Browse Source

Refactor structure, add chat utils

Reduce coupling by putting types and definitions in their appropriate
files.

Chat utils for creating prompts and messages have been added.
master
TheoryOfNekomata 1 year ago
parent
commit
9959b0bfd8
10 changed files with 430 additions and 133 deletions
  1. +2
    -1
      package.json
  2. +3
    -3
      src/index.ts
  3. +103
    -0
      src/platforms/openai/chat.ts
  4. +9
    -25
      src/platforms/openai/common.ts
  5. +78
    -4
      src/platforms/openai/events.ts
  6. +1
    -1
      src/platforms/openai/features/chat-completion.ts
  7. +5
    -91
      src/platforms/openai/index.ts
  8. +7
    -7
      test/platforms/openai/api.test.ts
  9. +194
    -0
      test/platforms/openai/chat.test.ts
  10. +28
    -1
      yarn.lock

+ 2
- 1
package.json View File

@@ -48,6 +48,7 @@
"access": "public"
},
"dependencies": {
"fetch-ponyfill": "^7.1.0"
"fetch-ponyfill": "^7.1.0",
"handlebars": "^4.7.7"
}
}

+ 3
- 3
src/index.ts View File

@@ -1,10 +1,10 @@
import * as OpenAiImpl from './platforms/openai';

export const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const;
export type PlatformConfig = OpenAiImpl.PlatformConfig;
export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter;
const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const;

export * as OpenAi from './platforms/openai';
export type PlatformConfig = OpenAiImpl.PlatformConfig;
export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter;

