Browse Source

Implement model listing, other image functions

Add image edits and variations endpoints.
master
TheoryOfNekomata 1 year ago
parent
commit
5fe705e6dc
10 changed files with 258 additions and 35 deletions
  1. +53
    -0
      src/packages/form-data.ts
  2. +3
    -3
      src/platforms/openai/chat.ts
  3. +4
    -2
      src/platforms/openai/common.ts
  4. +43
    -10
      src/platforms/openai/events.ts
  5. +3
    -3
      src/platforms/openai/features/chat-completion.ts
  6. +3
    -3
      src/platforms/openai/features/edit.ts
  7. +105
    -10
      src/platforms/openai/features/image.ts
  8. +40
    -0
      src/platforms/openai/features/model.ts
  9. +3
    -3
      src/platforms/openai/features/text-completion.ts
  10. +1
    -1
      src/platforms/openai/index.ts

+ 53
- 0
src/packages/form-data.ts View File

@@ -0,0 +1,53 @@
const appendValue = (formData: FormData, key: string, value: unknown, arrayDepth = 0) => {
if (value instanceof Buffer) {
formData.append(key, new Blob([value]));
return;
}
if (value instanceof Blob) {
formData.append(key, value);
return;
}
if (value instanceof File) {
formData.append(key, value, value.name);
return;
}
if (typeof value === 'string') {
formData.append(key, value);
return;
}
if (typeof value === 'number' && !Number.isNaN(value)) {
formData.append(key, value.toString(10));
return;
}
if (typeof value === 'boolean') {
formData.append(key, value.toString());
return;
}
if (Array.isArray(value) && arrayDepth === 0) {
appendValue(formData, key, value, arrayDepth + 1);
return;
}
if (typeof value === 'object' && value !== null) {
formData.append(key, JSON.stringify(value));
return;
}
throw new Error(`Invalid value for key: ${key}`);
};

export const fromJson = (json: Record<string, unknown>) => {
const formData = new FormData();
Object
.entries(json)
.forEach(([key, value]) => {
appendValue(formData, key, value);
});
return formData;
};

export const toJson = (formData: FormData) => {
const json = {} as Record<string, unknown>;
formData.forEach((value, key) => {
json[key] = value;
});
return json;
};

+ 3
- 3
src/platforms/openai/chat.ts View File

@@ -85,7 +85,7 @@ export const buildChatFromTranscript = (transcript: string) => {
lastRole = MessageRole.SYSTEM;
return {
role: MessageRole.SYSTEM,
content: line.substring('system:'.length).trim(),
content: line.slice('system:'.length).trim(),
};
}

@@ -93,7 +93,7 @@ export const buildChatFromTranscript = (transcript: string) => {
lastRole = MessageRole.USER;
return {
role: MessageRole.USER,
content: line.substring('user:'.length).trim(),
content: line.slice('user:'.length).trim(),
};
}

@@ -101,7 +101,7 @@ export const buildChatFromTranscript = (transcript: string) => {
lastRole = MessageRole.ASSISTANT;
return {
role: MessageRole.ASSISTANT,
content: line.substring('assistant:'.length).trim(),
content: line.slice('assistant:'.length).trim(),
};
}



+ 4
- 2
src/platforms/openai/common.ts View File

@@ -15,14 +15,16 @@ export type DataEventId = string;

export type Timestamp = number;

