Parcourir la source

Refactor structure, add chat utils

Reduce coupling by putting types and definitions in their appropriate
files.

Chat utils for creating prompts and messages have been added.
master
TheoryOfNekomata il y a 1 an
Parent
révision
9959b0bfd8
10 fichiers modifiés avec 430 ajouts et 133 suppressions
  1. +2
    -1
      package.json
  2. +3
    -3
      src/index.ts
  3. +103
    -0
      src/platforms/openai/chat.ts
  4. +9
    -25
      src/platforms/openai/common.ts
  5. +78
    -4
      src/platforms/openai/events.ts
  6. +1
    -1
      src/platforms/openai/features/chat-completion.ts
  7. +5
    -91
      src/platforms/openai/index.ts
  8. +7
    -7
      test/platforms/openai/api.test.ts
  9. +194
    -0
      test/platforms/openai/chat.test.ts
  10. +28
    -1
      yarn.lock

+ 2
- 1
package.json Voir le fichier

@@ -48,6 +48,7 @@
"access": "public"
},
"dependencies": {
"fetch-ponyfill": "^7.1.0"
"fetch-ponyfill": "^7.1.0",
"handlebars": "^4.7.7"
}
}

+ 3
- 3
src/index.ts Voir le fichier

@@ -1,10 +1,10 @@
import * as OpenAiImpl from './platforms/openai';

export const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const;
export type PlatformConfig = OpenAiImpl.PlatformConfig;
export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter;
const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const;

export * as OpenAi from './platforms/openai';
export type PlatformConfig = OpenAiImpl.PlatformConfig;
export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter;

