feat: add chat completions driver to puterai module

This commit is contained in:
KernelDeimos 2024-08-01 15:12:23 -04:00 committed by Eric Dubé
parent ef6671da18
commit 4e3bd1831e
4 changed files with 273 additions and 0 deletions

View File

@ -24,6 +24,20 @@ class AIInterfaceService extends BaseService {
},
}
});
// Register the generic chat-completion interface; drivers such as
// 'openai-completion' implement this to provide chatbot functionality.
col_interfaces.set('puter-chat-completion', {
description: 'Chatbot.',
methods: {
complete: {
description: 'Get completions for a chat log.',
parameters: {
// Chat log: array of message objects (or plain strings).
messages: { type: 'json' },
// NOTE(review): presumably opts into image ("vision") inputs;
// the visible driver implementation does not read it yet — confirm.
vision: { type: 'flag' },
},
result: { type: 'json' }
}
}
});
}
}

View File

@ -0,0 +1,229 @@
const APIError = require('../../api/APIError');
const BaseService = require('../../services/BaseService');
const { Context } = require('../../util/context');
const SmolUtil = require('../../util/smolutil');
/**
 * OpenAICompletionService implements the 'puter-chat-completion' driver
 * interface using OpenAI's chat completions API. It handles input and
 * output moderation, input token counting (with a hard 4096-token budget),
 * and records spending/usage metrics.
 */
class OpenAICompletionService extends BaseService {
    static MODULES = {
        openai: require('openai'),
        tiktoken: require('tiktoken'),
    }

    /**
     * Initialize the OpenAI client. The secret key is read from this
     * service's own config, falling back to the global config.
     */
    async _init () {
        const sk_key =
            this.config?.openai?.secret_key ??
            this.global_config.openai?.secret_key;

        this.openai = new this.modules.openai.OpenAI({
            apiKey: sk_key
        });
    }

    static IMPLEMENTS = {
        ['puter-chat-completion']: {
            /**
             * Driver method for the 'complete' operation of the
             * 'puter-chat-completion' interface. Always moderates input
             * and output, and always uses the gpt-4o model.
             *
             * @param {object} args
             * @param {Array} args.messages - the chat log to complete
             * @param {boolean} args.vision - declared by the interface but
             *   not used here yet — TODO confirm intended semantics.
             */
            async complete ({ messages, vision }) {
                const model = 'gpt-4o';
                return await this.complete(messages, {
                    model,
                    moderation: true,
                });
            }
        }
    };

    /**
     * Run a text through OpenAI's moderation endpoint.
     *
     * @param {string} text - the text to check
     * @returns {Promise<{flagged: boolean, results: object}>}
     *   `flagged` is true when any moderation result flagged the text;
     *   `results` is the raw moderation API response.
     */
    async check_moderation (text) {
        const results = await this.openai.moderations.create({
            input: text,
        });

        // Flagged if any individual moderation result is flagged.
        const flagged = (results?.results ?? []).some(result => result.flagged);

        return {
            flagged,
            results,
        };
    }

