diff --git a/.agents/skills/billing/SKILL.md b/.agents/skills/billing/SKILL.md new file mode 100644 index 0000000000..56113a0a47 --- /dev/null +++ b/.agents/skills/billing/SKILL.md @@ -0,0 +1,59 @@ +--- +name: billing +description: Use when changing Comp AI billing, Stripe products/prices, subscription checkout, org payment methods, entitlements, usage ledgers, invoices, or billing webhooks. +metadata: + short-description: Comp AI billing architecture +--- + +# Billing + +Comp AI billing is SKU-first and subscription-ready. Stripe is the payment provider; Comp AI owns catalog definitions, entitlement state, usage gating, and audit history. + +## Core Rules + +- Use `@trycompai/billing` for SKU keys, amounts, Stripe product IDs, Stripe price IDs, cadence, and included usage. +- Do not add product-specific nullable fields to `OrganizationBilling`. +- Keep org-level billing generic: `stripeCustomerId`, `stripePaymentMethodId`, and `paymentMethodUpdatedAt`. +- Store per-product state in generic per-SKU tables keyed by `skuKey`. +- Gate paid actions from local entitlement/usage state, not directly from Stripe object reads. +- Treat Stripe webhooks as eventually consistent, retryable, and possibly out of order. +- Make webhook and entitlement handling idempotent with Stripe event IDs, invoice IDs, subscription item IDs, and period start/end. +- Archive accidental live/test Stripe objects instead of reusing them blindly. + +## Current SKU Shape + +- Active background-check subscriptions: `background_checks_monthly_3` ($79/mo, 3 checks), `background_checks_monthly_10` ($199/mo, 10 checks), and `background_checks_monthly_20` ($399/mo, 20 checks). +- Active penetration-test subscriptions: `pentest_monthly_1` ($299/mo, 1 scan), `pentest_monthly_3` ($499/mo, 3 scans), and `pentest_monthly_5_current` ($899/mo, 5 scans). +- Deprecated / legacy SKUs remain in the catalog for historical Stripe records: `background_check_one_time`, `background_checks_monthly_25`, `pentest_monthly_4`, `pentest_monthly_5`, and `pentest_monthly_10`. + +Live catalog entries should only be added after deliberate live Stripe object creation. + +## Implementation Pattern + +1. Add or update SKU definitions and Stripe IDs in `packages/billing/src/index.ts` and `packages/billing/src/sku-definitions.ts`. +2. Store Stripe IDs in the catalog by environment rather than adding new env vars. +3. Create subscriptions through Stripe Checkout `mode: subscription`. +4. Use the shared Stripe customer default payment method unless the product explicitly needs overrides. +5. For allowance products, sync local subscription state from Stripe subscription items and period timestamps. +6. Record usage in the billing usage ledger with a stable idempotency key. +7. Record support-relevant mutations in billing audit events. + +## Webhooks + +Handle these events before launching subscription access: + +- `checkout.session.completed` +- `invoice.paid` +- `invoice.payment_failed` +- `invoice.payment_action_required` +- `customer.subscription.updated` +- `customer.subscription.deleted` + +Provision or renew allowance only for the matching subscription item and SKU. Do not use generic invoice-level periods for multi-item subscriptions. + +## Validation + +- Run `bun run db:generate` after Prisma schema changes. +- Run `bun run check:prisma-schemas` to catch stale copied schema fragments. +- Run catalog tests after SKU changes. +- Add webhook tests for duplicate events, out-of-order delivery, payment failure, action required, cancellation, and renewal. diff --git a/apps/api/Dockerfile.multistage b/apps/api/Dockerfile.multistage index 1830d7c426..85d4b07cd7 100644 --- a/apps/api/Dockerfile.multistage +++ b/apps/api/Dockerfile.multistage @@ -25,6 +25,7 @@ COPY packages/integration-platform/package.json ./packages/integration-platform/ COPY packages/tsconfig/package.json ./packages/tsconfig/ COPY packages/email/package.json ./packages/email/ COPY packages/company/package.json ./packages/company/ +COPY packages/billing/package.json ./packages/billing/ # Copy API package.json COPY apps/api/package.json ./apps/api/ @@ -55,6 +56,7 @@ COPY packages/integration-platform ./packages/integration-platform COPY packages/tsconfig ./packages/tsconfig COPY packages/email ./packages/email COPY packages/company ./packages/company +COPY packages/billing ./packages/billing # Copy API source COPY apps/api ./apps/api @@ -66,7 +68,8 @@ RUN cd packages/db && bun run build RUN cd packages/auth && bun run build \ && cd ../integration-platform && bun run build \ && cd ../email && bun run build \ - && cd ../company && bun run build + && cd ../company && bun run build \ + && cd ../billing && bun run build # Copy model files to api schema dir, then build NestJS app # Note: @prisma/client is already generated by packages/db build (generate-prisma-client-js.js) @@ -111,6 +114,7 @@ COPY --from=builder --chown=nestjs:nestjs /app/packages/integration-platform ./p COPY --from=builder --chown=nestjs:nestjs /app/packages/tsconfig ./packages/tsconfig COPY --from=builder --chown=nestjs:nestjs /app/packages/email ./packages/email COPY --from=builder --chown=nestjs:nestjs /app/packages/company ./packages/company +COPY --from=builder --chown=nestjs:nestjs /app/packages/billing ./packages/billing # Copy production node_modules (includes Prisma client already generated for linux/amd64) COPY --from=builder --chown=nestjs:nestjs /app/node_modules ./node_modules diff --git a/apps/api/buildspec.yml b/apps/api/buildspec.yml index 26b9091ed0..2ebf7b13c1 100644 --- a/apps/api/buildspec.yml +++ b/apps/api/buildspec.yml @@ -41,6 +41,7 @@ phases: - cd packages/db && bun run build && cd ../.. - cd packages/integration-platform && bun run build && cd ../.. - cd packages/company && bun run build && cd ../.. + - cd packages/billing && bun run build && cd ../.. - echo "Building NestJS application..." - echo "APP_NAME is set to $APP_NAME" @@ -81,11 +82,13 @@ phases: - rm -rf ../docker-build/node_modules/@trycompai/integration-platform - rm -rf ../docker-build/node_modules/@trycompai/auth - rm -rf ../docker-build/node_modules/@trycompai/company + - rm -rf ../docker-build/node_modules/@trycompai/billing - mkdir -p ../docker-build/node_modules/@trycompai/utils - mkdir -p ../docker-build/node_modules/@trycompai/db - mkdir -p ../docker-build/node_modules/@trycompai/integration-platform - mkdir -p ../docker-build/node_modules/@trycompai/auth - mkdir -p ../docker-build/node_modules/@trycompai/company + - mkdir -p ../docker-build/node_modules/@trycompai/billing - cp -r ../../packages/utils/src ../docker-build/node_modules/@trycompai/utils/ - cp ../../packages/utils/package.json ../docker-build/node_modules/@trycompai/utils/ - cp -r ../../packages/db/dist ../docker-build/node_modules/@trycompai/db/ @@ -96,10 +99,12 @@ phases: - cp ../../packages/auth/package.json ../docker-build/node_modules/@trycompai/auth/ - cp -r ../../packages/company/dist ../docker-build/node_modules/@trycompai/company/ - cp ../../packages/company/package.json ../docker-build/node_modules/@trycompai/company/ + - cp -r ../../packages/billing/dist ../docker-build/node_modules/@trycompai/billing/ + - cp ../../packages/billing/package.json ../docker-build/node_modules/@trycompai/billing/ - cp Dockerfile ../docker-build/ # Remove workspace dependencies from package.json (they're copied manually above) - - cat package.json | jq 'del(.dependencies["@trycompai/integration-platform"]) | del(.dependencies["@trycompai/auth"]) | del(.dependencies["@trycompai/company"])' > ../docker-build/package.json + - cat package.json | jq 'del(.dependencies["@trycompai/integration-platform"]) | del(.dependencies["@trycompai/auth"]) | del(.dependencies["@trycompai/company"]) | del(.dependencies["@trycompai/billing"])' > ../docker-build/package.json - cp ../../bun.lock ../docker-build/ || true - echo "Building Docker image..." diff --git a/apps/api/package.json b/apps/api/package.json index b36942bebc..8e9a3cb9ea 100644 --- a/apps/api/package.json +++ b/apps/api/package.json @@ -78,6 +78,7 @@ "@trigger.dev/build": "4.4.3", "@trigger.dev/sdk": "4.4.3", "@trycompai/auth": "workspace:*", + "@trycompai/billing": "workspace:*", "@trycompai/company": "workspace:*", "@trycompai/db": "workspace:*", "@trycompai/email": "workspace:*", @@ -169,7 +170,9 @@ "moduleNameMapper": { "^@db$": "/../prisma/index", "^@/(.*)$": "/$1", + "^\\./sku-definitions\\.js$": "/../../../packages/billing/src/sku-definitions.ts", "^@trycompai/auth$": "/../../../packages/auth/src/index.ts", + "^@trycompai/billing$": "/../../../packages/billing/src/index.ts", "^@trycompai/company$": "/../../../packages/company/src/index.ts", "^@trycompai/db$": "@prisma/client", "^@trycompai/email$": "/../../../packages/email/index.ts", @@ -183,7 +186,7 @@ "build": "nest build", "build:docker": "bunx prisma generate --schema=prisma/schema && nest build", "db:generate": "bun run db:getschema && bunx prisma generate --schema=prisma/schema", - "db:getschema": "find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\;", + "db:getschema": "find prisma/schema -name '*.prisma' ! -name 'schema.prisma' -delete && find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\;", "db:migrate": "cd ../../packages/db && bunx prisma migrate dev && cd ../../apps/api", "deploy:trigger-prod": "npx trigger.dev@4.4.3 deploy", "dev": "bunx concurrently --kill-others --names \"nest,trigger\" --prefix-colors \"green,blue\" \"nest start --watch\" \"trigger dev\"", diff --git a/apps/api/src/admin-organizations/admin-audit-log-context.spec.ts b/apps/api/src/admin-organizations/admin-audit-log-context.spec.ts new file mode 100644 index 0000000000..2fbd6bcc13 --- /dev/null +++ b/apps/api/src/admin-organizations/admin-audit-log-context.spec.ts @@ -0,0 +1,81 @@ +import { of } from 'rxjs'; +import { AdminAuditLogInterceptor } from './admin-audit-log.interceptor'; + +const mockCreate = jest.fn().mockResolvedValue({}); +const mockContextFind = jest.fn(); + +jest.mock('@db', () => ({ + AuditLogEntityType: { + organization: 'organization', + finding: 'finding', + policy: 'policy', + task: 'task', + vendor: 'vendor', + }, + Prisma: {}, + db: { + auditLog: { + get create() { + return mockCreate; + }, + }, + context: { + get findFirst() { + return mockContextFind; + }, + }, + }, +})); + +jest.mock('../audit/audit-log.constants', () => ({ + MUTATION_METHODS: new Set(['POST', 'PATCH', 'PUT', 'DELETE']), + SENSITIVE_KEYS: new Set(), +})); + +describe('AdminAuditLogInterceptor context parsing', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockContextFind.mockResolvedValue({ + question: 'Which subprocessors are used for production data?', + }); + }); + + it('keeps context entity ids instead of treating them as org-level actions', (done) => { + const request = { + method: 'PATCH', + url: '/v1/admin/organizations/org_1/context/ctx_1', + params: { orgId: 'org_1' }, + body: { answer: 'Updated answer' }, + userId: 'usr_admin', + }; + const context = { + getHandler: () => context, + switchToHttp: () => ({ getRequest: () => request }), + } as unknown as Parameters[0]; + const interceptor = new AdminAuditLogInterceptor({ + get: jest.fn().mockReturnValue(false), + } as never); + + interceptor + .intercept(context, { handle: () => of({ ok: true }) }) + .subscribe({ + complete: () => { + setTimeout(() => { + expect(mockContextFind).toHaveBeenCalledWith({ + where: { id: 'ctx_1', organizationId: 'org_1' }, + select: { question: true }, + }); + expect(mockCreate).toHaveBeenCalledWith({ + data: expect.objectContaining({ + entityType: 'organization', + entityId: 'ctx_1', + description: + "Updated context 'Which subprocessors are used for production data?'", + }), + }); + done(); + }, 50); + }, + }); + }); +}); diff --git a/apps/api/src/admin-organizations/admin-audit-log.interceptor.spec.ts b/apps/api/src/admin-organizations/admin-audit-log.interceptor.spec.ts index a8b310b2e7..ccdca4374c 100644 --- a/apps/api/src/admin-organizations/admin-audit-log.interceptor.spec.ts +++ b/apps/api/src/admin-organizations/admin-audit-log.interceptor.spec.ts @@ -28,7 +28,7 @@ jest.mock('@db', () => ({ return mockPolicyFind; }, }, - taskItem: { + task: { get findFirst() { return mockTaskFind; }, @@ -72,6 +72,7 @@ function buildContext(overrides: { }; return { + getHandler: () => buildContext, switchToHttp: () => ({ getRequest: () => request }), } as unknown as Parameters[0]; } @@ -310,4 +311,29 @@ describe('AdminAuditLogInterceptor', () => { }, }); }); + + it('should audit billing mutations against the organization id', (done) => { + const ctx = buildContext({ + method: 'PUT', + url: '/v1/admin/organizations/org_1/billing/preferences', + params: { orgId: 'org_1' }, + body: { billingEmail: 'accounts@example.com' }, + }); + + interceptor.intercept(ctx, nextHandler).subscribe({ + complete: () => { + setTimeout(() => { + expect(mockCreate).toHaveBeenCalledWith({ + data: expect.objectContaining({ + organizationId: 'org_1', + entityType: 'organization', + entityId: 'org_1', + description: 'Updated billing', + }), + }); + done(); + }, 50); + }, + }); + }); }); diff --git a/apps/api/src/admin-organizations/admin-audit-log.interceptor.ts b/apps/api/src/admin-organizations/admin-audit-log.interceptor.ts index 121f1caf76..9db723b935 100644 --- a/apps/api/src/admin-organizations/admin-audit-log.interceptor.ts +++ b/apps/api/src/admin-organizations/admin-audit-log.interceptor.ts @@ -24,6 +24,10 @@ const SEGMENT_TO_RESOURCE: Record< entity: AuditLogEntityType.pentest, singular: 'pentest credits', }, + billing: { + entity: AuditLogEntityType.organization, + singular: 'billing', + }, }; const SPECIAL_ACTION_DESCRIPTIONS: Record = { @@ -153,6 +157,14 @@ export class AdminAuditLogInterceptor implements NestInterceptor { } const mapped = SEGMENT_TO_RESOURCE[resourceSegment]; + if (resourceSegment === 'billing' && mapped) { + return { + resource: mapped.singular, + entityType: mapped.entity, + entityId: orgId, + actionSegment: possibleEntityId ?? null, + }; + } return { resource: mapped?.singular ?? resourceSegment, diff --git a/apps/api/src/admin-organizations/admin-billing-actions.service.spec.ts b/apps/api/src/admin-organizations/admin-billing-actions.service.spec.ts new file mode 100644 index 0000000000..af5f54ce02 --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing-actions.service.spec.ts @@ -0,0 +1,91 @@ +import { HttpStatus, NotFoundException } from '@nestjs/common'; +import { db } from '@db'; +import { AdminBillingActionsService } from './admin-billing-actions.service'; + +jest.mock('@db', () => ({ + db: { + organization: { findUnique: jest.fn() }, + organizationBilling: { findUnique: jest.fn() }, + organizationBillingSubscription: { + findMany: jest.fn(), + findFirst: jest.fn(), + updateMany: jest.fn(), + }, + billingAuditEvent: { create: jest.fn() }, + }, +})); + +const mockedDb = db as unknown as { + organization: { findUnique: jest.Mock }; + organizationBilling: { findUnique: jest.Mock }; + organizationBillingSubscription: { findMany: jest.Mock }; +}; + +describe('AdminBillingActionsService', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockedDb.organization.findUnique.mockResolvedValue({ + id: 'org_1', + name: 'Customer', + }); + mockedDb.organizationBilling.findUnique.mockResolvedValue({ + organizationId: 'org_1', + stripeCustomerId: 'cus_org_1', + }); + mockedDb.organizationBillingSubscription.findMany.mockResolvedValue([]); + }); + + it('rejects invoice recovery links for invoices owned by another customer', async () => { + const service = new AdminBillingActionsService( + { + getClient: () => ({ + invoices: { + retrieve: jest.fn().mockResolvedValue({ + id: 'in_other', + customer: 'cus_other', + status: 'open', + }), + }, + }), + isConfigured: () => true, + } as never, + {} as never, + {} as never, + {} as never, + ); + + await expect( + service.getInvoiceRetryLink({ + organizationId: 'org_1', + adminUserId: 'usr_admin', + invoiceId: 'in_other', + note: 'customer asked', + }), + ).rejects.toBeInstanceOf(NotFoundException); + }); + + it('rejects invoice recovery links when Stripe billing is not configured', async () => { + const getClient = jest.fn(); + const service = new AdminBillingActionsService( + { + getClient, + isConfigured: () => false, + } as never, + {} as never, + {} as never, + {} as never, + ); + + await expect( + service.getInvoiceRetryLink({ + organizationId: 'org_1', + adminUserId: 'usr_admin', + invoiceId: 'in_open', + note: 'retry payment', + }), + ).rejects.toMatchObject({ + status: HttpStatus.PAYMENT_REQUIRED, + }); + expect(getClient).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/api/src/admin-organizations/admin-billing-actions.service.ts b/apps/api/src/admin-organizations/admin-billing-actions.service.ts new file mode 100644 index 0000000000..2b5af7979a --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing-actions.service.ts @@ -0,0 +1,193 @@ +import { + BadRequestException, + Injectable, + NotFoundException, +} from '@nestjs/common'; +import { db } from '@db'; +import { BillingCreditsService } from '../billing/billing-credits.service'; +import { BillingService } from '../billing/billing.service'; +import { validateBillingRedirectUrl } from '../billing/billing-redirect-urls'; +import { assertStripeBillingConfigured } from '../billing/billing-stripe-config'; +import { StripeService } from '../stripe/stripe.service'; +import { getInvoiceCustomerId } from './admin-billing.helpers'; +import { getOrgBillingContext, writeBillingAudit } from './admin-billing.data'; +import { AdminBillingService } from './admin-billing.service'; +import type { AdminBillingStatus } from './admin-billing.types'; + +@Injectable() +export class AdminBillingActionsService { + constructor( + private readonly stripeService: StripeService, + private readonly billingService: BillingService, + private readonly credits: BillingCreditsService, + private readonly adminBilling: AdminBillingService, + ) {} + + async cancelSubscription(params: { + organizationId: string; + adminUserId: string; + subscriptionId: string; + mode: 'period_end' | 'immediate'; + note: string; + confirm?: string; + }): Promise { + assertStripeBillingConfigured(this.stripeService); + const subscription = await this.getScopedSubscription(params); + if (params.mode === 'immediate' && params.confirm !== 'cancel now') { + throw new BadRequestException('Type "cancel now" to confirm.'); + } + const stripe = this.stripeService.getClient(); + const updated = + params.mode === 'immediate' + ? await stripe.subscriptions.cancel(subscription.stripeSubscriptionId, { + cancellation_details: { comment: params.note }, + }) + : await stripe.subscriptions.update(subscription.stripeSubscriptionId, { + cancel_at_period_end: true, + cancellation_details: { comment: params.note }, + }); + await db.organizationBillingSubscription.updateMany({ + where: { id: subscription.id, organizationId: params.organizationId }, + data: { + stripeStatus: updated.status, + cancelAtPeriodEnd: updated.cancel_at_period_end, + canceledAt: updated.canceled_at + ? new Date(updated.canceled_at * 1000) + : null, + }, + }); + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_subscription_canceled', + skuKey: subscription.skuKey, + metadata: { + adminUserId: params.adminUserId, + subscriptionId: subscription.id, + mode: params.mode, + note: params.note, + }, + }); + return this.adminBilling.getStatus(params.organizationId); + } + + async resumeSubscription(params: { + organizationId: string; + adminUserId: string; + subscriptionId: string; + note: string; + }): Promise { + assertStripeBillingConfigured(this.stripeService); + const subscription = await this.getScopedSubscription(params); + const updated = await this.stripeService + .getClient() + .subscriptions.update(subscription.stripeSubscriptionId, { + cancel_at_period_end: false, + }); + await db.organizationBillingSubscription.updateMany({ + where: { id: subscription.id, organizationId: params.organizationId }, + data: { + stripeStatus: updated.status, + cancelAtPeriodEnd: updated.cancel_at_period_end, + canceledAt: null, + }, + }); + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_subscription_resumed', + skuKey: subscription.skuKey, + metadata: { + adminUserId: params.adminUserId, + subscriptionId: subscription.id, + note: params.note, + }, + }); + return this.adminBilling.getStatus(params.organizationId); + } + + async createPaymentLink(params: { + organizationId: string; + successUrl: string; + cancelUrl: string; + }) { + validateBillingRedirectUrl(params.successUrl); + validateBillingRedirectUrl(params.cancelUrl); + return this.billingService.createSetupSession(params); + } + + async grantCredits(params: { + organizationId: string; + adminUserId: string; + productKey: 'pentest' | 'background_check'; + quantity: number; + note: string; + confirm?: string; + }): Promise { + if (params.quantity >= 25 && params.confirm !== 'grant credits') { + throw new BadRequestException('Type "grant credits" to confirm.'); + } + await this.credits.grant({ + organizationId: params.organizationId, + productKey: params.productKey, + quantity: params.quantity, + source: 'manual', + note: params.note, + adminUserId: params.adminUserId, + }); + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_credits_granted', + metadata: { + adminUserId: params.adminUserId, + productKey: params.productKey, + quantity: params.quantity, + note: params.note, + }, + }); + return this.adminBilling.getStatus(params.organizationId); + } + + async getInvoiceRetryLink(params: { + organizationId: string; + adminUserId: string; + invoiceId: string; + note: string; + }) { + assertStripeBillingConfigured(this.stripeService); + const { billing } = await getOrgBillingContext(params.organizationId); + if (!billing) throw new NotFoundException('Billing customer not found.'); + const invoice = await this.stripeService + .getClient() + .invoices.retrieve(params.invoiceId); + if (getInvoiceCustomerId(invoice) !== billing.stripeCustomerId) { + throw new NotFoundException('Invoice not found.'); + } + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_invoice_retry_link_created', + metadata: { + adminUserId: params.adminUserId, + invoiceId: params.invoiceId, + note: params.note, + }, + }); + return { + hostedInvoiceUrl: invoice.hosted_invoice_url ?? null, + invoicePdfUrl: invoice.invoice_pdf ?? null, + status: invoice.status ?? 'unknown', + }; + } + + private async getScopedSubscription(params: { + organizationId: string; + subscriptionId: string; + }) { + const subscription = await db.organizationBillingSubscription.findFirst({ + where: { + id: params.subscriptionId, + organizationId: params.organizationId, + }, + }); + if (!subscription) throw new NotFoundException('Subscription not found.'); + return subscription; + } +} diff --git a/apps/api/src/admin-organizations/admin-billing.controller.ts b/apps/api/src/admin-organizations/admin-billing.controller.ts new file mode 100644 index 0000000000..8af9805632 --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.controller.ts @@ -0,0 +1,208 @@ +import { + Body, + Controller, + Get, + Param, + Post, + Put, + Req, + UseGuards, + UseInterceptors, + UsePipes, + ValidationPipe, +} from '@nestjs/common'; +import { ApiExcludeController, ApiOperation, ApiTags } from '@nestjs/swagger'; +import { Throttle } from '@nestjs/throttler'; +import { + subscriptionBillingSkuKeys, + type BillingSkuKey, +} from '@trycompai/billing'; +import { PlatformAdminGuard } from '../auth/platform-admin.guard'; +import { AdminAuditLogInterceptor } from './admin-audit-log.interceptor'; +import { AdminBillingActionsService } from './admin-billing-actions.service'; +import { AdminBillingService } from './admin-billing.service'; +import { + AdminBillingCancelSubscriptionDto, + AdminBillingGrantCreditsDto, + AdminBillingInvoiceActionDto, + AdminBillingNoteDto, + AdminBillingPaymentLinkDto, + AdminBillingPreferencesDto, + AdminBillingSubscriptionPreviewDto, + AdminBillingSubscriptionDto, +} from './dto/admin-billing.dto'; + +interface AdminRequest { + userId: string; +} + +@ApiExcludeController() +@ApiTags('Admin - Billing') +@Controller({ path: 'admin/organizations/:orgId/billing', version: '1' }) +@UseGuards(PlatformAdminGuard) +@UseInterceptors(AdminAuditLogInterceptor) +@UsePipes( + new ValidationPipe({ + whitelist: true, + forbidNonWhitelisted: true, + transform: true, + }), +) +@Throttle({ default: { ttl: 60_000, limit: 30 } }) +export class AdminBillingController { + constructor( + private readonly billing: AdminBillingService, + private readonly actions: AdminBillingActionsService, + ) {} + + @Get() + @ApiOperation({ summary: 'Get admin billing status for an organization' }) + async getStatus(@Param('orgId') orgId: string) { + return this.billing.getStatus(orgId); + } + + @Put('preferences') + @ApiOperation({ summary: 'Update billing preferences for an organization' }) + async updatePreferences( + @Param('orgId') orgId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingPreferencesDto, + ) { + return this.billing.updatePreferences({ + organizationId: orgId, + adminUserId: req.userId, + note: body.note, + confirmBillingEmailChange: body.confirmBillingEmailChange, + preferences: { + companyName: body.companyName, + billingEmail: body.billingEmail, + purchaseOrder: body.purchaseOrder ?? null, + address: { + line1: body.addressLine1 ?? null, + line2: body.addressLine2 ?? null, + city: body.addressCity ?? null, + state: body.addressState ?? null, + postalCode: body.addressPostalCode ?? null, + country: body.addressCountry ?? null, + }, + taxId: { + type: body.taxIdType ?? null, + value: body.taxIdValue ?? null, + }, + }, + }); + } + + @Post('subscriptions/preview') + @ApiOperation({ summary: 'Preview a subscription plan change' }) + async previewSubscription( + @Param('orgId') orgId: string, + @Body() body: AdminBillingSubscriptionPreviewDto, + ) { + return this.billing.previewSubscription({ + organizationId: orgId, + skuKey: assertSubscriptionSku(body.skuKey), + }); + } + + @Post('subscriptions') + @ApiOperation({ summary: 'Create or change a subscription plan' }) + async setSubscription( + @Param('orgId') orgId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingSubscriptionDto, + ) { + return this.billing.setSubscription({ + organizationId: orgId, + adminUserId: req.userId, + skuKey: assertSubscriptionSku(body.skuKey), + returnUrl: body.returnUrl, + note: body.note, + confirmDowngrade: body.confirmDowngrade, + }); + } + + @Post('subscriptions/:subscriptionId/cancel') + @Throttle({ default: { ttl: 60_000, limit: 5 } }) + async cancelSubscription( + @Param('orgId') orgId: string, + @Param('subscriptionId') subscriptionId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingCancelSubscriptionDto, + ) { + return this.actions.cancelSubscription({ + organizationId: orgId, + adminUserId: req.userId, + subscriptionId, + mode: body.mode, + note: body.note, + confirm: body.confirm, + }); + } + + @Post('subscriptions/:subscriptionId/resume') + async resumeSubscription( + @Param('orgId') orgId: string, + @Param('subscriptionId') subscriptionId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingNoteDto, + ) { + return this.actions.resumeSubscription({ + organizationId: orgId, + adminUserId: req.userId, + subscriptionId, + note: body.note, + }); + } + + @Post('payment-link') + async createPaymentLink( + @Param('orgId') orgId: string, + @Body() body: AdminBillingPaymentLinkDto, + ) { + return this.actions.createPaymentLink({ + organizationId: orgId, + successUrl: body.successUrl, + cancelUrl: body.cancelUrl, + }); + } + + @Post('credits') + @Throttle({ default: { ttl: 60_000, limit: 5 } }) + async grantCredits( + @Param('orgId') orgId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingGrantCreditsDto, + ) { + return this.actions.grantCredits({ + organizationId: orgId, + adminUserId: req.userId, + productKey: body.productKey, + quantity: body.quantity, + note: body.note, + confirm: body.confirm, + }); + } + + @Post('invoices/:invoiceId/retry-link') + async getInvoiceRetryLink( + @Param('orgId') orgId: string, + @Param('invoiceId') invoiceId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingInvoiceActionDto, + ) { + return this.actions.getInvoiceRetryLink({ + organizationId: orgId, + adminUserId: req.userId, + invoiceId, + note: body.note, + }); + } +} + +function assertSubscriptionSku(value: string): BillingSkuKey { + if (subscriptionBillingSkuKeys.some((skuKey) => skuKey === value)) { + return value as BillingSkuKey; + } + throw new Error('Invalid subscription SKU.'); +} diff --git a/apps/api/src/admin-organizations/admin-billing.data.spec.ts b/apps/api/src/admin-organizations/admin-billing.data.spec.ts new file mode 100644 index 0000000000..95a623f315 --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.data.spec.ts @@ -0,0 +1,66 @@ +import { createAdminSubscription } from './admin-billing.data'; + +jest.mock('@db', () => ({ + db: {}, + Prisma: {}, +})); + +describe('createAdminSubscription', () => { + it('uses a stable Stripe idempotency key for subscription create retries', async () => { + const subscription = { + id: 'sub_1', + status: 'active', + cancel_at_period_end: false, + canceled_at: null, + items: { + data: [ + { + id: 'si_1', + current_period_start: 1777564800, + current_period_end: 1780243200, + }, + ], + }, + }; + const subscriptionsCreate = jest.fn().mockResolvedValue(subscription); + const entitlements = { syncSubscriptionItem: jest.fn() }; + + await createAdminSubscription({ + organizationId: 'org_1', + stripeCustomerId: 'cus_1', + skuKey: 'pentest_monthly_1', + stripePriceId: 'price_1', + includedQuantity: 1, + idempotencyKey: + 'admin-subscription-create:org_1:pentest_monthly_1:cus_1:none', + stripeService: { + getClient: () => ({ + subscriptions: { create: subscriptionsCreate }, + }), + } as never, + entitlements: entitlements as never, + }); + await createAdminSubscription({ + organizationId: 'org_1', + stripeCustomerId: 'cus_1', + skuKey: 'pentest_monthly_1', + stripePriceId: 'price_1', + includedQuantity: 1, + idempotencyKey: + 'admin-subscription-create:org_1:pentest_monthly_1:cus_1:none', + stripeService: { + getClient: () => ({ + subscriptions: { create: subscriptionsCreate }, + }), + } as never, + entitlements: entitlements as never, + }); + + const firstKey = subscriptionsCreate.mock.calls[0][1].idempotencyKey; + const secondKey = subscriptionsCreate.mock.calls[1][1].idempotencyKey; + expect(firstKey).toBe( + 'admin-subscription-create:org_1:pentest_monthly_1:cus_1:none', + ); + expect(secondKey).toBe(firstKey); + }); +}); diff --git a/apps/api/src/admin-organizations/admin-billing.data.ts b/apps/api/src/admin-organizations/admin-billing.data.ts new file mode 100644 index 0000000000..ac0c98d7fd --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.data.ts @@ -0,0 +1,112 @@ +import { NotFoundException } from '@nestjs/common'; +import { db, Prisma } from '@db'; +import type { BillingSkuKey } from '@trycompai/billing'; +import type Stripe from 'stripe'; +import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; +import { StripeService } from '../stripe/stripe.service'; +import { dateFromSeconds, readNumber } from './admin-billing.helpers'; + +export async function getOrgBillingContext(organizationId: string) { + const [organization, billing, subscriptions] = await Promise.all([ + db.organization.findUnique({ + where: { id: organizationId }, + select: { id: true, name: true }, + }), + db.organizationBilling.findUnique({ where: { organizationId } }), + db.organizationBillingSubscription.findMany({ + where: { organizationId }, + orderBy: [{ createdAt: 'desc' }], + }), + ]); + if (!organization) { + throw new NotFoundException(`Organization ${organizationId} not found`); + } + return { organization, billing, subscriptions }; +} + +export async function writeBillingAudit(params: { + organizationId: string; + eventType: string; + skuKey?: string | null; + stripeEventId?: string | null; + metadata?: Prisma.InputJsonValue; +}) { + await db.billingAuditEvent.create({ + data: { + organizationId: params.organizationId, + eventType: params.eventType, + skuKey: params.skuKey, + stripeEventId: params.stripeEventId, + metadata: params.metadata, + }, + }); +} + +export async function createAdminSubscription(params: { + organizationId: string; + stripeCustomerId: string; + skuKey: BillingSkuKey; + stripePriceId: string; + includedQuantity: number; + idempotencyKey: string; + stripeService: StripeService; + entitlements: BillingEntitlementsService; +}) { + const subscription = await params.stripeService + .getClient() + .subscriptions.create( + { + customer: params.stripeCustomerId, + items: [{ price: params.stripePriceId, quantity: 1 }], + metadata: { + organizationId: params.organizationId, + skuKey: params.skuKey, + source: 'admin-billing', + }, + expand: ['items.data.price'], + }, + { + idempotencyKey: params.idempotencyKey, + }, + ); + const item = subscription.items.data[0]; + if (!item) throw new NotFoundException('Stripe subscription item not found.'); + await syncStripeSubscriptionItem({ + organizationId: params.organizationId, + skuKey: params.skuKey, + stripePriceId: params.stripePriceId, + includedQuantity: params.includedQuantity, + subscription, + item, + entitlements: params.entitlements, + }); +} + +export async function syncStripeSubscriptionItem(params: { + organizationId: string; + skuKey: BillingSkuKey; + stripePriceId: string; + includedQuantity: number; + subscription: Stripe.Subscription; + item: Stripe.SubscriptionItem; + entitlements: BillingEntitlementsService; +}) { + await params.entitlements.syncSubscriptionItem({ + organizationId: params.organizationId, + skuKey: params.skuKey, + stripeSubscriptionId: params.subscription.id, + stripeSubscriptionItemId: params.item.id, + stripePriceId: params.stripePriceId, + stripeStatus: params.subscription.status, + currentPeriodStart: dateFromSeconds( + readNumber(params.item, 'current_period_start'), + ), + currentPeriodEnd: dateFromSeconds( + readNumber(params.item, 'current_period_end'), + ), + includedQuantity: params.includedQuantity, + cancelAtPeriodEnd: params.subscription.cancel_at_period_end, + canceledAt: dateFromSeconds(params.subscription.canceled_at), + stripeEventId: undefined, + }); +} diff --git a/apps/api/src/admin-organizations/admin-billing.helpers.ts b/apps/api/src/admin-organizations/admin-billing.helpers.ts new file mode 100644 index 0000000000..678066a80d --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.helpers.ts @@ -0,0 +1,93 @@ +import { + getBillingCatalog, + getBillingSku, + getBillingSkuProductKey, + type BillingProductKey, + type BillingSkuKey, +} from '@trycompai/billing'; +import type Stripe from 'stripe'; +import type { + AdminBillingPlan, + AdminBillingSubscription, +} from './admin-billing.types'; + +export function listAdminBillingPlans(): AdminBillingPlan[] { + return Object.values(getBillingCatalog().skus) + .filter((sku) => sku.cadence === 'month' && !sku.deprecated) + .map((sku) => ({ + skuKey: sku.key, + productKey: sku.productKey, + name: sku.name, + unitAmount: sku.unitAmount, + currency: sku.currency, + includedQuantity: sku.includedUsage?.quantity ?? 0, + })); +} + +export function mapAdminSubscription(subscription: { + id: string; + skuKey: string; + stripeSubscriptionId: string; + stripeSubscriptionItemId: string; + stripeStatus: string; + includedQuantity: number; + usedQuantity: number; + currentPeriodStart: Date | null; + currentPeriodEnd: Date | null; + cancelAtPeriodEnd: boolean; + canceledAt: Date | null; +}): AdminBillingSubscription { + return { + id: subscription.id, + skuKey: subscription.skuKey, + productKey: getBillingSkuProductKey(subscription.skuKey), + stripeSubscriptionId: subscription.stripeSubscriptionId, + stripeSubscriptionItemId: subscription.stripeSubscriptionItemId, + stripeStatus: subscription.stripeStatus, + includedQuantity: subscription.includedQuantity, + usedQuantity: subscription.usedQuantity, + remainingQuantity: Math.max( + subscription.includedQuantity - subscription.usedQuantity, + 0, + ), + currentPeriodStart: subscription.currentPeriodStart?.toISOString() ?? null, + currentPeriodEnd: subscription.currentPeriodEnd?.toISOString() ?? null, + cancelAtPeriodEnd: subscription.cancelAtPeriodEnd, + canceledAt: subscription.canceledAt?.toISOString() ?? null, + }; +} + +export function getProductFromSku(skuKey: string): BillingProductKey | null { + return getBillingSkuProductKey(skuKey); +} + +export function isDowngrade(params: { + currentIncludedQuantity: number; + nextSkuKey: BillingSkuKey; +}) { + const nextSku = getBillingSku({ skuKey: params.nextSkuKey }); + return ( + (nextSku.includedUsage?.quantity ?? 0) < params.currentIncludedQuantity + ); +} + +export function dateFromSeconds(value: number | null | undefined): Date | null { + return typeof value === 'number' ? new Date(value * 1000) : null; +} + +export function readNumber(value: unknown, key: string): number | null { + if (typeof value !== 'object' || value === null) return null; + const raw = (value as Record)[key]; + return typeof raw === 'number' ? raw : null; +} + +export function extractStripeId(value: unknown): string | null { + if (typeof value === 'string') return value; + if (typeof value !== 'object' || value === null) return null; + const raw = (value as Record).id; + return typeof raw === 'string' ? raw : null; +} + +export function getInvoiceCustomerId(invoice: Stripe.Invoice): string | null { + return extractStripeId(invoice.customer); +} diff --git a/apps/api/src/admin-organizations/admin-billing.service.spec.ts b/apps/api/src/admin-organizations/admin-billing.service.spec.ts new file mode 100644 index 0000000000..fb27da8481 --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.service.spec.ts @@ -0,0 +1,140 @@ +import { db } from '@db'; +import { AdminBillingService } from './admin-billing.service'; + +jest.mock('@db', () => ({ + db: { + organization: { findUnique: jest.fn() }, + organizationBilling: { findUnique: jest.fn() }, + organizationBillingSubscription: { findMany: jest.fn() }, + billingAuditEvent: { create: jest.fn() }, + }, +})); + +const mockedDb = db as unknown as { + organization: { findUnique: jest.Mock }; + organizationBilling: { findUnique: jest.Mock }; + organizationBillingSubscription: { findMany: jest.Mock }; + billingAuditEvent: { create: jest.Mock }; +}; + +describe('AdminBillingService', () => { + const originalStripeSecretKey = process.env.STRIPE_SECRET_KEY; + const restoreStripeSecretKey = () => { + if (typeof originalStripeSecretKey === 'string') { + process.env.STRIPE_SECRET_KEY = originalStripeSecretKey; + return; + } + delete process.env.STRIPE_SECRET_KEY; + }; + + beforeEach(() => { + jest.clearAllMocks(); + restoreStripeSecretKey(); + mockedDb.organization.findUnique.mockResolvedValue({ + id: 'org_1', + name: 'Customer', + }); + mockedDb.organizationBilling.findUnique.mockResolvedValue({ + organizationId: 'org_1', + stripeCustomerId: 'cus_org_1', + stripePaymentMethodId: 'pm_1', + }); + mockedDb.organizationBillingSubscription.findMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_1', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'active', + }, + ]); + mockedDb.billingAuditEvent.create.mockResolvedValue({}); + }); + + afterAll(() => { + restoreStripeSecretKey(); + }); + + it('previews subscription changes with the configured billing catalog environment', async () => { + process.env.STRIPE_SECRET_KEY = 'sk_live_preview_test'; + const createPreview = jest.fn().mockResolvedValue({ + amount_due: 49900, + currency: 'usd', + }); + const service = new AdminBillingService( + { + getClient: () => ({ + invoices: { createPreview }, + }), + isConfigured: () => true, + } as never, + {} as never, + {} as never, + {} as never, + ); + + await expect( + service.previewSubscription({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_3', + }), + ).resolves.toMatchObject({ + amountDue: 49900, + currency: 'usd', + subscriptionId: 'obs_1', + }); + + expect(createPreview).toHaveBeenCalledWith( + expect.objectContaining({ + subscription_details: expect.objectContaining({ + items: [ + expect.objectContaining({ + id: 'si_1', + price: 'price_1TS3zMCxqPDT5y0WC2OyJNAv', + quantity: 1, + }), + ], + }), + }), + ); + }); + + it('writes an audit event when admin checkout immediately changes an existing subscription', async () => { + const service = new AdminBillingService( + { isConfigured: () => true } as never, + { + createSubscriptionCheckoutSession: jest + .fn() + .mockResolvedValue({ changed: true }), + } as never, + {} as never, + {} as never, + ); + jest.spyOn(service, 'getStatus').mockResolvedValue({} as never); + mockedDb.organizationBilling.findUnique.mockResolvedValue({ + organizationId: 'org_1', + stripeCustomerId: 'cus_org_1', + stripePaymentMethodId: null, + }); + + await service.setSubscription({ + organizationId: 'org_1', + adminUserId: 'usr_admin', + skuKey: 'pentest_monthly_3', + returnUrl: 'http://localhost:3000/org_1/settings/billing', + note: 'Upgrade for customer', + }); + + expect(mockedDb.billingAuditEvent.create).toHaveBeenCalledWith({ + data: expect.objectContaining({ + organizationId: 'org_1', + eventType: 'admin_subscription_set', + skuKey: 'pentest_monthly_3', + metadata: { + adminUserId: 'usr_admin', + note: 'Upgrade for customer', + }, + }), + }); + }); +}); diff --git a/apps/api/src/admin-organizations/admin-billing.service.ts b/apps/api/src/admin-organizations/admin-billing.service.ts new file mode 100644 index 0000000000..ba680b41f3 --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.service.ts @@ -0,0 +1,278 @@ +import { + BadRequestException, + Injectable, + NotFoundException, +} from '@nestjs/common'; +import { db } from '@db'; +import { + getBillingSku, + resolveBillingCatalogEnvironment, + type BillingSkuKey, +} from '@trycompai/billing'; +import { BillingCreditsService } from '../billing/billing-credits.service'; +import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; +import { listBillingInvoices } from '../billing/billing-invoices'; +import { + getBillingPreferences, + type BillingPreferencesInput, + updateBillingPreferences, +} from '../billing/billing-preferences'; +import { validateBillingRedirectUrl } from '../billing/billing-redirect-urls'; +import { assertStripeBillingConfigured } from '../billing/billing-stripe-config'; +import { changeSubscriptionPlan } from '../billing/billing-subscription-plans'; +import { BillingService } from '../billing/billing.service'; +import { StripeService } from '../stripe/stripe.service'; +import { + getProductFromSku, + isDowngrade, + listAdminBillingPlans, + mapAdminSubscription, +} from './admin-billing.helpers'; +import { + createAdminSubscription, + getOrgBillingContext, + writeBillingAudit, +} from './admin-billing.data'; +import type { + AdminBillingPreview, + AdminBillingStatus, +} from './admin-billing.types'; + +@Injectable() +export class AdminBillingService { + constructor( + private readonly stripeService: StripeService, + private readonly billingService: BillingService, + private readonly credits: BillingCreditsService, + private readonly entitlements: BillingEntitlementsService, + ) {} + + async getStatus(organizationId: string): Promise { + const { organization, billing, subscriptions } = + await getOrgBillingContext(organizationId); + const [ + preferences, + invoices, + usageRows, + creditBalances, + creditEvents, + auditEvents, + ] = await Promise.all([ + getBillingPreferences({ + stripeService: this.stripeService, + stripeCustomerId: billing?.stripeCustomerId ?? null, + fallbackCompanyName: organization.name, + }), + listBillingInvoices({ + stripeService: this.stripeService, + stripeCustomerId: billing?.stripeCustomerId ?? null, + }), + this.billingService + .getStatus(organizationId) + .then((status) => status.usageRows), + this.credits.listBalances(organizationId), + this.credits.listEvents({ organizationId, take: 50 }), + db.billingAuditEvent.findMany({ + where: { organizationId }, + orderBy: { createdAt: 'desc' }, + take: 50, + }), + ]); + + return { + organization, + stripeCustomerId: billing?.stripeCustomerId ?? null, + hasPaymentMethod: !!billing?.stripePaymentMethodId, + paymentMethodUpdatedAt: + billing?.paymentMethodUpdatedAt?.toISOString() ?? null, + preferences, + availablePlans: listAdminBillingPlans(), + subscriptions: subscriptions.map(mapAdminSubscription), + creditBalances, + creditEvents, + usageRows, + invoices, + failedInvoices: invoices.filter((invoice) => + ['open', 'past_due', 'uncollectible'].includes(invoice.status), + ), + auditEvents: auditEvents.map((event) => ({ + id: event.id, + eventType: event.eventType, + skuKey: event.skuKey, + stripeEventId: event.stripeEventId, + metadata: event.metadata, + createdAt: event.createdAt.toISOString(), + })), + }; + } + + async updatePreferences(params: { + organizationId: string; + adminUserId: string; + preferences: BillingPreferencesInput; + note: string; + confirmBillingEmailChange?: boolean; + }): Promise { + assertStripeBillingConfigured(this.stripeService); + const status = await this.getStatus(params.organizationId); + const currentEmail = status.preferences.billingEmail; + if ( + currentEmail && + currentEmail !== params.preferences.billingEmail && + !params.confirmBillingEmailChange + ) { + throw new BadRequestException('Confirm billing email change.'); + } + + const result = await updateBillingPreferences({ + stripeService: this.stripeService, + organizationId: params.organizationId, + preferences: params.preferences, + }); + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_billing_preferences_updated', + metadata: { + adminUserId: params.adminUserId, + stripeCustomerId: result.stripeCustomerId, + billingEmail: result.preferences.billingEmail, + note: params.note, + }, + }); + return this.getStatus(params.organizationId); + } + + async previewSubscription(params: { + organizationId: string; + skuKey: BillingSkuKey; + }): Promise { + const { billing, subscriptions } = await getOrgBillingContext( + params.organizationId, + ); + if (!billing) throw new NotFoundException('Billing customer not found.'); + assertStripeBillingConfigured(this.stripeService); + const sku = getBillingSku({ + environment: resolveBillingCatalogEnvironment(), + skuKey: params.skuKey, + }); + const current = subscriptions.find( + (item) => + item.stripeStatus !== 'canceled' && + getProductFromSku(item.skuKey) === sku.productKey, + ); + const prorationDate = Math.floor(Date.now() / 1000); + const invoice = await this.stripeService + .getClient() + .invoices.createPreview({ + customer: billing.stripeCustomerId, + ...(current ? { subscription: current.stripeSubscriptionId } : {}), + subscription_details: { + proration_date: current ? prorationDate : undefined, + items: [ + current + ? { + id: current.stripeSubscriptionItemId, + price: sku.stripePriceId, + quantity: 1, + } + : { price: sku.stripePriceId, quantity: 1 }, + ], + }, + }); + return { + amountDue: invoice.amount_due, + currency: invoice.currency, + subscriptionId: current?.id ?? null, + prorationDate, + }; + } + + async setSubscription(params: { + organizationId: string; + adminUserId: string; + skuKey: BillingSkuKey; + returnUrl: string; + note: string; + confirmDowngrade?: boolean; + }): Promise { + validateBillingRedirectUrl(params.returnUrl); + assertStripeBillingConfigured(this.stripeService); + const { billing, subscriptions } = await getOrgBillingContext( + params.organizationId, + ); + const sku = getBillingSku({ + environment: resolveBillingCatalogEnvironment(), + skuKey: params.skuKey, + }); + const current = subscriptions.find( + (item) => + item.stripeStatus !== 'canceled' && + getProductFromSku(item.skuKey) === sku.productKey, + ); + const latestProductSubscription = subscriptions.find( + (item) => getProductFromSku(item.skuKey) === sku.productKey, + ); + if ( + current && + isDowngrade({ + currentIncludedQuantity: current.includedQuantity, + nextSkuKey: params.skuKey, + }) && + !params.confirmDowngrade + ) { + throw new BadRequestException('Confirm plan downgrade.'); + } + if (!billing?.stripePaymentMethodId) { + const result = + await this.billingService.createSubscriptionCheckoutSession({ + organizationId: params.organizationId, + skuKey: params.skuKey, + successUrl: params.returnUrl, + cancelUrl: params.returnUrl, + }); + if (!('changed' in result)) return result; + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_subscription_set', + skuKey: sku.key, + metadata: { adminUserId: params.adminUserId, note: params.note }, + }); + return this.getStatus(params.organizationId); + } + if (current) { + await changeSubscriptionPlan({ + organizationId: params.organizationId, + subscription: current, + skuKey: sku.key, + stripePriceId: sku.stripePriceId, + includedQuantity: sku.includedUsage?.quantity ?? 0, + stripeService: this.stripeService, + entitlements: this.entitlements, + }); + } else { + await createAdminSubscription({ + organizationId: params.organizationId, + stripeCustomerId: billing.stripeCustomerId, + skuKey: sku.key, + stripePriceId: sku.stripePriceId, + includedQuantity: sku.includedUsage?.quantity ?? 0, + idempotencyKey: [ + 'admin-subscription-create', + params.organizationId, + sku.key, + billing.stripeCustomerId, + latestProductSubscription?.stripeSubscriptionId ?? 'none', + ].join(':'), + stripeService: this.stripeService, + entitlements: this.entitlements, + }); + } + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_subscription_set', + skuKey: sku.key, + metadata: { adminUserId: params.adminUserId, note: params.note }, + }); + return this.getStatus(params.organizationId); + } +} diff --git a/apps/api/src/admin-organizations/admin-billing.types.ts b/apps/api/src/admin-organizations/admin-billing.types.ts new file mode 100644 index 0000000000..b7cf8a97d6 --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.types.ts @@ -0,0 +1,68 @@ +import type { BillingProductKey } from '@trycompai/billing'; +import type { + BillingCreditBalanceSummary, + BillingCreditEventSummary, +} from '../billing/billing-credits.types'; +import type { BillingInvoice } from '../billing/billing-invoices'; +import type { BillingPreferences } from '../billing/billing-preferences'; +import type { BillingUsageRow } from '../billing/billing.types'; + +export interface AdminBillingPlan { + skuKey: string; + productKey: BillingProductKey; + name: string; + unitAmount: number; + currency: string; + includedQuantity: number; +} + +export interface AdminBillingSubscription { + id: string; + skuKey: string; + productKey: BillingProductKey | null; + stripeSubscriptionId: string; + stripeSubscriptionItemId: string; + stripeStatus: string; + includedQuantity: number; + usedQuantity: number; + remainingQuantity: number; + currentPeriodStart: string | null; + currentPeriodEnd: string | null; + cancelAtPeriodEnd: boolean; + canceledAt: string | null; +} + +export interface AdminBillingAuditEvent { + id: string; + eventType: string; + skuKey: string | null; + stripeEventId: string | null; + metadata: unknown; + createdAt: string; +} + +export interface AdminBillingStatus { + organization: { + id: string; + name: string; + }; + stripeCustomerId: string | null; + hasPaymentMethod: boolean; + paymentMethodUpdatedAt: string | null; + preferences: BillingPreferences; + availablePlans: AdminBillingPlan[]; + subscriptions: AdminBillingSubscription[]; + creditBalances: BillingCreditBalanceSummary[]; + creditEvents: BillingCreditEventSummary[]; + usageRows: BillingUsageRow[]; + invoices: BillingInvoice[]; + failedInvoices: BillingInvoice[]; + auditEvents: AdminBillingAuditEvent[]; +} + +export interface AdminBillingPreview { + amountDue: number; + currency: string; + subscriptionId: string | null; + prorationDate: number; +} diff --git a/apps/api/src/admin-organizations/admin-organizations.module.ts b/apps/api/src/admin-organizations/admin-organizations.module.ts index 938dcc6668..4afc166c4c 100644 --- a/apps/api/src/admin-organizations/admin-organizations.module.ts +++ b/apps/api/src/admin-organizations/admin-organizations.module.ts @@ -7,7 +7,11 @@ import { EvidenceFormsModule } from '../evidence-forms/evidence-forms.module'; import { PoliciesModule } from '../policies/policies.module'; import { CommentsModule } from '../comments/comments.module'; import { AttachmentsModule } from '../attachments/attachments.module'; +import { BillingModule } from '../billing/billing.module'; import { SecurityPenetrationTestsModule } from '../security-penetration-tests/security-penetration-tests.module'; +import { AdminBillingActionsService } from './admin-billing-actions.service'; +import { AdminBillingController } from './admin-billing.controller'; +import { AdminBillingService } from './admin-billing.service'; import { AdminOrganizationsController } from './admin-organizations.controller'; import { AdminOrganizationsService } from './admin-organizations.service'; import { PurgeOrganizationService } from './purge-organization.service'; @@ -31,6 +35,7 @@ import { AdminPentestCreditsController } from './admin-pentest-credits.controlle PoliciesModule, CommentsModule, AttachmentsModule, + BillingModule, SecurityPenetrationTestsModule, ], controllers: [ @@ -42,9 +47,12 @@ import { AdminPentestCreditsController } from './admin-pentest-credits.controlle AdminContextController, AdminEvidenceController, AdminPentestCreditsController, + AdminBillingController, ], providers: [ AdminOrganizationsService, + AdminBillingService, + AdminBillingActionsService, PurgeOrganizationService, PurgeOrganizationSnapshotService, PurgeOrganizationExternalService, diff --git a/apps/api/src/admin-organizations/dto/admin-billing.dto.ts b/apps/api/src/admin-organizations/dto/admin-billing.dto.ts new file mode 100644 index 0000000000..32c798ee29 --- /dev/null +++ b/apps/api/src/admin-organizations/dto/admin-billing.dto.ts @@ -0,0 +1,154 @@ +import { subscriptionBillingSkuKeys } from '@trycompai/billing'; +import { + IsBoolean, + IsEmail, + IsIn, + IsInt, + IsNotEmpty, + IsOptional, + IsString, + IsUrl, + Length, + Max, + Min, +} from 'class-validator'; +import { billingTaxIdTypes } from '../../billing/billing-preferences'; + +export class AdminBillingPreferencesDto { + @IsString() + @Length(1, 150) + companyName: string; + + @IsEmail() + billingEmail: string; + + @IsOptional() + @IsString() + @Length(0, 140) + purchaseOrder: string | null; + + @IsOptional() + @IsString() + @Length(0, 200) + addressLine1: string | null; + + @IsOptional() + @IsString() + @Length(0, 200) + addressLine2: string | null; + + @IsOptional() + @IsString() + @Length(0, 100) + addressCity: string | null; + + @IsOptional() + @IsString() + @Length(0, 100) + addressState: string | null; + + @IsOptional() + @IsString() + @Length(0, 32) + addressPostalCode: string | null; + + @IsOptional() + @IsString() + @Length(0, 2) + addressCountry: string | null; + + @IsOptional() + @IsString() + @IsIn([...billingTaxIdTypes, '']) + taxIdType: string | null; + + @IsOptional() + @IsString() + @Length(0, 64) + taxIdValue: string | null; + + @IsOptional() + @IsBoolean() + confirmBillingEmailChange?: boolean; + + @IsString() + @Length(3, 500) + note: string; +} + +export class AdminBillingSubscriptionDto { + @IsString() + @IsIn(subscriptionBillingSkuKeys) + skuKey: string; + + @IsString() + @IsUrl({ require_tld: false }) + returnUrl: string; + + @IsString() + @Length(3, 500) + note: string; + + @IsOptional() + @IsBoolean() + confirmDowngrade?: boolean; +} + +export class AdminBillingSubscriptionPreviewDto { + @IsString() + @IsIn(subscriptionBillingSkuKeys) + skuKey: string; +} + +export class AdminBillingCancelSubscriptionDto { + @IsIn(['period_end', 'immediate']) + mode: 'period_end' | 'immediate'; + + @IsString() + @Length(3, 500) + note: string; + + @IsOptional() + @IsString() + confirm?: string; +} + +export class AdminBillingNoteDto { + @IsString() + @Length(3, 500) + note: string; +} + +export class AdminBillingPaymentLinkDto { + @IsString() + @IsUrl({ require_tld: false }) + successUrl: string; + + @IsString() + @IsUrl({ require_tld: false }) + cancelUrl: string; +} + +export class AdminBillingGrantCreditsDto { + @IsIn(['pentest', 'background_check']) + productKey: 'pentest' | 'background_check'; + + @IsInt() + @Min(1) + @Max(1000) + quantity: number; + + @IsString() + @Length(3, 500) + note: string; + + @IsOptional() + @IsString() + confirm?: string; +} + +export class AdminBillingInvoiceActionDto { + @IsString() + @IsNotEmpty() + note: string; +} diff --git a/apps/api/src/app.module.ts b/apps/api/src/app.module.ts index ce20e1782d..318ba6afde 100644 --- a/apps/api/src/app.module.ts +++ b/apps/api/src/app.module.ts @@ -53,6 +53,7 @@ import { AdminOrganizationsModule } from './admin-organizations/admin-organizati import { AdminFeatureFlagsModule } from './admin-feature-flags/admin-feature-flags.module'; import { TimelinesModule } from './timelines/timelines.module'; import { BackgroundChecksModule } from './background-checks/background-checks.module'; +import { BillingModule } from './billing/billing.module'; @Module({ imports: [ @@ -114,6 +115,7 @@ import { BackgroundChecksModule } from './background-checks/background-checks.mo SecretsModule, SecurityPenetrationTestsModule, StripeModule, + BillingModule, BackgroundChecksModule, AdminOrganizationsModule, AdminFeatureFlagsModule, diff --git a/apps/api/src/background-checks/background-check-billing-customer.spec.ts b/apps/api/src/background-checks/background-check-billing-customer.spec.ts new file mode 100644 index 0000000000..e47abaf77f --- /dev/null +++ b/apps/api/src/background-checks/background-check-billing-customer.spec.ts @@ -0,0 +1,118 @@ +import { db, Prisma } from '@db'; +import type { StripeService } from '../stripe/stripe.service'; +import { findOrCreateBackgroundCheckBillingCustomer } from './background-check-billing-customer'; + +jest.mock('@db', () => { + class PrismaClientKnownRequestError extends Error { + code: string; + + constructor(message: string, options: { code: string }) { + super(message); + this.code = options.code; + } + } + + return { + Prisma: { + PrismaClientKnownRequestError, + }, + db: { + organizationBilling: { + findUnique: jest.fn(), + create: jest.fn(), + }, + organization: { + findUnique: jest.fn(), + }, + }, + }; +}); + +const mockedDb = db as jest.Mocked; + +describe('findOrCreateBackgroundCheckBillingCustomer', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('recovers when a concurrent request creates the billing row first', async () => { + type OrganizationBilling = typeof db.organizationBilling; + type Organization = typeof db.organization; + const dbMocks = mockedDb as unknown as { + organizationBilling: { + findUnique: jest.MockedFunction; + create: jest.MockedFunction; + }; + organization: { + findUnique: jest.MockedFunction; + }; + }; + + dbMocks.organizationBilling.findUnique + .mockResolvedValueOnce(null) + .mockResolvedValueOnce({ + stripeCustomerId: 'cus_winner', + } as Awaited>); + dbMocks.organization.findUnique.mockResolvedValueOnce({ + name: 'Acme', + } as Awaited>); + dbMocks.organizationBilling.create.mockRejectedValueOnce( + new Prisma.PrismaClientKnownRequestError('Unique constraint failed', { + code: 'P2002', + clientVersion: 'test', + }), + ); + + const customersCreate = jest.fn().mockResolvedValue({ id: 'cus_loser' }); + const customersUpdate = jest.fn().mockResolvedValue({ id: 'cus_winner' }); + const stripeService = { + getClient: () => ({ + customers: { + create: customersCreate, + update: customersUpdate, + }, + }), + } as unknown as StripeService; + + await expect( + findOrCreateBackgroundCheckBillingCustomer({ + stripeService, + organizationId: 'org_1', + customerEmail: 'billing@trycomp.ai', + }), + ).resolves.toBe('cus_winner'); + + expect(customersCreate).toHaveBeenCalledWith( + { + name: 'Acme', + metadata: { organizationId: 'org_1' }, + }, + { idempotencyKey: 'background-check-customer:org_1' }, + ); + expect(customersUpdate).toHaveBeenCalledWith('cus_winner', { + email: 'billing@trycomp.ai', + }); + }); + + it('does not create a Stripe client when existing billing needs no update', async () => { + type OrganizationBilling = typeof db.organizationBilling; + const dbMocks = mockedDb as unknown as { + organizationBilling: { + findUnique: jest.MockedFunction; + }; + }; + dbMocks.organizationBilling.findUnique.mockResolvedValueOnce({ + stripeCustomerId: 'cus_existing', + } as Awaited>); + const getClient = jest.fn(); + + await expect( + findOrCreateBackgroundCheckBillingCustomer({ + stripeService: { getClient } as unknown as StripeService, + organizationId: 'org_1', + }), + ).resolves.toBe('cus_existing'); + + expect(getClient).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/api/src/background-checks/background-check-billing-customer.ts b/apps/api/src/background-checks/background-check-billing-customer.ts new file mode 100644 index 0000000000..c57dc72b6f --- /dev/null +++ b/apps/api/src/background-checks/background-check-billing-customer.ts @@ -0,0 +1,109 @@ +import { NotFoundException } from '@nestjs/common'; +import { db, Prisma } from '@db'; +import { StripeService } from '../stripe/stripe.service'; + +export async function findOrCreateBackgroundCheckBillingCustomer({ + stripeService, + organizationId, + customerEmail, +}: { + stripeService: StripeService; + organizationId: string; + customerEmail?: string; +}): Promise { + const existingBilling = await db.organizationBilling.findUnique({ + where: { organizationId }, + select: { stripeCustomerId: true }, + }); + + if (existingBilling) { + await updateStripeCustomerEmail({ + stripeService, + stripeCustomerId: existingBilling.stripeCustomerId, + customerEmail, + }); + return existingBilling.stripeCustomerId; + } + + const organization = await db.organization.findUnique({ + where: { id: organizationId }, + select: { name: true }, + }); + + if (!organization) { + throw new NotFoundException('Organization not found.'); + } + + const stripe = stripeService.getClient(); + const customer = await stripe.customers.create( + { + name: organization.name, + metadata: { organizationId }, + }, + { + idempotencyKey: `background-check-customer:${organizationId}`, + }, + ); + + try { + await db.organizationBilling.create({ + data: { + organizationId, + stripeCustomerId: customer.id, + }, + }); + } catch (error) { + if (!isUniqueConstraintError(error)) { + throw error; + } + + const billing = await db.organizationBilling.findUnique({ + where: { organizationId }, + select: { stripeCustomerId: true }, + }); + + if (!billing) { + throw error; + } + + await updateStripeCustomerEmail({ + stripeService, + stripeCustomerId: billing.stripeCustomerId, + customerEmail, + }); + + return billing.stripeCustomerId; + } + + await updateStripeCustomerEmail({ + stripeService, + stripeCustomerId: customer.id, + customerEmail, + }); + + return customer.id; +} + +async function updateStripeCustomerEmail({ + stripeService, + stripeCustomerId, + customerEmail, +}: { + stripeService: StripeService; + stripeCustomerId: string; + customerEmail?: string; +}): Promise { + if (!customerEmail) return; + + const stripe = stripeService.getClient(); + await stripe.customers.update(stripeCustomerId, { + email: customerEmail, + }); +} + +function isUniqueConstraintError(error: unknown): boolean { + return ( + error instanceof Prisma.PrismaClientKnownRequestError && + error.code === 'P2002' + ); +} diff --git a/apps/api/src/background-checks/background-check-billing-invoices.ts b/apps/api/src/background-checks/background-check-billing-invoices.ts new file mode 100644 index 0000000000..376a6a560e --- /dev/null +++ b/apps/api/src/background-checks/background-check-billing-invoices.ts @@ -0,0 +1,59 @@ +import Stripe from 'stripe'; +import { StripeService } from '../stripe/stripe.service'; + +export interface BackgroundCheckBillingInvoice { + id: string; + number: string; + createdAt: string; + dueDate: string | null; + amountPaid: number; + amountDue: number; + currency: string; + status: string; + type: 'Subscription' | 'One Time'; + hostedInvoiceUrl: string | null; + invoicePdfUrl: string | null; +} + +export async function listBackgroundCheckBillingInvoices({ + stripeService, + stripeCustomerId, +}: { + stripeService: StripeService; + stripeCustomerId: string | null; +}): Promise { + if (!stripeCustomerId || !stripeService.isConfigured()) { + return []; + } + + const stripe = stripeService.getClient(); + const invoices = await stripe.invoices.list({ + customer: stripeCustomerId, + limit: 10, + }); + + return invoices.data.map(toBillingInvoice); +} + +function toBillingInvoice( + invoice: Stripe.Invoice, +): BackgroundCheckBillingInvoice { + return { + id: invoice.id, + number: invoice.number ?? invoice.id, + createdAt: new Date(invoice.created * 1000).toISOString(), + dueDate: invoice.due_date + ? new Date(invoice.due_date * 1000).toISOString() + : null, + amountPaid: invoice.amount_paid, + amountDue: invoice.amount_due, + currency: invoice.currency, + status: invoice.status ?? 'unknown', + type: + invoice.parent?.type === 'subscription_details' + ? 'Subscription' + : 'One Time', + hostedInvoiceUrl: invoice.hosted_invoice_url ?? null, + invoicePdfUrl: invoice.invoice_pdf ?? null, + }; +} diff --git a/apps/api/src/background-checks/background-check-billing-urls.spec.ts b/apps/api/src/background-checks/background-check-billing-urls.spec.ts new file mode 100644 index 0000000000..3fba0f271f --- /dev/null +++ b/apps/api/src/background-checks/background-check-billing-urls.spec.ts @@ -0,0 +1,42 @@ +import { BadRequestException } from '@nestjs/common'; +import { validateBackgroundCheckBillingRedirectUrl } from './background-check-billing-urls'; + +describe('validateBackgroundCheckBillingRedirectUrl', () => { + const originalEnv = { ...process.env }; + + afterEach(() => { + process.env = { ...originalEnv }; + }); + + it('throws a controlled error when the configured app URL is malformed', () => { + process.env.NEXT_PUBLIC_APP_URL = 'not a url'; + process.env.APP_URL = ''; + process.env.BETTER_AUTH_URL = ''; + + expect(() => + validateBackgroundCheckBillingRedirectUrl( + 'https://app.trycomp.ai/return', + ), + ).toThrow(BadRequestException); + }); + + it('rejects opaque configured app URLs', () => { + process.env.NEXT_PUBLIC_APP_URL = 'file:///tmp/app.html'; + process.env.APP_URL = ''; + process.env.BETTER_AUTH_URL = ''; + + expect(() => + validateBackgroundCheckBillingRedirectUrl('file:///tmp/return.html'), + ).toThrow(BadRequestException); + }); + + it('rejects opaque redirect URLs', () => { + process.env.NEXT_PUBLIC_APP_URL = 'https://app.trycomp.ai'; + process.env.APP_URL = ''; + process.env.BETTER_AUTH_URL = ''; + + expect(() => + validateBackgroundCheckBillingRedirectUrl('data:text/html,ok'), + ).toThrow(BadRequestException); + }); +}); diff --git a/apps/api/src/background-checks/background-check-billing-urls.ts b/apps/api/src/background-checks/background-check-billing-urls.ts new file mode 100644 index 0000000000..d2171035c2 --- /dev/null +++ b/apps/api/src/background-checks/background-check-billing-urls.ts @@ -0,0 +1,48 @@ +import { BadRequestException } from '@nestjs/common'; + +export function validateBackgroundCheckBillingRedirectUrl(url: string): void { + const appUrl = + process.env.NEXT_PUBLIC_APP_URL || + process.env.APP_URL || + process.env.BETTER_AUTH_URL; + if (!appUrl) { + throw new BadRequestException('App URL is not configured on the server.'); + } + + let appUrlParsed: URL; + try { + appUrlParsed = new URL(appUrl); + } catch { + throw new BadRequestException( + 'App URL is not configured correctly on the server.', + ); + } + if (!isWebUrl(appUrlParsed)) { + throw new BadRequestException( + 'App URL is not configured correctly on the server.', + ); + } + + let parsed: URL; + try { + parsed = new URL(url); + } catch { + throw new BadRequestException('Invalid redirect URL.'); + } + if (!isWebUrl(parsed)) { + throw new BadRequestException('Invalid redirect URL.'); + } + + if (parsed.origin !== appUrlParsed.origin) { + throw new BadRequestException( + 'Redirect URL must belong to the application origin.', + ); + } +} + +function isWebUrl(url: URL): boolean { + return ( + url.origin !== 'null' && + (url.protocol === 'https:' || url.protocol === 'http:') + ); +} diff --git a/apps/api/src/background-checks/background-check-billing.controller.ts b/apps/api/src/background-checks/background-check-billing.controller.ts index 708d0dd2c1..e872ccb06a 100644 --- a/apps/api/src/background-checks/background-check-billing.controller.ts +++ b/apps/api/src/background-checks/background-check-billing.controller.ts @@ -1,9 +1,17 @@ -import { Body, Controller, Get, HttpCode, Post, UseGuards } from '@nestjs/common'; +import { + Body, + Controller, + Get, + HttpCode, + Post, + UseGuards, +} from '@nestjs/common'; import { ApiOperation, ApiSecurity, ApiTags } from '@nestjs/swagger'; -import { OrganizationId } from '../auth/auth-context.decorator'; +import { AuthContext, OrganizationId } from '../auth/auth-context.decorator'; import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; import { PermissionGuard } from '../auth/permission.guard'; import { RequirePermission } from '../auth/require-permission.decorator'; +import type { AuthContext as AuthContextType } from '../auth/types'; import { BackgroundCheckBillingService } from './background-check-billing.service'; import { BackgroundCheckBillingPortalDto, @@ -28,15 +36,19 @@ export class BackgroundCheckBillingController { @Post('setup-session') @RequirePermission('organization', 'update') @HttpCode(200) - @ApiOperation({ summary: 'Create a Stripe setup session for background checks' }) + @ApiOperation({ + summary: 'Create a Stripe setup session for background checks', + }) async setupSession( @OrganizationId() organizationId: string, + @AuthContext() authContext: AuthContextType, @Body() body: BackgroundCheckSetupSessionDto, ) { return this.billingService.createSetupSession({ organizationId, successUrl: body.successUrl, cancelUrl: body.cancelUrl, + customerEmail: authContext.userEmail, }); } diff --git a/apps/api/src/background-checks/background-check-billing.service.ts b/apps/api/src/background-checks/background-check-billing.service.ts index 21a1b8cd63..8b9b7b3975 100644 --- a/apps/api/src/background-checks/background-check-billing.service.ts +++ b/apps/api/src/background-checks/background-check-billing.service.ts @@ -1,210 +1,43 @@ -import { BadRequestException, Injectable, NotFoundException } from '@nestjs/common'; -import { db } from '@db'; -import { StripeService } from '../stripe/stripe.service'; +import { Injectable } from '@nestjs/common'; +import { + getBillingSku, + resolveBillingCatalogEnvironment, +} from '@trycompai/billing'; +import { BillingService } from '../billing/billing.service'; +import { validateBackgroundCheckBillingRedirectUrl } from './background-check-billing-urls'; @Injectable() export class BackgroundCheckBillingService { - constructor(private readonly stripeService: StripeService) {} + constructor(private readonly billingService: BillingService) {} - async getStatus(organizationId: string): Promise<{ - hasBilling: boolean; - hasPaymentMethod: boolean; - setupAt: Date | null; - usage: { - backgroundChecks: number; - penetrationTests: number; - }; - }> { - const [billing, backgroundChecks, penetrationTests] = await Promise.all([ - db.organizationBilling.findUnique({ - where: { organizationId }, - select: { - stripeCustomerId: true, - stripeBackgroundCheckPaymentMethodId: true, - backgroundCheckPaymentMethodSetupAt: true, - }, - }), - db.backgroundCheckRequest.count({ where: { organizationId } }), - db.securityPenetrationTestRun.count({ where: { organizationId } }), - ]); - - return { - hasBilling: !!billing, - hasPaymentMethod: !!billing?.stripeBackgroundCheckPaymentMethodId, - setupAt: billing?.backgroundCheckPaymentMethodSetupAt ?? null, - usage: { - backgroundChecks, - penetrationTests, - }, - }; + async getStatus(organizationId: string) { + return this.billingService.getStatus(organizationId); } - async createSetupSession({ - organizationId, - successUrl, - cancelUrl, - }: { + async createSetupSession(params: { organizationId: string; successUrl: string; cancelUrl: string; + customerEmail?: string; }): Promise<{ url: string }> { - this.validateRedirectUrl(successUrl); - this.validateRedirectUrl(cancelUrl); - - const stripe = this.stripeService.getClient(); - const customerId = await this.findOrCreateCustomer(organizationId); - const price = await this.getBackgroundCheckPrice(); - - const session = await stripe.checkout.sessions.create({ - mode: 'setup', - customer: customerId, - currency: price.currency, - success_url: successUrl, - cancel_url: cancelUrl, - metadata: { - organizationId, - source: 'comp-background-check', - }, - }); - - if (!session.url) { - throw new BadRequestException('Failed to create Stripe Checkout session.'); - } - - return { url: session.url }; + validateBackgroundCheckBillingRedirectUrl(params.successUrl); + validateBackgroundCheckBillingRedirectUrl(params.cancelUrl); + return this.billingService.createSetupSession(params); } - async handleSetupSuccess({ - organizationId, - sessionId, - }: { + async handleSetupSuccess(params: { organizationId: string; sessionId: string; }): Promise<{ success: true }> { - const stripe = this.stripeService.getClient(); - const session = await stripe.checkout.sessions.retrieve(sessionId, { - expand: ['setup_intent'], - }); - - if (session.status !== 'complete') { - throw new BadRequestException('Checkout session is not complete.'); - } - - if ( - session.metadata?.organizationId && - session.metadata.organizationId !== organizationId - ) { - throw new BadRequestException( - 'Checkout session does not belong to this organization.', - ); - } - - const stripeCustomerId = this.extractStripeId(session.customer); - if (!stripeCustomerId) { - throw new BadRequestException('Checkout session is missing a customer.'); - } - - await this.assertCustomerBelongsToOrganization({ - organizationId, - stripeCustomerId, - }); - - const setupIntent = session.setup_intent; - if (!setupIntent || typeof setupIntent === 'string') { - throw new BadRequestException( - 'Checkout session is missing a setup intent.', - ); - } - - const paymentMethodId = this.extractStripeId(setupIntent.payment_method); - if (!paymentMethodId) { - throw new BadRequestException('Setup intent is missing a payment method.'); - } - - await stripe.customers.update(stripeCustomerId, { - invoice_settings: { - default_payment_method: paymentMethodId, - }, - }); - - await db.organizationBilling.upsert({ - where: { organizationId }, - create: { - organizationId, - stripeCustomerId, - stripeBackgroundCheckPaymentMethodId: paymentMethodId, - backgroundCheckPaymentMethodSetupAt: new Date(), - }, - update: { - stripeCustomerId, - stripeBackgroundCheckPaymentMethodId: paymentMethodId, - backgroundCheckPaymentMethodSetupAt: new Date(), - }, - }); - - return { success: true }; + return this.billingService.handleSetupSuccess(params); } - async createBillingPortalSession({ - organizationId, - returnUrl, - }: { + async createBillingPortalSession(params: { organizationId: string; returnUrl: string; }): Promise<{ url: string }> { - this.validateRedirectUrl(returnUrl); - - const stripe = this.stripeService.getClient(); - const billing = await db.organizationBilling.findUnique({ - where: { organizationId }, - select: { stripeCustomerId: true }, - }); - - if (!billing) { - throw new NotFoundException('No billing record found for this organization.'); - } - - const portalSession = await stripe.billingPortal.sessions.create({ - customer: billing.stripeCustomerId, - return_url: returnUrl, - }); - - return { url: portalSession.url }; - } - - async findOrCreateCustomer(organizationId: string): Promise { - const existingBilling = await db.organizationBilling.findUnique({ - where: { organizationId }, - select: { stripeCustomerId: true }, - }); - - if (existingBilling) { - return existingBilling.stripeCustomerId; - } - - const organization = await db.organization.findUnique({ - where: { id: organizationId }, - select: { name: true }, - }); - - if (!organization) { - throw new NotFoundException('Organization not found.'); - } - - const stripe = this.stripeService.getClient(); - const customer = await stripe.customers.create({ - name: organization.name, - metadata: { organizationId }, - }); - - await db.organizationBilling.create({ - data: { - organizationId, - stripeCustomerId: customer.id, - }, - }); - - return customer.id; + validateBackgroundCheckBillingRedirectUrl(params.returnUrl); + return this.billingService.createBillingPortalSession(params); } async getBackgroundCheckPrice(): Promise<{ @@ -212,80 +45,17 @@ export class BackgroundCheckBillingService { unitAmount: number; currency: string; }> { - const priceId = process.env.STRIPE_BACKGROUND_CHECK_PRICE_ID; - if (!priceId) { - throw new BadRequestException( - 'Background check pricing is not configured. Contact support.', - ); - } - - const stripe = this.stripeService.getClient(); - const price = await stripe.prices.retrieve(priceId); - if (price.unit_amount === null || price.unit_amount === undefined) { - throw new BadRequestException( - 'Background check pricing is not configured. Contact support.', - ); - } - + const sku = getBillingSku({ + environment: resolveBillingCatalogEnvironment({ + stripeSecretKey: process.env.STRIPE_SECRET_KEY, + nodeEnv: process.env.NODE_ENV, + }), + skuKey: 'background_check_one_time', + }); return { - id: price.id, - unitAmount: price.unit_amount, - currency: price.currency, + id: sku.stripePriceId, + unitAmount: sku.unitAmount, + currency: sku.currency, }; } - - private validateRedirectUrl(url: string): void { - const appUrl = - process.env.NEXT_PUBLIC_APP_URL || - process.env.APP_URL || - process.env.BETTER_AUTH_URL; - if (!appUrl) { - throw new BadRequestException('App URL is not configured on the server.'); - } - - let parsed: URL; - try { - parsed = new URL(url); - } catch { - throw new BadRequestException('Invalid redirect URL.'); - } - - if (parsed.origin !== new URL(appUrl).origin) { - throw new BadRequestException('Redirect URL must belong to the application origin.'); - } - } - - private extractStripeId(value: string | { id?: string } | null): string | null { - if (!value) return null; - if (typeof value === 'string') return value; - return value.id ?? null; - } - - private async assertCustomerBelongsToOrganization({ - organizationId, - stripeCustomerId, - }: { - organizationId: string; - stripeCustomerId: string; - }): Promise { - const billing = await db.organizationBilling.findUnique({ - where: { organizationId }, - select: { stripeCustomerId: true }, - }); - - if (billing?.stripeCustomerId === stripeCustomerId) { - return; - } - - const stripe = this.stripeService.getClient(); - const customer = await stripe.customers.retrieve(stripeCustomerId); - if ( - customer.deleted || - customer.metadata?.organizationId !== organizationId - ) { - throw new BadRequestException( - 'Checkout session does not belong to this organization.', - ); - } - } } diff --git a/apps/api/src/background-checks/background-check-payment.service.spec.ts b/apps/api/src/background-checks/background-check-payment.service.spec.ts index b794294310..62443a91ca 100644 --- a/apps/api/src/background-checks/background-check-payment.service.spec.ts +++ b/apps/api/src/background-checks/background-check-payment.service.spec.ts @@ -1,21 +1,21 @@ -import { HttpException, HttpStatus } from '@nestjs/common'; -import { db } from '@db'; +jest.mock('@db', () => ({ db: {} })); + +import { HttpException, HttpStatus, Logger } from '@nestjs/common'; +import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; import { StripeService } from '../stripe/stripe.service'; import { BackgroundCheckBillingService } from './background-check-billing.service'; import { BackgroundCheckPaymentService } from './background-check-payment.service'; -jest.mock('@db', () => ({ - db: { - organizationBilling: { - findUnique: jest.fn(), - }, - }, -})); - -const mockedDb = db as jest.Mocked; - -function mockAsync(fn: unknown): jest.MockedFunction<() => Promise> { - return fn as jest.MockedFunction<() => Promise>; +function mockEntitlements( + overrides: Partial = {}, +): BillingEntitlementsService { + return { + tryConsumeIncludedUsageForProduct: jest + .fn() + .mockResolvedValue({ status: 'not_configured' }), + refundIncludedUsageForProduct: jest.fn().mockResolvedValue(undefined), + ...overrides, + } as unknown as BillingEntitlementsService; } describe('BackgroundCheckPaymentService', () => { @@ -23,58 +23,142 @@ describe('BackgroundCheckPaymentService', () => { jest.clearAllMocks(); }); - it('throws payment required when no background check payment method exists', async () => { - mockAsync>>( - mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce(null); + it('consumes background check subscription allowance', async () => { + const tryConsumeIncludedUsageForProduct = jest + .fn() + .mockResolvedValue({ status: 'consumed', subscriptionId: 'obs_1' }); + const entitlements = mockEntitlements({ + tryConsumeIncludedUsageForProduct, + }); const service = new BackgroundCheckPaymentService( { getClient: jest.fn() } as unknown as StripeService, - { getBackgroundCheckPrice: jest.fn() } as unknown as BackgroundCheckBillingService, + {} as BackgroundCheckBillingService, + entitlements, ); await expect( service.charge({ organizationId: 'org_1', memberId: 'mem_1' }), - ).rejects.toThrow( - expect.objectContaining({ - status: HttpStatus.PAYMENT_REQUIRED, + ).resolves.toEqual({ + paymentIntentId: null, + invoiceId: null, + status: 'subscription_included', + amount: 0, + currency: 'usd', + }); + + expect(tryConsumeIncludedUsageForProduct).toHaveBeenCalledWith({ + organizationId: 'org_1', + productKey: 'background_check', + sourceResourceId: 'mem_1', + }); + }); + + it('blocks when no background check subscription is configured', async () => { + const service = new BackgroundCheckPaymentService( + { getClient: jest.fn() } as unknown as StripeService, + {} as BackgroundCheckBillingService, + mockEntitlements(), + ); + + try { + await service.charge({ organizationId: 'org_1', memberId: 'mem_1' }); + throw new Error('Expected charge to require payment'); + } catch (error) { + expect(error).toBeInstanceOf(HttpException); + if (!(error instanceof HttpException)) throw error; + expect(error.getStatus()).toBe(HttpStatus.PAYMENT_REQUIRED); + expect(error.getResponse()).toEqual( + expect.objectContaining({ + code: 'background_check_subscription_required', + }), + ); + } + }); + + it('blocks when background check subscription allowance is exhausted', async () => { + const service = new BackgroundCheckPaymentService( + { getClient: jest.fn() } as unknown as StripeService, + {} as BackgroundCheckBillingService, + mockEntitlements({ + tryConsumeIncludedUsageForProduct: jest + .fn() + .mockResolvedValue({ status: 'exhausted', subscriptionId: 'obs_1' }), }), ); + + try { + await service.charge({ organizationId: 'org_1', memberId: 'mem_1' }); + throw new Error('Expected charge to require payment'); + } catch (error) { + expect(error).toBeInstanceOf(HttpException); + if (!(error instanceof HttpException)) throw error; + expect(error.getStatus()).toBe(HttpStatus.PAYMENT_REQUIRED); + expect(error.getResponse()).toEqual( + expect.objectContaining({ + code: 'background_check_subscription_exhausted', + }), + ); + } }); - it('charges Stripe with payment-method scoped idempotency key', async () => { - mockAsync>>( - mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce({ - stripeCustomerId: 'cus_1', - stripeBackgroundCheckPaymentMethodId: 'pm_1', - } as Awaited>); - const paymentIntentsCreate = jest.fn().mockResolvedValue({ - id: 'pi_1', - status: 'succeeded', + it('refunds the consumed background check allowance by product family', async () => { + const refundIncludedUsageForProduct = jest + .fn() + .mockResolvedValue(undefined); + const entitlements = mockEntitlements({ refundIncludedUsageForProduct }); + const service = new BackgroundCheckPaymentService( + { getClient: jest.fn() } as unknown as StripeService, + {} as BackgroundCheckBillingService, + entitlements, + ); + + await expect( + service.refund({ + organizationId: 'org_1', + memberId: 'mem_1', + paymentIntentId: null, + }), + ).resolves.toBeNull(); + + expect(refundIncludedUsageForProduct).toHaveBeenCalledWith({ + organizationId: 'org_1', + productKey: 'background_check', + sourceResourceId: 'mem_1', + reason: 'background_check_failed', + }); + }); + + it('does not throw when refunding consumed allowance fails', async () => { + const loggerSpy = jest + .spyOn(Logger.prototype, 'error') + .mockImplementation(); + const entitlements = mockEntitlements({ + refundIncludedUsageForProduct: jest + .fn() + .mockRejectedValue(new Error('refund failed')), }); const service = new BackgroundCheckPaymentService( - { - getClient: () => ({ paymentIntents: { create: paymentIntentsCreate } }), - } as unknown as StripeService, - { - getBackgroundCheckPrice: jest.fn().mockResolvedValue({ - id: 'price_bg', - unitAmount: 1250, - currency: 'usd', - }), - } as unknown as BackgroundCheckBillingService, + { getClient: jest.fn() } as unknown as StripeService, + {} as BackgroundCheckBillingService, + entitlements, ); - await service.charge({ organizationId: 'org_1', memberId: 'mem_1' }); + await expect( + service.refund({ + organizationId: 'org_1', + memberId: 'mem_1', + paymentIntentId: null, + }), + ).resolves.toBeNull(); - expect(paymentIntentsCreate).toHaveBeenCalledWith( + expect(loggerSpy).toHaveBeenCalledWith( + 'Failed to refund background check included usage - manual credit review required.', expect.objectContaining({ - amount: 1250, - customer: 'cus_1', - description: 'Comp AI - Background Check x1', - payment_method: 'pm_1', + organizationId: 'org_1', + memberId: 'mem_1', + error: 'refund failed', }), - { idempotencyKey: 'background-check:org_1:mem_1:price_bg:pm_1' }, ); + loggerSpy.mockRestore(); }); }); diff --git a/apps/api/src/background-checks/background-check-payment.service.ts b/apps/api/src/background-checks/background-check-payment.service.ts index 17e0dc4501..326c9593d7 100644 --- a/apps/api/src/background-checks/background-check-payment.service.ts +++ b/apps/api/src/background-checks/background-check-payment.service.ts @@ -1,96 +1,83 @@ import { HttpException, HttpStatus, Injectable, Logger } from '@nestjs/common'; -import { db } from '@db'; +import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; import { StripeService } from '../stripe/stripe.service'; import { BackgroundCheckBillingService } from './background-check-billing.service'; @Injectable() export class BackgroundCheckPaymentService { - private static readonly receiptDescription = - 'Comp AI - Background Check x1'; - private readonly logger = new Logger(BackgroundCheckPaymentService.name); constructor( private readonly stripeService: StripeService, private readonly billingService: BackgroundCheckBillingService, + private readonly entitlements: BillingEntitlementsService, ) {} - async charge(params: { - organizationId: string; - memberId: string; - }): Promise<{ - paymentIntentId: string; + async charge(params: { organizationId: string; memberId: string }): Promise<{ + paymentIntentId: string | null; + invoiceId: string | null; status: string; amount: number; currency: string; }> { - const billing = await db.organizationBilling.findUnique({ - where: { organizationId: params.organizationId }, - select: { - stripeCustomerId: true, - stripeBackgroundCheckPaymentMethodId: true, - }, - }); + const includedUsage = + await this.entitlements.tryConsumeIncludedUsageForProduct({ + organizationId: params.organizationId, + productKey: 'background_check', + sourceResourceId: params.memberId, + }); - if (!billing?.stripeBackgroundCheckPaymentMethodId) { - throw new HttpException( - 'No background check payment method on file. Update billing first.', - HttpStatus.PAYMENT_REQUIRED, - ); + if (includedUsage.status === 'consumed') { + return { + paymentIntentId: null, + invoiceId: null, + status: 'subscription_included', + amount: 0, + currency: 'usd', + }; } - const price = await this.billingService.getBackgroundCheckPrice(); - const stripe = this.stripeService.getClient(); - const paymentIntent = await stripe.paymentIntents.create( + throw new HttpException( { - customer: billing.stripeCustomerId, - amount: price.unitAmount, - currency: price.currency, - description: BackgroundCheckPaymentService.receiptDescription, - payment_method: billing.stripeBackgroundCheckPaymentMethodId, - off_session: true, - confirm: true, - automatic_payment_methods: { - enabled: true, - allow_redirects: 'never', - }, - metadata: { - source: 'comp-background-check', - compOrganizationId: params.organizationId, - compMemberId: params.memberId, - }, - }, - { - idempotencyKey: [ - 'background-check', - params.organizationId, - params.memberId, - price.id, - billing.stripeBackgroundCheckPaymentMethodId, - ].join(':'), + error: + includedUsage.status === 'exhausted' + ? 'No background checks remaining in your subscription. Upgrade or wait for your monthly allowance to reset.' + : 'Choose a background check plan before requesting a background check.', + code: + includedUsage.status === 'exhausted' + ? 'background_check_subscription_exhausted' + : 'background_check_subscription_required', }, + HttpStatus.PAYMENT_REQUIRED, ); - - if (paymentIntent.status !== 'succeeded') { - throw new HttpException( - 'Background check payment failed. Update billing and try again.', - HttpStatus.PAYMENT_REQUIRED, - ); - } - - return { - paymentIntentId: paymentIntent.id, - status: paymentIntent.status, - amount: price.unitAmount, - currency: price.currency, - }; } async refund(params: { organizationId: string; memberId: string; - paymentIntentId: string; + paymentIntentId: string | null; }): Promise { + if (!params.paymentIntentId) { + try { + await this.entitlements.refundIncludedUsageForProduct({ + organizationId: params.organizationId, + productKey: 'background_check', + sourceResourceId: params.memberId, + reason: 'background_check_failed', + }); + } catch (error) { + this.logger.error( + 'Failed to refund background check included usage - manual credit review required.', + { + organizationId: params.organizationId, + memberId: params.memberId, + error: error instanceof Error ? error.message : 'Unknown error', + }, + ); + } + return null; + } + try { const stripe = this.stripeService.getClient(); const refund = await stripe.refunds.create( @@ -107,7 +94,7 @@ export class BackgroundCheckPaymentService { return refund.id; } catch (error) { this.logger.error( - 'Failed to refund background check payment — manual refund required.', + 'Failed to refund background check payment - manual refund required.', { organizationId: params.organizationId, memberId: params.memberId, @@ -118,4 +105,8 @@ export class BackgroundCheckPaymentService { return null; } } + + async getBackgroundCheckPrice() { + return this.billingService.getBackgroundCheckPrice(); + } } diff --git a/apps/api/src/background-checks/background-checks.module.ts b/apps/api/src/background-checks/background-checks.module.ts index b1e16339b6..b643bd468f 100644 --- a/apps/api/src/background-checks/background-checks.module.ts +++ b/apps/api/src/background-checks/background-checks.module.ts @@ -1,6 +1,7 @@ import { Module } from '@nestjs/common'; import { AttachmentsModule } from '../attachments/attachments.module'; import { AuthModule } from '../auth/auth.module'; +import { BillingModule } from '../billing/billing.module'; import { BackgroundCheckBillingController } from './background-check-billing.controller'; import { BackgroundCheckBillingService } from './background-check-billing.service'; import { BackgroundCheckCustomService } from './background-check-custom.service'; @@ -13,7 +14,7 @@ import { import { BackgroundChecksService } from './background-checks.service'; @Module({ - imports: [AuthModule, AttachmentsModule], + imports: [AuthModule, AttachmentsModule, BillingModule], controllers: [ BackgroundChecksController, PeopleBackgroundChecksController, diff --git a/apps/api/src/background-checks/background-checks.service.spec.ts b/apps/api/src/background-checks/background-checks.service.spec.ts index 17dfcaa0e4..2d23c240c9 100644 --- a/apps/api/src/background-checks/background-checks.service.spec.ts +++ b/apps/api/src/background-checks/background-checks.service.spec.ts @@ -1,9 +1,9 @@ import { BackgroundCheckIdentityClient } from './background-check-identity.client'; +import { BillingService } from '../billing/billing.service'; import { BackgroundCheckBillingService } from './background-check-billing.service'; import { BackgroundCheckPaymentService } from './background-check-payment.service'; import { BackgroundChecksService } from './background-checks.service'; import { db } from '@db'; -import type { StripeService } from '../stripe/stripe.service'; jest.mock('@db', () => { class PrismaClientKnownRequestError extends Error { @@ -61,7 +61,8 @@ function mockAsync(fn: unknown): jest.MockedFunction<() => Promise> { function invocationOrder(fn: unknown, index = 0): number { return ( - (fn as { mock: { invocationCallOrder: number[] } }).mock.invocationCallOrder[index] ?? 0 + (fn as { mock: { invocationCallOrder: number[] } }).mock + .invocationCallOrder[index] ?? 0 ); } @@ -132,7 +133,9 @@ describe('background checks', () => { compOrganizationId: 'org_1', compMemberId: 'mem_1', }); - expect(body.callbackUrl).toBe('https://api.trycomp.ai/v1/background-checks/webhook'); + expect(body.callbackUrl).toBe( + 'https://api.trycomp.ai/v1/background-checks/webhook', + ); expect(body.requesterNotes).toBeUndefined(); }); @@ -192,7 +195,9 @@ describe('background checks', () => { mockAsync>>( mockedDb.backgroundCheckRequest.findUnique, ).mockResolvedValueOnce( - existing as Awaited>, + existing as Awaited< + ReturnType + >, ); const identityClient = { createBackgroundCheck: jest.fn() }; const paymentService = { charge: jest.fn(), refund: jest.fn() }; @@ -243,7 +248,9 @@ describe('background checks', () => { } as Awaited>); const identityClient = { - createBackgroundCheck: jest.fn().mockRejectedValue(new Error('identity down')), + createBackgroundCheck: jest + .fn() + .mockRejectedValue(new Error('identity down')), }; const paymentService = { charge: jest.fn().mockResolvedValue({ @@ -369,9 +376,9 @@ describe('background checks', () => { }), ); // Record is created before Identity API is called - expect(invocationOrder(mockedDb.backgroundCheckRequest.create)).toBeLessThan( - invocationOrder(identityClient.createBackgroundCheck), - ); + expect( + invocationOrder(mockedDb.backgroundCheckRequest.create), + ).toBeLessThan(invocationOrder(identityClient.createBackgroundCheck)); expect(identityClient.createBackgroundCheck).toHaveBeenCalledWith( expect.not.objectContaining({ requesterNotes: expect.any(String), @@ -425,50 +432,15 @@ describe('background checks', () => { }); it('uses BETTER_AUTH_URL as the local app URL fallback for setup redirects', async () => { - process.env = { - ...process.env, - NEXT_PUBLIC_APP_URL: '', - APP_URL: '', - BETTER_AUTH_URL: 'http://localhost:3000', - }; - mockAsync>>( - mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce(null); - mockAsync>>( - mockedDb.organization.findUnique, - ).mockResolvedValueOnce({ - name: 'Acme', - } as Awaited>); - mockAsync>>( - mockedDb.organizationBilling.create, - ).mockResolvedValueOnce({ - organizationId: 'org_1', - stripeCustomerId: 'cus_1', - } as Awaited>); - - const stripe = { - checkout: { - sessions: { - create: jest.fn().mockResolvedValue({ - url: 'https://checkout.stripe.com/c/session_1', - }), - }, - }, - customers: { - create: jest.fn().mockResolvedValue({ id: 'cus_1' }), - }, - prices: { - retrieve: jest.fn().mockResolvedValue({ - id: 'price_bg', - unit_amount: 4900, - currency: 'usd', - }), - }, - }; - const stripeService = { - getClient: () => stripe, - } as unknown as StripeService; - const service = new BackgroundCheckBillingService(stripeService); + process.env.NEXT_PUBLIC_APP_URL = ''; + process.env.APP_URL = ''; + process.env.BETTER_AUTH_URL = 'http://localhost:3000'; + const billingService = { + createSetupSession: jest.fn().mockResolvedValue({ + url: 'https://checkout.stripe.com/c/session_1', + }), + } as unknown as BillingService; + const service = new BackgroundCheckBillingService(billingService); await expect( service.createSetupSession({ @@ -476,26 +448,45 @@ describe('background checks', () => { successUrl: 'http://localhost:3000/org_1/people/mem_1?background_check_billing=success', cancelUrl: 'http://localhost:3000/org_1/people/mem_1', + customerEmail: 'billing@trycomp.ai', }), ).resolves.toEqual({ url: 'https://checkout.stripe.com/c/session_1' }); + + expect(billingService.createSetupSession).toHaveBeenCalledWith({ + organizationId: 'org_1', + successUrl: + 'http://localhost:3000/org_1/people/mem_1?background_check_billing=success', + cancelUrl: 'http://localhost:3000/org_1/people/mem_1', + customerEmail: 'billing@trycomp.ai', + }); }); it('includes background check and penetration test usage in billing status', async () => { - mockAsync>>( - mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce({ - stripeCustomerId: 'cus_1', - stripeBackgroundCheckPaymentMethodId: 'pm_1', - backgroundCheckPaymentMethodSetupAt: new Date('2026-04-29T12:00:00.000Z'), - } as Awaited>); - mockAsync(mockedDb.backgroundCheckRequest.count).mockResolvedValueOnce(4); - mockAsync( - mockedDb.securityPenetrationTestRun.count, - ).mockResolvedValueOnce(2); - - const service = new BackgroundCheckBillingService({ - getClient: jest.fn(), - } as unknown as StripeService); + const billingService = { + getStatus: jest.fn().mockResolvedValue({ + hasBilling: true, + hasPaymentMethod: true, + setupAt: new Date('2026-04-29T12:00:00.000Z'), + usage: { backgroundChecks: 4, penetrationTests: 2 }, + subscriptions: [], + invoices: [ + { + id: 'in_1', + number: 'INV-001', + createdAt: '2026-04-30T00:00:00.000Z', + dueDate: null, + amountPaid: 4900, + amountDue: 4900, + currency: 'usd', + status: 'paid', + type: 'One Time', + hostedInvoiceUrl: 'https://invoice.stripe.com/i/in_1', + invoicePdfUrl: 'https://invoice.stripe.com/i/in_1.pdf', + }, + ], + }), + } as unknown as BillingService; + const service = new BackgroundCheckBillingService(billingService); await expect(service.getStatus('org_1')).resolves.toMatchObject({ hasBilling: true, @@ -504,12 +495,16 @@ describe('background checks', () => { backgroundChecks: 4, penetrationTests: 2, }, + invoices: [ + { + id: 'in_1', + number: 'INV-001', + amountPaid: 4900, + status: 'paid', + type: 'One Time', + }, + ], }); - expect(mockedDb.backgroundCheckRequest.count).toHaveBeenCalledWith({ - where: { organizationId: 'org_1' }, - }); - expect(mockedDb.securityPenetrationTestRun.count).toHaveBeenCalledWith({ - where: { organizationId: 'org_1' }, - }); + expect(billingService.getStatus).toHaveBeenCalledWith('org_1'); }); }); diff --git a/apps/api/src/background-checks/dto/background-check-billing.dto.ts b/apps/api/src/background-checks/dto/background-check-billing.dto.ts index 7d1e9c249e..b796939c57 100644 --- a/apps/api/src/background-checks/dto/background-check-billing.dto.ts +++ b/apps/api/src/background-checks/dto/background-check-billing.dto.ts @@ -1,4 +1,4 @@ -import { IsString, IsUrl } from 'class-validator'; +import { IsNotEmpty, IsString, IsUrl } from 'class-validator'; export class BackgroundCheckSetupSessionDto { @IsString() @@ -12,6 +12,7 @@ export class BackgroundCheckSetupSessionDto { export class BackgroundCheckSetupSuccessDto { @IsString() + @IsNotEmpty() sessionId: string; } diff --git a/apps/api/src/billing/billing-credits.service.spec.ts b/apps/api/src/billing/billing-credits.service.spec.ts new file mode 100644 index 0000000000..f3d42c7ec0 --- /dev/null +++ b/apps/api/src/billing/billing-credits.service.spec.ts @@ -0,0 +1,278 @@ +import { db } from '@db'; +import { BillingCreditsService } from './billing-credits.service'; + +jest.mock('@db', () => ({ + db: { + billingCreditBalance: { + findMany: jest.fn(), + findFirst: jest.fn(), + findFirstOrThrow: jest.fn(), + findUniqueOrThrow: jest.fn(), + create: jest.fn(), + }, + billingCreditEvent: { + findMany: jest.fn(), + findFirst: jest.fn(), + }, + $transaction: jest.fn(), + }, +})); + +type MockTx = { + billingCreditEvent: { create: jest.Mock }; + billingCreditBalance: { update: jest.Mock; updateMany: jest.Mock }; +}; + +const mockedDb = db as unknown as { + billingCreditBalance: { + findMany: jest.Mock; + findFirst: jest.Mock; + findFirstOrThrow: jest.Mock; + findUniqueOrThrow: jest.Mock; + create: jest.Mock; + }; + billingCreditEvent: { findMany: jest.Mock; findFirst: jest.Mock }; + $transaction: jest.Mock; +}; + +describe('BillingCreditsService', () => { + let service: BillingCreditsService; + let tx: MockTx; + + beforeEach(() => { + jest.clearAllMocks(); + tx = { + billingCreditEvent: { create: jest.fn() }, + billingCreditBalance: { + update: jest.fn(), + updateMany: jest.fn().mockResolvedValue({ count: 1 }), + }, + }; + mockedDb.$transaction.mockImplementation( + (callback: (tx: MockTx) => Promise) => callback(tx), + ); + mockedDb.billingCreditEvent.findFirst.mockResolvedValue(null); + service = new BillingCreditsService(); + }); + + it('grants credits to an org-scoped product balance', async () => { + mockedDb.billingCreditBalance.findFirst.mockResolvedValue({ + id: 'bcb_1', + organizationId: 'org_1', + productKey: 'pentest', + skuKey: null, + }); + mockedDb.billingCreditBalance.findUniqueOrThrow.mockResolvedValue({ + id: 'bcb_1', + productKey: 'pentest', + skuKey: null, + balance: 3, + totalGranted: 3, + totalConsumed: 0, + totalRefunded: 0, + lastSource: 'manual', + updatedAt: new Date('2026-05-01T00:00:00.000Z'), + }); + + await expect( + service.grant({ + organizationId: 'org_1', + productKey: 'pentest', + quantity: 3, + source: 'manual', + note: 'CS goodwill', + adminUserId: 'usr_1', + }), + ).resolves.toMatchObject({ balance: 3, productKey: 'pentest' }); + + expect(tx.billingCreditEvent.create).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ + organizationId: 'org_1', + productKey: 'pentest', + quantity: 3, + adminUserId: 'usr_1', + }), + }), + ); + expect(tx.billingCreditBalance.update).toHaveBeenCalledWith( + expect.objectContaining({ + where: { id: 'bcb_1' }, + data: expect.objectContaining({ balance: { increment: 3 } }), + }), + ); + }); + + it('consumes one available manual credit atomically', async () => { + mockedDb.billingCreditBalance.findMany.mockResolvedValue([ + { + id: 'bcb_1', + organizationId: 'org_1', + productKey: 'background_check', + skuKey: null, + balance: 1, + }, + ]); + + await expect( + service.tryConsumeForProduct({ + organizationId: 'org_1', + productKey: 'background_check', + sourceResourceId: 'mem_1', + }), + ).resolves.toEqual({ status: 'consumed' }); + + expect(tx.billingCreditBalance.updateMany).toHaveBeenCalledWith({ + where: { id: 'bcb_1', balance: { gt: 0 } }, + data: { + balance: { decrement: 1 }, + totalConsumed: { increment: 1 }, + }, + }); + }); + + it('returns exhausted when a concurrent consume drains the selected balance', async () => { + mockedDb.billingCreditBalance.findMany.mockResolvedValue([ + { + id: 'bcb_1', + organizationId: 'org_1', + productKey: 'background_check', + skuKey: null, + balance: 1, + }, + ]); + tx.billingCreditBalance.updateMany.mockResolvedValue({ count: 0 }); + + await expect( + service.tryConsumeForProduct({ + organizationId: 'org_1', + productKey: 'background_check', + sourceResourceId: 'mem_1', + }), + ).resolves.toEqual({ status: 'exhausted' }); + }); + + it('tries another available balance when the first selected balance is drained concurrently', async () => { + mockedDb.billingCreditBalance.findMany.mockResolvedValue([ + { + id: 'bcb_1', + organizationId: 'org_1', + productKey: 'background_check', + skuKey: null, + balance: 1, + }, + { + id: 'bcb_2', + organizationId: 'org_1', + productKey: 'background_check', + skuKey: 'background_checks_monthly_3', + balance: 2, + }, + ]); + tx.billingCreditBalance.updateMany + .mockResolvedValueOnce({ count: 0 }) + .mockResolvedValueOnce({ count: 1 }); + + await expect( + service.tryConsumeForProduct({ + organizationId: 'org_1', + productKey: 'background_check', + sourceResourceId: 'mem_1', + }), + ).resolves.toEqual({ status: 'consumed' }); + + expect(tx.billingCreditBalance.updateMany).toHaveBeenNthCalledWith(1, { + where: { id: 'bcb_1', balance: { gt: 0 } }, + data: { + balance: { decrement: 1 }, + totalConsumed: { increment: 1 }, + }, + }); + expect(tx.billingCreditBalance.updateMany).toHaveBeenNthCalledWith(2, { + where: { id: 'bcb_2', balance: { gt: 0 } }, + data: { + balance: { decrement: 1 }, + totalConsumed: { increment: 1 }, + }, + }); + expect(tx.billingCreditEvent.create).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ + balanceId: 'bcb_2', + skuKey: 'background_checks_monthly_3', + }), + }), + ); + }); + + it('treats consume retries as consumed before checking balance', async () => { + mockedDb.billingCreditEvent.findFirst.mockResolvedValue({ id: 'bce_1' }); + + await expect( + service.tryConsumeForProduct({ + organizationId: 'org_1', + productKey: 'background_check', + sourceResourceId: 'mem_1', + }), + ).resolves.toEqual({ status: 'consumed' }); + + expect(mockedDb.billingCreditBalance.findMany).not.toHaveBeenCalled(); + }); + + it('uses a stable refund idempotency key across reasons', async () => { + mockedDb.billingCreditEvent.findFirst.mockResolvedValue({ + id: 'bce_consume_1', + balanceId: 'bcb_1', + productKey: 'pentest', + skuKey: null, + quantity: 1, + }); + + await expect( + service.refundForProduct({ + organizationId: 'org_1', + productKey: 'pentest', + sourceResourceId: 'run_1', + reason: 'first reason', + }), + ).resolves.toBe(true); + + expect(tx.billingCreditEvent.create).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ + idempotencyKey: 'credit-refund:org_1:pentest:run_1', + note: 'first reason', + }), + }), + ); + }); + + it('uses the provided transaction client for credit refunds', async () => { + const transactionClient = { + billingCreditEvent: { + findFirst: jest.fn().mockResolvedValue({ + id: 'bce_consume_1', + balanceId: 'bcb_1', + productKey: 'pentest', + skuKey: null, + quantity: 1, + }), + create: jest.fn(), + }, + billingCreditBalance: { update: jest.fn() }, + }; + + await expect( + service.refundForProduct({ + organizationId: 'org_1', + productKey: 'pentest', + sourceResourceId: 'run_1', + reason: 'canceled', + tx: transactionClient as never, + }), + ).resolves.toBe(true); + + expect(transactionClient.billingCreditEvent.create).toHaveBeenCalled(); + expect(mockedDb.$transaction).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/api/src/billing/billing-credits.service.ts b/apps/api/src/billing/billing-credits.service.ts new file mode 100644 index 0000000000..b848772a14 --- /dev/null +++ b/apps/api/src/billing/billing-credits.service.ts @@ -0,0 +1,300 @@ +import { Injectable } from '@nestjs/common'; +import { Prisma, db } from '@db'; +import type { BillingProductKey, BillingSkuKey } from '@trycompai/billing'; +import { isUniqueConstraintError } from './billing-entitlements.types'; +import { + assertCreditEventType, + assertProductKey, + type BillingCreditBalanceSummary, + validateCreditInput, +} from './billing-credits.types'; + +@Injectable() +export class BillingCreditsService { + async listBalances(organizationId: string) { + const balances = await db.billingCreditBalance.findMany({ + where: { organizationId }, + orderBy: [{ productKey: 'asc' }, { createdAt: 'asc' }], + }); + return balances.map((balance) => ({ + id: balance.id, + productKey: assertProductKey(balance.productKey), + skuKey: balance.skuKey, + balance: balance.balance, + totalGranted: balance.totalGranted, + totalConsumed: balance.totalConsumed, + totalRefunded: balance.totalRefunded, + lastSource: balance.lastSource, + updatedAt: balance.updatedAt.toISOString(), + })); + } + + async listEvents(params: { organizationId: string; take?: number }) { + return ( + await db.billingCreditEvent.findMany({ + where: { organizationId: params.organizationId }, + orderBy: { createdAt: 'desc' }, + take: params.take ?? 50, + }) + ).map((event) => ({ + id: event.id, + productKey: assertProductKey(event.productKey), + skuKey: event.skuKey, + eventType: assertCreditEventType(event.eventType), + quantity: event.quantity, + source: event.source, + note: event.note, + adminUserId: event.adminUserId, + sourceResourceId: event.sourceResourceId, + createdAt: event.createdAt.toISOString(), + })); + } + + async grant(params: { + organizationId: string; + productKey: BillingProductKey; + skuKey?: BillingSkuKey | null; + quantity: number; + source: string; + note: string; + adminUserId?: string | null; + idempotencyKey?: string; + }): Promise { + validateCreditInput(params); + const balance = await this.findOrCreateBalance({ + organizationId: params.organizationId, + productKey: params.productKey, + skuKey: params.skuKey ?? null, + }); + const idempotencyKey = + params.idempotencyKey ?? + [ + 'grant', + params.organizationId, + params.productKey, + params.skuKey ?? 'product', + params.quantity, + params.note.trim().toLowerCase(), + ].join(':'); + + try { + await db.$transaction(async (tx) => { + await tx.billingCreditEvent.create({ + data: { + organizationId: params.organizationId, + balanceId: balance.id, + productKey: params.productKey, + skuKey: params.skuKey ?? null, + eventType: 'grant', + quantity: params.quantity, + source: params.source, + note: params.note.trim(), + adminUserId: params.adminUserId ?? null, + idempotencyKey, + }, + }); + await tx.billingCreditBalance.update({ + where: { id: balance.id }, + data: { + balance: { increment: params.quantity }, + totalGranted: { increment: params.quantity }, + lastSource: params.source, + }, + }); + }); + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + } + + return this.getBalance(balance.id); + } + + async tryConsumeForProduct(params: { + organizationId: string; + productKey: BillingProductKey; + sourceResourceId: string; + }): Promise<{ status: 'consumed' | 'not_configured' | 'exhausted' }> { + const idempotencyKey = [ + 'credit-consume', + params.organizationId, + params.productKey, + params.sourceResourceId, + ].join(':'); + const existingConsumption = await db.billingCreditEvent.findFirst({ + where: { + organizationId: params.organizationId, + productKey: params.productKey, + eventType: 'consume', + idempotencyKey, + }, + select: { id: true }, + }); + if (existingConsumption) return { status: 'consumed' }; + + const existing = await db.billingCreditBalance.findMany({ + where: { + organizationId: params.organizationId, + productKey: params.productKey, + }, + orderBy: { createdAt: 'asc' }, + }); + if (existing.length === 0) return { status: 'not_configured' }; + + const availableBalances = existing.filter((item) => item.balance > 0); + if (availableBalances.length === 0) return { status: 'exhausted' }; + + for (const balance of availableBalances) { + try { + const consumed = await db.$transaction(async (tx) => { + const updated = await tx.billingCreditBalance.updateMany({ + where: { id: balance.id, balance: { gt: 0 } }, + data: { + balance: { decrement: 1 }, + totalConsumed: { increment: 1 }, + }, + }); + if (updated.count === 0) return false; + await tx.billingCreditEvent.create({ + data: { + organizationId: params.organizationId, + balanceId: balance.id, + productKey: params.productKey, + skuKey: balance.skuKey, + eventType: 'consume', + quantity: 1, + source: 'manual_credit', + sourceResourceId: params.sourceResourceId, + idempotencyKey, + }, + }); + return true; + }); + if (consumed) { + return { status: 'consumed' }; + } + } catch (error) { + if (isUniqueConstraintError(error)) return { status: 'consumed' }; + throw error; + } + } + + return { status: 'exhausted' }; + } + + async refundForProduct(params: { + organizationId: string; + productKey: BillingProductKey; + sourceResourceId: string; + reason: string; + tx?: Prisma.TransactionClient; + }): Promise { + const client = params.tx ?? db; + const consumed = await client.billingCreditEvent.findFirst({ + where: { + organizationId: params.organizationId, + productKey: params.productKey, + eventType: 'consume', + sourceResourceId: params.sourceResourceId, + }, + orderBy: { createdAt: 'desc' }, + }); + if (!consumed) return false; + + const idempotencyKey = [ + 'credit-refund', + params.organizationId, + params.productKey, + params.sourceResourceId, + ].join(':'); + + try { + const writeRefund = async (tx: Prisma.TransactionClient) => { + await tx.billingCreditEvent.create({ + data: { + organizationId: params.organizationId, + balanceId: consumed.balanceId, + productKey: params.productKey, + skuKey: consumed.skuKey, + eventType: 'refund', + quantity: consumed.quantity, + source: 'refund', + note: params.reason, + sourceResourceId: params.sourceResourceId, + linkedEventId: consumed.id, + idempotencyKey, + }, + }); + await tx.billingCreditBalance.update({ + where: { id: consumed.balanceId }, + data: { + balance: { increment: consumed.quantity }, + totalRefunded: { increment: consumed.quantity }, + totalConsumed: { decrement: consumed.quantity }, + lastSource: 'refund', + }, + }); + }; + if (params.tx) { + await writeRefund(params.tx); + } else { + await db.$transaction(writeRefund); + } + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + } + + return true; + } + + private async findOrCreateBalance(params: { + organizationId: string; + productKey: BillingProductKey; + skuKey: BillingSkuKey | null; + }) { + const existing = await db.billingCreditBalance.findFirst({ + where: { + organizationId: params.organizationId, + productKey: params.productKey, + skuKey: params.skuKey, + }, + }); + if (existing) return existing; + + try { + return await db.billingCreditBalance.create({ + data: { + organizationId: params.organizationId, + productKey: params.productKey, + skuKey: params.skuKey, + lastSource: 'manual', + }, + }); + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + return db.billingCreditBalance.findFirstOrThrow({ + where: { + organizationId: params.organizationId, + productKey: params.productKey, + skuKey: params.skuKey, + }, + }); + } + } + + private async getBalance(id: string): Promise { + const balance = await db.billingCreditBalance.findUniqueOrThrow({ + where: { id }, + }); + return { + id: balance.id, + productKey: assertProductKey(balance.productKey), + skuKey: balance.skuKey, + balance: balance.balance, + totalGranted: balance.totalGranted, + totalConsumed: balance.totalConsumed, + totalRefunded: balance.totalRefunded, + lastSource: balance.lastSource, + updatedAt: balance.updatedAt.toISOString(), + }; + } +} diff --git a/apps/api/src/billing/billing-credits.types.ts b/apps/api/src/billing/billing-credits.types.ts new file mode 100644 index 0000000000..865cd9f320 --- /dev/null +++ b/apps/api/src/billing/billing-credits.types.ts @@ -0,0 +1,76 @@ +import { BadRequestException } from '@nestjs/common'; +import { + getBillingSkuProductKey, + type BillingProductKey, + type BillingSkuKey, +} from '@trycompai/billing'; + +export type BillingCreditEventType = + | 'grant' + | 'consume' + | 'refund' + | 'adjustment' + | 'migration'; + +export function assertCreditEventType(value: string): BillingCreditEventType { + if ( + value === 'grant' || + value === 'consume' || + value === 'refund' || + value === 'adjustment' || + value === 'migration' + ) { + return value; + } + throw new BadRequestException('Unsupported billing credit event type.'); +} + +export interface BillingCreditBalanceSummary { + id: string; + productKey: BillingProductKey; + skuKey: string | null; + balance: number; + totalGranted: number; + totalConsumed: number; + totalRefunded: number; + lastSource: string; + updatedAt: string; +} + +export interface BillingCreditEventSummary { + id: string; + productKey: BillingProductKey; + skuKey: string | null; + eventType: BillingCreditEventType; + quantity: number; + source: string; + note: string | null; + adminUserId: string | null; + sourceResourceId: string | null; + createdAt: string; +} + +export function validateCreditInput(params: { + productKey: BillingProductKey; + skuKey?: BillingSkuKey | null; + quantity: number; + note: string; +}) { + if (!Number.isInteger(params.quantity) || params.quantity < 1) { + throw new BadRequestException('Credit amount must be a positive integer.'); + } + if (!params.note.trim()) { + throw new BadRequestException('A note is required for credit grants.'); + } + if (params.skuKey) { + const productKey = getBillingSkuProductKey(params.skuKey); + if (productKey !== params.productKey) { + throw new BadRequestException('SKU does not belong to product.'); + } + } +} + +export function assertProductKey(value: string): BillingProductKey { + if (value === 'pentest' || value === 'background_check') return value; + throw new BadRequestException('Unsupported billing product.'); +} diff --git a/apps/api/src/billing/billing-customer.spec.ts b/apps/api/src/billing/billing-customer.spec.ts new file mode 100644 index 0000000000..abb8b78f31 --- /dev/null +++ b/apps/api/src/billing/billing-customer.spec.ts @@ -0,0 +1,64 @@ +import { db } from '@db'; +import { findOrCreateBillingCustomer } from './billing-customer'; + +jest.mock('@db', () => ({ + Prisma: { + PrismaClientKnownRequestError: class PrismaClientKnownRequestError extends Error { + code: string; + + constructor(code: string) { + super(code); + this.code = code; + } + }, + }, + db: { + organization: { findUniqueOrThrow: jest.fn() }, + organizationBilling: { create: jest.fn(), findUnique: jest.fn() }, + }, +})); + +const mockedDb = db as unknown as { + organization: { findUniqueOrThrow: jest.Mock }; + organizationBilling: { create: jest.Mock; findUnique: jest.Mock }; +}; + +describe('findOrCreateBillingCustomer', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockedDb.organizationBilling.findUnique.mockResolvedValue(null); + mockedDb.organization.findUniqueOrThrow.mockResolvedValue({ + id: 'org_1', + name: 'Acme', + }); + mockedDb.organizationBilling.create.mockResolvedValue({}); + }); + + it('keeps customer creation idempotent when caller email varies', async () => { + const customersCreate = jest.fn().mockResolvedValue({ id: 'cus_1' }); + const customersUpdate = jest.fn().mockResolvedValue({ id: 'cus_1' }); + + await expect( + findOrCreateBillingCustomer({ + stripeService: { + getClient: () => ({ + customers: { create: customersCreate, update: customersUpdate }, + }), + } as never, + organizationId: 'org_1', + customerEmail: 'billing@example.com', + }), + ).resolves.toBe('cus_1'); + + expect(customersCreate).toHaveBeenCalledWith( + { + name: 'Acme', + metadata: { organizationId: 'org_1' }, + }, + { idempotencyKey: 'organization-billing-customer:org_1' }, + ); + expect(customersUpdate).toHaveBeenCalledWith('cus_1', { + email: 'billing@example.com', + }); + }); +}); diff --git a/apps/api/src/billing/billing-customer.ts b/apps/api/src/billing/billing-customer.ts new file mode 100644 index 0000000000..370dd6fc58 --- /dev/null +++ b/apps/api/src/billing/billing-customer.ts @@ -0,0 +1,63 @@ +import { Prisma, db } from '@db'; +import type { StripeService } from '../stripe/stripe.service'; + +export async function findOrCreateBillingCustomer(params: { + stripeService: StripeService; + organizationId: string; + customerEmail?: string; +}): Promise { + const existing = await db.organizationBilling.findUnique({ + where: { organizationId: params.organizationId }, + select: { stripeCustomerId: true }, + }); + if (existing) { + return existing.stripeCustomerId; + } + + const organization = await db.organization.findUniqueOrThrow({ + where: { id: params.organizationId }, + select: { id: true, name: true }, + }); + const stripe = params.stripeService.getClient(); + const customer = await stripe.customers.create( + { + name: organization.name, + metadata: { organizationId: organization.id }, + }, + { + idempotencyKey: ['organization-billing-customer', organization.id].join( + ':', + ), + }, + ); + + try { + await db.organizationBilling.create({ + data: { + organizationId: organization.id, + stripeCustomerId: customer.id, + }, + }); + if (params.customerEmail) { + await stripe.customers.update(customer.id, { + email: params.customerEmail, + }); + } + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + const raced = await db.organizationBilling.findUniqueOrThrow({ + where: { organizationId: organization.id }, + select: { stripeCustomerId: true }, + }); + return raced.stripeCustomerId; + } + + return customer.id; +} + +function isUniqueConstraintError(error: unknown): boolean { + return ( + error instanceof Prisma.PrismaClientKnownRequestError && + error.code === 'P2002' + ); +} diff --git a/apps/api/src/billing/billing-entitlements.service.spec.ts b/apps/api/src/billing/billing-entitlements.service.spec.ts new file mode 100644 index 0000000000..3cdce30579 --- /dev/null +++ b/apps/api/src/billing/billing-entitlements.service.spec.ts @@ -0,0 +1,360 @@ +import { db } from '@db'; +import { BillingEntitlementsService } from './billing-entitlements.service'; + +jest.mock('@db', () => ({ + db: { + organizationBillingSubscription: { + findUnique: jest.fn(), + findMany: jest.fn(), + }, + billingUsageEvent: { + findFirst: jest.fn(), + findUnique: jest.fn(), + }, + billingAuditEvent: { + create: jest.fn(), + }, + $transaction: jest.fn(), + }, +})); + +type MockTx = { + organizationBillingSubscription: { + create: jest.Mock; + findUnique: jest.Mock; + updateMany: jest.Mock; + }; + billingUsageEvent: { + create: jest.Mock; + findFirst: jest.Mock; + findUnique: jest.Mock; + }; +}; + +const mockedDb = db as unknown as { + organizationBillingSubscription: { + findUnique: jest.Mock; + findMany: jest.Mock; + }; + billingUsageEvent: { findFirst: jest.Mock; findUnique: jest.Mock }; + billingAuditEvent: { create: jest.Mock }; + $transaction: jest.Mock; +}; + +describe('BillingEntitlementsService', () => { + let tx: MockTx; + let service: BillingEntitlementsService; + + beforeEach(() => { + jest.clearAllMocks(); + tx = { + organizationBillingSubscription: { + create: jest.fn(), + findUnique: jest.fn(), + updateMany: jest.fn().mockResolvedValue({ count: 1 }), + }, + billingUsageEvent: { + create: jest.fn(), + findFirst: jest.fn(), + findUnique: jest.fn(), + }, + }; + mockedDb.billingUsageEvent.findFirst.mockResolvedValue(null); + mockedDb.billingUsageEvent.findUnique.mockResolvedValue(null); + mockedDb.$transaction.mockImplementation( + (callback: (tx: MockTx) => Promise) => callback(tx), + ); + mockedDb.billingAuditEvent.create.mockResolvedValue({}); + service = new BillingEntitlementsService(); + }); + + it('applies same-period subscription updates that shorten currentPeriodEnd', async () => { + tx.organizationBillingSubscription.findUnique.mockResolvedValue({ + id: 'obs_1', + stripeSubscriptionItemId: 'si_1', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + }); + + await service.syncSubscriptionItem({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_5', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripePriceId: 'price_1', + stripeStatus: 'canceled', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-04-15T00:00:00.000Z'), + includedQuantity: 5, + cancelAtPeriodEnd: false, + canceledAt: new Date('2026-04-10T00:00:00.000Z'), + stripeEventId: 'evt_1', + }); + + expect(tx.organizationBillingSubscription.updateMany).toHaveBeenCalledWith( + expect.objectContaining({ + where: expect.objectContaining({ + id: 'obs_1', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + }), + data: expect.not.objectContaining({ usedQuantity: 0 }), + }), + ); + expect(tx.billingUsageEvent.create).not.toHaveBeenCalled(); + }); + + it('retries without resetting usage when another sync already advanced the period', async () => { + tx.organizationBillingSubscription.findUnique + .mockResolvedValueOnce({ + id: 'obs_1', + stripeSubscriptionItemId: 'si_1', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + }) + .mockResolvedValueOnce({ + id: 'obs_1', + stripeSubscriptionItemId: 'si_1', + currentPeriodStart: new Date('2026-05-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-06-01T00:00:00.000Z'), + }); + tx.organizationBillingSubscription.updateMany + .mockResolvedValueOnce({ count: 0 }) + .mockResolvedValueOnce({ count: 1 }); + + await service.syncSubscriptionItem({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_5', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripePriceId: 'price_1', + stripeStatus: 'active', + currentPeriodStart: new Date('2026-05-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-06-01T00:00:00.000Z'), + includedQuantity: 5, + cancelAtPeriodEnd: false, + canceledAt: null, + stripeEventId: 'evt_2', + }); + + expect( + tx.organizationBillingSubscription.updateMany, + ).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ + where: { + id: 'obs_1', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + }, + data: expect.objectContaining({ usedQuantity: 0 }), + }), + ); + expect( + tx.organizationBillingSubscription.updateMany, + ).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ + where: { + id: 'obs_1', + currentPeriodStart: new Date('2026-05-01T00:00:00.000Z'), + }, + data: expect.not.objectContaining({ usedQuantity: 0 }), + }), + ); + expect(tx.billingUsageEvent.create).not.toHaveBeenCalled(); + expect(mockedDb.billingAuditEvent.create).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ + eventType: 'subscription_synced', + skuKey: 'pentest_monthly_5', + }), + }), + ); + }); + + it('consumes the active subscription for a product family', async () => { + mockedDb.organizationBillingSubscription.findMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_4', + stripeStatus: 'active', + usedQuantity: 1, + includedQuantity: 4, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + ]); + mockedDb.organizationBillingSubscription.findUnique.mockResolvedValue({ + id: 'obs_1', + skuKey: 'pentest_monthly_4', + stripeStatus: 'active', + usedQuantity: 1, + includedQuantity: 4, + stripeSubscriptionItemId: 'si_1', + currentPeriodStart: new Date('2026-04-30T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }); + tx.organizationBillingSubscription.updateMany.mockResolvedValue({ + count: 1, + }); + + await expect( + service.tryConsumeIncludedUsageForProduct({ + organizationId: 'org_1', + productKey: 'pentest', + sourceResourceId: 'run_1', + }), + ).resolves.toEqual({ status: 'consumed', subscriptionId: 'obs_1' }); + + expect( + mockedDb.organizationBillingSubscription.findUnique, + ).toHaveBeenCalledWith({ + where: { + organizationId_skuKey: { + organizationId: 'org_1', + skuKey: 'pentest_monthly_4', + }, + }, + }); + }); + + it('treats product consumption retries as consumed before credit fallback', async () => { + const credits = { tryConsumeForProduct: jest.fn() }; + service = new BillingEntitlementsService(credits as never); + mockedDb.organizationBillingSubscription.findMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_1', + stripeStatus: 'active', + usedQuantity: 1, + includedQuantity: 1, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + ]); + mockedDb.billingUsageEvent.findFirst.mockResolvedValue({ + skuKey: 'pentest_monthly_1', + }); + + await expect( + service.tryConsumeIncludedUsageForProduct({ + organizationId: 'org_1', + productKey: 'pentest', + sourceResourceId: 'run_1', + }), + ).resolves.toEqual({ status: 'consumed', subscriptionId: 'obs_1' }); + + expect(credits.tryConsumeForProduct).not.toHaveBeenCalled(); + }); + + it('falls back to credits when included usage is exhausted during consume', async () => { + const credits = { + tryConsumeForProduct: jest.fn().mockResolvedValue({ status: 'consumed' }), + }; + service = new BillingEntitlementsService(credits as never); + mockedDb.organizationBillingSubscription.findMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_1', + stripeStatus: 'active', + usedQuantity: 0, + includedQuantity: 1, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + ]); + mockedDb.organizationBillingSubscription.findUnique.mockResolvedValue({ + id: 'obs_1', + skuKey: 'pentest_monthly_1', + stripeStatus: 'active', + usedQuantity: 0, + includedQuantity: 1, + stripeSubscriptionItemId: 'si_1', + currentPeriodStart: new Date('2026-04-30T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }); + tx.organizationBillingSubscription.updateMany.mockResolvedValue({ + count: 0, + }); + + await expect( + service.tryConsumeIncludedUsageForProduct({ + organizationId: 'org_1', + productKey: 'pentest', + sourceResourceId: 'run_1', + }), + ).resolves.toEqual({ status: 'consumed', subscriptionId: 'manual_credit' }); + + expect(credits.tryConsumeForProduct).toHaveBeenCalledWith({ + organizationId: 'org_1', + productKey: 'pentest', + sourceResourceId: 'run_1', + }); + }); + + it('returns exhausted instead of throwing when allowance is concurrently exhausted', async () => { + mockedDb.organizationBillingSubscription.findUnique.mockResolvedValue({ + id: 'obs_1', + skuKey: 'background_checks_monthly_3', + stripeStatus: 'active', + usedQuantity: 0, + includedQuantity: 1, + stripeSubscriptionItemId: 'si_1', + currentPeriodStart: new Date('2026-04-30T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }); + tx.organizationBillingSubscription.updateMany.mockResolvedValue({ + count: 0, + }); + + await expect( + service.tryConsumeIncludedUsage({ + organizationId: 'org_1', + skuKey: 'background_checks_monthly_3', + sourceResourceId: 'mem_1', + }), + ).resolves.toEqual({ status: 'exhausted', subscriptionId: 'obs_1' }); + }); + + it('uses credit refund fallback even with a transaction client', async () => { + const credits = { refundForProduct: jest.fn().mockResolvedValue(true) }; + service = new BillingEntitlementsService(credits as never); + tx.billingUsageEvent.findFirst.mockResolvedValue(null); + + await service.refundIncludedUsageForProduct({ + organizationId: 'org_1', + productKey: 'pentest', + sourceResourceId: 'run_1', + reason: 'canceled', + tx: tx as never, + }); + + expect(credits.refundForProduct).toHaveBeenCalledWith({ + organizationId: 'org_1', + productKey: 'pentest', + sourceResourceId: 'run_1', + reason: 'canceled', + tx, + }); + }); + + it('uses a stable included-usage refund key across reasons', async () => { + tx.billingUsageEvent.findUnique.mockResolvedValue({ + stripeSubscriptionItemId: 'si_1', + periodStart: new Date('2026-04-30T00:00:00.000Z'), + periodEnd: new Date('2026-05-30T00:00:00.000Z'), + }); + + await service.refundIncludedUsage({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_1', + sourceResourceId: 'run_1', + reason: 'first reason', + tx: tx as never, + }); + + expect(tx.billingUsageEvent.create).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ + idempotencyKey: 'refund:org_1:pentest_monthly_1:run_1', + }), + }), + ); + }); +}); diff --git a/apps/api/src/billing/billing-entitlements.service.ts b/apps/api/src/billing/billing-entitlements.service.ts new file mode 100644 index 0000000000..8fc87be63f --- /dev/null +++ b/apps/api/src/billing/billing-entitlements.service.ts @@ -0,0 +1,434 @@ +import { Injectable, Optional } from '@nestjs/common'; +import { Prisma, db } from '@db'; +import { + getBillingSkuKeysForProduct, + type BillingProductKey, + type BillingSkuKey, +} from '@trycompai/billing'; +import { BillingCreditsService } from './billing-credits.service'; +import { refundIncludedUsageEvent } from './billing-included-usage-refunds'; +import { + type BillingConsumeResult, + isAccessStatus, + isUniqueConstraintError, + sameTime, + type SyncSubscriptionItemParams, + type WriteBillingAuditEventParams, +} from './billing-entitlements.types'; + +@Injectable() +export class BillingEntitlementsService { + constructor(@Optional() private readonly credits?: BillingCreditsService) {} + + async tryConsumeIncludedUsageForProduct(params: { + organizationId: string; + productKey: BillingProductKey; + sourceResourceId: string; + }): Promise { + const skuKeys = getBillingSkuKeysForProduct(params.productKey); + const subscriptions = await db.organizationBillingSubscription.findMany({ + where: { + organizationId: params.organizationId, + skuKey: { in: skuKeys }, + }, + orderBy: [{ createdAt: 'desc' }], + }); + + const activeSubscription = subscriptions.find( + (subscription) => + isAccessStatus(subscription.stripeStatus) && + (!subscription.currentPeriodEnd || + subscription.currentPeriodEnd.getTime() > Date.now()), + ); + if (!activeSubscription) { + return this.tryConsumeCreditFallback({ + ...params, + fallbackStatus: 'not_configured', + }); + } + + if ( + activeSubscription.usedQuantity >= activeSubscription.includedQuantity + ) { + const existingUsage = await this.findExistingIncludedUsageForProduct({ + organizationId: params.organizationId, + skuKeys, + sourceResourceId: params.sourceResourceId, + subscriptions, + }); + if (existingUsage) return existingUsage; + + const creditResult = await this.tryConsumeCreditFallback({ + ...params, + fallbackStatus: 'exhausted', + }); + return creditResult.status === 'consumed' + ? creditResult + : { status: 'exhausted', subscriptionId: activeSubscription.id }; + } + + const usageResult = await this.tryConsumeIncludedUsage({ + organizationId: params.organizationId, + skuKey: activeSubscription.skuKey as BillingSkuKey, + sourceResourceId: params.sourceResourceId, + }); + if (usageResult.status === 'consumed') return usageResult; + + const creditResult = await this.tryConsumeCreditFallback({ + ...params, + fallbackStatus: + usageResult.status === 'exhausted' ? 'exhausted' : 'not_configured', + }); + return creditResult.status === 'consumed' ? creditResult : usageResult; + } + + async tryConsumeIncludedUsage(params: { + organizationId: string; + skuKey: BillingSkuKey; + sourceResourceId: string; + }): Promise { + const subscription = await db.organizationBillingSubscription.findUnique({ + where: { + organizationId_skuKey: { + organizationId: params.organizationId, + skuKey: params.skuKey, + }, + }, + }); + + if (!subscription || !isAccessStatus(subscription.stripeStatus)) { + return { status: 'not_configured' }; + } + + const idempotencyKey = [ + 'consume', + params.organizationId, + params.skuKey, + params.sourceResourceId, + ].join(':'); + const existingUsage = await db.billingUsageEvent.findUnique({ + where: { idempotencyKey }, + select: { id: true }, + }); + if (existingUsage) { + return { status: 'consumed', subscriptionId: subscription.id }; + } + + if ( + subscription.currentPeriodEnd && + subscription.currentPeriodEnd.getTime() <= Date.now() + ) { + return { status: 'not_configured' }; + } + + if (subscription.usedQuantity >= subscription.includedQuantity) { + return { status: 'exhausted', subscriptionId: subscription.id }; + } + + try { + await db.$transaction(async (tx) => { + await tx.billingUsageEvent.create({ + data: { + organizationId: params.organizationId, + skuKey: params.skuKey, + eventType: 'consume', + quantity: 1, + sourceResourceId: params.sourceResourceId, + idempotencyKey, + stripeSubscriptionItemId: subscription.stripeSubscriptionItemId, + periodStart: subscription.currentPeriodStart, + periodEnd: subscription.currentPeriodEnd, + }, + }); + + const updated = await tx.organizationBillingSubscription.updateMany({ + where: { + id: subscription.id, + usedQuantity: { lt: subscription.includedQuantity }, + }, + data: { usedQuantity: { increment: 1 } }, + }); + + if (updated.count === 0) { + throw new BillingAllowanceExhaustedError(); + } + }); + } catch (error) { + if (error instanceof BillingAllowanceExhaustedError) { + return { status: 'exhausted', subscriptionId: subscription.id }; + } + if (isUniqueConstraintError(error)) { + return { status: 'consumed', subscriptionId: subscription.id }; + } + throw error; + } + + return { status: 'consumed', subscriptionId: subscription.id }; + } + + async syncSubscriptionItem( + params: SyncSubscriptionItemParams, + ): Promise { + for (let attempt = 0; attempt < 2; attempt += 1) { + let syncResult: { didSync: boolean; retry: boolean }; + try { + syncResult = await db.$transaction(async (tx) => { + const existing = await tx.organizationBillingSubscription.findUnique({ + where: { + organizationId_skuKey: { + organizationId: params.organizationId, + skuKey: params.skuKey, + }, + }, + select: { + id: true, + currentPeriodStart: true, + }, + }); + if ( + existing?.currentPeriodStart && + params.currentPeriodStart && + existing.currentPeriodStart.getTime() > + params.currentPeriodStart.getTime() + ) { + return { didSync: false, retry: false }; + } + + const resetUsage = + !existing || + !sameTime(existing.currentPeriodStart, params.currentPeriodStart); + if (!existing) { + await tx.organizationBillingSubscription.create({ + data: { + organizationId: params.organizationId, + skuKey: params.skuKey, + stripeSubscriptionId: params.stripeSubscriptionId, + stripeSubscriptionItemId: params.stripeSubscriptionItemId, + stripePriceId: params.stripePriceId, + stripeStatus: params.stripeStatus, + currentPeriodStart: params.currentPeriodStart, + currentPeriodEnd: params.currentPeriodEnd, + includedQuantity: params.includedQuantity, + cancelAtPeriodEnd: params.cancelAtPeriodEnd, + canceledAt: params.canceledAt, + }, + }); + } else { + const updated = await tx.organizationBillingSubscription.updateMany( + { + where: { + id: existing.id, + currentPeriodStart: existing.currentPeriodStart, + }, + data: { + stripeSubscriptionId: params.stripeSubscriptionId, + stripeSubscriptionItemId: params.stripeSubscriptionItemId, + stripePriceId: params.stripePriceId, + stripeStatus: params.stripeStatus, + currentPeriodStart: params.currentPeriodStart, + currentPeriodEnd: params.currentPeriodEnd, + includedQuantity: params.includedQuantity, + ...(resetUsage ? { usedQuantity: 0 } : {}), + cancelAtPeriodEnd: params.cancelAtPeriodEnd, + canceledAt: params.canceledAt, + }, + }, + ); + if (updated.count === 0) { + return { didSync: false, retry: true }; + } + } + + if (resetUsage) { + const idempotencyKey = [ + 'grant', + params.organizationId, + params.skuKey, + params.stripeSubscriptionItemId, + params.currentPeriodStart?.toISOString() ?? 'none', + params.currentPeriodEnd?.toISOString() ?? 'none', + params.stripeEventId ?? 'manual', + ].join(':'); + await tx.billingUsageEvent.create({ + data: { + organizationId: params.organizationId, + skuKey: params.skuKey, + eventType: 'grant', + quantity: params.includedQuantity, + idempotencyKey, + stripeEventId: params.stripeEventId, + stripeSubscriptionItemId: params.stripeSubscriptionItemId, + periodStart: params.currentPeriodStart, + periodEnd: params.currentPeriodEnd, + }, + }); + } + return { didSync: true, retry: false }; + }); + } catch (error) { + if (isUniqueConstraintError(error)) { + return; + } + throw error; + } + if (!syncResult.didSync && syncResult.retry) continue; + if (!syncResult.didSync) return; + + await this.writeAuditEvent({ + organizationId: params.organizationId, + eventType: 'subscription_synced', + skuKey: params.skuKey, + stripeEventId: params.stripeEventId, + metadata: { + stripeSubscriptionId: params.stripeSubscriptionId, + stripeSubscriptionItemId: params.stripeSubscriptionItemId, + stripeStatus: params.stripeStatus, + }, + }); + return; + } + } + + async recordOneTimeUsage(params: { + organizationId: string; + skuKey: BillingSkuKey; + sourceResourceId: string; + stripeInvoiceId?: string; + }): Promise { + const idempotencyKey = [ + 'one-time', + params.organizationId, + params.skuKey, + params.sourceResourceId, + ].join(':'); + + try { + await db.billingUsageEvent.create({ + data: { + organizationId: params.organizationId, + skuKey: params.skuKey, + eventType: 'one_time', + quantity: 1, + sourceResourceId: params.sourceResourceId, + idempotencyKey, + stripeInvoiceId: params.stripeInvoiceId, + }, + }); + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + } + } + + async refundIncludedUsage(params: { + organizationId: string; + skuKey: BillingSkuKey; + sourceResourceId: string; + reason: string; + tx?: Prisma.TransactionClient; + }): Promise { + await refundIncludedUsageEvent(params); + } + + async refundIncludedUsageForProduct(params: { + organizationId: string; + productKey: BillingProductKey; + sourceResourceId: string; + reason: string; + tx?: Prisma.TransactionClient; + }): Promise { + const skuKeys = getBillingSkuKeysForProduct(params.productKey); + const client = params.tx ?? db; + const consumed = await client.billingUsageEvent.findFirst({ + where: { + organizationId: params.organizationId, + skuKey: { in: skuKeys }, + eventType: 'consume', + sourceResourceId: params.sourceResourceId, + }, + orderBy: { createdAt: 'desc' }, + select: { skuKey: true }, + }); + if (!consumed) { + if (!this.credits) return; + await this.credits.refundForProduct({ + organizationId: params.organizationId, + productKey: params.productKey, + sourceResourceId: params.sourceResourceId, + reason: params.reason, + tx: params.tx, + }); + return; + } + await refundIncludedUsageEvent({ + organizationId: params.organizationId, + skuKey: consumed.skuKey as BillingSkuKey, + sourceResourceId: params.sourceResourceId, + reason: params.reason, + tx: params.tx, + }); + } + + private async tryConsumeCreditFallback(params: { + organizationId: string; + productKey: BillingProductKey; + sourceResourceId: string; + fallbackStatus: 'not_configured' | 'exhausted'; + }): Promise { + if (!this.credits) { + return params.fallbackStatus === 'exhausted' + ? { status: 'exhausted', subscriptionId: 'manual_credit' } + : { status: 'not_configured' }; + } + const creditResult = await this.credits.tryConsumeForProduct({ + organizationId: params.organizationId, + productKey: params.productKey, + sourceResourceId: params.sourceResourceId, + }); + return creditResult.status === 'consumed' + ? { status: 'consumed', subscriptionId: 'manual_credit' } + : params.fallbackStatus === 'exhausted' + ? { status: 'exhausted', subscriptionId: 'manual_credit' } + : { status: 'not_configured' }; + } + + private async findExistingIncludedUsageForProduct(params: { + organizationId: string; + skuKeys: string[]; + sourceResourceId: string; + subscriptions: { id: string; skuKey: string }[]; + }): Promise { + const existingUsage = await db.billingUsageEvent.findFirst({ + where: { + organizationId: params.organizationId, + skuKey: { in: params.skuKeys }, + eventType: 'consume', + sourceResourceId: params.sourceResourceId, + }, + select: { skuKey: true }, + orderBy: { createdAt: 'desc' }, + }); + if (!existingUsage) return null; + + const subscription = params.subscriptions.find( + (item) => item.skuKey === existingUsage.skuKey, + ); + return { + status: 'consumed', + subscriptionId: subscription?.id ?? 'included_usage', + }; + } + + async writeAuditEvent(params: WriteBillingAuditEventParams): Promise { + await db.billingAuditEvent.create({ + data: { + organizationId: params.organizationId, + eventType: params.eventType, + skuKey: params.skuKey, + stripeEventId: params.stripeEventId, + metadata: params.metadata, + }, + }); + } +} + +class BillingAllowanceExhaustedError extends Error {} diff --git a/apps/api/src/billing/billing-entitlements.types.ts b/apps/api/src/billing/billing-entitlements.types.ts new file mode 100644 index 0000000000..6c95d092f0 --- /dev/null +++ b/apps/api/src/billing/billing-entitlements.types.ts @@ -0,0 +1,47 @@ +import { Prisma } from '@db'; +import type { BillingSkuKey } from '@trycompai/billing'; + +export type BillingConsumeResult = + | { status: 'consumed'; subscriptionId: string } + | { status: 'exhausted'; subscriptionId: string } + | { status: 'not_configured' }; + +export type SyncSubscriptionItemParams = { + organizationId: string; + skuKey: BillingSkuKey; + stripeSubscriptionId: string; + stripeSubscriptionItemId: string; + stripePriceId: string; + stripeStatus: string; + currentPeriodStart: Date | null; + currentPeriodEnd: Date | null; + includedQuantity: number; + cancelAtPeriodEnd: boolean; + canceledAt: Date | null; + stripeEventId?: string; +}; + +export type WriteBillingAuditEventParams = { + organizationId: string; + eventType: string; + skuKey?: string; + stripeEventId?: string; + metadata?: Prisma.InputJsonValue; +}; + +export function isAccessStatus(status: string): boolean { + return status === 'active' || status === 'trialing'; +} + +export function sameTime(left: Date | null, right: Date | null): boolean { + if (!left && !right) return true; + if (!left || !right) return false; + return left.getTime() === right.getTime(); +} + +export function isUniqueConstraintError(error: unknown): boolean { + return ( + error instanceof Prisma.PrismaClientKnownRequestError && + error.code === 'P2002' + ); +} diff --git a/apps/api/src/billing/billing-included-usage-refunds.ts b/apps/api/src/billing/billing-included-usage-refunds.ts new file mode 100644 index 0000000000..3de6a02746 --- /dev/null +++ b/apps/api/src/billing/billing-included-usage-refunds.ts @@ -0,0 +1,69 @@ +import { Prisma, db } from '@db'; +import type { BillingSkuKey } from '@trycompai/billing'; +import { isUniqueConstraintError } from './billing-entitlements.types'; + +export async function refundIncludedUsageEvent(params: { + organizationId: string; + skuKey: BillingSkuKey; + sourceResourceId: string; + reason: string; + tx?: Prisma.TransactionClient; +}): Promise { + const consumeKey = [ + 'consume', + params.organizationId, + params.skuKey, + params.sourceResourceId, + ].join(':'); + const refundKey = [ + 'refund', + params.organizationId, + params.skuKey, + params.sourceResourceId, + ].join(':'); + + try { + const writeRefund = async (tx: Prisma.TransactionClient) => { + const consumed = await tx.billingUsageEvent.findUnique({ + where: { idempotencyKey: consumeKey }, + select: { + stripeSubscriptionItemId: true, + periodStart: true, + periodEnd: true, + }, + }); + if (!consumed) return; + + await tx.billingUsageEvent.create({ + data: { + organizationId: params.organizationId, + skuKey: params.skuKey, + eventType: 'refund', + quantity: 1, + sourceResourceId: params.sourceResourceId, + idempotencyKey: refundKey, + stripeSubscriptionItemId: consumed.stripeSubscriptionItemId, + periodStart: consumed.periodStart, + periodEnd: consumed.periodEnd, + }, + }); + + await tx.organizationBillingSubscription.updateMany({ + where: { + organizationId: params.organizationId, + skuKey: params.skuKey, + usedQuantity: { gt: 0 }, + }, + data: { usedQuantity: { decrement: 1 } }, + }); + }; + + if (params.tx) { + await writeRefund(params.tx); + } else { + await db.$transaction(writeRefund); + } + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + } +} diff --git a/apps/api/src/billing/billing-invoices.spec.ts b/apps/api/src/billing/billing-invoices.spec.ts new file mode 100644 index 0000000000..7dee6a84e8 --- /dev/null +++ b/apps/api/src/billing/billing-invoices.spec.ts @@ -0,0 +1,105 @@ +import type Stripe from 'stripe'; +import type { StripeService } from '../stripe/stripe.service'; +import { listBillingInvoices } from './billing-invoices'; + +function mockStripeService(params: { invoices: Stripe.Invoice[] }): StripeService { + const invoicesList = jest.fn().mockResolvedValue({ data: params.invoices }); + return { + isConfigured: () => true, + getClient: () => ({ + invoices: { + list: invoicesList, + }, + }), + } as unknown as StripeService; +} + +function createInvoice(params: { + id: string; + amountPaid: number; + priceId: string; + billingReason?: Stripe.Invoice.BillingReason | null; +}): Stripe.Invoice { + const line = { + id: `il_${params.id}`, + object: 'line_item', + parent: null, + pricing: { + type: 'price_details', + unit_amount_decimal: String(params.amountPaid), + price_details: { + price: params.priceId, + product: 'prod_test', + }, + }, + subscription: null, + } as unknown as Stripe.InvoiceLineItem; + + return { + id: params.id, + number: `${params.id}-0001`, + created: 1777564800, + due_date: null, + amount_paid: params.amountPaid, + amount_due: params.amountPaid, + currency: 'usd', + status: 'paid', + billing_reason: params.billingReason ?? 'manual', + hosted_invoice_url: `https://invoice.stripe.test/${params.id}`, + invoice_pdf: `https://invoice.stripe.test/${params.id}.pdf`, + lines: { + object: 'list', + data: [line], + has_more: false, + url: `/v1/invoices/${params.id}/lines`, + }, + } as unknown as Stripe.Invoice; +} + +describe('listBillingInvoices', () => { + it('labels subscription catalog invoices as subscriptions', async () => { + const invoices = await listBillingInvoices({ + stripeService: mockStripeService({ + invoices: [ + createInvoice({ + id: 'in_subscription', + amountPaid: 39900, + priceId: 'price_1TRya6CkFWhKYvHI1sJ2M2no', + }), + ], + }), + stripeCustomerId: 'cus_1', + }); + + expect(invoices).toEqual([ + expect.objectContaining({ + id: 'in_subscription', + amountPaid: 39900, + type: 'Subscription', + }), + ]); + }); + + it('keeps one-time catalog invoices labelled as one-time', async () => { + const invoices = await listBillingInvoices({ + stripeService: mockStripeService({ + invoices: [ + createInvoice({ + id: 'in_one_time', + amountPaid: 4900, + priceId: 'price_1TRWckCkFWhKYvHIA1GLv1sO', + }), + ], + }), + stripeCustomerId: 'cus_1', + }); + + expect(invoices).toEqual([ + expect.objectContaining({ + id: 'in_one_time', + amountPaid: 4900, + type: 'One Time', + }), + ]); + }); +}); diff --git a/apps/api/src/billing/billing-invoices.ts b/apps/api/src/billing/billing-invoices.ts new file mode 100644 index 0000000000..e855a770b4 --- /dev/null +++ b/apps/api/src/billing/billing-invoices.ts @@ -0,0 +1,95 @@ +import type Stripe from 'stripe'; +import { getBillingSkuByStripePriceId } from '@trycompai/billing'; +import type { StripeService } from '../stripe/stripe.service'; + +export interface BillingInvoice { + id: string; + number: string; + createdAt: string; + dueDate: string | null; + amountPaid: number; + amountDue: number; + currency: string; + status: string; + type: 'Subscription' | 'One Time'; + hostedInvoiceUrl: string | null; + invoicePdfUrl: string | null; +} + +export async function listBillingInvoices(params: { + stripeService: StripeService; + stripeCustomerId: string | null; +}): Promise { + if (!params.stripeCustomerId || !params.stripeService.isConfigured()) { + return []; + } + + const stripe = params.stripeService.getClient(); + const invoices = await stripe.invoices.list({ + customer: params.stripeCustomerId, + limit: 20, + }); + + return invoices.data.map(mapInvoice); +} + +function mapInvoice(invoice: Stripe.Invoice): BillingInvoice { + return { + id: invoice.id, + number: invoice.number ?? invoice.id, + createdAt: new Date(invoice.created * 1000).toISOString(), + dueDate: invoice.due_date + ? new Date(invoice.due_date * 1000).toISOString() + : null, + amountPaid: invoice.amount_paid, + amountDue: invoice.amount_due, + currency: invoice.currency, + status: invoice.status ?? 'unknown', + type: hasSubscription(invoice) ? 'Subscription' : 'One Time', + hostedInvoiceUrl: invoice.hosted_invoice_url ?? null, + invoicePdfUrl: invoice.invoice_pdf ?? null, + }; +} + +function hasSubscription(invoice: Stripe.Invoice): boolean { + if (invoice.billing_reason?.startsWith('subscription')) { + return true; + } + + return invoice.lines.data.some(isSubscriptionLine); +} + +function isSubscriptionLine(line: Stripe.InvoiceLineItem): boolean { + if (line.parent?.subscription_item_details?.subscription) { + return true; + } + + if (line.subscription) { + return true; + } + + const priceId = getLinePriceId(line); + if (!priceId) { + return false; + } + + const sku = getBillingSkuByStripePriceId({ stripePriceId: priceId }); + return sku?.cadence === 'month'; +} + +function getLinePriceId(line: Stripe.InvoiceLineItem): string | null { + const price = line.pricing?.price_details?.price; + if (typeof price === 'string') { + return price; + } + + if (isRecord(price) && typeof price.id === 'string') { + return price.id; + } + + return null; +} + +function isRecord(value: unknown): value is Record { + return typeof value === 'object' && value !== null; +} diff --git a/apps/api/src/billing/billing-preferences.spec.ts b/apps/api/src/billing/billing-preferences.spec.ts new file mode 100644 index 0000000000..72408d6457 --- /dev/null +++ b/apps/api/src/billing/billing-preferences.spec.ts @@ -0,0 +1,262 @@ +import { db } from '@db'; +import type Stripe from 'stripe'; +import type { StripeService } from '../stripe/stripe.service'; +import { updateBillingPreferences } from './billing-preferences'; + +jest.mock('@db', () => ({ + db: { + organizationBilling: { + findUnique: jest.fn(), + }, + }, +})); + +const mockedDb = db as unknown as { + organizationBilling: { findUnique: jest.Mock }; +}; + +function mockStripeService(client: unknown): StripeService { + return { + isConfigured: () => true, + getClient: () => client, + } as unknown as StripeService; +} + +function createCustomer(): Stripe.Customer { + return { + id: 'cus_1', + object: 'customer', + address: { + line1: '1 Test Street', + line2: null, + city: 'London', + state: null, + postal_code: 'SW1A 1AA', + country: 'GB', + }, + business_name: 'Test Company', + created: 1777564800, + default_source: null, + description: null, + email: 'accounts@example.com', + invoice_settings: { + custom_fields: [{ name: 'PO / Reference', value: 'PO-123' }], + default_payment_method: null, + footer: null, + rendering_options: null, + }, + livemode: false, + metadata: {}, + name: 'Test Company', + shipping: null, + } as unknown as Stripe.Customer; +} + +describe('updateBillingPreferences', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockedDb.organizationBilling.findUnique.mockResolvedValue({ + stripeCustomerId: 'cus_1', + }); + }); + + it('updates the Stripe customer fields used for B2B invoices', async () => { + const customersUpdate = jest.fn().mockResolvedValue(createCustomer()); + const taxIdsList = jest.fn().mockResolvedValue({ data: [] }); + const taxIdsCreate = jest.fn().mockResolvedValue({ + id: 'txi_1', + type: 'gb_vat', + value: 'GB123456789', + verification: { status: 'verified' }, + }); + + const result = await updateBillingPreferences({ + stripeService: mockStripeService({ + customers: { update: customersUpdate }, + taxIds: { list: taxIdsList, create: taxIdsCreate, del: jest.fn() }, + }), + organizationId: 'org_1', + preferences: { + companyName: 'Test Company', + billingEmail: 'accounts@example.com', + purchaseOrder: 'PO-123', + address: { + line1: '1 Test Street', + line2: null, + city: 'London', + state: null, + postalCode: 'SW1A 1AA', + country: 'gb', + }, + taxId: { type: 'gb_vat', value: 'GB123456789' }, + }, + }); + + expect(customersUpdate).toHaveBeenCalledWith( + 'cus_1', + expect.objectContaining({ + email: 'accounts@example.com', + name: 'Test Company', + business_name: 'Test Company', + address: expect.objectContaining({ country: 'GB' }), + invoice_settings: { + custom_fields: [{ name: 'PO / Reference', value: 'PO-123' }], + }, + }), + ); + expect(taxIdsCreate).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'gb_vat', + value: 'GB123456789', + owner: { type: 'customer', customer: 'cus_1' }, + }), + expect.objectContaining({ + idempotencyKey: expect.stringContaining('cus_1'), + }), + ); + expect(result.preferences).toEqual( + expect.objectContaining({ + billingEmail: 'accounts@example.com', + purchaseOrder: 'PO-123', + taxId: expect.objectContaining({ type: 'gb_vat' }), + }), + ); + }); + + it('sends empty strings to Stripe to clear blank address fields', async () => { + const customersUpdate = jest.fn().mockResolvedValue({ + ...createCustomer(), + address: null, + }); + const taxIdsList = jest.fn().mockResolvedValue({ data: [] }); + + await updateBillingPreferences({ + stripeService: mockStripeService({ + customers: { update: customersUpdate }, + taxIds: { list: taxIdsList, create: jest.fn(), del: jest.fn() }, + }), + organizationId: 'org_1', + preferences: { + companyName: 'Test Company', + billingEmail: 'accounts@example.com', + purchaseOrder: null, + address: { + line1: '', + line2: null, + city: '', + state: null, + postalCode: '', + country: '', + }, + taxId: null, + }, + }); + + expect(customersUpdate).toHaveBeenCalledWith( + 'cus_1', + expect.objectContaining({ + address: { + line1: '', + line2: '', + city: '', + state: '', + postal_code: '', + country: '', + }, + }), + ); + }); + + it('creates the replacement tax ID before deleting stale tax IDs', async () => { + const customersUpdate = jest.fn().mockResolvedValue(createCustomer()); + const taxIdsDelete = jest.fn().mockResolvedValue({}); + const taxIdsCreate = jest.fn().mockResolvedValue({ + id: 'txi_new', + type: 'gb_vat', + value: 'GB987654321', + verification: { status: 'verified' }, + }); + const taxIdsList = jest.fn().mockResolvedValue({ + data: [ + { id: 'txi_old_1', type: 'gb_vat', value: 'GB111111111' }, + { id: 'txi_old_2', type: 'us_ein', value: '12-3456789' }, + ], + has_more: false, + }); + + await updateBillingPreferences({ + stripeService: mockStripeService({ + customers: { update: customersUpdate }, + taxIds: { + list: taxIdsList, + create: taxIdsCreate, + del: taxIdsDelete, + }, + }), + organizationId: 'org_1', + preferences: { + companyName: 'Test Company', + billingEmail: 'accounts@example.com', + purchaseOrder: null, + address: { + line1: '1 Test Street', + line2: null, + city: 'London', + state: null, + postalCode: 'SW1A 1AA', + country: 'GB', + }, + taxId: { type: 'gb_vat', value: 'GB987654321' }, + }, + }); + + expect(taxIdsDelete).toHaveBeenCalledWith('txi_old_1'); + expect(taxIdsDelete).toHaveBeenCalledWith('txi_old_2'); + expect(taxIdsCreate).toHaveBeenCalled(); + expect(taxIdsCreate.mock.invocationCallOrder[0]).toBeLessThan( + taxIdsDelete.mock.invocationCallOrder[0], + ); + }); + + it('keeps existing tax IDs when Stripe rejects the replacement tax ID', async () => { + const customersUpdate = jest.fn().mockResolvedValue(createCustomer()); + const taxIdsDelete = jest.fn().mockResolvedValue({}); + const taxIdsCreate = jest + .fn() + .mockRejectedValue(new Error('invalid tax id')); + const taxIdsList = jest.fn().mockResolvedValue({ + data: [{ id: 'txi_old_1', type: 'gb_vat', value: 'GB111111111' }], + has_more: false, + }); + + await expect( + updateBillingPreferences({ + stripeService: mockStripeService({ + customers: { update: customersUpdate }, + taxIds: { + list: taxIdsList, + create: taxIdsCreate, + del: taxIdsDelete, + }, + }), + organizationId: 'org_1', + preferences: { + companyName: 'Test Company', + billingEmail: 'accounts@example.com', + purchaseOrder: null, + address: { + line1: '1 Test Street', + line2: null, + city: 'London', + state: null, + postalCode: 'SW1A 1AA', + country: 'GB', + }, + taxId: { type: 'gb_vat', value: 'GB987654321' }, + }, + }), + ).rejects.toThrow('invalid tax id'); + + expect(taxIdsDelete).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/api/src/billing/billing-preferences.ts b/apps/api/src/billing/billing-preferences.ts new file mode 100644 index 0000000000..5037c357d9 --- /dev/null +++ b/apps/api/src/billing/billing-preferences.ts @@ -0,0 +1,300 @@ +import type Stripe from 'stripe'; +import { BadRequestException } from '@nestjs/common'; +import type { StripeService } from '../stripe/stripe.service'; +import { findOrCreateBillingCustomer } from './billing-customer'; + +export const billingTaxIdTypes = [ + 'gb_vat', + 'eu_vat', + 'us_ein', + 'au_abn', + 'ca_bn', + 'nz_gst', + 'sg_gst', + 'sg_uen', +] as const satisfies readonly Stripe.TaxId.Type[]; + +export type BillingTaxIdType = (typeof billingTaxIdTypes)[number]; + +export interface BillingPreferences { + companyName: string | null; + billingEmail: string | null; + purchaseOrder: string | null; + address: { + line1: string | null; + line2: string | null; + city: string | null; + state: string | null; + postalCode: string | null; + country: string | null; + }; + taxId: { + id: string; + type: string; + value: string; + verificationStatus: string | null; + } | null; +} + +export interface BillingPreferencesInput { + companyName: string; + billingEmail: string; + purchaseOrder: string | null; + address: BillingPreferences['address']; + taxId: { + type: string | null; + value: string | null; + } | null; +} + +const purchaseOrderFieldName = 'PO / Reference'; + +export async function getBillingPreferences(params: { + stripeService: StripeService; + stripeCustomerId: string | null; + fallbackCompanyName: string | null; +}): Promise { + if (!params.stripeCustomerId || !params.stripeService.isConfigured()) { + return createEmptyPreferences({ companyName: params.fallbackCompanyName }); + } + + const stripe = params.stripeService.getClient(); + const [customer, taxIds] = await Promise.all([ + stripe.customers.retrieve(params.stripeCustomerId), + listCustomerTaxIds({ stripe, stripeCustomerId: params.stripeCustomerId }), + ]); + + if (isDeletedCustomer(customer)) { + return createEmptyPreferences({ companyName: params.fallbackCompanyName }); + } + + return mapCustomerPreferences({ customer, taxId: taxIds[0] ?? null }); +} + +export async function updateBillingPreferences(params: { + stripeService: StripeService; + organizationId: string; + preferences: BillingPreferencesInput; +}): Promise<{ stripeCustomerId: string; preferences: BillingPreferences }> { + validatePreferences(params.preferences); + + const stripeCustomerId = await findOrCreateBillingCustomer({ + stripeService: params.stripeService, + organizationId: params.organizationId, + customerEmail: params.preferences.billingEmail, + }); + const stripe = params.stripeService.getClient(); + const existingTaxIds = await listCustomerTaxIds({ stripe, stripeCustomerId }); + const customer = await stripe.customers.update(stripeCustomerId, { + email: params.preferences.billingEmail, + name: params.preferences.companyName, + business_name: params.preferences.companyName, + address: toStripeAddress(params.preferences.address), + invoice_settings: { + custom_fields: params.preferences.purchaseOrder + ? [ + { + name: purchaseOrderFieldName, + value: params.preferences.purchaseOrder, + }, + ] + : '', + }, + metadata: { organizationId: params.organizationId }, + }); + + const taxId = await syncPrimaryTaxId({ + stripe, + stripeCustomerId, + existingTaxIds, + taxId: params.preferences.taxId, + }); + + return { + stripeCustomerId, + preferences: mapCustomerPreferences({ customer, taxId }), + }; +} + +function validatePreferences(preferences: BillingPreferencesInput): void { + if (!preferences.companyName.trim()) { + throw new BadRequestException('Company name is required.'); + } + if (!preferences.billingEmail.trim()) { + throw new BadRequestException('Billing email is required.'); + } + + const type = preferences.taxId?.type?.trim() ?? ''; + const value = preferences.taxId?.value?.trim() ?? ''; + if (type && !isSupportedTaxIdType(type)) { + throw new BadRequestException('Unsupported tax ID type.'); + } + if ((type && !value) || (!type && value)) { + throw new BadRequestException( + 'Tax ID type and value must be set together.', + ); + } +} + +async function syncPrimaryTaxId(params: { + stripe: Stripe; + stripeCustomerId: string; + existingTaxIds: Stripe.TaxId[]; + taxId: BillingPreferencesInput['taxId']; +}): Promise { + const type = params.taxId?.type?.trim() ?? ''; + const value = params.taxId?.value?.trim() ?? ''; + if (!type || !value) { + for (const existingTaxId of params.existingTaxIds) { + await params.stripe.taxIds.del(existingTaxId.id); + } + return null; + } + + const matchingTaxId = params.existingTaxIds.find( + (existingTaxId) => + existingTaxId.type === type && existingTaxId.value === value, + ); + if (matchingTaxId) { + for (const existingTaxId of params.existingTaxIds) { + if (existingTaxId.id !== matchingTaxId.id) { + await params.stripe.taxIds.del(existingTaxId.id); + } + } + return matchingTaxId; + } + + if (!isSupportedTaxIdType(type)) { + throw new BadRequestException('Unsupported tax ID type.'); + } + + const createdTaxId = await params.stripe.taxIds.create( + { + type, + value, + owner: { type: 'customer', customer: params.stripeCustomerId }, + }, + { + idempotencyKey: [ + 'billing-tax-id', + params.stripeCustomerId, + type, + value.replace(/[^a-zA-Z0-9]/g, '').toLowerCase(), + ].join(':'), + }, + ); + + for (const existingTaxId of params.existingTaxIds) { + await params.stripe.taxIds.del(existingTaxId.id); + } + + return createdTaxId; +} + +async function listCustomerTaxIds(params: { + stripe: Stripe; + stripeCustomerId: string; +}): Promise { + const taxIds: Stripe.TaxId[] = []; + let startingAfter: string | undefined; + + do { + const page = await params.stripe.taxIds.list({ + owner: { type: 'customer', customer: params.stripeCustomerId }, + limit: 100, + ...(startingAfter ? { starting_after: startingAfter } : {}), + }); + taxIds.push(...page.data); + startingAfter = page.has_more ? page.data.at(-1)?.id : undefined; + } while (startingAfter); + + return taxIds; +} + +function mapCustomerPreferences(params: { + customer: Stripe.Customer; + taxId: Stripe.TaxId | null; +}): BillingPreferences { + return { + companyName: params.customer.name ?? params.customer.business_name ?? null, + billingEmail: params.customer.email ?? null, + purchaseOrder: findInvoiceCustomFieldValue( + params.customer, + purchaseOrderFieldName, + ), + address: { + line1: params.customer.address?.line1 ?? null, + line2: params.customer.address?.line2 ?? null, + city: params.customer.address?.city ?? null, + state: params.customer.address?.state ?? null, + postalCode: params.customer.address?.postal_code ?? null, + country: params.customer.address?.country ?? null, + }, + taxId: params.taxId + ? { + id: params.taxId.id, + type: params.taxId.type, + value: params.taxId.value, + verificationStatus: params.taxId.verification?.status ?? null, + } + : null, + }; +} + +function toStripeAddress( + address: BillingPreferences['address'], +): Stripe.AddressParam { + return { + line1: emptyToString(address.line1), + line2: emptyToString(address.line2), + city: emptyToString(address.city), + state: emptyToString(address.state), + postal_code: emptyToString(address.postalCode), + country: emptyToString(address.country).toUpperCase(), + }; +} + +function createEmptyPreferences(params: { + companyName: string | null; +}): BillingPreferences { + return { + companyName: params.companyName, + billingEmail: null, + purchaseOrder: null, + address: { + line1: null, + line2: null, + city: null, + state: null, + postalCode: null, + country: null, + }, + taxId: null, + }; +} + +function findInvoiceCustomFieldValue( + customer: Stripe.Customer, + name: string, +): string | null { + return ( + customer.invoice_settings.custom_fields?.find( + (field) => field.name === name, + )?.value ?? null + ); +} + +function emptyToString(value: string | null): string { + const trimmed = value?.trim(); + return trimmed ? trimmed : ''; +} + +function isDeletedCustomer( + customer: Stripe.Customer | Stripe.DeletedCustomer, +): customer is Stripe.DeletedCustomer { + return 'deleted' in customer && customer.deleted === true; +} + +function isSupportedTaxIdType(value: string): value is BillingTaxIdType { + return billingTaxIdTypes.some((type) => type === value); +} diff --git a/apps/api/src/billing/billing-redirect-urls.spec.ts b/apps/api/src/billing/billing-redirect-urls.spec.ts new file mode 100644 index 0000000000..0475fb44c6 --- /dev/null +++ b/apps/api/src/billing/billing-redirect-urls.spec.ts @@ -0,0 +1,14 @@ +import { BadRequestException } from '@nestjs/common'; +import { validateBillingRedirectUrl } from './billing-redirect-urls'; + +describe('validateBillingRedirectUrl', () => { + it('allows http only for local development hosts', () => { + expect(() => + validateBillingRedirectUrl('http://localhost:3000/org_1/billing'), + ).not.toThrow(); + + expect(() => + validateBillingRedirectUrl('http://app.trycomp.ai/org_1/billing'), + ).toThrow(BadRequestException); + }); +}); diff --git a/apps/api/src/billing/billing-redirect-urls.ts b/apps/api/src/billing/billing-redirect-urls.ts new file mode 100644 index 0000000000..7865048657 --- /dev/null +++ b/apps/api/src/billing/billing-redirect-urls.ts @@ -0,0 +1,30 @@ +import { BadRequestException } from '@nestjs/common'; + +const allowedHosts = new Set([ + 'localhost', + '127.0.0.1', + 'app.trycomp.ai', + 'app.staging.trycomp.ai', +]); +const localDevelopmentHosts = new Set(['localhost', '127.0.0.1']); + +export function validateBillingRedirectUrl(value: string): void { + let url: URL; + try { + url = new URL(value); + } catch { + throw new BadRequestException('Billing redirect URL is invalid.'); + } + + if (!['http:', 'https:'].includes(url.protocol)) { + throw new BadRequestException('Billing redirect URL is invalid.'); + } + + if (!allowedHosts.has(url.hostname)) { + throw new BadRequestException('Billing redirect URL is not allowed.'); + } + + if (url.protocol === 'http:' && !localDevelopmentHosts.has(url.hostname)) { + throw new BadRequestException('Billing redirect URL must use HTTPS.'); + } +} diff --git a/apps/api/src/billing/billing-setup-sessions.ts b/apps/api/src/billing/billing-setup-sessions.ts new file mode 100644 index 0000000000..ddd2ff4fd3 --- /dev/null +++ b/apps/api/src/billing/billing-setup-sessions.ts @@ -0,0 +1,112 @@ +import { BadRequestException } from '@nestjs/common'; +import { db } from '@db'; +import { StripeService } from '../stripe/stripe.service'; +import { findOrCreateBillingCustomer } from './billing-customer'; +import { validateBillingRedirectUrl } from './billing-redirect-urls'; +import { assertStripeBillingConfigured } from './billing-stripe-config'; +import { extractStripeId } from './billing-stripe-ids'; + +export async function createBillingSetupSession(params: { + organizationId: string; + successUrl: string; + cancelUrl: string; + customerEmail?: string; + stripeService: StripeService; +}): Promise<{ url: string }> { + validateBillingRedirectUrl(params.successUrl); + validateBillingRedirectUrl(params.cancelUrl); + assertStripeBillingConfigured(params.stripeService); + + const stripe = params.stripeService.getClient(); + const customerId = await findOrCreateBillingCustomer({ + stripeService: params.stripeService, + organizationId: params.organizationId, + customerEmail: params.customerEmail, + }); + + const session = await stripe.checkout.sessions.create({ + mode: 'setup', + customer: customerId, + currency: 'usd', + success_url: params.successUrl, + cancel_url: params.cancelUrl, + metadata: { + organizationId: params.organizationId, + source: 'comp-billing-setup', + }, + }); + + if (!session.url) { + throw new BadRequestException('Failed to create Stripe Checkout session.'); + } + + return { url: session.url }; +} + +export async function handleBillingSetupSuccess(params: { + organizationId: string; + sessionId: string; + stripeService: StripeService; +}): Promise<{ success: true }> { + assertStripeBillingConfigured(params.stripeService); + + const stripe = params.stripeService.getClient(); + const session = await stripe.checkout.sessions.retrieve(params.sessionId, { + expand: ['setup_intent'], + }); + + if (session.status !== 'complete') { + throw new BadRequestException('Checkout session is not complete.'); + } + + if (session.metadata?.organizationId !== params.organizationId) { + throw new BadRequestException( + 'Checkout session does not belong to this organization.', + ); + } + + const stripeCustomerId = extractStripeId(session.customer); + if (!stripeCustomerId) { + throw new BadRequestException('Checkout session is missing a customer.'); + } + const billing = await db.organizationBilling.findUnique({ + where: { organizationId: params.organizationId }, + select: { stripeCustomerId: true }, + }); + if (billing && billing.stripeCustomerId !== stripeCustomerId) { + throw new BadRequestException( + 'Checkout session customer does not match this organization.', + ); + } + + const setupIntent = session.setup_intent; + if (!setupIntent || typeof setupIntent === 'string') { + throw new BadRequestException('Checkout session is missing a setup intent.'); + } + + const paymentMethodId = extractStripeId(setupIntent.payment_method); + if (!paymentMethodId) { + throw new BadRequestException('Setup intent is missing a payment method.'); + } + + await stripe.customers.update(stripeCustomerId, { + invoice_settings: { default_payment_method: paymentMethodId }, + }); + + await db.organizationBilling.upsert({ + where: { organizationId: params.organizationId }, + create: { + organizationId: params.organizationId, + stripeCustomerId, + stripePaymentMethodId: paymentMethodId, + paymentMethodUpdatedAt: new Date(), + }, + update: { + stripeCustomerId, + stripePaymentMethodId: paymentMethodId, + paymentMethodUpdatedAt: new Date(), + }, + }); + + return { success: true }; +} diff --git a/apps/api/src/billing/billing-stripe-config.ts b/apps/api/src/billing/billing-stripe-config.ts new file mode 100644 index 0000000000..fa393e1ab9 --- /dev/null +++ b/apps/api/src/billing/billing-stripe-config.ts @@ -0,0 +1,13 @@ +import { HttpException, HttpStatus } from '@nestjs/common'; +import type { StripeService } from '../stripe/stripe.service'; + +export function assertStripeBillingConfigured( + stripeService: StripeService, +): void { + if (stripeService.isConfigured()) return; + + throw new HttpException( + 'Stripe billing is not configured.', + HttpStatus.PAYMENT_REQUIRED, + ); +} diff --git a/apps/api/src/billing/billing-stripe-ids.ts b/apps/api/src/billing/billing-stripe-ids.ts new file mode 100644 index 0000000000..d51132dad5 --- /dev/null +++ b/apps/api/src/billing/billing-stripe-ids.ts @@ -0,0 +1,7 @@ +export function extractStripeId( + value: string | { id?: string } | null, +): string | null { + if (!value) return null; + if (typeof value === 'string') return value; + return value.id ?? null; +} diff --git a/apps/api/src/billing/billing-subscription-plans.spec.ts b/apps/api/src/billing/billing-subscription-plans.spec.ts new file mode 100644 index 0000000000..842f4ec960 --- /dev/null +++ b/apps/api/src/billing/billing-subscription-plans.spec.ts @@ -0,0 +1,75 @@ +import { db } from '@db'; +import { changeSubscriptionPlan } from './billing-subscription-plans'; + +jest.mock('@db', () => ({ + db: { + organizationBillingSubscription: { update: jest.fn() }, + }, +})); + +const mockedDb = db as unknown as { + organizationBillingSubscription: { update: jest.Mock }; +}; + +describe('changeSubscriptionPlan', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockedDb.organizationBillingSubscription.update.mockResolvedValue({}); + }); + + it('includes the source plan and local state version in the Stripe idempotency key', async () => { + const subscriptionsUpdate = jest.fn().mockResolvedValue({ + id: 'sub_1', + status: 'active', + cancel_at_period_end: false, + canceled_at: null, + items: { + data: [ + { + id: 'si_1', + current_period_start: 1775001600, + current_period_end: 1777593600, + }, + ], + }, + }); + + await changeSubscriptionPlan({ + organizationId: 'org_1', + subscription: { + id: 'obs_1', + skuKey: 'pentest_monthly_3', + stripeStatus: 'active', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + updatedAt: new Date('2026-04-30T10:00:00.000Z'), + }, + skuKey: 'pentest_monthly_5_current', + stripePriceId: 'price_next', + includedQuantity: 5, + stripeService: { + getClient: () => ({ + subscriptions: { update: subscriptionsUpdate }, + }), + } as never, + entitlements: { writeAuditEvent: jest.fn() } as never, + }); + + expect(subscriptionsUpdate).toHaveBeenCalledWith( + 'sub_1', + expect.anything(), + { + idempotencyKey: [ + 'subscription-plan-change-v2', + 'org_1', + 'si_1', + 'pentest_monthly_3', + 'pentest_monthly_5_current', + new Date('2026-04-30T10:00:00.000Z').getTime(), + ].join(':'), + }, + ); + }); +}); diff --git a/apps/api/src/billing/billing-subscription-plans.ts b/apps/api/src/billing/billing-subscription-plans.ts new file mode 100644 index 0000000000..14004d7141 --- /dev/null +++ b/apps/api/src/billing/billing-subscription-plans.ts @@ -0,0 +1,197 @@ +import { HttpException, HttpStatus } from '@nestjs/common'; +import { db } from '@db'; +import { + billingCatalogs, + getBillingSku, + type BillingProductKey, + type BillingSku, + type BillingSkuKey, + getBillingSkuProductKey, +} from '@trycompai/billing'; +import { BillingEntitlementsService } from './billing-entitlements.service'; +import { StripeService } from '../stripe/stripe.service'; + +export async function findActiveProductSubscription(params: { + organizationId: string; + productKey: BillingProductKey; +}) { + const subscriptions = await findProductSubscriptions(params); + return ( + subscriptions.find((subscription) => { + return ( + subscription.stripeStatus === 'active' || + subscription.stripeStatus === 'trialing' + ); + }) ?? null + ); +} + +export async function findProductSubscriptions(params: { + organizationId: string; + productKey: BillingProductKey; +}) { + const subscriptions = await db.organizationBillingSubscription.findMany({ + where: { organizationId: params.organizationId }, + orderBy: { createdAt: 'desc' }, + }); + return subscriptions.filter( + (subscription) => + getBillingSkuProductKey(subscription.skuKey) === params.productKey, + ); +} + +export async function changeSubscriptionPlan(params: { + organizationId: string; + subscription: { + id: string; + skuKey: string; + stripeStatus: string; + stripeSubscriptionId: string; + stripeSubscriptionItemId: string; + currentPeriodStart: Date | null; + currentPeriodEnd: Date | null; + updatedAt: Date; + }; + skuKey: BillingSkuKey; + stripePriceId: string; + includedQuantity: number; + stripeService: StripeService; + entitlements: BillingEntitlementsService; +}): Promise<{ changed: true }> { + const stripe = params.stripeService.getClient(); + const isUpgrade = isPlanUpgrade({ + currentSkuKey: params.subscription.skuKey, + nextSkuKey: params.skuKey, + }); + const shouldEndTrial = + isUpgrade && params.subscription.stripeStatus === 'trialing'; + const updateParams = { + items: [ + isUpgrade + ? { + id: params.subscription.stripeSubscriptionItemId, + price: params.stripePriceId, + quantity: 1, + } + : { + id: params.subscription.stripeSubscriptionItemId, + price: params.stripePriceId, + }, + ], + metadata: { + organizationId: params.organizationId, + skuKey: params.skuKey, + source: 'comp-billing-subscription', + }, + ...(isUpgrade + ? { + proration_behavior: 'always_invoice' as const, + payment_behavior: 'error_if_incomplete' as const, + } + : {}), + ...(shouldEndTrial ? { trial_end: 'now' as const } : {}), + }; + let updatedSubscription: Awaited< + ReturnType + >; + try { + updatedSubscription = await stripe.subscriptions.update( + params.subscription.stripeSubscriptionId, + updateParams, + { + idempotencyKey: [ + 'subscription-plan-change-v2', + params.organizationId, + params.subscription.stripeSubscriptionItemId, + params.subscription.skuKey, + params.skuKey, + params.subscription.updatedAt.getTime(), + ].join(':'), + }, + ); + } catch (error) { + if (isUpgrade && isPaymentRequiredStripeError(error)) { + throw new HttpException( + 'We could not charge the prorated upgrade amount. Please update your payment method and try again.', + HttpStatus.PAYMENT_REQUIRED, + ); + } + throw error; + } + + const updatedItem = + updatedSubscription.items.data.find( + (item) => item.id === params.subscription.stripeSubscriptionItemId, + ) ?? updatedSubscription.items.data[0]; + + const currentPeriodStart = + dateFromSeconds(readNumber(updatedItem, 'current_period_start')) ?? + params.subscription.currentPeriodStart; + const currentPeriodEnd = + dateFromSeconds(readNumber(updatedItem, 'current_period_end')) ?? + params.subscription.currentPeriodEnd; + + await db.organizationBillingSubscription.update({ + where: { id: params.subscription.id }, + data: { + skuKey: params.skuKey, + stripeSubscriptionId: updatedSubscription.id, + stripeSubscriptionItemId: + updatedItem?.id ?? params.subscription.stripeSubscriptionItemId, + stripePriceId: params.stripePriceId, + stripeStatus: updatedSubscription.status, + currentPeriodStart, + currentPeriodEnd, + includedQuantity: params.includedQuantity, + cancelAtPeriodEnd: updatedSubscription.cancel_at_period_end, + canceledAt: dateFromSeconds(updatedSubscription.canceled_at), + }, + }); + + await params.entitlements.writeAuditEvent({ + organizationId: params.organizationId, + eventType: 'subscription_plan_changed', + skuKey: params.skuKey, + metadata: { + stripeSubscriptionId: updatedSubscription.id, + stripeSubscriptionItemId: + updatedItem?.id ?? params.subscription.stripeSubscriptionItemId, + previousSkuKey: params.subscription.skuKey, + }, + }); + + return { changed: true }; +} + +function dateFromSeconds(value: number | null): Date | null { + return value === null ? null : new Date(value * 1000); +} + +function readNumber(value: unknown, key: string): number | null { + if (typeof value !== 'object' || value === null) return null; + const raw = (value as Record)[key]; + return typeof raw === 'number' ? raw : null; +} + +function isPlanUpgrade(params: { + currentSkuKey: string; + nextSkuKey: BillingSkuKey; +}): boolean { + const currentSku = findBillingSku(params.currentSkuKey); + const nextSku = getBillingSku({ skuKey: params.nextSkuKey }); + return currentSku !== null && nextSku.unitAmount > currentSku.unitAmount; +} + +function findBillingSku(skuKey: string): BillingSku | null { + for (const catalog of Object.values(billingCatalogs)) { + const sku = Object.values(catalog.skus).find((item) => item.key === skuKey); + if (sku) return sku; + } + return null; +} + +function isPaymentRequiredStripeError(error: unknown): boolean { + if (typeof error !== 'object' || error === null) return false; + const record = error as Record; + return record.statusCode === HttpStatus.PAYMENT_REQUIRED; +} diff --git a/apps/api/src/billing/billing-usage.spec.ts b/apps/api/src/billing/billing-usage.spec.ts new file mode 100644 index 0000000000..03ba99e1cc --- /dev/null +++ b/apps/api/src/billing/billing-usage.spec.ts @@ -0,0 +1,256 @@ +import { db } from '@db'; +import { listBillingUsageRows } from './billing-usage'; + +jest.mock('@db', () => ({ + db: { + backgroundCheckRequest: { findMany: jest.fn() }, + securityPenetrationTestRun: { findMany: jest.fn() }, + billingUsageEvent: { findMany: jest.fn() }, + }, +})); + +const mockedDb = db as jest.Mocked; +const backgroundCheckFindMany = mockedDb.backgroundCheckRequest + .findMany as unknown as jest.Mock; +const pentestRunFindMany = mockedDb.securityPenetrationTestRun + .findMany as unknown as jest.Mock; +const billingUsageEventFindMany = mockedDb.billingUsageEvent + .findMany as unknown as jest.Mock; + +describe('listBillingUsageRows', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('combines run history with subscription allowance details', async () => { + backgroundCheckFindMany.mockResolvedValue([ + { + id: 'bcr_1', + memberId: 'mem_1', + employeeName: 'Ada Lovelace', + employeeEmail: 'ada@example.com', + status: 'completed', + stripePaymentStatus: 'succeeded', + createdAt: new Date('2026-04-30T10:00:00.000Z'), + updatedAt: new Date('2026-04-30T10:05:00.000Z'), + }, + ]); + pentestRunFindMany.mockResolvedValue([ + { + id: 'ptr_1', + providerRunId: 'run_1', + billingUsageSourceId: 'pending:run_1', + createdAt: new Date('2026-04-30T11:00:00.000Z'), + updatedAt: new Date('2026-04-30T11:05:00.000Z'), + }, + ]); + billingUsageEventFindMany.mockResolvedValue([ + { + skuKey: 'background_checks_monthly_25', + eventType: 'consume', + sourceResourceId: 'mem_1', + stripeInvoiceId: null, + }, + { + skuKey: 'pentest_monthly_5', + eventType: 'consume', + sourceResourceId: 'pending:run_1', + stripeInvoiceId: null, + }, + ]); + + const rows = await listBillingUsageRows({ + organizationId: 'org_1', + subscriptions: [ + { + skuKey: 'background_checks_monthly_25', + includedQuantity: 25, + usedQuantity: 2, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + { + skuKey: 'pentest_monthly_5', + includedQuantity: 5, + usedQuantity: 1, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + ], + }); + + expect(rows).toEqual([ + expect.objectContaining({ + service: 'Penetration Test', + details: 'run_1', + billingType: 'Subscription allowance', + subscriptionRemaining: 4, + }), + expect.objectContaining({ + service: 'Background Check', + details: 'Ada Lovelace (ada@example.com)', + billingType: 'Subscription allowance', + subscriptionRemaining: 23, + }), + ]); + }); + + it('labels legacy pentest rows independently of current subscription state', async () => { + backgroundCheckFindMany.mockResolvedValue([]); + pentestRunFindMany.mockResolvedValue([ + { + id: 'ptr_legacy', + providerRunId: 'run_legacy', + billingUsageSourceId: null, + createdAt: new Date('2026-04-30T11:00:00.000Z'), + updatedAt: new Date('2026-04-30T11:05:00.000Z'), + }, + ]); + billingUsageEventFindMany.mockResolvedValue([]); + + const rows = await listBillingUsageRows({ + organizationId: 'org_1', + subscriptions: [ + { + skuKey: 'pentest_monthly_5', + includedQuantity: 5, + usedQuantity: 1, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + ], + }); + + expect(rows[0]).toEqual( + expect.objectContaining({ + service: 'Penetration Test', + billingType: 'Trial credit', + }), + ); + }); + + it('uses the usage event sku when multiple subscriptions share a product', async () => { + backgroundCheckFindMany.mockResolvedValue([ + { + id: 'bcr_2', + memberId: 'mem_2', + employeeName: 'Grace Hopper', + employeeEmail: 'grace@example.com', + status: 'completed', + stripePaymentStatus: 'succeeded', + createdAt: new Date('2026-04-30T12:00:00.000Z'), + updatedAt: new Date('2026-04-30T12:05:00.000Z'), + }, + ]); + pentestRunFindMany.mockResolvedValue([]); + billingUsageEventFindMany.mockResolvedValue([ + { + skuKey: 'background_checks_monthly_10', + eventType: 'consume', + sourceResourceId: 'mem_2', + stripeInvoiceId: null, + }, + ]); + + const rows = await listBillingUsageRows({ + organizationId: 'org_1', + subscriptions: [ + { + skuKey: 'background_checks_monthly_3', + includedQuantity: 3, + usedQuantity: 3, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + { + skuKey: 'background_checks_monthly_10', + includedQuantity: 10, + usedQuantity: 4, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + ], + }); + + expect(rows[0]).toEqual( + expect.objectContaining({ + skuKey: 'background_checks_monthly_10', + subscriptionRemaining: 6, + subscriptionIncluded: 10, + }), + ); + }); + + it('uses the newest usage event for each source resource', async () => { + backgroundCheckFindMany.mockResolvedValue([ + { + id: 'bcr_4', + memberId: 'mem_4', + employeeName: 'Mary Jackson', + employeeEmail: 'mary@example.com', + status: 'completed', + stripePaymentStatus: 'succeeded', + createdAt: new Date('2026-04-30T15:00:00.000Z'), + updatedAt: new Date('2026-04-30T15:05:00.000Z'), + }, + ]); + pentestRunFindMany.mockResolvedValue([]); + billingUsageEventFindMany.mockResolvedValue([ + { + skuKey: 'background_checks_monthly_10', + eventType: 'one_time', + sourceResourceId: 'mem_4', + stripeInvoiceId: 'in_latest', + }, + { + skuKey: 'background_checks_monthly_3', + eventType: 'consume', + sourceResourceId: 'mem_4', + stripeInvoiceId: null, + }, + ]); + + const rows = await listBillingUsageRows({ + organizationId: 'org_1', + subscriptions: [], + }); + + expect(rows[0]).toEqual( + expect.objectContaining({ + skuKey: 'background_checks_monthly_10', + billingType: 'One-time invoice', + }), + ); + }); + + it('fetches usage events for the displayed source resources without a global cap', async () => { + backgroundCheckFindMany.mockResolvedValue([ + { + id: 'bcr_3', + memberId: 'mem_3', + employeeName: 'Katherine Johnson', + employeeEmail: 'katherine@example.com', + status: 'completed', + stripePaymentStatus: 'succeeded', + createdAt: new Date('2026-04-30T13:00:00.000Z'), + updatedAt: new Date('2026-04-30T13:05:00.000Z'), + }, + ]); + pentestRunFindMany.mockResolvedValue([ + { + id: 'ptr_3', + providerRunId: 'run_3', + billingUsageSourceId: 'pending:run_3', + createdAt: new Date('2026-04-30T14:00:00.000Z'), + updatedAt: new Date('2026-04-30T14:05:00.000Z'), + }, + ]); + billingUsageEventFindMany.mockResolvedValue([]); + + await listBillingUsageRows({ + organizationId: 'org_1', + subscriptions: [], + }); + + const usageQuery = billingUsageEventFindMany.mock.calls[0][0]; + expect(usageQuery.where.sourceResourceId).toEqual({ + in: ['mem_3', 'pending:run_3'], + }); + expect(usageQuery).not.toHaveProperty('take'); + }); +}); diff --git a/apps/api/src/billing/billing-usage.ts b/apps/api/src/billing/billing-usage.ts new file mode 100644 index 0000000000..f383b4c3cf --- /dev/null +++ b/apps/api/src/billing/billing-usage.ts @@ -0,0 +1,175 @@ +import { db } from '@db'; +import { getBillingSkuProductKey } from '@trycompai/billing'; +import type { BillingUsageRow } from './billing.types'; + +type SubscriptionSummary = { + skuKey: string; + includedQuantity: number; + usedQuantity: number; + currentPeriodEnd: Date | null; +}; + +const backgroundCheckSku = 'background_checks_monthly_3'; +const pentestSku = 'pentest_monthly_1'; + +export async function listBillingUsageRows(params: { + organizationId: string; + subscriptions: SubscriptionSummary[]; +}): Promise { + const [backgroundChecks, pentestRuns] = await Promise.all([ + db.backgroundCheckRequest.findMany({ + where: { organizationId: params.organizationId }, + orderBy: { createdAt: 'desc' }, + take: 50, + select: { + id: true, + memberId: true, + employeeName: true, + employeeEmail: true, + status: true, + stripePaymentStatus: true, + createdAt: true, + updatedAt: true, + }, + }), + db.securityPenetrationTestRun.findMany({ + where: { organizationId: params.organizationId }, + orderBy: { createdAt: 'desc' }, + take: 50, + select: { + id: true, + providerRunId: true, + billingUsageSourceId: true, + createdAt: true, + updatedAt: true, + }, + }), + ]); + const sourceResourceIds = [ + ...backgroundChecks.map((request) => request.memberId), + ...pentestRuns.map((run) => run.billingUsageSourceId), + ].filter((value): value is string => typeof value === 'string'); + const usageEvents = + sourceResourceIds.length > 0 + ? await db.billingUsageEvent.findMany({ + where: { + organizationId: params.organizationId, + eventType: { in: ['consume', 'one_time'] }, + sourceResourceId: { in: sourceResourceIds }, + }, + orderBy: { createdAt: 'desc' }, + select: { + skuKey: true, + eventType: true, + sourceResourceId: true, + stripeInvoiceId: true, + }, + }) + : []; + + const usageBySource = new Map(); + for (const event of usageEvents) { + if (!event.sourceResourceId || usageBySource.has(event.sourceResourceId)) { + continue; + } + usageBySource.set(event.sourceResourceId, event); + } + + const rows = [ + ...backgroundChecks.map((request) => { + const usage = usageBySource.get(request.memberId); + return toBillingUsageRow({ + id: request.id, + service: 'Background Check', + skuKey: usage?.skuKey ?? backgroundCheckSku, + details: `${request.employeeName} (${request.employeeEmail})`, + status: formatStatus(request.status), + billingType: formatBillingType( + usage?.eventType, + usage?.stripeInvoiceId, + ), + createdAt: request.createdAt, + updatedAt: request.updatedAt, + subscriptions: params.subscriptions, + }); + }), + ...pentestRuns.map((run) => { + const usage = run.billingUsageSourceId + ? usageBySource.get(run.billingUsageSourceId) + : undefined; + return toBillingUsageRow({ + id: run.id, + service: 'Penetration Test', + skuKey: usage?.skuKey ?? pentestSku, + details: run.providerRunId, + status: 'Created', + billingType: usage + ? formatBillingType(usage.eventType, usage.stripeInvoiceId) + : 'Trial credit', + createdAt: run.createdAt, + updatedAt: run.updatedAt, + subscriptions: params.subscriptions, + }); + }), + ]; + + return rows.sort((first, second) => + second.createdAt.localeCompare(first.createdAt), + ); +} + +function toBillingUsageRow(params: { + id: string; + service: BillingUsageRow['service']; + skuKey: string; + details: string; + status: string; + billingType: string; + createdAt: Date; + updatedAt: Date; + subscriptions: SubscriptionSummary[]; +}): BillingUsageRow { + const productKey = getBillingSkuProductKey(params.skuKey); + const subscription = + params.subscriptions.find((item) => item.skuKey === params.skuKey) ?? + params.subscriptions.find((item) => + productKey + ? getBillingSkuProductKey(item.skuKey) === productKey + : item.skuKey === params.skuKey, + ); + const remaining = subscription + ? Math.max(subscription.includedQuantity - subscription.usedQuantity, 0) + : null; + + return { + id: params.id, + service: params.service, + skuKey: params.skuKey, + details: params.details, + status: params.status, + billingType: params.billingType, + createdAt: params.createdAt.toISOString(), + updatedAt: params.updatedAt.toISOString(), + subscriptionRemaining: remaining, + subscriptionIncluded: subscription?.includedQuantity ?? null, + subscriptionPeriodEnd: + subscription?.currentPeriodEnd?.toISOString() ?? null, + }; +} + +function formatBillingType( + eventType?: string, + stripeInvoiceId?: string | null, +): string { + if (eventType === 'consume') return 'Subscription allowance'; + if (eventType === 'one_time') + return stripeInvoiceId ? 'One-time invoice' : 'One-time'; + return 'Legacy / manual'; +} + +function formatStatus(status: string): string { + return status + .split('_') + .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) + .join(' '); +} diff --git a/apps/api/src/billing/billing-webhook.service.spec.ts b/apps/api/src/billing/billing-webhook.service.spec.ts new file mode 100644 index 0000000000..a434431ebe --- /dev/null +++ b/apps/api/src/billing/billing-webhook.service.spec.ts @@ -0,0 +1,69 @@ +import { BillingWebhookService } from './billing-webhook.service'; +import { + claimStripeWebhookEvent, + markStripeWebhookFailed, +} from './stripe-webhook-records'; + +jest.mock('@db', () => ({ + Prisma: {}, + db: {}, +})); + +jest.mock('./stripe-webhook-records', () => ({ + claimStripeWebhookEvent: jest.fn(), + markStripeWebhookFailed: jest.fn(), + markStripeWebhookProcessed: jest.fn(), +})); + +const mockClaimStripeWebhookEvent = claimStripeWebhookEvent as jest.Mock; +const mockMarkStripeWebhookFailed = markStripeWebhookFailed as jest.Mock; + +describe('BillingWebhookService', () => { + const originalSecret = process.env.STRIPE_WEBHOOK_SECRET; + + beforeEach(() => { + jest.clearAllMocks(); + process.env.STRIPE_WEBHOOK_SECRET = 'whsec_test'; + mockClaimStripeWebhookEvent.mockResolvedValue({ status: 'claimed' }); + }); + + afterAll(() => { + if (typeof originalSecret === 'string') { + process.env.STRIPE_WEBHOOK_SECRET = originalSecret; + return; + } + delete process.env.STRIPE_WEBHOOK_SECRET; + }); + + it('rethrows the processing error when marking the webhook failed also fails', async () => { + const processingError = new Error('processing failed'); + mockMarkStripeWebhookFailed.mockRejectedValue(new Error('db failed')); + const service = new BillingWebhookService( + { + getClient: () => ({ + webhooks: { + constructEvent: () => ({ + id: 'evt_1', + type: 'invoice.payment_failed', + data: { object: { customer: 'cus_missing' } }, + }), + }, + }), + } as never, + {} as never, + ); + jest + .spyOn( + service as unknown as { processEvent: () => Promise }, + 'processEvent', + ) + .mockRejectedValue(processingError); + + await expect( + service.handleWebhook({ + rawBody: Buffer.from('{}'), + signature: 'sig', + }), + ).rejects.toBe(processingError); + }); +}); diff --git a/apps/api/src/billing/billing-webhook.service.ts b/apps/api/src/billing/billing-webhook.service.ts new file mode 100644 index 0000000000..f56e7c26c6 --- /dev/null +++ b/apps/api/src/billing/billing-webhook.service.ts @@ -0,0 +1,282 @@ +import { BadRequestException, Injectable } from '@nestjs/common'; +import { Prisma, db } from '@db'; +import { getBillingSkuByStripePriceId } from '@trycompai/billing'; +import Stripe from 'stripe'; +import { StripeService } from '../stripe/stripe.service'; +import { BillingEntitlementsService } from './billing-entitlements.service'; +import { + claimStripeWebhookEvent, + markStripeWebhookFailed, + markStripeWebhookProcessed, +} from './stripe-webhook-records'; + +@Injectable() +export class BillingWebhookService { + constructor( + private readonly stripeService: StripeService, + private readonly entitlements: BillingEntitlementsService, + ) {} + + async handleWebhook(params: { + rawBody: Buffer | undefined; + signature: string | undefined; + }): Promise<{ ok: true; duplicate?: true }> { + if (!params.rawBody) throw new BadRequestException('Raw body unavailable.'); + if (!params.signature) { + throw new BadRequestException('Stripe signature header is missing.'); + } + + const secret = + process.env.STRIPE_WEBHOOK_SECRET ?? + process.env.STRIPE_PENTEST_WEBHOOK_SECRET; + if (!secret) + throw new BadRequestException('Stripe webhook secret is not configured.'); + + const stripe = this.stripeService.getClient(); + let event: Stripe.Event; + try { + event = stripe.webhooks.constructEvent( + params.rawBody, + params.signature, + secret, + ); + } catch { + throw new BadRequestException('Invalid Stripe webhook signature.'); + } + const claim = await claimStripeWebhookEvent({ + stripeEventId: event.id, + eventType: event.type, + payload: event.data.object as unknown as Prisma.InputJsonValue, + }); + if (claim.status === 'duplicate') return { ok: true, duplicate: true }; + + try { + await this.processEvent(event); + await markStripeWebhookProcessed(event.id); + return { ok: true }; + } catch (error) { + try { + await markStripeWebhookFailed({ stripeEventId: event.id, error }); + } catch { + // Preserve the processing error so Stripe retries for the real failure. + } + throw error; + } + } + + private async processEvent(event: Stripe.Event): Promise { + switch (event.type) { + case 'checkout.session.completed': + await this.handleCheckoutSessionCompleted(event); + return; + case 'customer.subscription.updated': + case 'customer.subscription.deleted': + await this.syncSubscriptionFromEvent(event); + return; + case 'invoice.paid': + await this.handleInvoicePaid(event); + return; + case 'invoice.payment_failed': + case 'invoice.payment_action_required': + await this.handleInvoiceRecoveryEvent(event); + return; + default: + return; + } + } + + private async handleCheckoutSessionCompleted( + event: Stripe.Event, + ): Promise { + const session = event.data.object as Stripe.Checkout.Session; + if (session.mode !== 'subscription') return; + + const organizationId = session.metadata?.organizationId; + const customerId = extractStripeId(session.customer); + if (!organizationId || !customerId) return; + + await db.organizationBilling.upsert({ + where: { organizationId }, + create: { organizationId, stripeCustomerId: customerId }, + update: { stripeCustomerId: customerId }, + }); + + const subscriptionId = extractStripeId(session.subscription); + if (!subscriptionId) return; + const subscription = await this.retrieveSubscription(subscriptionId); + await this.syncCustomerPaymentMethod({ + organizationId, + customerId, + subscription, + stripeEventId: event.id, + }); + await this.syncSubscriptionItems({ + subscription, + organizationId, + stripeEventId: event.id, + }); + } + + private async syncSubscriptionFromEvent(event: Stripe.Event): Promise { + const subscription = event.data.object as Stripe.Subscription; + const organizationId = + await this.resolveSubscriptionOrganization(subscription); + if (!organizationId) return; + await this.syncSubscriptionItems({ + subscription, + organizationId, + stripeEventId: event.id, + }); + } + + private async handleInvoicePaid(event: Stripe.Event): Promise { + const invoice = event.data.object as Stripe.Invoice; + const subscriptionId = extractStripeId(readField(invoice, 'subscription')); + if (!subscriptionId) return; + const subscription = await this.retrieveSubscription(subscriptionId); + const organizationId = + await this.resolveSubscriptionOrganization(subscription); + if (!organizationId) return; + await this.syncSubscriptionItems({ + subscription, + organizationId, + stripeEventId: event.id, + }); + } + + private async handleInvoiceRecoveryEvent(event: Stripe.Event): Promise { + const invoice = event.data.object as Stripe.Invoice; + const customerId = extractStripeId(invoice.customer); + if (!customerId) return; + const billing = await db.organizationBilling.findFirst({ + where: { stripeCustomerId: customerId }, + select: { organizationId: true }, + }); + if (!billing) return; + await this.entitlements.writeAuditEvent({ + organizationId: billing.organizationId, + eventType: event.type, + stripeEventId: event.id, + metadata: { invoiceId: invoice.id }, + }); + } + + private async retrieveSubscription( + subscriptionId: string, + ): Promise { + return this.stripeService + .getClient() + .subscriptions.retrieve(subscriptionId, { + expand: ['items.data.price'], + }); + } + + private async syncCustomerPaymentMethod(params: { + organizationId: string; + customerId: string; + subscription: Stripe.Subscription; + stripeEventId: string; + }): Promise { + const subscriptionPaymentMethodId = extractStripeId( + params.subscription.default_payment_method, + ); + const customer = await this.stripeService + .getClient() + .customers.retrieve(params.customerId); + if (customer.deleted) return; + + const customerPaymentMethodId = extractStripeId( + customer.invoice_settings.default_payment_method, + ); + const paymentMethodId = + subscriptionPaymentMethodId ?? customerPaymentMethodId; + if (!paymentMethodId) return; + + await this.stripeService.getClient().customers.update(params.customerId, { + invoice_settings: { default_payment_method: paymentMethodId }, + }); + await db.organizationBilling.update({ + where: { organizationId: params.organizationId }, + data: { + stripePaymentMethodId: paymentMethodId, + paymentMethodUpdatedAt: new Date(), + }, + }); + await this.entitlements.writeAuditEvent({ + organizationId: params.organizationId, + eventType: 'payment_method_updated', + stripeEventId: params.stripeEventId, + metadata: { source: 'checkout.session.completed' }, + }); + } + + private async syncSubscriptionItems(params: { + subscription: Stripe.Subscription; + organizationId: string; + stripeEventId: string; + }): Promise { + for (const item of params.subscription.items.data) { + const priceId = item.price.id; + const sku = getBillingSkuByStripePriceId({ stripePriceId: priceId }); + if (!sku?.includedUsage) continue; + await this.entitlements.syncSubscriptionItem({ + organizationId: params.organizationId, + skuKey: sku.key, + stripeSubscriptionId: params.subscription.id, + stripeSubscriptionItemId: item.id, + stripePriceId: priceId, + stripeStatus: params.subscription.status, + currentPeriodStart: dateFromSeconds( + readNumber(item, 'current_period_start'), + ), + currentPeriodEnd: dateFromSeconds( + readNumber(item, 'current_period_end'), + ), + includedQuantity: sku.includedUsage.quantity, + cancelAtPeriodEnd: params.subscription.cancel_at_period_end, + canceledAt: dateFromSeconds(params.subscription.canceled_at), + stripeEventId: params.stripeEventId, + }); + } + } + + private async resolveSubscriptionOrganization( + subscription: Stripe.Subscription, + ): Promise { + if (subscription.metadata?.organizationId) { + return subscription.metadata.organizationId; + } + const customerId = extractStripeId(subscription.customer); + if (!customerId) return null; + const billing = await db.organizationBilling.findFirst({ + where: { stripeCustomerId: customerId }, + select: { organizationId: true }, + }); + return billing?.organizationId ?? null; + } +} + +function extractStripeId(value: unknown): string | null { + if (typeof value === 'string') return value; + if (!isRecord(value)) return null; + return typeof value.id === 'string' ? value.id : null; +} + +function dateFromSeconds(value: number | null): Date | null { + return value === null ? null : new Date(value * 1000); +} + +function readNumber(value: unknown, key: string): number | null { + if (!isRecord(value)) return null; + const raw = value[key]; + return typeof raw === 'number' ? raw : null; +} + +function readField(value: unknown, key: string): unknown { + if (!isRecord(value)) return null; + return value[key]; +} + +function isRecord(value: unknown): value is Record { + return typeof value === 'object' && value !== null; +} diff --git a/apps/api/src/billing/billing.controller.ts b/apps/api/src/billing/billing.controller.ts new file mode 100644 index 0000000000..a2e4e9ddb6 --- /dev/null +++ b/apps/api/src/billing/billing.controller.ts @@ -0,0 +1,153 @@ +import { + Body, + Controller, + Get, + Headers, + HttpCode, + Post, + Put, + Req, + UseGuards, + type RawBodyRequest, +} from '@nestjs/common'; +import { ApiOperation, ApiSecurity, ApiTags } from '@nestjs/swagger'; +import type { Request } from 'express'; +import { AuthContext, OrganizationId } from '../auth/auth-context.decorator'; +import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; +import { PermissionGuard } from '../auth/permission.guard'; +import { Public } from '../auth/public.decorator'; +import { RequirePermission } from '../auth/require-permission.decorator'; +import type { AuthContext as AuthContextType } from '../auth/types'; +import { BillingService } from './billing.service'; +import { BillingWebhookService } from './billing-webhook.service'; +import { + BillingPortalDto, + BillingPreferencesDto, + BillingSetupSessionDto, + BillingSetupSuccessDto, + BillingSubscriptionCheckoutDto, +} from './dto/billing.dto'; + +@ApiTags('Billing') +@Controller({ path: 'billing', version: '1' }) +@UseGuards(HybridAuthGuard, PermissionGuard) +@ApiSecurity('apikey') +export class BillingController { + constructor( + private readonly billingService: BillingService, + private readonly webhookService: BillingWebhookService, + ) {} + + @Get('status') + @RequirePermission('organization', 'read') + @ApiOperation({ summary: 'Get organization billing status' }) + async getStatus(@OrganizationId() organizationId: string) { + return this.billingService.getStatus(organizationId); + } + + @Put('preferences') + @RequirePermission('organization', 'update') + @ApiOperation({ summary: 'Update organization billing preferences' }) + async updatePreferences( + @OrganizationId() organizationId: string, + @Body() body: BillingPreferencesDto, + ) { + return this.billingService.updatePreferences({ + organizationId, + preferences: { + companyName: body.companyName, + billingEmail: body.billingEmail, + purchaseOrder: body.purchaseOrder ?? null, + address: { + line1: body.addressLine1 ?? null, + line2: body.addressLine2 ?? null, + city: body.addressCity ?? null, + state: body.addressState ?? null, + postalCode: body.addressPostalCode ?? null, + country: body.addressCountry ?? null, + }, + taxId: { + type: body.taxIdType ?? null, + value: body.taxIdValue ?? null, + }, + }, + }); + } + + @Post('setup-session') + @RequirePermission('organization', 'update') + @HttpCode(200) + @ApiOperation({ summary: 'Create a Stripe setup session' }) + async setupSession( + @OrganizationId() organizationId: string, + @AuthContext() authContext: AuthContextType, + @Body() body: BillingSetupSessionDto, + ) { + return this.billingService.createSetupSession({ + organizationId, + successUrl: body.successUrl, + cancelUrl: body.cancelUrl, + customerEmail: authContext.userEmail, + }); + } + + @Post('setup-success') + @RequirePermission('organization', 'update') + @HttpCode(200) + @ApiOperation({ summary: 'Persist a successful Stripe setup session' }) + async setupSuccess( + @OrganizationId() organizationId: string, + @Body() body: BillingSetupSuccessDto, + ) { + return this.billingService.handleSetupSuccess({ + organizationId, + sessionId: body.sessionId, + }); + } + + @Post('portal') + @RequirePermission('organization', 'update') + @HttpCode(200) + @ApiOperation({ summary: 'Create a Stripe billing portal session' }) + async portal( + @OrganizationId() organizationId: string, + @Body() body: BillingPortalDto, + ) { + return this.billingService.createBillingPortalSession({ + organizationId, + returnUrl: body.returnUrl, + }); + } + + @Post('subscription-session') + @RequirePermission('organization', 'update') + @HttpCode(200) + @ApiOperation({ summary: 'Create a Stripe subscription Checkout session' }) + async subscriptionSession( + @OrganizationId() organizationId: string, + @AuthContext() authContext: AuthContextType, + @Body() body: BillingSubscriptionCheckoutDto, + ) { + return this.billingService.createSubscriptionCheckoutSession({ + organizationId, + skuKey: body.skuKey, + successUrl: body.successUrl, + cancelUrl: body.cancelUrl, + customerEmail: authContext.userEmail, + }); + } + + @Post('webhook') + @Public() + @HttpCode(200) + @ApiOperation({ summary: 'Receive Stripe billing webhook events' }) + async webhook( + @Headers('stripe-signature') signature: string | undefined, + @Req() req: RawBodyRequest, + ) { + return this.webhookService.handleWebhook({ + rawBody: req.rawBody, + signature, + }); + } +} diff --git a/apps/api/src/billing/billing.module.ts b/apps/api/src/billing/billing.module.ts new file mode 100644 index 0000000000..9e22d1e96f --- /dev/null +++ b/apps/api/src/billing/billing.module.ts @@ -0,0 +1,21 @@ +import { Module } from '@nestjs/common'; +import { AuthModule } from '../auth/auth.module'; +import { StripeModule } from '../stripe/stripe.module'; +import { BillingController } from './billing.controller'; +import { BillingCreditsService } from './billing-credits.service'; +import { BillingEntitlementsService } from './billing-entitlements.service'; +import { BillingService } from './billing.service'; +import { BillingWebhookService } from './billing-webhook.service'; + +@Module({ + imports: [AuthModule, StripeModule], + controllers: [BillingController], + providers: [ + BillingService, + BillingCreditsService, + BillingEntitlementsService, + BillingWebhookService, + ], + exports: [BillingService, BillingCreditsService, BillingEntitlementsService], +}) +export class BillingModule {} diff --git a/apps/api/src/billing/billing.service.spec.ts b/apps/api/src/billing/billing.service.spec.ts new file mode 100644 index 0000000000..05084a2b17 --- /dev/null +++ b/apps/api/src/billing/billing.service.spec.ts @@ -0,0 +1,557 @@ +import { BadRequestException, HttpException, HttpStatus } from '@nestjs/common'; +import { db } from '@db'; +import { BillingService } from './billing.service'; +import type { StripeService } from '../stripe/stripe.service'; + +jest.mock('@db', () => ({ + db: { + organization: { + findUniqueOrThrow: jest.fn(), + }, + organizationBilling: { + create: jest.fn(), + findUnique: jest.fn(), + }, + organizationBillingSubscription: { + findMany: jest.fn(), + update: jest.fn(), + }, + backgroundCheckRequest: { + count: jest.fn(), + findMany: jest.fn(), + }, + securityPenetrationTestRun: { + count: jest.fn(), + findMany: jest.fn(), + }, + billingUsageEvent: { + findMany: jest.fn(), + }, + }, +})); + +const mockedDb = db as jest.Mocked; +const organizationFindUniqueOrThrow = mockedDb.organization + .findUniqueOrThrow as unknown as jest.Mock; +const organizationBillingFindUnique = mockedDb.organizationBilling + .findUnique as unknown as jest.Mock; +const organizationBillingCreate = mockedDb.organizationBilling + .create as unknown as jest.Mock; +const organizationBillingSubscriptionFindMany = mockedDb + .organizationBillingSubscription.findMany as unknown as jest.Mock; +const organizationBillingSubscriptionUpdate = mockedDb + .organizationBillingSubscription.update as unknown as jest.Mock; +const backgroundCheckRequestCount = mockedDb.backgroundCheckRequest + .count as unknown as jest.Mock; +const backgroundCheckRequestFindMany = mockedDb.backgroundCheckRequest + .findMany as unknown as jest.Mock; +const securityPenetrationTestRunCount = mockedDb.securityPenetrationTestRun + .count as unknown as jest.Mock; +const securityPenetrationTestRunFindMany = mockedDb.securityPenetrationTestRun + .findMany as unknown as jest.Mock; +const billingUsageEventFindMany = mockedDb.billingUsageEvent + .findMany as unknown as jest.Mock; + +function mockStripeService(client: unknown): StripeService { + return { + isConfigured: () => client !== null, + getClient: () => client, + } as unknown as StripeService; +} + +describe('BillingService', () => { + beforeEach(() => { + jest.clearAllMocks(); + organizationFindUniqueOrThrow.mockResolvedValue({ + id: 'org_1', + name: 'Test Company', + }); + organizationBillingFindUnique.mockResolvedValue(null); + organizationBillingCreate.mockResolvedValue({ + id: 'obil_1', + organizationId: 'org_1', + stripeCustomerId: 'cus_1', + stripePaymentMethodId: null, + paymentMethodUpdatedAt: null, + createdAt: new Date('2026-04-30T00:00:00.000Z'), + updatedAt: new Date('2026-04-30T00:00:00.000Z'), + }); + organizationBillingSubscriptionFindMany.mockResolvedValue([]); + organizationBillingSubscriptionUpdate.mockResolvedValue({}); + backgroundCheckRequestCount.mockResolvedValue(0); + backgroundCheckRequestFindMany.mockResolvedValue([]); + securityPenetrationTestRunCount.mockResolvedValue(0); + securityPenetrationTestRunFindMany.mockResolvedValue([]); + billingUsageEventFindMany.mockResolvedValue([]); + }); + + it('creates a Stripe subscription checkout session from the billing catalog', async () => { + const customersCreate = jest.fn().mockResolvedValue({ id: 'cus_1' }); + const customersUpdate = jest.fn().mockResolvedValue({ id: 'cus_1' }); + const sessionsCreate = jest.fn().mockResolvedValue({ + url: 'https://checkout.stripe.test/session', + }); + const service = new BillingService( + mockStripeService({ + customers: { create: customersCreate, update: customersUpdate }, + checkout: { sessions: { create: sessionsCreate } }, + }), + { syncSubscriptionItem: jest.fn() } as never, + ); + + await expect( + service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_1', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + customerEmail: 'admin@example.com', + }), + ).resolves.toEqual({ url: 'https://checkout.stripe.test/session' }); + + expect(customersCreate).toHaveBeenCalledWith( + expect.objectContaining({ + metadata: { organizationId: 'org_1' }, + }), + { idempotencyKey: 'organization-billing-customer:org_1' }, + ); + expect(customersUpdate).toHaveBeenCalledWith('cus_1', { + email: 'admin@example.com', + }); + expect(sessionsCreate).toHaveBeenCalledWith( + expect.objectContaining({ + mode: 'subscription', + customer: 'cus_1', + line_items: [{ price: 'price_1TS3ziCkFWhKYvHI0H5TWxNI', quantity: 1 }], + payment_method_collection: 'always', + metadata: expect.objectContaining({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_1', + }), + subscription_data: expect.objectContaining({ + trial_period_days: 14, + }), + }), + ); + }); + + it('does not apply a trial when the product has historical subscription rows', async () => { + const customersCreate = jest.fn().mockResolvedValue({ id: 'cus_1' }); + const sessionsCreate = jest.fn().mockResolvedValue({ + url: 'https://checkout.stripe.test/session', + }); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_3', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'canceled', + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + customers: { create: customersCreate }, + checkout: { sessions: { create: sessionsCreate } }, + }), + { syncSubscriptionItem: jest.fn() } as never, + ); + + await service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_1', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }); + + expect(sessionsCreate).toHaveBeenCalledWith( + expect.not.objectContaining({ + payment_method_collection: 'always', + }), + ); + expect(sessionsCreate).toHaveBeenCalledWith( + expect.objectContaining({ + subscription_data: expect.not.objectContaining({ + trial_period_days: expect.any(Number), + }), + }), + ); + }); + + it('never applies a trial to higher tiers', async () => { + const customersCreate = jest.fn().mockResolvedValue({ id: 'cus_1' }); + const sessionsCreate = jest.fn().mockResolvedValue({ + url: 'https://checkout.stripe.test/session', + }); + const service = new BillingService( + mockStripeService({ + customers: { create: customersCreate }, + checkout: { sessions: { create: sessionsCreate } }, + }), + { syncSubscriptionItem: jest.fn() } as never, + ); + + await service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_3', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }); + + expect(sessionsCreate).toHaveBeenCalledWith( + expect.objectContaining({ + subscription_data: expect.not.objectContaining({ + trial_period_days: expect.any(Number), + }), + }), + ); + }); + + it('marks trial eligibility false after any product subscription history', async () => { + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'background_checks_monthly_10', + stripeStatus: 'canceled', + includedQuantity: 10, + usedQuantity: 0, + currentPeriodStart: null, + currentPeriodEnd: null, + cancelAtPeriodEnd: false, + }, + ]); + const service = new BillingService( + mockStripeService({ + invoices: { list: jest.fn().mockResolvedValue({ data: [] }) }, + customers: { retrieve: jest.fn().mockResolvedValue({}) }, + paymentMethods: { retrieve: jest.fn() }, + }), + { syncSubscriptionItem: jest.fn() } as never, + ); + + await expect(service.getStatus('org_1')).resolves.toMatchObject({ + trialEligibility: { + pentest: true, + background_check: false, + }, + }); + }); + + it('charges immediately when upgrading an existing product subscription', async () => { + const subscriptionsUpdate = jest.fn().mockResolvedValue({ + id: 'sub_1', + status: 'active', + cancel_at_period_end: false, + canceled_at: null, + items: { + data: [ + { + id: 'si_1', + current_period_start: 1775001600, + current_period_end: 1777593600, + }, + ], + }, + }); + const writeAuditEvent = jest.fn().mockResolvedValue(undefined); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_3', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'active', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + updatedAt: new Date('2026-04-30T09:00:00.000Z'), + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + subscriptions: { update: subscriptionsUpdate }, + }), + { syncSubscriptionItem: jest.fn(), writeAuditEvent } as never, + ); + + await expect( + service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_5_current', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }), + ).resolves.toEqual({ changed: true }); + + expect(subscriptionsUpdate).toHaveBeenCalledWith( + 'sub_1', + expect.objectContaining({ + items: [ + { + id: 'si_1', + price: 'price_1TS3zjCkFWhKYvHISBHjtZXB', + quantity: 1, + }, + ], + proration_behavior: 'always_invoice', + payment_behavior: 'error_if_incomplete', + }), + expect.anything(), + ); + expect(organizationBillingSubscriptionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + where: { id: 'obs_1' }, + data: expect.objectContaining({ + skuKey: 'pentest_monthly_5_current', + stripeSubscriptionItemId: 'si_1', + includedQuantity: 5, + }), + }), + ); + expect(writeAuditEvent).toHaveBeenCalledWith( + expect.objectContaining({ + organizationId: 'org_1', + eventType: 'subscription_plan_changed', + skuKey: 'pentest_monthly_5_current', + }), + ); + }); + + it('ends the trial immediately when upgrading from a trial plan', async () => { + const subscriptionsUpdate = jest.fn().mockResolvedValue({ + id: 'sub_1', + status: 'active', + cancel_at_period_end: false, + canceled_at: null, + items: { + data: [ + { + id: 'si_1', + current_period_start: 1775001600, + current_period_end: 1777593600, + }, + ], + }, + }); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'background_checks_monthly_3', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'trialing', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + updatedAt: new Date('2026-04-30T09:00:00.000Z'), + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + subscriptions: { update: subscriptionsUpdate }, + }), + { syncSubscriptionItem: jest.fn(), writeAuditEvent: jest.fn() } as never, + ); + + await service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'background_checks_monthly_20', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }); + + expect(subscriptionsUpdate).toHaveBeenCalledWith( + 'sub_1', + expect.objectContaining({ + items: [ + { + id: 'si_1', + price: 'price_1TS3zjCkFWhKYvHIU5jMCCWs', + quantity: 1, + }, + ], + proration_behavior: 'always_invoice', + payment_behavior: 'error_if_incomplete', + trial_end: 'now', + }), + expect.objectContaining({ + idempotencyKey: [ + 'subscription-plan-change-v2', + 'org_1', + 'si_1', + 'background_checks_monthly_3', + 'background_checks_monthly_20', + new Date('2026-04-30T09:00:00.000Z').getTime(), + ].join(':'), + }), + ); + }); + + it('does not grant upgraded credits when immediate upgrade payment fails', async () => { + const subscriptionsUpdate = jest.fn().mockRejectedValue( + Object.assign(new Error('Card declined'), { + statusCode: HttpStatus.PAYMENT_REQUIRED, + }), + ); + const writeAuditEvent = jest.fn().mockResolvedValue(undefined); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_3', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'active', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + updatedAt: new Date('2026-04-30T09:00:00.000Z'), + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + subscriptions: { update: subscriptionsUpdate }, + }), + { syncSubscriptionItem: jest.fn(), writeAuditEvent } as never, + ); + + try { + await service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_5_current', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }); + throw new Error('Expected upgrade payment failure'); + } catch (error) { + expect(error).toBeInstanceOf(HttpException); + if (error instanceof HttpException) { + expect(error.getStatus()).toBe(HttpStatus.PAYMENT_REQUIRED); + } + } + expect(organizationBillingSubscriptionUpdate).not.toHaveBeenCalled(); + expect(writeAuditEvent).not.toHaveBeenCalled(); + }); + + it('keeps same-plan subscription changes rejected before calling Stripe', async () => { + const subscriptionsUpdate = jest.fn(); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_3', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'active', + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + subscriptions: { update: subscriptionsUpdate }, + }), + { syncSubscriptionItem: jest.fn(), writeAuditEvent: jest.fn() } as never, + ); + + await expect( + service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_3', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }), + ).rejects.toBeInstanceOf(BadRequestException); + expect(subscriptionsUpdate).not.toHaveBeenCalled(); + }); + + it('does not immediately invoice lower-priced plan switches', async () => { + const subscriptionsUpdate = jest.fn().mockResolvedValue({ + id: 'sub_1', + status: 'active', + cancel_at_period_end: false, + canceled_at: null, + items: { + data: [ + { + id: 'si_1', + current_period_start: 1775001600, + current_period_end: 1777593600, + }, + ], + }, + }); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_5_current', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'active', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + updatedAt: new Date('2026-04-30T09:00:00.000Z'), + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + subscriptions: { update: subscriptionsUpdate }, + }), + { syncSubscriptionItem: jest.fn(), writeAuditEvent: jest.fn() } as never, + ); + + await service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_3', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }); + + expect(subscriptionsUpdate).toHaveBeenCalledWith( + 'sub_1', + { + items: [{ id: 'si_1', price: 'price_1TS3ziCkFWhKYvHI1nbXC7UU' }], + metadata: { + organizationId: 'org_1', + skuKey: 'pentest_monthly_3', + source: 'comp-billing-subscription', + }, + }, + expect.anything(), + ); + }); + + it('does not create subscription checkout for one-time SKUs', async () => { + const service = new BillingService( + mockStripeService({ + checkout: { sessions: { create: jest.fn() } }, + }), + { syncSubscriptionItem: jest.fn() } as never, + ); + + await expect( + service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'background_check_one_time', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }), + ).rejects.toBeInstanceOf(BadRequestException); + }); + + it('returns a controlled error when Stripe is not configured', async () => { + const service = new BillingService(mockStripeService(null), { + syncSubscriptionItem: jest.fn(), + } as never); + + await expect( + service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_1', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }), + ).rejects.toMatchObject({ + status: HttpStatus.PAYMENT_REQUIRED, + }); + }); +}); diff --git a/apps/api/src/billing/billing.service.ts b/apps/api/src/billing/billing.service.ts new file mode 100644 index 0000000000..43fa2d84cb --- /dev/null +++ b/apps/api/src/billing/billing.service.ts @@ -0,0 +1,274 @@ +import { + BadRequestException, + Injectable, + NotFoundException, +} from '@nestjs/common'; +import { db } from '@db'; +import { + type BillingProductKey, + getBillingSku, + getBillingSkuProductKey, + resolveBillingCatalogEnvironment, + isSubscriptionBillingSkuKey, +} from '@trycompai/billing'; +import { StripeService } from '../stripe/stripe.service'; +import { findOrCreateBillingCustomer } from './billing-customer'; +import { BillingEntitlementsService } from './billing-entitlements.service'; +import { listBillingInvoices } from './billing-invoices'; +import { + type BillingPreferencesInput, + getBillingPreferences, + updateBillingPreferences, +} from './billing-preferences'; +import { validateBillingRedirectUrl } from './billing-redirect-urls'; +import { + createBillingSetupSession, + handleBillingSetupSuccess, +} from './billing-setup-sessions'; +import { assertStripeBillingConfigured } from './billing-stripe-config'; +import { + changeSubscriptionPlan, + findProductSubscriptions, +} from './billing-subscription-plans'; +import type { BillingStatus } from './billing.types'; +import { listBillingUsageRows } from './billing-usage'; + +@Injectable() +export class BillingService { + constructor( + private readonly stripeService: StripeService, + private readonly entitlements: BillingEntitlementsService, + ) {} + + async getStatus(organizationId: string): Promise { + const [ + organization, + billing, + subscriptions, + backgroundChecks, + penetrationTests, + ] = await Promise.all([ + db.organization.findUniqueOrThrow({ + where: { id: organizationId }, + select: { name: true }, + }), + db.organizationBilling.findUnique({ + where: { organizationId }, + select: { + stripeCustomerId: true, + stripePaymentMethodId: true, + paymentMethodUpdatedAt: true, + }, + }), + db.organizationBillingSubscription.findMany({ + where: { organizationId }, + orderBy: { skuKey: 'asc' }, + }), + db.backgroundCheckRequest.count({ where: { organizationId } }), + db.securityPenetrationTestRun.count({ where: { organizationId } }), + ]); + const [invoices, preferences, usageRows] = await Promise.all([ + listBillingInvoices({ + stripeService: this.stripeService, + stripeCustomerId: billing?.stripeCustomerId ?? null, + }), + getBillingPreferences({ + stripeService: this.stripeService, + stripeCustomerId: billing?.stripeCustomerId ?? null, + fallbackCompanyName: organization.name, + }), + listBillingUsageRows({ organizationId, subscriptions }), + ]); + + return { + hasBilling: !!billing, + hasPaymentMethod: !!billing?.stripePaymentMethodId, + setupAt: billing?.paymentMethodUpdatedAt ?? null, + usage: { backgroundChecks, penetrationTests }, + preferences, + trialEligibility: getTrialEligibility(subscriptions), + usageRows, + subscriptions: subscriptions.map((subscription) => ({ + skuKey: subscription.skuKey, + status: subscription.stripeStatus, + includedQuantity: subscription.includedQuantity, + usedQuantity: subscription.usedQuantity, + currentPeriodStart: + subscription.currentPeriodStart?.toISOString() ?? null, + currentPeriodEnd: subscription.currentPeriodEnd?.toISOString() ?? null, + cancelAtPeriodEnd: subscription.cancelAtPeriodEnd, + })), + invoices, + }; + } + + async updatePreferences(params: { + organizationId: string; + preferences: BillingPreferencesInput; + }): Promise { + assertStripeBillingConfigured(this.stripeService); + + const result = await updateBillingPreferences({ + stripeService: this.stripeService, + organizationId: params.organizationId, + preferences: params.preferences, + }); + + await db.billingAuditEvent.create({ + data: { + organizationId: params.organizationId, + eventType: 'billing_preferences_updated', + metadata: { + stripeCustomerId: result.stripeCustomerId, + billingEmail: result.preferences.billingEmail, + }, + }, + }); + + return this.getStatus(params.organizationId); + } + + async createSetupSession(params: { + organizationId: string; + successUrl: string; + cancelUrl: string; + customerEmail?: string; + }): Promise<{ url: string }> { + return createBillingSetupSession({ + ...params, + stripeService: this.stripeService, + }); + } + + async handleSetupSuccess(params: { + organizationId: string; + sessionId: string; + }): Promise<{ success: true }> { + return handleBillingSetupSuccess({ + ...params, + stripeService: this.stripeService, + }); + } + + async createBillingPortalSession(params: { + organizationId: string; + returnUrl: string; + }): Promise<{ url: string }> { + validateBillingRedirectUrl(params.returnUrl); + assertStripeBillingConfigured(this.stripeService); + + const billing = await db.organizationBilling.findUnique({ + where: { organizationId: params.organizationId }, + select: { stripeCustomerId: true }, + }); + if (!billing) { + throw new NotFoundException( + 'No billing record found for this organization.', + ); + } + + const session = await this.stripeService + .getClient() + .billingPortal.sessions.create({ + customer: billing.stripeCustomerId, + return_url: params.returnUrl, + }); + return { url: session.url }; + } + + async createSubscriptionCheckoutSession(params: { + organizationId: string; + skuKey: string; + successUrl: string; + cancelUrl: string; + customerEmail?: string; + }): Promise<{ url: string } | { changed: true }> { + validateBillingRedirectUrl(params.successUrl); + validateBillingRedirectUrl(params.cancelUrl); + if (!isSubscriptionBillingSkuKey(params.skuKey)) { + throw new BadRequestException('Unknown subscription SKU.'); + } + assertStripeBillingConfigured(this.stripeService); + + const environment = resolveBillingCatalogEnvironment({ + stripeSecretKey: process.env.STRIPE_SECRET_KEY, + nodeEnv: process.env.NODE_ENV, + }); + const sku = getBillingSku({ environment, skuKey: params.skuKey }); + const productSubscriptions = await findProductSubscriptions({ + organizationId: params.organizationId, + productKey: sku.productKey, + }); + const existingSubscription = + productSubscriptions.find( + (subscription) => + subscription.stripeStatus === 'active' || + subscription.stripeStatus === 'trialing', + ) ?? null; + if (existingSubscription) { + if (existingSubscription.skuKey === sku.key) { + throw new BadRequestException('This plan is already active.'); + } + return changeSubscriptionPlan({ + organizationId: params.organizationId, + subscription: existingSubscription, + skuKey: sku.key, + stripePriceId: sku.stripePriceId, + includedQuantity: sku.includedUsage?.quantity ?? 0, + stripeService: this.stripeService, + entitlements: this.entitlements, + }); + } + + const customerId = await findOrCreateBillingCustomer({ + stripeService: this.stripeService, + organizationId: params.organizationId, + customerEmail: params.customerEmail, + }); + const stripe = this.stripeService.getClient(); + const applyTrial = + productSubscriptions.length === 0 && typeof sku.trialDays === 'number'; + const session = await stripe.checkout.sessions.create({ + mode: 'subscription', + customer: customerId, + line_items: [{ price: sku.stripePriceId, quantity: 1 }], + success_url: params.successUrl, + cancel_url: params.cancelUrl, + ...(applyTrial ? { payment_method_collection: 'always' } : {}), + metadata: { + organizationId: params.organizationId, + skuKey: sku.key, + source: 'comp-billing-subscription', + }, + subscription_data: { + ...(applyTrial ? { trial_period_days: sku.trialDays } : {}), + metadata: { + organizationId: params.organizationId, + skuKey: sku.key, + source: 'comp-billing-subscription', + }, + }, + }); + + if (!session.url) { + throw new BadRequestException( + 'Failed to create Stripe Checkout session.', + ); + } + return { url: session.url }; + } +} + +function getTrialEligibility( + subscriptions: Array<{ skuKey: string }>, +): Record { + const productHistory = new Set(); + for (const subscription of subscriptions) { + const productKey = getBillingSkuProductKey(subscription.skuKey); + if (productKey) productHistory.add(productKey); + } + return { + pentest: !productHistory.has('pentest'), + background_check: !productHistory.has('background_check'), + }; +} diff --git a/apps/api/src/billing/billing.types.ts b/apps/api/src/billing/billing.types.ts new file mode 100644 index 0000000000..f9db394ae9 --- /dev/null +++ b/apps/api/src/billing/billing.types.ts @@ -0,0 +1,42 @@ +import type { BillingInvoice } from './billing-invoices'; +import type { BillingPreferences } from './billing-preferences'; + +export interface BillingStatus { + hasBilling: boolean; + hasPaymentMethod: boolean; + setupAt: Date | null; + usage: { + backgroundChecks: number; + penetrationTests: number; + }; + preferences: BillingPreferences; + trialEligibility: { + pentest: boolean; + background_check: boolean; + }; + usageRows: BillingUsageRow[]; + subscriptions: Array<{ + skuKey: string; + status: string; + includedQuantity: number; + usedQuantity: number; + currentPeriodStart: string | null; + currentPeriodEnd: string | null; + cancelAtPeriodEnd: boolean; + }>; + invoices: BillingInvoice[]; +} + +export interface BillingUsageRow { + id: string; + service: 'Penetration Test' | 'Background Check'; + skuKey: string; + details: string; + status: string; + billingType: string; + createdAt: string; + updatedAt: string; + subscriptionRemaining: number | null; + subscriptionIncluded: number | null; + subscriptionPeriodEnd: string | null; +} diff --git a/apps/api/src/billing/dto/billing.dto.ts b/apps/api/src/billing/dto/billing.dto.ts new file mode 100644 index 0000000000..e0ba0a6b0a --- /dev/null +++ b/apps/api/src/billing/dto/billing.dto.ts @@ -0,0 +1,101 @@ +import { subscriptionBillingSkuKeys } from '@trycompai/billing'; +import { + IsEmail, + IsIn, + IsNotEmpty, + IsOptional, + IsString, + IsUrl, + Length, +} from 'class-validator'; +import { billingTaxIdTypes } from '../billing-preferences'; + +export class BillingSetupSessionDto { + @IsString() + @IsUrl({ require_tld: false }, { message: 'successUrl must be a valid URL' }) + successUrl: string; + + @IsString() + @IsUrl({ require_tld: false }, { message: 'cancelUrl must be a valid URL' }) + cancelUrl: string; +} + +export class BillingSetupSuccessDto { + @IsString() + @IsNotEmpty() + sessionId: string; +} + +export class BillingPortalDto { + @IsString() + @IsUrl({ require_tld: false }, { message: 'returnUrl must be a valid URL' }) + returnUrl: string; +} + +export class BillingSubscriptionCheckoutDto { + @IsString() + @IsIn(subscriptionBillingSkuKeys) + skuKey: string; + + @IsString() + @IsUrl({ require_tld: false }, { message: 'successUrl must be a valid URL' }) + successUrl: string; + + @IsString() + @IsUrl({ require_tld: false }, { message: 'cancelUrl must be a valid URL' }) + cancelUrl: string; +} + +export class BillingPreferencesDto { + @IsString() + @Length(1, 150) + companyName: string; + + @IsEmail() + billingEmail: string; + + @IsOptional() + @IsString() + @Length(0, 140) + purchaseOrder: string | null; + + @IsOptional() + @IsString() + @Length(0, 200) + addressLine1: string | null; + + @IsOptional() + @IsString() + @Length(0, 200) + addressLine2: string | null; + + @IsOptional() + @IsString() + @Length(0, 100) + addressCity: string | null; + + @IsOptional() + @IsString() + @Length(0, 100) + addressState: string | null; + + @IsOptional() + @IsString() + @Length(0, 32) + addressPostalCode: string | null; + + @IsOptional() + @IsString() + @Length(0, 2) + addressCountry: string | null; + + @IsOptional() + @IsString() + @IsIn([...billingTaxIdTypes, '']) + taxIdType: string | null; + + @IsOptional() + @IsString() + @Length(0, 64) + taxIdValue: string | null; +} diff --git a/apps/api/src/billing/stripe-webhook-records.spec.ts b/apps/api/src/billing/stripe-webhook-records.spec.ts new file mode 100644 index 0000000000..a0dc29f321 --- /dev/null +++ b/apps/api/src/billing/stripe-webhook-records.spec.ts @@ -0,0 +1,124 @@ +import { Prisma, db } from '@db'; +import { claimStripeWebhookEvent } from './stripe-webhook-records'; + +jest.mock('@db', () => { + class PrismaClientKnownRequestError extends Error { + code: string; + + constructor(message: string, params: { code: string }) { + super(message); + this.code = params.code; + } + } + + return { + Prisma: { + PrismaClientKnownRequestError, + }, + db: { + stripeWebhookEvent: { + create: jest.fn(), + updateMany: jest.fn(), + update: jest.fn(), + }, + }, + }; +}); + +const stripeWebhookEventCreate = db.stripeWebhookEvent + .create as unknown as jest.Mock; +const stripeWebhookEventUpdateMany = db.stripeWebhookEvent + .updateMany as unknown as jest.Mock; + +describe('stripe webhook records', () => { + beforeEach(() => { + jest.clearAllMocks(); + jest + .useFakeTimers() + .setSystemTime(new Date('2026-04-30T12:00:00.000Z').getTime()); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + it('claims a new Stripe webhook event', async () => { + stripeWebhookEventCreate.mockResolvedValue({}); + + await expect( + claimStripeWebhookEvent({ + stripeEventId: 'evt_1', + eventType: 'invoice.paid', + payload: { id: 'in_1' }, + }), + ).resolves.toEqual({ status: 'claimed' }); + + expect(stripeWebhookEventCreate).toHaveBeenCalledWith({ + data: expect.objectContaining({ + stripeEventId: 'evt_1', + status: 'processing', + }), + }); + }); + + it('atomically reclaims failed webhook events', async () => { + stripeWebhookEventCreate.mockRejectedValue( + new Prisma.PrismaClientKnownRequestError('Unique', { + code: 'P2002', + clientVersion: 'test', + }), + ); + stripeWebhookEventUpdateMany.mockResolvedValue({ count: 1 }); + + await expect( + claimStripeWebhookEvent({ + stripeEventId: 'evt_1', + eventType: 'invoice.paid', + payload: { id: 'in_1' }, + }), + ).resolves.toEqual({ status: 'claimed' }); + + expect(stripeWebhookEventUpdateMany).toHaveBeenCalledWith( + expect.objectContaining({ + where: expect.objectContaining({ + stripeEventId: 'evt_1', + OR: expect.arrayContaining([ + expect.objectContaining({ status: 'failed' }), + ]), + }), + data: expect.objectContaining({ status: 'processing', error: null }), + }), + ); + }); + + it('reclaims only one stale processing webhook retry', async () => { + stripeWebhookEventCreate.mockRejectedValue( + new Prisma.PrismaClientKnownRequestError('Unique', { + code: 'P2002', + clientVersion: 'test', + }), + ); + stripeWebhookEventUpdateMany.mockResolvedValue({ count: 0 }); + + await expect( + claimStripeWebhookEvent({ + stripeEventId: 'evt_1', + eventType: 'invoice.paid', + payload: { id: 'in_1' }, + }), + ).resolves.toEqual({ status: 'duplicate' }); + + expect(stripeWebhookEventUpdateMany).toHaveBeenCalledWith( + expect.objectContaining({ + where: expect.objectContaining({ + OR: expect.arrayContaining([ + expect.objectContaining({ + status: 'processing', + processedAt: { lt: new Date('2026-04-30T11:45:00.000Z') }, + }), + ]), + }), + }), + ); + }); +}); diff --git a/apps/api/src/billing/stripe-webhook-records.ts b/apps/api/src/billing/stripe-webhook-records.ts new file mode 100644 index 0000000000..76d28deb1b --- /dev/null +++ b/apps/api/src/billing/stripe-webhook-records.ts @@ -0,0 +1,80 @@ +import { Prisma, db } from '@db'; + +const processingReclaimAfterMs = 15 * 60 * 1000; + +export type StripeWebhookClaim = + | { status: 'claimed' } + | { status: 'duplicate' }; + +export async function claimStripeWebhookEvent(params: { + stripeEventId: string; + eventType: string; + payload: Prisma.InputJsonValue; +}): Promise { + try { + await db.stripeWebhookEvent.create({ + data: { + stripeEventId: params.stripeEventId, + eventType: params.eventType, + payload: params.payload, + status: 'processing', + }, + }); + return { status: 'claimed' }; + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + const reclaimBefore = new Date(Date.now() - processingReclaimAfterMs); + const reclaimed = await db.stripeWebhookEvent.updateMany({ + where: { + stripeEventId: params.stripeEventId, + OR: [ + { status: 'failed' }, + { status: 'processing', processedAt: { lt: reclaimBefore } }, + ], + }, + data: { + eventType: params.eventType, + payload: params.payload, + status: 'processing', + error: null, + processedAt: new Date(), + }, + }); + + if (reclaimed.count === 1) { + return { status: 'claimed' }; + } + + return { status: 'duplicate' }; + } +} + +export async function markStripeWebhookProcessed( + stripeEventId: string, +): Promise { + await db.stripeWebhookEvent.update({ + where: { stripeEventId }, + data: { status: 'processed', error: null, processedAt: new Date() }, + }); +} + +export async function markStripeWebhookFailed(params: { + stripeEventId: string; + error: unknown; +}): Promise { + await db.stripeWebhookEvent.update({ + where: { stripeEventId: params.stripeEventId }, + data: { + status: 'failed', + error: + params.error instanceof Error ? params.error.message : 'Unknown error', + }, + }); +} + +function isUniqueConstraintError(error: unknown): boolean { + return ( + error instanceof Prisma.PrismaClientKnownRequestError && + error.code === 'P2002' + ); +} diff --git a/apps/api/src/main.ts b/apps/api/src/main.ts index babfc7e119..d46f1dd99c 100644 --- a/apps/api/src/main.ts +++ b/apps/api/src/main.ts @@ -93,6 +93,8 @@ async function bootstrap(): Promise { '/security-penetration-tests/webhook', '/v1/background-checks/webhook', '/background-checks/webhook', + '/v1/billing/webhook', + '/billing/webhook', ]; const needsRawBody = (req: express.Request): boolean => RAW_BODY_PATHS.some((p) => req.path.endsWith(p)); diff --git a/apps/api/src/security-penetration-tests/pentest-credits.service.ts b/apps/api/src/security-penetration-tests/pentest-credits.service.ts index d2afc75bad..4e89994752 100644 --- a/apps/api/src/security-penetration-tests/pentest-credits.service.ts +++ b/apps/api/src/security-penetration-tests/pentest-credits.service.ts @@ -3,8 +3,10 @@ import { HttpStatus, Injectable, Logger, + Optional, } from '@nestjs/common'; import { AuditLogEntityType, db, Prisma } from '@db'; +import { BillingCreditsService } from '../billing/billing-credits.service'; /** * Source of credits — free-form so v2 can add new sources without a schema @@ -45,6 +47,9 @@ export type PentestAuditAction = @Injectable() export class PentestCreditsService { private readonly logger = new Logger(PentestCreditsService.name); + constructor( + @Optional() private readonly billingCredits?: BillingCreditsService, + ) {} /** * Default trial grant for new orgs. Static today; in v2 this can become a @@ -54,12 +59,30 @@ export class PentestCreditsService { private readonly initialTrialAmount = 1; async getStatus(organizationId: string): Promise { + if (this.billingCredits) { + const balances = await this.billingCredits.listBalances(organizationId); + const balance = balances.find((item) => item.productKey === 'pentest'); + return balance + ? { + balance: balance.balance, + totalGranted: balance.totalGranted, + totalConsumed: balance.totalConsumed, + lastGrantSource: balance.lastSource, + } + : { + balance: 0, + totalGranted: 0, + totalConsumed: 0, + lastGrantSource: 'none', + }; + } + const row = await db.pentestCredits.findUnique({ where: { organizationId }, }); if (!row) { - // Org has no row yet. Return a zero balance — the client UI treats - // this as "no trial granted, paid plans coming soon." + // Org has no row yet. Return a zero balance for historical/admin + // callers; customer-facing scan creation now uses subscriptions. return { balance: 0, totalGranted: 0, @@ -99,6 +122,16 @@ export class PentestCreditsService { `grant amount must be positive (got ${amount} for org=${organizationId})`, ); } + if (this.billingCredits) { + await this.billingCredits.grant({ + organizationId, + productKey: 'pentest', + quantity: amount, + source, + note: `Pentest credits granted from ${source}`, + }); + return; + } await db.pentestCredits.upsert({ where: { organizationId }, create: { @@ -140,6 +173,25 @@ export class PentestCreditsService { runId?: string, tx?: Pick, ): Promise { + if (this.billingCredits && !tx) { + const result = await this.billingCredits.tryConsumeForProduct({ + organizationId, + productKey: 'pentest', + sourceResourceId: runId ?? 'pending', + }); + if (result.status !== 'consumed') { + throw new HttpException( + { + error: + 'No pentest runs remaining. Choose a penetration test plan to continue.', + code: 'pentest_credits_exhausted', + }, + HttpStatus.PAYMENT_REQUIRED, + ); + } + return this.getStatus(organizationId); + } + const client = tx ?? db; const updated = await client.pentestCredits.updateMany({ where: { organizationId, balance: { gt: 0 } }, @@ -157,7 +209,7 @@ export class PentestCreditsService { throw new HttpException( { error: - 'No pentest runs remaining. Paid plans coming soon — contact support if you need access today.', + 'No pentest runs remaining. Choose a penetration test plan to continue.', code: 'pentest_credits_exhausted', }, HttpStatus.PAYMENT_REQUIRED, @@ -177,7 +229,12 @@ export class PentestCreditsService { totalConsumed: row.totalConsumed, lastGrantSource: row.lastGrantSource, } - : { balance: 0, totalGranted: 0, totalConsumed: 0, lastGrantSource: 'none' }; + : { + balance: 0, + totalGranted: 0, + totalConsumed: 0, + lastGrantSource: 'none', + }; this.logger.log( `[Credits] debit org=${organizationId} run=${runId ?? 'pending'} balance=${status.balance}`, ); @@ -202,6 +259,16 @@ export class PentestCreditsService { reason: string, tx?: Pick, ): Promise { + if (this.billingCredits && !tx) { + await this.billingCredits.refundForProduct({ + organizationId, + productKey: 'pentest', + sourceResourceId: runId ?? 'pending', + reason, + }); + return; + } + // Use the optional tx client when provided so the caller can wrap // claim+refund in a single transaction (webhook idempotency: if the // refund DB write fails, the claim rolls back and a redelivered diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts new file mode 100644 index 0000000000..830150233d --- /dev/null +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts @@ -0,0 +1,244 @@ +import { db } from '@db'; +import { HttpException, HttpStatus } from '@nestjs/common'; +import type { BillingEntitlementsService } from '../billing/billing-entitlements.service'; +import type { PentestCreditsService } from './pentest-credits.service'; +import { SecurityPenetrationTestsService } from './security-penetration-tests.service'; + +const mockMacedPentestsCreate = jest.fn(); + +jest.mock( + '@maced/api-client', + () => ({ + createMacedClient: () => ({ + pentests: { + create: mockMacedPentestsCreate, + }, + }), + MacedApiError: class MacedApiError extends Error {}, + MacedWebhookSignatureError: class MacedWebhookSignatureError extends Error { + code = 'invalid_signature'; + }, + MacedClient: { + webhooks: { + constructEvent: jest.fn(), + }, + }, + }), + { virtual: true }, +); + +jest.mock('@db', () => ({ + db: { + securityPenetrationTestRun: { + upsert: jest.fn(), + updateMany: jest.fn(), + findUnique: jest.fn(), + }, + secret: { + upsert: jest.fn(), + }, + $transaction: jest.fn(), + }, +})); + +type MockDb = { + securityPenetrationTestRun: { + upsert: jest.Mock; + updateMany: jest.Mock; + findUnique: jest.Mock; + }; + secret: { + upsert: jest.Mock; + }; + $transaction: jest.Mock; +}; + +describe('SecurityPenetrationTestsService billing usage', () => { + const originalMacedApiKey = process.env.MACED_API_KEY; + const mockedDb = db as unknown as MockDb; + const credits: jest.Mocked< + Pick< + PentestCreditsService, + 'getStatus' | 'debitOrThrow' | 'refund' | 'writePentestAuditEntry' + > + > = { + getStatus: jest.fn(), + debitOrThrow: jest.fn(), + refund: jest.fn(), + writePentestAuditEntry: jest.fn(), + }; + const billingEntitlements: jest.Mocked< + Pick< + BillingEntitlementsService, + 'tryConsumeIncludedUsageForProduct' | 'refundIncludedUsageForProduct' + > + > = { + tryConsumeIncludedUsageForProduct: jest.fn(), + refundIncludedUsageForProduct: jest.fn(), + }; + let service: SecurityPenetrationTestsService; + + beforeEach(() => { + jest.clearAllMocks(); + process.env.MACED_API_KEY = 'mc_dev_test_maced_api_key'; + mockMacedPentestsCreate.mockResolvedValue({ + id: 'run_subscription', + status: 'provisioning', + }); + credits.debitOrThrow.mockResolvedValue({ + balance: 4, + totalGranted: 5, + totalConsumed: 1, + lastGrantSource: 'trial', + }); + credits.refund.mockResolvedValue(); + credits.writePentestAuditEntry.mockResolvedValue(); + billingEntitlements.tryConsumeIncludedUsageForProduct.mockResolvedValue({ + status: 'consumed', + subscriptionId: 'obs_1', + }); + billingEntitlements.refundIncludedUsageForProduct.mockResolvedValue(); + mockedDb.securityPenetrationTestRun.upsert.mockResolvedValue({}); + mockedDb.securityPenetrationTestRun.updateMany.mockResolvedValue({ + count: 1, + }); + mockedDb.$transaction.mockImplementation( + (callback: (tx: MockDb) => Promise) => callback(mockedDb), + ); + service = new SecurityPenetrationTestsService( + credits as unknown as PentestCreditsService, + billingEntitlements as unknown as BillingEntitlementsService, + ); + }); + + afterAll(() => { + process.env.MACED_API_KEY = originalMacedApiKey; + }); + + it('persists the subscription usage source on subscription-backed runs', async () => { + await service.createReport('org_123', { + targetUrl: 'https://app.example.com', + repoUrl: 'https://github.com/org/repo', + }); + + expect(credits.debitOrThrow).not.toHaveBeenCalled(); + expect(mockedDb.securityPenetrationTestRun.upsert).toHaveBeenCalledWith( + expect.objectContaining({ + create: expect.objectContaining({ + billingUsageSourceId: expect.stringMatching(/^pending:/), + }), + }), + ); + }); + + it('requires a subscription or free trial instead of debiting wallet credits', async () => { + billingEntitlements.tryConsumeIncludedUsageForProduct.mockResolvedValue({ + status: 'not_configured', + }); + + await expect( + service.createReport('org_123', { + targetUrl: 'https://app.example.com', + }), + ).rejects.toMatchObject({ + status: 402, + response: expect.objectContaining({ + code: 'pentest_subscription_required', + }), + }); + + expect(credits.debitOrThrow).not.toHaveBeenCalled(); + expect(credits.writePentestAuditEntry).toHaveBeenCalledWith( + expect.objectContaining({ + organizationId: 'org_123', + action: 'pentest_create_blocked', + metadata: expect.objectContaining({ + reason: 'pentest_subscription_required', + }), + }), + ); + expect(mockMacedPentestsCreate).not.toHaveBeenCalled(); + }); + + it('preserves exhausted subscription reasons from string payment errors', async () => { + billingEntitlements.tryConsumeIncludedUsageForProduct.mockRejectedValue( + new HttpException( + 'pentest_subscription_exhausted', + HttpStatus.PAYMENT_REQUIRED, + ), + ); + + await expect( + service.createReport('org_123', { + targetUrl: 'https://app.example.com', + }), + ).rejects.toMatchObject({ + status: 402, + }); + + expect(credits.writePentestAuditEntry).toHaveBeenCalledWith( + expect.objectContaining({ + organizationId: 'org_123', + action: 'pentest_create_blocked', + metadata: expect.objectContaining({ + reason: 'pentest_subscription_exhausted', + }), + }), + ); + expect(mockMacedPentestsCreate).not.toHaveBeenCalled(); + }); + + it('refunds subscription usage on terminal failure for subscription-backed runs', async () => { + mockedDb.securityPenetrationTestRun.findUnique.mockResolvedValue({ + organizationId: 'org_123', + billingUsageSourceId: 'pending:run_subscription', + }); + const refundInvoker = service as unknown as { + refundOnTerminalFailure: ( + providerRunId: string, + eventType: 'pentest.failed', + ) => Promise; + }; + + await refundInvoker.refundOnTerminalFailure( + 'run_subscription', + 'pentest.failed', + ); + + expect( + billingEntitlements.refundIncludedUsageForProduct, + ).toHaveBeenCalledWith({ + organizationId: 'org_123', + productKey: 'pentest', + sourceResourceId: 'pending:run_subscription', + reason: 'pentest.failed', + tx: mockedDb, + }); + expect(credits.refund).not.toHaveBeenCalled(); + }); + + it('keeps legacy credit refunds for terminal failures without subscription usage', async () => { + mockedDb.securityPenetrationTestRun.findUnique.mockResolvedValue({ + organizationId: 'org_123', + billingUsageSourceId: null, + }); + const refundInvoker = service as unknown as { + refundOnTerminalFailure: ( + providerRunId: string, + eventType: 'pentest.failed', + ) => Promise; + }; + + await refundInvoker.refundOnTerminalFailure('run_legacy', 'pentest.failed'); + + expect(credits.refund).toHaveBeenCalledWith( + 'org_123', + 'run_legacy', + 'pentest.failed', + mockedDb, + ); + expect( + billingEntitlements.refundIncludedUsageForProduct, + ).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.module.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.module.ts index 1560d1c85e..1f648b8f34 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.module.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.module.ts @@ -1,12 +1,13 @@ import { Module } from '@nestjs/common'; import { AuthModule } from '../auth/auth.module'; +import { BillingModule } from '../billing/billing.module'; import { PentestCreditsController } from './pentest-credits.controller'; import { PentestCreditsService } from './pentest-credits.service'; import { SecurityPenetrationTestsController } from './security-penetration-tests.controller'; import { SecurityPenetrationTestsService } from './security-penetration-tests.service'; @Module({ - imports: [AuthModule], + imports: [AuthModule, BillingModule], controllers: [SecurityPenetrationTestsController, PentestCreditsController], providers: [SecurityPenetrationTestsService, PentestCreditsService], exports: [PentestCreditsService], diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts index 2afda1b79c..5c26b06c51 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts @@ -1,6 +1,7 @@ import { HttpException, HttpStatus } from '@nestjs/common'; import { db } from '@db'; import { createHash } from 'node:crypto'; +import type { BillingEntitlementsService } from '../billing/billing-entitlements.service'; import type { CredentialVaultService } from '../integration-platform/services/credential-vault.service'; import type { CreatePenetrationTestDto } from './dto/create-penetration-test.dto'; import type { PentestCreditsService } from './pentest-credits.service'; @@ -22,6 +23,16 @@ const mockPentestCreditsService: jest.Mocked< refund: jest.fn(), }; +const mockBillingEntitlementsService: jest.Mocked< + Pick< + BillingEntitlementsService, + 'tryConsumeIncludedUsageForProduct' | 'refundIncludedUsageForProduct' + > +> = { + tryConsumeIncludedUsageForProduct: jest.fn(), + refundIncludedUsageForProduct: jest.fn(), +}; + jest.mock('@db', () => ({ db: { securityPenetrationTestRun: { @@ -102,8 +113,13 @@ describe('SecurityPenetrationTestsService', () => { lastGrantSource: 'trial', }); mockPentestCreditsService.refund.mockResolvedValue(); + mockBillingEntitlementsService.tryConsumeIncludedUsageForProduct.mockResolvedValue({ + status: 'not_configured', + }); + mockBillingEntitlementsService.refundIncludedUsageForProduct.mockResolvedValue(); service = new SecurityPenetrationTestsService( mockPentestCreditsService as unknown as PentestCreditsService, + mockBillingEntitlementsService as unknown as BillingEntitlementsService, ); fetchMock.mockReset(); global.fetch = fetchMock as unknown as typeof fetch; diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts index 8217a2e420..2cb341475b 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts @@ -21,8 +21,10 @@ import { type PentestProgress as MacedPentestProgress, type PentestWithProgress, } from '@maced/api-client'; +import { randomUUID } from 'crypto'; import type { CreatePenetrationTestDto } from './dto/create-penetration-test.dto'; +import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; import { PentestCreditsService } from './pentest-credits.service'; /** @@ -112,7 +114,10 @@ export class SecurityPenetrationTestsService { private readonly logger = new Logger(SecurityPenetrationTestsService.name); private readonly macedClient: MacedClient; - constructor(private readonly credits: PentestCreditsService) { + constructor( + private readonly credits: PentestCreditsService, + private readonly billingEntitlements: BillingEntitlementsService, + ) { const apiKey = process.env.MACED_API_KEY; if (!apiKey) { // Throw at construction so the app fails loudly on boot, not on first request. @@ -168,8 +173,7 @@ export class SecurityPenetrationTestsService { // …). Include the constructor name + message in the log so we can tell // what actually broke without a debugger. const errName = error?.constructor?.name ?? typeof error; - const errMessage = - error instanceof Error ? error.message : String(error); + const errMessage = error instanceof Error ? error.message : String(error); this.logger.error( `Transport failure calling Maced (${context}): ${errName} — ${errMessage}`, ); @@ -222,32 +226,60 @@ export class SecurityPenetrationTestsService { payload: CreatePenetrationTestDto, ): Promise { const resolvedWebhookUrl = this.resolveWebhookUrl(payload.webhookUrl); + const billingUsageSourceId = `pending:${randomUUID()}`; + let consumedSubscriptionAllowance = false; - // Debit FIRST so concurrent fast-clicks block at the cheap DB - // conditional update before any of them reach Maced. Without this, - // double-clicks would race past the balance check, all hit Maced - // (creating multiple paid runs), and only one would win the debit — - // we'd burn money on orphaned provider-side runs. Atomic - // `updateMany WHERE balance > 0` guarantees only one decrement - // succeeds. + // Reserve subscription allowance before calling Maced so fast double-clicks + // cannot create more paid provider runs than the organization has available. try { - await this.credits.debitOrThrow(organizationId); + const subscriptionUsage = + await this.billingEntitlements.tryConsumeIncludedUsageForProduct({ + organizationId, + productKey: 'pentest', + sourceResourceId: billingUsageSourceId, + }); + if (subscriptionUsage.status === 'exhausted') { + throw new HttpException( + { + error: + 'No pentest runs remaining in your subscription. Upgrade or wait for your monthly allowance to reset.', + code: 'pentest_subscription_exhausted', + }, + HttpStatus.PAYMENT_REQUIRED, + ); + } + if (subscriptionUsage.status === 'not_configured') { + throw new HttpException( + { + error: 'Start a penetration test plan or free trial to run scans.', + code: 'pentest_subscription_required', + }, + HttpStatus.PAYMENT_REQUIRED, + ); + } else { + consumedSubscriptionAllowance = true; + } } catch (error) { if ( error instanceof HttpException && error.getStatus() === HttpStatus.PAYMENT_REQUIRED ) { // Record the blocked attempt so support / compliance can answer - // "did the user try to scan after their trial was used?". Best- + // "did the user try to scan without an allowance?". Best- // effort — never let an audit-log failure hide the 402 from the // user. + const response = error.getResponse(); + const reason = getPaymentRequiredCode(response); await this.credits.writePentestAuditEntry({ organizationId, action: 'pentest_create_blocked', runId: null, - description: 'Pentest create blocked: no credits remaining', + description: + reason === 'pentest_subscription_exhausted' + ? 'Pentest create blocked: subscription exhausted' + : 'Pentest create blocked: subscription required', metadata: { - reason: 'pentest_credits_exhausted', + reason, targetUrl: payload.targetUrl, }, }); @@ -289,17 +321,37 @@ export class SecurityPenetrationTestsService { } catch (error) { // Provider call failed after we debited. Refund so the user isn't // charged for a run that never started. - await this.refundQuietly(organizationId, 'pending', 'maced_create_failed'); + if (consumedSubscriptionAllowance) { + await this.refundBillingUsageQuietly({ + organizationId, + sourceResourceId: billingUsageSourceId, + reason: 'maced_create_failed', + }); + } else { + await this.refundQuietly( + organizationId, + 'pending', + 'maced_create_failed', + ); + } throw error; } const providerRunId = createdReport.id; if (!providerRunId) { - await this.refundQuietly( - organizationId, - 'pending', - 'maced_missing_run_id', - ); + if (consumedSubscriptionAllowance) { + await this.refundBillingUsageQuietly({ + organizationId, + sourceResourceId: billingUsageSourceId, + reason: 'maced_missing_run_id', + }); + } else { + await this.refundQuietly( + organizationId, + 'pending', + 'maced_missing_run_id', + ); + } throw new HttpException( { error: 'Create response missing report identifier' }, HttpStatus.BAD_GATEWAY, @@ -309,6 +361,7 @@ export class SecurityPenetrationTestsService { const ownershipPersisted = await this.persistRunOwnershipWithRetry( organizationId, providerRunId, + consumedSubscriptionAllowance ? billingUsageSourceId : null, ); if (!ownershipPersisted) { // We debited and Maced created the run, but our DB rejected the @@ -316,11 +369,19 @@ export class SecurityPenetrationTestsService { // shouldn't pay for it. The Maced run is orphaned (no // ownership) but Maced has the `compOrganizationId` metadata if // support ever needs to clean it up. - await this.refundQuietly( - organizationId, - providerRunId, - 'ownership_persist_failed', - ); + if (consumedSubscriptionAllowance) { + await this.refundBillingUsageQuietly({ + organizationId, + sourceResourceId: billingUsageSourceId, + reason: 'ownership_persist_failed', + }); + } else { + await this.refundQuietly( + organizationId, + providerRunId, + 'ownership_persist_failed', + ); + } throw new HttpException( { error: @@ -374,10 +435,7 @@ export class SecurityPenetrationTestsService { ); } - async getReportIssues( - organizationId: string, - id: string, - ): Promise { + async getReportIssues(organizationId: string, id: string): Promise { await this.assertRunOwnership(organizationId, id); return this.callMaced( () => this.macedClient.pentests.issues(id), @@ -448,7 +506,9 @@ export class SecurityPenetrationTestsService { eventId?: string; }> { if (!params.rawBody) { - throw new BadRequestException('Missing raw body for webhook verification'); + throw new BadRequestException( + 'Missing raw body for webhook verification', + ); } const secret = process.env.MACED_WEBHOOK_SIGNING_SECRET; @@ -524,15 +584,13 @@ export class SecurityPenetrationTestsService { * those are rare race-condition artifacts and don't represent * customer-visible state. */ - private async auditPentestCompleted( - data: { - pentestId: string; - targetUrl: string; - issueCount: number; - durationMs: number; - agentCount: number; - }, - ): Promise { + private async auditPentestCompleted(data: { + pentestId: string; + targetUrl: string; + issueCount: number; + durationMs: number; + agentCount: number; + }): Promise { // Atomic claim — only the first webhook delivery for this run gets // count: 1 back. Subsequent redeliveries see `completed_audit_at` // already set and bail out before writing a duplicate audit row. @@ -617,20 +675,25 @@ export class SecurityPenetrationTestsService { const run = await tx.securityPenetrationTestRun.findUnique({ where: { providerRunId }, - select: { organizationId: true }, + select: { organizationId: true, billingUsageSourceId: true }, }); if (!run) { // Vanishingly rare race; abort the transaction so the claim // rolls back. Webhook redelivery will retry. - throw new Error( - `Run row vanished after claim for ${providerRunId}`, - ); + throw new Error(`Run row vanished after claim for ${providerRunId}`); + } + + if (run.billingUsageSourceId) { + await this.billingEntitlements.refundIncludedUsageForProduct({ + organizationId: run.organizationId, + productKey: 'pentest', + sourceResourceId: run.billingUsageSourceId, + reason: eventType, + tx, + }); + return; } - // Pass the tx client through so the wallet write happens in - // the same transaction as the claim. If this throws, the claim - // is undone and the error propagates up to handleWebhook → Maced - // sees 5xx and redelivers, allowing retry. await this.credits.refund( run.organizationId, providerRunId, @@ -866,9 +929,31 @@ export class SecurityPenetrationTestsService { } } + private async refundBillingUsageQuietly(params: { + organizationId: string; + sourceResourceId: string; + reason: string; + }): Promise { + try { + await this.billingEntitlements.refundIncludedUsageForProduct({ + organizationId: params.organizationId, + productKey: 'pentest', + sourceResourceId: params.sourceResourceId, + reason: params.reason, + }); + } catch (error) { + this.logger.error( + `Billing usage refund failed for org=${params.organizationId} source=${params.sourceResourceId} reason=${params.reason}: ${ + error instanceof Error ? error.message : String(error) + }`, + ); + } + } + private async persistRunOwnership( organizationId: string, reportId: string, + billingUsageSourceId: string | null, ): Promise { // Defensive: if a row already exists for this providerRunId, do NOT // overwrite its organizationId. Maced generates unique providerRunIds @@ -885,6 +970,7 @@ export class SecurityPenetrationTestsService { create: { organizationId, providerRunId: reportId, + billingUsageSourceId, }, update: {}, }); @@ -893,10 +979,15 @@ export class SecurityPenetrationTestsService { private async persistRunOwnershipWithRetry( organizationId: string, reportId: string, + billingUsageSourceId: string | null, ): Promise { for (let attempt = 1; attempt <= 3; attempt += 1) { try { - await this.persistRunOwnership(organizationId, reportId); + await this.persistRunOwnership( + organizationId, + reportId, + billingUsageSourceId, + ); return true; } catch (error) { this.logger.error( @@ -1002,5 +1093,34 @@ export class SecurityPenetrationTestsService { return hosts; } +} +function getPaymentRequiredCode(response: unknown): string { + if (typeof response === 'string') { + return normalizePaymentRequiredCode(response); + } + if (typeof response !== 'object' || response === null) { + return 'pentest_subscription_required'; + } + const code = (response as Record).code; + if (typeof code === 'string') return normalizePaymentRequiredCode(code); + + const message = (response as Record).message; + return typeof message === 'string' + ? normalizePaymentRequiredCode(message) + : 'pentest_subscription_required'; +} + +function normalizePaymentRequiredCode(value: string): string { + if ( + value === 'pentest_subscription_exhausted' || + value.toLowerCase().includes('exhaust') || + value.toLowerCase().includes('remaining') + ) { + return 'pentest_subscription_exhausted'; + } + if (value === 'pentest_subscription_required') { + return 'pentest_subscription_required'; + } + return 'pentest_subscription_required'; } diff --git a/apps/app/package.json b/apps/app/package.json index de6e57de53..118705ed56 100644 --- a/apps/app/package.json +++ b/apps/app/package.json @@ -67,6 +67,7 @@ "@trigger.dev/react-hooks": "4.4.3", "@trigger.dev/sdk": "4.4.3", "@trycompai/auth": "workspace:*", + "@trycompai/billing": "workspace:*", "@trycompai/company": "workspace:*", "@trycompai/db": "workspace:*", "@trycompai/design-system": "^1.0.32", @@ -190,7 +191,7 @@ "build": "next build", "build:docker": "prisma generate --schema=prisma/schema && node ../../packages/db/scripts/fix-generated-extensions.js src/generated/prisma && next build", "db:generate": "bun run db:getschema && prisma generate --schema=prisma/schema && node ../../packages/db/scripts/fix-generated-extensions.js src/generated/prisma", - "db:getschema": "find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\;", + "db:getschema": "find prisma/schema -name '*.prisma' ! -name 'schema.prisma' -delete && find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\;", "db:migrate": "cd ../../packages/db && bunx prisma migrate dev && cd ../../apps/app", "deploy:trigger-prod": "npx trigger.dev@4.4.3 deploy", "dev": "bunx concurrently --kill-others --names \"next,trigger\" --prefix-colors \"yellow,blue\" \"NODE_OPTIONS='--no-deprecation' next dev --turbo -p 3000\" \"NODE_OPTIONS='--no-deprecation' trigger dev\"", diff --git a/apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminBillingForms.tsx b/apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminBillingForms.tsx new file mode 100644 index 0000000000..83a0e0e33b --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminBillingForms.tsx @@ -0,0 +1,228 @@ +'use client'; + +import { zodResolver } from '@hookform/resolvers/zod'; +import { + Button, + Input, + Select, + SelectContent, + SelectItem, + SelectTrigger, + Textarea, +} from '@trycompai/design-system'; +import { Controller, useForm } from 'react-hook-form'; +import { z } from 'zod'; +import type { AdminBillingStatus } from './AdminBillingTypes'; + +const creditSchema = z.object({ + productKey: z.enum(['pentest', 'background_check']), + quantity: z.number().int().min(1).max(1000), + note: z.string().min(3).max(500), + confirm: z.string().optional(), +}); + +export type CreditFormValues = z.infer; + +export function CreditGrantForm({ + onSubmit, + loading, +}: { + onSubmit: (values: CreditFormValues) => Promise; + loading: boolean; +}) { + const form = useForm({ + resolver: zodResolver(creditSchema), + defaultValues: { + productKey: 'pentest', + quantity: 1, + note: '', + confirm: '', + }, + }); + + const handleSubmit = form.handleSubmit(async (values) => { + await onSubmit(values); + form.reset({ + productKey: values.productKey, + quantity: 1, + note: '', + confirm: '', + }); + }); + + return ( +
+ ( + + )} + /> + + + +
+ +
+ + ); +} + +const preferenceSchema = z.object({ + companyName: z.string().min(1).max(150), + billingEmail: z.string().email(), + purchaseOrder: z.string().max(140).optional(), + addressLine1: z.string().max(200).optional(), + addressLine2: z.string().max(200).optional(), + addressCity: z.string().max(100).optional(), + addressState: z.string().max(100).optional(), + addressPostalCode: z.string().max(32).optional(), + addressCountry: z.string().max(2).optional(), + taxIdType: z.string().optional(), + taxIdValue: z.string().max(64).optional(), + confirmBillingEmailChange: z.boolean().optional(), + note: z.string().min(3).max(500), +}); + +export type PreferenceFormValues = z.infer; + +export function BillingPreferencesAdminForm({ + status, + onSubmit, + loading, +}: { + status: AdminBillingStatus; + onSubmit: (values: PreferenceFormValues) => Promise; + loading: boolean; +}) { + const form = useForm({ + resolver: zodResolver(preferenceSchema), + defaultValues: { + companyName: status.preferences.companyName ?? '', + billingEmail: status.preferences.billingEmail ?? '', + purchaseOrder: status.preferences.purchaseOrder ?? '', + addressLine1: status.preferences.address.line1 ?? '', + addressLine2: status.preferences.address.line2 ?? '', + addressCity: status.preferences.address.city ?? '', + addressState: status.preferences.address.state ?? '', + addressPostalCode: status.preferences.address.postalCode ?? '', + addressCountry: status.preferences.address.country ?? '', + taxIdType: status.preferences.taxId?.type ?? '', + taxIdValue: status.preferences.taxId?.value ?? '', + confirmBillingEmailChange: false, + note: '', + }, + }); + + return ( +
+ + + + + + + + + + + + +
+