export const createAiClient = (configParams: PlatformConfig): PlatformEventEmitter => {
const {


+ 103
- 0
src/platforms/openai/chat.ts Voir le fichier

@@ -0,0 +1,103 @@
import Handlebars from 'handlebars';
import { Message, MessageRole } from './message';

const isValidMessageObject = (maybeMessage: unknown): maybeMessage is Message => {
if (typeof maybeMessage !== 'object') {
return false;
}

if (maybeMessage === null) {
return false;
}

const messageObject = maybeMessage as Record<string, unknown>;

return (
Object.values(MessageRole).includes(messageObject.role as MessageRole)
&& typeof messageObject.content === 'string'
);
};

export const normalizeChatMessage = (messageRaw: Message | Message[]) => {
if (typeof messageRaw === 'string') {
return [
{
role: MessageRole.USER,
content: messageRaw,
},
];
}

if (Array.isArray(messageRaw)) {
return messageRaw.map((message) => {
if (typeof message === 'string') {
return {
role: MessageRole.USER,
content: message,
};
}

if (isValidMessageObject(message)) {
return message;
}

throw new Error('Invalid message format');
});
}

if (isValidMessageObject(messageRaw)) {
return [messageRaw];
}

throw new Error('Invalid message format');
};

export const buildChatFromTranscript = (transcript: string) => {
const parameterized = Handlebars.create().compile(transcript, {
noEscape: true,
ignoreStandalone: true,
strict: true,
preventIndent: true,
});

return (params: Record<string, unknown>) => {
const compiled = parameterized(params);
const prompts = compiled.split('\n---\n');
return prompts.map((prompt) => {
const lines = prompt.trim().split('\n\n');
let lastRole = MessageRole.USER;
return lines.filter((s) => s.trim().length > 0).map((lineRaw) => {
const line = lineRaw.replace(/\n/g, ' ');
const lineCheckRole = line.toLowerCase();
if (lineCheckRole.startsWith('system:')) {
lastRole = MessageRole.SYSTEM;
return {
role: MessageRole.SYSTEM,
content: line.substring('system:'.length).trim(),
};
}

if (lineCheckRole.startsWith('user:')) {
lastRole = MessageRole.USER;
return {
role: MessageRole.USER,
content: line.substring('user:'.length).trim(),
};
}

if (lineCheckRole.startsWith('assistant:')) {
lastRole = MessageRole.ASSISTANT;
return {
role: MessageRole.ASSISTANT,
content: line.substring('assistant:'.length).trim(),
};
}

return {
role: lastRole,
content: line.trim(),
};
});
});
};
};

+ 9
- 25
src/platforms/openai/common.ts Voir le fichier

@@ -1,5 +1,3 @@
import { Message, MessageRole } from './message';

export enum FinishReason {
STOP = 'stop',
LENGTH = 'length',
@@ -48,27 +46,13 @@ export class PlatformError extends Error {
}
}

export const normalizeChatMessage = (messageRaw: Message | Message[]) => {
if (typeof messageRaw === 'string') {
return [
{
role: MessageRole.USER,
content: messageRaw,
},
];
}

if (Array.isArray(messageRaw)) {
return messageRaw.map((message) => {
if (typeof message === 'string') {
return {
role: MessageRole.USER,
content: message,
};
}
return message;
});
}
export enum ApiVersion {
V1 = 'v1',
}

return messageRaw;
};
export interface Configuration {
organizationId?: string;
apiVersion: ApiVersion;
apiKey: string;
baseUrl?: string;
}

+ 78
- 4
src/platforms/openai/events.ts Voir le fichier

@@ -1,7 +1,11 @@
import { CreateChatCompletionParams } from './features/chat-completion';
import { CreateImageParams } from './features/image';
import { CreateTextCompletionParams } from './features/text-completion';
import { CreateEditParams } from './features/edit';
import { PassThrough } from 'stream';
import { EventEmitter } from 'events';
import fetchPonyfill from 'fetch-ponyfill';
import { Configuration } from './common';
import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion';
import { CreateChatCompletionParams, createChatCompletion } from './features/chat-completion';
import { CreateImageParams, createImage } from './features/image';
import { CreateEditParams, createEdit } from './features/edit';

export type DataEventCallback<D> = (data: D) => void;

@@ -16,3 +20,73 @@ export interface PlatformEventEmitter extends NodeJS.EventEmitter {
on(event: 'end', callback: () => void): this;
on(event: 'error', callback: ErrorEventCallback): this;
}

export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly createCompletion: PlatformEventEmitter['createCompletion'];

readonly createImage: PlatformEventEmitter['createImage'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

constructor(configParams: Configuration) {
super();
const headers: Record<string, string> = {
Authorization: `Bearer ${configParams.apiKey}`,
};

if (configParams.organizationId) {
headers['OpenAI-Organization'] = configParams.organizationId;
}

const { fetch: fetchInstance } = fetchPonyfill();
const doFetch = (method: string, path: string, body: Record<string, unknown>) => {
const theFetchParams = {
method,
headers: {
...headers,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
};

const url = new URL(
`/${configParams.apiVersion}${path}`,
configParams.baseUrl ?? 'https://api.openai.com',
).toString();

this.emit('start', {
...theFetchParams,
url,
});

return fetchInstance(url, theFetchParams);
};

const consumeStream = async (response: Response) => {
// eslint-disable-next-line no-restricted-syntax
for await (const chunk of response.body as unknown as PassThrough) {
const chunkStringMaybeMultiple = chunk.toString();
const chunkStrings = chunkStringMaybeMultiple
.split('\n')
.filter((chunkString: string) => chunkString.length > 0);
chunkStrings.forEach((chunkString: string) => {
const dataRaw = chunkString.split('data: ').at(1);
if (!dataRaw) {
return;
}
if (dataRaw === '[DONE]') {
return;
}
const data = JSON.parse(dataRaw);
this.emit('data', data);
});
}
};
this.createImage = createImage.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);
}
}

+ 1
- 1
src/platforms/openai/features/chat-completion.ts Voir le fichier

@@ -3,13 +3,13 @@ import {
ConsumeStream,
DataEventId,
DoFetch,
normalizeChatMessage,
PlatformError,
PlatformResponse,
UsageMetadata,
} from '../common';
import { Message, MessageObject } from '../message';
import { ChatCompletionModel } from '../models';
import { normalizeChatMessage } from '../chat';

export interface CreateChatCompletionParams {
messages: Message | Message[];


+ 5
- 91
src/platforms/openai/index.ts Voir le fichier

@@ -1,20 +1,16 @@
import fetchPonyfill from 'fetch-ponyfill';
import { EventEmitter } from 'events';
import { PassThrough } from 'stream';
import { PlatformEventEmitter } from './events';
import { createTextCompletion, TextCompletion } from './features/text-completion';
import { createImage } from './features/image';
import { createChatCompletion, ChatCompletion } from './features/chat-completion';
import { createEdit } from './features/edit';
import { Configuration } from './common';

export * from './message';
export * from './models';
export { PlatformEventEmitter, ChatCompletion, TextCompletion };
export * from './common';
export { PlatformEventEmitter, PlatformEventEmitterImpl } from './events';
export {
ChatCompletion,
ChatCompletionChunkDataEvent,
DataEventObjectType as ChatCompletionDataEventObjectType,
} from './features/chat-completion';
export {
TextCompletion,
TextCompletionChunkDataEvent,
DataEventObjectType as TextCompletionDataEventObjectType,
} from './features/text-completion';
@@ -23,11 +19,6 @@ export {
DataEventObjectType as EditDataEventObjectType,
} from './features/edit';
export { CreateImageDataEvent, CreateImageSize } from './features/image';
export * from './common';

export enum ApiVersion {
V1 = 'v1',
}

export const PLATFORM_ID = 'openai' as const;

@@ -35,80 +26,3 @@ export interface PlatformConfig {
platform: typeof PLATFORM_ID;
platformConfiguration: Configuration;
}

export interface Configuration {
organizationId?: string;
apiVersion: ApiVersion;
apiKey: string;
baseUrl?: string;
}

export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly createCompletion: PlatformEventEmitter['createCompletion'];

readonly createImage: PlatformEventEmitter['createImage'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

constructor(configParams: Configuration) {
super();
const headers: Record<string, string> = {
Authorization: `Bearer ${configParams.apiKey}`,
};

if (configParams.organizationId) {
headers['OpenAI-Organization'] = configParams.organizationId;
}

const { fetch: fetchInstance } = fetchPonyfill();
const doFetch = (method: string, path: string, body: Record<string, unknown>) => {
const theFetchParams = {
method,
headers: {
...headers,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
};

const url = new URL(
`/${configParams.apiVersion}${path}`,
configParams.baseUrl ?? 'https://api.openai.com',
).toString();

this.emit('start', {
...theFetchParams,
url,
});

return fetchInstance(url, theFetchParams);
};

const consumeStream = async (response: Response) => {
// eslint-disable-next-line no-restricted-syntax
for await (const chunk of response.body as unknown as PassThrough) {
const chunkStringMaybeMultiple = chunk.toString();
const chunkStrings = chunkStringMaybeMultiple
.split('\n')
.filter((chunkString: string) => chunkString.length > 0);
chunkStrings.forEach((chunkString: string) => {
const dataRaw = chunkString.split('data: ').at(1);
if (!dataRaw) {
return;
}
if (dataRaw === '[DONE]') {
return;
}
const data = JSON.parse(dataRaw);
this.emit('data', data);
});
}
};
this.createImage = createImage.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);
}
}

test/index.test.ts → test/platforms/openai/api.test.ts Voir le fichier

@@ -10,14 +10,14 @@ import {
createAiClient,
PlatformEventEmitter,
OpenAi,
} from '../src';
} from '../../../src';

describe('ai-utils', () => {
describe('OpenAI', () => {
beforeAll(() => {
config();
});

describe('OpenAI', () => {
describe.skip('API', () => {
let aiClient: PlatformEventEmitter;

beforeEach(() => {
@@ -31,7 +31,7 @@ describe('ai-utils', () => {
});
});

describe.skip('createChatCompletion', () => {
describe('createChatCompletion', () => {
let result: Partial<OpenAi.ChatCompletion> | undefined;

beforeEach(() => {
@@ -100,7 +100,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createImage', () => {
describe('createImage', () => {
it('works', () => new Promise<void>((resolve, reject) => {
aiClient.on<OpenAi.CreateImageDataEvent>('data', (r) => {
expect(r).toHaveProperty('created', expect.any(Number));
@@ -123,7 +123,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createCompletion', () => {
describe('createCompletion', () => {
let result: Partial<OpenAi.TextCompletion> | undefined;

beforeEach(() => {
@@ -187,7 +187,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createEdit', () => {
describe('createEdit', () => {
it('works', () => new Promise<void>((resolve, reject) => {
aiClient.on<OpenAi.CreateEditDataEvent>('data', (r) => {
expect(r).toHaveProperty('object', OpenAi.EditDataEventObjectType.EDIT);

+ 194
- 0
test/platforms/openai/chat.test.ts Voir le fichier

@@ -0,0 +1,194 @@
import { describe, it, expect } from 'vitest';
import * as Chat from '../../../src/platforms/openai/chat';
import { MessageRole } from '../../../src/platforms/openai';

describe('OpenAI', () => {
describe('chat', () => {
describe('normalizeChatMessage', () => {
it('normalizes a basic string', () => {
const message = Chat.normalizeChatMessage('This is a user message.');

expect(message).toHaveLength(1);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('normalizes a string array', () => {
const message = Chat.normalizeChatMessage([
'This is a user message.',
'This is another user message.',
]);

expect(message).toHaveLength(2);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(message).toContainEqual({
role: 'user',
content: 'This is another user message.',
});
});

it('normalizes a message object', () => {
const message = Chat.normalizeChatMessage({
role: MessageRole.USER,
content: 'This is a user message.',
});

expect(message).toHaveLength(1);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('normalizes a message object array', () => {
const message = Chat.normalizeChatMessage([
{
role: MessageRole.USER,
content: 'This is a user message.',
},
{
role: MessageRole.USER,
content: 'This is another user message.',
},
]);

expect(message).toHaveLength(2);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(message).toContainEqual({
role: 'user',
content: 'This is another user message.',
});
});
});

describe('buildChatFromTranscript', () => {
it('processes line breaks correctly', () => {
const message = `
SYSTEM: This is a system message. This is a chat from the
user: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(1);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message. This is a chat from the user: This is a user message.',
});
});

it('makes distinctions between different dialogues', () => {
const message = `
SYSTEM: This is a system message. This is a chat from the

user: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(2);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message. This is a chat from the',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('builds an array of chat messages from a single string.', () => {
const message = `
SYSTEM: This is a system message.

USER: This is a user message.

SYSTEM: This is another system message.

USER: This is another user message.

ASSISTANT: This is an assistant message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(5);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message.',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is another system message.',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is another user message.',
});

expect(prompts[0]).toContainEqual({
role: 'assistant',
content: 'This is an assistant message.',
});
});

it('builds multiple chat messages with a divider.', () => {
const message = `
SYSTEM: This is a system message.
---
USER: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts).toHaveLength(2);
expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message.',
});
expect(prompts[1]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('injects parameters into the chat messages.', () => {
const message = `
Say {{name}}. <name> <age> <foo> {{htmlChar}} \\{{escaped}}
`;

const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({ name: 'Hello', htmlChar: '<html>' });
expect(prompts[0][0]).toEqual({
role: 'user',
content: 'Say Hello. <name> <age> <foo> <html> {{escaped}}',
});
});
});
});
});

+ 28
- 1
yarn.lock Voir le fichier

@@ -1923,6 +1923,18 @@ grapheme-splitter@^1.0.4:
resolved "https://registry.yarnpkg.com/grapheme-splitter/-/grapheme-splitter-1.0.4.tgz#9cf3a665c6247479896834af35cf1dbb4400767e"
integrity sha512-bzh50DW9kTPM00T8y4o8vQg89Di9oLJVLW/KaOGIXJWP/iqCN6WKYkbNOF04vFLJhwcpYUh9ydh/+5vpOqV4YQ==
handlebars@^4.7.7:
version "4.7.7"
resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.7.7.tgz#9ce33416aad02dbd6c8fafa8240d5d98004945a1"
integrity sha512-aAcXm5OAfE/8IXkcZvCepKU3VzW1/39Fb5ZuqMtgI/hT8X2YgoMvBY5dLhq/cpOvw7Lk1nK/UF71aLG/ZnVYRA==
dependencies:
minimist "^1.2.5"
neo-async "^2.6.0"
source-map "^0.6.1"
wordwrap "^1.0.0"
optionalDependencies:
uglify-js "^3.1.4"
has-bigints@^1.0.1, has-bigints@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/has-bigints/-/has-bigints-1.0.2.tgz#0871bd3e3d51626f6ca0966668ba35d5602d6eaa"
@@ -2440,7 +2452,7 @@ minimatch@^3.0.4, minimatch@^3.0.5, minimatch@^3.1.1, minimatch@^3.1.2:
dependencies:
brace-expansion "^1.1.7"
minimist@^1.2.0, minimist@^1.2.6:
minimist@^1.2.0, minimist@^1.2.5, minimist@^1.2.6:
version "1.2.8"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.8.tgz#c1a464e7693302e082a075cee0c057741ac4772c"
integrity sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==
@@ -2480,6 +2492,11 @@ natural-compare@^1.4.0:
resolved "https://registry.yarnpkg.com/natural-compare/-/natural-compare-1.4.0.tgz#4abebfeed7541f2c27acfb29bdbbd15c8d5ba4f7"
integrity sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==
neo-async@^2.6.0:
version "2.6.2"
resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.2.tgz#b4aafb93e3aeb2d8174ca53cf163ab7d7308305f"
integrity sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==
node-fetch@~2.6.1:
version "2.6.9"
resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.6.9.tgz#7c7f744b5cc6eb5fd404e0c7a9fec630a55657e6"
@@ -3301,6 +3318,11 @@ ufo@^1.1.1:
resolved "https://registry.yarnpkg.com/ufo/-/ufo-1.1.1.tgz#e70265e7152f3aba425bd013d150b2cdf4056d7c"
integrity sha512-MvlCc4GHrmZdAllBc0iUDowff36Q9Ndw/UzqmEKyrfSzokTd9ZCy1i+IIk5hrYKkjoYVQyNbrw7/F8XJ2rEwTg==
uglify-js@^3.1.4:
version "3.17.4"
resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.17.4.tgz#61678cf5fa3f5b7eb789bb345df29afb8257c22c"
integrity sha512-T9q82TJI9e/C1TAxYvfb16xO120tMVFZrGA3f9/P4424DNu6ypK103y0GPFVa17yotwSyZW5iYXgjYHkGrJW/g==
unbox-primitive@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/unbox-primitive/-/unbox-primitive-1.0.2.tgz#29032021057d5e6cdbd08c5129c226dff8ed6f9e"
@@ -3485,6 +3507,11 @@ word-wrap@^1.2.3:
resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c"
integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==
wordwrap@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb"
integrity sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==
wrap-ansi@^6.2.0:
version "6.2.0"
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53"


Chargement…
Annuler
Enregistrer