diff --git a/.env.sample b/.env.sample index 9847a1d..d4f9581 100644 --- a/.env.sample +++ b/.env.sample @@ -1 +1,5 @@ -OPENAI_API_KEY= \ No newline at end of file +OPENAI_API_KEY= +OPENAI_ENDPOINT= + +# For running the tests +IDENTITY_SERVER_URL= \ No newline at end of file diff --git a/bun.lockb b/bun.lockb index 1ee564d..36941d6 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/package.json b/package.json index 61480d8..824d576 100644 --- a/package.json +++ b/package.json @@ -7,7 +7,8 @@ "typescript": "^5.7.3" }, "scripts": { - "docs": "docsify serve docs" + "docs": "docsify serve docs", + "test": "bun test tests/e2e.test.ts" }, "type": "module", "workspaces": [ @@ -15,4 +16,4 @@ "packages/mcp", "chat" ] -} \ No newline at end of file +} diff --git a/packages/daemon/src/daemon.ts b/packages/daemon/src/daemon.ts index c1d92d4..72adff9 100644 --- a/packages/daemon/src/daemon.ts +++ b/packages/daemon/src/daemon.ts @@ -6,12 +6,18 @@ import { type IMessageLifecycle, type IHook, type IHookLog, + type MultiMessageSchema, } from "./types"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { SSEClientTransport } from "./SSEClientTransport.js"; import type { TextContent } from "@modelcontextprotocol/sdk/types.js"; import type { Keypair } from "@solana/web3.js"; -import { createPrompt, generateText } from "./llm.js"; +import { + createMultiplePrompts, + createPrompt, + generateText, + generateTextWithMessages, +} from "./llm.js"; import nacl from "tweetnacl"; import { nanoid } from "nanoid"; import { Buffer } from "buffer"; @@ -66,7 +72,7 @@ export class Daemon implements IDaemon { ) { this.modelApiKeys = { generationKey: opts.modelApiKeys.generationKey, - } + }; this.keypair = opts.privateKey; @@ -270,6 +276,43 @@ export class Daemon implements IDaemon { toolArgs?: { [key: string]: any; // key = `serverUrl-toolName` }; + /** + * Use a custom system prompt instead of the default one + */ + customSystemPrompt?: string; + /** + * 
Opt to use a custom message template instead of the default one. + * + * This involves passing a string with the following placeholders: + * - {{name}} + * - {{identity}} + * - {{message}} + * - {{context}} + * - {{tools}}
const toolArgs = + opts?.toolArgs?.[`${tool.serverUrl}-${tool.tool.name}`]; + postProcessPromises.push( + this.callTool(tool.tool.name, tool.serverUrl, { + lifecycle, + args: toolArgs, + }) + ); + } + + const postProcessResults = await Promise.all(postProcessPromises); + lifecycle.postProcessLog = postProcessResults + .map((lfcyl) => { + return lfcyl.postProcessLog; + }) + .flat(); + } + + return lifecycle; + } + + async multipleMessages( + messages: MultiMessageSchema[], + opts?: { + channelId?: string; + context?: boolean; + actions?: boolean; + postProcess?: boolean; + toolArgs?: { + [key: string]: any; // key = `serverUrl-toolName` + }; + /** + * Use a custom system prompt instead of the default one + */ + customSystemPrompt?: string; + /** + * Opt to use a custom message template instead of the default one. + * + * This involves passing a string with the following placeholders: + * - {{name}} + * - {{identity}} + * - {{message}} + * - {{context}} + * - {{tools}}
true; + + const formattedMessages = messages.map( + (m) => ` + # ${m.role} + ${m.content} + ` + ); + + // Lifecycle: message -> fetchContext -> generateText -> takeActions -> hooks -> callHooks -> postProcess + let lifecycle: IMessageLifecycle = { + daemonPubkey: this.keypair.publicKey.toBase58(), + daemonName: this.character?.name ?? "", + messageId: nanoid(), + message: formattedMessages, + createdAt: new Date().toISOString(), + approval: "", + channelId: opts?.channelId ?? null, + identityPrompt: + this.character?.identityPrompt ?? DEFAULT_IDENTITY_PROMPT(this), + context: [], + tools: [], + generatedPrompt: "", + output: "", + hooks: [], + hooksLog: [], + actionsLog: [], + postProcessLog: [], + }; + + // Generate Approval + lifecycle = this.generateApproval(lifecycle); + + if (context) { + let contextPromises: Promise[] = []; + for (const tool of this.tools.context) { + const toolArgs = + opts?.toolArgs?.[`${tool.serverUrl}-${tool.tool.name}`]; + contextPromises.push( + this.callTool(tool.tool.name, tool.serverUrl, { + lifecycle, + args: toolArgs, + }) + ); + } + + const contextResults = await Promise.all(contextPromises); + lifecycle.context = contextResults + .map((lfcyl) => { + return lfcyl.context; + }) + .flat(); + lifecycle.tools = contextResults + .map((lfcyl) => { + return lfcyl.tools; + }) + .flat(); + } + + // Construct messages with custom prompt if provided + const prompts = createMultiplePrompts( + lifecycle, + messages, + opts?.customMessageTemplate + ); + lifecycle.generatedPrompt = prompts.map((p) => p.content); + // Generate Text given multiple messages + lifecycle.output = await generateTextWithMessages( + this.character.modelSettings.generation, + this.modelApiKeys.generationKey, + prompts, + opts?.customSystemPrompt ); if (actions) { diff --git a/packages/daemon/src/llm.ts b/packages/daemon/src/llm.ts index 74a9ef5..aaa0546 100644 --- a/packages/daemon/src/llm.ts +++ b/packages/daemon/src/llm.ts @@ -1,6 +1,13 @@ import OpenAI from "openai"; 
-import type { IMessageLifecycle, ModelSettings } from "./types"; +import type { + IMessageLifecycle, + ModelSettings, + MultiMessageSchema, +} from "./types"; import Anthropic from "@anthropic-ai/sdk"; +import type { ChatCompletionMessageParam } from "openai/resources"; +import type { MessageParam } from "@anthropic-ai/sdk/resources"; +import { parseTemplate } from "./templateParser"; export const SYSTEM_PROMPT = ` You are an AI agent operating within a framework that provides you with: @@ -74,7 +81,8 @@ export async function generateEmbeddings( export async function generateText( generationModelSettings: ModelSettings, generationModelKey: string, - userMessage: string + userMessage: string, + customSystemPrompt?: string ): Promise { switch (generationModelSettings?.provider) { case "openai": @@ -89,7 +97,7 @@ export async function generateText( messages: [ { role: "system", - content: SYSTEM_PROMPT, + content: customSystemPrompt ?? SYSTEM_PROMPT, }, { role: "user", @@ -110,7 +118,7 @@ export async function generateText( const anthropicResponse = await anthropic.messages.create({ model: generationModelSettings.name, - system: SYSTEM_PROMPT, + system: customSystemPrompt ?? SYSTEM_PROMPT, messages: [ { role: "user", @@ -126,8 +134,96 @@ export async function generateText( } } -export function createPrompt(lifecycle: IMessageLifecycle): string { - return ` +/** + * Generates text using an OpenAI or Anthropic compatible model, using multiple messages. + * + * @param generationModelSettings The settings for the generation model. + * @param generationModelKey The API key for the generation model. + * @param messages An array of messages to generate text from. + * @param customSystemPrompt An optional custom system prompt to use for the generation. + * @returns The generated text as a string. 
+ */ +export async function generateTextWithMessages( + generationModelSettings: ModelSettings, + generationModelKey: string, + messages: MultiMessageSchema[], + customSystemPrompt?: string +): Promise { + switch (generationModelSettings?.provider) { + case "openai": + const openai = new OpenAI({ + apiKey: generationModelKey, + baseURL: generationModelSettings.endpoint, + dangerouslyAllowBrowser: true, + }); + + const formattedMessages = messages.map( + (m) => + ({ + role: m.role, + content: m.content, + }) as ChatCompletionMessageParam + ); + + const openaiResponse = await openai.chat.completions.create({ + model: generationModelSettings.name, + messages: [ + { + role: "system", + content: customSystemPrompt ?? SYSTEM_PROMPT, + }, + ...formattedMessages, + ], + temperature: generationModelSettings.temperature, + max_completion_tokens: generationModelSettings.maxTokens, + }); + + return openaiResponse.choices[0].message.content ?? ""; + break; + case "anthropic": + const anthropic = new Anthropic({ + apiKey: generationModelKey, + baseURL: generationModelSettings.endpoint, + }); + + const anthropicResponse = await anthropic.messages.create({ + model: generationModelSettings.name, + system: customSystemPrompt ?? SYSTEM_PROMPT, + messages: [ + ...messages.map( + (m) => + ({ + role: m.role, + content: m.content, + }) as MessageParam + ), + ], + max_tokens: generationModelSettings.maxTokens ?? 1000, + temperature: generationModelSettings.temperature ?? 
0.2, + }); + + // content is an array of ContentBlock objects; join only the text blocks + return anthropicResponse.content + .map((block) => (block.type === "text" ? block.text : "")) + .join("\n"); + break; + } +}
variablePattern = /\{\{(\w+)\}\}/g; + const matches = Array.from(template.matchAll(variablePattern)); + const usedVars = matches.map((match) => match[1]); + + const invalidVars = usedVars.filter( + (v) => !MESSAGE_ALLOWED_VARIABLES.includes(v as MessageAllowedVariable) + ); + + return invalidVars; +} + +export function parseTemplate( + data: TemplateData, + template: string +): { role: "user" | "assistant"; content: string } { + const invalidVars = validateTemplate(template); + if (invalidVars.length > 0) { + throw new Error( + `Template contains invalid variables: ${invalidVars.join(", ")}` + ); + } + + const variableMap = { + name: data.lifecycle.daemonName, + identity: data.lifecycle.identityPrompt, + message: data.message.content, + context: data.lifecycle.context?.join("\n") ?? "", + tools: data.lifecycle.tools?.join("\n") ?? "", + }; + + return { + role: data.message.role, + content: template.replace( + /\{\{(\w+)\}\}/g, + (_, variable: MessageAllowedVariable) => { + if (MESSAGE_ALLOWED_VARIABLES.includes(variable)) { + return variableMap[variable] as string; + } + return `{{${variable}}}`; + } + ), + }; +} diff --git a/packages/daemon/src/types.ts b/packages/daemon/src/types.ts index ebba9c7..02b0ec5 100644 --- a/packages/daemon/src/types.ts +++ b/packages/daemon/src/types.ts @@ -20,7 +20,7 @@ export type IHookLog = any; export const ZMessageLifecycle = z.object({ daemonPubkey: z.string(), - message: z.string(), + message: z.string().or(z.array(z.string())), messageId: z.string(), createdAt: z.string(), approval: z.string(), @@ -29,7 +29,7 @@ export const ZMessageLifecycle = z.object({ identityPrompt: z.string().nullable(), context: z.array(z.string()).default([]), tools: z.array(z.string()).default([]), - generatedPrompt: z.string().default(""), + generatedPrompt: z.string().or(z.array(z.string())).default(""), output: z.string().default(""), hooks: z.array(ZHook).default([]), hooksLog: z.array(z.string()).default([]), @@ -127,6 +127,11 @@ export interface 
ToolRegistration { tool: ITool; } +export interface MultiMessageSchema { + role: "user" | "assistant"; + content: string; +} + export interface IDaemon { // Properties character: Character | undefined; diff --git a/tests/e2e.test.ts b/tests/e2e.test.ts new file mode 100644 index 0000000..a659c02 --- /dev/null +++ b/tests/e2e.test.ts @@ -0,0 +1,104 @@ +import { + Daemon, + type Character, + type MultiMessageSchema, +} from "@spacemangaming/daemon"; +import { Keypair } from "@solana/web3.js"; +import { expect, test } from "bun:test"; + +test( + "Should use message templates & handle multiple messages", + async () => { + const daemon = new Daemon(); + + const identityKp = Keypair.generate(); + const prompt = + "You are Bob, a helpful assistant who has a knack for building things."; + + await daemon.init(process.env.IDENTITY_SERVER_URL!, { + character: { + name: "Bob", + pubkey: identityKp.publicKey.toBase58(), + identityPrompt: prompt, + identityServerUrl: process.env.IDENTITY_SERVER_URL!, + modelSettings: { + embedding: { + provider: "openai", + name: "text-embedding-3-small", + endpoint: process.env.OPENAI_ENDPOINT!, + apiKey: process.env.OPENAI_API_KEY!, + }, + generation: { + provider: "openai", + name: "gpt-4o", + endpoint: process.env.OPENAI_ENDPOINT!, + apiKey: process.env.OPENAI_API_KEY!, + }, + }, + bootstrap: [], + } as Character, + privateKey: identityKp, + modelApiKeys: { + generationKey: process.env.OPENAI_API_KEY!, + embeddingKey: process.env.OPENAI_API_KEY!, + }, + }); + + const lifecycle = await daemon.message("Hello!", { + context: true, + actions: true, + postProcess: false, + customSystemPrompt: + "As Bob The Builder, guide the user through a conversation.", + customMessageTemplate: ` + {{message}} + `, + }); + + expect(lifecycle).toBeDefined(); + + console.log("Message:", lifecycle.message); + console.log("Output:", lifecycle.output); + console.log("Generated Prompts:", lifecycle.generatedPrompt); + + const historyWithNewMessage: 
MultiMessageSchema[] = [ + { + role: "user", + content: lifecycle.message as string, + }, + { + role: "assistant", + content: lifecycle.output, + }, + { + role: "user", + content: "Awesome!", + }, + ]; + + console.log("History with new message:", historyWithNewMessage); + + const lifecycleWithMultiple = await daemon.multipleMessages( + historyWithNewMessage, + { + context: true, + actions: true, + postProcess: false, + customSystemPrompt: + "As Bob The Builder, guide the user through a conversation.", + customMessageTemplate: ` + {{message}} + `, + } + ); + + expect(lifecycleWithMultiple).toBeDefined(); + + const generatedPrompts = lifecycleWithMultiple.generatedPrompt as string[]; + + console.log("Messages:", lifecycleWithMultiple.message); + console.log("Output:", lifecycleWithMultiple.output); + console.log("Generated Prompts:", generatedPrompts); + }, + { timeout: 60000 } +);