From f316c505c4403034bd184d273a54a2f147977e47 Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Thu, 30 Apr 2026 17:39:48 +0100 Subject: [PATCH 01/20] feat(billing): implement background check billing customer and invoice management - Added functions to find or create a billing customer in Stripe and list billing invoices. - Introduced validation for redirect URLs in billing processes. - Updated BackgroundCheckBillingService to utilize new billing customer and invoice functionalities. - Enhanced BackgroundCheckPaymentService to handle invoice creation and payment processing. - Created BillingInvoicesTable component for displaying invoices in the UI. - Updated tests to cover new billing features and ensure proper functionality. --- .../background-check-billing-customer.ts | 53 ++++++ .../background-check-billing-invoices.ts | 59 ++++++ .../background-check-billing-urls.ts | 24 +++ .../background-check-billing.controller.ts | 18 +- .../background-check-billing.service.ts | 104 ++++------- .../background-check-payment.service.spec.ts | 81 +++++++-- .../background-check-payment.service.ts | 139 +++++++++++--- .../background-checks.service.spec.ts | 62 ++++++- .../settings/billing/BillingInvoicesTable.tsx | 170 ++++++++++++++++++ .../billing/BillingSettingsClient.test.tsx | 71 +++++++- .../billing/BillingSettingsClient.tsx | 65 +++---- .../(app)/[orgId]/settings/billing/page.tsx | 7 +- .../(app)/[orgId]/settings/billing/types.ts | 15 ++ 13 files changed, 705 insertions(+), 163 deletions(-) create mode 100644 apps/api/src/background-checks/background-check-billing-customer.ts create mode 100644 apps/api/src/background-checks/background-check-billing-invoices.ts create mode 100644 apps/api/src/background-checks/background-check-billing-urls.ts create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/BillingInvoicesTable.tsx diff --git a/apps/api/src/background-checks/background-check-billing-customer.ts b/apps/api/src/background-checks/background-check-billing-customer.ts new file mode 100644 index 0000000000..a6038f2778 --- /dev/null +++ b/apps/api/src/background-checks/background-check-billing-customer.ts @@ -0,0 +1,53 @@ +import { NotFoundException } from '@nestjs/common'; +import { db } from '@db'; +import { StripeService } from '../stripe/stripe.service'; + +export async function findOrCreateBackgroundCheckBillingCustomer({ + stripeService, + organizationId, + customerEmail, +}: { + stripeService: StripeService; + organizationId: string; + customerEmail?: string; +}): Promise { + const existingBilling = await db.organizationBilling.findUnique({ + where: { organizationId }, + select: { stripeCustomerId: true }, + }); + const stripe = stripeService.getClient(); + + if (existingBilling) { + if (customerEmail) { + await stripe.customers.update(existingBilling.stripeCustomerId, { + email: customerEmail, + }); + } + + return existingBilling.stripeCustomerId; + } + + const organization = await db.organization.findUnique({ + where: { id: organizationId }, + select: { name: true }, + }); + + if (!organization) { + throw new NotFoundException('Organization not found.'); + } + + const customer = await stripe.customers.create({ + name: organization.name, + ...(customerEmail ? { email: customerEmail } : {}), + metadata: { organizationId }, + }); + + await db.organizationBilling.create({ + data: { + organizationId, + stripeCustomerId: customer.id, + }, + }); + + return customer.id; +} diff --git a/apps/api/src/background-checks/background-check-billing-invoices.ts b/apps/api/src/background-checks/background-check-billing-invoices.ts new file mode 100644 index 0000000000..376a6a560e --- /dev/null +++ b/apps/api/src/background-checks/background-check-billing-invoices.ts @@ -0,0 +1,59 @@ +import Stripe from 'stripe'; +import { StripeService } from '../stripe/stripe.service'; + +export interface BackgroundCheckBillingInvoice { + id: string; + number: string; + createdAt: string; + dueDate: string | null; + amountPaid: number; + amountDue: number; + currency: string; + status: string; + type: 'Subscription' | 'One Time'; + hostedInvoiceUrl: string | null; + invoicePdfUrl: string | null; +} + +export async function listBackgroundCheckBillingInvoices({ + stripeService, + stripeCustomerId, +}: { + stripeService: StripeService; + stripeCustomerId: string | null; +}): Promise { + if (!stripeCustomerId || !stripeService.isConfigured()) { + return []; + } + + const stripe = stripeService.getClient(); + const invoices = await stripe.invoices.list({ + customer: stripeCustomerId, + limit: 10, + }); + + return invoices.data.map(toBillingInvoice); +} + +function toBillingInvoice( + invoice: Stripe.Invoice, +): BackgroundCheckBillingInvoice { + return { + id: invoice.id, + number: invoice.number ?? invoice.id, + createdAt: new Date(invoice.created * 1000).toISOString(), + dueDate: invoice.due_date + ? new Date(invoice.due_date * 1000).toISOString() + : null, + amountPaid: invoice.amount_paid, + amountDue: invoice.amount_due, + currency: invoice.currency, + status: invoice.status ?? 'unknown', + type: + invoice.parent?.type === 'subscription_details' + ? 'Subscription' + : 'One Time', + hostedInvoiceUrl: invoice.hosted_invoice_url ?? null, + invoicePdfUrl: invoice.invoice_pdf ?? null, + }; +} diff --git a/apps/api/src/background-checks/background-check-billing-urls.ts b/apps/api/src/background-checks/background-check-billing-urls.ts new file mode 100644 index 0000000000..9fd7aae9dd --- /dev/null +++ b/apps/api/src/background-checks/background-check-billing-urls.ts @@ -0,0 +1,24 @@ +import { BadRequestException } from '@nestjs/common'; + +export function validateBackgroundCheckBillingRedirectUrl(url: string): void { + const appUrl = + process.env.NEXT_PUBLIC_APP_URL || + process.env.APP_URL || + process.env.BETTER_AUTH_URL; + if (!appUrl) { + throw new BadRequestException('App URL is not configured on the server.'); + } + + let parsed: URL; + try { + parsed = new URL(url); + } catch { + throw new BadRequestException('Invalid redirect URL.'); + } + + if (parsed.origin !== new URL(appUrl).origin) { + throw new BadRequestException( + 'Redirect URL must belong to the application origin.', + ); + } +} diff --git a/apps/api/src/background-checks/background-check-billing.controller.ts b/apps/api/src/background-checks/background-check-billing.controller.ts index 708d0dd2c1..e872ccb06a 100644 --- a/apps/api/src/background-checks/background-check-billing.controller.ts +++ b/apps/api/src/background-checks/background-check-billing.controller.ts @@ -1,9 +1,17 @@ -import { Body, Controller, Get, HttpCode, Post, UseGuards } from '@nestjs/common'; +import { + Body, + Controller, + Get, + HttpCode, + Post, + UseGuards, +} from '@nestjs/common'; import { ApiOperation, ApiSecurity, ApiTags } from '@nestjs/swagger'; -import { OrganizationId } from '../auth/auth-context.decorator'; +import { AuthContext, OrganizationId } from '../auth/auth-context.decorator'; import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; import { PermissionGuard } from '../auth/permission.guard'; import { RequirePermission } from '../auth/require-permission.decorator'; +import type { AuthContext as AuthContextType } from '../auth/types'; import { BackgroundCheckBillingService } from './background-check-billing.service'; import { BackgroundCheckBillingPortalDto, @@ -28,15 +36,19 @@ export class BackgroundCheckBillingController { @Post('setup-session') @RequirePermission('organization', 'update') @HttpCode(200) - @ApiOperation({ summary: 'Create a Stripe setup session for background checks' }) + @ApiOperation({ + summary: 'Create a Stripe setup session for background checks', + }) async setupSession( @OrganizationId() organizationId: string, + @AuthContext() authContext: AuthContextType, @Body() body: BackgroundCheckSetupSessionDto, ) { return this.billingService.createSetupSession({ organizationId, successUrl: body.successUrl, cancelUrl: body.cancelUrl, + customerEmail: authContext.userEmail, }); } diff --git a/apps/api/src/background-checks/background-check-billing.service.ts b/apps/api/src/background-checks/background-check-billing.service.ts index 21a1b8cd63..da138b69ef 100644 --- a/apps/api/src/background-checks/background-check-billing.service.ts +++ b/apps/api/src/background-checks/background-check-billing.service.ts @@ -1,6 +1,16 @@ -import { BadRequestException, Injectable, NotFoundException } from '@nestjs/common'; +import { + BadRequestException, + Injectable, + NotFoundException, +} from '@nestjs/common'; import { db } from '@db'; import { StripeService } from '../stripe/stripe.service'; +import { findOrCreateBackgroundCheckBillingCustomer } from './background-check-billing-customer'; +import { + type BackgroundCheckBillingInvoice, + listBackgroundCheckBillingInvoices, +} from './background-check-billing-invoices'; +import { validateBackgroundCheckBillingRedirectUrl } from './background-check-billing-urls'; @Injectable() export class BackgroundCheckBillingService { @@ -14,6 +24,7 @@ export class BackgroundCheckBillingService { backgroundChecks: number; penetrationTests: number; }; + invoices: BackgroundCheckBillingInvoice[]; }> { const [billing, backgroundChecks, penetrationTests] = await Promise.all([ db.organizationBilling.findUnique({ @@ -27,6 +38,10 @@ export class BackgroundCheckBillingService { db.backgroundCheckRequest.count({ where: { organizationId } }), db.securityPenetrationTestRun.count({ where: { organizationId } }), ]); + const invoices = await listBackgroundCheckBillingInvoices({ + stripeService: this.stripeService, + stripeCustomerId: billing?.stripeCustomerId ?? null, + }); return { hasBilling: !!billing, @@ -36,6 +51,7 @@ export class BackgroundCheckBillingService { backgroundChecks, penetrationTests, }, + invoices, }; } @@ -43,16 +59,22 @@ export class BackgroundCheckBillingService { organizationId, successUrl, cancelUrl, + customerEmail, }: { organizationId: string; successUrl: string; cancelUrl: string; + customerEmail?: string; }): Promise<{ url: string }> { - this.validateRedirectUrl(successUrl); - this.validateRedirectUrl(cancelUrl); + validateBackgroundCheckBillingRedirectUrl(successUrl); + validateBackgroundCheckBillingRedirectUrl(cancelUrl); const stripe = this.stripeService.getClient(); - const customerId = await this.findOrCreateCustomer(organizationId); + const customerId = await findOrCreateBackgroundCheckBillingCustomer({ + stripeService: this.stripeService, + organizationId, + customerEmail, + }); const price = await this.getBackgroundCheckPrice(); const session = await stripe.checkout.sessions.create({ @@ -68,7 +90,9 @@ export class BackgroundCheckBillingService { }); if (!session.url) { - throw new BadRequestException('Failed to create Stripe Checkout session.'); + throw new BadRequestException( + 'Failed to create Stripe Checkout session.', + ); } return { url: session.url }; @@ -118,7 +142,9 @@ export class BackgroundCheckBillingService { const paymentMethodId = this.extractStripeId(setupIntent.payment_method); if (!paymentMethodId) { - throw new BadRequestException('Setup intent is missing a payment method.'); + throw new BadRequestException( + 'Setup intent is missing a payment method.', + ); } await stripe.customers.update(stripeCustomerId, { @@ -152,7 +178,7 @@ export class BackgroundCheckBillingService { organizationId: string; returnUrl: string; }): Promise<{ url: string }> { - this.validateRedirectUrl(returnUrl); + validateBackgroundCheckBillingRedirectUrl(returnUrl); const stripe = this.stripeService.getClient(); const billing = await db.organizationBilling.findUnique({ @@ -161,7 +187,9 @@ export class BackgroundCheckBillingService { }); if (!billing) { - throw new NotFoundException('No billing record found for this organization.'); + throw new NotFoundException( + 'No billing record found for this organization.', + ); } const portalSession = await stripe.billingPortal.sessions.create({ @@ -172,41 +200,6 @@ export class BackgroundCheckBillingService { return { url: portalSession.url }; } - async findOrCreateCustomer(organizationId: string): Promise { - const existingBilling = await db.organizationBilling.findUnique({ - where: { organizationId }, - select: { stripeCustomerId: true }, - }); - - if (existingBilling) { - return existingBilling.stripeCustomerId; - } - - const organization = await db.organization.findUnique({ - where: { id: organizationId }, - select: { name: true }, - }); - - if (!organization) { - throw new NotFoundException('Organization not found.'); - } - - const stripe = this.stripeService.getClient(); - const customer = await stripe.customers.create({ - name: organization.name, - metadata: { organizationId }, - }); - - await db.organizationBilling.create({ - data: { - organizationId, - stripeCustomerId: customer.id, - }, - }); - - return customer.id; - } - async getBackgroundCheckPrice(): Promise<{ id: string; unitAmount: number; @@ -234,28 +227,9 @@ export class BackgroundCheckBillingService { }; } - private validateRedirectUrl(url: string): void { - const appUrl = - process.env.NEXT_PUBLIC_APP_URL || - process.env.APP_URL || - process.env.BETTER_AUTH_URL; - if (!appUrl) { - throw new BadRequestException('App URL is not configured on the server.'); - } - - let parsed: URL; - try { - parsed = new URL(url); - } catch { - throw new BadRequestException('Invalid redirect URL.'); - } - - if (parsed.origin !== new URL(appUrl).origin) { - throw new BadRequestException('Redirect URL must belong to the application origin.'); - } - } - - private extractStripeId(value: string | { id?: string } | null): string | null { + private extractStripeId( + value: string | { id?: string } | null, + ): string | null { if (!value) return null; if (typeof value === 'string') return value; return value.id ?? null; diff --git a/apps/api/src/background-checks/background-check-payment.service.spec.ts b/apps/api/src/background-checks/background-check-payment.service.spec.ts index b794294310..0359de5768 100644 --- a/apps/api/src/background-checks/background-check-payment.service.spec.ts +++ b/apps/api/src/background-checks/background-check-payment.service.spec.ts @@ -29,7 +29,9 @@ describe('BackgroundCheckPaymentService', () => { ).mockResolvedValueOnce(null); const service = new BackgroundCheckPaymentService( { getClient: jest.fn() } as unknown as StripeService, - { getBackgroundCheckPrice: jest.fn() } as unknown as BackgroundCheckBillingService, + { + getBackgroundCheckPrice: jest.fn(), + } as unknown as BackgroundCheckBillingService, ); await expect( @@ -41,20 +43,40 @@ describe('BackgroundCheckPaymentService', () => { ); }); - it('charges Stripe with payment-method scoped idempotency key', async () => { + it('creates and pays a Stripe invoice with payment-method scoped idempotency keys', async () => { mockAsync>>( mockedDb.organizationBilling.findUnique, ).mockResolvedValueOnce({ stripeCustomerId: 'cus_1', stripeBackgroundCheckPaymentMethodId: 'pm_1', } as Awaited>); - const paymentIntentsCreate = jest.fn().mockResolvedValue({ - id: 'pi_1', - status: 'succeeded', + const invoiceItemsCreate = jest.fn().mockResolvedValue({ id: 'ii_1' }); + const invoicesCreate = jest.fn().mockResolvedValue({ id: 'in_1' }); + const finalizeInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); + const invoicesPay = jest.fn().mockResolvedValue({ + id: 'in_1', + status: 'paid', + payments: { + data: [ + { + payment: { + type: 'payment_intent', + payment_intent: 'pi_1', + }, + }, + ], + }, }); const service = new BackgroundCheckPaymentService( { - getClient: () => ({ paymentIntents: { create: paymentIntentsCreate } }), + getClient: () => ({ + invoiceItems: { create: invoiceItemsCreate }, + invoices: { + create: invoicesCreate, + finalizeInvoice, + pay: invoicesPay, + }, + }), } as unknown as StripeService, { getBackgroundCheckPrice: jest.fn().mockResolvedValue({ @@ -65,16 +87,55 @@ describe('BackgroundCheckPaymentService', () => { } as unknown as BackgroundCheckBillingService, ); - await service.charge({ organizationId: 'org_1', memberId: 'mem_1' }); + await expect( + service.charge({ organizationId: 'org_1', memberId: 'mem_1' }), + ).resolves.toMatchObject({ + paymentIntentId: 'pi_1', + invoiceId: 'in_1', + status: 'succeeded', + }); - expect(paymentIntentsCreate).toHaveBeenCalledWith( + expect(invoicesCreate).toHaveBeenCalledWith( expect.objectContaining({ - amount: 1250, customer: 'cus_1', + collection_method: 'charge_automatically', description: 'Comp AI - Background Check x1', + default_payment_method: 'pm_1', + statement_descriptor: 'COMP AI BG CHECK', + }), + { idempotencyKey: 'background-check:org_1:mem_1:price_bg:pm_1:invoice' }, + ); + expect(invoiceItemsCreate).toHaveBeenCalledWith( + expect.objectContaining({ + customer: 'cus_1', + invoice: 'in_1', + pricing: { + price: 'price_bg', + }, + quantity: 1, + }), + { + idempotencyKey: 'background-check:org_1:mem_1:price_bg:pm_1:line-item', + }, + ); + expect(finalizeInvoice).toHaveBeenCalledWith( + 'in_1', + { auto_advance: false }, + { + idempotencyKey: + 'background-check:org_1:mem_1:price_bg:pm_1:finalize-invoice', + }, + ); + expect(invoicesPay).toHaveBeenCalledWith( + 'in_1', + expect.objectContaining({ payment_method: 'pm_1', + off_session: true, }), - { idempotencyKey: 'background-check:org_1:mem_1:price_bg:pm_1' }, + { + idempotencyKey: + 'background-check:org_1:mem_1:price_bg:pm_1:pay-invoice', + }, ); }); }); diff --git a/apps/api/src/background-checks/background-check-payment.service.ts b/apps/api/src/background-checks/background-check-payment.service.ts index 17e0dc4501..b49fa6bcb1 100644 --- a/apps/api/src/background-checks/background-check-payment.service.ts +++ b/apps/api/src/background-checks/background-check-payment.service.ts @@ -1,12 +1,14 @@ import { HttpException, HttpStatus, Injectable, Logger } from '@nestjs/common'; import { db } from '@db'; +import Stripe from 'stripe'; import { StripeService } from '../stripe/stripe.service'; import { BackgroundCheckBillingService } from './background-check-billing.service'; @Injectable() export class BackgroundCheckPaymentService { - private static readonly receiptDescription = - 'Comp AI - Background Check x1'; + private static readonly receiptDescription = 'Comp AI - Background Check x1'; + + private static readonly statementDescriptor = 'COMP AI BG CHECK'; private readonly logger = new Logger(BackgroundCheckPaymentService.name); @@ -15,11 +17,9 @@ export class BackgroundCheckPaymentService { private readonly billingService: BackgroundCheckBillingService, ) {} - async charge(params: { - organizationId: string; - memberId: string; - }): Promise<{ + async charge(params: { organizationId: string; memberId: string }): Promise<{ paymentIntentId: string; + invoiceId: string; status: string; amount: number; currency: string; @@ -41,37 +41,91 @@ export class BackgroundCheckPaymentService { const price = await this.billingService.getBackgroundCheckPrice(); const stripe = this.stripeService.getClient(); - const paymentIntent = await stripe.paymentIntents.create( + const metadata = { + source: 'comp-background-check', + compOrganizationId: params.organizationId, + compMemberId: params.memberId, + }; + const idempotencyKeyParts = [ + 'background-check', + params.organizationId, + params.memberId, + price.id, + billing.stripeBackgroundCheckPaymentMethodId, + ]; + + const invoice = await stripe.invoices.create( { customer: billing.stripeCustomerId, - amount: price.unitAmount, + collection_method: 'charge_automatically', currency: price.currency, + default_payment_method: billing.stripeBackgroundCheckPaymentMethodId, description: BackgroundCheckPaymentService.receiptDescription, - payment_method: billing.stripeBackgroundCheckPaymentMethodId, - off_session: true, - confirm: true, - automatic_payment_methods: { - enabled: true, - allow_redirects: 'never', - }, - metadata: { - source: 'comp-background-check', - compOrganizationId: params.organizationId, - compMemberId: params.memberId, + statement_descriptor: BackgroundCheckPaymentService.statementDescriptor, + auto_advance: false, + metadata, + }, + { + idempotencyKey: [...idempotencyKeyParts, 'invoice'].join(':'), + }, + ); + + await stripe.invoiceItems.create( + { + customer: billing.stripeCustomerId, + invoice: invoice.id, + pricing: { + price: price.id, }, + quantity: 1, + description: BackgroundCheckPaymentService.receiptDescription, + metadata, + }, + { + idempotencyKey: [...idempotencyKeyParts, 'line-item'].join(':'), }, + ); + + await stripe.invoices.finalizeInvoice( + invoice.id, + { auto_advance: false }, { - idempotencyKey: [ - 'background-check', - params.organizationId, - params.memberId, - price.id, - billing.stripeBackgroundCheckPaymentMethodId, - ].join(':'), + idempotencyKey: [...idempotencyKeyParts, 'finalize-invoice'].join(':'), }, ); - if (paymentIntent.status !== 'succeeded') { + let paidInvoice: Stripe.Invoice; + try { + paidInvoice = await stripe.invoices.pay( + invoice.id, + { + payment_method: billing.stripeBackgroundCheckPaymentMethodId, + off_session: true, + expand: ['payments'], + }, + { + idempotencyKey: [...idempotencyKeyParts, 'pay-invoice'].join(':'), + }, + ); + } catch (error) { + await this.voidInvoice({ stripe, invoiceId: invoice.id }); + throw new HttpException( + 'Background check payment failed. Update billing and try again.', + HttpStatus.PAYMENT_REQUIRED, + { cause: error }, + ); + } + + if (paidInvoice.status !== 'paid') { + await this.voidInvoice({ stripe, invoiceId: invoice.id }); + throw new HttpException( + 'Background check payment failed. Update billing and try again.', + HttpStatus.PAYMENT_REQUIRED, + ); + } + + const paymentIntentId = this.extractPaymentIntentId(paidInvoice); + if (!paymentIntentId) { throw new HttpException( 'Background check payment failed. Update billing and try again.', HttpStatus.PAYMENT_REQUIRED, @@ -79,8 +133,9 @@ export class BackgroundCheckPaymentService { } return { - paymentIntentId: paymentIntent.id, - status: paymentIntent.status, + paymentIntentId, + invoiceId: paidInvoice.id, + status: 'succeeded', amount: price.unitAmount, currency: price.currency, }; @@ -118,4 +173,30 @@ export class BackgroundCheckPaymentService { return null; } } + + private extractPaymentIntentId(invoice: Stripe.Invoice): string | null { + const payment = invoice.payments?.data.find( + (invoicePayment) => invoicePayment.payment.type === 'payment_intent', + ); + const paymentIntent = payment?.payment.payment_intent; + if (!paymentIntent) return null; + return typeof paymentIntent === 'string' ? paymentIntent : paymentIntent.id; + } + + private async voidInvoice({ + stripe, + invoiceId, + }: { + stripe: Stripe; + invoiceId: string; + }): Promise { + try { + await stripe.invoices.voidInvoice(invoiceId); + } catch (error) { + this.logger.error('Failed to void unpaid background check invoice.', { + invoiceId, + error: error instanceof Error ? error.message : 'Unknown error', + }); + } + } } diff --git a/apps/api/src/background-checks/background-checks.service.spec.ts b/apps/api/src/background-checks/background-checks.service.spec.ts index 17dfcaa0e4..0fc3c88d7c 100644 --- a/apps/api/src/background-checks/background-checks.service.spec.ts +++ b/apps/api/src/background-checks/background-checks.service.spec.ts @@ -61,7 +61,8 @@ function mockAsync(fn: unknown): jest.MockedFunction<() => Promise> { function invocationOrder(fn: unknown, index = 0): number { return ( - (fn as { mock: { invocationCallOrder: number[] } }).mock.invocationCallOrder[index] ?? 0 + (fn as { mock: { invocationCallOrder: number[] } }).mock + .invocationCallOrder[index] ?? 0 ); } @@ -132,7 +133,9 @@ describe('background checks', () => { compOrganizationId: 'org_1', compMemberId: 'mem_1', }); - expect(body.callbackUrl).toBe('https://api.trycomp.ai/v1/background-checks/webhook'); + expect(body.callbackUrl).toBe( + 'https://api.trycomp.ai/v1/background-checks/webhook', + ); expect(body.requesterNotes).toBeUndefined(); }); @@ -192,7 +195,9 @@ describe('background checks', () => { mockAsync>>( mockedDb.backgroundCheckRequest.findUnique, ).mockResolvedValueOnce( - existing as Awaited>, + existing as Awaited< + ReturnType + >, ); const identityClient = { createBackgroundCheck: jest.fn() }; const paymentService = { charge: jest.fn(), refund: jest.fn() }; @@ -243,7 +248,9 @@ describe('background checks', () => { } as Awaited>); const identityClient = { - createBackgroundCheck: jest.fn().mockRejectedValue(new Error('identity down')), + createBackgroundCheck: jest + .fn() + .mockRejectedValue(new Error('identity down')), }; const paymentService = { charge: jest.fn().mockResolvedValue({ @@ -369,9 +376,9 @@ describe('background checks', () => { }), ); // Record is created before Identity API is called - expect(invocationOrder(mockedDb.backgroundCheckRequest.create)).toBeLessThan( - invocationOrder(identityClient.createBackgroundCheck), - ); + expect( + invocationOrder(mockedDb.backgroundCheckRequest.create), + ).toBeLessThan(invocationOrder(identityClient.createBackgroundCheck)); expect(identityClient.createBackgroundCheck).toHaveBeenCalledWith( expect.not.objectContaining({ requesterNotes: expect.any(String), @@ -476,8 +483,15 @@ describe('background checks', () => { successUrl: 'http://localhost:3000/org_1/people/mem_1?background_check_billing=success', cancelUrl: 'http://localhost:3000/org_1/people/mem_1', + customerEmail: 'billing@trycomp.ai', }), ).resolves.toEqual({ url: 'https://checkout.stripe.com/c/session_1' }); + + expect(stripe.customers.create).toHaveBeenCalledWith({ + name: 'Acme', + email: 'billing@trycomp.ai', + metadata: { organizationId: 'org_1' }, + }); }); it('includes background check and penetration test usage in billing status', async () => { @@ -488,13 +502,33 @@ describe('background checks', () => { stripeBackgroundCheckPaymentMethodId: 'pm_1', backgroundCheckPaymentMethodSetupAt: new Date('2026-04-29T12:00:00.000Z'), } as Awaited>); - mockAsync(mockedDb.backgroundCheckRequest.count).mockResolvedValueOnce(4); + mockAsync( + mockedDb.backgroundCheckRequest.count, + ).mockResolvedValueOnce(4); mockAsync( mockedDb.securityPenetrationTestRun.count, ).mockResolvedValueOnce(2); + const invoicesList = jest.fn().mockResolvedValue({ + data: [ + { + id: 'in_1', + number: 'INV-001', + created: 1777464000, + due_date: null, + amount_paid: 4900, + amount_due: 4900, + currency: 'usd', + status: 'paid', + parent: null, + hosted_invoice_url: 'https://invoice.stripe.com/i/in_1', + invoice_pdf: 'https://invoice.stripe.com/i/in_1.pdf', + }, + ], + }); const service = new BackgroundCheckBillingService({ - getClient: jest.fn(), + getClient: () => ({ invoices: { list: invoicesList } }), + isConfigured: () => true, } as unknown as StripeService); await expect(service.getStatus('org_1')).resolves.toMatchObject({ @@ -504,7 +538,17 @@ describe('background checks', () => { backgroundChecks: 4, penetrationTests: 2, }, + invoices: [ + { + id: 'in_1', + number: 'INV-001', + amountPaid: 4900, + status: 'paid', + type: 'One Time', + }, + ], }); + expect(invoicesList).toHaveBeenCalledWith({ customer: 'cus_1', limit: 10 }); expect(mockedDb.backgroundCheckRequest.count).toHaveBeenCalledWith({ where: { organizationId: 'org_1' }, }); diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingInvoicesTable.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingInvoicesTable.tsx new file mode 100644 index 0000000000..3d980f01b5 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingInvoicesTable.tsx @@ -0,0 +1,170 @@ +'use client'; + +import { Badge, Button, Input, Stack, Text } from '@trycompai/design-system'; +import { Download, Launch, Search } from '@trycompai/design-system/icons'; +import type React from 'react'; +import { useMemo, useState } from 'react'; +import type { BillingInvoice } from './types'; + +interface BillingInvoicesTableProps { + invoices: BillingInvoice[]; +} + +export function BillingInvoicesTable({ invoices }: BillingInvoicesTableProps) { + const [query, setQuery] = useState(''); + const filteredInvoices = useMemo(() => { + const normalizedQuery = query.trim().toLowerCase(); + if (!normalizedQuery) return invoices; + + return invoices.filter((invoice) => { + const searchable = [ + invoice.number, + invoice.status, + invoice.type, + formatAmount(invoice.amountPaid, invoice.currency), + formatDate(invoice.createdAt), + ] + .join(' ') + .toLowerCase(); + + return searchable.includes(normalizedQuery); + }); + }, [invoices, query]); + + return ( +
+
+ + + + Invoices + + + View and download invoices for paid services. + + +
+
+ + + + setQuery(event.target.value)} + /> +
+
+
+
+
+ + + + Invoice + Date + Due Date + Amount + Status + Type + + Actions + + + + + {filteredInvoices.map((invoice) => ( + + ))} + {filteredInvoices.length === 0 && ( + + + + )} + +
+ + {invoices.length === 0 ? 'No invoices yet.' : 'No invoices match your search.'} + +
+
+
+ + {filteredInvoices.length} of {invoices.length} invoice + {invoices.length === 1 ? '' : 's'} + +
+
+ ); +} + +function InvoiceRow({ invoice }: { invoice: BillingInvoice }) { + return ( + + {invoice.number} + {formatDate(invoice.createdAt)} + + {invoice.dueDate ? formatDate(invoice.dueDate) : 'On receipt'} + + {formatAmount(invoice.amountPaid, invoice.currency)} + + + {formatStatus(invoice.status)} + + + {invoice.type} + +
+ {invoice.hostedInvoiceUrl && ( + + )} + {invoice.invoicePdfUrl && ( + + )} +
+ + + ); +} + +function TableHead({ children }: { children: React.ReactNode }) { + return {children}; +} + +function formatAmount(amount: number, currency: string) { + return new Intl.NumberFormat('en-US', { + style: 'currency', + currency: currency.toUpperCase(), + }).format(amount / 100); +} + +function formatDate(date: string) { + return new Intl.DateTimeFormat('en-US', { + month: 'short', + day: 'numeric', + year: 'numeric', + }).format(new Date(date)); +} + +function formatStatus(status: string) { + return status.charAt(0).toUpperCase() + status.slice(1); +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.test.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.test.tsx index b4aa714e9d..8b6e00617c 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.test.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.test.tsx @@ -18,9 +18,7 @@ const permissionMock = vi.hoisted(() => ({ vi.mock('@/hooks/use-permissions', () => ({ usePermissions: () => ({ hasPermission: (resource: string, action: string) => - resource === 'organization' && - action === 'update' && - permissionMock.canUpdateOrganization, + resource === 'organization' && action === 'update' && permissionMock.canUpdateOrganization, }), })); @@ -45,10 +43,28 @@ function renderBillingSettings({ hasPaymentMethod = true, backgroundChecks = 4, penetrationTests = 2, + invoices = [ + { + id: 'in_1', + number: 'INV-001', + createdAt: '2026-04-30T09:35:07.000Z', + dueDate: null, + amountPaid: 4900, + amountDue: 4900, + currency: 'usd', + status: 'paid', + type: 'One Time' as const, + hostedInvoiceUrl: 'https://invoice.stripe.com/i/in_1', + invoicePdfUrl: 'https://invoice.stripe.com/i/in_1.pdf', + }, + ], }: { hasPaymentMethod?: boolean; backgroundChecks?: number; penetrationTests?: number; + invoices?: NonNullable< + Parameters[0]['initialBillingStatus']['invoices'] + >; } = {}) { return render( new Map() }}> @@ -58,6 +74,7 @@ function renderBillingSettings({ hasPaymentMethod, setupAt: null, usage: { backgroundChecks, penetrationTests }, + invoices, }} /> , @@ -88,7 +105,11 @@ describe('BillingSettingsClient', () => { expect(screen.getByText('2')).toBeInTheDocument(); expect(screen.getByText('4')).toBeInTheDocument(); expect(screen.getByText(/update billing details, cards, and receipts/i)).toBeInTheDocument(); - await user.click(screen.getByRole('button', { name: /open stripe portal/i })); + expect(screen.getByText('Invoices')).toBeInTheDocument(); + expect(screen.getByText('INV-001')).toBeInTheDocument(); + expect(screen.getByText('$49.00')).toBeInTheDocument(); + expect(screen.getByText('One Time')).toBeInTheDocument(); + await user.click(screen.getByRole('button', { name: /update billing details/i })); await waitFor(() => { expect(apiClient.post).toHaveBeenCalledWith( @@ -125,7 +146,7 @@ describe('BillingSettingsClient', () => { permissionMock.canUpdateOrganization = false; renderBillingSettings({ hasPaymentMethod: true }); - const button = screen.getByRole('button', { name: /open stripe portal/i }); + const button = screen.getByRole('button', { name: /update billing details/i }); expect(button).toBeDisabled(); await user.click(button); @@ -146,5 +167,45 @@ describe('BillingSettingsClient', () => { ); expect(screen.getAllByText('0')).toHaveLength(2); + expect(screen.getByText('No invoices yet.')).toBeInTheDocument(); + }); + + it('filters invoices by search query', async () => { + const user = userEvent.setup(); + renderBillingSettings({ + invoices: [ + { + id: 'in_1', + number: 'INV-001', + createdAt: '2026-04-30T09:35:07.000Z', + dueDate: null, + amountPaid: 4900, + amountDue: 4900, + currency: 'usd', + status: 'paid', + type: 'One Time', + hostedInvoiceUrl: 'https://invoice.stripe.com/i/in_1', + invoicePdfUrl: 'https://invoice.stripe.com/i/in_1.pdf', + }, + { + id: 'in_2', + number: 'INV-002', + createdAt: '2026-04-01T09:35:07.000Z', + dueDate: null, + amountPaid: 0, + amountDue: 0, + currency: 'usd', + status: 'paid', + type: 'Subscription', + hostedInvoiceUrl: null, + invoicePdfUrl: null, + }, + ], + }); + + await user.type(screen.getByLabelText('Search invoices'), 'subscription'); + + expect(screen.queryByText('INV-001')).not.toBeInTheDocument(); + expect(screen.getByText('INV-002')).toBeInTheDocument(); }); }); diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx index 736481961a..eb2034a47b 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx @@ -13,9 +13,10 @@ import { } from '@trycompai/design-system'; import { Launch } from '@trycompai/design-system/icons'; import { usePathname, useRouter, useSearchParams } from 'next/navigation'; -import { useEffect, useRef, useState } from 'react'; +import { useEffect, useMemo, useRef, useState } from 'react'; import { toast } from 'sonner'; import useSWR from 'swr'; +import { BillingInvoicesTable } from './BillingInvoicesTable'; import type { BackgroundCheckBillingStatus } from './types'; interface BillingSettingsClientProps { @@ -40,29 +41,29 @@ export function BillingSettingsClient({ const { hasPermission } = usePermissions(); const canManageBilling = hasPermission('organization', 'update'); - const { data: billingStatus, mutate: mutateBillingStatus } = - useSWR( - ['/v1/background-check-billing/status', organizationId], - async ([endpoint]) => { - const response = await apiClient.get( - endpoint, - organizationId, - ); - if (response.error || !response.data) { - throw new Error('Failed to load billing status'); - } - return response.data; - }, - { - fallbackData: initialBillingStatus, - revalidateOnMount: false, - }, - ); + const { data: billingStatus, mutate: mutateBillingStatus } = useSWR( + ['/v1/background-check-billing/status', organizationId], + async ([endpoint]) => { + const response = await apiClient.get(endpoint, organizationId); + if (response.error || !response.data) { + throw new Error('Failed to load billing status'); + } + return response.data; + }, + { + fallbackData: initialBillingStatus, + revalidateOnMount: false, + }, + ); const hasPaymentMethod = billingStatus?.hasPaymentMethod === true; const usage = billingStatus?.usage ?? initialBillingStatus.usage ?? defaultUsage; + const invoices = useMemo( + () => billingStatus?.invoices ?? initialBillingStatus.invoices ?? [], + [billingStatus?.invoices, initialBillingStatus.invoices], + ); const statusLabel = hasPaymentMethod ? 'Billing set up' : 'Payment method needed'; - const actionLabel = hasPaymentMethod ? 'Open Stripe portal' : 'Add payment method'; + const actionLabel = hasPaymentMethod ? 'Update Billing Details' : 'Add payment method'; useEffect(() => { const sessionId = searchParams.get('session_id'); @@ -89,12 +90,13 @@ export function BillingSettingsClient({ hasPaymentMethod: true, setupAt: new Date().toISOString(), usage, + invoices, }, { revalidate: true }, ); router.replace(pathname, { scroll: false }); })(); - }, [mutateBillingStatus, organizationId, pathname, router, searchParams]); + }, [invoices, mutateBillingStatus, organizationId, pathname, router, searchParams, usage]); const handleOpenBilling = async () => { setIsOpeningBilling(true); @@ -109,11 +111,7 @@ export function BillingSettingsClient({ successUrl: `${returnUrl}?background_check_billing=success&session_id={CHECKOUT_SESSION_ID}`, cancelUrl: returnUrl, }; - const response = await apiClient.post<{ url: string }>( - endpoint, - body, - organizationId, - ); + const response = await apiClient.post<{ url: string }>(endpoint, body, organizationId); if (response.data?.url) { window.location.href = response.data.url; @@ -158,14 +156,8 @@ export function BillingSettingsClient({
- - + +
@@ -175,9 +167,7 @@ export function BillingSettingsClient({
Payment method - - {statusLabel} - + {statusLabel}
{hasPaymentMethod @@ -200,6 +190,7 @@ export function BillingSettingsClient({ + ); } diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx index b413e6151d..bf72e67da4 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx @@ -10,13 +10,10 @@ const emptyBillingStatus: BackgroundCheckBillingStatus = { backgroundChecks: 0, penetrationTests: 0, }, + invoices: [], }; -export default async function BillingPage({ - params, -}: { - params: Promise<{ orgId: string }>; -}) { +export default async function BillingPage({ params }: { params: Promise<{ orgId: string }> }) { const { orgId } = await params; const response = await serverApi.get( '/v1/background-check-billing/status', diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/types.ts b/apps/app/src/app/(app)/[orgId]/settings/billing/types.ts index 335e87c6f7..31c2e8c1a3 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/types.ts +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/types.ts @@ -1,3 +1,17 @@ +export interface BillingInvoice { + id: string; + number: string; + createdAt: string; + dueDate: string | null; + amountPaid: number; + amountDue: number; + currency: string; + status: string; + type: 'Subscription' | 'One Time'; + hostedInvoiceUrl: string | null; + invoicePdfUrl: string | null; +} + export interface BackgroundCheckBillingStatus { hasPaymentMethod: boolean; setupAt: string | null; @@ -5,4 +19,5 @@ export interface BackgroundCheckBillingStatus { backgroundChecks: number; penetrationTests: number; }; + invoices?: BillingInvoice[]; } From 9acb95ebc6a865b547de8aabb5787c377addc095 Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Thu, 30 Apr 2026 17:45:56 +0100 Subject: [PATCH 02/20] fix(billing): update layout of BillingInvoicesTable component - Refactored the layout of the BillingInvoicesTable component to improve responsiveness. - Changed Stack component to a div with flex properties for better alignment on larger screens. - Adjusted width properties for the search input to enhance UI consistency. --- .../(app)/[orgId]/settings/billing/BillingInvoicesTable.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingInvoicesTable.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingInvoicesTable.tsx index 3d980f01b5..3aae34864b 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingInvoicesTable.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingInvoicesTable.tsx @@ -34,7 +34,7 @@ export function BillingInvoicesTable({ invoices }: BillingInvoicesTableProps) { return (
- +
Invoices @@ -43,7 +43,7 @@ export function BillingInvoicesTable({ invoices }: BillingInvoicesTableProps) { View and download invoices for paid services. -
+
@@ -56,7 +56,7 @@ export function BillingInvoicesTable({ invoices }: BillingInvoicesTableProps) { />
- +
From 11efc56a554ed471f48a7249c4236aabd3a60d2d Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Thu, 30 Apr 2026 17:55:36 +0100 Subject: [PATCH 03/20] feat(billing): add tests for background check billing customer and URL validation - Introduced unit tests for the findOrCreateBackgroundCheckBillingCustomer function to ensure proper handling of concurrent requests and Stripe customer updates. - Added tests for validateBackgroundCheckBillingRedirectUrl to validate app URL configurations and handle malformed URLs. - Enhanced BackgroundCheckPaymentService tests to cover invoice voiding scenarios when invoice item creation or finalization fails. --- .../background-check-billing-customer.spec.ts | 96 +++++++++++++++++++ .../background-check-billing-customer.ts | 90 +++++++++++++---- .../background-check-billing-urls.spec.ts | 22 +++++ .../background-check-billing-urls.ts | 11 ++- .../background-check-payment.service.spec.ts | 95 ++++++++++++++++++ .../background-check-payment.service.ts | 48 +++++----- .../background-checks.service.spec.ts | 12 ++- 7 files changed, 330 insertions(+), 44 deletions(-) create mode 100644 apps/api/src/background-checks/background-check-billing-customer.spec.ts create mode 100644 apps/api/src/background-checks/background-check-billing-urls.spec.ts diff --git a/apps/api/src/background-checks/background-check-billing-customer.spec.ts b/apps/api/src/background-checks/background-check-billing-customer.spec.ts new file mode 100644 index 0000000000..d67cb79922 --- /dev/null +++ b/apps/api/src/background-checks/background-check-billing-customer.spec.ts @@ -0,0 +1,96 @@ +import { db, Prisma } from '@db'; +import type { StripeService } from '../stripe/stripe.service'; +import { findOrCreateBackgroundCheckBillingCustomer } from './background-check-billing-customer'; + +jest.mock('@db', () => { + class PrismaClientKnownRequestError extends Error { + code: string; + + constructor(message: string, options: { code: string }) { + super(message); + this.code = options.code; + } + } + + return { + Prisma: { + PrismaClientKnownRequestError, + }, + db: { + organizationBilling: { + findUnique: jest.fn(), + create: jest.fn(), + }, + organization: { + findUnique: jest.fn(), + }, + }, + }; +}); + +const mockedDb = db as jest.Mocked; + +describe('findOrCreateBackgroundCheckBillingCustomer', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('recovers when a concurrent request creates the billing row first', async () => { + type OrganizationBilling = typeof db.organizationBilling; + type Organization = typeof db.organization; + const dbMocks = mockedDb as unknown as { + organizationBilling: { + findUnique: jest.MockedFunction; + create: jest.MockedFunction; + }; + organization: { + findUnique: jest.MockedFunction; + }; + }; + + dbMocks.organizationBilling.findUnique + .mockResolvedValueOnce(null) + .mockResolvedValueOnce({ + stripeCustomerId: 'cus_winner', + } as Awaited>); + dbMocks.organization.findUnique.mockResolvedValueOnce({ + name: 'Acme', + } as Awaited>); + dbMocks.organizationBilling.create.mockRejectedValueOnce( + new Prisma.PrismaClientKnownRequestError('Unique constraint failed', { + code: 'P2002', + clientVersion: 'test', + }), + ); + + const customersCreate = jest.fn().mockResolvedValue({ id: 'cus_loser' }); + const customersUpdate = jest.fn().mockResolvedValue({ id: 'cus_winner' }); + const stripeService = { + getClient: () => ({ + customers: { + create: customersCreate, + update: customersUpdate, + }, + }), + } as unknown as StripeService; + + await expect( + findOrCreateBackgroundCheckBillingCustomer({ + stripeService, + organizationId: 'org_1', + customerEmail: 'billing@trycomp.ai', + }), + ).resolves.toBe('cus_winner'); + + expect(customersCreate).toHaveBeenCalledWith( + { + name: 'Acme', + metadata: { organizationId: 'org_1' }, + }, + { idempotencyKey: 'background-check-customer:org_1' }, + ); + expect(customersUpdate).toHaveBeenCalledWith('cus_winner', { + email: 'billing@trycomp.ai', + }); + }); +}); diff --git a/apps/api/src/background-checks/background-check-billing-customer.ts b/apps/api/src/background-checks/background-check-billing-customer.ts index a6038f2778..e8604d435e 100644 --- a/apps/api/src/background-checks/background-check-billing-customer.ts +++ b/apps/api/src/background-checks/background-check-billing-customer.ts @@ -1,5 +1,5 @@ import { NotFoundException } from '@nestjs/common'; -import { db } from '@db'; +import { db, Prisma } from '@db'; import { StripeService } from '../stripe/stripe.service'; export async function findOrCreateBackgroundCheckBillingCustomer({ @@ -18,12 +18,11 @@ export async function findOrCreateBackgroundCheckBillingCustomer({ const stripe = stripeService.getClient(); if (existingBilling) { - if (customerEmail) { - await stripe.customers.update(existingBilling.stripeCustomerId, { - email: customerEmail, - }); - } - + await updateStripeCustomerEmail({ + stripeService, + stripeCustomerId: existingBilling.stripeCustomerId, + customerEmail, + }); return existingBilling.stripeCustomerId; } @@ -36,18 +35,75 @@ export async function findOrCreateBackgroundCheckBillingCustomer({ throw new NotFoundException('Organization not found.'); } - const customer = await stripe.customers.create({ - name: organization.name, - ...(customerEmail ? { email: customerEmail } : {}), - metadata: { organizationId }, - }); - - await db.organizationBilling.create({ - data: { - organizationId, - stripeCustomerId: customer.id, + const customer = await stripe.customers.create( + { + name: organization.name, + metadata: { organizationId }, }, + { + idempotencyKey: `background-check-customer:${organizationId}`, + }, + ); + + try { + await db.organizationBilling.create({ + data: { + organizationId, + stripeCustomerId: customer.id, + }, + }); + } catch (error) { + if (!isUniqueConstraintError(error)) { + throw error; + } + + const billing = await db.organizationBilling.findUnique({ + where: { organizationId }, + select: { stripeCustomerId: true }, + }); + + if (!billing) { + throw error; + } + + await updateStripeCustomerEmail({ + stripeService, + stripeCustomerId: billing.stripeCustomerId, + customerEmail, + }); + + return billing.stripeCustomerId; + } + + await updateStripeCustomerEmail({ + stripeService, + stripeCustomerId: customer.id, + customerEmail, }); return customer.id; } + +async function updateStripeCustomerEmail({ + stripeService, + stripeCustomerId, + customerEmail, +}: { + stripeService: StripeService; + stripeCustomerId: string; + customerEmail?: string; +}): Promise { + if (!customerEmail) return; + + const stripe = stripeService.getClient(); + await stripe.customers.update(stripeCustomerId, { + email: customerEmail, + }); +} + +function isUniqueConstraintError(error: unknown): boolean { + return ( + error instanceof Prisma.PrismaClientKnownRequestError && + error.code === 'P2002' + ); +} diff --git a/apps/api/src/background-checks/background-check-billing-urls.spec.ts b/apps/api/src/background-checks/background-check-billing-urls.spec.ts new file mode 100644 index 0000000000..ab4c843afe --- /dev/null +++ b/apps/api/src/background-checks/background-check-billing-urls.spec.ts @@ -0,0 +1,22 @@ +import { BadRequestException } from '@nestjs/common'; +import { validateBackgroundCheckBillingRedirectUrl } from './background-check-billing-urls'; + +describe('validateBackgroundCheckBillingRedirectUrl', () => { + const originalEnv = { ...process.env }; + + afterEach(() => { + process.env = { ...originalEnv }; + }); + + it('throws a controlled error when the configured app URL is malformed', () => { + process.env.NEXT_PUBLIC_APP_URL = 'not a url'; + process.env.APP_URL = ''; + process.env.BETTER_AUTH_URL = ''; + + expect(() => + validateBackgroundCheckBillingRedirectUrl( + 'https://app.trycomp.ai/return', + ), + ).toThrow(BadRequestException); + }); +}); diff --git a/apps/api/src/background-checks/background-check-billing-urls.ts b/apps/api/src/background-checks/background-check-billing-urls.ts index 9fd7aae9dd..01b2541bc3 100644 --- a/apps/api/src/background-checks/background-check-billing-urls.ts +++ b/apps/api/src/background-checks/background-check-billing-urls.ts @@ -9,6 +9,15 @@ export function validateBackgroundCheckBillingRedirectUrl(url: string): void { throw new BadRequestException('App URL is not configured on the server.'); } + let appOrigin: string; + try { + appOrigin = new URL(appUrl).origin; + } catch { + throw new BadRequestException( + 'App URL is not configured correctly on the server.', + ); + } + let parsed: URL; try { parsed = new URL(url); @@ -16,7 +25,7 @@ export function validateBackgroundCheckBillingRedirectUrl(url: string): void { throw new BadRequestException('Invalid redirect URL.'); } - if (parsed.origin !== new URL(appUrl).origin) { + if (parsed.origin !== appOrigin) { throw new BadRequestException( 'Redirect URL must belong to the application origin.', ); diff --git a/apps/api/src/background-checks/background-check-payment.service.spec.ts b/apps/api/src/background-checks/background-check-payment.service.spec.ts index 0359de5768..e17936cec9 100644 --- a/apps/api/src/background-checks/background-check-payment.service.spec.ts +++ b/apps/api/src/background-checks/background-check-payment.service.spec.ts @@ -138,4 +138,99 @@ describe('BackgroundCheckPaymentService', () => { }, ); }); + + it('voids the invoice when adding the invoice item fails', async () => { + mockAsync>>( + mockedDb.organizationBilling.findUnique, + ).mockResolvedValueOnce({ + stripeCustomerId: 'cus_1', + stripeBackgroundCheckPaymentMethodId: 'pm_1', + } as Awaited>); + const invoiceItemsCreate = jest + .fn() + .mockRejectedValue(new Error('line item failed')); + const invoicesCreate = jest.fn().mockResolvedValue({ id: 'in_1' }); + const finalizeInvoice = jest.fn(); + const invoicesPay = jest.fn(); + const voidInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); + const service = new BackgroundCheckPaymentService( + { + getClient: () => ({ + invoiceItems: { create: invoiceItemsCreate }, + invoices: { + create: invoicesCreate, + finalizeInvoice, + pay: invoicesPay, + voidInvoice, + }, + }), + } as unknown as StripeService, + { + getBackgroundCheckPrice: jest.fn().mockResolvedValue({ + id: 'price_bg', + unitAmount: 1250, + currency: 'usd', + }), + } as unknown as BackgroundCheckBillingService, + ); + + await expect( + service.charge({ organizationId: 'org_1', memberId: 'mem_1' }), + ).rejects.toThrow( + expect.objectContaining({ + status: HttpStatus.PAYMENT_REQUIRED, + }), + ); + + expect(voidInvoice).toHaveBeenCalledWith('in_1'); + expect(finalizeInvoice).not.toHaveBeenCalled(); + expect(invoicesPay).not.toHaveBeenCalled(); + }); + + it('voids the invoice when finalizing fails', async () => { + mockAsync>>( + mockedDb.organizationBilling.findUnique, + ).mockResolvedValueOnce({ + stripeCustomerId: 'cus_1', + stripeBackgroundCheckPaymentMethodId: 'pm_1', + } as Awaited>); + const invoiceItemsCreate = jest.fn().mockResolvedValue({ id: 'ii_1' }); + const invoicesCreate = jest.fn().mockResolvedValue({ id: 'in_1' }); + const finalizeInvoice = jest + .fn() + .mockRejectedValue(new Error('finalize failed')); + const invoicesPay = jest.fn(); + const voidInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); + const service = new BackgroundCheckPaymentService( + { + getClient: () => ({ + invoiceItems: { create: invoiceItemsCreate }, + invoices: { + create: invoicesCreate, + finalizeInvoice, + pay: invoicesPay, + voidInvoice, + }, + }), + } as unknown as StripeService, + { + getBackgroundCheckPrice: jest.fn().mockResolvedValue({ + id: 'price_bg', + unitAmount: 1250, + currency: 'usd', + }), + } as unknown as BackgroundCheckBillingService, + ); + + await expect( + service.charge({ organizationId: 'org_1', memberId: 'mem_1' }), + ).rejects.toThrow( + expect.objectContaining({ + status: HttpStatus.PAYMENT_REQUIRED, + }), + ); + + expect(voidInvoice).toHaveBeenCalledWith('in_1'); + expect(invoicesPay).not.toHaveBeenCalled(); + }); }); diff --git a/apps/api/src/background-checks/background-check-payment.service.ts b/apps/api/src/background-checks/background-check-payment.service.ts index b49fa6bcb1..e4b7e5450f 100644 --- a/apps/api/src/background-checks/background-check-payment.service.ts +++ b/apps/api/src/background-checks/background-check-payment.service.ts @@ -70,32 +70,34 @@ export class BackgroundCheckPaymentService { }, ); - await stripe.invoiceItems.create( - { - customer: billing.stripeCustomerId, - invoice: invoice.id, - pricing: { - price: price.id, + let paidInvoice: Stripe.Invoice; + try { + await stripe.invoiceItems.create( + { + customer: billing.stripeCustomerId, + invoice: invoice.id, + pricing: { + price: price.id, + }, + quantity: 1, + description: BackgroundCheckPaymentService.receiptDescription, + metadata, }, - quantity: 1, - description: BackgroundCheckPaymentService.receiptDescription, - metadata, - }, - { - idempotencyKey: [...idempotencyKeyParts, 'line-item'].join(':'), - }, - ); + { + idempotencyKey: [...idempotencyKeyParts, 'line-item'].join(':'), + }, + ); - await stripe.invoices.finalizeInvoice( - invoice.id, - { auto_advance: false }, - { - idempotencyKey: [...idempotencyKeyParts, 'finalize-invoice'].join(':'), - }, - ); + await stripe.invoices.finalizeInvoice( + invoice.id, + { auto_advance: false }, + { + idempotencyKey: [...idempotencyKeyParts, 'finalize-invoice'].join( + ':', + ), + }, + ); - let paidInvoice: Stripe.Invoice; - try { paidInvoice = await stripe.invoices.pay( invoice.id, { diff --git a/apps/api/src/background-checks/background-checks.service.spec.ts b/apps/api/src/background-checks/background-checks.service.spec.ts index 0fc3c88d7c..67771a382d 100644 --- a/apps/api/src/background-checks/background-checks.service.spec.ts +++ b/apps/api/src/background-checks/background-checks.service.spec.ts @@ -463,6 +463,7 @@ describe('background checks', () => { }, customers: { create: jest.fn().mockResolvedValue({ id: 'cus_1' }), + update: jest.fn().mockResolvedValue({ id: 'cus_1' }), }, prices: { retrieve: jest.fn().mockResolvedValue({ @@ -487,10 +488,15 @@ describe('background checks', () => { }), ).resolves.toEqual({ url: 'https://checkout.stripe.com/c/session_1' }); - expect(stripe.customers.create).toHaveBeenCalledWith({ - name: 'Acme', + expect(stripe.customers.create).toHaveBeenCalledWith( + { + name: 'Acme', + metadata: { organizationId: 'org_1' }, + }, + { idempotencyKey: 'background-check-customer:org_1' }, + ); + expect(stripe.customers.update).toHaveBeenCalledWith('cus_1', { email: 'billing@trycomp.ai', - metadata: { organizationId: 'org_1' }, }); }); From 435c235cf10304b92c2a39a091c34e1a12dba2c2 Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Thu, 30 Apr 2026 20:49:14 +0100 Subject: [PATCH 04/20] feat(billing): integrate billing module and enhance background check services - Added the @trycompai/billing package to the workspace and integrated it into the API. - Updated BackgroundCheckPaymentService to utilize the new BillingEntitlementsService for managing billing entitlements and usage. - Refactored background check billing logic to improve invoice handling and payment processing. - Introduced new billing-related endpoints and services, including billing customer management and usage tracking. - Enhanced tests for billing functionalities to ensure robust coverage of new features. --- .agents/skills/billing/SKILL.md | 59 + apps/api/package.json | 4 +- apps/api/src/app.module.ts | 2 + .../background-check-billing.service.ts | 248 +- .../background-check-payment.service.spec.ts | 110 +- .../background-check-payment.service.ts | 102 +- .../background-checks.module.ts | 3 +- .../background-checks.service.spec.ts | 136 +- apps/api/src/billing/billing-customer.ts | 59 + .../billing/billing-entitlements.service.ts | 298 ++ .../src/billing/billing-entitlements.types.ts | 47 + apps/api/src/billing/billing-invoices.spec.ts | 105 + apps/api/src/billing/billing-invoices.ts | 95 + .../src/billing/billing-preferences.spec.ts | 120 + apps/api/src/billing/billing-preferences.ts | 261 ++ apps/api/src/billing/billing-redirect-urls.ts | 25 + apps/api/src/billing/billing-usage.spec.ts | 87 + apps/api/src/billing/billing-usage.ts | 150 + .../src/billing/billing-webhook.service.ts | 270 ++ apps/api/src/billing/billing.controller.ts | 153 + apps/api/src/billing/billing.module.ts | 18 + apps/api/src/billing/billing.service.spec.ts | 109 + apps/api/src/billing/billing.service.ts | 299 ++ apps/api/src/billing/billing.types.ts | 38 + apps/api/src/billing/dto/billing.dto.ts | 92 + .../api/src/billing/stripe-webhook-records.ts | 67 + apps/api/src/main.ts | 2 + .../security-penetration-tests.module.ts | 3 +- ...security-penetration-tests.service.spec.ts | 16 + .../security-penetration-tests.service.ts | 135 +- apps/app/package.json | 3 +- .../billing/BillingAddOnPlansClient.tsx | 124 + .../settings/billing/BillingAddOns.test.tsx | 163 + .../billing/BillingAddOnsOverview.tsx | 74 + .../billing/BillingPaymentMethodCard.tsx | 75 + .../billing/BillingPreferencesForm.tsx | 294 ++ .../billing/BillingSettingsClient.test.tsx | 116 +- .../billing/BillingSettingsClient.tsx | 201 +- .../billing/BillingSettingsDetails.test.tsx | 197 ++ .../billing/BillingSubscriptionPlans.tsx | 97 + .../settings/billing/BillingUsageTable.tsx | 143 + .../settings/billing/add-ons/[addOn]/page.tsx | 40 + .../[orgId]/settings/billing/billingAddOns.ts | 37 + .../billing/billingPreferencesFormSchema.ts | 128 + .../settings/billing/emptyBillingStatus.ts | 14 + .../(app)/[orgId]/settings/billing/page.tsx | 15 +- .../(app)/[orgId]/settings/billing/types.ts | 45 + apps/framework-editor/package.json | 2 +- apps/framework-editor/prisma/schema.prisma | 2611 ----------------- apps/portal/package.json | 2 +- bun.lock | 12 + package.json | 3 +- packages/billing/package.json | 22 + packages/billing/src/catalog.test.ts | 52 + packages/billing/src/index.ts | 142 + packages/billing/tsconfig.json | 9 + .../migration.sql | 125 + .../migration.sql | 1 + .../migration.sql | 2 + .../db/prisma/schema/background-check.prisma | 4 +- .../prisma/schema/organization-billing.prisma | 96 +- packages/db/prisma/schema/organization.prisma | 7 +- .../db/prisma/schema/pentest-credits.prisma | 2 +- packages/docs/openapi.json | 210 ++ scripts/check-generated-prisma-schemas.js | 61 + 65 files changed, 5062 insertions(+), 3180 deletions(-) create mode 100644 .agents/skills/billing/SKILL.md create mode 100644 apps/api/src/billing/billing-customer.ts create mode 100644 apps/api/src/billing/billing-entitlements.service.ts create mode 100644 apps/api/src/billing/billing-entitlements.types.ts create mode 100644 apps/api/src/billing/billing-invoices.spec.ts create mode 100644 apps/api/src/billing/billing-invoices.ts create mode 100644 apps/api/src/billing/billing-preferences.spec.ts create mode 100644 apps/api/src/billing/billing-preferences.ts create mode 100644 apps/api/src/billing/billing-redirect-urls.ts create mode 100644 apps/api/src/billing/billing-usage.spec.ts create mode 100644 apps/api/src/billing/billing-usage.ts create mode 100644 apps/api/src/billing/billing-webhook.service.ts create mode 100644 apps/api/src/billing/billing.controller.ts create mode 100644 apps/api/src/billing/billing.module.ts create mode 100644 apps/api/src/billing/billing.service.spec.ts create mode 100644 apps/api/src/billing/billing.service.ts create mode 100644 apps/api/src/billing/billing.types.ts create mode 100644 apps/api/src/billing/dto/billing.dto.ts create mode 100644 apps/api/src/billing/stripe-webhook-records.ts create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnPlansClient.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/BillingPaymentMethodCard.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/BillingPreferencesForm.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsDetails.test.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/BillingUsageTable.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/add-ons/[addOn]/page.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/billingAddOns.ts create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/billingPreferencesFormSchema.ts create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts delete mode 100644 apps/framework-editor/prisma/schema.prisma create mode 100644 packages/billing/package.json create mode 100644 packages/billing/src/catalog.test.ts create mode 100644 packages/billing/src/index.ts create mode 100644 packages/billing/tsconfig.json create mode 100644 packages/db/prisma/migrations/20260430180000_subscription_billing_foundation/migration.sql create mode 100644 packages/db/prisma/migrations/20260430181000_drop_stale_pentest_subscription/migration.sql create mode 100644 packages/db/prisma/migrations/20260430182000_name_billing_subscription_unique_index/migration.sql create mode 100644 scripts/check-generated-prisma-schemas.js diff --git a/.agents/skills/billing/SKILL.md b/.agents/skills/billing/SKILL.md new file mode 100644 index 0000000000..d92c1cb48a --- /dev/null +++ b/.agents/skills/billing/SKILL.md @@ -0,0 +1,59 @@ +--- +name: billing +description: Use when changing Comp AI billing, Stripe products/prices, subscription checkout, org payment methods, entitlements, usage ledgers, invoices, or billing webhooks. +metadata: + short-description: Comp AI billing architecture +--- + +# Billing + +Comp AI billing is SKU-first and subscription-ready. Stripe is the payment provider; Comp AI owns catalog definitions, entitlement state, usage gating, and audit history. + +## Core Rules + +- Use `@trycompai/billing` for SKU keys, amounts, Stripe product IDs, Stripe price IDs, cadence, and included usage. +- Do not add product-specific nullable fields to `OrganizationBilling`. +- Keep org-level billing generic: `stripeCustomerId`, `stripePaymentMethodId`, and `paymentMethodUpdatedAt`. +- Store per-product state in generic per-SKU tables keyed by `skuKey`. +- Gate paid actions from local entitlement/usage state, not directly from Stripe object reads. +- Treat Stripe webhooks as eventually consistent, retryable, and possibly out of order. +- Make webhook and entitlement handling idempotent with Stripe event IDs, invoice IDs, subscription item IDs, and period start/end. +- Archive accidental live/test Stripe objects instead of reusing them blindly. + +## Current SKU Shape + +- `background_check_one_time`: one-time `$49` background check. +- `background_checks_monthly_25`: `$249/mo`, includes 25 background checks per month. +- `pentest_monthly_5`: `$399/mo`, includes 5 penetration-test scans per month. + +Live catalog entries should only be added after deliberate live Stripe object creation. + +## Implementation Pattern + +1. Add or update the SKU in `packages/billing/src/catalog.ts`. +2. Store Stripe IDs in the catalog by environment rather than adding new env vars. +3. Create subscriptions through Stripe Checkout `mode: subscription`. +4. Use the shared Stripe customer default payment method unless the product explicitly needs overrides. +5. For allowance products, sync local subscription state from Stripe subscription items and period timestamps. +6. Record usage in the billing usage ledger with a stable idempotency key. +7. Record support-relevant mutations in billing audit events. + +## Webhooks + +Handle these events before launching subscription access: + +- `checkout.session.completed` +- `invoice.paid` +- `invoice.payment_failed` +- `invoice.payment_action_required` +- `customer.subscription.updated` +- `customer.subscription.deleted` + +Provision or renew allowance only for the matching subscription item and SKU. Do not use generic invoice-level periods for multi-item subscriptions. + +## Validation + +- Run `bun run db:generate` after Prisma schema changes. +- Run `bun run check:prisma-schemas` to catch stale copied schema fragments. +- Run catalog tests after SKU changes. +- Add webhook tests for duplicate events, out-of-order delivery, payment failure, action required, cancellation, and renewal. diff --git a/apps/api/package.json b/apps/api/package.json index b36942bebc..a43b296328 100644 --- a/apps/api/package.json +++ b/apps/api/package.json @@ -78,6 +78,7 @@ "@trigger.dev/build": "4.4.3", "@trigger.dev/sdk": "4.4.3", "@trycompai/auth": "workspace:*", + "@trycompai/billing": "workspace:*", "@trycompai/company": "workspace:*", "@trycompai/db": "workspace:*", "@trycompai/email": "workspace:*", @@ -170,6 +171,7 @@ "^@db$": "/../prisma/index", "^@/(.*)$": "/$1", "^@trycompai/auth$": "/../../../packages/auth/src/index.ts", + "^@trycompai/billing$": "/../../../packages/billing/src/index.ts", "^@trycompai/company$": "/../../../packages/company/src/index.ts", "^@trycompai/db$": "@prisma/client", "^@trycompai/email$": "/../../../packages/email/index.ts", @@ -183,7 +185,7 @@ "build": "nest build", "build:docker": "bunx prisma generate --schema=prisma/schema && nest build", "db:generate": "bun run db:getschema && bunx prisma generate --schema=prisma/schema", - "db:getschema": "find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\;", + "db:getschema": "find prisma/schema -name '*.prisma' ! -name 'schema.prisma' -delete && find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\;", "db:migrate": "cd ../../packages/db && bunx prisma migrate dev && cd ../../apps/api", "deploy:trigger-prod": "npx trigger.dev@4.4.3 deploy", "dev": "bunx concurrently --kill-others --names \"nest,trigger\" --prefix-colors \"green,blue\" \"nest start --watch\" \"trigger dev\"", diff --git a/apps/api/src/app.module.ts b/apps/api/src/app.module.ts index ce20e1782d..318ba6afde 100644 --- a/apps/api/src/app.module.ts +++ b/apps/api/src/app.module.ts @@ -53,6 +53,7 @@ import { AdminOrganizationsModule } from './admin-organizations/admin-organizati import { AdminFeatureFlagsModule } from './admin-feature-flags/admin-feature-flags.module'; import { TimelinesModule } from './timelines/timelines.module'; import { BackgroundChecksModule } from './background-checks/background-checks.module'; +import { BillingModule } from './billing/billing.module'; @Module({ imports: [ @@ -114,6 +115,7 @@ import { BackgroundChecksModule } from './background-checks/background-checks.mo SecretsModule, SecurityPenetrationTestsModule, StripeModule, + BillingModule, BackgroundChecksModule, AdminOrganizationsModule, AdminFeatureFlagsModule, diff --git a/apps/api/src/background-checks/background-check-billing.service.ts b/apps/api/src/background-checks/background-check-billing.service.ts index da138b69ef..8622e58f0c 100644 --- a/apps/api/src/background-checks/background-check-billing.service.ts +++ b/apps/api/src/background-checks/background-check-billing.service.ts @@ -1,203 +1,35 @@ -import { - BadRequestException, - Injectable, - NotFoundException, -} from '@nestjs/common'; -import { db } from '@db'; -import { StripeService } from '../stripe/stripe.service'; -import { findOrCreateBackgroundCheckBillingCustomer } from './background-check-billing-customer'; -import { - type BackgroundCheckBillingInvoice, - listBackgroundCheckBillingInvoices, -} from './background-check-billing-invoices'; -import { validateBackgroundCheckBillingRedirectUrl } from './background-check-billing-urls'; +import { Injectable } from '@nestjs/common'; +import { BillingService } from '../billing/billing.service'; @Injectable() export class BackgroundCheckBillingService { - constructor(private readonly stripeService: StripeService) {} + constructor(private readonly billingService: BillingService) {} - async getStatus(organizationId: string): Promise<{ - hasBilling: boolean; - hasPaymentMethod: boolean; - setupAt: Date | null; - usage: { - backgroundChecks: number; - penetrationTests: number; - }; - invoices: BackgroundCheckBillingInvoice[]; - }> { - const [billing, backgroundChecks, penetrationTests] = await Promise.all([ - db.organizationBilling.findUnique({ - where: { organizationId }, - select: { - stripeCustomerId: true, - stripeBackgroundCheckPaymentMethodId: true, - backgroundCheckPaymentMethodSetupAt: true, - }, - }), - db.backgroundCheckRequest.count({ where: { organizationId } }), - db.securityPenetrationTestRun.count({ where: { organizationId } }), - ]); - const invoices = await listBackgroundCheckBillingInvoices({ - stripeService: this.stripeService, - stripeCustomerId: billing?.stripeCustomerId ?? null, - }); - - return { - hasBilling: !!billing, - hasPaymentMethod: !!billing?.stripeBackgroundCheckPaymentMethodId, - setupAt: billing?.backgroundCheckPaymentMethodSetupAt ?? null, - usage: { - backgroundChecks, - penetrationTests, - }, - invoices, - }; + async getStatus(organizationId: string) { + return this.billingService.getStatus(organizationId); } - async createSetupSession({ - organizationId, - successUrl, - cancelUrl, - customerEmail, - }: { + async createSetupSession(params: { organizationId: string; successUrl: string; cancelUrl: string; customerEmail?: string; }): Promise<{ url: string }> { - validateBackgroundCheckBillingRedirectUrl(successUrl); - validateBackgroundCheckBillingRedirectUrl(cancelUrl); - - const stripe = this.stripeService.getClient(); - const customerId = await findOrCreateBackgroundCheckBillingCustomer({ - stripeService: this.stripeService, - organizationId, - customerEmail, - }); - const price = await this.getBackgroundCheckPrice(); - - const session = await stripe.checkout.sessions.create({ - mode: 'setup', - customer: customerId, - currency: price.currency, - success_url: successUrl, - cancel_url: cancelUrl, - metadata: { - organizationId, - source: 'comp-background-check', - }, - }); - - if (!session.url) { - throw new BadRequestException( - 'Failed to create Stripe Checkout session.', - ); - } - - return { url: session.url }; + return this.billingService.createSetupSession(params); } - async handleSetupSuccess({ - organizationId, - sessionId, - }: { + async handleSetupSuccess(params: { organizationId: string; sessionId: string; }): Promise<{ success: true }> { - const stripe = this.stripeService.getClient(); - const session = await stripe.checkout.sessions.retrieve(sessionId, { - expand: ['setup_intent'], - }); - - if (session.status !== 'complete') { - throw new BadRequestException('Checkout session is not complete.'); - } - - if ( - session.metadata?.organizationId && - session.metadata.organizationId !== organizationId - ) { - throw new BadRequestException( - 'Checkout session does not belong to this organization.', - ); - } - - const stripeCustomerId = this.extractStripeId(session.customer); - if (!stripeCustomerId) { - throw new BadRequestException('Checkout session is missing a customer.'); - } - - await this.assertCustomerBelongsToOrganization({ - organizationId, - stripeCustomerId, - }); - - const setupIntent = session.setup_intent; - if (!setupIntent || typeof setupIntent === 'string') { - throw new BadRequestException( - 'Checkout session is missing a setup intent.', - ); - } - - const paymentMethodId = this.extractStripeId(setupIntent.payment_method); - if (!paymentMethodId) { - throw new BadRequestException( - 'Setup intent is missing a payment method.', - ); - } - - await stripe.customers.update(stripeCustomerId, { - invoice_settings: { - default_payment_method: paymentMethodId, - }, - }); - - await db.organizationBilling.upsert({ - where: { organizationId }, - create: { - organizationId, - stripeCustomerId, - stripeBackgroundCheckPaymentMethodId: paymentMethodId, - backgroundCheckPaymentMethodSetupAt: new Date(), - }, - update: { - stripeCustomerId, - stripeBackgroundCheckPaymentMethodId: paymentMethodId, - backgroundCheckPaymentMethodSetupAt: new Date(), - }, - }); - - return { success: true }; + return this.billingService.handleSetupSuccess(params); } - async createBillingPortalSession({ - organizationId, - returnUrl, - }: { + async createBillingPortalSession(params: { organizationId: string; returnUrl: string; }): Promise<{ url: string }> { - validateBackgroundCheckBillingRedirectUrl(returnUrl); - - const stripe = this.stripeService.getClient(); - const billing = await db.organizationBilling.findUnique({ - where: { organizationId }, - select: { stripeCustomerId: true }, - }); - - if (!billing) { - throw new NotFoundException( - 'No billing record found for this organization.', - ); - } - - const portalSession = await stripe.billingPortal.sessions.create({ - customer: billing.stripeCustomerId, - return_url: returnUrl, - }); - - return { url: portalSession.url }; + return this.billingService.createBillingPortalSession(params); } async getBackgroundCheckPrice(): Promise<{ @@ -205,61 +37,11 @@ export class BackgroundCheckBillingService { unitAmount: number; currency: string; }> { - const priceId = process.env.STRIPE_BACKGROUND_CHECK_PRICE_ID; - if (!priceId) { - throw new BadRequestException( - 'Background check pricing is not configured. Contact support.', - ); - } - - const stripe = this.stripeService.getClient(); - const price = await stripe.prices.retrieve(priceId); - if (price.unit_amount === null || price.unit_amount === undefined) { - throw new BadRequestException( - 'Background check pricing is not configured. Contact support.', - ); - } - + const sku = this.billingService.getOneTimeBackgroundCheckSku(); return { - id: price.id, - unitAmount: price.unit_amount, - currency: price.currency, + id: sku.stripePriceId, + unitAmount: sku.unitAmount, + currency: sku.currency, }; } - - private extractStripeId( - value: string | { id?: string } | null, - ): string | null { - if (!value) return null; - if (typeof value === 'string') return value; - return value.id ?? null; - } - - private async assertCustomerBelongsToOrganization({ - organizationId, - stripeCustomerId, - }: { - organizationId: string; - stripeCustomerId: string; - }): Promise { - const billing = await db.organizationBilling.findUnique({ - where: { organizationId }, - select: { stripeCustomerId: true }, - }); - - if (billing?.stripeCustomerId === stripeCustomerId) { - return; - } - - const stripe = this.stripeService.getClient(); - const customer = await stripe.customers.retrieve(stripeCustomerId); - if ( - customer.deleted || - customer.metadata?.organizationId !== organizationId - ) { - throw new BadRequestException( - 'Checkout session does not belong to this organization.', - ); - } - } } diff --git a/apps/api/src/background-checks/background-check-payment.service.spec.ts b/apps/api/src/background-checks/background-check-payment.service.spec.ts index e17936cec9..bc5041bb16 100644 --- a/apps/api/src/background-checks/background-check-payment.service.spec.ts +++ b/apps/api/src/background-checks/background-check-payment.service.spec.ts @@ -1,5 +1,6 @@ -import { HttpException, HttpStatus } from '@nestjs/common'; +import { HttpStatus } from '@nestjs/common'; import { db } from '@db'; +import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; import { StripeService } from '../stripe/stripe.service'; import { BackgroundCheckBillingService } from './background-check-billing.service'; import { BackgroundCheckPaymentService } from './background-check-payment.service'; @@ -18,6 +19,33 @@ function mockAsync(fn: unknown): jest.MockedFunction<() => Promise> { return fn as jest.MockedFunction<() => Promise>; } +function mockEntitlements( + overrides: Partial = {}, +): BillingEntitlementsService { + return { + tryConsumeIncludedUsage: jest + .fn() + .mockResolvedValue({ status: 'not_configured' }), + recordOneTimeUsage: jest.fn().mockResolvedValue(undefined), + refundIncludedUsage: jest.fn().mockResolvedValue(undefined), + syncSubscriptionItem: jest.fn().mockResolvedValue(undefined), + writeAuditEvent: jest.fn().mockResolvedValue(undefined), + ...overrides, + } as unknown as BillingEntitlementsService; +} + +function mockBillingRow() { + return { + id: 'obil_1', + organizationId: 'org_1', + stripeCustomerId: 'cus_1', + stripePaymentMethodId: 'pm_1', + paymentMethodUpdatedAt: null, + createdAt: new Date('2026-04-30T00:00:00.000Z'), + updatedAt: new Date('2026-04-30T00:00:00.000Z'), + }; +} + describe('BackgroundCheckPaymentService', () => { beforeEach(() => { jest.clearAllMocks(); @@ -32,6 +60,7 @@ describe('BackgroundCheckPaymentService', () => { { getBackgroundCheckPrice: jest.fn(), } as unknown as BackgroundCheckBillingService, + mockEntitlements(), ); await expect( @@ -46,10 +75,7 @@ describe('BackgroundCheckPaymentService', () => { it('creates and pays a Stripe invoice with payment-method scoped idempotency keys', async () => { mockAsync>>( mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce({ - stripeCustomerId: 'cus_1', - stripeBackgroundCheckPaymentMethodId: 'pm_1', - } as Awaited>); + ).mockResolvedValueOnce(mockBillingRow()); const invoiceItemsCreate = jest.fn().mockResolvedValue({ id: 'ii_1' }); const invoicesCreate = jest.fn().mockResolvedValue({ id: 'in_1' }); const finalizeInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); @@ -85,6 +111,7 @@ describe('BackgroundCheckPaymentService', () => { currency: 'usd', }), } as unknown as BackgroundCheckBillingService, + mockEntitlements(), ); await expect( @@ -139,20 +166,18 @@ describe('BackgroundCheckPaymentService', () => { ); }); - it('voids the invoice when adding the invoice item fails', async () => { + it('deletes the draft invoice when adding the invoice item fails', async () => { mockAsync>>( mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce({ - stripeCustomerId: 'cus_1', - stripeBackgroundCheckPaymentMethodId: 'pm_1', - } as Awaited>); + ).mockResolvedValueOnce(mockBillingRow()); const invoiceItemsCreate = jest .fn() .mockRejectedValue(new Error('line item failed')); const invoicesCreate = jest.fn().mockResolvedValue({ id: 'in_1' }); const finalizeInvoice = jest.fn(); const invoicesPay = jest.fn(); - const voidInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); + const deleteInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); + const voidInvoice = jest.fn(); const service = new BackgroundCheckPaymentService( { getClient: () => ({ @@ -161,6 +186,7 @@ describe('BackgroundCheckPaymentService', () => { create: invoicesCreate, finalizeInvoice, pay: invoicesPay, + del: deleteInvoice, voidInvoice, }, }), @@ -172,6 +198,7 @@ describe('BackgroundCheckPaymentService', () => { currency: 'usd', }), } as unknown as BackgroundCheckBillingService, + mockEntitlements(), ); await expect( @@ -182,24 +209,69 @@ describe('BackgroundCheckPaymentService', () => { }), ); - expect(voidInvoice).toHaveBeenCalledWith('in_1'); + expect(deleteInvoice).toHaveBeenCalledWith('in_1'); + expect(voidInvoice).not.toHaveBeenCalled(); expect(finalizeInvoice).not.toHaveBeenCalled(); expect(invoicesPay).not.toHaveBeenCalled(); }); - it('voids the invoice when finalizing fails', async () => { + it('deletes the draft invoice when finalizing fails', async () => { mockAsync>>( mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce({ - stripeCustomerId: 'cus_1', - stripeBackgroundCheckPaymentMethodId: 'pm_1', - } as Awaited>); + ).mockResolvedValueOnce(mockBillingRow()); const invoiceItemsCreate = jest.fn().mockResolvedValue({ id: 'ii_1' }); const invoicesCreate = jest.fn().mockResolvedValue({ id: 'in_1' }); const finalizeInvoice = jest .fn() .mockRejectedValue(new Error('finalize failed')); const invoicesPay = jest.fn(); + const deleteInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); + const voidInvoice = jest.fn(); + const service = new BackgroundCheckPaymentService( + { + getClient: () => ({ + invoiceItems: { create: invoiceItemsCreate }, + invoices: { + create: invoicesCreate, + finalizeInvoice, + pay: invoicesPay, + del: deleteInvoice, + voidInvoice, + }, + }), + } as unknown as StripeService, + { + getBackgroundCheckPrice: jest.fn().mockResolvedValue({ + id: 'price_bg', + unitAmount: 1250, + currency: 'usd', + }), + } as unknown as BackgroundCheckBillingService, + mockEntitlements(), + ); + + await expect( + service.charge({ organizationId: 'org_1', memberId: 'mem_1' }), + ).rejects.toThrow( + expect.objectContaining({ + status: HttpStatus.PAYMENT_REQUIRED, + }), + ); + + expect(deleteInvoice).toHaveBeenCalledWith('in_1'); + expect(voidInvoice).not.toHaveBeenCalled(); + expect(invoicesPay).not.toHaveBeenCalled(); + }); + + it('voids the finalized invoice when paying fails', async () => { + mockAsync>>( + mockedDb.organizationBilling.findUnique, + ).mockResolvedValueOnce(mockBillingRow()); + const invoiceItemsCreate = jest.fn().mockResolvedValue({ id: 'ii_1' }); + const invoicesCreate = jest.fn().mockResolvedValue({ id: 'in_1' }); + const finalizeInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); + const invoicesPay = jest.fn().mockRejectedValue(new Error('pay failed')); + const deleteInvoice = jest.fn(); const voidInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); const service = new BackgroundCheckPaymentService( { @@ -209,6 +281,7 @@ describe('BackgroundCheckPaymentService', () => { create: invoicesCreate, finalizeInvoice, pay: invoicesPay, + del: deleteInvoice, voidInvoice, }, }), @@ -220,6 +293,7 @@ describe('BackgroundCheckPaymentService', () => { currency: 'usd', }), } as unknown as BackgroundCheckBillingService, + mockEntitlements(), ); await expect( @@ -230,7 +304,7 @@ describe('BackgroundCheckPaymentService', () => { }), ); + expect(deleteInvoice).not.toHaveBeenCalled(); expect(voidInvoice).toHaveBeenCalledWith('in_1'); - expect(invoicesPay).not.toHaveBeenCalled(); }); }); diff --git a/apps/api/src/background-checks/background-check-payment.service.ts b/apps/api/src/background-checks/background-check-payment.service.ts index e4b7e5450f..a976b1e575 100644 --- a/apps/api/src/background-checks/background-check-payment.service.ts +++ b/apps/api/src/background-checks/background-check-payment.service.ts @@ -1,6 +1,7 @@ import { HttpException, HttpStatus, Injectable, Logger } from '@nestjs/common'; import { db } from '@db'; import Stripe from 'stripe'; +import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; import { StripeService } from '../stripe/stripe.service'; import { BackgroundCheckBillingService } from './background-check-billing.service'; @@ -15,26 +16,42 @@ export class BackgroundCheckPaymentService { constructor( private readonly stripeService: StripeService, private readonly billingService: BackgroundCheckBillingService, + private readonly entitlements: BillingEntitlementsService, ) {} async charge(params: { organizationId: string; memberId: string }): Promise<{ - paymentIntentId: string; - invoiceId: string; + paymentIntentId: string | null; + invoiceId: string | null; status: string; amount: number; currency: string; }> { + const includedUsage = await this.entitlements.tryConsumeIncludedUsage({ + organizationId: params.organizationId, + skuKey: 'background_checks_monthly_25', + sourceResourceId: params.memberId, + }); + if (includedUsage.status === 'consumed') { + return { + paymentIntentId: null, + invoiceId: null, + status: 'subscription_included', + amount: 0, + currency: 'usd', + }; + } + const billing = await db.organizationBilling.findUnique({ where: { organizationId: params.organizationId }, select: { stripeCustomerId: true, - stripeBackgroundCheckPaymentMethodId: true, + stripePaymentMethodId: true, }, }); - if (!billing?.stripeBackgroundCheckPaymentMethodId) { + if (!billing?.stripePaymentMethodId) { throw new HttpException( - 'No background check payment method on file. Update billing first.', + 'No payment method on file. Update billing first.', HttpStatus.PAYMENT_REQUIRED, ); } @@ -51,7 +68,7 @@ export class BackgroundCheckPaymentService { params.organizationId, params.memberId, price.id, - billing.stripeBackgroundCheckPaymentMethodId, + billing.stripePaymentMethodId, ]; const invoice = await stripe.invoices.create( @@ -59,7 +76,7 @@ export class BackgroundCheckPaymentService { customer: billing.stripeCustomerId, collection_method: 'charge_automatically', currency: price.currency, - default_payment_method: billing.stripeBackgroundCheckPaymentMethodId, + default_payment_method: billing.stripePaymentMethodId, description: BackgroundCheckPaymentService.receiptDescription, statement_descriptor: BackgroundCheckPaymentService.statementDescriptor, auto_advance: false, @@ -71,6 +88,7 @@ export class BackgroundCheckPaymentService { ); let paidInvoice: Stripe.Invoice; + let invoiceFinalized = false; try { await stripe.invoiceItems.create( { @@ -97,11 +115,12 @@ export class BackgroundCheckPaymentService { ), }, ); + invoiceFinalized = true; paidInvoice = await stripe.invoices.pay( invoice.id, { - payment_method: billing.stripeBackgroundCheckPaymentMethodId, + payment_method: billing.stripePaymentMethodId, off_session: true, expand: ['payments'], }, @@ -110,7 +129,11 @@ export class BackgroundCheckPaymentService { }, ); } catch (error) { - await this.voidInvoice({ stripe, invoiceId: invoice.id }); + await this.cleanupUnpaidInvoice({ + stripe, + invoiceId: invoice.id, + finalized: invoiceFinalized, + }); throw new HttpException( 'Background check payment failed. Update billing and try again.', HttpStatus.PAYMENT_REQUIRED, @@ -119,7 +142,11 @@ export class BackgroundCheckPaymentService { } if (paidInvoice.status !== 'paid') { - await this.voidInvoice({ stripe, invoiceId: invoice.id }); + await this.cleanupUnpaidInvoice({ + stripe, + invoiceId: invoice.id, + finalized: true, + }); throw new HttpException( 'Background check payment failed. Update billing and try again.', HttpStatus.PAYMENT_REQUIRED, @@ -134,6 +161,13 @@ export class BackgroundCheckPaymentService { ); } + await this.entitlements.recordOneTimeUsage({ + organizationId: params.organizationId, + skuKey: 'background_check_one_time', + sourceResourceId: params.memberId, + stripeInvoiceId: paidInvoice.id, + }); + return { paymentIntentId, invoiceId: paidInvoice.id, @@ -146,8 +180,18 @@ export class BackgroundCheckPaymentService { async refund(params: { organizationId: string; memberId: string; - paymentIntentId: string; + paymentIntentId: string | null; }): Promise { + if (!params.paymentIntentId) { + await this.entitlements.refundIncludedUsage({ + organizationId: params.organizationId, + skuKey: 'background_checks_monthly_25', + sourceResourceId: params.memberId, + reason: 'background_check_failed', + }); + return null; + } + try { const stripe = this.stripeService.getClient(); const refund = await stripe.refunds.create( @@ -185,7 +229,41 @@ export class BackgroundCheckPaymentService { return typeof paymentIntent === 'string' ? paymentIntent : paymentIntent.id; } - private async voidInvoice({ + private async cleanupUnpaidInvoice({ + stripe, + invoiceId, + finalized, + }: { + stripe: Stripe; + invoiceId: string; + finalized: boolean; + }): Promise { + if (!finalized) { + await this.deleteDraftInvoice({ stripe, invoiceId }); + return; + } + + await this.voidFinalizedInvoice({ stripe, invoiceId }); + } + + private async deleteDraftInvoice({ + stripe, + invoiceId, + }: { + stripe: Stripe; + invoiceId: string; + }): Promise { + try { + await stripe.invoices.del(invoiceId); + } catch (error) { + this.logger.error('Failed to delete draft background check invoice.', { + invoiceId, + error: error instanceof Error ? error.message : 'Unknown error', + }); + } + } + + private async voidFinalizedInvoice({ stripe, invoiceId, }: { diff --git a/apps/api/src/background-checks/background-checks.module.ts b/apps/api/src/background-checks/background-checks.module.ts index b1e16339b6..b643bd468f 100644 --- a/apps/api/src/background-checks/background-checks.module.ts +++ b/apps/api/src/background-checks/background-checks.module.ts @@ -1,6 +1,7 @@ import { Module } from '@nestjs/common'; import { AttachmentsModule } from '../attachments/attachments.module'; import { AuthModule } from '../auth/auth.module'; +import { BillingModule } from '../billing/billing.module'; import { BackgroundCheckBillingController } from './background-check-billing.controller'; import { BackgroundCheckBillingService } from './background-check-billing.service'; import { BackgroundCheckCustomService } from './background-check-custom.service'; @@ -13,7 +14,7 @@ import { import { BackgroundChecksService } from './background-checks.service'; @Module({ - imports: [AuthModule, AttachmentsModule], + imports: [AuthModule, AttachmentsModule, BillingModule], controllers: [ BackgroundChecksController, PeopleBackgroundChecksController, diff --git a/apps/api/src/background-checks/background-checks.service.spec.ts b/apps/api/src/background-checks/background-checks.service.spec.ts index 67771a382d..39aff5847d 100644 --- a/apps/api/src/background-checks/background-checks.service.spec.ts +++ b/apps/api/src/background-checks/background-checks.service.spec.ts @@ -1,9 +1,9 @@ import { BackgroundCheckIdentityClient } from './background-check-identity.client'; +import { BillingService } from '../billing/billing.service'; import { BackgroundCheckBillingService } from './background-check-billing.service'; import { BackgroundCheckPaymentService } from './background-check-payment.service'; import { BackgroundChecksService } from './background-checks.service'; import { db } from '@db'; -import type { StripeService } from '../stripe/stripe.service'; jest.mock('@db', () => { class PrismaClientKnownRequestError extends Error { @@ -432,51 +432,12 @@ describe('background checks', () => { }); it('uses BETTER_AUTH_URL as the local app URL fallback for setup redirects', async () => { - process.env = { - ...process.env, - NEXT_PUBLIC_APP_URL: '', - APP_URL: '', - BETTER_AUTH_URL: 'http://localhost:3000', - }; - mockAsync>>( - mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce(null); - mockAsync>>( - mockedDb.organization.findUnique, - ).mockResolvedValueOnce({ - name: 'Acme', - } as Awaited>); - mockAsync>>( - mockedDb.organizationBilling.create, - ).mockResolvedValueOnce({ - organizationId: 'org_1', - stripeCustomerId: 'cus_1', - } as Awaited>); - - const stripe = { - checkout: { - sessions: { - create: jest.fn().mockResolvedValue({ - url: 'https://checkout.stripe.com/c/session_1', - }), - }, - }, - customers: { - create: jest.fn().mockResolvedValue({ id: 'cus_1' }), - update: jest.fn().mockResolvedValue({ id: 'cus_1' }), - }, - prices: { - retrieve: jest.fn().mockResolvedValue({ - id: 'price_bg', - unit_amount: 4900, - currency: 'usd', - }), - }, - }; - const stripeService = { - getClient: () => stripe, - } as unknown as StripeService; - const service = new BackgroundCheckBillingService(stripeService); + const billingService = { + createSetupSession: jest.fn().mockResolvedValue({ + url: 'https://checkout.stripe.com/c/session_1', + }), + } as unknown as BillingService; + const service = new BackgroundCheckBillingService(billingService); await expect( service.createSetupSession({ @@ -488,54 +449,41 @@ describe('background checks', () => { }), ).resolves.toEqual({ url: 'https://checkout.stripe.com/c/session_1' }); - expect(stripe.customers.create).toHaveBeenCalledWith( - { - name: 'Acme', - metadata: { organizationId: 'org_1' }, - }, - { idempotencyKey: 'background-check-customer:org_1' }, - ); - expect(stripe.customers.update).toHaveBeenCalledWith('cus_1', { - email: 'billing@trycomp.ai', + expect(billingService.createSetupSession).toHaveBeenCalledWith({ + organizationId: 'org_1', + successUrl: + 'http://localhost:3000/org_1/people/mem_1?background_check_billing=success', + cancelUrl: 'http://localhost:3000/org_1/people/mem_1', + customerEmail: 'billing@trycomp.ai', }); }); it('includes background check and penetration test usage in billing status', async () => { - mockAsync>>( - mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce({ - stripeCustomerId: 'cus_1', - stripeBackgroundCheckPaymentMethodId: 'pm_1', - backgroundCheckPaymentMethodSetupAt: new Date('2026-04-29T12:00:00.000Z'), - } as Awaited>); - mockAsync( - mockedDb.backgroundCheckRequest.count, - ).mockResolvedValueOnce(4); - mockAsync( - mockedDb.securityPenetrationTestRun.count, - ).mockResolvedValueOnce(2); - - const invoicesList = jest.fn().mockResolvedValue({ - data: [ - { - id: 'in_1', - number: 'INV-001', - created: 1777464000, - due_date: null, - amount_paid: 4900, - amount_due: 4900, - currency: 'usd', - status: 'paid', - parent: null, - hosted_invoice_url: 'https://invoice.stripe.com/i/in_1', - invoice_pdf: 'https://invoice.stripe.com/i/in_1.pdf', - }, - ], - }); - const service = new BackgroundCheckBillingService({ - getClient: () => ({ invoices: { list: invoicesList } }), - isConfigured: () => true, - } as unknown as StripeService); + const billingService = { + getStatus: jest.fn().mockResolvedValue({ + hasBilling: true, + hasPaymentMethod: true, + setupAt: new Date('2026-04-29T12:00:00.000Z'), + usage: { backgroundChecks: 4, penetrationTests: 2 }, + subscriptions: [], + invoices: [ + { + id: 'in_1', + number: 'INV-001', + createdAt: '2026-04-30T00:00:00.000Z', + dueDate: null, + amountPaid: 4900, + amountDue: 4900, + currency: 'usd', + status: 'paid', + type: 'One Time', + hostedInvoiceUrl: 'https://invoice.stripe.com/i/in_1', + invoicePdfUrl: 'https://invoice.stripe.com/i/in_1.pdf', + }, + ], + }), + } as unknown as BillingService; + const service = new BackgroundCheckBillingService(billingService); await expect(service.getStatus('org_1')).resolves.toMatchObject({ hasBilling: true, @@ -554,12 +502,6 @@ describe('background checks', () => { }, ], }); - expect(invoicesList).toHaveBeenCalledWith({ customer: 'cus_1', limit: 10 }); - expect(mockedDb.backgroundCheckRequest.count).toHaveBeenCalledWith({ - where: { organizationId: 'org_1' }, - }); - expect(mockedDb.securityPenetrationTestRun.count).toHaveBeenCalledWith({ - where: { organizationId: 'org_1' }, - }); + expect(billingService.getStatus).toHaveBeenCalledWith('org_1'); }); }); diff --git a/apps/api/src/billing/billing-customer.ts b/apps/api/src/billing/billing-customer.ts new file mode 100644 index 0000000000..27e1276ec7 --- /dev/null +++ b/apps/api/src/billing/billing-customer.ts @@ -0,0 +1,59 @@ +import { Prisma, db } from '@db'; +import type { StripeService } from '../stripe/stripe.service'; + +export async function findOrCreateBillingCustomer(params: { + stripeService: StripeService; + organizationId: string; + customerEmail?: string; +}): Promise { + const existing = await db.organizationBilling.findUnique({ + where: { organizationId: params.organizationId }, + select: { stripeCustomerId: true }, + }); + if (existing) { + return existing.stripeCustomerId; + } + + const organization = await db.organization.findUniqueOrThrow({ + where: { id: params.organizationId }, + select: { id: true, name: true }, + }); + const stripe = params.stripeService.getClient(); + const customer = await stripe.customers.create( + { + name: organization.name, + email: params.customerEmail, + metadata: { organizationId: organization.id }, + }, + { + idempotencyKey: ['organization-billing-customer', organization.id].join( + ':', + ), + }, + ); + + try { + await db.organizationBilling.create({ + data: { + organizationId: organization.id, + stripeCustomerId: customer.id, + }, + }); + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + const raced = await db.organizationBilling.findUniqueOrThrow({ + where: { organizationId: organization.id }, + select: { stripeCustomerId: true }, + }); + return raced.stripeCustomerId; + } + + return customer.id; +} + +function isUniqueConstraintError(error: unknown): boolean { + return ( + error instanceof Prisma.PrismaClientKnownRequestError && + error.code === 'P2002' + ); +} diff --git a/apps/api/src/billing/billing-entitlements.service.ts b/apps/api/src/billing/billing-entitlements.service.ts new file mode 100644 index 0000000000..a788d637a0 --- /dev/null +++ b/apps/api/src/billing/billing-entitlements.service.ts @@ -0,0 +1,298 @@ +import { HttpException, HttpStatus, Injectable } from '@nestjs/common'; +import { db } from '@db'; +import type { BillingSkuKey } from '@trycompai/billing'; +import { + type BillingConsumeResult, + isAccessStatus, + isUniqueConstraintError, + sameTime, + type SyncSubscriptionItemParams, + type WriteBillingAuditEventParams, +} from './billing-entitlements.types'; + +@Injectable() +export class BillingEntitlementsService { + async tryConsumeIncludedUsage(params: { + organizationId: string; + skuKey: BillingSkuKey; + sourceResourceId: string; + }): Promise { + const subscription = await db.organizationBillingSubscription.findUnique({ + where: { + organizationId_skuKey: { + organizationId: params.organizationId, + skuKey: params.skuKey, + }, + }, + }); + + if (!subscription || !isAccessStatus(subscription.stripeStatus)) { + return { status: 'not_configured' }; + } + + if ( + subscription.currentPeriodEnd && + subscription.currentPeriodEnd.getTime() <= Date.now() + ) { + return { status: 'not_configured' }; + } + + if (subscription.usedQuantity >= subscription.includedQuantity) { + return { status: 'exhausted', subscriptionId: subscription.id }; + } + + const idempotencyKey = [ + 'consume', + params.organizationId, + params.skuKey, + params.sourceResourceId, + ].join(':'); + + try { + await db.$transaction(async (tx) => { + await tx.billingUsageEvent.create({ + data: { + organizationId: params.organizationId, + skuKey: params.skuKey, + eventType: 'consume', + quantity: 1, + sourceResourceId: params.sourceResourceId, + idempotencyKey, + stripeSubscriptionItemId: subscription.stripeSubscriptionItemId, + periodStart: subscription.currentPeriodStart, + periodEnd: subscription.currentPeriodEnd, + }, + }); + + const updated = await tx.organizationBillingSubscription.updateMany({ + where: { + id: subscription.id, + usedQuantity: { lt: subscription.includedQuantity }, + }, + data: { usedQuantity: { increment: 1 } }, + }); + + if (updated.count === 0) { + throw new HttpException( + 'Subscription allowance has been exhausted.', + HttpStatus.PAYMENT_REQUIRED, + ); + } + }); + } catch (error) { + if (isUniqueConstraintError(error)) { + return { status: 'consumed', subscriptionId: subscription.id }; + } + throw error; + } + + return { status: 'consumed', subscriptionId: subscription.id }; + } + + async syncSubscriptionItem( + params: SyncSubscriptionItemParams, + ): Promise { + const existing = await db.organizationBillingSubscription.findUnique({ + where: { + organizationId_skuKey: { + organizationId: params.organizationId, + skuKey: params.skuKey, + }, + }, + select: { + stripeSubscriptionItemId: true, + currentPeriodStart: true, + currentPeriodEnd: true, + }, + }); + if ( + existing?.currentPeriodEnd && + params.currentPeriodEnd && + existing.currentPeriodEnd.getTime() > params.currentPeriodEnd.getTime() + ) { + return; + } + + const resetUsage = + !sameTime( + existing?.currentPeriodStart ?? null, + params.currentPeriodStart, + ) || + !sameTime(existing?.currentPeriodEnd ?? null, params.currentPeriodEnd); + + await db.$transaction(async (tx) => { + await tx.organizationBillingSubscription.upsert({ + where: { + organizationId_skuKey: { + organizationId: params.organizationId, + skuKey: params.skuKey, + }, + }, + create: { + organizationId: params.organizationId, + skuKey: params.skuKey, + stripeSubscriptionId: params.stripeSubscriptionId, + stripeSubscriptionItemId: params.stripeSubscriptionItemId, + stripePriceId: params.stripePriceId, + stripeStatus: params.stripeStatus, + currentPeriodStart: params.currentPeriodStart, + currentPeriodEnd: params.currentPeriodEnd, + includedQuantity: params.includedQuantity, + cancelAtPeriodEnd: params.cancelAtPeriodEnd, + canceledAt: params.canceledAt, + }, + update: { + stripeSubscriptionId: params.stripeSubscriptionId, + stripeSubscriptionItemId: params.stripeSubscriptionItemId, + stripePriceId: params.stripePriceId, + stripeStatus: params.stripeStatus, + currentPeriodStart: params.currentPeriodStart, + currentPeriodEnd: params.currentPeriodEnd, + includedQuantity: params.includedQuantity, + ...(resetUsage ? { usedQuantity: 0 } : {}), + cancelAtPeriodEnd: params.cancelAtPeriodEnd, + canceledAt: params.canceledAt, + }, + }); + + if (resetUsage) { + const idempotencyKey = [ + 'grant', + params.organizationId, + params.skuKey, + params.stripeSubscriptionItemId, + params.currentPeriodStart?.toISOString() ?? 'none', + params.currentPeriodEnd?.toISOString() ?? 'none', + params.stripeEventId ?? 'manual', + ].join(':'); + await tx.billingUsageEvent.create({ + data: { + organizationId: params.organizationId, + skuKey: params.skuKey, + eventType: 'grant', + quantity: params.includedQuantity, + idempotencyKey, + stripeEventId: params.stripeEventId, + stripeSubscriptionItemId: params.stripeSubscriptionItemId, + periodStart: params.currentPeriodStart, + periodEnd: params.currentPeriodEnd, + }, + }); + } + }); + + await this.writeAuditEvent({ + organizationId: params.organizationId, + eventType: 'subscription_synced', + skuKey: params.skuKey, + stripeEventId: params.stripeEventId, + metadata: { + stripeSubscriptionId: params.stripeSubscriptionId, + stripeSubscriptionItemId: params.stripeSubscriptionItemId, + stripeStatus: params.stripeStatus, + }, + }); + } + + async recordOneTimeUsage(params: { + organizationId: string; + skuKey: BillingSkuKey; + sourceResourceId: string; + stripeInvoiceId?: string; + }): Promise { + const idempotencyKey = [ + 'one-time', + params.organizationId, + params.skuKey, + params.sourceResourceId, + ].join(':'); + + try { + await db.billingUsageEvent.create({ + data: { + organizationId: params.organizationId, + skuKey: params.skuKey, + eventType: 'one_time', + quantity: 1, + sourceResourceId: params.sourceResourceId, + idempotencyKey, + stripeInvoiceId: params.stripeInvoiceId, + }, + }); + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + } + } + + async refundIncludedUsage(params: { + organizationId: string; + skuKey: BillingSkuKey; + sourceResourceId: string; + reason: string; + }): Promise { + const consumeKey = [ + 'consume', + params.organizationId, + params.skuKey, + params.sourceResourceId, + ].join(':'); + const refundKey = [ + 'refund', + params.organizationId, + params.skuKey, + params.sourceResourceId, + params.reason, + ].join(':'); + + try { + await db.$transaction(async (tx) => { + const consumed = await tx.billingUsageEvent.findUnique({ + where: { idempotencyKey: consumeKey }, + select: { + stripeSubscriptionItemId: true, + periodStart: true, + periodEnd: true, + }, + }); + if (!consumed) return; + + await tx.billingUsageEvent.create({ + data: { + organizationId: params.organizationId, + skuKey: params.skuKey, + eventType: 'refund', + quantity: 1, + sourceResourceId: params.sourceResourceId, + idempotencyKey: refundKey, + stripeSubscriptionItemId: consumed.stripeSubscriptionItemId, + periodStart: consumed.periodStart, + periodEnd: consumed.periodEnd, + }, + }); + + await tx.organizationBillingSubscription.updateMany({ + where: { + organizationId: params.organizationId, + skuKey: params.skuKey, + usedQuantity: { gt: 0 }, + }, + data: { usedQuantity: { decrement: 1 } }, + }); + }); + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + } + } + + async writeAuditEvent(params: WriteBillingAuditEventParams): Promise { + await db.billingAuditEvent.create({ + data: { + organizationId: params.organizationId, + eventType: params.eventType, + skuKey: params.skuKey, + stripeEventId: params.stripeEventId, + metadata: params.metadata, + }, + }); + } +} diff --git a/apps/api/src/billing/billing-entitlements.types.ts b/apps/api/src/billing/billing-entitlements.types.ts new file mode 100644 index 0000000000..6c95d092f0 --- /dev/null +++ b/apps/api/src/billing/billing-entitlements.types.ts @@ -0,0 +1,47 @@ +import { Prisma } from '@db'; +import type { BillingSkuKey } from '@trycompai/billing'; + +export type BillingConsumeResult = + | { status: 'consumed'; subscriptionId: string } + | { status: 'exhausted'; subscriptionId: string } + | { status: 'not_configured' }; + +export type SyncSubscriptionItemParams = { + organizationId: string; + skuKey: BillingSkuKey; + stripeSubscriptionId: string; + stripeSubscriptionItemId: string; + stripePriceId: string; + stripeStatus: string; + currentPeriodStart: Date | null; + currentPeriodEnd: Date | null; + includedQuantity: number; + cancelAtPeriodEnd: boolean; + canceledAt: Date | null; + stripeEventId?: string; +}; + +export type WriteBillingAuditEventParams = { + organizationId: string; + eventType: string; + skuKey?: string; + stripeEventId?: string; + metadata?: Prisma.InputJsonValue; +}; + +export function isAccessStatus(status: string): boolean { + return status === 'active' || status === 'trialing'; +} + +export function sameTime(left: Date | null, right: Date | null): boolean { + if (!left && !right) return true; + if (!left || !right) return false; + return left.getTime() === right.getTime(); +} + +export function isUniqueConstraintError(error: unknown): boolean { + return ( + error instanceof Prisma.PrismaClientKnownRequestError && + error.code === 'P2002' + ); +} diff --git a/apps/api/src/billing/billing-invoices.spec.ts b/apps/api/src/billing/billing-invoices.spec.ts new file mode 100644 index 0000000000..7dee6a84e8 --- /dev/null +++ b/apps/api/src/billing/billing-invoices.spec.ts @@ -0,0 +1,105 @@ +import type Stripe from 'stripe'; +import type { StripeService } from '../stripe/stripe.service'; +import { listBillingInvoices } from './billing-invoices'; + +function mockStripeService(params: { invoices: Stripe.Invoice[] }): StripeService { + const invoicesList = jest.fn().mockResolvedValue({ data: params.invoices }); + return { + isConfigured: () => true, + getClient: () => ({ + invoices: { + list: invoicesList, + }, + }), + } as unknown as StripeService; +} + +function createInvoice(params: { + id: string; + amountPaid: number; + priceId: string; + billingReason?: Stripe.Invoice.BillingReason | null; +}): Stripe.Invoice { + const line = { + id: `il_${params.id}`, + object: 'line_item', + parent: null, + pricing: { + type: 'price_details', + unit_amount_decimal: String(params.amountPaid), + price_details: { + price: params.priceId, + product: 'prod_test', + }, + }, + subscription: null, + } as unknown as Stripe.InvoiceLineItem; + + return { + id: params.id, + number: `${params.id}-0001`, + created: 1777564800, + due_date: null, + amount_paid: params.amountPaid, + amount_due: params.amountPaid, + currency: 'usd', + status: 'paid', + billing_reason: params.billingReason ?? 'manual', + hosted_invoice_url: `https://invoice.stripe.test/${params.id}`, + invoice_pdf: `https://invoice.stripe.test/${params.id}.pdf`, + lines: { + object: 'list', + data: [line], + has_more: false, + url: `/v1/invoices/${params.id}/lines`, + }, + } as unknown as Stripe.Invoice; +} + +describe('listBillingInvoices', () => { + it('labels subscription catalog invoices as subscriptions', async () => { + const invoices = await listBillingInvoices({ + stripeService: mockStripeService({ + invoices: [ + createInvoice({ + id: 'in_subscription', + amountPaid: 39900, + priceId: 'price_1TRya6CkFWhKYvHI1sJ2M2no', + }), + ], + }), + stripeCustomerId: 'cus_1', + }); + + expect(invoices).toEqual([ + expect.objectContaining({ + id: 'in_subscription', + amountPaid: 39900, + type: 'Subscription', + }), + ]); + }); + + it('keeps one-time catalog invoices labelled as one-time', async () => { + const invoices = await listBillingInvoices({ + stripeService: mockStripeService({ + invoices: [ + createInvoice({ + id: 'in_one_time', + amountPaid: 4900, + priceId: 'price_1TRWckCkFWhKYvHIA1GLv1sO', + }), + ], + }), + stripeCustomerId: 'cus_1', + }); + + expect(invoices).toEqual([ + expect.objectContaining({ + id: 'in_one_time', + amountPaid: 4900, + type: 'One Time', + }), + ]); + }); +}); diff --git a/apps/api/src/billing/billing-invoices.ts b/apps/api/src/billing/billing-invoices.ts new file mode 100644 index 0000000000..e855a770b4 --- /dev/null +++ b/apps/api/src/billing/billing-invoices.ts @@ -0,0 +1,95 @@ +import type Stripe from 'stripe'; +import { getBillingSkuByStripePriceId } from '@trycompai/billing'; +import type { StripeService } from '../stripe/stripe.service'; + +export interface BillingInvoice { + id: string; + number: string; + createdAt: string; + dueDate: string | null; + amountPaid: number; + amountDue: number; + currency: string; + status: string; + type: 'Subscription' | 'One Time'; + hostedInvoiceUrl: string | null; + invoicePdfUrl: string | null; +} + +export async function listBillingInvoices(params: { + stripeService: StripeService; + stripeCustomerId: string | null; +}): Promise { + if (!params.stripeCustomerId || !params.stripeService.isConfigured()) { + return []; + } + + const stripe = params.stripeService.getClient(); + const invoices = await stripe.invoices.list({ + customer: params.stripeCustomerId, + limit: 20, + }); + + return invoices.data.map(mapInvoice); +} + +function mapInvoice(invoice: Stripe.Invoice): BillingInvoice { + return { + id: invoice.id, + number: invoice.number ?? invoice.id, + createdAt: new Date(invoice.created * 1000).toISOString(), + dueDate: invoice.due_date + ? new Date(invoice.due_date * 1000).toISOString() + : null, + amountPaid: invoice.amount_paid, + amountDue: invoice.amount_due, + currency: invoice.currency, + status: invoice.status ?? 'unknown', + type: hasSubscription(invoice) ? 'Subscription' : 'One Time', + hostedInvoiceUrl: invoice.hosted_invoice_url ?? null, + invoicePdfUrl: invoice.invoice_pdf ?? null, + }; +} + +function hasSubscription(invoice: Stripe.Invoice): boolean { + if (invoice.billing_reason?.startsWith('subscription')) { + return true; + } + + return invoice.lines.data.some(isSubscriptionLine); +} + +function isSubscriptionLine(line: Stripe.InvoiceLineItem): boolean { + if (line.parent?.subscription_item_details?.subscription) { + return true; + } + + if (line.subscription) { + return true; + } + + const priceId = getLinePriceId(line); + if (!priceId) { + return false; + } + + const sku = getBillingSkuByStripePriceId({ stripePriceId: priceId }); + return sku?.cadence === 'month'; +} + +function getLinePriceId(line: Stripe.InvoiceLineItem): string | null { + const price = line.pricing?.price_details?.price; + if (typeof price === 'string') { + return price; + } + + if (isRecord(price) && typeof price.id === 'string') { + return price.id; + } + + return null; +} + +function isRecord(value: unknown): value is Record { + return typeof value === 'object' && value !== null; +} diff --git a/apps/api/src/billing/billing-preferences.spec.ts b/apps/api/src/billing/billing-preferences.spec.ts new file mode 100644 index 0000000000..c76c7aadf6 --- /dev/null +++ b/apps/api/src/billing/billing-preferences.spec.ts @@ -0,0 +1,120 @@ +import { db } from '@db'; +import type Stripe from 'stripe'; +import type { StripeService } from '../stripe/stripe.service'; +import { updateBillingPreferences } from './billing-preferences'; + +jest.mock('@db', () => ({ + db: { + organizationBilling: { + findUnique: jest.fn(), + }, + }, +})); + +const organizationBillingFindUnique = db.organizationBilling + .findUnique as unknown as jest.Mock; + +function mockStripeService(client: unknown): StripeService { + return { + isConfigured: () => true, + getClient: () => client, + } as unknown as StripeService; +} + +function createCustomer(): Stripe.Customer { + return { + id: 'cus_1', + object: 'customer', + address: { + line1: '1 Test Street', + line2: null, + city: 'London', + state: null, + postal_code: 'SW1A 1AA', + country: 'GB', + }, + business_name: 'Test Company', + created: 1777564800, + default_source: null, + description: null, + email: 'accounts@example.com', + invoice_settings: { + custom_fields: [{ name: 'PO / Reference', value: 'PO-123' }], + default_payment_method: null, + footer: null, + rendering_options: null, + }, + livemode: false, + metadata: {}, + name: 'Test Company', + shipping: null, + } as unknown as Stripe.Customer; +} + +describe('updateBillingPreferences', () => { + beforeEach(() => { + jest.clearAllMocks(); + organizationBillingFindUnique.mockResolvedValue({ stripeCustomerId: 'cus_1' }); + }); + + it('updates the Stripe customer fields used for B2B invoices', async () => { + const customersUpdate = jest.fn().mockResolvedValue(createCustomer()); + const taxIdsList = jest.fn().mockResolvedValue({ data: [] }); + const taxIdsCreate = jest.fn().mockResolvedValue({ + id: 'txi_1', + type: 'gb_vat', + value: 'GB123456789', + verification: { status: 'verified' }, + }); + + const result = await updateBillingPreferences({ + stripeService: mockStripeService({ + customers: { update: customersUpdate }, + taxIds: { list: taxIdsList, create: taxIdsCreate, del: jest.fn() }, + }), + organizationId: 'org_1', + preferences: { + companyName: 'Test Company', + billingEmail: 'accounts@example.com', + purchaseOrder: 'PO-123', + address: { + line1: '1 Test Street', + line2: null, + city: 'London', + state: null, + postalCode: 'SW1A 1AA', + country: 'gb', + }, + taxId: { type: 'gb_vat', value: 'GB123456789' }, + }, + }); + + expect(customersUpdate).toHaveBeenCalledWith( + 'cus_1', + expect.objectContaining({ + email: 'accounts@example.com', + name: 'Test Company', + business_name: 'Test Company', + address: expect.objectContaining({ country: 'GB' }), + invoice_settings: { + custom_fields: [{ name: 'PO / Reference', value: 'PO-123' }], + }, + }), + ); + expect(taxIdsCreate).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'gb_vat', + value: 'GB123456789', + owner: { type: 'customer', customer: 'cus_1' }, + }), + expect.objectContaining({ idempotencyKey: expect.stringContaining('cus_1') }), + ); + expect(result.preferences).toEqual( + expect.objectContaining({ + billingEmail: 'accounts@example.com', + purchaseOrder: 'PO-123', + taxId: expect.objectContaining({ type: 'gb_vat' }), + }), + ); + }); +}); diff --git a/apps/api/src/billing/billing-preferences.ts b/apps/api/src/billing/billing-preferences.ts new file mode 100644 index 0000000000..ac05f49d79 --- /dev/null +++ b/apps/api/src/billing/billing-preferences.ts @@ -0,0 +1,261 @@ +import type Stripe from 'stripe'; +import { BadRequestException } from '@nestjs/common'; +import type { StripeService } from '../stripe/stripe.service'; +import { findOrCreateBillingCustomer } from './billing-customer'; + +export const billingTaxIdTypes = [ + 'gb_vat', + 'eu_vat', + 'us_ein', + 'au_abn', + 'ca_bn', + 'nz_gst', + 'sg_gst', + 'sg_uen', +] as const satisfies readonly Stripe.TaxId.Type[]; + +export type BillingTaxIdType = (typeof billingTaxIdTypes)[number]; + +export interface BillingPreferences { + companyName: string | null; + billingEmail: string | null; + purchaseOrder: string | null; + address: { + line1: string | null; + line2: string | null; + city: string | null; + state: string | null; + postalCode: string | null; + country: string | null; + }; + taxId: { + id: string; + type: string; + value: string; + verificationStatus: string | null; + } | null; +} + +export interface BillingPreferencesInput { + companyName: string; + billingEmail: string; + purchaseOrder: string | null; + address: BillingPreferences['address']; + taxId: { + type: string | null; + value: string | null; + } | null; +} + +const purchaseOrderFieldName = 'PO / Reference'; + +export async function getBillingPreferences(params: { + stripeService: StripeService; + stripeCustomerId: string | null; + fallbackCompanyName: string | null; +}): Promise { + if (!params.stripeCustomerId || !params.stripeService.isConfigured()) { + return createEmptyPreferences({ companyName: params.fallbackCompanyName }); + } + + const stripe = params.stripeService.getClient(); + const [customer, taxIds] = await Promise.all([ + stripe.customers.retrieve(params.stripeCustomerId), + listCustomerTaxIds({ stripe, stripeCustomerId: params.stripeCustomerId }), + ]); + + if (isDeletedCustomer(customer)) { + return createEmptyPreferences({ companyName: params.fallbackCompanyName }); + } + + return mapCustomerPreferences({ customer, taxId: taxIds[0] ?? null }); +} + +export async function updateBillingPreferences(params: { + stripeService: StripeService; + organizationId: string; + preferences: BillingPreferencesInput; +}): Promise<{ stripeCustomerId: string; preferences: BillingPreferences }> { + validatePreferences(params.preferences); + + const stripeCustomerId = await findOrCreateBillingCustomer({ + stripeService: params.stripeService, + organizationId: params.organizationId, + customerEmail: params.preferences.billingEmail, + }); + const stripe = params.stripeService.getClient(); + const existingTaxIds = await listCustomerTaxIds({ stripe, stripeCustomerId }); + const customer = await stripe.customers.update(stripeCustomerId, { + email: params.preferences.billingEmail, + name: params.preferences.companyName, + business_name: params.preferences.companyName, + address: toStripeAddress(params.preferences.address), + invoice_settings: { + custom_fields: params.preferences.purchaseOrder + ? [{ name: purchaseOrderFieldName, value: params.preferences.purchaseOrder }] + : '', + }, + metadata: { organizationId: params.organizationId }, + }); + + const taxId = await syncPrimaryTaxId({ + stripe, + stripeCustomerId, + existingTaxId: existingTaxIds[0] ?? null, + taxId: params.preferences.taxId, + }); + + return { + stripeCustomerId, + preferences: mapCustomerPreferences({ customer, taxId }), + }; +} + +function validatePreferences(preferences: BillingPreferencesInput): void { + if (!preferences.companyName.trim()) { + throw new BadRequestException('Company name is required.'); + } + if (!preferences.billingEmail.trim()) { + throw new BadRequestException('Billing email is required.'); + } + + const type = preferences.taxId?.type?.trim() ?? ''; + const value = preferences.taxId?.value?.trim() ?? ''; + if (type && !isSupportedTaxIdType(type)) { + throw new BadRequestException('Unsupported tax ID type.'); + } + if ((type && !value) || (!type && value)) { + throw new BadRequestException('Tax ID type and value must be set together.'); + } +} + +async function syncPrimaryTaxId(params: { + stripe: Stripe; + stripeCustomerId: string; + existingTaxId: Stripe.TaxId | null; + taxId: BillingPreferencesInput['taxId']; +}): Promise { + const type = params.taxId?.type?.trim() ?? ''; + const value = params.taxId?.value?.trim() ?? ''; + if (!type || !value) { + if (params.existingTaxId) { + await params.stripe.taxIds.del(params.existingTaxId.id); + } + return null; + } + + if (params.existingTaxId?.type === type && params.existingTaxId.value === value) { + return params.existingTaxId; + } + + if (params.existingTaxId) { + await params.stripe.taxIds.del(params.existingTaxId.id); + } + + if (!isSupportedTaxIdType(type)) { + throw new BadRequestException('Unsupported tax ID type.'); + } + + return params.stripe.taxIds.create( + { + type, + value, + owner: { type: 'customer', customer: params.stripeCustomerId }, + }, + { + idempotencyKey: [ + 'billing-tax-id', + params.stripeCustomerId, + type, + value.replace(/[^a-zA-Z0-9]/g, '').toLowerCase(), + ].join(':'), + }, + ); +} + +async function listCustomerTaxIds(params: { + stripe: Stripe; + stripeCustomerId: string; +}): Promise { + const taxIds = await params.stripe.taxIds.list({ + owner: { type: 'customer', customer: params.stripeCustomerId }, + limit: 5, + }); + return taxIds.data; +} + +function mapCustomerPreferences(params: { + customer: Stripe.Customer; + taxId: Stripe.TaxId | null; +}): BillingPreferences { + return { + companyName: params.customer.name ?? params.customer.business_name ?? null, + billingEmail: params.customer.email ?? null, + purchaseOrder: findInvoiceCustomFieldValue(params.customer, purchaseOrderFieldName), + address: { + line1: params.customer.address?.line1 ?? null, + line2: params.customer.address?.line2 ?? null, + city: params.customer.address?.city ?? null, + state: params.customer.address?.state ?? null, + postalCode: params.customer.address?.postal_code ?? null, + country: params.customer.address?.country ?? null, + }, + taxId: params.taxId + ? { + id: params.taxId.id, + type: params.taxId.type, + value: params.taxId.value, + verificationStatus: params.taxId.verification?.status ?? null, + } + : null, + }; +} + +function toStripeAddress(address: BillingPreferences['address']): Stripe.AddressParam { + return { + line1: emptyToUndefined(address.line1), + line2: emptyToUndefined(address.line2), + city: emptyToUndefined(address.city), + state: emptyToUndefined(address.state), + postal_code: emptyToUndefined(address.postalCode), + country: emptyToUndefined(address.country)?.toUpperCase(), + }; +} + +function createEmptyPreferences(params: { companyName: string | null }): BillingPreferences { + return { + companyName: params.companyName, + billingEmail: null, + purchaseOrder: null, + address: { + line1: null, + line2: null, + city: null, + state: null, + postalCode: null, + country: null, + }, + taxId: null, + }; +} + +function findInvoiceCustomFieldValue(customer: Stripe.Customer, name: string): string | null { + return ( + customer.invoice_settings.custom_fields?.find((field) => field.name === name)?.value ?? null + ); +} + +function emptyToUndefined(value: string | null): string | undefined { + const trimmed = value?.trim(); + return trimmed ? trimmed : undefined; +} + +function isDeletedCustomer( + customer: Stripe.Customer | Stripe.DeletedCustomer, +): customer is Stripe.DeletedCustomer { + return 'deleted' in customer && customer.deleted === true; +} + +function isSupportedTaxIdType(value: string): value is BillingTaxIdType { + return billingTaxIdTypes.some((type) => type === value); +} diff --git a/apps/api/src/billing/billing-redirect-urls.ts b/apps/api/src/billing/billing-redirect-urls.ts new file mode 100644 index 0000000000..eb94d43ded --- /dev/null +++ b/apps/api/src/billing/billing-redirect-urls.ts @@ -0,0 +1,25 @@ +import { BadRequestException } from '@nestjs/common'; + +const allowedHosts = new Set([ + 'localhost', + '127.0.0.1', + 'app.trycomp.ai', + 'app.staging.trycomp.ai', +]); + +export function validateBillingRedirectUrl(value: string): void { + let url: URL; + try { + url = new URL(value); + } catch { + throw new BadRequestException('Billing redirect URL is invalid.'); + } + + if (!['http:', 'https:'].includes(url.protocol)) { + throw new BadRequestException('Billing redirect URL is invalid.'); + } + + if (!allowedHosts.has(url.hostname)) { + throw new BadRequestException('Billing redirect URL is not allowed.'); + } +} diff --git a/apps/api/src/billing/billing-usage.spec.ts b/apps/api/src/billing/billing-usage.spec.ts new file mode 100644 index 0000000000..4340e73585 --- /dev/null +++ b/apps/api/src/billing/billing-usage.spec.ts @@ -0,0 +1,87 @@ +import { db } from '@db'; +import { listBillingUsageRows } from './billing-usage'; + +jest.mock('@db', () => ({ + db: { + backgroundCheckRequest: { findMany: jest.fn() }, + securityPenetrationTestRun: { findMany: jest.fn() }, + billingUsageEvent: { findMany: jest.fn() }, + }, +})); + +const mockedDb = db as jest.Mocked; +const backgroundCheckFindMany = mockedDb.backgroundCheckRequest + .findMany as unknown as jest.Mock; +const pentestRunFindMany = mockedDb.securityPenetrationTestRun + .findMany as unknown as jest.Mock; +const billingUsageEventFindMany = mockedDb.billingUsageEvent + .findMany as unknown as jest.Mock; + +describe('listBillingUsageRows', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('combines run history with subscription allowance details', async () => { + backgroundCheckFindMany.mockResolvedValue([ + { + id: 'bcr_1', + memberId: 'mem_1', + employeeName: 'Ada Lovelace', + employeeEmail: 'ada@example.com', + status: 'completed', + stripePaymentStatus: 'succeeded', + createdAt: new Date('2026-04-30T10:00:00.000Z'), + updatedAt: new Date('2026-04-30T10:05:00.000Z'), + }, + ]); + pentestRunFindMany.mockResolvedValue([ + { + id: 'ptr_1', + providerRunId: 'run_1', + createdAt: new Date('2026-04-30T11:00:00.000Z'), + updatedAt: new Date('2026-04-30T11:05:00.000Z'), + }, + ]); + billingUsageEventFindMany.mockResolvedValue([ + { + skuKey: 'background_checks_monthly_25', + eventType: 'consume', + sourceResourceId: 'mem_1', + stripeInvoiceId: null, + }, + ]); + + const rows = await listBillingUsageRows({ + organizationId: 'org_1', + subscriptions: [ + { + skuKey: 'background_checks_monthly_25', + includedQuantity: 25, + usedQuantity: 2, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + { + skuKey: 'pentest_monthly_5', + includedQuantity: 5, + usedQuantity: 1, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + ], + }); + + expect(rows).toEqual([ + expect.objectContaining({ + service: 'Penetration Test', + details: 'run_1', + subscriptionRemaining: 4, + }), + expect.objectContaining({ + service: 'Background Check', + details: 'Ada Lovelace (ada@example.com)', + billingType: 'Subscription allowance', + subscriptionRemaining: 23, + }), + ]); + }); +}); diff --git a/apps/api/src/billing/billing-usage.ts b/apps/api/src/billing/billing-usage.ts new file mode 100644 index 0000000000..f5108fcbf6 --- /dev/null +++ b/apps/api/src/billing/billing-usage.ts @@ -0,0 +1,150 @@ +import { db } from '@db'; +import type { BillingSkuKey } from '@trycompai/billing'; +import type { BillingUsageRow } from './billing.types'; + +type SubscriptionSummary = { + skuKey: string; + includedQuantity: number; + usedQuantity: number; + currentPeriodEnd: Date | null; +}; + +const backgroundCheckSku = 'background_checks_monthly_25'; +const pentestSku = 'pentest_monthly_5'; + +export async function listBillingUsageRows(params: { + organizationId: string; + subscriptions: SubscriptionSummary[]; +}): Promise { + const [backgroundChecks, pentestRuns, usageEvents] = await Promise.all([ + db.backgroundCheckRequest.findMany({ + where: { organizationId: params.organizationId }, + orderBy: { createdAt: 'desc' }, + take: 50, + select: { + id: true, + memberId: true, + employeeName: true, + employeeEmail: true, + status: true, + stripePaymentStatus: true, + createdAt: true, + updatedAt: true, + }, + }), + db.securityPenetrationTestRun.findMany({ + where: { organizationId: params.organizationId }, + orderBy: { createdAt: 'desc' }, + take: 50, + select: { + id: true, + providerRunId: true, + createdAt: true, + updatedAt: true, + }, + }), + db.billingUsageEvent.findMany({ + where: { + organizationId: params.organizationId, + eventType: { in: ['consume', 'one_time'] }, + sourceResourceId: { not: null }, + }, + orderBy: { createdAt: 'desc' }, + take: 100, + select: { + skuKey: true, + eventType: true, + sourceResourceId: true, + stripeInvoiceId: true, + }, + }), + ]); + + const usageBySource = new Map( + usageEvents + .filter((event) => event.sourceResourceId) + .map((event) => [event.sourceResourceId as string, event]), + ); + + const rows = [ + ...backgroundChecks.map((request) => { + const usage = usageBySource.get(request.memberId); + return toBillingUsageRow({ + id: request.id, + service: 'Background Check', + skuKey: backgroundCheckSku, + details: `${request.employeeName} (${request.employeeEmail})`, + status: formatStatus(request.status), + billingType: formatBillingType(usage?.eventType, usage?.stripeInvoiceId), + createdAt: request.createdAt, + updatedAt: request.updatedAt, + subscriptions: params.subscriptions, + }); + }), + ...pentestRuns.map((run) => + toBillingUsageRow({ + id: run.id, + service: 'Penetration Test', + skuKey: pentestSku, + details: run.providerRunId, + status: 'Created', + billingType: formatPentestBillingType(params.subscriptions), + createdAt: run.createdAt, + updatedAt: run.updatedAt, + subscriptions: params.subscriptions, + }), + ), + ]; + + return rows.sort((first, second) => second.createdAt.localeCompare(first.createdAt)); +} + +function toBillingUsageRow(params: { + id: string; + service: BillingUsageRow['service']; + skuKey: BillingSkuKey; + details: string; + status: string; + billingType: string; + createdAt: Date; + updatedAt: Date; + subscriptions: SubscriptionSummary[]; +}): BillingUsageRow { + const subscription = params.subscriptions.find((item) => item.skuKey === params.skuKey); + const remaining = subscription + ? Math.max(subscription.includedQuantity - subscription.usedQuantity, 0) + : null; + + return { + id: params.id, + service: params.service, + skuKey: params.skuKey, + details: params.details, + status: params.status, + billingType: params.billingType, + createdAt: params.createdAt.toISOString(), + updatedAt: params.updatedAt.toISOString(), + subscriptionRemaining: remaining, + subscriptionIncluded: subscription?.includedQuantity ?? null, + subscriptionPeriodEnd: subscription?.currentPeriodEnd?.toISOString() ?? null, + }; +} + +function formatBillingType(eventType?: string, stripeInvoiceId?: string | null): string { + if (eventType === 'consume') return 'Subscription allowance'; + if (eventType === 'one_time') return stripeInvoiceId ? 'One-time invoice' : 'One-time'; + return 'Legacy / manual'; +} + +function formatPentestBillingType(subscriptions: SubscriptionSummary[]): string { + return subscriptions.some((item) => item.skuKey === pentestSku) + ? 'Subscription allowance' + : 'Trial credit'; +} + +function formatStatus(status: string): string { + return status + .split('_') + .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) + .join(' '); +} diff --git a/apps/api/src/billing/billing-webhook.service.ts b/apps/api/src/billing/billing-webhook.service.ts new file mode 100644 index 0000000000..d409a2de26 --- /dev/null +++ b/apps/api/src/billing/billing-webhook.service.ts @@ -0,0 +1,270 @@ +import { BadRequestException, Injectable } from '@nestjs/common'; +import { Prisma, db } from '@db'; +import { getBillingSkuByStripePriceId } from '@trycompai/billing'; +import Stripe from 'stripe'; +import { StripeService } from '../stripe/stripe.service'; +import { BillingEntitlementsService } from './billing-entitlements.service'; +import { + claimStripeWebhookEvent, + markStripeWebhookFailed, + markStripeWebhookProcessed, +} from './stripe-webhook-records'; + +@Injectable() +export class BillingWebhookService { + constructor( + private readonly stripeService: StripeService, + private readonly entitlements: BillingEntitlementsService, + ) {} + + async handleWebhook(params: { + rawBody: Buffer | undefined; + signature: string | undefined; + }): Promise<{ ok: true; duplicate?: true }> { + if (!params.rawBody) throw new BadRequestException('Raw body unavailable.'); + if (!params.signature) { + throw new BadRequestException('Stripe signature header is missing.'); + } + + const secret = + process.env.STRIPE_WEBHOOK_SECRET ?? + process.env.STRIPE_PENTEST_WEBHOOK_SECRET; + if (!secret) + throw new BadRequestException('Stripe webhook secret is not configured.'); + + const event = this.stripeService + .getClient() + .webhooks.constructEvent(params.rawBody, params.signature, secret); + const claim = await claimStripeWebhookEvent({ + stripeEventId: event.id, + eventType: event.type, + payload: event.data.object as unknown as Prisma.InputJsonValue, + }); + if (claim.status === 'duplicate') return { ok: true, duplicate: true }; + + try { + await this.processEvent(event); + await markStripeWebhookProcessed(event.id); + return { ok: true }; + } catch (error) { + await markStripeWebhookFailed({ stripeEventId: event.id, error }); + throw error; + } + } + + private async processEvent(event: Stripe.Event): Promise { + switch (event.type) { + case 'checkout.session.completed': + await this.handleCheckoutSessionCompleted(event); + return; + case 'customer.subscription.updated': + case 'customer.subscription.deleted': + await this.syncSubscriptionFromEvent(event); + return; + case 'invoice.paid': + await this.handleInvoicePaid(event); + return; + case 'invoice.payment_failed': + case 'invoice.payment_action_required': + await this.handleInvoiceRecoveryEvent(event); + return; + default: + return; + } + } + + private async handleCheckoutSessionCompleted( + event: Stripe.Event, + ): Promise { + const session = event.data.object as Stripe.Checkout.Session; + if (session.mode !== 'subscription') return; + + const organizationId = session.metadata?.organizationId; + const customerId = extractStripeId(session.customer); + if (!organizationId || !customerId) return; + + await db.organizationBilling.upsert({ + where: { organizationId }, + create: { organizationId, stripeCustomerId: customerId }, + update: { stripeCustomerId: customerId }, + }); + + const subscriptionId = extractStripeId(session.subscription); + if (!subscriptionId) return; + const subscription = await this.retrieveSubscription(subscriptionId); + await this.syncCustomerPaymentMethod({ + organizationId, + customerId, + subscription, + stripeEventId: event.id, + }); + await this.syncSubscriptionItems({ + subscription, + organizationId, + stripeEventId: event.id, + }); + } + + private async syncSubscriptionFromEvent(event: Stripe.Event): Promise { + const subscription = event.data.object as Stripe.Subscription; + const organizationId = + await this.resolveSubscriptionOrganization(subscription); + if (!organizationId) return; + await this.syncSubscriptionItems({ + subscription, + organizationId, + stripeEventId: event.id, + }); + } + + private async handleInvoicePaid(event: Stripe.Event): Promise { + const invoice = event.data.object as Stripe.Invoice; + const subscriptionId = extractStripeId(readField(invoice, 'subscription')); + if (!subscriptionId) return; + const subscription = await this.retrieveSubscription(subscriptionId); + const organizationId = + await this.resolveSubscriptionOrganization(subscription); + if (!organizationId) return; + await this.syncSubscriptionItems({ + subscription, + organizationId, + stripeEventId: event.id, + }); + } + + private async handleInvoiceRecoveryEvent(event: Stripe.Event): Promise { + const invoice = event.data.object as Stripe.Invoice; + const customerId = extractStripeId(invoice.customer); + if (!customerId) return; + const billing = await db.organizationBilling.findFirst({ + where: { stripeCustomerId: customerId }, + select: { organizationId: true }, + }); + if (!billing) return; + await this.entitlements.writeAuditEvent({ + organizationId: billing.organizationId, + eventType: event.type, + stripeEventId: event.id, + metadata: { invoiceId: invoice.id }, + }); + } + + private async retrieveSubscription( + subscriptionId: string, + ): Promise { + return this.stripeService + .getClient() + .subscriptions.retrieve(subscriptionId, { + expand: ['items.data.price'], + }); + } + + private async syncCustomerPaymentMethod(params: { + organizationId: string; + customerId: string; + subscription: Stripe.Subscription; + stripeEventId: string; + }): Promise { + const subscriptionPaymentMethodId = extractStripeId( + params.subscription.default_payment_method, + ); + const customer = await this.stripeService + .getClient() + .customers.retrieve(params.customerId); + if (customer.deleted) return; + + const customerPaymentMethodId = extractStripeId( + customer.invoice_settings.default_payment_method, + ); + const paymentMethodId = + subscriptionPaymentMethodId ?? customerPaymentMethodId; + if (!paymentMethodId) return; + + await this.stripeService.getClient().customers.update(params.customerId, { + invoice_settings: { default_payment_method: paymentMethodId }, + }); + await db.organizationBilling.update({ + where: { organizationId: params.organizationId }, + data: { + stripePaymentMethodId: paymentMethodId, + paymentMethodUpdatedAt: new Date(), + }, + }); + await this.entitlements.writeAuditEvent({ + organizationId: params.organizationId, + eventType: 'payment_method_updated', + stripeEventId: params.stripeEventId, + metadata: { source: 'checkout.session.completed' }, + }); + } + + private async syncSubscriptionItems(params: { + subscription: Stripe.Subscription; + organizationId: string; + stripeEventId: string; + }): Promise { + for (const item of params.subscription.items.data) { + const priceId = item.price.id; + const sku = getBillingSkuByStripePriceId({ stripePriceId: priceId }); + if (!sku?.includedUsage) continue; + await this.entitlements.syncSubscriptionItem({ + organizationId: params.organizationId, + skuKey: sku.key, + stripeSubscriptionId: params.subscription.id, + stripeSubscriptionItemId: item.id, + stripePriceId: priceId, + stripeStatus: params.subscription.status, + currentPeriodStart: dateFromSeconds( + readNumber(item, 'current_period_start'), + ), + currentPeriodEnd: dateFromSeconds( + readNumber(item, 'current_period_end'), + ), + includedQuantity: sku.includedUsage.quantity, + cancelAtPeriodEnd: params.subscription.cancel_at_period_end, + canceledAt: dateFromSeconds(params.subscription.canceled_at), + stripeEventId: params.stripeEventId, + }); + } + } + + private async resolveSubscriptionOrganization( + subscription: Stripe.Subscription, + ): Promise { + if (subscription.metadata?.organizationId) { + return subscription.metadata.organizationId; + } + const customerId = extractStripeId(subscription.customer); + if (!customerId) return null; + const billing = await db.organizationBilling.findFirst({ + where: { stripeCustomerId: customerId }, + select: { organizationId: true }, + }); + return billing?.organizationId ?? null; + } +} + +function extractStripeId(value: unknown): string | null { + if (typeof value === 'string') return value; + if (!isRecord(value)) return null; + return typeof value.id === 'string' ? value.id : null; +} + +function dateFromSeconds(value: number | null): Date | null { + return value === null ? null : new Date(value * 1000); +} + +function readNumber(value: unknown, key: string): number | null { + if (!isRecord(value)) return null; + const raw = value[key]; + return typeof raw === 'number' ? raw : null; +} + +function readField(value: unknown, key: string): unknown { + if (!isRecord(value)) return null; + return value[key]; +} + +function isRecord(value: unknown): value is Record { + return typeof value === 'object' && value !== null; +} diff --git a/apps/api/src/billing/billing.controller.ts b/apps/api/src/billing/billing.controller.ts new file mode 100644 index 0000000000..a2e4e9ddb6 --- /dev/null +++ b/apps/api/src/billing/billing.controller.ts @@ -0,0 +1,153 @@ +import { + Body, + Controller, + Get, + Headers, + HttpCode, + Post, + Put, + Req, + UseGuards, + type RawBodyRequest, +} from '@nestjs/common'; +import { ApiOperation, ApiSecurity, ApiTags } from '@nestjs/swagger'; +import type { Request } from 'express'; +import { AuthContext, OrganizationId } from '../auth/auth-context.decorator'; +import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; +import { PermissionGuard } from '../auth/permission.guard'; +import { Public } from '../auth/public.decorator'; +import { RequirePermission } from '../auth/require-permission.decorator'; +import type { AuthContext as AuthContextType } from '../auth/types'; +import { BillingService } from './billing.service'; +import { BillingWebhookService } from './billing-webhook.service'; +import { + BillingPortalDto, + BillingPreferencesDto, + BillingSetupSessionDto, + BillingSetupSuccessDto, + BillingSubscriptionCheckoutDto, +} from './dto/billing.dto'; + +@ApiTags('Billing') +@Controller({ path: 'billing', version: '1' }) +@UseGuards(HybridAuthGuard, PermissionGuard) +@ApiSecurity('apikey') +export class BillingController { + constructor( + private readonly billingService: BillingService, + private readonly webhookService: BillingWebhookService, + ) {} + + @Get('status') + @RequirePermission('organization', 'read') + @ApiOperation({ summary: 'Get organization billing status' }) + async getStatus(@OrganizationId() organizationId: string) { + return this.billingService.getStatus(organizationId); + } + + @Put('preferences') + @RequirePermission('organization', 'update') + @ApiOperation({ summary: 'Update organization billing preferences' }) + async updatePreferences( + @OrganizationId() organizationId: string, + @Body() body: BillingPreferencesDto, + ) { + return this.billingService.updatePreferences({ + organizationId, + preferences: { + companyName: body.companyName, + billingEmail: body.billingEmail, + purchaseOrder: body.purchaseOrder ?? null, + address: { + line1: body.addressLine1 ?? null, + line2: body.addressLine2 ?? null, + city: body.addressCity ?? null, + state: body.addressState ?? null, + postalCode: body.addressPostalCode ?? null, + country: body.addressCountry ?? null, + }, + taxId: { + type: body.taxIdType ?? null, + value: body.taxIdValue ?? null, + }, + }, + }); + } + + @Post('setup-session') + @RequirePermission('organization', 'update') + @HttpCode(200) + @ApiOperation({ summary: 'Create a Stripe setup session' }) + async setupSession( + @OrganizationId() organizationId: string, + @AuthContext() authContext: AuthContextType, + @Body() body: BillingSetupSessionDto, + ) { + return this.billingService.createSetupSession({ + organizationId, + successUrl: body.successUrl, + cancelUrl: body.cancelUrl, + customerEmail: authContext.userEmail, + }); + } + + @Post('setup-success') + @RequirePermission('organization', 'update') + @HttpCode(200) + @ApiOperation({ summary: 'Persist a successful Stripe setup session' }) + async setupSuccess( + @OrganizationId() organizationId: string, + @Body() body: BillingSetupSuccessDto, + ) { + return this.billingService.handleSetupSuccess({ + organizationId, + sessionId: body.sessionId, + }); + } + + @Post('portal') + @RequirePermission('organization', 'update') + @HttpCode(200) + @ApiOperation({ summary: 'Create a Stripe billing portal session' }) + async portal( + @OrganizationId() organizationId: string, + @Body() body: BillingPortalDto, + ) { + return this.billingService.createBillingPortalSession({ + organizationId, + returnUrl: body.returnUrl, + }); + } + + @Post('subscription-session') + @RequirePermission('organization', 'update') + @HttpCode(200) + @ApiOperation({ summary: 'Create a Stripe subscription Checkout session' }) + async subscriptionSession( + @OrganizationId() organizationId: string, + @AuthContext() authContext: AuthContextType, + @Body() body: BillingSubscriptionCheckoutDto, + ) { + return this.billingService.createSubscriptionCheckoutSession({ + organizationId, + skuKey: body.skuKey, + successUrl: body.successUrl, + cancelUrl: body.cancelUrl, + customerEmail: authContext.userEmail, + }); + } + + @Post('webhook') + @Public() + @HttpCode(200) + @ApiOperation({ summary: 'Receive Stripe billing webhook events' }) + async webhook( + @Headers('stripe-signature') signature: string | undefined, + @Req() req: RawBodyRequest, + ) { + return this.webhookService.handleWebhook({ + rawBody: req.rawBody, + signature, + }); + } +} diff --git a/apps/api/src/billing/billing.module.ts b/apps/api/src/billing/billing.module.ts new file mode 100644 index 0000000000..c2c78f3e4a --- /dev/null +++ b/apps/api/src/billing/billing.module.ts @@ -0,0 +1,18 @@ +import { Module } from '@nestjs/common'; +import { AuthModule } from '../auth/auth.module'; +import { BillingController } from './billing.controller'; +import { BillingEntitlementsService } from './billing-entitlements.service'; +import { BillingService } from './billing.service'; +import { BillingWebhookService } from './billing-webhook.service'; + +@Module({ + imports: [AuthModule], + controllers: [BillingController], + providers: [ + BillingService, + BillingEntitlementsService, + BillingWebhookService, + ], + exports: [BillingService, BillingEntitlementsService], +}) +export class BillingModule {} diff --git a/apps/api/src/billing/billing.service.spec.ts b/apps/api/src/billing/billing.service.spec.ts new file mode 100644 index 0000000000..151db5f412 --- /dev/null +++ b/apps/api/src/billing/billing.service.spec.ts @@ -0,0 +1,109 @@ +import { BadRequestException } from '@nestjs/common'; +import { db } from '@db'; +import { BillingService } from './billing.service'; +import type { StripeService } from '../stripe/stripe.service'; + +jest.mock('@db', () => ({ + db: { + organization: { + findUniqueOrThrow: jest.fn(), + }, + organizationBilling: { + create: jest.fn(), + findUnique: jest.fn(), + }, + }, +})); + +const mockedDb = db as jest.Mocked; +const organizationFindUniqueOrThrow = mockedDb.organization + .findUniqueOrThrow as unknown as jest.Mock; +const organizationBillingFindUnique = mockedDb.organizationBilling + .findUnique as unknown as jest.Mock; +const organizationBillingCreate = mockedDb.organizationBilling + .create as unknown as jest.Mock; + +function mockStripeService(client: unknown): StripeService { + return { + getClient: () => client, + } as unknown as StripeService; +} + +describe('BillingService', () => { + beforeEach(() => { + jest.clearAllMocks(); + organizationFindUniqueOrThrow.mockResolvedValue({ + id: 'org_1', + name: 'Test Company', + }); + organizationBillingFindUnique.mockResolvedValue(null); + organizationBillingCreate.mockResolvedValue({ + id: 'obil_1', + organizationId: 'org_1', + stripeCustomerId: 'cus_1', + stripePaymentMethodId: null, + paymentMethodUpdatedAt: null, + createdAt: new Date('2026-04-30T00:00:00.000Z'), + updatedAt: new Date('2026-04-30T00:00:00.000Z'), + }); + }); + + it('creates a Stripe subscription checkout session from the billing catalog', async () => { + const customersCreate = jest.fn().mockResolvedValue({ id: 'cus_1' }); + const sessionsCreate = jest.fn().mockResolvedValue({ + url: 'https://checkout.stripe.test/session', + }); + const service = new BillingService( + mockStripeService({ + customers: { create: customersCreate }, + checkout: { sessions: { create: sessionsCreate } }, + }), + ); + + await expect( + service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_5', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + customerEmail: 'admin@example.com', + }), + ).resolves.toEqual({ url: 'https://checkout.stripe.test/session' }); + + expect(customersCreate).toHaveBeenCalledWith( + expect.objectContaining({ + email: 'admin@example.com', + metadata: { organizationId: 'org_1' }, + }), + { idempotencyKey: 'organization-billing-customer:org_1' }, + ); + expect(sessionsCreate).toHaveBeenCalledWith( + expect.objectContaining({ + mode: 'subscription', + customer: 'cus_1', + line_items: [{ price: 'price_1TRya6CkFWhKYvHI1sJ2M2no', quantity: 1 }], + metadata: expect.objectContaining({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_5', + }), + }), + ); + }); + + it('does not create subscription checkout for one-time SKUs', async () => { + const service = new BillingService( + mockStripeService({ + checkout: { sessions: { create: jest.fn() } }, + }), + ); + + await expect( + service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'background_check_one_time', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }), + ).rejects.toBeInstanceOf(BadRequestException); + }); +}); diff --git a/apps/api/src/billing/billing.service.ts b/apps/api/src/billing/billing.service.ts new file mode 100644 index 0000000000..a00663f237 --- /dev/null +++ b/apps/api/src/billing/billing.service.ts @@ -0,0 +1,299 @@ +import { + BadRequestException, + Injectable, + NotFoundException, +} from '@nestjs/common'; +import { db } from '@db'; +import { + type BillingSku, + type BillingSkuKey, + getBillingSku, + isSubscriptionBillingSkuKey, +} from '@trycompai/billing'; +import { StripeService } from '../stripe/stripe.service'; +import { findOrCreateBillingCustomer } from './billing-customer'; +import { listBillingInvoices } from './billing-invoices'; +import { + type BillingPreferencesInput, + getBillingPreferences, + updateBillingPreferences, +} from './billing-preferences'; +import { validateBillingRedirectUrl } from './billing-redirect-urls'; +import type { BillingStatus } from './billing.types'; +import { listBillingUsageRows } from './billing-usage'; + +@Injectable() +export class BillingService { + constructor(private readonly stripeService: StripeService) {} + + async getStatus(organizationId: string): Promise { + const [organization, billing, subscriptions, backgroundChecks, penetrationTests] = + await Promise.all([ + db.organization.findUniqueOrThrow({ + where: { id: organizationId }, + select: { name: true }, + }), + db.organizationBilling.findUnique({ + where: { organizationId }, + select: { + stripeCustomerId: true, + stripePaymentMethodId: true, + paymentMethodUpdatedAt: true, + }, + }), + db.organizationBillingSubscription.findMany({ + where: { organizationId }, + orderBy: { skuKey: 'asc' }, + }), + db.backgroundCheckRequest.count({ where: { organizationId } }), + db.securityPenetrationTestRun.count({ where: { organizationId } }), + ]); + const [invoices, preferences, usageRows] = await Promise.all([ + listBillingInvoices({ + stripeService: this.stripeService, + stripeCustomerId: billing?.stripeCustomerId ?? null, + }), + getBillingPreferences({ + stripeService: this.stripeService, + stripeCustomerId: billing?.stripeCustomerId ?? null, + fallbackCompanyName: organization.name, + }), + listBillingUsageRows({ organizationId, subscriptions }), + ]); + + return { + hasBilling: !!billing, + hasPaymentMethod: !!billing?.stripePaymentMethodId, + setupAt: billing?.paymentMethodUpdatedAt ?? null, + usage: { backgroundChecks, penetrationTests }, + preferences, + usageRows, + subscriptions: subscriptions.map((subscription) => ({ + skuKey: subscription.skuKey, + status: subscription.stripeStatus, + includedQuantity: subscription.includedQuantity, + usedQuantity: subscription.usedQuantity, + currentPeriodStart: + subscription.currentPeriodStart?.toISOString() ?? null, + currentPeriodEnd: subscription.currentPeriodEnd?.toISOString() ?? null, + cancelAtPeriodEnd: subscription.cancelAtPeriodEnd, + })), + invoices, + }; + } + + async updatePreferences(params: { + organizationId: string; + preferences: BillingPreferencesInput; + }): Promise { + const result = await updateBillingPreferences({ + stripeService: this.stripeService, + organizationId: params.organizationId, + preferences: params.preferences, + }); + + await db.billingAuditEvent.create({ + data: { + organizationId: params.organizationId, + eventType: 'billing_preferences_updated', + metadata: { + stripeCustomerId: result.stripeCustomerId, + billingEmail: result.preferences.billingEmail, + }, + }, + }); + + return this.getStatus(params.organizationId); + } + + async createSetupSession(params: { + organizationId: string; + successUrl: string; + cancelUrl: string; + customerEmail?: string; + }): Promise<{ url: string }> { + validateBillingRedirectUrl(params.successUrl); + validateBillingRedirectUrl(params.cancelUrl); + + const stripe = this.stripeService.getClient(); + const customerId = await findOrCreateBillingCustomer({ + stripeService: this.stripeService, + organizationId: params.organizationId, + customerEmail: params.customerEmail, + }); + + const session = await stripe.checkout.sessions.create({ + mode: 'setup', + customer: customerId, + currency: 'usd', + success_url: params.successUrl, + cancel_url: params.cancelUrl, + metadata: { + organizationId: params.organizationId, + source: 'comp-billing-setup', + }, + }); + + if (!session.url) { + throw new BadRequestException( + 'Failed to create Stripe Checkout session.', + ); + } + + return { url: session.url }; + } + + async handleSetupSuccess(params: { + organizationId: string; + sessionId: string; + }): Promise<{ success: true }> { + const stripe = this.stripeService.getClient(); + const session = await stripe.checkout.sessions.retrieve(params.sessionId, { + expand: ['setup_intent'], + }); + + if (session.status !== 'complete') { + throw new BadRequestException('Checkout session is not complete.'); + } + + if (session.metadata?.organizationId !== params.organizationId) { + throw new BadRequestException( + 'Checkout session does not belong to this organization.', + ); + } + + const stripeCustomerId = extractStripeId(session.customer); + if (!stripeCustomerId) { + throw new BadRequestException('Checkout session is missing a customer.'); + } + const billing = await db.organizationBilling.findUnique({ + where: { organizationId: params.organizationId }, + select: { stripeCustomerId: true }, + }); + if (billing && billing.stripeCustomerId !== stripeCustomerId) { + throw new BadRequestException( + 'Checkout session customer does not match this organization.', + ); + } + + const setupIntent = session.setup_intent; + if (!setupIntent || typeof setupIntent === 'string') { + throw new BadRequestException( + 'Checkout session is missing a setup intent.', + ); + } + + const paymentMethodId = extractStripeId(setupIntent.payment_method); + if (!paymentMethodId) { + throw new BadRequestException( + 'Setup intent is missing a payment method.', + ); + } + + await stripe.customers.update(stripeCustomerId, { + invoice_settings: { default_payment_method: paymentMethodId }, + }); + + await db.organizationBilling.upsert({ + where: { organizationId: params.organizationId }, + create: { + organizationId: params.organizationId, + stripeCustomerId, + stripePaymentMethodId: paymentMethodId, + paymentMethodUpdatedAt: new Date(), + }, + update: { + stripeCustomerId, + stripePaymentMethodId: paymentMethodId, + paymentMethodUpdatedAt: new Date(), + }, + }); + + return { success: true }; + } + + async createBillingPortalSession(params: { + organizationId: string; + returnUrl: string; + }): Promise<{ url: string }> { + validateBillingRedirectUrl(params.returnUrl); + + const billing = await db.organizationBilling.findUnique({ + where: { organizationId: params.organizationId }, + select: { stripeCustomerId: true }, + }); + if (!billing) { + throw new NotFoundException( + 'No billing record found for this organization.', + ); + } + + const session = await this.stripeService + .getClient() + .billingPortal.sessions.create({ + customer: billing.stripeCustomerId, + return_url: params.returnUrl, + }); + return { url: session.url }; + } + + async createSubscriptionCheckoutSession(params: { + organizationId: string; + skuKey: string; + successUrl: string; + cancelUrl: string; + customerEmail?: string; + }): Promise<{ url: string }> { + validateBillingRedirectUrl(params.successUrl); + validateBillingRedirectUrl(params.cancelUrl); + if (!isSubscriptionBillingSkuKey(params.skuKey)) { + throw new BadRequestException('Unknown subscription SKU.'); + } + + const sku = getBillingSku({ skuKey: params.skuKey }); + const customerId = await findOrCreateBillingCustomer({ + stripeService: this.stripeService, + organizationId: params.organizationId, + customerEmail: params.customerEmail, + }); + const stripe = this.stripeService.getClient(); + const session = await stripe.checkout.sessions.create({ + mode: 'subscription', + customer: customerId, + line_items: [{ price: sku.stripePriceId, quantity: 1 }], + success_url: params.successUrl, + cancel_url: params.cancelUrl, + metadata: { + organizationId: params.organizationId, + skuKey: sku.key, + source: 'comp-billing-subscription', + }, + subscription_data: { + metadata: { + organizationId: params.organizationId, + skuKey: sku.key, + source: 'comp-billing-subscription', + }, + }, + }); + + if (!session.url) { + throw new BadRequestException( + 'Failed to create Stripe Checkout session.', + ); + } + return { url: session.url }; + } + + getOneTimeBackgroundCheckSku(): BillingSku { + return getBillingSku({ skuKey: 'background_check_one_time' }); + } +} + +function extractStripeId( + value: string | { id?: string } | null, +): string | null { + if (!value) return null; + if (typeof value === 'string') return value; + return value.id ?? null; +} diff --git a/apps/api/src/billing/billing.types.ts b/apps/api/src/billing/billing.types.ts new file mode 100644 index 0000000000..8e3b2a8575 --- /dev/null +++ b/apps/api/src/billing/billing.types.ts @@ -0,0 +1,38 @@ +import type { BillingInvoice } from './billing-invoices'; +import type { BillingPreferences } from './billing-preferences'; + +export interface BillingStatus { + hasBilling: boolean; + hasPaymentMethod: boolean; + setupAt: Date | null; + usage: { + backgroundChecks: number; + penetrationTests: number; + }; + preferences: BillingPreferences; + usageRows: BillingUsageRow[]; + subscriptions: Array<{ + skuKey: string; + status: string; + includedQuantity: number; + usedQuantity: number; + currentPeriodStart: string | null; + currentPeriodEnd: string | null; + cancelAtPeriodEnd: boolean; + }>; + invoices: BillingInvoice[]; +} + +export interface BillingUsageRow { + id: string; + service: 'Penetration Test' | 'Background Check'; + skuKey: string; + details: string; + status: string; + billingType: string; + createdAt: string; + updatedAt: string; + subscriptionRemaining: number | null; + subscriptionIncluded: number | null; + subscriptionPeriodEnd: string | null; +} diff --git a/apps/api/src/billing/dto/billing.dto.ts b/apps/api/src/billing/dto/billing.dto.ts new file mode 100644 index 0000000000..9218231068 --- /dev/null +++ b/apps/api/src/billing/dto/billing.dto.ts @@ -0,0 +1,92 @@ +import { subscriptionBillingSkuKeys } from '@trycompai/billing'; +import { IsEmail, IsIn, IsOptional, IsString, IsUrl, Length } from 'class-validator'; +import { billingTaxIdTypes } from '../billing-preferences'; + +export class BillingSetupSessionDto { + @IsString() + @IsUrl({ require_tld: false }, { message: 'successUrl must be a valid URL' }) + successUrl: string; + + @IsString() + @IsUrl({ require_tld: false }, { message: 'cancelUrl must be a valid URL' }) + cancelUrl: string; +} + +export class BillingSetupSuccessDto { + @IsString() + sessionId: string; +} + +export class BillingPortalDto { + @IsString() + @IsUrl({ require_tld: false }, { message: 'returnUrl must be a valid URL' }) + returnUrl: string; +} + +export class BillingSubscriptionCheckoutDto { + @IsString() + @IsIn(subscriptionBillingSkuKeys) + skuKey: string; + + @IsString() + @IsUrl({ require_tld: false }, { message: 'successUrl must be a valid URL' }) + successUrl: string; + + @IsString() + @IsUrl({ require_tld: false }, { message: 'cancelUrl must be a valid URL' }) + cancelUrl: string; +} + +export class BillingPreferencesDto { + @IsString() + @Length(1, 150) + companyName: string; + + @IsEmail() + billingEmail: string; + + @IsOptional() + @IsString() + @Length(0, 140) + purchaseOrder: string | null; + + @IsOptional() + @IsString() + @Length(0, 200) + addressLine1: string | null; + + @IsOptional() + @IsString() + @Length(0, 200) + addressLine2: string | null; + + @IsOptional() + @IsString() + @Length(0, 100) + addressCity: string | null; + + @IsOptional() + @IsString() + @Length(0, 100) + addressState: string | null; + + @IsOptional() + @IsString() + @Length(0, 32) + addressPostalCode: string | null; + + @IsOptional() + @IsString() + @Length(0, 2) + addressCountry: string | null; + + @IsOptional() + @IsString() + @IsIn([...billingTaxIdTypes, '']) + taxIdType: string | null; + + @IsOptional() + @IsString() + @Length(0, 64) + taxIdValue: string | null; +} diff --git a/apps/api/src/billing/stripe-webhook-records.ts b/apps/api/src/billing/stripe-webhook-records.ts new file mode 100644 index 0000000000..52638d7f67 --- /dev/null +++ b/apps/api/src/billing/stripe-webhook-records.ts @@ -0,0 +1,67 @@ +import { Prisma, db } from '@db'; + +export type StripeWebhookClaim = + | { status: 'claimed' } + | { status: 'duplicate' }; + +export async function claimStripeWebhookEvent(params: { + stripeEventId: string; + eventType: string; + payload: Prisma.InputJsonValue; +}): Promise { + try { + await db.stripeWebhookEvent.create({ + data: { + stripeEventId: params.stripeEventId, + eventType: params.eventType, + payload: params.payload, + status: 'processing', + }, + }); + return { status: 'claimed' }; + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + const existing = await db.stripeWebhookEvent.findUnique({ + where: { stripeEventId: params.stripeEventId }, + select: { status: true }, + }); + if (existing?.status === 'failed') { + await db.stripeWebhookEvent.update({ + where: { stripeEventId: params.stripeEventId }, + data: { status: 'processing', error: null }, + }); + return { status: 'claimed' }; + } + return { status: 'duplicate' }; + } +} + +export async function markStripeWebhookProcessed( + stripeEventId: string, +): Promise { + await db.stripeWebhookEvent.update({ + where: { stripeEventId }, + data: { status: 'processed', error: null, processedAt: new Date() }, + }); +} + +export async function markStripeWebhookFailed(params: { + stripeEventId: string; + error: unknown; +}): Promise { + await db.stripeWebhookEvent.update({ + where: { stripeEventId: params.stripeEventId }, + data: { + status: 'failed', + error: + params.error instanceof Error ? params.error.message : 'Unknown error', + }, + }); +} + +function isUniqueConstraintError(error: unknown): boolean { + return ( + error instanceof Prisma.PrismaClientKnownRequestError && + error.code === 'P2002' + ); +} diff --git a/apps/api/src/main.ts b/apps/api/src/main.ts index babfc7e119..d46f1dd99c 100644 --- a/apps/api/src/main.ts +++ b/apps/api/src/main.ts @@ -93,6 +93,8 @@ async function bootstrap(): Promise { '/security-penetration-tests/webhook', '/v1/background-checks/webhook', '/background-checks/webhook', + '/v1/billing/webhook', + '/billing/webhook', ]; const needsRawBody = (req: express.Request): boolean => RAW_BODY_PATHS.some((p) => req.path.endsWith(p)); diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.module.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.module.ts index 1560d1c85e..1f648b8f34 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.module.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.module.ts @@ -1,12 +1,13 @@ import { Module } from '@nestjs/common'; import { AuthModule } from '../auth/auth.module'; +import { BillingModule } from '../billing/billing.module'; import { PentestCreditsController } from './pentest-credits.controller'; import { PentestCreditsService } from './pentest-credits.service'; import { SecurityPenetrationTestsController } from './security-penetration-tests.controller'; import { SecurityPenetrationTestsService } from './security-penetration-tests.service'; @Module({ - imports: [AuthModule], + imports: [AuthModule, BillingModule], controllers: [SecurityPenetrationTestsController, PentestCreditsController], providers: [SecurityPenetrationTestsService, PentestCreditsService], exports: [PentestCreditsService], diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts index 2afda1b79c..6bee1e0dc2 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts @@ -1,6 +1,7 @@ import { HttpException, HttpStatus } from '@nestjs/common'; import { db } from '@db'; import { createHash } from 'node:crypto'; +import type { BillingEntitlementsService } from '../billing/billing-entitlements.service'; import type { CredentialVaultService } from '../integration-platform/services/credential-vault.service'; import type { CreatePenetrationTestDto } from './dto/create-penetration-test.dto'; import type { PentestCreditsService } from './pentest-credits.service'; @@ -22,6 +23,16 @@ const mockPentestCreditsService: jest.Mocked< refund: jest.fn(), }; +const mockBillingEntitlementsService: jest.Mocked< + Pick< + BillingEntitlementsService, + 'tryConsumeIncludedUsage' | 'refundIncludedUsage' + > +> = { + tryConsumeIncludedUsage: jest.fn(), + refundIncludedUsage: jest.fn(), +}; + jest.mock('@db', () => ({ db: { securityPenetrationTestRun: { @@ -102,8 +113,13 @@ describe('SecurityPenetrationTestsService', () => { lastGrantSource: 'trial', }); mockPentestCreditsService.refund.mockResolvedValue(); + mockBillingEntitlementsService.tryConsumeIncludedUsage.mockResolvedValue({ + status: 'not_configured', + }); + mockBillingEntitlementsService.refundIncludedUsage.mockResolvedValue(); service = new SecurityPenetrationTestsService( mockPentestCreditsService as unknown as PentestCreditsService, + mockBillingEntitlementsService as unknown as BillingEntitlementsService, ); fetchMock.mockReset(); global.fetch = fetchMock as unknown as typeof fetch; diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts index 8217a2e420..4e3aed0f13 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts @@ -21,8 +21,10 @@ import { type PentestProgress as MacedPentestProgress, type PentestWithProgress, } from '@maced/api-client'; +import { randomUUID } from 'crypto'; import type { CreatePenetrationTestDto } from './dto/create-penetration-test.dto'; +import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; import { PentestCreditsService } from './pentest-credits.service'; /** @@ -112,7 +114,10 @@ export class SecurityPenetrationTestsService { private readonly logger = new Logger(SecurityPenetrationTestsService.name); private readonly macedClient: MacedClient; - constructor(private readonly credits: PentestCreditsService) { + constructor( + private readonly credits: PentestCreditsService, + private readonly billingEntitlements: BillingEntitlementsService, + ) { const apiKey = process.env.MACED_API_KEY; if (!apiKey) { // Throw at construction so the app fails loudly on boot, not on first request. @@ -168,8 +173,7 @@ export class SecurityPenetrationTestsService { // …). Include the constructor name + message in the log so we can tell // what actually broke without a debugger. const errName = error?.constructor?.name ?? typeof error; - const errMessage = - error instanceof Error ? error.message : String(error); + const errMessage = error instanceof Error ? error.message : String(error); this.logger.error( `Transport failure calling Maced (${context}): ${errName} — ${errMessage}`, ); @@ -222,6 +226,8 @@ export class SecurityPenetrationTestsService { payload: CreatePenetrationTestDto, ): Promise { const resolvedWebhookUrl = this.resolveWebhookUrl(payload.webhookUrl); + const billingUsageSourceId = `pending:${randomUUID()}`; + let consumedSubscriptionAllowance = false; // Debit FIRST so concurrent fast-clicks block at the cheap DB // conditional update before any of them reach Maced. Without this, @@ -231,7 +237,27 @@ export class SecurityPenetrationTestsService { // `updateMany WHERE balance > 0` guarantees only one decrement // succeeds. try { - await this.credits.debitOrThrow(organizationId); + const subscriptionUsage = + await this.billingEntitlements.tryConsumeIncludedUsage({ + organizationId, + skuKey: 'pentest_monthly_5', + sourceResourceId: billingUsageSourceId, + }); + if (subscriptionUsage.status === 'exhausted') { + throw new HttpException( + { + error: + 'No pentest runs remaining in your subscription. Upgrade or wait for your monthly allowance to reset.', + code: 'pentest_subscription_exhausted', + }, + HttpStatus.PAYMENT_REQUIRED, + ); + } + if (subscriptionUsage.status === 'not_configured') { + await this.credits.debitOrThrow(organizationId); + } else { + consumedSubscriptionAllowance = true; + } } catch (error) { if ( error instanceof HttpException && @@ -289,17 +315,37 @@ export class SecurityPenetrationTestsService { } catch (error) { // Provider call failed after we debited. Refund so the user isn't // charged for a run that never started. - await this.refundQuietly(organizationId, 'pending', 'maced_create_failed'); + if (consumedSubscriptionAllowance) { + await this.refundBillingUsageQuietly({ + organizationId, + sourceResourceId: billingUsageSourceId, + reason: 'maced_create_failed', + }); + } else { + await this.refundQuietly( + organizationId, + 'pending', + 'maced_create_failed', + ); + } throw error; } const providerRunId = createdReport.id; if (!providerRunId) { - await this.refundQuietly( - organizationId, - 'pending', - 'maced_missing_run_id', - ); + if (consumedSubscriptionAllowance) { + await this.refundBillingUsageQuietly({ + organizationId, + sourceResourceId: billingUsageSourceId, + reason: 'maced_missing_run_id', + }); + } else { + await this.refundQuietly( + organizationId, + 'pending', + 'maced_missing_run_id', + ); + } throw new HttpException( { error: 'Create response missing report identifier' }, HttpStatus.BAD_GATEWAY, @@ -316,11 +362,19 @@ export class SecurityPenetrationTestsService { // shouldn't pay for it. The Maced run is orphaned (no // ownership) but Maced has the `compOrganizationId` metadata if // support ever needs to clean it up. - await this.refundQuietly( - organizationId, - providerRunId, - 'ownership_persist_failed', - ); + if (consumedSubscriptionAllowance) { + await this.refundBillingUsageQuietly({ + organizationId, + sourceResourceId: billingUsageSourceId, + reason: 'ownership_persist_failed', + }); + } else { + await this.refundQuietly( + organizationId, + providerRunId, + 'ownership_persist_failed', + ); + } throw new HttpException( { error: @@ -374,10 +428,7 @@ export class SecurityPenetrationTestsService { ); } - async getReportIssues( - organizationId: string, - id: string, - ): Promise { + async getReportIssues(organizationId: string, id: string): Promise { await this.assertRunOwnership(organizationId, id); return this.callMaced( () => this.macedClient.pentests.issues(id), @@ -448,7 +499,9 @@ export class SecurityPenetrationTestsService { eventId?: string; }> { if (!params.rawBody) { - throw new BadRequestException('Missing raw body for webhook verification'); + throw new BadRequestException( + 'Missing raw body for webhook verification', + ); } const secret = process.env.MACED_WEBHOOK_SIGNING_SECRET; @@ -524,15 +577,13 @@ export class SecurityPenetrationTestsService { * those are rare race-condition artifacts and don't represent * customer-visible state. */ - private async auditPentestCompleted( - data: { - pentestId: string; - targetUrl: string; - issueCount: number; - durationMs: number; - agentCount: number; - }, - ): Promise { + private async auditPentestCompleted(data: { + pentestId: string; + targetUrl: string; + issueCount: number; + durationMs: number; + agentCount: number; + }): Promise { // Atomic claim — only the first webhook delivery for this run gets // count: 1 back. Subsequent redeliveries see `completed_audit_at` // already set and bail out before writing a duplicate audit row. @@ -622,9 +673,7 @@ export class SecurityPenetrationTestsService { if (!run) { // Vanishingly rare race; abort the transaction so the claim // rolls back. Webhook redelivery will retry. - throw new Error( - `Run row vanished after claim for ${providerRunId}`, - ); + throw new Error(`Run row vanished after claim for ${providerRunId}`); } // Pass the tx client through so the wallet write happens in @@ -866,6 +915,27 @@ export class SecurityPenetrationTestsService { } } + private async refundBillingUsageQuietly(params: { + organizationId: string; + sourceResourceId: string; + reason: string; + }): Promise { + try { + await this.billingEntitlements.refundIncludedUsage({ + organizationId: params.organizationId, + skuKey: 'pentest_monthly_5', + sourceResourceId: params.sourceResourceId, + reason: params.reason, + }); + } catch (error) { + this.logger.error( + `Billing usage refund failed for org=${params.organizationId} source=${params.sourceResourceId} reason=${params.reason}: ${ + error instanceof Error ? error.message : String(error) + }`, + ); + } + } + private async persistRunOwnership( organizationId: string, reportId: string, @@ -1002,5 +1072,4 @@ export class SecurityPenetrationTestsService { return hosts; } - } diff --git a/apps/app/package.json b/apps/app/package.json index de6e57de53..118705ed56 100644 --- a/apps/app/package.json +++ b/apps/app/package.json @@ -67,6 +67,7 @@ "@trigger.dev/react-hooks": "4.4.3", "@trigger.dev/sdk": "4.4.3", "@trycompai/auth": "workspace:*", + "@trycompai/billing": "workspace:*", "@trycompai/company": "workspace:*", "@trycompai/db": "workspace:*", "@trycompai/design-system": "^1.0.32", @@ -190,7 +191,7 @@ "build": "next build", "build:docker": "prisma generate --schema=prisma/schema && node ../../packages/db/scripts/fix-generated-extensions.js src/generated/prisma && next build", "db:generate": "bun run db:getschema && prisma generate --schema=prisma/schema && node ../../packages/db/scripts/fix-generated-extensions.js src/generated/prisma", - "db:getschema": "find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\;", + "db:getschema": "find prisma/schema -name '*.prisma' ! -name 'schema.prisma' -delete && find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\;", "db:migrate": "cd ../../packages/db && bunx prisma migrate dev && cd ../../apps/app", "deploy:trigger-prod": "npx trigger.dev@4.4.3 deploy", "dev": "bunx concurrently --kill-others --names \"next,trigger\" --prefix-colors \"yellow,blue\" \"NODE_OPTIONS='--no-deprecation' next dev --turbo -p 3000\" \"NODE_OPTIONS='--no-deprecation' trigger dev\"", diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnPlansClient.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnPlansClient.tsx new file mode 100644 index 0000000000..3e29e6dd0c --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnPlansClient.tsx @@ -0,0 +1,124 @@ +'use client'; + +import { usePermissions } from '@/hooks/use-permissions'; +import { apiClient } from '@/lib/api-client'; +import type { BillingSkuKey } from '@trycompai/billing'; +import { + PageHeader, + PageHeaderDescription, + PageLayout, + Section, + Tabs, + TabsContent, + TabsList, + TabsTrigger, +} from '@trycompai/design-system'; +import { useMemo, useState } from 'react'; +import { toast } from 'sonner'; +import useSWR from 'swr'; +import { BillingSubscriptionPlans } from './BillingSubscriptionPlans'; +import type { BillingAddOn } from './billingAddOns'; +import type { BackgroundCheckBillingStatus } from './types'; + +interface BillingAddOnPlansClientProps { + organizationId: string; + addOn: BillingAddOn; + initialBillingStatus: BackgroundCheckBillingStatus; +} + +export function BillingAddOnPlansClient({ + organizationId, + addOn, + initialBillingStatus, +}: BillingAddOnPlansClientProps) { + const [loadingSubscriptionSku, setLoadingSubscriptionSku] = useState(null); + const { hasPermission } = usePermissions(); + const canManageBilling = hasPermission('organization', 'update'); + + const { data: billingStatus } = useSWR( + ['/v1/billing/status', organizationId], + async ([endpoint]) => { + const response = await apiClient.get(endpoint, organizationId); + if (response.error || !response.data) { + throw new Error('Failed to load billing status'); + } + return response.data; + }, + { + fallbackData: initialBillingStatus, + revalidateOnMount: false, + }, + ); + + const subscriptions = useMemo( + () => billingStatus?.subscriptions ?? initialBillingStatus.subscriptions ?? [], + [billingStatus?.subscriptions, initialBillingStatus.subscriptions], + ); + + const handleOpenSubscription = async (skuKey: BillingSkuKey) => { + setLoadingSubscriptionSku(skuKey); + + const returnUrl = `${window.location.origin}/${organizationId}/settings/billing/add-ons/${addOn.slug}`; + const response = await apiClient.post<{ url: string }>( + '/v1/billing/subscription-session', + { + skuKey, + successUrl: `${returnUrl}?billing_subscription=success&session_id={CHECKOUT_SESSION_ID}`, + cancelUrl: returnUrl, + }, + organizationId, + ); + + if (response.data?.url) { + window.location.href = response.data.url; + return; + } + + toast.error('Failed to open checkout'); + setLoadingSubscriptionSku(null); + }; + + return ( + + + Overview + + } + > + {addOn.description} + + } + > + +
+ +
+
+
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx new file mode 100644 index 0000000000..0c779cfc94 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx @@ -0,0 +1,163 @@ +import { apiClient } from '@/lib/api-client'; +import { render, screen, waitFor } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { SWRConfig } from 'swr'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { BillingAddOnPlansClient } from './BillingAddOnPlansClient'; +import { BillingAddOnsOverview } from './BillingAddOnsOverview'; +import { getBillingAddOn } from './billingAddOns'; +import type { BackgroundCheckBillingStatus } from './types'; + +const navigationMock = vi.hoisted(() => ({ + push: vi.fn(), +})); + +const permissionMock = vi.hoisted(() => ({ + canUpdateOrganization: true, +})); + +vi.mock('@/hooks/use-permissions', () => ({ + usePermissions: () => ({ + hasPermission: (resource: string, action: string) => + resource === 'organization' && action === 'update' && permissionMock.canUpdateOrganization, + }), +})); + +vi.mock('@/lib/api-client', () => ({ + apiClient: { + get: vi.fn(), + post: vi.fn(), + }, +})); + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ push: navigationMock.push }), +})); + +vi.mock('sonner', () => ({ + toast: { error: vi.fn() }, +})); + +const emptyBillingStatus: BackgroundCheckBillingStatus = { + hasPaymentMethod: true, + setupAt: null, + usage: { backgroundChecks: 0, penetrationTests: 0 }, + preferences: null, + usageRows: [], + subscriptions: [], + invoices: [], +}; + +function renderAddOnPlans({ + addOnSlug, + subscriptions = [], +}: { + addOnSlug: 'penetration-tests' | 'background-checks'; + subscriptions?: NonNullable; +}) { + const addOn = getBillingAddOn(addOnSlug); + if (!addOn) throw new Error(`Missing test add-on: ${addOnSlug}`); + + return render( + new Map() }}> + + , + ); +} + +describe('billing add-ons', () => { + beforeEach(() => { + vi.clearAllMocks(); + navigationMock.push.mockReset(); + permissionMock.canUpdateOrganization = true; + vi.mocked(apiClient.post).mockResolvedValue({ + data: { url: '#stripe-session' }, + status: 200, + }); + }); + + it('shows product-level add-ons before plan selection', () => { + render(); + + expect(screen.getByText('Penetration Tests')).toBeInTheDocument(); + expect(screen.getByText('Background Checks')).toBeInTheDocument(); + screen.getByRole('button', { name: /view penetration tests plans/i }).click(); + expect(navigationMock.push).toHaveBeenCalledWith( + '/org_1/settings/billing/add-ons/penetration-tests', + ); + screen.getByRole('button', { name: /view background checks plans/i }).click(); + expect(navigationMock.push).toHaveBeenCalledWith( + '/org_1/settings/billing/add-ons/background-checks', + ); + }); + + it('opens subscription checkout for an add-on plan', async () => { + const user = userEvent.setup(); + renderAddOnPlans({ addOnSlug: 'background-checks' }); + + await user.click(screen.getByRole('button', { name: /subscribe to background checks/i })); + + await waitFor(() => { + expect(apiClient.post).toHaveBeenCalledWith( + '/v1/billing/subscription-session', + { + skuKey: 'background_checks_monthly_25', + successUrl: + 'http://localhost:3000/org_1/settings/billing/add-ons/background-checks?billing_subscription=success&session_id={CHECKOUT_SESSION_ID}', + cancelUrl: 'http://localhost:3000/org_1/settings/billing/add-ons/background-checks', + }, + 'org_1', + ); + }); + }); + + it('shows add-on plans on a standalone overview tab', () => { + renderAddOnPlans({ addOnSlug: 'penetration-tests' }); + + expect(screen.getByRole('link', { name: /add-ons/i })).toHaveAttribute( + 'href', + '/org_1/settings/billing', + ); + expect(screen.getByRole('heading', { name: 'Penetration Test' })).toBeInTheDocument(); + expect(screen.getByRole('tab', { name: /^overview$/i })).toBeInTheDocument(); + expect( + screen.getByRole('button', { name: /subscribe to penetration tests/i }), + ).toBeInTheDocument(); + }); + + it('shows active add-on subscriptions as disabled plan actions', async () => { + const user = userEvent.setup(); + renderAddOnPlans({ + addOnSlug: 'penetration-tests', + subscriptions: [ + { + skuKey: 'pentest_monthly_5', + status: 'active', + includedQuantity: 5, + usedQuantity: 1, + currentPeriodStart: '2026-04-30T00:00:00.000Z', + currentPeriodEnd: '2026-05-30T00:00:00.000Z', + cancelAtPeriodEnd: false, + }, + ], + }); + + const activeButton = screen.getByRole('button', { name: /active subscription/i }); + expect(activeButton).toBeDisabled(); + expect(screen.getByText('1 of 5 used this period.')).toBeInTheDocument(); + + await user.click(activeButton); + expect(apiClient.post).not.toHaveBeenCalledWith( + '/v1/billing/subscription-session', + expect.objectContaining({ skuKey: 'pentest_monthly_5' }), + 'org_1', + ); + }); +}); diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx new file mode 100644 index 0000000000..5b51244841 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx @@ -0,0 +1,74 @@ +'use client'; + +import { Badge, Button, Stack, Text } from '@trycompai/design-system'; +import { ArrowRight } from '@trycompai/design-system/icons'; +import { useRouter } from 'next/navigation'; +import { billingAddOns } from './billingAddOns'; +import type { BackgroundCheckBillingStatus } from './types'; + +interface BillingAddOnsOverviewProps { + organizationId: string; + subscriptions: NonNullable; +} + +export function BillingAddOnsOverview({ + organizationId, + subscriptions, +}: BillingAddOnsOverviewProps) { + const router = useRouter(); + + return ( +
+ {billingAddOns.map((addOn) => { + const skuKeys: readonly string[] = addOn.skuKeys; + const activeSubscription = subscriptions.find( + (subscription) => + skuKeys.includes(subscription.skuKey) && + (subscription.status === 'active' || subscription.status === 'trialing'), + ); + + return ( +
+ + +
+ {addOn.name} + {activeSubscription && Active} +
+ + {addOn.description} + +
+ + + + {addOn.summary} + + {activeSubscription && ( + + {activeSubscription.usedQuantity} of {activeSubscription.includedQuantity} used + this period. + + )} + + +
+ +
+
+
+ ); + })} +
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingPaymentMethodCard.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingPaymentMethodCard.tsx new file mode 100644 index 0000000000..9a30bb415c --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingPaymentMethodCard.tsx @@ -0,0 +1,75 @@ +'use client'; + +import { + Badge, + Button, + Card, + CardContent, + CardHeader, + CardTitle, + Stack, + Text, +} from '@trycompai/design-system'; +import { Launch } from '@trycompai/design-system/icons'; + +interface BillingPaymentMethodCardProps { + hasPaymentMethod: boolean; + statusLabel: string; + canManageBilling: boolean; + isOpeningBilling: boolean; + onOpenBilling: () => void; +} + +export function BillingPaymentMethodCard({ + hasPaymentMethod, + statusLabel, + canManageBilling, + isOpeningBilling, + onOpenBilling, +}: BillingPaymentMethodCardProps) { + const actionLabel = hasPaymentMethod ? 'Open Billing Portal' : 'Add payment method'; + + return ( + + + Managed securely in the billing portal + + + + } + > + +
+ Payment method + {statusLabel} +
+
+ + + + {hasPaymentMethod + ? 'A payment method is connected. Open the billing portal to update billing details, cards, and receipts.' + : 'Add a payment method to use paid services such as background checks and penetration testing.'} + + {!canManageBilling && ( + + Ask an organization admin to update billing details. + + )} + + +
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingPreferencesForm.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingPreferencesForm.tsx new file mode 100644 index 0000000000..5cad22c38a --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingPreferencesForm.tsx @@ -0,0 +1,294 @@ +'use client'; + +import { apiClient } from '@/lib/api-client'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { + Button, + Card, + Input, + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, + Stack, + Text, +} from '@trycompai/design-system'; +import type React from 'react'; +import { Controller, useForm } from 'react-hook-form'; +import { toast } from 'sonner'; +import type { BackgroundCheckBillingStatus, BillingPreferences } from './types'; +import { + billingCountries, + billingPreferencesSchema, + getCountryLabel, + getTaxIdTypeLabel, + taxIdTypes, + toBillingPreferencesFormValues, + toBillingPreferencesPayload, + type BillingPreferencesFormValues, +} from './billingPreferencesFormSchema'; + +interface BillingPreferencesFormProps { + organizationId: string; + preferences: BillingPreferences | null; + disabled: boolean; + onSaved: (status: BackgroundCheckBillingStatus) => void; +} + +export function BillingPreferencesForm({ + organizationId, + preferences, + disabled, + onSaved, +}: BillingPreferencesFormProps) { + const { + control, + register, + handleSubmit, + formState: { errors, isSubmitting, isDirty }, + reset, + } = useForm({ + resolver: zodResolver(billingPreferencesSchema), + values: toBillingPreferencesFormValues(preferences), + }); + + const handleSave = handleSubmit(async (values) => { + const response = await apiClient.put( + '/v1/billing/preferences', + toBillingPreferencesPayload(values), + organizationId, + ); + + if (response.error || !response.data) { + toast.error(response.error ?? 'Failed to save billing preferences'); + return; + } + + onSaved(response.data); + reset(toBillingPreferencesFormValues(response.data.preferences ?? null)); + toast.success('Billing preferences saved'); + }); + + return ( +
+ + + Future invoice emails are sent to the billing email above. + + + + } + > + +
+ + + + + + +
+ +
+ + + + + + + + + + + + + + + + + ( + + )} + /> + +
+ +
+ + ( + + )} + /> + + + + + + + +
+ +
+
+ + ); +} + +function FormField({ + id, + label, + error, + children, +}: { + id: string; + label: string; + error?: string; + children: React.ReactNode; +}) { + return ( + + + {children} + {error && ( + + {error} + + )} + + ); +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.test.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.test.tsx index 8b6e00617c..828eca59d8 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.test.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.test.tsx @@ -26,6 +26,7 @@ vi.mock('@/lib/api-client', () => ({ apiClient: { get: vi.fn(), post: vi.fn(), + put: vi.fn(), }, })); @@ -58,6 +59,22 @@ function renderBillingSettings({ invoicePdfUrl: 'https://invoice.stripe.com/i/in_1.pdf', }, ], + subscriptions = [], + usageRows = [], + preferences = { + companyName: 'Test Company', + billingEmail: 'billing@example.com', + purchaseOrder: null, + address: { + line1: null, + line2: null, + city: null, + state: null, + postalCode: null, + country: null, + }, + taxId: null, + }, }: { hasPaymentMethod?: boolean; backgroundChecks?: number; @@ -65,6 +82,15 @@ function renderBillingSettings({ invoices?: NonNullable< Parameters[0]['initialBillingStatus']['invoices'] >; + subscriptions?: NonNullable< + Parameters[0]['initialBillingStatus']['subscriptions'] + >; + usageRows?: NonNullable< + Parameters[0]['initialBillingStatus']['usageRows'] + >; + preferences?: NonNullable< + Parameters[0]['initialBillingStatus']['preferences'] + >; } = {}) { return render( new Map() }}> @@ -74,6 +100,9 @@ function renderBillingSettings({ hasPaymentMethod, setupAt: null, usage: { backgroundChecks, penetrationTests }, + preferences, + usageRows, + subscriptions, invoices, }} /> @@ -86,34 +115,56 @@ describe('BillingSettingsClient', () => { vi.clearAllMocks(); vi.mocked(apiClient.get).mockReset(); vi.mocked(apiClient.post).mockReset(); + vi.mocked(apiClient.put).mockReset(); vi.mocked(apiClient.post).mockResolvedValue({ data: { url: '#stripe-session' }, status: 200, }); + vi.mocked(apiClient.put).mockResolvedValue({ + data: { + hasPaymentMethod: true, + setupAt: null, + usage: { backgroundChecks: 4, penetrationTests: 2 }, + usageRows: [], + preferences: { + companyName: 'Test Company', + billingEmail: 'accounts@example.com', + purchaseOrder: 'PO-123', + address: { + line1: '1 Test Street', + line2: null, + city: 'London', + state: null, + postalCode: 'SW1A 1AA', + country: 'GB', + }, + taxId: null, + }, + subscriptions: [], + invoices: [], + }, + status: 200, + }); permissionMock.canUpdateOrganization = true; navigationMock.searchParams = new URLSearchParams(); }); - it('opens the Stripe billing portal when a payment method is saved', async () => { + it('opens the billing portal when a payment method is saved', async () => { const user = userEvent.setup(); renderBillingSettings({ hasPaymentMethod: true }); + await user.click(screen.getByRole('tab', { name: /billing details/i })); expect(screen.getByText('Billing set up')).toBeInTheDocument(); - expect(screen.getByText('Historical usage')).toBeInTheDocument(); - expect(screen.getByText('Penetration Tests')).toBeInTheDocument(); - expect(screen.getByText('Background Checks')).toBeInTheDocument(); - expect(screen.getByText('2')).toBeInTheDocument(); - expect(screen.getByText('4')).toBeInTheDocument(); expect(screen.getByText(/update billing details, cards, and receipts/i)).toBeInTheDocument(); expect(screen.getByText('Invoices')).toBeInTheDocument(); expect(screen.getByText('INV-001')).toBeInTheDocument(); expect(screen.getByText('$49.00')).toBeInTheDocument(); expect(screen.getByText('One Time')).toBeInTheDocument(); - await user.click(screen.getByRole('button', { name: /update billing details/i })); + await user.click(screen.getByRole('button', { name: /open billing portal/i })); await waitFor(() => { expect(apiClient.post).toHaveBeenCalledWith( - '/v1/background-check-billing/portal', + '/v1/billing/portal', { returnUrl: 'http://localhost:3000/org_1/settings/billing' }, 'org_1', ); @@ -124,13 +175,14 @@ describe('BillingSettingsClient', () => { const user = userEvent.setup(); renderBillingSettings({ hasPaymentMethod: false }); + await user.click(screen.getByRole('tab', { name: /billing details/i })); expect(screen.getByText('Payment method needed')).toBeInTheDocument(); expect(screen.getByText(/background checks and penetration testing/i)).toBeInTheDocument(); await user.click(screen.getByRole('button', { name: /add payment method/i })); await waitFor(() => { expect(apiClient.post).toHaveBeenCalledWith( - '/v1/background-check-billing/setup-session', + '/v1/billing/setup-session', { successUrl: 'http://localhost:3000/org_1/settings/billing?background_check_billing=success&session_id={CHECKOUT_SESSION_ID}', @@ -146,7 +198,8 @@ describe('BillingSettingsClient', () => { permissionMock.canUpdateOrganization = false; renderBillingSettings({ hasPaymentMethod: true }); - const button = screen.getByRole('button', { name: /update billing details/i }); + await user.click(screen.getByRole('tab', { name: /billing details/i })); + const button = screen.getByRole('button', { name: /open billing portal/i }); expect(button).toBeDisabled(); await user.click(button); @@ -156,7 +209,8 @@ describe('BillingSettingsClient', () => { ).toBeInTheDocument(); }); - it('falls back to zero usage for older billing status payloads', () => { + it('falls back to zero usage for older billing status payloads', async () => { + const user = userEvent.setup(); render( new Map() }}> { , ); + await user.click(screen.getByRole('tab', { name: /^usage$/i })); expect(screen.getAllByText('0')).toHaveLength(2); + await user.click(screen.getByRole('tab', { name: /billing details/i })); expect(screen.getByText('No invoices yet.')).toBeInTheDocument(); }); - it('filters invoices by search query', async () => { - const user = userEvent.setup(); - renderBillingSettings({ - invoices: [ - { - id: 'in_1', - number: 'INV-001', - createdAt: '2026-04-30T09:35:07.000Z', - dueDate: null, - amountPaid: 4900, - amountDue: 4900, - currency: 'usd', - status: 'paid', - type: 'One Time', - hostedInvoiceUrl: 'https://invoice.stripe.com/i/in_1', - invoicePdfUrl: 'https://invoice.stripe.com/i/in_1.pdf', - }, - { - id: 'in_2', - number: 'INV-002', - createdAt: '2026-04-01T09:35:07.000Z', - dueDate: null, - amountPaid: 0, - amountDue: 0, - currency: 'usd', - status: 'paid', - type: 'Subscription', - hostedInvoiceUrl: null, - invoicePdfUrl: null, - }, - ], - }); - - await user.type(screen.getByLabelText('Search invoices'), 'subscription'); - - expect(screen.queryByText('INV-001')).not.toBeInTheDocument(); - expect(screen.getByText('INV-002')).toBeInTheDocument(); - }); }); diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx index eb2034a47b..faca0cf6e7 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx @@ -3,20 +3,24 @@ import { usePermissions } from '@/hooks/use-permissions'; import { apiClient } from '@/lib/api-client'; import { - Badge, - Button, PageHeader, PageLayout, Section, Stack, - Text, + Tabs, + TabsContent, + TabsList, + TabsTrigger, } from '@trycompai/design-system'; -import { Launch } from '@trycompai/design-system/icons'; import { usePathname, useRouter, useSearchParams } from 'next/navigation'; import { useEffect, useMemo, useRef, useState } from 'react'; import { toast } from 'sonner'; import useSWR from 'swr'; +import { BillingAddOnsOverview } from './BillingAddOnsOverview'; import { BillingInvoicesTable } from './BillingInvoicesTable'; +import { BillingPaymentMethodCard } from './BillingPaymentMethodCard'; +import { BillingPreferencesForm } from './BillingPreferencesForm'; +import { BillingUsageTable } from './BillingUsageTable'; import type { BackgroundCheckBillingStatus } from './types'; interface BillingSettingsClientProps { @@ -42,7 +46,7 @@ export function BillingSettingsClient({ const canManageBilling = hasPermission('organization', 'update'); const { data: billingStatus, mutate: mutateBillingStatus } = useSWR( - ['/v1/background-check-billing/status', organizationId], + ['/v1/billing/status', organizationId], async ([endpoint]) => { const response = await apiClient.get(endpoint, organizationId); if (response.error || !response.data) { @@ -62,8 +66,19 @@ export function BillingSettingsClient({ () => billingStatus?.invoices ?? initialBillingStatus.invoices ?? [], [billingStatus?.invoices, initialBillingStatus.invoices], ); + const usageRows = useMemo( + () => billingStatus?.usageRows ?? initialBillingStatus.usageRows ?? [], + [billingStatus?.usageRows, initialBillingStatus.usageRows], + ); + const subscriptions = useMemo( + () => billingStatus?.subscriptions ?? initialBillingStatus.subscriptions ?? [], + [billingStatus?.subscriptions, initialBillingStatus.subscriptions], + ); + const preferences = useMemo( + () => billingStatus?.preferences ?? initialBillingStatus.preferences ?? null, + [billingStatus?.preferences, initialBillingStatus.preferences], + ); const statusLabel = hasPaymentMethod ? 'Billing set up' : 'Payment method needed'; - const actionLabel = hasPaymentMethod ? 'Update Billing Details' : 'Add payment method'; useEffect(() => { const sessionId = searchParams.get('session_id'); @@ -73,7 +88,7 @@ export function BillingSettingsClient({ handledSessionId.current = sessionId; void (async () => { const setupResponse = await apiClient.post<{ success: true }>( - '/v1/background-check-billing/setup-success', + '/v1/billing/setup-success', { sessionId }, organizationId, ); @@ -91,20 +106,32 @@ export function BillingSettingsClient({ setupAt: new Date().toISOString(), usage, invoices, + preferences, + usageRows, + subscriptions, }, { revalidate: true }, ); router.replace(pathname, { scroll: false }); })(); - }, [invoices, mutateBillingStatus, organizationId, pathname, router, searchParams, usage]); + }, [ + invoices, + mutateBillingStatus, + organizationId, + pathname, + router, + searchParams, + subscriptions, + usage, + preferences, + usageRows, + ]); const handleOpenBilling = async () => { setIsOpeningBilling(true); const returnUrl = `${window.location.origin}/${organizationId}/settings/billing`; - const endpoint = hasPaymentMethod - ? '/v1/background-check-billing/portal' - : '/v1/background-check-billing/setup-session'; + const endpoint = hasPaymentMethod ? '/v1/billing/portal' : '/v1/billing/setup-session'; const body = hasPaymentMethod ? { returnUrl } : { @@ -118,94 +145,78 @@ export function BillingSettingsClient({ return; } - toast.error('Failed to open Stripe billing'); + toast.error('Failed to open billing portal'); setIsOpeningBilling(false); }; return ( - } - onClick={handleOpenBilling} - > - {actionLabel} - - } - /> - } - > -
+ + Add-ons + Usage + Billing Details + + } + /> + } > -
-
- - - Historical usage - - Completed paid services for this organization. - - -
- - -
-
-
-
-
-
- -
- Payment method - {statusLabel} -
- - {hasPaymentMethod - ? 'A payment method is connected. Open the Stripe portal to update billing details, cards, and receipts.' - : 'Add a payment method to use paid services such as background checks and penetration testing.'} - -
-
-
- - Managed securely in Stripe - - {!canManageBilling && ( - - Ask an organization admin to update billing details. - - )} -
-
-
-
-
- -
- ); -} - -function UsageMetric({ label, value }: { label: string; value: number }) { - return ( -
- - - {label} - - - {value.toLocaleString()} - - -
+ +
+ +
+
+ + +
+ +
+
+ + + +
+ +
+
+ { + void mutateBillingStatus(status, { revalidate: false }); + }} + /> +
+ +
+
+ + ); } diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsDetails.test.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsDetails.test.tsx new file mode 100644 index 0000000000..dfa47d11f6 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsDetails.test.tsx @@ -0,0 +1,197 @@ +import { apiClient } from '@/lib/api-client'; +import { render, screen, waitFor } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { SWRConfig } from 'swr'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { BillingSettingsClient } from './BillingSettingsClient'; +import type { BackgroundCheckBillingStatus } from './types'; + +vi.mock('@/hooks/use-permissions', () => ({ + usePermissions: () => ({ + hasPermission: (resource: string, action: string) => + resource === 'organization' && action === 'update', + }), +})); + +vi.mock('@/lib/api-client', () => ({ + apiClient: { + get: vi.fn(), + post: vi.fn(), + put: vi.fn(), + }, +})); + +vi.mock('next/navigation', () => ({ + usePathname: () => '/org_1/settings/billing', + useRouter: () => ({ replace: vi.fn() }), + useSearchParams: () => new URLSearchParams(), +})); + +vi.mock('sonner', () => ({ + toast: { success: vi.fn(), error: vi.fn() }, +})); + +function renderBillingSettings(status: Partial = {}) { + return render( + new Map() }}> + + , + ); +} + +describe('BillingSettingsClient details', () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(apiClient.put).mockResolvedValue({ + data: { + hasPaymentMethod: true, + setupAt: null, + usage: { backgroundChecks: 4, penetrationTests: 2 }, + usageRows: [], + preferences: { + companyName: 'Test Company', + billingEmail: 'accounts@example.com', + purchaseOrder: 'PO-123', + address: { + line1: '1 Test Street', + line2: null, + city: 'London', + state: null, + postalCode: 'SW1A 1AA', + country: 'GB', + }, + taxId: null, + }, + subscriptions: [], + invoices: [], + }, + status: 200, + }); + }); + + it('filters invoices by search query', async () => { + const user = userEvent.setup(); + renderBillingSettings({ + invoices: [ + { + id: 'in_1', + number: 'INV-001', + createdAt: '2026-04-30T09:35:07.000Z', + dueDate: null, + amountPaid: 4900, + amountDue: 4900, + currency: 'usd', + status: 'paid', + type: 'One Time', + hostedInvoiceUrl: 'https://invoice.stripe.com/i/in_1', + invoicePdfUrl: 'https://invoice.stripe.com/i/in_1.pdf', + }, + { + id: 'in_2', + number: 'INV-002', + createdAt: '2026-04-01T09:35:07.000Z', + dueDate: null, + amountPaid: 0, + amountDue: 0, + currency: 'usd', + status: 'paid', + type: 'Subscription', + hostedInvoiceUrl: null, + invoicePdfUrl: null, + }, + ], + }); + + await user.click(screen.getByRole('tab', { name: /billing details/i })); + await user.type(screen.getByLabelText('Search invoices'), 'subscription'); + + expect(screen.queryByText('INV-001')).not.toBeInTheDocument(); + expect(screen.getByText('INV-002')).toBeInTheDocument(); + }); + + it('shows run history and monthly remaining usage on the Usage tab', async () => { + const user = userEvent.setup(); + renderBillingSettings({ + subscriptions: [ + { + skuKey: 'background_checks_monthly_25', + status: 'active', + includedQuantity: 25, + usedQuantity: 3, + currentPeriodStart: '2026-04-30T00:00:00.000Z', + currentPeriodEnd: '2026-05-30T00:00:00.000Z', + cancelAtPeriodEnd: false, + }, + ], + usageRows: [ + { + id: 'bcr_1', + service: 'Background Check', + skuKey: 'background_checks_monthly_25', + details: 'Ada Lovelace (ada@example.com)', + status: 'Completed', + billingType: 'Subscription allowance', + createdAt: '2026-04-30T10:00:00.000Z', + updatedAt: '2026-04-30T10:05:00.000Z', + subscriptionRemaining: 22, + subscriptionIncluded: 25, + subscriptionPeriodEnd: '2026-05-30T00:00:00.000Z', + }, + ], + }); + + await user.click(screen.getByRole('tab', { name: /^usage$/i })); + + expect(screen.getByText('Ada Lovelace (ada@example.com)')).toBeInTheDocument(); + expect(screen.getByText('Subscription allowance')).toBeInTheDocument(); + expect(screen.getByText('22 of 25')).toBeInTheDocument(); + }); + + it('saves billing preferences', async () => { + const user = userEvent.setup(); + renderBillingSettings(); + + await user.click(screen.getByRole('tab', { name: /billing details/i })); + await user.clear(screen.getByLabelText('Billing email')); + await user.type(screen.getByLabelText('Billing email'), 'accounts@example.com'); + await user.type(screen.getByLabelText('PO / reference'), 'PO-123'); + await user.click(screen.getByRole('button', { name: /save billing preferences/i })); + + await waitFor(() => { + expect(apiClient.put).toHaveBeenCalledWith( + '/v1/billing/preferences', + expect.objectContaining({ + companyName: 'Test Company', + billingEmail: 'accounts@example.com', + purchaseOrder: 'PO-123', + }), + 'org_1', + ); + }); + }); +}); diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx new file mode 100644 index 0000000000..74bf2f36a8 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx @@ -0,0 +1,97 @@ +'use client'; + +import { getBillingSku, type BillingSkuKey } from '@trycompai/billing'; +import { Button, Stack, Text } from '@trycompai/design-system'; +import { Launch } from '@trycompai/design-system/icons'; +import type { BackgroundCheckBillingStatus } from './types'; + +interface BillingSubscriptionPlansProps { + skuKeys?: readonly BillingSkuKey[]; + subscriptions: NonNullable; + disabled: boolean; + loadingSkuKey: string | null; + onSubscribe: (skuKey: BillingSkuKey) => void; +} + +const planSkuKeys = [ + 'pentest_monthly_5', + 'background_checks_monthly_25', +] as const satisfies readonly BillingSkuKey[]; + +export function BillingSubscriptionPlans({ + skuKeys = planSkuKeys, + subscriptions, + disabled, + loadingSkuKey, + onSubscribe, +}: BillingSubscriptionPlansProps) { + return ( +
+ {skuKeys.map((skuKey) => { + const sku = getBillingSku({ skuKey }); + const subscription = subscriptions.find((item) => item.skuKey === skuKey); + const active = subscription?.status === 'active' || subscription?.status === 'trialing'; + const included = sku.includedUsage; + + return ( +
+ + + {sku.name} + + {sku.description} + + + + + {formatAmount(sku.unitAmount)} + / month + + {included && ( + + {included.quantity} {formatUsageUnit(included.unit)} included monthly + + )} + + {active && subscription ? ( + + + {subscription.usedQuantity} of {subscription.includedQuantity} used this + period. + + + + ) : ( + + )} + +
+ ); + })} +
+ ); +} + +function formatAmount(amount: number) { + return new Intl.NumberFormat('en-US', { + style: 'currency', + currency: 'USD', + maximumFractionDigits: 0, + }).format(amount / 100); +} + +function formatUsageUnit(unit: string) { + return unit === 'scan' ? 'scans' : 'background checks'; +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingUsageTable.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingUsageTable.tsx new file mode 100644 index 0000000000..b6f6c8004d --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingUsageTable.tsx @@ -0,0 +1,143 @@ +'use client'; + +import { Card, CardContent, CardHeader, CardTitle, Stack, Text } from '@trycompai/design-system'; +import type React from 'react'; +import type { BackgroundCheckBillingStatus, BillingUsageRow } from './types'; + +interface BillingUsageTableProps { + subscriptions: NonNullable; + usageRows: BillingUsageRow[]; +} + +export function BillingUsageTable({ subscriptions, usageRows }: BillingUsageTableProps) { + return ( + +
+ item.skuKey === 'pentest_monthly_5')} + /> + item.skuKey === 'background_checks_monthly_25', + )} + /> +
+
+
+ + + Run history + + + Paid service runs for this organization. + + +
+
+
+ + + Service + Details + Billing + Status + Run date + Subscription remaining + + + + {usageRows.map((row) => ( + + ))} + {usageRows.length === 0 && ( + + + + )} + +
+ + No paid service runs yet. + +
+
+
+ + {usageRows.length} service run{usageRows.length === 1 ? '' : 's'} + +
+
+ + ); +} + +function AllowanceCard({ + label, + subscription, +}: { + label: string; + subscription?: NonNullable[number]; +}) { + const remaining = subscription + ? Math.max(subscription.includedQuantity - subscription.usedQuantity, 0) + : null; + + return ( + + + {label} + + + + + {remaining ?? 0} + + + {subscription + ? `${subscription.usedQuantity} of ${subscription.includedQuantity} used this period.` + : 'No active monthly subscription.'} + + {subscription?.currentPeriodEnd && ( + + Renews {formatDate(subscription.currentPeriodEnd)} + + )} + + + + ); +} + +function UsageRow({ row }: { row: BillingUsageRow }) { + return ( + + {row.service} + {row.details} + {row.billingType} + {row.status} + {formatDate(row.createdAt)} + {formatRemaining(row)} + + ); +} + +function TableHead({ children }: { children: React.ReactNode }) { + return {children}; +} + +function formatRemaining(row: BillingUsageRow) { + if (row.subscriptionRemaining === null || row.subscriptionIncluded === null) { + return 'No subscription'; + } + return `${row.subscriptionRemaining} of ${row.subscriptionIncluded}`; +} + +function formatDate(date: string) { + return new Intl.DateTimeFormat('en-US', { + month: 'short', + day: 'numeric', + year: 'numeric', + }).format(new Date(date)); +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/add-ons/[addOn]/page.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/add-ons/[addOn]/page.tsx new file mode 100644 index 0000000000..20d48209c5 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/add-ons/[addOn]/page.tsx @@ -0,0 +1,40 @@ +import { serverApi } from '@/lib/api-server'; +import type { Metadata } from 'next'; +import { notFound } from 'next/navigation'; +import { BillingAddOnPlansClient } from '../../BillingAddOnPlansClient'; +import { emptyBillingStatus } from '../../emptyBillingStatus'; +import { getBillingAddOn } from '../../billingAddOns'; +import type { BackgroundCheckBillingStatus } from '../../types'; + +export default async function BillingAddOnPage({ + params, +}: { + params: Promise<{ orgId: string; addOn: string }>; +}) { + const { orgId, addOn: addOnSlug } = await params; + const addOn = getBillingAddOn(addOnSlug); + if (!addOn) notFound(); + + const response = await serverApi.get('/v1/billing/status'); + + return ( + + ); +} + +export async function generateMetadata({ + params, +}: { + params: Promise<{ addOn: string }>; +}): Promise { + const { addOn: addOnSlug } = await params; + const addOn = getBillingAddOn(addOnSlug); + + return { + title: addOn ? `${addOn.detailTitle} Billing` : 'Billing Add-on', + }; +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/billingAddOns.ts b/apps/app/src/app/(app)/[orgId]/settings/billing/billingAddOns.ts new file mode 100644 index 0000000000..4095f441af --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/billingAddOns.ts @@ -0,0 +1,37 @@ +import type { BillingSkuKey as CatalogBillingSkuKey } from '@trycompai/billing'; + +export type BillingSkuKey = CatalogBillingSkuKey; + +export type BillingAddOnSlug = 'penetration-tests' | 'background-checks'; + +export interface BillingAddOn { + slug: BillingAddOnSlug; + name: string; + detailTitle: string; + description: string; + summary: string; + skuKeys: readonly BillingSkuKey[]; +} + +export const billingAddOns = [ + { + slug: 'penetration-tests', + name: 'Penetration Tests', + detailTitle: 'Penetration Test', + description: 'Run security scans with monthly included usage and centralized billing.', + summary: 'Plans from $399/mo', + skuKeys: ['pentest_monthly_5'], + }, + { + slug: 'background-checks', + name: 'Background Checks', + detailTitle: 'Background Checks', + description: 'Verify employees with subscription allowances or one-off checks as needed.', + summary: 'Plans from $249/mo', + skuKeys: ['background_checks_monthly_25'], + }, +] as const satisfies readonly BillingAddOn[]; + +export function getBillingAddOn(slug: string) { + return billingAddOns.find((addOn) => addOn.slug === slug) ?? null; +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/billingPreferencesFormSchema.ts b/apps/app/src/app/(app)/[orgId]/settings/billing/billingPreferencesFormSchema.ts new file mode 100644 index 0000000000..13245bedfa --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/billingPreferencesFormSchema.ts @@ -0,0 +1,128 @@ +import { z } from 'zod'; +import type { BillingPreferences } from './types'; + +export const taxIdTypes = [ + { value: 'none', label: 'No tax ID' }, + { value: 'gb_vat', label: 'UK VAT' }, + { value: 'eu_vat', label: 'EU VAT' }, + { value: 'us_ein', label: 'US EIN' }, + { value: 'au_abn', label: 'Australia ABN' }, + { value: 'ca_bn', label: 'Canada BN' }, + { value: 'nz_gst', label: 'New Zealand GST' }, + { value: 'sg_gst', label: 'Singapore GST' }, + { value: 'sg_uen', label: 'Singapore UEN' }, +] as const; + +export const billingCountries = [ + { value: 'none', label: 'No country' }, + { value: 'AU', label: 'Australia' }, + { value: 'AT', label: 'Austria' }, + { value: 'BE', label: 'Belgium' }, + { value: 'BR', label: 'Brazil' }, + { value: 'CA', label: 'Canada' }, + { value: 'DK', label: 'Denmark' }, + { value: 'FI', label: 'Finland' }, + { value: 'FR', label: 'France' }, + { value: 'DE', label: 'Germany' }, + { value: 'HK', label: 'Hong Kong' }, + { value: 'IN', label: 'India' }, + { value: 'IE', label: 'Ireland' }, + { value: 'IT', label: 'Italy' }, + { value: 'JP', label: 'Japan' }, + { value: 'NL', label: 'Netherlands' }, + { value: 'NZ', label: 'New Zealand' }, + { value: 'NO', label: 'Norway' }, + { value: 'PT', label: 'Portugal' }, + { value: 'SG', label: 'Singapore' }, + { value: 'ZA', label: 'South Africa' }, + { value: 'ES', label: 'Spain' }, + { value: 'SE', label: 'Sweden' }, + { value: 'CH', label: 'Switzerland' }, + { value: 'AE', label: 'United Arab Emirates' }, + { value: 'GB', label: 'United Kingdom' }, + { value: 'US', label: 'United States' }, +] as const; + +const taxIdTypeValues = taxIdTypes.map((type) => type.value) as [ + string, + ...string[], +]; + +export const billingPreferencesSchema = z + .object({ + companyName: z.string().trim().min(1, 'Company name is required').max(150), + billingEmail: z.string().trim().email('Enter a valid billing email').max(512), + purchaseOrder: z.string().trim().max(140).optional(), + addressLine1: z.string().trim().max(200).optional(), + addressLine2: z.string().trim().max(200).optional(), + addressCity: z.string().trim().max(100).optional(), + addressState: z.string().trim().max(100).optional(), + addressPostalCode: z.string().trim().max(32).optional(), + addressCountry: z + .string() + .trim() + .transform((value) => { + const normalizedValue = value.toUpperCase(); + return normalizedValue === 'NONE' ? '' : normalizedValue; + }) + .refine((value) => value === '' || /^[A-Z]{2}$/.test(value), { + message: 'Use a 2-letter country code', + }), + taxIdType: z.enum(taxIdTypeValues), + taxIdValue: z.string().trim().max(64).optional(), + }) + .refine((value) => value.taxIdType === 'none' || !!value.taxIdValue, { + message: 'Enter the tax ID value', + path: ['taxIdValue'], + }) + .refine((value) => value.taxIdType !== 'none' || !value.taxIdValue, { + message: 'Choose the tax ID type', + path: ['taxIdType'], + }); + +export type BillingPreferencesFormValues = z.infer; + +export function toBillingPreferencesFormValues( + preferences: BillingPreferences | null, +): BillingPreferencesFormValues { + return { + companyName: preferences?.companyName ?? '', + billingEmail: preferences?.billingEmail ?? '', + purchaseOrder: preferences?.purchaseOrder ?? '', + addressLine1: preferences?.address.line1 ?? '', + addressLine2: preferences?.address.line2 ?? '', + addressCity: preferences?.address.city ?? '', + addressState: preferences?.address.state ?? '', + addressPostalCode: preferences?.address.postalCode ?? '', + addressCountry: preferences?.address.country?.toUpperCase() ?? '', + taxIdType: preferences?.taxId?.type ?? 'none', + taxIdValue: preferences?.taxId?.value ?? '', + }; +} + +export function toBillingPreferencesPayload(values: BillingPreferencesFormValues) { + return { + companyName: values.companyName, + billingEmail: values.billingEmail, + purchaseOrder: values.purchaseOrder ?? '', + addressLine1: values.addressLine1 ?? '', + addressLine2: values.addressLine2 ?? '', + addressCity: values.addressCity ?? '', + addressState: values.addressState ?? '', + addressPostalCode: values.addressPostalCode ?? '', + addressCountry: values.addressCountry ?? '', + taxIdType: values.taxIdType === 'none' ? '' : values.taxIdType, + taxIdValue: values.taxIdValue ?? '', + }; +} + +export function getTaxIdTypeLabel(value: string) { + return taxIdTypes.find((type) => type.value === value)?.label ?? value; +} + +export function getCountryLabel(value: string) { + if (!value) return 'No country'; + const normalizedValue = value.toUpperCase(); + const country = billingCountries.find((item) => item.value === normalizedValue); + return country ? `${country.label} (${country.value})` : normalizedValue; +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts b/apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts new file mode 100644 index 0000000000..1947baee7e --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts @@ -0,0 +1,14 @@ +import type { BackgroundCheckBillingStatus } from './types'; + +export const emptyBillingStatus: BackgroundCheckBillingStatus = { + hasPaymentMethod: false, + setupAt: null, + usage: { + backgroundChecks: 0, + penetrationTests: 0, + }, + invoices: [], + subscriptions: [], + usageRows: [], + preferences: null, +}; diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx index bf72e67da4..e1198c1f3d 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx @@ -1,23 +1,12 @@ import { serverApi } from '@/lib/api-server'; import type { Metadata } from 'next'; import { BillingSettingsClient } from './BillingSettingsClient'; +import { emptyBillingStatus } from './emptyBillingStatus'; import type { BackgroundCheckBillingStatus } from './types'; -const emptyBillingStatus: BackgroundCheckBillingStatus = { - hasPaymentMethod: false, - setupAt: null, - usage: { - backgroundChecks: 0, - penetrationTests: 0, - }, - invoices: [], -}; - export default async function BillingPage({ params }: { params: Promise<{ orgId: string }> }) { const { orgId } = await params; - const response = await serverApi.get( - '/v1/background-check-billing/status', - ); + const response = await serverApi.get('/v1/billing/status'); return ( ; invoices?: BillingInvoice[]; } + +export interface BillingPreferences { + companyName: string | null; + billingEmail: string | null; + purchaseOrder: string | null; + address: { + line1: string | null; + line2: string | null; + city: string | null; + state: string | null; + postalCode: string | null; + country: string | null; + }; + taxId: { + id: string; + type: string; + value: string; + verificationStatus: string | null; + } | null; +} + +export interface BillingUsageRow { + id: string; + service: 'Penetration Test' | 'Background Check'; + skuKey: string; + details: string; + status: string; + billingType: string; + createdAt: string; + updatedAt: string; + subscriptionRemaining: number | null; + subscriptionIncluded: number | null; + subscriptionPeriodEnd: string | null; +} diff --git a/apps/framework-editor/package.json b/apps/framework-editor/package.json index 940d6925b4..72457f1fd6 100644 --- a/apps/framework-editor/package.json +++ b/apps/framework-editor/package.json @@ -42,7 +42,7 @@ "private": true, "scripts": { "build": "next build", - "db:generate": "find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\; && prisma generate --schema=prisma/schema && node ../../packages/db/scripts/fix-generated-extensions.js src/generated/prisma", + "db:generate": "find prisma/schema -name '*.prisma' ! -name 'schema.prisma' -delete && find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\; && prisma generate --schema=prisma/schema && node ../../packages/db/scripts/fix-generated-extensions.js src/generated/prisma", "dev": "next dev --port 3004", "lint": "echo 'no lint configured'", "prebuild": "bun run db:generate", diff --git a/apps/framework-editor/prisma/schema.prisma b/apps/framework-editor/prisma/schema.prisma deleted file mode 100644 index c7f57a4879..0000000000 --- a/apps/framework-editor/prisma/schema.prisma +++ /dev/null @@ -1,2611 +0,0 @@ -generator client { - provider = "prisma-client-js" - engineType = "binary" - previewFeatures = ["postgresqlExtensions"] - binaryTargets = ["rhel-openssl-3.0.x", "native", "debian-openssl-3.0.x", "linux-musl-openssl-3.0.x", "linux-musl-arm64-openssl-3.0.x"] -} - -datasource db { - provider = "postgresql" - url = env("DATABASE_URL") - extensions = [pgcrypto] -} - - -// ===== attachments.prisma ===== -model Attachment { - id String @id @default(dbgenerated("generate_prefixed_cuid('att'::text)")) - name String - url String - type AttachmentType - entityId String - entityType AttachmentEntityType - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relationships - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - comment Comment? @relation(fields: [commentId], references: [id]) - commentId String? - - @@index([entityId, entityType]) -} - -enum AttachmentEntityType { - task - vendor - risk - comment - trust_nda - task_item -} - -enum AttachmentType { - image - video - audio - document - other -} - - -// ===== auth.prisma ===== -model User { - id String @id @default(dbgenerated("generate_prefixed_cuid('usr'::text)")) - name String - email String - emailVerified Boolean - image String? - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - lastLogin DateTime? - emailNotificationsUnsubscribed Boolean @default(false) - emailPreferences Json? @default("{\"policyNotifications\":true,\"taskReminders\":true,\"weeklyTaskDigest\":true,\"unassignedItemsNotifications\":true}") - role String? @default("user") - banned Boolean? - banReason String? - banExpires DateTime? - isPlatformAdmin Boolean @default(false) - - accounts Account[] - auditLog AuditLog[] - integrationResults IntegrationResult[] - invitations Invitation[] - members Member[] - sessions Session[] - fleetPolicyResults FleetPolicyResult[] - evidenceSubmissions EvidenceSubmission[] @relation("EvidenceSubmitter") - evidenceReviews EvidenceSubmission[] @relation("EvidenceReviewer") - adminFindings Finding[] @relation("AdminFindingCreator") - - @@unique([email]) -} - -model EmployeeTrainingVideoCompletion { - id String @id @default(dbgenerated("generate_prefixed_cuid('evc'::text)")) - completedAt DateTime? - videoId String - - memberId String - member Member @relation(fields: [memberId], references: [id], onDelete: Cascade) - - @@unique([memberId, videoId]) - @@index([memberId]) -} - -model Session { - id String @id @default(dbgenerated("generate_prefixed_cuid('ses'::text)")) - expiresAt DateTime - token String - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - ipAddress String? - userAgent String? - userId String - activeOrganizationId String? - impersonatedBy String? - user User @relation(fields: [userId], references: [id], onDelete: Cascade) - - @@unique([token]) -} - -model Account { - id String @id @default(dbgenerated("generate_prefixed_cuid('acc'::text)")) - accountId String - providerId String - userId String - user User @relation(fields: [userId], references: [id], onDelete: Cascade) - accessToken String? - refreshToken String? - idToken String? - accessTokenExpiresAt DateTime? - refreshTokenExpiresAt DateTime? - scope String? - password String? - createdAt DateTime - updatedAt DateTime -} - -model Verification { - id String @id @default(dbgenerated("generate_prefixed_cuid('ver'::text)")) - identifier String - value String - expiresAt DateTime - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt -} - -// JWT Plugin - Required by Better Auth JWT plugin -// https://www.better-auth.com/docs/plugins/jwt -model Jwks { - id String @id @default(dbgenerated("generate_prefixed_cuid('jwk'::text)")) - publicKey String - privateKey String - createdAt DateTime @default(now()) - expiresAt DateTime? - - @@map("jwks") -} - -model Member { - id String @id @default(dbgenerated("generate_prefixed_cuid('mem'::text)")) - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - userId String - user User @relation(fields: [userId], references: [id], onDelete: Cascade) - role String // Purposefully a string, since BetterAuth doesn't support enums this way - createdAt DateTime @default(now()) - - department Departments @default(none) - jobTitle String? - isActive Boolean @default(true) - deactivated Boolean @default(false) - externalUserId String? - externalUserSource String? - employeeTrainingVideoCompletion EmployeeTrainingVideoCompletion[] - fleetDmLabelId Int? - - assignedPolicies Policy[] @relation("PolicyAssignee") // Policies where this member is an assignee - approvedPolicies Policy[] @relation("PolicyApprover") // Policies where this member is an approver - approvedSOADocuments SOADocument[] @relation("SOADocumentApprover") // SOA documents where this member is an approver - risks Risk[] - tasks Task[] - vendors Vendor[] - comments Comment[] - auditLogs AuditLog[] - reviewedAccessRequests TrustAccessRequest[] @relation("TrustAccessRequestReviewer") - issuedGrants TrustAccessGrant[] @relation("IssuedGrants") - revokedGrants TrustAccessGrant[] @relation("RevokedGrants") - createdTaskItems TaskItem[] @relation("TaskItemCreator") - updatedTaskItems TaskItem[] @relation("TaskItemUpdater") - assignedTaskItems TaskItem[] @relation("TaskItemAssignee") - createdFindings Finding[] @relation("FindingCreatedBy") - publishedPolicyVersions PolicyVersion[] @relation("PolicyVersionPublisher") - approvedTasks Task[] @relation("TaskApprover") - devices Device[] -} - -model Invitation { - id String @id @default(dbgenerated("generate_prefixed_cuid('inv'::text)")) - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - email String - role String // Purposefully a string, since BetterAuth doesn't support enums this way - status String - expiresAt DateTime - inviterId String - user User @relation(fields: [inviterId], references: [id], onDelete: Cascade) - createdAt DateTime @default(now()) -} - -// This is only for the app to consume, shouldn't be enforced by DB -// Otherwise it won't work with Better Auth, as per https://www.better-auth.com/docs/plugins/organization#access-control -enum Role { - owner - admin - auditor - employee - contractor -} - -// Custom roles for dynamic access control -// This table stores organization-specific custom roles created via better-auth -// See: https://www.better-auth.com/docs/plugins/organization#dynamic-access-control -model OrganizationRole { - id String @id @default(dbgenerated("generate_prefixed_cuid('rol'::text)")) - name String - permissions String @db.Text // Stored as serialized JSON string for better-auth compatibility - obligations String @default("{}") @db.Text // JSON: { compliance?: boolean } - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@unique([organizationId, name]) - @@map("organization_role") -} - -enum PolicyStatus { - draft - published - needs_review -} - - -// ===== automation-run.prisma ===== -model EvidenceAutomationRun { - id String @id @default(dbgenerated("generate_prefixed_cuid('ear'::text)")) - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relations - evidenceAutomationId String - evidenceAutomation EvidenceAutomation @relation(fields: [evidenceAutomationId], references: [id], onDelete: Cascade) - - // Run details - status EvidenceAutomationRunStatus @default(pending) - startedAt DateTime? - completedAt DateTime? - - // Results - success Boolean? - error String? - logs Json? - output Json? - - // Evaluation - evaluationStatus EvidenceAutomationEvaluationStatus? - evaluationReason String? - - // Metadata - triggeredBy EvidenceAutomationTrigger @default(scheduled) - runDuration Int? // in milliseconds - version Int? // Version number that was executed (null = draft) - task Task? @relation(fields: [taskId], references: [id]) - taskId String? - - @@index([evidenceAutomationId]) - @@index([status]) - @@index([createdAt]) - @@index([version]) -} - -enum EvidenceAutomationRunStatus { - pending - running - completed - failed - cancelled -} - -enum EvidenceAutomationTrigger { - manual - scheduled - api -} - -enum EvidenceAutomationEvaluationStatus { - pass - fail -} - - -// ===== automation-version.prisma ===== -model EvidenceAutomationVersion { - id String @id @default(dbgenerated("generate_prefixed_cuid('eav'::text)")) - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relations - evidenceAutomationId String - evidenceAutomation EvidenceAutomation @relation(fields: [evidenceAutomationId], references: [id], onDelete: Cascade) - - // Version details - version Int // Sequential version number (1, 2, 3...) - scriptKey String // S3 key for this version's script - publishedBy String? // User ID who published - changelog String? // Optional description of changes - - @@unique([evidenceAutomationId, version]) - @@index([evidenceAutomationId]) - @@index([createdAt]) -} - - -// ===== automation.prisma ===== -model EvidenceAutomation { - id String @id @default(dbgenerated("generate_prefixed_cuid('aut'::text)")) - name String - description String? - createdAt DateTime @default(now()) - isEnabled Boolean @default(false) - - chatHistory String? - evaluationCriteria String? - - taskId String - task Task @relation(fields: [taskId], references: [id], onDelete: Cascade) - - // Relations - runs EvidenceAutomationRun[] - versions EvidenceAutomationVersion[] - - @@index([taskId]) -} - - -// ===== browserbase-context.prisma ===== -/// Stores Browserbase context IDs for browser-based automation -/// One context per organization - shared like a normal browser -model BrowserbaseContext { - id String @id @default(dbgenerated("generate_prefixed_cuid('bbc'::text)")) - - /// Organization that owns this browser context - organizationId String @unique - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - /// Browserbase context ID from their API - contextId String - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@index([organizationId]) -} - -/// Browser automation configuration linked to a task -model BrowserAutomation { - id String @id @default(dbgenerated("generate_prefixed_cuid('bau'::text)")) - name String - description String? - - /// Task this automation belongs to - taskId String - task Task @relation(fields: [taskId], references: [id], onDelete: Cascade) - - /// Starting URL for the automation - targetUrl String - - /// Natural language instruction for the AI agent - instruction String - - /// Whether automation is enabled for scheduled runs - isEnabled Boolean @default(false) - - /// Cron expression for scheduled runs (null = manual only) - schedule String? - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - runs BrowserAutomationRun[] - - @@index([taskId]) -} - -/// Records of browser automation executions -model BrowserAutomationRun { - id String @id @default(dbgenerated("generate_prefixed_cuid('bar'::text)")) - - /// Parent automation - automationId String - automation BrowserAutomation @relation(fields: [automationId], references: [id], onDelete: Cascade) - - /// Execution status - status BrowserAutomationRunStatus @default(pending) - - /// Timestamps - startedAt DateTime? - completedAt DateTime? - - /// Duration in milliseconds - durationMs Int? - - /// Screenshot URL in S3 (if successful) - screenshotUrl String? - - /// Evaluation result - whether the automation fulfilled the task requirements - evaluationStatus BrowserAutomationEvaluationStatus? - - /// AI explanation of why it passed or failed - evaluationReason String? - - /// Error message (if failed) - error String? - - createdAt DateTime @default(now()) - - @@index([automationId]) - @@index([status]) - @@index([createdAt]) -} - -enum BrowserAutomationEvaluationStatus { - pass - fail -} - -enum BrowserAutomationRunStatus { - pending - running - completed - failed -} - - -// ===== comment.prisma ===== -model Comment { - id String @id @default(dbgenerated("generate_prefixed_cuid('cmt'::text)")) - content String - entityId String - entityType CommentEntityType - - // Dates - createdAt DateTime @default(now()) - - // Relationships - authorId String - author Member @relation(fields: [authorId], references: [id]) - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - // Relation to Attachments - attachments Attachment[] - - @@index([entityId]) -} - -enum CommentEntityType { - task - vendor - risk - policy -} - - -// ===== context.prisma ===== -model Context { - id String @id @default(dbgenerated("generate_prefixed_cuid('ctx'::text)")) - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - question String - answer String - - tags String[] - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@index([organizationId]) - @@index([question]) - @@index([answer]) - @@index([tags]) -} - - -// ===== control-document-type.prisma ===== -model ControlDocumentType { - id String @id @default(dbgenerated("generate_prefixed_cuid('cdt'::text)")) - controlId String - control Control @relation(fields: [controlId], references: [id], onDelete: Cascade) - formType EvidenceFormType - - @@unique([controlId, formType]) - @@index([controlId]) -} - - -// ===== control.prisma ===== -model Control { - // Metadata - id String @id @default(dbgenerated("generate_prefixed_cuid('ctl'::text)")) - name String - description String - - // Review dates - lastReviewDate DateTime? - nextReviewDate DateTime? - - // Relationships - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - organizationId String - requirementsMapped RequirementMap[] - tasks Task[] - policies Policy[] - controlTemplateId String? - controlTemplate FrameworkEditorControlTemplate? @relation(fields: [controlTemplateId], references: [id]) - controlDocumentTypes ControlDocumentType[] - - @@index([organizationId]) -} - - -// ===== device.prisma ===== -model Device { - id String @id @default(dbgenerated("generate_prefixed_cuid('dev'::text)")) - name String - hostname String - platform DevicePlatform - osVersion String - serialNumber String? - hardwareModel String? - - memberId String - member Member @relation(fields: [memberId], references: [id], onDelete: Cascade) - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - isCompliant Boolean @default(false) - diskEncryptionEnabled Boolean @default(false) - antivirusEnabled Boolean @default(false) - passwordPolicySet Boolean @default(false) - screenLockEnabled Boolean @default(false) - checkDetails Json? - - lastCheckIn DateTime? - agentVersion String? - installedAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@unique([serialNumber, organizationId]) - @@index([memberId]) - @@index([organizationId]) - @@index([isCompliant]) -} - -enum DevicePlatform { - macos - windows - linux -} - - -// ===== dynamic-integration.prisma ===== -// ===== Dynamic Integration Platform ===== -// Stores integration manifests and declarative check definitions in the database -// Enables adding new integrations without code changes or deployments - -/// Stores a full integration manifest as JSON — replaces hand-written TypeScript manifests -model DynamicIntegration { - id String @id @default(dbgenerated("generate_prefixed_cuid('din'::text)")) - /// Unique slug (e.g., "azure-devops", "office-365") - slug String @unique - /// Display name - name String - /// Short description for catalog - description String - /// Category for grouping - category String - /// Logo URL - logoUrl String - /// URL to documentation - docsUrl String? - - /// API base URL for ctx.fetch - baseUrl String? - /// Default headers (JSON object) - defaultHeaders Json? - - /// Auth strategy config (JSON — matches AuthStrategy type: oauth2/api_key/basic/jwt/custom) - authConfig Json - - /// Capabilities JSON array (default ["checks"]) - capabilities Json @default("[\"checks\"]") - - /// Whether multiple connections per org are allowed - supportsMultipleConnections Boolean @default(false) - - /// Declarative sync definition (JSON — DSL steps that produce employee list) - /// When present and capabilities includes 'sync', enables employee sync - syncDefinition Json? - - /// Whether this dynamic integration is active - isActive Boolean @default(true) - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - checks DynamicCheck[] - - @@index([slug]) - @@index([category]) - @@index([isActive]) -} - -/// Stores a declarative check definition — DSL JSON replaces hand-written run() functions -model DynamicCheck { - id String @id @default(dbgenerated("generate_prefixed_cuid('dck'::text)")) - - /// Parent integration - integrationId String - integration DynamicIntegration @relation(fields: [integrationId], references: [id], onDelete: Cascade) - - /// Unique slug within integration (e.g., "mfa_enabled") - checkSlug String - - /// Human-readable name - name String - /// Description of what this check does - description String - - /// Task template ID for auto-completion (references TASK_TEMPLATES) - taskMapping String? - - /// Default severity for findings - defaultSeverity String @default("medium") - - /// Declarative DSL definition (JSON — the step-by-step instructions) - definition Json - - /// Check-level variables (JSON array of CheckVariable) - variables Json @default("[]") - - /// Whether this check is enabled - isEnabled Boolean @default(true) - - /// Display order - sortOrder Int @default(0) - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@unique([integrationId, checkSlug]) - @@index([integrationId]) - @@index([isEnabled]) -} - - -// ===== evidence-submission.prisma ===== -model EvidenceSubmission { - id String @id @default(dbgenerated("generate_prefixed_cuid('evs'::text)")) - organizationId String - formType EvidenceFormType - submittedById String? - submittedAt DateTime @default(now()) - data Json - status String @default("pending") - reviewedById String? - reviewedAt DateTime? - reviewReason String? - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - submittedBy User? @relation("EvidenceSubmitter", fields: [submittedById], references: [id], onDelete: SetNull) - reviewedBy User? @relation("EvidenceReviewer", fields: [reviewedById], references: [id], onDelete: SetNull) - findings Finding[] - - @@index([organizationId, formType, submittedAt]) - @@index([organizationId, formType]) - @@index([submittedById, status]) -} - - -// ===== finding.prisma ===== -enum FindingType { - soc2 - iso27001 -} - -enum FindingStatus { - open - ready_for_review - needs_revision - closed -} - -model FindingTemplate { - id String @id @default(dbgenerated("generate_prefixed_cuid('fnd_t'::text)")) - category String // e.g., "evidence_issue", "further_evidence", "task_specific", "na_incorrect" - title String // Short title - content String // Full message template - order Int @default(0) - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - findings Finding[] -} - -model Finding { - id String @id @default(dbgenerated("generate_prefixed_cuid('fnd'::text)")) - type FindingType @default(soc2) - status FindingStatus @default(open) - content String // Custom message or copied from template - revisionNote String? // Auditor's note when requesting revision - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relationships - taskId String? - task Task? @relation(fields: [taskId], references: [id], onDelete: Cascade) - evidenceSubmissionId String? - evidenceSubmission EvidenceSubmission? @relation(fields: [evidenceSubmissionId], references: [id], onDelete: Cascade) - evidenceFormType EvidenceFormType? - templateId String? - template FindingTemplate? @relation(fields: [templateId], references: [id]) - createdById String? - createdBy Member? @relation("FindingCreatedBy", fields: [createdById], references: [id]) - createdByAdminId String? - createdByAdmin User? @relation("AdminFindingCreator", fields: [createdByAdminId], references: [id]) - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - @@index([taskId]) - @@index([evidenceSubmissionId]) - @@index([evidenceFormType]) - @@index([organizationId, status]) -} - - -// ===== fleet-policy-result.prisma ===== -model FleetPolicyResult { - id String @id @default(dbgenerated("generate_prefixed_cuid('fpr'::text)")) - userId String - organizationId String - fleetPolicyId Int - fleetPolicyName String - fleetPolicyResponse String - attachments String[] @default([]) - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - user User @relation(fields: [userId], references: [id], onDelete: Cascade) - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - @@index([userId]) - @@index([organizationId]) -} - - -// ===== framework-editor.prisma ===== -// --- Data for Framework Editor --- -model FrameworkEditorVideo { - id String @id @default(dbgenerated("generate_prefixed_cuid('frk_vi'::text)")) - title String - description String - youtubeId String - url String - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @default(now()) @updatedAt -} - -model FrameworkEditorFramework { - id String @id @default(dbgenerated("generate_prefixed_cuid('frk'::text)")) - name String // e.g., "soc2", "iso27001" - version String - description String - visible Boolean @default(false) - - requirements FrameworkEditorRequirement[] - frameworkInstances FrameworkInstance[] - soaConfigurations SOAFrameworkConfiguration[] // Multiple SOA config versions per framework - soaDocuments SOADocument[] // SOA documents from organizations - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @default(now()) @updatedAt -} - -model FrameworkEditorRequirement { - id String @id @default(dbgenerated("generate_prefixed_cuid('frk_rq'::text)")) - frameworkId String - framework FrameworkEditorFramework @relation(fields: [frameworkId], references: [id]) - - name String // Original requirement ID within that framework, e.g., "Privacy" - identifier String @default("") // Unique identifier for the requirement, e.g., "cc1-1" - description String - - controlTemplates FrameworkEditorControlTemplate[] - requirementMaps RequirementMap[] - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @default(now()) @updatedAt -} - -model FrameworkEditorPolicyTemplate { - id String @id @default(dbgenerated("generate_prefixed_cuid('frk_pt'::text)")) - name String - description String - frequency Frequency // Using the enum from shared.prisma - department Departments // Using the enum from shared.prisma - content Json - - controlTemplates FrameworkEditorControlTemplate[] - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @default(now()) @updatedAt - - // Instances - policies Policy[] -} - -model FrameworkEditorTaskTemplate { - id String @id @default(dbgenerated("generate_prefixed_cuid('frk_tt'::text)")) - name String - description String - frequency Frequency // Using the enum from shared.prisma - department Departments // Using the enum from shared.prisma - automationStatus TaskAutomationStatus @default(AUTOMATED) - - controlTemplates FrameworkEditorControlTemplate[] - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @default(now()) @updatedAt - - // Instances - tasks Task[] -} - -model FrameworkEditorControlTemplate { - id String @id @default(dbgenerated("generate_prefixed_cuid('frk_ct'::text)")) - name String - description String - - policyTemplates FrameworkEditorPolicyTemplate[] - requirements FrameworkEditorRequirement[] - taskTemplates FrameworkEditorTaskTemplate[] - documentTypes EvidenceFormType[] - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @default(now()) @updatedAt - - // Instances - controls Control[] -} - - -// ===== framework.prisma ===== -model FrameworkInstance { - // Metadata - id String @id @default(dbgenerated("generate_prefixed_cuid('frm'::text)")) - organizationId String - - frameworkId String - framework FrameworkEditorFramework @relation(fields: [frameworkId], references: [id], onDelete: Cascade) - - // Relationships - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - requirementsMapped RequirementMap[] - - @@unique([organizationId, frameworkId]) -} - - -// ===== integration-platform.prisma ===== -// ===== Integration Platform ===== -// New integration platform models for scalable, config-driven integrations - -/// Stores metadata about available integration providers (synced from code manifests) -model IntegrationProvider { - id String @id @default(dbgenerated("generate_prefixed_cuid('prv'::text)")) - /// Unique slug matching manifest ID (e.g., "github", "slack") - slug String @unique - /// Display name - name String - /// Category for grouping - category String - /// Hash of manifest for detecting changes - manifestHash String? - /// Capabilities JSON array - capabilities Json @default("[]") - /// Whether provider is active - isActive Boolean @default(true) - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - connections IntegrationConnection[] - - @@index([slug]) - @@index([category]) -} - -/// Represents an organization's connection to an integration provider -model IntegrationConnection { - id String @id @default(dbgenerated("generate_prefixed_cuid('icn'::text)")) - - /// Reference to the provider - providerId String - provider IntegrationProvider @relation(fields: [providerId], references: [id], onDelete: Cascade) - - /// Organization that owns this connection - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - /// Connection status - status IntegrationConnectionStatus @default(pending) - - /// Auth strategy used (oauth2, api_key, basic, jwt, custom) - authStrategy String - - /// Reference to active credential version - activeCredentialVersionId String? - - /// Last successful sync timestamp - lastSyncAt DateTime? - - /// Next scheduled sync timestamp - nextSyncAt DateTime? - - /// Custom sync cadence (cron expression), null = use default - syncCadence String? - - /// Additional metadata (e.g., connected account info) - metadata Json? - - /// User-configured variables for checks (collected after OAuth) - variables Json? - - /// Error message if status is error - errorMessage String? - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - credentialVersions IntegrationCredentialVersion[] - runs IntegrationRun[] - findings IntegrationPlatformFinding[] - checkRuns IntegrationCheckRun[] - syncLogs IntegrationSyncLog[] - - @@index([organizationId]) - @@index([providerId]) - @@index([providerId, organizationId]) - @@index([status]) -} - -enum IntegrationConnectionStatus { - pending // Awaiting credential setup - active // Connected and operational - error // Connection has errors - paused // Manually paused by user - disconnected // User disconnected -} - -/// Stores encrypted credentials with versioning for audit trail -model IntegrationCredentialVersion { - id String @id @default(dbgenerated("generate_prefixed_cuid('icv'::text)")) - - /// Parent connection - connectionId String - connection IntegrationConnection @relation(fields: [connectionId], references: [id], onDelete: Cascade) - - /// Encrypted credential payload (JSON with encrypted fields) - encryptedPayload Json - - /// Version number (auto-increment per connection) - version Int - - /// Token expiration (for OAuth tokens) - expiresAt DateTime? - - /// When this version was rotated/replaced - rotatedAt DateTime? - - createdAt DateTime @default(now()) - - @@unique([connectionId, version]) - @@index([connectionId]) -} - -/// Records each sync/job execution for audit and debugging -model IntegrationRun { - id String @id @default(dbgenerated("generate_prefixed_cuid('irn'::text)")) - - /// Parent connection - connectionId String - connection IntegrationConnection @relation(fields: [connectionId], references: [id], onDelete: Cascade) - - /// Type of job - jobType IntegrationRunJobType - - /// Execution status - status IntegrationRunStatus @default(pending) - - /// Timestamps - startedAt DateTime? - completedAt DateTime? - - /// Duration in milliseconds - durationMs Int? - - /// Number of findings from this run - findingsCount Int @default(0) - - /// Error details if failed - error Json? - - /// Additional metadata (trigger source, cursor, etc.) - metadata Json? - - createdAt DateTime @default(now()) - - findings IntegrationPlatformFinding[] - - @@index([connectionId]) - @@index([status]) - @@index([createdAt]) -} - -enum IntegrationRunJobType { - full_sync - delta_sync - webhook - manual - test_connection -} - -enum IntegrationRunStatus { - pending - running - success - failed - cancelled -} - -/// Stores findings/results from integration syncs -model IntegrationPlatformFinding { - id String @id @default(dbgenerated("generate_prefixed_cuid('ipf'::text)")) - - /// Parent run (optional - webhooks may not have runs) - runId String? - run IntegrationRun? @relation(fields: [runId], references: [id], onDelete: SetNull) - - /// Parent connection - connectionId String - connection IntegrationConnection @relation(fields: [connectionId], references: [id], onDelete: Cascade) - - /// Resource classification - resourceType String - resourceId String - - /// Finding details - title String - description String? - - /// Severity level - severity IntegrationFindingSeverity @default(info) - - /// Finding status - status IntegrationFindingStatus @default(open) - - /// Remediation guidance - remediation String? - - /// Raw payload from provider - rawPayload Json? - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@index([connectionId]) - @@index([runId]) - @@index([resourceType, resourceId]) - @@index([severity]) - @@index([status]) -} - -enum IntegrationFindingSeverity { - info - low - medium - high - critical -} - -enum IntegrationFindingStatus { - open - resolved - ignored -} - -/// Stores OAuth state for CSRF protection during OAuth flow -model IntegrationOAuthState { - id String @id @default(dbgenerated("generate_prefixed_cuid('ios'::text)")) - - /// Random state parameter - state String @unique - - /// Provider slug - providerSlug String - - /// Organization initiating the OAuth - organizationId String - - /// User initiating the OAuth - userId String - - /// PKCE code verifier (if using PKCE) - codeVerifier String? - - /// Redirect URL after OAuth completes - redirectUrl String? - - /// Expiration timestamp - expiresAt DateTime - - createdAt DateTime @default(now()) - - @@index([state]) - @@index([expiresAt]) -} - -/// Stores organization-level OAuth app credentials -/// Allows orgs (especially self-hosters) to use their own OAuth apps -model IntegrationOAuthApp { - id String @id @default(dbgenerated("generate_prefixed_cuid('ioa'::text)")) - - /// Provider slug (e.g., "github", "slack") - providerSlug String - - /// Organization that owns this OAuth app config - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - /// Encrypted client ID - encryptedClientId Json - - /// Encrypted client secret - encryptedClientSecret Json - - /// Optional: custom scopes (overrides manifest defaults) - customScopes String[] - - /// Provider-specific settings (e.g., Rippling app name for authorize URL) - /// Stored as JSON: { "appName": "compai533c" } - customSettings Json? - - /// Whether this config is active - isActive Boolean @default(true) - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@unique([providerSlug, organizationId]) - @@index([organizationId]) - @@index([providerSlug]) -} - -/// Records check runs linked to tasks for compliance verification -model IntegrationCheckRun { - id String @id @default(dbgenerated("generate_prefixed_cuid('icr'::text)")) - - /// Parent connection - connectionId String - connection IntegrationConnection @relation(fields: [connectionId], references: [id], onDelete: Cascade) - - /// Task being verified (optional - checks can run without a task) - taskId String? - task Task? @relation(fields: [taskId], references: [id], onDelete: SetNull) - - /// Check ID from the manifest - checkId String - - /// Check name (denormalized for display) - checkName String - - /// Execution status - status IntegrationRunStatus @default(pending) - - /// Timestamps - startedAt DateTime? - completedAt DateTime? - - /// Duration in milliseconds - durationMs Int? - - /// Summary counts - totalChecked Int @default(0) - passedCount Int @default(0) - failedCount Int @default(0) - - /// Error message if failed - errorMessage String? - - /// Full execution logs (JSON array) - logs Json? - - createdAt DateTime @default(now()) - - /// Results from this check run - results IntegrationCheckResult[] - - @@index([connectionId]) - @@index([taskId]) - @@index([checkId]) - @@index([status]) - @@index([createdAt]) -} - -/// Stores individual results (pass/fail) from check runs -model IntegrationCheckResult { - id String @id @default(dbgenerated("generate_prefixed_cuid('icx'::text)")) - - /// Parent check run - checkRunId String - checkRun IntegrationCheckRun @relation(fields: [checkRunId], references: [id], onDelete: Cascade) - - /// Whether this result is a pass or fail - passed Boolean - - /// Resource classification - resourceType String - resourceId String - - /// Result details - title String - description String? - - /// Severity (for failures) - severity IntegrationFindingSeverity? - - /// Remediation guidance (for failures) - remediation String? - - /// Evidence/proof (JSON - API response data) - evidence Json? - - /// When this evidence was collected - collectedAt DateTime @default(now()) - - @@index([checkRunId]) - @@index([passed]) - @@index([resourceType, resourceId]) -} - -/// Stores platform-wide OAuth app credentials -/// Used by platform operators to provide default OAuth apps for all users -model IntegrationPlatformCredential { - id String @id @default(dbgenerated("generate_prefixed_cuid('ipc'::text)")) - - /// Provider slug (e.g., "github", "slack") - unique per platform - providerSlug String @unique - - /// Encrypted client ID - encryptedClientId Json - - /// Encrypted client secret - encryptedClientSecret Json - - /// Masked display hint for client ID (computed at write time) - clientIdHint String? - - /// Masked display hint for client secret (computed at write time) - clientSecretHint String? - - /// Optional: custom scopes (overrides manifest defaults) - customScopes String[] - - /// Provider-specific settings (e.g., Rippling app name for authorize URL) - /// Stored as JSON: { "appName": "compai533c" } - customSettings Json? - - /// Whether this credential is active - isActive Boolean @default(true) - - /// Who created this credential - createdById String? - - /// Who last updated this credential - updatedById String? - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@index([providerSlug]) -} - - -// ===== integration-sync-log.prisma ===== -// ===== Integration Sync Log ===== -// Generic audit trail for integration sync operations (employee sync, role discovery, etc.) - -model IntegrationSyncLog { - id String @id @default(dbgenerated("generate_prefixed_cuid('isl'::text)")) - connectionId String - connection IntegrationConnection @relation(fields: [connectionId], references: [id], onDelete: Cascade) - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - /// Provider slug (e.g., "google-workspace", "rippling", "jumpcloud") - provider String - /// Event type (e.g., "employee_sync", "role_discovery", "role_mapping_save") - eventType String - /// Execution status - status IntegrationSyncLogStatus @default(pending) - /// When the operation started executing - startedAt DateTime? - /// When the operation completed (success or failure) - completedAt DateTime? - /// Duration in milliseconds - durationMs Int? - /// Flexible result payload (e.g., { imported, deactivated, reactivated, skipped, errors }) - result Json? - /// Error message if failed - error String? - /// How the sync was triggered: "manual", "scheduled", "api" - triggeredBy String? - /// User who triggered the sync (null for automated/cron) - userId String? - - createdAt DateTime @default(now()) - - @@index([connectionId]) - @@index([organizationId]) - @@index([provider]) - @@index([createdAt]) -} - -enum IntegrationSyncLogStatus { - pending - running - success - failed -} - - -// ===== integration.prisma ===== -model Integration { - id String @id @default(dbgenerated("generate_prefixed_cuid('int'::text)")) - name String - integrationId String - settings Json - userSettings Json - organizationId String - lastRunAt DateTime? - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - results IntegrationResult[] - - @@index([organizationId]) -} - -model IntegrationResult { - id String @id @default(dbgenerated("generate_prefixed_cuid('itr'::text)")) - title String? - description String? - remediation String? - status String? - severity String? - resultDetails Json? - completedAt DateTime? @default(now()) - integrationId String - organizationId String - assignedUserId String? - - assignedUser User? @relation(fields: [assignedUserId], references: [id], onDelete: Cascade) - integration Integration @relation(fields: [integrationId], references: [id], onDelete: Cascade) - - @@index([integrationId]) -} - - -// ===== knowledge-base-document.prisma ===== -model KnowledgeBaseDocument { - id String @id @default(dbgenerated("generate_prefixed_cuid('kbd'::text)")) - name String // Original filename - description String? // Optional user description/notes - s3Key String // S3 storage key (e.g., "org123/knowledge-base-documents/timestamp-file.pdf") - fileType String // MIME type (e.g., "application/pdf") - fileSize Int // File size in bytes - processingStatus KnowledgeBaseDocumentProcessingStatus @default(pending) // Track indexing status - processedAt DateTime? // When indexing completed - triggerRunId String? // Trigger.dev run ID for tracking processing progress - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relationships - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - @@index([organizationId]) - @@index([organizationId, processingStatus]) - @@index([s3Key]) - @@index([triggerRunId]) -} - -enum KnowledgeBaseDocumentProcessingStatus { - pending // Uploaded but not yet processed/indexed - processing // Currently being processed/indexed - completed // Successfully indexed in vector database - failed // Processing failed -} - - -// ===== notification-policy.prisma ===== -model RoleNotificationSetting { - id String @id @default(dbgenerated("generate_prefixed_cuid('rns'::text)")) - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - role String // "owner", "admin", "auditor", "employee", "contractor", or custom role name - - policyNotifications Boolean @default(true) - taskReminders Boolean @default(true) - taskAssignments Boolean @default(true) - taskMentions Boolean @default(true) - weeklyTaskDigest Boolean @default(true) - findingNotifications Boolean @default(true) - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@unique([organizationId, role]) - @@map("role_notification_setting") -} - - -// ===== onboarding.prisma ===== -model Onboarding { - organizationId String @id - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - policies Boolean @default(false) - employees Boolean @default(false) - vendors Boolean @default(false) - integrations Boolean @default(false) - risk Boolean @default(false) - team Boolean @default(false) - tasks Boolean @default(false) - callBooked Boolean @default(false) - companyBookingDetails Json? - companyDetails Json? - triggerJobId String? - triggerJobCompleted Boolean @default(false) - - @@index([organizationId]) -} - - -// ===== org-chart.prisma ===== -model OrganizationChart { - id String @id @default(dbgenerated("generate_prefixed_cuid('och'::text)")) - organizationId String @unique - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - name String @default("Organization Chart") - type String @default("interactive") // "interactive" or "uploaded" - nodes Json @default("[]") - edges Json @default("[]") - uploadedImageUrl String? // S3 key when type="uploaded" - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@index([organizationId]) -} - - -// ===== organization-billing.prisma ===== -model OrganizationBilling { - id String @id @default(dbgenerated("generate_prefixed_cuid('obil'::text)")) - organizationId String @unique @map("organization_id") - stripeCustomerId String @map("stripe_customer_id") - createdAt DateTime @default(now()) @map("created_at") - updatedAt DateTime @updatedAt @map("updated_at") - - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - pentestSubscription PentestSubscription? - - @@map("organization_billing") -} - - -// ===== organization.prisma ===== -model Organization { - id String @id @default(dbgenerated("generate_prefixed_cuid('org'::text)")) - name String - slug String @unique @default(dbgenerated("generate_prefixed_cuid('slug'::text)")) - logo String? - createdAt DateTime @default(now()) - metadata String? - onboarding Onboarding? - website String? - onboardingCompleted Boolean @default(false) - hasAccess Boolean @default(false) - advancedModeEnabled Boolean @default(false) - evidenceApprovalEnabled Boolean @default(false) - deviceAgentStepEnabled Boolean @default(true) - securityTrainingStepEnabled Boolean @default(true) - whistleblowerReportEnabled Boolean @default(true) - accessRequestFormEnabled Boolean @default(true) - - // FleetDM - fleetDmLabelId Int? - isFleetSetupCompleted Boolean @default(false) - - // Employee sync provider (e.g., 'google-workspace', 'rippling') - // When set, the scheduled sync will only use this provider - employeeSyncProvider String? - - apiKeys ApiKey[] - auditLog AuditLog[] - controls Control[] - frameworkInstances FrameworkInstance[] - integrations Integration[] - invitations Invitation[] - members Member[] - policy Policy[] - risk Risk[] - vendors Vendor[] - tasks Task[] - taskItems TaskItem[] - comments Comment[] - attachments Attachment[] - evidenceSubmissions EvidenceSubmission[] - trust Trust[] - context Context[] - secrets Secret[] - securityPenetrationTestRuns SecurityPenetrationTestRun[] - trustAccessRequests TrustAccessRequest[] - trustNdaAgreements TrustNDAAgreement[] - trustDocuments TrustDocument[] - trustResources TrustResource[] @relation("OrganizationTrustResources") - trustCustomLinks TrustCustomLink[] - knowledgeBaseDocuments KnowledgeBaseDocument[] - questionnaires Questionnaire[] - securityQuestionnaireManualAnswers SecurityQuestionnaireManualAnswer[] - soaDocuments SOADocument[] - primaryColor String? - trustPortalFaqs Json? // Array of { question: string, answer: string, order: number } - - // Integration Platform - integrationConnections IntegrationConnection[] - integrationOAuthApps IntegrationOAuthApp[] - integrationSyncLogs IntegrationSyncLog[] - - // Pentest Subscription - pentestSubscription PentestSubscription? - billing OrganizationBilling? - - // Browser Automation - browserbaseContext BrowserbaseContext? - fleetPolicyResults FleetPolicyResult[] - - // Findings - findings Finding[] - - // Device Agent - devices Device[] - - // Org Chart - organizationChart OrganizationChart? - - // RBAC - organizationRoles OrganizationRole[] - roleNotificationSettings RoleNotificationSetting[] - - @@index([slug]) -} - - -// ===== pentest-subscription.prisma ===== -model PentestSubscription { - id String @id @default(dbgenerated("generate_prefixed_cuid('psub'::text)")) - organizationId String @unique @map("organization_id") - organizationBillingId String @unique @map("organization_billing_id") - stripeSubscriptionId String @map("stripe_subscription_id") - stripePriceId String @map("stripe_price_id") - stripeOveragePriceId String? @map("stripe_overage_price_id") - status String @default("active") // active | cancelled | past_due - includedRunsPerPeriod Int @default(3) @map("included_runs_per_period") - currentPeriodStart DateTime @map("current_period_start") - currentPeriodEnd DateTime @map("current_period_end") - createdAt DateTime @default(now()) @map("created_at") - updatedAt DateTime @updatedAt @map("updated_at") - - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - organizationBilling OrganizationBilling @relation(fields: [organizationBillingId], references: [id]) - - @@index([organizationId]) - @@map("pentest_subscriptions") -} - - -// ===== policy.prisma ===== -enum PolicyDisplayFormat { - EDITOR - PDF -} - -enum PolicyVisibility { - ALL // Visible to everyone in organization - DEPARTMENT // Only visible to specified departments -} - -model Policy { - id String @id @default(dbgenerated("generate_prefixed_cuid('pol'::text)")) - name String - description String? - status PolicyStatus @default(draft) - content Json[] - draftContent Json[] @default([]) - frequency Frequency? - department Departments? - isRequiredToSign Boolean @default(true) - signedBy String[] @default([]) - reviewDate DateTime? - isArchived Boolean @default(false) - displayFormat PolicyDisplayFormat @default(EDITOR) - pdfUrl String? - - // Visibility settings (for department-specific policies) - visibility PolicyVisibility @default(ALL) - visibleToDepartments Departments[] @default([]) - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - lastArchivedAt DateTime? - lastPublishedAt DateTime? - - // Relationships - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - assigneeId String? - assignee Member? @relation("PolicyAssignee", fields: [assigneeId], references: [id], onDelete: SetNull, onUpdate: Cascade) - approverId String? - approver Member? @relation("PolicyApprover", fields: [approverId], references: [id], onDelete: SetNull, onUpdate: Cascade) - policyTemplateId String? - policyTemplate FrameworkEditorPolicyTemplate? @relation(fields: [policyTemplateId], references: [id]) - controls Control[] - currentVersionId String? @unique - currentVersion PolicyVersion? @relation("PolicyCurrentVersion", fields: [currentVersionId], references: [id]) - pendingVersionId String? - versions PolicyVersion[] @relation("PolicyVersions") - - @@index([organizationId]) -} - -model PolicyVersion { - id String @id @default(dbgenerated("generate_prefixed_cuid('pv'::text)")) - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relations - policyId String - policy Policy @relation("PolicyVersions", fields: [policyId], references: [id], onDelete: Cascade) - currentForPolicy Policy? @relation("PolicyCurrentVersion") - - // Version details - version Int - content Json[] - pdfUrl String? - publishedById String? - publishedBy Member? @relation("PolicyVersionPublisher", fields: [publishedById], references: [id], onDelete: SetNull) - changelog String? - - @@unique([policyId, version]) - @@index([policyId]) - @@index([createdAt]) -} - - -// ===== questionnaire.prisma ===== -model Questionnaire { - id String @id @default(dbgenerated("generate_prefixed_cuid('qst'::text)")) - filename String // Original filename - s3Key String // S3 storage key for the uploaded file - fileType String // MIME type (e.g., "application/pdf") - fileSize Int // File size in bytes - status QuestionnaireStatus @default(parsing) // Parsing status - parsedAt DateTime? // When parsing completed - totalQuestions Int @default(0) // Total number of questions parsed - answeredQuestions Int @default(0) // Number of questions with answers - source String @default("internal") // Source of the questionnaire: 'internal' (from app) or 'external' (from trust portal) - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relationships - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - questions QuestionnaireQuestionAnswer[] - manualAnswers SecurityQuestionnaireManualAnswer[] // Manual answers saved from this questionnaire - - @@index([organizationId]) - @@index([organizationId, createdAt]) - @@index([status]) - @@index([source]) -} - -model QuestionnaireQuestionAnswer { - id String @id @default(dbgenerated("generate_prefixed_cuid('qqa'::text)")) - question String // The question text - answer String? // The answer (nullable if not provided in file or not generated yet) - status QuestionnaireAnswerStatus @default(untouched) // Answer status - questionIndex Int // Order/index of the question in the questionnaire - sources Json? // Sources used for generated answers (array of source objects) - generatedAt DateTime? // When answer was generated (if status is generated) - updatedBy String? // User ID who last updated the answer (if manual) - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relationships - questionnaireId String - questionnaire Questionnaire @relation(fields: [questionnaireId], references: [id], onDelete: Cascade) - - @@index([questionnaireId]) - @@index([questionnaireId, questionIndex]) - @@index([status]) -} - -enum QuestionnaireStatus { - parsing // Currently being parsed - completed // Successfully parsed - failed // Parsing failed -} - -enum QuestionnaireAnswerStatus { - untouched // No answer yet (empty or not generated) - generated // AI generated answer - manual // Manually written/edited by user -} - - -// ===== requirement.prisma ===== -model RequirementMap { - id String @id @default(dbgenerated("generate_prefixed_cuid('req'::text)")) - - requirementId String - requirement FrameworkEditorRequirement @relation(fields: [requirementId], references: [id], onDelete: Cascade) - - controlId String - control Control @relation(fields: [controlId], references: [id], onDelete: Cascade) - - frameworkInstanceId String - frameworkInstance FrameworkInstance @relation(fields: [frameworkInstanceId], references: [id], onDelete: Cascade) - - @@unique([controlId, frameworkInstanceId, requirementId]) - @@index([requirementId, frameworkInstanceId]) -} - - -// ===== risk.prisma ===== -model Risk { - // Metadata - id String @id @default(dbgenerated("generate_prefixed_cuid('rsk'::text)")) - title String - description String - category RiskCategory - department Departments? - status RiskStatus @default(open) - likelihood Likelihood @default(very_unlikely) - impact Impact @default(insignificant) - residualLikelihood Likelihood @default(very_unlikely) - residualImpact Impact @default(insignificant) - treatmentStrategyDescription String? - treatmentStrategy RiskTreatmentType @default(accept) - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relationships - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - assigneeId String? - assignee Member? @relation(fields: [assigneeId], references: [id]) - tasks Task[] - - @@index([organizationId]) - @@index([category]) - @@index([status]) -} - -enum RiskTreatmentType { - accept - avoid - mitigate - transfer -} - -enum RiskCategory { - customer - fraud - governance - operations - other - people - regulatory - reporting - resilience - technology - vendor_management -} - -enum RiskStatus { - open - pending - closed - archived -} - - -// ===== secret.prisma ===== -model Secret { - id String @id @default(dbgenerated("generate_prefixed_cuid('sec'::text)")) - organizationId String @map("organization_id") - name String - value String @db.Text // Encrypted value - description String? @db.Text - category String? // e.g., "api", "webhook", "database", etc. - lastUsedAt DateTime? @map("last_used_at") - createdAt DateTime @default(now()) @map("created_at") - updatedAt DateTime @updatedAt @map("updated_at") - - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - @@unique([organizationId, name]) - @@map("secrets") -} - - -// ===== security-penetration-test-run.prisma ===== -model SecurityPenetrationTestRun { - id String @id @default(dbgenerated("generate_prefixed_cuid('ptr'::text)")) - organizationId String @map("organization_id") - providerRunId String @map("provider_run_id") - createdAt DateTime @default(now()) @map("created_at") - updatedAt DateTime @updatedAt @map("updated_at") - - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - @@unique([providerRunId]) - @@index([organizationId]) - @@map("security_penetration_test_runs") -} - - -// ===== security-questionnaire-manual-answer.prisma ===== -model SecurityQuestionnaireManualAnswer { - id String @id @default(dbgenerated("generate_prefixed_cuid('sqma'::text)")) - question String // The question text - answer String // The answer text (required for saved answers) - tags String[] @default([]) // Optional tags for categorization - - // Optional reference to original questionnaire (for tracking) - sourceQuestionnaireId String? - sourceQuestionnaire Questionnaire? @relation(fields: [sourceQuestionnaireId], references: [id], onDelete: SetNull) - - // User who created/updated this answer - createdBy String? // User ID - updatedBy String? // User ID - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relationships - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - @@unique([organizationId, question]) // Prevent duplicate questions per organization - @@index([organizationId]) - @@index([organizationId, question]) - @@index([tags]) - @@index([createdAt]) -} - - -// ===== shared.prisma ===== -model ApiKey { - id String @id @default(dbgenerated("generate_prefixed_cuid('apk'::text)")) - name String - key String @unique - keyPrefix String? - salt String? - createdAt DateTime @default(now()) - expiresAt DateTime? - lastUsedAt DateTime? - isActive Boolean @default(true) - scopes String[] @default([]) - - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - organizationId String - - @@index([organizationId]) - @@index([key]) - @@index([keyPrefix]) -} - -model AuditLog { - id String @id @default(dbgenerated("generate_prefixed_cuid('aud'::text)")) - timestamp DateTime @default(now()) - organizationId String - userId String - memberId String? - data Json - description String? - entityId String? - entityType AuditLogEntityType? - - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - user User @relation(fields: [userId], references: [id], onDelete: Cascade) - member Member? @relation(fields: [memberId], references: [id], onDelete: Cascade) - - @@index([userId]) - @@index([organizationId]) - @@index([memberId]) - @@index([entityType]) -} - -enum AuditLogEntityType { - organization - framework - requirement - control - policy - task - people - risk - vendor - tests - integration - trust - finding -} - -enum EvidenceFormType { - board_meeting @map("board-meeting") - it_leadership_meeting @map("it-leadership-meeting") - risk_committee_meeting @map("risk-committee-meeting") - meeting - access_request @map("access-request") - whistleblower_report @map("whistleblower-report") - penetration_test @map("penetration-test") - rbac_matrix @map("rbac-matrix") - infrastructure_inventory @map("infrastructure-inventory") - employee_performance_evaluation @map("employee-performance-evaluation") - network_diagram @map("network-diagram") - tabletop_exercise @map("tabletop-exercise") -} - -model GlobalVendors { - website String @id @unique - company_name String? - legal_name String? - company_description String? - company_hq_address String? - privacy_policy_url String? - terms_of_service_url String? - service_level_agreement_url String? - security_page_url String? - trust_page_url String? - security_certifications String[] - subprocessors String[] - type_of_company String? - - // Vendor Risk Assessment (shared across all organizations) - riskAssessmentData Json? - riskAssessmentVersion String? - riskAssessmentUpdatedAt DateTime? - - approved Boolean @default(false) - createdAt DateTime @default(now()) - - @@index([website]) -} - -enum Departments { - none - admin - gov - hr - it - itsm - qms -} - -enum Frequency { - monthly - quarterly - yearly -} - -enum Likelihood { - very_unlikely - unlikely - possible - likely - very_likely -} - -enum Impact { - insignificant - minor - moderate - major - severe -} - - -// ===== soa.prisma ===== -// Statement of Applicability (SOA) Auto-complete Configuration and Answers - -model SOAFrameworkConfiguration { - id String @id @default(dbgenerated("generate_prefixed_cuid('soa_cfg'::text)")) - frameworkId String - framework FrameworkEditorFramework @relation(fields: [frameworkId], references: [id], onDelete: Cascade) - - // Configuration versioning - allows multiple configurations per framework - version Int @default(1) // Version number for this configuration (increments when config changes) - isLatest Boolean @default(true) // Whether this is the latest configuration version - - // Column definitions for SOA structure (template used when creating new documents) - columns Json // Array of { name: string, type: string } objects - // Example: [{ name: "Control ID", type: "string" }, { name: "Control Name", type: "string" }, { name: "Applicable", type: "boolean" }, { name: "Justification", type: "text" }] - - // Predefined questions for this framework - // Documents reference a specific configuration version via SOADocument.configurationId - // Old documents keep their old config version, new documents use new config version - questions Json // Array of question objects with unique IDs - // Example: [{ id: "A.5.1.1", text: "Is this control applicable?", columnMapping: "Applicable", controlId: "A.5.1.1" }, ...] - // IMPORTANT: question.id must be unique and stable - this is what SOAAnswer.questionId references - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relationships - documents SOADocument[] - - @@unique([frameworkId, version]) // Prevent duplicate configuration versions - @@index([frameworkId]) - @@index([frameworkId, version]) - @@index([frameworkId, isLatest]) -} - -model SOADocument { - id String @id @default(dbgenerated("generate_prefixed_cuid('soa_doc'::text)")) - - // Framework and organization context - frameworkId String - framework FrameworkEditorFramework @relation(fields: [frameworkId], references: [id], onDelete: Cascade) - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - // Configuration reference - references a specific SOAFrameworkConfiguration version - // Each document version can use a different configuration version - // Old documents keep their old config, new documents use new config - configurationId String - configuration SOAFrameworkConfiguration @relation(fields: [configurationId], references: [id], onDelete: Cascade) - - // Document versioning - version Int @default(1) // Version number for this document (increments yearly) - isLatest Boolean @default(true) // Whether this is the latest version - - // Document status - status SOADocumentStatus @default(draft) // draft, in_progress, completed - - // Document metadata - totalQuestions Int @default(0) // Total number of questions in this document - answeredQuestions Int @default(0) // Number of questions with answers - - // Approval tracking - preparedBy String @default("Comp AI") // Always "Comp AI" - approverId String? // Member ID who will approve this document (set when submitted for approval) - approver Member? @relation("SOADocumentApprover", fields: [approverId], references: [id], onDelete: SetNull, onUpdate: Cascade) - approvedAt DateTime? // When document was approved - - // Dates - completedAt DateTime? // When document was completed - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - // Relationships - answers SOAAnswer[] - - @@unique([frameworkId, organizationId, version]) // Prevent duplicate versions - @@index([frameworkId, organizationId]) - @@index([frameworkId, organizationId, version]) - @@index([frameworkId, organizationId, isLatest]) - @@index([configurationId]) - @@index([status]) -} - -model SOAAnswer { - id String @id @default(dbgenerated("generate_prefixed_cuid('soa_ans'::text)")) - - // Document context (replaces direct framework/organization link) - documentId String - document SOADocument @relation(fields: [documentId], references: [id], onDelete: Cascade) - - // Question reference - references question.id from SOADocument.configuration.questions - // References the specific configuration version that the document uses - // If config changes, old documents still reference their old config version - questionId String // Must match a question.id from SOADocument.configuration.questions - - // Answer data - simple text answer - answer String? // Text answer (nullable if not generated yet) - - // Answer metadata - status SOAAnswerStatus @default(untouched) // untouched, generated, manual - sources Json? // Sources used for generated answers (similar to questionnaire) - generatedAt DateTime? // When answer was generated - - // Answer versioning (within the document) - answerVersion Int @default(1) // Version number for this specific answer - isLatestAnswer Boolean @default(true) // Whether this is the latest version of this answer - - // User tracking - createdBy String? // User ID who created this answer - updatedBy String? // User ID who last updated this answer - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@unique([documentId, questionId, answerVersion]) // Prevent duplicate answer versions - @@index([documentId]) - @@index([documentId, questionId]) - @@index([documentId, questionId, isLatestAnswer]) - @@index([status]) -} - -enum SOADocumentStatus { - draft // Document is being created/edited - in_progress // Document is being generated - needs_review // Document is submitted for approval - completed // Document is complete and approved -} - -enum SOAAnswerStatus { - untouched // No answer yet (not generated) - generated // AI generated answer - manual // Manually written/edited by user -} - - -// ===== task-item.prisma ===== -model TaskItem { - id String @id @default(dbgenerated("generate_prefixed_cuid('tski'::text)")) - title String - description String? - status TaskItemStatus @default(todo) - priority TaskItemPriority @default(medium) - - // Polymorphic relation (like Comment and Attachment) - entityId String - entityType TaskItemEntityType - - // Assignment (nullable) - assigneeId String? - assignee Member? @relation("TaskItemAssignee", fields: [assigneeId], references: [id], onDelete: SetNull) - - // Creator & Updater - createdById String - createdBy Member @relation("TaskItemCreator", fields: [createdById], references: [id]) - updatedById String? - updatedBy Member? @relation("TaskItemUpdater", fields: [updatedById], references: [id]) - - // Relationships - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@index([entityId, entityType]) - @@index([organizationId]) - @@index([assigneeId]) - @@index([status]) - @@index([priority]) -} - -enum TaskItemStatus { - todo - in_progress - in_review - done - canceled -} - -enum TaskItemPriority { - urgent - high - medium - low -} - -enum TaskItemEntityType { - vendor - risk -} - - -// ===== task.prisma ===== -model Task { - // Metadata - id String @id @default(dbgenerated("generate_prefixed_cuid('tsk'::text)")) - title String - description String - status TaskStatus @default(todo) - automationStatus TaskAutomationStatus @default(AUTOMATED) - frequency TaskFrequency? - department Departments? @default(none) - order Int @default(0) - - // Dates - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - lastCompletedAt DateTime? - reviewDate DateTime? - - // Relationships - assigneeId String? - assignee Member? @relation(fields: [assigneeId], references: [id]) - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - taskTemplateId String? - taskTemplate FrameworkEditorTaskTemplate? @relation(fields: [taskTemplateId], references: [id]) - controls Control[] - vendors Vendor[] - risks Risk[] - evidenceAutomations EvidenceAutomation[] - browserAutomations BrowserAutomation[] - - evidenceAutomationRuns EvidenceAutomationRun[] - integrationCheckRuns IntegrationCheckRun[] - findings Finding[] - - // Evidence approval - approverId String? - approver Member? @relation("TaskApprover", fields: [approverId], references: [id]) - approvedAt DateTime? - previousStatus TaskStatus? -} - -enum TaskStatus { - todo - in_progress - in_review - done - not_relevant - failed -} - -enum TaskFrequency { - daily - weekly - monthly - quarterly - yearly -} - -enum TaskAutomationStatus { - AUTOMATED - MANUAL -} - - -// ===== trust.prisma ===== -model Trust { - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - friendlyUrl String? @unique - domain String? - domainVerified Boolean @default(false) - isVercelDomain Boolean @default(false) - vercelVerification String? - status TrustStatus @default(published) - contactEmail String? - - /// Domains that bypass NDA signing when requesting trust portal access - allowedDomains String[] @default([]) - - email String? - privacyPolicy String? - soc2 Boolean @default(false) - soc2type1 Boolean @default(false) - soc2type2 Boolean @default(false) - iso27001 Boolean @default(false) - iso42001 Boolean @default(false) - nen7510 Boolean @default(false) - gdpr Boolean @default(false) - hipaa Boolean @default(false) - pci_dss Boolean @default(false) - iso9001 Boolean @default(false) - - soc2_status FrameworkStatus @default(started) - soc2type1_status FrameworkStatus @default(started) - soc2type2_status FrameworkStatus @default(started) - iso27001_status FrameworkStatus @default(started) - iso42001_status FrameworkStatus @default(started) - nen7510_status FrameworkStatus @default(started) - gdpr_status FrameworkStatus @default(started) - hipaa_status FrameworkStatus @default(started) - pci_dss_status FrameworkStatus @default(started) - iso9001_status FrameworkStatus @default(started) - - // Overview section for public trust portal - overviewTitle String? - overviewContent String? // Markdown content with links - showOverview Boolean @default(false) - - // Favicon for trust portal (stored in S3) - favicon String? - - @@id([status, organizationId]) - @@unique([organizationId]) - @@index([organizationId]) - @@index([friendlyUrl]) -} - -enum TrustStatus { - draft - published -} - -enum FrameworkStatus { - started - in_progress - compliant -} - -enum TrustFramework { - iso_27001 - iso_42001 - gdpr - hipaa - soc2_type1 - soc2_type2 - pci_dss - nen_7510 - iso_9001 -} - -model TrustResource { - id String @id @default(dbgenerated("generate_prefixed_cuid('tcr'::text)")) - organizationId String - organization Organization @relation("OrganizationTrustResources", fields: [organizationId], references: [id], onDelete: Cascade) - framework TrustFramework - s3Key String - fileName String - fileSize Int - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@unique([organizationId, framework]) - @@index([organizationId]) -} - -model TrustAccessRequest { - id String @id @default(dbgenerated("generate_prefixed_cuid('tar'::text)")) - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - name String - email String - company String? - jobTitle String? - purpose String? - requestedDurationDays Int? - - status TrustAccessRequestStatus @default(under_review) - reviewerMemberId String? - reviewer Member? @relation("TrustAccessRequestReviewer", fields: [reviewerMemberId], references: [id], onDelete: SetNull) - reviewedAt DateTime? - decisionReason String? - - ipAddress String? - userAgent String? - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - grant TrustAccessGrant? @relation("RequestGrant") - ndaAgreements TrustNDAAgreement[] @relation("RequestNDA") - - @@index([organizationId]) - @@index([email]) - @@index([status]) - @@index([organizationId, status]) -} - -model TrustAccessGrant { - id String @id @default(dbgenerated("generate_prefixed_cuid('tag'::text)")) - - accessRequestId String @unique - accessRequest TrustAccessRequest @relation("RequestGrant", fields: [accessRequestId], references: [id], onDelete: Cascade) - - subjectEmail String - - status TrustAccessGrantStatus @default(active) - expiresAt DateTime - - accessToken String? @unique - accessTokenExpiresAt DateTime? - - issuedByMemberId String? - issuedBy Member? @relation("IssuedGrants", fields: [issuedByMemberId], references: [id], onDelete: SetNull) - - revokedAt DateTime? - revokedByMemberId String? - revokedBy Member? @relation("RevokedGrants", fields: [revokedByMemberId], references: [id], onDelete: SetNull) - revokeReason String? - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - ndaAgreement TrustNDAAgreement? @relation("GrantNDA") - - @@index([accessRequestId]) - @@index([subjectEmail]) - @@index([status]) - @@index([expiresAt]) - @@index([status, expiresAt]) - @@index([accessToken]) -} - -enum TrustAccessRequestStatus { - under_review - approved - denied - canceled -} - -enum TrustAccessGrantStatus { - active - expired - revoked -} - -model TrustNDAAgreement { - id String @id @default(dbgenerated("generate_prefixed_cuid('tna'::text)")) - - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - accessRequestId String - accessRequest TrustAccessRequest @relation("RequestNDA", fields: [accessRequestId], references: [id], onDelete: Cascade) - - grantId String? @unique - grant TrustAccessGrant? @relation("GrantNDA", fields: [grantId], references: [id], onDelete: SetNull) - - signerName String? - signerEmail String? - - status TrustNDAStatus @default(pending) - - signToken String @unique - signTokenExpiresAt DateTime - - pdfTemplateKey String? - pdfSignedKey String? - - signedAt DateTime? - - ipAddress String? - userAgent String? - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@index([organizationId]) - @@index([accessRequestId]) - @@index([signToken]) - @@index([status]) -} - -enum TrustNDAStatus { - pending - signed - void -} - -model TrustDocument { - id String @id @default(dbgenerated("generate_prefixed_cuid('tdoc'::text)")) - - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - name String - description String? - s3Key String - - isActive Boolean @default(true) - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@index([organizationId]) - @@index([organizationId, isActive]) -} - -model TrustCustomLink { - id String @id @default(dbgenerated("generate_prefixed_cuid('tcl'::text)")) - - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - - title String - description String? - url String - order Int @default(0) - isActive Boolean @default(true) - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - @@index([organizationId]) - @@index([organizationId, isActive, order]) -} - - -// ===== vendor.prisma ===== -model Vendor { - id String @id @default(dbgenerated("generate_prefixed_cuid('vnd'::text)")) - name String - description String - category VendorCategory @default(other) - status VendorStatus @default(not_assessed) - inherentProbability Likelihood @default(very_unlikely) - inherentImpact Impact @default(insignificant) - residualProbability Likelihood @default(very_unlikely) - residualImpact Impact @default(insignificant) - website String? - isSubProcessor Boolean @default(false) - - // Trust Portal display settings - logoUrl String? - showOnTrustPortal Boolean @default(false) - trustPortalOrder Int? - complianceBadges Json? // Array of { type: 'soc2' | 'iso27001' | etc, verified: boolean } - - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - - organizationId String - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - assigneeId String? - assignee Member? @relation(fields: [assigneeId], references: [id], onDelete: Cascade) - contacts VendorContact[] - tasks Task[] - - @@index([organizationId]) - @@index([assigneeId]) - @@index([category]) -} - -model VendorContact { - id String @id @default(dbgenerated("generate_prefixed_cuid('vct'::text)")) - vendorId String - name String - email String - phone String - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - Vendor Vendor @relation(fields: [vendorId], references: [id], onDelete: Cascade) - - @@index([vendorId]) -} - -enum VendorCategory { - cloud - infrastructure - software_as_a_service - finance - marketing - sales - hr - other -} - -enum VendorStatus { - not_assessed - in_progress - assessed -} diff --git a/apps/portal/package.json b/apps/portal/package.json index 7bee2618e2..972093953b 100644 --- a/apps/portal/package.json +++ b/apps/portal/package.json @@ -57,7 +57,7 @@ "build": "next build", "build:docker": "prisma generate --schema=prisma/schema && node ../../packages/db/scripts/fix-generated-extensions.js src/generated/prisma && next build", "db:generate": "bun run db:getschema && prisma generate --schema=prisma/schema && node ../../packages/db/scripts/fix-generated-extensions.js src/generated/prisma", - "db:getschema": "find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\;", + "db:getschema": "find prisma/schema -name '*.prisma' ! -name 'schema.prisma' -delete && find ../../packages/db/prisma/schema -name '*.prisma' ! -name 'schema.prisma' -exec cp {} prisma/schema/ \\;", "db:migrate": "cd ../../packages/db && bunx prisma migrate dev && cd ../../apps/portal", "dev": "next dev --turbopack -p 3002", "lint": "eslint . && prettier --check .", diff --git a/bun.lock b/bun.lock index 9acbc4bbd0..a2a44f2404 100644 --- a/bun.lock +++ b/bun.lock @@ -145,6 +145,7 @@ "@trigger.dev/build": "4.4.3", "@trigger.dev/sdk": "4.4.3", "@trycompai/auth": "workspace:*", + "@trycompai/billing": "workspace:*", "@trycompai/company": "workspace:*", "@trycompai/db": "workspace:*", "@trycompai/email": "workspace:*", @@ -282,6 +283,7 @@ "@trigger.dev/react-hooks": "4.4.3", "@trigger.dev/sdk": "4.4.3", "@trycompai/auth": "workspace:*", + "@trycompai/billing": "workspace:*", "@trycompai/company": "workspace:*", "@trycompai/db": "workspace:*", "@trycompai/design-system": "^1.0.32", @@ -512,6 +514,14 @@ "typescript": "^5.7.3", }, }, + "packages/billing": { + "name": "@trycompai/billing", + "version": "1.0.0", + "devDependencies": { + "@trycompai/tsconfig": "workspace:*", + "typescript": "^5.8.3", + }, + }, "packages/company": { "name": "@trycompai/company", "version": "1.0.0", @@ -2695,6 +2705,8 @@ "@trycompai/auth": ["@trycompai/auth@workspace:packages/auth"], + "@trycompai/billing": ["@trycompai/billing@workspace:packages/billing"], + "@trycompai/company": ["@trycompai/company@workspace:packages/company"], "@trycompai/db": ["@trycompai/db@workspace:packages/db"], diff --git a/package.json b/package.json index 6b9ab0e668..9ad2cb8c00 100644 --- a/package.json +++ b/package.json @@ -61,6 +61,7 @@ "scripts": { "build": "turbo build", "check:env": "node scripts/check-unused-env.js", + "check:prisma-schemas": "node scripts/check-generated-prisma-schemas.js", "clean": "git clean -xdf node_modules", "clean:workspaces": "turbo clean", "deploy:trigger-prod": "npx trigger.dev@4.4.3 deploy", @@ -71,7 +72,7 @@ "deps:update": "syncpack update", "deps:upgrade": "syncpack update && bun install", "dev": "turbo dev --parallel", - "db:generate": "turbo run db:generate --filter=@trycompai/db --filter=@trycompai/app --filter=@trycompai/portal --filter=@trycompai/api", + "db:generate": "turbo run db:generate --filter=@trycompai/db --filter=@trycompai/app --filter=@trycompai/portal --filter=@trycompai/api --filter=@trycompai/framework-editor", "docker:clean": "bun run -F @trycompai/db docker:clean", "docker:down": "bun run -F @trycompai/db docker:down", "docker:up": "bun run -F @trycompai/db docker:up", diff --git a/packages/billing/package.json b/packages/billing/package.json new file mode 100644 index 0000000000..2fcdca892e --- /dev/null +++ b/packages/billing/package.json @@ -0,0 +1,22 @@ +{ + "name": "@trycompai/billing", + "version": "1.0.0", + "private": true, + "exports": { + ".": "./src/index.ts" + }, + "main": "src/index.ts", + "types": "src/index.ts", + "scripts": { + "clean": "rm -rf .turbo node_modules", + "format": "prettier --write .", + "lint": "prettier --check .", + "test": "bun test src/*.test.ts", + "typecheck": "tsc --noEmit" + }, + "sideEffects": false, + "devDependencies": { + "@trycompai/tsconfig": "workspace:*", + "typescript": "^5.8.3" + } +} diff --git a/packages/billing/src/catalog.test.ts b/packages/billing/src/catalog.test.ts new file mode 100644 index 0000000000..3755c33b4b --- /dev/null +++ b/packages/billing/src/catalog.test.ts @@ -0,0 +1,52 @@ +import { describe, expect, it } from 'bun:test'; + +import { + getBillingCatalog, + getBillingSku, + getBillingSkuByStripePriceId, + isSubscriptionBillingSkuKey, +} from './index'; + +describe('billing catalog', () => { + it('records the test-mode subscription prices', () => { + const catalog = getBillingCatalog('test'); + + expect(catalog.products.pentest).toBe('prod_UQqGJryNvXajUt'); + expect(catalog.prices.pentest_monthly_5).toBe('price_1TRya6CkFWhKYvHI1sJ2M2no'); + expect(catalog.products.background_check).toBe('prod_UQNNIS1di6uOLB'); + expect(catalog.prices.background_checks_monthly_25).toBe('price_1TRya7CkFWhKYvHIDbCWSITp'); + }); + + it('models included monthly usage for subscription skus', () => { + expect( + getBillingSku({ environment: 'test', skuKey: 'pentest_monthly_5' }).includedUsage, + ).toEqual({ quantity: 5, unit: 'scan', reset: 'monthly' }); + + expect( + getBillingSku({ + environment: 'test', + skuKey: 'background_checks_monthly_25', + }).includedUsage, + ).toEqual({ quantity: 25, unit: 'background_check', reset: 'monthly' }); + }); + + it('finds skus by Stripe price id', () => { + expect( + getBillingSkuByStripePriceId({ + environment: 'test', + stripePriceId: 'price_1TRya6CkFWhKYvHI1sJ2M2no', + })?.key, + ).toBe('pentest_monthly_5'); + expect( + getBillingSkuByStripePriceId({ + environment: 'test', + stripePriceId: 'price_missing', + }), + ).toBeNull(); + }); + + it('recognizes subscription sku keys', () => { + expect(isSubscriptionBillingSkuKey('pentest_monthly_5')).toBe(true); + expect(isSubscriptionBillingSkuKey('background_check_one_time')).toBe(false); + }); +}); diff --git a/packages/billing/src/index.ts b/packages/billing/src/index.ts new file mode 100644 index 0000000000..49524dd13b --- /dev/null +++ b/packages/billing/src/index.ts @@ -0,0 +1,142 @@ +export type BillingCatalogEnvironment = 'test'; + +export type BillingSkuKey = + | 'background_check_one_time' + | 'background_checks_monthly_25' + | 'pentest_monthly_5'; + +export type BillingProductKey = 'background_check' | 'pentest'; + +export type BillingCadence = 'one_time' | 'month'; + +export type BillingSku = { + key: BillingSkuKey; + productKey: BillingProductKey; + name: string; + description: string; + cadence: BillingCadence; + currency: 'usd'; + unitAmount: number; + stripeProductId: string; + stripePriceId: string; + includedUsage?: { + quantity: number; + unit: 'background_check' | 'scan'; + reset: 'monthly'; + }; +}; + +export const subscriptionBillingSkuKeys = [ + 'background_checks_monthly_25', + 'pentest_monthly_5', +] as const satisfies readonly BillingSkuKey[]; + +export type BillingCatalog = { + environment: BillingCatalogEnvironment; + products: Record; + prices: Record; + skus: Record; +}; + +const testSkus = { + background_check_one_time: { + key: 'background_check_one_time', + productKey: 'background_check', + name: 'Employee Background Check', + description: 'One-time employee background check.', + cadence: 'one_time', + currency: 'usd', + unitAmount: 4900, + stripeProductId: 'prod_UQNNIS1di6uOLB', + stripePriceId: 'price_1TRWckCkFWhKYvHIA1GLv1sO', + }, + background_checks_monthly_25: { + key: 'background_checks_monthly_25', + productKey: 'background_check', + name: 'Background Checks Monthly', + description: + 'Monthly Comp AI background check package. Includes 25 checks per month.', + cadence: 'month', + currency: 'usd', + unitAmount: 24900, + stripeProductId: 'prod_UQNNIS1di6uOLB', + stripePriceId: 'price_1TRya7CkFWhKYvHIDbCWSITp', + includedUsage: { + quantity: 25, + unit: 'background_check', + reset: 'monthly', + }, + }, + pentest_monthly_5: { + key: 'pentest_monthly_5', + productKey: 'pentest', + name: 'Penetration Tests Monthly', + description: + 'Monthly Comp AI penetration testing package. Includes 5 scans per month.', + cadence: 'month', + currency: 'usd', + unitAmount: 39900, + stripeProductId: 'prod_UQqGJryNvXajUt', + stripePriceId: 'price_1TRya6CkFWhKYvHI1sJ2M2no', + includedUsage: { + quantity: 5, + unit: 'scan', + reset: 'monthly', + }, + }, +} satisfies Record; + +export const billingCatalogs = { + test: createCatalog({ environment: 'test', skus: testSkus }), +} satisfies Record; + +export function getBillingCatalog( + environment: BillingCatalogEnvironment = 'test', +): BillingCatalog { + return billingCatalogs[environment]; +} + +export function getBillingSku(params: { + environment?: BillingCatalogEnvironment; + skuKey: BillingSkuKey; +}): BillingSku { + const catalog = getBillingCatalog(params.environment); + return catalog.skus[params.skuKey]; +} + +export function getBillingSkuByStripePriceId(params: { + environment?: BillingCatalogEnvironment; + stripePriceId: string; +}): BillingSku | null { + const catalog = getBillingCatalog(params.environment); + return ( + Object.values(catalog.skus).find( + (sku) => sku.stripePriceId === params.stripePriceId, + ) ?? null + ); +} + +export function isSubscriptionBillingSkuKey(value: string): value is BillingSkuKey { + return subscriptionBillingSkuKeys.some((skuKey) => skuKey === value); +} + +function createCatalog(params: { + environment: BillingCatalogEnvironment; + skus: Record; +}): BillingCatalog { + return { + environment: params.environment, + products: { + background_check: params.skus.background_check_one_time.stripeProductId, + pentest: params.skus.pentest_monthly_5.stripeProductId, + }, + prices: { + background_check_one_time: + params.skus.background_check_one_time.stripePriceId, + background_checks_monthly_25: + params.skus.background_checks_monthly_25.stripePriceId, + pentest_monthly_5: params.skus.pentest_monthly_5.stripePriceId, + }, + skus: params.skus, + }; +} diff --git a/packages/billing/tsconfig.json b/packages/billing/tsconfig.json new file mode 100644 index 0000000000..df12fbc94a --- /dev/null +++ b/packages/billing/tsconfig.json @@ -0,0 +1,9 @@ +{ + "extends": "@trycompai/tsconfig/base.json", + "compilerOptions": { + "incremental": true, + "tsBuildInfoFile": "node_modules/.cache/tsbuildinfo.json" + }, + "include": ["src"], + "exclude": ["node_modules"] +} diff --git a/packages/db/prisma/migrations/20260430180000_subscription_billing_foundation/migration.sql b/packages/db/prisma/migrations/20260430180000_subscription_billing_foundation/migration.sql new file mode 100644 index 0000000000..69d9095c5a --- /dev/null +++ b/packages/db/prisma/migrations/20260430180000_subscription_billing_foundation/migration.sql @@ -0,0 +1,125 @@ +ALTER TABLE "organization_billing" + RENAME COLUMN "stripe_background_check_payment_method_id" TO "stripe_payment_method_id"; + +ALTER TABLE "organization_billing" + RENAME COLUMN "background_check_payment_method_setup_at" TO "payment_method_updated_at"; + +CREATE TABLE "organization_billing_subscriptions" ( + "id" TEXT NOT NULL DEFAULT generate_prefixed_cuid('obs'::text), + "organization_id" TEXT NOT NULL, + "sku_key" TEXT NOT NULL, + "stripe_subscription_id" TEXT NOT NULL, + "stripe_subscription_item_id" TEXT NOT NULL, + "stripe_price_id" TEXT NOT NULL, + "stripe_status" TEXT NOT NULL, + "current_period_start" TIMESTAMP(3), + "current_period_end" TIMESTAMP(3), + "included_quantity" INTEGER NOT NULL, + "used_quantity" INTEGER NOT NULL DEFAULT 0, + "cancel_at_period_end" BOOLEAN NOT NULL DEFAULT false, + "canceled_at" TIMESTAMP(3), + "created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "organization_billing_subscriptions_pkey" PRIMARY KEY ("id") +); + +CREATE TABLE "billing_usage_events" ( + "id" TEXT NOT NULL DEFAULT generate_prefixed_cuid('bue'::text), + "organization_id" TEXT NOT NULL, + "sku_key" TEXT NOT NULL, + "event_type" TEXT NOT NULL, + "quantity" INTEGER NOT NULL, + "source_resource_id" TEXT, + "idempotency_key" TEXT NOT NULL, + "stripe_event_id" TEXT, + "stripe_invoice_id" TEXT, + "stripe_subscription_item_id" TEXT, + "period_start" TIMESTAMP(3), + "period_end" TIMESTAMP(3), + "created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "billing_usage_events_pkey" PRIMARY KEY ("id") +); + +CREATE TABLE "stripe_webhook_events" ( + "id" TEXT NOT NULL DEFAULT generate_prefixed_cuid('swe'::text), + "stripe_event_id" TEXT NOT NULL, + "event_type" TEXT NOT NULL, + "payload" JSONB NOT NULL, + "status" TEXT NOT NULL DEFAULT 'processed', + "error" TEXT, + "processed_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "stripe_webhook_events_pkey" PRIMARY KEY ("id") +); + +CREATE TABLE "billing_audit_events" ( + "id" TEXT NOT NULL DEFAULT generate_prefixed_cuid('bae'::text), + "organization_id" TEXT NOT NULL, + "event_type" TEXT NOT NULL, + "sku_key" TEXT, + "stripe_event_id" TEXT, + "metadata" JSONB, + "created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "billing_audit_events_pkey" PRIMARY KEY ("id") +); + +CREATE UNIQUE INDEX "organization_billing_subscriptions_organization_id_sku_key_key" + ON "organization_billing_subscriptions"("organization_id", "sku_key"); + +CREATE UNIQUE INDEX "organization_billing_subscriptions_stripe_subscription_id_stripe_subscription_item_id_key" + ON "organization_billing_subscriptions"("stripe_subscription_id", "stripe_subscription_item_id"); + +CREATE INDEX "organization_billing_subscriptions_organization_id_idx" + ON "organization_billing_subscriptions"("organization_id"); + +CREATE INDEX "organization_billing_subscriptions_stripe_subscription_id_idx" + ON "organization_billing_subscriptions"("stripe_subscription_id"); + +CREATE INDEX "organization_billing_subscriptions_sku_key_idx" + ON "organization_billing_subscriptions"("sku_key"); + +CREATE UNIQUE INDEX "billing_usage_events_idempotency_key_key" + ON "billing_usage_events"("idempotency_key"); + +CREATE INDEX "billing_usage_events_organization_id_sku_key_idx" + ON "billing_usage_events"("organization_id", "sku_key"); + +CREATE INDEX "billing_usage_events_stripe_event_id_idx" + ON "billing_usage_events"("stripe_event_id"); + +CREATE INDEX "billing_usage_events_stripe_invoice_id_idx" + ON "billing_usage_events"("stripe_invoice_id"); + +CREATE UNIQUE INDEX "stripe_webhook_events_stripe_event_id_key" + ON "stripe_webhook_events"("stripe_event_id"); + +CREATE INDEX "stripe_webhook_events_event_type_idx" + ON "stripe_webhook_events"("event_type"); + +CREATE INDEX "stripe_webhook_events_status_idx" + ON "stripe_webhook_events"("status"); + +CREATE INDEX "billing_audit_events_organization_id_idx" + ON "billing_audit_events"("organization_id"); + +CREATE INDEX "billing_audit_events_stripe_event_id_idx" + ON "billing_audit_events"("stripe_event_id"); + +CREATE INDEX "billing_audit_events_sku_key_idx" + ON "billing_audit_events"("sku_key"); + +ALTER TABLE "organization_billing_subscriptions" + ADD CONSTRAINT "organization_billing_subscriptions_organization_id_fkey" + FOREIGN KEY ("organization_id") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +ALTER TABLE "billing_usage_events" + ADD CONSTRAINT "billing_usage_events_organization_id_fkey" + FOREIGN KEY ("organization_id") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +ALTER TABLE "billing_audit_events" + ADD CONSTRAINT "billing_audit_events_organization_id_fkey" + FOREIGN KEY ("organization_id") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/packages/db/prisma/migrations/20260430181000_drop_stale_pentest_subscription/migration.sql b/packages/db/prisma/migrations/20260430181000_drop_stale_pentest_subscription/migration.sql new file mode 100644 index 0000000000..93347bc65a --- /dev/null +++ b/packages/db/prisma/migrations/20260430181000_drop_stale_pentest_subscription/migration.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS "pentest_subscriptions"; diff --git a/packages/db/prisma/migrations/20260430182000_name_billing_subscription_unique_index/migration.sql b/packages/db/prisma/migrations/20260430182000_name_billing_subscription_unique_index/migration.sql new file mode 100644 index 0000000000..15a030da77 --- /dev/null +++ b/packages/db/prisma/migrations/20260430182000_name_billing_subscription_unique_index/migration.sql @@ -0,0 +1,2 @@ +ALTER INDEX IF EXISTS "organization_billing_subscriptions_stripe_subscription_id_strip" + RENAME TO "org_billing_subs_stripe_sub_item_key"; diff --git a/packages/db/prisma/schema/background-check.prisma b/packages/db/prisma/schema/background-check.prisma index c365efed7d..d80fe594bd 100644 --- a/packages/db/prisma/schema/background-check.prisma +++ b/packages/db/prisma/schema/background-check.prisma @@ -25,8 +25,8 @@ model BackgroundCheckRequest { createdAt DateTime @default(now()) updatedAt DateTime @updatedAt - organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) - member Member @relation(fields: [memberId], references: [id], onDelete: Cascade) + organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) + member Member @relation(fields: [memberId], references: [id], onDelete: Cascade) webhookEvents BackgroundCheckWebhookEvent[] @@unique([organizationId, memberId]) diff --git a/packages/db/prisma/schema/organization-billing.prisma b/packages/db/prisma/schema/organization-billing.prisma index a48fb991ad..9b83418e33 100644 --- a/packages/db/prisma/schema/organization-billing.prisma +++ b/packages/db/prisma/schema/organization-billing.prisma @@ -1,13 +1,95 @@ model OrganizationBilling { - id String @id @default(dbgenerated("generate_prefixed_cuid('obil'::text)")) - organizationId String @unique @map("organization_id") - stripeCustomerId String @map("stripe_customer_id") - stripeBackgroundCheckPaymentMethodId String? @map("stripe_background_check_payment_method_id") - backgroundCheckPaymentMethodSetupAt DateTime? @map("background_check_payment_method_setup_at") - createdAt DateTime @default(now()) @map("created_at") - updatedAt DateTime @updatedAt @map("updated_at") + id String @id @default(dbgenerated("generate_prefixed_cuid('obil'::text)")) + organizationId String @unique @map("organization_id") + stripeCustomerId String @map("stripe_customer_id") + stripePaymentMethodId String? @map("stripe_payment_method_id") + paymentMethodUpdatedAt DateTime? @map("payment_method_updated_at") + createdAt DateTime @default(now()) @map("created_at") + updatedAt DateTime @updatedAt @map("updated_at") organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) @@map("organization_billing") } + +model OrganizationBillingSubscription { + id String @id @default(dbgenerated("generate_prefixed_cuid('obs'::text)")) + organizationId String @map("organization_id") + skuKey String @map("sku_key") + stripeSubscriptionId String @map("stripe_subscription_id") + stripeSubscriptionItemId String @map("stripe_subscription_item_id") + stripePriceId String @map("stripe_price_id") + stripeStatus String @map("stripe_status") + currentPeriodStart DateTime? @map("current_period_start") + currentPeriodEnd DateTime? @map("current_period_end") + includedQuantity Int @map("included_quantity") + usedQuantity Int @default(0) @map("used_quantity") + cancelAtPeriodEnd Boolean @default(false) @map("cancel_at_period_end") + canceledAt DateTime? @map("canceled_at") + createdAt DateTime @default(now()) @map("created_at") + updatedAt DateTime @updatedAt @map("updated_at") + + organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) + + @@unique([organizationId, skuKey]) + @@unique([stripeSubscriptionId, stripeSubscriptionItemId], map: "org_billing_subs_stripe_sub_item_key") + @@index([organizationId]) + @@index([stripeSubscriptionId]) + @@index([skuKey]) + @@map("organization_billing_subscriptions") +} + +model BillingUsageEvent { + id String @id @default(dbgenerated("generate_prefixed_cuid('bue'::text)")) + organizationId String @map("organization_id") + skuKey String @map("sku_key") + eventType String @map("event_type") + quantity Int + sourceResourceId String? @map("source_resource_id") + idempotencyKey String @unique @map("idempotency_key") + stripeEventId String? @map("stripe_event_id") + stripeInvoiceId String? @map("stripe_invoice_id") + stripeSubscriptionItemId String? @map("stripe_subscription_item_id") + periodStart DateTime? @map("period_start") + periodEnd DateTime? @map("period_end") + createdAt DateTime @default(now()) @map("created_at") + + organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) + + @@index([organizationId, skuKey]) + @@index([stripeEventId]) + @@index([stripeInvoiceId]) + @@map("billing_usage_events") +} + +model StripeWebhookEvent { + id String @id @default(dbgenerated("generate_prefixed_cuid('swe'::text)")) + stripeEventId String @unique @map("stripe_event_id") + eventType String @map("event_type") + payload Json + status String @default("processed") + error String? + processedAt DateTime @default(now()) @map("processed_at") + createdAt DateTime @default(now()) @map("created_at") + + @@index([eventType]) + @@index([status]) + @@map("stripe_webhook_events") +} + +model BillingAuditEvent { + id String @id @default(dbgenerated("generate_prefixed_cuid('bae'::text)")) + organizationId String @map("organization_id") + eventType String @map("event_type") + skuKey String? @map("sku_key") + stripeEventId String? @map("stripe_event_id") + metadata Json? + createdAt DateTime @default(now()) @map("created_at") + + organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) + + @@index([organizationId]) + @@index([stripeEventId]) + @@index([skuKey]) + @@map("billing_audit_events") +} diff --git a/packages/db/prisma/schema/organization.prisma b/packages/db/prisma/schema/organization.prisma index 9124c2375f..7c18fe32b2 100644 --- a/packages/db/prisma/schema/organization.prisma +++ b/packages/db/prisma/schema/organization.prisma @@ -63,8 +63,11 @@ model Organization { // Pentest credits — wallet of run-credits an org can spend. // Source of credits (trial / future Stripe subscription / top-up) // is metadata on the row; balance is unified. - pentestCredits PentestCredits? - billing OrganizationBilling? + pentestCredits PentestCredits? + billing OrganizationBilling? + billingSubscriptions OrganizationBillingSubscription[] + billingUsageEvents BillingUsageEvent[] + billingAuditEvents BillingAuditEvent[] // Browser Automation browserbaseContext BrowserbaseContext? diff --git a/packages/db/prisma/schema/pentest-credits.prisma b/packages/db/prisma/schema/pentest-credits.prisma index 9ee9fce98b..283c7d7663 100644 --- a/packages/db/prisma/schema/pentest-credits.prisma +++ b/packages/db/prisma/schema/pentest-credits.prisma @@ -34,7 +34,7 @@ model PentestCredits { lastGrantSource String @default("trial") @map("last_grant_source") createdAt DateTime @default(now()) @map("created_at") - updatedAt DateTime @updatedAt @map("updated_at") + updatedAt DateTime @updatedAt @map("updated_at") organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) diff --git a/packages/docs/openapi.json b/packages/docs/openapi.json index da3c15884d..2097acecad 100644 --- a/packages/docs/openapi.json +++ b/packages/docs/openapi.json @@ -21065,6 +21065,196 @@ ] } }, + "/v1/billing/status": { + "get": { + "operationId": "BillingController_getStatus_v1", + "parameters": [], + "responses": { + "200": { + "description": "" + } + }, + "security": [ + { + "apikey": [] + } + ], + "summary": "Get organization billing status", + "tags": [ + "Billing" + ] + } + }, + "/v1/billing/preferences": { + "put": { + "operationId": "BillingController_updatePreferences_v1", + "parameters": [], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BillingPreferencesDto" + } + } + } + }, + "responses": { + "200": { + "description": "" + } + }, + "security": [ + { + "apikey": [] + } + ], + "summary": "Update organization billing preferences", + "tags": [ + "Billing" + ] + } + }, + "/v1/billing/setup-session": { + "post": { + "operationId": "BillingController_setupSession_v1", + "parameters": [], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BillingSetupSessionDto" + } + } + } + }, + "responses": { + "200": { + "description": "" + } + }, + "security": [ + { + "apikey": [] + } + ], + "summary": "Create a Stripe setup session", + "tags": [ + "Billing" + ] + } + }, + "/v1/billing/setup-success": { + "post": { + "operationId": "BillingController_setupSuccess_v1", + "parameters": [], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BillingSetupSuccessDto" + } + } + } + }, + "responses": { + "200": { + "description": "" + } + }, + "security": [ + { + "apikey": [] + } + ], + "summary": "Persist a successful Stripe setup session", + "tags": [ + "Billing" + ] + } + }, + "/v1/billing/portal": { + "post": { + "operationId": "BillingController_portal_v1", + "parameters": [], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BillingPortalDto" + } + } + } + }, + "responses": { + "200": { + "description": "" + } + }, + "security": [ + { + "apikey": [] + } + ], + "summary": "Create a Stripe billing portal session", + "tags": [ + "Billing" + ] + } + }, + "/v1/billing/subscription-session": { + "post": { + "operationId": "BillingController_subscriptionSession_v1", + "parameters": [], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BillingSubscriptionCheckoutDto" + } + } + } + }, + "responses": { + "200": { + "description": "" + } + }, + "security": [ + { + "apikey": [] + } + ], + "summary": "Create a Stripe subscription Checkout session", + "tags": [ + "Billing" + ] + } + }, + "/v1/billing/webhook": { + "post": { + "operationId": "BillingController_webhook_v1", + "parameters": [], + "responses": { + "200": { + "description": "" + } + }, + "security": [ + { + "apikey": [] + } + ], + "summary": "Receive Stripe billing webhook events", + "tags": [ + "Billing" + ] + } + }, "/v1/background-checks/{id}": { "get": { "operationId": "BackgroundChecksController_getById_v1", @@ -26124,6 +26314,26 @@ "targetUrl" ] }, + "BillingPreferencesDto": { + "type": "object", + "properties": {} + }, + "BillingSetupSessionDto": { + "type": "object", + "properties": {} + }, + "BillingSetupSuccessDto": { + "type": "object", + "properties": {} + }, + "BillingPortalDto": { + "type": "object", + "properties": {} + }, + "BillingSubscriptionCheckoutDto": { + "type": "object", + "properties": {} + }, "RequestBackgroundCheckDto": { "type": "object", "properties": {} diff --git a/scripts/check-generated-prisma-schemas.js b/scripts/check-generated-prisma-schemas.js new file mode 100644 index 0000000000..af0f5d6637 --- /dev/null +++ b/scripts/check-generated-prisma-schemas.js @@ -0,0 +1,61 @@ +#!/usr/bin/env node + +const { existsSync, readFileSync, readdirSync } = require('node:fs'); +const path = require('node:path'); + +const repoRoot = process.cwd(); +const canonicalDir = path.join(repoRoot, 'packages/db/prisma/schema'); +const appsDir = path.join(repoRoot, 'apps'); + +const readPrismaFiles = (directory) => + new Set( + readdirSync(directory) + .filter((file) => file.endsWith('.prisma')) + .filter((file) => file !== 'schema.prisma'), + ); + +const canonicalFiles = readPrismaFiles(canonicalDir); +const errors = []; + +for (const appName of readdirSync(appsDir)) { + const appPrismaDir = path.join(appsDir, appName, 'prisma'); + const schemaDir = path.join(appPrismaDir, 'schema'); + + if (!existsSync(schemaDir)) { + continue; + } + + const localFiles = readPrismaFiles(schemaDir); + + for (const file of localFiles) { + if (!canonicalFiles.has(file)) { + errors.push(`${appName}: stale Prisma schema fragment ${file}`); + } + } + + for (const file of canonicalFiles) { + if (!localFiles.has(file)) { + errors.push(`${appName}: missing Prisma schema fragment ${file}`); + } + } + + const legacySchemaFile = path.join(appPrismaDir, 'schema.prisma'); + if (!existsSync(legacySchemaFile)) { + continue; + } + + const legacySchema = readFileSync(legacySchemaFile, 'utf8'); + if (/^model\s+\w+/m.test(legacySchema)) { + errors.push(`${appName}: legacy prisma/schema.prisma contains model definitions`); + } +} + +if (errors.length > 0) { + console.error('Generated Prisma schemas are out of sync:'); + for (const error of errors) { + console.error(`- ${error}`); + } + process.exit(1); +} + +console.log('Generated Prisma schemas are in sync.'); From 7697765ae9d6802d3ae468f23b83dac8618ddc0e Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Thu, 30 Apr 2026 23:15:07 +0100 Subject: [PATCH 05/20] fix(billing): harden stripe flows --- .../background-check-billing-customer.spec.ts | 22 +++ .../background-check-billing-customer.ts | 2 +- .../background-check-billing.service.ts | 7 +- .../background-checks.service.spec.ts | 3 + .../dto/background-check-billing.dto.ts | 3 +- .../billing-entitlements.service.spec.ts | 85 +++++++++ .../billing/billing-entitlements.service.ts | 71 +------ .../billing/billing-included-usage-refunds.ts | 70 +++++++ .../src/billing/billing-preferences.spec.ts | 52 ++++- apps/api/src/billing/billing-preferences.ts | 54 ++++-- .../src/billing/billing-redirect-urls.spec.ts | 14 ++ apps/api/src/billing/billing-redirect-urls.ts | 5 + apps/api/src/billing/billing-stripe-config.ts | 10 + apps/api/src/billing/billing-stripe-ids.ts | 7 + apps/api/src/billing/billing-usage.spec.ts | 41 ++++ apps/api/src/billing/billing-usage.ts | 46 +++-- .../src/billing/billing-webhook.service.ts | 14 +- apps/api/src/billing/billing.service.spec.ts | 14 ++ apps/api/src/billing/billing.service.ts | 69 +++---- apps/api/src/billing/dto/billing.dto.ts | 11 +- .../billing/stripe-webhook-records.spec.ts | 124 ++++++++++++ .../api/src/billing/stripe-webhook-records.ts | 29 ++- ...security-penetration-tests.billing.spec.ts | 177 ++++++++++++++++++ .../security-penetration-tests.service.ts | 27 ++- .../billing/BillingPreferencesForm.tsx | 11 +- packages/billing/package.json | 40 ++-- packages/billing/src/index.ts | 24 +-- packages/billing/tsconfig.json | 14 +- .../migration.sql | 2 + .../security-penetration-test-run.prisma | 11 +- scripts/check-generated-prisma-schemas.js | 9 + 31 files changed, 866 insertions(+), 202 deletions(-) create mode 100644 apps/api/src/billing/billing-entitlements.service.spec.ts create mode 100644 apps/api/src/billing/billing-included-usage-refunds.ts create mode 100644 apps/api/src/billing/billing-redirect-urls.spec.ts create mode 100644 apps/api/src/billing/billing-stripe-config.ts create mode 100644 apps/api/src/billing/billing-stripe-ids.ts create mode 100644 apps/api/src/billing/stripe-webhook-records.spec.ts create mode 100644 apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts create mode 100644 packages/db/prisma/migrations/20260430183000_pentest_billing_usage_source/migration.sql diff --git a/apps/api/src/background-checks/background-check-billing-customer.spec.ts b/apps/api/src/background-checks/background-check-billing-customer.spec.ts index d67cb79922..e47abaf77f 100644 --- a/apps/api/src/background-checks/background-check-billing-customer.spec.ts +++ b/apps/api/src/background-checks/background-check-billing-customer.spec.ts @@ -93,4 +93,26 @@ describe('findOrCreateBackgroundCheckBillingCustomer', () => { email: 'billing@trycomp.ai', }); }); + + it('does not create a Stripe client when existing billing needs no update', async () => { + type OrganizationBilling = typeof db.organizationBilling; + const dbMocks = mockedDb as unknown as { + organizationBilling: { + findUnique: jest.MockedFunction; + }; + }; + dbMocks.organizationBilling.findUnique.mockResolvedValueOnce({ + stripeCustomerId: 'cus_existing', + } as Awaited>); + const getClient = jest.fn(); + + await expect( + findOrCreateBackgroundCheckBillingCustomer({ + stripeService: { getClient } as unknown as StripeService, + organizationId: 'org_1', + }), + ).resolves.toBe('cus_existing'); + + expect(getClient).not.toHaveBeenCalled(); + }); }); diff --git a/apps/api/src/background-checks/background-check-billing-customer.ts b/apps/api/src/background-checks/background-check-billing-customer.ts index e8604d435e..c57dc72b6f 100644 --- a/apps/api/src/background-checks/background-check-billing-customer.ts +++ b/apps/api/src/background-checks/background-check-billing-customer.ts @@ -15,7 +15,6 @@ export async function findOrCreateBackgroundCheckBillingCustomer({ where: { organizationId }, select: { stripeCustomerId: true }, }); - const stripe = stripeService.getClient(); if (existingBilling) { await updateStripeCustomerEmail({ @@ -35,6 +34,7 @@ export async function findOrCreateBackgroundCheckBillingCustomer({ throw new NotFoundException('Organization not found.'); } + const stripe = stripeService.getClient(); const customer = await stripe.customers.create( { name: organization.name, diff --git a/apps/api/src/background-checks/background-check-billing.service.ts b/apps/api/src/background-checks/background-check-billing.service.ts index 8622e58f0c..0236fedea2 100644 --- a/apps/api/src/background-checks/background-check-billing.service.ts +++ b/apps/api/src/background-checks/background-check-billing.service.ts @@ -1,5 +1,7 @@ import { Injectable } from '@nestjs/common'; +import { getBillingSku } from '@trycompai/billing'; import { BillingService } from '../billing/billing.service'; +import { validateBackgroundCheckBillingRedirectUrl } from './background-check-billing-urls'; @Injectable() export class BackgroundCheckBillingService { @@ -15,6 +17,8 @@ export class BackgroundCheckBillingService { cancelUrl: string; customerEmail?: string; }): Promise<{ url: string }> { + validateBackgroundCheckBillingRedirectUrl(params.successUrl); + validateBackgroundCheckBillingRedirectUrl(params.cancelUrl); return this.billingService.createSetupSession(params); } @@ -29,6 +33,7 @@ export class BackgroundCheckBillingService { organizationId: string; returnUrl: string; }): Promise<{ url: string }> { + validateBackgroundCheckBillingRedirectUrl(params.returnUrl); return this.billingService.createBillingPortalSession(params); } @@ -37,7 +42,7 @@ export class BackgroundCheckBillingService { unitAmount: number; currency: string; }> { - const sku = this.billingService.getOneTimeBackgroundCheckSku(); + const sku = getBillingSku({ skuKey: 'background_check_one_time' }); return { id: sku.stripePriceId, unitAmount: sku.unitAmount, diff --git a/apps/api/src/background-checks/background-checks.service.spec.ts b/apps/api/src/background-checks/background-checks.service.spec.ts index 39aff5847d..2d23c240c9 100644 --- a/apps/api/src/background-checks/background-checks.service.spec.ts +++ b/apps/api/src/background-checks/background-checks.service.spec.ts @@ -432,6 +432,9 @@ describe('background checks', () => { }); it('uses BETTER_AUTH_URL as the local app URL fallback for setup redirects', async () => { + process.env.NEXT_PUBLIC_APP_URL = ''; + process.env.APP_URL = ''; + process.env.BETTER_AUTH_URL = 'http://localhost:3000'; const billingService = { createSetupSession: jest.fn().mockResolvedValue({ url: 'https://checkout.stripe.com/c/session_1', diff --git a/apps/api/src/background-checks/dto/background-check-billing.dto.ts b/apps/api/src/background-checks/dto/background-check-billing.dto.ts index 7d1e9c249e..b796939c57 100644 --- a/apps/api/src/background-checks/dto/background-check-billing.dto.ts +++ b/apps/api/src/background-checks/dto/background-check-billing.dto.ts @@ -1,4 +1,4 @@ -import { IsString, IsUrl } from 'class-validator'; +import { IsNotEmpty, IsString, IsUrl } from 'class-validator'; export class BackgroundCheckSetupSessionDto { @IsString() @@ -12,6 +12,7 @@ export class BackgroundCheckSetupSessionDto { export class BackgroundCheckSetupSuccessDto { @IsString() + @IsNotEmpty() sessionId: string; } diff --git a/apps/api/src/billing/billing-entitlements.service.spec.ts b/apps/api/src/billing/billing-entitlements.service.spec.ts new file mode 100644 index 0000000000..f884f34747 --- /dev/null +++ b/apps/api/src/billing/billing-entitlements.service.spec.ts @@ -0,0 +1,85 @@ +import { db } from '@db'; +import { BillingEntitlementsService } from './billing-entitlements.service'; + +jest.mock('@db', () => ({ + db: { + organizationBillingSubscription: { + findUnique: jest.fn(), + }, + billingAuditEvent: { + create: jest.fn(), + }, + $transaction: jest.fn(), + }, +})); + +type MockTx = { + organizationBillingSubscription: { + upsert: jest.Mock; + updateMany: jest.Mock; + }; + billingUsageEvent: { + create: jest.Mock; + findUnique: jest.Mock; + }; +}; + +const mockedDb = db as unknown as { + organizationBillingSubscription: { findUnique: jest.Mock }; + billingAuditEvent: { create: jest.Mock }; + $transaction: jest.Mock; +}; + +describe('BillingEntitlementsService', () => { + let tx: MockTx; + let service: BillingEntitlementsService; + + beforeEach(() => { + jest.clearAllMocks(); + tx = { + organizationBillingSubscription: { + upsert: jest.fn(), + updateMany: jest.fn(), + }, + billingUsageEvent: { + create: jest.fn(), + findUnique: jest.fn(), + }, + }; + mockedDb.$transaction.mockImplementation( + (callback: (tx: MockTx) => Promise) => callback(tx), + ); + mockedDb.billingAuditEvent.create.mockResolvedValue({}); + service = new BillingEntitlementsService(); + }); + + it('applies same-period subscription updates that shorten currentPeriodEnd', async () => { + mockedDb.organizationBillingSubscription.findUnique.mockResolvedValue({ + stripeSubscriptionItemId: 'si_1', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + }); + + await service.syncSubscriptionItem({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_5', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripePriceId: 'price_1', + stripeStatus: 'canceled', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-04-15T00:00:00.000Z'), + includedQuantity: 5, + cancelAtPeriodEnd: false, + canceledAt: new Date('2026-04-10T00:00:00.000Z'), + stripeEventId: 'evt_1', + }); + + expect(tx.organizationBillingSubscription.upsert).toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.not.objectContaining({ usedQuantity: 0 }), + }), + ); + expect(tx.billingUsageEvent.create).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/api/src/billing/billing-entitlements.service.ts b/apps/api/src/billing/billing-entitlements.service.ts index a788d637a0..2ebff97ba9 100644 --- a/apps/api/src/billing/billing-entitlements.service.ts +++ b/apps/api/src/billing/billing-entitlements.service.ts @@ -1,6 +1,7 @@ import { HttpException, HttpStatus, Injectable } from '@nestjs/common'; -import { db } from '@db'; +import { Prisma, db } from '@db'; import type { BillingSkuKey } from '@trycompai/billing'; +import { refundIncludedUsageEvent } from './billing-included-usage-refunds'; import { type BillingConsumeResult, isAccessStatus, @@ -106,19 +107,17 @@ export class BillingEntitlementsService { }, }); if ( - existing?.currentPeriodEnd && - params.currentPeriodEnd && - existing.currentPeriodEnd.getTime() > params.currentPeriodEnd.getTime() + existing?.currentPeriodStart && + params.currentPeriodStart && + existing.currentPeriodStart.getTime() > + params.currentPeriodStart.getTime() ) { return; } const resetUsage = - !sameTime( - existing?.currentPeriodStart ?? null, - params.currentPeriodStart, - ) || - !sameTime(existing?.currentPeriodEnd ?? null, params.currentPeriodEnd); + !existing || + !sameTime(existing.currentPeriodStart, params.currentPeriodStart); await db.$transaction(async (tx) => { await tx.organizationBillingSubscription.upsert({ @@ -229,59 +228,9 @@ export class BillingEntitlementsService { skuKey: BillingSkuKey; sourceResourceId: string; reason: string; + tx?: Prisma.TransactionClient; }): Promise { - const consumeKey = [ - 'consume', - params.organizationId, - params.skuKey, - params.sourceResourceId, - ].join(':'); - const refundKey = [ - 'refund', - params.organizationId, - params.skuKey, - params.sourceResourceId, - params.reason, - ].join(':'); - - try { - await db.$transaction(async (tx) => { - const consumed = await tx.billingUsageEvent.findUnique({ - where: { idempotencyKey: consumeKey }, - select: { - stripeSubscriptionItemId: true, - periodStart: true, - periodEnd: true, - }, - }); - if (!consumed) return; - - await tx.billingUsageEvent.create({ - data: { - organizationId: params.organizationId, - skuKey: params.skuKey, - eventType: 'refund', - quantity: 1, - sourceResourceId: params.sourceResourceId, - idempotencyKey: refundKey, - stripeSubscriptionItemId: consumed.stripeSubscriptionItemId, - periodStart: consumed.periodStart, - periodEnd: consumed.periodEnd, - }, - }); - - await tx.organizationBillingSubscription.updateMany({ - where: { - organizationId: params.organizationId, - skuKey: params.skuKey, - usedQuantity: { gt: 0 }, - }, - data: { usedQuantity: { decrement: 1 } }, - }); - }); - } catch (error) { - if (!isUniqueConstraintError(error)) throw error; - } + await refundIncludedUsageEvent(params); } async writeAuditEvent(params: WriteBillingAuditEventParams): Promise { diff --git a/apps/api/src/billing/billing-included-usage-refunds.ts b/apps/api/src/billing/billing-included-usage-refunds.ts new file mode 100644 index 0000000000..174aac5003 --- /dev/null +++ b/apps/api/src/billing/billing-included-usage-refunds.ts @@ -0,0 +1,70 @@ +import { Prisma, db } from '@db'; +import type { BillingSkuKey } from '@trycompai/billing'; +import { isUniqueConstraintError } from './billing-entitlements.types'; + +export async function refundIncludedUsageEvent(params: { + organizationId: string; + skuKey: BillingSkuKey; + sourceResourceId: string; + reason: string; + tx?: Prisma.TransactionClient; +}): Promise { + const consumeKey = [ + 'consume', + params.organizationId, + params.skuKey, + params.sourceResourceId, + ].join(':'); + const refundKey = [ + 'refund', + params.organizationId, + params.skuKey, + params.sourceResourceId, + params.reason, + ].join(':'); + + try { + const writeRefund = async (tx: Prisma.TransactionClient) => { + const consumed = await tx.billingUsageEvent.findUnique({ + where: { idempotencyKey: consumeKey }, + select: { + stripeSubscriptionItemId: true, + periodStart: true, + periodEnd: true, + }, + }); + if (!consumed) return; + + await tx.billingUsageEvent.create({ + data: { + organizationId: params.organizationId, + skuKey: params.skuKey, + eventType: 'refund', + quantity: 1, + sourceResourceId: params.sourceResourceId, + idempotencyKey: refundKey, + stripeSubscriptionItemId: consumed.stripeSubscriptionItemId, + periodStart: consumed.periodStart, + periodEnd: consumed.periodEnd, + }, + }); + + await tx.organizationBillingSubscription.updateMany({ + where: { + organizationId: params.organizationId, + skuKey: params.skuKey, + usedQuantity: { gt: 0 }, + }, + data: { usedQuantity: { decrement: 1 } }, + }); + }; + + if (params.tx) { + await writeRefund(params.tx); + } else { + await db.$transaction(writeRefund); + } + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + } +} diff --git a/apps/api/src/billing/billing-preferences.spec.ts b/apps/api/src/billing/billing-preferences.spec.ts index c76c7aadf6..5af1391554 100644 --- a/apps/api/src/billing/billing-preferences.spec.ts +++ b/apps/api/src/billing/billing-preferences.spec.ts @@ -54,7 +54,9 @@ function createCustomer(): Stripe.Customer { describe('updateBillingPreferences', () => { beforeEach(() => { jest.clearAllMocks(); - organizationBillingFindUnique.mockResolvedValue({ stripeCustomerId: 'cus_1' }); + organizationBillingFindUnique.mockResolvedValue({ + stripeCustomerId: 'cus_1', + }); }); it('updates the Stripe customer fields used for B2B invoices', async () => { @@ -107,7 +109,9 @@ describe('updateBillingPreferences', () => { value: 'GB123456789', owner: { type: 'customer', customer: 'cus_1' }, }), - expect.objectContaining({ idempotencyKey: expect.stringContaining('cus_1') }), + expect.objectContaining({ + idempotencyKey: expect.stringContaining('cus_1'), + }), ); expect(result.preferences).toEqual( expect.objectContaining({ @@ -117,4 +121,48 @@ describe('updateBillingPreferences', () => { }), ); }); + + it('sends empty strings to Stripe to clear blank address fields', async () => { + const customersUpdate = jest.fn().mockResolvedValue({ + ...createCustomer(), + address: null, + }); + const taxIdsList = jest.fn().mockResolvedValue({ data: [] }); + + await updateBillingPreferences({ + stripeService: mockStripeService({ + customers: { update: customersUpdate }, + taxIds: { list: taxIdsList, create: jest.fn(), del: jest.fn() }, + }), + organizationId: 'org_1', + preferences: { + companyName: 'Test Company', + billingEmail: 'accounts@example.com', + purchaseOrder: null, + address: { + line1: '', + line2: null, + city: '', + state: null, + postalCode: '', + country: '', + }, + taxId: null, + }, + }); + + expect(customersUpdate).toHaveBeenCalledWith( + 'cus_1', + expect.objectContaining({ + address: { + line1: '', + line2: '', + city: '', + state: '', + postal_code: '', + country: '', + }, + }), + ); + }); }); diff --git a/apps/api/src/billing/billing-preferences.ts b/apps/api/src/billing/billing-preferences.ts index ac05f49d79..20f62c2511 100644 --- a/apps/api/src/billing/billing-preferences.ts +++ b/apps/api/src/billing/billing-preferences.ts @@ -92,7 +92,12 @@ export async function updateBillingPreferences(params: { address: toStripeAddress(params.preferences.address), invoice_settings: { custom_fields: params.preferences.purchaseOrder - ? [{ name: purchaseOrderFieldName, value: params.preferences.purchaseOrder }] + ? [ + { + name: purchaseOrderFieldName, + value: params.preferences.purchaseOrder, + }, + ] : '', }, metadata: { organizationId: params.organizationId }, @@ -125,7 +130,9 @@ function validatePreferences(preferences: BillingPreferencesInput): void { throw new BadRequestException('Unsupported tax ID type.'); } if ((type && !value) || (!type && value)) { - throw new BadRequestException('Tax ID type and value must be set together.'); + throw new BadRequestException( + 'Tax ID type and value must be set together.', + ); } } @@ -144,7 +151,10 @@ async function syncPrimaryTaxId(params: { return null; } - if (params.existingTaxId?.type === type && params.existingTaxId.value === value) { + if ( + params.existingTaxId?.type === type && + params.existingTaxId.value === value + ) { return params.existingTaxId; } @@ -191,7 +201,10 @@ function mapCustomerPreferences(params: { return { companyName: params.customer.name ?? params.customer.business_name ?? null, billingEmail: params.customer.email ?? null, - purchaseOrder: findInvoiceCustomFieldValue(params.customer, purchaseOrderFieldName), + purchaseOrder: findInvoiceCustomFieldValue( + params.customer, + purchaseOrderFieldName, + ), address: { line1: params.customer.address?.line1 ?? null, line2: params.customer.address?.line2 ?? null, @@ -211,18 +224,22 @@ function mapCustomerPreferences(params: { }; } -function toStripeAddress(address: BillingPreferences['address']): Stripe.AddressParam { +function toStripeAddress( + address: BillingPreferences['address'], +): Stripe.AddressParam { return { - line1: emptyToUndefined(address.line1), - line2: emptyToUndefined(address.line2), - city: emptyToUndefined(address.city), - state: emptyToUndefined(address.state), - postal_code: emptyToUndefined(address.postalCode), - country: emptyToUndefined(address.country)?.toUpperCase(), + line1: emptyToString(address.line1), + line2: emptyToString(address.line2), + city: emptyToString(address.city), + state: emptyToString(address.state), + postal_code: emptyToString(address.postalCode), + country: emptyToString(address.country).toUpperCase(), }; } -function createEmptyPreferences(params: { companyName: string | null }): BillingPreferences { +function createEmptyPreferences(params: { + companyName: string | null; +}): BillingPreferences { return { companyName: params.companyName, billingEmail: null, @@ -239,15 +256,20 @@ function createEmptyPreferences(params: { companyName: string | null }): Billing }; } -function findInvoiceCustomFieldValue(customer: Stripe.Customer, name: string): string | null { +function findInvoiceCustomFieldValue( + customer: Stripe.Customer, + name: string, +): string | null { return ( - customer.invoice_settings.custom_fields?.find((field) => field.name === name)?.value ?? null + customer.invoice_settings.custom_fields?.find( + (field) => field.name === name, + )?.value ?? null ); } -function emptyToUndefined(value: string | null): string | undefined { +function emptyToString(value: string | null): string { const trimmed = value?.trim(); - return trimmed ? trimmed : undefined; + return trimmed ? trimmed : ''; } function isDeletedCustomer( diff --git a/apps/api/src/billing/billing-redirect-urls.spec.ts b/apps/api/src/billing/billing-redirect-urls.spec.ts new file mode 100644 index 0000000000..0475fb44c6 --- /dev/null +++ b/apps/api/src/billing/billing-redirect-urls.spec.ts @@ -0,0 +1,14 @@ +import { BadRequestException } from '@nestjs/common'; +import { validateBillingRedirectUrl } from './billing-redirect-urls'; + +describe('validateBillingRedirectUrl', () => { + it('allows http only for local development hosts', () => { + expect(() => + validateBillingRedirectUrl('http://localhost:3000/org_1/billing'), + ).not.toThrow(); + + expect(() => + validateBillingRedirectUrl('http://app.trycomp.ai/org_1/billing'), + ).toThrow(BadRequestException); + }); +}); diff --git a/apps/api/src/billing/billing-redirect-urls.ts b/apps/api/src/billing/billing-redirect-urls.ts index eb94d43ded..7865048657 100644 --- a/apps/api/src/billing/billing-redirect-urls.ts +++ b/apps/api/src/billing/billing-redirect-urls.ts @@ -6,6 +6,7 @@ const allowedHosts = new Set([ 'app.trycomp.ai', 'app.staging.trycomp.ai', ]); +const localDevelopmentHosts = new Set(['localhost', '127.0.0.1']); export function validateBillingRedirectUrl(value: string): void { let url: URL; @@ -22,4 +23,8 @@ export function validateBillingRedirectUrl(value: string): void { if (!allowedHosts.has(url.hostname)) { throw new BadRequestException('Billing redirect URL is not allowed.'); } + + if (url.protocol === 'http:' && !localDevelopmentHosts.has(url.hostname)) { + throw new BadRequestException('Billing redirect URL must use HTTPS.'); + } } diff --git a/apps/api/src/billing/billing-stripe-config.ts b/apps/api/src/billing/billing-stripe-config.ts new file mode 100644 index 0000000000..74f1c4a920 --- /dev/null +++ b/apps/api/src/billing/billing-stripe-config.ts @@ -0,0 +1,10 @@ +import { BadRequestException } from '@nestjs/common'; +import type { StripeService } from '../stripe/stripe.service'; + +export function assertStripeBillingConfigured( + stripeService: StripeService, +): void { + if (stripeService.isConfigured()) return; + + throw new BadRequestException('Stripe billing is not configured.'); +} diff --git a/apps/api/src/billing/billing-stripe-ids.ts b/apps/api/src/billing/billing-stripe-ids.ts new file mode 100644 index 0000000000..d51132dad5 --- /dev/null +++ b/apps/api/src/billing/billing-stripe-ids.ts @@ -0,0 +1,7 @@ +export function extractStripeId( + value: string | { id?: string } | null, +): string | null { + if (!value) return null; + if (typeof value === 'string') return value; + return value.id ?? null; +} diff --git a/apps/api/src/billing/billing-usage.spec.ts b/apps/api/src/billing/billing-usage.spec.ts index 4340e73585..610fe65989 100644 --- a/apps/api/src/billing/billing-usage.spec.ts +++ b/apps/api/src/billing/billing-usage.spec.ts @@ -39,6 +39,7 @@ describe('listBillingUsageRows', () => { { id: 'ptr_1', providerRunId: 'run_1', + billingUsageSourceId: 'pending:run_1', createdAt: new Date('2026-04-30T11:00:00.000Z'), updatedAt: new Date('2026-04-30T11:05:00.000Z'), }, @@ -50,6 +51,12 @@ describe('listBillingUsageRows', () => { sourceResourceId: 'mem_1', stripeInvoiceId: null, }, + { + skuKey: 'pentest_monthly_5', + eventType: 'consume', + sourceResourceId: 'pending:run_1', + stripeInvoiceId: null, + }, ]); const rows = await listBillingUsageRows({ @@ -74,6 +81,7 @@ describe('listBillingUsageRows', () => { expect.objectContaining({ service: 'Penetration Test', details: 'run_1', + billingType: 'Subscription allowance', subscriptionRemaining: 4, }), expect.objectContaining({ @@ -84,4 +92,37 @@ describe('listBillingUsageRows', () => { }), ]); }); + + it('labels legacy pentest rows independently of current subscription state', async () => { + backgroundCheckFindMany.mockResolvedValue([]); + pentestRunFindMany.mockResolvedValue([ + { + id: 'ptr_legacy', + providerRunId: 'run_legacy', + billingUsageSourceId: null, + createdAt: new Date('2026-04-30T11:00:00.000Z'), + updatedAt: new Date('2026-04-30T11:05:00.000Z'), + }, + ]); + billingUsageEventFindMany.mockResolvedValue([]); + + const rows = await listBillingUsageRows({ + organizationId: 'org_1', + subscriptions: [ + { + skuKey: 'pentest_monthly_5', + includedQuantity: 5, + usedQuantity: 1, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + ], + }); + + expect(rows[0]).toEqual( + expect.objectContaining({ + service: 'Penetration Test', + billingType: 'Trial credit', + }), + ); + }); }); diff --git a/apps/api/src/billing/billing-usage.ts b/apps/api/src/billing/billing-usage.ts index f5108fcbf6..26c64a787b 100644 --- a/apps/api/src/billing/billing-usage.ts +++ b/apps/api/src/billing/billing-usage.ts @@ -39,6 +39,7 @@ export async function listBillingUsageRows(params: { select: { id: true, providerRunId: true, + billingUsageSourceId: true, createdAt: true, updatedAt: true, }, @@ -75,28 +76,38 @@ export async function listBillingUsageRows(params: { skuKey: backgroundCheckSku, details: `${request.employeeName} (${request.employeeEmail})`, status: formatStatus(request.status), - billingType: formatBillingType(usage?.eventType, usage?.stripeInvoiceId), + billingType: formatBillingType( + usage?.eventType, + usage?.stripeInvoiceId, + ), createdAt: request.createdAt, updatedAt: request.updatedAt, subscriptions: params.subscriptions, }); }), - ...pentestRuns.map((run) => - toBillingUsageRow({ + ...pentestRuns.map((run) => { + const usage = run.billingUsageSourceId + ? usageBySource.get(run.billingUsageSourceId) + : undefined; + return toBillingUsageRow({ id: run.id, service: 'Penetration Test', skuKey: pentestSku, details: run.providerRunId, status: 'Created', - billingType: formatPentestBillingType(params.subscriptions), + billingType: usage + ? formatBillingType(usage.eventType, usage.stripeInvoiceId) + : 'Trial credit', createdAt: run.createdAt, updatedAt: run.updatedAt, subscriptions: params.subscriptions, - }), - ), + }); + }), ]; - return rows.sort((first, second) => second.createdAt.localeCompare(first.createdAt)); + return rows.sort((first, second) => + second.createdAt.localeCompare(first.createdAt), + ); } function toBillingUsageRow(params: { @@ -110,7 +121,9 @@ function toBillingUsageRow(params: { updatedAt: Date; subscriptions: SubscriptionSummary[]; }): BillingUsageRow { - const subscription = params.subscriptions.find((item) => item.skuKey === params.skuKey); + const subscription = params.subscriptions.find( + (item) => item.skuKey === params.skuKey, + ); const remaining = subscription ? Math.max(subscription.includedQuantity - subscription.usedQuantity, 0) : null; @@ -126,22 +139,21 @@ function toBillingUsageRow(params: { updatedAt: params.updatedAt.toISOString(), subscriptionRemaining: remaining, subscriptionIncluded: subscription?.includedQuantity ?? null, - subscriptionPeriodEnd: subscription?.currentPeriodEnd?.toISOString() ?? null, + subscriptionPeriodEnd: + subscription?.currentPeriodEnd?.toISOString() ?? null, }; } -function formatBillingType(eventType?: string, stripeInvoiceId?: string | null): string { +function formatBillingType( + eventType?: string, + stripeInvoiceId?: string | null, +): string { if (eventType === 'consume') return 'Subscription allowance'; - if (eventType === 'one_time') return stripeInvoiceId ? 'One-time invoice' : 'One-time'; + if (eventType === 'one_time') + return stripeInvoiceId ? 'One-time invoice' : 'One-time'; return 'Legacy / manual'; } -function formatPentestBillingType(subscriptions: SubscriptionSummary[]): string { - return subscriptions.some((item) => item.skuKey === pentestSku) - ? 'Subscription allowance' - : 'Trial credit'; -} - function formatStatus(status: string): string { return status .split('_') diff --git a/apps/api/src/billing/billing-webhook.service.ts b/apps/api/src/billing/billing-webhook.service.ts index d409a2de26..d15842adc5 100644 --- a/apps/api/src/billing/billing-webhook.service.ts +++ b/apps/api/src/billing/billing-webhook.service.ts @@ -32,9 +32,17 @@ export class BillingWebhookService { if (!secret) throw new BadRequestException('Stripe webhook secret is not configured.'); - const event = this.stripeService - .getClient() - .webhooks.constructEvent(params.rawBody, params.signature, secret); + const stripe = this.stripeService.getClient(); + let event: Stripe.Event; + try { + event = stripe.webhooks.constructEvent( + params.rawBody, + params.signature, + secret, + ); + } catch { + throw new BadRequestException('Invalid Stripe webhook signature.'); + } const claim = await claimStripeWebhookEvent({ stripeEventId: event.id, eventType: event.type, diff --git a/apps/api/src/billing/billing.service.spec.ts b/apps/api/src/billing/billing.service.spec.ts index 151db5f412..f0a4cddca1 100644 --- a/apps/api/src/billing/billing.service.spec.ts +++ b/apps/api/src/billing/billing.service.spec.ts @@ -25,6 +25,7 @@ const organizationBillingCreate = mockedDb.organizationBilling function mockStripeService(client: unknown): StripeService { return { + isConfigured: () => client !== null, getClient: () => client, } as unknown as StripeService; } @@ -106,4 +107,17 @@ describe('BillingService', () => { }), ).rejects.toBeInstanceOf(BadRequestException); }); + + it('returns a controlled error when Stripe is not configured', async () => { + const service = new BillingService(mockStripeService(null)); + + await expect( + service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_5', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }), + ).rejects.toBeInstanceOf(BadRequestException); + }); }); diff --git a/apps/api/src/billing/billing.service.ts b/apps/api/src/billing/billing.service.ts index a00663f237..a1e718a812 100644 --- a/apps/api/src/billing/billing.service.ts +++ b/apps/api/src/billing/billing.service.ts @@ -5,7 +5,6 @@ import { } from '@nestjs/common'; import { db } from '@db'; import { - type BillingSku, type BillingSkuKey, getBillingSku, isSubscriptionBillingSkuKey, @@ -19,6 +18,8 @@ import { updateBillingPreferences, } from './billing-preferences'; import { validateBillingRedirectUrl } from './billing-redirect-urls'; +import { assertStripeBillingConfigured } from './billing-stripe-config'; +import { extractStripeId } from './billing-stripe-ids'; import type { BillingStatus } from './billing.types'; import { listBillingUsageRows } from './billing-usage'; @@ -27,27 +28,32 @@ export class BillingService { constructor(private readonly stripeService: StripeService) {} async getStatus(organizationId: string): Promise { - const [organization, billing, subscriptions, backgroundChecks, penetrationTests] = - await Promise.all([ - db.organization.findUniqueOrThrow({ - where: { id: organizationId }, - select: { name: true }, - }), - db.organizationBilling.findUnique({ - where: { organizationId }, - select: { - stripeCustomerId: true, - stripePaymentMethodId: true, - paymentMethodUpdatedAt: true, - }, - }), - db.organizationBillingSubscription.findMany({ - where: { organizationId }, - orderBy: { skuKey: 'asc' }, - }), - db.backgroundCheckRequest.count({ where: { organizationId } }), - db.securityPenetrationTestRun.count({ where: { organizationId } }), - ]); + const [ + organization, + billing, + subscriptions, + backgroundChecks, + penetrationTests, + ] = await Promise.all([ + db.organization.findUniqueOrThrow({ + where: { id: organizationId }, + select: { name: true }, + }), + db.organizationBilling.findUnique({ + where: { organizationId }, + select: { + stripeCustomerId: true, + stripePaymentMethodId: true, + paymentMethodUpdatedAt: true, + }, + }), + db.organizationBillingSubscription.findMany({ + where: { organizationId }, + orderBy: { skuKey: 'asc' }, + }), + db.backgroundCheckRequest.count({ where: { organizationId } }), + db.securityPenetrationTestRun.count({ where: { organizationId } }), + ]); const [invoices, preferences, usageRows] = await Promise.all([ listBillingInvoices({ stripeService: this.stripeService, @@ -86,6 +92,8 @@ export class BillingService { organizationId: string; preferences: BillingPreferencesInput; }): Promise { + assertStripeBillingConfigured(this.stripeService); + const result = await updateBillingPreferences({ stripeService: this.stripeService, organizationId: params.organizationId, @@ -114,6 +122,7 @@ export class BillingService { }): Promise<{ url: string }> { validateBillingRedirectUrl(params.successUrl); validateBillingRedirectUrl(params.cancelUrl); + assertStripeBillingConfigured(this.stripeService); const stripe = this.stripeService.getClient(); const customerId = await findOrCreateBillingCustomer({ @@ -147,6 +156,8 @@ export class BillingService { organizationId: string; sessionId: string; }): Promise<{ success: true }> { + assertStripeBillingConfigured(this.stripeService); + const stripe = this.stripeService.getClient(); const session = await stripe.checkout.sessions.retrieve(params.sessionId, { expand: ['setup_intent'], @@ -217,6 +228,7 @@ export class BillingService { returnUrl: string; }): Promise<{ url: string }> { validateBillingRedirectUrl(params.returnUrl); + assertStripeBillingConfigured(this.stripeService); const billing = await db.organizationBilling.findUnique({ where: { organizationId: params.organizationId }, @@ -249,6 +261,7 @@ export class BillingService { if (!isSubscriptionBillingSkuKey(params.skuKey)) { throw new BadRequestException('Unknown subscription SKU.'); } + assertStripeBillingConfigured(this.stripeService); const sku = getBillingSku({ skuKey: params.skuKey }); const customerId = await findOrCreateBillingCustomer({ @@ -284,16 +297,4 @@ export class BillingService { } return { url: session.url }; } - - getOneTimeBackgroundCheckSku(): BillingSku { - return getBillingSku({ skuKey: 'background_check_one_time' }); - } -} - -function extractStripeId( - value: string | { id?: string } | null, -): string | null { - if (!value) return null; - if (typeof value === 'string') return value; - return value.id ?? null; } diff --git a/apps/api/src/billing/dto/billing.dto.ts b/apps/api/src/billing/dto/billing.dto.ts index 9218231068..e0ba0a6b0a 100644 --- a/apps/api/src/billing/dto/billing.dto.ts +++ b/apps/api/src/billing/dto/billing.dto.ts @@ -1,5 +1,13 @@ import { subscriptionBillingSkuKeys } from '@trycompai/billing'; -import { IsEmail, IsIn, IsOptional, IsString, IsUrl, Length } from 'class-validator'; +import { + IsEmail, + IsIn, + IsNotEmpty, + IsOptional, + IsString, + IsUrl, + Length, +} from 'class-validator'; import { billingTaxIdTypes } from '../billing-preferences'; export class BillingSetupSessionDto { @@ -14,6 +22,7 @@ export class BillingSetupSessionDto { export class BillingSetupSuccessDto { @IsString() + @IsNotEmpty() sessionId: string; } diff --git a/apps/api/src/billing/stripe-webhook-records.spec.ts b/apps/api/src/billing/stripe-webhook-records.spec.ts new file mode 100644 index 0000000000..a0dc29f321 --- /dev/null +++ b/apps/api/src/billing/stripe-webhook-records.spec.ts @@ -0,0 +1,124 @@ +import { Prisma, db } from '@db'; +import { claimStripeWebhookEvent } from './stripe-webhook-records'; + +jest.mock('@db', () => { + class PrismaClientKnownRequestError extends Error { + code: string; + + constructor(message: string, params: { code: string }) { + super(message); + this.code = params.code; + } + } + + return { + Prisma: { + PrismaClientKnownRequestError, + }, + db: { + stripeWebhookEvent: { + create: jest.fn(), + updateMany: jest.fn(), + update: jest.fn(), + }, + }, + }; +}); + +const stripeWebhookEventCreate = db.stripeWebhookEvent + .create as unknown as jest.Mock; +const stripeWebhookEventUpdateMany = db.stripeWebhookEvent + .updateMany as unknown as jest.Mock; + +describe('stripe webhook records', () => { + beforeEach(() => { + jest.clearAllMocks(); + jest + .useFakeTimers() + .setSystemTime(new Date('2026-04-30T12:00:00.000Z').getTime()); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + it('claims a new Stripe webhook event', async () => { + stripeWebhookEventCreate.mockResolvedValue({}); + + await expect( + claimStripeWebhookEvent({ + stripeEventId: 'evt_1', + eventType: 'invoice.paid', + payload: { id: 'in_1' }, + }), + ).resolves.toEqual({ status: 'claimed' }); + + expect(stripeWebhookEventCreate).toHaveBeenCalledWith({ + data: expect.objectContaining({ + stripeEventId: 'evt_1', + status: 'processing', + }), + }); + }); + + it('atomically reclaims failed webhook events', async () => { + stripeWebhookEventCreate.mockRejectedValue( + new Prisma.PrismaClientKnownRequestError('Unique', { + code: 'P2002', + clientVersion: 'test', + }), + ); + stripeWebhookEventUpdateMany.mockResolvedValue({ count: 1 }); + + await expect( + claimStripeWebhookEvent({ + stripeEventId: 'evt_1', + eventType: 'invoice.paid', + payload: { id: 'in_1' }, + }), + ).resolves.toEqual({ status: 'claimed' }); + + expect(stripeWebhookEventUpdateMany).toHaveBeenCalledWith( + expect.objectContaining({ + where: expect.objectContaining({ + stripeEventId: 'evt_1', + OR: expect.arrayContaining([ + expect.objectContaining({ status: 'failed' }), + ]), + }), + data: expect.objectContaining({ status: 'processing', error: null }), + }), + ); + }); + + it('reclaims only one stale processing webhook retry', async () => { + stripeWebhookEventCreate.mockRejectedValue( + new Prisma.PrismaClientKnownRequestError('Unique', { + code: 'P2002', + clientVersion: 'test', + }), + ); + stripeWebhookEventUpdateMany.mockResolvedValue({ count: 0 }); + + await expect( + claimStripeWebhookEvent({ + stripeEventId: 'evt_1', + eventType: 'invoice.paid', + payload: { id: 'in_1' }, + }), + ).resolves.toEqual({ status: 'duplicate' }); + + expect(stripeWebhookEventUpdateMany).toHaveBeenCalledWith( + expect.objectContaining({ + where: expect.objectContaining({ + OR: expect.arrayContaining([ + expect.objectContaining({ + status: 'processing', + processedAt: { lt: new Date('2026-04-30T11:45:00.000Z') }, + }), + ]), + }), + }), + ); + }); +}); diff --git a/apps/api/src/billing/stripe-webhook-records.ts b/apps/api/src/billing/stripe-webhook-records.ts index 52638d7f67..76d28deb1b 100644 --- a/apps/api/src/billing/stripe-webhook-records.ts +++ b/apps/api/src/billing/stripe-webhook-records.ts @@ -1,5 +1,7 @@ import { Prisma, db } from '@db'; +const processingReclaimAfterMs = 15 * 60 * 1000; + export type StripeWebhookClaim = | { status: 'claimed' } | { status: 'duplicate' }; @@ -21,17 +23,28 @@ export async function claimStripeWebhookEvent(params: { return { status: 'claimed' }; } catch (error) { if (!isUniqueConstraintError(error)) throw error; - const existing = await db.stripeWebhookEvent.findUnique({ - where: { stripeEventId: params.stripeEventId }, - select: { status: true }, + const reclaimBefore = new Date(Date.now() - processingReclaimAfterMs); + const reclaimed = await db.stripeWebhookEvent.updateMany({ + where: { + stripeEventId: params.stripeEventId, + OR: [ + { status: 'failed' }, + { status: 'processing', processedAt: { lt: reclaimBefore } }, + ], + }, + data: { + eventType: params.eventType, + payload: params.payload, + status: 'processing', + error: null, + processedAt: new Date(), + }, }); - if (existing?.status === 'failed') { - await db.stripeWebhookEvent.update({ - where: { stripeEventId: params.stripeEventId }, - data: { status: 'processing', error: null }, - }); + + if (reclaimed.count === 1) { return { status: 'claimed' }; } + return { status: 'duplicate' }; } } diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts new file mode 100644 index 0000000000..f3c9712c68 --- /dev/null +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts @@ -0,0 +1,177 @@ +import { db } from '@db'; +import type { BillingEntitlementsService } from '../billing/billing-entitlements.service'; +import type { PentestCreditsService } from './pentest-credits.service'; +import { SecurityPenetrationTestsService } from './security-penetration-tests.service'; + +const mockMacedPentestsCreate = jest.fn(); + +jest.mock( + '@maced/api-client', + () => ({ + createMacedClient: () => ({ + pentests: { + create: mockMacedPentestsCreate, + }, + }), + MacedApiError: class MacedApiError extends Error {}, + MacedWebhookSignatureError: class MacedWebhookSignatureError extends Error { + code = 'invalid_signature'; + }, + MacedClient: { + webhooks: { + constructEvent: jest.fn(), + }, + }, + }), + { virtual: true }, +); + +jest.mock('@db', () => ({ + db: { + securityPenetrationTestRun: { + upsert: jest.fn(), + updateMany: jest.fn(), + findUnique: jest.fn(), + }, + secret: { + upsert: jest.fn(), + }, + $transaction: jest.fn(), + }, +})); + +type MockDb = { + securityPenetrationTestRun: { + upsert: jest.Mock; + updateMany: jest.Mock; + findUnique: jest.Mock; + }; + secret: { + upsert: jest.Mock; + }; + $transaction: jest.Mock; +}; + +describe('SecurityPenetrationTestsService billing usage', () => { + const originalMacedApiKey = process.env.MACED_API_KEY; + const mockedDb = db as unknown as MockDb; + const credits: jest.Mocked< + Pick + > = { + getStatus: jest.fn(), + debitOrThrow: jest.fn(), + refund: jest.fn(), + }; + const billingEntitlements: jest.Mocked< + Pick< + BillingEntitlementsService, + 'tryConsumeIncludedUsage' | 'refundIncludedUsage' + > + > = { + tryConsumeIncludedUsage: jest.fn(), + refundIncludedUsage: jest.fn(), + }; + let service: SecurityPenetrationTestsService; + + beforeEach(() => { + jest.clearAllMocks(); + process.env.MACED_API_KEY = 'mc_dev_test_maced_api_key'; + mockMacedPentestsCreate.mockResolvedValue({ + id: 'run_subscription', + status: 'provisioning', + }); + credits.debitOrThrow.mockResolvedValue({ + balance: 4, + totalGranted: 5, + totalConsumed: 1, + lastGrantSource: 'trial', + }); + credits.refund.mockResolvedValue(); + billingEntitlements.tryConsumeIncludedUsage.mockResolvedValue({ + status: 'consumed', + subscriptionId: 'obs_1', + }); + billingEntitlements.refundIncludedUsage.mockResolvedValue(); + mockedDb.securityPenetrationTestRun.upsert.mockResolvedValue({}); + mockedDb.securityPenetrationTestRun.updateMany.mockResolvedValue({ + count: 1, + }); + mockedDb.$transaction.mockImplementation( + (callback: (tx: MockDb) => Promise) => callback(mockedDb), + ); + service = new SecurityPenetrationTestsService( + credits as unknown as PentestCreditsService, + billingEntitlements as unknown as BillingEntitlementsService, + ); + }); + + afterAll(() => { + process.env.MACED_API_KEY = originalMacedApiKey; + }); + + it('persists the subscription usage source on subscription-backed runs', async () => { + await service.createReport('org_123', { + targetUrl: 'https://app.example.com', + repoUrl: 'https://github.com/org/repo', + }); + + expect(credits.debitOrThrow).not.toHaveBeenCalled(); + expect(mockedDb.securityPenetrationTestRun.upsert).toHaveBeenCalledWith( + expect.objectContaining({ + create: expect.objectContaining({ + billingUsageSourceId: expect.stringMatching(/^pending:/), + }), + }), + ); + }); + + it('refunds subscription usage on terminal failure for subscription-backed runs', async () => { + mockedDb.securityPenetrationTestRun.findUnique.mockResolvedValue({ + organizationId: 'org_123', + billingUsageSourceId: 'pending:run_subscription', + }); + const refundInvoker = service as unknown as { + refundOnTerminalFailure: ( + providerRunId: string, + eventType: 'pentest.failed', + ) => Promise; + }; + + await refundInvoker.refundOnTerminalFailure( + 'run_subscription', + 'pentest.failed', + ); + + expect(billingEntitlements.refundIncludedUsage).toHaveBeenCalledWith({ + organizationId: 'org_123', + skuKey: 'pentest_monthly_5', + sourceResourceId: 'pending:run_subscription', + reason: 'pentest.failed', + tx: mockedDb, + }); + expect(credits.refund).not.toHaveBeenCalled(); + }); + + it('keeps legacy credit refunds for terminal failures without subscription usage', async () => { + mockedDb.securityPenetrationTestRun.findUnique.mockResolvedValue({ + organizationId: 'org_123', + billingUsageSourceId: null, + }); + const refundInvoker = service as unknown as { + refundOnTerminalFailure: ( + providerRunId: string, + eventType: 'pentest.failed', + ) => Promise; + }; + + await refundInvoker.refundOnTerminalFailure('run_legacy', 'pentest.failed'); + + expect(credits.refund).toHaveBeenCalledWith( + 'org_123', + 'run_legacy', + 'pentest.failed', + mockedDb, + ); + expect(billingEntitlements.refundIncludedUsage).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts index 4e3aed0f13..843ebd75d4 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts @@ -355,6 +355,7 @@ export class SecurityPenetrationTestsService { const ownershipPersisted = await this.persistRunOwnershipWithRetry( organizationId, providerRunId, + consumedSubscriptionAllowance ? billingUsageSourceId : null, ); if (!ownershipPersisted) { // We debited and Maced created the run, but our DB rejected the @@ -668,7 +669,7 @@ export class SecurityPenetrationTestsService { const run = await tx.securityPenetrationTestRun.findUnique({ where: { providerRunId }, - select: { organizationId: true }, + select: { organizationId: true, billingUsageSourceId: true }, }); if (!run) { // Vanishingly rare race; abort the transaction so the claim @@ -676,10 +677,17 @@ export class SecurityPenetrationTestsService { throw new Error(`Run row vanished after claim for ${providerRunId}`); } - // Pass the tx client through so the wallet write happens in - // the same transaction as the claim. If this throws, the claim - // is undone and the error propagates up to handleWebhook → Maced - // sees 5xx and redelivers, allowing retry. + if (run.billingUsageSourceId) { + await this.billingEntitlements.refundIncludedUsage({ + organizationId: run.organizationId, + skuKey: 'pentest_monthly_5', + sourceResourceId: run.billingUsageSourceId, + reason: eventType, + tx, + }); + return; + } + await this.credits.refund( run.organizationId, providerRunId, @@ -939,6 +947,7 @@ export class SecurityPenetrationTestsService { private async persistRunOwnership( organizationId: string, reportId: string, + billingUsageSourceId: string | null, ): Promise { // Defensive: if a row already exists for this providerRunId, do NOT // overwrite its organizationId. Maced generates unique providerRunIds @@ -955,6 +964,7 @@ export class SecurityPenetrationTestsService { create: { organizationId, providerRunId: reportId, + billingUsageSourceId, }, update: {}, }); @@ -963,10 +973,15 @@ export class SecurityPenetrationTestsService { private async persistRunOwnershipWithRetry( organizationId: string, reportId: string, + billingUsageSourceId: string | null, ): Promise { for (let attempt = 1; attempt <= 3; attempt += 1) { try { - await this.persistRunOwnership(organizationId, reportId); + await this.persistRunOwnership( + organizationId, + reportId, + billingUsageSourceId, + ); return true; } catch (error) { this.logger.error( diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingPreferencesForm.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingPreferencesForm.tsx index 5cad22c38a..947a94dc53 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingPreferencesForm.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingPreferencesForm.tsx @@ -17,7 +17,6 @@ import { import type React from 'react'; import { Controller, useForm } from 'react-hook-form'; import { toast } from 'sonner'; -import type { BackgroundCheckBillingStatus, BillingPreferences } from './types'; import { billingCountries, billingPreferencesSchema, @@ -28,6 +27,7 @@ import { toBillingPreferencesPayload, type BillingPreferencesFormValues, } from './billingPreferencesFormSchema'; +import type { BackgroundCheckBillingStatus, BillingPreferences } from './types'; interface BillingPreferencesFormProps { organizationId: string; @@ -215,7 +215,11 @@ export function BillingPreferencesForm({
- + field.onChange(value ?? '')} disabled={disabled} > - + {getTaxIdTypeLabel(field.value)} @@ -260,7 +264,6 @@ export function BillingPreferencesForm({ />
-
diff --git a/packages/billing/package.json b/packages/billing/package.json index 2fcdca892e..bd2d1ad328 100644 --- a/packages/billing/package.json +++ b/packages/billing/package.json @@ -1,22 +1,22 @@ { - "name": "@trycompai/billing", - "version": "1.0.0", - "private": true, - "exports": { - ".": "./src/index.ts" - }, - "main": "src/index.ts", - "types": "src/index.ts", - "scripts": { - "clean": "rm -rf .turbo node_modules", - "format": "prettier --write .", - "lint": "prettier --check .", - "test": "bun test src/*.test.ts", - "typecheck": "tsc --noEmit" - }, - "sideEffects": false, - "devDependencies": { - "@trycompai/tsconfig": "workspace:*", - "typescript": "^5.8.3" - } + "name": "@trycompai/billing", + "version": "1.0.0", + "private": true, + "exports": { + ".": "./src/index.ts" + }, + "main": "src/index.ts", + "types": "src/index.ts", + "scripts": { + "clean": "rm -rf .turbo node_modules", + "format": "prettier --write .", + "lint": "prettier --check .", + "test": "bun test src/*.test.ts", + "typecheck": "tsc --noEmit" + }, + "sideEffects": false, + "devDependencies": { + "@trycompai/tsconfig": "workspace:*", + "typescript": "^5.8.3" + } } diff --git a/packages/billing/src/index.ts b/packages/billing/src/index.ts index 49524dd13b..5de001d3ff 100644 --- a/packages/billing/src/index.ts +++ b/packages/billing/src/index.ts @@ -31,6 +31,8 @@ export const subscriptionBillingSkuKeys = [ 'pentest_monthly_5', ] as const satisfies readonly BillingSkuKey[]; +export type SubscriptionBillingSkuKey = (typeof subscriptionBillingSkuKeys)[number]; + export type BillingCatalog = { environment: BillingCatalogEnvironment; products: Record; @@ -54,8 +56,7 @@ const testSkus = { key: 'background_checks_monthly_25', productKey: 'background_check', name: 'Background Checks Monthly', - description: - 'Monthly Comp AI background check package. Includes 25 checks per month.', + description: 'Monthly Comp AI background check package. Includes 25 checks per month.', cadence: 'month', currency: 'usd', unitAmount: 24900, @@ -71,8 +72,7 @@ const testSkus = { key: 'pentest_monthly_5', productKey: 'pentest', name: 'Penetration Tests Monthly', - description: - 'Monthly Comp AI penetration testing package. Includes 5 scans per month.', + description: 'Monthly Comp AI penetration testing package. Includes 5 scans per month.', cadence: 'month', currency: 'usd', unitAmount: 39900, @@ -90,9 +90,7 @@ export const billingCatalogs = { test: createCatalog({ environment: 'test', skus: testSkus }), } satisfies Record; -export function getBillingCatalog( - environment: BillingCatalogEnvironment = 'test', -): BillingCatalog { +export function getBillingCatalog(environment: BillingCatalogEnvironment = 'test'): BillingCatalog { return billingCatalogs[environment]; } @@ -110,13 +108,11 @@ export function getBillingSkuByStripePriceId(params: { }): BillingSku | null { const catalog = getBillingCatalog(params.environment); return ( - Object.values(catalog.skus).find( - (sku) => sku.stripePriceId === params.stripePriceId, - ) ?? null + Object.values(catalog.skus).find((sku) => sku.stripePriceId === params.stripePriceId) ?? null ); } -export function isSubscriptionBillingSkuKey(value: string): value is BillingSkuKey { +export function isSubscriptionBillingSkuKey(value: string): value is SubscriptionBillingSkuKey { return subscriptionBillingSkuKeys.some((skuKey) => skuKey === value); } @@ -131,10 +127,8 @@ function createCatalog(params: { pentest: params.skus.pentest_monthly_5.stripeProductId, }, prices: { - background_check_one_time: - params.skus.background_check_one_time.stripePriceId, - background_checks_monthly_25: - params.skus.background_checks_monthly_25.stripePriceId, + background_check_one_time: params.skus.background_check_one_time.stripePriceId, + background_checks_monthly_25: params.skus.background_checks_monthly_25.stripePriceId, pentest_monthly_5: params.skus.pentest_monthly_5.stripePriceId, }, skus: params.skus, diff --git a/packages/billing/tsconfig.json b/packages/billing/tsconfig.json index df12fbc94a..20fef847d2 100644 --- a/packages/billing/tsconfig.json +++ b/packages/billing/tsconfig.json @@ -1,9 +1,9 @@ { - "extends": "@trycompai/tsconfig/base.json", - "compilerOptions": { - "incremental": true, - "tsBuildInfoFile": "node_modules/.cache/tsbuildinfo.json" - }, - "include": ["src"], - "exclude": ["node_modules"] + "extends": "@trycompai/tsconfig/base.json", + "compilerOptions": { + "incremental": true, + "tsBuildInfoFile": "node_modules/.cache/tsbuildinfo.json" + }, + "include": ["src"], + "exclude": ["node_modules"] } diff --git a/packages/db/prisma/migrations/20260430183000_pentest_billing_usage_source/migration.sql b/packages/db/prisma/migrations/20260430183000_pentest_billing_usage_source/migration.sql new file mode 100644 index 0000000000..0b467bc6df --- /dev/null +++ b/packages/db/prisma/migrations/20260430183000_pentest_billing_usage_source/migration.sql @@ -0,0 +1,2 @@ +ALTER TABLE "security_penetration_test_runs" +ADD COLUMN "billing_usage_source_id" TEXT; diff --git a/packages/db/prisma/schema/security-penetration-test-run.prisma b/packages/db/prisma/schema/security-penetration-test-run.prisma index a4cfa455b6..99e08f3854 100644 --- a/packages/db/prisma/schema/security-penetration-test-run.prisma +++ b/packages/db/prisma/schema/security-penetration-test-run.prisma @@ -1,9 +1,10 @@ model SecurityPenetrationTestRun { - id String @id @default(dbgenerated("generate_prefixed_cuid('ptr'::text)")) - organizationId String @map("organization_id") - providerRunId String @map("provider_run_id") - createdAt DateTime @default(now()) @map("created_at") - updatedAt DateTime @updatedAt @map("updated_at") + id String @id @default(dbgenerated("generate_prefixed_cuid('ptr'::text)")) + organizationId String @map("organization_id") + providerRunId String @map("provider_run_id") + billingUsageSourceId String? @map("billing_usage_source_id") + createdAt DateTime @default(now()) @map("created_at") + updatedAt DateTime @updatedAt @map("updated_at") /// Set the first time we refund this run's credit (e.g. on /// `pentest.failed` / `pentest.cancelled` webhooks). Used to make the diff --git a/scripts/check-generated-prisma-schemas.js b/scripts/check-generated-prisma-schemas.js index af0f5d6637..12052d3145 100644 --- a/scripts/check-generated-prisma-schemas.js +++ b/scripts/check-generated-prisma-schemas.js @@ -36,6 +36,15 @@ for (const appName of readdirSync(appsDir)) { for (const file of canonicalFiles) { if (!localFiles.has(file)) { errors.push(`${appName}: missing Prisma schema fragment ${file}`); + continue; + } + + const canonicalPath = path.join(canonicalDir, file); + const localPath = path.join(schemaDir, file); + const canonicalContents = readFileSync(canonicalPath, 'utf8'); + const localContents = readFileSync(localPath, 'utf8'); + if (localContents !== canonicalContents) { + errors.push(`${appName}: Prisma schema fragment ${file} is out of sync`); } } From f7e5e9fdb56f92a8d9e00f6580108d515c04d017 Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Fri, 1 May 2026 00:54:29 +0100 Subject: [PATCH 06/20] feat(billing): enhance billing services and subscription management - Updated BackgroundCheckBillingService to utilize resolveBillingCatalogEnvironment for improved SKU resolution. - Refactored BackgroundCheckPaymentService to streamline payment processing and error handling. - Introduced new billing setup session management with createBillingSetupSession and handleBillingSetupSuccess functions. - Enhanced BillingService to support subscription plan changes and trial eligibility checks. - Added new billing-related utility functions for managing subscriptions and usage tracking. - Updated tests to cover new billing functionalities and ensure robust integration with Stripe. --- apps/api/package.json | 1 + .../background-check-billing.service.ts | 13 +- .../background-check-payment.service.spec.ts | 300 ++++-------------- .../background-check-payment.service.ts | 225 ++----------- .../billing-entitlements.service.spec.ts | 44 ++- .../billing/billing-entitlements.service.ts | 70 +++- .../api/src/billing/billing-setup-sessions.ts | 112 +++++++ .../src/billing/billing-subscription-plans.ts | 133 ++++++++ apps/api/src/billing/billing-usage.ts | 13 +- apps/api/src/billing/billing.service.spec.ts | 228 ++++++++++++- apps/api/src/billing/billing.service.ts | 168 +++++----- apps/api/src/billing/billing.types.ts | 4 + .../pentest-credits.service.ts | 20 +- ...security-penetration-tests.billing.spec.ts | 56 +++- ...security-penetration-tests.service.spec.ts | 10 +- .../security-penetration-tests.service.ts | 45 ++- .../components/BackgroundCheckDetailsForm.tsx | 28 +- .../components/BackgroundCheckWizardParts.tsx | 13 +- .../EmployeeBackgroundCheck.test.tsx | 169 +++------- .../components/EmployeeBackgroundCheck.tsx | 39 ++- .../components/backgroundCheckTypes.ts | 9 + .../penetration-test-page-client.test.tsx | 30 +- .../_components/CreateRunPanel.tsx | 65 ++-- .../_components/DetailPane.tsx | 6 +- .../_components/EmptyState.tsx | 42 +-- .../penetration-tests/_components/RunList.tsx | 70 ++-- .../_components/SplitView.tsx | 118 ++++--- .../hooks/use-penetration-tests.test.tsx | 22 +- .../billing/BillingAddOnPlansClient.tsx | 22 +- .../settings/billing/BillingAddOns.test.tsx | 35 +- .../billing/BillingAddOnsOverview.tsx | 98 ++++-- .../billing/BillingSettingsClient.tsx | 2 +- .../billing/BillingSubscriptionPlans.tsx | 141 ++++++-- .../settings/billing/BillingUsageTable.tsx | 7 +- .../[orgId]/settings/billing/billingAddOns.ts | 34 +- .../(app)/[orgId]/settings/billing/types.ts | 4 + .../components/PostPaymentOnboarding.tsx | 204 ++++++------ .../actions/create-organization-minimal.ts | 7 - .../setup/actions/create-organization.ts | 7 - .../components/OrganizationSetupForm.tsx | 7 +- packages/billing/src/catalog.test.ts | 95 ++++-- packages/billing/src/index.ts | 178 +++++++---- packages/billing/src/sku-definitions.ts | 152 +++++++++ .../migration.sql | 13 + 44 files changed, 1840 insertions(+), 1219 deletions(-) create mode 100644 apps/api/src/billing/billing-setup-sessions.ts create mode 100644 apps/api/src/billing/billing-subscription-plans.ts create mode 100644 packages/billing/src/sku-definitions.ts create mode 100644 packages/db/prisma/migrations/20260501001000_remove_pentest_trial_credits/migration.sql diff --git a/apps/api/package.json b/apps/api/package.json index a43b296328..8e9a3cb9ea 100644 --- a/apps/api/package.json +++ b/apps/api/package.json @@ -170,6 +170,7 @@ "moduleNameMapper": { "^@db$": "/../prisma/index", "^@/(.*)$": "/$1", + "^\\./sku-definitions\\.js$": "/../../../packages/billing/src/sku-definitions.ts", "^@trycompai/auth$": "/../../../packages/auth/src/index.ts", "^@trycompai/billing$": "/../../../packages/billing/src/index.ts", "^@trycompai/company$": "/../../../packages/company/src/index.ts", diff --git a/apps/api/src/background-checks/background-check-billing.service.ts b/apps/api/src/background-checks/background-check-billing.service.ts index 0236fedea2..8b9b7b3975 100644 --- a/apps/api/src/background-checks/background-check-billing.service.ts +++ b/apps/api/src/background-checks/background-check-billing.service.ts @@ -1,5 +1,8 @@ import { Injectable } from '@nestjs/common'; -import { getBillingSku } from '@trycompai/billing'; +import { + getBillingSku, + resolveBillingCatalogEnvironment, +} from '@trycompai/billing'; import { BillingService } from '../billing/billing.service'; import { validateBackgroundCheckBillingRedirectUrl } from './background-check-billing-urls'; @@ -42,7 +45,13 @@ export class BackgroundCheckBillingService { unitAmount: number; currency: string; }> { - const sku = getBillingSku({ skuKey: 'background_check_one_time' }); + const sku = getBillingSku({ + environment: resolveBillingCatalogEnvironment({ + stripeSecretKey: process.env.STRIPE_SECRET_KEY, + nodeEnv: process.env.NODE_ENV, + }), + skuKey: 'background_check_one_time', + }); return { id: sku.stripePriceId, unitAmount: sku.unitAmount, diff --git a/apps/api/src/background-checks/background-check-payment.service.spec.ts b/apps/api/src/background-checks/background-check-payment.service.spec.ts index bc5041bb16..9c4bb48dc2 100644 --- a/apps/api/src/background-checks/background-check-payment.service.spec.ts +++ b/apps/api/src/background-checks/background-check-payment.service.spec.ts @@ -1,203 +1,61 @@ +jest.mock('@db', () => ({ db: {} })); + import { HttpStatus } from '@nestjs/common'; -import { db } from '@db'; import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; import { StripeService } from '../stripe/stripe.service'; import { BackgroundCheckBillingService } from './background-check-billing.service'; import { BackgroundCheckPaymentService } from './background-check-payment.service'; -jest.mock('@db', () => ({ - db: { - organizationBilling: { - findUnique: jest.fn(), - }, - }, -})); - -const mockedDb = db as jest.Mocked; - -function mockAsync(fn: unknown): jest.MockedFunction<() => Promise> { - return fn as jest.MockedFunction<() => Promise>; -} - function mockEntitlements( overrides: Partial = {}, ): BillingEntitlementsService { return { - tryConsumeIncludedUsage: jest + tryConsumeIncludedUsageForProduct: jest .fn() .mockResolvedValue({ status: 'not_configured' }), - recordOneTimeUsage: jest.fn().mockResolvedValue(undefined), - refundIncludedUsage: jest.fn().mockResolvedValue(undefined), - syncSubscriptionItem: jest.fn().mockResolvedValue(undefined), - writeAuditEvent: jest.fn().mockResolvedValue(undefined), + refundIncludedUsageForProduct: jest.fn().mockResolvedValue(undefined), ...overrides, } as unknown as BillingEntitlementsService; } -function mockBillingRow() { - return { - id: 'obil_1', - organizationId: 'org_1', - stripeCustomerId: 'cus_1', - stripePaymentMethodId: 'pm_1', - paymentMethodUpdatedAt: null, - createdAt: new Date('2026-04-30T00:00:00.000Z'), - updatedAt: new Date('2026-04-30T00:00:00.000Z'), - }; -} - describe('BackgroundCheckPaymentService', () => { beforeEach(() => { jest.clearAllMocks(); }); - it('throws payment required when no background check payment method exists', async () => { - mockAsync>>( - mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce(null); + it('consumes background check subscription allowance', async () => { + const entitlements = mockEntitlements({ + tryConsumeIncludedUsageForProduct: jest + .fn() + .mockResolvedValue({ status: 'consumed', subscriptionId: 'obs_1' }), + }); const service = new BackgroundCheckPaymentService( { getClient: jest.fn() } as unknown as StripeService, - { - getBackgroundCheckPrice: jest.fn(), - } as unknown as BackgroundCheckBillingService, - mockEntitlements(), + {} as BackgroundCheckBillingService, + entitlements, ); await expect( service.charge({ organizationId: 'org_1', memberId: 'mem_1' }), - ).rejects.toThrow( - expect.objectContaining({ - status: HttpStatus.PAYMENT_REQUIRED, - }), - ); - }); - - it('creates and pays a Stripe invoice with payment-method scoped idempotency keys', async () => { - mockAsync>>( - mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce(mockBillingRow()); - const invoiceItemsCreate = jest.fn().mockResolvedValue({ id: 'ii_1' }); - const invoicesCreate = jest.fn().mockResolvedValue({ id: 'in_1' }); - const finalizeInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); - const invoicesPay = jest.fn().mockResolvedValue({ - id: 'in_1', - status: 'paid', - payments: { - data: [ - { - payment: { - type: 'payment_intent', - payment_intent: 'pi_1', - }, - }, - ], - }, + ).resolves.toEqual({ + paymentIntentId: null, + invoiceId: null, + status: 'subscription_included', + amount: 0, + currency: 'usd', }); - const service = new BackgroundCheckPaymentService( - { - getClient: () => ({ - invoiceItems: { create: invoiceItemsCreate }, - invoices: { - create: invoicesCreate, - finalizeInvoice, - pay: invoicesPay, - }, - }), - } as unknown as StripeService, - { - getBackgroundCheckPrice: jest.fn().mockResolvedValue({ - id: 'price_bg', - unitAmount: 1250, - currency: 'usd', - }), - } as unknown as BackgroundCheckBillingService, - mockEntitlements(), - ); - await expect( - service.charge({ organizationId: 'org_1', memberId: 'mem_1' }), - ).resolves.toMatchObject({ - paymentIntentId: 'pi_1', - invoiceId: 'in_1', - status: 'succeeded', + expect(entitlements.tryConsumeIncludedUsageForProduct).toHaveBeenCalledWith({ + organizationId: 'org_1', + productKey: 'background_check', + sourceResourceId: 'mem_1', }); - - expect(invoicesCreate).toHaveBeenCalledWith( - expect.objectContaining({ - customer: 'cus_1', - collection_method: 'charge_automatically', - description: 'Comp AI - Background Check x1', - default_payment_method: 'pm_1', - statement_descriptor: 'COMP AI BG CHECK', - }), - { idempotencyKey: 'background-check:org_1:mem_1:price_bg:pm_1:invoice' }, - ); - expect(invoiceItemsCreate).toHaveBeenCalledWith( - expect.objectContaining({ - customer: 'cus_1', - invoice: 'in_1', - pricing: { - price: 'price_bg', - }, - quantity: 1, - }), - { - idempotencyKey: 'background-check:org_1:mem_1:price_bg:pm_1:line-item', - }, - ); - expect(finalizeInvoice).toHaveBeenCalledWith( - 'in_1', - { auto_advance: false }, - { - idempotencyKey: - 'background-check:org_1:mem_1:price_bg:pm_1:finalize-invoice', - }, - ); - expect(invoicesPay).toHaveBeenCalledWith( - 'in_1', - expect.objectContaining({ - payment_method: 'pm_1', - off_session: true, - }), - { - idempotencyKey: - 'background-check:org_1:mem_1:price_bg:pm_1:pay-invoice', - }, - ); }); - it('deletes the draft invoice when adding the invoice item fails', async () => { - mockAsync>>( - mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce(mockBillingRow()); - const invoiceItemsCreate = jest - .fn() - .mockRejectedValue(new Error('line item failed')); - const invoicesCreate = jest.fn().mockResolvedValue({ id: 'in_1' }); - const finalizeInvoice = jest.fn(); - const invoicesPay = jest.fn(); - const deleteInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); - const voidInvoice = jest.fn(); + it('blocks when no background check subscription is configured', async () => { const service = new BackgroundCheckPaymentService( - { - getClient: () => ({ - invoiceItems: { create: invoiceItemsCreate }, - invoices: { - create: invoicesCreate, - finalizeInvoice, - pay: invoicesPay, - del: deleteInvoice, - voidInvoice, - }, - }), - } as unknown as StripeService, - { - getBackgroundCheckPrice: jest.fn().mockResolvedValue({ - id: 'price_bg', - unitAmount: 1250, - currency: 'usd', - }), - } as unknown as BackgroundCheckBillingService, + { getClient: jest.fn() } as unknown as StripeService, + {} as BackgroundCheckBillingService, mockEntitlements(), ); @@ -206,48 +64,22 @@ describe('BackgroundCheckPaymentService', () => { ).rejects.toThrow( expect.objectContaining({ status: HttpStatus.PAYMENT_REQUIRED, + response: expect.objectContaining({ + code: 'background_check_subscription_required', + }), }), ); - - expect(deleteInvoice).toHaveBeenCalledWith('in_1'); - expect(voidInvoice).not.toHaveBeenCalled(); - expect(finalizeInvoice).not.toHaveBeenCalled(); - expect(invoicesPay).not.toHaveBeenCalled(); }); - it('deletes the draft invoice when finalizing fails', async () => { - mockAsync>>( - mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce(mockBillingRow()); - const invoiceItemsCreate = jest.fn().mockResolvedValue({ id: 'ii_1' }); - const invoicesCreate = jest.fn().mockResolvedValue({ id: 'in_1' }); - const finalizeInvoice = jest - .fn() - .mockRejectedValue(new Error('finalize failed')); - const invoicesPay = jest.fn(); - const deleteInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); - const voidInvoice = jest.fn(); + it('blocks when background check subscription allowance is exhausted', async () => { const service = new BackgroundCheckPaymentService( - { - getClient: () => ({ - invoiceItems: { create: invoiceItemsCreate }, - invoices: { - create: invoicesCreate, - finalizeInvoice, - pay: invoicesPay, - del: deleteInvoice, - voidInvoice, - }, - }), - } as unknown as StripeService, - { - getBackgroundCheckPrice: jest.fn().mockResolvedValue({ - id: 'price_bg', - unitAmount: 1250, - currency: 'usd', - }), - } as unknown as BackgroundCheckBillingService, - mockEntitlements(), + { getClient: jest.fn() } as unknown as StripeService, + {} as BackgroundCheckBillingService, + mockEntitlements({ + tryConsumeIncludedUsageForProduct: jest + .fn() + .mockResolvedValue({ status: 'exhausted', subscriptionId: 'obs_1' }), + }), ); await expect( @@ -255,56 +87,34 @@ describe('BackgroundCheckPaymentService', () => { ).rejects.toThrow( expect.objectContaining({ status: HttpStatus.PAYMENT_REQUIRED, + response: expect.objectContaining({ + code: 'background_check_subscription_exhausted', + }), }), ); - - expect(deleteInvoice).toHaveBeenCalledWith('in_1'); - expect(voidInvoice).not.toHaveBeenCalled(); - expect(invoicesPay).not.toHaveBeenCalled(); }); - it('voids the finalized invoice when paying fails', async () => { - mockAsync>>( - mockedDb.organizationBilling.findUnique, - ).mockResolvedValueOnce(mockBillingRow()); - const invoiceItemsCreate = jest.fn().mockResolvedValue({ id: 'ii_1' }); - const invoicesCreate = jest.fn().mockResolvedValue({ id: 'in_1' }); - const finalizeInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); - const invoicesPay = jest.fn().mockRejectedValue(new Error('pay failed')); - const deleteInvoice = jest.fn(); - const voidInvoice = jest.fn().mockResolvedValue({ id: 'in_1' }); + it('refunds the consumed background check allowance by product family', async () => { + const entitlements = mockEntitlements(); const service = new BackgroundCheckPaymentService( - { - getClient: () => ({ - invoiceItems: { create: invoiceItemsCreate }, - invoices: { - create: invoicesCreate, - finalizeInvoice, - pay: invoicesPay, - del: deleteInvoice, - voidInvoice, - }, - }), - } as unknown as StripeService, - { - getBackgroundCheckPrice: jest.fn().mockResolvedValue({ - id: 'price_bg', - unitAmount: 1250, - currency: 'usd', - }), - } as unknown as BackgroundCheckBillingService, - mockEntitlements(), + { getClient: jest.fn() } as unknown as StripeService, + {} as BackgroundCheckBillingService, + entitlements, ); await expect( - service.charge({ organizationId: 'org_1', memberId: 'mem_1' }), - ).rejects.toThrow( - expect.objectContaining({ - status: HttpStatus.PAYMENT_REQUIRED, + service.refund({ + organizationId: 'org_1', + memberId: 'mem_1', + paymentIntentId: null, }), - ); + ).resolves.toBeNull(); - expect(deleteInvoice).not.toHaveBeenCalled(); - expect(voidInvoice).toHaveBeenCalledWith('in_1'); + expect(entitlements.refundIncludedUsageForProduct).toHaveBeenCalledWith({ + organizationId: 'org_1', + productKey: 'background_check', + sourceResourceId: 'mem_1', + reason: 'background_check_failed', + }); }); }); diff --git a/apps/api/src/background-checks/background-check-payment.service.ts b/apps/api/src/background-checks/background-check-payment.service.ts index a976b1e575..adac71125c 100644 --- a/apps/api/src/background-checks/background-check-payment.service.ts +++ b/apps/api/src/background-checks/background-check-payment.service.ts @@ -1,16 +1,10 @@ import { HttpException, HttpStatus, Injectable, Logger } from '@nestjs/common'; -import { db } from '@db'; -import Stripe from 'stripe'; import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; import { StripeService } from '../stripe/stripe.service'; import { BackgroundCheckBillingService } from './background-check-billing.service'; @Injectable() export class BackgroundCheckPaymentService { - private static readonly receiptDescription = 'Comp AI - Background Check x1'; - - private static readonly statementDescriptor = 'COMP AI BG CHECK'; - private readonly logger = new Logger(BackgroundCheckPaymentService.name); constructor( @@ -26,11 +20,13 @@ export class BackgroundCheckPaymentService { amount: number; currency: string; }> { - const includedUsage = await this.entitlements.tryConsumeIncludedUsage({ - organizationId: params.organizationId, - skuKey: 'background_checks_monthly_25', - sourceResourceId: params.memberId, - }); + const includedUsage = + await this.entitlements.tryConsumeIncludedUsageForProduct({ + organizationId: params.organizationId, + productKey: 'background_check', + sourceResourceId: params.memberId, + }); + if (includedUsage.status === 'consumed') { return { paymentIntentId: null, @@ -41,140 +37,19 @@ export class BackgroundCheckPaymentService { }; } - const billing = await db.organizationBilling.findUnique({ - where: { organizationId: params.organizationId }, - select: { - stripeCustomerId: true, - stripePaymentMethodId: true, - }, - }); - - if (!billing?.stripePaymentMethodId) { - throw new HttpException( - 'No payment method on file. Update billing first.', - HttpStatus.PAYMENT_REQUIRED, - ); - } - - const price = await this.billingService.getBackgroundCheckPrice(); - const stripe = this.stripeService.getClient(); - const metadata = { - source: 'comp-background-check', - compOrganizationId: params.organizationId, - compMemberId: params.memberId, - }; - const idempotencyKeyParts = [ - 'background-check', - params.organizationId, - params.memberId, - price.id, - billing.stripePaymentMethodId, - ]; - - const invoice = await stripe.invoices.create( + throw new HttpException( { - customer: billing.stripeCustomerId, - collection_method: 'charge_automatically', - currency: price.currency, - default_payment_method: billing.stripePaymentMethodId, - description: BackgroundCheckPaymentService.receiptDescription, - statement_descriptor: BackgroundCheckPaymentService.statementDescriptor, - auto_advance: false, - metadata, - }, - { - idempotencyKey: [...idempotencyKeyParts, 'invoice'].join(':'), + error: + includedUsage.status === 'exhausted' + ? 'No background checks remaining in your subscription. Upgrade or wait for your monthly allowance to reset.' + : 'Choose a background check plan before requesting a background check.', + code: + includedUsage.status === 'exhausted' + ? 'background_check_subscription_exhausted' + : 'background_check_subscription_required', }, + HttpStatus.PAYMENT_REQUIRED, ); - - let paidInvoice: Stripe.Invoice; - let invoiceFinalized = false; - try { - await stripe.invoiceItems.create( - { - customer: billing.stripeCustomerId, - invoice: invoice.id, - pricing: { - price: price.id, - }, - quantity: 1, - description: BackgroundCheckPaymentService.receiptDescription, - metadata, - }, - { - idempotencyKey: [...idempotencyKeyParts, 'line-item'].join(':'), - }, - ); - - await stripe.invoices.finalizeInvoice( - invoice.id, - { auto_advance: false }, - { - idempotencyKey: [...idempotencyKeyParts, 'finalize-invoice'].join( - ':', - ), - }, - ); - invoiceFinalized = true; - - paidInvoice = await stripe.invoices.pay( - invoice.id, - { - payment_method: billing.stripePaymentMethodId, - off_session: true, - expand: ['payments'], - }, - { - idempotencyKey: [...idempotencyKeyParts, 'pay-invoice'].join(':'), - }, - ); - } catch (error) { - await this.cleanupUnpaidInvoice({ - stripe, - invoiceId: invoice.id, - finalized: invoiceFinalized, - }); - throw new HttpException( - 'Background check payment failed. Update billing and try again.', - HttpStatus.PAYMENT_REQUIRED, - { cause: error }, - ); - } - - if (paidInvoice.status !== 'paid') { - await this.cleanupUnpaidInvoice({ - stripe, - invoiceId: invoice.id, - finalized: true, - }); - throw new HttpException( - 'Background check payment failed. Update billing and try again.', - HttpStatus.PAYMENT_REQUIRED, - ); - } - - const paymentIntentId = this.extractPaymentIntentId(paidInvoice); - if (!paymentIntentId) { - throw new HttpException( - 'Background check payment failed. Update billing and try again.', - HttpStatus.PAYMENT_REQUIRED, - ); - } - - await this.entitlements.recordOneTimeUsage({ - organizationId: params.organizationId, - skuKey: 'background_check_one_time', - sourceResourceId: params.memberId, - stripeInvoiceId: paidInvoice.id, - }); - - return { - paymentIntentId, - invoiceId: paidInvoice.id, - status: 'succeeded', - amount: price.unitAmount, - currency: price.currency, - }; } async refund(params: { @@ -183,9 +58,9 @@ export class BackgroundCheckPaymentService { paymentIntentId: string | null; }): Promise { if (!params.paymentIntentId) { - await this.entitlements.refundIncludedUsage({ + await this.entitlements.refundIncludedUsageForProduct({ organizationId: params.organizationId, - skuKey: 'background_checks_monthly_25', + productKey: 'background_check', sourceResourceId: params.memberId, reason: 'background_check_failed', }); @@ -208,7 +83,7 @@ export class BackgroundCheckPaymentService { return refund.id; } catch (error) { this.logger.error( - 'Failed to refund background check payment — manual refund required.', + 'Failed to refund background check payment - manual refund required.', { organizationId: params.organizationId, memberId: params.memberId, @@ -220,63 +95,7 @@ export class BackgroundCheckPaymentService { } } - private extractPaymentIntentId(invoice: Stripe.Invoice): string | null { - const payment = invoice.payments?.data.find( - (invoicePayment) => invoicePayment.payment.type === 'payment_intent', - ); - const paymentIntent = payment?.payment.payment_intent; - if (!paymentIntent) return null; - return typeof paymentIntent === 'string' ? paymentIntent : paymentIntent.id; - } - - private async cleanupUnpaidInvoice({ - stripe, - invoiceId, - finalized, - }: { - stripe: Stripe; - invoiceId: string; - finalized: boolean; - }): Promise { - if (!finalized) { - await this.deleteDraftInvoice({ stripe, invoiceId }); - return; - } - - await this.voidFinalizedInvoice({ stripe, invoiceId }); - } - - private async deleteDraftInvoice({ - stripe, - invoiceId, - }: { - stripe: Stripe; - invoiceId: string; - }): Promise { - try { - await stripe.invoices.del(invoiceId); - } catch (error) { - this.logger.error('Failed to delete draft background check invoice.', { - invoiceId, - error: error instanceof Error ? error.message : 'Unknown error', - }); - } - } - - private async voidFinalizedInvoice({ - stripe, - invoiceId, - }: { - stripe: Stripe; - invoiceId: string; - }): Promise { - try { - await stripe.invoices.voidInvoice(invoiceId); - } catch (error) { - this.logger.error('Failed to void unpaid background check invoice.', { - invoiceId, - error: error instanceof Error ? error.message : 'Unknown error', - }); - } + async getBackgroundCheckPrice() { + return this.billingService.getBackgroundCheckPrice(); } } diff --git a/apps/api/src/billing/billing-entitlements.service.spec.ts b/apps/api/src/billing/billing-entitlements.service.spec.ts index f884f34747..cfdecaf08a 100644 --- a/apps/api/src/billing/billing-entitlements.service.spec.ts +++ b/apps/api/src/billing/billing-entitlements.service.spec.ts @@ -5,6 +5,7 @@ jest.mock('@db', () => ({ db: { organizationBillingSubscription: { findUnique: jest.fn(), + findMany: jest.fn(), }, billingAuditEvent: { create: jest.fn(), @@ -25,7 +26,7 @@ type MockTx = { }; const mockedDb = db as unknown as { - organizationBillingSubscription: { findUnique: jest.Mock }; + organizationBillingSubscription: { findUnique: jest.Mock; findMany: jest.Mock }; billingAuditEvent: { create: jest.Mock }; $transaction: jest.Mock; }; @@ -82,4 +83,45 @@ describe('BillingEntitlementsService', () => { ); expect(tx.billingUsageEvent.create).not.toHaveBeenCalled(); }); + + it('consumes the active subscription for a product family', async () => { + mockedDb.organizationBillingSubscription.findMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_4', + stripeStatus: 'active', + usedQuantity: 1, + includedQuantity: 4, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + ]); + mockedDb.organizationBillingSubscription.findUnique.mockResolvedValue({ + id: 'obs_1', + skuKey: 'pentest_monthly_4', + stripeStatus: 'active', + usedQuantity: 1, + includedQuantity: 4, + stripeSubscriptionItemId: 'si_1', + currentPeriodStart: new Date('2026-04-30T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }); + tx.organizationBillingSubscription.updateMany.mockResolvedValue({ count: 1 }); + + await expect( + service.tryConsumeIncludedUsageForProduct({ + organizationId: 'org_1', + productKey: 'pentest', + sourceResourceId: 'run_1', + }), + ).resolves.toEqual({ status: 'consumed', subscriptionId: 'obs_1' }); + + expect(mockedDb.organizationBillingSubscription.findUnique).toHaveBeenCalledWith({ + where: { + organizationId_skuKey: { + organizationId: 'org_1', + skuKey: 'pentest_monthly_4', + }, + }, + }); + }); }); diff --git a/apps/api/src/billing/billing-entitlements.service.ts b/apps/api/src/billing/billing-entitlements.service.ts index 2ebff97ba9..dd4612c1a1 100644 --- a/apps/api/src/billing/billing-entitlements.service.ts +++ b/apps/api/src/billing/billing-entitlements.service.ts @@ -1,6 +1,10 @@ import { HttpException, HttpStatus, Injectable } from '@nestjs/common'; import { Prisma, db } from '@db'; -import type { BillingSkuKey } from '@trycompai/billing'; +import { + getBillingSkuKeysForProduct, + type BillingProductKey, + type BillingSkuKey, +} from '@trycompai/billing'; import { refundIncludedUsageEvent } from './billing-included-usage-refunds'; import { type BillingConsumeResult, @@ -13,6 +17,41 @@ import { @Injectable() export class BillingEntitlementsService { + async tryConsumeIncludedUsageForProduct(params: { + organizationId: string; + productKey: BillingProductKey; + sourceResourceId: string; + }): Promise { + const skuKeys = getBillingSkuKeysForProduct(params.productKey); + const subscriptions = await db.organizationBillingSubscription.findMany({ + where: { + organizationId: params.organizationId, + skuKey: { in: skuKeys }, + }, + orderBy: [{ createdAt: 'desc' }], + }); + + const activeSubscription = subscriptions.find( + (subscription) => + isAccessStatus(subscription.stripeStatus) && + (!subscription.currentPeriodEnd || + subscription.currentPeriodEnd.getTime() > Date.now()), + ); + if (!activeSubscription) { + return { status: 'not_configured' }; + } + + if (activeSubscription.usedQuantity >= activeSubscription.includedQuantity) { + return { status: 'exhausted', subscriptionId: activeSubscription.id }; + } + + return this.tryConsumeIncludedUsage({ + organizationId: params.organizationId, + skuKey: activeSubscription.skuKey as BillingSkuKey, + sourceResourceId: params.sourceResourceId, + }); + } + async tryConsumeIncludedUsage(params: { organizationId: string; skuKey: BillingSkuKey; @@ -233,6 +272,35 @@ export class BillingEntitlementsService { await refundIncludedUsageEvent(params); } + async refundIncludedUsageForProduct(params: { + organizationId: string; + productKey: BillingProductKey; + sourceResourceId: string; + reason: string; + tx?: Prisma.TransactionClient; + }): Promise { + const skuKeys = getBillingSkuKeysForProduct(params.productKey); + const client = params.tx ?? db; + const consumed = await client.billingUsageEvent.findFirst({ + where: { + organizationId: params.organizationId, + skuKey: { in: skuKeys }, + eventType: 'consume', + sourceResourceId: params.sourceResourceId, + }, + orderBy: { createdAt: 'desc' }, + select: { skuKey: true }, + }); + if (!consumed) return; + await refundIncludedUsageEvent({ + organizationId: params.organizationId, + skuKey: consumed.skuKey as BillingSkuKey, + sourceResourceId: params.sourceResourceId, + reason: params.reason, + tx: params.tx, + }); + } + async writeAuditEvent(params: WriteBillingAuditEventParams): Promise { await db.billingAuditEvent.create({ data: { diff --git a/apps/api/src/billing/billing-setup-sessions.ts b/apps/api/src/billing/billing-setup-sessions.ts new file mode 100644 index 0000000000..ddd2ff4fd3 --- /dev/null +++ b/apps/api/src/billing/billing-setup-sessions.ts @@ -0,0 +1,112 @@ +import { BadRequestException } from '@nestjs/common'; +import { db } from '@db'; +import { StripeService } from '../stripe/stripe.service'; +import { findOrCreateBillingCustomer } from './billing-customer'; +import { validateBillingRedirectUrl } from './billing-redirect-urls'; +import { assertStripeBillingConfigured } from './billing-stripe-config'; +import { extractStripeId } from './billing-stripe-ids'; + +export async function createBillingSetupSession(params: { + organizationId: string; + successUrl: string; + cancelUrl: string; + customerEmail?: string; + stripeService: StripeService; +}): Promise<{ url: string }> { + validateBillingRedirectUrl(params.successUrl); + validateBillingRedirectUrl(params.cancelUrl); + assertStripeBillingConfigured(params.stripeService); + + const stripe = params.stripeService.getClient(); + const customerId = await findOrCreateBillingCustomer({ + stripeService: params.stripeService, + organizationId: params.organizationId, + customerEmail: params.customerEmail, + }); + + const session = await stripe.checkout.sessions.create({ + mode: 'setup', + customer: customerId, + currency: 'usd', + success_url: params.successUrl, + cancel_url: params.cancelUrl, + metadata: { + organizationId: params.organizationId, + source: 'comp-billing-setup', + }, + }); + + if (!session.url) { + throw new BadRequestException('Failed to create Stripe Checkout session.'); + } + + return { url: session.url }; +} + +export async function handleBillingSetupSuccess(params: { + organizationId: string; + sessionId: string; + stripeService: StripeService; +}): Promise<{ success: true }> { + assertStripeBillingConfigured(params.stripeService); + + const stripe = params.stripeService.getClient(); + const session = await stripe.checkout.sessions.retrieve(params.sessionId, { + expand: ['setup_intent'], + }); + + if (session.status !== 'complete') { + throw new BadRequestException('Checkout session is not complete.'); + } + + if (session.metadata?.organizationId !== params.organizationId) { + throw new BadRequestException( + 'Checkout session does not belong to this organization.', + ); + } + + const stripeCustomerId = extractStripeId(session.customer); + if (!stripeCustomerId) { + throw new BadRequestException('Checkout session is missing a customer.'); + } + const billing = await db.organizationBilling.findUnique({ + where: { organizationId: params.organizationId }, + select: { stripeCustomerId: true }, + }); + if (billing && billing.stripeCustomerId !== stripeCustomerId) { + throw new BadRequestException( + 'Checkout session customer does not match this organization.', + ); + } + + const setupIntent = session.setup_intent; + if (!setupIntent || typeof setupIntent === 'string') { + throw new BadRequestException('Checkout session is missing a setup intent.'); + } + + const paymentMethodId = extractStripeId(setupIntent.payment_method); + if (!paymentMethodId) { + throw new BadRequestException('Setup intent is missing a payment method.'); + } + + await stripe.customers.update(stripeCustomerId, { + invoice_settings: { default_payment_method: paymentMethodId }, + }); + + await db.organizationBilling.upsert({ + where: { organizationId: params.organizationId }, + create: { + organizationId: params.organizationId, + stripeCustomerId, + stripePaymentMethodId: paymentMethodId, + paymentMethodUpdatedAt: new Date(), + }, + update: { + stripeCustomerId, + stripePaymentMethodId: paymentMethodId, + paymentMethodUpdatedAt: new Date(), + }, + }); + + return { success: true }; +} diff --git a/apps/api/src/billing/billing-subscription-plans.ts b/apps/api/src/billing/billing-subscription-plans.ts new file mode 100644 index 0000000000..61b128b6d9 --- /dev/null +++ b/apps/api/src/billing/billing-subscription-plans.ts @@ -0,0 +1,133 @@ +import { db } from '@db'; +import { + type BillingProductKey, + type BillingSkuKey, + getBillingSkuProductKey, +} from '@trycompai/billing'; +import { BillingEntitlementsService } from './billing-entitlements.service'; +import { StripeService } from '../stripe/stripe.service'; + +export async function findActiveProductSubscription(params: { + organizationId: string; + productKey: BillingProductKey; +}) { + const subscriptions = await findProductSubscriptions(params); + return ( + subscriptions.find((subscription) => { + return ( + subscription.stripeStatus === 'active' || + subscription.stripeStatus === 'trialing' + ); + }) ?? null + ); +} + +export async function findProductSubscriptions(params: { + organizationId: string; + productKey: BillingProductKey; +}) { + const subscriptions = await db.organizationBillingSubscription.findMany({ + where: { organizationId: params.organizationId }, + orderBy: { createdAt: 'desc' }, + }); + return subscriptions.filter( + (subscription) => + getBillingSkuProductKey(subscription.skuKey) === params.productKey, + ); +} + +export async function changeSubscriptionPlan(params: { + organizationId: string; + subscription: { + id: string; + skuKey: string; + stripeSubscriptionId: string; + stripeSubscriptionItemId: string; + currentPeriodStart: Date | null; + currentPeriodEnd: Date | null; + }; + skuKey: BillingSkuKey; + stripePriceId: string; + includedQuantity: number; + stripeService: StripeService; + entitlements: BillingEntitlementsService; +}): Promise<{ changed: true }> { + const stripe = params.stripeService.getClient(); + const updatedSubscription = await stripe.subscriptions.update( + params.subscription.stripeSubscriptionId, + { + items: [ + { + id: params.subscription.stripeSubscriptionItemId, + price: params.stripePriceId, + }, + ], + metadata: { + organizationId: params.organizationId, + skuKey: params.skuKey, + source: 'comp-billing-subscription', + }, + }, + { + idempotencyKey: [ + 'subscription-plan-change', + params.organizationId, + params.subscription.stripeSubscriptionItemId, + params.skuKey, + ].join(':'), + }, + ); + + const updatedItem = + updatedSubscription.items.data.find( + (item) => item.id === params.subscription.stripeSubscriptionItemId, + ) ?? updatedSubscription.items.data[0]; + + const currentPeriodStart = + dateFromSeconds(readNumber(updatedItem, 'current_period_start')) ?? + params.subscription.currentPeriodStart; + const currentPeriodEnd = + dateFromSeconds(readNumber(updatedItem, 'current_period_end')) ?? + params.subscription.currentPeriodEnd; + + await db.organizationBillingSubscription.update({ + where: { id: params.subscription.id }, + data: { + skuKey: params.skuKey, + stripeSubscriptionId: updatedSubscription.id, + stripeSubscriptionItemId: + updatedItem?.id ?? params.subscription.stripeSubscriptionItemId, + stripePriceId: params.stripePriceId, + stripeStatus: updatedSubscription.status, + currentPeriodStart, + currentPeriodEnd, + includedQuantity: params.includedQuantity, + cancelAtPeriodEnd: updatedSubscription.cancel_at_period_end, + canceledAt: dateFromSeconds(updatedSubscription.canceled_at), + }, + }); + + await params.entitlements.writeAuditEvent({ + organizationId: params.organizationId, + eventType: 'subscription_plan_changed', + skuKey: params.skuKey, + metadata: { + stripeSubscriptionId: updatedSubscription.id, + stripeSubscriptionItemId: + updatedItem?.id ?? params.subscription.stripeSubscriptionItemId, + previousSkuKey: params.subscription.skuKey, + }, + }); + + return { changed: true }; +} + +function dateFromSeconds(value: number | null): Date | null { + return value === null ? null : new Date(value * 1000); +} + +function readNumber(value: unknown, key: string): number | null { + if (typeof value !== 'object' || value === null) return null; + const raw = (value as Record)[key]; + return typeof raw === 'number' ? raw : null; +} diff --git a/apps/api/src/billing/billing-usage.ts b/apps/api/src/billing/billing-usage.ts index 26c64a787b..eeb41e8816 100644 --- a/apps/api/src/billing/billing-usage.ts +++ b/apps/api/src/billing/billing-usage.ts @@ -1,5 +1,5 @@ import { db } from '@db'; -import type { BillingSkuKey } from '@trycompai/billing'; +import { getBillingSkuProductKey, type BillingSkuKey } from '@trycompai/billing'; import type { BillingUsageRow } from './billing.types'; type SubscriptionSummary = { @@ -9,8 +9,8 @@ type SubscriptionSummary = { currentPeriodEnd: Date | null; }; -const backgroundCheckSku = 'background_checks_monthly_25'; -const pentestSku = 'pentest_monthly_5'; +const backgroundCheckSku = 'background_checks_monthly_3'; +const pentestSku = 'pentest_monthly_1'; export async function listBillingUsageRows(params: { organizationId: string; @@ -121,8 +121,11 @@ function toBillingUsageRow(params: { updatedAt: Date; subscriptions: SubscriptionSummary[]; }): BillingUsageRow { - const subscription = params.subscriptions.find( - (item) => item.skuKey === params.skuKey, + const productKey = getBillingSkuProductKey(params.skuKey); + const subscription = params.subscriptions.find((item) => + productKey + ? getBillingSkuProductKey(item.skuKey) === productKey + : item.skuKey === params.skuKey, ); const remaining = subscription ? Math.max(subscription.includedQuantity - subscription.usedQuantity, 0) diff --git a/apps/api/src/billing/billing.service.spec.ts b/apps/api/src/billing/billing.service.spec.ts index f0a4cddca1..7b0c6d3ccf 100644 --- a/apps/api/src/billing/billing.service.spec.ts +++ b/apps/api/src/billing/billing.service.spec.ts @@ -12,6 +12,21 @@ jest.mock('@db', () => ({ create: jest.fn(), findUnique: jest.fn(), }, + organizationBillingSubscription: { + findMany: jest.fn(), + update: jest.fn(), + }, + backgroundCheckRequest: { + count: jest.fn(), + findMany: jest.fn(), + }, + securityPenetrationTestRun: { + count: jest.fn(), + findMany: jest.fn(), + }, + billingUsageEvent: { + findMany: jest.fn(), + }, }, })); @@ -22,6 +37,20 @@ const organizationBillingFindUnique = mockedDb.organizationBilling .findUnique as unknown as jest.Mock; const organizationBillingCreate = mockedDb.organizationBilling .create as unknown as jest.Mock; +const organizationBillingSubscriptionFindMany = mockedDb + .organizationBillingSubscription.findMany as unknown as jest.Mock; +const organizationBillingSubscriptionUpdate = mockedDb + .organizationBillingSubscription.update as unknown as jest.Mock; +const backgroundCheckRequestCount = mockedDb.backgroundCheckRequest + .count as unknown as jest.Mock; +const backgroundCheckRequestFindMany = mockedDb.backgroundCheckRequest + .findMany as unknown as jest.Mock; +const securityPenetrationTestRunCount = mockedDb.securityPenetrationTestRun + .count as unknown as jest.Mock; +const securityPenetrationTestRunFindMany = mockedDb.securityPenetrationTestRun + .findMany as unknown as jest.Mock; +const billingUsageEventFindMany = mockedDb.billingUsageEvent + .findMany as unknown as jest.Mock; function mockStripeService(client: unknown): StripeService { return { @@ -47,6 +76,13 @@ describe('BillingService', () => { createdAt: new Date('2026-04-30T00:00:00.000Z'), updatedAt: new Date('2026-04-30T00:00:00.000Z'), }); + organizationBillingSubscriptionFindMany.mockResolvedValue([]); + organizationBillingSubscriptionUpdate.mockResolvedValue({}); + backgroundCheckRequestCount.mockResolvedValue(0); + backgroundCheckRequestFindMany.mockResolvedValue([]); + securityPenetrationTestRunCount.mockResolvedValue(0); + securityPenetrationTestRunFindMany.mockResolvedValue([]); + billingUsageEventFindMany.mockResolvedValue([]); }); it('creates a Stripe subscription checkout session from the billing catalog', async () => { @@ -59,12 +95,13 @@ describe('BillingService', () => { customers: { create: customersCreate }, checkout: { sessions: { create: sessionsCreate } }, }), + { syncSubscriptionItem: jest.fn() } as never, ); await expect( service.createSubscriptionCheckoutSession({ organizationId: 'org_1', - skuKey: 'pentest_monthly_5', + skuKey: 'pentest_monthly_1', successUrl: 'http://localhost:3000/org_1/settings/billing/success', cancelUrl: 'http://localhost:3000/org_1/settings/billing', customerEmail: 'admin@example.com', @@ -82,13 +119,191 @@ describe('BillingService', () => { expect.objectContaining({ mode: 'subscription', customer: 'cus_1', - line_items: [{ price: 'price_1TRya6CkFWhKYvHI1sJ2M2no', quantity: 1 }], + line_items: [{ price: 'price_1TS3ziCkFWhKYvHI0H5TWxNI', quantity: 1 }], + payment_method_collection: 'always', metadata: expect.objectContaining({ organizationId: 'org_1', - skuKey: 'pentest_monthly_5', + skuKey: 'pentest_monthly_1', + }), + subscription_data: expect.objectContaining({ + trial_period_days: 14, + }), + }), + ); + }); + + it('does not apply a trial when the product has historical subscription rows', async () => { + const customersCreate = jest.fn().mockResolvedValue({ id: 'cus_1' }); + const sessionsCreate = jest.fn().mockResolvedValue({ + url: 'https://checkout.stripe.test/session', + }); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_3', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'canceled', + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + customers: { create: customersCreate }, + checkout: { sessions: { create: sessionsCreate } }, + }), + { syncSubscriptionItem: jest.fn() } as never, + ); + + await service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_1', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }); + + expect(sessionsCreate).toHaveBeenCalledWith( + expect.not.objectContaining({ + payment_method_collection: 'always', + }), + ); + expect(sessionsCreate).toHaveBeenCalledWith( + expect.objectContaining({ + subscription_data: expect.not.objectContaining({ + trial_period_days: expect.any(Number), + }), + }), + ); + }); + + it('never applies a trial to higher tiers', async () => { + const customersCreate = jest.fn().mockResolvedValue({ id: 'cus_1' }); + const sessionsCreate = jest.fn().mockResolvedValue({ + url: 'https://checkout.stripe.test/session', + }); + const service = new BillingService( + mockStripeService({ + customers: { create: customersCreate }, + checkout: { sessions: { create: sessionsCreate } }, + }), + { syncSubscriptionItem: jest.fn() } as never, + ); + + await service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_3', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }); + + expect(sessionsCreate).toHaveBeenCalledWith( + expect.objectContaining({ + subscription_data: expect.not.objectContaining({ + trial_period_days: expect.any(Number), + }), + }), + ); + }); + + it('marks trial eligibility false after any product subscription history', async () => { + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'background_checks_monthly_10', + stripeStatus: 'canceled', + includedQuantity: 10, + usedQuantity: 0, + currentPeriodStart: null, + currentPeriodEnd: null, + cancelAtPeriodEnd: false, + }, + ]); + const service = new BillingService( + mockStripeService({ + invoices: { list: jest.fn().mockResolvedValue({ data: [] }) }, + customers: { retrieve: jest.fn().mockResolvedValue({}) }, + paymentMethods: { retrieve: jest.fn() }, + }), + { syncSubscriptionItem: jest.fn() } as never, + ); + + await expect(service.getStatus('org_1')).resolves.toMatchObject({ + trialEligibility: { + pentest: true, + background_check: false, + }, + }); + }); + + it('changes an existing product subscription instead of creating another checkout', async () => { + const subscriptionsUpdate = jest.fn().mockResolvedValue({ + id: 'sub_1', + status: 'active', + cancel_at_period_end: false, + canceled_at: null, + items: { + data: [ + { + id: 'si_1', + current_period_start: 1775001600, + current_period_end: 1777593600, + }, + ], + }, + }); + const writeAuditEvent = jest.fn().mockResolvedValue(undefined); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_1', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'active', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + subscriptions: { update: subscriptionsUpdate }, + }), + { syncSubscriptionItem: jest.fn(), writeAuditEvent } as never, + ); + + await expect( + service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_3', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }), + ).resolves.toEqual({ changed: true }); + + expect(subscriptionsUpdate).toHaveBeenCalledWith( + 'sub_1', + expect.objectContaining({ + items: [{ id: 'si_1', price: 'price_1TS3ziCkFWhKYvHI1nbXC7UU' }], + }), + expect.anything(), + ); + expect(organizationBillingSubscriptionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + where: { id: 'obs_1' }, + data: expect.objectContaining({ + skuKey: 'pentest_monthly_3', + stripeSubscriptionItemId: 'si_1', + includedQuantity: 3, }), }), ); + expect(writeAuditEvent).toHaveBeenCalledWith( + expect.objectContaining({ + organizationId: 'org_1', + eventType: 'subscription_plan_changed', + skuKey: 'pentest_monthly_3', + }), + ); }); it('does not create subscription checkout for one-time SKUs', async () => { @@ -96,6 +311,7 @@ describe('BillingService', () => { mockStripeService({ checkout: { sessions: { create: jest.fn() } }, }), + { syncSubscriptionItem: jest.fn() } as never, ); await expect( @@ -109,12 +325,14 @@ describe('BillingService', () => { }); it('returns a controlled error when Stripe is not configured', async () => { - const service = new BillingService(mockStripeService(null)); + const service = new BillingService(mockStripeService(null), { + syncSubscriptionItem: jest.fn(), + } as never); await expect( service.createSubscriptionCheckoutSession({ organizationId: 'org_1', - skuKey: 'pentest_monthly_5', + skuKey: 'pentest_monthly_1', successUrl: 'http://localhost:3000/org_1/settings/billing/success', cancelUrl: 'http://localhost:3000/org_1/settings/billing', }), diff --git a/apps/api/src/billing/billing.service.ts b/apps/api/src/billing/billing.service.ts index a1e718a812..43fa2d84cb 100644 --- a/apps/api/src/billing/billing.service.ts +++ b/apps/api/src/billing/billing.service.ts @@ -5,12 +5,15 @@ import { } from '@nestjs/common'; import { db } from '@db'; import { - type BillingSkuKey, + type BillingProductKey, getBillingSku, + getBillingSkuProductKey, + resolveBillingCatalogEnvironment, isSubscriptionBillingSkuKey, } from '@trycompai/billing'; import { StripeService } from '../stripe/stripe.service'; import { findOrCreateBillingCustomer } from './billing-customer'; +import { BillingEntitlementsService } from './billing-entitlements.service'; import { listBillingInvoices } from './billing-invoices'; import { type BillingPreferencesInput, @@ -18,14 +21,24 @@ import { updateBillingPreferences, } from './billing-preferences'; import { validateBillingRedirectUrl } from './billing-redirect-urls'; +import { + createBillingSetupSession, + handleBillingSetupSuccess, +} from './billing-setup-sessions'; import { assertStripeBillingConfigured } from './billing-stripe-config'; -import { extractStripeId } from './billing-stripe-ids'; +import { + changeSubscriptionPlan, + findProductSubscriptions, +} from './billing-subscription-plans'; import type { BillingStatus } from './billing.types'; import { listBillingUsageRows } from './billing-usage'; @Injectable() export class BillingService { - constructor(private readonly stripeService: StripeService) {} + constructor( + private readonly stripeService: StripeService, + private readonly entitlements: BillingEntitlementsService, + ) {} async getStatus(organizationId: string): Promise { const [ @@ -73,6 +86,7 @@ export class BillingService { setupAt: billing?.paymentMethodUpdatedAt ?? null, usage: { backgroundChecks, penetrationTests }, preferences, + trialEligibility: getTrialEligibility(subscriptions), usageRows, subscriptions: subscriptions.map((subscription) => ({ skuKey: subscription.skuKey, @@ -120,107 +134,20 @@ export class BillingService { cancelUrl: string; customerEmail?: string; }): Promise<{ url: string }> { - validateBillingRedirectUrl(params.successUrl); - validateBillingRedirectUrl(params.cancelUrl); - assertStripeBillingConfigured(this.stripeService); - - const stripe = this.stripeService.getClient(); - const customerId = await findOrCreateBillingCustomer({ + return createBillingSetupSession({ + ...params, stripeService: this.stripeService, - organizationId: params.organizationId, - customerEmail: params.customerEmail, - }); - - const session = await stripe.checkout.sessions.create({ - mode: 'setup', - customer: customerId, - currency: 'usd', - success_url: params.successUrl, - cancel_url: params.cancelUrl, - metadata: { - organizationId: params.organizationId, - source: 'comp-billing-setup', - }, }); - - if (!session.url) { - throw new BadRequestException( - 'Failed to create Stripe Checkout session.', - ); - } - - return { url: session.url }; } async handleSetupSuccess(params: { organizationId: string; sessionId: string; }): Promise<{ success: true }> { - assertStripeBillingConfigured(this.stripeService); - - const stripe = this.stripeService.getClient(); - const session = await stripe.checkout.sessions.retrieve(params.sessionId, { - expand: ['setup_intent'], - }); - - if (session.status !== 'complete') { - throw new BadRequestException('Checkout session is not complete.'); - } - - if (session.metadata?.organizationId !== params.organizationId) { - throw new BadRequestException( - 'Checkout session does not belong to this organization.', - ); - } - - const stripeCustomerId = extractStripeId(session.customer); - if (!stripeCustomerId) { - throw new BadRequestException('Checkout session is missing a customer.'); - } - const billing = await db.organizationBilling.findUnique({ - where: { organizationId: params.organizationId }, - select: { stripeCustomerId: true }, - }); - if (billing && billing.stripeCustomerId !== stripeCustomerId) { - throw new BadRequestException( - 'Checkout session customer does not match this organization.', - ); - } - - const setupIntent = session.setup_intent; - if (!setupIntent || typeof setupIntent === 'string') { - throw new BadRequestException( - 'Checkout session is missing a setup intent.', - ); - } - - const paymentMethodId = extractStripeId(setupIntent.payment_method); - if (!paymentMethodId) { - throw new BadRequestException( - 'Setup intent is missing a payment method.', - ); - } - - await stripe.customers.update(stripeCustomerId, { - invoice_settings: { default_payment_method: paymentMethodId }, - }); - - await db.organizationBilling.upsert({ - where: { organizationId: params.organizationId }, - create: { - organizationId: params.organizationId, - stripeCustomerId, - stripePaymentMethodId: paymentMethodId, - paymentMethodUpdatedAt: new Date(), - }, - update: { - stripeCustomerId, - stripePaymentMethodId: paymentMethodId, - paymentMethodUpdatedAt: new Date(), - }, + return handleBillingSetupSuccess({ + ...params, + stripeService: this.stripeService, }); - - return { success: true }; } async createBillingPortalSession(params: { @@ -255,7 +182,7 @@ export class BillingService { successUrl: string; cancelUrl: string; customerEmail?: string; - }): Promise<{ url: string }> { + }): Promise<{ url: string } | { changed: true }> { validateBillingRedirectUrl(params.successUrl); validateBillingRedirectUrl(params.cancelUrl); if (!isSubscriptionBillingSkuKey(params.skuKey)) { @@ -263,25 +190,58 @@ export class BillingService { } assertStripeBillingConfigured(this.stripeService); - const sku = getBillingSku({ skuKey: params.skuKey }); + const environment = resolveBillingCatalogEnvironment({ + stripeSecretKey: process.env.STRIPE_SECRET_KEY, + nodeEnv: process.env.NODE_ENV, + }); + const sku = getBillingSku({ environment, skuKey: params.skuKey }); + const productSubscriptions = await findProductSubscriptions({ + organizationId: params.organizationId, + productKey: sku.productKey, + }); + const existingSubscription = + productSubscriptions.find( + (subscription) => + subscription.stripeStatus === 'active' || + subscription.stripeStatus === 'trialing', + ) ?? null; + if (existingSubscription) { + if (existingSubscription.skuKey === sku.key) { + throw new BadRequestException('This plan is already active.'); + } + return changeSubscriptionPlan({ + organizationId: params.organizationId, + subscription: existingSubscription, + skuKey: sku.key, + stripePriceId: sku.stripePriceId, + includedQuantity: sku.includedUsage?.quantity ?? 0, + stripeService: this.stripeService, + entitlements: this.entitlements, + }); + } + const customerId = await findOrCreateBillingCustomer({ stripeService: this.stripeService, organizationId: params.organizationId, customerEmail: params.customerEmail, }); const stripe = this.stripeService.getClient(); + const applyTrial = + productSubscriptions.length === 0 && typeof sku.trialDays === 'number'; const session = await stripe.checkout.sessions.create({ mode: 'subscription', customer: customerId, line_items: [{ price: sku.stripePriceId, quantity: 1 }], success_url: params.successUrl, cancel_url: params.cancelUrl, + ...(applyTrial ? { payment_method_collection: 'always' } : {}), metadata: { organizationId: params.organizationId, skuKey: sku.key, source: 'comp-billing-subscription', }, subscription_data: { + ...(applyTrial ? { trial_period_days: sku.trialDays } : {}), metadata: { organizationId: params.organizationId, skuKey: sku.key, @@ -298,3 +258,17 @@ export class BillingService { return { url: session.url }; } } + +function getTrialEligibility( + subscriptions: Array<{ skuKey: string }>, +): Record { + const productHistory = new Set(); + for (const subscription of subscriptions) { + const productKey = getBillingSkuProductKey(subscription.skuKey); + if (productKey) productHistory.add(productKey); + } + return { + pentest: !productHistory.has('pentest'), + background_check: !productHistory.has('background_check'), + }; +} diff --git a/apps/api/src/billing/billing.types.ts b/apps/api/src/billing/billing.types.ts index 8e3b2a8575..f9db394ae9 100644 --- a/apps/api/src/billing/billing.types.ts +++ b/apps/api/src/billing/billing.types.ts @@ -10,6 +10,10 @@ export interface BillingStatus { penetrationTests: number; }; preferences: BillingPreferences; + trialEligibility: { + pentest: boolean; + background_check: boolean; + }; usageRows: BillingUsageRow[]; subscriptions: Array<{ skuKey: string; diff --git a/apps/api/src/security-penetration-tests/pentest-credits.service.ts b/apps/api/src/security-penetration-tests/pentest-credits.service.ts index d2afc75bad..a2952d82f5 100644 --- a/apps/api/src/security-penetration-tests/pentest-credits.service.ts +++ b/apps/api/src/security-penetration-tests/pentest-credits.service.ts @@ -1,9 +1,4 @@ -import { - HttpException, - HttpStatus, - Injectable, - Logger, -} from '@nestjs/common'; +import { HttpException, HttpStatus, Injectable, Logger } from '@nestjs/common'; import { AuditLogEntityType, db, Prisma } from '@db'; /** @@ -58,8 +53,8 @@ export class PentestCreditsService { where: { organizationId }, }); if (!row) { - // Org has no row yet. Return a zero balance — the client UI treats - // this as "no trial granted, paid plans coming soon." + // Org has no row yet. Return a zero balance for historical/admin + // callers; customer-facing scan creation now uses subscriptions. return { balance: 0, totalGranted: 0, @@ -157,7 +152,7 @@ export class PentestCreditsService { throw new HttpException( { error: - 'No pentest runs remaining. Paid plans coming soon — contact support if you need access today.', + 'No pentest runs remaining. Choose a penetration test plan to continue.', code: 'pentest_credits_exhausted', }, HttpStatus.PAYMENT_REQUIRED, @@ -177,7 +172,12 @@ export class PentestCreditsService { totalConsumed: row.totalConsumed, lastGrantSource: row.lastGrantSource, } - : { balance: 0, totalGranted: 0, totalConsumed: 0, lastGrantSource: 'none' }; + : { + balance: 0, + totalGranted: 0, + totalConsumed: 0, + lastGrantSource: 'none', + }; this.logger.log( `[Credits] debit org=${organizationId} run=${runId ?? 'pending'} balance=${status.balance}`, ); diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts index f3c9712c68..740ddbfd1d 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts @@ -56,20 +56,24 @@ describe('SecurityPenetrationTestsService billing usage', () => { const originalMacedApiKey = process.env.MACED_API_KEY; const mockedDb = db as unknown as MockDb; const credits: jest.Mocked< - Pick + Pick< + PentestCreditsService, + 'getStatus' | 'debitOrThrow' | 'refund' | 'writePentestAuditEntry' + > > = { getStatus: jest.fn(), debitOrThrow: jest.fn(), refund: jest.fn(), + writePentestAuditEntry: jest.fn(), }; const billingEntitlements: jest.Mocked< Pick< BillingEntitlementsService, - 'tryConsumeIncludedUsage' | 'refundIncludedUsage' + 'tryConsumeIncludedUsageForProduct' | 'refundIncludedUsageForProduct' > > = { - tryConsumeIncludedUsage: jest.fn(), - refundIncludedUsage: jest.fn(), + tryConsumeIncludedUsageForProduct: jest.fn(), + refundIncludedUsageForProduct: jest.fn(), }; let service: SecurityPenetrationTestsService; @@ -87,11 +91,12 @@ describe('SecurityPenetrationTestsService billing usage', () => { lastGrantSource: 'trial', }); credits.refund.mockResolvedValue(); - billingEntitlements.tryConsumeIncludedUsage.mockResolvedValue({ + credits.writePentestAuditEntry.mockResolvedValue(); + billingEntitlements.tryConsumeIncludedUsageForProduct.mockResolvedValue({ status: 'consumed', subscriptionId: 'obs_1', }); - billingEntitlements.refundIncludedUsage.mockResolvedValue(); + billingEntitlements.refundIncludedUsageForProduct.mockResolvedValue(); mockedDb.securityPenetrationTestRun.upsert.mockResolvedValue({}); mockedDb.securityPenetrationTestRun.updateMany.mockResolvedValue({ count: 1, @@ -125,6 +130,35 @@ describe('SecurityPenetrationTestsService billing usage', () => { ); }); + it('requires a subscription or free trial instead of debiting wallet credits', async () => { + billingEntitlements.tryConsumeIncludedUsageForProduct.mockResolvedValue({ + status: 'not_configured', + }); + + await expect( + service.createReport('org_123', { + targetUrl: 'https://app.example.com', + }), + ).rejects.toMatchObject({ + status: 402, + response: expect.objectContaining({ + code: 'pentest_subscription_required', + }), + }); + + expect(credits.debitOrThrow).not.toHaveBeenCalled(); + expect(credits.writePentestAuditEntry).toHaveBeenCalledWith( + expect.objectContaining({ + organizationId: 'org_123', + action: 'pentest_create_blocked', + metadata: expect.objectContaining({ + reason: 'pentest_subscription_required', + }), + }), + ); + expect(mockMacedPentestsCreate).not.toHaveBeenCalled(); + }); + it('refunds subscription usage on terminal failure for subscription-backed runs', async () => { mockedDb.securityPenetrationTestRun.findUnique.mockResolvedValue({ organizationId: 'org_123', @@ -142,9 +176,11 @@ describe('SecurityPenetrationTestsService billing usage', () => { 'pentest.failed', ); - expect(billingEntitlements.refundIncludedUsage).toHaveBeenCalledWith({ + expect( + billingEntitlements.refundIncludedUsageForProduct, + ).toHaveBeenCalledWith({ organizationId: 'org_123', - skuKey: 'pentest_monthly_5', + productKey: 'pentest', sourceResourceId: 'pending:run_subscription', reason: 'pentest.failed', tx: mockedDb, @@ -172,6 +208,8 @@ describe('SecurityPenetrationTestsService billing usage', () => { 'pentest.failed', mockedDb, ); - expect(billingEntitlements.refundIncludedUsage).not.toHaveBeenCalled(); + expect( + billingEntitlements.refundIncludedUsageForProduct, + ).not.toHaveBeenCalled(); }); }); diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts index 6bee1e0dc2..5c26b06c51 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts @@ -26,11 +26,11 @@ const mockPentestCreditsService: jest.Mocked< const mockBillingEntitlementsService: jest.Mocked< Pick< BillingEntitlementsService, - 'tryConsumeIncludedUsage' | 'refundIncludedUsage' + 'tryConsumeIncludedUsageForProduct' | 'refundIncludedUsageForProduct' > > = { - tryConsumeIncludedUsage: jest.fn(), - refundIncludedUsage: jest.fn(), + tryConsumeIncludedUsageForProduct: jest.fn(), + refundIncludedUsageForProduct: jest.fn(), }; jest.mock('@db', () => ({ @@ -113,10 +113,10 @@ describe('SecurityPenetrationTestsService', () => { lastGrantSource: 'trial', }); mockPentestCreditsService.refund.mockResolvedValue(); - mockBillingEntitlementsService.tryConsumeIncludedUsage.mockResolvedValue({ + mockBillingEntitlementsService.tryConsumeIncludedUsageForProduct.mockResolvedValue({ status: 'not_configured', }); - mockBillingEntitlementsService.refundIncludedUsage.mockResolvedValue(); + mockBillingEntitlementsService.refundIncludedUsageForProduct.mockResolvedValue(); service = new SecurityPenetrationTestsService( mockPentestCreditsService as unknown as PentestCreditsService, mockBillingEntitlementsService as unknown as BillingEntitlementsService, diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts index 843ebd75d4..b4caf5643e 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts @@ -229,18 +229,13 @@ export class SecurityPenetrationTestsService { const billingUsageSourceId = `pending:${randomUUID()}`; let consumedSubscriptionAllowance = false; - // Debit FIRST so concurrent fast-clicks block at the cheap DB - // conditional update before any of them reach Maced. Without this, - // double-clicks would race past the balance check, all hit Maced - // (creating multiple paid runs), and only one would win the debit — - // we'd burn money on orphaned provider-side runs. Atomic - // `updateMany WHERE balance > 0` guarantees only one decrement - // succeeds. + // Reserve subscription allowance before calling Maced so fast double-clicks + // cannot create more paid provider runs than the organization has available. try { const subscriptionUsage = - await this.billingEntitlements.tryConsumeIncludedUsage({ + await this.billingEntitlements.tryConsumeIncludedUsageForProduct({ organizationId, - skuKey: 'pentest_monthly_5', + productKey: 'pentest', sourceResourceId: billingUsageSourceId, }); if (subscriptionUsage.status === 'exhausted') { @@ -254,7 +249,13 @@ export class SecurityPenetrationTestsService { ); } if (subscriptionUsage.status === 'not_configured') { - await this.credits.debitOrThrow(organizationId); + throw new HttpException( + { + error: 'Start a penetration test plan or free trial to run scans.', + code: 'pentest_subscription_required', + }, + HttpStatus.PAYMENT_REQUIRED, + ); } else { consumedSubscriptionAllowance = true; } @@ -264,16 +265,18 @@ export class SecurityPenetrationTestsService { error.getStatus() === HttpStatus.PAYMENT_REQUIRED ) { // Record the blocked attempt so support / compliance can answer - // "did the user try to scan after their trial was used?". Best- + // "did the user try to scan without an allowance?". Best- // effort — never let an audit-log failure hide the 402 from the // user. + const response = error.getResponse(); + const reason = getPaymentRequiredCode(response); await this.credits.writePentestAuditEntry({ organizationId, action: 'pentest_create_blocked', runId: null, - description: 'Pentest create blocked: no credits remaining', + description: 'Pentest create blocked: subscription required', metadata: { - reason: 'pentest_credits_exhausted', + reason, targetUrl: payload.targetUrl, }, }); @@ -678,9 +681,9 @@ export class SecurityPenetrationTestsService { } if (run.billingUsageSourceId) { - await this.billingEntitlements.refundIncludedUsage({ + await this.billingEntitlements.refundIncludedUsageForProduct({ organizationId: run.organizationId, - skuKey: 'pentest_monthly_5', + productKey: 'pentest', sourceResourceId: run.billingUsageSourceId, reason: eventType, tx, @@ -929,9 +932,9 @@ export class SecurityPenetrationTestsService { reason: string; }): Promise { try { - await this.billingEntitlements.refundIncludedUsage({ + await this.billingEntitlements.refundIncludedUsageForProduct({ organizationId: params.organizationId, - skuKey: 'pentest_monthly_5', + productKey: 'pentest', sourceResourceId: params.sourceResourceId, reason: params.reason, }); @@ -1088,3 +1091,11 @@ export class SecurityPenetrationTestsService { return hosts; } } + +function getPaymentRequiredCode(response: unknown): string { + if (typeof response !== 'object' || response === null) { + return 'pentest_subscription_required'; + } + const code = (response as Record).code; + return typeof code === 'string' ? code : 'pentest_subscription_required'; +} diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/BackgroundCheckDetailsForm.tsx b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/BackgroundCheckDetailsForm.tsx index b1d6c8fefe..7260c25249 100644 --- a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/BackgroundCheckDetailsForm.tsx +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/BackgroundCheckDetailsForm.tsx @@ -19,7 +19,8 @@ export function BackgroundCheckDetailsForm({ isOpeningBilling, isRequesting, billingSetupComplete, - hasPaymentMethod, + backgroundChecksRemaining, + billingHref, canGoBack, onBack, onSubmit, @@ -29,7 +30,8 @@ export function BackgroundCheckDetailsForm({ isOpeningBilling: boolean; isRequesting: boolean; billingSetupComplete: boolean; - hasPaymentMethod: boolean; + backgroundChecksRemaining: number | null; + billingHref: string; canGoBack: boolean; onBack: () => void; onSubmit: (values: BackgroundCheckFormValues) => Promise; @@ -100,13 +102,29 @@ export function BackgroundCheckDetailsForm({ - {hasPaymentMethod && ( + {backgroundChecksRemaining !== null && backgroundChecksRemaining > 0 && ( - Your saved card will be charged $49 for this background check. + {backgroundChecksRemaining} background check + {backgroundChecksRemaining === 1 ? '' : 's'} remaining this period. + + )} + {(backgroundChecksRemaining === null || backgroundChecksRemaining === 0) && ( + + No background checks remaining.{' '} + + Choose a plan + + . )} diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/BackgroundCheckWizardParts.tsx b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/BackgroundCheckWizardParts.tsx index 7024ba3096..09b02c66df 100644 --- a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/BackgroundCheckWizardParts.tsx +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/BackgroundCheckWizardParts.tsx @@ -39,12 +39,9 @@ export function OverviewStep({
- - Launch price - + Plans from - $99{' '} - $49 per check + $79 / month @@ -63,14 +60,14 @@ export function OverviewStep({ iconRight={} onClick={onOpenBilling} > - Set up billing + View plans )} {!hasPaymentMethod && ( - You can also manage payment methods from{' '} + Manage monthly background check credits from{' '} - Billing settings + Billing plans . diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx index d94ea6788c..447ba9a81c 100644 --- a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx @@ -8,6 +8,7 @@ import { EmployeeBackgroundCheck } from './EmployeeBackgroundCheck'; const navigationMock = vi.hoisted(() => ({ pathname: '/org_1/people/mem_1', + push: vi.fn(), replace: vi.fn(), searchParams: new URLSearchParams(), })); @@ -29,7 +30,7 @@ vi.mock('@/lib/api-client', () => ({ vi.mock('next/navigation', () => ({ usePathname: () => navigationMock.pathname, - useRouter: () => ({ replace: navigationMock.replace }), + useRouter: () => ({ push: navigationMock.push, replace: navigationMock.replace }), useSearchParams: () => navigationMock.searchParams, })); @@ -60,6 +61,16 @@ const emptyBackgroundCheckDetails = { reportSyncedAt: null, }; +const activeBackgroundCheckSubscription = { + skuKey: 'background_checks_monthly_3', + status: 'active', + includedQuantity: 3, + usedQuantity: 1, + currentPeriodStart: '2026-04-30T00:00:00.000Z', + currentPeriodEnd: '2026-05-30T00:00:00.000Z', + cancelAtPeriodEnd: false, +}; + function renderSection(props?: Partial[0]>) { return render( new Map() }}> @@ -67,7 +78,11 @@ function renderSection(props?: Partial , @@ -77,6 +92,7 @@ function renderSection(props?: Partial { beforeEach(() => { vi.clearAllMocks(); + navigationMock.push.mockReset(); vi.mocked(apiClient.get).mockReset(); vi.mocked(apiClient.post).mockReset(); window.sessionStorage.clear(); @@ -112,13 +128,12 @@ describe('EmployeeBackgroundCheck', () => { expect( screen.getByText('Streamline employee background checks with Comp AI.'), ).toBeInTheDocument(); - expect(screen.getByText('$99')).toBeInTheDocument(); - expect(screen.getByText('$49 per check')).toBeInTheDocument(); - expect(screen.queryByText('$49', { selector: '[data-slot=\"badge\"]' })).not.toBeInTheDocument(); - expect(screen.getByRole('button', { name: /set up billing/i })).toBeInTheDocument(); - expect(screen.getByRole('link', { name: /billing settings/i })).toHaveAttribute( + expect(screen.getByText('$79 / month')).toBeInTheDocument(); + expect(screen.queryByText(/charged \$49/i)).not.toBeInTheDocument(); + expect(screen.getByRole('button', { name: /view plans/i })).toBeInTheDocument(); + expect(screen.getByRole('link', { name: /billing plans/i })).toHaveAttribute( 'href', - '/org_1/settings/billing', + '/org_1/settings/billing/add-ons/background-checks', ); }); @@ -128,7 +143,7 @@ describe('EmployeeBackgroundCheck', () => { expect(screen.getByText('Employee Background Check')).toBeInTheDocument(); expect(screen.getByLabelText('Personal email')).toBeInTheDocument(); expect( - screen.getByText('Your saved card will be charged $49 for this background check.'), + screen.getByText('2 background checks remaining this period.'), ).toBeInTheDocument(); expect(screen.queryByRole('button', { name: /back/i })).not.toBeInTheDocument(); }); @@ -170,152 +185,44 @@ describe('EmployeeBackgroundCheck', () => { expect(screen.getByText(/spam or junk folders/i)).toBeInTheDocument(); }); - it('starts billing setup from the overview when no payment method exists', async () => { + it('opens plan selection from the overview when no subscription exists', async () => { const user = userEvent.setup(); - vi.mocked(apiClient.post).mockResolvedValueOnce({ - data: {}, - status: 200, - }); renderSection({ initialBillingStatus: { hasPaymentMethod: false, setupAt: null }, }); - await user.click(screen.getByRole('button', { name: /set up billing/i })); + await user.click(screen.getByRole('button', { name: /view plans/i })); - await waitFor(() => { - expect(apiClient.post).toHaveBeenCalledWith( - '/v1/background-check-billing/setup-session', - expect.objectContaining({ - successUrl: expect.stringContaining('background_check_billing=success'), - cancelUrl: 'http://localhost:3000/org_1/settings/billing', - }), - 'org_1', - ); - }); - expect(apiClient.post).toHaveBeenCalledWith( - '/v1/background-check-billing/setup-session', - expect.objectContaining({ - successUrl: expect.stringContaining('/org_1/settings/billing?'), - }), - 'org_1', + expect(navigationMock.push).toHaveBeenCalledWith( + '/org_1/settings/billing/add-ons/background-checks', ); expect( window.sessionStorage.getItem('background-check:org_1:mem_1:pending-request'), ).toBeNull(); }); - it('restores the pending check after Stripe setup before completing it', async () => { + it('stores the pending check and routes to plans when allowance disappears', async () => { const user = userEvent.setup(); - navigationMock.pathname = '/org_1/people/mem_1'; - navigationMock.searchParams = new URLSearchParams( - 'background_check_billing=success&background_check_step=details&session_id=cs_1', - ); - window.sessionStorage.setItem( - 'background-check:org_1:mem_1:pending-request', - JSON.stringify({ - organizationId: 'org_1', - memberId: 'mem_1', - requesterNotes: 'Recruiting requested an expedited check.', - }), - ); - vi.mocked(apiClient.get).mockImplementation(async (endpoint) => { - if (endpoint === '/v1/background-check-billing/status') { - return { - data: { hasPaymentMethod: true, setupAt: '2026-04-29T12:00:00.000Z' }, - status: 200, - }; - } - return { data: null, status: 200 }; - }); - vi.mocked(apiClient.post) - .mockResolvedValueOnce({ - data: { success: true }, - status: 200, - }) - .mockResolvedValueOnce({ - data: { - id: 'bcr_1', - employeeName: 'Ada Lovelace', - employeeEmail: 'ada@example.com', - requesterNotes: 'Recruiting requested an expedited check.', - candidateUrl: 'https://identity.trycomp.ai/cand_1', - status: 'invited', - lastSyncedAt: null, - ...emptyBackgroundCheckDetails, - }, - status: 200, - }); - - renderSection({ - initialBillingStatus: { hasPaymentMethod: false, setupAt: null }, + vi.mocked(apiClient.post).mockResolvedValueOnce({ + error: 'No credits', + status: 402, }); + renderSection(); - await waitFor(() => { - expect(apiClient.post).toHaveBeenCalledWith( - '/v1/background-check-billing/setup-success', - { sessionId: 'cs_1' }, - 'org_1', - ); - }); - expect(apiClient.post).not.toHaveBeenCalledWith( - '/v1/people/mem_1/background-check', - expect.anything(), - 'org_1', - ); - expect(await screen.findByText('Payment method saved')).toBeInTheDocument(); - expect(screen.getByLabelText('Personal email')).toHaveValue(''); - expect( - screen.getByDisplayValue('Recruiting requested an expedited check.'), - ).toBeInTheDocument(); - expect(window.sessionStorage.getItem('background-check:org_1:mem_1:pending-request')).toContain( + await user.type(screen.getByLabelText('Personal email'), 'ada@example.com'); + await user.type( + screen.getByLabelText('Additional information'), 'Recruiting requested an expedited check.', ); - - await user.type(screen.getByLabelText('Personal email'), 'ada@example.com'); await user.click(screen.getByRole('button', { name: /complete/i })); await waitFor(() => { - expect(apiClient.post).toHaveBeenCalledWith( - '/v1/people/mem_1/background-check', - expect.objectContaining({ - employeeEmail: 'ada@example.com', - requesterNotes: 'Recruiting requested an expedited check.', - }), - 'org_1', + expect(navigationMock.push).toHaveBeenCalledWith( + '/org_1/settings/billing/add-ons/background-checks', ); }); expect( window.sessionStorage.getItem('background-check:org_1:mem_1:pending-request'), ).toBeNull(); }); - - it('shows an update payment dialog when payment fails', async () => { - const user = userEvent.setup(); - vi.mocked(apiClient.post) - .mockResolvedValueOnce({ - error: 'Invalid API Key provided: PLACEHOLDER', - status: 402, - }) - .mockResolvedValueOnce({ data: {}, status: 200 }); - renderSection(); - - await user.type(screen.getByLabelText('Personal email'), 'ada@example.com'); - await user.click(screen.getByRole('button', { name: /complete/i })); - - expect( - await screen.findByRole('heading', { name: /update payment method/i }), - ).toBeInTheDocument(); - expect(screen.queryByText(/PLACEHOLDER/)).not.toBeInTheDocument(); - expect( - screen.getByText('Payment failed. Update payment method and try again.'), - ).toBeInTheDocument(); - - await user.click(screen.getByRole('button', { name: /update payment method/i })); - - expect(apiClient.post).toHaveBeenCalledWith( - '/v1/background-check-billing/portal', - expect.objectContaining({ returnUrl: expect.stringContaining('/org_1/people/mem_1') }), - 'org_1', - ); - }); }); diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx index 5caff177e8..134e2da751 100644 --- a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx @@ -3,6 +3,7 @@ import { usePermissions } from '@/hooks/use-permissions'; import { apiClient } from '@/lib/api-client'; import type { Member, User } from '@db'; +import { getBillingSkuProductKey } from '@trycompai/billing'; import { zodResolver } from '@hookform/resolvers/zod'; import { usePathname, useRouter, useSearchParams } from 'next/navigation'; import { useCallback, useEffect, useRef, useState } from 'react'; @@ -93,7 +94,20 @@ export function EmployeeBackgroundCheck({ const canRequest = hasPermission('member', 'update'); const canManageBilling = hasPermission('organization', 'update'); const hasPaymentMethod = billingStatus?.hasPaymentMethod === true; - const visibleWizardStep = hasPaymentMethod ? 'details' : wizardStep; + const backgroundCheckSubscription = (billingStatus?.subscriptions ?? []).find( + (subscription) => + getBillingSkuProductKey(subscription.skuKey) === 'background_check' && + (subscription.status === 'active' || subscription.status === 'trialing'), + ); + const backgroundChecksRemaining = backgroundCheckSubscription + ? Math.max( + backgroundCheckSubscription.includedQuantity - backgroundCheckSubscription.usedQuantity, + 0, + ) + : null; + const hasBackgroundCheckAllowance = + backgroundChecksRemaining !== null && backgroundChecksRemaining > 0; + const visibleWizardStep = hasBackgroundCheckAllowance ? 'details' : wizardStep; const writePendingRequest = useCallback( (values: BackgroundCheckFormValues) => { @@ -123,7 +137,8 @@ export function EmployeeBackgroundCheck({ if (response.error || !response.data) { if (response.status === 402) { - setPaymentIssue('Payment failed. Update payment method and try again.'); + toast.error('Choose or upgrade a background check plan to continue.'); + router.push(`/${organizationId}/settings/billing/add-ons/background-checks`); return false; } toast.error('Failed to request background check'); @@ -138,7 +153,7 @@ export function EmployeeBackgroundCheck({ clearPendingRequest(); return true; }, - [clearPendingRequest, employee.id, mutateBackgroundCheck, organizationId], + [clearPendingRequest, employee.id, mutateBackgroundCheck, organizationId, router], ); useEffect(() => { @@ -219,8 +234,9 @@ export function EmployeeBackgroundCheck({ }; const handleComplete = async (values: BackgroundCheckFormValues) => { - if (!hasPaymentMethod) { - await handleOpenBilling(values); + if (!hasBackgroundCheckAllowance) { + writePendingRequest(values); + router.push(`/${organizationId}/settings/billing/add-ons/background-checks`); return; } @@ -242,13 +258,15 @@ export function EmployeeBackgroundCheck({ <> {visibleWizardStep === 'overview' && ( setWizardStep('details')} - onOpenBilling={() => void handleOpenBilling()} + onOpenBilling={() => + router.push(`/${organizationId}/settings/billing/add-ons/background-checks`) + } /> )} {visibleWizardStep === 'details' && ( @@ -258,8 +276,9 @@ export function EmployeeBackgroundCheck({ isOpeningBilling={isOpeningBilling} isRequesting={isRequesting} billingSetupComplete={billingSetupComplete} - hasPaymentMethod={hasPaymentMethod} - canGoBack={!hasPaymentMethod} + backgroundChecksRemaining={backgroundChecksRemaining} + canGoBack={!hasBackgroundCheckAllowance} + billingHref={`/${organizationId}/settings/billing/add-ons/background-checks`} onBack={() => setWizardStep('overview')} onSubmit={handleComplete} /> diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckTypes.ts b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckTypes.ts index 2e130e415a..44347d001f 100644 --- a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckTypes.ts +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckTypes.ts @@ -34,6 +34,15 @@ export interface CustomBackgroundCheckAttachment { export interface BackgroundCheckBillingStatus { hasPaymentMethod: boolean; setupAt: string | null; + subscriptions?: Array<{ + skuKey: string; + status: string; + includedQuantity: number; + usedQuantity: number; + currentPeriodStart: string | null; + currentPeriodEnd: string | null; + cancelAtPeriodEnd: boolean; + }>; } export function isCompletedBackgroundCheck(status: BackgroundCheckStatus): boolean { diff --git a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/[reportId]/penetration-test-page-client.test.tsx b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/[reportId]/penetration-test-page-client.test.tsx index 7cb7d18537..da4c554598 100644 --- a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/[reportId]/penetration-test-page-client.test.tsx +++ b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/[reportId]/penetration-test-page-client.test.tsx @@ -1,14 +1,24 @@ import { render, screen } from '@testing-library/react'; -import { beforeEach, describe, expect, it, vi } from 'vitest'; import type { ReactNode } from 'react'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; import type { PentestRun } from '@/lib/security/penetration-tests-client'; import { PenetrationTestPageClient } from './penetration-test-page-client'; const usePenetrationTestMock = vi.fn(); const usePenetrationTestProgressMock = vi.fn(); +const usePenetrationTestsMock = vi.fn(); const pushMock = vi.fn(); +vi.mock('@/lib/api-client', () => ({ + api: { + get: vi.fn().mockResolvedValue({ + status: 200, + data: { subscriptions: [] }, + }), + }, +})); + vi.mock('next/link', () => ({ default: ({ href, children, ...props }: { href: string; children: ReactNode }) => ( @@ -20,6 +30,15 @@ vi.mock('next/link', () => ({ vi.mock('../hooks/use-penetration-tests', () => ({ usePenetrationTest: (...args: never[]) => usePenetrationTestMock(...args), usePenetrationTestProgress: (...args: never[]) => usePenetrationTestProgressMock(...args), + usePenetrationTests: (...args: never[]) => usePenetrationTestsMock(...args), + usePenetrationTestIssues: () => ({ issues: [], isLoading: false, error: undefined }), + usePenetrationTestEvents: () => ({ events: [], isLoading: false }), + useCreatePenetrationTest: () => ({ + createReport: vi.fn(), + isCreating: false, + error: null, + resetError: vi.fn(), + }), })); vi.mock('next/navigation', () => ({ @@ -33,10 +52,19 @@ vi.mock('next/navigation', () => ({ const reportMock = usePenetrationTestMock as ReturnType; const progressMock = usePenetrationTestProgressMock as ReturnType; +const reportsMock = usePenetrationTestsMock as ReturnType; describe('PenetrationTestPageClient', () => { beforeEach(() => { vi.clearAllMocks(); + reportsMock.mockReturnValue({ + reports: [], + isLoading: false, + error: undefined, + mutate: vi.fn(), + activeReports: [], + completedReports: [], + }); }); it('shows a loading indicator before the report is available', () => { diff --git a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.tsx b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.tsx index 5e78e7e73d..1b7b733ef4 100644 --- a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.tsx +++ b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.tsx @@ -1,20 +1,20 @@ 'use client'; +import type { PentestCreateRequest } from '@/lib/security/penetration-tests-client'; import { Button } from '@trycompai/design-system'; import { ArrowRight, Link } from '@trycompai/design-system/icons'; import { useRouter } from 'next/navigation'; import { useState } from 'react'; import { toast } from 'sonner'; -import type { PentestCreateRequest } from '@/lib/security/penetration-tests-client'; interface CreateRunPanelProps { orgId: string; onSubmit: (payload: PentestCreateRequest) => Promise<{ id: string }>; isSubmitting?: boolean; - /** Spendable credit balance — disables submit when 0. */ + /** Subscription allowance balance — disables submit when 0. */ balance?: number; - /** True when the trial has already been used (paid plans coming soon). */ - trialUsed?: boolean; + planRequired?: boolean; + quotaLabel?: 'Plan'; } /** @@ -27,7 +27,8 @@ export function CreateRunPanel({ onSubmit, isSubmitting, balance, - trialUsed, + planRequired, + quotaLabel = 'Plan', }: CreateRunPanelProps) { const router = useRouter(); const [targetUrl, setTargetUrl] = useState(''); @@ -42,10 +43,11 @@ export function CreateRunPanel({ e.preventDefault(); if (!canCreate) { toast.error( - trialUsed - ? "You've used your trial run. Paid plans coming soon." - : 'No pentest runs remaining.', + planRequired + ? 'Start a plan or free trial to run penetration tests.' + : 'No pentest runs remaining. Choose a plan to continue.', ); + router.push(`/${orgId}/settings/billing/add-ons/penetration-tests`); return; } const normalized = normalizeUrl(targetUrl); @@ -58,9 +60,7 @@ export function CreateRunPanel({ targetUrl: normalized, ...(repoUrl.trim() ? { repoUrl: repoUrl.trim() } : {}), }); - router.push( - `/${orgId}/security/penetration-tests/${encodeURIComponent(result.id)}`, - ); + router.push(`/${orgId}/security/penetration-tests/${encodeURIComponent(result.id)}`); } catch { // onSubmit handles its own toast. } @@ -80,15 +80,15 @@ export function CreateRunPanel({ Start a penetration test

- Scans typically take 1–3 hours. Findings stream in as they're - discovered — you don't need to keep this page open. + Scans typically take 1–3 hours. Findings stream in as they're discovered — you don't + need to keep this page open.

{!canCreate && (
- {trialUsed - ? "You've used your trial run. Paid plans are coming soon — contact support if you need access today." - : 'No pentest runs remaining.'} + {planRequired + ? 'Start a plan or free trial to run penetration tests.' + : 'No pentest runs remaining. Choose a plan to continue.'}
)} @@ -100,9 +100,7 @@ export function CreateRunPanel({ Target URL
- - https:// - + https://

- Must be reachable from the scanner — localhost and private IPs - are rejected. + Must be reachable from the scanner — localhost and private IPs are rejected.

@@ -136,8 +133,7 @@ export function CreateRunPanel({ />

- Public repositories only. We use source context to write better - remediation steps. + Public repositories only. We use source context to write better remediation steps.

@@ -158,26 +154,17 @@ export function CreateRunPanel({ ))}

- Findings stream into this page as they're discovered — you can - close this tab and come back. + Findings stream into this page as they're discovered — you can close this tab and come + back.

- -
@@ -190,9 +177,7 @@ export function CreateRunPanel({ function normalizeUrl(value: string): string | null { const trimmed = value.trim(); if (!trimmed) return null; - const withProtocol = /^https?:\/\//i.test(trimmed) - ? trimmed - : `https://${trimmed}`; + const withProtocol = /^https?:\/\//i.test(trimmed) ? trimmed : `https://${trimmed}`; try { const url = new URL(withProtocol); if (url.protocol !== 'http:' && url.protocol !== 'https:') return null; diff --git a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/DetailPane.tsx b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/DetailPane.tsx index aa182d2d93..d7913b1c35 100644 --- a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/DetailPane.tsx +++ b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/DetailPane.tsx @@ -1,6 +1,6 @@ 'use client'; -import { AlertCircle, Loader2 } from 'lucide-react'; +import { ErrorFilled, InProgress } from '@trycompai/design-system/icons'; import type { PentestAgentEvent, PentestIssue, @@ -48,7 +48,7 @@ export function DetailPane({ if (isLoading && !run) { return (
- +
); } @@ -57,7 +57,7 @@ export function DetailPane({ return (
- +

Unable to load scan

{error?.message ?? 'No scan found for this organization.'}

diff --git a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/EmptyState.tsx b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/EmptyState.tsx index ab3e054742..b270f32858 100644 --- a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/EmptyState.tsx +++ b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/EmptyState.tsx @@ -1,31 +1,25 @@ 'use client'; import { Button } from '@trycompai/design-system'; -import { - Link, - Settings, - Rocket, -} from '@trycompai/design-system/icons'; +import { Link, Rocket, Settings } from '@trycompai/design-system/icons'; interface EmptyStateProps { onCreateClick: () => void; - /** Spendable balance — when 0 the CTA is disabled. */ + /** Subscription allowance balance — when 0 the CTA routes to plans. */ balance?: number; - /** True if the trial has already been used (paid plans coming soon copy). */ - trialUsed?: boolean; + planRequired?: boolean; + quotaLabel?: 'Plan'; } const STEPS = [ { title: 'Connect target', - description: - 'Point the scanner at a URL you own. HTTPS required.', + description: 'Point the scanner at a URL you own. HTTPS required.', Icon: Link, }, { title: 'Configure scope', - description: - 'Optionally attach a public repository for deeper, code-aware coverage.', + description: 'Optionally attach a public repository for deeper, code-aware coverage.', Icon: Settings, }, { @@ -39,19 +33,18 @@ const STEPS = [ export function EmptyState({ onCreateClick, balance, - trialUsed, + planRequired, + quotaLabel = 'Plan', }: EmptyStateProps) { const canCreate = balance === undefined ? true : balance > 0; - const tagline = trialUsed - ? "You've used your trial run. Paid plans are coming soon — contact support if you need access today." + const tagline = planRequired + ? 'Start a plan or 14-day free trial to run your first scan.' : 'Automated black-box pen testing. Start a scan to see findings here.'; return (
-

- Penetration Tests -

+

Penetration Tests

New @@ -59,8 +52,8 @@ export function EmptyState({

{tagline}

-
@@ -73,19 +66,14 @@ export function EmptyState({ {STEPS.map((step, i) => { const { Icon } = step; return ( -
  • +
  • {i + 1}
    {step.title}
    -
    - {step.description} -
    +
    {step.description}
  • ); diff --git a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/RunList.tsx b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/RunList.tsx index d88c5eaab3..f184bfea5f 100644 --- a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/RunList.tsx +++ b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/RunList.tsx @@ -1,10 +1,10 @@ 'use client'; -import { cn } from '@trycompai/design-system/cn'; +import type { PentestRun } from '@/lib/security/penetration-tests-client'; import { Progress } from '@trycompai/design-system'; +import { cn } from '@trycompai/design-system/cn'; import { Add } from '@trycompai/design-system/icons'; import { useRouter } from 'next/navigation'; -import type { PentestRun } from '@/lib/security/penetration-tests-client'; import { formatReportDate } from '../lib'; import { StatusPill } from './StatusPill'; import { isRunInProgress } from './severity'; @@ -14,10 +14,10 @@ interface RunListProps { runs: PentestRun[]; selectedRunId: string | null; onCreateClick: () => void; - /** Spendable credit balance — drives the "X runs left" badge. */ + /** Subscription allowance balance — drives the "X runs left" badge. */ balance?: number; - /** True when the user has used their initial trial (balance 0, totalGranted > 0). */ - trialUsed?: boolean; + planRequired?: boolean; + quotaLabel?: 'Plan'; } export function RunList({ @@ -26,14 +26,13 @@ export function RunList({ selectedRunId, onCreateClick, balance, - trialUsed, + planRequired, + quotaLabel = 'Plan', }: RunListProps) { const router = useRouter(); const canCreate = balance === undefined ? true : balance > 0; const newButtonTitle = !canCreate - ? trialUsed - ? "You've used your trial run. Paid plans coming soon." - : 'No pentest runs remaining.' + ? 'Choose a plan or start a free trial to continue scanning.' : 'Start a new scan'; return (
    {balance !== undefined && ( - + )} ); @@ -94,7 +90,9 @@ export function RunList({ interface QuotaFooterProps { balance: number; - trialUsed: boolean; + planRequired: boolean; + quotaLabel: 'Plan'; + orgId: string; } /** @@ -103,18 +101,16 @@ interface QuotaFooterProps { * actions. Falls back to a "Contact support" mailto when the user is at * zero so they have a clear next step. */ -function QuotaFooter({ balance, trialUsed }: QuotaFooterProps) { +function QuotaFooter({ balance, planRequired, quotaLabel, orgId }: QuotaFooterProps) { if (balance > 0) { return (
    - - {balance} - {' '} - scan{balance === 1 ? '' : 's'} remaining + {balance} scan + {balance === 1 ? '' : 's'} remaining - Trial + {quotaLabel}
    ); @@ -123,15 +119,15 @@ function QuotaFooter({ balance, trialUsed }: QuotaFooterProps) { return (
    - {trialUsed ? "You've used your trial scan" : 'No scans available'} + {planRequired ? 'Plan required' : 'No scans available'}
    @@ -164,16 +160,12 @@ function RunRow({ orgId, run, selected }: RunRowProps) { tabIndex={0} aria-current={selected ? 'true' : undefined} onClick={() => - router.push( - `/${orgId}/security/penetration-tests/${encodeURIComponent(run.id)}`, - ) + router.push(`/${orgId}/security/penetration-tests/${encodeURIComponent(run.id)}`) } onKeyDown={(e) => { if (e.key === 'Enter' || e.key === ' ') { e.preventDefault(); - router.push( - `/${orgId}/security/penetration-tests/${encodeURIComponent(run.id)}`, - ); + router.push(`/${orgId}/security/penetration-tests/${encodeURIComponent(run.id)}`); } }} className={cn( @@ -184,9 +176,7 @@ function RunRow({ orgId, run, selected }: RunRowProps) { >
    - - {shortId} - + {shortId}
    {target}
    diff --git a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/SplitView.tsx b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/SplitView.tsx index 740b1c8302..79147c1014 100644 --- a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/SplitView.tsx +++ b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/SplitView.tsx @@ -1,23 +1,24 @@ 'use client'; import { api } from '@/lib/api-client'; -import { cn } from '@trycompai/design-system/cn'; -import { ArrowLeft } from '@trycompai/design-system/icons'; import type { PentestCreateRequest, PentestIssue, PentestRun, } from '@/lib/security/penetration-tests-client'; +import { getBillingSkuProductKey } from '@trycompai/billing'; +import { cn } from '@trycompai/design-system/cn'; +import { ArrowLeft } from '@trycompai/design-system/icons'; import { useRouter } from 'next/navigation'; import { useState } from 'react'; import { toast } from 'sonner'; +import useSWR from 'swr'; import { useCreatePenetrationTest, usePenetrationTest, usePenetrationTestEvents, usePenetrationTestIssues, usePenetrationTests, - usePentestCredits, } from '../hooks/use-penetration-tests'; import { CreateRunPanel } from './CreateRunPanel'; import { DetailPane } from './DetailPane'; @@ -26,6 +27,15 @@ import { OverviewPane } from './OverviewPane'; import { RunList } from './RunList'; import './pentest-tokens.css'; +interface BillingStatus { + subscriptions?: Array<{ + skuKey: string; + status: string; + includedQuantity: number; + usedQuantity: number; + }>; +} + interface SplitViewProps { orgId: string; selectedRunId: string | null; @@ -38,15 +48,9 @@ interface SplitViewProps { * - `/pentests/:id` → Detail (selectedRunId set) * - `/pentests/new` → Create panel (mode create, list dimmed) */ -export function SplitView({ - orgId, - selectedRunId, - mode = 'default', -}: SplitViewProps) { +export function SplitView({ orgId, selectedRunId, mode = 'default' }: SplitViewProps) { const router = useRouter(); - const [selectedFinding, setSelectedFinding] = useState( - null, - ); + const [selectedFinding, setSelectedFinding] = useState(null); const { reports, isLoading: listLoading } = usePenetrationTests(orgId); const { @@ -54,36 +58,38 @@ export function SplitView({ isLoading: runLoading, error: runError, } = usePenetrationTest(orgId, selectedRunId ?? ''); - const { issues } = usePenetrationTestIssues( - orgId, - selectedRunId ?? '', - selectedRun?.status, + const { issues } = usePenetrationTestIssues(orgId, selectedRunId ?? '', selectedRun?.status); + const { events } = usePenetrationTestEvents(orgId, selectedRunId ?? '', selectedRun?.status); + const { createReport, isCreating } = useCreatePenetrationTest(orgId); + const { data: billingStatus } = useSWR( + orgId ? (['/v1/billing/status', orgId] as const) : null, + async ([endpoint, organizationId]: readonly [string, string]) => { + const response = await api.get(endpoint, organizationId); + if (response.status < 200 || response.status >= 300) { + throw new Error(response.error ?? 'Failed to load billing status'); + } + return response.data ?? {}; + }, ); - const { events } = usePenetrationTestEvents( - orgId, - selectedRunId ?? '', - selectedRun?.status, + const pentestSubscription = (billingStatus?.subscriptions ?? []).find( + (subscription) => + getBillingSkuProductKey(subscription.skuKey) === 'pentest' && + (subscription.status === 'active' || subscription.status === 'trialing'), ); - const { createReport, isCreating } = useCreatePenetrationTest(orgId); - const { credits } = usePentestCredits(orgId); - // Keep `balance` undefined while credits are loading. Coalescing to 0 - // would prematurely disable "+ New scan" before we know the user's - // real balance — child props treat `undefined` as "loading, allow - // optimistic UI" and a real `0` as "confirmed empty, block create." - const balance = credits?.balance; - const trialUsed = - credits !== undefined && credits.balance === 0 && credits.totalGranted > 0; + const subscriptionBalance = pentestSubscription + ? Math.max(pentestSubscription.includedQuantity - pentestSubscription.usedQuantity, 0) + : null; + // Keep `balance` undefined while billing is loading so the page does not + // flash a blocked state before subscription allowance is known. + const balance = subscriptionBalance ?? (billingStatus === undefined ? undefined : 0); + const planRequired = subscriptionBalance === null && billingStatus !== undefined; + const quotaLabel = 'Plan'; const showEmptyState = - !listLoading && - reports.length === 0 && - selectedRunId === null && - mode !== 'create'; + !listLoading && reports.length === 0 && selectedRunId === null && mode !== 'create'; const isCreateMode = mode === 'create'; - const handleCreateSubmit = async ( - payload: PentestCreateRequest, - ): Promise<{ id: string }> => { + const handleCreateSubmit = async (payload: PentestCreateRequest): Promise<{ id: string }> => { try { const result = await createReport(payload); return { id: result.id }; @@ -128,10 +134,10 @@ export function SplitView({ filename: `penetration-test-${runId}.pdf`, }); - const goToCreate = () => - router.push(`/${orgId}/security/penetration-tests/new`); - const goToList = () => - router.push(`/${orgId}/security/penetration-tests`); + const goToCreate = () => router.push(`/${orgId}/security/penetration-tests/new`); + const goToPentestPlans = () => + router.push(`/${orgId}/settings/billing/add-ons/penetration-tests`); + const goToList = () => router.push(`/${orgId}/security/penetration-tests`); // Below `xl` (1280px) we show ONE pane at a time, picked from the URL — // list on `/pentests`, main on `/pentests/:id` and `/pentests/new`. @@ -152,9 +158,10 @@ export function SplitView({ // (4rem); the outer shell padding is undone by `-m-*`.
    ); @@ -174,17 +181,13 @@ export function SplitView({ orgId={orgId} runs={reports as PentestRun[]} selectedRunId={selectedRunId} - onCreateClick={goToCreate} + onCreateClick={balance === 0 ? goToPentestPlans : goToCreate} balance={balance} - trialUsed={trialUsed} + planRequired={planRequired} + quotaLabel={quotaLabel} />
    -
    +
    {/* Back-to-list bar shown below xl. The sidebar is hidden on phones / tablets / narrow laptops once a run is selected (or in create mode), so we surface a persistent path back to the @@ -208,13 +211,14 @@ export function SplitView({ onSubmit={handleCreateSubmit} isSubmitting={isCreating} balance={balance} - trialUsed={trialUsed} + planRequired={planRequired} + quotaLabel={quotaLabel} /> ) : selectedRunId === null ? ( 0} onDownloadMarkdown={handleDownloadMarkdownById} onDownloadPdf={handleDownloadPdfById} @@ -251,9 +255,7 @@ async function downloadArtifact({ // whether `filename` is set — both Markdown and PDF callers pass a // filename, so the previous `filename ? pdf : md` check requested // application/pdf for both formats. - const accept = filename?.toLowerCase().endsWith('.pdf') - ? 'application/pdf' - : 'text/markdown'; + const accept = filename?.toLowerCase().endsWith('.pdf') ? 'application/pdf' : 'text/markdown'; try { const response = await api.raw(path, { method: 'GET', @@ -262,9 +264,7 @@ async function downloadArtifact({ }); if (!response.ok) { const body = await response.text().catch(() => ''); - throw new Error( - safeErrorMessage(body) ?? `Request failed with status ${response.status}`, - ); + throw new Error(safeErrorMessage(body) ?? `Request failed with status ${response.status}`); } const blob = await response.blob(); const objectUrl = URL.createObjectURL(blob); @@ -276,9 +276,7 @@ async function downloadArtifact({ document.body.removeChild(link); window.setTimeout(() => URL.revokeObjectURL(objectUrl), 60_000); } catch (err) { - toast.error( - err instanceof Error ? err.message : 'Unable to download report', - ); + toast.error(err instanceof Error ? err.message : 'Unable to download report'); } } diff --git a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/hooks/use-penetration-tests.test.tsx b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/hooks/use-penetration-tests.test.tsx index 022a64fe66..aa94e781a8 100644 --- a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/hooks/use-penetration-tests.test.tsx +++ b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/hooks/use-penetration-tests.test.tsx @@ -154,9 +154,12 @@ describe('use-penetration-tests hooks', () => { }); it('skips progress polling when report is completed', () => { - const { result } = renderHook(() => usePenetrationTestProgress('org_123', 'run_completed', 'completed'), { - wrapper, - }); + const { result } = renderHook( + () => usePenetrationTestProgress('org_123', 'run_completed', 'completed'), + { + wrapper, + }, + ); expect(result.current.progress).toBeUndefined(); expect(fetchMock).not.toHaveBeenCalled(); @@ -228,15 +231,10 @@ describe('use-penetration-tests hooks', () => { expect(requestBody.repoUrl).toBeUndefined(); }); - it('billing action failure surfaces the error after run creation', async () => { - // First call: create pentest (success) + it('creates a run through the subscription-gated create endpoint', async () => { fetchMock.mockResolvedValueOnce( createJsonResponse({ id: 'run_billed', status: 'provisioning' }), ); - // Second call: billing charge (failure via API) - fetchMock.mockResolvedValueOnce( - createJsonResponse({ error: 'No active pentest subscription.' }, 402), - ); const { result } = renderHook(() => useCreatePenetrationTest('org_123'), { wrapper }); @@ -245,11 +243,11 @@ describe('use-penetration-tests hooks', () => { result.current.createReport({ targetUrl: 'https://app.example.com', }), - ).rejects.toThrow('No active pentest subscription.'); + ).resolves.toMatchObject({ id: 'run_billed' }); }); - expect(fetchMock).toHaveBeenCalledTimes(2); - expect(result.current.error).toBe('No active pentest subscription.'); + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(result.current.error).toBeNull(); }); it('surfaces json provider error objects from create response', async () => { diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnPlansClient.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnPlansClient.tsx index 3e29e6dd0c..5fbf1a6245 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnPlansClient.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnPlansClient.tsx @@ -35,7 +35,7 @@ export function BillingAddOnPlansClient({ const { hasPermission } = usePermissions(); const canManageBilling = hasPermission('organization', 'update'); - const { data: billingStatus } = useSWR( + const { data: billingStatus, mutate: mutateBillingStatus } = useSWR( ['/v1/billing/status', organizationId], async ([endpoint]) => { const response = await apiClient.get(endpoint, organizationId); @@ -54,12 +54,17 @@ export function BillingAddOnPlansClient({ () => billingStatus?.subscriptions ?? initialBillingStatus.subscriptions ?? [], [billingStatus?.subscriptions, initialBillingStatus.subscriptions], ); + const trialEligibility = billingStatus?.trialEligibility ?? + initialBillingStatus.trialEligibility ?? { + pentest: false, + background_check: false, + }; const handleOpenSubscription = async (skuKey: BillingSkuKey) => { setLoadingSubscriptionSku(skuKey); const returnUrl = `${window.location.origin}/${organizationId}/settings/billing/add-ons/${addOn.slug}`; - const response = await apiClient.post<{ url: string }>( + const response = await apiClient.post<{ url?: string; changed?: true }>( '/v1/billing/subscription-session', { skuKey, @@ -74,6 +79,13 @@ export function BillingAddOnPlansClient({ return; } + if (response.data?.changed) { + toast.success('Plan updated'); + await mutateBillingStatus(); + setLoadingSubscriptionSku(null); + return; + } + toast.error('Failed to open checkout'); setLoadingSubscriptionSku(null); }; @@ -105,13 +117,11 @@ export function BillingAddOnPlansClient({ } > -
    +
    ; + trialEligibility?: BackgroundCheckBillingStatus['trialEligibility']; }) { const addOn = getBillingAddOn(addOnSlug); if (!addOn) throw new Error(`Missing test add-on: ${addOnSlug}`); @@ -66,6 +69,7 @@ function renderAddOnPlans({ initialBillingStatus={{ ...emptyBillingStatus, subscriptions, + trialEligibility, }} /> , @@ -102,13 +106,13 @@ describe('billing add-ons', () => { const user = userEvent.setup(); renderAddOnPlans({ addOnSlug: 'background-checks' }); - await user.click(screen.getByRole('button', { name: /subscribe to background checks/i })); + await user.click(screen.getByRole('button', { name: /start free trial/i })); await waitFor(() => { expect(apiClient.post).toHaveBeenCalledWith( '/v1/billing/subscription-session', { - skuKey: 'background_checks_monthly_25', + skuKey: 'background_checks_monthly_3', successUrl: 'http://localhost:3000/org_1/settings/billing/add-ons/background-checks?billing_subscription=success&session_id={CHECKOUT_SESSION_ID}', cancelUrl: 'http://localhost:3000/org_1/settings/billing/add-ons/background-checks', @@ -125,11 +129,20 @@ describe('billing add-ons', () => { 'href', '/org_1/settings/billing', ); - expect(screen.getByRole('heading', { name: 'Penetration Test' })).toBeInTheDocument(); + expect(screen.getByRole('heading', { name: 'Penetration Tests' })).toBeInTheDocument(); expect(screen.getByRole('tab', { name: /^overview$/i })).toBeInTheDocument(); - expect( - screen.getByRole('button', { name: /subscribe to penetration tests/i }), - ).toBeInTheDocument(); + expect(screen.getByRole('button', { name: /start free trial/i })).toBeInTheDocument(); + expect(screen.getByText('14-day free trial')).toBeInTheDocument(); + }); + + it('hides trial copy once a product has subscription history', () => { + renderAddOnPlans({ + addOnSlug: 'penetration-tests', + trialEligibility: { pentest: false, background_check: true }, + }); + + expect(screen.queryByText('14-day free trial')).not.toBeInTheDocument(); + expect(screen.getByRole('button', { name: /start monthly scans/i })).toBeInTheDocument(); }); it('shows active add-on subscriptions as disabled plan actions', async () => { @@ -138,9 +151,9 @@ describe('billing add-ons', () => { addOnSlug: 'penetration-tests', subscriptions: [ { - skuKey: 'pentest_monthly_5', + skuKey: 'pentest_monthly_1', status: 'active', - includedQuantity: 5, + includedQuantity: 1, usedQuantity: 1, currentPeriodStart: '2026-04-30T00:00:00.000Z', currentPeriodEnd: '2026-05-30T00:00:00.000Z', @@ -149,14 +162,14 @@ describe('billing add-ons', () => { ], }); - const activeButton = screen.getByRole('button', { name: /active subscription/i }); + const activeButton = screen.getByRole('button', { name: /current plan/i }); expect(activeButton).toBeDisabled(); - expect(screen.getByText('1 of 5 used this period.')).toBeInTheDocument(); + expect(screen.getByText(/0 of 1.*remaining this period/i)).toBeInTheDocument(); await user.click(activeButton); expect(apiClient.post).not.toHaveBeenCalledWith( '/v1/billing/subscription-session', - expect.objectContaining({ skuKey: 'pentest_monthly_5' }), + expect.objectContaining({ skuKey: 'pentest_monthly_1' }), 'org_1', ); }); diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx index 5b51244841..4db2655ca8 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx @@ -2,6 +2,7 @@ import { Badge, Button, Stack, Text } from '@trycompai/design-system'; import { ArrowRight } from '@trycompai/design-system/icons'; +import { getBillingSku, getBillingSkuProductKey } from '@trycompai/billing'; import { useRouter } from 'next/navigation'; import { billingAddOns } from './billingAddOns'; import type { BackgroundCheckBillingStatus } from './types'; @@ -18,40 +19,86 @@ export function BillingAddOnsOverview({ const router = useRouter(); return ( -
    +
    {billingAddOns.map((addOn) => { - const skuKeys: readonly string[] = addOn.skuKeys; + const productKey = getBillingSku({ skuKey: addOn.skuKeys[0] }).productKey; const activeSubscription = subscriptions.find( (subscription) => - skuKeys.includes(subscription.skuKey) && + getBillingSkuProductKey(subscription.skuKey) === productKey && (subscription.status === 'active' || subscription.status === 'trialing'), ); + const remaining = activeSubscription + ? Math.max(activeSubscription.includedQuantity - activeSubscription.usedQuantity, 0) + : null; return ( -
    - - -
    - {addOn.name} +
    +
    + + +
    +
    + {addOn.name} + + {addOn.summary} + +
    {activeSubscription && Active}
    {addOn.description} -
    - - - {addOn.summary} + {addOn.proof} - {activeSubscription && ( - - {activeSubscription.usedQuantity} of {activeSubscription.includedQuantity} used - this period. - - )} +
    + {addOn.skuKeys.map((skuKey) => { + const sku = getBillingSku({ skuKey }); + return ( +
    +
    + + {formatAmount(sku.unitAmount)} + + + /mo + +
    + + {sku.includedUsage?.quantity}{' '} + {formatUnit({ + unit: sku.includedUsage?.unit, + quantity: sku.includedUsage?.quantity, + })} + +
    + ); + })} +
    + +
    + + {remaining === null + ? 'No active subscription yet.' + : `${remaining} of ${activeSubscription?.includedQuantity ?? 0} ${formatUnit({ + unit: + activeSubscription && + getBillingSkuProductKey(activeSubscription.skuKey) === 'pentest' + ? 'scan' + : 'background_check', + quantity: activeSubscription?.includedQuantity, + })} remaining this period.`} + +
    +
    @@ -72,3 +119,16 @@ export function BillingAddOnsOverview({
    ); } + +function formatAmount(amount: number) { + return new Intl.NumberFormat('en-US', { + style: 'currency', + currency: 'USD', + maximumFractionDigits: 0, + }).format(amount / 100); +} + +function formatUnit(params: { unit?: string; quantity?: number }) { + if (params.unit === 'scan') return params.quantity === 1 ? 'scan' : 'scans'; + return params.quantity === 1 ? 'check' : 'checks'; +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx index faca0cf6e7..64eec73387 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx @@ -168,7 +168,7 @@ export function BillingSettingsClient({
    ; + trialEligibility: { + pentest: boolean; + background_check: boolean; + }; disabled: boolean; loadingSkuKey: string | null; onSubscribe: (skuKey: BillingSkuKey) => void; } const planSkuKeys = [ - 'pentest_monthly_5', - 'background_checks_monthly_25', + 'pentest_monthly_1', + 'pentest_monthly_3', + 'pentest_monthly_5_current', + 'background_checks_monthly_3', + 'background_checks_monthly_10', + 'background_checks_monthly_20', ] as const satisfies readonly BillingSkuKey[]; export function BillingSubscriptionPlans({ skuKeys = planSkuKeys, subscriptions, + trialEligibility, disabled, loadingSkuKey, onSubscribe, }: BillingSubscriptionPlansProps) { return ( -
    +
    {skuKeys.map((skuKey) => { const sku = getBillingSku({ skuKey }); - const subscription = subscriptions.find((item) => item.skuKey === skuKey); + const subscription = subscriptions.find( + (item) => getBillingSkuProductKey(item.skuKey) === sku.productKey, + ); const active = subscription?.status === 'active' || subscription?.status === 'trialing'; + const current = active && subscription?.skuKey === skuKey; const included = sku.includedUsage; + const remaining = subscription + ? Math.max(subscription.includedQuantity - subscription.usedQuantity, 0) + : null; + const unit = included ? formatUsageUnit(included.unit, included.quantity) : 'credits'; + const cta = getPlanCta({ + active, + productKey: sku.productKey, + quantity: included?.quantity ?? 0, + }); + const trialEligible = + !active && typeof sku.trialDays === 'number' && trialEligibility[sku.productKey]; + const buttonLabel = trialEligible ? 'Start free trial' : cta; return ( -
    - - - {sku.name} +
    +
    + +
    + {sku.name} + {current && Current} + {trialEligible && 14-day free trial} +
    {sku.description} @@ -45,38 +78,48 @@ export function BillingSubscriptionPlans({ {formatAmount(sku.unitAmount)} - / month + / mo {included && ( - {included.quantity} {formatUsageUnit(included.unit)} included monthly + {included.quantity} {unit} every month )} - {active && subscription ? ( - +
    + + {trialEligible + ? 'Try it free for 14 days. Add a card now, pay only if you keep it.' + : getPlanPromise(sku.productKey, included?.quantity ?? 0)} + +
    + {current && subscription ? ( +
    - {subscription.usedQuantity} of {subscription.includedQuantity} used this - period. + {remaining} of {subscription.includedQuantity} {unit} + remaining this period. - - +
    ) : ( - +
    + +
    )} -
    +
    ); })} @@ -92,6 +135,38 @@ function formatAmount(amount: number) { }).format(amount / 100); } -function formatUsageUnit(unit: string) { - return unit === 'scan' ? 'scans' : 'background checks'; +function formatUsageUnit(unit: string, quantity: number) { + if (unit === 'scan') return quantity === 1 ? 'scan' : 'scans'; + return quantity === 1 ? 'background check' : 'background checks'; +} + +function getPlanPromise(productKey: string, quantity: number) { + if (productKey === 'pentest') { + if (quantity === 1) return 'Validate your highest-risk app every month.'; + if (quantity === 3) return 'Cover launch windows and retest fixes without waiting.'; + return 'Keep critical surfaces continuously audit-ready.'; + } + if (quantity === 3) return 'Cover your next hires without per-check approvals.'; + if (quantity === 10) return 'Keep recruiting moving with predictable checks.'; + return 'Scale hiring without surprise background-check spend.'; +} + +function getPlanCta({ + active, + productKey, + quantity, +}: { + active: boolean; + productKey: string; + quantity: number; +}) { + const action = active ? 'Switch to' : 'Start'; + if (productKey === 'pentest') { + if (quantity === 1) return `${action} monthly scans`; + if (quantity === 3) return `${action} release coverage`; + return `${action} continuous coverage`; + } + if (quantity === 3) return `${action} hiring checks`; + if (quantity === 10) return `${action} recruiting coverage`; + return `${action} hiring at scale`; } diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingUsageTable.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingUsageTable.tsx index b6f6c8004d..8868e0d2c1 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingUsageTable.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingUsageTable.tsx @@ -1,6 +1,7 @@ 'use client'; import { Card, CardContent, CardHeader, CardTitle, Stack, Text } from '@trycompai/design-system'; +import { getBillingSkuProductKey } from '@trycompai/billing'; import type React from 'react'; import type { BackgroundCheckBillingStatus, BillingUsageRow } from './types'; @@ -15,12 +16,14 @@ export function BillingUsageTable({ subscriptions, usageRows }: BillingUsageTabl
    item.skuKey === 'pentest_monthly_5')} + subscription={subscriptions.find( + (item) => getBillingSkuProductKey(item.skuKey) === 'pentest', + )} /> item.skuKey === 'background_checks_monthly_25', + (item) => getBillingSkuProductKey(item.skuKey) === 'background_check', )} />
    diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/billingAddOns.ts b/apps/app/src/app/(app)/[orgId]/settings/billing/billingAddOns.ts index 4095f441af..7841ca1f8d 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/billingAddOns.ts +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/billingAddOns.ts @@ -10,6 +10,10 @@ export interface BillingAddOn { detailTitle: string; description: string; summary: string; + proof: string; + planTitle: string; + planDescription: string; + ctaLabel: string; skuKeys: readonly BillingSkuKey[]; } @@ -17,18 +21,34 @@ export const billingAddOns = [ { slug: 'penetration-tests', name: 'Penetration Tests', - detailTitle: 'Penetration Test', - description: 'Run security scans with monthly included usage and centralized billing.', - summary: 'Plans from $399/mo', - skuKeys: ['pentest_monthly_5'], + detailTitle: 'Penetration Tests', + description: + 'Catch exploitable issues before customers, auditors, or attackers do. Monthly scans ship with evidence your team can act on.', + summary: 'Security proof from $299/mo', + proof: 'Turn every release into an audit-ready security check.', + planTitle: 'Choose your scan cadence', + planDescription: + 'Start with one premium monthly scan, or cover releases, retests, and critical customer-facing surfaces with higher allowances.', + ctaLabel: 'Compare scan plans', + skuKeys: ['pentest_monthly_1', 'pentest_monthly_3', 'pentest_monthly_5_current'], }, { slug: 'background-checks', name: 'Background Checks', detailTitle: 'Background Checks', - description: 'Verify employees with subscription allowances or one-off checks as needed.', - summary: 'Plans from $249/mo', - skuKeys: ['background_checks_monthly_25'], + description: + 'Keep hiring moving with compliant employee checks, predictable monthly credits, and no surprise per-check invoice dance.', + summary: 'Hiring checks from $79/mo', + proof: 'Verify new hires faster and keep every result tied to the employee record.', + planTitle: 'Pick your hiring volume', + planDescription: + 'Buy the monthly allowance that matches your hiring pace while keeping premium screening spend predictable.', + ctaLabel: 'Compare check plans', + skuKeys: [ + 'background_checks_monthly_3', + 'background_checks_monthly_10', + 'background_checks_monthly_20', + ], }, ] as const satisfies readonly BillingAddOn[]; diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/types.ts b/apps/app/src/app/(app)/[orgId]/settings/billing/types.ts index f9dada1dec..71f1b1885d 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/types.ts +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/types.ts @@ -20,6 +20,10 @@ export interface BackgroundCheckBillingStatus { penetrationTests: number; }; preferences?: BillingPreferences | null; + trialEligibility?: { + pentest: boolean; + background_check: boolean; + }; usageRows?: BillingUsageRow[]; subscriptions?: Array<{ skuKey: string; diff --git a/apps/app/src/app/(app)/onboarding/components/PostPaymentOnboarding.tsx b/apps/app/src/app/(app)/onboarding/components/PostPaymentOnboarding.tsx index 4a9fda707f..de7f1fa76f 100644 --- a/apps/app/src/app/(app)/onboarding/components/PostPaymentOnboarding.tsx +++ b/apps/app/src/app/(app)/onboarding/components/PostPaymentOnboarding.tsx @@ -3,13 +3,12 @@ import { OnboardingStepInput } from '@/app/(app)/setup/components/OnboardingStepInput'; import { AnimatedWrapper } from '@/components/animated-wrapper'; import { LogoSpinner } from '@/components/logo-spinner'; +import type { Organization } from '@db'; import { Button } from '@trycompai/ui/button'; import { Form, FormControl, FormField, FormItem, FormMessage } from '@trycompai/ui/form'; -import type { Organization } from '@db'; import { AnimatePresence, motion } from 'framer-motion'; import { AlertCircle, Loader2 } from 'lucide-react'; import { useCallback, useEffect, useMemo, useState } from 'react'; -import Balancer from 'react-wrap-balancer'; import { usePostPaymentOnboarding } from '../hooks/usePostPaymentOnboarding'; import { CancelOnboardingButton } from './CancelOnboardingButton'; @@ -62,7 +61,7 @@ export function PostPaymentOnboarding({ const [hasInvalidUrl, setHasInvalidUrl] = useState(false); // Track if user has attempted to submit with invalid URLs const [showUrlError, setShowUrlError] = useState(false); - + const handleTouchedInvalidUrlChange = useCallback((hasInvalid: boolean) => { setHasInvalidUrl(hasInvalid); // Clear error if URLs are now valid @@ -104,7 +103,8 @@ export function PostPaymentOnboarding({ } // For software step, check if there's a value in software OR customVendors if (step.key === 'software') { - const hasSoftwareValue = Boolean(currentStepValue) && String(currentStepValue).trim().length > 0; + const hasSoftwareValue = + Boolean(currentStepValue) && String(currentStepValue).trim().length > 0; const hasCustomVendors = Array.isArray(customVendorsValue) && customVendorsValue.length > 0; return hasSoftwareValue || hasCustomVendors; } @@ -160,12 +160,12 @@ export function PostPaymentOnboarding({

    - {step?.question || ''} + {step?.question || ''}

    - Our AI will personalize the platform based on your answers. + Our AI will personalize the platform based on your answers.

    @@ -235,128 +235,126 @@ export function PostPaymentOnboarding({ > -

    - Please fix the invalid URL format -

    +

    Please fix the invalid URL format

    )} -
    - - - {stepIndex > 0 && ( - - - - )} - - - {isSkippable && ( + + + )} + + + {isSkippable && ( + + + + )} + + {(isLocal || canSkipOnboarding) && ( )} - - {(isLocal || canSkipOnboarding) && ( - - - )} - - {isLastStep ? ( - - ) : ( - + ) : ( + - )} - -
    + + Continue + + + )} + +
    diff --git a/apps/app/src/app/(app)/setup/actions/create-organization-minimal.ts b/apps/app/src/app/(app)/setup/actions/create-organization-minimal.ts index b759e7592d..c84c6af096 100644 --- a/apps/app/src/app/(app)/setup/actions/create-organization-minimal.ts +++ b/apps/app/src/app/(app)/setup/actions/create-organization-minimal.ts @@ -134,13 +134,6 @@ export const createOrganizationMinimal = authActionClientWithoutOrg role: 'owner', }, }, - pentestCredits: { - create: { - balance: 1, - totalGranted: 1, - lastGrantSource: 'trial', - }, - }, // Save framework context: display names for AI prompts + raw IDs for recovery context: { createMany: { diff --git a/apps/app/src/app/(app)/setup/actions/create-organization.ts b/apps/app/src/app/(app)/setup/actions/create-organization.ts index ffd66e8d4e..e0adedae9a 100644 --- a/apps/app/src/app/(app)/setup/actions/create-organization.ts +++ b/apps/app/src/app/(app)/setup/actions/create-organization.ts @@ -62,13 +62,6 @@ export const createOrganization = authActionClientWithoutOrg role: 'owner', }, }, - pentestCredits: { - create: { - balance: 1, - totalGranted: 1, - lastGrantSource: 'trial', - }, - }, context: { create: steps .filter((step) => step.key !== 'organizationName' && step.key !== 'website') diff --git a/apps/app/src/app/(app)/setup/components/OrganizationSetupForm.tsx b/apps/app/src/app/(app)/setup/components/OrganizationSetupForm.tsx index bdeafadc42..775921ff1b 100644 --- a/apps/app/src/app/(app)/setup/components/OrganizationSetupForm.tsx +++ b/apps/app/src/app/(app)/setup/components/OrganizationSetupForm.tsx @@ -5,7 +5,6 @@ import { LogoSpinner } from '@/components/logo-spinner'; import { Form, FormControl, FormField, FormItem, FormMessage } from '@trycompai/ui/form'; import { useRouter } from 'next/navigation'; import { useEffect, useState } from 'react'; -import Balancer from 'react-wrap-balancer'; import { useOnboardingForm } from '../hooks/useOnboardingForm'; import { OnboardingFormActions } from './OnboardingFormActions'; import { OnboardingStepInput } from './OnboardingStepInput'; @@ -101,13 +100,11 @@ export function OrganizationSetupForm({ {/* Title */}
    -

    - {step.question} -

    +

    {step.question}

    - Our AI will personalize the platform based on your answers. + Our AI will personalize the platform based on your answers.

    diff --git a/packages/billing/src/catalog.test.ts b/packages/billing/src/catalog.test.ts index 3755c33b4b..75fc039d6e 100644 --- a/packages/billing/src/catalog.test.ts +++ b/packages/billing/src/catalog.test.ts @@ -4,49 +4,98 @@ import { getBillingCatalog, getBillingSku, getBillingSkuByStripePriceId, + getSubscriptionBillingSkuKeysForProduct, isSubscriptionBillingSkuKey, + resolveBillingCatalogEnvironment, } from './index'; describe('billing catalog', () => { - it('records the test-mode subscription prices', () => { - const catalog = getBillingCatalog('test'); - - expect(catalog.products.pentest).toBe('prod_UQqGJryNvXajUt'); - expect(catalog.prices.pentest_monthly_5).toBe('price_1TRya6CkFWhKYvHI1sJ2M2no'); - expect(catalog.products.background_check).toBe('prod_UQNNIS1di6uOLB'); - expect(catalog.prices.background_checks_monthly_25).toBe('price_1TRya7CkFWhKYvHIDbCWSITp'); + it('records test and live subscription products', () => { + expect(getBillingCatalog('test').products).toEqual({ + pentest: 'prod_UQvQW6Pbh1HrEw', + background_check: 'prod_UQvQYmEniNCyvI', + }); + expect(getBillingCatalog('live').products).toEqual({ + pentest: 'prod_UQvQVLDFiaWxNf', + background_check: 'prod_UQvQRREj9JaIh1', + }); }); - it('models included monthly usage for subscription skus', () => { + it('models included monthly usage for all purchasable subscription SKUs', () => { + expect( + getBillingSku({ environment: 'test', skuKey: 'pentest_monthly_1' }).includedUsage, + ).toEqual({ quantity: 1, unit: 'scan', reset: 'monthly' }); expect( - getBillingSku({ environment: 'test', skuKey: 'pentest_monthly_5' }).includedUsage, + getBillingSku({ environment: 'test', skuKey: 'pentest_monthly_3' }).includedUsage, + ).toEqual({ quantity: 3, unit: 'scan', reset: 'monthly' }); + expect( + getBillingSku({ environment: 'test', skuKey: 'pentest_monthly_5_current' }).includedUsage, ).toEqual({ quantity: 5, unit: 'scan', reset: 'monthly' }); + expect( + getBillingSku({ environment: 'test', skuKey: 'background_checks_monthly_3' }).includedUsage, + ).toEqual({ quantity: 3, unit: 'background_check', reset: 'monthly' }); + expect( + getBillingSku({ environment: 'test', skuKey: 'background_checks_monthly_10' }).includedUsage, + ).toEqual({ quantity: 10, unit: 'background_check', reset: 'monthly' }); + expect( + getBillingSku({ environment: 'test', skuKey: 'background_checks_monthly_20' }).includedUsage, + ).toEqual({ quantity: 20, unit: 'background_check', reset: 'monthly' }); + }); + it('offers 14-day trials only on the lowest subscription tiers', () => { + expect(getBillingSku({ environment: 'test', skuKey: 'pentest_monthly_1' }).trialDays).toBe(14); expect( - getBillingSku({ - environment: 'test', - skuKey: 'background_checks_monthly_25', - }).includedUsage, - ).toEqual({ quantity: 25, unit: 'background_check', reset: 'monthly' }); + getBillingSku({ environment: 'test', skuKey: 'background_checks_monthly_3' }).trialDays, + ).toBe(14); + expect( + getBillingSku({ environment: 'test', skuKey: 'pentest_monthly_3' }).trialDays, + ).toBeUndefined(); + expect( + getBillingSku({ environment: 'test', skuKey: 'background_checks_monthly_10' }).trialDays, + ).toBeUndefined(); }); - it('finds skus by Stripe price id', () => { + it('finds SKUs by Stripe price id across environments', () => { expect( getBillingSkuByStripePriceId({ environment: 'test', - stripePriceId: 'price_1TRya6CkFWhKYvHI1sJ2M2no', + stripePriceId: 'price_1TS3ziCkFWhKYvHI0H5TWxNI', })?.key, - ).toBe('pentest_monthly_5'); + ).toBe('pentest_monthly_1'); expect( getBillingSkuByStripePriceId({ - environment: 'test', - stripePriceId: 'price_missing', - }), - ).toBeNull(); + environment: 'live', + stripePriceId: 'price_1TS3zVCxqPDT5y0WCsB6ywMP', + })?.key, + ).toBe('background_checks_monthly_20'); + expect(getBillingSkuByStripePriceId({ stripePriceId: 'price_missing' })).toBeNull(); }); - it('recognizes subscription sku keys', () => { - expect(isSubscriptionBillingSkuKey('pentest_monthly_5')).toBe(true); + it('recognizes only current subscription SKU keys as purchasable', () => { + expect(isSubscriptionBillingSkuKey('pentest_monthly_1')).toBe(true); + expect(isSubscriptionBillingSkuKey('pentest_monthly_3')).toBe(true); + expect(isSubscriptionBillingSkuKey('background_checks_monthly_3')).toBe(true); + expect(isSubscriptionBillingSkuKey('pentest_monthly_5')).toBe(false); + expect(isSubscriptionBillingSkuKey('pentest_monthly_10')).toBe(false); expect(isSubscriptionBillingSkuKey('background_check_one_time')).toBe(false); }); + + it('groups purchasable subscriptions by product', () => { + expect(getSubscriptionBillingSkuKeysForProduct('pentest')).toEqual([ + 'pentest_monthly_1', + 'pentest_monthly_3', + 'pentest_monthly_5_current', + ]); + expect(getSubscriptionBillingSkuKeysForProduct('background_check')).toEqual([ + 'background_checks_monthly_3', + 'background_checks_monthly_10', + 'background_checks_monthly_20', + ]); + }); + + it('resolves catalog environment from Stripe secret key prefix', () => { + expect(resolveBillingCatalogEnvironment({ stripeSecretKey: 'sk_test_123' })).toBe('test'); + expect(resolveBillingCatalogEnvironment({ stripeSecretKey: 'sk_live_123' })).toBe('live'); + expect(resolveBillingCatalogEnvironment({ nodeEnv: 'development' })).toBe('test'); + }); }); diff --git a/packages/billing/src/index.ts b/packages/billing/src/index.ts index 5de001d3ff..09b31916a2 100644 --- a/packages/billing/src/index.ts +++ b/packages/billing/src/index.ts @@ -1,12 +1,21 @@ -export type BillingCatalogEnvironment = 'test'; +import { createSkus } from './sku-definitions'; + +export type BillingCatalogEnvironment = 'test' | 'live'; export type BillingSkuKey = | 'background_check_one_time' + | 'background_checks_monthly_3' + | 'background_checks_monthly_10' + | 'background_checks_monthly_20' | 'background_checks_monthly_25' - | 'pentest_monthly_5'; + | 'pentest_monthly_1' + | 'pentest_monthly_3' + | 'pentest_monthly_4' + | 'pentest_monthly_5' + | 'pentest_monthly_5_current' + | 'pentest_monthly_10'; export type BillingProductKey = 'background_check' | 'pentest'; - export type BillingCadence = 'one_time' | 'month'; export type BillingSku = { @@ -24,11 +33,17 @@ export type BillingSku = { unit: 'background_check' | 'scan'; reset: 'monthly'; }; + trialDays?: number; + deprecated?: boolean; }; export const subscriptionBillingSkuKeys = [ - 'background_checks_monthly_25', - 'pentest_monthly_5', + 'background_checks_monthly_3', + 'background_checks_monthly_10', + 'background_checks_monthly_20', + 'pentest_monthly_1', + 'pentest_monthly_3', + 'pentest_monthly_5_current', ] as const satisfies readonly BillingSkuKey[]; export type SubscriptionBillingSkuKey = (typeof subscriptionBillingSkuKeys)[number]; @@ -40,56 +55,77 @@ export type BillingCatalog = { skus: Record; }; -const testSkus = { - background_check_one_time: { - key: 'background_check_one_time', - productKey: 'background_check', - name: 'Employee Background Check', - description: 'One-time employee background check.', - cadence: 'one_time', - currency: 'usd', - unitAmount: 4900, - stripeProductId: 'prod_UQNNIS1di6uOLB', - stripePriceId: 'price_1TRWckCkFWhKYvHIA1GLv1sO', - }, - background_checks_monthly_25: { - key: 'background_checks_monthly_25', - productKey: 'background_check', - name: 'Background Checks Monthly', - description: 'Monthly Comp AI background check package. Includes 25 checks per month.', - cadence: 'month', - currency: 'usd', - unitAmount: 24900, - stripeProductId: 'prod_UQNNIS1di6uOLB', - stripePriceId: 'price_1TRya7CkFWhKYvHIDbCWSITp', - includedUsage: { - quantity: 25, - unit: 'background_check', - reset: 'monthly', - }, +const legacyTestProducts = { + backgroundCheck: 'prod_UQNNIS1di6uOLB', + pentest: 'prod_UQqGJryNvXajUt', +}; + +const legacyLiveProducts = { + backgroundCheck: 'prod_UQNH0N7tplDVIy', + pentest: 'prod_UQqFwiuxcFTZ0G', +}; + +const testProducts = { + backgroundCheck: 'prod_UQvQYmEniNCyvI', + pentest: 'prod_UQvQW6Pbh1HrEw', +}; + +const liveProducts = { + backgroundCheck: 'prod_UQvQRREj9JaIh1', + pentest: 'prod_UQvQVLDFiaWxNf', +}; + +const testSkus = createSkus({ + products: testProducts, + legacyProducts: legacyTestProducts, + prices: { + background_check_one_time: 'price_1TRWckCkFWhKYvHIA1GLv1sO', + background_checks_monthly_3: 'price_1TS3zjCkFWhKYvHIkramHVwe', + background_checks_monthly_10: 'price_1TS3zjCkFWhKYvHIRylPAQeI', + background_checks_monthly_20: 'price_1TS3zjCkFWhKYvHIU5jMCCWs', + background_checks_monthly_25: 'price_1TRya7CkFWhKYvHIDbCWSITp', + pentest_monthly_1: 'price_1TS3ziCkFWhKYvHI0H5TWxNI', + pentest_monthly_3: 'price_1TS3ziCkFWhKYvHI1nbXC7UU', + pentest_monthly_4: 'price_1TS3ZsCkFWhKYvHIgAogX3Po', + pentest_monthly_5_current: 'price_1TS3zjCkFWhKYvHISBHjtZXB', + pentest_monthly_10: 'price_1TS3ZsCkFWhKYvHI8gLvPL1t', + pentest_monthly_5: 'price_1TRya6CkFWhKYvHI1sJ2M2no', }, - pentest_monthly_5: { - key: 'pentest_monthly_5', - productKey: 'pentest', - name: 'Penetration Tests Monthly', - description: 'Monthly Comp AI penetration testing package. Includes 5 scans per month.', - cadence: 'month', - currency: 'usd', - unitAmount: 39900, - stripeProductId: 'prod_UQqGJryNvXajUt', - stripePriceId: 'price_1TRya6CkFWhKYvHI1sJ2M2no', - includedUsage: { - quantity: 5, - unit: 'scan', - reset: 'monthly', - }, +}); + +const liveSkus = createSkus({ + products: liveProducts, + legacyProducts: legacyLiveProducts, + prices: { + background_check_one_time: 'price_1TRWWzCxqPDT5y0W2cjTNfIq', + background_checks_monthly_3: 'price_1TS3zTCxqPDT5y0WaDAtQ6EW', + background_checks_monthly_10: 'price_1TS3zUCxqPDT5y0Whecvdmjl', + background_checks_monthly_20: 'price_1TS3zVCxqPDT5y0WCsB6ywMP', + background_checks_monthly_25: 'price_legacy_background_checks_monthly_25', + pentest_monthly_1: 'price_1TS3zKCxqPDT5y0WsZnBU8NT', + pentest_monthly_3: 'price_1TS3zMCxqPDT5y0WC2OyJNAv', + pentest_monthly_4: 'price_1TS3ZcCxqPDT5y0WRHbHXuFd', + pentest_monthly_5_current: 'price_1TS3zNCxqPDT5y0WYC5mLjwA', + pentest_monthly_10: 'price_1TS3ZdCxqPDT5y0WKsHTKiTY', + pentest_monthly_5: 'price_legacy_pentest_monthly_5', }, -} satisfies Record; +}); export const billingCatalogs = { test: createCatalog({ environment: 'test', skus: testSkus }), + live: createCatalog({ environment: 'live', skus: liveSkus }), } satisfies Record; +export function resolveBillingCatalogEnvironment(params?: { + stripeSecretKey?: string | null; + nodeEnv?: string | null; +}): BillingCatalogEnvironment { + const stripeSecretKey = params?.stripeSecretKey ?? process.env.STRIPE_SECRET_KEY; + if (stripeSecretKey?.startsWith('sk_live_')) return 'live'; + if (stripeSecretKey?.startsWith('sk_test_')) return 'test'; + return (params?.nodeEnv ?? process.env.NODE_ENV) === 'production' ? 'live' : 'test'; +} + export function getBillingCatalog(environment: BillingCatalogEnvironment = 'test'): BillingCatalog { return billingCatalogs[environment]; } @@ -106,9 +142,37 @@ export function getBillingSkuByStripePriceId(params: { environment?: BillingCatalogEnvironment; stripePriceId: string; }): BillingSku | null { - const catalog = getBillingCatalog(params.environment); - return ( - Object.values(catalog.skus).find((sku) => sku.stripePriceId === params.stripePriceId) ?? null + const catalogs = params.environment + ? [getBillingCatalog(params.environment)] + : Object.values(billingCatalogs); + for (const catalog of catalogs) { + const sku = Object.values(catalog.skus).find( + (item) => item.stripePriceId === params.stripePriceId, + ); + if (sku) return sku; + } + return null; +} + +export function getBillingSkuProductKey(skuKey: string): BillingProductKey | null { + for (const catalog of Object.values(billingCatalogs)) { + const sku = catalog.skus[skuKey as BillingSkuKey]; + if (sku) return sku.productKey; + } + return null; +} + +export function getBillingSkuKeysForProduct(productKey: BillingProductKey): BillingSkuKey[] { + return Object.values(billingCatalogs.test.skus) + .filter((sku) => sku.productKey === productKey) + .map((sku) => sku.key); +} + +export function getSubscriptionBillingSkuKeysForProduct( + productKey: BillingProductKey, +): SubscriptionBillingSkuKey[] { + return subscriptionBillingSkuKeys.filter( + (skuKey) => billingCatalogs.test.skus[skuKey].productKey === productKey, ); } @@ -123,14 +187,12 @@ function createCatalog(params: { return { environment: params.environment, products: { - background_check: params.skus.background_check_one_time.stripeProductId, - pentest: params.skus.pentest_monthly_5.stripeProductId, - }, - prices: { - background_check_one_time: params.skus.background_check_one_time.stripePriceId, - background_checks_monthly_25: params.skus.background_checks_monthly_25.stripePriceId, - pentest_monthly_5: params.skus.pentest_monthly_5.stripePriceId, + background_check: params.skus.background_checks_monthly_3.stripeProductId, + pentest: params.skus.pentest_monthly_1.stripeProductId, }, + prices: Object.fromEntries( + Object.values(params.skus).map((sku) => [sku.key, sku.stripePriceId]), + ) as Record, skus: params.skus, }; } diff --git a/packages/billing/src/sku-definitions.ts b/packages/billing/src/sku-definitions.ts new file mode 100644 index 0000000000..1f5fef569e --- /dev/null +++ b/packages/billing/src/sku-definitions.ts @@ -0,0 +1,152 @@ +import type { BillingSku, BillingSkuKey } from './index'; + +export function createSkus(params: { + products: { backgroundCheck: string; pentest: string }; + legacyProducts: { backgroundCheck: string; pentest: string }; + prices: Record; +}): Record { + return { + background_check_one_time: { + key: 'background_check_one_time', + productKey: 'background_check', + name: 'Employee Background Check', + description: 'One-time employee background check.', + cadence: 'one_time', + currency: 'usd', + unitAmount: 4900, + stripeProductId: params.legacyProducts.backgroundCheck, + stripePriceId: params.prices.background_check_one_time, + deprecated: true, + }, + background_checks_monthly_3: { + key: 'background_checks_monthly_3', + productKey: 'background_check', + name: 'Hiring Starter', + description: 'Premium screening coverage for your next trusted hires.', + cadence: 'month', + currency: 'usd', + unitAmount: 7900, + stripeProductId: params.products.backgroundCheck, + stripePriceId: params.prices.background_checks_monthly_3, + includedUsage: { quantity: 3, unit: 'background_check', reset: 'monthly' }, + trialDays: 14, + }, + background_checks_monthly_10: { + key: 'background_checks_monthly_10', + productKey: 'background_check', + name: 'Hiring Momentum', + description: 'Predictable coverage for steady recruiting pipelines.', + cadence: 'month', + currency: 'usd', + unitAmount: 19900, + stripeProductId: params.products.backgroundCheck, + stripePriceId: params.prices.background_checks_monthly_10, + includedUsage: { quantity: 10, unit: 'background_check', reset: 'monthly' }, + }, + background_checks_monthly_20: { + key: 'background_checks_monthly_20', + productKey: 'background_check', + name: 'Hiring Scale', + description: 'High-volume screening with finance-friendly spend control.', + cadence: 'month', + currency: 'usd', + unitAmount: 39900, + stripeProductId: params.products.backgroundCheck, + stripePriceId: params.prices.background_checks_monthly_20, + includedUsage: { quantity: 20, unit: 'background_check', reset: 'monthly' }, + }, + background_checks_monthly_25: { + key: 'background_checks_monthly_25', + productKey: 'background_check', + name: 'Background Checks Monthly', + description: 'Legacy monthly package. Includes 25 checks per month.', + cadence: 'month', + currency: 'usd', + unitAmount: 24900, + stripeProductId: params.legacyProducts.backgroundCheck, + stripePriceId: params.prices.background_checks_monthly_25, + includedUsage: { quantity: 25, unit: 'background_check', reset: 'monthly' }, + deprecated: true, + }, + pentest_monthly_1: { + key: 'pentest_monthly_1', + productKey: 'pentest', + name: 'Launch Guard', + description: "One premium monthly scan for your app's most important surface.", + cadence: 'month', + currency: 'usd', + unitAmount: 29900, + stripeProductId: params.products.pentest, + stripePriceId: params.prices.pentest_monthly_1, + includedUsage: { quantity: 1, unit: 'scan', reset: 'monthly' }, + trialDays: 14, + }, + pentest_monthly_3: { + key: 'pentest_monthly_3', + productKey: 'pentest', + name: 'Release Shield', + description: 'Three scans for launch windows, retests, and customer-facing apps.', + cadence: 'month', + currency: 'usd', + unitAmount: 49900, + stripeProductId: params.products.pentest, + stripePriceId: params.prices.pentest_monthly_3, + includedUsage: { quantity: 3, unit: 'scan', reset: 'monthly' }, + }, + pentest_monthly_5_current: { + key: 'pentest_monthly_5_current', + productKey: 'pentest', + name: 'Continuous Assurance', + description: 'Five scans for teams shipping across multiple critical surfaces.', + cadence: 'month', + currency: 'usd', + unitAmount: 89900, + stripeProductId: params.products.pentest, + stripePriceId: params.prices.pentest_monthly_5_current, + includedUsage: { quantity: 5, unit: 'scan', reset: 'monthly' }, + }, + pentest_monthly_4: legacyPentestSku({ + key: 'pentest_monthly_4', + priceId: params.prices.pentest_monthly_4, + productId: params.products.pentest, + quantity: 4, + unitAmount: 49900, + }), + pentest_monthly_10: legacyPentestSku({ + key: 'pentest_monthly_10', + priceId: params.prices.pentest_monthly_10, + productId: params.products.pentest, + quantity: 10, + unitAmount: 79900, + }), + pentest_monthly_5: legacyPentestSku({ + key: 'pentest_monthly_5', + priceId: params.prices.pentest_monthly_5, + productId: params.legacyProducts.pentest, + quantity: 5, + unitAmount: 39900, + }), + }; +} + +function legacyPentestSku(params: { + key: 'pentest_monthly_4' | 'pentest_monthly_5' | 'pentest_monthly_10'; + priceId: string; + productId: string; + quantity: number; + unitAmount: number; +}): BillingSku { + return { + key: params.key, + productKey: 'pentest', + name: 'Penetration Tests Legacy', + description: `Legacy package. Includes ${params.quantity} scans per month.`, + cadence: 'month', + currency: 'usd', + unitAmount: params.unitAmount, + stripeProductId: params.productId, + stripePriceId: params.priceId, + includedUsage: { quantity: params.quantity, unit: 'scan', reset: 'monthly' }, + deprecated: true, + }; +} diff --git a/packages/db/prisma/migrations/20260501001000_remove_pentest_trial_credits/migration.sql b/packages/db/prisma/migrations/20260501001000_remove_pentest_trial_credits/migration.sql new file mode 100644 index 0000000000..133c48684e --- /dev/null +++ b/packages/db/prisma/migrations/20260501001000_remove_pentest_trial_credits/migration.sql @@ -0,0 +1,13 @@ +-- Remove the old customer-facing one-free-pentest-run balance now that +-- add-ons use Stripe subscription trials. Keep lifetime grant history intact; +-- this only removes one spendable wallet credit from each org that still has +-- one available. +UPDATE "pentest_credits" +SET + "balance" = GREATEST("balance" - 1, 0), + "last_grant_source" = CASE + WHEN "last_grant_source" = 'trial' THEN 'migration_remove_trial' + ELSE "last_grant_source" + END, + "updated_at" = CURRENT_TIMESTAMP +WHERE "balance" > 0; From d855b7bf6196e73975de66eb6db6ede6f6278b9d Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Fri, 1 May 2026 00:58:41 +0100 Subject: [PATCH 07/20] chore(billing): update TypeScript configuration and import paths - Added "allowImportingTsExtensions" option to tsconfig.json for improved module resolution. - Updated import statement in index.ts to include the .ts extension for sku-definitions. --- packages/billing/src/index.ts | 2 +- packages/billing/tsconfig.json | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/billing/src/index.ts b/packages/billing/src/index.ts index 09b31916a2..330b8b6562 100644 --- a/packages/billing/src/index.ts +++ b/packages/billing/src/index.ts @@ -1,4 +1,4 @@ -import { createSkus } from './sku-definitions'; +import { createSkus } from './sku-definitions.ts'; export type BillingCatalogEnvironment = 'test' | 'live'; diff --git a/packages/billing/tsconfig.json b/packages/billing/tsconfig.json index 20fef847d2..4c11875686 100644 --- a/packages/billing/tsconfig.json +++ b/packages/billing/tsconfig.json @@ -2,6 +2,7 @@ "extends": "@trycompai/tsconfig/base.json", "compilerOptions": { "incremental": true, + "allowImportingTsExtensions": true, "tsBuildInfoFile": "node_modules/.cache/tsbuildinfo.json" }, "include": ["src"], From 50bcc7a167a6c3342cb2d3307cabc44711ed563a Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Fri, 1 May 2026 01:00:26 +0100 Subject: [PATCH 08/20] chore(billing): update import paths for SKU definitions - Changed import statement in index.ts to use .js extension for sku-definitions. - Added sku-definitions.js to export createSkus from sku-definitions.ts. --- packages/billing/src/index.ts | 2 +- packages/billing/src/sku-definitions.js | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 packages/billing/src/sku-definitions.js diff --git a/packages/billing/src/index.ts b/packages/billing/src/index.ts index 330b8b6562..70cf4158d4 100644 --- a/packages/billing/src/index.ts +++ b/packages/billing/src/index.ts @@ -1,4 +1,4 @@ -import { createSkus } from './sku-definitions.ts'; +import { createSkus } from './sku-definitions.js'; export type BillingCatalogEnvironment = 'test' | 'live'; diff --git a/packages/billing/src/sku-definitions.js b/packages/billing/src/sku-definitions.js new file mode 100644 index 0000000000..c9534bf616 --- /dev/null +++ b/packages/billing/src/sku-definitions.js @@ -0,0 +1 @@ +export { createSkus } from './sku-definitions.ts'; From 519bebb5810cc5b301715687c400c16d1a0febfa Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Fri, 1 May 2026 01:12:06 +0100 Subject: [PATCH 09/20] feat(billing): add billing add-ons functionality and trial eligibility - Implemented proxy handling for billing add-ons in the proxy.ts file. - Enhanced BillingAddOnsOverview component to display trial eligibility badges and messages. - Updated BillingSettingsClient to pass trial eligibility data. - Created a new catch-all page for billing add-ons to handle dynamic routing. - Added layout component for billing add-ons section. - Updated tests to cover new billing add-ons features and trial eligibility scenarios. --- .../settings/billing/BillingAddOns.test.tsx | 28 +++++++++- .../billing/BillingAddOnsOverview.tsx | 14 +++-- .../billing/BillingSettingsClient.tsx | 6 +++ .../billing/BillingSubscriptionPlans.tsx | 26 +++++----- .../settings/billing/[...addOn]/page.tsx | 51 +++++++++++++++++++ .../settings/billing/add-ons/layout.tsx | 3 ++ .../settings/billing/emptyBillingStatus.ts | 4 ++ .../(app)/[orgId]/settings/billing/page.tsx | 31 +++++++++-- apps/app/src/proxy.ts | 13 +++++ 9 files changed, 155 insertions(+), 21 deletions(-) create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/[...addOn]/page.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/add-ons/layout.tsx diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx index 12c28ecc6a..b7b0329414 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx @@ -88,10 +88,18 @@ describe('billing add-ons', () => { }); it('shows product-level add-ons before plan selection', () => { - render(); + render( + , + ); expect(screen.getByText('Penetration Tests')).toBeInTheDocument(); expect(screen.getByText('Background Checks')).toBeInTheDocument(); + expect(screen.getAllByText('14-day free trial')).toHaveLength(2); + expect(screen.getAllByText(/no charge today/i)).toHaveLength(2); screen.getByRole('button', { name: /view penetration tests plans/i }).click(); expect(navigationMock.push).toHaveBeenCalledWith( '/org_1/settings/billing/add-ons/penetration-tests', @@ -102,6 +110,24 @@ describe('billing add-ons', () => { ); }); + it('hides overview trial copy for products with subscription history', () => { + render( + , + ); + + expect(screen.getAllByText('14-day free trial')).toHaveLength(1); + expect( + screen.queryByText(/start with a 14-day free trial on the first tier/i), + ).toBeInTheDocument(); + expect( + screen.getByText('Turn every release into an audit-ready security check.'), + ).toBeInTheDocument(); + }); + it('opens subscription checkout for an add-on plan', async () => { const user = userEvent.setup(); renderAddOnPlans({ addOnSlug: 'background-checks' }); diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx index 4db2655ca8..3224bcd01b 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOnsOverview.tsx @@ -1,8 +1,8 @@ 'use client'; +import { getBillingSku, getBillingSkuProductKey } from '@trycompai/billing'; import { Badge, Button, Stack, Text } from '@trycompai/design-system'; import { ArrowRight } from '@trycompai/design-system/icons'; -import { getBillingSku, getBillingSkuProductKey } from '@trycompai/billing'; import { useRouter } from 'next/navigation'; import { billingAddOns } from './billingAddOns'; import type { BackgroundCheckBillingStatus } from './types'; @@ -10,11 +10,13 @@ import type { BackgroundCheckBillingStatus } from './types'; interface BillingAddOnsOverviewProps { organizationId: string; subscriptions: NonNullable; + trialEligibility?: BackgroundCheckBillingStatus['trialEligibility']; } export function BillingAddOnsOverview({ organizationId, subscriptions, + trialEligibility, }: BillingAddOnsOverviewProps) { const router = useRouter(); @@ -27,6 +29,7 @@ export function BillingAddOnsOverview({ getBillingSkuProductKey(subscription.skuKey) === productKey && (subscription.status === 'active' || subscription.status === 'trialing'), ); + const trialEligible = !activeSubscription && trialEligibility?.[productKey] === true; const remaining = activeSubscription ? Math.max(activeSubscription.includedQuantity - activeSubscription.usedQuantity, 0) : null; @@ -46,13 +49,18 @@ export function BillingAddOnsOverview({ {addOn.summary}
    - {activeSubscription && Active} +
    + {trialEligible && 14-day free trial} + {activeSubscription && Active} +
    {addOn.description} - {addOn.proof} + {trialEligible + ? 'Start with a 14-day free trial on the first tier. Card required, no charge today.' + : addOn.proof} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx index 64eec73387..b5ab644a09 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSettingsClient.tsx @@ -74,6 +74,11 @@ export function BillingSettingsClient({ () => billingStatus?.subscriptions ?? initialBillingStatus.subscriptions ?? [], [billingStatus?.subscriptions, initialBillingStatus.subscriptions], ); + const trialEligibility = billingStatus?.trialEligibility ?? + initialBillingStatus.trialEligibility ?? { + pentest: false, + background_check: false, + }; const preferences = useMemo( () => billingStatus?.preferences ?? initialBillingStatus.preferences ?? null, [billingStatus?.preferences, initialBillingStatus.preferences], @@ -173,6 +178,7 @@ export function BillingSettingsClient({
    diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx index d43325bd40..d5af90a8ed 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx @@ -71,11 +71,13 @@ export function BillingSubscriptionPlans({ {current && Current} {trialEligible && 14-day free trial}
    - - {sku.description} - +
    + + {sku.description} + +
    - +
    {formatAmount(sku.unitAmount)} / mo @@ -85,8 +87,8 @@ export function BillingSubscriptionPlans({ {included.quantity} {unit} every month )} - -
    +
    +
    {trialEligible ? 'Try it free for 14 days. Add a card now, pay only if you keep it.' @@ -142,13 +144,13 @@ function formatUsageUnit(unit: string, quantity: number) { function getPlanPromise(productKey: string, quantity: number) { if (productKey === 'pentest') { - if (quantity === 1) return 'Validate your highest-risk app every month.'; - if (quantity === 3) return 'Cover launch windows and retest fixes without waiting.'; - return 'Keep critical surfaces continuously audit-ready.'; + if (quantity === 1) return 'Validate critical surfaces every month.'; + if (quantity === 3) return 'Launch and retest fixes with confidence.'; + return 'Keep key surfaces audit-ready.'; } - if (quantity === 3) return 'Cover your next hires without per-check approvals.'; - if (quantity === 10) return 'Keep recruiting moving with predictable checks.'; - return 'Scale hiring without surprise background-check spend.'; + if (quantity === 3) return 'Cover new hires without approval delays.'; + if (quantity === 10) return 'Keep recruiting moving predictably.'; + return 'Scale screening without surprise spend.'; } function getPlanCta({ diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/[...addOn]/page.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/[...addOn]/page.tsx new file mode 100644 index 0000000000..f331c451e4 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/[...addOn]/page.tsx @@ -0,0 +1,51 @@ +import { serverApi } from '@/lib/api-server'; +import type { Metadata } from 'next'; +import { notFound } from 'next/navigation'; +import { BillingAddOnPlansClient } from '../BillingAddOnPlansClient'; +import { getBillingAddOn } from '../billingAddOns'; +import { emptyBillingStatus } from '../emptyBillingStatus'; +import type { BackgroundCheckBillingStatus } from '../types'; + +function getAddOnSlug(addOnPath: string[]) { + if (addOnPath.length !== 2 || addOnPath[0] !== 'add-ons') { + return null; + } + return addOnPath[1]; +} + +export default async function BillingAddOnCatchAllPage({ + params, +}: { + params: Promise<{ orgId: string; addOn: string[] }>; +}) { + const { orgId, addOn: addOnPath } = await params; + const addOnSlug = getAddOnSlug(addOnPath); + if (!addOnSlug) notFound(); + + const addOn = getBillingAddOn(addOnSlug); + if (!addOn) notFound(); + + const response = await serverApi.get('/v1/billing/status'); + + return ( + + ); +} + +export async function generateMetadata({ + params, +}: { + params: Promise<{ addOn: string[] }>; +}): Promise { + const { addOn: addOnPath } = await params; + const addOnSlug = getAddOnSlug(addOnPath); + const addOn = addOnSlug ? getBillingAddOn(addOnSlug) : null; + + return { + title: addOn ? `${addOn.detailTitle} Billing` : 'Billing Add-on', + }; +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/add-ons/layout.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/add-ons/layout.tsx new file mode 100644 index 0000000000..aa80481a4b --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/add-ons/layout.tsx @@ -0,0 +1,3 @@ +export default function BillingAddOnsLayout({ children }: { children: React.ReactNode }) { + return children; +} diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts b/apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts index 1947baee7e..f27dc73dd2 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts @@ -9,6 +9,10 @@ export const emptyBillingStatus: BackgroundCheckBillingStatus = { }, invoices: [], subscriptions: [], + trialEligibility: { + pentest: true, + background_check: true, + }, usageRows: [], preferences: null, }; diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx index e1198c1f3d..01ced89825 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/page.tsx @@ -1,18 +1,39 @@ import { serverApi } from '@/lib/api-server'; import type { Metadata } from 'next'; +import { notFound } from 'next/navigation'; +import { BillingAddOnPlansClient } from './BillingAddOnPlansClient'; import { BillingSettingsClient } from './BillingSettingsClient'; +import { getBillingAddOn } from './billingAddOns'; import { emptyBillingStatus } from './emptyBillingStatus'; import type { BackgroundCheckBillingStatus } from './types'; -export default async function BillingPage({ params }: { params: Promise<{ orgId: string }> }) { +export default async function BillingPage({ + params, + searchParams, +}: { + params: Promise<{ orgId: string }>; + searchParams?: Promise<{ addOn?: string }>; +}) { const { orgId } = await params; + const addOnSlug = (await searchParams)?.addOn; const response = await serverApi.get('/v1/billing/status'); + const initialBillingStatus = response.data ?? emptyBillingStatus; + + if (addOnSlug) { + const addOn = getBillingAddOn(addOnSlug); + if (!addOn) notFound(); + + return ( + + ); + } return ( - + ); } diff --git a/apps/app/src/proxy.ts b/apps/app/src/proxy.ts index 403d104c76..68d328add9 100644 --- a/apps/app/src/proxy.ts +++ b/apps/app/src/proxy.ts @@ -77,6 +77,19 @@ export async function proxy(request: NextRequest) { return NextResponse.redirect(url); } + const billingAddOnMatch = nextUrl.pathname.match( + /^\/([^/]+)\/settings\/billing\/add-ons\/([^/]+)$/, + ); + if (billingAddOnMatch) { + const [, orgId, addOn] = billingAddOnMatch; + const url = new URL(`/${orgId}/settings/billing`, request.url); + nextUrl.searchParams.forEach((value, key) => { + url.searchParams.set(key, value); + }); + url.searchParams.set('addOn', addOn); + return NextResponse.rewrite(url); + } + // Org existence and membership checks happen in the app layouts/pages so // users get the proper redirect instead of a raw 403 response from middleware. From c8cfb04952f3151342b33cdf6591b81862d5140e Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Fri, 1 May 2026 12:33:31 +0100 Subject: [PATCH 10/20] fix(billing): handle subscription edge cases --- apps/api/src/billing/billing-usage.spec.ts | 50 +++++++++++ apps/api/src/billing/billing-usage.ts | 20 +++-- ...security-penetration-tests.billing.spec.ts | 29 +++++++ .../security-penetration-tests.service.ts | 29 ++++++- .../EmployeeBackgroundCheck.test.tsx | 18 ++-- .../components/EmployeeBackgroundCheck.tsx | 87 ++++++++----------- .../components/backgroundCheckForm.ts | 4 + .../useEmployeeBackgroundCheckData.ts | 60 +++++++++++++ .../_components/CreateRunPanel.test.tsx | 38 ++++++++ .../_components/CreateRunPanel.tsx | 1 + .../settings/billing/BillingAddOns.test.tsx | 31 +++++++ .../billing/BillingSubscriptionPlans.tsx | 6 +- .../settings/billing/emptyBillingStatus.ts | 4 +- packages/billing/tsconfig.json | 1 + 14 files changed, 304 insertions(+), 74 deletions(-) create mode 100644 apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/useEmployeeBackgroundCheckData.ts create mode 100644 apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.test.tsx diff --git a/apps/api/src/billing/billing-usage.spec.ts b/apps/api/src/billing/billing-usage.spec.ts index 610fe65989..4d04ab3973 100644 --- a/apps/api/src/billing/billing-usage.spec.ts +++ b/apps/api/src/billing/billing-usage.spec.ts @@ -125,4 +125,54 @@ describe('listBillingUsageRows', () => { }), ); }); + + it('uses the usage event sku when multiple subscriptions share a product', async () => { + backgroundCheckFindMany.mockResolvedValue([ + { + id: 'bcr_2', + memberId: 'mem_2', + employeeName: 'Grace Hopper', + employeeEmail: 'grace@example.com', + status: 'completed', + stripePaymentStatus: 'succeeded', + createdAt: new Date('2026-04-30T12:00:00.000Z'), + updatedAt: new Date('2026-04-30T12:05:00.000Z'), + }, + ]); + pentestRunFindMany.mockResolvedValue([]); + billingUsageEventFindMany.mockResolvedValue([ + { + skuKey: 'background_checks_monthly_10', + eventType: 'consume', + sourceResourceId: 'mem_2', + stripeInvoiceId: null, + }, + ]); + + const rows = await listBillingUsageRows({ + organizationId: 'org_1', + subscriptions: [ + { + skuKey: 'background_checks_monthly_3', + includedQuantity: 3, + usedQuantity: 3, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + { + skuKey: 'background_checks_monthly_10', + includedQuantity: 10, + usedQuantity: 4, + currentPeriodEnd: new Date('2026-05-30T00:00:00.000Z'), + }, + ], + }); + + expect(rows[0]).toEqual( + expect.objectContaining({ + skuKey: 'background_checks_monthly_10', + subscriptionRemaining: 6, + subscriptionIncluded: 10, + }), + ); + }); }); diff --git a/apps/api/src/billing/billing-usage.ts b/apps/api/src/billing/billing-usage.ts index eeb41e8816..675e50c7ae 100644 --- a/apps/api/src/billing/billing-usage.ts +++ b/apps/api/src/billing/billing-usage.ts @@ -1,5 +1,5 @@ import { db } from '@db'; -import { getBillingSkuProductKey, type BillingSkuKey } from '@trycompai/billing'; +import { getBillingSkuProductKey } from '@trycompai/billing'; import type { BillingUsageRow } from './billing.types'; type SubscriptionSummary = { @@ -73,7 +73,7 @@ export async function listBillingUsageRows(params: { return toBillingUsageRow({ id: request.id, service: 'Background Check', - skuKey: backgroundCheckSku, + skuKey: usage?.skuKey ?? backgroundCheckSku, details: `${request.employeeName} (${request.employeeEmail})`, status: formatStatus(request.status), billingType: formatBillingType( @@ -92,7 +92,7 @@ export async function listBillingUsageRows(params: { return toBillingUsageRow({ id: run.id, service: 'Penetration Test', - skuKey: pentestSku, + skuKey: usage?.skuKey ?? pentestSku, details: run.providerRunId, status: 'Created', billingType: usage @@ -113,7 +113,7 @@ export async function listBillingUsageRows(params: { function toBillingUsageRow(params: { id: string; service: BillingUsageRow['service']; - skuKey: BillingSkuKey; + skuKey: string; details: string; status: string; billingType: string; @@ -122,11 +122,13 @@ function toBillingUsageRow(params: { subscriptions: SubscriptionSummary[]; }): BillingUsageRow { const productKey = getBillingSkuProductKey(params.skuKey); - const subscription = params.subscriptions.find((item) => - productKey - ? getBillingSkuProductKey(item.skuKey) === productKey - : item.skuKey === params.skuKey, - ); + const subscription = + params.subscriptions.find((item) => item.skuKey === params.skuKey) ?? + params.subscriptions.find((item) => + productKey + ? getBillingSkuProductKey(item.skuKey) === productKey + : item.skuKey === params.skuKey, + ); const remaining = subscription ? Math.max(subscription.includedQuantity - subscription.usedQuantity, 0) : null; diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts index 740ddbfd1d..830150233d 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.billing.spec.ts @@ -1,4 +1,5 @@ import { db } from '@db'; +import { HttpException, HttpStatus } from '@nestjs/common'; import type { BillingEntitlementsService } from '../billing/billing-entitlements.service'; import type { PentestCreditsService } from './pentest-credits.service'; import { SecurityPenetrationTestsService } from './security-penetration-tests.service'; @@ -159,6 +160,34 @@ describe('SecurityPenetrationTestsService billing usage', () => { expect(mockMacedPentestsCreate).not.toHaveBeenCalled(); }); + it('preserves exhausted subscription reasons from string payment errors', async () => { + billingEntitlements.tryConsumeIncludedUsageForProduct.mockRejectedValue( + new HttpException( + 'pentest_subscription_exhausted', + HttpStatus.PAYMENT_REQUIRED, + ), + ); + + await expect( + service.createReport('org_123', { + targetUrl: 'https://app.example.com', + }), + ).rejects.toMatchObject({ + status: 402, + }); + + expect(credits.writePentestAuditEntry).toHaveBeenCalledWith( + expect.objectContaining({ + organizationId: 'org_123', + action: 'pentest_create_blocked', + metadata: expect.objectContaining({ + reason: 'pentest_subscription_exhausted', + }), + }), + ); + expect(mockMacedPentestsCreate).not.toHaveBeenCalled(); + }); + it('refunds subscription usage on terminal failure for subscription-backed runs', async () => { mockedDb.securityPenetrationTestRun.findUnique.mockResolvedValue({ organizationId: 'org_123', diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts index b4caf5643e..2cb341475b 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts @@ -274,7 +274,10 @@ export class SecurityPenetrationTestsService { organizationId, action: 'pentest_create_blocked', runId: null, - description: 'Pentest create blocked: subscription required', + description: + reason === 'pentest_subscription_exhausted' + ? 'Pentest create blocked: subscription exhausted' + : 'Pentest create blocked: subscription required', metadata: { reason, targetUrl: payload.targetUrl, @@ -1093,9 +1096,31 @@ export class SecurityPenetrationTestsService { } function getPaymentRequiredCode(response: unknown): string { + if (typeof response === 'string') { + return normalizePaymentRequiredCode(response); + } if (typeof response !== 'object' || response === null) { return 'pentest_subscription_required'; } const code = (response as Record).code; - return typeof code === 'string' ? code : 'pentest_subscription_required'; + if (typeof code === 'string') return normalizePaymentRequiredCode(code); + + const message = (response as Record).message; + return typeof message === 'string' + ? normalizePaymentRequiredCode(message) + : 'pentest_subscription_required'; +} + +function normalizePaymentRequiredCode(value: string): string { + if ( + value === 'pentest_subscription_exhausted' || + value.toLowerCase().includes('exhaust') || + value.toLowerCase().includes('remaining') + ) { + return 'pentest_subscription_exhausted'; + } + if (value === 'pentest_subscription_required') { + return 'pentest_subscription_required'; + } + return 'pentest_subscription_required'; } diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx index 447ba9a81c..5884cba919 100644 --- a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx @@ -142,9 +142,7 @@ describe('EmployeeBackgroundCheck', () => { expect(screen.getByText('Employee Background Check')).toBeInTheDocument(); expect(screen.getByLabelText('Personal email')).toBeInTheDocument(); - expect( - screen.getByText('2 background checks remaining this period.'), - ).toBeInTheDocument(); + expect(screen.getByText('2 background checks remaining this period.')).toBeInTheDocument(); expect(screen.queryByRole('button', { name: /back/i })).not.toBeInTheDocument(); }); @@ -201,7 +199,7 @@ describe('EmployeeBackgroundCheck', () => { ).toBeNull(); }); - it('stores the pending check and routes to plans when allowance disappears', async () => { + it('stores the pending check details and routes to plans when allowance disappears', async () => { const user = userEvent.setup(); vi.mocked(apiClient.post).mockResolvedValueOnce({ error: 'No credits', @@ -221,8 +219,14 @@ describe('EmployeeBackgroundCheck', () => { '/org_1/settings/billing/add-ons/background-checks', ); }); - expect( - window.sessionStorage.getItem('background-check:org_1:mem_1:pending-request'), - ).toBeNull(); + expect(window.sessionStorage.getItem('background-check:org_1:mem_1:pending-request')).toBe( + JSON.stringify({ + organizationId: 'org_1', + memberId: 'mem_1', + employeeName: 'Ada Lovelace', + employeeEmail: 'ada@example.com', + requesterNotes: 'Recruiting requested an expedited check.', + }), + ); }); }); diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx index 134e2da751..5ee9ada405 100644 --- a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx @@ -3,13 +3,11 @@ import { usePermissions } from '@/hooks/use-permissions'; import { apiClient } from '@/lib/api-client'; import type { Member, User } from '@db'; -import { getBillingSkuProductKey } from '@trycompai/billing'; import { zodResolver } from '@hookform/resolvers/zod'; import { usePathname, useRouter, useSearchParams } from 'next/navigation'; import { useCallback, useEffect, useRef, useState } from 'react'; import { useForm } from 'react-hook-form'; import { toast } from 'sonner'; -import useSWR from 'swr'; import { BackgroundCheckDetailsForm } from './BackgroundCheckDetailsForm'; import { type BackgroundCheckFormValues, @@ -23,6 +21,11 @@ import type { BackgroundCheckBillingStatus, BackgroundCheckRecord } from './back import { OverviewStep } from './BackgroundCheckWizardParts'; import { CustomBackgroundCheckUpload } from './CustomBackgroundCheckUpload'; import { PaymentMethodUpdateDialog } from './PaymentMethodUpdateDialog'; +import { + getBackgroundChecksRemaining, + useBackgroundCheckBillingStatus, + useBackgroundCheckRecord, +} from './useEmployeeBackgroundCheckData'; interface EmployeeBackgroundCheckProps { employee: Member & { user: User }; @@ -52,27 +55,15 @@ export function EmployeeBackgroundCheck({ const [requestConfirmation, setRequestConfirmation] = useState(null); const { hasPermission } = usePermissions(); - const { data: backgroundCheck, mutate: mutateBackgroundCheck } = useSWR( - [`/v1/people/${employee.id}/background-check`, organizationId], - async ([endpoint]) => { - const response = await apiClient.get(endpoint, organizationId); - if (response.error) throw new Error('Failed to load background check'); - return response.data ?? null; - }, - { fallbackData: initialBackgroundCheck }, - ); - - const { data: billingStatus, mutate: mutateBillingStatus } = useSWR( - ['/v1/background-check-billing/status', organizationId], - async ([endpoint]) => { - const response = await apiClient.get(endpoint, organizationId); - if (response.error || !response.data) { - throw new Error('Failed to load billing status'); - } - return response.data; - }, - { fallbackData: initialBillingStatus }, - ); + const { data: backgroundCheck, mutate: mutateBackgroundCheck } = useBackgroundCheckRecord({ + employeeId: employee.id, + initialBackgroundCheck, + organizationId, + }); + const { data: billingStatus, mutate: mutateBillingStatus } = useBackgroundCheckBillingStatus({ + initialBillingStatus, + organizationId, + }); const form = useForm({ resolver: zodResolver(backgroundCheckSchema), @@ -94,17 +85,7 @@ export function EmployeeBackgroundCheck({ const canRequest = hasPermission('member', 'update'); const canManageBilling = hasPermission('organization', 'update'); const hasPaymentMethod = billingStatus?.hasPaymentMethod === true; - const backgroundCheckSubscription = (billingStatus?.subscriptions ?? []).find( - (subscription) => - getBillingSkuProductKey(subscription.skuKey) === 'background_check' && - (subscription.status === 'active' || subscription.status === 'trialing'), - ); - const backgroundChecksRemaining = backgroundCheckSubscription - ? Math.max( - backgroundCheckSubscription.includedQuantity - backgroundCheckSubscription.usedQuantity, - 0, - ) - : null; + const backgroundChecksRemaining = getBackgroundChecksRemaining({ billingStatus }); const hasBackgroundCheckAllowance = backgroundChecksRemaining !== null && backgroundChecksRemaining > 0; const visibleWizardStep = hasBackgroundCheckAllowance ? 'details' : wizardStep; @@ -137,6 +118,7 @@ export function EmployeeBackgroundCheck({ if (response.error || !response.data) { if (response.status === 402) { + writePendingRequest(values); toast.error('Choose or upgrade a background check plan to continue.'); router.push(`/${organizationId}/settings/billing/add-ons/background-checks`); return false; @@ -153,7 +135,14 @@ export function EmployeeBackgroundCheck({ clearPendingRequest(); return true; }, - [clearPendingRequest, employee.id, mutateBackgroundCheck, organizationId, router], + [ + clearPendingRequest, + employee.id, + mutateBackgroundCheck, + organizationId, + router, + writePendingRequest, + ], ); useEffect(() => { @@ -183,35 +172,31 @@ export function EmployeeBackgroundCheck({ { revalidate: true }, ); - const pendingRequest = readPendingBackgroundCheckRequest({ organizationId, memberId: employee.id }); + const pendingRequest = readPendingBackgroundCheckRequest({ + organizationId, + memberId: employee.id, + }); if (!pendingRequest) { router.replace(pathname, { scroll: false }); return; } form.reset({ - employeeName: form.getValues('employeeName') || employee.user.name || '', - employeeEmail: form.getValues('employeeEmail') || '', + employeeName: pendingRequest.employeeName, + employeeEmail: pendingRequest.employeeEmail, requesterNotes: pendingRequest.requesterNotes ?? '', }); router.replace(pathname, { scroll: false }); })(); - }, [ - form, - employee.id, - employee.user.name, - mutateBillingStatus, - organizationId, - pathname, - router, - searchParams, - ]); + }, [form, employee.id, mutateBillingStatus, organizationId, pathname, router, searchParams]); const handleOpenBilling = async (values?: BackgroundCheckFormValues) => { setIsOpeningBilling(true); if (values) writePendingRequest(values); - const returnPath = hasPaymentMethod ? `/${organizationId}/people/${employee.id}` : `/${organizationId}/settings/billing`; + const returnPath = hasPaymentMethod + ? `/${organizationId}/people/${employee.id}` + : `/${organizationId}/settings/billing`; const returnUrl = `${window.location.origin}${returnPath}`; const endpoint = hasPaymentMethod ? '/v1/background-check-billing/portal' @@ -298,9 +283,7 @@ export function EmployeeBackgroundCheck({ employeeName={employee.user.name ?? employee.user.email} organizationId={organizationId} onUploaded={async (uploadedBackgroundCheck) => { - await mutateBackgroundCheck(uploadedBackgroundCheck, { - revalidate: false, - }); + await mutateBackgroundCheck(uploadedBackgroundCheck, { revalidate: false }); }} /> diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckForm.ts b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckForm.ts index 6d2b543473..f4e8c3b225 100644 --- a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckForm.ts +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckForm.ts @@ -11,6 +11,8 @@ export type BackgroundCheckFormValues = z.infer; const pendingBackgroundCheckSchema = z.object({ memberId: z.string(), organizationId: z.string(), + employeeName: z.string(), + employeeEmail: z.string(), requesterNotes: z.string().optional(), }); @@ -66,6 +68,8 @@ export function writePendingBackgroundCheckRequest({ const pendingRequest: PendingBackgroundCheckRequest = { organizationId, memberId, + employeeName: values.employeeName, + employeeEmail: values.employeeEmail, requesterNotes: values.requesterNotes, }; window.sessionStorage.setItem( diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/useEmployeeBackgroundCheckData.ts b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/useEmployeeBackgroundCheckData.ts new file mode 100644 index 0000000000..35a5287c12 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/useEmployeeBackgroundCheckData.ts @@ -0,0 +1,60 @@ +'use client'; + +import { apiClient } from '@/lib/api-client'; +import { getBillingSkuProductKey } from '@trycompai/billing'; +import useSWR from 'swr'; +import type { BackgroundCheckBillingStatus, BackgroundCheckRecord } from './backgroundCheckTypes'; + +export function useBackgroundCheckRecord({ + employeeId, + initialBackgroundCheck, + organizationId, +}: { + employeeId: string; + initialBackgroundCheck: BackgroundCheckRecord | null; + organizationId: string; +}) { + return useSWR( + [`/v1/people/${employeeId}/background-check`, organizationId], + async ([endpoint]) => { + const response = await apiClient.get(endpoint, organizationId); + if (response.error) throw new Error('Failed to load background check'); + return response.data ?? null; + }, + { fallbackData: initialBackgroundCheck }, + ); +} + +export function useBackgroundCheckBillingStatus({ + initialBillingStatus, + organizationId, +}: { + initialBillingStatus: BackgroundCheckBillingStatus; + organizationId: string; +}) { + return useSWR( + ['/v1/background-check-billing/status', organizationId], + async ([endpoint]) => { + const response = await apiClient.get(endpoint, organizationId); + if (response.error || !response.data) { + throw new Error('Failed to load billing status'); + } + return response.data; + }, + { fallbackData: initialBillingStatus }, + ); +} + +export function getBackgroundChecksRemaining({ + billingStatus, +}: { + billingStatus: BackgroundCheckBillingStatus | undefined; +}): number | null { + const subscription = (billingStatus?.subscriptions ?? []).find( + (item) => + getBillingSkuProductKey(item.skuKey) === 'background_check' && + (item.status === 'active' || item.status === 'trialing'), + ); + if (!subscription) return null; + return Math.max(subscription.includedQuantity - subscription.usedQuantity, 0); +} diff --git a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.test.tsx b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.test.tsx new file mode 100644 index 0000000000..63168dfebe --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.test.tsx @@ -0,0 +1,38 @@ +import { render, screen, waitFor } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { CreateRunPanel } from './CreateRunPanel'; + +const navigationMock = vi.hoisted(() => ({ + push: vi.fn(), +})); + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ push: navigationMock.push }), +})); + +vi.mock('sonner', () => ({ + toast: { error: vi.fn() }, +})); + +describe('CreateRunPanel', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('routes users without allowance to billing even when required fields are empty', async () => { + const user = userEvent.setup(); + const onSubmit = vi.fn(async () => ({ id: 'run_1' })); + + render(); + + await user.click(screen.getByRole('button', { name: /choose plan/i })); + + await waitFor(() => { + expect(navigationMock.push).toHaveBeenCalledWith( + '/org_1/settings/billing/add-ons/penetration-tests', + ); + }); + expect(onSubmit).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.tsx b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.tsx index 1b7b733ef4..4f42e8d7d9 100644 --- a/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.tsx +++ b/apps/app/src/app/(app)/[orgId]/security/penetration-tests/_components/CreateRunPanel.tsx @@ -70,6 +70,7 @@ export function CreateRunPanel({
    void handleSubmit(e)} className="rounded-[var(--radius)] border border-border bg-card p-8 shadow-[0_24px_48px_-12px_rgba(0,0,0,0.12)]" > diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx index b7b0329414..26f4669496 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingAddOns.test.tsx @@ -199,4 +199,35 @@ describe('billing add-ons', () => { 'org_1', ); }); + + it('ignores inactive same-product subscriptions when selecting the current plan', () => { + renderAddOnPlans({ + addOnSlug: 'penetration-tests', + subscriptions: [ + { + skuKey: 'pentest_monthly_1', + status: 'canceled', + includedQuantity: 1, + usedQuantity: 1, + currentPeriodStart: '2026-03-30T00:00:00.000Z', + currentPeriodEnd: '2026-04-30T00:00:00.000Z', + cancelAtPeriodEnd: false, + }, + { + skuKey: 'pentest_monthly_3', + status: 'active', + includedQuantity: 3, + usedQuantity: 1, + currentPeriodStart: '2026-04-30T00:00:00.000Z', + currentPeriodEnd: '2026-05-30T00:00:00.000Z', + cancelAtPeriodEnd: false, + }, + ], + trialEligibility: { pentest: false, background_check: true }, + }); + + expect(screen.getByRole('button', { name: /current plan/i })).toBeDisabled(); + expect(screen.getByText(/2 of 3.*remaining this period/i)).toBeInTheDocument(); + expect(screen.getByRole('button', { name: /switch to monthly scans/i })).toBeInTheDocument(); + }); }); diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx index d5af90a8ed..e6f9511824 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/BillingSubscriptionPlans.tsx @@ -39,9 +39,11 @@ export function BillingSubscriptionPlans({ {skuKeys.map((skuKey) => { const sku = getBillingSku({ skuKey }); const subscription = subscriptions.find( - (item) => getBillingSkuProductKey(item.skuKey) === sku.productKey, + (item) => + getBillingSkuProductKey(item.skuKey) === sku.productKey && + (item.status === 'active' || item.status === 'trialing'), ); - const active = subscription?.status === 'active' || subscription?.status === 'trialing'; + const active = subscription !== undefined; const current = active && subscription?.skuKey === skuKey; const included = sku.includedUsage; const remaining = subscription diff --git a/apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts b/apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts index f27dc73dd2..635758166c 100644 --- a/apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts +++ b/apps/app/src/app/(app)/[orgId]/settings/billing/emptyBillingStatus.ts @@ -10,8 +10,8 @@ export const emptyBillingStatus: BackgroundCheckBillingStatus = { invoices: [], subscriptions: [], trialEligibility: { - pentest: true, - background_check: true, + pentest: false, + background_check: false, }, usageRows: [], preferences: null, diff --git a/packages/billing/tsconfig.json b/packages/billing/tsconfig.json index 4c11875686..b03f5fbbab 100644 --- a/packages/billing/tsconfig.json +++ b/packages/billing/tsconfig.json @@ -3,6 +3,7 @@ "compilerOptions": { "incremental": true, "allowImportingTsExtensions": true, + "noEmit": true, "tsBuildInfoFile": "node_modules/.cache/tsbuildinfo.json" }, "include": ["src"], From f655f37d86ee08c5beec82eb9cb4c01415236ec3 Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Fri, 1 May 2026 12:41:54 +0100 Subject: [PATCH 11/20] fix(billing): preserve legacy background check drafts --- .../EmployeeBackgroundCheck.test.tsx | 33 +++++++++++++++++++ .../components/EmployeeBackgroundCheck.tsx | 5 +-- .../components/backgroundCheckForm.ts | 4 +-- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx index 5884cba919..1032e7b43d 100644 --- a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.test.tsx @@ -229,4 +229,37 @@ describe('EmployeeBackgroundCheck', () => { }), ); }); + + it('keeps legacy pending check drafts that do not have employee details', async () => { + navigationMock.searchParams = new URLSearchParams({ + background_check_billing: 'success', + session_id: 'cs_test_legacy', + }); + window.sessionStorage.setItem( + 'background-check:org_1:mem_1:pending-request', + JSON.stringify({ + organizationId: 'org_1', + memberId: 'mem_1', + requesterNotes: 'Legacy note before billing.', + }), + ); + + renderSection(); + + await waitFor(() => { + expect(apiClient.post).toHaveBeenCalledWith( + '/v1/background-check-billing/setup-success', + { sessionId: 'cs_test_legacy' }, + 'org_1', + ); + }); + expect(screen.getByLabelText('Employee name')).toHaveValue('Ada Lovelace'); + expect(screen.getByLabelText('Personal email')).toHaveValue(''); + expect(screen.getByLabelText('Additional information')).toHaveValue( + 'Legacy note before billing.', + ); + expect( + window.sessionStorage.getItem('background-check:org_1:mem_1:pending-request'), + ).not.toBeNull(); + }); }); diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx index 5ee9ada405..6341617f8b 100644 --- a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/EmployeeBackgroundCheck.tsx @@ -182,8 +182,9 @@ export function EmployeeBackgroundCheck({ } form.reset({ - employeeName: pendingRequest.employeeName, - employeeEmail: pendingRequest.employeeEmail, + employeeName: + pendingRequest.employeeName ?? form.getValues('employeeName') ?? employee.user.name ?? '', + employeeEmail: pendingRequest.employeeEmail ?? form.getValues('employeeEmail') ?? '', requesterNotes: pendingRequest.requesterNotes ?? '', }); router.replace(pathname, { scroll: false }); diff --git a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckForm.ts b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckForm.ts index f4e8c3b225..fceed22d21 100644 --- a/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckForm.ts +++ b/apps/app/src/app/(app)/[orgId]/people/[employeeId]/components/backgroundCheckForm.ts @@ -11,8 +11,8 @@ export type BackgroundCheckFormValues = z.infer; const pendingBackgroundCheckSchema = z.object({ memberId: z.string(), organizationId: z.string(), - employeeName: z.string(), - employeeEmail: z.string(), + employeeName: z.string().optional(), + employeeEmail: z.string().optional(), requesterNotes: z.string().optional(), }); From be9027b8666106edd8821ea4d5f29a7c23360c86 Mon Sep 17 00:00:00 2001 From: Lewis Carhart Date: Fri, 1 May 2026 16:30:21 +0100 Subject: [PATCH 12/20] feat(billing): implement admin billing actions and controller - Added AdminBillingActionsService to handle subscription management, including cancellation, resumption, and credit granting. - Created AdminBillingController to expose billing endpoints for managing organization billing preferences and subscriptions. - Introduced DTOs for billing actions to validate incoming requests. - Implemented billing audit logging for actions performed by admins. - Enhanced billing data fetching and context management for organizations. --- .../admin-audit-log.interceptor.ts | 4 + .../admin-billing-actions.service.spec.ts | 66 ++++ .../admin-billing-actions.service.ts | 192 ++++++++++++ .../admin-billing.controller.ts | 208 +++++++++++++ .../admin-organizations/admin-billing.data.ts | 115 +++++++ .../admin-billing.helpers.ts | 93 ++++++ .../admin-billing.service.ts | 260 ++++++++++++++++ .../admin-billing.types.ts | 68 +++++ .../admin-organizations.module.ts | 8 + .../dto/admin-billing.dto.ts | 154 ++++++++++ .../billing/billing-credits.service.spec.ts | 132 ++++++++ .../src/billing/billing-credits.service.ts | 284 ++++++++++++++++++ apps/api/src/billing/billing-credits.types.ts | 63 ++++ .../billing/billing-entitlements.service.ts | 54 +++- .../src/billing/billing-subscription-plans.ts | 106 +++++-- apps/api/src/billing/billing.module.ts | 7 +- apps/api/src/billing/billing.service.spec.ts | 215 ++++++++++++- .../pentest-credits.service.ts | 59 ++++ .../components/AdminBillingForms.tsx | 228 ++++++++++++++ .../components/AdminBillingTab.test.tsx | 124 ++++++++ .../components/AdminBillingTab.tsx | 234 +++++++++++++++ .../components/AdminBillingTables.tsx | 173 +++++++++++ .../components/AdminBillingTypes.ts | 75 +++++ .../components/AdminOrgDangerDialogs.tsx | 114 +++++++ .../[adminOrgId]/components/AdminOrgTabs.tsx | 175 +++-------- .../billing/BillingAddOnPlansClient.tsx | 2 +- .../settings/billing/BillingAddOns.test.tsx | 90 +++++- .../billing/BillingAddOnsOverview.tsx | 32 +- .../billing/BillingPlanChangeDialog.tsx | 128 ++++++++ .../billing/BillingSubscriptionPlans.tsx | 44 ++- .../migration.sql | 117 ++++++++ .../20260501120159_billing/migration.sql | 2 + .../prisma/schema/organization-billing.prisma | 45 +++ packages/db/prisma/schema/organization.prisma | 2 + 34 files changed, 3476 insertions(+), 197 deletions(-) create mode 100644 apps/api/src/admin-organizations/admin-billing-actions.service.spec.ts create mode 100644 apps/api/src/admin-organizations/admin-billing-actions.service.ts create mode 100644 apps/api/src/admin-organizations/admin-billing.controller.ts create mode 100644 apps/api/src/admin-organizations/admin-billing.data.ts create mode 100644 apps/api/src/admin-organizations/admin-billing.helpers.ts create mode 100644 apps/api/src/admin-organizations/admin-billing.service.ts create mode 100644 apps/api/src/admin-organizations/admin-billing.types.ts create mode 100644 apps/api/src/admin-organizations/dto/admin-billing.dto.ts create mode 100644 apps/api/src/billing/billing-credits.service.spec.ts create mode 100644 apps/api/src/billing/billing-credits.service.ts create mode 100644 apps/api/src/billing/billing-credits.types.ts create mode 100644 apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminBillingForms.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminBillingTab.test.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminBillingTab.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminBillingTables.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminBillingTypes.ts create mode 100644 apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminOrgDangerDialogs.tsx create mode 100644 apps/app/src/app/(app)/[orgId]/settings/billing/BillingPlanChangeDialog.tsx create mode 100644 packages/db/prisma/migrations/20260501090000_generic_billing_credits/migration.sql create mode 100644 packages/db/prisma/migrations/20260501120159_billing/migration.sql diff --git a/apps/api/src/admin-organizations/admin-audit-log.interceptor.ts b/apps/api/src/admin-organizations/admin-audit-log.interceptor.ts index 121f1caf76..a35f6f6977 100644 --- a/apps/api/src/admin-organizations/admin-audit-log.interceptor.ts +++ b/apps/api/src/admin-organizations/admin-audit-log.interceptor.ts @@ -24,6 +24,10 @@ const SEGMENT_TO_RESOURCE: Record< entity: AuditLogEntityType.pentest, singular: 'pentest credits', }, + billing: { + entity: AuditLogEntityType.organization, + singular: 'billing', + }, }; const SPECIAL_ACTION_DESCRIPTIONS: Record = { diff --git a/apps/api/src/admin-organizations/admin-billing-actions.service.spec.ts b/apps/api/src/admin-organizations/admin-billing-actions.service.spec.ts new file mode 100644 index 0000000000..43a925e4d3 --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing-actions.service.spec.ts @@ -0,0 +1,66 @@ +import { NotFoundException } from '@nestjs/common'; +import { db } from '@db'; +import { AdminBillingActionsService } from './admin-billing-actions.service'; + +jest.mock('@db', () => ({ + db: { + organization: { findUnique: jest.fn() }, + organizationBilling: { findUnique: jest.fn() }, + organizationBillingSubscription: { + findMany: jest.fn(), + findFirst: jest.fn(), + updateMany: jest.fn(), + }, + billingAuditEvent: { create: jest.fn() }, + }, +})); + +const mockedDb = db as unknown as { + organization: { findUnique: jest.Mock }; + organizationBilling: { findUnique: jest.Mock }; + organizationBillingSubscription: { findMany: jest.Mock }; +}; + +describe('AdminBillingActionsService', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockedDb.organization.findUnique.mockResolvedValue({ + id: 'org_1', + name: 'Customer', + }); + mockedDb.organizationBilling.findUnique.mockResolvedValue({ + organizationId: 'org_1', + stripeCustomerId: 'cus_org_1', + }); + mockedDb.organizationBillingSubscription.findMany.mockResolvedValue([]); + }); + + it('rejects invoice recovery links for invoices owned by another customer', async () => { + const service = new AdminBillingActionsService( + { + getClient: () => ({ + invoices: { + retrieve: jest.fn().mockResolvedValue({ + id: 'in_other', + customer: 'cus_other', + status: 'open', + }), + }, + }), + isConfigured: () => true, + } as never, + {} as never, + {} as never, + {} as never, + ); + + await expect( + service.getInvoiceRetryLink({ + organizationId: 'org_1', + adminUserId: 'usr_admin', + invoiceId: 'in_other', + note: 'customer asked', + }), + ).rejects.toBeInstanceOf(NotFoundException); + }); +}); diff --git a/apps/api/src/admin-organizations/admin-billing-actions.service.ts b/apps/api/src/admin-organizations/admin-billing-actions.service.ts new file mode 100644 index 0000000000..f7d5e1c04c --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing-actions.service.ts @@ -0,0 +1,192 @@ +import { + BadRequestException, + Injectable, + NotFoundException, +} from '@nestjs/common'; +import { db } from '@db'; +import { BillingCreditsService } from '../billing/billing-credits.service'; +import { BillingService } from '../billing/billing.service'; +import { validateBillingRedirectUrl } from '../billing/billing-redirect-urls'; +import { assertStripeBillingConfigured } from '../billing/billing-stripe-config'; +import { StripeService } from '../stripe/stripe.service'; +import { getInvoiceCustomerId } from './admin-billing.helpers'; +import { getOrgBillingContext, writeBillingAudit } from './admin-billing.data'; +import { AdminBillingService } from './admin-billing.service'; +import type { AdminBillingStatus } from './admin-billing.types'; + +@Injectable() +export class AdminBillingActionsService { + constructor( + private readonly stripeService: StripeService, + private readonly billingService: BillingService, + private readonly credits: BillingCreditsService, + private readonly adminBilling: AdminBillingService, + ) {} + + async cancelSubscription(params: { + organizationId: string; + adminUserId: string; + subscriptionId: string; + mode: 'period_end' | 'immediate'; + note: string; + confirm?: string; + }): Promise { + assertStripeBillingConfigured(this.stripeService); + const subscription = await this.getScopedSubscription(params); + if (params.mode === 'immediate' && params.confirm !== 'cancel now') { + throw new BadRequestException('Type "cancel now" to confirm.'); + } + const stripe = this.stripeService.getClient(); + const updated = + params.mode === 'immediate' + ? await stripe.subscriptions.cancel(subscription.stripeSubscriptionId, { + cancellation_details: { comment: params.note }, + }) + : await stripe.subscriptions.update(subscription.stripeSubscriptionId, { + cancel_at_period_end: true, + cancellation_details: { comment: params.note }, + }); + await db.organizationBillingSubscription.updateMany({ + where: { id: subscription.id, organizationId: params.organizationId }, + data: { + stripeStatus: updated.status, + cancelAtPeriodEnd: updated.cancel_at_period_end, + canceledAt: updated.canceled_at + ? new Date(updated.canceled_at * 1000) + : null, + }, + }); + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_subscription_canceled', + skuKey: subscription.skuKey, + metadata: { + adminUserId: params.adminUserId, + subscriptionId: subscription.id, + mode: params.mode, + note: params.note, + }, + }); + return this.adminBilling.getStatus(params.organizationId); + } + + async resumeSubscription(params: { + organizationId: string; + adminUserId: string; + subscriptionId: string; + note: string; + }): Promise { + assertStripeBillingConfigured(this.stripeService); + const subscription = await this.getScopedSubscription(params); + const updated = await this.stripeService + .getClient() + .subscriptions.update(subscription.stripeSubscriptionId, { + cancel_at_period_end: false, + }); + await db.organizationBillingSubscription.updateMany({ + where: { id: subscription.id, organizationId: params.organizationId }, + data: { + stripeStatus: updated.status, + cancelAtPeriodEnd: updated.cancel_at_period_end, + canceledAt: null, + }, + }); + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_subscription_resumed', + skuKey: subscription.skuKey, + metadata: { + adminUserId: params.adminUserId, + subscriptionId: subscription.id, + note: params.note, + }, + }); + return this.adminBilling.getStatus(params.organizationId); + } + + async createPaymentLink(params: { + organizationId: string; + successUrl: string; + cancelUrl: string; + }) { + validateBillingRedirectUrl(params.successUrl); + validateBillingRedirectUrl(params.cancelUrl); + return this.billingService.createSetupSession(params); + } + + async grantCredits(params: { + organizationId: string; + adminUserId: string; + productKey: 'pentest' | 'background_check'; + quantity: number; + note: string; + confirm?: string; + }): Promise { + if (params.quantity >= 25 && params.confirm !== 'grant credits') { + throw new BadRequestException('Type "grant credits" to confirm.'); + } + await this.credits.grant({ + organizationId: params.organizationId, + productKey: params.productKey, + quantity: params.quantity, + source: 'manual', + note: params.note, + adminUserId: params.adminUserId, + }); + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_credits_granted', + metadata: { + adminUserId: params.adminUserId, + productKey: params.productKey, + quantity: params.quantity, + note: params.note, + }, + }); + return this.adminBilling.getStatus(params.organizationId); + } + + async getInvoiceRetryLink(params: { + organizationId: string; + adminUserId: string; + invoiceId: string; + note: string; + }) { + const { billing } = await getOrgBillingContext(params.organizationId); + if (!billing) throw new NotFoundException('Billing customer not found.'); + const invoice = await this.stripeService + .getClient() + .invoices.retrieve(params.invoiceId); + if (getInvoiceCustomerId(invoice) !== billing.stripeCustomerId) { + throw new NotFoundException('Invoice not found.'); + } + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_invoice_retry_link_created', + metadata: { + adminUserId: params.adminUserId, + invoiceId: params.invoiceId, + note: params.note, + }, + }); + return { + hostedInvoiceUrl: invoice.hosted_invoice_url ?? null, + invoicePdfUrl: invoice.invoice_pdf ?? null, + status: invoice.status ?? 'unknown', + }; + } + + private async getScopedSubscription(params: { + organizationId: string; + subscriptionId: string; + }) { + const subscription = await db.organizationBillingSubscription.findFirst({ + where: { + id: params.subscriptionId, + organizationId: params.organizationId, + }, + }); + if (!subscription) throw new NotFoundException('Subscription not found.'); + return subscription; + } +} diff --git a/apps/api/src/admin-organizations/admin-billing.controller.ts b/apps/api/src/admin-organizations/admin-billing.controller.ts new file mode 100644 index 0000000000..8af9805632 --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.controller.ts @@ -0,0 +1,208 @@ +import { + Body, + Controller, + Get, + Param, + Post, + Put, + Req, + UseGuards, + UseInterceptors, + UsePipes, + ValidationPipe, +} from '@nestjs/common'; +import { ApiExcludeController, ApiOperation, ApiTags } from '@nestjs/swagger'; +import { Throttle } from '@nestjs/throttler'; +import { + subscriptionBillingSkuKeys, + type BillingSkuKey, +} from '@trycompai/billing'; +import { PlatformAdminGuard } from '../auth/platform-admin.guard'; +import { AdminAuditLogInterceptor } from './admin-audit-log.interceptor'; +import { AdminBillingActionsService } from './admin-billing-actions.service'; +import { AdminBillingService } from './admin-billing.service'; +import { + AdminBillingCancelSubscriptionDto, + AdminBillingGrantCreditsDto, + AdminBillingInvoiceActionDto, + AdminBillingNoteDto, + AdminBillingPaymentLinkDto, + AdminBillingPreferencesDto, + AdminBillingSubscriptionPreviewDto, + AdminBillingSubscriptionDto, +} from './dto/admin-billing.dto'; + +interface AdminRequest { + userId: string; +} + +@ApiExcludeController() +@ApiTags('Admin - Billing') +@Controller({ path: 'admin/organizations/:orgId/billing', version: '1' }) +@UseGuards(PlatformAdminGuard) +@UseInterceptors(AdminAuditLogInterceptor) +@UsePipes( + new ValidationPipe({ + whitelist: true, + forbidNonWhitelisted: true, + transform: true, + }), +) +@Throttle({ default: { ttl: 60_000, limit: 30 } }) +export class AdminBillingController { + constructor( + private readonly billing: AdminBillingService, + private readonly actions: AdminBillingActionsService, + ) {} + + @Get() + @ApiOperation({ summary: 'Get admin billing status for an organization' }) + async getStatus(@Param('orgId') orgId: string) { + return this.billing.getStatus(orgId); + } + + @Put('preferences') + @ApiOperation({ summary: 'Update billing preferences for an organization' }) + async updatePreferences( + @Param('orgId') orgId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingPreferencesDto, + ) { + return this.billing.updatePreferences({ + organizationId: orgId, + adminUserId: req.userId, + note: body.note, + confirmBillingEmailChange: body.confirmBillingEmailChange, + preferences: { + companyName: body.companyName, + billingEmail: body.billingEmail, + purchaseOrder: body.purchaseOrder ?? null, + address: { + line1: body.addressLine1 ?? null, + line2: body.addressLine2 ?? null, + city: body.addressCity ?? null, + state: body.addressState ?? null, + postalCode: body.addressPostalCode ?? null, + country: body.addressCountry ?? null, + }, + taxId: { + type: body.taxIdType ?? null, + value: body.taxIdValue ?? null, + }, + }, + }); + } + + @Post('subscriptions/preview') + @ApiOperation({ summary: 'Preview a subscription plan change' }) + async previewSubscription( + @Param('orgId') orgId: string, + @Body() body: AdminBillingSubscriptionPreviewDto, + ) { + return this.billing.previewSubscription({ + organizationId: orgId, + skuKey: assertSubscriptionSku(body.skuKey), + }); + } + + @Post('subscriptions') + @ApiOperation({ summary: 'Create or change a subscription plan' }) + async setSubscription( + @Param('orgId') orgId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingSubscriptionDto, + ) { + return this.billing.setSubscription({ + organizationId: orgId, + adminUserId: req.userId, + skuKey: assertSubscriptionSku(body.skuKey), + returnUrl: body.returnUrl, + note: body.note, + confirmDowngrade: body.confirmDowngrade, + }); + } + + @Post('subscriptions/:subscriptionId/cancel') + @Throttle({ default: { ttl: 60_000, limit: 5 } }) + async cancelSubscription( + @Param('orgId') orgId: string, + @Param('subscriptionId') subscriptionId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingCancelSubscriptionDto, + ) { + return this.actions.cancelSubscription({ + organizationId: orgId, + adminUserId: req.userId, + subscriptionId, + mode: body.mode, + note: body.note, + confirm: body.confirm, + }); + } + + @Post('subscriptions/:subscriptionId/resume') + async resumeSubscription( + @Param('orgId') orgId: string, + @Param('subscriptionId') subscriptionId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingNoteDto, + ) { + return this.actions.resumeSubscription({ + organizationId: orgId, + adminUserId: req.userId, + subscriptionId, + note: body.note, + }); + } + + @Post('payment-link') + async createPaymentLink( + @Param('orgId') orgId: string, + @Body() body: AdminBillingPaymentLinkDto, + ) { + return this.actions.createPaymentLink({ + organizationId: orgId, + successUrl: body.successUrl, + cancelUrl: body.cancelUrl, + }); + } + + @Post('credits') + @Throttle({ default: { ttl: 60_000, limit: 5 } }) + async grantCredits( + @Param('orgId') orgId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingGrantCreditsDto, + ) { + return this.actions.grantCredits({ + organizationId: orgId, + adminUserId: req.userId, + productKey: body.productKey, + quantity: body.quantity, + note: body.note, + confirm: body.confirm, + }); + } + + @Post('invoices/:invoiceId/retry-link') + async getInvoiceRetryLink( + @Param('orgId') orgId: string, + @Param('invoiceId') invoiceId: string, + @Req() req: AdminRequest, + @Body() body: AdminBillingInvoiceActionDto, + ) { + return this.actions.getInvoiceRetryLink({ + organizationId: orgId, + adminUserId: req.userId, + invoiceId, + note: body.note, + }); + } +} + +function assertSubscriptionSku(value: string): BillingSkuKey { + if (subscriptionBillingSkuKeys.some((skuKey) => skuKey === value)) { + return value as BillingSkuKey; + } + throw new Error('Invalid subscription SKU.'); +} diff --git a/apps/api/src/admin-organizations/admin-billing.data.ts b/apps/api/src/admin-organizations/admin-billing.data.ts new file mode 100644 index 0000000000..56d562e19a --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.data.ts @@ -0,0 +1,115 @@ +import { NotFoundException } from '@nestjs/common'; +import { db, Prisma } from '@db'; +import type { BillingSkuKey } from '@trycompai/billing'; +import type Stripe from 'stripe'; +import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; +import { StripeService } from '../stripe/stripe.service'; +import { dateFromSeconds, readNumber } from './admin-billing.helpers'; + +export async function getOrgBillingContext(organizationId: string) { + const [organization, billing, subscriptions] = await Promise.all([ + db.organization.findUnique({ + where: { id: organizationId }, + select: { id: true, name: true }, + }), + db.organizationBilling.findUnique({ where: { organizationId } }), + db.organizationBillingSubscription.findMany({ + where: { organizationId }, + orderBy: [{ createdAt: 'desc' }], + }), + ]); + if (!organization) { + throw new NotFoundException(`Organization ${organizationId} not found`); + } + return { organization, billing, subscriptions }; +} + +export async function writeBillingAudit(params: { + organizationId: string; + eventType: string; + skuKey?: string | null; + stripeEventId?: string | null; + metadata?: Prisma.InputJsonValue; +}) { + await db.billingAuditEvent.create({ + data: { + organizationId: params.organizationId, + eventType: params.eventType, + skuKey: params.skuKey, + stripeEventId: params.stripeEventId, + metadata: params.metadata, + }, + }); +} + +export async function createAdminSubscription(params: { + organizationId: string; + stripeCustomerId: string; + skuKey: BillingSkuKey; + stripePriceId: string; + includedQuantity: number; + stripeService: StripeService; + entitlements: BillingEntitlementsService; +}) { + const subscription = await params.stripeService + .getClient() + .subscriptions.create( + { + customer: params.stripeCustomerId, + items: [{ price: params.stripePriceId, quantity: 1 }], + metadata: { + organizationId: params.organizationId, + skuKey: params.skuKey, + source: 'admin-billing', + }, + expand: ['items.data.price'], + }, + { + idempotencyKey: [ + 'admin-subscription-create', + params.organizationId, + params.skuKey, + ].join(':'), + }, + ); + const item = subscription.items.data[0]; + if (!item) throw new NotFoundException('Stripe subscription item not found.'); + await syncStripeSubscriptionItem({ + organizationId: params.organizationId, + skuKey: params.skuKey, + stripePriceId: params.stripePriceId, + includedQuantity: params.includedQuantity, + subscription, + item, + entitlements: params.entitlements, + }); +} + +export async function syncStripeSubscriptionItem(params: { + organizationId: string; + skuKey: BillingSkuKey; + stripePriceId: string; + includedQuantity: number; + subscription: Stripe.Subscription; + item: Stripe.SubscriptionItem; + entitlements: BillingEntitlementsService; +}) { + await params.entitlements.syncSubscriptionItem({ + organizationId: params.organizationId, + skuKey: params.skuKey, + stripeSubscriptionId: params.subscription.id, + stripeSubscriptionItemId: params.item.id, + stripePriceId: params.stripePriceId, + stripeStatus: params.subscription.status, + currentPeriodStart: dateFromSeconds( + readNumber(params.item, 'current_period_start'), + ), + currentPeriodEnd: dateFromSeconds( + readNumber(params.item, 'current_period_end'), + ), + includedQuantity: params.includedQuantity, + cancelAtPeriodEnd: params.subscription.cancel_at_period_end, + canceledAt: dateFromSeconds(params.subscription.canceled_at), + stripeEventId: undefined, + }); +} diff --git a/apps/api/src/admin-organizations/admin-billing.helpers.ts b/apps/api/src/admin-organizations/admin-billing.helpers.ts new file mode 100644 index 0000000000..7e7e7a333a --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.helpers.ts @@ -0,0 +1,93 @@ +import { + getBillingCatalog, + getBillingSku, + getBillingSkuProductKey, + type BillingProductKey, + type BillingSkuKey, +} from '@trycompai/billing'; +import type Stripe from 'stripe'; +import type { + AdminBillingPlan, + AdminBillingSubscription, +} from './admin-billing.types'; + +export function listAdminBillingPlans(): AdminBillingPlan[] { + return Object.values(getBillingCatalog().skus) + .filter((sku) => sku.cadence === 'month' && !sku.deprecated) + .map((sku) => ({ + skuKey: sku.key, + productKey: sku.productKey, + name: sku.name, + unitAmount: sku.unitAmount, + currency: sku.currency, + includedQuantity: sku.includedUsage?.quantity ?? 0, + })); +} + +export function mapAdminSubscription(subscription: { + id: string; + skuKey: string; + stripeSubscriptionId: string; + stripeSubscriptionItemId: string; + stripeStatus: string; + includedQuantity: number; + usedQuantity: number; + currentPeriodStart: Date | null; + currentPeriodEnd: Date | null; + cancelAtPeriodEnd: boolean; + canceledAt: Date | null; +}): AdminBillingSubscription { + return { + id: subscription.id, + skuKey: subscription.skuKey, + productKey: getBillingSkuProductKey(subscription.skuKey), + stripeSubscriptionId: subscription.stripeSubscriptionId, + stripeSubscriptionItemId: subscription.stripeSubscriptionItemId, + stripeStatus: subscription.stripeStatus, + includedQuantity: subscription.includedQuantity, + usedQuantity: subscription.usedQuantity, + remainingQuantity: Math.max( + subscription.includedQuantity - subscription.usedQuantity, + 0, + ), + currentPeriodStart: subscription.currentPeriodStart?.toISOString() ?? null, + currentPeriodEnd: subscription.currentPeriodEnd?.toISOString() ?? null, + cancelAtPeriodEnd: subscription.cancelAtPeriodEnd, + canceledAt: subscription.canceledAt?.toISOString() ?? null, + }; +} + +export function getProductFromSku(skuKey: BillingSkuKey): BillingProductKey { + return getBillingSku({ skuKey }).productKey; +} + +export function isDowngrade(params: { + currentIncludedQuantity: number; + nextSkuKey: BillingSkuKey; +}) { + const nextSku = getBillingSku({ skuKey: params.nextSkuKey }); + return ( + (nextSku.includedUsage?.quantity ?? 0) < params.currentIncludedQuantity + ); +} + +export function dateFromSeconds(value: number | null | undefined): Date | null { + return typeof value === 'number' ? new Date(value * 1000) : null; +} + +export function readNumber(value: unknown, key: string): number | null { + if (typeof value !== 'object' || value === null) return null; + const raw = (value as Record)[key]; + return typeof raw === 'number' ? raw : null; +} + +export function extractStripeId(value: unknown): string | null { + if (typeof value === 'string') return value; + if (typeof value !== 'object' || value === null) return null; + const raw = (value as Record).id; + return typeof raw === 'string' ? raw : null; +} + +export function getInvoiceCustomerId(invoice: Stripe.Invoice): string | null { + return extractStripeId(invoice.customer); +} diff --git a/apps/api/src/admin-organizations/admin-billing.service.ts b/apps/api/src/admin-organizations/admin-billing.service.ts new file mode 100644 index 0000000000..f8b5ac8a3b --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.service.ts @@ -0,0 +1,260 @@ +import { + BadRequestException, + Injectable, + NotFoundException, +} from '@nestjs/common'; +import { db } from '@db'; +import { + getBillingSku, + resolveBillingCatalogEnvironment, + type BillingSkuKey, +} from '@trycompai/billing'; +import { BillingCreditsService } from '../billing/billing-credits.service'; +import { BillingEntitlementsService } from '../billing/billing-entitlements.service'; +import { listBillingInvoices } from '../billing/billing-invoices'; +import { + getBillingPreferences, + type BillingPreferencesInput, + updateBillingPreferences, +} from '../billing/billing-preferences'; +import { validateBillingRedirectUrl } from '../billing/billing-redirect-urls'; +import { assertStripeBillingConfigured } from '../billing/billing-stripe-config'; +import { changeSubscriptionPlan } from '../billing/billing-subscription-plans'; +import { BillingService } from '../billing/billing.service'; +import { StripeService } from '../stripe/stripe.service'; +import { + getProductFromSku, + isDowngrade, + listAdminBillingPlans, + mapAdminSubscription, +} from './admin-billing.helpers'; +import { + createAdminSubscription, + getOrgBillingContext, + writeBillingAudit, +} from './admin-billing.data'; +import type { + AdminBillingPreview, + AdminBillingStatus, +} from './admin-billing.types'; + +@Injectable() +export class AdminBillingService { + constructor( + private readonly stripeService: StripeService, + private readonly billingService: BillingService, + private readonly credits: BillingCreditsService, + private readonly entitlements: BillingEntitlementsService, + ) {} + + async getStatus(organizationId: string): Promise { + const { organization, billing, subscriptions } = + await getOrgBillingContext(organizationId); + const [ + preferences, + invoices, + usageRows, + creditBalances, + creditEvents, + auditEvents, + ] = await Promise.all([ + getBillingPreferences({ + stripeService: this.stripeService, + stripeCustomerId: billing?.stripeCustomerId ?? null, + fallbackCompanyName: organization.name, + }), + listBillingInvoices({ + stripeService: this.stripeService, + stripeCustomerId: billing?.stripeCustomerId ?? null, + }), + this.billingService + .getStatus(organizationId) + .then((status) => status.usageRows), + this.credits.listBalances(organizationId), + this.credits.listEvents({ organizationId, take: 50 }), + db.billingAuditEvent.findMany({ + where: { organizationId }, + orderBy: { createdAt: 'desc' }, + take: 50, + }), + ]); + + return { + organization, + stripeCustomerId: billing?.stripeCustomerId ?? null, + hasPaymentMethod: !!billing?.stripePaymentMethodId, + paymentMethodUpdatedAt: + billing?.paymentMethodUpdatedAt?.toISOString() ?? null, + preferences, + availablePlans: listAdminBillingPlans(), + subscriptions: subscriptions.map(mapAdminSubscription), + creditBalances, + creditEvents, + usageRows, + invoices, + failedInvoices: invoices.filter((invoice) => + ['open', 'past_due', 'uncollectible'].includes(invoice.status), + ), + auditEvents: auditEvents.map((event) => ({ + id: event.id, + eventType: event.eventType, + skuKey: event.skuKey, + stripeEventId: event.stripeEventId, + metadata: event.metadata, + createdAt: event.createdAt.toISOString(), + })), + }; + } + + async updatePreferences(params: { + organizationId: string; + adminUserId: string; + preferences: BillingPreferencesInput; + note: string; + confirmBillingEmailChange?: boolean; + }): Promise { + assertStripeBillingConfigured(this.stripeService); + const status = await this.getStatus(params.organizationId); + const currentEmail = status.preferences.billingEmail; + if ( + currentEmail && + currentEmail !== params.preferences.billingEmail && + !params.confirmBillingEmailChange + ) { + throw new BadRequestException('Confirm billing email change.'); + } + + const result = await updateBillingPreferences({ + stripeService: this.stripeService, + organizationId: params.organizationId, + preferences: params.preferences, + }); + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_billing_preferences_updated', + metadata: { + adminUserId: params.adminUserId, + stripeCustomerId: result.stripeCustomerId, + billingEmail: result.preferences.billingEmail, + note: params.note, + }, + }); + return this.getStatus(params.organizationId); + } + + async previewSubscription(params: { + organizationId: string; + skuKey: BillingSkuKey; + }): Promise { + const { billing, subscriptions } = await getOrgBillingContext( + params.organizationId, + ); + if (!billing) throw new NotFoundException('Billing customer not found.'); + assertStripeBillingConfigured(this.stripeService); + const sku = getBillingSku({ skuKey: params.skuKey }); + const current = subscriptions.find( + (item) => + item.stripeStatus !== 'canceled' && + getProductFromSku(item.skuKey as BillingSkuKey) === sku.productKey, + ); + const prorationDate = Math.floor(Date.now() / 1000); + const invoice = await this.stripeService + .getClient() + .invoices.createPreview({ + customer: billing.stripeCustomerId, + ...(current ? { subscription: current.stripeSubscriptionId } : {}), + subscription_details: { + proration_date: current ? prorationDate : undefined, + items: [ + current + ? { + id: current.stripeSubscriptionItemId, + price: sku.stripePriceId, + quantity: 1, + } + : { price: sku.stripePriceId, quantity: 1 }, + ], + }, + }); + return { + amountDue: invoice.amount_due, + currency: invoice.currency, + subscriptionId: current?.id ?? null, + prorationDate, + }; + } + + async setSubscription(params: { + organizationId: string; + adminUserId: string; + skuKey: BillingSkuKey; + returnUrl: string; + note: string; + confirmDowngrade?: boolean; + }): Promise { + validateBillingRedirectUrl(params.returnUrl); + assertStripeBillingConfigured(this.stripeService); + const { billing, subscriptions } = await getOrgBillingContext( + params.organizationId, + ); + const sku = getBillingSku({ + environment: resolveBillingCatalogEnvironment(), + skuKey: params.skuKey, + }); + const current = subscriptions.find( + (item) => + item.stripeStatus !== 'canceled' && + getProductFromSku(item.skuKey as BillingSkuKey) === sku.productKey, + ); + if ( + current && + isDowngrade({ + currentIncludedQuantity: current.includedQuantity, + nextSkuKey: params.skuKey, + }) && + !params.confirmDowngrade + ) { + throw new BadRequestException('Confirm plan downgrade.'); + } + if (!billing?.stripePaymentMethodId) { + const result = + await this.billingService.createSubscriptionCheckoutSession({ + organizationId: params.organizationId, + skuKey: params.skuKey, + successUrl: params.returnUrl, + cancelUrl: params.returnUrl, + }); + return 'changed' in result + ? this.getStatus(params.organizationId) + : result; + } + if (current) { + await changeSubscriptionPlan({ + organizationId: params.organizationId, + subscription: current, + skuKey: sku.key, + stripePriceId: sku.stripePriceId, + includedQuantity: sku.includedUsage?.quantity ?? 0, + stripeService: this.stripeService, + entitlements: this.entitlements, + }); + } else { + await createAdminSubscription({ + organizationId: params.organizationId, + stripeCustomerId: billing.stripeCustomerId, + skuKey: sku.key, + stripePriceId: sku.stripePriceId, + includedQuantity: sku.includedUsage?.quantity ?? 0, + stripeService: this.stripeService, + entitlements: this.entitlements, + }); + } + await writeBillingAudit({ + organizationId: params.organizationId, + eventType: 'admin_subscription_set', + skuKey: sku.key, + metadata: { adminUserId: params.adminUserId, note: params.note }, + }); + return this.getStatus(params.organizationId); + } +} diff --git a/apps/api/src/admin-organizations/admin-billing.types.ts b/apps/api/src/admin-organizations/admin-billing.types.ts new file mode 100644 index 0000000000..b7cf8a97d6 --- /dev/null +++ b/apps/api/src/admin-organizations/admin-billing.types.ts @@ -0,0 +1,68 @@ +import type { BillingProductKey } from '@trycompai/billing'; +import type { + BillingCreditBalanceSummary, + BillingCreditEventSummary, +} from '../billing/billing-credits.types'; +import type { BillingInvoice } from '../billing/billing-invoices'; +import type { BillingPreferences } from '../billing/billing-preferences'; +import type { BillingUsageRow } from '../billing/billing.types'; + +export interface AdminBillingPlan { + skuKey: string; + productKey: BillingProductKey; + name: string; + unitAmount: number; + currency: string; + includedQuantity: number; +} + +export interface AdminBillingSubscription { + id: string; + skuKey: string; + productKey: BillingProductKey | null; + stripeSubscriptionId: string; + stripeSubscriptionItemId: string; + stripeStatus: string; + includedQuantity: number; + usedQuantity: number; + remainingQuantity: number; + currentPeriodStart: string | null; + currentPeriodEnd: string | null; + cancelAtPeriodEnd: boolean; + canceledAt: string | null; +} + +export interface AdminBillingAuditEvent { + id: string; + eventType: string; + skuKey: string | null; + stripeEventId: string | null; + metadata: unknown; + createdAt: string; +} + +export interface AdminBillingStatus { + organization: { + id: string; + name: string; + }; + stripeCustomerId: string | null; + hasPaymentMethod: boolean; + paymentMethodUpdatedAt: string | null; + preferences: BillingPreferences; + availablePlans: AdminBillingPlan[]; + subscriptions: AdminBillingSubscription[]; + creditBalances: BillingCreditBalanceSummary[]; + creditEvents: BillingCreditEventSummary[]; + usageRows: BillingUsageRow[]; + invoices: BillingInvoice[]; + failedInvoices: BillingInvoice[]; + auditEvents: AdminBillingAuditEvent[]; +} + +export interface AdminBillingPreview { + amountDue: number; + currency: string; + subscriptionId: string | null; + prorationDate: number; +} diff --git a/apps/api/src/admin-organizations/admin-organizations.module.ts b/apps/api/src/admin-organizations/admin-organizations.module.ts index 938dcc6668..4afc166c4c 100644 --- a/apps/api/src/admin-organizations/admin-organizations.module.ts +++ b/apps/api/src/admin-organizations/admin-organizations.module.ts @@ -7,7 +7,11 @@ import { EvidenceFormsModule } from '../evidence-forms/evidence-forms.module'; import { PoliciesModule } from '../policies/policies.module'; import { CommentsModule } from '../comments/comments.module'; import { AttachmentsModule } from '../attachments/attachments.module'; +import { BillingModule } from '../billing/billing.module'; import { SecurityPenetrationTestsModule } from '../security-penetration-tests/security-penetration-tests.module'; +import { AdminBillingActionsService } from './admin-billing-actions.service'; +import { AdminBillingController } from './admin-billing.controller'; +import { AdminBillingService } from './admin-billing.service'; import { AdminOrganizationsController } from './admin-organizations.controller'; import { AdminOrganizationsService } from './admin-organizations.service'; import { PurgeOrganizationService } from './purge-organization.service'; @@ -31,6 +35,7 @@ import { AdminPentestCreditsController } from './admin-pentest-credits.controlle PoliciesModule, CommentsModule, AttachmentsModule, + BillingModule, SecurityPenetrationTestsModule, ], controllers: [ @@ -42,9 +47,12 @@ import { AdminPentestCreditsController } from './admin-pentest-credits.controlle AdminContextController, AdminEvidenceController, AdminPentestCreditsController, + AdminBillingController, ], providers: [ AdminOrganizationsService, + AdminBillingService, + AdminBillingActionsService, PurgeOrganizationService, PurgeOrganizationSnapshotService, PurgeOrganizationExternalService, diff --git a/apps/api/src/admin-organizations/dto/admin-billing.dto.ts b/apps/api/src/admin-organizations/dto/admin-billing.dto.ts new file mode 100644 index 0000000000..32c798ee29 --- /dev/null +++ b/apps/api/src/admin-organizations/dto/admin-billing.dto.ts @@ -0,0 +1,154 @@ +import { subscriptionBillingSkuKeys } from '@trycompai/billing'; +import { + IsBoolean, + IsEmail, + IsIn, + IsInt, + IsNotEmpty, + IsOptional, + IsString, + IsUrl, + Length, + Max, + Min, +} from 'class-validator'; +import { billingTaxIdTypes } from '../../billing/billing-preferences'; + +export class AdminBillingPreferencesDto { + @IsString() + @Length(1, 150) + companyName: string; + + @IsEmail() + billingEmail: string; + + @IsOptional() + @IsString() + @Length(0, 140) + purchaseOrder: string | null; + + @IsOptional() + @IsString() + @Length(0, 200) + addressLine1: string | null; + + @IsOptional() + @IsString() + @Length(0, 200) + addressLine2: string | null; + + @IsOptional() + @IsString() + @Length(0, 100) + addressCity: string | null; + + @IsOptional() + @IsString() + @Length(0, 100) + addressState: string | null; + + @IsOptional() + @IsString() + @Length(0, 32) + addressPostalCode: string | null; + + @IsOptional() + @IsString() + @Length(0, 2) + addressCountry: string | null; + + @IsOptional() + @IsString() + @IsIn([...billingTaxIdTypes, '']) + taxIdType: string | null; + + @IsOptional() + @IsString() + @Length(0, 64) + taxIdValue: string | null; + + @IsOptional() + @IsBoolean() + confirmBillingEmailChange?: boolean; + + @IsString() + @Length(3, 500) + note: string; +} + +export class AdminBillingSubscriptionDto { + @IsString() + @IsIn(subscriptionBillingSkuKeys) + skuKey: string; + + @IsString() + @IsUrl({ require_tld: false }) + returnUrl: string; + + @IsString() + @Length(3, 500) + note: string; + + @IsOptional() + @IsBoolean() + confirmDowngrade?: boolean; +} + +export class AdminBillingSubscriptionPreviewDto { + @IsString() + @IsIn(subscriptionBillingSkuKeys) + skuKey: string; +} + +export class AdminBillingCancelSubscriptionDto { + @IsIn(['period_end', 'immediate']) + mode: 'period_end' | 'immediate'; + + @IsString() + @Length(3, 500) + note: string; + + @IsOptional() + @IsString() + confirm?: string; +} + +export class AdminBillingNoteDto { + @IsString() + @Length(3, 500) + note: string; +} + +export class AdminBillingPaymentLinkDto { + @IsString() + @IsUrl({ require_tld: false }) + successUrl: string; + + @IsString() + @IsUrl({ require_tld: false }) + cancelUrl: string; +} + +export class AdminBillingGrantCreditsDto { + @IsIn(['pentest', 'background_check']) + productKey: 'pentest' | 'background_check'; + + @IsInt() + @Min(1) + @Max(1000) + quantity: number; + + @IsString() + @Length(3, 500) + note: string; + + @IsOptional() + @IsString() + confirm?: string; +} + +export class AdminBillingInvoiceActionDto { + @IsString() + @IsNotEmpty() + note: string; +} diff --git a/apps/api/src/billing/billing-credits.service.spec.ts b/apps/api/src/billing/billing-credits.service.spec.ts new file mode 100644 index 0000000000..f14d604c92 --- /dev/null +++ b/apps/api/src/billing/billing-credits.service.spec.ts @@ -0,0 +1,132 @@ +import { db } from '@db'; +import { BillingCreditsService } from './billing-credits.service'; + +jest.mock('@db', () => ({ + db: { + billingCreditBalance: { + findMany: jest.fn(), + findFirst: jest.fn(), + findFirstOrThrow: jest.fn(), + findUniqueOrThrow: jest.fn(), + create: jest.fn(), + }, + billingCreditEvent: { + findMany: jest.fn(), + findFirst: jest.fn(), + }, + $transaction: jest.fn(), + }, +})); + +type MockTx = { + billingCreditEvent: { create: jest.Mock }; + billingCreditBalance: { update: jest.Mock; updateMany: jest.Mock }; +}; + +const mockedDb = db as unknown as { + billingCreditBalance: { + findMany: jest.Mock; + findFirst: jest.Mock; + findFirstOrThrow: jest.Mock; + findUniqueOrThrow: jest.Mock; + create: jest.Mock; + }; + billingCreditEvent: { findMany: jest.Mock; findFirst: jest.Mock }; + $transaction: jest.Mock; +}; + +describe('BillingCreditsService', () => { + let service: BillingCreditsService; + let tx: MockTx; + + beforeEach(() => { + jest.clearAllMocks(); + tx = { + billingCreditEvent: { create: jest.fn() }, + billingCreditBalance: { + update: jest.fn(), + updateMany: jest.fn().mockResolvedValue({ count: 1 }), + }, + }; + mockedDb.$transaction.mockImplementation( + (callback: (tx: MockTx) => Promise) => callback(tx), + ); + service = new BillingCreditsService(); + }); + + it('grants credits to an org-scoped product balance', async () => { + mockedDb.billingCreditBalance.findFirst.mockResolvedValue({ + id: 'bcb_1', + organizationId: 'org_1', + productKey: 'pentest', + skuKey: null, + }); + mockedDb.billingCreditBalance.findUniqueOrThrow.mockResolvedValue({ + id: 'bcb_1', + productKey: 'pentest', + skuKey: null, + balance: 3, + totalGranted: 3, + totalConsumed: 0, + totalRefunded: 0, + lastSource: 'manual', + updatedAt: new Date('2026-05-01T00:00:00.000Z'), + }); + + await expect( + service.grant({ + organizationId: 'org_1', + productKey: 'pentest', + quantity: 3, + source: 'manual', + note: 'CS goodwill', + adminUserId: 'usr_1', + }), + ).resolves.toMatchObject({ balance: 3, productKey: 'pentest' }); + + expect(tx.billingCreditEvent.create).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ + organizationId: 'org_1', + productKey: 'pentest', + quantity: 3, + adminUserId: 'usr_1', + }), + }), + ); + expect(tx.billingCreditBalance.update).toHaveBeenCalledWith( + expect.objectContaining({ + where: { id: 'bcb_1' }, + data: expect.objectContaining({ balance: { increment: 3 } }), + }), + ); + }); + + it('consumes one available manual credit atomically', async () => { + mockedDb.billingCreditBalance.findMany.mockResolvedValue([ + { + id: 'bcb_1', + organizationId: 'org_1', + productKey: 'background_check', + skuKey: null, + balance: 1, + }, + ]); + + await expect( + service.tryConsumeForProduct({ + organizationId: 'org_1', + productKey: 'background_check', + sourceResourceId: 'mem_1', + }), + ).resolves.toEqual({ status: 'consumed' }); + + expect(tx.billingCreditBalance.updateMany).toHaveBeenCalledWith({ + where: { id: 'bcb_1', balance: { gt: 0 } }, + data: { + balance: { decrement: 1 }, + totalConsumed: { increment: 1 }, + }, + }); + }); +}); diff --git a/apps/api/src/billing/billing-credits.service.ts b/apps/api/src/billing/billing-credits.service.ts new file mode 100644 index 0000000000..a224a2207c --- /dev/null +++ b/apps/api/src/billing/billing-credits.service.ts @@ -0,0 +1,284 @@ +import { BadRequestException, Injectable } from '@nestjs/common'; +import { db } from '@db'; +import type { BillingProductKey, BillingSkuKey } from '@trycompai/billing'; +import { isUniqueConstraintError } from './billing-entitlements.types'; +import { + assertProductKey, + type BillingCreditBalanceSummary, + type BillingCreditEventSummary, + validateCreditInput, +} from './billing-credits.types'; + +@Injectable() +export class BillingCreditsService { + async listBalances( + organizationId: string, + ): Promise { + const balances = await db.billingCreditBalance.findMany({ + where: { organizationId }, + orderBy: [{ productKey: 'asc' }, { createdAt: 'asc' }], + }); + return balances.map((balance) => ({ + id: balance.id, + productKey: assertProductKey(balance.productKey), + skuKey: balance.skuKey, + balance: balance.balance, + totalGranted: balance.totalGranted, + totalConsumed: balance.totalConsumed, + totalRefunded: balance.totalRefunded, + lastSource: balance.lastSource, + updatedAt: balance.updatedAt.toISOString(), + })); + } + + async listEvents(params: { + organizationId: string; + take?: number; + }): Promise { + const events = await db.billingCreditEvent.findMany({ + where: { organizationId: params.organizationId }, + orderBy: { createdAt: 'desc' }, + take: params.take ?? 50, + }); + return events.map((event) => ({ + id: event.id, + productKey: assertProductKey(event.productKey), + skuKey: event.skuKey, + eventType: event.eventType, + quantity: event.quantity, + source: event.source, + note: event.note, + adminUserId: event.adminUserId, + sourceResourceId: event.sourceResourceId, + createdAt: event.createdAt.toISOString(), + })); + } + + async grant(params: { + organizationId: string; + productKey: BillingProductKey; + skuKey?: BillingSkuKey | null; + quantity: number; + source: string; + note: string; + adminUserId?: string | null; + idempotencyKey?: string; + }): Promise { + validateCreditInput(params); + const balance = await this.findOrCreateBalance({ + organizationId: params.organizationId, + productKey: params.productKey, + skuKey: params.skuKey ?? null, + }); + const idempotencyKey = + params.idempotencyKey ?? + [ + 'grant', + params.organizationId, + params.productKey, + params.skuKey ?? 'product', + params.quantity, + params.note.trim().toLowerCase(), + ].join(':'); + + try { + await db.$transaction(async (tx) => { + await tx.billingCreditEvent.create({ + data: { + organizationId: params.organizationId, + balanceId: balance.id, + productKey: params.productKey, + skuKey: params.skuKey ?? null, + eventType: 'grant', + quantity: params.quantity, + source: params.source, + note: params.note.trim(), + adminUserId: params.adminUserId ?? null, + idempotencyKey, + }, + }); + await tx.billingCreditBalance.update({ + where: { id: balance.id }, + data: { + balance: { increment: params.quantity }, + totalGranted: { increment: params.quantity }, + lastSource: params.source, + }, + }); + }); + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + } + + return this.getBalance(balance.id); + } + + async tryConsumeForProduct(params: { + organizationId: string; + productKey: BillingProductKey; + sourceResourceId: string; + }): Promise<{ status: 'consumed' | 'not_configured' | 'exhausted' }> { + const existing = await db.billingCreditBalance.findMany({ + where: { + organizationId: params.organizationId, + productKey: params.productKey, + }, + orderBy: { createdAt: 'asc' }, + }); + if (existing.length === 0) return { status: 'not_configured' }; + + const balance = existing.find((item) => item.balance > 0); + if (!balance) return { status: 'exhausted' }; + + const idempotencyKey = [ + 'credit-consume', + params.organizationId, + params.productKey, + params.sourceResourceId, + ].join(':'); + + try { + await db.$transaction(async (tx) => { + await tx.billingCreditEvent.create({ + data: { + organizationId: params.organizationId, + balanceId: balance.id, + productKey: params.productKey, + skuKey: balance.skuKey, + eventType: 'consume', + quantity: 1, + source: 'manual_credit', + sourceResourceId: params.sourceResourceId, + idempotencyKey, + }, + }); + const updated = await tx.billingCreditBalance.updateMany({ + where: { id: balance.id, balance: { gt: 0 } }, + data: { + balance: { decrement: 1 }, + totalConsumed: { increment: 1 }, + }, + }); + if (updated.count === 0) { + throw new BadRequestException('Credit balance is exhausted.'); + } + }); + } catch (error) { + if (isUniqueConstraintError(error)) return { status: 'consumed' }; + throw error; + } + + return { status: 'consumed' }; + } + + async refundForProduct(params: { + organizationId: string; + productKey: BillingProductKey; + sourceResourceId: string; + reason: string; + }): Promise { + const consumed = await db.billingCreditEvent.findFirst({ + where: { + organizationId: params.organizationId, + productKey: params.productKey, + eventType: 'consume', + sourceResourceId: params.sourceResourceId, + }, + orderBy: { createdAt: 'desc' }, + }); + if (!consumed) return false; + + const idempotencyKey = [ + 'credit-refund', + params.organizationId, + params.productKey, + params.sourceResourceId, + params.reason, + ].join(':'); + + try { + await db.$transaction(async (tx) => { + await tx.billingCreditEvent.create({ + data: { + organizationId: params.organizationId, + balanceId: consumed.balanceId, + productKey: params.productKey, + skuKey: consumed.skuKey, + eventType: 'refund', + quantity: consumed.quantity, + source: 'refund', + note: params.reason, + sourceResourceId: params.sourceResourceId, + linkedEventId: consumed.id, + idempotencyKey, + }, + }); + await tx.billingCreditBalance.update({ + where: { id: consumed.balanceId }, + data: { + balance: { increment: consumed.quantity }, + totalRefunded: { increment: consumed.quantity }, + totalConsumed: { decrement: consumed.quantity }, + lastSource: 'refund', + }, + }); + }); + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + } + + return true; + } + + private async findOrCreateBalance(params: { + organizationId: string; + productKey: BillingProductKey; + skuKey: BillingSkuKey | null; + }) { + const existing = await db.billingCreditBalance.findFirst({ + where: { + organizationId: params.organizationId, + productKey: params.productKey, + skuKey: params.skuKey, + }, + }); + if (existing) return existing; + + try { + return await db.billingCreditBalance.create({ + data: { + organizationId: params.organizationId, + productKey: params.productKey, + skuKey: params.skuKey, + lastSource: 'manual', + }, + }); + } catch (error) { + if (!isUniqueConstraintError(error)) throw error; + return db.billingCreditBalance.findFirstOrThrow({ + where: { + organizationId: params.organizationId, + productKey: params.productKey, + skuKey: params.skuKey, + }, + }); + } + } + + private async getBalance(id: string): Promise { + const balance = await db.billingCreditBalance.findUniqueOrThrow({ + where: { id }, + }); + return { + id: balance.id, + productKey: assertProductKey(balance.productKey), + skuKey: balance.skuKey, + balance: balance.balance, + totalGranted: balance.totalGranted, + totalConsumed: balance.totalConsumed, + totalRefunded: balance.totalRefunded, + lastSource: balance.lastSource, + updatedAt: balance.updatedAt.toISOString(), + }; + } +} diff --git a/apps/api/src/billing/billing-credits.types.ts b/apps/api/src/billing/billing-credits.types.ts new file mode 100644 index 0000000000..6df403775b --- /dev/null +++ b/apps/api/src/billing/billing-credits.types.ts @@ -0,0 +1,63 @@ +import { BadRequestException } from '@nestjs/common'; +import { + getBillingSkuProductKey, + type BillingProductKey, + type BillingSkuKey, +} from '@trycompai/billing'; + +export type BillingCreditEventType = + | 'grant' + | 'consume' + | 'refund' + | 'adjustment' + | 'migration'; + +export interface BillingCreditBalanceSummary { + id: string; + productKey: BillingProductKey; + skuKey: string | null; + balance: number; + totalGranted: number; + totalConsumed: number; + totalRefunded: number; + lastSource: string; + updatedAt: string; +} + +export interface BillingCreditEventSummary { + id: string; + productKey: BillingProductKey; + skuKey: string | null; + eventType: string; + quantity: number; + source: string; + note: string | null; + adminUserId: string | null; + sourceResourceId: string | null; + createdAt: string; +} + +export function validateCreditInput(params: { + productKey: BillingProductKey; + skuKey?: BillingSkuKey | null; + quantity: number; + note: string; +}) { + if (!Number.isInteger(params.quantity) || params.quantity < 1) { + throw new BadRequestException('Credit amount must be a positive integer.'); + } + if (!params.note.trim()) { + throw new BadRequestException('A note is required for credit grants.'); + } + if (params.skuKey) { + const productKey = getBillingSkuProductKey(params.skuKey); + if (productKey !== params.productKey) { + throw new BadRequestException('SKU does not belong to product.'); + } + } +} + +export function assertProductKey(value: string): BillingProductKey { + if (value === 'pentest' || value === 'background_check') return value; + throw new BadRequestException('Unsupported billing product.'); +} diff --git a/apps/api/src/billing/billing-entitlements.service.ts b/apps/api/src/billing/billing-entitlements.service.ts index dd4612c1a1..95aa86445c 100644 --- a/apps/api/src/billing/billing-entitlements.service.ts +++ b/apps/api/src/billing/billing-entitlements.service.ts @@ -5,6 +5,7 @@ import { type BillingProductKey, type BillingSkuKey, } from '@trycompai/billing'; +import { BillingCreditsService } from './billing-credits.service'; import { refundIncludedUsageEvent } from './billing-included-usage-refunds'; import { type BillingConsumeResult, @@ -17,6 +18,8 @@ import { @Injectable() export class BillingEntitlementsService { + constructor(private readonly credits?: BillingCreditsService) {} + async tryConsumeIncludedUsageForProduct(params: { organizationId: string; productKey: BillingProductKey; @@ -38,11 +41,22 @@ export class BillingEntitlementsService { subscription.currentPeriodEnd.getTime() > Date.now()), ); if (!activeSubscription) { - return { status: 'not_configured' }; + return this.tryConsumeCreditFallback({ + ...params, + fallbackStatus: 'not_configured', + }); } - if (activeSubscription.usedQuantity >= activeSubscription.includedQuantity) { - return { status: 'exhausted', subscriptionId: activeSubscription.id }; + if ( + activeSubscription.usedQuantity >= activeSubscription.includedQuantity + ) { + const creditResult = await this.tryConsumeCreditFallback({ + ...params, + fallbackStatus: 'exhausted', + }); + return creditResult.status === 'consumed' + ? creditResult + : { status: 'exhausted', subscriptionId: activeSubscription.id }; } return this.tryConsumeIncludedUsage({ @@ -291,7 +305,16 @@ export class BillingEntitlementsService { orderBy: { createdAt: 'desc' }, select: { skuKey: true }, }); - if (!consumed) return; + if (!consumed) { + if (params.tx || !this.credits) return; + await this.credits.refundForProduct({ + organizationId: params.organizationId, + productKey: params.productKey, + sourceResourceId: params.sourceResourceId, + reason: params.reason, + }); + return; + } await refundIncludedUsageEvent({ organizationId: params.organizationId, skuKey: consumed.skuKey as BillingSkuKey, @@ -301,6 +324,29 @@ export class BillingEntitlementsService { }); } + private async tryConsumeCreditFallback(params: { + organizationId: string; + productKey: BillingProductKey; + sourceResourceId: string; + fallbackStatus: 'not_configured' | 'exhausted'; + }): Promise { + if (!this.credits) { + return params.fallbackStatus === 'exhausted' + ? { status: 'exhausted', subscriptionId: 'manual_credit' } + : { status: 'not_configured' }; + } + const creditResult = await this.credits.tryConsumeForProduct({ + organizationId: params.organizationId, + productKey: params.productKey, + sourceResourceId: params.sourceResourceId, + }); + return creditResult.status === 'consumed' + ? { status: 'consumed', subscriptionId: 'manual_credit' } + : params.fallbackStatus === 'exhausted' + ? { status: 'exhausted', subscriptionId: 'manual_credit' } + : { status: 'not_configured' }; + } + async writeAuditEvent(params: WriteBillingAuditEventParams): Promise { await db.billingAuditEvent.create({ data: { diff --git a/apps/api/src/billing/billing-subscription-plans.ts b/apps/api/src/billing/billing-subscription-plans.ts index 61b128b6d9..2cf87884eb 100644 --- a/apps/api/src/billing/billing-subscription-plans.ts +++ b/apps/api/src/billing/billing-subscription-plans.ts @@ -1,6 +1,10 @@ +import { HttpException, HttpStatus } from '@nestjs/common'; import { db } from '@db'; import { + billingCatalogs, + getBillingSku, type BillingProductKey, + type BillingSku, type BillingSkuKey, getBillingSkuProductKey, } from '@trycompai/billing'; @@ -41,6 +45,7 @@ export async function changeSubscriptionPlan(params: { subscription: { id: string; skuKey: string; + stripeStatus: string; stripeSubscriptionId: string; stripeSubscriptionItemId: string; currentPeriodStart: Date | null; @@ -53,30 +58,62 @@ export async function changeSubscriptionPlan(params: { entitlements: BillingEntitlementsService; }): Promise<{ changed: true }> { const stripe = params.stripeService.getClient(); - const updatedSubscription = await stripe.subscriptions.update( - params.subscription.stripeSubscriptionId, - { - items: [ - { - id: params.subscription.stripeSubscriptionItemId, - price: params.stripePriceId, - }, - ], - metadata: { - organizationId: params.organizationId, - skuKey: params.skuKey, - source: 'comp-billing-subscription', - }, - }, - { - idempotencyKey: [ - 'subscription-plan-change', - params.organizationId, - params.subscription.stripeSubscriptionItemId, - params.skuKey, - ].join(':'), + const isUpgrade = isPlanUpgrade({ + currentSkuKey: params.subscription.skuKey, + nextSkuKey: params.skuKey, + }); + const shouldEndTrial = isUpgrade && params.subscription.stripeStatus === 'trialing'; + const updateParams = { + items: [ + isUpgrade + ? { + id: params.subscription.stripeSubscriptionItemId, + price: params.stripePriceId, + quantity: 1, + } + : { + id: params.subscription.stripeSubscriptionItemId, + price: params.stripePriceId, + }, + ], + metadata: { + organizationId: params.organizationId, + skuKey: params.skuKey, + source: 'comp-billing-subscription', }, - ); + ...(isUpgrade + ? { + proration_behavior: 'always_invoice' as const, + payment_behavior: 'error_if_incomplete' as const, + } + : {}), + ...(shouldEndTrial ? { trial_end: 'now' as const } : {}), + }; + let updatedSubscription: Awaited< + ReturnType + >; + try { + updatedSubscription = await stripe.subscriptions.update( + params.subscription.stripeSubscriptionId, + updateParams, + { + idempotencyKey: [ + 'subscription-plan-change-v2', + params.organizationId, + params.subscription.stripeSubscriptionItemId, + params.skuKey, + ].join(':'), + }, + ); + } catch (error) { + if (isUpgrade && isPaymentRequiredStripeError(error)) { + throw new HttpException( + 'We could not charge the prorated upgrade amount. Please update your payment method and try again.', + HttpStatus.PAYMENT_REQUIRED, + ); + } + throw error; + } const updatedItem = updatedSubscription.items.data.find( @@ -131,3 +168,26 @@ function readNumber(value: unknown, key: string): number | null { const raw = (value as Record)[key]; return typeof raw === 'number' ? raw : null; } + +function isPlanUpgrade(params: { + currentSkuKey: string; + nextSkuKey: BillingSkuKey; +}): boolean { + const currentSku = findBillingSku(params.currentSkuKey); + const nextSku = getBillingSku({ skuKey: params.nextSkuKey }); + return currentSku !== null && nextSku.unitAmount > currentSku.unitAmount; +} + +function findBillingSku(skuKey: string): BillingSku | null { + for (const catalog of Object.values(billingCatalogs)) { + const sku = Object.values(catalog.skus).find((item) => item.key === skuKey); + if (sku) return sku; + } + return null; +} + +function isPaymentRequiredStripeError(error: unknown): boolean { + if (typeof error !== 'object' || error === null) return false; + const record = error as Record; + return record.statusCode === HttpStatus.PAYMENT_REQUIRED; +} diff --git a/apps/api/src/billing/billing.module.ts b/apps/api/src/billing/billing.module.ts index c2c78f3e4a..9e22d1e96f 100644 --- a/apps/api/src/billing/billing.module.ts +++ b/apps/api/src/billing/billing.module.ts @@ -1,18 +1,21 @@ import { Module } from '@nestjs/common'; import { AuthModule } from '../auth/auth.module'; +import { StripeModule } from '../stripe/stripe.module'; import { BillingController } from './billing.controller'; +import { BillingCreditsService } from './billing-credits.service'; import { BillingEntitlementsService } from './billing-entitlements.service'; import { BillingService } from './billing.service'; import { BillingWebhookService } from './billing-webhook.service'; @Module({ - imports: [AuthModule], + imports: [AuthModule, StripeModule], controllers: [BillingController], providers: [ BillingService, + BillingCreditsService, BillingEntitlementsService, BillingWebhookService, ], - exports: [BillingService, BillingEntitlementsService], + exports: [BillingService, BillingCreditsService, BillingEntitlementsService], }) export class BillingModule {} diff --git a/apps/api/src/billing/billing.service.spec.ts b/apps/api/src/billing/billing.service.spec.ts index 7b0c6d3ccf..649a205a0e 100644 --- a/apps/api/src/billing/billing.service.spec.ts +++ b/apps/api/src/billing/billing.service.spec.ts @@ -1,4 +1,4 @@ -import { BadRequestException } from '@nestjs/common'; +import { BadRequestException, HttpException, HttpStatus } from '@nestjs/common'; import { db } from '@db'; import { BillingService } from './billing.service'; import type { StripeService } from '../stripe/stripe.service'; @@ -235,7 +235,7 @@ describe('BillingService', () => { }); }); - it('changes an existing product subscription instead of creating another checkout', async () => { + it('charges immediately when upgrading an existing product subscription', async () => { const subscriptionsUpdate = jest.fn().mockResolvedValue({ id: 'sub_1', status: 'active', @@ -255,7 +255,7 @@ describe('BillingService', () => { organizationBillingSubscriptionFindMany.mockResolvedValue([ { id: 'obs_1', - skuKey: 'pentest_monthly_1', + skuKey: 'pentest_monthly_3', stripeSubscriptionId: 'sub_1', stripeSubscriptionItemId: 'si_1', stripeStatus: 'active', @@ -274,7 +274,7 @@ describe('BillingService', () => { await expect( service.createSubscriptionCheckoutSession({ organizationId: 'org_1', - skuKey: 'pentest_monthly_3', + skuKey: 'pentest_monthly_5_current', successUrl: 'http://localhost:3000/org_1/settings/billing/success', cancelUrl: 'http://localhost:3000/org_1/settings/billing', }), @@ -283,7 +283,15 @@ describe('BillingService', () => { expect(subscriptionsUpdate).toHaveBeenCalledWith( 'sub_1', expect.objectContaining({ - items: [{ id: 'si_1', price: 'price_1TS3ziCkFWhKYvHI1nbXC7UU' }], + items: [ + { + id: 'si_1', + price: 'price_1TS3zjCkFWhKYvHISBHjtZXB', + quantity: 1, + }, + ], + proration_behavior: 'always_invoice', + payment_behavior: 'error_if_incomplete', }), expect.anything(), ); @@ -291,9 +299,9 @@ describe('BillingService', () => { expect.objectContaining({ where: { id: 'obs_1' }, data: expect.objectContaining({ - skuKey: 'pentest_monthly_3', + skuKey: 'pentest_monthly_5_current', stripeSubscriptionItemId: 'si_1', - includedQuantity: 3, + includedQuantity: 5, }), }), ); @@ -301,8 +309,201 @@ describe('BillingService', () => { expect.objectContaining({ organizationId: 'org_1', eventType: 'subscription_plan_changed', + skuKey: 'pentest_monthly_5_current', + }), + ); + }); + + it('ends the trial immediately when upgrading from a trial plan', async () => { + const subscriptionsUpdate = jest.fn().mockResolvedValue({ + id: 'sub_1', + status: 'active', + cancel_at_period_end: false, + canceled_at: null, + items: { + data: [ + { + id: 'si_1', + current_period_start: 1775001600, + current_period_end: 1777593600, + }, + ], + }, + }); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'background_checks_monthly_3', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'trialing', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + subscriptions: { update: subscriptionsUpdate }, + }), + { syncSubscriptionItem: jest.fn(), writeAuditEvent: jest.fn() } as never, + ); + + await service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'background_checks_monthly_20', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }); + + expect(subscriptionsUpdate).toHaveBeenCalledWith( + 'sub_1', + expect.objectContaining({ + items: [ + { + id: 'si_1', + price: 'price_1TS3zjCkFWhKYvHIU5jMCCWs', + quantity: 1, + }, + ], + proration_behavior: 'always_invoice', + payment_behavior: 'error_if_incomplete', + trial_end: 'now', + }), + expect.objectContaining({ + idempotencyKey: + 'subscription-plan-change-v2:org_1:si_1:background_checks_monthly_20', + }), + ); + }); + + it('does not grant upgraded credits when immediate upgrade payment fails', async () => { + const subscriptionsUpdate = jest.fn().mockRejectedValue( + Object.assign(new Error('Card declined'), { + statusCode: HttpStatus.PAYMENT_REQUIRED, + }), + ); + const writeAuditEvent = jest.fn().mockResolvedValue(undefined); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', skuKey: 'pentest_monthly_3', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'active', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + subscriptions: { update: subscriptionsUpdate }, }), + { syncSubscriptionItem: jest.fn(), writeAuditEvent } as never, + ); + + try { + await service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_5_current', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }); + throw new Error('Expected upgrade payment failure'); + } catch (error) { + expect(error).toBeInstanceOf(HttpException); + if (error instanceof HttpException) { + expect(error.getStatus()).toBe(HttpStatus.PAYMENT_REQUIRED); + } + } + expect(organizationBillingSubscriptionUpdate).not.toHaveBeenCalled(); + expect(writeAuditEvent).not.toHaveBeenCalled(); + }); + + it('keeps same-plan subscription changes rejected before calling Stripe', async () => { + const subscriptionsUpdate = jest.fn(); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_3', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'active', + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + subscriptions: { update: subscriptionsUpdate }, + }), + { syncSubscriptionItem: jest.fn(), writeAuditEvent: jest.fn() } as never, + ); + + await expect( + service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_3', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }), + ).rejects.toBeInstanceOf(BadRequestException); + expect(subscriptionsUpdate).not.toHaveBeenCalled(); + }); + + it('does not immediately invoice lower-priced plan switches', async () => { + const subscriptionsUpdate = jest.fn().mockResolvedValue({ + id: 'sub_1', + status: 'active', + cancel_at_period_end: false, + canceled_at: null, + items: { + data: [ + { + id: 'si_1', + current_period_start: 1775001600, + current_period_end: 1777593600, + }, + ], + }, + }); + organizationBillingSubscriptionFindMany.mockResolvedValue([ + { + id: 'obs_1', + skuKey: 'pentest_monthly_5_current', + stripeSubscriptionId: 'sub_1', + stripeSubscriptionItemId: 'si_1', + stripeStatus: 'active', + currentPeriodStart: new Date('2026-04-01T00:00:00.000Z'), + currentPeriodEnd: new Date('2026-05-01T00:00:00.000Z'), + createdAt: new Date('2026-04-01T00:00:00.000Z'), + }, + ]); + const service = new BillingService( + mockStripeService({ + subscriptions: { update: subscriptionsUpdate }, + }), + { syncSubscriptionItem: jest.fn(), writeAuditEvent: jest.fn() } as never, + ); + + await service.createSubscriptionCheckoutSession({ + organizationId: 'org_1', + skuKey: 'pentest_monthly_3', + successUrl: 'http://localhost:3000/org_1/settings/billing/success', + cancelUrl: 'http://localhost:3000/org_1/settings/billing', + }); + + expect(subscriptionsUpdate).toHaveBeenCalledWith( + 'sub_1', + { + items: [{ id: 'si_1', price: 'price_1TS3ziCkFWhKYvHI1nbXC7UU' }], + metadata: { + organizationId: 'org_1', + skuKey: 'pentest_monthly_3', + source: 'comp-billing-subscription', + }, + }, + expect.anything(), ); }); diff --git a/apps/api/src/security-penetration-tests/pentest-credits.service.ts b/apps/api/src/security-penetration-tests/pentest-credits.service.ts index a2952d82f5..0f310c08e3 100644 --- a/apps/api/src/security-penetration-tests/pentest-credits.service.ts +++ b/apps/api/src/security-penetration-tests/pentest-credits.service.ts @@ -1,5 +1,6 @@ import { HttpException, HttpStatus, Injectable, Logger } from '@nestjs/common'; import { AuditLogEntityType, db, Prisma } from '@db'; +import { BillingCreditsService } from '../billing/billing-credits.service'; /** * Source of credits — free-form so v2 can add new sources without a schema @@ -40,6 +41,7 @@ export type PentestAuditAction = @Injectable() export class PentestCreditsService { private readonly logger = new Logger(PentestCreditsService.name); + constructor(private readonly billingCredits?: BillingCreditsService) {} /** * Default trial grant for new orgs. Static today; in v2 this can become a @@ -49,6 +51,24 @@ export class PentestCreditsService { private readonly initialTrialAmount = 1; async getStatus(organizationId: string): Promise { + if (this.billingCredits) { + const balances = await this.billingCredits.listBalances(organizationId); + const balance = balances.find((item) => item.productKey === 'pentest'); + return balance + ? { + balance: balance.balance, + totalGranted: balance.totalGranted, + totalConsumed: balance.totalConsumed, + lastGrantSource: balance.lastSource, + } + : { + balance: 0, + totalGranted: 0, + totalConsumed: 0, + lastGrantSource: 'none', + }; + } + const row = await db.pentestCredits.findUnique({ where: { organizationId }, }); @@ -94,6 +114,16 @@ export class PentestCreditsService { `grant amount must be positive (got ${amount} for org=${organizationId})`, ); } + if (this.billingCredits) { + await this.billingCredits.grant({ + organizationId, + productKey: 'pentest', + quantity: amount, + source, + note: `Pentest credits granted from ${source}`, + }); + return; + } await db.pentestCredits.upsert({ where: { organizationId }, create: { @@ -135,6 +165,25 @@ export class PentestCreditsService { runId?: string, tx?: Pick, ): Promise { + if (this.billingCredits && !tx) { + const result = await this.billingCredits.tryConsumeForProduct({ + organizationId, + productKey: 'pentest', + sourceResourceId: runId ?? 'pending', + }); + if (result.status !== 'consumed') { + throw new HttpException( + { + error: + 'No pentest runs remaining. Choose a penetration test plan to continue.', + code: 'pentest_credits_exhausted', + }, + HttpStatus.PAYMENT_REQUIRED, + ); + } + return this.getStatus(organizationId); + } + const client = tx ?? db; const updated = await client.pentestCredits.updateMany({ where: { organizationId, balance: { gt: 0 } }, @@ -202,6 +251,16 @@ export class PentestCreditsService { reason: string, tx?: Pick, ): Promise { + if (this.billingCredits && !tx) { + await this.billingCredits.refundForProduct({ + organizationId, + productKey: 'pentest', + sourceResourceId: runId ?? 'pending', + reason, + }); + return; + } + // Use the optional tx client when provided so the caller can wrap // claim+refund in a single transaction (webhook idempotency: if the // refund DB write fails, the claim rolls back and a redelivered diff --git a/apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminBillingForms.tsx b/apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminBillingForms.tsx new file mode 100644 index 0000000000..83a0e0e33b --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/admin/organizations/[adminOrgId]/components/AdminBillingForms.tsx @@ -0,0 +1,228 @@ +'use client'; + +import { zodResolver } from '@hookform/resolvers/zod'; +import { + Button, + Input, + Select, + SelectContent, + SelectItem, + SelectTrigger, + Textarea, +} from '@trycompai/design-system'; +import { Controller, useForm } from 'react-hook-form'; +import { z } from 'zod'; +import type { AdminBillingStatus } from './AdminBillingTypes'; + +const creditSchema = z.object({ + productKey: z.enum(['pentest', 'background_check']), + quantity: z.number().int().min(1).max(1000), + note: z.string().min(3).max(500), + confirm: z.string().optional(), +}); + +export type CreditFormValues = z.infer; + +export function CreditGrantForm({ + onSubmit, + loading, +}: { + onSubmit: (values: CreditFormValues) => Promise; + loading: boolean; +}) { + const form = useForm({ + resolver: zodResolver(creditSchema), + defaultValues: { + productKey: 'pentest', + quantity: 1, + note: '', + confirm: '', + }, + }); + + const handleSubmit = form.handleSubmit(async (values) => { + await onSubmit(values); + form.reset({ + productKey: values.productKey, + quantity: 1, + note: '', + confirm: '', + }); + }); + + return ( + + ( + + )} + /> + + + +
    + +
    + + ); +} + +const preferenceSchema = z.object({ + companyName: z.string().min(1).max(150), + billingEmail: z.string().email(), + purchaseOrder: z.string().max(140).optional(), + addressLine1: z.string().max(200).optional(), + addressLine2: z.string().max(200).optional(), + addressCity: z.string().max(100).optional(), + addressState: z.string().max(100).optional(), + addressPostalCode: z.string().max(32).optional(), + addressCountry: z.string().max(2).optional(), + taxIdType: z.string().optional(), + taxIdValue: z.string().max(64).optional(), + confirmBillingEmailChange: z.boolean().optional(), + note: z.string().min(3).max(500), +}); + +export type PreferenceFormValues = z.infer; + +export function BillingPreferencesAdminForm({ + status, + onSubmit, + loading, +}: { + status: AdminBillingStatus; + onSubmit: (values: PreferenceFormValues) => Promise; + loading: boolean; +}) { + const form = useForm({ + resolver: zodResolver(preferenceSchema), + defaultValues: { + companyName: status.preferences.companyName ?? '', + billingEmail: status.preferences.billingEmail ?? '', + purchaseOrder: status.preferences.purchaseOrder ?? '', + addressLine1: status.preferences.address.line1 ?? '', + addressLine2: status.preferences.address.line2 ?? '', + addressCity: status.preferences.address.city ?? '', + addressState: status.preferences.address.state ?? '', + addressPostalCode: status.preferences.address.postalCode ?? '', + addressCountry: status.preferences.address.country ?? '', + taxIdType: status.preferences.taxId?.type ?? '', + taxIdValue: status.preferences.taxId?.value ?? '', + confirmBillingEmailChange: false, + note: '', + }, + }); + + return ( +
    + + + + + + + + + + + + +
    +