浏览代码

Refactor structure, add chat utils

Reduce coupling by putting types and definitions in their appropriate
files.

Chat utils for creating prompts and messages have been added.
master
父节点
当前提交
9959b0bfd8
共有 10 个文件被更改,包括 430 次插入133 次删除
  1. +2
    -1
      package.json
  2. +3
    -3
      src/index.ts
  3. +103
    -0
      src/platforms/openai/chat.ts
  4. +9
    -25
      src/platforms/openai/common.ts
  5. +78
    -4
      src/platforms/openai/events.ts
  6. +1
    -1
      src/platforms/openai/features/chat-completion.ts
  7. +5
    -91
      src/platforms/openai/index.ts
  8. +7
    -7
      test/platforms/openai/api.test.ts
  9. +194
    -0
      test/platforms/openai/chat.test.ts
  10. +28
    -1
      yarn.lock

+ 2
- 1
package.json 查看文件

@@ -48,6 +48,7 @@
"access": "public"
},
"dependencies": {
"fetch-ponyfill": "^7.1.0"
"fetch-ponyfill": "^7.1.0",
"handlebars": "^4.7.7"
}
}

+ 3
- 3
src/index.ts 查看文件

@@ -1,10 +1,10 @@
import * as OpenAiImpl from './platforms/openai';

export const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const;
export type PlatformConfig = OpenAiImpl.PlatformConfig;
export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter;
const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const;

export * as OpenAi from './platforms/openai';
export type PlatformConfig = OpenAiImpl.PlatformConfig;
export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter;

