From 5fe705e6dc0924892ccda10f559b4cdb679d7d67 Mon Sep 17 00:00:00 2001 From: TheoryOfNekomata Date: Sat, 22 Apr 2023 18:50:46 +0800 Subject: [PATCH] Implement model listing, other image functions Add image edits and variations endpoints. --- src/packages/form-data.ts | 53 ++++++++ src/platforms/openai/chat.ts | 6 +- src/platforms/openai/common.ts | 6 +- src/platforms/openai/events.ts | 53 ++++++-- .../openai/features/chat-completion.ts | 6 +- src/platforms/openai/features/edit.ts | 6 +- src/platforms/openai/features/image.ts | 115 ++++++++++++++++-- src/platforms/openai/features/model.ts | 40 ++++++ .../openai/features/text-completion.ts | 6 +- src/platforms/openai/index.ts | 2 +- 10 files changed, 258 insertions(+), 35 deletions(-) create mode 100644 src/packages/form-data.ts create mode 100644 src/platforms/openai/features/model.ts diff --git a/src/packages/form-data.ts b/src/packages/form-data.ts new file mode 100644 index 0000000..99b2e10 --- /dev/null +++ b/src/packages/form-data.ts @@ -0,0 +1,53 @@ +const appendValue = (formData: FormData, key: string, value: unknown, arrayDepth = 0) => { + if (value instanceof Buffer) { + formData.append(key, new Blob([value])); + return; + } + if (value instanceof Blob) { + formData.append(key, value); + return; + } + if (value instanceof File) { + formData.append(key, value, value.name); + return; + } + if (typeof value === 'string') { + formData.append(key, value); + return; + } + if (typeof value === 'number' && !Number.isNaN(value)) { + formData.append(key, value.toString(10)); + return; + } + if (typeof value === 'boolean') { + formData.append(key, value.toString()); + return; + } + if (Array.isArray(value) && arrayDepth === 0) { + appendValue(formData, key, value, arrayDepth + 1); + return; + } + if (typeof value === 'object' && value !== null) { + formData.append(key, JSON.stringify(value)); + return; + } + throw new Error(`Invalid value for key: ${key}`); +}; + +export const fromJson = (json: Record) => { + const formData = new FormData(); + Object + .entries(json) + .forEach(([key, value]) => { + appendValue(formData, key, value); + }); + return formData; +}; + +export const toJson = (formData: FormData) => { + const json = {} as Record; + formData.forEach((value, key) => { + json[key] = value; + }); + return json; +}; diff --git a/src/platforms/openai/chat.ts b/src/platforms/openai/chat.ts index 9b806a3..2f2238a 100644 --- a/src/platforms/openai/chat.ts +++ b/src/platforms/openai/chat.ts @@ -85,7 +85,7 @@ export const buildChatFromTranscript = (transcript: string) => { lastRole = MessageRole.SYSTEM; return { role: MessageRole.SYSTEM, - content: line.substring('system:'.length).trim(), + content: line.slice('system:'.length).trim(), }; } @@ -93,7 +93,7 @@ export const buildChatFromTranscript = (transcript: string) => { lastRole = MessageRole.USER; return { role: MessageRole.USER, - content: line.substring('user:'.length).trim(), + content: line.slice('user:'.length).trim(), }; } @@ -101,7 +101,7 @@ export const buildChatFromTranscript = (transcript: string) => { lastRole = MessageRole.ASSISTANT; return { role: MessageRole.ASSISTANT, - content: line.substring('assistant:'.length).trim(), + content: line.slice('assistant:'.length).trim(), }; } diff --git a/src/platforms/openai/common.ts b/src/platforms/openai/common.ts index afad96a..5bcc7cc 100644 --- a/src/platforms/openai/common.ts +++ b/src/platforms/openai/common.ts @@ -15,14 +15,16 @@ export type DataEventId = string; export type Timestamp = number; -export interface PlatformResponse { +export interface CreatedResource { created: Timestamp; } +export type DoFetchBody = BodyInit | Record + export type DoFetch = ( method: string, path: string, - body: Record + body?: DoFetchBody ) => Promise; export type ConsumeStream = ( diff --git a/src/platforms/openai/events.ts b/src/platforms/openai/events.ts index 1648849..958561e 100644 --- a/src/platforms/openai/events.ts +++ b/src/platforms/openai/events.ts @@ -1,11 +1,19 @@ import { PassThrough } from 'stream'; import { EventEmitter } from 'events'; import fetchPonyfill from 'fetch-ponyfill'; -import { Configuration } from './common'; +import { Configuration, DoFetchBody } from './common'; import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion'; import { CreateChatCompletionParams, createChatCompletion } from './features/chat-completion'; -import { CreateImageParams, createImage } from './features/image'; +import { + CreateImageParams, + createImage, + CreateImageEditParams, + createImageEdit, + CreateImageVariationParams, + createImageVariation, +} from './features/image'; import { CreateEditParams, createEdit } from './features/edit'; +import { listModels } from './features/model'; export type DataEventCallback = (data: D) => void; @@ -14,8 +22,11 @@ export type ErrorEventCallback = (event: Error) => void; export interface PlatformEventEmitter extends NodeJS.EventEmitter { createChatCompletion(params: CreateChatCompletionParams): void; createImage(params: CreateImageParams): void; + createImageEdit(params: CreateImageEditParams): void; + createImageVariation(params: CreateImageVariationParams): void; createCompletion(params: CreateTextCompletionParams): void; createEdit(params: CreateEditParams): void; + listModels(): void; on(event: 'data', callback: DataEventCallback): this; on(event: 'end', callback: () => void): this; on(event: 'error', callback: ErrorEventCallback): this; @@ -26,29 +37,48 @@ export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEv readonly createImage: PlatformEventEmitter['createImage']; + readonly createImageEdit: PlatformEventEmitter['createImageEdit']; + + readonly createImageVariation: PlatformEventEmitter['createImageVariation']; + readonly createChatCompletion: PlatformEventEmitter['createChatCompletion']; readonly createEdit: PlatformEventEmitter['createEdit']; + readonly listModels: PlatformEventEmitter['listModels']; + constructor(configParams: Configuration) { super(); - const headers: Record = { + const platformHeaders: Record = { Authorization: `Bearer ${configParams.apiKey}`, }; if (configParams.organizationId) { - headers['OpenAI-Organization'] = configParams.organizationId; + platformHeaders['OpenAI-Organization'] = configParams.organizationId; } const { fetch: fetchInstance } = fetchPonyfill(); - const doFetch = (method: string, path: string, body: Record) => { + const doFetch = (method: string, path: string, body?: DoFetchBody) => { + const requestHeaders = { + ...platformHeaders, + }; + + let theBody: BodyInit; + + if ( + body instanceof FormData + || body instanceof URLSearchParams + ) { + theBody = body; + } else { + theBody = JSON.stringify(body); + requestHeaders['Content-Type'] = 'application/json'; + } + const theFetchParams = { method, - headers: { - ...headers, - 'Content-Type': 'application/json', - }, - body: JSON.stringify(body), + headers: requestHeaders, + body: theBody, }; const url = new URL( @@ -85,7 +115,10 @@ export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEv }); } }; + this.listModels = listModels.bind(this, doFetch); this.createImage = createImage.bind(this, doFetch); + this.createImageVariation = createImageVariation.bind(this, doFetch); + this.createImageEdit = createImageEdit.bind(this, doFetch); this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream); this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream); this.createEdit = createEdit.bind(this, doFetch); diff --git a/src/platforms/openai/features/chat-completion.ts b/src/platforms/openai/features/chat-completion.ts index 6ce7059..b65455c 100644 --- a/src/platforms/openai/features/chat-completion.ts +++ b/src/platforms/openai/features/chat-completion.ts @@ -4,7 +4,7 @@ import { DataEventId, DoFetch, PlatformError, - PlatformResponse, + CreatedResource, } from '../common'; import { UsageMetadata, @@ -41,7 +41,7 @@ export enum DataEventObjectType { export interface CreateChatCompletionDataEvent< C extends Partial -> extends PlatformResponse { +> extends CreatedResource { id: DataEventId; object: DataEventObjectType; model: ChatCompletionModel; @@ -72,7 +72,7 @@ export function createChatCompletion( frequency_penalty: params.frequencyPenalty ?? 0, logit_bias: params.logitBias ?? {}, user: params.user, - }) + } as Record) .then(async (response) => { if (!response.ok) { this.emit('error', new PlatformError( diff --git a/src/platforms/openai/features/edit.ts b/src/platforms/openai/features/edit.ts index 0104bc6..a9f6a96 100644 --- a/src/platforms/openai/features/edit.ts +++ b/src/platforms/openai/features/edit.ts @@ -2,7 +2,7 @@ import { ChoiceBase, DoFetch, PlatformError, - PlatformResponse, + CreatedResource, } from '../common'; import { UsageMetadata, @@ -26,7 +26,7 @@ export interface EditChoice extends ChoiceBase { text: string; } -export interface CreateEditDataEvent extends PlatformResponse, UsageMetadata { +export interface CreateEditDataEvent extends CreatedResource, UsageMetadata { object: DataEventObjectType; choices: EditChoice[]; } @@ -43,7 +43,7 @@ export function createEdit( n: params.n ?? 1, temperature: params.temperature ?? 1, top_p: params.topP ?? 1, - }) + } as Record) .then(async (response) => { if (!response.ok) { this.emit('error', new PlatformError( diff --git a/src/platforms/openai/features/image.ts b/src/platforms/openai/features/image.ts index 25d5953..fa440f4 100644 --- a/src/platforms/openai/features/image.ts +++ b/src/platforms/openai/features/image.ts @@ -1,16 +1,17 @@ +import * as FormDataUtils from '../../../packages/form-data'; import { DoFetch, PlatformError, - PlatformResponse, + CreatedResource, } from '../common'; -export enum CreateImageSize { +export enum ImageSize { SQUARE_256 = '256x256', SQUARE_512 = '512x512', SQUARE_1024 = '1024x1024', } -export enum CreateImageResponseFormat { +export enum ImageResponseFormat { URL = 'url', BASE64_JSON = 'b64_json', } @@ -18,15 +19,15 @@ export enum CreateImageResponseFormat { export interface CreateImageParams { prompt: string; n? : number; - size?: CreateImageSize; + size?: ImageSize; user?: string; } -export interface CreateImageData { +export interface ImageData { b64_json: string; } -export interface CreateImageDataEvent extends PlatformResponse { +export interface ImageDataEvent extends CreatedResource { data: Buffer[]; } @@ -38,10 +39,10 @@ export function createImage( doFetch('POST', '/images/generations', { prompt: params.prompt, n: params.n ?? 1, - size: params.size ?? CreateImageSize.SQUARE_1024, + size: params.size ?? ImageSize.SQUARE_1024, user: params.user, - response_format: CreateImageResponseFormat.BASE64_JSON, - }) + response_format: ImageResponseFormat.BASE64_JSON, + } as Record) .then(async (response) => { if (!response.ok) { this.emit('error', new PlatformError( @@ -54,7 +55,101 @@ export function createImage( } const responseData = await response.json() as Record; - const data = responseData.data as CreateImageData[]; + const data = responseData.data as ImageData[]; + this.emit('data', { + ...responseData, + data: data.map((item) => Buffer.from(item.b64_json, 'base64')), + }); + this.emit('end'); + }) + .catch((err) => { + this.emit('error', err as Error); + this.emit('end'); + }); + + return this; +} + +export interface CreateImageEditParams { + image: Buffer; + mask?: Buffer; + prompt: string; + n?: number; + size?: ImageSize; + user?: string; +} + +export function createImageEdit( + this: NodeJS.EventEmitter, + doFetch: DoFetch, + params: CreateImageEditParams, +) { + doFetch('POST', '/images/edits', FormDataUtils.fromJson({ + prompt: params.prompt, + image: params.image, + mask: params.mask, + n: params.n ?? 1, + size: params.size ?? ImageSize.SQUARE_1024, + response_format: ImageResponseFormat.BASE64_JSON, + })) + .then(async (response) => { + if (!response.ok) { + this.emit('error', new PlatformError( + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + `Request from platform returned with status: ${response.status}`, + response, + )); + this.emit('end'); + return; + } + + const responseData = await response.json() as Record; + const data = responseData.data as ImageData[]; + this.emit('data', { + ...responseData, + data: data.map((item) => Buffer.from(item.b64_json, 'base64')), + }); + this.emit('end'); + }) + .catch((err) => { + this.emit('error', err as Error); + this.emit('end'); + }); + + return this; +} + +export interface CreateImageVariationParams { + image: Buffer; + n?: number; + size?: ImageSize; + user?: string; +} + +export function createImageVariation( + this: NodeJS.EventEmitter, + doFetch: DoFetch, + params: CreateImageVariationParams, +) { + doFetch('POST', '/images/variations', FormDataUtils.fromJson({ + image: params.image, + n: params.n ?? 1, + size: params.size ?? ImageSize.SQUARE_1024, + response_format: ImageResponseFormat.BASE64_JSON, + })) + .then(async (response) => { + if (!response.ok) { + this.emit('error', new PlatformError( + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + `Request from platform returned with status: ${response.status}`, + response, + )); + this.emit('end'); + return; + } + + const responseData = await response.json() as Record; + const data = responseData.data as ImageData[]; this.emit('data', { ...responseData, data: data.map((item) => Buffer.from(item.b64_json, 'base64')), diff --git a/src/platforms/openai/features/model.ts b/src/platforms/openai/features/model.ts new file mode 100644 index 0000000..c22f4fc --- /dev/null +++ b/src/platforms/openai/features/model.ts @@ -0,0 +1,40 @@ +import { DoFetch, PlatformError } from '../common'; + +export enum DataEventObjectType { + MODEL = 'model', +} + +export interface ModelData { + id: string; + object: DataEventObjectType, + owned_by: string; + permission: string[]; +} + +export function listModels( + this: NodeJS.EventEmitter, + doFetch: DoFetch, +) { + doFetch('GET', '/models') + .then(async (response) => { + if (!response.ok) { + this.emit('error', new PlatformError( + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + `Request from platform returned with status: ${response.status}`, + response, + )); + this.emit('end'); + return; + } + + const responseData = await response.json() as Record; + this.emit('data', responseData.data as ModelData[]); + this.emit('end'); + }) + .catch((err) => { + this.emit('error', err as Error); + this.emit('end'); + }); + + return this; +} diff --git a/src/platforms/openai/features/text-completion.ts b/src/platforms/openai/features/text-completion.ts index a838c32..c169667 100644 --- a/src/platforms/openai/features/text-completion.ts +++ b/src/platforms/openai/features/text-completion.ts @@ -5,7 +5,7 @@ import { DoFetch, FinishableChoiceBase, PlatformError, - PlatformResponse, + CreatedResource, } from '../common'; import { UsageMetadata, @@ -39,7 +39,7 @@ export interface TextCompletionChoice extends FinishableChoiceBase { export interface CreateTextCompletionDataEvent< C extends Partial -> extends PlatformResponse { +> extends CreatedResource { id: DataEventId; object: DataEventObjectType; model: TextCompletionModel; @@ -73,7 +73,7 @@ export function createTextCompletion( user: params.user, presence_penalty: params.presencePenalty, frequency_penalty: params.frequencyPenalty, - }) + } as Record) .then(async (response) => { if (!response.ok) { this.emit('error', new PlatformError( diff --git a/src/platforms/openai/index.ts b/src/platforms/openai/index.ts index 63dc4a5..51f7c99 100644 --- a/src/platforms/openai/index.ts +++ b/src/platforms/openai/index.ts @@ -23,7 +23,7 @@ export { CreateEditDataEvent, DataEventObjectType as EditDataEventObjectType, } from './features/edit'; -export { CreateImageDataEvent, CreateImageSize } from './features/image'; +export { ImageDataEvent, ImageSize } from './features/image'; export const PLATFORM_ID = 'openai' as const;