    /**
     * Request a chat completion from OpenAI.
     *
     * Unlike a naive implementation, the caller's `messages` array and the
     * message objects inside it are never mutated; normalization happens on
     * copies.
     *
     * @param {Array<string|object>} messages - chat log; plain strings are
     *   treated as user messages.
     * @param {object} options
     * @param {boolean} options.moderation - when true, every text part of
     *   the input (and the completion output) is checked against OpenAI's
     *   moderation endpoint before being accepted.
     * @param {string} [options.model] - OpenAI model name
     *   (default: 'gpt-3.5-turbo').
     * @returns {Promise<object>} the first completion choice
     * @throws {Error} when messages are malformed, a message is flagged by
     *   moderation, or the model produced an empty response.
     * @throws {APIError} 'max_tokens_exceeded' when the input leaves too
     *   little room (<= 8 tokens) for the response.
     */
    async complete (messages, { moderation, model }) {
        // Validate messages
        if ( ! Array.isArray(messages) ) {
            throw new Error('`messages` must be an array');
        }

        model = model ?? 'gpt-3.5-turbo';

        // Normalize into local copies so the caller's objects are never
        // mutated by this method.
        messages = messages.map(m => this._normalize_message_(m));

        // Check every text part of the input against the moderation
        // endpoint before anything is sent to the completion endpoint.
        if ( moderation ) {
            for ( const msg of messages ) {
                for ( const text of this._extract_texts_(msg.content) ) {
                    const moderation_result = await this.check_moderation(text);
                    if ( moderation_result.flagged ) {
                        throw new Error('message is not allowed');
                    }
                }
            }
        }

        messages.unshift({
            role: 'system',
            content: 'You are running inside a Puter app.',
        });

        // The user's pseudonymous id is forwarded to OpenAI (the `user`
        // field) for abuse attribution; a missing actor is reported with an
        // alarm but is not fatal.
        const user_private_uid = Context.get('actor')?.private_uid ?? 'UNKNOWN';
        if ( user_private_uid === 'UNKNOWN' ) {
            this.errors.report('chat-completion-service:unknown-user', {
                message: 'failed to get a user ID for an OpenAI request',
                alarm: true,
                trace: true,
            });
        }
        this.log.info('PRIVATE UID FOR USER ' + user_private_uid)

        // Here's something fun; the documentation shows `type: 'image_url'`
        // in objects that contain an image url, but everything still works
        // if that's missing. We normalise it here so the token count code
        // works. (These content objects are our copies, so this is safe.)
        for ( const msg of messages ) {
            if ( ! msg.content ) continue;
            if ( typeof msg.content !== 'object' ) continue;
            for ( const o of SmolUtil.ensure_array(msg.content) ) {
                if ( ! Object.hasOwn(o, 'image_url') ) continue;
                if ( o.type ) continue;
                o.type = 'image_url';
            }
        }

        // Count input tokens. Serializing the whole chat log over-estimates;
        // tokens counted for image urls are subtracted afterwards since an
        // image url is not billed as text.
        const token_count = this._count_input_tokens_(messages, model);

        const max_tokens = 4096 - token_count;
        if ( max_tokens <= 8 ) {
            throw APIError.create('max_tokens_exceeded', null, {
                input_tokens: token_count,
                max_tokens: 4096 - 8,
            });
        }

        const completion = await this.openai.chat.completions.create({
            user: user_private_uid,
            messages: messages,
            model: model,
            max_tokens,
        });

        this.log.info('how many choices?: ' + completion.choices.length);

        // Record spending information
        const spending_meta = {
            timestamp: Date.now(),
            count_tokens_input: token_count,
            // Output tokens are over-estimated by serializing all choices.
            count_tokens_output: this._count_tokens_(
                JSON.stringify(completion.choices), model),
        };
        const svc_spending = Context.get('services').get('spending');
        svc_spending.record_spending('openai', 'chat-completion', spending_meta);

        const svc_counting = Context.get('services').get('counting');
        svc_counting.increment({
            service_name: 'openai:chat-completion',
            service_type: 'gpt',
            values: {
                model,
                input_tokens: token_count,
                output_tokens: spending_meta.count_tokens_output,
            }
        });

        const is_empty = completion.choices?.[0]?.message?.content?.trim() === '';
        if ( is_empty ) {
            // GPT refuses to generate an empty response if you ask it to,
            // so this will probably only happen on an error condition.
            throw new Error('an empty response was generated');
        }

        // We need to moderate the completion too
        if ( moderation ) {
            const text = completion.choices[0].message.content;
            const moderation_result = await this.check_moderation(text);
            if ( moderation_result.flagged ) {
                throw new Error('message is not allowed');
            }
        }

        return completion.choices[0];
    }

