diff --git a/src/compute-engine/differential-equation-utils.ts b/src/compute-engine/differential-equation-utils.ts new file mode 100644 index 00000000..ea27b053 --- /dev/null +++ b/src/compute-engine/differential-equation-utils.ts @@ -0,0 +1,47 @@ +import type { Expression, IComputeEngine } from './global-types'; +import { isFunction, isSymbol } from './boxed-expression/type-guards'; + +export function symbolArg( + engine: IComputeEngine, + arg: Expression | undefined +): Expression { + if (arg === undefined) return engine.error('missing'); + if (!isSymbol(arg)) return engine.typeError('symbol', arg.type, arg); + return arg; +} + +export function isDependentFunction( + expr: Expression, + dependentName: string, + independentName: string +): boolean { + return ( + isFunction(expr) && + expr.operator === dependentName && + expr.nops === 1 && + isSymbol(expr.op1, independentName) + ); +} + +export function isDerivativeOfDependent( + expr: Expression, + dependentName: string, + independentName: string +): boolean { + if (isFunction(expr, 'D')) { + return ( + isDependentFunction(expr.op1, dependentName, independentName) && + isSymbol(expr.op2, independentName) + ); + } + + if (isFunction(expr, 'Apply') && isFunction(expr.op1, 'Derivative')) { + return ( + isSymbol(expr.op1.op1, dependentName) && + expr.nops === 2 && + isSymbol(expr.op2, independentName) + ); + } + + return false; +} diff --git a/src/compute-engine/library/calculus.ts b/src/compute-engine/library/calculus.ts index 15d531c0..2a25551b 100644 --- a/src/compute-engine/library/calculus.ts +++ b/src/compute-engine/library/calculus.ts @@ -13,8 +13,11 @@ import { import { monteCarloEstimate } from '../numerics/monte-carlo'; import { integrateSemiInfiniteOscillatory } from '../numerics/oscillatory-quadrature'; import { centeredDiff8thOrder, limit } from '../numerics/numeric'; +import { nDSolve } from '../numerics/differential-equations'; import { derivative, differentiate } from '../symbolic/derivative'; import { antiderivative } from '../symbolic/antiderivative'; +import { dSolve } from '../symbolic/differential-equations'; +import { symbolArg } from '../differential-equation-utils'; import { symbolicLimit } from '../symbolic/limit'; import { residue } from '../symbolic/residue'; import { canonicalLimits, canonicalLimitsSequence } from './utils'; @@ -465,6 +468,66 @@ volumes }, }, + DSolve: { + description: 'Symbolic differential equation solver.', + broadcastable: false, + lazy: true, + signature: '(expression, symbol, symbol) -> expression', + canonical: (ops, { engine }) => { + if (ops.length === 0) + return engine._fn('DSolve', [ + engine.error('missing'), + engine.error('missing'), + engine.error('missing'), + ]); + if (ops.length === 1) + return engine._fn('DSolve', [ + ops[0], + engine.error('missing'), + engine.error('missing'), + ]); + if (ops.length === 2) + return engine._fn('DSolve', [ + ops[0], + symbolArg(engine, ops[1]), + engine.error('missing'), + ]); + + return engine._fn('DSolve', [ + ops[0], + symbolArg(engine, ops[1]), + symbolArg(engine, ops[2]), + ]); + }, + evaluate: ([equation, dependent, independent]) => + dSolve(equation, dependent, independent), + }, + + NDSolve: { + description: 'Numerical differential equation solver.', + broadcastable: false, + lazy: true, + signature: + '(expression, symbol, limits:(tuple|symbol), number, number?) -> list', + canonical: (ops, { engine }) => { + const missing = engine.error('missing'); + const limits = + ops[2] && isFunction(ops[2]) + ? canonicalLimits(ops[2].ops, { engine }) + : canonicalLimits(ops[2] ? [ops[2]] : [], { engine }); + + return engine._fn('NDSolve', [ + ops[0] ?? missing, + symbolArg(engine, ops[1]), + limits ?? missing, + ops[3]?.canonical ?? missing, + ...(ops[4] ? [ops[4].canonical] : []), + ]); + }, + evaluate: ([equation, dependent, limits, initialValue, steps]) => + nDSolve(equation, dependent, limits, initialValue, steps), + }, + // This is used to represent the indexing set/limits (i.e. // an index, lower and upper bounds) of a function // (not to be confused with Limit, which calculates the limit of a diff --git a/src/compute-engine/numerics/differential-equations.ts b/src/compute-engine/numerics/differential-equations.ts new file mode 100644 index 00000000..25a5a393 --- /dev/null +++ b/src/compute-engine/numerics/differential-equations.ts @@ -0,0 +1,148 @@ +import { checkDeadline } from '../../common/interruptible'; +import type { Expression } from '../global-types'; +import { isFunction, sym } from '../boxed-expression/type-guards'; +import { + isDependentFunction, + isDerivativeOfDependent, +} from '../differential-equation-utils'; + +export type RK4Options = { + steps: number; + deadline?: number; +}; + +export type ODESample = readonly [x: number, y: number]; + +/** + * Fixed-step classical fourth-order Runge-Kutta solver for scalar explicit + * initial value problems: y' = f(x, y), y(x0) = y0. + */ +export function rk4( + f: (x: number, y: number) => number, + x0: number, + y0: number, + x1: number, + options: RK4Options +): ODESample[] | undefined { + const steps = Math.trunc(options.steps); + if ( + !Number.isFinite(x0) || + !Number.isFinite(y0) || + !Number.isFinite(x1) || + !Number.isInteger(steps) || + steps <= 0 + ) + return undefined; + + const h = (x1 - x0) / steps; + const samples: ODESample[] = [[x0, y0]]; + let x = x0; + let y = y0; + + for (let i = 0; i < steps; i++) { + if ((i & 0xff) === 0) checkDeadline(options.deadline); + + const k1 = f(x, y); + const k2 = f(x + h / 2, y + (h * k1) / 2); + const k3 = f(x + h / 2, y + (h * k2) / 2); + const k4 = f(x + h, y + h * k3); + if (![k1, k2, k3, k4].every(Number.isFinite)) return undefined; + + y += (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4); + x = i === steps - 1 ? x1 : x + h; + if (!Number.isFinite(y)) return undefined; + samples.push([x, y]); + } + + return samples; +} + +function explicitRhs( + equation: Expression, + dependentName: string, + independentName: string +): Expression | undefined { + if (!isFunction(equation, 'Equal')) return undefined; + if (isDerivativeOfDependent(equation.op1, dependentName, independentName)) + return equation.op2; + if (isDerivativeOfDependent(equation.op2, dependentName, independentName)) + return equation.op1; + return undefined; +} + +function substituteDependentCall( + expr: Expression, + dependentName: string, + independentName: string, + stateName: string +): Expression { + if (isDependentFunction(expr, dependentName, independentName)) + return expr.engine.symbol(stateName); + if (!isFunction(expr)) return expr; + return expr.engine._fn( + expr.operator, + expr.ops.map((op) => + substituteDependentCall(op, dependentName, independentName, stateName) + ) + ); +} + +export function nDSolve( + equation: Expression, + dependent: Expression, + limits: Expression, + initialValue: Expression, + stepsExpr?: Expression +): Expression | undefined { + const ce = equation.engine; + const dependentName = sym(dependent); + if (!dependentName) return undefined; + + if (!isFunction(limits, 'Limits')) return undefined; + const independentName = sym(limits.op1); + if (!independentName) return undefined; + + const [x0, x1, y0] = [ + limits.op2.N().re, + limits.op3.N().re, + initialValue.N().re, + ]; + if (![x0, x1, y0].every(Number.isFinite)) return undefined; + + const steps = stepsExpr === undefined ? 100 : stepsExpr.N().re; + if ( + !Number.isInteger(steps) || + steps <= 0 || + steps > ce.iterationLimit || + steps + 1 > ce.maxCollectionSize + ) + return undefined; + + const rhs = explicitRhs(equation.structural, dependentName, independentName); + if (!rhs) return undefined; + + const stateName = `ndsolve${dependentName}state`; + const compiledRhs = substituteDependentCall( + rhs, + dependentName, + independentName, + stateName + ); + const compiled = ce._compile(compiledRhs, { realOnly: true }); + if (!compiled.success) return undefined; + const run = compiled.run as (vars: Record) => number; + + const samples = rk4( + (x, y) => run({ [independentName]: x, [stateName]: y }), + x0, + y0, + x1, + { steps, deadline: ce._deadline } + ); + if (!samples) return undefined; + + return ce._fn( + 'List', + samples.map(([x, y]) => ce._fn('List', [ce.number(x), ce.number(y)])) + ); +} diff --git a/src/compute-engine/symbolic/differential-equations.ts b/src/compute-engine/symbolic/differential-equations.ts new file mode 100644 index 00000000..f2ace10c --- /dev/null +++ b/src/compute-engine/symbolic/differential-equations.ts @@ -0,0 +1,258 @@ +import { antiderivative } from './antiderivative'; +import type { Expression } from '../global-types'; +import { isFunction, isSymbol, sym } from '../boxed-expression/type-guards'; +import { + isDependentFunction, + isDerivativeOfDependent, +} from '../differential-equation-utils'; + +interface LinearTermCoefficients { + derivative: Expression; + dependent: Expression; + rest: Expression; +} + +function functionName(expr: Expression): string | undefined { + if (!isFunction(expr)) return undefined; + return expr.operator; +} + +function splitTerm( + term: Expression, + dependentName: string, + independentName: string +): { kind: 'derivative' | 'dependent' | 'rest'; coefficient: Expression } { + const ce = term.engine; + + if (isDerivativeOfDependent(term, dependentName, independentName)) + return { kind: 'derivative', coefficient: ce.One }; + + if (isDependentFunction(term, dependentName, independentName)) + return { kind: 'dependent', coefficient: ce.One }; + + if (isFunction(term, 'Negate')) { + const result = splitTerm(term.op1, dependentName, independentName); + return { ...result, coefficient: result.coefficient.neg() }; + } + + if (isFunction(term, 'Multiply')) { + let derivativeIndex = -1; + let dependentIndex = -1; + + for (let i = 0; i < term.ops.length; i++) { + if ( + isDerivativeOfDependent(term.ops[i], dependentName, independentName) + ) { + if (derivativeIndex >= 0) return { kind: 'rest', coefficient: term }; + derivativeIndex = i; + } else if ( + isDependentFunction(term.ops[i], dependentName, independentName) + ) { + if (dependentIndex >= 0) return { kind: 'rest', coefficient: term }; + dependentIndex = i; + } + } + + if (derivativeIndex >= 0 && dependentIndex >= 0) + return { kind: 'rest', coefficient: term }; + + const matchedIndex = + derivativeIndex >= 0 ? derivativeIndex : dependentIndex; + if (matchedIndex >= 0) { + const coefficientFactors = term.ops.filter((_, i) => i !== matchedIndex); + const coefficient = + coefficientFactors.length === 0 + ? ce.One + : ce.function('Multiply', coefficientFactors); + return { + kind: derivativeIndex >= 0 ? 'derivative' : 'dependent', + coefficient, + }; + } + } + + return { kind: 'rest', coefficient: term }; +} + +function hasDependentOrDerivative( + expr: Expression, + dependentName: string, + independentName: string +): boolean { + if (isDependentFunction(expr, dependentName, independentName)) return true; + if (isDerivativeOfDependent(expr, dependentName, independentName)) + return true; + if (!isFunction(expr)) return false; + return expr.ops.some((op) => + hasDependentOrDerivative(op, dependentName, independentName) + ); +} + +function collectLinearTerms( + residual: Expression, + dependentName: string, + independentName: string +): LinearTermCoefficients { + const ce = residual.engine; + const terms = isFunction(residual, 'Add') + ? residual.ops + : isFunction(residual, 'Subtract') + ? [residual.op1, residual.op2.neg()] + : [residual]; + let derivative = ce.Zero; + let dependent = ce.Zero; + let rest = ce.Zero; + + for (const term of terms) { + const split = splitTerm(term, dependentName, independentName); + if (split.kind === 'derivative') + derivative = derivative.add(split.coefficient); + else if (split.kind === 'dependent') + dependent = dependent.add(split.coefficient); + else rest = rest.add(split.coefficient); + } + + return { derivative, dependent, rest }; +} + +function negateCoefficients( + coefficients: LinearTermCoefficients +): LinearTermCoefficients { + return { + derivative: coefficients.derivative.neg(), + dependent: coefficients.dependent.neg(), + rest: coefficients.rest.neg(), + }; +} + +function equationCoefficients( + equation: Expression, + dependentName: string, + independentName: string +): LinearTermCoefficients { + if (!isFunction(equation, 'Equal')) + return collectLinearTerms( + equation.structural, + dependentName, + independentName + ); + + const lhs = collectLinearTerms( + equation.op1.structural, + dependentName, + independentName + ); + const rhs = negateCoefficients( + collectLinearTerms(equation.op2.structural, dependentName, independentName) + ); + + return { + derivative: lhs.derivative.add(rhs.derivative), + dependent: lhs.dependent.add(rhs.dependent), + rest: lhs.rest.add(rhs.rest), + }; +} + +function expressionForDependent( + dependent: Expression, + independent: Expression +): { + dependentName: string; + independentName: string; + dependentCall: Expression; +} | null { + const dependentName = sym(dependent) ?? functionName(dependent); + const independentName = sym(independent); + if (!dependentName || !independentName) return null; + + const ce = dependent.engine; + return { + dependentName, + independentName, + dependentCall: ce.function(dependentName, [ce.symbol(independentName)]), + }; +} + +function collectSymbols( + expr: Expression, + symbols = new Set() +): Set { + if (isSymbol(expr)) symbols.add(expr.symbol); + if (isFunction(expr)) { + for (const op of expr.ops) collectSymbols(op, symbols); + } + return symbols; +} + +function integrationConstant(equation: Expression): Expression { + const ce = equation.engine; + const usedSymbols = collectSymbols(equation); + + for (let i = 0; ; i++) { + const name = i === 0 ? 'C' : 'c'.repeat(i); + if (usedSymbols.has(name)) continue; + if (ce.context.lexicalScope.bindings.has(name)) continue; + return ce.symbol(name); + } +} + +/** + * Solve a small first-order linear ODE subset: + * + * y'(x) + p(x)y(x) = q(x) + * + * The returned expression is a `List` of `Equal` expressions for `y(x)`. + * Unsupported equations return `undefined`, allowing the `DSolve` operator to + * remain inert. + */ +export function dSolve( + equation: Expression, + dependent: Expression, + independent: Expression +): Expression | undefined { + const names = expressionForDependent(dependent, independent); + if (!names) return undefined; + + const { dependentName, independentName, dependentCall } = names; + const ce = equation.engine; + + const coefficients = equationCoefficients( + equation, + dependentName, + independentName + ); + + if (coefficients.derivative.isSame(0)) return undefined; + if ( + hasDependentOrDerivative( + coefficients.derivative, + dependentName, + independentName + ) || + hasDependentOrDerivative( + coefficients.dependent, + dependentName, + independentName + ) || + hasDependentOrDerivative(coefficients.rest, dependentName, independentName) + ) + return undefined; + + const p = coefficients.dependent.div(coefficients.derivative).simplify(); + const q = coefficients.rest.neg().div(coefficients.derivative).simplify(); + const c = integrationConstant(equation); + + let solution: Expression; + if (p.isSame(0)) { + const integral = antiderivative(q, independentName); + solution = c.add(integral).simplify(); + } else { + const integralP = antiderivative(p, independentName); + const integratingFactor = ce.function('Exp', [integralP]).simplify(); + const weightedRhs = integratingFactor.mul(q).simplify(); + const integral = antiderivative(weightedRhs, independentName); + solution = c.add(integral).div(integratingFactor).simplify(); + } + + return ce.function('List', [ce.function('Equal', [dependentCall, solution])]); +} diff --git a/test/compute-engine/differential-equations.test.ts b/test/compute-engine/differential-equations.test.ts new file mode 100644 index 00000000..2bfea910 --- /dev/null +++ b/test/compute-engine/differential-equations.test.ts @@ -0,0 +1,244 @@ +import { engine } from '../utils'; + +function dsolve(equation: unknown, dependent = 'y', independent = 'x') { + return engine.expr(['DSolve', equation, dependent, independent]).evaluate(); +} + +function ndsolve( + equation: unknown, + initialValue: unknown, + steps = 100, + dependent = 'y', + independent = 'x', + lower = 0, + upper = 1, + limits: unknown = ['Limits', independent, lower, upper] +) { + return engine + .expr(['NDSolve', equation, dependent, limits, initialValue, steps]) + .evaluate(); +} + +function finalSample(result: ReturnType): [number, number] { + const sample = result.ops[result.ops.length - 1]; + return [sample.op1.N().re, sample.op2.N().re]; +} + +function verifyFirstOrderSolution( + solution: ReturnType, + rhs: unknown +): boolean { + const solutionEquation = solution.op1; + const yValue = solutionEquation.op2; + const derivative = engine + .expr(['D', yValue, 'x']) + .evaluate() + .simplify().structural; + const expectedTemplate = engine.expr(rhs, { form: 'raw' }); + const expected = + expectedTemplate.replace( + { match: ['y', 'x'], replace: yValue }, + { recursive: true } + ) ?? expectedTemplate; + const sample = { C: 2, x: 0.75 }; + const value = derivative + .subs(sample) + .sub(expected.structural.subs(sample)) + .simplify() + .N().re; + return Math.abs(value) < 1e-10; +} + +describe('DSolve', () => { + test('solves y prime equals y', () => { + const solution = dsolve(['Equal', ['D', ['y', 'x'], 'x'], ['y', 'x']]); + + expect(solution.toString()).toMatchInlineSnapshot(`[y(x) === C * e^x]`); + expect(verifyFirstOrderSolution(solution, ['y', 'x'])).toBe(true); + }); + + test('solves y prime equals constant multiple of y', () => { + const solution = dsolve([ + 'Equal', + ['D', ['y', 'x'], 'x'], + ['Multiply', 3, ['y', 'x']], + ]); + + expect(solution.toString()).toMatchInlineSnapshot(`[y(x) === C / e^(-3x)]`); + expect( + verifyFirstOrderSolution(solution, ['Multiply', 3, ['y', 'x']]) + ).toBe(true); + }); + + test('solves y prime equals x squared', () => { + const solution = dsolve([ + 'Equal', + ['D', ['y', 'x'], 'x'], + ['Power', 'x', 2], + ]); + + expect(solution.toString()).toMatchInlineSnapshot( + `[y(x) === 1/3 * x^3 + C]` + ); + expect(verifyFirstOrderSolution(solution, ['Power', 'x', 2])).toBe(true); + }); + + test('solves first-order linear equation', () => { + const solution = dsolve([ + 'Equal', + ['Add', ['D', ['y', 'x'], 'x'], ['y', 'x']], + 'x', + ]); + + expect(solution.toString()).toMatchInlineSnapshot( + `[y(x) === x + C / e^x - 1]` + ); + expect( + verifyFirstOrderSolution(solution, ['Subtract', 'x', ['y', 'x']]) + ).toBe(true); + }); + + test('solves first-order homogeneous linear equation with variable coefficient', () => { + const solution = dsolve([ + 'Equal', + ['Add', ['D', ['y', 'x'], 'x'], ['Multiply', 2, 'x', ['y', 'x']]], + 0, + ]); + + expect(solution.toString()).toMatchInlineSnapshot(`[y(x) === C / e^(x^2)]`); + expect( + verifyFirstOrderSolution(solution, [ + 'Negate', + ['Multiply', 2, 'x', ['y', 'x']], + ]) + ).toBe(true); + }); + + test('uses a fallback integration constant when C is already declared', () => { + engine.declare('C', 'real'); + try { + const solution = dsolve(['Equal', ['D', ['y', 'x'], 'x'], ['y', 'x']]); + + expect(solution.toString()).toMatchInlineSnapshot(`[y(x) === c * e^x]`); + expect(verifyFirstOrderSolution(solution, ['y', 'x'])).toBe(true); + } finally { + engine.forget('C'); + } + }); + + test('stays inert for unsupported nonlinear first-order equations', () => { + const result = dsolve([ + 'Equal', + ['D', ['y', 'x'], 'x'], + ['Power', ['y', 'x'], 2], + ]); + + expect(result.operator).toBe('DSolve'); + }); + + test('stays inert for unsupported higher-order equations', () => { + const result = dsolve([ + 'Equal', + ['D', ['D', ['y', 'x'], 'x'], 'x'], + ['y', 'x'], + ]); + + expect(result.operator).toBe('DSolve'); + }); +}); + +describe('NDSolve', () => { + test('solves y prime equals y with RK4 samples', () => { + const result = ndsolve(['Equal', ['D', ['y', 'x'], 'x'], ['y', 'x']], 1); + const [x, y] = finalSample(result); + const expected = engine.expr(['Exp', 1]).N().re; + + expect(result.operator).toBe('List'); + expect(result.ops.length).toBe(101); + expect(x).toBeCloseTo(1, 12); + expect(y).toBeCloseTo(expected, 8); + }); + + test('accepts tuple limits', () => { + const result = ndsolve( + ['Equal', ['D', ['y', 'x'], 'x'], ['y', 'x']], + 1, + 100, + 'y', + 'x', + 0, + 1, + ['Tuple', 'x', 0, 1] + ); + const [x, y] = finalSample(result); + const expected = engine.expr(['Exp', 1]).N().re; + + expect(result.operator).toBe('List'); + expect(x).toBeCloseTo(1, 12); + expect(y).toBeCloseTo(expected, 8); + }); + + test('solves variable coefficient first-order IVP with RK4 samples', () => { + const result = ndsolve( + [ + 'Equal', + ['D', ['y', 'x'], 'x'], + ['Negate', ['Multiply', 2, 'x', ['y', 'x']]], + ], + 1 + ); + const [, y] = finalSample(result); + const expected = engine.expr(['Exp', -1]).N().re; + + expect(y).toBeCloseTo(expected, 8); + }); + + test('solves inhomogeneous polynomial IVP with RK4 samples', () => { + const result = ndsolve( + ['Equal', ['D', ['y', 'x'], 'x'], ['Power', 'x', 2]], + 0 + ); + const [, y] = finalSample(result); + + expect(y).toBeCloseTo(1 / 3, 12); + }); + + test('solves IVP with non-elementary antiderivative using RK4 samples', () => { + const result = ndsolve( + ['Equal', ['D', ['y', 'x'], 'x'], ['Exp', ['Negate', ['Power', 'x', 2]]]], + 0, + 400 + ); + const [, y] = finalSample(result); + const expected = engine + .expr(['Multiply', ['Divide', ['Sqrt', 'Pi'], 2], ['Erf', 1]]) + .N().re; + + expect(y).toBeCloseTo(expected, 10); + }); + + test('stays inert for unsupported implicit equations', () => { + const result = ndsolve( + ['Equal', ['Add', ['D', ['y', 'x'], 'x'], ['y', 'x']], 'x'], + 1 + ); + + expect(result.operator).toBe('NDSolve'); + }); + + test('stays inert when requested steps exceed the iteration limit', () => { + const savedLimit = engine.iterationLimit; + engine.iterationLimit = 10; + try { + const result = ndsolve( + ['Equal', ['D', ['y', 'x'], 'x'], ['y', 'x']], + 1, + 11 + ); + + expect(result.operator).toBe('NDSolve'); + } finally { + engine.iterationLimit = savedLimit; + } + }); +});