export const createAiClient = (configParams: PlatformConfig): PlatformEventEmitter => {
const {


+ 103
- 0
src/platforms/openai/chat.ts View File

@@ -0,0 +1,103 @@
import Handlebars from 'handlebars';
import { Message, MessageRole } from './message';

const isValidMessageObject = (maybeMessage: unknown): maybeMessage is Message => {
if (typeof maybeMessage !== 'object') {
return false;
}

if (maybeMessage === null) {
return false;
}

const messageObject = maybeMessage as Record<string, unknown>;

return (
Object.values(MessageRole).includes(messageObject.role as MessageRole)
&& typeof messageObject.content === 'string'
);
};

export const normalizeChatMessage = (messageRaw: Message | Message[]) => {
if (typeof messageRaw === 'string') {
return [
{
role: MessageRole.USER,
content: messageRaw,
},
];
}

if (Array.isArray(messageRaw)) {
return messageRaw.map((message) => {
if (typeof message === 'string') {
return {
role: MessageRole.USER,
content: message,
};
}

if (isValidMessageObject(message)) {
return message;
}

throw new Error('Invalid message format');
});
}

if (isValidMessageObject(messageRaw)) {
return [messageRaw];
}

throw new Error('Invalid message format');
};

export const buildChatFromTranscript = (transcript: string) => {
const parameterized = Handlebars.create().compile(transcript, {
noEscape: true,
ignoreStandalone: true,
strict: true,
preventIndent: true,
});

return (params: Record<string, unknown>) => {
const compiled = parameterized(params);
const prompts = compiled.split('\n---\n');
return prompts.map((prompt) => {
const lines = prompt.trim().split('\n\n');
let lastRole = MessageRole.USER;
return lines.filter((s) => s.trim().length > 0).map((lineRaw) => {
const line = lineRaw.replace(/\n/g, ' ');
const lineCheckRole = line.toLowerCase();
if (lineCheckRole.startsWith('system:')) {
lastRole = MessageRole.SYSTEM;
return {
role: MessageRole.SYSTEM,
content: line.substring('system:'.length).trim(),
};
}

if (lineCheckRole.startsWith('user:')) {
lastRole = MessageRole.USER;
return {
role: MessageRole.USER,
content: line.substring('user:'.length).trim(),
};
}

if (lineCheckRole.startsWith('assistant:')) {
lastRole = MessageRole.ASSISTANT;
return {
role: MessageRole.ASSISTANT,
content: line.substring('assistant:'.length).trim(),
};
}

return {
role: lastRole,
content: line.trim(),
};
});
});
};
};

+ 9
- 25
src/platforms/openai/common.ts View File

@@ -1,5 +1,3 @@
import { Message, MessageRole } from './message';

export enum FinishReason {
STOP = 'stop',
LENGTH = 'length',
@@ -48,27 +46,13 @@ export class PlatformError extends Error {
}
}

export const normalizeChatMessage = (messageRaw: Message | Message[]) => {
if (typeof messageRaw === 'string') {
return [
{
role: MessageRole.USER,
content: messageRaw,
},
];
}

if (Array.isArray(messageRaw)) {
return messageRaw.map((message) => {
if (typeof message === 'string') {
return {
role: MessageRole.USER,
content: message,
};
}
return message;
});
}
export enum ApiVersion {
V1 = 'v1',
}

return messageRaw;
};
export interface Configuration {
organizationId?: string;
apiVersion: ApiVersion;
apiKey: string;
baseUrl?: string;
}

+ 78
- 4
src/platforms/openai/events.ts View File

@@ -1,7 +1,11 @@
import { CreateChatCompletionParams } from './features/chat-completion';
import { CreateImageParams } from './features/image';
import { CreateTextCompletionParams } from './features/text-completion';
import { CreateEditParams } from './features/edit';
import { PassThrough } from 'stream';
import { EventEmitter } from 'events';
import fetchPonyfill from 'fetch-ponyfill';
import { Configuration } from './common';
import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion';
import { CreateChatCompletionParams, createChatCompletion } from './features/chat-completion';
import { CreateImageParams, createImage } from './features/image';
import { CreateEditParams, createEdit } from './features/edit';

export type DataEventCallback<D> = (data: D) => void;

@@ -16,3 +20,73 @@ export interface PlatformEventEmitter extends NodeJS.EventEmitter {
on(event: 'end', callback: () => void): this;
on(event: 'error', callback: ErrorEventCallback): this;
}

export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly createCompletion: PlatformEventEmitter['createCompletion'];

readonly createImage: PlatformEventEmitter['createImage'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

constructor(configParams: Configuration) {
super();
const headers: Record<string, string> = {
Authorization: `Bearer ${configParams.apiKey}`,
};

if (configParams.organizationId) {
headers['OpenAI-Organization'] = configParams.organizationId;
}

const { fetch: fetchInstance } = fetchPonyfill();
const doFetch = (method: string, path: string, body: Record<string, unknown>) => {
const theFetchParams = {
method,
headers: {
...headers,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
};

const url = new URL(
`/${configParams.apiVersion}${path}`,
configParams.baseUrl ?? 'https://api.openai.com',
).toString();

this.emit('start', {
...theFetchParams,
url,
});

return fetchInstance(url, theFetchParams);
};

const consumeStream = async (response: Response) => {
// eslint-disable-next-line no-restricted-syntax
for await (const chunk of response.body as unknown as PassThrough) {
const chunkStringMaybeMultiple = chunk.toString();
const chunkStrings = chunkStringMaybeMultiple
.split('\n')
.filter((chunkString: string) => chunkString.length > 0);
chunkStrings.forEach((chunkString: string) => {
const dataRaw = chunkString.split('data: ').at(1);
if (!dataRaw) {
return;
}
if (dataRaw === '[DONE]') {
return;
}
const data = JSON.parse(dataRaw);
this.emit('data', data);
});
}
};
this.createImage = createImage.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);
}
}

+ 1
- 1
src/platforms/openai/features/chat-completion.ts View File

@@ -3,13 +3,13 @@ import {
ConsumeStream,
DataEventId,
DoFetch,
normalizeChatMessage,
PlatformError,
PlatformResponse,
UsageMetadata,
} from '../common';
import { Message, MessageObject } from '../message';
import { ChatCompletionModel } from '../models';
import { normalizeChatMessage } from '../chat';

export interface CreateChatCompletionParams {
messages: Message | Message[];


+ 5
- 91
src/platforms/openai/index.ts View File

@@ -1,20 +1,16 @@
import fetchPonyfill from 'fetch-ponyfill';
import { EventEmitter } from 'events';
import { PassThrough } from 'stream';
import { PlatformEventEmitter } from './events';
import { createTextCompletion, TextCompletion } from './features/text-completion';
import { createImage } from './features/image';
import { createChatCompletion, ChatCompletion } from './features/chat-completion';
import { createEdit } from './features/edit';
import { Configuration } from './common';

export * from './message';
export * from './models';
export { PlatformEventEmitter, ChatCompletion, TextCompletion };
export * from './common';
export { PlatformEventEmitter, PlatformEventEmitterImpl } from './events';
export {
ChatCompletion,
ChatCompletionChunkDataEvent,
DataEventObjectType as ChatCompletionDataEventObjectType,
} from './features/chat-completion';
export {
TextCompletion,
TextCompletionChunkDataEvent,
DataEventObjectType as TextCompletionDataEventObjectType,
} from './features/text-completion';
@@ -23,11 +19,6 @@ export {
DataEventObjectType as EditDataEventObjectType,
} from './features/edit';
export { CreateImageDataEvent, CreateImageSize } from './features/image';
export * from './common';

export enum ApiVersion {
V1 = 'v1',
}

export const PLATFORM_ID = 'openai' as const;

@@ -35,80 +26,3 @@ export interface PlatformConfig {
platform: typeof PLATFORM_ID;
platformConfiguration: Configuration;
}

export interface Configuration {
organizationId?: string;
apiVersion: ApiVersion;
apiKey: string;
baseUrl?: string;
}

export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly createCompletion: PlatformEventEmitter['createCompletion'];

readonly createImage: PlatformEventEmitter['createImage'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

constructor(configParams: Configuration) {
super();
const headers: Record<string, string> = {
Authorization: `Bearer ${configParams.apiKey}`,
};

if (configParams.organizationId) {
headers['OpenAI-Organization'] = configParams.organizationId;
}

const { fetch: fetchInstance } = fetchPonyfill();
const doFetch = (method: string, path: string, body: Record<string, unknown>) => {
const theFetchParams = {
method,
headers: {
...headers,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
};

const url = new URL(
`/${configParams.apiVersion}${path}`,
configParams.baseUrl ?? 'https://api.openai.com',
).toString();

this.emit('start', {
...theFetchParams,
url,
});

return fetchInstance(url, theFetchParams);
};

const consumeStream = async (response: Response) => {
// eslint-disable-next-line no-restricted-syntax
for await (const chunk of response.body as unknown as PassThrough) {
const chunkStringMaybeMultiple = chunk.toString();
const chunkStrings = chunkStringMaybeMultiple
.split('\n')
.filter((chunkString: string) => chunkString.length > 0);
chunkStrings.forEach((chunkString: string) => {
const dataRaw = chunkString.split('data: ').at(1);
if (!dataRaw) {
return;
}
if (dataRaw === '[DONE]') {
return;
}
const data = JSON.parse(dataRaw);
this.emit('data', data);
});
}
};
this.createImage = createImage.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);
}
}

test/index.test.ts → test/platforms/openai/api.test.ts View File

@@ -10,14 +10,14 @@ import {
createAiClient,
PlatformEventEmitter,
OpenAi,
} from '../src';
} from '../../../src';

describe('ai-utils', () => {
describe('OpenAI', () => {
beforeAll(() => {
config();
});

describe('OpenAI', () => {
describe.skip('API', () => {
let aiClient: PlatformEventEmitter;

beforeEach(() => {
@@ -31,7 +31,7 @@ describe('ai-utils', () => {
});
});

describe.skip('createChatCompletion', () => {
describe('createChatCompletion', () => {
let result: Partial<OpenAi.ChatCompletion> | undefined;

beforeEach(() => {
@@ -100,7 +100,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createImage', () => {
describe('createImage', () => {
it('works', () => new Promise<void>((resolve, reject) => {
aiClient.on<OpenAi.CreateImageDataEvent>('data', (r) => {
expect(r).toHaveProperty('created', expect.any(Number));
@@ -123,7 +123,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createCompletion', () => {
describe('createCompletion', () => {
let result: Partial<OpenAi.TextCompletion> | undefined;

beforeEach(() => {
@@ -187,7 +187,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createEdit', () => {
describe('createEdit', () => {
it('works', () => new Promise<void>((resolve, reject) => {
aiClient.on<OpenAi.CreateEditDataEvent>('data', (r) => {
expect(r).toHaveProperty('object', OpenAi.EditDataEventObjectType.EDIT);

+ 194
- 0
test/platforms/openai/chat.test.ts View File

@@ -0,0 +1,194 @@
import { describe, it, expect } from 'vitest';
import * as Chat from '../../../src/platforms/openai/chat';
import { MessageRole } from '../../../src/platforms/openai';

describe('OpenAI', () => {
describe('chat', () => {
describe('normalizeChatMessage', () => {
it('normalizes a basic string', () => {
const message = Chat.normalizeChatMessage('This is a user message.');

expect(message).toHaveLength(1);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('normalizes a string array', () => {
const message = Chat.normalizeChatMessage([
'This is a user message.',
'This is another user message.',
]);

expect(message).toHaveLength(2);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(message).toContainEqual({
role: 'user',
content: 'This is another user message.',
});
});

it('normalizes a message object', () => {
const message = Chat.normalizeChatMessage({
role: MessageRole.USER,
content: 'This is a user message.',
});

expect(message).toHaveLength(1);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('normalizes a message object array', () => {
const message = Chat.normalizeChatMessage([
{
role: MessageRole.USER,
content: 'This is a user message.',
},
{
role: MessageRole.USER,
content: 'This is another user message.',
},
]);

expect(message).toHaveLength(2);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(message).toContainEqual({
role: 'user',
content: 'This is another user message.',
});
});
});

describe('buildChatFromTranscript', () => {
it('processes line breaks correctly', () => {
const message = `
SYSTEM: This is a system message. This is a chat from the
user: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(1);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message. This is a chat from the user: This is a user message.',
});
});

it('makes distinctions between different dialogues', () => {
const message = `
SYSTEM: This is a system message. This is a chat from the

user: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(2);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message. This is a chat from the',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('builds an array of chat messages from a single string.', () => {
const message = `
SYSTEM: This is a system message.

USER: This is a user message.

SYSTEM: This is another system message.

USER: This is another user message.

ASSISTANT: This is an assistant message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(5);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message.',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is another system message.',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is another user message.',
});

expect(prompts[0]).toContainEqual({
role: 'assistant',
content: 'This is an assistant message.',
});
});

it('builds multiple chat messages with a divider.', () => {
const message = `
SYSTEM: This is a system message.
---
USER: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts).toHaveLength(2);
expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message.',
});
expect(prompts[1]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('injects parameters into the chat messages.', () => {
const message = `
Say {{name}}. <name> <age> <foo> {{htmlChar}} \\{{escaped}}
`;

const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({ name: 'Hello', htmlChar: '<html>' });
expect(prompts[0][0]).toEqual({
role: 'user',
content: 'Say Hello. <name> <age> <foo> <html> {{escaped}}',
});
});
});
});
});

+ 28
- 1
yarn.lock View File

@@ -1923,6 +1923,18 @@ grapheme-splitter@^1.0.4:
resolved "https://registry.yarnpkg.com/grapheme-splitter/-/grapheme-splitter-1.0.4.tgz#9cf3a665c6247479896834af35cf1dbb4400767e"
integrity sha512-bzh50DW9kTPM00T8y4o8vQg89Di9oLJVLW/KaOGIXJWP/iqCN6WKYkbNOF04vFLJhwcpYUh9ydh/+5vpOqV4YQ==
handlebars@^4.7.7:
version "4.7.7"
resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.7.7.tgz#9ce33416aad02dbd6c8fafa8240d5d98004945a1"
integrity sha512-aAcXm5OAfE/8IXkcZvCepKU3VzW1/39Fb5ZuqMtgI/hT8X2YgoMvBY5dLhq/cpOvw7Lk1nK/UF71aLG/ZnVYRA==
dependencies:
minimist "^1.2.5"
neo-async "^2.6.0"
source-map "^0.6.1"
wordwrap "^1.0.0"
optionalDependencies:
uglify-js "^3.1.4"
has-bigints@^1.0.1, has-bigints@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/has-bigints/-/has-bigints-1.0.2.tgz#0871bd3e3d51626f6ca0966668ba35d5602d6eaa"
@@ -2440,7 +2452,7 @@ minimatch@^3.0.4, minimatch@^3.0.5, minimatch@^3.1.1, minimatch@^3.1.2:
dependencies:
brace-expansion "^1.1.7"
minimist@^1.2.0, minimist@^1.2.6:
minimist@^1.2.0, minimist@^1.2.5, minimist@^1.2.6:
version "1.2.8"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.8.tgz#c1a464e7693302e082a075cee0c057741ac4772c"
integrity sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==
@@ -2480,6 +2492,11 @@ natural-compare@^1.4.0:
resolved "https://registry.yarnpkg.com/natural-compare/-/natural-compare-1.4.0.tgz#4abebfeed7541f2c27acfb29bdbbd15c8d5ba4f7"
integrity sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==
neo-async@^2.6.0:
version "2.6.2"
resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.2.tgz#b4aafb93e3aeb2d8174ca53cf163ab7d7308305f"
integrity sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==
node-fetch@~2.6.1:
version "2.6.9"
resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.6.9.tgz#7c7f744b5cc6eb5fd404e0c7a9fec630a55657e6"
@@ -3301,6 +3318,11 @@ ufo@^1.1.1:
resolved "https://registry.yarnpkg.com/ufo/-/ufo-1.1.1.tgz#e70265e7152f3aba425bd013d150b2cdf4056d7c"
integrity sha512-MvlCc4GHrmZdAllBc0iUDowff36Q9Ndw/UzqmEKyrfSzokTd9ZCy1i+IIk5hrYKkjoYVQyNbrw7/F8XJ2rEwTg==
uglify-js@^3.1.4:
version "3.17.4"
resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.17.4.tgz#61678cf5fa3f5b7eb789bb345df29afb8257c22c"
integrity sha512-T9q82TJI9e/C1TAxYvfb16xO120tMVFZrGA3f9/P4424DNu6ypK103y0GPFVa17yotwSyZW5iYXgjYHkGrJW/g==
unbox-primitive@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/unbox-primitive/-/unbox-primitive-1.0.2.tgz#29032021057d5e6cdbd08c5129c226dff8ed6f9e"
@@ -3485,6 +3507,11 @@ word-wrap@^1.2.3:
resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c"
integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==
wordwrap@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb"
integrity sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==
wrap-ansi@^6.2.0:
version "6.2.0"
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53"


Loading…
Cancel
Save