    /**
     * Coerce one chat message into the object form expected by OpenAI,
     * returning a copy so the caller's object is not mutated. Plain strings
     * become `{ role: 'user', content: <string> }`; a missing role defaults
     * to 'user'.
     *
     * @throws {Error} when the message is neither a string nor an object,
     *   or has no `content`.
     */
    _normalize_message_ (message) {
        let msg = message;
        if ( typeof msg === 'string' ) msg = { content: msg };
        if ( typeof msg !== 'object' || msg === null ) {
            throw new Error('each message must be a string or an object');
        }
        msg = { ...msg };
        if ( ! msg.role ) msg.role = 'user';
        if ( ! msg.content ) {
            throw new Error('each message must have a `content` property');
        }
        // Copy structured content (array of parts, or a single part object)
        // so later normalization (adding `type: 'image_url'`) doesn't touch
        // the caller's objects.
        if ( Array.isArray(msg.content) ) {
            msg.content = msg.content.map(o =>
                ( typeof o === 'object' && o !== null ) ? { ...o } : o);
        } else if ( typeof msg.content === 'object' ) {
            msg.content = { ...msg.content };
        }
        return msg;
    }

    /**
     * Collect the text parts of a message's `content`, which may be a plain
     * string, a single content-part object, or an array of parts. Parts
     * without string text (e.g. image parts) are excluded, so the result is
     * always safe to send to the moderation endpoint.
     */
    _extract_texts_ (content) {
        if ( typeof content === 'string' ) return [content];
        if ( typeof content !== 'object' || content === null ) return [];

        const parts = SmolUtil.ensure_array(content);
        return parts
            .filter(o => ( ! o.type && Object.hasOwn(o, 'text') )
                || o.type === 'text')
            .map(o => o.text)
            .filter(text => typeof text === 'string');
    }

    /**
     * Count the tokens of `text` for `model` using tiktoken, releasing the
     * encoder's WASM memory when done.
     */
    _count_tokens_ (text, model) {
        const enc = this.modules.tiktoken.encoding_for_model(model);
        try {
            return enc.encode(text).length;
        } finally {
            // tiktoken encoders hold WASM memory that must be freed.
            enc.free();
        }
    }

    /**
     * Estimate the input token count: tokens of the JSON-serialized chat
     * log, minus tokens counted for image urls. One encoder is created for
     * the whole pass and freed afterwards.
     */
    _count_input_tokens_ (messages, model) {
        const enc = this.modules.tiktoken.encoding_for_model(model);
        try {
            let token_count = enc.encode(JSON.stringify(messages)).length;

            // Subtract image urls
            for ( const msg of messages ) {
                if ( ! msg.content ) continue;
                if ( typeof msg.content !== 'object' ) continue;
                for ( const o of SmolUtil.ensure_array(msg.content) ) {
                    if ( o.type !== 'image_url' ) continue;
                    token_count -= enc.encode(o.image_url?.url ?? '').length;
                }
            }

            return token_count;
        } finally {
            // tiktoken encoders hold WASM memory that must be freed.
            enc.free();
        }
    }
}
// Export the service class; it is registered under the name
// 'openai-completion' by PuterAIModule.
module.exports = {
OpenAICompletionService,
};

View File

@ -9,6 +9,9 @@ class PuterAIModule extends AdvancedBase {
const { AWSTextractService } = require('./AWSTextractService');
services.registerService('aws-textract', AWSTextractService);
// Register the chat completions driver; it implements the
// 'puter-chat-completion' interface via OpenAI.
const { OpenAICompletionService } = require('./OpenAICompletionService');
services.registerService('openai-completion', OpenAICompletionService);
}
}

View File

@ -14,4 +14,31 @@ await (await fetch("http://api.puter.localhost:4100/drivers/call", {
}),
"method": "POST",
})).json();
```
```javascript
await (await fetch("http://api.puter.localhost:4100/drivers/call", {
"headers": {
"Content-Type": "application/json",
"Authorization": `Bearer ${puter.authToken}`,
},
"body": JSON.stringify({
interface: 'puter-chat-completion',
driver: 'openai-completion',
method: 'complete',
args: {
messages: [
{
role: 'system',
content: 'Act like Spongebob'
},
{
role: 'user',
content: 'How do I make my code run faster?'
},
]
},
}),
"method": "POST",
})).json();
```