Browse Source

Add embedding functions

Consume API endpoints for embeddings.
master
TheoryOfNekomata 1 year ago
parent
commit
d6a03b147c
11 changed files with 162 additions and 30 deletions
  1. +14
    -1
      README.md
  2. +4
    -0
      src/platforms/openai/common.ts
  3. +22
    -13
      src/platforms/openai/events.ts
  4. +3
    -1
      src/platforms/openai/features/chat-completion.ts
  5. +3
    -2
      src/platforms/openai/features/edit.ts
  6. +60
    -0
      src/platforms/openai/features/embedding.ts
  7. +3
    -2
      src/platforms/openai/features/image.ts
  8. +40
    -3
      src/platforms/openai/features/model.ts
  9. +3
    -1
      src/platforms/openai/features/text-completion.ts
  10. +7
    -4
      src/platforms/openai/usage.ts
  11. +3
    -3
      test/platforms/openai/usage.test.ts

+ 14
- 1
README.md View File

@@ -15,10 +15,23 @@ Many-in-one AI client.
- [X] generate
- [X] edit
- [X] variation
- [ ] embeddings
- [X] embeddings
- [ ] audio
- [ ] transcription
- [ ] translation
- [ ] files
- [ ] list
- [ ] upload
- [ ] delete
- [ ] retrieve metadata
- [ ] retrieve content
- [ ] fine-tunes
- [ ] create
- [ ] list
- [ ] retrieve
- [ ] cancel
- [ ] list events
- [ ] delete model
- [ ] moderations
* ElevenLabs
- [X] TTS (stream)


+ 4
- 0
src/platforms/openai/common.ts View File

@@ -30,4 +30,8 @@ export interface Configuration {
baseUrl?: string;
}

export enum ResponseObjectType {
LIST = 'list',
}

export const DEFAULT_BASE_URL = 'https://api.openai.com' as const;

+ 22
- 13
src/platforms/openai/events.ts View File

@@ -1,6 +1,7 @@
import { PassThrough } from 'stream';
import { EventEmitter } from 'events';
import fetchPonyfill from 'fetch-ponyfill';
import { DoFetchBody, processRequest } from '../../packages/request';
import * as AllPlatformsCommon from '../../common';
import { Configuration, DEFAULT_BASE_URL } from './common';
import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion';
@@ -14,33 +15,39 @@ import {
createImageVariation,
} from './features/image';
import { CreateEditParams, createEdit } from './features/edit';
import { listModels } from './features/model';
import { DoFetchBody, processRequest } from '../../packages/request';
import { listModels, ModelId, retrieveModel } from './features/model';
import { createEmbedding, CreateEmbeddingParams } from './features/embedding';

export interface PlatformEventEmitter extends AllPlatformsCommon.PlatformEventEmitter {
listModels(): void;
retrieveModel(modelId: ModelId): void;
createCompletion(params: CreateTextCompletionParams): void;
createChatCompletion(params: CreateChatCompletionParams): void;
createEdit(params: CreateEditParams): void;
createImage(params: CreateImageParams): void;
createImageEdit(params: CreateImageEditParams): void;
createImageVariation(params: CreateImageVariationParams): void;
createCompletion(params: CreateTextCompletionParams): void;
createEdit(params: CreateEditParams): void;
listModels(): void;
createEmbedding(params: CreateEmbeddingParams): void;
}

export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEventEmitter {
readonly listModels: PlatformEventEmitter['listModels'];

readonly retrieveModel: PlatformEventEmitter['retrieveModel'];

readonly createCompletion: PlatformEventEmitter['createCompletion'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

readonly createImage: PlatformEventEmitter['createImage'];

readonly createImageEdit: PlatformEventEmitter['createImageEdit'];

readonly createImageVariation: PlatformEventEmitter['createImageVariation'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

readonly listModels: PlatformEventEmitter['listModels'];
readonly createEmbedding: PlatformEventEmitter['createEmbedding'];

constructor(configParams: Configuration) {
super();
@@ -110,11 +117,13 @@ export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEv
}
};
this.listModels = listModels.bind(this, doFetch);
this.createImage = createImage.bind(this, doFetch);
this.createImageVariation = createImageVariation.bind(this, doFetch);
this.createImageEdit = createImageEdit.bind(this, doFetch);
this.retrieveModel = retrieveModel.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);
this.createImage = createImage.bind(this, doFetch);
this.createImageEdit = createImageEdit.bind(this, doFetch);
this.createImageVariation = createImageVariation.bind(this, doFetch);
this.createEmbedding = createEmbedding.bind(this, doFetch);
}
}

