Skip to content

Add type inference #1652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 31 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
337ed96
feat: infer types for constant expressions
loki259 Apr 5, 2025
4ef9811
style: make stylistic changes to satisfy ESLint
loki259 Apr 6, 2025
746596b
test(typing): streamline the type inferencer tests using `describe.each`
loki259 Apr 7, 2025
e53e003
feat(typing): add preliminary support for expression lists and `NULL`
loki259 Apr 7, 2025
8072fe6
test(typing): add tests for unsupported expression types
loki259 Apr 7, 2025
ce1fc8c
test: add a test for empty expression lists
loki259 Apr 7, 2025
53fe22f
feat(typing): add support for complex number literals
loki259 Apr 7, 2025
18ec199
test(types): refactor tests to use tree-sitter
loki259 Apr 7, 2025
bb514bf
refactor: refactor the type representation of R data types
loki259 May 3, 2025
2eccfcd
Merge branch 'main' into 174-add-type-inference
loki259 May 4, 2025
67f9204
Merge branch 'main' into 174-add-type-inference
loki259 May 8, 2025
ebcff1e
feat(ast): add `mapAstInfo` for transforming metadata attached to nodes
loki259 May 10, 2025
9c8c4c1
Merge branch '1637-map-attached-ast-info' into 174-add-type-inference
loki259 May 10, 2025
e8f02a2
wip(typing): rework type inference to use type variables and constraints
loki259 May 10, 2025
c74a1cb
Merge branch 'main' into 174-add-type-inference
loki259 May 11, 2025
9b7e136
refactor(typing): replace '<>' with `EmptyArgument` constant
loki259 May 11, 2025
d936697
refactor(typing): rename string representation of `RDataTypeTag.Any`
loki259 May 11, 2025
ac3b105
feat-fix(typing): fix erroneous property access of `inferredType`
loki259 May 11, 2025
9710138
feat-fix(semantic-cfg): explicit onProgram visitor
EagleoutIce May 11, 2025
14913c4
refactor(ast): rework `mapAstInfo` to mutate the AST nodes in place
loki259 May 12, 2025
5339c1a
refactor(ast): define default for `downUpdater` argument of `mapAstInfo`
loki259 May 12, 2025
d00cced
feat(ast): add wrapper of `mapAstInfo` for normalized ASTs
loki259 May 12, 2025
c966fe9
Merge branch '1637-map-attached-ast-info' into 174-add-type-inference
loki259 May 12, 2025
b1b5963
refactor(typing): infer types directly on the normalized AST
loki259 May 12, 2025
a9bbff0
test(typing): add test helper for asserting types of queried subnodes
loki259 May 14, 2025
d65e243
meta(typing): rename test file to better reflect its testing domain
loki259 May 15, 2025
2427695
feat(typing): add support for variables
loki259 May 15, 2025
45a417d
refactor(typing): handle expression lists more robustly via return edges
loki259 May 19, 2025
5d631fc
refactor(typing): handle variables more robustly
loki259 May 19, 2025
76c5c1b
feat(typing): add preliminary support for functions
loki259 May 19, 2025
fa736e2
refactor(typing): handle program expression lists more robustly
loki259 May 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/control-flow/dfg-cfg-guided-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export class DataflowAwareCfgGuidedVisitor<
private onExprOrStmtNode(node: CfgStatementVertex | CfgExpressionVertex): void {
const dfgVertex = this.getDataflowGraph(node.id);
if(!dfgVertex) {
this.visitUnknown(node);
return;
}

Expand All @@ -76,6 +77,12 @@ export class DataflowAwareCfgGuidedVisitor<
}
}

/**
* called for every cfg vertex that has no corresponding dataflow vertex.
*/
protected visitUnknown(_vertex: CfgStatementVertex | CfgExpressionVertex): void {
}

protected visitValue(_val: DataflowGraphVertexValue): void {
}

Expand Down
14 changes: 13 additions & 1 deletion src/control-flow/semantic-cfg-guided-visitor.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { ControlFlowInformation } from './control-flow-graph';
import type { CfgExpressionVertex, CfgStatementVertex, ControlFlowInformation } from './control-flow-graph';

import type { DataflowInformation } from '../dataflow/info';

Expand Down Expand Up @@ -28,6 +28,7 @@ import type { FunctionArgument } from '../dataflow/graph/graph';
import { edgeIncludesType, EdgeType } from '../dataflow/graph/edge';
import { guard } from '../util/assert';
import type { NoInfo, RNode } from '../r-bridge/lang-4.x/ast/model/model';
import type { RExpressionList } from '../r-bridge/lang-4.x/ast/model/nodes/r-expression-list';



