diff --git a/packages/ai/src/ai/tools/annotations.ts b/packages/ai/src/ai/tools/annotations.ts index 096f40461..b83985b0f 100644 --- a/packages/ai/src/ai/tools/annotations.ts +++ b/packages/ai/src/ai/tools/annotations.ts @@ -1,7 +1,12 @@ import { tool } from "ai"; import dayjs from "dayjs"; import { z } from "zod"; -import { callRPCProcedure, createToolLogger, getAppContext } from "./utils"; +import { + callRPCProcedure, + createToolLogger, + getAppContext, + resolveToolWebsite, +} from "./utils"; const logger = createToolLogger("Annotations Tools"); @@ -95,10 +100,11 @@ export function createAnnotationTools() { execute: async ({ websiteId, chartType, chartContext }, options) => { const context = getAppContext(options); try { + const resolved = resolveToolWebsite(context, websiteId); const result = await callRPCProcedure( "annotations", "list", - { websiteId, chartType, chartContext }, + { websiteId: resolved.websiteId, chartType, chartContext }, context ); return { @@ -107,7 +113,7 @@ export function createAnnotationTools() { }; } catch (error) { logger.error("Failed to list annotations", { - websiteId, + websiteId: websiteId, chartType, error, }); @@ -140,6 +146,7 @@ export function createAnnotationTools() { options ) => { const context = getAppContext(options); + const resolved = resolveToolWebsite(context, websiteId); try { if (!confirmed) { const dateRangePreview = `${chartContext.dateRange.start_date} to ${chartContext.dateRange.end_date} (${chartContext.dateRange.granularity})`; @@ -149,7 +156,7 @@ export function createAnnotationTools() { message: "Please review the annotation details below and confirm if you want to create it:", annotation: { - websiteId, + websiteId: resolved.websiteId, chartType, dateRange: dateRangePreview, annotationType, @@ -170,7 +177,7 @@ export function createAnnotationTools() { "annotations", "create", { - websiteId, + websiteId: resolved.websiteId, chartType, chartContext, annotationType, diff --git a/packages/ai/src/ai/tools/utils/context.test.ts b/packages/ai/src/ai/tools/utils/context.test.ts index 92096cdea..b0258acfe 100644 --- a/packages/ai/src/ai/tools/utils/context.test.ts +++ b/packages/ai/src/ai/tools/utils/context.test.ts @@ -86,4 +86,46 @@ describe("resolveToolWebsite", () => { expect(() => resolveToolWebsite(ctx)).toThrow(/multiple websites/); }); + + it("resolves a domain name matching ctx.websiteDomain to the context websiteId", () => { + // Insights agent passes domain names (e.g. "databuddy.cc") because LLM confuses + // domain with websiteId after seeing domain-keyed SQL queries. + const ctx = makeCtx({ + websiteId: "internal-id-xyz", + websiteDomain: "databuddy.cc", + }); + + expect(resolveToolWebsite(ctx, "databuddy.cc")).toEqual({ + websiteId: "internal-id-xyz", + domain: "databuddy.cc", + }); + }); + + it("resolves a domain name matching an accessible website's domain", () => { + const ctx = makeCtx({ + accessibleWebsites: [ + { id: "web_a", domain: "a.com", name: null, isPublic: null, createdAt: null }, + { id: "web_b", domain: "b.com", name: null, isPublic: null, createdAt: null }, + ], + }); + + expect(resolveToolWebsite(ctx, "b.com")).toEqual({ + websiteId: "web_b", + domain: "b.com", + }); + }); + + it("still rejects a domain that is not ctx.websiteDomain and not in accessible list", () => { + const ctx = makeCtx({ + websiteId: "internal-id-xyz", + websiteDomain: "databuddy.cc", + accessibleWebsites: [ + { id: "web_a", domain: "a.com", name: null, isPublic: null, createdAt: null }, + ], + }); + + expect(() => resolveToolWebsite(ctx, "unknown.com")).toThrow( + /not in this workspace/ + ); + }); }); diff --git a/packages/ai/src/ai/tools/utils/context.ts b/packages/ai/src/ai/tools/utils/context.ts index 48b155ecb..767e1f271 100644 --- a/packages/ai/src/ai/tools/utils/context.ts +++ b/packages/ai/src/ai/tools/utils/context.ts @@ -27,15 +27,29 @@ export function resolveToolWebsite( (id === ctx.websiteId ? ctx.websiteDomain : undefined); if (inputWebsiteId) { - const isAccessible = + // Try exact ID match + const isAccessibleById = accessible.some((w) => w.id === inputWebsiteId) || inputWebsiteId === ctx.websiteId; - if (!isAccessible) { - throw new Error( - `Website "${inputWebsiteId}" is not in this workspace. Call list_websites to see available websites.` - ); + if (isAccessibleById) { + return { websiteId: inputWebsiteId, domain: domainFor(inputWebsiteId) }; } - return { websiteId: inputWebsiteId, domain: domainFor(inputWebsiteId) }; + + // Try domain name match (model may pass a domain instead of an ID) + const byDomain = accessible.find((w) => w.domain === inputWebsiteId); + if (byDomain) { + return { websiteId: byDomain.id, domain: byDomain.domain ?? undefined }; + } + if (ctx.websiteDomain === inputWebsiteId) { + const fallbackId = ctx.defaultWebsiteId ?? ctx.websiteId; + if (fallbackId) { + return { websiteId: fallbackId, domain: ctx.websiteDomain }; + } + } + + throw new Error( + `Website "${inputWebsiteId}" is not in this workspace. Call list_websites to see available websites.` + ); } const fallbackId = ctx.defaultWebsiteId ?? ctx.websiteId;