Quellcode durchsuchen

Refactor structure, add chat utils

Reduce coupling by putting types and definitions in their appropriate
files.

Chat utils for creating prompts and messages have been added.
master
TheoryOfNekomata vor 1 Jahr
Ursprung
Commit
9959b0bfd8
10 geänderte Dateien mit 430 neuen und 133 gelöschten Zeilen
  1. +2
    -1
      package.json
  2. +3
    -3
      src/index.ts
  3. +103
    -0
      src/platforms/openai/chat.ts
  4. +9
    -25
      src/platforms/openai/common.ts
  5. +78
    -4
      src/platforms/openai/events.ts
  6. +1
    -1
      src/platforms/openai/features/chat-completion.ts
  7. +5
    -91
      src/platforms/openai/index.ts
  8. +7
    -7
      test/platforms/openai/api.test.ts
  9. +194
    -0
      test/platforms/openai/chat.test.ts
  10. +28
    -1
      yarn.lock

+ 2
- 1
package.json Datei anzeigen

@@ -48,6 +48,7 @@
"access": "public"
},
"dependencies": {
"fetch-ponyfill": "^7.1.0"
"fetch-ponyfill": "^7.1.0",
"handlebars": "^4.7.7"
}
}

+ 3
- 3
src/index.ts Datei anzeigen

@@ -1,10 +1,10 @@
import * as OpenAiImpl from './platforms/openai';

export const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const;
export type PlatformConfig = OpenAiImpl.PlatformConfig;
export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter;
const SUPPORTED_PLATFORMS = { OpenAi: OpenAiImpl } as const;

export * as OpenAi from './platforms/openai';
export type PlatformConfig = OpenAiImpl.PlatformConfig;
export type PlatformEventEmitter = OpenAiImpl.PlatformEventEmitter;

