Ver código fonte

Refactor structure, add chat utils

Reduce coupling by putting types and definitions in their appropriate
files.

Chat utils for creating prompts and messages have been added.
master
TheoryOfNekomata 1 ano atrás
pai
commit
9959b0bfd8
10 arquivos alterados com 430 adições e 133 exclusões
  1. +2
    -1
      package.json
  2. +3
    -3
      src/index.ts
  3. +103
    -0
      src/platforms/openai/chat.ts
  4. +9
    -25
      src/platforms/openai/common.ts
  5. +78
    -4
      src/platforms/openai/events.ts
  6. +1
    -1
      src/platforms/openai/features/chat-completion.ts
  7. +5
    -91
      src/platforms/openai/index.ts
  8. +7
    -7
      test/platforms/openai/api.test.ts
  9. +194
    -0
      test/platforms/openai/chat.test.ts
  10. +28
    -1
      yarn.lock

+ 2
- 1
package.json Ver arquivo

@@ -48,6 +48,7 @@
"access": "public"
},
"dependencies": {
"fetch-ponyfill": "^7.1.0"
"fetch-ponyfill": "^7.1.0",
"handlebars": "^4.7.7"
}
}

+ 3
- 3
src/index.ts Ver arquivo

@@ -1,10 +1,10 @@
import * as OpenAiImpl from './platforms/openai';

export const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const;
export type PlatformConfig = OpenAiImpl.PlatformConfig;
export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter;
const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const;

export * as OpenAi from './platforms/openai';
export type PlatformConfig = OpenAiImpl.PlatformConfig;
export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter;

