diff --git a/src/control-flow/dfg-cfg-guided-visitor.ts b/src/control-flow/dfg-cfg-guided-visitor.ts index 47bd315bfc..eb33574b4a 100644 --- a/src/control-flow/dfg-cfg-guided-visitor.ts +++ b/src/control-flow/dfg-cfg-guided-visitor.ts @@ -51,6 +51,7 @@ export class DataflowAwareCfgGuidedVisitor< private onExprOrStmtNode(node: CfgStatementVertex | CfgExpressionVertex): void { const dfgVertex = this.getDataflowGraph(node.id); if(!dfgVertex) { + this.visitUnknown(node); return; } @@ -76,6 +77,12 @@ export class DataflowAwareCfgGuidedVisitor< } } + /** + * called for every cfg vertex that has no corresponding dataflow vertex. + */ + protected visitUnknown(_vertex: CfgStatementVertex | CfgExpressionVertex): void { + } + protected visitValue(_val: DataflowGraphVertexValue): void { } diff --git a/src/control-flow/semantic-cfg-guided-visitor.ts b/src/control-flow/semantic-cfg-guided-visitor.ts index ef2986fae1..dd851f872d 100644 --- a/src/control-flow/semantic-cfg-guided-visitor.ts +++ b/src/control-flow/semantic-cfg-guided-visitor.ts @@ -1,4 +1,4 @@ -import type { ControlFlowInformation } from './control-flow-graph'; +import type { CfgExpressionVertex, CfgStatementVertex, ControlFlowInformation } from './control-flow-graph'; import type { DataflowInformation } from '../dataflow/info'; @@ -28,6 +28,7 @@ import type { FunctionArgument } from '../dataflow/graph/graph'; import { edgeIncludesType, EdgeType } from '../dataflow/graph/edge'; import { guard } from '../util/assert'; import type { NoInfo, RNode } from '../r-bridge/lang-4.x/ast/model/model'; +import type { RExpressionList } from '../r-bridge/lang-4.x/ast/model/nodes/r-expression-list'; @@ -115,6 +116,14 @@ export class SemanticCfgGuidedVisitor< } } + protected override visitUnknown(vertex: CfgStatementVertex | CfgExpressionVertex) { + super.visitUnknown(vertex); + const ast = this.getNormalizedAst(vertex.id); + if(ast && ast.type === RType.ExpressionList && ast.info.parent === undefined) { + this.onProgram(ast); + } + } + protected onDispatchFunctionCallOrigins(call: DataflowGraphVertexFunctionCall, origins: readonly string[]) { for(const origin of origins) { this.onDispatchFunctionCallOrigin(call, origin); @@ -182,6 +191,9 @@ export class SemanticCfgGuidedVisitor< } } + protected onProgram(_data: RExpressionList) { + } + /** * Requests the {@link getOriginInDfg|origins} of the given node. diff --git a/src/r-bridge/lang-4.x/ast/model/model.ts b/src/r-bridge/lang-4.x/ast/model/model.ts index 78e08db983..d9111b893a 100644 --- a/src/r-bridge/lang-4.x/ast/model/model.ts +++ b/src/r-bridge/lang-4.x/ast/model/model.ts @@ -31,7 +31,7 @@ export type NoInfo = object; * Will be used to reconstruct the source of the given element in the R-ast. * This will not be part of most comparisons as it is mainly of interest to the reconstruction of R code. */ -interface Source { +export interface Source { /** * The range is different from the assigned {@link Location} as it refers to the complete source range covered by the given * element. diff --git a/src/r-bridge/lang-4.x/ast/model/processing/decorate.ts b/src/r-bridge/lang-4.x/ast/model/processing/decorate.ts index 26ac2dbcb2..b15e5da13c 100644 --- a/src/r-bridge/lang-4.x/ast/model/processing/decorate.ts +++ b/src/r-bridge/lang-4.x/ast/model/processing/decorate.ts @@ -9,7 +9,7 @@ * @module */ -import type { NoInfo, RNode } from '../model'; +import type { NoInfo, RNode, Source } from '../model'; import { guard } from '../../../../../util/assert'; import type { SourceRange } from '../../../../../util/range'; import { BiMap } from '../../../../../util/collections/bimap'; @@ -463,3 +463,73 @@ function createFoldForFunctionArgument(info: FoldInfo) { return decorated; }; } + + +export function mapAstInfo(ast: RNode, down: Down, infoMapper: (node: RNode, down: Down) => NewInfo, downUpdater: (node: RNode, down: Down) => Down = (_node, down) => down): RNode { + const fullInfoMapper = (node: RNode, down: Down): NewInfo & Source => { + const sourceInfo = { + ...(node.info.fullRange !== undefined ? { fullRange: node.info.fullRange } : {}), + ...(node.info.fullLexeme !== undefined ? { fullLexeme: node.info.fullLexeme } : {}), + ...(node.info.additionalTokens !== undefined ? { additionalTokens: node.info.additionalTokens } : {}), + ...(node.info.file !== undefined ? { file: node.info.file } : {}) + }; + const mappedInfo = infoMapper(node, down); + return { ...sourceInfo, ...mappedInfo }; + }; + + function updateInfo(n: RNode, down: Down): RNode { + (n.info as NewInfo) = fullInfoMapper(n, down); + return n as unknown as RNode; + } + + return foldAstStateful(ast, down, { + down: downUpdater, + foldNumber: updateInfo, + foldString: updateInfo, + foldLogical: updateInfo, + foldSymbol: updateInfo, + foldAccess: (node, _name, _access, down) => updateInfo(node, down), + foldBinaryOp: (op, _lhs, _rhs, down) => updateInfo(op, down), + foldPipe: (op, _lhs, _rhs, down) => updateInfo(op, down), + foldUnaryOp: (op, _operand, down) => updateInfo(op, down), + loop: { + foldFor: (loop, _variable, _vector, _body, down) => updateInfo(loop, down), + foldWhile: (loop, _condition, _body, down) => updateInfo(loop, down), + foldRepeat: (loop, _body, down) => updateInfo(loop, down), + foldNext: (next, down) => updateInfo(next, down), + foldBreak: (next, down) => updateInfo(next, down), + }, + other: { + foldComment: (comment, down) => updateInfo(comment, down), + foldLineDirective: (comment, down) => updateInfo(comment, down), + }, + foldIfThenElse: (ifThenExpr, _condition, _then, _otherwise, down ) => + updateInfo(ifThenExpr, down), + foldExprList: (exprList, _grouping, _expressions, down) => updateInfo(exprList, down), + functions: { + foldFunctionDefinition: (definition, _parameters, _body, down) => updateInfo(definition, down), + /** folds named and unnamed function calls */ + foldFunctionCall: (call, _functionNameOrExpression, _args, down) => updateInfo(call, down), + /** The `name` is `undefined` if the argument is unnamed, the value, if we have something like `x=,...` */ + foldArgument: (argument, _name, _value, down) => updateInfo(argument, down), + /** The `defaultValue` is `undefined` if the argument was not initialized with a default value */ + foldParameter: (parameter, _name, _defaultValue, down) => updateInfo(parameter, down) + } + }); +} + +export function mapNormalizedAstInfo(normalizedAst: NormalizedAst, down: Down, infoMapper: (node: RNode, down: Down) => NewInfo, downUpdater: (node: RNode, down: Down) => Down = (_node, down) => down): NormalizedAst { + const parentInfoPreservingMapper = (node: RNode, down: Down): NewInfo & ParentInformation => { + const parentInfo = { + id: node.info.id, + parent: node.info.parent, + role: node.info.role, + nesting: node.info.nesting, + index: node.info.index + }; + const mappedInfo = infoMapper(node, down); + return { ...parentInfo, ...mappedInfo }; + }; + mapAstInfo(normalizedAst.ast, down, parentInfoPreservingMapper, downUpdater); + return normalizedAst as unknown as NormalizedAst; +} \ No newline at end of file diff --git a/src/typing/infer.ts b/src/typing/infer.ts new file mode 100644 index 0000000000..a1770d2472 --- /dev/null +++ b/src/typing/infer.ts @@ -0,0 +1,171 @@ +import { extractCFG } from '../control-flow/extract-cfg'; +import { SemanticCfgGuidedVisitor } from '../control-flow/semantic-cfg-guided-visitor'; +import type { DataflowGraphVertexFunctionCall, DataflowGraphVertexFunctionDefinition, DataflowGraphVertexUse, DataflowGraphVertexValue } from '../dataflow/graph/vertex'; +import type { DataflowInformation } from '../dataflow/info'; +import type { RLogical } from '../r-bridge/lang-4.x/ast/model/nodes/r-logical'; +import type { RNumber } from '../r-bridge/lang-4.x/ast/model/nodes/r-number'; +import type { RString } from '../r-bridge/lang-4.x/ast/model/nodes/r-string'; +import type { NormalizedAst } from '../r-bridge/lang-4.x/ast/model/processing/decorate'; +import { mapNormalizedAstInfo } from '../r-bridge/lang-4.x/ast/model/processing/decorate'; +import type { RDataType } from './types'; +import { RTypeVariable , RComplexType, RDoubleType, RIntegerType, RLogicalType, RStringType, resolveType, RNullType, RFunctionType } from './types'; +import type { RExpressionList } from '../r-bridge/lang-4.x/ast/model/nodes/r-expression-list'; +import { guard } from '../util/assert'; +import { OriginType } from '../dataflow/origin/dfg-get-origin'; +import type { NodeId } from '../r-bridge/lang-4.x/ast/model/processing/node-id'; +import { edgeIncludesType, EdgeType } from '../dataflow/graph/edge'; + +export function inferDataTypes(ast: NormalizedAst, dataFlowInfo: DataflowInformation): NormalizedAst { + const astWithTypeVars = decorateTypeVariables(ast); + const controlFlowInfo = extractCFG(astWithTypeVars); + const config = { + normalizedAst: astWithTypeVars, + controlFlow: controlFlowInfo, + dataflow: dataFlowInfo, + defaultVisitingOrder: 'forward' as const, + }; + const visitor = new TypeInferingCfgGuidedVisitor(config); + visitor.start(); + + return resolveTypeVariables(astWithTypeVars); +} + +type UnresolvedTypeInfo = { + typeVariable: RTypeVariable; +}; + +export type DataTypeInfo = { + inferredType: RDataType; +} + +function decorateTypeVariables(ast: NormalizedAst): NormalizedAst { + return mapNormalizedAstInfo(ast, {}, (node, _down) => ({ ...node.info, typeVariable: new RTypeVariable() })); +} + +function resolveTypeVariables(ast: NormalizedAst): NormalizedAst & DataTypeInfo> { + return mapNormalizedAstInfo(ast, {}, (node, _down) => { + const { typeVariable, ...rest } = node.info; + return { ...rest, inferredType: resolveType(typeVariable) }; + }); +} + +class TypeInferingCfgGuidedVisitor extends SemanticCfgGuidedVisitor{ + override onLogicalConstant(_vertex: DataflowGraphVertexValue, node: RLogical): void { + node.info.typeVariable.unify(new RLogicalType()); + } + + override onNumberConstant(_vertex: DataflowGraphVertexValue, node: RNumber): void { + if(node.content.complexNumber) { + node.info.typeVariable.unify(new RComplexType()); + } else if(Number.isInteger(node.content.num)) { + node.info.typeVariable.unify(new RIntegerType()); + } else { + node.info.typeVariable.unify(new RDoubleType()); + } + } + + override onStringConstant(_vertex: DataflowGraphVertexValue, node: RString): void { + node.info.typeVariable.unify(new RStringType()); + } + + override onVariableUse(vertex: DataflowGraphVertexUse): void { + const node = this.getNormalizedAst(vertex.id); + guard(node !== undefined, 'Expected AST node to be defined'); + + const origins = this.getOrigins(vertex.id); + const readOrigins = origins?.filter((origin) => origin.type === OriginType.ReadVariableOrigin); + + if(readOrigins === undefined || readOrigins.length === 0) { + node.info.typeVariable.unify(new RNullType()); + return; + } + + for(const readOrigin of readOrigins) { + const readNode = this.getNormalizedAst(readOrigin.id); + guard(readNode !== undefined, 'Expected read node to be defined'); + node.info.typeVariable.unify(readNode.info.typeVariable); + } + } + + override onAssignmentCall(data: { call: DataflowGraphVertexFunctionCall, target?: NodeId, source?: NodeId }): void { + if(data.target === undefined || data.source === undefined) { + return; // Malformed assignment + } + + const variableNode = this.getNormalizedAst(data.target); + const valueNode = this.getNormalizedAst(data.source); + const assignmentNode = this.getNormalizedAst(data.call.id); + guard(variableNode !== undefined && valueNode !== undefined && assignmentNode !== undefined, 'Expected AST nodes to be defined'); + + variableNode.info.typeVariable.unify(valueNode.info.typeVariable); + assignmentNode.info.typeVariable.unify(variableNode.info.typeVariable); + } + + override onDefaultFunctionCall(data: { call: DataflowGraphVertexFunctionCall }): void { + const node = this.getNormalizedAst(data.call.id); + guard(node !== undefined, 'Expected AST node to be defined'); + + const outgoing = this.config.dataflow.graph.outgoingEdges(data.call.id); + const callsTargets = outgoing?.entries() + .filter(([_target, edge]) => edgeIncludesType(edge.types, EdgeType.Calls)) + .map(([target, _edge]) => target) + .toArray(); + + guard(callsTargets === undefined || callsTargets.length <= 1, 'Expected at most one call edge'); + + if(callsTargets === undefined || callsTargets.length === 0) { + // TODO: Handle builtin functions + return; + } + + const target = this.getNormalizedAst(callsTargets[0]); + guard(target !== undefined, 'Expected target node to be defined'); + + target.info.typeVariable.unify(new RFunctionType()); + } + + override onFunctionDefinition(vertex: DataflowGraphVertexFunctionDefinition): void { + const node = this.getNormalizedAst(vertex.id); + guard(node !== undefined, 'Expected AST node to be defined'); + + node.info.typeVariable.unify(new RFunctionType()); + } + + override onProgram(node: RExpressionList) { + const exitPoints = this.config.dataflow.exitPoints; + const evalCandidates = exitPoints.map((exitPoint) => exitPoint.nodeId); + + if(evalCandidates.length === 0) { + node.info.typeVariable.unify(new RNullType()); + return; + } + + for(const candidateId of evalCandidates) { + const candidate = this.getNormalizedAst(candidateId); + guard(candidate !== undefined, 'Expected target node to be defined'); + node.info.typeVariable.unify(candidate.info.typeVariable); + } + } + + override onExpressionList(data: { call: DataflowGraphVertexFunctionCall }) { + const node = this.getNormalizedAst(data.call.id); + guard(node !== undefined, 'Expected AST node to be defined'); + + const outgoing = this.config.dataflow.graph.outgoingEdges(data.call.id); + const evalCandidates = outgoing?.entries() + .filter(([_target, edge]) => edgeIncludesType(edge.types, EdgeType.Returns)) + .map(([target, _edge]) => target) + .toArray(); + + if(evalCandidates === undefined || evalCandidates.length === 0) { + node.info.typeVariable.unify(new RNullType()); + return; + } + + for(const candidateId of evalCandidates) { + const candidate = this.getNormalizedAst(candidateId); + guard(candidate !== undefined, 'Expected target node to be defined'); + node.info.typeVariable.unify(candidate.info.typeVariable); + } + } +} \ No newline at end of file diff --git a/src/typing/types.ts b/src/typing/types.ts new file mode 100644 index 0000000000..1fe2654b04 --- /dev/null +++ b/src/typing/types.ts @@ -0,0 +1,172 @@ +// import type { NodeId } from '../r-bridge/lang-4.x/ast/model/processing/node-id'; + +// export type TypeId = T & { __brand?: 'type-id' }; +// export function typeIdFromNodeId(nodeId: NodeId): TypeId { +// return nodeId as TypeId; +// } + +// interface UnificationError extends Error { +// __brand: 'unification-error'; +// } +// function isUnificationError(error: unknown): error is UnificationError { +// return typeof error === 'object' && error !== null && '__brand' in error && error.__brand === 'unification-error'; +// } + + +/** + * This enum lists a tag for each of the possible R data types inferred by the + * type inferencer. It is mainly used to identify subtypes of {@link RDataType}. + */ +export enum RDataTypeTag { + /** {@link RAnyType} */ + Any = 'RAnyType', + /** {@link RLogicalType} */ + Logical = 'RLogicalType', + /** {@link RIntegerType} */ + Integer = 'RIntegerType', + /** {@link RDoubleType} */ + Double = 'RDoubleType', + /** {@link RComplexType} */ + Complex = 'RComplexType', + /** {@link RStringType} */ + String = 'RStringType', + /** {@link RRawType} */ + Raw = 'RRawType', + /** {@link RNullType} */ + Null = 'RNullType', + /** {@link RFunctionType} */ + Function = 'RFunctionType', + /** {@link RListType} */ + List = 'RListType', + /** {@link REnvironmentType} */ + Environment = 'REnvironmentType', + /** {@link RSpecialType} */ + Special = 'RSpecialType', + /** {@link RBuiltinType} */ + Builtin = 'RBuiltinType', + /** {@link RTypeVariable} */ + Variable = 'RTypeVariable', +} + +export class RAnyType { + readonly tag = RDataTypeTag.Any; +} + +export class RLogicalType { + readonly tag = RDataTypeTag.Logical; +} + +export class RIntegerType { + readonly tag = RDataTypeTag.Integer; +} + +export class RDoubleType { + readonly tag = RDataTypeTag.Double; +} + +export class RComplexType { + readonly tag = RDataTypeTag.Complex; +} + +export class RStringType { + readonly tag = RDataTypeTag.String; +} + +export class RRawType { + readonly tag = RDataTypeTag.Raw; +} + +export class RNullType { + readonly tag = RDataTypeTag.Null; +} + +export class RFunctionType { + readonly tag = RDataTypeTag.Function; +} + +export class RListType { + readonly tag = RDataTypeTag.List; +} + +export class REnvironmentType { + readonly tag = RDataTypeTag.Environment; +} + +export class RSpecialType { + readonly tag = RDataTypeTag.Special; +} + +export class RBuiltinType { + readonly tag = RDataTypeTag.Builtin; +} + +export class RTypeVariable { + readonly tag = RDataTypeTag.Variable; + private boundType: UnresolvedRDataType | undefined; + + find(): UnresolvedRDataType { + if(this.boundType instanceof RTypeVariable) { + this.boundType = this.boundType.find(); + } + return this.boundType ?? this; + } + + unify(other: UnresolvedRDataType): void { + const thisRep = this.find(); + const otherRep = other instanceof RTypeVariable ? other.find() : other; + + if(thisRep === otherRep) { + return; + } + + if(thisRep instanceof RTypeVariable) { + thisRep.boundType = otherRep; + } else if(otherRep instanceof RTypeVariable) { + otherRep.boundType = thisRep; + } else if(thisRep.tag !== otherRep.tag) { + this.boundType = new RAnyType(); + } + } +} + + +export function resolveType(type: UnresolvedRDataType): RDataType { + if(type instanceof RTypeVariable) { + const typeRep = type.find(); + return typeRep !== type ? resolveType(typeRep) : { tag: RDataTypeTag.Any }; + } + return type; +} + + +export type PrimitiveRDataType + = RAnyType + | RLogicalType + | RIntegerType + | RDoubleType + | RComplexType + | RStringType + | RRawType + | RNullType + | RFunctionType + | RListType + | REnvironmentType + | RSpecialType + | RBuiltinType; + +export type CompoundRDataType = never; + +/** + * The `RDataType` type is the union of all possible types that can be inferred + * by the type inferencer for R objects. + * It should be used whenever you either not care what kind of + * type you are dealing with or if you want to handle all possible types. + */ +export type RDataType = PrimitiveRDataType | CompoundRDataType; + +export type UnresolvedCompoundRDataType = never; + +export type UnresolvedRDataType + = PrimitiveRDataType + | UnresolvedCompoundRDataType + | RTypeVariable; \ No newline at end of file diff --git a/test/functionality/_helper/typing/assert-inferred-type.ts b/test/functionality/_helper/typing/assert-inferred-type.ts new file mode 100644 index 0000000000..368808eb3d --- /dev/null +++ b/test/functionality/_helper/typing/assert-inferred-type.ts @@ -0,0 +1,38 @@ +import { describe, expect, test } from 'vitest'; +import { TreeSitterExecutor } from '../../../../src/r-bridge/lang-4.x/tree-sitter/tree-sitter-executor'; +import { createDataflowPipeline } from '../../../../src/core/steps/pipeline/default-pipelines'; +import { requestFromInput } from '../../../../src/r-bridge/retriever'; +import type { RDataType } from '../../../../src/typing/types'; +import { inferDataTypes } from '../../../../src/typing/infer'; +import type { FlowrSearch } from '../../../../src/search/flowr-search-builder'; +import { runSearch } from '../../../../src/search/flowr-search-executor'; + +export function assertInferredType(input: string, expectedType: RDataType): void { + test(`Infer ${expectedType.tag} for ${input}`, async() => { + const executor = new TreeSitterExecutor(); + const result = await createDataflowPipeline(executor, { request: requestFromInput(input) }).allRemainingSteps(); + const typedAst = inferDataTypes(result.normalize, result.dataflow); + const rootNode = typedAst.ast; + expect(rootNode.info.inferredType).toEqual(expectedType); + }); +} +export function assertInferredTypes( + input: string, + ...expectations: { query: FlowrSearch, expectedType: RDataType }[] +): void { + describe(`Infer types for ${input}`, async() => { + const executor = new TreeSitterExecutor(); + const result = await createDataflowPipeline(executor, { request: requestFromInput(input) }).allRemainingSteps(); + inferDataTypes(result.normalize, result.dataflow); + + describe.each(expectations)('Infer $expectedType.tag for query $query', ({ query, expectedType }) => { + const searchResult = runSearch(query, result); + expect(searchResult).toHaveLength(1); + const node = searchResult[0].node; + + test(`Infer ${expectedType.tag} for ${node.lexeme}`, () => { + expect(node.info.inferredType).toEqual(expectedType); + }); + }); + }); +} \ No newline at end of file diff --git a/test/functionality/typing/basic-expression-type-inference.test.ts b/test/functionality/typing/basic-expression-type-inference.test.ts new file mode 100644 index 0000000000..e278225a6d --- /dev/null +++ b/test/functionality/typing/basic-expression-type-inference.test.ts @@ -0,0 +1,33 @@ +import { describe } from 'vitest'; +import { RDataTypeTag } from '../../../src/typing/types'; +import { assertInferredType, assertInferredTypes } from '../_helper/typing/assert-inferred-type'; +import { Q } from '../../../src/search/flowr-search-builder'; + +describe('Infer types for currently supported R expressions', () => { + // Test type inference for constants + describe.each([ + { description: 'logical constant', input: 'TRUE', expectedType: { tag: RDataTypeTag.Logical as const } }, + { description: 'integer constant', input: '42', expectedType: { tag: RDataTypeTag.Integer as const } }, + { description: 'double constant', input: '42.5', expectedType: { tag: RDataTypeTag.Double as const } }, + { description: 'complex number constant', input: '42i', expectedType: { tag: RDataTypeTag.Complex as const } }, + { description: 'string constant', input: '"Hello, world!"', expectedType: { tag: RDataTypeTag.String as const } }, + { description: 'empty expression list', input: '{}', expectedType: { tag: RDataTypeTag.Null as const } } + ])('Infer $expectedType for $description', ({ input, expectedType }) => assertInferredType(input, expectedType)); + + // Test type inference for variables + describe('Infer types for variables', () => { + assertInferredTypes( + 'x <- 42; x', + { query: Q.var('x').first().build(), expectedType: { tag: RDataTypeTag.Integer as const } }, + { query: Q.criterion('1@<-').build(), expectedType: { tag: RDataTypeTag.Integer as const } }, + { query: Q.var('x').last().build(), expectedType: { tag: RDataTypeTag.Integer as const } } + ); + assertInferredType('y', { tag: RDataTypeTag.Null }); + }); + + // Test type inference for currently unsupported R expressions + describe('Infer no type information for currently unsupported R expressions', () => { + assertInferredType('1 + 2', { tag: RDataTypeTag.Any }); + assertInferredType('print("Hello, world!")', { tag: RDataTypeTag.Any }); + }); +}); \ No newline at end of file