export const createAiClient = (configParams: PlatformConfig): PlatformEventEmitter => {
const {


+ 103
- 0
src/platforms/openai/chat.ts Datei anzeigen

@@ -0,0 +1,103 @@
import Handlebars from 'handlebars';
import { Message, MessageRole } from './message';

const isValidMessageObject = (maybeMessage: unknown): maybeMessage is Message => {
if (typeof maybeMessage !== 'object') {
return false;
}

if (maybeMessage === null) {
return false;
}

const messageObject = maybeMessage as Record<string, unknown>;

return (
Object.values(MessageRole).includes(messageObject.role as MessageRole)
&& typeof messageObject.content === 'string'
);
};

export const normalizeChatMessage = (messageRaw: Message | Message[]) => {
if (typeof messageRaw === 'string') {
return [
{
role: MessageRole.USER,
content: messageRaw,
},
];
}

if (Array.isArray(messageRaw)) {
return messageRaw.map((message) => {
if (typeof message === 'string') {
return {
role: MessageRole.USER,
content: message,
};
}

if (isValidMessageObject(message)) {
return message;
}

throw new Error('Invalid message format');
});
}

if (isValidMessageObject(messageRaw)) {
return [messageRaw];
}

throw new Error('Invalid message format');
};

export const buildChatFromTranscript = (transcript: string) => {
const parameterized = Handlebars.create().compile(transcript, {
noEscape: true,
ignoreStandalone: true,
strict: true,
preventIndent: true,
});

return (params: Record<string, unknown>) => {
const compiled = parameterized(params);
const prompts = compiled.split('\n---\n');
return prompts.map((prompt) => {
const lines = prompt.trim().split('\n\n');
let lastRole = MessageRole.USER;
return lines.filter((s) => s.trim().length > 0).map((lineRaw) => {
const line = lineRaw.replace(/\n/g, ' ');
const lineCheckRole = line.toLowerCase();
if (lineCheckRole.startsWith('system:')) {
lastRole = MessageRole.SYSTEM;
return {
role: MessageRole.SYSTEM,
content: line.substring('system:'.length).trim(),
};
}

if (lineCheckRole.startsWith('user:')) {
lastRole = MessageRole.USER;
return {
role: MessageRole.USER,
content: line.substring('user:'.length).trim(),
};
}

if (lineCheckRole.startsWith('assistant:')) {
lastRole = MessageRole.ASSISTANT;
return {
role: MessageRole.ASSISTANT,
content: line.substring('assistant:'.length).trim(),
};
}

return {
role: lastRole,
content: line.trim(),
};
});
});
};
};

+ 9
- 25
src/platforms/openai/common.ts Datei anzeigen

@@ -1,5 +1,3 @@
import { Message, MessageRole } from './message';

export enum FinishReason {
STOP = 'stop',
LENGTH = 'length',
@@ -48,27 +46,13 @@ export class PlatformError extends Error {
}
}

export const normalizeChatMessage = (messageRaw: Message | Message[]) => {
if (typeof messageRaw === 'string') {
return [
{
role: MessageRole.USER,
content: messageRaw,
},
];
}

if (Array.isArray(messageRaw)) {
return messageRaw.map((message) => {
if (typeof message === 'string') {
return {
role: MessageRole.USER,
content: message,
};
}
return message;
});
}
export enum ApiVersion {
V1 = 'v1',
}

return messageRaw;
};
export interface Configuration {
organizationId?: string;
apiVersion: ApiVersion;
apiKey: string;
baseUrl?: string;
}

+ 78
- 4
src/platforms/openai/events.ts Datei anzeigen

@@ -1,7 +1,11 @@
import { CreateChatCompletionParams } from './features/chat-completion';
import { CreateImageParams } from './features/image';
import { CreateTextCompletionParams } from './features/text-completion';
import { CreateEditParams } from './features/edit';
import { PassThrough } from 'stream';
import { EventEmitter } from 'events';
import fetchPonyfill from 'fetch-ponyfill';
import { Configuration } from './common';
import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion';
import { CreateChatCompletionParams, createChatCompletion } from './features/chat-completion';
import { CreateImageParams, createImage } from './features/image';
import { CreateEditParams, createEdit } from './features/edit';

export type DataEventCallback<D> = (data: D) => void;

@@ -16,3 +20,73 @@ export interface PlatformEventEmitter extends NodeJS.EventEmitter {
on(event: 'end', callback: () => void): this;
on(event: 'error', callback: ErrorEventCallback): this;
}

export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly createCompletion: PlatformEventEmitter['createCompletion'];

readonly createImage: PlatformEventEmitter['createImage'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

constructor(configParams: Configuration) {
super();
const headers: Record<string, string> = {
Authorization: `Bearer ${configParams.apiKey}`,
};

if (configParams.organizationId) {
headers['OpenAI-Organization'] = configParams.organizationId;
}

const { fetch: fetchInstance } = fetchPonyfill();
const doFetch = (method: string, path: string, body: Record<string, unknown>) => {
const theFetchParams = {
method,
headers: {
...headers,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
};

const url = new URL(
`/${configParams.apiVersion}${path}`,
configParams.baseUrl ?? 'https://api.openai.com',
).toString();

this.emit('start', {
...theFetchParams,
url,
});

return fetchInstance(url, theFetchParams);
};

const consumeStream = async (response: Response) => {
// eslint-disable-next-line no-restricted-syntax
for await (const chunk of response.body as unknown as PassThrough) {
const chunkStringMaybeMultiple = chunk.toString();
const chunkStrings = chunkStringMaybeMultiple
.split('\n')
.filter((chunkString: string) => chunkString.length > 0);
chunkStrings.forEach((chunkString: string) => {
const dataRaw = chunkString.split('data: ').at(1);
if (!dataRaw) {
return;
}
if (dataRaw === '[DONE]') {
return;
}
const data = JSON.parse(dataRaw);
this.emit('data', data);
});
}
};
this.createImage = createImage.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);
}
}

+ 1
- 1
src/platforms/openai/features/chat-completion.ts Datei anzeigen

@@ -3,13 +3,13 @@ import {
ConsumeStream,
DataEventId,
DoFetch,
normalizeChatMessage,
PlatformError,
PlatformResponse,
UsageMetadata,
} from '../common';
import { Message, MessageObject } from '../message';
import { ChatCompletionModel } from '../models';
import { normalizeChatMessage } from '../chat';

export interface CreateChatCompletionParams {
messages: Message | Message[];


+ 5
- 91
src/platforms/openai/index.ts Datei anzeigen

@@ -1,20 +1,16 @@
import fetchPonyfill from 'fetch-ponyfill';
import { EventEmitter } from 'events';
import { PassThrough } from 'stream';
import { PlatformEventEmitter } from './events';
import { createTextCompletion, TextCompletion } from './features/text-completion';
import { createImage } from './features/image';
import { createChatCompletion, ChatCompletion } from './features/chat-completion';
import { createEdit } from './features/edit';
import { Configuration } from './common';

export * from './message';
export * from './models';
export { PlatformEventEmitter, ChatCompletion, TextCompletion };
export * from './common';
export { PlatformEventEmitter, PlatformEventEmitterImpl } from './events';
export {
ChatCompletion,
ChatCompletionChunkDataEvent,
DataEventObjectType as ChatCompletionDataEventObjectType,
} from './features/chat-completion';
export {
TextCompletion,
TextCompletionChunkDataEvent,
DataEventObjectType as TextCompletionDataEventObjectType,
} from './features/text-completion';
@@ -23,11 +19,6 @@ export {
DataEventObjectType as EditDataEventObjectType,
} from './features/edit';
export { CreateImageDataEvent, CreateImageSize } from './features/image';
export * from './common';

export enum ApiVersion {
V1 = 'v1',
}

export const PLATFORM_ID = 'openai' as const;

@@ -35,80 +26,3 @@ export interface PlatformConfig {
platform: typeof PLATFORM_ID;
platformConfiguration: Configuration;
}

export interface Configuration {
organizationId?: string;
apiVersion: ApiVersion;
apiKey: string;
baseUrl?: string;
}

export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly createCompletion: PlatformEventEmitter['createCompletion'];

readonly createImage: PlatformEventEmitter['createImage'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

constructor(configParams: Configuration) {
super();
const headers: Record<string, string> = {
Authorization: `Bearer ${configParams.apiKey}`,
};

if (configParams.organizationId) {
headers['OpenAI-Organization'] = configParams.organizationId;
}

const { fetch: fetchInstance } = fetchPonyfill();
const doFetch = (method: string, path: string, body: Record<string, unknown>) => {
const theFetchParams = {
method,
headers: {
...headers,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
};

const url = new URL(
`/${configParams.apiVersion}${path}`,
configParams.baseUrl ?? 'https://api.openai.com',
).toString();

this.emit('start', {
...theFetchParams,
url,
});

return fetchInstance(url, theFetchParams);
};

const consumeStream = async (response: Response) => {
// eslint-disable-next-line no-restricted-syntax
for await (const chunk of response.body as unknown as PassThrough) {
const chunkStringMaybeMultiple = chunk.toString();
const chunkStrings = chunkStringMaybeMultiple
.split('\n')
.filter((chunkString: string) => chunkString.length > 0);
chunkStrings.forEach((chunkString: string) => {
const dataRaw = chunkString.split('data: ').at(1);
if (!dataRaw) {
return;
}
if (dataRaw === '[DONE]') {
return;
}
const data = JSON.parse(dataRaw);
this.emit('data', data);
});
}
};
this.createImage = createImage.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);
}
}

test/index.test.ts → test/platforms/openai/api.test.ts Datei anzeigen

@@ -10,14 +10,14 @@ import {
createAiClient,
PlatformEventEmitter,
OpenAi,
} from '../src';
} from '../../../src';

describe('ai-utils', () => {
describe('OpenAI', () => {
beforeAll(() => {
config();
});

describe('OpenAI', () => {
describe.skip('API', () => {
let aiClient: PlatformEventEmitter;

beforeEach(() => {
@@ -31,7 +31,7 @@ describe('ai-utils', () => {
});
});

describe.skip('createChatCompletion', () => {
describe('createChatCompletion', () => {
let result: Partial<OpenAi.ChatCompletion> | undefined;

beforeEach(() => {
@@ -100,7 +100,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createImage', () => {
describe('createImage', () => {
it('works', () => new Promise<void>((resolve, reject) => {
aiClient.on<OpenAi.CreateImageDataEvent>('data', (r) => {
expect(r).toHaveProperty('created', expect.any(Number));
@@ -123,7 +123,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createCompletion', () => {
describe('createCompletion', () => {
let result: Partial<OpenAi.TextCompletion> | undefined;

beforeEach(() => {
@@ -187,7 +187,7 @@ describe('ai-utils', () => {
}), { timeout: 10000 });
});

describe.skip('createEdit', () => {
describe('createEdit', () => {
it('works', () => new Promise<void>((resolve, reject) => {
aiClient.on<OpenAi.CreateEditDataEvent>('data', (r) => {
expect(r).toHaveProperty('object', OpenAi.EditDataEventObjectType.EDIT);

+ 194
- 0
test/platforms/openai/chat.test.ts Datei anzeigen

@@ -0,0 +1,194 @@
import { describe, it, expect } from 'vitest';
import * as Chat from '../../../src/platforms/openai/chat';
import { MessageRole } from '../../../src/platforms/openai';

describe('OpenAI', () => {
describe('chat', () => {
describe('normalizeChatMessage', () => {
it('normalizes a basic string', () => {
const message = Chat.normalizeChatMessage('This is a user message.');

expect(message).toHaveLength(1);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('normalizes a string array', () => {
const message = Chat.normalizeChatMessage([
'This is a user message.',
'This is another user message.',
]);

expect(message).toHaveLength(2);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(message).toContainEqual({
role: 'user',
content: 'This is another user message.',
});
});

it('normalizes a message object', () => {
const message = Chat.normalizeChatMessage({
role: MessageRole.USER,
content: 'This is a user message.',
});

expect(message).toHaveLength(1);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('normalizes a message object array', () => {
const message = Chat.normalizeChatMessage([
{
role: MessageRole.USER,
content: 'This is a user message.',
},
{
role: MessageRole.USER,
content: 'This is another user message.',
},
]);

expect(message).toHaveLength(2);

expect(message).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(message).toContainEqual({
role: 'user',
content: 'This is another user message.',
});
});
});

describe('buildChatFromTranscript', () => {
it('processes line breaks correctly', () => {
const message = `
SYSTEM: This is a system message. This is a chat from the
user: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(1);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message. This is a chat from the user: This is a user message.',
});
});

it('makes distinctions between different dialogues', () => {
const message = `
SYSTEM: This is a system message. This is a chat from the

user: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(2);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message. This is a chat from the',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('builds an array of chat messages from a single string.', () => {
const message = `
SYSTEM: This is a system message.

USER: This is a user message.

SYSTEM: This is another system message.

USER: This is another user message.

ASSISTANT: This is an assistant message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts[0]).toHaveLength(5);

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message.',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});

expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is another system message.',
});

expect(prompts[0]).toContainEqual({
role: 'user',
content: 'This is another user message.',
});

expect(prompts[0]).toContainEqual({
role: 'assistant',
content: 'This is an assistant message.',
});
});

it('builds multiple chat messages with a divider.', () => {
const message = `
SYSTEM: This is a system message.
---
USER: This is a user message.
`;
const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({});

expect(prompts).toHaveLength(2);
expect(prompts[0]).toContainEqual({
role: 'system',
content: 'This is a system message.',
});
expect(prompts[1]).toContainEqual({
role: 'user',
content: 'This is a user message.',
});
});

it('injects parameters into the chat messages.', () => {
const message = `
Say {{name}}. <name> <age> <foo> {{htmlChar}} \\{{escaped}}
`;

const parameterized = Chat.buildChatFromTranscript(message);
const prompts = parameterized({ name: 'Hello', htmlChar: '<html>' });
expect(prompts[0][0]).toEqual({
role: 'user',
content: 'Say Hello. <name> <age> <foo> <html> {{escaped}}',
});
});
});
});
});

+ 28
- 1
yarn.lock Datei anzeigen

@@ -1923,6 +1923,18 @@ grapheme-splitter@^1.0.4:
resolved "https://registry.yarnpkg.com/grapheme-splitter/-/grapheme-splitter-1.0.4.tgz#9cf3a665c6247479896834af35cf1dbb4400767e"
integrity sha512-bzh50DW9kTPM00T8y4o8vQg89Di9oLJVLW/KaOGIXJWP/iqCN6WKYkbNOF04vFLJhwcpYUh9ydh/+5vpOqV4YQ==
handlebars@^4.7.7:
version "4.7.7"
resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.7.7.tgz#9ce33416aad02dbd6c8fafa8240d5d98004945a1"
integrity sha512-aAcXm5OAfE/8IXkcZvCepKU3VzW1/39Fb5ZuqMtgI/hT8X2YgoMvBY5dLhq/cpOvw7Lk1nK/UF71aLG/ZnVYRA==
dependencies:
minimist "^1.2.5"
neo-async "^2.6.0"
source-map "^0.6.1"
wordwrap "^1.0.0"
optionalDependencies:
uglify-js "^3.1.4"
has-bigints@^1.0.1, has-bigints@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/has-bigints/-/has-bigints-1.0.2.tgz#0871bd3e3d51626f6ca0966668ba35d5602d6eaa"
@@ -2440,7 +2452,7 @@ minimatch@^3.0.4, minimatch@^3.0.5, minimatch@^3.1.1, minimatch@^3.1.2:
dependencies:
brace-expansion "^1.1.7"
minimist@^1.2.0, minimist@^1.2.6:
minimist@^1.2.0, minimist@^1.2.5, minimist@^1.2.6:
version "1.2.8"
resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.8.tgz#c1a464e7693302e082a075cee0c057741ac4772c"
integrity sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==
@@ -2480,6 +2492,11 @@ natural-compare@^1.4.0:
resolved "https://registry.yarnpkg.com/natural-compare/-/natural-compare-1.4.0.tgz#4abebfeed7541f2c27acfb29bdbbd15c8d5ba4f7"
integrity sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==
neo-async@^2.6.0:
version "2.6.2"
resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.2.tgz#b4aafb93e3aeb2d8174ca53cf163ab7d7308305f"
integrity sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==
node-fetch@~2.6.1:
version "2.6.9"
resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.6.9.tgz#7c7f744b5cc6eb5fd404e0c7a9fec630a55657e6"
@@ -3301,6 +3318,11 @@ ufo@^1.1.1:
resolved "https://registry.yarnpkg.com/ufo/-/ufo-1.1.1.tgz#e70265e7152f3aba425bd013d150b2cdf4056d7c"
integrity sha512-MvlCc4GHrmZdAllBc0iUDowff36Q9Ndw/UzqmEKyrfSzokTd9ZCy1i+IIk5hrYKkjoYVQyNbrw7/F8XJ2rEwTg==
uglify-js@^3.1.4:
version "3.17.4"
resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.17.4.tgz#61678cf5fa3f5b7eb789bb345df29afb8257c22c"
integrity sha512-T9q82TJI9e/C1TAxYvfb16xO120tMVFZrGA3f9/P4424DNu6ypK103y0GPFVa17yotwSyZW5iYXgjYHkGrJW/g==
unbox-primitive@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/unbox-primitive/-/unbox-primitive-1.0.2.tgz#29032021057d5e6cdbd08c5129c226dff8ed6f9e"
@@ -3485,6 +3507,11 @@ word-wrap@^1.2.3:
resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c"
integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==
wordwrap@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/wordwrap/-/wordwrap-1.0.0.tgz#27584810891456a4171c8d0226441ade90cbcaeb"
integrity sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==
wrap-ansi@^6.2.0:
version "6.2.0"
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53"


Laden…
Abbrechen
Speichern