+ 3
- 1
src/platforms/openai/features/chat-completion.ts View File

@@ -4,6 +4,7 @@ import {
CreatedResource,
} from '../common';
import {
CompletionUsage,
UsageMetadata,
} from '../usage';
import { ChatCompletionModel } from '../models';
@@ -48,7 +49,8 @@ export interface CreateChatCompletionDataEvent<
}

export interface ChatCompletion
extends CreateChatCompletionDataEvent<Partial<ChatCompletionChoice>>, UsageMetadata {}
extends CreateChatCompletionDataEvent<Partial<ChatCompletionChoice>>,
UsageMetadata<CompletionUsage> {}

export type ChatCompletionChunkDataEvent = CreateChatCompletionDataEvent<ChatCompletionChunkChoice>;



+ 3
- 2
src/platforms/openai/features/edit.ts View File

@@ -3,6 +3,7 @@ import {
CreatedResource,
} from '../common';
import {
CompletionUsage,
UsageMetadata,
} from '../usage';
import { EditModel } from '../models';
@@ -26,7 +27,7 @@ export interface EditChoice extends ChoiceBase {
text: string;
}

export interface CreateEditDataEvent extends CreatedResource, UsageMetadata {
export interface CreateEditDataEvent extends CreatedResource, UsageMetadata<CompletionUsage> {
object: DataEventObjectType;
choices: EditChoice[];
}
@@ -55,7 +56,7 @@ export function createEdit(
return;
}

const responseData = await response.json() as Record<string, unknown>;
const responseData = await response.json() as CreateEditDataEvent;
this.emit('data', responseData);
this.emit('end');
})


+ 60
- 0
src/platforms/openai/features/embedding.ts View File

@@ -0,0 +1,60 @@
import { DoFetch } from '../../../packages/request';
import { EmbeddingModel } from '../models';
import { PlatformApiError } from '../../../common';
import { ResponseObjectType } from '../common';
import { PromptUsage, UsageMetadata } from '../usage';

/** Object-type discriminator carried by each embedding item in a response. */
export enum DataEventObjectType {
  EMBEDDING = 'embedding',
}

/**
 * Request body for POST /embeddings.
 * Mirrors the OpenAI "create embedding" endpoint parameters.
 */
export interface CreateEmbeddingParams {
  /** ID of the embedding model to use. */
  model: EmbeddingModel;
  /** Text (or pre-tokenized) input; single value or a batch of values. */
  input: string | number[] | string[] | number[][];
  /** Optional end-user identifier forwarded for abuse monitoring. */
  user?: string;
}

/** A single embedding vector plus its position in the request batch. */
export interface Embedding {
  object: DataEventObjectType;
  /** The embedding vector itself. */
  embedding: number[];
  /** Index of the corresponding input in the request batch. */
  index: number;
}

/**
 * Full response payload for a create-embedding call.
 * Embedding requests consume prompt tokens only, hence
 * UsageMetadata<PromptUsage> rather than a completion usage shape.
 */
export interface CreateEmbeddingResponse extends UsageMetadata<PromptUsage> {
  object: ResponseObjectType;
  data: Embedding[];
  model: EmbeddingModel;
}

/**
 * Requests embeddings from the platform and re-emits the outcome on `this`.
 *
 * Emits, in order:
 * - `'data'` with the parsed {@link CreateEmbeddingResponse} on success,
 * - `'error'` with a `PlatformApiError` on a non-2xx response, or with the
 *   thrown error on network/parse failure,
 * - `'end'` in every case.
 *
 * @param doFetch - Bound fetch helper that performs the authenticated request.
 * @param params - Model, input, and optional user identifier.
 * @returns `this`, to allow chaining on the emitter.
 */
export function createEmbedding(
  this: NodeJS.EventEmitter,
  doFetch: DoFetch,
  params: CreateEmbeddingParams,
) {
  doFetch('POST', '/embeddings', {
    model: params.model,
    input: params.input,
    user: params.user,
  })
    .then(async (response) => {
      if (!response.ok) {
        // FIX: error message previously said "Create chat completion"
        // (copy-paste from chat-completion.ts); name the correct operation.
        this.emit('error', new PlatformApiError(
          // eslint-disable-next-line @typescript-eslint/restrict-template-expressions
          `Create embedding returned with status: ${response.status}`,
          response,
        ));
        this.emit('end');
        return;
      }

      const responseData = await response.json() as CreateEmbeddingResponse;
      this.emit('data', responseData);
      this.emit('end');
    })
    .catch((err) => {
      this.emit('error', err as Error);
      this.emit('end');
    });

  return this;
}

