diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index 27d77432a..8c1af4303 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -40,7 +40,9 @@ import { ReferenceNode, SelectionNode, SelectQueryNode, + sql, TableNode, + UnaryOperationNode, ValueListNode, ValueNode, WhereNode, @@ -253,13 +255,20 @@ export class ExpressionTransformer { if (ValueListNode.is(right)) { return BinaryOperationNode.create(left, OperatorNode.create('in'), right); } else { - // array contains const leftFieldDef = this.getFieldDefFromFieldRef(normalizedLeft, context); const comparand = leftFieldDef && QueryUtils.isEnum(this.schema, leftFieldDef.type) - ? // cast lhs otherwise dialect like pg can reject due to type mismatch - this.dialect.castText(new ExpressionWrapper(left)).toOperationNode() + ? this.dialect.castText(new ExpressionWrapper(left)).toOperationNode() : left; + + // if RHS is a subquery selecting an array column, use + // a cross-db EXISTS approach instead of `= ANY(subquery)` + const rightFieldDef = this.getFieldDefFromFieldRef(normalizedRight, context); + if (rightFieldDef?.array && SelectQueryNode.is(right)) { + return this.buildArrayInExists(comparand, right as SelectQueryNode); + } + + // array contains return BinaryOperationNode.create( comparand, OperatorNode.create('='), @@ -702,7 +711,11 @@ export class ExpressionTransformer { } else { // transform the first segment into a relation access, then continue with the rest of the members const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.thisType, expr.members[0]!); - receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext); + receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, { + ...restContext, + modelOrType: context.thisType, + alias: context.thisAlias, + }); members = expr.members.slice(1); // startType should be the type of the relation access startType = firstMemberFieldDef.type; @@ -756,7 +769,7 @@ export class ExpressionTransformer { currType = fieldDef.type; } - let currNode: SelectQueryNode | ColumnNode | ReferenceNode | undefined = undefined; + let currNode: SelectQueryNode | ColumnNode | ReferenceNode | FunctionNode | undefined = undefined; for (let i = members.length - 1; i >= 0; i--) { const member = members[i]!; @@ -788,7 +801,9 @@ export class ExpressionTransformer { invariant(i === members.length - 1, 'plain field access must be the last segment'); invariant(!currNode, 'plain field access must be the last segment'); - currNode = ColumnNode.create(member); + currNode = fieldDef.array && this.schema.provider.type === 'postgresql' + ? FunctionNode.create('unnest', [ColumnNode.create(member)]) + : ColumnNode.create(member); } } @@ -1015,6 +1030,21 @@ export class ExpressionTransformer { ExpressionUtils.isThis(expr.receiver) ) { return QueryUtils.getField(this.schema, model, expr.members[0]!); + } else if ( + ExpressionUtils.isMember(expr) && + ExpressionUtils.isThis(expr.receiver) && + expr.members.length > 1 + ) { + // `this.relation.field` chain — walk from the @@allow model + const firstDef = QueryUtils.getField(this.schema, model, expr.members[0]!); + if (!firstDef?.relation) return undefined; + let currModel = firstDef.type; + for (let i = 1; i < expr.members.length - 1; i++) { + const hopDef = QueryUtils.getField(this.schema, currModel, expr.members[i]!); + if (!hopDef?.relation) return undefined; + currModel = hopDef.type; + } + return QueryUtils.getField(this.schema, currModel, expr.members[expr.members.length - 1]!); } else if (ExpressionUtils.isMember(expr) && ExpressionUtils.isField(expr.receiver)) { // relation chain access (e.g. `owner.id`, `user.profile.uuid_field`): walk the // relation hops and return the terminal field's FieldDef so native-type info @@ -1032,4 +1062,73 @@ export class ExpressionTransformer { return undefined; } } + /** + * Build a cross-database EXISTS subquery for `scalar IN relation.arrayField`. + * Preserves the original subquery FROM to handle joined relations (e.g. m2m). + */ + private buildArrayInExists( + scalar: OperationNode, + subquery: SelectQueryNode, + ): OperationNode { + // PG: subquery already has unnest() from _member, just use = ANY(subquery) + if (this.schema.provider.type === 'postgresql') { + return BinaryOperationNode.create( + scalar, + OperatorNode.create('='), + FunctionNode.create('any', [subquery as unknown as OperationNode]), + ); + } + + const eb = this.eb; + + const table = subquery.from!.froms[0] as TableNode; + const tableName = table.table.identifier.name; + + const sel = subquery.selections![0]!; + const alias = sel.selection as AliasNode; + const colName = (alias.node as ColumnNode).column.name; + + const tableRef = eb.ref(`${tableName}.${colName}`); + const scalarRef = new ExpressionWrapper(scalar); + + let arrayCheck: OperationNode; + if (this.schema.provider.type === 'sqlite') { + arrayCheck = eb + .exists( + eb + .selectFrom(eb.fn('json_each', [tableRef]).as('_je')) + .select(eb.lit(1).as('_')) + .where(eb.ref('_je.value'), '=', scalarRef), + ) + .toOperationNode(); + } else { + // mysql + arrayCheck = eb + .exists( + eb + .selectFrom( + sql`JSON_TABLE(${tableRef}, '$[*]' COLUMNS(value JSON PATH '$'))`.as('_jt'), + ) + .select(eb.lit(1).as('_')) + .where(eb.ref('_jt.value'), '=', scalarRef), + ) + .toOperationNode(); + } + + // combine array check with original WHERE + const combinedWhere = subquery.where + ? conjunction(this.dialect, [subquery.where.where, arrayCheck]) + : arrayCheck; + + // preserve original FROM to handle joins (m2m etc.) + return UnaryOperationNode.create( + { + kind: 'SelectQueryNode', + from: subquery.from, + where: WhereNode.create(combinedWhere), + selections: [SelectionNode.create(AliasNode.create(ValueNode.createImmediate(1), IdentifierNode.create('_')))], + } as SelectQueryNode, + OperatorNode.create('exists'), + ); + } } diff --git a/tests/e2e/orm/policy/auth-access.test.ts b/tests/e2e/orm/policy/auth-access.test.ts index 56942de49..6beb828ed 100644 --- a/tests/e2e/orm/policy/auth-access.test.ts +++ b/tests/e2e/orm/policy/auth-access.test.ts @@ -475,4 +475,130 @@ model Channel { userDb2.channel.update({ where: { id: 1 }, data: { name: 'general-updated' } }), ).resolves.toBeTruthy(); }); + + it('resolves this.relation.field against @@allow model in collection predicates (Fix #1)', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id @default(autoincrement()) + level Int + permissions Permission[] + posts Post[] + @@auth +} + +model Permission { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + clearance Int +} + +model Post { + id Int @id @default(autoincrement()) + author User @relation(fields: [authorId], references: [id]) + authorId Int + + @@allow('read', auth().permissions?[p, p.clearance >= this.author.level]) +} +`, + { provider: 'postgresql' }, + ); + + await db.$unuseAll().post.create({ + data: { id: 1, author: { create: { id: 1, level: 5 } } }, + }); + await db.$unuseAll().post.create({ + data: { id: 2, author: { create: { id: 2, level: 10 } } }, + }); + + // no auth: no permissions → cannot read any post + await expect(db.post.findMany()).resolves.toHaveLength(0); + + // clearance 5: can read author level ≤ 5 → only post 1 (author level 5) + const user1 = db.$setAuth({ + id: 3, + permissions: [{ id: 1, clearance: 5 }], + }); + const posts1 = await user1.post.findMany(); + expect(posts1.map((p) => p.id).sort()).toEqual([1]); + + // clearance 10: can read author level ≤ 10 → both posts + const user2 = db.$setAuth({ + id: 4, + permissions: [{ id: 2, clearance: 10 }], + }); + const posts2 = await user2.post.findMany(); + expect(posts2.map((p) => p.id).sort()).toEqual([1, 2]); + }); + + it('handles this.relation.arrayField with in operator (Fix #2)', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id @default(autoincrement()) + permissions Permission[] + @@auth +} + +model Group { + id Int @id @default(autoincrement()) + visibleDocIds Int[] + docs Doc[] +} + +model Permission { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + allowedDocIds Int[] +} + +model Doc { + id Int @id @default(autoincrement()) + group Group @relation(fields: [groupId], references: [id]) + groupId Int + + @@allow('read', + auth().permissions?[p, this.id in p.allowedDocIds] || + this.id in this.group.visibleDocIds + ) +} +`, + { provider: 'postgresql' }, + ); + + await db.$unuseAll().group.create({ + data: { id: 1, visibleDocIds: [1] }, + }); + await db.$unuseAll().group.create({ + data: { id: 2, visibleDocIds: [] }, + }); + await db.$unuseAll().user.create({ + data: { id: 1 }, + }); + await db.$unuseAll().user.create({ + data: { id: 2 }, + }); + await db.$unuseAll().permission.create({ + data: { id: 10, userId: 2, allowedDocIds: [2] }, + }); + await db.$unuseAll().doc.createMany({ + data: [ + { id: 1, groupId: 1 }, + { id: 2, groupId: 2 }, + ], + }); + + // User 1 (no perms): doc 1 visible via group.visibleDocIds + const user1 = db.$setAuth({ id: 1, permissions: [] }); + expect((await user1.doc.findMany()).map((d) => d.id).sort()).toEqual([1]); + + // User 2 (perm allows doc 2): sees doc 1 (group-visible) + doc 2 (permission) + const user2 = db.$setAuth({ + id: 2, + permissions: [{ id: 10, allowedDocIds: [2] }], + }); + expect((await user2.doc.findMany()).map((d) => d.id).sort()).toEqual([1, 2]); + }); });