diff --git a/README.md b/README.md index 542f9ea..7be2b55 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ npm install opengradient-sdk ## Requirements - Node.js 18+ (for global `fetch`) -- A funded EVM wallet on Base (settlement happens in OPG on the Base network via [x402](https://x402.org)) +- A funded EVM wallet on Base for direct `chat()` / `completion()` calls. Relay-backed `chatOhttp()` calls do not require a wallet private key. ## Quick Start @@ -50,6 +50,31 @@ for await (const chunk of stream) { } ``` +### OHTTP-encrypted chat + +Use `chatOhttp` when you want the request body encrypted for an OHTTP-enabled +TEE gateway through the same chat API relay used by the frontend. The default +relay paths are `/api/v1/chat/ohttp/config` and `/api/v1/chat/ohttp`. + +```typescript +import { Client, TEE_LLM } from "opengradient-sdk"; + +const client = new Client({ + ohttpRelayUrl: process.env.OG_OHTTP_RELAY_URL, + ohttpHeaders: { + Authorization: `Bearer ${process.env.OG_OHTTP_AUTH_TOKEN}`, + }, +}); + +const result = await client.llm.chatOhttp({ + model: TEE_LLM.GPT_5, + messages: [{ role: "user", content: "Hello over OHTTP" }], +}); + +console.log(result.chatOutput?.content); +console.log("TEE:", result.teeId); +``` + ### Tool / function calling ```typescript diff --git a/examples/llm_chat_ohttp.ts b/examples/llm_chat_ohttp.ts new file mode 100644 index 0000000..a715f5e --- /dev/null +++ b/examples/llm_chat_ohttp.ts @@ -0,0 +1,30 @@ +// Run an OHTTP-encrypted chat completion through an OHTTP relay/gateway. +// +// Run with: +// OG_OHTTP_RELAY_URL=https://chat-api.example OG_OHTTP_AUTH_TOKEN=... npx ts-node examples/llm_chat_ohttp.ts + +import { Client, TEE_LLM } from "../src"; + +async function main() { + const client = new Client({ + ohttpRelayUrl: process.env.OG_OHTTP_RELAY_URL, + ohttpConfigPath: process.env.OG_OHTTP_CONFIG_PATH, + ohttpRequestPath: process.env.OG_OHTTP_REQUEST_PATH, + ohttpHeaders: process.env.OG_OHTTP_AUTH_TOKEN + ? { Authorization: `Bearer ${process.env.OG_OHTTP_AUTH_TOKEN}` } + : undefined, + }); + + const result = await client.llm.chatOhttp({ + model: TEE_LLM.GPT_5, + messages: [{ role: "user", content: "Explain OHTTP in one paragraph." }], + }); + + console.log(`Response: ${result.chatOutput?.content}`); + console.log(`TEE: ${result.teeId ?? "(unknown)"}`); +} + +main().catch((err) => { + console.error(err); + process.exit(1); +}); diff --git a/package-lock.json b/package-lock.json index 714f2e5..fffa670 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,14 +1,17 @@ { "name": "opengradient-sdk", - "version": "2.1.0", + "version": "2.1.1", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "opengradient-sdk", - "version": "2.1.0", + "version": "2.1.1", "license": "MIT", "dependencies": { + "@noble/ciphers": "^1.3.0", + "@noble/curves": "^1.8.0", + "@noble/hashes": "^1.4.0", "@x402/core": "^2.11.0", "@x402/evm": "^2.11.0", "@x402/fetch": "^2.11.0", diff --git a/package.json b/package.json index 7807050..0dcbda8 100644 --- a/package.json +++ b/package.json @@ -52,6 +52,9 @@ "testEnvironment": "node" }, "dependencies": { + "@noble/ciphers": "^1.3.0", + "@noble/curves": "^1.8.0", + "@noble/hashes": "^1.4.0", "@x402/core": "^2.11.0", "@x402/evm": "^2.11.0", "@x402/fetch": "^2.11.0", diff --git a/src/client.ts b/src/client.ts index cd58e43..d0ef6f6 100644 --- a/src/client.ts +++ b/src/client.ts @@ -1,4 +1,5 @@ import { LLM } from "./llm"; +import { OHTTPClient } from "./ohttp"; import { ClientConfig } from "./types"; import { RegistryTEEConnection, @@ -31,34 +32,50 @@ import { */ export class Client { readonly llm: LLM; + readonly ohttp: OHTTPClient; constructor(config: ClientConfig) { - const privateKey = ( - config.privateKey.startsWith("0x") - ? config.privateKey - : `0x${config.privateKey}` - ) as `0x${string}`; + const privateKey = config.privateKey + ? ((config.privateKey.startsWith("0x") + ? config.privateKey + : `0x${config.privateKey}`) as `0x${string}`) + : undefined; - let connection: TEEConnection; - if (config.llmServerUrl) { - connection = new StaticTEEConnection(config.llmServerUrl); - } else { - const registry = new TEERegistry( - config.rpcUrl ?? DEFAULT_OG_RPC_URL, - config.teeRegistryAddress ?? DEFAULT_TEE_REGISTRY_ADDRESS, - ); - connection = new RegistryTEEConnection(registry); + let connection: TEEConnection | undefined; + let registry: TEERegistry | undefined; + if (privateKey || !hasExplicitOHTTPRelay(config)) { + if (config.llmServerUrl) { + connection = new StaticTEEConnection(config.llmServerUrl); + } else { + registry = new TEERegistry( + config.rpcUrl ?? DEFAULT_OG_RPC_URL, + config.teeRegistryAddress ?? DEFAULT_TEE_REGISTRY_ADDRESS, + ); + connection = new RegistryTEEConnection(registry); + } } + this.ohttp = new OHTTPClient({ + relayUrl: config.ohttpRelayUrl ?? config.llmServerUrl, + requestPath: config.ohttpRequestPath, + configPath: config.ohttpConfigPath, + headers: config.ohttpHeaders, + }); + this.llm = new LLM({ privateKey, maxPaymentValue: config.maxPaymentValue, connection, + ohttp: this.ohttp, }); } /** Tear down dispatchers and any background refresh timers. */ async close(): Promise { - await this.llm.close(); + await Promise.all([this.llm.close(), this.ohttp.close()]); } } + +function hasExplicitOHTTPRelay(config: ClientConfig): boolean { + return Boolean(config.ohttpRelayUrl); +} diff --git a/src/index.ts b/src/index.ts index c049518..69e998f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,5 +1,22 @@ export { Client } from "./client"; export { LLM } from "./llm"; +export { + OHTTPClient, + OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE, + OHTTP_REQUEST_MEDIA_TYPE, + OHTTP_RESPONSE_MEDIA_TYPE, +} from "./ohttp"; +export type { + OHTTPByteStreamResult, + OHTTPClientConfig, + OHTTPGatewayConfig, + OHTTPGatewayMetadata, + OHTTPJsonResult, + OHTTPKeyConfig, + OHTTPRequestOptions, + OHTTPRouteMetadata, + OHTTPSigningKey, +} from "./ohttp"; export { TEE_LLM, X402SettlementMode, OpenGradientError } from "./types"; @@ -24,6 +41,7 @@ export { TEE_TYPE_VALIDATOR, } from "./teeRegistry"; export type { TEEEndpoint } from "./teeRegistry"; +export type { TEEOHTTPConfig, TEEOHTTPEndpoint } from "./teeRegistry"; export { RegistryTEEConnection, diff --git a/src/llm.ts b/src/llm.ts index dbdeb7d..313ca01 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -4,6 +4,7 @@ import { UptoEvmScheme } from "@x402/evm"; import { registerExactEvmScheme } from "@x402/evm/exact/client"; import { privateKeyToAccount } from "viem/accounts"; import type { Agent } from "undici"; +import type { OHTTPClient, OHTTPRouteMetadata } from "./ohttp"; import { ChatParams, ChatMessage, @@ -28,10 +29,12 @@ const CHAT_ENDPOINT = "/v1/chat/completions"; const COMPLETION_ENDPOINT = "/v1/completions"; export interface LLMConfig { - privateKey: `0x${string}`; + privateKey?: `0x${string}`; maxPaymentValue?: bigint; /** Resolves the active TEE endpoint and TLS dispatcher. */ - connection: TEEConnection; + connection?: TEEConnection; + /** Optional OHTTP transport for encrypted relay requests. */ + ohttp?: OHTTPClient; } /** @@ -51,7 +54,7 @@ export class LLM { /** Tear down dispatchers and any background refresh timers. */ async close(): Promise { - await this.config.connection.close(); + await this.config.connection?.close(); } /** @@ -132,6 +135,22 @@ export class LLM { return this.chatNonStreaming(params); } + /** + * Perform an OHTTP-encrypted chat completion through an OHTTP relay/gateway. + */ + chatOhttp(params: ChatParams & { stream?: false }): Promise; + /** + * Perform an OHTTP-encrypted streaming chat completion. + */ + chatOhttp(params: ChatParams & { stream: true }): AsyncIterable; + chatOhttp( + params: ChatParams & { stream?: boolean }, + ): Promise | AsyncIterable { + this.validateChatParams(params); + if (params.stream) return this.chatOhttpStream(params); + return this.chatOhttpNonStreaming(params); + } + private async chatNonStreaming( params: ChatParams, ): Promise { @@ -187,6 +206,42 @@ export class LLM { }; } + private async chatOhttpNonStreaming( + params: ChatParams, + ): Promise { + const payload = this.buildChatPayload(params, false); + const { body, route } = await this.requireOHTTPClient().requestJson< + Record + >(payload); + + const choices = body.choices as + | Array<{ + message?: ChatMessage; + finish_reason?: string; + }> + | undefined; + if (!choices || choices.length === 0) { + throw new OpenGradientError( + `Invalid OHTTP response: 'choices' missing or empty in ${JSON.stringify(body)}`, + ); + } + + const message = choices[0].message ?? { role: "assistant" }; + normalizeChatMessage(message); + + return { + finishReason: choices[0].finish_reason, + chatOutput: message, + usage: body.usage as TokenUsage | undefined, + dataSettlementTransactionHash: body.data_settlement_transaction_hash, + dataSettlementBlobId: body.data_settlement_blob_id, + teeSignature: body.tee_signature, + teeTimestamp: body.tee_timestamp, + teeId: route.teeId, + teeEndpoint: route.teeEndpoint, + }; + } + private async *chatStream(params: ChatParams): AsyncIterable { const payload = this.buildChatPayload(params, true); const settlementMode = @@ -207,7 +262,7 @@ export class LLM { } // Connection-level failure during stream setup: re-resolve and retry once. try { - await this.config.connection.reconnect(); + await this.requireTEEConnection().reconnect(); } catch (reconnectErr) { throw new OpenGradientError( `TEE LLM stream failed and registry refresh failed: ${String(reconnectErr)}`, @@ -276,6 +331,67 @@ export class LLM { } } + private async *chatOhttpStream( + params: ChatParams, + ): AsyncIterable { + const payload = this.buildChatPayload(params, true); + const nonStreamPayload = this.buildChatPayload(params, false); + const stream = await this.requireOHTTPClient().streamBytes(payload); + const decoder = new TextDecoder(); + let buffer = ""; + let pendingFinal: StreamChunk | null = null; + let finalData: Record | null = null; + let fullContent = ""; + + for await (const bytes of stream.chunks) { + buffer += decoder.decode(bytes, { stream: true }); + + let newlineIdx; + while ((newlineIdx = buffer.indexOf("\n")) !== -1) { + const line = buffer.slice(0, newlineIdx).trim(); + buffer = buffer.slice(newlineIdx + 1); + if (!line || !line.startsWith("data: ")) continue; + + const dataStr = line.slice(6).trim(); + if (dataStr === "[DONE]") continue; + + let data: any; + try { + data = JSON.parse(dataStr); + } catch { + continue; + } + if (typeof data.error === "string") { + throw new OpenGradientError(data.error); + } + + const chunk = withOHTTPRoute(parseStreamChunk(data), stream.route); + for (const choice of chunk.choices) { + if (choice.delta.content) fullContent += choice.delta.content; + } + if ( + typeof data.tee_signature === "string" || + typeof data.tee_output_hash === "string" + ) { + finalData = data; + } + if (chunk.isFinal) { + pendingFinal = chunk; + continue; + } + yield chunk; + } + } + + if (finalData) { + await stream.verifyResponse(finalData, { + responseContent: fullContent, + requestHashCandidates: [nonStreamPayload, payload], + }); + } + if (pendingFinal) yield pendingFinal; + } + private async *chatToolsAsStream( params: ChatParams, ): AsyncIterable { @@ -340,7 +456,8 @@ export class LLM { private getX402Client(): x402Client { if (!this.x402ClientInstance) { - const account = privateKeyToAccount(this.config.privateKey); + const privateKey = this.requirePrivateKey(); + const account = privateKeyToAccount(privateKey); const client = new x402Client(); registerExactEvmScheme(client, { signer: account }); // The TEE may quote the "upto" scheme — register it on EVM networks too. @@ -350,6 +467,45 @@ export class LLM { return this.x402ClientInstance; } + private requireOHTTPClient(): OHTTPClient { + if (!this.config.ohttp) { + throw new OpenGradientError( + "OHTTP is not configured. Pass ohttpRelayUrl, ohttpConfigUrl, or an OHTTPClient to use chatOhttp().", + ); + } + return this.config.ohttp; + } + + private requirePrivateKey(): `0x${string}` { + if (!this.config.privateKey) { + throw new OpenGradientError( + "An EVM privateKey is required for direct x402-paid chat() and completion() calls. Use chatOhttp() with ohttpRelayUrl for relay-paid OHTTP calls.", + ); + } + return this.config.privateKey; + } + + private requireTEEConnection(): TEEConnection { + if (!this.config.connection) { + throw new OpenGradientError( + "A TEE connection is not configured. Provide privateKey for direct x402 calls, or use chatOhttp() with an OHTTP relay.", + ); + } + return this.config.connection; + } + + private validateChatParams(params: ChatParams): void { + if (params.responseFormat?.type === "json_object") { + const provider = params.model.split("/")[0]; + if (provider === "anthropic") { + throw new OpenGradientError( + "Anthropic models do not support response_format type 'json_object'. " + + "Use { type: 'json_schema', jsonSchema: {...} } with an explicit schema instead.", + ); + } + } + } + /** * Build a paid fetch that injects the TEE's pinned TLS dispatcher into every * request (including x402 payment retries). @@ -371,7 +527,8 @@ export class LLM { body: Record, settlementMode: X402SettlementMode, ): Promise<{ response: Response; tee: ActiveTEE }> { - this.config.connection.ensureRefreshLoop(); + const connection = this.requireTEEConnection(); + connection.ensureRefreshLoop(); try { return await this.sendOnce(path, body, settlementMode); } catch (e) { @@ -380,7 +537,7 @@ export class LLM { throw e; } try { - await this.config.connection.reconnect(); + await connection.reconnect(); } catch (reconnectErr) { throw new OpenGradientError( `TEE LLM request failed and registry refresh failed: ${String(reconnectErr)}`, @@ -395,7 +552,7 @@ export class LLM { body: Record, settlementMode: X402SettlementMode, ): Promise<{ response: Response; tee: ActiveTEE }> { - const tee = await this.config.connection.ensureConnected(); + const tee = await this.requireTEEConnection().ensureConnected(); const url = `${trimSlash(tee.endpoint)}${path}`; const paidFetch = this.buildPaidFetch(tee.dispatcher); @@ -453,6 +610,30 @@ function serializeResponseFormat(format: ResponseFormat): Record { return out; } +function normalizeChatMessage(message: ChatMessage): void { + // Some providers (Anthropic via the proxy) return content as an array of + // typed blocks; flatten to a plain string for parity with Python. + if (Array.isArray((message as any).content)) { + message.content = ((message as any).content as any[]) + .filter((b) => b && typeof b === "object" && b.type === "text") + .map((b) => b.text ?? "") + .join(" ") + .trim(); + } +} + +function withOHTTPRoute( + chunk: StreamChunk, + route: OHTTPRouteMetadata, +): StreamChunk { + if (!chunk.isFinal) return chunk; + return { + ...chunk, + teeId: route.teeId, + teeEndpoint: route.teeEndpoint, + }; +} + function parseStreamChunk(data: any): StreamChunk { const choices: StreamChoice[] = (data.choices ?? []).map((c: any) => { // The TEE proxy sometimes sends SSE events using the non-streaming diff --git a/src/ohttp.ts b/src/ohttp.ts new file mode 100644 index 0000000..d53f0c8 --- /dev/null +++ b/src/ohttp.ts @@ -0,0 +1,1002 @@ +import { chacha20poly1305 } from "@noble/ciphers/chacha"; +import { x25519 } from "@noble/curves/ed25519"; +import { expand, extract } from "@noble/hashes/hkdf"; +import { hmac } from "@noble/hashes/hmac"; +import { keccak_256 } from "@noble/hashes/sha3"; +import { sha256 } from "@noble/hashes/sha256"; +import { + bytesToHex, + concatBytes, + hexToBytes, + utf8ToBytes, +} from "@noble/hashes/utils"; +import { webcrypto } from "node:crypto"; +import { OpenGradientError } from "./types"; + +export const OHTTP_REQUEST_MEDIA_TYPE = "message/ohttp-req"; +export const OHTTP_RESPONSE_MEDIA_TYPE = "message/ohttp-res"; +export const OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE = "message/ohttp-chunked-res"; + +const DEFAULT_OHTTP_CONFIG_PATH = "/api/v1/chat/ohttp/config"; +const DEFAULT_OHTTP_REQUEST_PATH = "/api/v1/chat/ohttp"; + +const KEY_CONFIG_ID = 0x01; +const KEM_ID_X25519 = 0x0020; +const KDF_ID_HKDF_SHA256 = 0x0001; +const AEAD_ID_CHACHA20_POLY1305 = 0x0003; +const NK = 32; +const NN = 12; +const NH = 32; + +const LABEL_REQUEST = utf8ToBytes("message/bhttp request"); +const LABEL_RESPONSE = utf8ToBytes("message/bhttp response"); +const LABEL_CHUNKED_RESPONSE = utf8ToBytes("message/bhttp chunked response"); +const LABEL_FINAL = utf8ToBytes("final"); +const HPKE_VERSION = utf8ToBytes("HPKE-v1"); + +const subtleCrypto = (globalThis.crypto ?? webcrypto).subtle; +const randomCrypto = globalThis.crypto ?? webcrypto; + +export interface OHTTPKeyConfig { + keyId: number; + kemId: number; + kdfId: number; + aeadId: number; + /** Hex-encoded X25519 public key, with or without 0x prefix. */ + publicKey: string; + /** Base64-encoded serialized key config. */ + keyConfig?: string; +} + +export interface OHTTPSigningKey { + /** PEM-encoded RSA public key used to verify TEE response signatures. */ + publicKey: string; + hpke?: OHTTPKeyConfig | null; +} + +export interface OHTTPGatewayMetadata { + endpoint: string; + host: string; +} + +export interface OHTTPGatewayConfig { + teeId?: string; + ohttp: OHTTPKeyConfig; + signingKey: OHTTPSigningKey; + teeGateway: OHTTPGatewayMetadata; +} + +export interface OHTTPRouteMetadata { + relayHost: string; + teeHost: string; + teeId?: string; + teeEndpoint: string; +} + +export interface OHTTPClientConfig { + /** Base URL for the chat API OHTTP relay. */ + relayUrl?: string; + requestPath?: string; + configPath?: string; + headers?: HeadersInit | (() => HeadersInit | Promise); + fetch?: typeof fetch; + config?: OHTTPGatewayConfig; +} + +export interface OHTTPRequestOptions { + signal?: AbortSignal; + headers?: HeadersInit; + verify?: boolean; +} + +export interface OHTTPJsonResult> { + status: number; + body: TBody; + route: OHTTPRouteMetadata; +} + +export interface OHTTPByteStreamResult { + chunks: AsyncIterable; + route: OHTTPRouteMetadata; + verifyResponse: ( + responseBody: Record, + options?: { + responseContent?: string; + requestHashCandidates?: unknown[]; + }, + ) => Promise; +} + +interface EncapsulatedRequest { + wire: Uint8Array; + enc: Uint8Array; + responseSecret: Uint8Array; + chunkedResponseSecret: Uint8Array; +} + +interface OHTTPInnerResult { + status: number; + body: Record; +} + +/** + * Minimal OHTTP client for encrypting JSON requests to an OHTTP-capable TEE + * gateway and decrypting the gateway response locally. + */ +export class OHTTPClient { + private configPromise: Promise | null = null; + private readonly fetchImpl: typeof fetch; + + constructor(private readonly config: OHTTPClientConfig = {}) { + this.fetchImpl = config.fetch ?? fetch; + } + + clearConfigCache(): void { + this.configPromise = null; + } + + async close(): Promise { + /* no-op: relay-backed OHTTP uses caller-provided fetch lifecycle */ + } + + async getConfig(): Promise { + this.configPromise ??= this.resolveConfig(); + return this.configPromise; + } + + async getRouteMetadata(): Promise { + const gatewayConfig = await this.getConfig(); + return routeMetadata(gatewayConfig, this.resolveRequestUrl()); + } + + async requestJson>( + body: unknown, + options: OHTTPRequestOptions = {}, + ): Promise> { + const gatewayConfig = await this.getConfig(); + validateOHTTPConfig(gatewayConfig); + + const encodedBody = utf8ToBytes(JSON.stringify(body)); + const encapsulated = encapsulateRequest( + publicKeyBytes(gatewayConfig.ohttp.publicKey), + encodedBody, + ); + const response = await this.sendEncapsulated( + gatewayConfig, + encapsulated.wire, + OHTTP_RESPONSE_MEDIA_TYPE, + options, + ); + + const sealed = new Uint8Array(await response.arrayBuffer()); + const inner = decryptResponse( + encapsulated.responseSecret, + encapsulated.enc, + sealed, + ); + if (inner.status >= 400) { + throw new OpenGradientError( + String(inner.body.error ?? "TEE OHTTP request failed"), + inner.status, + ); + } + + if (options.verify !== false) { + await verifyTeeResponse({ + innerRequest: body, + responseBody: inner.body, + signingKeyPem: gatewayConfig.signingKey.publicKey, + }); + } + + return { + status: inner.status, + body: inner.body as TBody, + route: routeMetadata(gatewayConfig, this.resolveRequestUrl()), + }; + } + + async streamBytes( + body: unknown, + options: OHTTPRequestOptions = {}, + ): Promise { + const gatewayConfig = await this.getConfig(); + validateOHTTPConfig(gatewayConfig); + + const encodedBody = utf8ToBytes(JSON.stringify(body)); + const encapsulated = encapsulateRequest( + publicKeyBytes(gatewayConfig.ohttp.publicKey), + encodedBody, + ); + const response = await this.sendEncapsulated( + gatewayConfig, + encapsulated.wire, + OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE, + options, + true, + ); + if (!response.body) { + throw new OpenGradientError("TEE OHTTP stream returned empty body"); + } + + const decrypter = new ChunkedOHTTPResponseDecrypter( + encapsulated.chunkedResponseSecret, + encapsulated.enc, + ); + return { + chunks: decryptChunkedStream(response.body, decrypter), + route: routeMetadata(gatewayConfig, this.resolveRequestUrl()), + verifyResponse: async (responseBody, verifyOptions) => { + if (options.verify === false) return; + await verifyTeeResponse({ + innerRequest: body, + responseBody, + signingKeyPem: gatewayConfig.signingKey.publicKey, + responseContent: verifyOptions?.responseContent, + requestHashCandidates: verifyOptions?.requestHashCandidates, + }); + }, + }; + } + + private async resolveConfig(): Promise { + if (this.config.config) return this.config.config; + + const configUrl = this.resolveConfigUrl(); + const response = await this.fetchImpl(configUrl, { + method: "GET", + headers: await this.baseHeaders(), + }); + if (!response.ok) { + throw new OpenGradientError( + `TEE OHTTP config request failed: HTTP ${response.status} - ${await safeResponseText(response)}`, + response.status, + ); + } + return normalizeGatewayConfig(await response.json()); + } + + private resolveConfigUrl(): string { + if (!this.config.relayUrl) { + throw new OpenGradientError("ohttpRelayUrl is required for OHTTP relay calls"); + } + return `${trimSlash(this.config.relayUrl)}${this.config.configPath ?? DEFAULT_OHTTP_CONFIG_PATH}`; + } + + private resolveRequestUrl(): string { + if (!this.config.relayUrl) { + throw new OpenGradientError("ohttpRelayUrl is required for OHTTP relay calls"); + } + return `${trimSlash(this.config.relayUrl)}${this.config.requestPath ?? DEFAULT_OHTTP_REQUEST_PATH}`; + } + + private async sendEncapsulated( + gatewayConfig: OHTTPGatewayConfig, + wire: Uint8Array, + accept: string, + options: OHTTPRequestOptions, + stream = false, + ): Promise { + const requestUrl = this.resolveRequestUrl(); + const response = await this.fetchImpl(requestUrl, { + method: "POST", + headers: { + ...(await this.baseHeaders()), + ...headersToRecord(options.headers), + "Content-Type": OHTTP_REQUEST_MEDIA_TYPE, + Accept: accept, + ...(stream ? { "X-OHTTP-Stream": "true" } : {}), + ...(gatewayConfig.teeId ? { "X-TEE-ID": gatewayConfig.teeId } : {}), + }, + body: arrayBufferBody(wire), + signal: options.signal, + }); + + if (!response.ok) { + throw new OpenGradientError( + `TEE OHTTP request failed: HTTP ${response.status} - ${await safeResponseText(response)}`, + response.status, + ); + } + return response; + } + + private async baseHeaders(): Promise> { + const headers = + typeof this.config.headers === "function" + ? await this.config.headers() + : this.config.headers; + return headersToRecord(headers); + } +} + +function normalizeGatewayConfig( + value: unknown, +): OHTTPGatewayConfig { + if (!isRecord(value)) { + throw new OpenGradientError("Malformed OHTTP config response"); + } + + const ohttpRaw = expectRecord(value.ohttp, "ohttp"); + const signingKeyRaw = expectRecord(value.signing_key ?? value.signingKey, "signing_key"); + const teeGatewayRaw = expectRecord( + value.tee_gateway ?? value.teeGateway, + "tee_gateway", + ); + + const ohttp = normalizeOHTTPKeyConfig(ohttpRaw); + + const hpkeRaw = signingKeyRaw.hpke; + return { + teeId: optionalStringField(value, "tee_id", "teeId"), + ohttp, + signingKey: { + publicKey: stringField(signingKeyRaw, "public_key", "publicKey"), + hpke: isRecord(hpkeRaw) ? normalizeOHTTPKeyConfig(hpkeRaw) : null, + }, + teeGateway: { + endpoint: stringField(teeGatewayRaw, "endpoint"), + host: + optionalStringField(teeGatewayRaw, "host") ?? + hostFromEndpoint(stringField(teeGatewayRaw, "endpoint")), + }, + }; +} + +function normalizeOHTTPKeyConfig(record: Record): OHTTPKeyConfig { + return { + keyId: numberField(record, "key_id", "keyId"), + kemId: numberField(record, "kem_id", "kemId"), + kdfId: numberField(record, "kdf_id", "kdfId"), + aeadId: numberField(record, "aead_id", "aeadId"), + publicKey: stringField(record, "public_key", "publicKey"), + keyConfig: optionalStringField(record, "key_config", "keyConfig"), + }; +} + +function validateOHTTPConfig(config: OHTTPGatewayConfig): void { + const { ohttp, signingKey } = config; + if ( + ohttp.keyId !== KEY_CONFIG_ID || + ohttp.kemId !== KEM_ID_X25519 || + ohttp.kdfId !== KDF_ID_HKDF_SHA256 || + ohttp.aeadId !== AEAD_ID_CHACHA20_POLY1305 + ) { + throw new OpenGradientError("Unsupported TEE OHTTP key configuration"); + } + + const hpkePublicKey = signingKey.hpke?.publicKey; + if (hpkePublicKey && cleanHex(hpkePublicKey) !== cleanHex(ohttp.publicKey)) { + throw new OpenGradientError( + "TEE signing key metadata does not match OHTTP key config", + ); + } +} + +function encapsulateRequest( + recipientPublicKey: Uint8Array, + plaintext: Uint8Array, +): EncapsulatedRequest { + const skE = randomCrypto.getRandomValues(new Uint8Array(32)); + const enc = x25519.getPublicKey(skE); + const dh = x25519.getSharedSecret(skE, recipientPublicKey); + const kemContext = concatBytes(enc, recipientPublicKey); + const sharedSecret = extractAndExpand(dh, kemContext); + + const header = headerBytes(); + const info = concatBytes(LABEL_REQUEST, new Uint8Array([0]), header); + const context = keySchedule(sharedSecret, info); + const ciphertext = chacha20poly1305( + context.key, + context.baseNonce, + new Uint8Array(), + ).encrypt(plaintext); + const wire = concatBytes(header, enc, ciphertext); + const responseSecret = hpkeLabeledExpand( + suiteId(), + context.exporterSecret, + utf8ToBytes("sec"), + LABEL_RESPONSE, + NK, + ); + const chunkedResponseSecret = hpkeLabeledExpand( + suiteId(), + context.exporterSecret, + utf8ToBytes("sec"), + LABEL_CHUNKED_RESPONSE, + NK, + ); + + return { wire, enc, responseSecret, chunkedResponseSecret }; +} + +function decryptResponse( + responseSecret: Uint8Array, + enc: Uint8Array, + sealed: Uint8Array, +): OHTTPInnerResult { + if (sealed.length <= NK) { + throw new OpenGradientError("Malformed OHTTP response"); + } + + const responseNonce = sealed.slice(0, NK); + const ciphertext = sealed.slice(NK); + const { key, nonce } = deriveResponseKeys(responseSecret, enc, responseNonce); + const plaintext = chacha20poly1305(key, nonce, new Uint8Array()).decrypt( + ciphertext, + ); + + return normalizeInnerResponse(JSON.parse(new TextDecoder().decode(plaintext))); +} + +class ChunkedOHTTPResponseDecrypter { + private buffer = new Uint8Array(); + private key: Uint8Array | null = null; + private nonce: Uint8Array | null = null; + private counter = 0; + private sawFinal = false; + + constructor( + private readonly responseSecret: Uint8Array, + private readonly enc: Uint8Array, + ) {} + + push(chunk: Uint8Array | undefined, done: boolean): Uint8Array[] { + if (chunk?.length) { + this.buffer = concatBytes(this.buffer, chunk); + } + + if (!this.key || !this.nonce) { + if (this.buffer.length < NK) { + if (done) { + throw new OpenGradientError("Malformed chunked OHTTP response"); + } + return []; + } + const responseNonce = this.buffer.slice(0, NK); + const context = deriveResponseKeys( + this.responseSecret, + this.enc, + responseNonce, + ); + this.key = context.key; + this.nonce = context.nonce; + this.buffer = this.buffer.slice(NK); + } + + const out: Uint8Array[] = []; + while (this.buffer.length > 0) { + const frame = decodeVarint(this.buffer, 0); + if (!frame) { + if (done) { + throw new OpenGradientError("Malformed chunked OHTTP response"); + } + break; + } + + const { value: sealedLength, offset } = frame; + if (sealedLength === 0) { + if (!done) break; + const ciphertext = this.buffer.slice(offset); + out.push(this.decryptChunk(ciphertext, true)); + this.buffer = new Uint8Array(); + this.sawFinal = true; + break; + } + + if (this.buffer.length < offset + sealedLength) { + if (done) { + throw new OpenGradientError("Truncated chunked OHTTP response"); + } + break; + } + + const ciphertext = this.buffer.slice(offset, offset + sealedLength); + out.push(this.decryptChunk(ciphertext, false)); + this.buffer = this.buffer.slice(offset + sealedLength); + } + + if (done && !this.sawFinal) { + throw new OpenGradientError("Chunked OHTTP response missing final marker"); + } + + return out; + } + + private decryptChunk(ciphertext: Uint8Array, isFinal: boolean): Uint8Array { + if (!this.key || !this.nonce) { + throw new OpenGradientError( + "Chunked OHTTP response keys are not initialized", + ); + } + + const chunkNonce = xorBytes(this.nonce, i2osp(this.counter, NN)); + const aad = isFinal ? LABEL_FINAL : new Uint8Array(); + const plaintext = chacha20poly1305(this.key, chunkNonce, aad).decrypt( + ciphertext, + ); + this.counter += 1; + return plaintext; + } +} + +async function* decryptChunkedStream( + body: ReadableStream, + decrypter: ChunkedOHTTPResponseDecrypter, +): AsyncIterable { + const reader = body.getReader(); + try { + while (true) { + const { value, done } = await reader.read(); + const chunks = decrypter.push(value, done); + for (const chunk of chunks) yield chunk; + if (done) break; + } + } finally { + reader.releaseLock(); + } +} + +function deriveResponseKeys( + responseSecret: Uint8Array, + enc: Uint8Array, + responseNonce: Uint8Array, +) { + const salt = concatBytes(enc, responseNonce); + const prk = hmac(sha256, salt, responseSecret); + return { + key: expand(sha256, prk, utf8ToBytes("key"), NK), + nonce: expand(sha256, prk, utf8ToBytes("nonce"), NN), + }; +} + +function normalizeInnerResponse(decoded: unknown): OHTTPInnerResult { + if (!isRecord(decoded)) { + throw new OpenGradientError("Malformed OHTTP response"); + } + + if (typeof decoded.status === "number" && isRecord(decoded.body)) { + return { + status: decoded.status, + body: decoded.body, + }; + } + + return { + status: 200, + body: decoded, + }; +} + +function extractAndExpand(dh: Uint8Array, kemContext: Uint8Array) { + const kemSuiteId = concatBytes(utf8ToBytes("KEM"), i2osp(KEM_ID_X25519, 2)); + const eaePrk = hpkeLabeledExtract( + kemSuiteId, + new Uint8Array(), + utf8ToBytes("eae_prk"), + dh, + ); + return hpkeLabeledExpand( + kemSuiteId, + eaePrk, + utf8ToBytes("shared_secret"), + kemContext, + NH, + ); +} + +function keySchedule(sharedSecret: Uint8Array, info: Uint8Array) { + const suite = suiteId(); + const pskIdHash = hpkeLabeledExtract( + suite, + new Uint8Array(), + utf8ToBytes("psk_id_hash"), + new Uint8Array(), + ); + const infoHash = hpkeLabeledExtract( + suite, + new Uint8Array(), + utf8ToBytes("info_hash"), + info, + ); + const keyScheduleContext = concatBytes( + new Uint8Array([0]), + pskIdHash, + infoHash, + ); + const secret = hpkeLabeledExtract( + suite, + sharedSecret, + utf8ToBytes("secret"), + new Uint8Array(), + ); + + return { + key: hpkeLabeledExpand( + suite, + secret, + utf8ToBytes("key"), + keyScheduleContext, + NK, + ), + baseNonce: hpkeLabeledExpand( + suite, + secret, + utf8ToBytes("base_nonce"), + keyScheduleContext, + NN, + ), + exporterSecret: hpkeLabeledExpand( + suite, + secret, + utf8ToBytes("exp"), + keyScheduleContext, + NH, + ), + }; +} + +function hpkeLabeledExtract( + suite: Uint8Array, + salt: Uint8Array, + label: Uint8Array, + ikm: Uint8Array, +) { + return extract(sha256, concatBytes(HPKE_VERSION, suite, label, ikm), salt); +} + +function hpkeLabeledExpand( + suite: Uint8Array, + prk: Uint8Array, + label: Uint8Array, + info: Uint8Array, + length: number, +) { + return expand( + sha256, + prk, + concatBytes(i2osp(length, 2), HPKE_VERSION, suite, label, info), + length, + ); +} + +function suiteId() { + return concatBytes( + utf8ToBytes("HPKE"), + i2osp(KEM_ID_X25519, 2), + i2osp(KDF_ID_HKDF_SHA256, 2), + i2osp(AEAD_ID_CHACHA20_POLY1305, 2), + ); +} + +function headerBytes() { + return concatBytes( + new Uint8Array([KEY_CONFIG_ID]), + i2osp(KEM_ID_X25519, 2), + i2osp(KDF_ID_HKDF_SHA256, 2), + i2osp(AEAD_ID_CHACHA20_POLY1305, 2), + ); +} + +async function verifyTeeResponse({ + innerRequest, + responseBody, + signingKeyPem, + responseContent, + requestHashCandidates, +}: { + innerRequest: unknown; + responseBody: Record; + signingKeyPem: string; + responseContent?: string; + requestHashCandidates?: unknown[]; +}) { + const teeRequestHash = + typeof responseBody.tee_request_hash === "string" + ? responseBody.tee_request_hash + : undefined; + const requestHashes = (requestHashCandidates ?? [innerRequest]).map( + (candidate) => keccak_256(utf8ToBytes(pythonJsonDumps(candidate))), + ); + const requestHash = + requestHashes.find( + (candidate) => bytesToHex(candidate) === teeRequestHash, + ) ?? requestHashes[0]; + if (teeRequestHash && teeRequestHash !== bytesToHex(requestHash)) { + throw new OpenGradientError("TEE request hash verification failed"); + } + + const outputContent = responseContent ?? responseContentForHash(responseBody); + const outputHash = keccak_256(utf8ToBytes(outputContent)); + const teeOutputHash = + typeof responseBody.tee_output_hash === "string" + ? responseBody.tee_output_hash + : undefined; + if (teeOutputHash && teeOutputHash !== bytesToHex(outputHash)) { + throw new OpenGradientError("TEE output hash verification failed"); + } + + const teeSignature = + typeof responseBody.tee_signature === "string" + ? responseBody.tee_signature + : undefined; + const teeTimestamp = + typeof responseBody.tee_timestamp === "string" || + typeof responseBody.tee_timestamp === "number" || + typeof responseBody.tee_timestamp === "bigint" + ? responseBody.tee_timestamp + : undefined; + + if (!teeSignature || teeTimestamp === undefined) { + return; + } + + const timestamp = uint256Bytes(BigInt(teeTimestamp)); + const msgHash = keccak_256(concatBytes(requestHash, outputHash, timestamp)); + const key = await importRsaPublicKey(signingKeyPem); + const ok = await subtleCrypto.verify( + { name: "RSA-PSS", saltLength: 32 }, + key, + arrayBufferBody(base64ToBytes(teeSignature)), + arrayBufferBody(msgHash), + ); + if (!ok) { + throw new OpenGradientError("TEE signature verification failed"); + } +} + +function responseContentForHash(responseBody: Record) { + const choice = getFirstChoice(responseBody); + const message = isRecord(choice?.message) ? choice.message : undefined; + if ( + choice?.finish_reason === "tool_calls" && + Array.isArray(message?.tool_calls) + ) { + return pythonJsonDumps(message.tool_calls); + } + return extractAssistantContent(responseBody); +} + +function extractAssistantContent(responseBody: Record) { + const choice = getFirstChoice(responseBody); + const message = isRecord(choice?.message) ? choice.message : undefined; + const content = message?.content; + if (Array.isArray(content)) { + return content + .map((part) => + isRecord(part) && typeof part.text === "string" ? part.text : "", + ) + .join(""); + } + return typeof content === "string" ? content : ""; +} + +function getFirstChoice( + responseBody: Record, +): Record | undefined { + const choices = responseBody.choices; + if (!Array.isArray(choices)) { + return undefined; + } + const choice = choices[0]; + return isRecord(choice) ? choice : undefined; +} + +function pythonJsonDumps(value: unknown, keyHint?: string): string { + if (value === null) return "null"; + if (Array.isArray(value)) { + return `[${value.map((item) => pythonJsonDumps(item)).join(", ")}]`; + } + if (isRecord(value)) { + return `{${Object.keys(value) + .sort() + .map( + (key) => `${jsonStringAscii(key)}: ${pythonJsonDumps(value[key], key)}`, + ) + .join(", ")}}`; + } + if (typeof value === "string") return jsonStringAscii(value); + if (typeof value === "number") { + if (keyHint === "temperature" && Number.isInteger(value)) { + return `${value}.0`; + } + return JSON.stringify(value); + } + if (typeof value === "boolean") { + return JSON.stringify(value); + } + return "null"; +} + +function jsonStringAscii(value: string) { + return JSON.stringify(value).replace( + /[\u007f-\uffff]/g, + (char) => `\\u${char.charCodeAt(0).toString(16).padStart(4, "0")}`, + ); +} + +async function importRsaPublicKey(pem: string) { + const der = base64ToBytes( + pem + .replace("-----BEGIN PUBLIC KEY-----", "") + .replace("-----END PUBLIC KEY-----", "") + .replace(/\s/g, ""), + ); + return subtleCrypto.importKey( + "spki", + der, + { name: "RSA-PSS", hash: "SHA-256" }, + false, + ["verify"], + ); +} + +function routeMetadata( + config: OHTTPGatewayConfig, + requestUrl: string, +): OHTTPRouteMetadata { + return { + relayHost: hostFromEndpoint(requestUrl), + teeHost: config.teeGateway.host, + teeId: config.teeId, + teeEndpoint: config.teeGateway.endpoint, + }; +} + +function publicKeyBytes(value: string): Uint8Array { + return hexToBytes(cleanHex(value)); +} + +function cleanHex(value: string): string { + return value.startsWith("0x") ? value.slice(2) : value; +} + +function i2osp(value: number, length: number) { + const out = new Uint8Array(length); + for (let i = length - 1, n = value; i >= 0; i -= 1, n >>= 8) { + out[i] = n & 0xff; + } + return out; +} + +function decodeVarint( + bytes: Uint8Array, + offset: number, +): { value: number; offset: number } | null { + if (offset >= bytes.length) { + return null; + } + + const first = bytes[offset]; + const length = 1 << (first >> 6); + if (offset + length > bytes.length) { + return null; + } + + let value = BigInt(first & 0x3f); + for (let i = 1; i < length; i += 1) { + value = (value << BigInt(8)) | BigInt(bytes[offset + i]); + } + + if (value > BigInt(Number.MAX_SAFE_INTEGER)) { + throw new OpenGradientError("Chunked OHTTP frame is too large"); + } + return { value: Number(value), offset: offset + length }; +} + +function xorBytes(left: Uint8Array, right: Uint8Array): Uint8Array { + const out = new Uint8Array(left.length); + for (let i = 0; i < left.length; i += 1) { + out[i] = left[i] ^ right[i]; + } + return out; +} + +function uint256Bytes(value: bigint) { + const out = new Uint8Array(32); + const byteShift = BigInt(8); + const byteMask = BigInt(0xff); + for (let i = 31, n = value; i >= 0; i -= 1, n >>= byteShift) { + out[i] = Number(n & byteMask); + } + return out; +} + +function base64ToBytes(value: string) { + return Uint8Array.from(Buffer.from(value, "base64")); +} + +function hostFromEndpoint(endpoint: string) { + try { + return new URL(endpoint).host; + } catch { + return endpoint; + } +} + +function arrayBufferBody(bytes: Uint8Array): ArrayBuffer { + return bytes.buffer.slice( + bytes.byteOffset, + bytes.byteOffset + bytes.byteLength, + ) as ArrayBuffer; +} + +function headersToRecord(headers: HeadersInit | undefined): Record { + const out: Record = {}; + if (!headers) return out; + if (headers instanceof Headers) { + headers.forEach((value, key) => { + out[key] = value; + }); + return out; + } + if (Array.isArray(headers)) { + for (const [key, value] of headers) out[key] = value; + return out; + } + for (const [key, value] of Object.entries(headers)) { + if (value !== undefined && value !== null) { + out[key] = String(value); + } + } + return out; +} + +async function safeResponseText(response: Response): Promise { + const text = await response.text().catch(() => ""); + if (!text) return "empty response"; + try { + const body = JSON.parse(text); + return body?.detail ? String(body.detail) : text; + } catch { + return text; + } +} + +function trimSlash(url: string): string { + return url.endsWith("/") ? url.slice(0, -1) : url; +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +function expectRecord(value: unknown, name: string): Record { + if (!isRecord(value)) { + throw new OpenGradientError(`Malformed OHTTP config response: ${name}`); + } + return value; +} + +function numberField( + record: Record, + snake: string, + camel?: string, +): number { + const value = record[snake] ?? (camel ? record[camel] : undefined); + if (typeof value !== "number") { + throw new OpenGradientError(`Malformed OHTTP config response: ${snake}`); + } + return value; +} + +function stringField( + record: Record, + snake: string, + camel?: string, +): string { + const value = record[snake] ?? (camel ? record[camel] : undefined); + if (typeof value !== "string") { + throw new OpenGradientError(`Malformed OHTTP config response: ${snake}`); + } + return value; +} + +function optionalStringField( + record: Record, + snake: string, + camel?: string, +): string | undefined { + const value = record[snake] ?? (camel ? record[camel] : undefined); + return typeof value === "string" ? value : undefined; +} diff --git a/src/teeRegistry.ts b/src/teeRegistry.ts index d8b5971..c0f5482 100644 --- a/src/teeRegistry.ts +++ b/src/teeRegistry.ts @@ -18,12 +18,28 @@ export interface TEEEndpoint { teeId: string; /** HTTPS endpoint URL of the TEE. */ endpoint: string; + /** DER-encoded public key used by the TEE for response attestations. */ + publicKeyDer: Uint8Array; /** DER-encoded X.509 certificate bytes as stored in the registry. */ tlsCertDer: Uint8Array; /** Wallet address that receives x402 payments for this TEE. */ paymentAddress: string; } +export interface TEEOHTTPConfig { + keyId: number; + kemId: number; + kdfId: number; + aeadId: number; + publicKey: Uint8Array; + keyConfig: Uint8Array; +} + +/** A verified TEE endpoint that also has an OHTTP HPKE key config. */ +export interface TEEOHTTPEndpoint extends TEEEndpoint { + ohttpConfig: TEEOHTTPConfig; +} + interface RawTEEInfo { owner: Address; paymentAddress: Address; @@ -37,6 +53,23 @@ interface RawTEEInfo { lastHeartbeatAt: bigint; } +interface RawOHTTPConfig { + keyId: number; + kemId: number; + kdfId: number; + aeadId: number; + publicKey: Hex; + keyConfig: Hex; + registeredAt: bigint; + updatedAt: bigint; +} + +interface RawTEERecordWithOHTTPConfig { + teeId: Hex; + tee: RawTEEInfo; + ohttpConfig: RawOHTTPConfig; +} + function hexToBytes(hex: Hex): Uint8Array { const cleaned = hex.startsWith("0x") ? hex.slice(2) : hex; if (cleaned.length === 0) return new Uint8Array(0); @@ -105,8 +138,66 @@ export class TEERegistry { out.push({ teeId: keccak256(tee.publicKey), endpoint: tee.endpoint, + publicKeyDer: hexToBytes(tee.publicKey), + tlsCertDer: hexToBytes(tee.tlsCertificate), + paymentAddress: tee.paymentAddress, + }); + } + return out; + } + + /** + * Return active TEEs of the given type that have a registered OHTTP HPKE + * configuration. + */ + async getActiveOHTTPTEEsByType(teeType: number): Promise { + let records: readonly RawTEERecordWithOHTTPConfig[]; + try { + records = (await this.client.readContract({ + address: this.address, + abi: TEE_REGISTRY_ABI, + functionName: "getActiveTEERecordsWithOHTTPConfig", + args: [teeType], + })) as readonly RawTEERecordWithOHTTPConfig[]; + } catch (e) { + // eslint-disable-next-line no-console + console.warn( + `Failed to fetch active OHTTP TEEs from registry (type=${teeType}): ${String(e)}`, + ); + return []; + } + + const out: TEEOHTTPEndpoint[] = []; + for (const record of [...records].sort((left, right) => + Number(right.tee.lastHeartbeatAt - left.tee.lastHeartbeatAt), + )) { + const { tee, ohttpConfig } = record; + if ( + !tee.enabled || + !tee.endpoint || + !tee.tlsCertificate || + tee.tlsCertificate === "0x" || + !tee.publicKey || + tee.publicKey === "0x" || + !ohttpConfig.publicKey || + ohttpConfig.publicKey === "0x" + ) { + continue; + } + out.push({ + teeId: record.teeId, + endpoint: tee.endpoint, + publicKeyDer: hexToBytes(tee.publicKey), tlsCertDer: hexToBytes(tee.tlsCertificate), paymentAddress: tee.paymentAddress, + ohttpConfig: { + keyId: ohttpConfig.keyId, + kemId: ohttpConfig.kemId, + kdfId: ohttpConfig.kdfId, + aeadId: ohttpConfig.aeadId, + publicKey: hexToBytes(ohttpConfig.publicKey), + keyConfig: hexToBytes(ohttpConfig.keyConfig), + }, }); } return out; @@ -123,4 +214,15 @@ export class TEERegistry { } return tees[Math.floor(Math.random() * tees.length)]; } + + /** + * Return a random active OHTTP-enabled LLM proxy TEE from the registry. + */ + async getLLMOHTTPTEE(): Promise { + const tees = await this.getActiveOHTTPTEEsByType(TEE_TYPE_LLM_PROXY); + if (tees.length === 0) { + return null; + } + return tees[Math.floor(Math.random() * tees.length)]; + } } diff --git a/src/types.ts b/src/types.ts index 7b026b5..4db509e 100644 --- a/src/types.ts +++ b/src/types.ts @@ -244,8 +244,12 @@ export interface StreamChunk { } export interface ClientConfig { - /** EVM private key (hex string, with or without 0x prefix). */ - privateKey: string; + /** + * EVM private key (hex string, with or without 0x prefix). Required for + * direct x402-paid `chat()` and `completion()` calls. Not required for + * relay-backed `chatOhttp()` calls because the relay pays the TEE. + */ + privateKey?: string; /** * Override with a hardcoded TEE LLM server URL (dev / self-hosted). When * set, the on-chain TEE registry is bypassed and TLS verification is @@ -262,6 +266,17 @@ export interface ClientConfig { rpcUrl?: string; /** Override the deployed TEERegistry contract address. */ teeRegistryAddress?: string; + /** + * Base URL for an OHTTP chat API relay. The default paths match the frontend: + * /api/v1/chat/ohttp/config and /api/v1/chat/ohttp. + */ + ohttpRelayUrl?: string; + /** Override the OHTTP config path. Defaults to /api/v1/chat/ohttp/config. */ + ohttpConfigPath?: string; + /** Override the OHTTP request path. Defaults to /api/v1/chat/ohttp. */ + ohttpRequestPath?: string; + /** Headers to attach to OHTTP config and request calls. */ + ohttpHeaders?: HeadersInit | (() => HeadersInit | Promise); } export class OpenGradientError extends Error {