Browse Source

Add embedding functions

Consume API endpoints for embeddings.
master
TheoryOfNekomata 1 year ago
parent
commit
d6a03b147c
11 changed files with 162 additions and 30 deletions
  1. +14
    -1
      README.md
  2. +4
    -0
      src/platforms/openai/common.ts
  3. +22
    -13
      src/platforms/openai/events.ts
  4. +3
    -1
      src/platforms/openai/features/chat-completion.ts
  5. +3
    -2
      src/platforms/openai/features/edit.ts
  6. +60
    -0
      src/platforms/openai/features/embedding.ts
  7. +3
    -2
      src/platforms/openai/features/image.ts
  8. +40
    -3
      src/platforms/openai/features/model.ts
  9. +3
    -1
      src/platforms/openai/features/text-completion.ts
  10. +7
    -4
      src/platforms/openai/usage.ts
  11. +3
    -3
      test/platforms/openai/usage.test.ts

+ 14
- 1
README.md View File

@@ -15,10 +15,23 @@ Many-in-one AI client.
- [X] generate - [X] generate
- [X] edit - [X] edit
- [X] variation - [X] variation
- [ ] embeddings
- [X] embeddings
- [ ] audio - [ ] audio
- [ ] transcription
- [ ] translation
- [ ] files - [ ] files
- [ ] list
- [ ] upload
- [ ] delete
- [ ] retrieve metadata
- [ ] retrieve content
- [ ] fine-tunes - [ ] fine-tunes
- [ ] create
- [ ] list
- [ ] retrieve
- [ ] cancel
- [ ] list events
- [ ] delete model
- [ ] moderations - [ ] moderations
* ElevenLabs * ElevenLabs
- [X] TTS (stream) - [X] TTS (stream)


+ 4
- 0
src/platforms/openai/common.ts View File

@@ -30,4 +30,8 @@ export interface Configuration {
baseUrl?: string; baseUrl?: string;
} }


export enum ResponseObjectType {
LIST = 'list',
}

export const DEFAULT_BASE_URL = 'https://api.openai.com' as const; export const DEFAULT_BASE_URL = 'https://api.openai.com' as const;

+ 22
- 13
src/platforms/openai/events.ts View File

@@ -1,6 +1,7 @@
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
import { EventEmitter } from 'events'; import { EventEmitter } from 'events';
import fetchPonyfill from 'fetch-ponyfill'; import fetchPonyfill from 'fetch-ponyfill';
import { DoFetchBody, processRequest } from '../../packages/request';
import * as AllPlatformsCommon from '../../common'; import * as AllPlatformsCommon from '../../common';
import { Configuration, DEFAULT_BASE_URL } from './common'; import { Configuration, DEFAULT_BASE_URL } from './common';
import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion'; import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion';
@@ -14,33 +15,39 @@ import {
createImageVariation, createImageVariation,
} from './features/image'; } from './features/image';
import { CreateEditParams, createEdit } from './features/edit'; import { CreateEditParams, createEdit } from './features/edit';
import { listModels } from './features/model';
import { DoFetchBody, processRequest } from '../../packages/request';
import { listModels, ModelId, retrieveModel } from './features/model';
import { createEmbedding, CreateEmbeddingParams } from './features/embedding';


export interface PlatformEventEmitter extends AllPlatformsCommon.PlatformEventEmitter { export interface PlatformEventEmitter extends AllPlatformsCommon.PlatformEventEmitter {
listModels(): void;
retrieveModel(modelId: ModelId): void;
createCompletion(params: CreateTextCompletionParams): void;
createChatCompletion(params: CreateChatCompletionParams): void; createChatCompletion(params: CreateChatCompletionParams): void;
createEdit(params: CreateEditParams): void;
createImage(params: CreateImageParams): void; createImage(params: CreateImageParams): void;
createImageEdit(params: CreateImageEditParams): void; createImageEdit(params: CreateImageEditParams): void;
createImageVariation(params: CreateImageVariationParams): void; createImageVariation(params: CreateImageVariationParams): void;
createCompletion(params: CreateTextCompletionParams): void;
createEdit(params: CreateEditParams): void;
listModels(): void;
createEmbedding(params: CreateEmbeddingParams): void;
} }


export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter { export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly listModels: PlatformEventEmitter['listModels'];

readonly retrieveModel: PlatformEventEmitter['retrieveModel'];

readonly createCompletion: PlatformEventEmitter['createCompletion']; readonly createCompletion: PlatformEventEmitter['createCompletion'];


readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

readonly createImage: PlatformEventEmitter['createImage']; readonly createImage: PlatformEventEmitter['createImage'];


readonly createImageEdit: PlatformEventEmitter['createImageEdit']; readonly createImageEdit: PlatformEventEmitter['createImageEdit'];


readonly createImageVariation: PlatformEventEmitter['createImageVariation']; readonly createImageVariation: PlatformEventEmitter['createImageVariation'];


readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

readonly listModels: PlatformEventEmitter['listModels'];
readonly createEmbedding: PlatformEventEmitter['createEmbedding'];


constructor(configParams: Configuration) { constructor(configParams: Configuration) {
super(); super();
@@ -110,11 +117,13 @@ export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEv
} }
}; };
this.listModels = listModels.bind(this, doFetch); this.listModels = listModels.bind(this, doFetch);
this.createImage = createImage.bind(this, doFetch);
this.createImageVariation = createImageVariation.bind(this, doFetch);
this.createImageEdit = createImageEdit.bind(this, doFetch);
this.retrieveModel = retrieveModel.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream); this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream); this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch); this.createEdit = createEdit.bind(this, doFetch);
this.createImage = createImage.bind(this, doFetch);
this.createImageEdit = createImageEdit.bind(this, doFetch);
this.createImageVariation = createImageVariation.bind(this, doFetch);
this.createEmbedding = createEmbedding.bind(this, doFetch);
} }
} }