export const createAiClient = (configParams: PlatformConfig): PlatformEventEmitter => {
const {


+ 103
- 0
src/platforms/openai/chat.ts 查看文件

@@ -0,0 +1,103 @@
import Handlebars from 'handlebars';
import { Message, MessageRole } from './message';

const isValidMessageObject = (maybeMessage: unknown): maybeMessage is Message => {
if (typeof maybeMessage !== 'object') {
return false;
}

if (maybeMessage === null) {
return false;
}

const messageObject = maybeMessage as Record<string, unknown>;

return (
Object.values(MessageRole).includes(messageObject.role as MessageRole)
&& typeof messageObject.content === 'string'
);
};

export const normalizeChatMessage = (messageRaw: Message | Message[]) => {
if (typeof messageRaw === 'string') {
return [
{
role: MessageRole.USER,
content: messageRaw,
},
];
}

if (Array.isArray(messageRaw)) {
return messageRaw.map((message) => {
if (typeof message === 'string') {
return {
role: MessageRole.USER,
content: message,
};
}

if (isValidMessageObject(message)) {
return message;
}

throw new Error('Invalid message format');
});
}

if (isValidMessageObject(messageRaw)) {
return [messageRaw];
}

throw new Error('Invalid message format');
};

export const buildChatFromTranscript = (transcript: string) => {
const parameterized = Handlebars.create().compile(transcript, {
noEscape: true,
ignoreStandalone: true,
strict: true,
preventIndent: true,
});

return (params: Record<string, unknown>) => {
const compiled = parameterized(params);
const prompts = compiled.split('\n---\n');
return prompts.map((prompt) => {
const lines = prompt.trim().split('\n\n');
let lastRole = MessageRole.USER;
return lines.filter((s) => s.trim().length > 0).map((lineRaw) => {
const line = lineRaw.replace(/\n/g, ' ');
const lineCheckRole = line.toLowerCase();
if (lineCheckRole.startsWith('system:')) {
lastRole = MessageRole.SYSTEM;
return {
role: MessageRole.SYSTEM,
content: line.substring('system:'.length).trim(),
};
}

if (lineCheckRole.startsWith('user:')) {
lastRole = MessageRole.USER;
return {
role: MessageRole.USER,
content: line.substring('user:'.length).trim(),
};
}

if (lineCheckRole.startsWith('assistant:')) {
lastRole = MessageRole.ASSISTANT;
return {
role: MessageRole.ASSISTANT,
content: line.substring('assistant:'.length).trim(),
};
}

return {
role: lastRole,
content: line.trim(),
};
});
});
};
};

+ 9
- 25
src/platforms/openai/common.ts 查看文件

@@ -1,5 +1,3 @@
import { Message, MessageRole } from './message';

export enum FinishReason {
STOP = 'stop',
LENGTH = 'length',
@@ -48,27 +46,13 @@ export class PlatformError extends Error {
}
}

export const normalizeChatMessage = (messageRaw: Message | Message[]) => {
if (typeof messageRaw === 'string') {
return [
{
role: MessageRole.USER,
content: messageRaw,
},
];
}

if (Array.isArray(messageRaw)) {
return messageRaw.map((message) => {
if (typeof message === 'string') {
return {
role: MessageRole.USER,
content: message,
};
}
return message;
});
}
export enum ApiVersion {
V1 = 'v1',
}

return messageRaw;
};
export interface Configuration {
organizationId?: string;
apiVersion: ApiVersion;
apiKey: string;
baseUrl?: string;
}

+ 78
- 4
src/platforms/openai/events.ts 查看文件

@@ -1,7 +1,11 @@
import { CreateChatCompletionParams } from './features/chat-completion';
import { CreateImageParams } from './features/image';
import { CreateTextCompletionParams } from './features/text-completion';
import { CreateEditParams } from './features/edit';
import { PassThrough } from 'stream';
import { EventEmitter } from 'events';
import fetchPonyfill from 'fetch-ponyfill';
import { Configuration } from './common';
import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion';
import { CreateChatCompletionParams, createChatCompletion } from './features/chat-completion';
import { CreateImageParams, createImage } from './features/image';
import { CreateEditParams, createEdit } from './features/edit';

export type DataEventCallback<D> = (data: D) => void;

@@ -16,3 +20,73 @@ export interface PlatformEventEmitter extends NodeJS.EventEmitter {
on(event: 'end', callback: () => void): this;
on(event: 'error', callback: ErrorEventCallback): this;
}

export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly createCompletion: PlatformEventEmitter['createCompletion'];

readonly createImage: PlatformEventEmitter['createImage'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

constructor(configParams: Configuration) {
super();
const headers: Record<string, string> = {
Authorization: `Bearer ${configParams.apiKey}`,
};

if (configParams.organizationId) {
headers['OpenAI-Organization'] = configParams.organizationId;
}

const { fetch: fetchInstance } = fetchPonyfill();
const doFetch = (method: string, path: string, body: Record<string, unknown>) => {
const theFetchParams = {
method,
headers: {
...headers,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
};

const url = new URL(
`/${configParams.apiVersion}${path}`,
configParams.baseUrl ?? 'https://api.openai.com',
).toString();

this.emit('start', {
...theFetchParams,
url,
});

return fetchInstance(url, theFetchParams);
};

const consumeStream = async (response: Response) => {
// eslint-disable-next-line no-restricted-syntax
for await (const chunk of response.body as unknown as PassThrough) {
const chunkStringMaybeMultiple = chunk.toString();
const chunkStrings = chunkStringMaybeMultiple
.split('\n')
.filter((chunkString: string) => chunkString.length > 0);
chunkStrings.forEach((chunkString: string) => {
const dataRaw = chunkString.split('data: ').at(1);
if (!dataRaw) {
return;
}
if (dataRaw === '[DONE]') {
return;
}
const data = JSON.parse(dataRaw);
this.emit('data', data);
});
}
};
this.createImage = createImage.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);
}
}

+ 1
- 1
src/platforms/openai/features/chat-completion.ts 查看文件

@@ -3,13 +3,13 @@ import {
ConsumeStream,
DataEventId,
DoFetch,
normalizeChatMessage,
PlatformError,
PlatformResponse,
UsageMetadata,
} from '../common';
import { Message, MessageObject } from '../message';
import { ChatCompletionModel } from '../models';
import { normalizeChatMessage } from '../chat';

export interface CreateChatCompletionParams {
messages: Message | Message[];


+ 5
- 91
src/platforms/openai/index.ts 查看文件

@@ -1,20 +1,16 @@
import fetchPonyfill from 'fetch-ponyfill';
import { EventEmitter } from 'events';
import { PassThrough } from 'stream';
import { PlatformEventEmitter } from './events';
import { createTextCompletion, TextCompletion } from './features/text-completion';
import { createImage } from './features/image';
import { createChatCompletion, ChatCompletion } from './features/chat-completion';
import { createEdit } from './features/edit';
import { Configuration } from './common';

export * from './message';
export * from './models';
export { PlatformEventEmitter, ChatCompletion, TextCompletion };
export * from './common';
export { PlatformEventEmitter, PlatformEventEmitterImpl } from './events';
export {
ChatCompletion,
ChatCompletionChunkDataEvent,
DataEventObjectType as ChatCompletionDataEventObjectType,
} from './features/chat-completion';
export {
TextCompletion,
TextCompletionChunkDataEvent,
DataEventObjectType as TextCompletionDataEventObjectType,
} from './features/text-completion';
@@ -23,11 +19,6 @@ export {
DataEventObjectType as EditDataEventObjectType,
} from './features/edit';
export { CreateImageDataEvent, CreateImageSize } from './features/image';
export * from './common';

export enum ApiVersion {
V1 = 'v1',
}

export const PLATFORM_ID = 'openai' as const;

@@ -35,80 +26,3 @@ export interface PlatformConfig {
platform: typeof PLATFORM_ID;
platformConfiguration: Configuration;
}

export interface Configuration {
organizationId?: string;
apiVersion: ApiVersion;
apiKey: string;
baseUrl?: string;
}

export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly createCompletion: PlatformEventEmitter['createCompletion'];

readonly createImage: PlatformEventEmitter['createImage'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

constructor(configParams: Configuration) {
super();
const headers: Record<string, string> = {
Authorization: `Bearer ${configParams.apiKey}`,
};

if (configParams.organizationId) {
headers['OpenAI-Organization'] = configParams.organizationId;
}

const { fetch: fetchInstance } = fetchPonyfill();
const doFetch = (method: string, path: string, body: Record<string, unknown>) => {
const theFetchParams = {
method,
headers: {
...headers,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
};

const url = new URL(
`/${configParams.apiVersion}${path}`,
configParams.baseUrl ?? 'https://api.openai.com',
).toString();

this.emit('start', {
...theFetchParams,
url,
});

return fetchInstance(url, theFetchParams);
};

const consumeStream = async (response: Response) => {
// eslint-disable-next-line no-restricted-syntax
for await (const chunk of response.body as unknown as PassThrough) {
const chunkStringMaybeMultiple = chunk.toString();
const chunkStrings = chunkStringMaybeMultiple
.split('\n')
.filter((chunkString: string) => chunkString.length > 0);
chunkStrings.forEach((chunkString: string) => {
const dataRaw = chunkString.split('data: ').at(1);
if (!dataRaw) {
return;
}
if (dataRaw === '[DONE]') {
return;
}
const data = JSON.parse(dataRaw);
this.emit('data', data);
});
}
};
this.createImage = createImage.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);
}
}

test/index.test.ts → test/platforms/openai/api.test.ts 查看文件

@@ -10,14 +10,14 @@ import {
createAiClient,
PlatformEventEmitter,
OpenAi,
} from '../src';
} from '../../../src';

describe('ai-utils', () => {
describe('OpenAI', () => {
beforeAll(() => {
config();
});

describe('OpenAI', () => {
describe.skip('API', () => {
let aiClient: PlatformEventEmitter;

beforeEach(() => {
@@ -31,7 +31,7 @@ describe('ai-utils', () => {
});
});

describe.skip('createChatCompletion', () => {
describe('createChatCompletion', () => {
let result: Partial<OpenAi.ChatCompletion> | undefined;

beforeEach(() => {
@@ -100,7 +100,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createImage', () => {
describe('createImage', () => {
it('works', () => new Promise<void>((resolve, reject) => {
aiClient.on<OpenAi.CreateImageDataEvent>('data', (r) => {
expect(r).toHaveProperty('created', expect.any(Number));
@@ -123,7 +123,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createCompletion', () => {
describe('createCompletion', () => {
let result: Partial<OpenAi.TextCompletion> | undefined;

beforeEach(() => {
@@ -187,7 +187,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createEdit', () => {
describe('createEdit', () => {
it('works', () => new Promise<void>((resolve, reject) => {
aiClient.on<OpenAi.CreateEditDataEvent>('data', (r) => {
expect(r).toHaveProperty('object', OpenAi.EditDataEventObjectType.EDIT);

+ 194
- 0
test/platforms/openai/chat.test.ts 查看文件

@@ -0,0 +1,194 @@
import { describe, it, expect } from 'vitest';
import * as Chat from '../../../src/platforms/openai/chat';
import { MessageRole } from '../../../src/platforms/openai';

describe('OpenAI', () => {
describe('chat', () => {
describe('normalizeChatMessage', () => {
it('normalizes a basic string', () => {
const message = Chat.normalizeChatMessage('This is a user message.');

expect(message).toHaveLength(1);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('normalizes a string array', () => {
const message = Chat.normalizeChatMessage([
'This is a user message.',
'This is another user message.',
]);

expect(message).toHaveLength(2);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(message).toContainEqual({
role: 'user',
content: 'This is another user message.',
});
});

it('normalizes a message object', () => {
const message = Chat.normalizeChatMessage({
role: MessageRole.USER,
content: 'This is a user message.',
});

expect(message).toHaveLength(1);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('normalizes a message object array', () => {
const message = Chat.normalizeChatMessage([
{
role: MessageRole.USER,
content: 'This is a user message.',
},
{
role: MessageRole.USER,
content: 'This is another user message.',
},
]);

expect(message).toHaveLength(2);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(message).toContainEqual({
role: 'user',
content: 'This is another user message.',
});
});
});

describe('buildChatFromTranscript', () => {
it('processes line breaks correctly', () => {
const message = `
SYSTEM: This is a system message. This is a chat from the
user: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(1);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message. This is a chat from the user: This is a user message.',
});
});

it('makes distinctions between different dialogues', () => {
const message = `
SYSTEM: This is a system message. This is a chat from the

user: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(2);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message. This is a chat from the',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('builds an array of chat messages from a single string.', () => {
const message = `
SYSTEM: This is a system message.

USER: This is a user message.

SYSTEM: This is another system message.

USER: This is another user message.

ASSISTANT: This is an assistant message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(5);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message.',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is another system message.',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is another user message.',
});

expect(prompts[0]).toContainEqual({
role: 'assistant',
content: 'This is an assistant message.',
});
});

it('builds multiple chat messages with a divider.', () => {
const message = `
SYSTEM: This is a system message.
---
USER: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts).toHaveLength(2);
expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message.',
});
expect(prompts[1]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('injects parameters into the chat messages.', () => {
const message = `
Say {{name}}. <name> <age> <foo> {{htmlChar}} \\{{escaped}}
`;

const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({ name: 'Hello', htmlChar: '<html>' });
expect(prompts[0][0]).toEqual({
role: 'user',
content: 'Say Hello. <name> <age> <foo> <html> {{escaped}}',
});
});
});
});
});

+ 28
- 1
yarn.lock 查看文件

@@ -1923,6 +1923,18 @@ grapheme-splitter@^1.0.4:
resolved "https://registry.yarnpkg.com/grapheme-splitter/-/grapheme-splitter-1.0.4.tgz#9cf3a665c6247479896834af35cf1dbb4400767e"
integrity sha512-bzh50DW9kTPM00T8y4o8vQg89Di9oLJVLW/KaOGIXJWP/iqCN6WKYkbNOF04vFLJhwcpYUh9ydh/+5vpOqV4YQ==
handlebars@^4.7.7:
version "4.7.7"
resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.7.7.tgz#9ce33416aad02dbd6c8fafa8240d5d98004945a1"
integrity sha512-aAcXm5OAfE/8IXkcZvCepKU3VzW1/39Fb5ZuqMtgI/hT8X2YgoMvBY5dLhq/cpOvw7Lk1nK/UF71aLG/ZnVYRA==
dependencies:
minimist "^1.2.5"
neo-async "^2.6.0"
source-map "^0.6.1"
wordwrap "^1.0.0"
optionalDependencies:
uglify-js "^3.1.4"
has-bigints@^1.0.1, has-bigints@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/has-bigints/-/has-bigints-1.0.2.tgz#0871bd3e3d51626f6ca0966668ba35d5602d6eaa"
@@ -2440,7 +2452,7 @@ minimatch@^3.0.4, minimatch@^3.0.5, minimatch@^3.1.1, minimatch@^3.1.2:
dependencies:
brace-expansion "^1.1.7"
minimist@^1.2.0, minimist@^1.2.6:
minimist@^1.2.0, minimist@^1.2.5, minimist@^1.2.6:
version "1.2.8"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.8.tgz#c1a464e7693302e082a075cee0c057741ac4772c"
integrity sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==
@@ -2480,6 +2492,11 @@ natural-compare@^1.4.0:
resolved "https://registry.yarnpkg.com/natural-compare/-/natural-compare-1.4.0.tgz#4abebfeed7541f2c27acfb29bdbbd15c8d5ba4f7"
integrity sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==
neo-async@^2.6.0:
version "2.6.2"
resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.2.tgz#b4aafb93e3aeb2d8174ca53cf163ab7d7308305f"
integrity sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==
node-fetch@~2.6.1:
version "2.6.9"
resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.6.9.tgz#7c7f744b5cc6eb5fd404e0c7a9fec630a55657e6"
@@ -3301,6 +3318,11 @@ ufo@^1.1.1:
resolved "https://registry.yarnpkg.com/ufo/-/ufo-1.1.1.tgz#e70265e7152f3aba425bd013d150b2cdf4056d7c"
integrity sha512-MvlCc4GHrmZdAllBc0iUDowff36Q9Ndw/UzqmEKyrfSzokTd9ZCy1i+IIk5hrYKkjoYVQyNbrw7/F8XJ2rEwTg==
uglify-js@^3.1.4:
version "3.17.4"
resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.17.4.tgz#61678cf5fa3f5b7eb789bb345df29afb8257c22c"
integrity sha512-T9q82TJI9e/C1TAxYvfb16xO120tMVFZrGA3f9/P4424DNu6ypK103y0GPFVa17yotwSyZW5iYXgjYHkGrJW/g==
unbox-primitive@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/unbox-primitive/-/unbox-primitive-1.0.2.tgz#29032021057d5e6cdbd08c5129c226dff8ed6f9e"
@@ -3485,6 +3507,11 @@ word-wrap@^1.2.3:
resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c"
integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==
wordwrap@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb"
integrity sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==
wrap-ansi@^6.2.0:
version "6.2.0"
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53"


正在加载...
取消
保存