Expand Down Expand Up @@ -115,6 +116,14 @@ export class SemanticCfgGuidedVisitor<
}
}

protected override visitUnknown(vertex: CfgStatementVertex | CfgExpressionVertex) {
super.visitUnknown(vertex);
const ast = this.getNormalizedAst(vertex.id);
if(ast && ast.type === RType.ExpressionList && ast.info.parent === undefined) {
this.onProgram(ast);
}
}

protected onDispatchFunctionCallOrigins(call: DataflowGraphVertexFunctionCall, origins: readonly string[]) {
for(const origin of origins) {
this.onDispatchFunctionCallOrigin(call, origin);
Expand Down Expand Up @@ -182,6 +191,9 @@ export class SemanticCfgGuidedVisitor<
}
}

protected onProgram(_data: RExpressionList<OtherInfo>) {
}


/**
* Requests the {@link getOriginInDfg|origins} of the given node.
Expand Down
2 changes: 1 addition & 1 deletion src/r-bridge/lang-4.x/ast/model/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export type NoInfo = object;
* Will be used to reconstruct the source of the given element in the R-ast.
* This will not be part of most comparisons as it is mainly of interest to the reconstruction of R code.
*/
interface Source {
export interface Source {
/**
* The range is different from the assigned {@link Location} as it refers to the complete source range covered by the given
* element.
Expand Down
72 changes: 71 additions & 1 deletion src/r-bridge/lang-4.x/ast/model/processing/decorate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* @module
*/

import type { NoInfo, RNode } from '../model';
import type { NoInfo, RNode, Source } from '../model';
import { guard } from '../../../../../util/assert';
import type { SourceRange } from '../../../../../util/range';
import { BiMap } from '../../../../../util/collections/bimap';
Expand Down Expand Up @@ -463,3 +463,73 @@ function createFoldForFunctionArgument<OtherInfo>(info: FoldInfo<OtherInfo>) {
return decorated;
};
}


export function mapAstInfo<OldInfo, Down, NewInfo>(ast: RNode<OldInfo>, down: Down, infoMapper: (node: RNode<OldInfo>, down: Down) => NewInfo, downUpdater: (node: RNode<OldInfo>, down: Down) => Down = (_node, down) => down): RNode<NewInfo> {
const fullInfoMapper = (node: RNode<OldInfo>, down: Down): NewInfo & Source => {
const sourceInfo = {
...(node.info.fullRange !== undefined ? { fullRange: node.info.fullRange } : {}),
...(node.info.fullLexeme !== undefined ? { fullLexeme: node.info.fullLexeme } : {}),
...(node.info.additionalTokens !== undefined ? { additionalTokens: node.info.additionalTokens } : {}),
...(node.info.file !== undefined ? { file: node.info.file } : {})
};
const mappedInfo = infoMapper(node, down);
return { ...sourceInfo, ...mappedInfo };
};

function updateInfo(n: RNode<OldInfo>, down: Down): RNode<NewInfo> {
(n.info as NewInfo) = fullInfoMapper(n, down);
return n as unknown as RNode<NewInfo>;
}

return foldAstStateful(ast, down, {
down: downUpdater,
foldNumber: updateInfo,
foldString: updateInfo,
foldLogical: updateInfo,
foldSymbol: updateInfo,
foldAccess: (node, _name, _access, down) => updateInfo(node, down),
foldBinaryOp: (op, _lhs, _rhs, down) => updateInfo(op, down),
foldPipe: (op, _lhs, _rhs, down) => updateInfo(op, down),
foldUnaryOp: (op, _operand, down) => updateInfo(op, down),
loop: {
foldFor: (loop, _variable, _vector, _body, down) => updateInfo(loop, down),
foldWhile: (loop, _condition, _body, down) => updateInfo(loop, down),
foldRepeat: (loop, _body, down) => updateInfo(loop, down),
foldNext: (next, down) => updateInfo(next, down),
foldBreak: (next, down) => updateInfo(next, down),
},
other: {
foldComment: (comment, down) => updateInfo(comment, down),
foldLineDirective: (comment, down) => updateInfo(comment, down),
},
foldIfThenElse: (ifThenExpr, _condition, _then, _otherwise, down ) =>
updateInfo(ifThenExpr, down),
foldExprList: (exprList, _grouping, _expressions, down) => updateInfo(exprList, down),
functions: {
foldFunctionDefinition: (definition, _parameters, _body, down) => updateInfo(definition, down),
/** folds named and unnamed function calls */
foldFunctionCall: (call, _functionNameOrExpression, _args, down) => updateInfo(call, down),
/** The `name` is `undefined` if the argument is unnamed, the value, if we have something like `x=,...` */
foldArgument: (argument, _name, _value, down) => updateInfo(argument, down),
/** The `defaultValue` is `undefined` if the argument was not initialized with a default value */
foldParameter: (parameter, _name, _defaultValue, down) => updateInfo(parameter, down)
}
});
}

export function mapNormalizedAstInfo<OldInfo, Down, NewInfo>(normalizedAst: NormalizedAst<OldInfo>, down: Down, infoMapper: (node: RNode<OldInfo & ParentInformation>, down: Down) => NewInfo, downUpdater: (node: RNode<OldInfo>, down: Down) => Down = (_node, down) => down): NormalizedAst<NewInfo> {
const parentInfoPreservingMapper = (node: RNode<OldInfo & ParentInformation>, down: Down): NewInfo & ParentInformation => {
const parentInfo = {
id: node.info.id,
parent: node.info.parent,
role: node.info.role,
nesting: node.info.nesting,
index: node.info.index
};
const mappedInfo = infoMapper(node, down);
return { ...parentInfo, ...mappedInfo };
};
mapAstInfo(normalizedAst.ast, down, parentInfoPreservingMapper, downUpdater);
return normalizedAst as unknown as NormalizedAst<NewInfo>;
}
171 changes: 171 additions & 0 deletions src/typing/infer.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import { extractCFG } from '../control-flow/extract-cfg';
import { SemanticCfgGuidedVisitor } from '../control-flow/semantic-cfg-guided-visitor';
import type { DataflowGraphVertexFunctionCall, DataflowGraphVertexFunctionDefinition, DataflowGraphVertexUse, DataflowGraphVertexValue } from '../dataflow/graph/vertex';
import type { DataflowInformation } from '../dataflow/info';
import type { RLogical } from '../r-bridge/lang-4.x/ast/model/nodes/r-logical';
import type { RNumber } from '../r-bridge/lang-4.x/ast/model/nodes/r-number';
import type { RString } from '../r-bridge/lang-4.x/ast/model/nodes/r-string';
import type { NormalizedAst } from '../r-bridge/lang-4.x/ast/model/processing/decorate';
import { mapNormalizedAstInfo } from '../r-bridge/lang-4.x/ast/model/processing/decorate';
import type { RDataType } from './types';
import { RTypeVariable , RComplexType, RDoubleType, RIntegerType, RLogicalType, RStringType, resolveType, RNullType, RFunctionType } from './types';
import type { RExpressionList } from '../r-bridge/lang-4.x/ast/model/nodes/r-expression-list';
import { guard } from '../util/assert';
import { OriginType } from '../dataflow/origin/dfg-get-origin';
import type { NodeId } from '../r-bridge/lang-4.x/ast/model/processing/node-id';
import { edgeIncludesType, EdgeType } from '../dataflow/graph/edge';

export function inferDataTypes<Info extends { typeVariable?: undefined }>(ast: NormalizedAst<Info>, dataFlowInfo: DataflowInformation): NormalizedAst<Info & DataTypeInfo> {
const astWithTypeVars = decorateTypeVariables(ast);
const controlFlowInfo = extractCFG(astWithTypeVars);
const config = {
normalizedAst: astWithTypeVars,
controlFlow: controlFlowInfo,
dataflow: dataFlowInfo,
defaultVisitingOrder: 'forward' as const,
};
const visitor = new TypeInferingCfgGuidedVisitor(config);
visitor.start();

return resolveTypeVariables(astWithTypeVars);
}

type UnresolvedTypeInfo = {
typeVariable: RTypeVariable;
};

export type DataTypeInfo = {
inferredType: RDataType;
}

function decorateTypeVariables<OtherInfo>(ast: NormalizedAst<OtherInfo>): NormalizedAst<OtherInfo & UnresolvedTypeInfo> {
return mapNormalizedAstInfo(ast, {}, (node, _down) => ({ ...node.info, typeVariable: new RTypeVariable() }));
}

function resolveTypeVariables<Info extends UnresolvedTypeInfo>(ast: NormalizedAst<Info>): NormalizedAst<Omit<Info, keyof UnresolvedTypeInfo> & DataTypeInfo> {
return mapNormalizedAstInfo(ast, {}, (node, _down) => {
const { typeVariable, ...rest } = node.info;
return { ...rest, inferredType: resolveType(typeVariable) };
});
}

class TypeInferingCfgGuidedVisitor extends SemanticCfgGuidedVisitor<UnresolvedTypeInfo>{
override onLogicalConstant(_vertex: DataflowGraphVertexValue, node: RLogical<UnresolvedTypeInfo>): void {
node.info.typeVariable.unify(new RLogicalType());
}

override onNumberConstant(_vertex: DataflowGraphVertexValue, node: RNumber<UnresolvedTypeInfo>): void {
if(node.content.complexNumber) {
node.info.typeVariable.unify(new RComplexType());
} else if(Number.isInteger(node.content.num)) {
node.info.typeVariable.unify(new RIntegerType());
} else {
node.info.typeVariable.unify(new RDoubleType());
}
}

override onStringConstant(_vertex: DataflowGraphVertexValue, node: RString<UnresolvedTypeInfo>): void {
node.info.typeVariable.unify(new RStringType());
}

override onVariableUse(vertex: DataflowGraphVertexUse): void {
const node = this.getNormalizedAst(vertex.id);
guard(node !== undefined, 'Expected AST node to be defined');

const origins = this.getOrigins(vertex.id);
const readOrigins = origins?.filter((origin) => origin.type === OriginType.ReadVariableOrigin);

if(readOrigins === undefined || readOrigins.length === 0) {
node.info.typeVariable.unify(new RNullType());
return;
}

for(const readOrigin of readOrigins) {
const readNode = this.getNormalizedAst(readOrigin.id);
guard(readNode !== undefined, 'Expected read node to be defined');
node.info.typeVariable.unify(readNode.info.typeVariable);
}
}

override onAssignmentCall(data: { call: DataflowGraphVertexFunctionCall, target?: NodeId, source?: NodeId }): void {
if(data.target === undefined || data.source === undefined) {
return; // Malformed assignment
}

const variableNode = this.getNormalizedAst(data.target);
const valueNode = this.getNormalizedAst(data.source);
const assignmentNode = this.getNormalizedAst(data.call.id);
guard(variableNode !== undefined && valueNode !== undefined && assignmentNode !== undefined, 'Expected AST nodes to be defined');

variableNode.info.typeVariable.unify(valueNode.info.typeVariable);
assignmentNode.info.typeVariable.unify(variableNode.info.typeVariable);
}

override onDefaultFunctionCall(data: { call: DataflowGraphVertexFunctionCall }): void {
const node = this.getNormalizedAst(data.call.id);
guard(node !== undefined, 'Expected AST node to be defined');

const outgoing = this.config.dataflow.graph.outgoingEdges(data.call.id);
const callsTargets = outgoing?.entries()
.filter(([_target, edge]) => edgeIncludesType(edge.types, EdgeType.Calls))
.map(([target, _edge]) => target)
.toArray();

guard(callsTargets === undefined || callsTargets.length <= 1, 'Expected at most one call edge');

if(callsTargets === undefined || callsTargets.length === 0) {
// TODO: Handle builtin functions

Check failure on line 117 in src/typing/infer.ts

View workflow job for this annotation

GitHub Actions / 👩‍🏫 Linting on Main

Unexpected 'todo' comment: 'TODO: Handle builtin functions'
return;
}

const target = this.getNormalizedAst(callsTargets[0]);
guard(target !== undefined, 'Expected target node to be defined');

target.info.typeVariable.unify(new RFunctionType());
}

override onFunctionDefinition(vertex: DataflowGraphVertexFunctionDefinition): void {
const node = this.getNormalizedAst(vertex.id);
guard(node !== undefined, 'Expected AST node to be defined');

node.info.typeVariable.unify(new RFunctionType());
}

override onProgram(node: RExpressionList<UnresolvedTypeInfo>) {
const exitPoints = this.config.dataflow.exitPoints;
const evalCandidates = exitPoints.map((exitPoint) => exitPoint.nodeId);

if(evalCandidates.length === 0) {
node.info.typeVariable.unify(new RNullType());
return;
}

for(const candidateId of evalCandidates) {
const candidate = this.getNormalizedAst(candidateId);
guard(candidate !== undefined, 'Expected target node to be defined');
node.info.typeVariable.unify(candidate.info.typeVariable);
}
}

override onExpressionList(data: { call: DataflowGraphVertexFunctionCall }) {
const node = this.getNormalizedAst(data.call.id);
guard(node !== undefined, 'Expected AST node to be defined');

const outgoing = this.config.dataflow.graph.outgoingEdges(data.call.id);
const evalCandidates = outgoing?.entries()
.filter(([_target, edge]) => edgeIncludesType(edge.types, EdgeType.Returns))
.map(([target, _edge]) => target)
.toArray();

if(evalCandidates === undefined || evalCandidates.length === 0) {
node.info.typeVariable.unify(new RNullType());
return;
}

for(const candidateId of evalCandidates) {
const candidate = this.getNormalizedAst(candidateId);
guard(candidate !== undefined, 'Expected target node to be defined');
node.info.typeVariable.unify(candidate.info.typeVariable);
}
}
}
Loading
Loading