+ 3
- 2
src/platforms/openai/features/image.ts View File

@@ -105,10 +105,11 @@ export function createImageEdit(

const responseData = await response.json() as Record<string, unknown>;
const data = responseData.data as ImageData[];
this.emit('data', {
const emittedData = {
...responseData,
data: data.map((item) => Buffer.from(item.b64_json, 'base64')),
});
} as ImageDataEvent;
this.emit('data', emittedData);
this.emit('end');
})
.catch((err) => {


+ 40
- 3
src/platforms/openai/features/model.ts View File

@@ -1,17 +1,25 @@
import { DoFetch } from '../../../packages/request';
import { PlatformApiError } from '../../../common';
import { ResponseObjectType } from '../common';

export enum DataEventObjectType {
MODEL = 'model',
}

export type ModelId = string;

export interface ModelData {
id: string;
id: ModelId;
object: DataEventObjectType,
owned_by: string;
permission: string[];
}

export interface ListModelsResponse {
object: ResponseObjectType;
data: ModelData[];
}

export function listModels(
this: NodeJS.EventEmitter,
doFetch: DoFetch,
@@ -28,8 +36,37 @@ export function listModels(
return;
}

const responseData = await response.json() as Record<string, unknown>;
this.emit('data', responseData.data as ModelData[]);
const responseData = await response.json() as ListModelsResponse;
this.emit('data', responseData);
this.emit('end');
})
.catch((err) => {
this.emit('error', err as Error);
this.emit('end');
});

return this;
}

export function retrieveModel(
this: NodeJS.EventEmitter,
doFetch: DoFetch,
modelId: ModelId,
) {
doFetch('GET', `/models/${modelId}`)
.then(async (response) => {
if (!response.ok) {
this.emit('error', new PlatformApiError(
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`Request from platform returned with status: ${response.status}`,
response,
));
this.emit('end');
return;
}

const responseData = await response.json() as ModelData;
this.emit('data', responseData);
this.emit('end');
})
.catch((err) => {


+ 3
- 1
src/platforms/openai/features/text-completion.ts View File

@@ -5,6 +5,7 @@ import {
CreatedResource,
} from '../common';
import {
CompletionUsage,
UsageMetadata,
} from '../usage';
import { ConsumeStream, DoFetch } from '../../../packages/request';
@@ -46,7 +47,8 @@ export interface CreateTextCompletionDataEvent<
}

export interface TextCompletion
extends CreateTextCompletionDataEvent<Partial<TextCompletionChoice>>, UsageMetadata {}
extends CreateTextCompletionDataEvent<Partial<TextCompletionChoice>>,
UsageMetadata<CompletionUsage> {}

export type TextCompletionChunkDataEvent = CreateTextCompletionDataEvent<TextCompletionChoice>;



+ 7
- 4
src/platforms/openai/usage.ts View File

@@ -53,12 +53,15 @@ export const getPromptTokens = (normalizedMessageArray: MessageObject[], model:
return getTokens(chatTokens, model);
};

export interface Usage {
export interface PromptUsage {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
}

export interface UsageMetadata {
usage: Usage;
export interface CompletionUsage extends PromptUsage {
completion_tokens: number;
}

export interface UsageMetadata<U extends PromptUsage> {
usage: U;
}

+ 3
- 3
test/platforms/openai/usage.test.ts View File

@@ -3,7 +3,7 @@ import {
ChatCompletionModel,
getPromptTokens, getTokens,
MessageRole,
Usage,
CompletionUsage,
} from '../../../src/platforms/openai';

describe('OpenAI', () => {
@@ -83,7 +83,7 @@ describe('OpenAI', () => {
request.model,
)
.length;
const usage: Usage = {
const usage: CompletionUsage = {
prompt_tokens: promptTokensLength,
completion_tokens: completionTokensLength,
total_tokens: promptTokensLength + completionTokensLength,
@@ -132,7 +132,7 @@ describe('OpenAI', () => {
request.model,
)
.length;
const usage: Usage = {
const usage: CompletionUsage = {
prompt_tokens: promptTokensLength,
completion_tokens: completionTokensLength,
total_tokens: promptTokensLength + completionTokensLength,


Loading…
Cancel
Save