diff --git a/package.json b/package.json index db3580d..cbbf297 100644 --- a/package.json +++ b/package.json @@ -48,6 +48,7 @@ "access": "public" }, "dependencies": { - "fetch-ponyfill": "^7.1.0" + "fetch-ponyfill": "^7.1.0", + "handlebars": "^4.7.7" } } diff --git a/src/index.ts b/src/index.ts index ac06b64..4b12a7e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,10 +1,10 @@ import * as OpenAiImpl from './platforms/openai'; -export const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const; -export type PlatformConfig = OpenAiImpl.PlatformConfig; -export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter; +const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const; export * as OpenAi from './platforms/openai'; +export type PlatformConfig = OpenAiImpl.PlatformConfig; +export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter; export const createAiClient = (configParams: PlatformConfig): PlatformEventEmitter => { const { diff --git a/src/platforms/openai/chat.ts b/src/platforms/openai/chat.ts new file mode 100644 index 0000000..50fd960 --- /dev/null +++ b/src/platforms/openai/chat.ts @@ -0,0 +1,103 @@ +import Handlebars from 'handlebars'; +import { Message, MessageRole } from './message'; + +const isValidMessageObject = (maybeMessage: unknown): maybeMessage is Message => { + if (typeof maybeMessage !== 'object') { + return false; + } + + if (maybeMessage === null) { + return false; + } + + const messageObject = maybeMessage as Record; + + return ( + Object.values(MessageRole).includes(messageObject.role as MessageRole) + && typeof messageObject.content === 'string' + ); +}; + +export const normalizeChatMessage = (messageRaw: Message | Message[]) => { + if (typeof messageRaw === 'string') { + return [ + { + role: MessageRole.USER, + content: messageRaw, + }, + ]; + } + + if (Array.isArray(messageRaw)) { + return messageRaw.map((message) => { + if (typeof message === 'string') { + return { + role: MessageRole.USER, + content: message, + }; + } + + if (isValidMessageObject(message)) { + return message; + } + + throw new Error('Invalid message format'); + }); + } + + if (isValidMessageObject(messageRaw)) { + return [messageRaw]; + } + + throw new Error('Invalid message format'); +}; + +export const buildChatFromTranscript = (transcript: string) => { + const parameterized = Handlebars.create().compile(transcript, { + noEscape: true, + ignoreStandalone: true, + strict: true, + preventIndent: true, + }); + + return (params: Record) => { + const compiled = parameterized(params); + const prompts = compiled.split('\n---\n'); + return prompts.map((prompt) => { + const lines = prompt.trim().split('\n\n'); + let lastRole = MessageRole.USER; + return lines.filter((s) => s.trim().length > 0).map((lineRaw) => { + const line = lineRaw.replace(/\n/g, ' '); + const lineCheckRole = line.toLowerCase(); + if (lineCheckRole.startsWith('system:')) { + lastRole = MessageRole.SYSTEM; + return { + role: MessageRole.SYSTEM, + content: line.substring('system:'.length).trim(), + }; + } + + if (lineCheckRole.startsWith('user:')) { + lastRole = MessageRole.USER; + return { + role: MessageRole.USER, + content: line.substring('user:'.length).trim(), + }; + } + + if (lineCheckRole.startsWith('assistant:')) { + lastRole = MessageRole.ASSISTANT; + return { + role: MessageRole.ASSISTANT, + content: line.substring('assistant:'.length).trim(), + }; + } + + return { + role: lastRole, + content: line.trim(), + }; + }); + }); + }; +}; diff --git a/src/platforms/openai/common.ts b/src/platforms/openai/common.ts index ed6439f..cb85b0d 100644 --- a/src/platforms/openai/common.ts +++ b/src/platforms/openai/common.ts @@ -1,5 +1,3 @@ -import { Message, MessageRole } from './message'; - export enum FinishReason { STOP = 'stop', LENGTH = 'length', @@ -48,27 +46,13 @@ export class PlatformError extends Error { } } -export const normalizeChatMessage = (messageRaw: Message | Message[]) => { - if (typeof messageRaw === 'string') { - return [ - { - role: MessageRole.USER, - content: messageRaw, - }, - ]; - } - - if (Array.isArray(messageRaw)) { - return messageRaw.map((message) => { - if (typeof message === 'string') { - return { - role: MessageRole.USER, - content: message, - }; - } - return message; - }); - } +export enum ApiVersion { + V1 = 'v1', +} - return messageRaw; -}; +export interface Configuration { + organizationId?: string; + apiVersion: ApiVersion; + apiKey: string; + baseUrl?: string; +} diff --git a/src/platforms/openai/events.ts b/src/platforms/openai/events.ts index 7c41c05..b3ef233 100644 --- a/src/platforms/openai/events.ts +++ b/src/platforms/openai/events.ts @@ -1,7 +1,11 @@ -import { CreateChatCompletionParams } from './features/chat-completion'; -import { CreateImageParams } from './features/image'; -import { CreateTextCompletionParams } from './features/text-completion'; -import { CreateEditParams } from './features/edit'; +import { PassThrough } from 'stream'; +import { EventEmitter } from 'events'; +import fetchPonyfill from 'fetch-ponyfill'; +import { Configuration } from './common'; +import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion'; +import { CreateChatCompletionParams, createChatCompletion } from './features/chat-completion'; +import { CreateImageParams, createImage } from './features/image'; +import { CreateEditParams, createEdit } from './features/edit'; export type DataEventCallback = (data: D) => void; @@ -16,3 +20,73 @@ export interface PlatformEventEmitter extends NodeJS.EventEmitter { on(event: 'end', callback: () => void): this; on(event: 'error', callback: ErrorEventCallback): this; } + +export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter { + readonly createCompletion: PlatformEventEmitter['createCompletion']; + + readonly createImage: PlatformEventEmitter['createImage']; + + readonly createChatCompletion: PlatformEventEmitter['createChatCompletion']; + + readonly createEdit: PlatformEventEmitter['createEdit']; + + constructor(configParams: Configuration) { + super(); + const headers: Record = { + Authorization: `Bearer ${configParams.apiKey}`, + }; + + if (configParams.organizationId) { + headers['OpenAI-Organization'] = configParams.organizationId; + } + + const { fetch: fetchInstance } = fetchPonyfill(); + const doFetch = (method: string, path: string, body: Record) => { + const theFetchParams = { + method, + headers: { + ...headers, + 'Content-Type': 'application/json', + }, + body: JSON.stringify(body), + }; + + const url = new URL( + `/${configParams.apiVersion}${path}`, + configParams.baseUrl ?? 'https://api.openai.com', + ).toString(); + + this.emit('start', { + ...theFetchParams, + url, + }); + + return fetchInstance(url, theFetchParams); + }; + + const consumeStream = async (response: Response) => { + // eslint-disable-next-line no-restricted-syntax + for await (const chunk of response.body as unknown as PassThrough) { + const chunkStringMaybeMultiple = chunk.toString(); + const chunkStrings = chunkStringMaybeMultiple + .split('\n') + .filter((chunkString: string) => chunkString.length > 0); + chunkStrings.forEach((chunkString: string) => { + const dataRaw = chunkString.split('data: ').at(1); + if (!dataRaw) { + return; + } + if (dataRaw === '[DONE]') { + return; + } + const data = JSON.parse(dataRaw); + this.emit('data', data); + }); + } + }; + this.createImage = createImage.bind(this, doFetch); + this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream); + this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream); + this.createEdit = createEdit.bind(this, doFetch); + } +} diff --git a/src/platforms/openai/features/chat-completion.ts b/src/platforms/openai/features/chat-completion.ts index da9b997..66e92a3 100644 --- a/src/platforms/openai/features/chat-completion.ts +++ b/src/platforms/openai/features/chat-completion.ts @@ -3,13 +3,13 @@ import { ConsumeStream, DataEventId, DoFetch, - normalizeChatMessage, PlatformError, PlatformResponse, UsageMetadata, } from '../common'; import { Message, MessageObject } from '../message'; import { ChatCompletionModel } from '../models'; +import { normalizeChatMessage } from '../chat'; export interface CreateChatCompletionParams { messages: Message | Message[]; diff --git a/src/platforms/openai/index.ts b/src/platforms/openai/index.ts index 2673242..5b87ccb 100644 --- a/src/platforms/openai/index.ts +++ b/src/platforms/openai/index.ts @@ -1,20 +1,16 @@ -import fetchPonyfill from 'fetch-ponyfill'; -import { EventEmitter } from 'events'; -import { PassThrough } from 'stream'; -import { PlatformEventEmitter } from './events'; -import { createTextCompletion, TextCompletion } from './features/text-completion'; -import { createImage } from './features/image'; -import { createChatCompletion, ChatCompletion } from './features/chat-completion'; -import { createEdit } from './features/edit'; +import { Configuration } from './common'; export * from './message'; export * from './models'; -export { PlatformEventEmitter, ChatCompletion, TextCompletion }; +export * from './common'; +export { PlatformEventEmitter, PlatformEventEmitterImpl } from './events'; export { + ChatCompletion, ChatCompletionChunkDataEvent, DataEventObjectType as ChatCompletionDataEventObjectType, } from './features/chat-completion'; export { + TextCompletion, TextCompletionChunkDataEvent, DataEventObjectType as TextCompletionDataEventObjectType, } from './features/text-completion'; @@ -23,11 +19,6 @@ export { DataEventObjectType as EditDataEventObjectType, } from './features/edit'; export { CreateImageDataEvent, CreateImageSize } from './features/image'; -export * from './common'; - -export enum ApiVersion { - V1 = 'v1', -} export const PLATFORM_ID = 'openai' as const; @@ -35,80 +26,3 @@ export interface PlatformConfig { platform: typeof PLATFORM_ID; platformConfiguration: Configuration; } - -export interface Configuration { - organizationId?: string; - apiVersion: ApiVersion; - apiKey: string; - baseUrl?: string; -} - -export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter { - readonly createCompletion: PlatformEventEmitter['createCompletion']; - - readonly createImage: PlatformEventEmitter['createImage']; - - readonly createChatCompletion: PlatformEventEmitter['createChatCompletion']; - - readonly createEdit: PlatformEventEmitter['createEdit']; - - constructor(configParams: Configuration) { - super(); - const headers: Record = { - Authorization: `Bearer ${configParams.apiKey}`, - }; - - if (configParams.organizationId) { - headers['OpenAI-Organization'] = configParams.organizationId; - } - - const { fetch: fetchInstance } = fetchPonyfill(); - const doFetch = (method: string, path: string, body: Record) => { - const theFetchParams = { - method, - headers: { - ...headers, - 'Content-Type': 'application/json', - }, - body: JSON.stringify(body), - }; - - const url = new URL( - `/${configParams.apiVersion}${path}`, - configParams.baseUrl ?? 'https://api.openai.com', - ).toString(); - - this.emit('start', { - ...theFetchParams, - url, - }); - - return fetchInstance(url, theFetchParams); - }; - - const consumeStream = async (response: Response) => { - // eslint-disable-next-line no-restricted-syntax - for await (const chunk of response.body as unknown as PassThrough) { - const chunkStringMaybeMultiple = chunk.toString(); - const chunkStrings = chunkStringMaybeMultiple - .split('\n') - .filter((chunkString: string) => chunkString.length > 0); - chunkStrings.forEach((chunkString: string) => { - const dataRaw = chunkString.split('data: ').at(1); - if (!dataRaw) { - return; - } - if (dataRaw === '[DONE]') { - return; - } - const data = JSON.parse(dataRaw); - this.emit('data', data); - }); - } - }; - this.createImage = createImage.bind(this, doFetch); - this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream); - this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream); - this.createEdit = createEdit.bind(this, doFetch); - } -} diff --git a/test/index.test.ts b/test/platforms/openai/api.test.ts similarity index 96% rename from test/index.test.ts rename to test/platforms/openai/api.test.ts index 1636328..83329c1 100644 --- a/test/index.test.ts +++ b/test/platforms/openai/api.test.ts @@ -10,14 +10,14 @@ import { createAiClient, PlatformEventEmitter, OpenAi, -} from '../src'; +} from '../../../src'; -describe('ai-utils', () => { +describe('OpenAI', () => { beforeAll(() => { config(); }); - describe('OpenAI', () => { + describe.skip('API', () => { let aiClient: PlatformEventEmitter; beforeEach(() => { @@ -31,7 +31,7 @@ describe('ai-utils', () => { }); }); - describe.skip('createChatCompletion', () => { + describe('createChatCompletion', () => { let result: Partial | undefined; beforeEach(() => { @@ -100,7 +100,7 @@ describe('ai-utils', () => { }), { timeout: 10000 }); }); - describe.skip('createImage', () => { + describe('createImage', () => { it('works', () => new Promise((resolve, reject) => { aiClient.on('data', (r) => { expect(r).toHaveProperty('created', expect.any(Number)); @@ -123,7 +123,7 @@ describe('ai-utils', () => { }), { timeout: 10000 }); }); - describe.skip('createCompletion', () => { + describe('createCompletion', () => { let result: Partial | undefined; beforeEach(() => { @@ -187,7 +187,7 @@ describe('ai-utils', () => { }), { timeout: 10000 }); }); - describe.skip('createEdit', () => { + describe('createEdit', () => { it('works', () => new Promise((resolve, reject) => { aiClient.on('data', (r) => { expect(r).toHaveProperty('object', OpenAi.EditDataEventObjectType.EDIT); diff --git a/test/platforms/openai/chat.test.ts b/test/platforms/openai/chat.test.ts new file mode 100644 index 0000000..70366ed --- /dev/null +++ b/test/platforms/openai/chat.test.ts @@ -0,0 +1,194 @@ +import { describe, it, expect } from 'vitest'; +import * as Chat from '../../../src/platforms/openai/chat'; +import { MessageRole } from '../../../src/platforms/openai'; + +describe('OpenAI', () => { + describe('chat', () => { + describe('normalizeChatMessage', () => { + it('normalizes a basic string', () => { + const message = Chat.normalizeChatMessage('This is a user message.'); + + expect(message).toHaveLength(1); + + expect(message).toContainEqual({ + role: 'user', + content: 'This is a user message.', + }); + }); + + it('normalizes a string array', () => { + const message = Chat.normalizeChatMessage([ + 'This is a user message.', + 'This is another user message.', + ]); + + expect(message).toHaveLength(2); + + expect(message).toContainEqual({ + role: 'user', + content: 'This is a user message.', + }); + + expect(message).toContainEqual({ + role: 'user', + content: 'This is another user message.', + }); + }); + + it('normalizes a message object', () => { + const message = Chat.normalizeChatMessage({ + role: MessageRole.USER, + content: 'This is a user message.', + }); + + expect(message).toHaveLength(1); + + expect(message).toContainEqual({ + role: 'user', + content: 'This is a user message.', + }); + }); + + it('normalizes a message object array', () => { + const message = Chat.normalizeChatMessage([ + { + role: MessageRole.USER, + content: 'This is a user message.', + }, + { + role: MessageRole.USER, + content: 'This is another user message.', + }, + ]); + + expect(message).toHaveLength(2); + + expect(message).toContainEqual({ + role: 'user', + content: 'This is a user message.', + }); + + expect(message).toContainEqual({ + role: 'user', + content: 'This is another user message.', + }); + }); + }); + + describe('buildChatFromTranscript', () => { + it('processes line breaks correctly', () => { + const message = ` +SYSTEM: This is a system message. This is a chat from the +user: This is a user message. +`; + const parameterized = Chat.buildChatFromTranscript(message); + const prompts = parameterized({}); + + expect(prompts[0]).toHaveLength(1); + + expect(prompts[0]).toContainEqual({ + role: 'system', + content: 'This is a system message. This is a chat from the user: This is a user message.', + }); + }); + + it('makes distinctions between different dialogues', () => { + const message = ` +SYSTEM: This is a system message. This is a chat from the + +user: This is a user message. +`; + const parameterized = Chat.buildChatFromTranscript(message); + const prompts = parameterized({}); + + expect(prompts[0]).toHaveLength(2); + + expect(prompts[0]).toContainEqual({ + role: 'system', + content: 'This is a system message. This is a chat from the', + }); + + expect(prompts[0]).toContainEqual({ + role: 'user', + content: 'This is a user message.', + }); + }); + + it('builds an array of chat messages from a single string.', () => { + const message = ` +SYSTEM: This is a system message. + +USER: This is a user message. + +SYSTEM: This is another system message. + +USER: This is another user message. + +ASSISTANT: This is an assistant message. +`; + const parameterized = Chat.buildChatFromTranscript(message); + const prompts = parameterized({}); + + expect(prompts[0]).toHaveLength(5); + + expect(prompts[0]).toContainEqual({ + role: 'system', + content: 'This is a system message.', + }); + + expect(prompts[0]).toContainEqual({ + role: 'user', + content: 'This is a user message.', + }); + + expect(prompts[0]).toContainEqual({ + role: 'system', + content: 'This is another system message.', + }); + + expect(prompts[0]).toContainEqual({ + role: 'user', + content: 'This is another user message.', + }); + + expect(prompts[0]).toContainEqual({ + role: 'assistant', + content: 'This is an assistant message.', + }); + }); + + it('builds multiple chat messages with a divider.', () => { + const message = ` +SYSTEM: This is a system message. +--- +USER: This is a user message. +`; + const parameterized = Chat.buildChatFromTranscript(message); + const prompts = parameterized({}); + + expect(prompts).toHaveLength(2); + expect(prompts[0]).toContainEqual({ + role: 'system', + content: 'This is a system message.', + }); + expect(prompts[1]).toContainEqual({ + role: 'user', + content: 'This is a user message.', + }); + }); + + it('injects parameters into the chat messages.', () => { + const message = ` + Say {{name}}. {{htmlChar}} \\{{escaped}} + `; + + const parameterized = Chat.buildChatFromTranscript(message); + const prompts = parameterized({ name: 'Hello', htmlChar: '' }); + expect(prompts[0][0]).toEqual({ + role: 'user', + content: 'Say Hello. {{escaped}}', + }); + }); + }); + }); +}); diff --git a/yarn.lock b/yarn.lock index e0e4ef1..e0bdb14 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1923,6 +1923,18 @@ grapheme-splitter@^1.0.4: resolved "https://registry.yarnpkg.com/grapheme-splitter/-/grapheme-splitter-1.0.4.tgz#9cf3a665c6247479896834af35cf1dbb4400767e" integrity sha512-bzh50DW9kTPM00T8y4o8vQg89Di9oLJVLW/KaOGIXJWP/iqCN6WKYkbNOF04vFLJhwcpYUh9ydh/+5vpOqV4YQ== +handlebars@^4.7.7: + version "4.7.7" + resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.7.7.tgz#9ce33416aad02dbd6c8fafa8240d5d98004945a1" + integrity sha512-aAcXm5OAfE/8IXkcZvCepKU3VzW1/39Fb5ZuqMtgI/hT8X2YgoMvBY5dLhq/cpOvw7Lk1nK/UF71aLG/ZnVYRA== + dependencies: + minimist "^1.2.5" + neo-async "^2.6.0" + source-map "^0.6.1" + wordwrap "^1.0.0" + optionalDependencies: + uglify-js "^3.1.4" + has-bigints@^1.0.1, has-bigints@^1.0.2: version "1.0.2" resolved "https://registry.yarnpkg.com/has-bigints/-/has-bigints-1.0.2.tgz#0871bd3e3d51626f6ca0966668ba35d5602d6eaa" @@ -2440,7 +2452,7 @@ minimatch@^3.0.4, minimatch@^3.0.5, minimatch@^3.1.1, minimatch@^3.1.2: dependencies: brace-expansion "^1.1.7" -minimist@^1.2.0, minimist@^1.2.6: +minimist@^1.2.0, minimist@^1.2.5, minimist@^1.2.6: version "1.2.8" resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.8.tgz#c1a464e7693302e082a075cee0c057741ac4772c" integrity sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA== @@ -2480,6 +2492,11 @@ natural-compare@^1.4.0: resolved "https://registry.yarnpkg.com/natural-compare/-/natural-compare-1.4.0.tgz#4abebfeed7541f2c27acfb29bdbbd15c8d5ba4f7" integrity sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw== +neo-async@^2.6.0: + version "2.6.2" + resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.2.tgz#b4aafb93e3aeb2d8174ca53cf163ab7d7308305f" + integrity sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw== + node-fetch@~2.6.1: version "2.6.9" resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.6.9.tgz#7c7f744b5cc6eb5fd404e0c7a9fec630a55657e6" @@ -3301,6 +3318,11 @@ ufo@^1.1.1: resolved "https://registry.yarnpkg.com/ufo/-/ufo-1.1.1.tgz#e70265e7152f3aba425bd013d150b2cdf4056d7c" integrity sha512-MvlCc4GHrmZdAllBc0iUDowff36Q9Ndw/UzqmEKyrfSzokTd9ZCy1i+IIk5hrYKkjoYVQyNbrw7/F8XJ2rEwTg== +uglify-js@^3.1.4: + version "3.17.4" + resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.17.4.tgz#61678cf5fa3f5b7eb789bb345df29afb8257c22c" + integrity sha512-T9q82TJI9e/C1TAxYvfb16xO120tMVFZrGA3f9/P4424DNu6ypK103y0GPFVa17yotwSyZW5iYXgjYHkGrJW/g== + unbox-primitive@^1.0.2: version "1.0.2" resolved "https://registry.yarnpkg.com/unbox-primitive/-/unbox-primitive-1.0.2.tgz#29032021057d5e6cdbd08c5129c226dff8ed6f9e" @@ -3485,6 +3507,11 @@ word-wrap@^1.2.3: resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c" integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ== +wordwrap@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb" + integrity sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q== + wrap-ansi@^6.2.0: version "6.2.0" resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53"