+ 3
- 1
src/platforms/openai/features/chat-completion.ts View File

@@ -4,6 +4,7 @@ import {
CreatedResource, CreatedResource,
} from '../common'; } from '../common';
import { import {
CompletionUsage,
UsageMetadata, UsageMetadata,
} from '../usage'; } from '../usage';
import { ChatCompletionModel } from '../models'; import { ChatCompletionModel } from '../models';
@@ -48,7 +49,8 @@ export interface CreateChatCompletionDataEvent<
} }


export interface ChatCompletion export interface ChatCompletion
extends CreateChatCompletionDataEvent<Partial<ChatCompletionChoice>>, UsageMetadata {}
extends CreateChatCompletionDataEvent<Partial<ChatCompletionChoice>>,
UsageMetadata<CompletionUsage> {}


export type ChatCompletionChunkDataEvent = CreateChatCompletionDataEvent<ChatCompletionChunkChoice>; export type ChatCompletionChunkDataEvent = CreateChatCompletionDataEvent<ChatCompletionChunkChoice>;




+ 3
- 2
src/platforms/openai/features/edit.ts View File

@@ -3,6 +3,7 @@ import {
CreatedResource, CreatedResource,
} from '../common'; } from '../common';
import { import {
CompletionUsage,
UsageMetadata, UsageMetadata,
} from '../usage'; } from '../usage';
import { EditModel } from '../models'; import { EditModel } from '../models';
@@ -26,7 +27,7 @@ export interface EditChoice extends ChoiceBase {
text: string; text: string;
} }


export interface CreateEditDataEvent extends CreatedResource, UsageMetadata {
export interface CreateEditDataEvent extends CreatedResource, UsageMetadata<CompletionUsage> {
object: DataEventObjectType; object: DataEventObjectType;
choices: EditChoice[]; choices: EditChoice[];
} }
@@ -55,7 +56,7 @@ export function createEdit(
return; return;
} }


const responseData = await response.json() as Record<string, unknown>;
const responseData = await response.json() as CreateEditDataEvent;
this.emit('data', responseData); this.emit('data', responseData);
this.emit('end'); this.emit('end');
}) })


+ 60
- 0
src/platforms/openai/features/embedding.ts View File

@@ -0,0 +1,60 @@
import { DoFetch } from '../../../packages/request';
import { EmbeddingModel } from '../models';
import { PlatformApiError } from '../../../common';
import { ResponseObjectType } from '../common';
import { PromptUsage, UsageMetadata } from '../usage';

export enum DataEventObjectType {
EMBEDDING = 'embedding',
}

export interface CreateEmbeddingParams {
model: EmbeddingModel;
input: string | number[] | string[] | number[][];
user?: string;
}

export interface Embedding {
object: DataEventObjectType;
embedding: number[];
index: number;
}

export interface CreateEmbeddingResponse extends UsageMetadata<PromptUsage> {
object: ResponseObjectType;
data: Embedding[];
model: EmbeddingModel;
}

export function createEmbedding(
this: NodeJS.EventEmitter,
doFetch: DoFetch,
params: CreateEmbeddingParams,
) {
doFetch('POST', '/embeddings', {
model: params.model,
input: params.input,
user: params.user,
})
.then(async (response) => {
if (!response.ok) {
this.emit('error', new PlatformApiError(
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`Create chat completion returned with status: ${response.status}`,
response,
));
this.emit('end');
return;
}

const responseData = await response.json() as CreateEmbeddingResponse;
this.emit('data', responseData);
this.emit('end');
})
.catch((err) => {
this.emit('error', err as Error);
this.emit('end');
});

return this;
}

+ 3
- 2
src/platforms/openai/features/image.ts View File