export interface PlatformResponse {
export interface CreatedResource {
created: Timestamp;
}

export type DoFetchBody = BodyInit | Record<string, unknown>

export type DoFetch = (
method: string,
path: string,
body: Record<string, unknown>
body?: DoFetchBody
) => Promise<Response>;

export type ConsumeStream = (


+ 43
- 10
src/platforms/openai/events.ts View File

@@ -1,11 +1,19 @@
import { PassThrough } from 'stream';
import { EventEmitter } from 'events';
import fetchPonyfill from 'fetch-ponyfill';
import { Configuration } from './common';
import { Configuration, DoFetchBody } from './common';
import { createTextCompletion, CreateTextCompletionParams } from './features/text-completion';
import { CreateChatCompletionParams, createChatCompletion } from './features/chat-completion';
import { CreateImageParams, createImage } from './features/image';
import {
CreateImageParams,
createImage,
CreateImageEditParams,
createImageEdit,
CreateImageVariationParams,
createImageVariation,
} from './features/image';
import { CreateEditParams, createEdit } from './features/edit';
import { listModels } from './features/model';

export type DataEventCallback<D> = (data: D) => void;

@@ -14,8 +22,11 @@ export type ErrorEventCallback = (event: Error) => void;
export interface PlatformEventEmitter extends NodeJS.EventEmitter {
createChatCompletion(params: CreateChatCompletionParams): void;
createImage(params: CreateImageParams): void;
createImageEdit(params: CreateImageEditParams): void;
createImageVariation(params: CreateImageVariationParams): void;
createCompletion(params: CreateTextCompletionParams): void;
createEdit(params: CreateEditParams): void;
listModels(): void;
on<D>(event: 'data', callback: DataEventCallback<D>): this;
on(event: 'end', callback: () => void): this;
on(event: 'error', callback: ErrorEventCallback): this;
@@ -26,29 +37,48 @@ export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEv

readonly createImage: PlatformEventEmitter['createImage'];

readonly createImageEdit: PlatformEventEmitter['createImageEdit'];

readonly createImageVariation: PlatformEventEmitter['createImageVariation'];

readonly createChatCompletion: PlatformEventEmitter['createChatCompletion'];

readonly createEdit: PlatformEventEmitter['createEdit'];

readonly listModels: PlatformEventEmitter['listModels'];

constructor(configParams: Configuration) {
super();
const headers: Record<string, string> = {
const platformHeaders: Record<string, string> = {
Authorization: `Bearer ${configParams.apiKey}`,
};

if (configParams.organizationId) {
headers['OpenAI-Organization'] = configParams.organizationId;
platformHeaders['OpenAI-Organization'] = configParams.organizationId;
}

const { fetch: fetchInstance } = fetchPonyfill();
const doFetch = (method: string, path: string, body: Record<string, unknown>) => {
const doFetch = (method: string, path: string, body?: DoFetchBody) => {
const requestHeaders = {
...platformHeaders,
};

let theBody: BodyInit;

if (
body instanceof FormData
|| body instanceof URLSearchParams
) {
theBody = body;
} else {
theBody = JSON.stringify(body);
requestHeaders['Content-Type'] = 'application/json';
}

const theFetchParams = {
method,
headers: {
...headers,
'Content-Type': 'application/json',
},
body: JSON.stringify(body),
headers: requestHeaders,
body: theBody,
};

const url = new URL(
@@ -85,7 +115,10 @@ export class PlatformEventEmitterImpl extends EventEmitter implements PlatformEv
});
}
};
this.listModels = listModels.bind(this, doFetch);
this.createImage = createImage.bind(this, doFetch);
this.createImageVariation = createImageVariation.bind(this, doFetch);
this.createImageEdit = createImageEdit.bind(this, doFetch);
this.createCompletion = createTextCompletion.bind(this, doFetch, consumeStream);
this.createChatCompletion = createChatCompletion.bind(this, doFetch, consumeStream);
this.createEdit = createEdit.bind(this, doFetch);


+ 3
- 3
src/platforms/openai/features/chat-completion.ts View File

@@ -4,7 +4,7 @@ import {
DataEventId,
DoFetch,
PlatformError,
PlatformResponse,
CreatedResource,
} from '../common';
import {
UsageMetadata,
@@ -41,7 +41,7 @@ export enum DataEventObjectType {

export interface CreateChatCompletionDataEvent<
C extends Partial<FinishableChoiceBase>
> extends PlatformResponse {
> extends CreatedResource {
id: DataEventId;
object: DataEventObjectType;
model: ChatCompletionModel;
@@ -72,7 +72,7 @@ export function createChatCompletion(
frequency_penalty: params.frequencyPenalty ?? 0,
logit_bias: params.logitBias ?? {},
user: params.user,
})
} as Record<string, unknown>)
.then(async (response) => {
if (!response.ok) {
this.emit('error', new PlatformError(


+ 3
- 3
src/platforms/openai/features/edit.ts View File

@@ -2,7 +2,7 @@ import {
ChoiceBase,
DoFetch,
PlatformError,
PlatformResponse,
CreatedResource,
} from '../common';
import {
UsageMetadata,
@@ -26,7 +26,7 @@ export interface EditChoice extends ChoiceBase {
text: string;
}

export interface CreateEditDataEvent extends PlatformResponse, UsageMetadata {
export interface CreateEditDataEvent extends CreatedResource, UsageMetadata {
object: DataEventObjectType;
choices: EditChoice[];
}
@@ -43,7 +43,7 @@ export function createEdit(
n: params.n ?? 1,
temperature: params.temperature ?? 1,
top_p: params.topP ?? 1,
})
} as Record<string, unknown>)
.then(async (response) => {
if (!response.ok) {
this.emit('error', new PlatformError(


+ 105
- 10
src/platforms/openai/features/image.ts View File

@@ -1,16 +1,17 @@
import * as FormDataUtils from '../../../packages/form-data';
import {
DoFetch,
PlatformError,
PlatformResponse,
CreatedResource,
} from '../common';

export enum CreateImageSize {
export enum ImageSize {
SQUARE_256 = '256x256',
SQUARE_512 = '512x512',
SQUARE_1024 = '1024x1024',
}

export enum CreateImageResponseFormat {
export enum ImageResponseFormat {
URL = 'url',
BASE64_JSON = 'b64_json',
}
@@ -18,15 +19,15 @@ export enum CreateImageResponseFormat {
export interface CreateImageParams {
prompt: string;
n? : number;
size?: CreateImageSize;
size?: ImageSize;
user?: string;
}

export interface CreateImageData {
export interface ImageData {
b64_json: string;
}

export interface CreateImageDataEvent extends PlatformResponse {
export interface ImageDataEvent extends CreatedResource {
data: Buffer[];
}

@@ -38,10 +39,10 @@ export function createImage(
doFetch('POST', '/images/generations', {
prompt: params.prompt,
n: params.n ?? 1,
size: params.size ?? CreateImageSize.SQUARE_1024,
size: params.size ?? ImageSize.SQUARE_1024,
user: params.user,
response_format: CreateImageResponseFormat.BASE64_JSON,
})
response_format: ImageResponseFormat.BASE64_JSON,
} as Record<string, unknown>)
.then(async (response) => {
if (!response.ok) {
this.emit('error', new PlatformError(
@@ -54,7 +55,101 @@ export function createImage(
}

const responseData = await response.json() as Record<string, unknown>;
const data = responseData.data as CreateImageData[];
const data = responseData.data as ImageData[];
this.emit('data', {
...responseData,
data: data.map((item) => Buffer.from(item.b64_json, 'base64')),
});
this.emit('end');
})
.catch((err) => {
this.emit('error', err as Error);
this.emit('end');
});

return this;
}

export interface CreateImageEditParams {
image: Buffer;
mask?: Buffer;
prompt: string;
n?: number;
size?: ImageSize;
user?: string;
}

export function createImageEdit(
this: NodeJS.EventEmitter,
doFetch: DoFetch,
params: CreateImageEditParams,
) {
doFetch('POST', '/images/edits', FormDataUtils.fromJson({
prompt: params.prompt,
image: params.image,
mask: params.mask,
n: params.n ?? 1,
size: params.size ?? ImageSize.SQUARE_1024,
response_format: ImageResponseFormat.BASE64_JSON,
}))
.then(async (response) => {
if (!response.ok) {
this.emit('error', new PlatformError(
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`Request from platform returned with status: ${response.status}`,
response,
));
this.emit('end');
return;
}

const responseData = await response.json() as Record<string, unknown>;
const data = responseData.data as ImageData[];
this.emit('data', {
...responseData,
data: data.map((item) => Buffer.from(item.b64_json, 'base64')),
});
this.emit('end');
})
.catch((err) => {
this.emit('error', err as Error);
this.emit('end');
});

return this;
}

export interface CreateImageVariationParams {
image: Buffer;
n?: number;
size?: ImageSize;
user?: string;
}

export function createImageVariation(
this: NodeJS.EventEmitter,
doFetch: DoFetch,
params: CreateImageVariationParams,
) {
doFetch('POST', '/images/variations', FormDataUtils.fromJson({
image: params.image,
n: params.n ?? 1,
size: params.size ?? ImageSize.SQUARE_1024,
response_format: ImageResponseFormat.BASE64_JSON,
}))
.then(async (response) => {
if (!response.ok) {
this.emit('error', new PlatformError(
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`Request from platform returned with status: ${response.status}`,
response,
));
this.emit('end');
return;
}

const responseData = await response.json() as Record<string, unknown>;
const data = responseData.data as ImageData[];
this.emit('data', {
...responseData,
data: data.map((item) => Buffer.from(item.b64_json, 'base64')),


+ 40
- 0
src/platforms/openai/features/model.ts View File

@@ -0,0 +1,40 @@
import { DoFetch, PlatformError } from '../common';

export enum DataEventObjectType {
MODEL = 'model',
}

export interface ModelData {
id: string;
object: DataEventObjectType,
owned_by: string;
permission: string[];
}

export function listModels(
this: NodeJS.EventEmitter,
doFetch: DoFetch,
) {
doFetch('GET', '/models')
.then(async (response) => {
if (!response.ok) {
this.emit('error', new PlatformError(
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`Request from platform returned with status: ${response.status}`,
response,
));
this.emit('end');
return;
}

const responseData = await response.json() as Record<string, unknown>;
this.emit('data', responseData.data as ModelData[]);
this.emit('end');
})
.catch((err) => {
this.emit('error', err as Error);
this.emit('end');
});

return this;
}

+ 3
- 3
src/platforms/openai/features/text-completion.ts View File

@@ -5,7 +5,7 @@ import {
DoFetch,
FinishableChoiceBase,
PlatformError,
PlatformResponse,
CreatedResource,
} from '../common';
import {
UsageMetadata,
@@ -39,7 +39,7 @@ export interface TextCompletionChoice extends FinishableChoiceBase {

export interface CreateTextCompletionDataEvent<
C extends Partial<FinishableChoiceBase>
> extends PlatformResponse {
> extends CreatedResource {
id: DataEventId;
object: DataEventObjectType;
model: TextCompletionModel;
@@ -73,7 +73,7 @@ export function createTextCompletion(
user: params.user,
presence_penalty: params.presencePenalty,
frequency_penalty: params.frequencyPenalty,
})
} as Record<string, unknown>)
.then(async (response) => {
if (!response.ok) {
this.emit('error', new PlatformError(


+ 1
- 1
src/platforms/openai/index.ts View File

@@ -23,7 +23,7 @@ export {
CreateEditDataEvent,
DataEventObjectType as EditDataEventObjectType,
} from './features/edit';
export { CreateImageDataEvent, CreateImageSize } from './features/image';
export { ImageDataEvent, ImageSize } from './features/image';

export const PLATFORM_ID = 'openai' as const;



Loading…
Cancel
Save