export const createAiClient = (configParams: PlatformConfig): PlatformEventEmitter => {
const {


+ 103
- 0
src/platforms/openai/chat.ts Ver arquivo

@@ -0,0 +1,103 @@
import Handlebars from 'handlebars';
import { Message, MessageRole } from './message';

const isValidMessageObject = (maybeMessage: unknown): maybeMessage is Message => {
if (typeof maybeMessage !== 'object') {
return false;
}

if (maybeMessage === null) {
return false;
}

const messageObject = maybeMessage as Record<string, unknown>;

return (
Object.values(MessageRole).includes(messageObject.role as MessageRole)
&& typeof messageObject.content === 'string'
);
};

export const normalizeChatMessage = (messageRaw: Message | Message[]) => {
if (typeof messageRaw === 'string') {
return [
{
role: MessageRole.USER,
content: messageRaw,
},
];
}

if (Array.isArray(messageRaw)) {
return messageRaw.map((message) => {
if (typeof message === 'string') {
return {
role: MessageRole.USER,
content: message,
};
}

if (isValidMessageObject(message)) {
return message;
}

throw new Error('Invalid message format');
});
}

if (isValidMessageObject(messageRaw)) {
return [messageRaw];
}

throw new Error('Invalid message format');
};

export const buildChatFromTranscript = (transcript: string) => {
const parameterized = Handlebars.create().compile(transcript, {
noEscape: true,
ignoreStandalone: true,
strict: true,
preventIndent: true,
});

return (params: Record<string, unknown>) => {
const compiled = parameterized(params);
const prompts = compiled.split('\n---\n');
return prompts.map((prompt) => {
const lines = prompt.trim().split('\n\n');
let lastRole = MessageRole.USER;
return lines.filter((s) => s.trim().length > 0).map((lineRaw) => {
const line = lineRaw.replace(/\n/g, ' ');
const lineCheckRole = line.toLowerCase();
if (lineCheckRole.startsWith('system:')) {
lastRole = MessageRole.SYSTEM;
return {
role: MessageRole.SYSTEM,
content: line.substring('system:'.length).trim(),
};
}

if (lineCheckRole.startsWith('user:')) {
lastRole = MessageRole.USER;
return {
role: MessageRole.USER,
content: line.substring('user:'.length).trim(),
};
}

if (lineCheckRole.startsWith('assistant:')) {
lastRole = MessageRole.ASSISTANT;
return {
role: MessageRole.ASSISTANT,
content: line.substring('assistant:'.length).trim(),
};
}

return {
role: lastRole,
content: line.trim(),
};
});
});
};
};

+ 9
- 25
src/platforms/openai/common.ts Ver arquivo

@@ -1,5 +1,3 @@
import { Message, MessageRole } from './message';

export enum FinishReason {
STOP = 'stop',
LENGTH = 'length',
@@ -48,27 +46,13 @@ export class PlatformError extends Error {
}
}

export const normalizeChatMessage = (messageRaw: Message | Message[]) => {
if (typeof messageRaw === 'string') {
return [
{
role: MessageRole.USER,
content: messageRaw,
},
];
}

if (Array.isArray(messageRaw)) {
return messageRaw.map((message) => {
if (typeof message === 'string') {
return {
role: MessageRole.USER,
content: message,
};
}
return message;
});
}
export enum ApiVersion {
V1 = 'v1',
}

return messageRaw;
};
export interface Configuration {
organizationId?: string;
apiVersion: ApiVersion;
apiKey: string;
baseUrl?: string;
}

+ 78
- 4
src/platforms/openai/events.ts Ver arquivo

@@ -1,7 +1,11 @@
import { CreateChatCompletionParams } from './features/chat-completion';
import { CreateImageParams } from './features/image';
import { CreateTextCompletionParams } from './features/text-completion';
import { CreateEditParams } from './features/edit';
import { PassThrough } from 'stream';
import { EventEmitter } from 'events';
import fetchPonyfill from 'fetch-ponyfill';
import { Configuration } from './common';
import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion';
import { CreateChatCompletionParams, createChatCompletion } from './features/chat-completion';
import { CreateImageParams, createImage } from './features/image';
import { CreateEditParams, createEdit } from './features/edit';

export type DataEventCallback<D> = (data: D) => void;

@@ -16,3 +20,73 @@ export interface PlatformEventEmitter extends NodeJS.EventEmitter {
on(event: 'end', callback: () => void): this;
on(event: 'error', callback: ErrorEventCallback): this;
}

export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly createCompletion: PlatformEventEmitter['createCompletion'];

readonly createImage: PlatformEventEmitter['createImage'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

constructor(configParams: Configuration) {
super();
const headers: Record<string, string> = {
Authorization: `Bearer ${configParams.apiKey}`,
};

if (configParams.organizationId) {
headers['OpenAI-Organization'] = configParams.organizationId;
}

const { fetch: fetchInstance } = fetchPonyfill();
const doFetch = (method: string, path: string, body: Record<string, unknown>) => {
const theFetchParams = {
method,
headers: {
...headers,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
};

const url = new URL(
`/${configParams.apiVersion}${path}`,
configParams.baseUrl ?? 'https://api.openai.com',
).toString();

this.emit('start', {
...theFetchParams,
url,
});

return fetchInstance(url, theFetchParams);
};

const consumeStream = async (response: Response) => {
// eslint-disable-next-line no-restricted-syntax
for await (const chunk of response.body as unknown as PassThrough) {
const chunkStringMaybeMultiple = chunk.toString();
const chunkStrings = chunkStringMaybeMultiple
.split('\n')
.filter((chunkString: string) => chunkString.length > 0);
chunkStrings.forEach((chunkString: string) => {
const dataRaw = chunkString.split('data: ').at(1);
if (!dataRaw) {
return;
}
if (dataRaw === '[DONE]') {
return;
}
const data = JSON.parse(dataRaw);
this.emit('data', data);
});
}
};
this.createImage = createImage.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);
}
}

+ 1
- 1
src/platforms/openai/features/chat-completion.ts Ver arquivo

@@ -3,13 +3,13 @@ import {
ConsumeStream,
DataEventId,
DoFetch,
normalizeChatMessage,
PlatformError,
PlatformResponse,
UsageMetadata,
} from '../common';
import { Message, MessageObject } from '../message';
import { ChatCompletionModel } from '../models';
import { normalizeChatMessage } from '../chat';

export interface CreateChatCompletionParams {
messages: Message | Message[];


+ 5
- 91
src/platforms/openai/index.ts Ver arquivo

@@ -1,20 +1,16 @@
import fetchPonyfill from 'fetch-ponyfill';
import { EventEmitter } from 'events';
import { PassThrough } from 'stream';
import { PlatformEventEmitter } from './events';
import { createTextCompletion, TextCompletion } from './features/text-completion';
import { createImage } from './features/image';
import { createChatCompletion, ChatCompletion } from './features/chat-completion';
import { createEdit } from './features/edit';
import { Configuration } from './common';

export * from './message';
export * from './models';
export { PlatformEventEmitter, ChatCompletion, TextCompletion };
export * from './common';
export { PlatformEventEmitter, PlatformEventEmitterImpl } from './events';
export {
ChatCompletion,
ChatCompletionChunkDataEvent,
DataEventObjectType as ChatCompletionDataEventObjectType,
} from './features/chat-completion';
export {
TextCompletion,
TextCompletionChunkDataEvent,
DataEventObjectType as TextCompletionDataEventObjectType,
} from './features/text-completion';
@@ -23,11 +19,6 @@ export {
DataEventObjectType as EditDataEventObjectType,
} from './features/edit';
export { CreateImageDataEvent, CreateImageSize } from './features/image';
export * from './common';

export enum ApiVersion {
V1 = 'v1',
}

export const PLATFORM_ID = 'openai' as const;

@@ -35,80 +26,3 @@ export interface PlatformConfig {
platform: typeof PLATFORM_ID;
platformConfiguration: Configuration;
}

export interface Configuration {
organizationId?: string;
apiVersion: ApiVersion;
apiKey: string;
baseUrl?: string;
}

export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly createCompletion: PlatformEventEmitter['createCompletion'];

readonly createImage: PlatformEventEmitter['createImage'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

constructor(configParams: Configuration) {
super();
const headers: Record<string, string> = {
Authorization: `Bearer ${configParams.apiKey}`,
};

if (configParams.organizationId) {
headers['OpenAI-Organization'] = configParams.organizationId;
}

const { fetch: fetchInstance } = fetchPonyfill();
const doFetch = (method: string, path: string, body: Record<string, unknown>) => {
const theFetchParams = {
method,
headers: {
...headers,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
};

const url = new URL(
`/${configParams.apiVersion}${path}`,
configParams.baseUrl ?? 'https://api.openai.com',
).toString();

this.emit('start', {
...theFetchParams,
url,
});

return fetchInstance(url, theFetchParams);
};

const consumeStream = async (response: Response) => {
// eslint-disable-next-line no-restricted-syntax
for await (const chunk of response.body as unknown as PassThrough) {
const chunkStringMaybeMultiple = chunk.toString();
const chunkStrings = chunkStringMaybeMultiple
.split('\n')
.filter((chunkString: string) => chunkString.length > 0);
chunkStrings.forEach((chunkString: string) => {
const dataRaw = chunkString.split('data: ').at(1);
if (!dataRaw) {
return;
}
if (dataRaw === '[DONE]') {
return;
}
const data = JSON.parse(dataRaw);
this.emit('data', data);
});
}
};
this.createImage = createImage.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);
}
}

test/index.test.ts → test/platforms/openai/api.test.ts Ver arquivo

@@ -10,14 +10,14 @@ import {
createAiClient,
PlatformEventEmitter,
OpenAi,
} from '../src';
} from '../../../src';

describe('ai-utils', () => {
describe('OpenAI', () => {
beforeAll(() => {
config();
});

describe('OpenAI', () => {
describe.skip('API', () => {
let aiClient: PlatformEventEmitter;

beforeEach(() => {
@@ -31,7 +31,7 @@ describe('ai-utils', () => {
});
});

describe.skip('createChatCompletion', () => {
describe('createChatCompletion', () => {
let result: Partial<OpenAi.ChatCompletion> | undefined;

beforeEach(() => {
@@ -100,7 +100,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createImage', () => {
describe('createImage', () => {
it('works', () => new Promise<void>((resolve, reject) => {
aiClient.on<OpenAi.CreateImageDataEvent>('data', (r) => {
expect(r).toHaveProperty('created', expect.any(Number));
@@ -123,7 +123,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createCompletion', () => {
describe('createCompletion', () => {
let result: Partial<OpenAi.TextCompletion> | undefined;

beforeEach(() => {
@@ -187,7 +187,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createEdit', () => {
describe('createEdit', () => {
it('works', () => new Promise<void>((resolve, reject) => {
aiClient.on<OpenAi.CreateEditDataEvent>('data', (r) => {
expect(r).toHaveProperty('object', OpenAi.EditDataEventObjectType.EDIT);

+ 194
- 0
test/platforms/openai/chat.test.ts Ver arquivo

@@ -0,0 +1,194 @@
import { describe, it, expect } from 'vitest';
import * as Chat from '../../../src/platforms/openai/chat';
import { MessageRole } from '../../../src/platforms/openai';

describe('OpenAI', () => {
describe('chat', () => {
describe('normalizeChatMessage', () => {
it('normalizes a basic string', () => {
const message = Chat.normalizeChatMessage('This is a user message.');

expect(message).toHaveLength(1);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('normalizes a string array', () => {
const message = Chat.normalizeChatMessage([
'This is a user message.',
'This is another user message.',
]);

expect(message).toHaveLength(2);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(message).toContainEqual({
role: 'user',
content: 'This is another user message.',
});
});

it('normalizes a message object', () => {
const message = Chat.normalizeChatMessage({
role: MessageRole.USER,
content: 'This is a user message.',
});

expect(message).toHaveLength(1);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('normalizes a message object array', () => {
const message = Chat.normalizeChatMessage([
{
role: MessageRole.USER,
content: 'This is a user message.',
},
{
role: MessageRole.USER,
content: 'This is another user message.',
},
]);

expect(message).toHaveLength(2);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(message).toContainEqual({
role: 'user',
content: 'This is another user message.',
});
});
});

describe('buildChatFromTranscript', () => {
it('processes line breaks correctly', () => {
const message = `
SYSTEM: This is a system message. This is a chat from the
user: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(1);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message. This is a chat from the user: This is a user message.',
});
});

it('makes distinctions between different dialogues', () => {
const message = `
SYSTEM: This is a system message. This is a chat from the

user: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(2);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message. This is a chat from the',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('builds an array of chat messages from a single string.', () => {
const message = `
SYSTEM: This is a system message.

USER: This is a user message.

SYSTEM: This is another system message.

USER: This is another user message.

ASSISTANT: This is an assistant message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(5);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message.',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is another system message.',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is another user message.',
});

expect(prompts[0]).toContainEqual({
role: 'assistant',
content: 'This is an assistant message.',
});
});

it('builds multiple chat messages with a divider.', () => {
const message = `
SYSTEM: This is a system message.
---
USER: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts).toHaveLength(2);
expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message.',
});
expect(prompts[1]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('injects parameters into the chat messages.', () => {
const message = `
Say {{name}}. <name> <age> <foo> {{htmlChar}} \\{{escaped}}
`;

const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({ name: 'Hello', htmlChar: '<html>' });
expect(prompts[0][0]).toEqual({
role: 'user',
content: 'Say Hello. <name> <age> <foo> <html> {{escaped}}',
});
});
});
});
});

+ 28
- 1
yarn.lock Ver arquivo

@@ -1923,6 +1923,18 @@ grapheme-splitter@^1.0.4:
resolved "https://registry.yarnpkg.com/grapheme-splitter/-/grapheme-splitter-1.0.4.tgz#9cf3a665c6247479896834af35cf1dbb4400767e"
integrity sha512-bzh50DW9kTPM00T8y4o8vQg89Di9oLJVLW/KaOGIXJWP/iqCN6WKYkbNOF04vFLJhwcpYUh9ydh/+5vpOqV4YQ==
handlebars@^4.7.7:
version "4.7.7"
resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.7.7.tgz#9ce33416aad02dbd6c8fafa8240d5d98004945a1"
integrity sha512-aAcXm5OAfE/8IXkcZvCepKU3VzW1/39Fb5ZuqMtgI/hT8X2YgoMvBY5dLhq/cpOvw7Lk1nK/UF71aLG/ZnVYRA==
dependencies:
minimist "^1.2.5"
neo-async "^2.6.0"
source-map "^0.6.1"
wordwrap "^1.0.0"
optionalDependencies:
uglify-js "^3.1.4"
has-bigints@^1.0.1, has-bigints@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/has-bigints/-/has-bigints-1.0.2.tgz#0871bd3e3d51626f6ca0966668ba35d5602d6eaa"
@@ -2440,7 +2452,7 @@ minimatch@^3.0.4, minimatch@^3.0.5, minimatch@^3.1.1, minimatch@^3.1.2:
dependencies:
brace-expansion "^1.1.7"
minimist@^1.2.0, minimist@^1.2.6:
minimist@^1.2.0, minimist@^1.2.5, minimist@^1.2.6:
version "1.2.8"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.8.tgz#c1a464e7693302e082a075cee0c057741ac4772c"
integrity sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==
@@ -2480,6 +2492,11 @@ natural-compare@^1.4.0:
resolved "https://registry.yarnpkg.com/natural-compare/-/natural-compare-1.4.0.tgz#4abebfeed7541f2c27acfb29bdbbd15c8d5ba4f7"
integrity sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==
neo-async@^2.6.0:
version "2.6.2"
resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.2.tgz#b4aafb93e3aeb2d8174ca53cf163ab7d7308305f"
integrity sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==
node-fetch@~2.6.1:
version "2.6.9"
resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.6.9.tgz#7c7f744b5cc6eb5fd404e0c7a9fec630a55657e6"
@@ -3301,6 +3318,11 @@ ufo@^1.1.1:
resolved "https://registry.yarnpkg.com/ufo/-/ufo-1.1.1.tgz#e70265e7152f3aba425bd013d150b2cdf4056d7c"
integrity sha512-MvlCc4GHrmZdAllBc0iUDowff36Q9Ndw/UzqmEKyrfSzokTd9ZCy1i+IIk5hrYKkjoYVQyNbrw7/F8XJ2rEwTg==
uglify-js@^3.1.4:
version "3.17.4"
resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.17.4.tgz#61678cf5fa3f5b7eb789bb345df29afb8257c22c"
integrity sha512-T9q82TJI9e/C1TAxYvfb16xO120tMVFZrGA3f9/P4424DNu6ypK103y0GPFVa17yotwSyZW5iYXgjYHkGrJW/g==
unbox-primitive@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/unbox-primitive/-/unbox-primitive-1.0.2.tgz#29032021057d5e6cdbd08c5129c226dff8ed6f9e"
@@ -3485,6 +3507,11 @@ word-wrap@^1.2.3:
resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c"
integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==
wordwrap@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb"
integrity sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==
wrap-ansi@^6.2.0:
version "6.2.0"
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53"


Carregando…
Cancelar
Salvar