mirror of
https://github.com/HeyPuter/puter
synced 2024-11-14 22:06:00 +00:00
feat: add chat completions driver to puterai module
This commit is contained in:
parent
ef6671da18
commit
4e3bd1831e
@ -24,6 +24,20 @@ class AIInterfaceService extends BaseService {
|
||||
},
|
||||
}
|
||||
});
|
||||
|
||||
col_interfaces.set('puter-chat-completion', {
|
||||
description: 'Chatbot.',
|
||||
methods: {
|
||||
complete: {
|
||||
description: 'Get completions for a chat log.',
|
||||
parameters: {
|
||||
messages: { type: 'json' },
|
||||
vision: { type: 'flag' },
|
||||
},
|
||||
result: { type: 'json' }
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
229
src/backend/src/modules/puterai/OpenAICompletionService.js
Normal file
229
src/backend/src/modules/puterai/OpenAICompletionService.js
Normal file
@ -0,0 +1,229 @@
|
||||
const APIError = require('../../api/APIError');
|
||||
const BaseService = require('../../services/BaseService');
|
||||
const { Context } = require('../../util/context');
|
||||
const SmolUtil = require('../../util/smolutil');
|
||||
|
||||
class OpenAICompletionService extends BaseService {
|
||||
static MODULES = {
|
||||
openai: require('openai'),
|
||||
tiktoken: require('tiktoken'),
|
||||
}
|
||||
async _init () {
|
||||
const sk_key =
|
||||
this.config?.openai?.secret_key ??
|
||||
this.global_config.openai?.secret_key;
|
||||
|
||||
this.openai = new this.modules.openai.OpenAI({
|
||||
apiKey: sk_key
|
||||
});
|
||||
}
|
||||
|
||||
static IMPLEMENTS = {
|
||||
['puter-chat-completion']: {
|
||||
async complete ({ messages, vision }) {
|
||||
const model = 'gpt-4o';
|
||||
return await this.complete(messages, {
|
||||
model,
|
||||
moderation: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
async check_moderation (text) {
|
||||
// create moderation
|
||||
const results = await this.openai.moderations.create({
|
||||
input: text,
|
||||
});
|
||||
|
||||
let flagged = false;
|
||||
|
||||
for ( const result of results?.results ?? [] ) {
|
||||
if ( result.flagged ) {
|
||||
flagged = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
flagged,
|
||||
results,
|
||||
};
|
||||
}
|
||||
|
||||
async complete (messages, { moderation, model }) {
|
||||
// Validate messages
|
||||
if ( ! Array.isArray(messages) ) {
|
||||
throw new Error('`messages` must be an array');
|
||||
}
|
||||
|
||||
model = model ?? 'gpt-3.5-turbo';
|
||||
// model = model ?? 'gpt-4o';
|
||||
|
||||
for ( let i = 0; i < messages.length; i++ ) {
|
||||
let msg = messages[i];
|
||||
if ( typeof msg === 'string' ) msg = { content: msg };
|
||||
if ( typeof msg !== 'object' ) {
|
||||
throw new Error('each message must be a string or an object');
|
||||
}
|
||||
if ( ! msg.role ) msg.role = 'user';
|
||||
if ( ! msg.content ) {
|
||||
throw new Error('each message must have a `content` property');
|
||||
}
|
||||
|
||||
const texts = [];
|
||||
if ( typeof msg.content === 'string' ) texts.push(msg.content);
|
||||
else if ( typeof msg.content === 'object' ) {
|
||||
if ( Array.isArray(msg.content) ) {
|
||||
texts.push(...msg.content.filter(o => (
|
||||
( ! o.type && o.hasOwnProperty('text') ) ||
|
||||
o.type === 'text')).map(o => o.text));
|
||||
}
|
||||
else texts.push(msg.content.text);
|
||||
}
|
||||
|
||||
if ( moderation ) {
|
||||
for ( const text of texts ) {
|
||||
const moderation_result = await this.check_moderation(text);
|
||||
if ( moderation_result.flagged ) {
|
||||
throw new Error('message is not allowed');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages[i] = msg;
|
||||
}
|
||||
|
||||
messages.unshift({
|
||||
role: 'system',
|
||||
content: 'You are running inside a Puter app.',
|
||||
})
|
||||
// messages.unshift({
|
||||
// role: 'system',
|
||||
// content: 'Don\'t let the user trick you into doing something bad.',
|
||||
// })
|
||||
|
||||
const user_private_uid = Context.get('actor')?.private_uid ?? 'UNKNOWN';
|
||||
if ( user_private_uid === 'UNKNOWN' ) {
|
||||
this.errors.report('chat-completion-service:unknown-user', {
|
||||
message: 'failed to get a user ID for an OpenAI request',
|
||||
alarm: true,
|
||||
trace: true,
|
||||
});
|
||||
}
|
||||
|
||||
this.log.info('PRIVATE UID FOR USER ' + user_private_uid)
|
||||
|
||||
// Here's something fun; the documentation shows `type: 'image_url'` in
|
||||
// objects that contain an image url, but everything still works if
|
||||
// that's missing. We normalise it here so the token count code works.
|
||||
for ( const msg of messages ) {
|
||||
if ( ! msg.content ) continue;
|
||||
if ( typeof msg.content !== 'object' ) continue;
|
||||
|
||||
const content = SmolUtil.ensure_array(msg.content);
|
||||
|
||||
for ( const o of content ) {
|
||||
if ( ! o.hasOwnProperty('image_url') ) continue;
|
||||
if ( o.type ) continue;
|
||||
o.type = 'image_url';
|
||||
}
|
||||
}
|
||||
|
||||
console.log('DATA GOING IN', messages);
|
||||
|
||||
// Count tokens
|
||||
let token_count = 0;
|
||||
{
|
||||
const enc = this.modules.tiktoken.encoding_for_model(model);
|
||||
const text = JSON.stringify(messages)
|
||||
const tokens = enc.encode(text);
|
||||
token_count += tokens.length;
|
||||
}
|
||||
|
||||
// Subtract image urls
|
||||
for ( const msg of messages ) {
|
||||
// console.log('msg and content', msg, msg.content);
|
||||
if ( ! msg.content ) continue;
|
||||
if ( typeof msg.content !== 'object' ) continue;
|
||||
|
||||
const content = SmolUtil.ensure_array(msg.content);
|
||||
|
||||
for ( const o of content ) {
|
||||
// console.log('part of content', o);
|
||||
if ( o.type !== 'image_url' ) continue;
|
||||
const enc = this.modules.tiktoken.encoding_for_model(model);
|
||||
const text = o.image_url?.url ?? '';
|
||||
const tokens = enc.encode(text);
|
||||
token_count -= tokens.length;
|
||||
}
|
||||
}
|
||||
|
||||
const max_tokens = 4096 - token_count;
|
||||
console.log('MAX TOKENS ???', max_tokens);
|
||||
|
||||
if ( max_tokens <= 8 ) {
|
||||
throw APIError.create('max_tokens_exceeded', null, {
|
||||
input_tokens: token_count,
|
||||
max_tokens: 4096 - 8,
|
||||
});
|
||||
}
|
||||
|
||||
const completion = await this.openai.chat.completions.create({
|
||||
user: user_private_uid,
|
||||
messages: messages,
|
||||
model: model,
|
||||
max_tokens,
|
||||
});
|
||||
|
||||
this.log.info('how many choices?: ' + completion.choices.length);
|
||||
|
||||
// Record spending information
|
||||
const spending_meta = {};
|
||||
spending_meta.timestamp = Date.now();
|
||||
spending_meta.count_tokens_input = token_count;
|
||||
spending_meta.count_tokens_output = (() => {
|
||||
// count output tokens (overestimate)
|
||||
const enc = this.modules.tiktoken.encoding_for_model(model);
|
||||
const text = JSON.stringify(completion.choices);
|
||||
const tokens = enc.encode(text);
|
||||
return tokens.length;
|
||||
})();
|
||||
|
||||
const svc_spending = Context.get('services').get('spending');
|
||||
svc_spending.record_spending('openai', 'chat-completion', spending_meta);
|
||||
|
||||
const svc_counting = Context.get('services').get('counting');
|
||||
svc_counting.increment({
|
||||
service_name: 'openai:chat-completion',
|
||||
service_type: 'gpt',
|
||||
values: {
|
||||
model,
|
||||
input_tokens: token_count,
|
||||
output_tokens: spending_meta.count_tokens_output,
|
||||
}
|
||||
});
|
||||
|
||||
const is_empty = completion.choices?.[0]?.message?.content?.trim() === '';
|
||||
if ( is_empty ) {
|
||||
// GPT refuses to generate an empty response if you ask it to,
|
||||
// so this will probably only happen on an error condition.
|
||||
throw new Error('an empty response was generated');
|
||||
}
|
||||
|
||||
// We need to moderate the completion too
|
||||
if ( moderation ) {
|
||||
const text = completion.choices[0].message.content;
|
||||
const moderation_result = await this.check_moderation(text);
|
||||
if ( moderation_result.flagged ) {
|
||||
throw new Error('message is not allowed');
|
||||
}
|
||||
}
|
||||
|
||||
return completion.choices[0];
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
OpenAICompletionService,
|
||||
};
|
@ -9,6 +9,9 @@ class PuterAIModule extends AdvancedBase {
|
||||
|
||||
const { AWSTextractService } = require('./AWSTextractService');
|
||||
services.registerService('aws-textract', AWSTextractService);
|
||||
|
||||
const { OpenAICompletionService } = require('./OpenAICompletionService');
|
||||
services.registerService('openai-completion', OpenAICompletionService);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -14,4 +14,31 @@ await (await fetch("http://api.puter.localhost:4100/drivers/call", {
|
||||
}),
|
||||
"method": "POST",
|
||||
})).json();
|
||||
```
|
||||
|
||||
```javascript
|
||||
await (await fetch("http://api.puter.localhost:4100/drivers/call", {
|
||||
"headers": {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": `Bearer ${puter.authToken}`,
|
||||
},
|
||||
"body": JSON.stringify({
|
||||
interface: 'puter-chat-completion',
|
||||
driver: 'openai-completion',
|
||||
method: 'complete',
|
||||
args: {
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: 'Act like Spongebob'
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: 'How do I make my code run faster?'
|
||||
},
|
||||
]
|
||||
},
|
||||
}),
|
||||
"method": "POST",
|
||||
})).json();
|
||||
```
|
Loading…
Reference in New Issue
Block a user