@@ -105,10 +105,11 @@ export function createImageEdit(


const responseData = await response.json() as Record<string, unknown>; const responseData = await response.json() as Record<string, unknown>;
const data = responseData.data as ImageData[]; const data = responseData.data as ImageData[];
this.emit('data', {
const emittedData = {
...responseData, ...responseData,
data: data.map((item) => Buffer.from(item.b64_json, 'base64')), data: data.map((item) => Buffer.from(item.b64_json, 'base64')),
});
} as ImageDataEvent;
this.emit('data', emittedData);
this.emit('end'); this.emit('end');
}) })
.catch((err) => { .catch((err) => {


+ 40
- 3
src/platforms/openai/features/model.ts View File

@@ -1,17 +1,25 @@
import { DoFetch } from '../../../packages/request'; import { DoFetch } from '../../../packages/request';
import { PlatformApiError } from '../../../common'; import { PlatformApiError } from '../../../common';
import { ResponseObjectType } from '../common';


export enum DataEventObjectType { export enum DataEventObjectType {
MODEL = 'model', MODEL = 'model',
} }


export type ModelId = string;

export interface ModelData { export interface ModelData {
id: string;
id: ModelId;
object: DataEventObjectType, object: DataEventObjectType,
owned_by: string; owned_by: string;
permission: string[]; permission: string[];
} }


export interface ListModelsResponse {
object: ResponseObjectType;
data: ModelData[];
}

export function listModels( export function listModels(
this: NodeJS.EventEmitter, this: NodeJS.EventEmitter,
doFetch: DoFetch, doFetch: DoFetch,
@@ -28,8 +36,37 @@ export function listModels(
return; return;
} }


const responseData = await response.json() as Record<string, unknown>;
this.emit('data', responseData.data as ModelData[]);
const responseData = await response.json() as ListModelsResponse;
this.emit('data', responseData);
this.emit('end');
})
.catch((err) => {
this.emit('error', err as Error);
this.emit('end');
});

return this;
}

export function retrieveModel(
this: NodeJS.EventEmitter,
doFetch: DoFetch,
modelId: ModelId,
) {
doFetch('GET', `/models/${modelId}`)
.then(async (response) => {
if (!response.ok) {
this.emit('error', new PlatformApiError(
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`Request from platform returned with status: ${response.status}`,
response,
));
this.emit('end');
return;
}

const responseData = await response.json() as ModelData;
this.emit('data', responseData);
this.emit('end'); this.emit('end');
}) })
.catch((err) => { .catch((err) => {


+ 3
- 1
src/platforms/openai/features/text-completion.ts View File

@@ -5,6 +5,7 @@ import {
CreatedResource, CreatedResource,
} from '../common'; } from '../common';
import { import {
CompletionUsage,
UsageMetadata, UsageMetadata,
} from '../usage'; } from '../usage';
import { ConsumeStream, DoFetch } from '../../../packages/request'; import { ConsumeStream, DoFetch } from '../../../packages/request';
@@ -46,7 +47,8 @@ export interface CreateTextCompletionDataEvent<
} }


export interface TextCompletion export interface TextCompletion
extends CreateTextCompletionDataEvent<Partial<TextCompletionChoice>>, UsageMetadata {}
extends CreateTextCompletionDataEvent<Partial<TextCompletionChoice>>,
UsageMetadata<CompletionUsage> {}


export type TextCompletionChunkDataEvent = CreateTextCompletionDataEvent<TextCompletionChoice>; export type TextCompletionChunkDataEvent = CreateTextCompletionDataEvent<TextCompletionChoice>;




+ 7
- 4
src/platforms/openai/usage.ts View File

@@ -53,12 +53,15 @@ export const getPromptTokens = (normalizedMessageArray: MessageObject[], model:
return getTokens(chatTokens, model); return getTokens(chatTokens, model);
}; };


export interface Usage {
export interface PromptUsage {
prompt_tokens: number; prompt_tokens: number;
completion_tokens: number;
total_tokens: number; total_tokens: number;
} }


export interface UsageMetadata {
usage: Usage;
export interface CompletionUsage extends PromptUsage {
completion_tokens: number;
}

export interface UsageMetadata<U extends PromptUsage> {
usage: U;
} }

+ 3
- 3
test/platforms/openai/usage.test.ts View File

@@ -3,7 +3,7 @@ import {
ChatCompletionModel, ChatCompletionModel,
getPromptTokens, getTokens, getPromptTokens, getTokens,
MessageRole, MessageRole,
Usage,
CompletionUsage,
} from '../../../src/platforms/openai'; } from '../../../src/platforms/openai';


describe('OpenAI', () => { describe('OpenAI', () => {
@@ -83,7 +83,7 @@ describe('OpenAI', () => {
request.model, request.model,
) )
.length; .length;
const usage: Usage = {
const usage: CompletionUsage = {
prompt_tokens: promptTokensLength, prompt_tokens: promptTokensLength,
completion_tokens: completionTokensLength, completion_tokens: completionTokensLength,
total_tokens: promptTokensLength + completionTokensLength, total_tokens: promptTokensLength + completionTokensLength,
@@ -132,7 +132,7 @@ describe('OpenAI', () => {
request.model, request.model,
) )
.length; .length;
const usage: Usage = {
const usage: CompletionUsage = {
prompt_tokens: promptTokensLength, prompt_tokens: promptTokensLength,
completion_tokens: completionTokensLength, completion_tokens: completionTokensLength,
total_tokens: promptTokensLength + completionTokensLength, total_tokens: promptTokensLength + completionTokensLength,


Loading…
Cancel
Save