From cb54db6129288ac5b850be8b8ea06fd354581b2c Mon Sep 17 00:00:00 2001 From: Mike Willbanks Date: Mon, 29 Dec 2025 17:03:22 +0000 Subject: [PATCH 1/4] feat: audit policy collection aliases provides a means to alias collections in @@allow collections by extending the ast this allows for utilizing collections inside of @@allow like: ``` memberships?[m, auth().memberships?[ tenantId == m.tenantId ... ] ] ``` --- packages/language/src/ast.ts | 7 ++ packages/language/src/generated/ast.ts | 8 ++- packages/language/src/generated/grammar.ts | 28 ++++++++ .../attribute-application-validator.ts | 13 +++- .../language/src/zmodel-code-generator.ts | 8 ++- packages/language/src/zmodel-linker.ts | 31 +++++++-- packages/language/src/zmodel-scope.ts | 21 ++++++ packages/language/src/zmodel.langium | 4 +- .../test/expression-validation.test.ts | 44 ++++++++++++ .../policy/src/expression-evaluator.ts | 31 ++++++++- .../policy/src/expression-transformer.ts | 69 +++++++++++++++++-- packages/schema/src/expression-utils.ts | 3 +- packages/schema/src/expression.ts | 1 + .../sdk/src/prisma/prisma-schema-generator.ts | 8 ++- 14 files changed, 250 insertions(+), 26 deletions(-) diff --git a/packages/language/src/ast.ts b/packages/language/src/ast.ts index 71d31d4d..c00cf647 100644 --- a/packages/language/src/ast.ts +++ b/packages/language/src/ast.ts @@ -46,6 +46,13 @@ declare module './ast' { $resolvedParam?: AttributeParam; } + interface BinaryExpr { + /** + * Optional iterator binding for collection predicates + */ + binding?: string; + } + export interface DataModel { /** * All fields including those marked with `@ignore` diff --git a/packages/language/src/generated/ast.ts b/packages/language/src/generated/ast.ts index e759aa1f..54a859ad 100644 --- a/packages/language/src/generated/ast.ts +++ b/packages/language/src/generated/ast.ts @@ -142,7 +142,7 @@ export function isMemberAccessTarget(item: unknown): item is MemberAccessTarget return reflection.isInstance(item, MemberAccessTarget); } -export type ReferenceTarget = DataField | EnumField | FunctionParam; +export type ReferenceTarget = BinaryExpr | DataField | EnumField | FunctionParam; export const ReferenceTarget = 'ReferenceTarget'; @@ -256,6 +256,7 @@ export function isAttributeParamType(item: unknown): item is AttributeParamType export interface BinaryExpr extends langium.AstNode { readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | FieldInitializer | FunctionDecl | MemberAccessExpr | ReferenceArg | UnaryExpr; readonly $type: 'BinaryExpr'; + binding?: RegularID; left: Expression; operator: '!' | '!=' | '&&' | '<' | '<=' | '==' | '>' | '>=' | '?' | '^' | 'in' | '||'; right: Expression; @@ -826,7 +827,6 @@ export class ZModelAstReflection extends langium.AbstractAstReflection { protected override computeIsSubtype(subtype: string, supertype: string): boolean { switch (subtype) { case ArrayExpr: - case BinaryExpr: case MemberAccessExpr: case NullExpr: case ObjectExpr: @@ -843,6 +843,9 @@ export class ZModelAstReflection extends langium.AbstractAstReflection { case Procedure: { return this.isSubtype(AbstractDeclaration, supertype); } + case BinaryExpr: { + return this.isSubtype(Expression, supertype) || this.isSubtype(ReferenceTarget, supertype); + } case BooleanLiteral: case NumberLiteral: case StringLiteral: { @@ -973,6 +976,7 @@ export class ZModelAstReflection extends langium.AbstractAstReflection { return { name: BinaryExpr, properties: [ + { name: 'binding' }, { name: 'left' }, { name: 'operator' }, { name: 'right' } diff --git a/packages/language/src/generated/grammar.ts b/packages/language/src/generated/grammar.ts index 02260ccd..6be9b88d 100644 --- a/packages/language/src/generated/grammar.ts +++ b/packages/language/src/generated/grammar.ts @@ -1418,6 +1418,28 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "$type": "Keyword", "value": "[" }, + { + "$type": "Group", + "elements": [ + { + "$type": "Assignment", + "feature": "binding", + "operator": "=", + "terminal": { + "$type": "RuleCall", + "rule": { + "$ref": "#/rules@51" + }, + "arguments": [] + } + }, + { + "$type": "Keyword", + "value": "," + } + ], + "cardinality": "?" + }, { "$type": "Assignment", "feature": "right", @@ -3996,6 +4018,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "typeRef": { "$ref": "#/rules@45" } + }, + { + "$type": "SimpleType", + "typeRef": { + "$ref": "#/rules@29/definition/elements@1/elements@0/inferredType" + } } ] } diff --git a/packages/language/src/validators/attribute-application-validator.ts b/packages/language/src/validators/attribute-application-validator.ts index 62df3a23..f9747ee0 100644 --- a/packages/language/src/validators/attribute-application-validator.ts +++ b/packages/language/src/validators/attribute-application-validator.ts @@ -491,9 +491,16 @@ function isValidAttributeTarget(attrDecl: Attribute, targetDecl: DataField) { return true; } - const fieldTypes = (targetField.args[0].value as ArrayExpr).items.map( - (item) => (item as ReferenceExpr).target.ref?.name, - ); + const fieldTypes = (targetField.args[0].value as ArrayExpr).items + .map((item) => { + if (!isReferenceExpr(item)) { + return undefined; + } + + const ref = item.target.ref; + return ref && 'name' in ref ? (ref as any).name : undefined; + }) + .filter((name): name is string => !!name); let allowed = false; for (const allowedType of fieldTypes) { diff --git a/packages/language/src/zmodel-code-generator.ts b/packages/language/src/zmodel-code-generator.ts index 55efb5fc..1e0366ed 100644 --- a/packages/language/src/zmodel-code-generator.ts +++ b/packages/language/src/zmodel-code-generator.ts @@ -252,13 +252,15 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ const { left: isLeftParenthesis, right: isRightParenthesis } = this.isParenthesesNeededForBinaryExpr(ast); + const collectionPredicate = isCollectionPredicate + ? `[${ast.binding ? `${ast.binding}, ${rightExpr}` : rightExpr}]` + : rightExpr; + return `${isLeftParenthesis ? '(' : ''}${this.generate(ast.left)}${ isLeftParenthesis ? ')' : '' }${isCollectionPredicate ? '' : this.binaryExprSpace}${operator}${ isCollectionPredicate ? '' : this.binaryExprSpace - }${isRightParenthesis ? '(' : ''}${ - isCollectionPredicate ? `[${rightExpr}]` : rightExpr - }${isRightParenthesis ? ')' : ''}`; + }${isRightParenthesis ? '(' : ''}${collectionPredicate}${isRightParenthesis ? ')' : ''}`; } @gen(ReferenceExpr) diff --git a/packages/language/src/zmodel-linker.ts b/packages/language/src/zmodel-linker.ts index 3bb45134..9394f6c3 100644 --- a/packages/language/src/zmodel-linker.ts +++ b/packages/language/src/zmodel-linker.ts @@ -25,6 +25,7 @@ import { DataModel, Enum, EnumField, + isBinaryExpr, type ExpressionType, FunctionDecl, FunctionParam, @@ -121,7 +122,8 @@ export class ZModelLinker extends DefaultLinker { const target = provider(reference.$refText); if (target) { reference._ref = target; - reference._nodeDescription = this.descriptions.createDescription(target, target.name, document); + const targetName = (target as any).name ?? (target as any).binding ?? reference.$refText; + reference._nodeDescription = this.descriptions.createDescription(target, targetName, document); // Add the reference to the document's array of references document.references.push(reference); @@ -249,13 +251,25 @@ export class ZModelLinker extends DefaultLinker { private resolveReference(node: ReferenceExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { this.resolveDefault(node, document, extraScopes); - - if (node.target.ref) { - // resolve type - if (node.target.ref.$type === EnumField) { - this.resolveToBuiltinTypeOrDecl(node, node.target.ref.$container); + const target = node.target.ref; + + if (target) { + if (isBinaryExpr(target) && ['?', '!', '^'].includes(target.operator)) { + const collectionType = target.left.$resolvedType; + if (collectionType?.decl) { + node.$resolvedType = { + decl: collectionType.decl, + array: false, + nullable: collectionType.nullable, + }; + } + } else if (target.$type === EnumField) { + this.resolveToBuiltinTypeOrDecl(node, target.$container); } else { - this.resolveToDeclaredType(node, (node.target.ref as DataField | FunctionParam).type); + const targetWithType = target as Partial; + if (targetWithType.type) { + this.resolveToDeclaredType(node, targetWithType.type); + } } } } @@ -506,6 +520,9 @@ export class ZModelLinker extends DefaultLinker { //#region Utils private resolveToDeclaredType(node: AstNode, type: FunctionParamType | DataFieldType) { + if (!type) { + return; + } let nullable = false; if (isDataFieldType(type)) { nullable = type.optional; diff --git a/packages/language/src/zmodel-scope.ts b/packages/language/src/zmodel-scope.ts index 6fd866f0..4bd4c830 100644 --- a/packages/language/src/zmodel-scope.ts +++ b/packages/language/src/zmodel-scope.ts @@ -7,6 +7,7 @@ import { StreamScope, UriUtils, interruptAndCheck, + stream, type AstNode, type AstNodeDescription, type LangiumCoreServices, @@ -18,7 +19,9 @@ import { import { match } from 'ts-pattern'; import { BinaryExpr, + Expression, MemberAccessExpr, + isBinaryExpr, isDataField, isDataModel, isEnumField, @@ -145,6 +148,9 @@ export class ZModelScopeProvider extends DefaultScopeProvider { .when(isReferenceExpr, (operand) => { // operand is a reference, it can only be a model/type-def field const ref = operand.target.ref; + if (isBinaryExpr(ref) && isCollectionPredicate(ref)) { + return this.createScopeForCollectionElement(ref.left, globalScope, allowTypeDefScope); + } if (isDataField(ref)) { return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope); } @@ -188,6 +194,21 @@ export class ZModelScopeProvider extends DefaultScopeProvider { // // typedef's fields are only added to the scope if the access starts with `auth().` const allowTypeDefScope = isAuthOrAuthMemberAccess(collection); + const collectionScope = this.createScopeForCollectionElement(collection, globalScope, allowTypeDefScope); + + if (collectionPredicate.binding) { + const description = this.descriptions.createDescription( + collectionPredicate, + collectionPredicate.binding, + collectionPredicate.$document!, + ); + return new StreamScope(stream([description]), collectionScope); + } + + return collectionScope; + } + + private createScopeForCollectionElement(collection: Expression, globalScope: Scope, allowTypeDefScope: boolean) { return match(collection) .when(isReferenceExpr, (expr) => { // collection is a reference - model or typedef field diff --git a/packages/language/src/zmodel.langium b/packages/language/src/zmodel.langium index 8d279787..a80c0a1f 100644 --- a/packages/language/src/zmodel.langium +++ b/packages/language/src/zmodel.langium @@ -66,7 +66,7 @@ ConfigArrayExpr: ConfigExpr: LiteralExpr | InvocationExpr | ConfigArrayExpr; -type ReferenceTarget = FunctionParam | DataField | EnumField; +type ReferenceTarget = FunctionParam | DataField | EnumField | BinaryExpr; ThisExpr: value='this'; @@ -113,7 +113,7 @@ CollectionPredicateExpr infers Expression: MemberAccessExpr ( {infer BinaryExpr.left=current} operator=('?'|'!'|'^') - '[' right=Expression ']' + '[' (binding=RegularID ',')? right=Expression ']' )*; InExpr infers Expression: diff --git a/packages/language/test/expression-validation.test.ts b/packages/language/test/expression-validation.test.ts index 100f02b2..7976c9e9 100644 --- a/packages/language/test/expression-validation.test.ts +++ b/packages/language/test/expression-validation.test.ts @@ -98,4 +98,48 @@ describe('Expression Validation Tests', () => { 'incompatible operand types', ); }); + + it('should allow collection predicate with iterator binding', async () => { + await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + @@allow('read', memberships?[m, m.tenantId == id]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + } + `); + }); + + it('should keep supporting unbound collection predicate syntax', async () => { + await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + @@allow('read', memberships?[tenantId == id]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + } + `); + }); }); diff --git a/packages/plugins/policy/src/expression-evaluator.ts b/packages/plugins/policy/src/expression-evaluator.ts index 45c7b855..d1e1ebe0 100644 --- a/packages/plugins/policy/src/expression-evaluator.ts +++ b/packages/plugins/policy/src/expression-evaluator.ts @@ -15,6 +15,7 @@ import { type ExpressionEvaluatorContext = { auth?: any; thisValue?: any; + scope?: Record; }; /** @@ -64,6 +65,9 @@ export class ExpressionEvaluator { } private evaluateField(expr: FieldExpression, context: ExpressionEvaluatorContext): any { + if (context.scope && expr.field in context.scope) { + return context.scope[expr.field]; + } return context.thisValue?.[expr.field]; } @@ -113,8 +117,28 @@ export class ExpressionEvaluator { invariant(Array.isArray(left), 'expected array'); return match(op) - .with('?', () => left.some((item: any) => this.evaluate(expr.right, { ...context, thisValue: item }))) - .with('!', () => left.every((item: any) => this.evaluate(expr.right, { ...context, thisValue: item }))) + .with('?', () => + left.some((item: any) => + this.evaluate(expr.right, { + ...context, + thisValue: item, + scope: expr.binding + ? { ...(context.scope ?? {}), [expr.binding]: item } + : context.scope, + }), + ), + ) + .with('!', () => + left.every((item: any) => + this.evaluate(expr.right, { + ...context, + thisValue: item, + scope: expr.binding + ? { ...(context.scope ?? {}), [expr.binding]: item } + : context.scope, + }), + ), + ) .with( '^', () => @@ -122,6 +146,9 @@ export class ExpressionEvaluator { this.evaluate(expr.right, { ...context, thisValue: item, + scope: expr.binding + ? { ...(context.scope ?? {}), [expr.binding]: item } + : context.scope, }), ), ) diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index 0ea84a97..ce3f6d0e 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -58,6 +58,8 @@ import { trueNode, } from './utils'; +type BindingScope = Record; + /** * Context for transforming a policy expression */ @@ -92,6 +94,11 @@ export type ExpressionTransformerContext = { */ contextValue?: Record; + /** + * Additional named bindings available during transformation + */ + scope?: BindingScope; + /** * The model or type name that `this` keyword refers to */ @@ -310,7 +317,11 @@ export class ExpressionTransformer { // LHS of the expression is evaluated as a value const evaluator = new ExpressionEvaluator(); - const receiver = evaluator.evaluate(expr.left, { thisValue: context.contextValue, auth: this.auth }); + const receiver = evaluator.evaluate(expr.left, { + thisValue: context.contextValue, + auth: this.auth, + scope: this.getEvaluationScope(context.scope), + }); // get LHS's type const baseType = this.isAuthMember(expr.left) ? this.authType : context.modelOrType; @@ -345,10 +356,18 @@ export class ExpressionTransformer { } } + const bindingScope = expr.binding + ? { + ...(context.scope ?? {}), + [expr.binding]: { type: newContextModel, alias: context.alias ?? newContextModel }, + } + : context.scope; + let predicateFilter = this.transform(expr.right, { ...context, modelOrType: newContextModel, alias: undefined, + scope: bindingScope, }); if (expr.op === '!') { @@ -391,6 +410,7 @@ export class ExpressionTransformer { const value = new ExpressionEvaluator().evaluate(expr, { auth: this.auth, thisValue: context.contextValue, + scope: this.getEvaluationScope(context.scope), }); return this.transformValue(value, 'Boolean'); } else { @@ -402,15 +422,20 @@ export class ExpressionTransformer { // e.g.: `auth().profiles[age == this.age]`, each `auth().profiles` element (which is a value) // is used to build an expression for the RHS `age == this.age` // the transformation happens recursively for nested collection predicates - const components = receiver.map((item) => - this.transform(expr.right, { + const components = receiver.map((item) => { + const bindingScope = expr.binding + ? { ...(context.scope ?? {}), [expr.binding]: { type: context.modelOrType, value: item } } + : context.scope; + + return this.transform(expr.right, { operation: context.operation, thisType: context.thisType, thisAlias: context.thisAlias, modelOrType: context.modelOrType, contextValue: item, - }), - ); + scope: bindingScope, + }); + }); // compose the components based on the operator return ( @@ -600,6 +625,25 @@ export class ExpressionTransformer { @expr('member') // @ts-ignore private _member(expr: MemberExpression, context: ExpressionTransformerContext) { + const bindingReceiver = + ExpressionUtils.isField(expr.receiver) && context.scope ? context.scope[expr.receiver.field] : undefined; + + if (bindingReceiver) { + if (bindingReceiver.value !== undefined) { + return this.valueMemberAccess(bindingReceiver.value, expr, bindingReceiver.type); + } + + const rewritten = ExpressionUtils.member(ExpressionUtils._this(), expr.members); + return this._member(rewritten, { + ...context, + modelOrType: bindingReceiver.type, + alias: bindingReceiver.alias ?? bindingReceiver.type, + thisType: bindingReceiver.type, + thisAlias: bindingReceiver.alias ?? bindingReceiver.type, + contextValue: bindingReceiver.value, + }); + } + // `auth()` member access if (this.isAuthCall(expr.receiver)) { return this.valueMemberAccess(this.auth, expr, this.authType); @@ -833,6 +877,21 @@ export class ExpressionTransformer { return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel); } + private getEvaluationScope(scope?: BindingScope) { + if (!scope) { + return undefined; + } + + const result: Record = {}; + for (const [key, value] of Object.entries(scope)) { + if (value.value !== undefined) { + result[key] = value.value; + } + } + + return Object.keys(result).length > 0 ? result : undefined; + } + private buildDelegateBaseFieldSelect(model: string, modelAlias: string, field: string, baseModel: string) { const idFields = QueryUtils.requireIdFields(this.client.$schema, model); return { diff --git a/packages/schema/src/expression-utils.ts b/packages/schema/src/expression-utils.ts index ee48aecc..f7bd526d 100644 --- a/packages/schema/src/expression-utils.ts +++ b/packages/schema/src/expression-utils.ts @@ -39,12 +39,13 @@ export const ExpressionUtils = { }; }, - binary: (left: Expression, op: BinaryOperator, right: Expression): BinaryExpression => { + binary: (left: Expression, op: BinaryOperator, right: Expression, binding?: string): BinaryExpression => { return { kind: 'binary', op, left, right, + binding, }; }, diff --git a/packages/schema/src/expression.ts b/packages/schema/src/expression.ts index 3ce3c2d1..b3bb9c40 100644 --- a/packages/schema/src/expression.ts +++ b/packages/schema/src/expression.ts @@ -41,6 +41,7 @@ export type BinaryExpression = { op: BinaryOperator; left: Expression; right: Expression; + binding?: string; }; export type CallExpression = { diff --git a/packages/sdk/src/prisma/prisma-schema-generator.ts b/packages/sdk/src/prisma/prisma-schema-generator.ts index 553658ad..11435333 100644 --- a/packages/sdk/src/prisma/prisma-schema-generator.ts +++ b/packages/sdk/src/prisma/prisma-schema-generator.ts @@ -15,6 +15,7 @@ import { Enum, EnumField, Expression, + isBinaryExpr, GeneratorDecl, InvocationExpr, isArrayExpr, @@ -352,10 +353,15 @@ export class PrismaSchemaGenerator { new Array(...node.items.map((item) => this.makeAttributeArgValue(item))), ); } else if (isReferenceExpr(node)) { + const ref = node.target.ref!; + const refName = (ref as any).name ?? (isBinaryExpr(ref) ? ref.binding : undefined); + if (!refName) { + throw Error(`Unsupported reference expression target: ${ref.$type}`); + } return new PrismaAttributeArgValue( 'FieldReference', new PrismaFieldReference( - node.target.ref!.name, + refName, node.args.map((arg) => new PrismaFieldReferenceArg(arg.name, this.exprToText(arg.value))), ), ); From aa4b30d7fbb415afac422ba42c32e29e0541782c Mon Sep 17 00:00:00 2001 From: Mike Willbanks Date: Mon, 29 Dec 2025 18:21:59 +0000 Subject: [PATCH 2/4] fix: code review comments + syntax fixes --- .../attribute-application-validator.ts | 3 +- packages/language/src/zmodel-linker.ts | 16 ++++--- .../sdk/src/prisma/prisma-schema-generator.ts | 7 ++- packages/sdk/src/ts-schema-generator.ts | 33 +++++++++++--- tests/e2e/orm/policy/auth-access.test.ts | 44 +++++++++++++++++++ 5 files changed, 88 insertions(+), 15 deletions(-) diff --git a/packages/language/src/validators/attribute-application-validator.ts b/packages/language/src/validators/attribute-application-validator.ts index f9747ee0..09124178 100644 --- a/packages/language/src/validators/attribute-application-validator.ts +++ b/packages/language/src/validators/attribute-application-validator.ts @@ -11,7 +11,6 @@ import { DataFieldAttribute, DataModelAttribute, InternalAttribute, - ReferenceExpr, isArrayExpr, isAttribute, isConfigArrayExpr, @@ -498,7 +497,7 @@ function isValidAttributeTarget(attrDecl: Attribute, targetDecl: DataField) { } const ref = item.target.ref; - return ref && 'name' in ref ? (ref as any).name : undefined; + return ref && 'name' in ref && typeof ref.name === 'string' ? ref.name : undefined; }) .filter((name): name is string => !!name); diff --git a/packages/language/src/zmodel-linker.ts b/packages/language/src/zmodel-linker.ts index 9394f6c3..ba8d9bf5 100644 --- a/packages/language/src/zmodel-linker.ts +++ b/packages/language/src/zmodel-linker.ts @@ -122,7 +122,12 @@ export class ZModelLinker extends DefaultLinker { const target = provider(reference.$refText); if (target) { reference._ref = target; - const targetName = (target as any).name ?? (target as any).binding ?? reference.$refText; + let targetName = reference.$refText; + if ('name' in target && typeof target.name === 'string') { + targetName = target.name; + } else if ('binding' in target && typeof (target as { binding?: unknown }).binding === 'string') { + targetName = (target as { binding: string }).binding; + } reference._nodeDescription = this.descriptions.createDescription(target, targetName, document); // Add the reference to the document's array of references @@ -265,11 +270,10 @@ export class ZModelLinker extends DefaultLinker { } } else if (target.$type === EnumField) { this.resolveToBuiltinTypeOrDecl(node, target.$container); - } else { - const targetWithType = target as Partial; - if (targetWithType.type) { - this.resolveToDeclaredType(node, targetWithType.type); - } + } else if (isDataField(target)) { + this.resolveToDeclaredType(node, target.type); + } else if (target.$type === FunctionParam && (target as FunctionParam).type) { + this.resolveToDeclaredType(node, (target as FunctionParam).type); } } } diff --git a/packages/sdk/src/prisma/prisma-schema-generator.ts b/packages/sdk/src/prisma/prisma-schema-generator.ts index 11435333..78a132c8 100644 --- a/packages/sdk/src/prisma/prisma-schema-generator.ts +++ b/packages/sdk/src/prisma/prisma-schema-generator.ts @@ -354,7 +354,12 @@ export class PrismaSchemaGenerator { ); } else if (isReferenceExpr(node)) { const ref = node.target.ref!; - const refName = (ref as any).name ?? (isBinaryExpr(ref) ? ref.binding : undefined); + const refName = + ('name' in ref && typeof (ref as { name?: unknown }).name === 'string') + ? (ref as { name: string }).name + : isBinaryExpr(ref) && typeof ref.binding === 'string' + ? ref.binding + : undefined; if (!refName) { throw Error(`Unsupported reference expression target: ${ref.$type}`); } diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index f68bb0bc..325926ac 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -1271,11 +1271,17 @@ export class TsSchemaGenerator { } private createBinaryExpression(expr: BinaryExpr) { - return this.createExpressionUtilsCall('binary', [ + const args = [ this.createExpression(expr.left), this.createLiteralNode(expr.operator), this.createExpression(expr.right), - ]); + ]; + + if (expr.binding) { + args.push(this.createLiteralNode(expr.binding)); + } + + return this.createExpressionUtilsCall('binary', args); } private createUnaryExpression(expr: UnaryExpr) { @@ -1292,13 +1298,28 @@ export class TsSchemaGenerator { } private createRefExpression(expr: ReferenceExpr): any { - if (isDataField(expr.target.ref)) { + const target = expr.target.ref; + if (isDataField(target)) { return this.createExpressionUtilsCall('field', [this.createLiteralNode(expr.target.$refText)]); - } else if (isEnumField(expr.target.ref)) { + } + + if (isEnumField(target)) { return this.createLiteralExpression('StringLiteral', expr.target.$refText); - } else { - throw new Error(`Unsupported reference type: ${expr.target.$refText}`); } + + const refName = + target && 'name' in target && typeof (target as { name?: unknown }).name === 'string' + ? (target as { name: string }).name + : isBinaryExpr(target) && typeof target.binding === 'string' + ? target.binding + : undefined; + + if (refName) { + return this.createExpressionUtilsCall('field', [this.createLiteralNode(refName)]); + } + + // Fallback: treat unknown reference targets (e.g. unresolved iterator bindings) as named fields + return this.createExpressionUtilsCall('field', [this.createLiteralNode(expr.target.$refText)]); } private createCallExpression(expr: InvocationExpr) { diff --git a/tests/e2e/orm/policy/auth-access.test.ts b/tests/e2e/orm/policy/auth-access.test.ts index b994324f..76e0c9f3 100644 --- a/tests/e2e/orm/policy/auth-access.test.ts +++ b/tests/e2e/orm/policy/auth-access.test.ts @@ -130,6 +130,50 @@ model Foo { await expect(db.$setAuth({ profiles: [{ age: 15 }, { age: 20 }] }).foo.findFirst()).toResolveTruthy(); }); + it('uses iterator binding inside collection predicate for auth model', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id + tenantId Int + memberships Membership[] @relation("UserMemberships") +} + +model Membership { + id Int @id + tenantId Int + userId Int + user User @relation("UserMemberships", fields: [userId], references: [id]) +} + +model Foo { + id Int @id + tenantId Int + @@allow('read', auth().memberships?[m, m.tenantId == this.tenantId]) +} +`, + ); + + await db.$unuseAll().foo.createMany({ + data: [ + { id: 1, tenantId: 1 }, + { id: 2, tenantId: 2 }, + ], + }); + + // allowed because iterator binding matches tenantId = 1 + await expect( + db.$setAuth({ tenantId: 1, memberships: [{ id: 10, tenantId: 1 }] }).foo.findMany(), + ).resolves.toEqual([ + { id: 1, tenantId: 1 }, + ]); + + // denied because membership tenantId doesn't match + await expect( + db.$setAuth({ tenantId: 1, memberships: [{ id: 20, tenantId: 3 }] }).foo.findMany(), + ).resolves.toEqual([]); + }); + it('works with shallow auth model collection predicates involving fields - some', async () => { const db = await createPolicyTestClient( ` From bac5e6b311d290605a0a1b4dfe2b7288a87e7178 Mon Sep 17 00:00:00 2001 From: Yiming Cao Date: Sat, 17 Jan 2026 01:33:06 +0800 Subject: [PATCH 3/4] refactor: extract collection predicate binding to its own language construct (#2) - adjusted language processing chain accordingly - fixed several issues in policy transformer/evaluator - more test cases --- packages/language/src/ast.ts | 7 - packages/language/src/generated/ast.ts | 40 +- packages/language/src/generated/grammar.ts | 278 ++++++----- .../src/validators/expression-validator.ts | 17 +- packages/language/src/zmodel-linker.ts | 35 +- packages/language/src/zmodel-scope.ts | 70 ++- packages/language/src/zmodel.langium | 8 +- .../test/expression-validation.test.ts | 269 +++++++++-- .../orm/src/client/crud/validator/utils.ts | 3 + packages/orm/src/utils/schema-utils.ts | 4 + .../policy/src/expression-evaluator.ts | 36 +- .../policy/src/expression-transformer.ts | 123 +++-- packages/schema/src/expression-utils.ts | 10 + packages/schema/src/expression.ts | 6 + .../sdk/src/prisma/prisma-schema-generator.ts | 12 +- packages/sdk/src/ts-schema-generator.ts | 35 +- tests/e2e/orm/policy/auth-access.test.ts | 8 +- .../orm/policy/collection-predicate.test.ts | 447 ++++++++++++++++++ 18 files changed, 1095 insertions(+), 313 deletions(-) create mode 100644 tests/e2e/orm/policy/collection-predicate.test.ts diff --git a/packages/language/src/ast.ts b/packages/language/src/ast.ts index c00cf647..71d31d4d 100644 --- a/packages/language/src/ast.ts +++ b/packages/language/src/ast.ts @@ -46,13 +46,6 @@ declare module './ast' { $resolvedParam?: AttributeParam; } - interface BinaryExpr { - /** - * Optional iterator binding for collection predicates - */ - binding?: string; - } - export interface DataModel { /** * All fields including those marked with `@ignore` diff --git a/packages/language/src/generated/ast.ts b/packages/language/src/generated/ast.ts index 54a859ad..6c1af8d4 100644 --- a/packages/language/src/generated/ast.ts +++ b/packages/language/src/generated/ast.ts @@ -142,7 +142,7 @@ export function isMemberAccessTarget(item: unknown): item is MemberAccessTarget return reflection.isInstance(item, MemberAccessTarget); } -export type ReferenceTarget = BinaryExpr | DataField | EnumField | FunctionParam; +export type ReferenceTarget = CollectionPredicateBinding | DataField | EnumField | FunctionParam; export const ReferenceTarget = 'ReferenceTarget'; @@ -256,7 +256,7 @@ export function isAttributeParamType(item: unknown): item is AttributeParamType export interface BinaryExpr extends langium.AstNode { readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | FieldInitializer | FunctionDecl | MemberAccessExpr | ReferenceArg | UnaryExpr; readonly $type: 'BinaryExpr'; - binding?: RegularID; + binding?: CollectionPredicateBinding; left: Expression; operator: '!' | '!=' | '&&' | '<' | '<=' | '==' | '>' | '>=' | '?' | '^' | 'in' | '||'; right: Expression; @@ -280,6 +280,18 @@ export function isBooleanLiteral(item: unknown): item is BooleanLiteral { return reflection.isInstance(item, BooleanLiteral); } +export interface CollectionPredicateBinding extends langium.AstNode { + readonly $container: BinaryExpr; + readonly $type: 'CollectionPredicateBinding'; + name: RegularID; +} + +export const CollectionPredicateBinding = 'CollectionPredicateBinding'; + +export function isCollectionPredicateBinding(item: unknown): item is CollectionPredicateBinding { + return reflection.isInstance(item, CollectionPredicateBinding); +} + export interface ConfigArrayExpr extends langium.AstNode { readonly $container: ConfigField; readonly $type: 'ConfigArrayExpr'; @@ -774,6 +786,7 @@ export type ZModelAstType = { AttributeParamType: AttributeParamType BinaryExpr: BinaryExpr BooleanLiteral: BooleanLiteral + CollectionPredicateBinding: CollectionPredicateBinding ConfigArrayExpr: ConfigArrayExpr ConfigExpr: ConfigExpr ConfigField: ConfigField @@ -821,12 +834,13 @@ export type ZModelAstType = { export class ZModelAstReflection extends langium.AbstractAstReflection { getAllTypes(): string[] { - return [AbstractDeclaration, Argument, ArrayExpr, Attribute, AttributeArg, AttributeParam, AttributeParamType, BinaryExpr, BooleanLiteral, ConfigArrayExpr, ConfigExpr, ConfigField, ConfigInvocationArg, ConfigInvocationExpr, DataField, DataFieldAttribute, DataFieldType, DataModel, DataModelAttribute, DataSource, Enum, EnumField, Expression, FieldInitializer, FunctionDecl, FunctionParam, FunctionParamType, GeneratorDecl, InternalAttribute, InvocationExpr, LiteralExpr, MemberAccessExpr, MemberAccessTarget, Model, ModelImport, NullExpr, NumberLiteral, ObjectExpr, Plugin, PluginField, Procedure, ProcedureParam, ReferenceArg, ReferenceExpr, ReferenceTarget, StringLiteral, ThisExpr, TypeDeclaration, TypeDef, UnaryExpr, UnsupportedFieldType]; + return [AbstractDeclaration, Argument, ArrayExpr, Attribute, AttributeArg, AttributeParam, AttributeParamType, BinaryExpr, BooleanLiteral, CollectionPredicateBinding, ConfigArrayExpr, ConfigExpr, ConfigField, ConfigInvocationArg, ConfigInvocationExpr, DataField, DataFieldAttribute, DataFieldType, DataModel, DataModelAttribute, DataSource, Enum, EnumField, Expression, FieldInitializer, FunctionDecl, FunctionParam, FunctionParamType, GeneratorDecl, InternalAttribute, InvocationExpr, LiteralExpr, MemberAccessExpr, MemberAccessTarget, Model, ModelImport, NullExpr, NumberLiteral, ObjectExpr, Plugin, PluginField, Procedure, ProcedureParam, ReferenceArg, ReferenceExpr, ReferenceTarget, StringLiteral, ThisExpr, TypeDeclaration, TypeDef, UnaryExpr, UnsupportedFieldType]; } protected override computeIsSubtype(subtype: string, supertype: string): boolean { switch (subtype) { case ArrayExpr: + case BinaryExpr: case MemberAccessExpr: case NullExpr: case ObjectExpr: @@ -843,14 +857,16 @@ export class ZModelAstReflection extends langium.AbstractAstReflection { case Procedure: { return this.isSubtype(AbstractDeclaration, supertype); } - case BinaryExpr: { - return this.isSubtype(Expression, supertype) || this.isSubtype(ReferenceTarget, supertype); - } case BooleanLiteral: case NumberLiteral: case StringLiteral: { return this.isSubtype(LiteralExpr, supertype); } + case CollectionPredicateBinding: + case EnumField: + case FunctionParam: { + return this.isSubtype(ReferenceTarget, supertype); + } case ConfigArrayExpr: { return this.isSubtype(ConfigExpr, supertype); } @@ -862,10 +878,6 @@ export class ZModelAstReflection extends langium.AbstractAstReflection { case TypeDef: { return this.isSubtype(AbstractDeclaration, supertype) || this.isSubtype(TypeDeclaration, supertype); } - case EnumField: - case FunctionParam: { - return this.isSubtype(ReferenceTarget, supertype); - } case InvocationExpr: case LiteralExpr: { return this.isSubtype(ConfigExpr, supertype) || this.isSubtype(Expression, supertype); @@ -991,6 +1003,14 @@ export class ZModelAstReflection extends langium.AbstractAstReflection { ] }; } + case CollectionPredicateBinding: { + return { + name: CollectionPredicateBinding, + properties: [ + { name: 'name' } + ] + }; + } case ConfigArrayExpr: { return { name: ConfigArrayExpr, diff --git a/packages/language/src/generated/grammar.ts b/packages/language/src/generated/grammar.ts index 6be9b88d..91b8a0c6 100644 --- a/packages/language/src/generated/grammar.ts +++ b/packages/language/src/generated/grammar.ts @@ -70,7 +70,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@69" + "$ref": "#/rules@70" }, "arguments": [] } @@ -119,42 +119,42 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@37" + "$ref": "#/rules@38" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@42" + "$ref": "#/rules@43" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@44" + "$ref": "#/rules@45" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@46" + "$ref": "#/rules@47" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@53" + "$ref": "#/rules@54" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@50" + "$ref": "#/rules@51" }, "arguments": [] } @@ -176,7 +176,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [], "cardinality": "*" @@ -192,7 +192,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -236,7 +236,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [], "cardinality": "*" @@ -252,7 +252,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -296,7 +296,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [], "cardinality": "*" @@ -308,7 +308,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -347,7 +347,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [], "cardinality": "*" @@ -363,7 +363,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -407,7 +407,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [], "cardinality": "*" @@ -419,7 +419,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -474,7 +474,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "definition": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@33" + "$ref": "#/rules@34" }, "arguments": [] }, @@ -495,7 +495,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@70" + "$ref": "#/rules@71" }, "arguments": [] } @@ -517,7 +517,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@69" + "$ref": "#/rules@70" }, "arguments": [] } @@ -539,7 +539,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@63" + "$ref": "#/rules@64" }, "arguments": [] } @@ -663,7 +663,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@68" + "$ref": "#/rules@69" }, "arguments": [] } @@ -761,7 +761,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@68" + "$ref": "#/rules@69" }, "arguments": [] } @@ -970,7 +970,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@52" + "$ref": "#/rules@53" }, "arguments": [] }, @@ -1069,7 +1069,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@68" + "$ref": "#/rules@69" }, "arguments": [] } @@ -1183,14 +1183,14 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@69" + "$ref": "#/rules@70" }, "arguments": [] } @@ -1235,7 +1235,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@46" + "$ref": "#/rules@47" }, "deprecatedSyntax": false } @@ -1247,7 +1247,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@35" + "$ref": "#/rules@36" }, "arguments": [], "cardinality": "?" @@ -1278,7 +1278,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@34" + "$ref": "#/rules@35" }, "arguments": [] }, @@ -1428,7 +1428,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@30" }, "arguments": [] } @@ -1468,6 +1468,28 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "parameters": [], "wildcard": false }, + { + "$type": "ParserRule", + "name": "CollectionPredicateBinding", + "definition": { + "$type": "Assignment", + "feature": "name", + "operator": "=", + "terminal": { + "$type": "RuleCall", + "rule": { + "$ref": "#/rules@52" + }, + "arguments": [] + } + }, + "definesHiddenTokens": false, + "entry": false, + "fragment": false, + "hiddenTokens": [], + "parameters": [], + "wildcard": false + }, { "$type": "ParserRule", "name": "InExpr", @@ -1543,7 +1565,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@30" + "$ref": "#/rules@31" }, "arguments": [] }, @@ -1592,7 +1614,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@30" + "$ref": "#/rules@31" }, "arguments": [] } @@ -1622,7 +1644,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@31" + "$ref": "#/rules@32" }, "arguments": [] }, @@ -1663,7 +1685,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@31" + "$ref": "#/rules@32" }, "arguments": [] } @@ -1693,7 +1715,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@32" + "$ref": "#/rules@33" }, "arguments": [] }, @@ -1734,7 +1756,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@32" + "$ref": "#/rules@33" }, "arguments": [] } @@ -1860,7 +1882,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@36" + "$ref": "#/rules@37" }, "arguments": [] } @@ -1879,7 +1901,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@36" + "$ref": "#/rules@37" }, "arguments": [] } @@ -1930,7 +1952,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [] }, @@ -1953,7 +1975,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -1964,14 +1986,14 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@38" + "$ref": "#/rules@39" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@39" + "$ref": "#/rules@40" }, "arguments": [] }, @@ -1981,14 +2003,14 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@39" + "$ref": "#/rules@40" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@38" + "$ref": "#/rules@39" }, "arguments": [] } @@ -2000,14 +2022,14 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@38" + "$ref": "#/rules@39" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@39" + "$ref": "#/rules@40" }, "arguments": [] } @@ -2037,7 +2059,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -2060,7 +2082,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@40" + "$ref": "#/rules@41" }, "arguments": [] } @@ -2072,7 +2094,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@57" + "$ref": "#/rules@58" }, "arguments": [] } @@ -2111,7 +2133,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@42" + "$ref": "#/rules@43" }, "deprecatedSyntax": false } @@ -2131,7 +2153,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@42" + "$ref": "#/rules@43" }, "deprecatedSyntax": false } @@ -2165,7 +2187,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@37" + "$ref": "#/rules@38" }, "deprecatedSyntax": false } @@ -2191,7 +2213,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [] }, @@ -2204,7 +2226,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@52" + "$ref": "#/rules@53" }, "arguments": [] } @@ -2216,7 +2238,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@41" + "$ref": "#/rules@42" }, "arguments": [] } @@ -2228,7 +2250,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@56" + "$ref": "#/rules@57" }, "arguments": [] }, @@ -2259,7 +2281,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@62" + "$ref": "#/rules@63" }, "arguments": [] } @@ -2271,7 +2293,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@43" + "$ref": "#/rules@44" }, "arguments": [] } @@ -2288,7 +2310,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] }, @@ -2348,7 +2370,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [] }, @@ -2365,7 +2387,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -2373,7 +2395,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@38" + "$ref": "#/rules@39" }, "arguments": [], "cardinality": "?" @@ -2392,7 +2414,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@40" + "$ref": "#/rules@41" }, "arguments": [] } @@ -2404,7 +2426,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@57" + "$ref": "#/rules@58" }, "arguments": [] } @@ -2477,7 +2499,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [] }, @@ -2494,7 +2516,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -2513,7 +2535,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@45" + "$ref": "#/rules@46" }, "arguments": [] } @@ -2525,7 +2547,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@57" + "$ref": "#/rules@58" }, "arguments": [] } @@ -2559,7 +2581,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [] }, @@ -2572,7 +2594,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@52" + "$ref": "#/rules@53" }, "arguments": [] } @@ -2584,7 +2606,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@56" + "$ref": "#/rules@57" }, "arguments": [] }, @@ -2608,7 +2630,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [], "cardinality": "*" @@ -2624,7 +2646,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -2643,7 +2665,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@47" + "$ref": "#/rules@48" }, "arguments": [] } @@ -2662,7 +2684,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@47" + "$ref": "#/rules@48" }, "arguments": [] } @@ -2688,7 +2710,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@48" + "$ref": "#/rules@49" }, "arguments": [] } @@ -2721,7 +2743,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@58" + "$ref": "#/rules@59" }, "arguments": [] }, @@ -2745,7 +2767,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [], "cardinality": "*" @@ -2757,7 +2779,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -2773,7 +2795,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@48" + "$ref": "#/rules@49" }, "arguments": [] } @@ -2813,7 +2835,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@61" + "$ref": "#/rules@62" }, "arguments": [] } @@ -2830,7 +2852,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] }, @@ -2876,7 +2898,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [], "cardinality": "*" @@ -2888,7 +2910,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -2904,7 +2926,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@48" + "$ref": "#/rules@49" }, "arguments": [] } @@ -2937,7 +2959,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [], "cardinality": "*" @@ -2962,7 +2984,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -2981,7 +3003,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@49" + "$ref": "#/rules@50" }, "arguments": [] } @@ -3000,7 +3022,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@47" + "$ref": "#/rules@48" }, "arguments": [] } @@ -3026,7 +3048,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@48" + "$ref": "#/rules@49" }, "arguments": [] } @@ -3038,7 +3060,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@58" + "$ref": "#/rules@59" }, "arguments": [] }, @@ -3063,7 +3085,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@68" + "$ref": "#/rules@69" }, "arguments": [] }, @@ -3126,7 +3148,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] }, @@ -3204,7 +3226,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [] }, @@ -3224,21 +3246,21 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@66" + "$ref": "#/rules@67" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@67" + "$ref": "#/rules@68" }, "arguments": [] } @@ -3259,7 +3281,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@54" + "$ref": "#/rules@55" }, "arguments": [] } @@ -3278,7 +3300,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@54" + "$ref": "#/rules@55" }, "arguments": [] } @@ -3300,7 +3322,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@58" + "$ref": "#/rules@59" }, "arguments": [] }, @@ -3328,7 +3350,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [] }, @@ -3351,7 +3373,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -3367,7 +3389,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@55" + "$ref": "#/rules@56" }, "arguments": [] } @@ -3379,7 +3401,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@58" + "$ref": "#/rules@59" }, "arguments": [] }, @@ -3413,7 +3435,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@61" + "$ref": "#/rules@62" }, "arguments": [] }, @@ -3444,7 +3466,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] }, @@ -3504,12 +3526,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@53" + "$ref": "#/rules@54" }, "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@67" + "$ref": "#/rules@68" }, "arguments": [] }, @@ -3526,7 +3548,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@59" + "$ref": "#/rules@60" }, "arguments": [], "cardinality": "?" @@ -3556,7 +3578,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@71" + "$ref": "#/rules@72" }, "arguments": [], "cardinality": "*" @@ -3568,12 +3590,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@53" + "$ref": "#/rules@54" }, "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@66" + "$ref": "#/rules@67" }, "arguments": [] }, @@ -3590,7 +3612,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@59" + "$ref": "#/rules@60" }, "arguments": [], "cardinality": "?" @@ -3624,12 +3646,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@53" + "$ref": "#/rules@54" }, "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [] }, @@ -3646,7 +3668,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@59" + "$ref": "#/rules@60" }, "arguments": [], "cardinality": "?" @@ -3681,7 +3703,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@60" + "$ref": "#/rules@61" }, "arguments": [] } @@ -3700,7 +3722,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@60" + "$ref": "#/rules@61" }, "arguments": [] } @@ -3732,7 +3754,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -4004,25 +4026,25 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "SimpleType", "typeRef": { - "$ref": "#/rules@47" + "$ref": "#/rules@48" } }, { "$type": "SimpleType", "typeRef": { - "$ref": "#/rules@40" + "$ref": "#/rules@41" } }, { "$type": "SimpleType", "typeRef": { - "$ref": "#/rules@45" + "$ref": "#/rules@46" } }, { "$type": "SimpleType", "typeRef": { - "$ref": "#/rules@29/definition/elements@1/elements@0/inferredType" + "$ref": "#/rules@30" } } ] @@ -4034,7 +4056,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "type": { "$type": "SimpleType", "typeRef": { - "$ref": "#/rules@40" + "$ref": "#/rules@41" } } }, @@ -4047,19 +4069,19 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "SimpleType", "typeRef": { - "$ref": "#/rules@37" + "$ref": "#/rules@38" } }, { "$type": "SimpleType", "typeRef": { - "$ref": "#/rules@42" + "$ref": "#/rules@43" } }, { "$type": "SimpleType", "typeRef": { - "$ref": "#/rules@44" + "$ref": "#/rules@45" } } ] diff --git a/packages/language/src/validators/expression-validator.ts b/packages/language/src/validators/expression-validator.ts index c2848c14..3efe6d91 100644 --- a/packages/language/src/validators/expression-validator.ts +++ b/packages/language/src/validators/expression-validator.ts @@ -3,6 +3,7 @@ import { BinaryExpr, Expression, isArrayExpr, + isCollectionPredicateBinding, isDataModel, isDataModelAttribute, isEnum, @@ -12,6 +13,7 @@ import { isReferenceExpr, isThisExpr, MemberAccessExpr, + ReferenceExpr, UnaryExpr, type ExpressionType, } from '../generated/ast'; @@ -51,7 +53,7 @@ export default class ExpressionValidator implements AstValidator { } return false; }); - if (!hasReferenceResolutionError) { + if (hasReferenceResolutionError) { // report silent errors not involving linker errors accept('error', 'Expression cannot be resolved', { node: expr, @@ -62,6 +64,9 @@ export default class ExpressionValidator implements AstValidator { // extra validations by expression type switch (expr.$type) { + case 'ReferenceExpr': + this.validateReferenceExpr(expr, accept); + break; case 'MemberAccessExpr': this.validateMemberAccessExpr(expr, accept); break; @@ -74,6 +79,16 @@ export default class ExpressionValidator implements AstValidator { } } + private validateReferenceExpr(expr: ReferenceExpr, accept: ValidationAcceptor) { + // reference to collection predicate's binding can't be used standalone like: + // `items?[e, e]`, `items?[e, e != null]`, etc. + if (isCollectionPredicateBinding(expr.target.ref) && !isMemberAccessExpr(expr.$container)) { + accept('error', 'Collection predicate binding cannot be used without a member access', { + node: expr, + }); + } + } + private validateMemberAccessExpr(expr: MemberAccessExpr, accept: ValidationAcceptor) { if (isBeforeInvocation(expr.operand) && isDataModel(expr.$resolvedType?.decl)) { accept('error', 'relation fields cannot be accessed from `before()`', { node: expr }); diff --git a/packages/language/src/zmodel-linker.ts b/packages/language/src/zmodel-linker.ts index ba8d9bf5..fc3a7f0d 100644 --- a/packages/language/src/zmodel-linker.ts +++ b/packages/language/src/zmodel-linker.ts @@ -24,11 +24,8 @@ import { DataFieldType, DataModel, Enum, - EnumField, - isBinaryExpr, type ExpressionType, FunctionDecl, - FunctionParam, FunctionParamType, InvocationExpr, LiteralExpr, @@ -44,10 +41,13 @@ import { UnaryExpr, isArrayExpr, isBooleanLiteral, + isCollectionPredicateBinding, isDataField, isDataFieldType, isDataModel, isEnum, + isEnumField, + isFunctionParam, isNumberLiteral, isReferenceExpr, isStringLiteral, @@ -122,13 +122,7 @@ export class ZModelLinker extends DefaultLinker { const target = provider(reference.$refText); if (target) { reference._ref = target; - let targetName = reference.$refText; - if ('name' in target && typeof target.name === 'string') { - targetName = target.name; - } else if ('binding' in target && typeof (target as { binding?: unknown }).binding === 'string') { - targetName = (target as { binding: string }).binding; - } - reference._nodeDescription = this.descriptions.createDescription(target, targetName, document); + reference._nodeDescription = this.descriptions.createDescription(target, target.name, document); // Add the reference to the document's array of references document.references.push(reference); @@ -259,21 +253,18 @@ export class ZModelLinker extends DefaultLinker { const target = node.target.ref; if (target) { - if (isBinaryExpr(target) && ['?', '!', '^'].includes(target.operator)) { - const collectionType = target.left.$resolvedType; - if (collectionType?.decl) { - node.$resolvedType = { - decl: collectionType.decl, - array: false, - nullable: collectionType.nullable, - }; + if (isCollectionPredicateBinding(target)) { + // collection predicate's binding is resolved to the element type of the collection + const collectionType = target.$container.left.$resolvedType; + if (collectionType) { + node.$resolvedType = { ...collectionType, array: false }; } - } else if (target.$type === EnumField) { + } else if (isEnumField(target)) { + // enum field is resolved to its containing enum this.resolveToBuiltinTypeOrDecl(node, target.$container); - } else if (isDataField(target)) { + } else if (isDataField(target) || isFunctionParam(target)) { + // other references are resolved to their declared type this.resolveToDeclaredType(node, target.type); - } else if (target.$type === FunctionParam && (target as FunctionParam).type) { - this.resolveToDeclaredType(node, (target as FunctionParam).type); } } } diff --git a/packages/language/src/zmodel-scope.ts b/packages/language/src/zmodel-scope.ts index 4bd4c830..30b77e29 100644 --- a/packages/language/src/zmodel-scope.ts +++ b/packages/language/src/zmodel-scope.ts @@ -7,7 +7,6 @@ import { StreamScope, UriUtils, interruptAndCheck, - stream, type AstNode, type AstNodeDescription, type LangiumCoreServices, @@ -21,7 +20,7 @@ import { BinaryExpr, Expression, MemberAccessExpr, - isBinaryExpr, + isCollectionPredicateBinding, isDataField, isDataModel, isEnumField, @@ -127,7 +126,7 @@ export class ZModelScopeProvider extends DefaultScopeProvider { // when reference expression is resolved inside a collection predicate, the scope is the collection const containerCollectionPredicate = getCollectionPredicateContext(context.container); if (containerCollectionPredicate) { - return this.getCollectionPredicateScope(context, containerCollectionPredicate as BinaryExpr); + return this.getCollectionPredicateScope(context, containerCollectionPredicate); } } @@ -148,13 +147,20 @@ export class ZModelScopeProvider extends DefaultScopeProvider { .when(isReferenceExpr, (operand) => { // operand is a reference, it can only be a model/type-def field const ref = operand.target.ref; - if (isBinaryExpr(ref) && isCollectionPredicate(ref)) { - return this.createScopeForCollectionElement(ref.left, globalScope, allowTypeDefScope); - } - if (isDataField(ref)) { - return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope); - } - return EMPTY_SCOPE; + return match(ref) + .when(isDataField, (r) => + // build a scope with model/typedef members + this.createScopeForContainer(r.type.reference?.ref, globalScope, allowTypeDefScope), + ) + .when(isCollectionPredicateBinding, (r) => + // build a scope from the collection predicate's collection + this.createScopeForCollectionPredicateCollection( + r.$container.left, + globalScope, + allowTypeDefScope, + ), + ) + .otherwise(() => EMPTY_SCOPE); }) .when(isMemberAccessExpr, (operand) => { // operand is a member access, it must be resolved to a non-array model/typedef type @@ -185,30 +191,44 @@ export class ZModelScopeProvider extends DefaultScopeProvider { .otherwise(() => EMPTY_SCOPE); } - private getCollectionPredicateScope(context: ReferenceInfo, collectionPredicate: BinaryExpr) { - const referenceType = this.reflection.getReferenceType(context); - const globalScope = this.getGlobalScope(referenceType, context); + private getCollectionPredicateScope(context: ReferenceInfo, collectionPredicate: BinaryExpr): Scope { + // walk up to collect all collection predicate bindings, which are all available in the scope + let currPredicate: BinaryExpr | undefined = collectionPredicate; + const bindingStack: AstNode[] = []; + while (currPredicate) { + if (currPredicate.binding) { + bindingStack.unshift(currPredicate.binding); + } + currPredicate = AstUtils.getContainerOfType(currPredicate.$container, isCollectionPredicate); + } + + // build a scope chain: global scope -> bindings' scope -> collection scope + const globalScope = this.getGlobalScope(this.reflection.getReferenceType(context), context); + const parentScope = bindingStack.reduce( + (scope, binding) => this.createScopeForNodes([binding], scope), + globalScope, + ); + const collection = collectionPredicate.left; // TODO: full support of typedef member access - // // typedef's fields are only added to the scope if the access starts with `auth().` + // typedef's fields are only added to the scope if the access starts with `auth().` const allowTypeDefScope = isAuthOrAuthMemberAccess(collection); - const collectionScope = this.createScopeForCollectionElement(collection, globalScope, allowTypeDefScope); - - if (collectionPredicate.binding) { - const description = this.descriptions.createDescription( - collectionPredicate, - collectionPredicate.binding, - collectionPredicate.$document!, - ); - return new StreamScope(stream([description]), collectionScope); - } + const collectionScope = this.createScopeForCollectionPredicateCollection( + collection, + parentScope, + allowTypeDefScope, + ); return collectionScope; } - private createScopeForCollectionElement(collection: Expression, globalScope: Scope, allowTypeDefScope: boolean) { + private createScopeForCollectionPredicateCollection( + collection: Expression, + globalScope: Scope, + allowTypeDefScope: boolean, + ) { return match(collection) .when(isReferenceExpr, (expr) => { // collection is a reference - model or typedef field diff --git a/packages/language/src/zmodel.langium b/packages/language/src/zmodel.langium index a80c0a1f..017164b9 100644 --- a/packages/language/src/zmodel.langium +++ b/packages/language/src/zmodel.langium @@ -66,7 +66,7 @@ ConfigArrayExpr: ConfigExpr: LiteralExpr | InvocationExpr | ConfigArrayExpr; -type ReferenceTarget = FunctionParam | DataField | EnumField | BinaryExpr; +type ReferenceTarget = FunctionParam | DataField | EnumField | CollectionPredicateBinding; ThisExpr: value='this'; @@ -109,13 +109,17 @@ UnaryExpr: // binary operator precedence follow Javascript's rules: // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/Operator_Precedence#table +// TODO: promote CollectionPredicateExpr to a first-class expression type CollectionPredicateExpr infers Expression: MemberAccessExpr ( {infer BinaryExpr.left=current} operator=('?'|'!'|'^') - '[' (binding=RegularID ',')? right=Expression ']' + '[' (binding=CollectionPredicateBinding ',')? right=Expression ']' )*; +CollectionPredicateBinding: + name=RegularID; + InExpr infers Expression: CollectionPredicateExpr ( {infer BinaryExpr.left=current} diff --git a/packages/language/test/expression-validation.test.ts b/packages/language/test/expression-validation.test.ts index 7976c9e9..40bad771 100644 --- a/packages/language/test/expression-validation.test.ts +++ b/packages/language/test/expression-validation.test.ts @@ -2,9 +2,10 @@ import { describe, it } from 'vitest'; import { loadSchema, loadSchemaWithError } from './utils'; describe('Expression Validation Tests', () => { - it('should reject model comparison1', async () => { - await loadSchemaWithError( - ` + describe('Model Comparison Tests', () => { + it('should reject model comparison1', async () => { + await loadSchemaWithError( + ` model User { id Int @id name String @@ -19,13 +20,13 @@ describe('Expression Validation Tests', () => { @@allow('all', author == this) } `, - 'comparison between models is not supported', - ); - }); + 'comparison between models is not supported', + ); + }); - it('should reject model comparison2', async () => { - await loadSchemaWithError( - ` + it('should reject model comparison2', async () => { + await loadSchemaWithError( + ` model User { id Int @id name String @@ -48,13 +49,13 @@ describe('Expression Validation Tests', () => { userId Int @unique } `, - 'comparison between models is not supported', - ); - }); + 'comparison between models is not supported', + ); + }); - it('should allow auth comparison with auth type', async () => { - await loadSchema( - ` + it('should allow auth comparison with auth type', async () => { + await loadSchema( + ` datasource db { provider = 'sqlite' url = 'file:./dev.db' @@ -75,12 +76,12 @@ describe('Expression Validation Tests', () => { @@allow('read', auth() == user) } `, - ); - }); + ); + }); - it('should reject auth comparison with non-auth type', async () => { - await loadSchemaWithError( - ` + it('should reject auth comparison with non-auth type', async () => { + await loadSchemaWithError( + ` model User { id Int @id name String @@ -95,12 +96,39 @@ describe('Expression Validation Tests', () => { @@allow('read', auth() == this) } `, - 'incompatible operand types', - ); + 'incompatible operand types', + ); + }); }); - it('should allow collection predicate with iterator binding', async () => { - await loadSchema(` + describe('Collection Predicate Tests', () => { + it('should reject standalone binding access', async () => { + await loadSchemaWithError( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + @@allow('read', memberships?[m, m != null]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + } + `, + 'binding cannot be used without a member access', + ); + }); + + it('should allow referencing binding', async () => { + await loadSchema(` datasource db { provider = 'sqlite' url = 'file:./dev.db' @@ -119,10 +147,10 @@ describe('Expression Validation Tests', () => { userId Int } `); - }); + }); - it('should keep supporting unbound collection predicate syntax', async () => { - await loadSchema(` + it('should keep supporting unbound collection predicate syntax', async () => { + await loadSchema(` datasource db { provider = 'sqlite' url = 'file:./dev.db' @@ -141,5 +169,192 @@ describe('Expression Validation Tests', () => { userId Int } `); + }); + + it('should support mixing bound and unbound syntax', async () => { + await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + @@allow('read', memberships?[m, m.tenantId == id && tenantId == id]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + } + `); + }); + + it('should allow disambiguation with this', async () => { + await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + value Int + @@allow('read', memberships?[m, m.value == this.value]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + value Int + } + `); + + await loadSchemaWithError( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + value String + @@allow('read', memberships?[m, m.value == this.value]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + value Int + } + `, + 'incompatible operand types', + ); + + await loadSchemaWithError( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + value String + @@allow('read', memberships?[value == this.value]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + value Int + } + `, + 'incompatible operand types', + ); + }); + + it('should support accessing binding from deep context', async () => { + await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + @@allow('read', memberships?[m, roles?[value == m.value]]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + value Int + roles Role[] + } + + model Role { + id Int @id + membership Membership @relation(fields: [membershipId], references: [id]) + membershipId Int + value Int + } + `); + + await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + @@allow('read', memberships?[m, roles?[r, r.value == m.value]]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + value Int + roles Role[] + } + + model Role { + id Int @id + membership Membership @relation(fields: [membershipId], references: [id]) + membershipId Int + value Int + } + `); + + await loadSchema(` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id Int @id + memberships Membership[] + x Int + @@allow('read', memberships?[m, roles?[this.x == m.value]]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + value Int + roles Role[] + } + + model Role { + id Int @id + membership Membership @relation(fields: [membershipId], references: [id]) + membershipId Int + value Int + } + `); + }); }); }); diff --git a/packages/orm/src/client/crud/validator/utils.ts b/packages/orm/src/client/crud/validator/utils.ts index bbd35900..c9909a70 100644 --- a/packages/orm/src/client/crud/validator/utils.ts +++ b/packages/orm/src/client/crud/validator/utils.ts @@ -310,6 +310,9 @@ function evalExpression(data: any, expr: Expression): unknown { .with({ kind: 'call' }, (e) => evalCall(data, e)) .with({ kind: 'this' }, () => data ?? null) .with({ kind: 'null' }, () => null) + .with({ kind: 'binding' }, () => { + throw new Error('Binding expression is not supported in validation expressions'); + }) .exhaustive(); } diff --git a/packages/orm/src/utils/schema-utils.ts b/packages/orm/src/utils/schema-utils.ts index b928b078..17fb3149 100644 --- a/packages/orm/src/utils/schema-utils.ts +++ b/packages/orm/src/utils/schema-utils.ts @@ -2,6 +2,7 @@ import { match } from 'ts-pattern'; import type { ArrayExpression, BinaryExpression, + BindingExpression, CallExpression, Expression, FieldExpression, @@ -24,6 +25,7 @@ export class ExpressionVisitor { .with({ kind: 'binary' }, (e) => this.visitBinary(e)) .with({ kind: 'unary' }, (e) => this.visitUnary(e)) .with({ kind: 'call' }, (e) => this.visitCall(e)) + .with({ kind: 'binding' }, (e) => this.visitBinding(e)) .with({ kind: 'this' }, (e) => this.visitThis(e)) .with({ kind: 'null' }, (e) => this.visitNull(e)) .exhaustive(); @@ -68,6 +70,8 @@ export class ExpressionVisitor { } } + protected visitBinding(_e: BindingExpression): VisitResult {} + protected visitThis(_e: ThisExpression): VisitResult {} protected visitNull(_e: NullExpression): VisitResult {} diff --git a/packages/plugins/policy/src/expression-evaluator.ts b/packages/plugins/policy/src/expression-evaluator.ts index d1e1ebe0..85e97a03 100644 --- a/packages/plugins/policy/src/expression-evaluator.ts +++ b/packages/plugins/policy/src/expression-evaluator.ts @@ -1,9 +1,9 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { match } from 'ts-pattern'; import { ExpressionUtils, type ArrayExpression, type BinaryExpression, + type BindingExpression, type CallExpression, type Expression, type FieldExpression, @@ -11,11 +11,13 @@ import { type MemberExpression, type UnaryExpression, } from '@zenstackhq/orm/schema'; +import { match } from 'ts-pattern'; type ExpressionEvaluatorContext = { auth?: any; thisValue?: any; - scope?: Record; + // scope for resolving references to collection predicate bindings + bindingScope?: Record; }; /** @@ -31,6 +33,7 @@ export class ExpressionEvaluator { .when(ExpressionUtils.isMember, (expr) => this.evaluateMember(expr, context)) .when(ExpressionUtils.isUnary, (expr) => this.evaluateUnary(expr, context)) .when(ExpressionUtils.isCall, (expr) => this.evaluateCall(expr, context)) + .when(ExpressionUtils.isBinding, (expr) => this.evaluateBinding(expr, context)) .when(ExpressionUtils.isThis, () => context.thisValue) .when(ExpressionUtils.isNull, () => null) .exhaustive(); @@ -65,8 +68,8 @@ export class ExpressionEvaluator { } private evaluateField(expr: FieldExpression, context: ExpressionEvaluatorContext): any { - if (context.scope && expr.field in context.scope) { - return context.scope[expr.field]; + if (context.bindingScope && expr.field in context.bindingScope) { + return context.bindingScope[expr.field]; } return context.thisValue?.[expr.field]; } @@ -122,9 +125,9 @@ export class ExpressionEvaluator { this.evaluate(expr.right, { ...context, thisValue: item, - scope: expr.binding - ? { ...(context.scope ?? {}), [expr.binding]: item } - : context.scope, + bindingScope: expr.binding + ? { ...(context.bindingScope ?? {}), [expr.binding]: item } + : context.bindingScope, }), ), ) @@ -133,9 +136,9 @@ export class ExpressionEvaluator { this.evaluate(expr.right, { ...context, thisValue: item, - scope: expr.binding - ? { ...(context.scope ?? {}), [expr.binding]: item } - : context.scope, + bindingScope: expr.binding + ? { ...(context.bindingScope ?? {}), [expr.binding]: item } + : context.bindingScope, }), ), ) @@ -146,12 +149,19 @@ export class ExpressionEvaluator { this.evaluate(expr.right, { ...context, thisValue: item, - scope: expr.binding - ? { ...(context.scope ?? {}), [expr.binding]: item } - : context.scope, + bindingScope: expr.binding + ? { ...(context.bindingScope ?? {}), [expr.binding]: item } + : context.bindingScope, }), ), ) .exhaustive(); } + + private evaluateBinding(expr: BindingExpression, context: ExpressionEvaluatorContext): any { + if (!context.bindingScope || !(expr.name in context.bindingScope)) { + throw new Error(`Unresolved binding: ${expr.name}`); + } + return context.bindingScope[expr.name]; + } } diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index ce3f6d0e..4f0f8fd7 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -10,6 +10,7 @@ import { import type { BinaryExpression, BinaryOperator, + BindingExpression, BuiltinType, FieldDef, GetModels, @@ -58,7 +59,7 @@ import { trueNode, } from './utils'; -type BindingScope = Record; +type BindingScope = Record; /** * Context for transforming a policy expression @@ -95,9 +96,9 @@ export type ExpressionTransformerContext = { contextValue?: Record; /** - * Additional named bindings available during transformation + * Additional named collection predicate bindings available during transformation */ - scope?: BindingScope; + bindingScope?: BindingScope; /** * The model or type name that `this` keyword refers to @@ -320,7 +321,7 @@ export class ExpressionTransformer { const receiver = evaluator.evaluate(expr.left, { thisValue: context.contextValue, auth: this.auth, - scope: this.getEvaluationScope(context.scope), + bindingScope: this.getEvaluationBindingScope(context.bindingScope), }); // get LHS's type @@ -345,11 +346,20 @@ export class ExpressionTransformer { newContextModel = fieldDef.type; } else { invariant( - ExpressionUtils.isMember(expr.left) && ExpressionUtils.isField(expr.left.receiver), + ExpressionUtils.isMember(expr.left) && + (ExpressionUtils.isField(expr.left.receiver) || ExpressionUtils.isBinding(expr.left.receiver)), 'left operand must be member access with field receiver', ); - const fieldDef = QueryUtils.requireField(this.schema, context.modelOrType, expr.left.receiver.field); - newContextModel = fieldDef.type; + if (ExpressionUtils.isField(expr.left.receiver)) { + // collection is a field access, context model is the field's type + const fieldDef = QueryUtils.requireField(this.schema, context.modelOrType, expr.left.receiver.field); + newContextModel = fieldDef.type; + } else { + // collection is a binding reference, get type from binding scope + const binding = this.requireBindingScope(expr.left.receiver, context); + newContextModel = binding.type; + } + for (const member of expr.left.members) { const memberDef = QueryUtils.requireField(this.schema, newContextModel, member); newContextModel = memberDef.type; @@ -358,16 +368,16 @@ export class ExpressionTransformer { const bindingScope = expr.binding ? { - ...(context.scope ?? {}), + ...(context.bindingScope ?? {}), [expr.binding]: { type: newContextModel, alias: context.alias ?? newContextModel }, } - : context.scope; + : context.bindingScope; let predicateFilter = this.transform(expr.right, { ...context, modelOrType: newContextModel, alias: undefined, - scope: bindingScope, + bindingScope: bindingScope, }); if (expr.op === '!') { @@ -410,7 +420,7 @@ export class ExpressionTransformer { const value = new ExpressionEvaluator().evaluate(expr, { auth: this.auth, thisValue: context.contextValue, - scope: this.getEvaluationScope(context.scope), + bindingScope: this.getEvaluationBindingScope(context.bindingScope), }); return this.transformValue(value, 'Boolean'); } else { @@ -424,8 +434,15 @@ export class ExpressionTransformer { // the transformation happens recursively for nested collection predicates const components = receiver.map((item) => { const bindingScope = expr.binding - ? { ...(context.scope ?? {}), [expr.binding]: { type: context.modelOrType, value: item } } - : context.scope; + ? { + ...(context.bindingScope ?? {}), + [expr.binding]: { + type: context.modelOrType, + alias: context.thisAlias ?? context.modelOrType, + value: item, + }, + } + : context.bindingScope; return this.transform(expr.right, { operation: context.operation, @@ -433,7 +450,7 @@ export class ExpressionTransformer { thisAlias: context.thisAlias, modelOrType: context.modelOrType, contextValue: item, - scope: bindingScope, + bindingScope: bindingScope, }); }); @@ -625,23 +642,12 @@ export class ExpressionTransformer { @expr('member') // @ts-ignore private _member(expr: MemberExpression, context: ExpressionTransformerContext) { - const bindingReceiver = - ExpressionUtils.isField(expr.receiver) && context.scope ? context.scope[expr.receiver.field] : undefined; - - if (bindingReceiver) { - if (bindingReceiver.value !== undefined) { - return this.valueMemberAccess(bindingReceiver.value, expr, bindingReceiver.type); + if (ExpressionUtils.isBinding(expr.receiver)) { + // if the binding has a plain value in the scope, evaluate directly + const scope = this.requireBindingScope(expr.receiver, context); + if (scope.value !== undefined) { + return this.valueMemberAccess(scope.value, expr, scope.type); } - - const rewritten = ExpressionUtils.member(ExpressionUtils._this(), expr.members); - return this._member(rewritten, { - ...context, - modelOrType: bindingReceiver.type, - alias: bindingReceiver.alias ?? bindingReceiver.type, - thisType: bindingReceiver.type, - thisAlias: bindingReceiver.alias ?? bindingReceiver.type, - contextValue: bindingReceiver.value, - }); } // `auth()` member access @@ -659,12 +665,15 @@ export class ExpressionTransformer { } invariant( - ExpressionUtils.isField(expr.receiver) || ExpressionUtils.isThis(expr.receiver), - 'expect receiver to be field expression or "this"', + ExpressionUtils.isField(expr.receiver) || + ExpressionUtils.isThis(expr.receiver) || + ExpressionUtils.isBinding(expr.receiver), + 'expect receiver to be field expression, collection predicate binding, or "this"', ); let members = expr.members; let receiver: OperationNode; + let startType: string | undefined; const { memberFilter, memberSelect, ...restContext } = context; if (ExpressionUtils.isThis(expr.receiver)) { @@ -682,6 +691,32 @@ export class ExpressionTransformer { const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.thisType, expr.members[0]!); receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext); members = expr.members.slice(1); + // startType should be the type of the relation access + startType = firstMemberFieldDef.type; + } + } else if (ExpressionUtils.isBinding(expr.receiver)) { + if (expr.members.length === 1) { + const bindingScope = this.requireBindingScope(expr.receiver, context); + // `binding.relation` case, equivalent to field access + return this._field(ExpressionUtils.field(expr.members[0]!), { + ...context, + modelOrType: bindingScope.type, + alias: bindingScope.alias, + thisType: context.thisType, + contextValue: undefined, + }); + } else { + // transform the first segment into a relation access, then continue with the rest of the members + const bindingScope = this.requireBindingScope(expr.receiver, context); + const firstMemberFieldDef = QueryUtils.requireField(this.schema, bindingScope.type, expr.members[0]!); + receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, { + ...restContext, + modelOrType: bindingScope.type, + alias: bindingScope.alias, + }); + members = expr.members.slice(1); + // startType should be the type of the relation access + startType = firstMemberFieldDef.type; } } else { receiver = this.transform(expr.receiver, restContext); @@ -689,13 +724,14 @@ export class ExpressionTransformer { invariant(SelectQueryNode.is(receiver), 'expected receiver to be select query'); - let startType: string; - if (ExpressionUtils.isField(expr.receiver)) { - const receiverField = QueryUtils.requireField(this.schema, context.modelOrType, expr.receiver.field); - startType = receiverField.type; - } else { - // "this." case - startType = context.thisType; + if (startType === undefined) { + if (ExpressionUtils.isField(expr.receiver)) { + const receiverField = QueryUtils.requireField(this.schema, context.modelOrType, expr.receiver.field); + startType = receiverField.type; + } else { + // "this." case - already handled above if members were sliced + startType = context.thisType; + } } // traverse forward to collect member types @@ -749,6 +785,12 @@ export class ExpressionTransformer { }; } + private requireBindingScope(expr: BindingExpression, context: ExpressionTransformerContext) { + const binding = context.bindingScope?.[expr.name]; + invariant(binding, `binding not found: ${expr.name}`); + return binding; + } + private valueMemberAccess(receiver: any, expr: MemberExpression, receiverType: string) { if (!receiver) { return ValueNode.createImmediate(null); @@ -877,7 +919,8 @@ export class ExpressionTransformer { return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel); } - private getEvaluationScope(scope?: BindingScope) { + // convert transformer's binding scope to equivalent expression evaluator binding scope + private getEvaluationBindingScope(scope?: BindingScope) { if (!scope) { return undefined; } diff --git a/packages/schema/src/expression-utils.ts b/packages/schema/src/expression-utils.ts index f7bd526d..07ee1c11 100644 --- a/packages/schema/src/expression-utils.ts +++ b/packages/schema/src/expression-utils.ts @@ -2,6 +2,7 @@ import type { ArrayExpression, BinaryExpression, BinaryOperator, + BindingExpression, CallExpression, Expression, FieldExpression, @@ -72,6 +73,13 @@ export const ExpressionUtils = { }; }, + binding: (name: string): BindingExpression => { + return { + kind: 'binding', + name, + }; + }, + _this: (): ThisExpression => { return { kind: 'this', @@ -118,6 +126,8 @@ export const ExpressionUtils = { isMember: (value: unknown): value is MemberExpression => ExpressionUtils.is(value, 'member'), + isBinding: (value: unknown): value is BindingExpression => ExpressionUtils.is(value, 'binding'), + getLiteralValue: (expr: Expression): string | number | boolean | undefined => { return ExpressionUtils.isLiteral(expr) ? expr.value : undefined; }, diff --git a/packages/schema/src/expression.ts b/packages/schema/src/expression.ts index b3bb9c40..1828b9cc 100644 --- a/packages/schema/src/expression.ts +++ b/packages/schema/src/expression.ts @@ -6,6 +6,7 @@ export type Expression = | CallExpression | UnaryExpression | BinaryExpression + | BindingExpression | ThisExpression | NullExpression; @@ -30,6 +31,11 @@ export type MemberExpression = { members: string[]; }; +export type BindingExpression = { + kind: 'binding'; + name: string; +}; + export type UnaryExpression = { kind: 'unary'; op: UnaryOperator; diff --git a/packages/sdk/src/prisma/prisma-schema-generator.ts b/packages/sdk/src/prisma/prisma-schema-generator.ts index 78a132c8..cbc3fc56 100644 --- a/packages/sdk/src/prisma/prisma-schema-generator.ts +++ b/packages/sdk/src/prisma/prisma-schema-generator.ts @@ -353,20 +353,10 @@ export class PrismaSchemaGenerator { new Array(...node.items.map((item) => this.makeAttributeArgValue(item))), ); } else if (isReferenceExpr(node)) { - const ref = node.target.ref!; - const refName = - ('name' in ref && typeof (ref as { name?: unknown }).name === 'string') - ? (ref as { name: string }).name - : isBinaryExpr(ref) && typeof ref.binding === 'string' - ? ref.binding - : undefined; - if (!refName) { - throw Error(`Unsupported reference expression target: ${ref.$type}`); - } return new PrismaAttributeArgValue( 'FieldReference', new PrismaFieldReference( - refName, + node.target.ref!.name, node.args.map((arg) => new PrismaFieldReferenceArg(arg.name, this.exprToText(arg.value))), ), ); diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index 325926ac..ab592085 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -13,6 +13,7 @@ import { InvocationExpr, isArrayExpr, isBinaryExpr, + isCollectionPredicateBinding, isDataField, isDataModel, isDataSource, @@ -1278,7 +1279,7 @@ export class TsSchemaGenerator { ]; if (expr.binding) { - args.push(this.createLiteralNode(expr.binding)); + args.push(this.createLiteralNode(expr.binding.name)); } return this.createExpressionUtilsCall('binary', args); @@ -1299,27 +1300,17 @@ export class TsSchemaGenerator { private createRefExpression(expr: ReferenceExpr): any { const target = expr.target.ref; - if (isDataField(target)) { - return this.createExpressionUtilsCall('field', [this.createLiteralNode(expr.target.$refText)]); - } - - if (isEnumField(target)) { - return this.createLiteralExpression('StringLiteral', expr.target.$refText); - } - - const refName = - target && 'name' in target && typeof (target as { name?: unknown }).name === 'string' - ? (target as { name: string }).name - : isBinaryExpr(target) && typeof target.binding === 'string' - ? target.binding - : undefined; - - if (refName) { - return this.createExpressionUtilsCall('field', [this.createLiteralNode(refName)]); - } - - // Fallback: treat unknown reference targets (e.g. unresolved iterator bindings) as named fields - return this.createExpressionUtilsCall('field', [this.createLiteralNode(expr.target.$refText)]); + return match(target) + .when(isDataField, () => + this.createExpressionUtilsCall('field', [this.createLiteralNode(expr.target.$refText)]), + ) + .when(isEnumField, () => this.createLiteralExpression('StringLiteral', expr.target.$refText)) + .when(isCollectionPredicateBinding, () => + this.createExpressionUtilsCall('binding', [this.createLiteralNode(expr.target.$refText)]), + ) + .otherwise(() => { + throw Error(`Unsupported reference type: ${expr.target.$refText}`); + }); } private createCallExpression(expr: InvocationExpr) { diff --git a/tests/e2e/orm/policy/auth-access.test.ts b/tests/e2e/orm/policy/auth-access.test.ts index 76e0c9f3..56942de4 100644 --- a/tests/e2e/orm/policy/auth-access.test.ts +++ b/tests/e2e/orm/policy/auth-access.test.ts @@ -149,7 +149,7 @@ model Membership { model Foo { id Int @id tenantId Int - @@allow('read', auth().memberships?[m, m.tenantId == this.tenantId]) + @@allow('read', auth().memberships?[m, m.tenantId == auth().tenantId]) } `, ); @@ -164,14 +164,12 @@ model Foo { // allowed because iterator binding matches tenantId = 1 await expect( db.$setAuth({ tenantId: 1, memberships: [{ id: 10, tenantId: 1 }] }).foo.findMany(), - ).resolves.toEqual([ - { id: 1, tenantId: 1 }, - ]); + ).toResolveWithLength(2); // denied because membership tenantId doesn't match await expect( db.$setAuth({ tenantId: 1, memberships: [{ id: 20, tenantId: 3 }] }).foo.findMany(), - ).resolves.toEqual([]); + ).toResolveWithLength(0); }); it('works with shallow auth model collection predicates involving fields - some', async () => { diff --git a/tests/e2e/orm/policy/collection-predicate.test.ts b/tests/e2e/orm/policy/collection-predicate.test.ts new file mode 100644 index 00000000..09e22228 --- /dev/null +++ b/tests/e2e/orm/policy/collection-predicate.test.ts @@ -0,0 +1,447 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Collection Predicate Tests', () => { + it('should support collection predicates without binding', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + memberships Membership[] + @@allow('create', true) + @@allow('read', memberships?[tenantId == id]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + @@allow('all', true) + } +`, + ); + await db.$unuseAll().user.create({ + data: { id: 1, memberships: { create: [{ id: 1, tenantId: 1 }] } }, + }); + await db.$unuseAll().user.create({ + data: { id: 2, memberships: { create: [{ id: 2, tenantId: 1 }] } }, + }); + await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.user.findUnique({ where: { id: 2 } })).toResolveNull(); + }); + + it('should support referencing binding', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + memberships Membership[] + @@allow('create', true) + @@allow('read', memberships?[m, m.tenantId == id]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + @@allow('all', true) + } +`, + ); + await db.$unuseAll().user.create({ + data: { id: 1, memberships: { create: [{ id: 1, tenantId: 1 }] } }, + }); + await db.$unuseAll().user.create({ + data: { id: 2, memberships: { create: [{ id: 2, tenantId: 1 }] } }, + }); + await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.user.findUnique({ where: { id: 2 } })).toResolveNull(); + }); + + it('should support mixing bound and unbound syntax', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + memberships Membership[] + @@allow('create', true) + @@allow('read', memberships?[m, m.tenantId == id && tenantId == id]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + @@allow('all', true) + } +`, + ); + await db.$unuseAll().user.create({ + data: { id: 1, memberships: { create: [{ id: 1, tenantId: 1 }] } }, + }); + await db.$unuseAll().user.create({ + data: { id: 2, memberships: { create: [{ id: 2, tenantId: 1 }] } }, + }); + await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.user.findUnique({ where: { id: 2 } })).toResolveNull(); + }); + + it('should allow disambiguation with this', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + memberships Membership[] + tenantId Int + @@allow('create', true) + @@allow('read', memberships?[m, m.tenantId == this.tenantId]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + @@allow('all', true) + } +`, + ); + await db.$unuseAll().user.create({ + data: { id: 1, tenantId: 1, memberships: { create: [{ id: 1, tenantId: 1 }] } }, + }); + await db.$unuseAll().user.create({ + data: { id: 2, tenantId: 2, memberships: { create: [{ id: 2, tenantId: 1 }] } }, + }); + await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.user.findUnique({ where: { id: 2 } })).toResolveNull(); + }); + + it('should support accessing binding from deep context - case 1', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + memberships Membership[] + @@allow('create', true) + @@allow('read', memberships?[m, roles?[tenantId == m.tenantId]]) + } + + model Membership { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + tenantId Int + roles Role[] + @@allow('all', true) + } + + model Role { + id Int @id + membership Membership @relation(fields: [membershipId], references: [id]) + membershipId Int + tenantId Int + @@allow('all', true) + } +`, + ); + await db.$unuseAll().user.create({ + data: { + id: 1, + memberships: { create: [{ id: 1, tenantId: 1, roles: { create: { id: 1, tenantId: 1 } } }] }, + }, + }); + await db.$unuseAll().user.create({ + data: { + id: 2, + memberships: { create: [{ id: 2, tenantId: 2, roles: { create: { id: 2, tenantId: 1 } } }] }, + }, + }); + await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.user.findUnique({ where: { id: 2 } })).toResolveNull(); + }); + + it('should support accessing binding from deep context - case 2', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + memberships Membership[] + tenantId Int + @@allow('create', true) + @@allow('read', memberships?[m, roles?[this.tenantId == m.tenantId]]) + } + + model Membership { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + tenantId Int + roles Role[] + @@allow('all', true) + } + + model Role { + id Int @id + membership Membership @relation(fields: [membershipId], references: [id]) + membershipId Int + @@allow('all', true) + } +`, + ); + await db.$unuseAll().user.create({ + data: { + id: 1, + tenantId: 1, + memberships: { create: [{ id: 1, tenantId: 1, roles: { create: { id: 1 } } }] }, + }, + }); + await db.$unuseAll().user.create({ + data: { + id: 2, + tenantId: 2, + memberships: { create: [{ id: 2, tenantId: 1, roles: { create: { id: 2 } } }] }, + }, + }); + await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.user.findUnique({ where: { id: 2 } })).toResolveNull(); + }); + + it('should support accessing to-one relation from binding', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + memberships Membership[] + tenants Tenant[] + @@allow('create', true) + @@allow('read', memberships?[m, m.tenant.ownerId == id]) + } + + model Tenant { + id Int @id + ownerId Int + owner User @relation(fields: [ownerId], references: [id]) + memberships Membership[] + @@allow('all', true) + } + + model Membership { + id Int @id + tenant Tenant @relation(fields: [tenantId], references: [id]) + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + @@allow('all', true) + } +`, + ); + await db.$unuseAll().user.create({ + data: { + id: 1, + memberships: { + create: [{ id: 1, tenant: { create: { id: 1, ownerId: 1 } } }], + }, + }, + }); + await db.$unuseAll().user.create({ + data: { + id: 2, + memberships: { + create: [{ id: 2, tenant: { create: { id: 2, ownerId: 1 } } }], + }, + }, + }); + await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.user.findUnique({ where: { id: 2 } })).toResolveNull(); + }); + + it('should support multiple bindings in nested predicates', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + memberships Membership[] + @@allow('create', true) + @@allow('read', memberships?[m, m.roles?[r, r.tenantId == m.tenantId]]) + } + + model Membership { + id Int @id + tenantId Int + user User @relation(fields: [userId], references: [id]) + userId Int + roles Role[] + @@allow('all', true) + } + + model Role { + id Int @id + tenantId Int + membership Membership @relation(fields: [membershipId], references: [id]) + membershipId Int + @@allow('all', true) + } +`, + ); + await db.$unuseAll().user.create({ + data: { + id: 1, + memberships: { + create: [{ id: 1, tenantId: 1, roles: { create: { id: 1, tenantId: 1 } } }], + }, + }, + }); + await db.$unuseAll().user.create({ + data: { + id: 2, + memberships: { + create: [{ id: 2, tenantId: 2, roles: { create: { id: 2, tenantId: 1 } } }], + }, + }, + }); + await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.user.findUnique({ where: { id: 2 } })).toResolveNull(); + }); + + it('should work with inner binding masking outer binding names', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + memberships Membership[] + tenantId Int + @@allow('create', true) + @@allow('read', memberships?[m, m.roles?[m, m.tenantId == this.tenantId]]) + } + + model Membership { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + roles Role[] + @@allow('all', true) + } + + model Role { + id Int @id + tenantId Int + membership Membership @relation(fields: [membershipId], references: [id]) + membershipId Int + @@allow('all', true) + } +`, + ); + await db.$unuseAll().user.create({ + data: { + id: 1, + tenantId: 1, + memberships: { create: [{ id: 1, roles: { create: { id: 1, tenantId: 1 } } }] }, + }, + }); + await db.$unuseAll().user.create({ + data: { + id: 2, + tenantId: 2, + memberships: { create: [{ id: 2, roles: { create: { id: 2, tenantId: 1 } } }] }, + }, + }); + await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.user.findUnique({ where: { id: 2 } })).toResolveNull(); + }); + + it('should work with bindings with auth collection predicates', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + companies Company[] + test Int + + @@allow('read', auth().companies?[c, c.staff?[s, s.companyId == this.test]]) + } + + model Company { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + + staff Staff[] + @@allow('read', true) + } + + model Staff { + id Int @id + + company Company @relation(fields: [companyId], references: [id]) + companyId Int + + @@allow('read', true) + } + `, + ); + await db.$unuseAll().user.create({ + data: { + id: 1, + test: 1, + companies: { create: { id: 1, staff: { create: { id: 1 } } } }, + }, + }); + + await expect( + db + .$setAuth({ id: 1, companies: [{ id: 1, staff: [{ id: 1, companyId: 1 }] }], test: 1 }) + .user.findUnique({ where: { id: 1 } }), + ).toResolveTruthy(); + }); + + it('should work with bindings with auth collection predicates - pure value', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + companies Company[] + + @@allow('read', auth().companies?[c, c.staff?[s, s.companyId == c.id]]) + } + + model Company { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + + staff Staff[] + @@allow('read', true) + } + + model Staff { + id Int @id + + company Company @relation(fields: [companyId], references: [id]) + companyId Int + + @@allow('read', true) + } + `, + ); + await db.$unuseAll().user.create({ + data: { + id: 1, + companies: { create: { id: 1, staff: { create: { id: 1 } } } }, + }, + }); + + await expect( + db + .$setAuth({ id: 1, companies: [{ id: 1, staff: [{ id: 1, companyId: 1 }] }] }) + .user.findUnique({ where: { id: 1 } }), + ).toResolveTruthy(); + await expect( + db + .$setAuth({ id: 1, companies: [{ id: 1, staff: [{ id: 1, companyId: 2 }] }] }) + .user.findUnique({ where: { id: 1 } }), + ).toResolveNull(); + }); +}); From 81009fb4d8acdb1c2707300b527706ce474e590a Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 18 Jan 2026 13:59:46 +0800 Subject: [PATCH 4/4] addressing PR comments --- packages/plugins/policy/src/expression-transformer.ts | 2 +- packages/sdk/src/prisma/prisma-schema-generator.ts | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index 912b7409..7977ccb2 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -370,7 +370,7 @@ export class ExpressionTransformer { const bindingScope = expr.binding ? { ...(context.bindingScope ?? {}), - [expr.binding]: { type: newContextModel, alias: context.alias ?? newContextModel }, + [expr.binding]: { type: newContextModel, alias: newContextModel }, } : context.bindingScope; diff --git a/packages/sdk/src/prisma/prisma-schema-generator.ts b/packages/sdk/src/prisma/prisma-schema-generator.ts index cbc3fc56..553658ad 100644 --- a/packages/sdk/src/prisma/prisma-schema-generator.ts +++ b/packages/sdk/src/prisma/prisma-schema-generator.ts @@ -15,7 +15,6 @@ import { Enum, EnumField, Expression, - isBinaryExpr, GeneratorDecl, InvocationExpr, isArrayExpr,