diff --git a/README.md b/README.md index 2cbac30..61f9cc0 100644 --- a/README.md +++ b/README.md @@ -15,10 +15,23 @@ Many-in-one AI client. - [X] generate - [X] edit - [X] variation - - [ ] embeddings + - [X] embeddings - [ ] audio + - [ ] transcription + - [ ] translation - [ ] files + - [ ] list + - [ ] upload + - [ ] delete + - [ ] retrieve metadata + - [ ] retrieve content - [ ] fine-tunes + - [ ] create + - [ ] list + - [ ] retrieve + - [ ] cancel + - [ ] list events + - [ ] delete model - [ ] moderations * ElevenLabs - [X] TTS (stream) diff --git a/src/platforms/openai/common.ts b/src/platforms/openai/common.ts index da4c01a..db277d1 100644 --- a/src/platforms/openai/common.ts +++ b/src/platforms/openai/common.ts @@ -30,4 +30,8 @@ export interface Configuration { baseUrl?: string; } +export enum ResponseObjectType { + LIST = 'list', +} + export const DEFAULT_BASE_URL = 'https://api.openai.com' as const; diff --git a/src/platforms/openai/events.ts b/src/platforms/openai/events.ts index 4804451..56efc03 100644 --- a/src/platforms/openai/events.ts +++ b/src/platforms/openai/events.ts @@ -1,6 +1,7 @@ import { PassThrough } from 'stream'; import { EventEmitter } from 'events'; import fetchPonyfill from 'fetch-ponyfill'; +import { DoFetchBody, processRequest } from '../../packages/request'; import * as AllPlatformsCommon from '../../common'; import { Configuration, DEFAULT_BASE_URL } from './common'; import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion'; @@ -14,33 +15,39 @@ import { createImageVariation, } from './features/image'; import { CreateEditParams, createEdit } from './features/edit'; -import { listModels } from './features/model'; -import { DoFetchBody, processRequest } from '../../packages/request'; +import { listModels, ModelId, retrieveModel } from './features/model'; +import { createEmbedding, CreateEmbeddingParams } from './features/embedding'; export interface PlatformEventEmitter extends 
AllPlatformsCommon.PlatformEventEmitter { + listModels(): void; + retrieveModel(modelId: ModelId): void; + createCompletion(params: CreateTextCompletionParams): void; createChatCompletion(params: CreateChatCompletionParams): void; + createEdit(params: CreateEditParams): void; createImage(params: CreateImageParams): void; createImageEdit(params: CreateImageEditParams): void; createImageVariation(params: CreateImageVariationParams): void; - createCompletion(params: CreateTextCompletionParams): void; - createEdit(params: CreateEditParams): void; - listModels(): void; + createEmbedding(params: CreateEmbeddingParams): void; } export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter { + readonly listModels: PlatformEventEmitter['listModels']; + + readonly retrieveModel: PlatformEventEmitter['retrieveModel']; + readonly createCompletion: PlatformEventEmitter['createCompletion']; + readonly createChatCompletion: PlatformEventEmitter['createChatCompletion']; + + readonly createEdit: PlatformEventEmitter['createEdit']; + readonly createImage: PlatformEventEmitter['createImage']; readonly createImageEdit: PlatformEventEmitter['createImageEdit']; readonly createImageVariation: PlatformEventEmitter['createImageVariation']; - readonly createChatCompletion: PlatformEventEmitter['createChatCompletion']; - - readonly createEdit: PlatformEventEmitter['createEdit']; - - readonly listModels: PlatformEventEmitter['listModels']; + readonly createEmbedding: PlatformEventEmitter['createEmbedding']; constructor(configParams: Configuration) { super(); @@ -110,11 +117,13 @@ export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEv } }; this.listModels = listModels.bind(this, doFetch); - this.createImage = createImage.bind(this, doFetch); - this.createImageVariation = createImageVariation.bind(this, doFetch); - this.createImageEdit = createImageEdit.bind(this, doFetch); + this.retrieveModel = retrieveModel.bind(this, doFetch); 
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream); this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream); this.createEdit = createEdit.bind(this, doFetch); + this.createImage = createImage.bind(this, doFetch); + this.createImageEdit = createImageEdit.bind(this, doFetch); + this.createImageVariation = createImageVariation.bind(this, doFetch); + this.createEmbedding = createEmbedding.bind(this, doFetch); } } diff --git a/src/platforms/openai/features/chat-completion.ts b/src/platforms/openai/features/chat-completion.ts index ea5bb42..c61fb0d 100644 --- a/src/platforms/openai/features/chat-completion.ts +++ b/src/platforms/openai/features/chat-completion.ts @@ -4,6 +4,7 @@ import { CreatedResource, } from '../common'; import { + CompletionUsage, UsageMetadata, } from '../usage'; import { ChatCompletionModel } from '../models'; @@ -48,7 +49,8 @@ export interface CreateChatCompletionDataEvent< } export interface ChatCompletion - extends CreateChatCompletionDataEvent>, UsageMetadata {} + extends CreateChatCompletionDataEvent>, + UsageMetadata<CompletionUsage> {} export type ChatCompletionChunkDataEvent = CreateChatCompletionDataEvent; diff --git a/src/platforms/openai/features/edit.ts b/src/platforms/openai/features/edit.ts index f4ff64d..b9e5a13 100644 --- a/src/platforms/openai/features/edit.ts +++ b/src/platforms/openai/features/edit.ts @@ -3,6 +3,7 @@ import { CreatedResource, } from '../common'; import { + CompletionUsage, UsageMetadata, } from '../usage'; import { EditModel } from '../models'; @@ -26,7 +27,7 @@ export interface EditChoice extends ChoiceBase { text: string; } -export interface CreateEditDataEvent extends CreatedResource, UsageMetadata { +export interface CreateEditDataEvent extends CreatedResource, UsageMetadata<CompletionUsage> { object: DataEventObjectType; choices: EditChoice[]; } @@ -55,7 +56,7 @@ export function createEdit( return; } - const responseData = await response.json() as Record<string, unknown>; + const responseData = await
response.json() as CreateEditDataEvent; this.emit('data', responseData); this.emit('end'); }) diff --git a/src/platforms/openai/features/embedding.ts b/src/platforms/openai/features/embedding.ts new file mode 100644 index 0000000..90b6db5 --- /dev/null +++ b/src/platforms/openai/features/embedding.ts @@ -0,0 +1,60 @@ +import { DoFetch } from '../../../packages/request'; +import { EmbeddingModel } from '../models'; +import { PlatformApiError } from '../../../common'; +import { ResponseObjectType } from '../common'; +import { PromptUsage, UsageMetadata } from '../usage'; + +export enum DataEventObjectType { + EMBEDDING = 'embedding', +} + +export interface CreateEmbeddingParams { + model: EmbeddingModel; + input: string | number[] | string[] | number[][]; + user?: string; +} + +export interface Embedding { + object: DataEventObjectType; + embedding: number[]; + index: number; +} + +export interface CreateEmbeddingResponse extends UsageMetadata<PromptUsage> { + object: ResponseObjectType; + data: Embedding[]; + model: EmbeddingModel; +} + +export function createEmbedding( + this: NodeJS.EventEmitter, + doFetch: DoFetch, + params: CreateEmbeddingParams, +) { + doFetch('POST', '/embeddings', { + model: params.model, + input: params.input, + user: params.user, + }) + .then(async (response) => { + if (!response.ok) { + this.emit('error', new PlatformApiError( + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + `Create embedding returned with status: ${response.status}`, + response, + )); + this.emit('end'); + return; + } + + const responseData = await response.json() as CreateEmbeddingResponse; + this.emit('data', responseData); + this.emit('end'); + }) + .catch((err) => { + this.emit('error', err as Error); + this.emit('end'); + }); + + return this; +} diff --git a/src/platforms/openai/features/image.ts b/src/platforms/openai/features/image.ts index 4c8e924..7d84b9c 100644 --- a/src/platforms/openai/features/image.ts +++ 
b/src/platforms/openai/features/image.ts @@ -105,10 +105,11 @@ export function createImageEdit( const responseData = await response.json() as Record; const data = responseData.data as ImageData[]; - this.emit('data', { + const emittedData = { ...responseData, data: data.map((item) => Buffer.from(item.b64_json, 'base64')), - }); + } as ImageDataEvent; + this.emit('data', emittedData); this.emit('end'); }) .catch((err) => { diff --git a/src/platforms/openai/features/model.ts b/src/platforms/openai/features/model.ts index e0374ec..aa58cc9 100644 --- a/src/platforms/openai/features/model.ts +++ b/src/platforms/openai/features/model.ts @@ -1,17 +1,25 @@ import { DoFetch } from '../../../packages/request'; import { PlatformApiError } from '../../../common'; +import { ResponseObjectType } from '../common'; export enum DataEventObjectType { MODEL = 'model', } +export type ModelId = string; + export interface ModelData { - id: string; + id: ModelId; object: DataEventObjectType, owned_by: string; permission: string[]; } +export interface ListModelsResponse { + object: ResponseObjectType; + data: ModelData[]; +} + export function listModels( this: NodeJS.EventEmitter, doFetch: DoFetch, @@ -28,8 +36,37 @@ export function listModels( return; } - const responseData = await response.json() as Record; - this.emit('data', responseData.data as ModelData[]); + const responseData = await response.json() as ListModelsResponse; + this.emit('data', responseData); + this.emit('end'); + }) + .catch((err) => { + this.emit('error', err as Error); + this.emit('end'); + }); + + return this; +} + +export function retrieveModel( + this: NodeJS.EventEmitter, + doFetch: DoFetch, + modelId: ModelId, +) { + doFetch('GET', `/models/${modelId}`) + .then(async (response) => { + if (!response.ok) { + this.emit('error', new PlatformApiError( + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + `Request from platform returned with status: ${response.status}`, + response, + )); 
+ this.emit('end'); + return; + } + + const responseData = await response.json() as ModelData; + this.emit('data', responseData); this.emit('end'); }) .catch((err) => { this.emit('error', err as Error); this.emit('end'); }); return this; } diff --git a/src/platforms/openai/features/text-completion.ts b/src/platforms/openai/features/text-completion.ts index 31bd859..0f6f8cf 100644 --- a/src/platforms/openai/features/text-completion.ts +++ b/src/platforms/openai/features/text-completion.ts @@ -5,6 +5,7 @@ import { CreatedResource, } from '../common'; import { + CompletionUsage, UsageMetadata, } from '../usage'; import { ConsumeStream, DoFetch } from '../../../packages/request'; @@ -46,7 +47,8 @@ export interface CreateTextCompletionDataEvent< } export interface TextCompletion - extends CreateTextCompletionDataEvent>, UsageMetadata {} + extends CreateTextCompletionDataEvent>, + UsageMetadata<CompletionUsage> {} export type TextCompletionChunkDataEvent = CreateTextCompletionDataEvent; diff --git a/src/platforms/openai/usage.ts b/src/platforms/openai/usage.ts index a2d6f8b..9240583 100644 --- a/src/platforms/openai/usage.ts +++ b/src/platforms/openai/usage.ts @@ -53,12 +53,15 @@ export const getPromptTokens = (normalizedMessageArray: MessageObject[], model: return getTokens(chatTokens, model); }; -export interface Usage { +export interface PromptUsage { prompt_tokens: number; - completion_tokens: number; total_tokens: number; } -export interface UsageMetadata { - usage: Usage; +export interface CompletionUsage extends PromptUsage { + completion_tokens: number; +} + +export interface UsageMetadata<U = CompletionUsage> { + usage: U; +} diff --git a/test/platforms/openai/usage.test.ts b/test/platforms/openai/usage.test.ts index 4c6e158..07622be 100644 --- a/test/platforms/openai/usage.test.ts +++ b/test/platforms/openai/usage.test.ts @@ -3,7 +3,7 @@ import { ChatCompletionModel, getPromptTokens, getTokens, MessageRole, - Usage, + CompletionUsage, } from '../../../src/platforms/openai'; describe('OpenAI', () => { @@ -83,7 +83,7 @@ describe('OpenAI', () => { request.model, )
.length; - const usage: Usage = { + const usage: CompletionUsage = { prompt_tokens: promptTokensLength, completion_tokens: completionTokensLength, total_tokens: promptTokensLength + completionTokensLength, @@ -132,7 +132,7 @@ describe('OpenAI', () => { request.model, ) .length; - const usage: Usage = { + const usage: CompletionUsage = { prompt_tokens: promptTokensLength, completion_tokens: completionTokensLength, total_tokens: promptTokensLength + completionTokensLength,