diff --git a/README.md b/README.md index ac1620dcc..d96396b38 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ PyPI version Continuous Integration Continuous Integration with Models + Test Coverage

diff --git a/brainpy/algorithms/offline.py b/brainpy/algorithms/offline.py index 36ebc30b3..476b52285 100644 --- a/brainpy/algorithms/offline.py +++ b/brainpy/algorithms/offline.py @@ -156,7 +156,7 @@ def gradient_descent_solve(self, targets, inputs, outputs=None): def cond_fun(a): i, par_old, par_new = a return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)), - i < self.max_iter).value + i < self.max_iter) def body_fun(a): i, _, par_new = a @@ -269,9 +269,17 @@ def call(self, targets, inputs, outputs=None): if self.gradient_descent: return self.gradient_descent_solve(targets, inputs) else: + n_features = inputs.shape[-1] temp = inputs.T @ inputs if self.regularizer.alpha > 0.: - temp += self.regularizer.alpha * jnp.eye(inputs.shape[-1]) + penalty = self.regularizer.alpha * jnp.ones((n_features,)) + # Do not penalize the intercept/bias column. ``polynomial_features`` + # (used by ``PolynomialRidgeRegression`` when ``add_bias=True``) + # prepends a constant column at index 0; shrinking it would bias + # the fit on data with a nonzero mean. + if getattr(self, 'add_bias', False): + penalty = penalty.at[0].set(0.) + temp += jnp.diag(penalty) weights = jnp.linalg.pinv(temp) @ (inputs.T @ targets) return weights @@ -383,7 +391,7 @@ def call(self, targets, inputs, outputs=None) -> ArrayType: def cond_fun(a): i, par_old, par_new = a return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)), - i < self.max_iter).value + i < self.max_iter) def body_fun(a): i, par_old, par_new = a diff --git a/brainpy/algorithms/online.py b/brainpy/algorithms/online.py index 9477154bd..9dd54c35d 100644 --- a/brainpy/algorithms/online.py +++ b/brainpy/algorithms/online.py @@ -145,12 +145,21 @@ def call( assert input.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {input.shape}' assert target.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {target.shape}' assert output.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {output.shape}' - k = jnp.dot(P.value, input.T) # (num_input, num_batch) + # Block recursive least squares update (valid for any batch size B>=1). + # See e.g. Haykin, "Adaptive Filter Theory", block/multi-sample RLS. + # K = P Hᵀ (I_B + H P Hᵀ)⁻¹ -> Kalman gain, shape (num_input, B) + # P <- P - K H P -> covariance update, shape (num_input, num_input) + # w <- w + K (target - H w) -> here output = H w, so dw = -K (output - target) + # For B==1 this reduces exactly to the previous scalar update + # (c = 1/(1+hPh)), but unlike `jnp.sum(1/(1+HPHᵀ))` it stays correct for + # B>1 by inverting the full (B, B) matrix instead of summing reciprocals. + Pv = P.value + k = jnp.dot(Pv, input.T) # (num_input, num_batch) hPh = jnp.dot(input, k) # (num_batch, num_batch) - c = jnp.sum(1.0 / (1.0 + hPh)) # () - P -= c * jnp.dot(k, k.T) # (num_input, num_input) + gain = jnp.dot(k, jnp.linalg.inv(jnp.eye(hPh.shape[0]) + hPh)) # (num_input, num_batch) + P.value = Pv - jnp.dot(gain, jnp.dot(input, Pv)) # (num_input, num_input) e = output - target # (num_batch, num_output) - dw = -c * jnp.dot(k, e) # (num_input, num_output) + dw = -jnp.dot(gain, e) # (num_input, num_output) return dw diff --git a/brainpy/analysis/lowdim/lowdim_analyzer.py b/brainpy/analysis/lowdim/lowdim_analyzer.py index aa199b9e8..02f8dfca9 100644 --- a/brainpy/analysis/lowdim/lowdim_analyzer.py +++ b/brainpy/analysis/lowdim/lowdim_analyzer.py @@ -374,7 +374,7 @@ def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_ # args: parameters, a list/tuple of vectors candidates = candidates.value if isinstance(candidates, bm.Array) else candidates selected_ids = np.arange(len(candidates)) - args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args) + args = tuple(a.value if isinstance(a, bm.Array) else a for a in args) for a in args: assert len(a) == len(candidates) if num_seg is None: num_seg = len(self.resolutions[self.x_var]) @@ -950,7 +950,7 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7, # args: parameters, a list/tuple of vectors candidates = candidates.value if isinstance(candidates, bm.Array) else candidates selected_ids = np.arange(len(candidates)) - args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args) + args = tuple(a.value if isinstance(a, bm.Array) else a for a in args) for a in args: assert len(a) == len(candidates) if self.convert_type() == C.x_by_y: @@ -1035,9 +1035,16 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7, all_ids.append(seg_ids) for i in range(len(all_args)): all_args[i].append(seg_args[i][ids]) - all_fps = jnp.concatenate(all_fps) - all_ids = jnp.concatenate(all_ids) - all_args = tuple(jnp.concatenate(args) for args in all_args) + if len(all_fps): + all_fps = jnp.concatenate(all_fps) + all_ids = jnp.concatenate(all_ids) + all_args = tuple(jnp.concatenate(args) for args in all_args) + else: + # No candidate converged to a fixed point. Return empty arrays + # with the correct shapes/dtypes instead of concatenating []. + all_fps = jnp.zeros((0, candidates.shape[1]), dtype=candidates.dtype) + all_ids = jnp.zeros((0,), dtype=selected_ids.dtype) + all_args = tuple(jnp.zeros((0,), dtype=a.dtype) for a in args) return all_fps, all_ids, all_args diff --git a/brainpy/analysis/utils/optimization.py b/brainpy/analysis/utils/optimization.py index 6630d3897..6c6755922 100644 --- a/brainpy/analysis/utils/optimization.py +++ b/brainpy/analysis/utils/optimization.py @@ -395,7 +395,7 @@ def roots_of_1d_by_x(f, candidates, args=()): """ f = f_without_jaxarray_return(f) candidates = candidates.value if isinstance(candidates, bm.Array) else candidates - args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args) + args = tuple(a.value if isinstance(a, bm.Array) else a for a in args) vals = f(candidates, *args) signs = jnp.sign(vals) zero_sign_idx = jnp.where(signs == 0)[0] diff --git a/brainpy/connect/random_conn.py b/brainpy/connect/random_conn.py index 50e658366..91640a2bd 100644 --- a/brainpy/connect/random_conn.py +++ b/brainpy/connect/random_conn.py @@ -84,10 +84,10 @@ def __repr__(self): f'seed={self.seed})') def _iii(self): - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - + # NOTE: no guard on ``include_self=False`` for rectangular (pre_num != + # post_num) shapes — ``build_coo``/``build_csr`` already drop coincident + # ``pre == post`` indices generically, so the previous (contradictory) + # ConnectorError was both wrong-worded and overly restrictive. if self.pre_ratio < 1.: pre_num_to_select = int(self.pre_num * self.pre_ratio) pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False) @@ -96,7 +96,10 @@ def _iii(self): pre_ids = jnp.arange(self.pre_num) post_num_total = self.post_num - post_num_to_select = int(self.post_num * self.prob) + # Round instead of truncating: ``int(post_num * prob)`` floors a small + # expected fan-out (e.g. ``3 * 0.3 = 0.9``) to 0 connections, silently + # producing an empty connectivity for small post populations. + post_num_to_select = int(round(self.post_num * self.prob)) if self.allow_multi_conn: selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select)) diff --git a/brainpy/delay.py b/brainpy/delay.py index ea5e7c22f..04270253c 100644 --- a/brainpy/delay.py +++ b/brainpy/delay.py @@ -251,10 +251,13 @@ def __init__( # delay data self._init = init + # ``self.data`` must exist before ``_init_data`` is called, because + # ``_init_data`` reads ``self.data`` to decide whether to allocate a + # new buffer. Initialize it unconditionally to avoid an AttributeError + # on the ``time > 0`` (``max_length > 0``) path. + self.data = None if self.max_length > 0: self._init_data(self.max_length) - else: - self.data = None # other info if entries is not None: diff --git a/brainpy/dnn/linear.py b/brainpy/dnn/linear.py index a38e5d385..4411cff55 100644 --- a/brainpy/dnn/linear.py +++ b/brainpy/dnn/linear.py @@ -225,8 +225,13 @@ def stdp_update( w_min: numbers.Number = None, w_max: numbers.Number = None ): + if bm.isscalar(self.W): + raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.') if not isinstance(self.W, bm.Variable): - raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.') + # Promote the plain-array weight to a traceable Variable so the STDP + # in-place update works even when the comm was not built in a training + # mode. Only reached from ``stdp_update``; non-plastic use is unaffected. + self.W = bm.Variable(self.W) if on_pre is not None: spike = on_pre['spike'] trace = on_pre['trace'] @@ -320,8 +325,13 @@ def stdp_update( w_min: numbers.Number = None, w_max: numbers.Number = None ): + if bm.isscalar(self.weight): + raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.') if not isinstance(self.weight, bm.Variable): - raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.') + # Promote the plain-array weight to a traceable Variable so the STDP + # in-place update works even when the comm was not built in a training + # mode. Only reached from ``stdp_update``; non-plastic use is unaffected. + self.weight = bm.Variable(self.weight) if on_pre is not None: spike = on_pre['spike'] trace = on_pre['trace'] @@ -375,7 +385,10 @@ def stdp_update( if isinstance(self.weight, float): raise ValueError(f'Cannot update the weight of a constant node.') if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) + # Promote the plain-array weight to a traceable Variable so that the + # STDP in-place update below works. This branch is only reached from + # ``stdp_update`` (plasticity), so non-plastic use is unaffected. + self.weight = bm.Variable(self.weight) if on_pre is not None: spike = on_pre['spike'] trace = on_pre['trace'] @@ -449,7 +462,10 @@ def stdp_update( if isinstance(self.weight, float): raise ValueError(f'Cannot update the weight of a constant node.') if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) + # Promote the plain-array weight to a traceable Variable so that the + # STDP in-place update below works. This branch is only reached from + # ``stdp_update`` (plasticity), so non-plastic use is unaffected. + self.weight = bm.Variable(self.weight) if on_pre is not None: spike = on_pre['spike'] trace = on_pre['trace'] @@ -500,7 +516,10 @@ def stdp_update( raise ValueError( f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) + # Promote the plain-array weight to a traceable Variable so that the + # STDP in-place update below works. This branch is only reached from + # ``stdp_update`` (plasticity), so non-plastic use is unaffected. + self.weight = bm.Variable(self.weight) if on_pre is not None: # update on presynaptic spike spike = on_pre['spike'] trace = on_pre['trace'] diff --git a/brainpy/dnn/normalization.py b/brainpy/dnn/normalization.py index 1b3a2e85c..00b321819 100644 --- a/brainpy/dnn/normalization.py +++ b/brainpy/dnn/normalization.py @@ -113,6 +113,12 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, ): + # ``BatchNorm`` computes statistics across a batch axis, so it requires a + # batching/training mode. The global default mode is ``NonBatchingMode``, + # which would make the layer raise ``UnsupportedError`` out-of-the-box + # (H-51). Default to the training mode when the user does not specify one. + if mode is None: + mode = bm.training_mode super(BatchNorm, self).__init__(name=name, mode=mode) # check.is_subclass(self.mode, (bm.BatchingMode, bm.TrainingMode), self.name) @@ -131,9 +137,19 @@ def __init__( self.running_mean = bm.Variable(jnp.zeros(self.num_features)) self.running_var = bm.Variable(jnp.ones(self.num_features)) if self.affine: - assert isinstance(self.mode, bm.TrainingMode) - self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_features)) - self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_features)) + bias = parameter(self.bias_initializer, self.num_features) + scale = parameter(self.scale_initializer, self.num_features) + # Make the affine parameters trainable only under a training mode; + # otherwise keep them as plain variables (matching the pattern used + # by ``brainpy.dnn.Linear``/``Conv``). This avoids the hard + # ``assert isinstance(self.mode, TrainingMode)`` that crashed the + # layer under non-training modes (H-51). + if isinstance(self.mode, bm.TrainingMode): + self.bias = bm.TrainVar(bias) + self.scale = bm.TrainVar(scale) + else: + self.bias = bm.Variable(bias) + self.scale = bm.Variable(scale) def _check_input_dim(self, x): raise NotImplementedError @@ -500,9 +516,20 @@ def __init__( assert all([isinstance(s, int) for s in normalized_shape]), 'Must be a sequence of integer.' self.elementwise_affine = elementwise_affine if self.elementwise_affine: - assert isinstance(self.mode, bm.TrainingMode) - self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape)) - self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape)) + bias = parameter(self.bias_initializer, self.normalized_shape) + scale = parameter(self.scale_initializer, self.normalized_shape) + # ``LayerNorm`` does not depend on batch statistics, so it works under + # any mode (including the default ``NonBatchingMode``). Only wrap the + # affine parameters as trainable variables under a training mode; + # otherwise keep them as plain variables. This removes the hard + # ``assert isinstance(self.mode, TrainingMode)`` that crashed the + # affine layer out-of-the-box (H-51). + if isinstance(self.mode, bm.TrainingMode): + self.bias = bm.TrainVar(bias) + self.scale = bm.TrainVar(scale) + else: + self.bias = bm.Variable(bias) + self.scale = bm.Variable(scale) def update(self, x): if x.shape[-len(self.normalized_shape):] != self.normalized_shape: @@ -585,16 +612,32 @@ def __init__( self.bias_initializer = bias_initializer self.scale_initializer = scale_initializer if self.affine: - assert isinstance(self.mode, bm.TrainingMode) - self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_channels)) - self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_channels)) + bias = parameter(self.bias_initializer, self.num_channels) + scale = parameter(self.scale_initializer, self.num_channels) + # ``GroupNorm``/``InstanceNorm`` compute statistics independently of + # the batch size, so they work under any mode (including the default + # ``NonBatchingMode``). Only make the affine parameters trainable + # under a training mode; otherwise keep them as plain variables. This + # removes the hard ``assert isinstance(self.mode, TrainingMode)`` that + # crashed the affine layer out-of-the-box (H-51). + if isinstance(self.mode, bm.TrainingMode): + self.bias = bm.TrainVar(bias) + self.scale = bm.TrainVar(scale) + else: + self.bias = bm.Variable(bias) + self.scale = bm.Variable(scale) def update(self, x): assert x.shape[-1] == self.num_channels origin_shape, origin_dim = x.shape, x.ndim group_shape = (-1,) + x.shape[1:-1] + (self.num_groups, self.num_channels // self.num_groups) x = bm.as_jax(x.reshape(group_shape)) - reduction_axes = tuple(range(1, x.ndim - 1)) + (-1,) + # After reshape the axis layout is + # ``[0]=batch, [1..ndim-3]=spatial, [ndim-2]=group, [ndim-1]=channels-per-group``. + # Normalization must reduce over the spatial axes and the within-group + # channel axis, but NOT over the group axis (``ndim-2``); otherwise the + # groups are averaged together and ``num_groups`` has no effect (C-05). + reduction_axes = tuple(range(1, x.ndim - 2)) + (-1,) mean = jnp.mean(x, reduction_axes, keepdims=True) var = jnp.var(x, reduction_axes, keepdims=True) x = (x - mean) * lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype)) diff --git a/brainpy/dyn/channels/calcium.py b/brainpy/dyn/channels/calcium.py index 3dd30d74c..b2bf3ca5e 100644 --- a/brainpy/dyn/channels/calcium.py +++ b/brainpy/dyn/channels/calcium.py @@ -20,6 +20,8 @@ from typing import Union, Callable, Optional +import jax.numpy as jnp + import brainpy.math as bm from brainpy.context import share from brainpy.dyn.ions.calcium import Calcium, CalciumDyna @@ -40,6 +42,20 @@ ] +def _exprel(x): + """Stable ``(exp(x) - 1) / x`` with a finite value *and* finite gradient at ``x == 0``. + + The HH/Markov rate functions have the removable-singularity form + ``num * temp / (exp(temp / k) - 1)`` which is ``0 / 0`` (NaN value and NaN + gradient) at the singular voltage. Rewriting them as ``num * k / exprel(...)`` + removes the singularity in value, but ``brainpy.math.exprel`` still yields a + NaN gradient at 0, so we use this branch-safe helper instead. + """ + small = jnp.abs(x) < 1e-7 + safe_x = jnp.where(small, 1.0, x) + return jnp.where(small, 1.0 + x / 2.0, jnp.expm1(safe_x) / safe_x) + + class CalciumChannel(IonChannel): """Base class for Calcium ion channels.""" @@ -708,7 +724,8 @@ def __init__( def f_p_alpha(self, V): temp = -27 - V + self.V_sh - return 0.055 * temp / (bm.exp(temp / 3.8) - 1) + # 0.055 * temp / (exp(temp/3.8) - 1) == 0.055 * 3.8 / exprel(temp/3.8) + return (0.055 * 3.8) / _exprel(temp / 3.8) def f_p_beta(self, V): return 0.94 * bm.exp((-75. - V + self.V_sh) / 17.) diff --git a/brainpy/dyn/channels/potassium.py b/brainpy/dyn/channels/potassium.py index 738d3bc62..5f2f224e7 100644 --- a/brainpy/dyn/channels/potassium.py +++ b/brainpy/dyn/channels/potassium.py @@ -20,6 +20,8 @@ from typing import Union, Callable, Optional, Sequence +import jax.numpy as jnp + import brainpy.math as bm from brainpy.context import share from brainpy.dyn.ions.potassium import Potassium @@ -43,6 +45,20 @@ ] +def _exprel(x): + """Stable ``(exp(x) - 1) / x`` with a finite value *and* finite gradient at ``x == 0``. + + The HH/Markov rate functions have the removable-singularity form + ``num * temp / (1 - exp(-temp / k))`` which is ``0 / 0`` (NaN value and NaN + gradient) at the singular voltage. Rewriting them as ``num * k / exprel(...)`` + removes the singularity in value, but ``brainpy.math.exprel`` still yields a + NaN gradient at 0, so we use this branch-safe helper instead. + """ + small = jnp.abs(x) < 1e-7 + safe_x = jnp.where(small, 1.0, x) + return jnp.where(small, 1.0 + x / 2.0, jnp.expm1(safe_x) / safe_x) + + class PotassiumChannel(IonChannel): """Base class for sodium channel dynamics.""" @@ -219,7 +235,8 @@ def __init__( def f_p_alpha(self, V): tmp = V - self.V_sh - 15. - return 0.032 * tmp / (1. - bm.exp(-tmp / 5.)) + # 0.032 * tmp / (1 - exp(-tmp/5)) == 0.032 * 5 / exprel(-tmp/5) + return 0.16 / _exprel(-tmp / 5.) def f_p_beta(self, V): return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.) @@ -287,7 +304,8 @@ def __init__( def f_p_alpha(self, V): c = 15 - V + self.V_sh - return 0.032 * c / (bm.exp(c / 5) - 1.) + # 0.032 * c / (exp(c/5) - 1) == 0.032 * 5 / exprel(c/5) + return 0.16 / _exprel(c / 5.) def f_p_beta(self, V): return 0.5 * bm.exp((10 - V + self.V_sh) / 40) @@ -356,7 +374,8 @@ def __init__( def f_p_alpha(self, V): temp = V - self.V_sh + 10 - return 0.01 * temp / (1 - bm.exp(-temp / 10)) + # 0.01 * temp / (1 - exp(-temp/10)) == 0.01 * 10 / exprel(-temp/10) + return 0.1 / _exprel(-temp / 10.) def f_p_beta(self, V): return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80) @@ -1188,7 +1207,8 @@ def __init__( def f_p_alpha(self, V): tmp = V - self.V_sh - 15. - return 0.032 * tmp / (1. - bm.exp(-tmp / 5.)) + # 0.032 * tmp / (1 - exp(-tmp/5)) == 0.032 * 5 / exprel(-tmp/5) + return 0.16 / _exprel(-tmp / 5.) def f_p_beta(self, V): return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.) @@ -1258,7 +1278,8 @@ def __init__( def f_p_alpha(self, V): c = 15 - V + self.V_sh - return 0.032 * c / (bm.exp(c / 5) - 1.) + # 0.032 * c / (exp(c/5) - 1) == 0.032 * 5 / exprel(c/5) + return 0.16 / _exprel(c / 5.) def f_p_beta(self, V): return 0.5 * bm.exp((10 - V + self.V_sh) / 40) @@ -1329,7 +1350,8 @@ def __init__( def f_p_alpha(self, V): temp = V - self.V_sh + 10 - return 0.01 * temp / (1 - bm.exp(-temp / 10)) + # 0.01 * temp / (1 - exp(-temp/10)) == 0.01 * 10 / exprel(-temp/10) + return 0.1 / _exprel(-temp / 10.) def f_p_beta(self, V): return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80) diff --git a/brainpy/dyn/channels/potassium_compatible.py b/brainpy/dyn/channels/potassium_compatible.py index c2590ec18..1f707bd65 100644 --- a/brainpy/dyn/channels/potassium_compatible.py +++ b/brainpy/dyn/channels/potassium_compatible.py @@ -28,6 +28,19 @@ from brainpy.integrators import odeint, JointEq from brainpy.types import ArrayType + +def _exprel(x): + """Stable ``(exp(x) - 1) / x`` with finite value *and* gradient at ``x == 0``. + + The HH/Markov rate functions have the removable-singularity form + ``num * temp / (1 - exp(-temp / k))`` (equivalently ``.../(exp(temp/k) - 1)``), + which is ``0 / 0`` (NaN value and NaN gradient) at the singular voltage. + Rewriting them as ``num * k / exprel(...)`` removes the singularity. + """ + small = bm.abs(x) < 1e-7 + safe_x = bm.where(small, 1.0, x) + return bm.where(small, 1.0 + x / 2.0, bm.expm1(safe_x) / safe_x) + __all__ = [ 'IKDR_Ba2002', 'IK_TM1991', @@ -204,7 +217,7 @@ def __init__( def f_p_alpha(self, V): tmp = V - self.V_sh - 15. - return 0.032 * tmp / (1. - bm.exp(-tmp / 5.)) + return 0.16 / _exprel(-tmp / 5.) def f_p_beta(self, V): return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.) @@ -274,7 +287,7 @@ def __init__( def f_p_alpha(self, V): c = 15 - V + self.V_sh - return 0.032 * c / (bm.exp(c / 5) - 1.) + return 0.16 / _exprel(c / 5.) def f_p_beta(self, V): return 0.5 * bm.exp((10 - V + self.V_sh) / 40) @@ -345,7 +358,7 @@ def __init__( def f_p_alpha(self, V): temp = V - self.V_sh + 10 - return 0.01 * temp / (1 - bm.exp(-temp / 10)) + return 0.1 / _exprel(-temp / 10.) def f_p_beta(self, V): return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80) diff --git a/brainpy/dyn/channels/sodium.py b/brainpy/dyn/channels/sodium.py index fb0ff882f..b1f9700e7 100644 --- a/brainpy/dyn/channels/sodium.py +++ b/brainpy/dyn/channels/sodium.py @@ -20,6 +20,8 @@ from typing import Union, Callable +import jax.numpy as jnp + import brainpy.math as bm from brainpy.context import share from brainpy.dyn.ions.sodium import Sodium @@ -36,6 +38,20 @@ ] +def _exprel(x): + """Stable ``(exp(x) - 1) / x`` with a finite value *and* finite gradient at ``x == 0``. + + The HH/Markov rate functions have the removable-singularity form + ``num * temp / (1 - exp(-temp / k))`` which is ``0 / 0`` (NaN value and NaN + gradient) at the singular voltage. Rewriting them as ``num * k / exprel(...)`` + removes the singularity in value, but ``brainpy.math.exprel`` still yields a + NaN gradient at 0, so we use this branch-safe helper instead. + """ + small = jnp.abs(x) < 1e-7 + safe_x = jnp.where(small, 1.0, x) + return jnp.where(small, 1.0 + x / 2.0, jnp.expm1(safe_x) / safe_x) + + class SodiumChannel(IonChannel): """Base class for sodium channel dynamics.""" @@ -212,11 +228,13 @@ def __init__( def f_p_alpha(self, V): temp = V - self.V_sh - 13. - return 0.32 * temp / (1. - bm.exp(-temp / 4.)) + # 0.32 * temp / (1 - exp(-temp/4)) == 0.32 * 4 / exprel(-temp/4) + return 1.28 / _exprel(-temp / 4.) def f_p_beta(self, V): temp = V - self.V_sh - 40. - return -0.28 * temp / (1. - bm.exp(temp / 5.)) + # -0.28 * temp / (1 - exp(temp/5)) == 0.28 * 5 / exprel(temp/5) + return 1.4 / _exprel(temp / 5.) def f_q_alpha(self, V): return 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.) @@ -296,11 +314,13 @@ def __init__( def f_p_alpha(self, V): temp = 13 - V + self.V_sh - return 0.32 * temp / (bm.exp(temp / 4) - 1.) + # 0.32 * temp / (exp(temp/4) - 1) == 0.32 * 4 / exprel(temp/4) + return 1.28 / _exprel(temp / 4.) def f_p_beta(self, V): temp = V - self.V_sh - 40 - return 0.28 * temp / (bm.exp(temp / 5) - 1) + # 0.28 * temp / (exp(temp/5) - 1) == 0.28 * 5 / exprel(temp/5) + return 1.4 / _exprel(temp / 5.) def f_q_alpha(self, V): return 0.128 * bm.exp((17 - V + self.V_sh) / 18) @@ -381,7 +401,8 @@ def __init__( def f_p_alpha(self, V): temp = V - self.V_sh - 5 - return 0.1 * temp / (1 - bm.exp(-temp / 10)) + # 0.1 * temp / (1 - exp(-temp/10)) == 0.1 * 10 / exprel(-temp/10) + return 1.0 / _exprel(-temp / 10.) def f_p_beta(self, V): return 4.0 * bm.exp(-(V - self.V_sh + 20) / 18) diff --git a/brainpy/dyn/channels/sodium_compatible.py b/brainpy/dyn/channels/sodium_compatible.py index 680151acd..f18000213 100644 --- a/brainpy/dyn/channels/sodium_compatible.py +++ b/brainpy/dyn/channels/sodium_compatible.py @@ -28,6 +28,19 @@ from brainpy.types import ArrayType from .base import IonChannel + +def _exprel(x): + """Stable ``(exp(x) - 1) / x`` with finite value *and* gradient at ``x == 0``. + + The HH/Markov rate functions have the removable-singularity form + ``num * temp / (1 - exp(-temp / k))`` (equivalently ``.../(exp(temp/k) - 1)``), + which is ``0 / 0`` (NaN value and NaN gradient) at the singular voltage. + Rewriting them as ``num * k / exprel(...)`` removes the singularity. + """ + small = bm.abs(x) < 1e-7 + safe_x = bm.where(small, 1.0, x) + return bm.where(small, 1.0 + x / 2.0, bm.expm1(safe_x) / safe_x) + __all__ = [ 'INa_Ba2002', 'INa_TM1991', @@ -198,11 +211,11 @@ def __init__( def f_p_alpha(self, V): temp = V - self.V_sh - 13. - return 0.32 * temp / (1. - bm.exp(-temp / 4.)) + return 1.28 / _exprel(-temp / 4.) def f_p_beta(self, V): temp = V - self.V_sh - 40. - return -0.28 * temp / (1. - bm.exp(temp / 5.)) + return 1.4 / _exprel(temp / 5.) def f_q_alpha(self, V): return 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.) @@ -284,11 +297,11 @@ def __init__( def f_p_alpha(self, V): temp = 13 - V + self.V_sh - return 0.32 * temp / (bm.exp(temp / 4) - 1.) + return 1.28 / _exprel(temp / 4.) def f_p_beta(self, V): temp = V - self.V_sh - 40 - return 0.28 * temp / (bm.exp(temp / 5) - 1) + return 1.4 / _exprel(temp / 5.) def f_q_alpha(self, V): return 0.128 * bm.exp((17 - V + self.V_sh) / 18) @@ -371,7 +384,7 @@ def __init__( def f_p_alpha(self, V): temp = V - self.V_sh - 5 - return 0.1 * temp / (1 - bm.exp(-temp / 10)) + return 1.0 / _exprel(-temp / 10.) def f_p_beta(self, V): return 4.0 * bm.exp(-(V - self.V_sh + 20) / 18) diff --git a/brainpy/dyn/ions/base.py b/brainpy/dyn/ions/base.py index 3a278c217..64526003c 100644 --- a/brainpy/dyn/ions/base.py +++ b/brainpy/dyn/ions/base.py @@ -52,7 +52,7 @@ def __init__(self, *ions, **channels): self.ions: Sequence['Ion'] = tuple(ions) self._ion_classes = tuple([type(ion) for ion in self.ions]) for k, v in channels.items(): - self.add_elem(k=v) + self.add_elem(**{k: v}) def update(self, V): nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) diff --git a/brainpy/dyn/ions/potassium.py b/brainpy/dyn/ions/potassium.py index 9a664531c..8d7ef7905 100644 --- a/brainpy/dyn/ions/potassium.py +++ b/brainpy/dyn/ions/potassium.py @@ -42,7 +42,7 @@ def __init__( self, size: Shape, keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -950., + E: Union[float, ArrayType, Initializer, Callable] = -95., C: Union[float, ArrayType, Initializer, Callable] = 0.0400811, method: str = 'exp_auto', name: Optional[str] = None, diff --git a/brainpy/dyn/neurons/lif.py b/brainpy/dyn/neurons/lif.py index 28d693c3d..d4fe42936 100644 --- a/brainpy/dyn/neurons/lif.py +++ b/brainpy/dyn/neurons/lif.py @@ -1106,7 +1106,14 @@ def __init__( self._V_initializer = is_initializer(V_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + # NOTE: ``self.noise`` is already created by ``ExpIFLTC.__init__`` above + # (via ``noise=noise``). Guard the integral on it so a configured + # ``noise=`` is honoured with ``sdeint`` instead of being silently + # dropped (mirrors every other ``*RefLTC`` and the non-Ref class). + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -3811,8 +3818,11 @@ def update(self, x=None): V += (self.V_reset - V) * spike_no_grad else: raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") - I1 += spike * (self.R1 * I1 + self.A1 - I1) - I2 += spike * (self.R2 * I2 + self.A2 - I2) + # Use ``spike_no_grad`` for every state reset so that ``detach_spk`` + # actually stops the gradient through the spike; the raw ``spike`` + # here would otherwise leak it into the I1/I2 resets. + I1 += spike_no_grad * (self.R1 * I1 + self.A1 - I1) + I2 += spike_no_grad * (self.R2 * I2 + self.A2 - I2) V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike_no_grad spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient @@ -4492,8 +4502,11 @@ def update(self, x=None): if isinstance(self.mode, bm.TrainingMode): spike = self.spk_fun(V - self.V_th) spike_no_grad = stop_gradient(spike) if self.detach_spk else spike - V += spike * (self.c - V) - u += spike * self.d + # Use ``spike_no_grad`` for the state resets so that ``detach_spk`` + # actually stops the gradient through the spike (raw ``spike`` here + # would otherwise make ``detach_spk`` a no-op for the V/u resets). + V += spike_no_grad * (self.c - V) + u += spike_no_grad * self.d spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient if self.ref_var: diff --git a/brainpy/dyn/projections/align_post.py b/brainpy/dyn/projections/align_post.py index af192cc68..2f0248d3b 100644 --- a/brainpy/dyn/projections/align_post.py +++ b/brainpy/dyn/projections/align_post.py @@ -383,7 +383,7 @@ def __init__( def update(self, x): current = self.comm(x) - g = self.syn(self.comm(x)) + g = self.syn(current) self.refs['out'].bind_cond(g) # synapse post current return current diff --git a/brainpy/dyn/projections/base.py b/brainpy/dyn/projections/base.py index 9bdca17d2..651a2d6ed 100644 --- a/brainpy/dyn/projections/base.py +++ b/brainpy/dyn/projections/base.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from brainpy import math as bm -from brainpy.mixin import ReturnInfo +# Re-export the real projection base classes. This module previously held a +# byte-for-byte duplicate of the private ``_get_return`` helper in ``utils.py`` +# (H-40); that helper is private, so ``from .base import *`` exported nothing and +# the duplicate was both dead and misleading versus the real base classes. +from brainpy.dynsys import Projection +from .conn import SynConn - -def _get_return(return_info): - if isinstance(return_info, bm.Variable): - return return_info.value - elif isinstance(return_info, ReturnInfo): - return return_info.get_data() - else: - raise NotImplementedError +__all__ = ['Projection', 'SynConn'] diff --git a/brainpy/dyn/projections/inputs.py b/brainpy/dyn/projections/inputs.py index 5c78ce210..39197c4fb 100644 --- a/brainpy/dyn/projections/inputs.py +++ b/brainpy/dyn/projections/inputs.py @@ -162,16 +162,19 @@ def update(self): p = self.freq * share['dt'] / 1e3 a = self.num_input * p b = self.num_input * (1 - p) + # standard deviation of the Binomial(num_input, p) distribution: + # sqrt(num_input * p * (1 - p)) = sqrt(b * p), NOT the variance b * p. + scale = bm.sqrt(b * p) if isinstance(share['dt'], numbers.Number): # dt is not traced if (a > 5) and (b > 5): - inp = bm.random.normal(a, b * p, self.target_var.shape) + inp = bm.random.normal(a, scale, self.target_var.shape) else: inp = bm.random.binomial(self.num_input, p, self.target_var.shape) else: # dt is traced inp = bm.cond((a > 5) * (b > 5), - lambda: bm.random.normal(a, b * p, self.target_var.shape), + lambda: bm.random.normal(a, scale, self.target_var.shape), lambda: bm.random.binomial(self.num_input, p, self.target_var.shape)) # inp = bm.sharding.partition(inp, self.target_var.sharding) diff --git a/brainpy/dyn/projections/plasticity.py b/brainpy/dyn/projections/plasticity.py index c6ccdd823..03fa3d509 100644 --- a/brainpy/dyn/projections/plasticity.py +++ b/brainpy/dyn/projections/plasticity.py @@ -225,18 +225,23 @@ def update(self): raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.') post_spike = self.refs['post'].spike.value + # weight bounds: pass ``None`` through unchanged (``bm.as_jax(None)`` raises); + # ``jnp.clip`` accepts ``None`` for an unbounded side. + w_min = None if self.W_min is None else bm.as_jax(self.W_min) + w_max = None if self.W_max is None else bm.as_jax(self.W_max) + # weight updates Apost = self.refs['post_trace'].g.value self.comm.stdp_update( on_pre={"spike": bm.as_jax(pre_spike), "trace": bm.as_jax(-Apost * self.A2)}, - w_min=bm.as_jax(self.W_min), - w_max=bm.as_jax(self.W_max), + w_min=w_min, + w_max=w_max, ) Apre = self.refs['pre_trace'].g.value self.comm.stdp_update( on_post={"spike": bm.as_jax(post_spike), "trace": bm.as_jax(Apre * self.A1)}, - w_min=bm.as_jax(self.W_min), - w_max=bm.as_jax(self.W_max), + w_min=w_min, + w_max=w_max, ) # synaptic currents diff --git a/brainpy/dyn/rates/populations.py b/brainpy/dyn/rates/populations.py index c774c5fd0..8ae90fdc8 100644 --- a/brainpy/dyn/rates/populations.py +++ b/brainpy/dyn/rates/populations.py @@ -718,7 +718,7 @@ def dx(self, x, t, y, x_ext, a, w): return (a - x * x - y * y) * x - w * y + x_ext def dy(self, y, t, x, y_ext, a, w): - return (a - x * x - y * y) * y - w * y + y_ext + return (a - x * x - y * y) * y + w * x + y_ext def update(self, inp_x=None, inp_y=None): t = share.load('t') @@ -1048,7 +1048,7 @@ def update(self, inp_e=None, inp_i=None): has_noise = bm.any(self.noise_e != 0.) if has_noise: - de += bm.random.randn(self.varshape) * self.noise_e + de += bm.random.randn(*self.varshape) * self.noise_e de = de / self.tau_e self.e.value = bm.maximum(self.e + de * dt, 0.) @@ -1057,7 +1057,7 @@ def update(self, inp_e=None, inp_i=None): has_noise = bm.any(self.noise_i != 0.) if has_noise: - di += bm.random.randn(self.varshape) * self.noise_i + di += bm.random.randn(*self.varshape) * self.noise_i di = di / self.tau_i self.i.value = bm.maximum(self.i + di * dt, 0.) return self.e.value diff --git a/brainpy/dyn/rates/reservoir.py b/brainpy/dyn/rates/reservoir.py index 662cffff3..57912b0b8 100644 --- a/brainpy/dyn/rates/reservoir.py +++ b/brainpy/dyn/rates/reservoir.py @@ -220,10 +220,12 @@ def update(self, x): hidden += bm.sparse.seg_matmul(self.state, sparse) else: hidden += self.state @ self.Wrec + if self.bias is not None: + hidden += self.bias if self.activation_type == 'internal': hidden = self.activation(hidden) if self.noise_rec > 0.: - hidden += self.noise_rec * bm.random.uniform(-1, -1, self.state.shape) + hidden += self.noise_rec * bm.random.uniform(-1, 1, self.state.shape) # new state/output state = (1 - self.leaky_rate) * self.state + self.leaky_rate * hidden if self.activation_type == 'external': diff --git a/brainpy/dyn/rates/rnncells.py b/brainpy/dyn/rates/rnncells.py index ef4107319..eae8b6fc4 100644 --- a/brainpy/dyn/rates/rnncells.py +++ b/brainpy/dyn/rates/rnncells.py @@ -398,7 +398,7 @@ def h(self): def h(self, value): if self.state is None: raise ValueError('Cannot set "h" state. Because the state is not initialized.') - self.state[:self.state.shape[0] // 2, :] = value + self.state[..., :self.state.shape[-1] // 2] = value @property def c(self): @@ -409,7 +409,7 @@ def c(self): def c(self, value): if self.state is None: raise ValueError('Cannot set "c" state. Because the state is not initialized.') - self.state[self.state.shape[0] // 2:, :] = value + self.state[..., self.state.shape[-1] // 2:] = value class _ConvNDLSTMCell(Layer): diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index 80f8ca849..96d666d18 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -859,7 +859,7 @@ def reset_state(self, batch_or_mode=None, **kwargs): @property def derivative(self): - du = lambda u, t: self.U - u / self.tau_f + du = lambda u, t: -u / self.tau_f dx = lambda x, t: (1 - x) / self.tau_d return JointEq(du, dx) @@ -877,8 +877,11 @@ def update(self, pre_spike): # x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x # --- simplified code: - u = pre_spike * self.U * (1 - self.u) + u - x = pre_spike * -u * self.x + x + # Apply the discrete spike jumps to the just-integrated (decayed) locals + # ``u``/``x`` rather than the pre-decay ``self.u``/``self.x``. The ``x`` + # jump uses the already-updated ``u``. + u = u + pre_spike * self.U * (1 - u) + x = x - pre_spike * u * x self.x.value = x self.u.value = u diff --git a/brainpy/dynold/experimental/others.py b/brainpy/dynold/experimental/others.py index 1759e0a13..461a49308 100644 --- a/brainpy/dynold/experimental/others.py +++ b/brainpy/dynold/experimental/others.py @@ -70,15 +70,18 @@ def update(self): p = self.freq * share.dt / 1e3 a = self.num_input * p b = self.num_input * (1 - p) + # standard deviation of the Binomial(num_input, p) distribution: + # sqrt(num_input * p * (1 - p)) = sqrt(b * p), NOT the variance b * p. + scale = bm.sqrt(b * p) if isinstance(share.dt, (int, float)): # dt is not in tracing if (a > 5) and (b > 5): - inp = bm.random.normal(a, b * p, self.target_shape) + inp = bm.random.normal(a, scale, self.target_shape) else: inp = bm.random.binomial(self.num_input, p, self.target_shape) else: # dt is in tracing inp = bm.cond((a > 5) * (b > 5), - lambda _: bm.random.normal(a, b * p, self.target_shape), + lambda _: bm.random.normal(a, scale, self.target_shape), lambda _: bm.random.binomial(self.num_input, p, self.target_shape), None) return inp * self.weight diff --git a/brainpy/dynold/neurons/reduced_models.py b/brainpy/dynold/neurons/reduced_models.py index 84a51f952..ad75f94f2 100644 --- a/brainpy/dynold/neurons/reduced_models.py +++ b/brainpy/dynold/neurons/reduced_models.py @@ -297,12 +297,12 @@ class ExpIF(lif.ExpIFRef): ------------- -------------- -------- --------------------------------------------------- V_rest -65 mV Resting potential. V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike. + V_th -55 mV Threshold potential of spike. V_T -59.9 mV Threshold potential of generating action potential. delta_T 3.48 \ Spike slope factor. R 1 \ Membrane resistance. tau 10 \ Membrane time constant. Compute by R * C. - tau_ref 1.7 \ Refractory period length. + tau_ref 0. \ Refractory period length. ============= ============== ======== =================================================== **Model Variables** @@ -412,7 +412,7 @@ class AdExIF(lif.AdExIFRef): ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ V_rest -65 mV Resting potential. V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. + V_th -55 mV Threshold potential of spike and reset. V_T -59.9 mV Threshold potential of generating action potential. delta_T 3.48 \ Spike slope factor. a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` diff --git a/brainpy/dynold/synapses/compat.py b/brainpy/dynold/synapses/compat.py index b27c6348a..008df156d 100644 --- a/brainpy/dynold/synapses/compat.py +++ b/brainpy/dynold/synapses/compat.py @@ -21,7 +21,7 @@ from brainpy.dynold.synouts import COBA, CUBA from brainpy.initialize import Initializer from brainpy.types import ArrayType -from .abstract_models import Delta, Exponential, DualExponential +from .abstract_models import Delta, Exponential, DualExponential, Alpha __all__ = [ 'DeltaSynapse', @@ -205,7 +205,7 @@ def __init__( output=COBA(E=E)) -class AlphaCUBA(DualExpCUBA): +class AlphaCUBA(Alpha): r"""Current-based alpha synapse model. .. deprecated:: 2.1.13 @@ -225,19 +225,23 @@ def __init__( method: str = 'exp_auto', name: str = None ): + # Alpha synapses have a single time constant; route through the + # single-tau ``Alpha`` implementation instead of a dual-exponential + # with ``tau_rise == tau_decay`` (which divides by zero in the peak + # normalizer ``A = tau_decay / (tau_decay - tau_rise)``). super().__init__(pre=pre, post=post, conn=conn, - conn_type=conn_type, + comp_method=conn_type, delay_step=delay_step, g_max=g_max, tau_decay=tau_decay, - tau_rise=tau_decay, method=method, - name=name) + name=name, + output=CUBA()) -class AlphaCOBA(DualExpCOBA): +class AlphaCOBA(Alpha): """Conductance-based alpha synapse model. .. deprecated:: 2.1.13 @@ -258,13 +262,17 @@ def __init__( method: str = 'exp_auto', name: str = None ): + # Alpha synapses have a single time constant; route through the + # single-tau ``Alpha`` implementation instead of a dual-exponential + # with ``tau_rise == tau_decay`` (which divides by zero in the peak + # normalizer ``A = tau_decay / (tau_decay - tau_rise)``). super().__init__(pre=pre, post=post, conn=conn, - conn_type=conn_type, + comp_method=conn_type, delay_step=delay_step, - g_max=g_max, E=E, + g_max=g_max, tau_decay=tau_decay, - tau_rise=tau_decay, method=method, - name=name) + name=name, + output=COBA(E=E)) diff --git a/brainpy/dynold/synapses/learning_rules.py b/brainpy/dynold/synapses/learning_rules.py index b413c7bb3..c71ea0d89 100644 --- a/brainpy/dynold/synapses/learning_rules.py +++ b/brainpy/dynold/synapses/learning_rules.py @@ -36,6 +36,16 @@ def __init__(self, size, keep_size, tau, U, tau_f, tau_d, mode=None, method='exp exp = synapses.Expon(size, keep_size, tau=tau, method=method, mode=mode) super().__init__(stp, exp) + def update(self, pre_spike): + # ``synapses.STP.update`` returns the synaptic resource fraction ``u*x``, + # which is non-zero even at rest (≈U). Feeding it directly into ``Expon`` + # (which treats its input as an additive current event) would inject + # ``u*x`` on *every* step, so the current keeps growing with zero + # presynaptic spikes. Gate the injected amplitude by ``pre_spike`` so the + # synaptic current only jumps when a presynaptic spike actually arrives. + ux = self[0](pre_spike) + return self[1](pre_spike * ux) + class STP(_TwoEndConnAlignPre): r"""Short-term plasticity model. diff --git a/brainpy/encoding/stateless_encoding.py b/brainpy/encoding/stateless_encoding.py index 5388e7bb5..d358b2c52 100644 --- a/brainpy/encoding/stateless_encoding.py +++ b/brainpy/encoding/stateless_encoding.py @@ -88,10 +88,22 @@ def single_step(self, x, i_step: int = None): Returns: out: Array. The encoded spike train. """ + # Draw a single Bernoulli sample for one step. (Delegating to + # ``multi_steps`` with ``n_time=None`` would crash on ``int(None / dt)``, + # and the old ``cond(..., self.multi_steps, x)`` passed the wrong number + # of arguments to ``multi_steps``.) + x = self._normalize(x) + spikes = bm.asarray(bm.random.rand(*x.shape) < x, dtype=x.dtype) if i_step is None: - return self.multi_steps(x, n_time=None) - else: - return bm.cond(bm.as_jax(i_step < self.first_spk_step), self._zero_out, self.multi_steps, x) + return spikes + # Before the first-spike step, emit no spikes. + before_first = bm.as_jax(i_step) < self.first_spk_step + return bm.asarray(bm.where(before_first, bm.zeros_like(spikes), spikes), dtype=x.dtype) + + def _normalize(self, x): + if (self.min_val is not None) and (self.max_val is not None): + x = (x - self.min_val) / (self.max_val - self.min_val) + return x * self.gain + self.offset def multi_steps(self, x, n_time: Optional[float]): """Generate spikes at multiple steps according to the inputs. @@ -108,11 +120,11 @@ def multi_steps(self, x, n_time: Optional[float]): Returns: out: Array. The encoded spike train. """ - n_time = int(n_time / bm.get_dt()) + # ``n_time=None`` means "encode the current single step" (see docstring); + # only convert to a step count when an actual duration is given. + n_time = None if n_time is None else int(n_time / bm.get_dt()) - if (self.min_val is not None) and (self.max_val is not None): - x = (x - self.min_val) / (self.max_val - self.min_val) - x = x * self.gain + self.offset + x = self._normalize(x) if n_time is not None and self.first_spk_step > 0: pre = bm.zeros((self.first_spk_step,) + x.shape, dtype=x.dtype) shape = ((n_time - self.first_spk_step,) + x.shape) diff --git a/brainpy/integrators/fde/Caputo.py b/brainpy/integrators/fde/Caputo.py index 0f48fc586..23afc6000 100644 --- a/brainpy/integrators/fde/Caputo.py +++ b/brainpy/integrators/fde/Caputo.py @@ -147,7 +147,8 @@ def __init__( f'but we got {self.alpha}.') # initial values - self.inits = check_inits(inits, self.variables) + inits = check_inits(inits, self.variables) + self.inits = bm.VarDict({v: bm.Variable(inits[v]) for v in self.variables}) # coefficients rgamma_alpha = bm.asarray(rgamma(bm.as_numpy(self.alpha))) @@ -163,6 +164,14 @@ def __init__( self.set_integral(self._integral_func) + def reset(self, inits): + """Reset the integrator states so it can be re-run from new initial values.""" + self.idx.value = bm.asarray([1]) + inits = check_inits(inits, self.variables) + for key, val in inits.items(): + self.inits[key] = val + self.f_states[key] = bm.zeros((self.num_memory,) + val.shape, dtype=self.f_states[key].dtype) + def _check_step(self, args): dt, t = args raise ValueError(f'The maximum number of step is {self.num_memory}, ' @@ -198,8 +207,8 @@ def _integral_func(self, *args, **kwargs): integrals = [] idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory for i, key in enumerate(self.variables): - integral = self.inits[key] + self.coef[idx, i] @ self.f_states[key] - integrals.append(integral * (dt ** self.alpha[i] / self.alpha[i])) + integral = self.coef[idx, i] @ self.f_states[key] + integrals.append(self.inits[key] + integral * (dt ** self.alpha[i] / self.alpha[i])) self.idx.value = (self.idx + 1) % self.num_memory # return integrals @@ -372,7 +381,7 @@ def hists(self, var=None, numpy=True): for k in self.variables} hists_ = {k: bm.cumsum(v, axis=0) for k, v in hists_.items()} if numpy: - hists_ = {k: v.numpy() for k, v in hists_} + hists_ = {k: v.numpy() for k, v in hists_.items()} return hists_ else: assert var in self.variables, (f'"{var}" is not defined in equation ' diff --git a/brainpy/integrators/fde/GL.py b/brainpy/integrators/fde/GL.py index 29f572e66..76c3fc555 100644 --- a/brainpy/integrators/fde/GL.py +++ b/brainpy/integrators/fde/GL.py @@ -184,7 +184,7 @@ def reset(self, inits): for key, val in inits.items(): delay = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype) delay[0] = val - self.delays[key].value = delay + self.delays[key + '_delay'].value = delay @property def binomial_coef(self): diff --git a/brainpy/integrators/fde/generic.py b/brainpy/integrators/fde/generic.py index 5a349f82a..be2f6890c 100644 --- a/brainpy/integrators/fde/generic.py +++ b/brainpy/integrators/fde/generic.py @@ -24,7 +24,7 @@ name2method = {} -_DEFAULT_DDE_METHOD = 'l1' +_DEFAULT_FDE_METHOD = 'l1' def fdeint( @@ -60,7 +60,7 @@ def fdeint( integral : FDEIntegrator The numerical solver of `f`. """ - method = _DEFAULT_DDE_METHOD if method is None else method + method = _DEFAULT_FDE_METHOD if method is None else method if method not in name2method: raise ValueError(f'Unknown FDE numerical method "{method}". Currently ' f'BrainPy supports: {list(name2method.keys())}') @@ -72,7 +72,7 @@ def fdeint( def set_default_fdeint(method): - """Set the default ODE numerical integrator method for differential equations. + """Set the default FDE numerical integrator method for fractional differential equations. Parameters:: @@ -82,25 +82,25 @@ def set_default_fdeint(method): if not isinstance(method, str): raise ValueError(f'Only support string, not {type(method)}.') if method not in name2method: - raise ValueError(f'Unsupported ODE_INT numerical method: {method}.') + raise ValueError(f'Unsupported FDE numerical method: {method}.') - global _DEFAULT_DDE_METHOD - _DEFAULT_ODE_METHOD = method + global _DEFAULT_FDE_METHOD + _DEFAULT_FDE_METHOD = method def get_default_fdeint(): - """Get the default ODE numerical integrator method. + """Get the default FDE numerical integrator method. Returns:: method : str The default numerical integrator method. """ - return _DEFAULT_DDE_METHOD + return _DEFAULT_FDE_METHOD def register_fde_integrator(name, integrator): - """Register a new ODE integrator. + """Register a new FDE integrator. Parameters:: @@ -117,5 +117,5 @@ def register_fde_integrator(name, integrator): def get_supported_methods(): - """Get all supported numerical methods for DDEs.""" + """Get all supported numerical methods for FDEs.""" return list(name2method.keys()) diff --git a/brainpy/integrators/joint_eq.py b/brainpy/integrators/joint_eq.py index 55c1a3af5..a4954963d 100644 --- a/brainpy/integrators/joint_eq.py +++ b/brainpy/integrators/joint_eq.py @@ -186,7 +186,12 @@ def __init__(self, *eqs): elif (key not in vars_in_eqs) and (key not in all_arg_pars): all_kwarg_pars[key] = value else: - raise DiffEqError + raise DiffEqError( + f'The keyword argument "{key}" conflicts with an existing name ' + f'in the joint equations: it is already used as a state variable ' + f'or a positional parameter. A keyword argument cannot reuse the ' + f'name of a state variable or positional parameter.' + ) # # variable names provided # if not isinstance(variables, (tuple, list)): diff --git a/brainpy/integrators/ode/adaptive_rk.py b/brainpy/integrators/ode/adaptive_rk.py index ee6864f3d..85826c2ab 100644 --- a/brainpy/integrators/ode/adaptive_rk.py +++ b/brainpy/integrators/ode/adaptive_rk.py @@ -184,7 +184,7 @@ def __init__(self, keywords['error'] = 'the local truncation error' for v in self.variables: keywords[f'{v}_te'] = 'the local truncation error' - self.code_scope['tol'] = tol + self.code_scope['tol'] = self.tol self.code_scope['math'] = jnp utils.check_kws(self.arg_names, keywords) @@ -212,13 +212,20 @@ def build(self): result.append(f'd{v}_k{i + 1} * {C.DT} * {diff}') if len(result) > 0: if self.var_type == C.SCALAR_VAR: - self.code_lines.append(f' {v}_te = abs({" + ".join(result)})') + self.code_lines.append(f' {v}_te = math.abs({" + ".join(result)})') else: - self.code_lines.append(f' {v}_te = sum(abs({" + ".join(result)}))') + self.code_lines.append(f' {v}_te = math.sum(math.abs({" + ".join(result)}))') errors_.append(f'{v}_te') if len(errors_) > 0: self.code_lines.append(f' error = {" + ".join(errors_)}') - self.code_lines.append(f' {C.DT}_new = math.where(error > tol, 0.9*{C.DT}*(tol/error)**0.2, {C.DT})') + # Two-sided step-size controller: shrink dt when the error is + # above tolerance, grow it when the error is comfortably below. + # The growth/shrink factor is clamped to keep the step change + # bounded so that dt can both decrease and increase. + self.code_lines.append( + f' factor = 0.9 * (tol / (error + 1e-12)) ** 0.2') + self.code_lines.append(' factor = math.clip(factor, 0.2, 5.0)') + self.code_lines.append(f' {C.DT}_new = {C.DT} * factor') return_args.append(f'{C.DT}_new') # returns self.code_lines.append(f' return {", ".join(return_args)}') @@ -529,7 +536,7 @@ class BoSh3(AdaptiveRKIntegrator): (0.0, 0.75), ('2/9', '1/3', '4/9')] B1 = ['2/9', '1/3', '4/9', 0.0] - B2 = ['-5/72', 1 / 12, '1/9', '-1/8'] + B2 = ['7/24', 0.25, '1/3', 0.125] C = [0., 0.5, 0.75, 1.0] diff --git a/brainpy/integrators/ode/exponential.py b/brainpy/integrators/ode/exponential.py index 4a20eb911..6d4ff9890 100644 --- a/brainpy/integrators/ode/exponential.py +++ b/brainpy/integrators/ode/exponential.py @@ -370,16 +370,17 @@ def _build_integrator(self, eq): # integration function def integral(*args, **kwargs): assert len(args) > 0 - if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]: + x0 = bm.as_jax(args[0]) + if x0.dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]: raise ValueError( 'The input data type should be float32, float64, float16, ' 'or bfloat16 when using Exponential Euler method.' - f'But we got {args[0].dtype}.' + f'But we got {x0.dtype}.' ) dt = kwargs.pop(C.DT, self.dt) linear, derivative = bm.vector_grad(eq, argnums=0, return_value=True)(*args, **kwargs) phi = bm.exprel(dt * linear) - return args[0] + dt * phi * derivative + return x0 + dt * phi * derivative return [(integral, vars, pars), ] diff --git a/brainpy/integrators/runner.py b/brainpy/integrators/runner.py index 54c78a9f2..b429e2647 100644 --- a/brainpy/integrators/runner.py +++ b/brainpy/integrators/runner.py @@ -251,8 +251,8 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i): if len(self.target.variables) == 1: self.variables[self.target.variables[0]].update(update_values) else: - for i, v in enumerate(self.target.variables): - self.variables[v].update(update_values[i]) + for j, v in enumerate(self.target.variables): + self.variables[v].update(update_values[j]) # progress bar if self.progress_bar: diff --git a/brainpy/integrators/sde/base.py b/brainpy/integrators/sde/base.py index a43de2f72..c68437e57 100644 --- a/brainpy/integrators/sde/base.py +++ b/brainpy/integrators/sde/base.py @@ -17,6 +17,7 @@ import jax.numpy as jnp +from brainpy import _errors as errors from brainpy import math as bm from brainpy.integrators import constants, utils from brainpy.integrators.base import Integrator @@ -95,5 +96,14 @@ def __init__( self.show_code = show_code def _check_vector_wiener_dim(self, noise_size, var_size): - if noise_size[:-1] > var_size[-len(noise_size) + 1:]: - raise ValueError(f"Incompatible shapes for shapes of noise {noise_size} and variable {var_size}") + noise_size = tuple(noise_size) + var_size = tuple(var_size) + # For a vector Wiener process the diffusion value has shape + # ``var_size + (m,)``: its leading dimensions (all but the last, which + # holds the ``m`` noise channels) must match the variable shape exactly. + if tuple(noise_size[:-1]) != var_size: + raise ValueError( + f"Incompatible shapes for vector Wiener process: the diffusion " + f"value has shape {noise_size}, so its leading dimensions " + f"{noise_size[:-1]} must equal the variable shape {var_size}." + ) diff --git a/brainpy/integrators/sde/normal.py b/brainpy/integrators/sde/normal.py index 0080d0e63..3165b81ac 100644 --- a/brainpy/integrators/sde/normal.py +++ b/brainpy/integrators/sde/normal.py @@ -17,6 +17,7 @@ import jax.numpy as jnp +from brainpy import _errors as errors from brainpy import math as bm from brainpy.integrators import constants, utils, joint_eq from brainpy.integrators.constants import DT diff --git a/brainpy/integrators/sde/srk_strong.py b/brainpy/integrators/sde/srk_strong.py deleted file mode 100644 index 98faefdd3..000000000 --- a/brainpy/integrators/sde/srk_strong.py +++ /dev/null @@ -1,473 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from brainpy import math -from brainpy.integrators import constants, utils - -__all__ = [ - 'srk1_strong', -] - -_SDE_UNKNOWN_NO = 0 - - -def basic_info(f, g): - vdt = 'dt' - if f.__name__.isidentifier(): - func_name = f.__name__ - elif g.__name__.isidentifier(): - func_name = g.__name__ - else: - global _SDE_UNKNOWN_NO - func_name = f'unknown_sde{_SDE_UNKNOWN_NO}' - func_new_name = constants.SDE_INT + func_name - variables, parameters, arguments = utils.get_args(f) - return vdt, variables, parameters, arguments, func_new_name - - -def _vector_wiener_terms(code_lines, sde_type, vdt, shape_D, shape_m): - if sde_type == constants.ITO_SDE: - I2 = f'0.5*(_term3 - {vdt} * math.eye({shape_m})) + _a*0.5*{vdt}/math.pi' - elif sde_type == constants.STRA_SDE: - I2 = f'0.5*_term3 + _a*0.5*dt/math.pi' - else: - raise ValueError(f'Unknown SDE_INT type: {sde_type}. We only supports {constants.SUPPORTED_INTG_TYPE}.') - - if shape_D: - shape_D = shape_D + '+' - - noise_string = f''' - # Noise Terms # - # ----------- # - - # single Ito integrals - _I1 = math.normal(0., {vdt}_sqrt, {shape_D}({shape_m},)) - # double Ito integrals - _h = (2.0 / {vdt}) ** 0.5) - _a = math.zeros(shape={shape_D}({shape_m}, {shape_m})) - for _k in range(1, num_iter + 1): - _x = math.normal(loc=0., scale=1., size={shape_D}({shape_m}, 1)) - _y = math.normal(loc=0., scale=1., size={shape_D}(1, {shape_m})) + _h * _I1 - _term1 = math.matmul(_x, _y) - _term2 = math.matmul(math.reshape(_y, {shape_D}({shape_m}, 1)), - math.reshape(_x, {shape_D}(1, {shape_m}))) - _a += (_term1 - _term2) / _k - _I1_rs = math.reshape(_I1, {shape_D}({shape_m}, 1)) - _term3 = math.matmul(_I1_rs, math.reshape(_I1, {shape_D}(1, {shape_m}))) - _I2 = {I2} - ''' - noise_lines = noise_string.split('\n') - code_lines.extend(noise_lines) - - -# ---------- -# Wrapper -# ---------- - - -def _srk2_pop_var_vector_wiener(sde_type, code_lines, variables, parameters, vdt): - # shape information - # ----- - all_f = [f'f_{var}' for var in variables] - all_g = [f'g_{var}' for var in variables] - noise_string = f''' - {", ".join(all_f)} = f({", ".join(variables + parameters)}) # shape = (..) - {", ".join(all_g)} = g({", ".join(variables + parameters)}) # shape = (.., m) - noise_shape = math.shape(g_x1) - _D = noise_shape[:-1] - _m = noise_shape[-1] - ''' - code_lines.extend(noise_string.split("\n")) - - # noise terms - _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='_D', shape_m='_m') - - # numerical integration - # step 1 - # --- - # g_x1_rs = math.reshape(g_x1, _D + (1, _m)) - # g_x2_rs = math.reshape(g_x2, _D + (1, _m)) - for var in variables: - code_lines.append(f" g_{var}_rs = math.reshape(g_{var}, _D+(1, _m))") - # step 2 - # --- - # g_H1_x1 = math.reshape(math.matmul(g_x1_rs, _I2) / dt_sqrt, _D + (_m,)) - # g_H1_x2 = math.reshape(math.matmul(g_x2_rs, _I2) / dt_sqrt, _D + (_m,)) - for var in variables: - code_lines.append(f' g_H1_{var} = math.reshape(math.matmul(g_{var}_rs, _I2) / {vdt}_sqrt, _D + (_m,))') - # step 3 - # --- - # x1_rs = math.reshape(x1, _D + (1,)) - # x2_rs = math.reshape(x2, _D + (1,)) - for var in variables: - code_lines.append(f' {var}_rs = math.reshape({var}, _D + (1,))') - # step 4 - # --- - # H2_x1 = x1_rs + g_H1_x1 - # H3_x1 = x1_rs - g_H1_x1 - for var in variables: - code_lines.append(f' H2_{var} = {var}_rs + g_H1_{var}') - code_lines.append(f' H3_{var} = {var}_rs - g_H1_{var}') - code_lines.append(' ') - # step 5 - # --- - # _g_x1 = math.matmul(g_x1_rs, _I1_rs) - for var in variables: - code_lines.append(f' _g_{var} = math.matmul(g_{var}_rs, _I1_rs)') - # step 6 - # ---- - # x1_new = x1 + f_x1 + _g_x1[..., 0, 0] - for var in variables: - code_lines.append(f' {var}_new = {var} + f_{var} + _g_{var}[..., 0, 0]') - # for _k in range(_m): - code_lines.append('for _k in range(_m):') - # g_x1_H2, g_x2_H2 = g(H2_x1[..., _k], H2_x2[..., _k], t, *args) - all_H2 = [f'H2_{var}[..., _k]' for var in variables] - all_g_H2 = [f'g_{var}_H2' for var in variables] - code_lines.append(f' {", ".join(all_g_H2)} = g({", ".join(all_H2 + parameters)})') - # g_x1_H3, g_x2_H3 = g(H3_x1[..., _k], H3_x2[..., _k], t, *args) - all_H3 = [f'H3_{var}[..., _k]' for var in variables] - all_g_H3 = [f'g_{var}_H3' for var in variables] - code_lines.append(f' {", ".join(all_g_H3)} = g({", ".join(all_H3 + parameters)})') - # x1_new += 0.5 * dt_sqrt * (g_x1_H2[..., _k] - g_x1_H3[..., _k]) - # x2_new += 0.5 * dt_sqrt * (g_x2_H2[..., _k] - g_x2_H3[..., _k]) - for var in variables: - code_lines.append(f' {var}_new += 0.5 * {vdt}_sqrt * (g_{var}_H2[..., _k] - g_{var}_H3[..., _k])') - - -def _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt): - if sde_type == constants.ITO_SDE: - I2 = f'0.5 * (_I1 * _I1 - {vdt})' - elif sde_type == constants.STRA_SDE: - I2 = f'0.5 * _I1 * _I1' - else: - raise ValueError(f'Unknown SDE_INT type: {sde_type}. We only supports {constants.SUPPORTED_INTG_TYPE}.') - - # shape info - # ----- - all_f = [f'f_{var}' for var in variables] - all_g = [f'g_{var}' for var in variables] - - code_string = f''' - {", ".join(all_f)} = f({", ".join(variables + parameters)}) # shape = (..) - {", ".join(all_g)} = g({", ".join(variables + parameters)}) # shape = (..) - - # single Ito integrals - _I1 = math.normal(0., {vdt}_sqrt, math.shape({variables[0]})) # shape = (..) - # double Ito integrals - _I2 = {I2} # shape = (..) - ''' - code_splits = code_string.split('\n') - code_lines.extend(code_splits) - - # numerical integration - # ----- - # H1 - for var in variables: - code_lines.append(f' g_H1_{var} = g_{var} * _I2 / {vdt}_sqrt # shape (.., )') - # H2 - all_H2 = [f'H2_{var}' for var in variables] - for var in variables: - code_lines.append(f' H2_{var} = {var} + g_H1_{var} # shape (.., )') - all_g_H2 = [f'g_{var}_H2' for var in variables] - code_lines.append(f' {", ".join(all_g_H2)} = g({", ".join(all_H2 + parameters)})') - code_lines.append(f' ') - # H3 - all_H3 = [f'H3_{var}' for var in variables] - for var in variables: - code_lines.append(f' H3_{var} = {var} - g_H1_{var} # shape (.., )') - all_g_H3 = [f'g_{var}_H3' for var in variables] - code_lines.append(f' {", ".join(all_g_H3)} = g({", ".join(all_H3 + parameters)})') - code_lines.append(f' ') - # final results - for var in variables: - code_lines.append(f' {var}_new = {var} + f_{var} + g_{var} * _I1 ' - f'+ 0.5 * {vdt}_sqrt * (g_{var}_H2 - g_{var}_H3)') - - -def _srk1_scalar_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt): - # shape information - all_f = [f'f_{var}' for var in variables] - all_g = [f'g_{var}' for var in variables] - code1 = f''' - # shape info # - # ---------- # - - {", ".join(all_f)} = f({", ".join(variables + parameters)}) # shape = () - {", ".join(all_g)} = g({", ".join(variables + parameters)}) # shape = (m) - noise_shape = math.shape(g_x1) - _m = noise_shape[0] - ''' - code_lines.extend(code1.split('\n')) - - # noise term - _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='', shape_m='_m') - - # numerical integration - - # p1 - # --- - # g_x1_rs = math.reshape(g_x1, (1, _m)) - # g_x2_rs = math.reshape(g_x2, (1, _m)) - for var in variables: - code_lines.append(f' g_{var}_rs = math.reshape(g_{var}, (1, _m))') - - # p2 - # --- - # g_H1_x1 = math.matmul(g_x1_rs, _I2) / dt_sqrt # shape (1, m) - # g_H1_x2 = math.matmul(g_x2_rs, _I2) / dt_sqrt # shape (1, m) - for var in variables: - code_lines.append(f' g_H1_{var} = math.matmul(g_{var}_rs, _I2) / {vdt}_sqrt # shape (1, m)') - - # p3 - # --- - # H2_x1 = x1 + g_H1_x1[0] # shape (m) - # H3_x1 = x1 - g_H1_x1[0] # shape (m) - for var in variables: - code_lines.append(f' H2_{var} = {var} + g_H1_{var}[0] # shape (m)') - code_lines.append(' ') - - # p4 - # --- - # g1_x1 = math.matmul(g_x1_rs, _I1_rs) # shape (1, 1) - # x1_new = x1 + f_x1 + g1_x1[0, 0] # shape () - for var in variables: - code_lines.append(f' g1_{var} = math.matmul(g_{var}_rs, _I1_rs) # shape (1, 1)') - code_lines.append(f' {var}_new = {var} + f_{var} + g1_{var}[0, 0] # shape ()') - - # p5 - # --- - # for _k in range(_m): - # g_x1_H2, g_x2_H2 = g(H2_x1[_k], H2_x2[_k], t, *args) - # g_x1_H3, g_x2_H3 = g(H3_x1[_k], H3_x2[_k], t, *args) - # x1_new += 0.5 * dt_sqrt * (g_x1_H2[_k] - g_x1_H3[_k]) - # x2_new += 0.5 * dt_sqrt * (g_x2_H2[_k] - g_x2_H3[_k]) - code_lines.append(' for _k in range(_m):') - all_h2_k = [f'H2_{var}[_k]' for var in variables] - all_g_h2 = [f'g_{var}_H2' for var in variables] - code_lines.append(f' {", ".join(all_g_h2)} = g({", ".join(all_h2_k + parameters)})') - all_h3_k = [f'H3_{var}[_k]' for var in variables] - all_g_h3 = [f'g_{var}_H3' for var in variables] - code_lines.append(f' {", ".join(all_g_h3)} = g({", ".join(all_h3_k + parameters)})') - for var in variables: - code_lines.append(f' {var}_new += 0.5 * {vdt}_sqrt * (g_{var}_H2[_k] - g_{var}_H3[_k])') - - -def _srk1_system_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt): - # shape information - code1 = f''' - # shape infor # - # ----------- # - - f_x = f({", ".join(variables + parameters)}) # shape = (d, ..) - g_x = g({", ".join(variables + parameters)}) # shape = (d, .., m) - _shape = math.shape(g_x) - _d = _shape[0] - _m = _shape[-1] - _D = _shape[1:-1] - ''' - code_lines.extend(code1.split('\n')) - - # noise term - _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='_D', shape_m='_m') - - # numerical integration - code2 = f''' - # numerical integration # - # --------------------- # - - g_x2 = math.moveaxis(g_x, 0, -2) # shape = (.., d, m) - g_H1_k = math.matmul(g_x2, _I2) / dt_sqrt # shape (.., d, m) - g_H1_k = math.moveaxis(g_H1_k, -2, 0) # shape (d, .., m) - x_rs = math.reshape(x, (_d,) + _D + (1,)) - H2 = x_rs + g_H1_k # shape (d, .., m) - H3 = x_rs - g_H1_k # shape (d, .., m) - - g1 = math.matmul(g_x2, _I1_rs) # shape (.., d, 1) - g1 = math.moveaxis(g1, -2, 0) # shape (d, .., 1) - y = x + f_x + g1[..., 0] # shape (d, ..) - for _k in range(_m): - y += 0.5 * dt_sqrt * g(H2[..., _k], t, *args)[..., _k] - y -= 0.5 * dt_sqrt * g(H3[..., _k], t, *args)[..., _k] - ''' - code_lines.extend(code2.split('\n')) - - -def _srk1_system_var_with_scalar_wiener(sde_type, code_lines, variables, parameters, vdt): - if sde_type == constants.ITO_SDE: - I2 = f'0.5 * (_I1 * _I1 - {vdt})' - elif sde_type == constants.STRA_SDE: - I2 = f'0.5 * _I1 * _I1' - else: - raise ValueError(f'Unknown SDE_INT type: {sde_type}. We only supports {constants.SUPPORTED_INTG_TYPE}.') - - code_string = f''' - f_x = f({", ".join(variables + parameters)}) # shape = (d, ..) - g_x = g({", ".join(variables + parameters)}) # shape = (d, ..) - _shape = math.shape(g_x) - _d = _shape[0] - _D = _shape[1:] - - # single Ito integrals - _I1 = math.normal(0., {vdt}_sqrt, _D) # shape = (..) - # double Ito integrals - _I2 = {I2} # shape = (..) - - # numerical integration # - # --------------------- # - g_H1_k = g_x * _I2 / {vdt}_sqrt # shape (d, ..) - H2 = x + g_H1_k # shape (d, ..) - H3 = x - g_H1_k # shape (d, ..) - - g1 = g_x * _I1 # shape (d, ..) - x_new = x + f_x + g1 # shape (d, ..) - x_new += 0.5 * {vdt}_sqrt * g(H2, {", ".join(parameters)}) - x_new -= 0.5 * {vdt}_sqrt * g(H3, {", ".join(parameters)}) - ''' - code_splits = code_string.split('\n') - code_lines.extend(code_splits) - - -def _srk1_wrapper(f, g, dt, sde_type, var_type, wiener_type, show_code, num_iter): - vdt, variables, parameters, arguments, func_name = basic_info(f=f, g=g) - - # 1. code scope - code_scope = {'f': f, 'g': g, vdt: dt, f'{vdt}_sqrt': dt ** 0.5, - 'math': math, 'num_iter': num_iter} - - # 2. code lines - code_lines = [f'def {func_name}({", ".join(arguments)}):'] - - if var_type == constants.SYSTEM_VAR: - if len(variables) > 1: - raise ValueError(f'SDE_INT with {constants.SYSTEM_VAR} variable type only ' - f'supports one system variable. But we got {variables}.') - - if wiener_type == constants.SCALAR_WIENER: - _srk1_system_var_with_scalar_wiener(sde_type, code_lines, variables, parameters, vdt) - elif wiener_type == constants.VECTOR_WIENER: - _srk1_system_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt) - else: - raise ValueError(f'Unknown Wiener type: {wiener_type}, we only ' - f'supports {constants.SUPPORTED_WIENER_TYPE}') - - elif var_type == constants.SCALAR_VAR: - if wiener_type == constants.SCALAR_WIENER: - _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt) - elif wiener_type == constants.VECTOR_WIENER: - _srk1_scalar_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt) - else: - raise ValueError(f'Unknown Wiener type: {wiener_type}, we only ' - f'supports {constants.SUPPORTED_WIENER_TYPE}') - - elif var_type == constants.POP_VAR: - if wiener_type == constants.SCALAR_WIENER: - _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt) - elif wiener_type == constants.VECTOR_WIENER: - _srk2_pop_var_vector_wiener(sde_type, code_lines, variables, parameters, vdt) - else: - raise ValueError(f'Unknown Wiener type: {wiener_type}, we only ' - f'supports {constants.SUPPORTED_WIENER_TYPE}') - - else: - raise ValueError(f'Unknown var type: {var_type}, we only ' - f'supports {constants.SUPPORTED_VAR_TYPE}') - # returns - new_vars = [f'{var}_new' for var in variables] - code_lines.append(f' return {", ".join(new_vars)}') - - # return and compile - utils.compile_code(code_lines, code_scope, show_code, variables) - return code_scope[func_name] - - -def _srk2_wrapper(): - pass - - -def _wrap(wrapper, f, g, dt, sde_type, var_type, wiener_type, show_code, num_iter): - """The brainpy_object function to format a SRK method. - - Parameters:: - - f : callable - The drift function of the SDE_INT. - g : callable - The diffusion function of the SDE_INT. - dt : float - The numerical precision. - sde_type : str - "utils.ITO_SDE" : Ito's Stochastic Calculus. - "utils.STRA_SDE" : Stratonovich's Stochastic Calculus. - wiener_type : str - var_type : str - "scalar" : with the shape of (). - "population" : with the shape of (N,) or (N1, N2) or (N1, N2, ...). - "system": with the shape of (d, ), (d, N), or (d, N1, N2). - show_code : bool - Whether show the formatted code. - - Returns:: - - numerical_func : callable - The numerical function. - """ - - sde_type = constants.ITO_SDE if sde_type is None else sde_type - assert sde_type in constants.SUPPORTED_INTG_TYPE, f'Currently, BrainPy only support SDE_INT types: ' \ - f'{constants.SUPPORTED_INTG_TYPE}. But we got {sde_type}.' - - var_type = constants.POP_VAR if var_type is None else var_type - assert var_type in constants.SUPPORTED_VAR_TYPE, f'Currently, BrainPy only supports variable types: ' \ - f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.' - - wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type - assert wiener_type in constants.SUPPORTED_WIENER_TYPE, f'Currently, BrainPy only supports Wiener ' \ - f'Process types: {constants.SUPPORTED_WIENER_TYPE}. ' \ - f'But we got {wiener_type}.' - - show_code = False if show_code is None else show_code - dt = math.get_dt() if dt is None else dt - num_iter = 10 if num_iter is None else num_iter - - if f is not None and g is not None: - return wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type, - var_type=var_type, wiener_type=wiener_type, num_iter=num_iter) - - elif f is not None: - return lambda g: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type, - var_type=var_type, wiener_type=wiener_type, num_iter=num_iter) - - elif g is not None: - return lambda f: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type, - var_type=var_type, wiener_type=wiener_type, num_iter=num_iter) - - else: - raise ValueError('Must provide "f" or "g".') - - -# ------------------ -# Numerical methods -# ------------------ - - -def srk1_strong(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, num_iter=None, show_code=None): - return _wrap(_srk1_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type, - wiener_type=wiener_type, show_code=show_code, num_iter=num_iter) - - -def srk2_strong(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, num_iter=None, show_code=None): - return _wrap(_srk2_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type, - wiener_type=wiener_type, show_code=show_code, num_iter=num_iter) diff --git a/brainpy/losses/comparison.py b/brainpy/losses/comparison.py index 67330b31a..c96644be2 100644 --- a/brainpy/losses/comparison.py +++ b/brainpy/losses/comparison.py @@ -198,10 +198,12 @@ def __init__(self, weight: Optional[ArrayType] = None, ignore_index: int = -100, self.label_smoothing = label_smoothing def update(self, input: ArrayType, target: ArrayType) -> ArrayType: - return cross_entropy_loss(input, target, weight=self.weight, reduction=self.reduction) + return cross_entropy_loss(input, target, weight=self.weight, reduction=self.reduction, + ignore_index=self.ignore_index, label_smoothing=self.label_smoothing) -def cross_entropy_loss(predicts, targets, weight=None, reduction='mean'): +def cross_entropy_loss(predicts, targets, weight=None, reduction='mean', + ignore_index=-100, label_smoothing=0.0): r"""This criterion combines ``LogSoftmax`` and `NLLLoss`` in one single class. It is useful when training a classification problem with `C` classes. @@ -260,11 +262,46 @@ def cross_entropy_loss(predicts, targets, weight=None, reduction='mean'): """ def _cel(_pred, _tar): + _pred = bm.as_jax(_pred) + num_classes = _pred.shape[-1] + # Per-sample class weight. The ``weight`` argument is indexed *by the + # target class* (``weight[y_n]``), not by the sample position. + sample_weight = None + # Mask of samples that should contribute to the loss (``ignore_index``). + valid_mask = None if bm.ndim(_tar) + 1 == bm.ndim(_pred): - _tar = bm.one_hot(_tar, _pred.shape[-1]) - loss = logsumexp(bm.as_jax(_pred), axis=-1) - (_pred * _tar).sum(axis=-1) - if weight is not None: - loss *= weight + # ``_tar`` holds integer class indices. + _tar_idx = bm.as_jax(_tar) + if weight is not None: + sample_weight = bm.as_jax(weight)[_tar_idx] + valid_mask = (_tar_idx != ignore_index) + # Build the (possibly label-smoothed) soft target distribution. Clamp + # ignored indices to 0 first so one_hot does not error on negatives. + _tar_clamped = jnp.where(valid_mask, _tar_idx, 0) + _soft = bm.as_jax(bm.one_hot(_tar_clamped, num_classes)) + if label_smoothing > 0.0: + _soft = _soft * (1.0 - label_smoothing) + label_smoothing / num_classes + else: + # ``_tar`` holds class probabilities / one-hot: the effective per-sample + # weight is the probability-weighted class weight (matches PyTorch). + _soft = bm.as_jax(_tar) + if label_smoothing > 0.0: + _soft = _soft * (1.0 - label_smoothing) + label_smoothing / num_classes + if weight is not None: + sample_weight = (bm.as_jax(weight) * _soft).sum(axis=-1) + loss = logsumexp(_pred, axis=-1) - (_pred * _soft).sum(axis=-1) + if sample_weight is not None: + loss = loss * sample_weight + if valid_mask is not None: + # Zero-out ignored samples so they contribute nothing to sum/mean. + loss = jnp.where(valid_mask, loss, 0.0) + if reduction == 'mean': + if sample_weight is not None: + denom = sample_weight if valid_mask is None else jnp.where(valid_mask, sample_weight, 0.0) + return loss.sum() / denom.sum() + if valid_mask is not None: + return loss.sum() / jnp.maximum(valid_mask.sum(), 1) + return loss.mean() return _reduce(outputs=loss, reduction=reduction) r = tree_map(_cel, predicts, targets, is_leaf=_is_leaf) @@ -458,7 +495,9 @@ def nll_loss(input, target, reduction: str = 'mean'): assert target.ndim + 1 == input.ndim input = bm.as_jax(input) target = bm.as_jax(target) - loss = input[jnp.arange(len(target)), target] + # Negative log-likelihood: l_n = -x_{n, y_n}. The leading minus sign is what + # makes this a *loss* to minimize (the raw log-probabilities are negative). + loss = -input[jnp.arange(len(target)), target] if reduction == 'mean': return loss.mean() elif reduction == 'sum': diff --git a/brainpy/math/_utils.py b/brainpy/math/_utils.py index 02ea058d6..bb526a4d4 100644 --- a/brainpy/math/_utils.py +++ b/brainpy/math/_utils.py @@ -60,6 +60,7 @@ def new_fun(*args, **kwargs): return tree_map(_return, r) else: out.value = r + return out new_fun.__doc__ = ( f'Similar to ``jax.numpy.{module + fun.__name__}`` function, ' diff --git a/brainpy/math/activations.py b/brainpy/math/activations.py index d8cedd35d..f3d66408c 100644 --- a/brainpy/math/activations.py +++ b/brainpy/math/activations.py @@ -665,7 +665,8 @@ def softmin(x, axis=-1): along dim will sum to 1). """ x = x.value if isinstance(x, Array) else x - unnormalized = jnp.exp(-x) + neg_x = -x + unnormalized = jnp.exp(neg_x - jax.lax.stop_gradient(neg_x.max(axis, keepdims=True))) return unnormalized / unnormalized.sum(axis, keepdims=True) diff --git a/brainpy/math/compat_numpy.py b/brainpy/math/compat_numpy.py index 51937348d..02e5de9f3 100644 --- a/brainpy/math/compat_numpy.py +++ b/brainpy/math/compat_numpy.py @@ -130,7 +130,7 @@ def fill_diagonal(a, val, inplace=True): if inplace: a.value = r else: - return r + return _return(r) def zeros(shape, dtype=None): @@ -142,7 +142,7 @@ def ones(shape, dtype=None): def empty(shape, dtype=None): - return _return(jnp.zeros(shape, dtype=dtype)) + return _return(jnp.empty(shape, dtype=dtype)) def zeros_like(a, dtype=None, shape=None): @@ -157,7 +157,7 @@ def ones_like(a, dtype=None, shape=None): def empty_like(a, dtype=None, shape=None): a = _as_jax_array_(a) - return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) + return _return(jnp.empty_like(a, dtype=dtype, shape=shape)) def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: @@ -215,8 +215,8 @@ def ascontiguousarray(a, dtype=None, order=None): def asfarray(a, dtype=None): - if not np.issubdtype(dtype, np.inexact): - dtype = np.float64 + if dtype is None or not np.issubdtype(dtype, np.inexact): + dtype = jnp.float64 return asarray(a, dtype=dtype) diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index 1728f55f7..d9939a64a 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -40,11 +40,12 @@ 'asin', 'arcsin', 'asinh', - 'arcsin', + 'arcsinh', 'atan', 'arctan', 'atan2', 'atanh', + 'arctanh', 'clamp_max', 'clamp_min', 'arctan2', @@ -161,6 +162,7 @@ def abs( else: _check_out(out) out.value = r + return out absolute = abs @@ -178,6 +180,7 @@ def acos( else: _check_out(out) out.value = r + return out arccos = acos @@ -195,6 +198,7 @@ def acosh( else: _check_out(out) out.value = r + return out arccosh = acosh @@ -224,6 +228,7 @@ def add( else: _check_out(out) out.value = r + return out def addcdiv( @@ -266,6 +271,7 @@ def angle( else: _check_out(out) out.value = r + return out def asin( @@ -280,6 +286,7 @@ def asin( else: _check_out(out) out.value = r + return out arcsin = asin @@ -297,6 +304,7 @@ def asinh( else: _check_out(out) out.value = r + return out arcsinh = asinh @@ -314,6 +322,7 @@ def atan( else: _check_out(out) out.value = r + return out arctan = atan @@ -331,6 +340,7 @@ def atanh( else: _check_out(out) out.value = r + return out arctanh = atanh @@ -350,6 +360,7 @@ def atan2( else: _check_out(out) out.value = r + return out arctan2 = atan2 diff --git a/brainpy/math/compat_tensorflow.py b/brainpy/math/compat_tensorflow.py index ede493ab2..6f5a07063 100644 --- a/brainpy/math/compat_tensorflow.py +++ b/brainpy/math/compat_tensorflow.py @@ -14,8 +14,10 @@ # ============================================================================== from typing import Union, Optional +import jax import jax.numpy as jnp import jax.ops +import jax.scipy.special from jax import lax from brainpy.math.interoperability import as_jax @@ -74,7 +76,7 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False): Returns: The reduced tensor. """ - r = jnp.log(jnp.sum(jnp.exp(_as_jax_array_(input_tensor)), axis=axis, keepdims=keepdims)) + r = jax.scipy.special.logsumexp(_as_jax_array_(input_tensor), axis=axis, keepdims=keepdims) return _return(r) diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index f9ff8fb9e..10d69b420 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import inspect import numbers from typing import Union, Callable @@ -45,6 +46,24 @@ def _as_jax_array(arr): return arr.value if isinstance(arr, Array) else arr +def _accepts_dtype_kwarg(func): + """Return True if ``func`` can be called with a ``dtype`` keyword argument. + + Initializer/Connector instances accept ``(shape, dtype=...)``, whereas a plain + callable such as ``lambda shape: ...`` does not. When the signature cannot be + introspected (e.g. some built-ins or C-implemented callables), we conservatively + assume the ``dtype`` keyword is accepted to preserve the historical behaviour. + """ + try: + sig = inspect.signature(func) + except (TypeError, ValueError): + return True + params = sig.parameters + if 'dtype' in params: + return True + return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + + class AbstractDelay(BrainPyObject): pass @@ -204,20 +223,31 @@ def reset(self, The maximum delay length. The unit is the time. t0: int, float The zero time. - before_t0: int, float, ArrayType + before_t0: callable, int, float, ArrayType The data before t0. + - when ``before_t0`` is a function, it should receive a time argument ``t`` + (mirroring the behaviour of ``__init__``). + - when ``before_t0`` is a tensor / numerical value, it is broadcast into the + delay data before ``t0``. """ + self.t0 = t0 self.delay_len = delay_len self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1 self.data.value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) self.data[-1] = delay_target self.idx = Variable(jnp.asarray([0])) - self.current_time = Variable(jnp.asarray([t0])) - if before_t0 is not None: - if not isinstance(before_t0, (ndarray, jnp.ndarray, float, int)): - raise ValueError('Only support numerical values.') + self.current_time = Variable(jnp.asarray([t0], dtype=get_float())) + if before_t0 is None: + self._before_type = _DATA_BEFORE + elif callable(before_t0): + self._before_t0 = lambda t: as_jax(broadcast_to(before_t0(t), delay_target.shape), + dtype=delay_target.dtype) + self._before_type = _FUNC_BEFORE + elif isinstance(before_t0, (ndarray, jnp.ndarray, float, int)): self.data[:-1] = before_t0 self._before_type = _DATA_BEFORE + else: + raise ValueError(f'"before_t0" does not support {type(before_t0)}') def _check_time1(self, times): prev_time, current_time = times @@ -268,7 +298,7 @@ def _after_t0(self, prev_time): f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}') def _true_fn(self, req_num_step, extra): - return self.data[self.idx[0] + req_num_step] + return self.data[(self.idx[0] + req_num_step) % self.num_delay_step] def _false_fn(self, req_num_step, extra): idx = jnp.asarray([self.idx[0] + req_num_step, @@ -303,6 +333,10 @@ class LengthDelay(AbstractDelay): initial_delay_data: Any The delay data. It can be a Python number, like float, int, boolean values. It can also be arrays. Or a callable function or instance of ``Connector``. + A callable will be invoked as ``initial_delay_data(shape, dtype=...)`` when its + signature accepts a ``dtype`` keyword (e.g. ``Initializer``/``Connector`` + instances), and as ``initial_delay_data(shape)`` otherwise (e.g. a plain + ``lambda shape: ...``). Note that ``initial_delay_data`` should be arranged as the following way:: delay = 1 [ data @@ -416,8 +450,16 @@ def reset( elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)): self.data[1:] = initial_delay_data elif callable(initial_delay_data): - self.data[1:] = initial_delay_data((delay_len,) + delay_target.shape, - dtype=delay_target.dtype) + shape = (delay_len,) + delay_target.shape + dtype = delay_target.dtype + # Initializer/Connector instances accept ``(shape, dtype=...)``, but a plain + # callable (e.g. ``lambda shape: ...``) may not accept the ``dtype`` keyword. + # Branch on whether the callable's signature accepts ``dtype`` and fall back + # to calling it with the shape only when it does not. + if _accepts_dtype_kwarg(initial_delay_data): + self.data[1:] = initial_delay_data(shape, dtype=dtype) + else: + self.data[1:] = initial_delay_data(shape) else: raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}') diff --git a/brainpy/math/einops.py b/brainpy/math/einops.py index b74eeb63e..05c9ab2ab 100644 --- a/brainpy/math/einops.py +++ b/brainpy/math/einops.py @@ -81,61 +81,6 @@ def __reduce(x: Union[Array, jax.Array], operation: str, reduced_axes): raise NotImplementedError("Unknown reduction ", operation) -def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes): - # 'collapses' neighboring axes if those participate in the result pattern in the same order - # TODO add support for added_axes - assert len(axes_reordering) + len(reduced_axes) == len(init_shapes) - # joining consecutive axes that will be reduced - # possibly we can skip this if all backends can optimize this (not sure) - reduced_axes = tuple(sorted(reduced_axes)) - for i in range(len(reduced_axes) - 1)[::-1]: - if reduced_axes[i] + 1 == reduced_axes[i + 1]: - removed_axis = reduced_axes[i + 1] - removed_length = init_shapes[removed_axis] - init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] - init_shapes[removed_axis - 1] *= removed_length - reduced_axes = reduced_axes[: i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:]) - - # removing axes that are moved together during reshape - def build_mapping(): - init_to_final = {} - for axis in range(len(init_shapes)): - if axis in reduced_axes: - init_to_final[axis] = None - else: - after_reduction = sum(x is not None for x in init_to_final.values()) - init_to_final[axis] = list(axes_reordering).index(after_reduction) - return init_to_final - - init_axis_to_final_axis = build_mapping() - - for init_axis in range(len(init_shapes) - 1)[::-1]: - if init_axis_to_final_axis[init_axis] is None: - continue - if init_axis_to_final_axis[init_axis + 1] is None: - continue - if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]: - removed_axis = init_axis + 1 - removed_length = init_shapes[removed_axis] - removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis)) - - reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes) - init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] - init_shapes[removed_axis - 1] *= removed_length - old_reordering = axes_reordering - axes_reordering = [] - for axis in old_reordering: - if axis == removed_axis_after_reduction: - pass - elif axis < removed_axis_after_reduction: - axes_reordering.append(axis) - else: - axes_reordering.append(axis - 1) - init_axis_to_final_axis = build_mapping() - - return init_shapes, reduced_axes, axes_reordering, final_shapes - - CookedRecipe = Tuple[Optional[List[int]], Optional[List[int]], List[int], Dict[int, int], Optional[List[int]], int] # Actual type is tuple[tuple[str, int], ...] diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py index 292d6dadf..3bf28d251 100644 --- a/brainpy/math/environment.py +++ b/brainpy/math/environment.py @@ -388,43 +388,57 @@ def set( numpy_func_return: str The array to return in all numpy functions. Support 'bp_array' and 'jax_array'. """ + # Validate all arguments BEFORE mutating any global state, so that an + # invalid argument cannot leave the environment in a half-updated state. if dt is not None: assert isinstance(dt, float), '"dt" must a float.' + if mode is not None: + assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.' + if membrane_scaling is not None: + assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.' + if x64 is not None: + assert isinstance(x64, bool), f'"x64" must be a bool.' + if float_ is not None: + assert isinstance(float_, type), '"float_" must a float.' + if int_ is not None: + assert isinstance(int_, type), '"int_" must a type.' + if bool_ is not None: + assert isinstance(bool_, type), '"bool_" must a type.' + if complex_ is not None: + assert isinstance(complex_, type), '"complex_" must a type.' + if numpy_func_return is not None: + assert numpy_func_return in ['bp_array', 'jax_array'], \ + '"numpy_func_return" must be "bp_array" or "jax_array".' + + # All validation passed; now apply the settings. + if dt is not None: set_dt(dt) if mode is not None: - assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.' set_mode(mode) if membrane_scaling is not None: - assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.' set_membrane_scaling(membrane_scaling) if x64 is not None: - assert isinstance(x64, bool), f'"x64" must be a bool.' set_x64(x64) if float_ is not None: - assert isinstance(float_, type), '"float_" must a float.' set_float(float_) if int_ is not None: - assert isinstance(int_, type), '"int_" must a type.' set_int(int_) if bool_ is not None: - assert isinstance(bool_, type), '"bool_" must a type.' set_bool(bool_) if complex_ is not None: - assert isinstance(complex_, type), '"complex_" must a type.' set_complex(complex_) if bp_object_as_pytree is not None: defaults.bp_object_as_pytree = bp_object_as_pytree if numpy_func_return is not None: - assert numpy_func_return in ['bp_array', 'jax_array'], f'"numpy_func_return" must be "bp_array" or "jax_array".' defaults.numpy_func_return = numpy_func_return @@ -643,6 +657,7 @@ def enable_x64(x64=None): def disable_x64(): + brainstate.environ.set(precision=32) config.update("jax_enable_x64", False) set_int(jnp.int32) set_float(jnp.float32) diff --git a/brainpy/math/event/csr_matmat.py b/brainpy/math/event/csr_matmat.py index c33d1103f..a8084328d 100644 --- a/brainpy/math/event/csr_matmat.py +++ b/brainpy/math/event/csr_matmat.py @@ -62,6 +62,6 @@ def csrmm( matrix = brainevent.BinaryArray(matrix) csr = brainevent.CSR((data, indices, indptr), shape=shape) if transpose: - return matrix @ csr + return csr.T @ matrix else: return csr @ matrix diff --git a/brainpy/math/jitconn/matvec.py b/brainpy/math/jitconn/matvec.py index 094f758e0..89442333c 100644 --- a/brainpy/math/jitconn/matvec.py +++ b/brainpy/math/jitconn/matvec.py @@ -77,6 +77,17 @@ def mv_prob_homo( The matrix shape. seed: int The random number generation seed. + + .. warning:: + + If ``seed`` is left as ``None`` (the default), a host random seed is + drawn with ``numpy.random.randint`` on **every call**. This makes the + result non-reproducible in eager mode, and -- because the seed is + captured as a Python constant -- it becomes **frozen** the first time + the function is traced under ``jit()``/``vmap()``, so every subsequent + jitted call reuses that single seed. For reproducible and correct + behaviour under JAX transformations, always pass an explicit integer + ``seed``. transpose: bool Transpose the random matrix or not. outdim_parallel: bool @@ -152,6 +163,17 @@ def mv_prob_uniform( The matrix shape. seed: int The random number generation seed. + + .. warning:: + + If ``seed`` is left as ``None`` (the default), a host random seed is + drawn with ``numpy.random.randint`` on **every call**. This makes the + result non-reproducible in eager mode, and -- because the seed is + captured as a Python constant -- it becomes **frozen** the first time + the function is traced under ``jit()``/``vmap()``, so every subsequent + jitted call reuses that single seed. For reproducible and correct + behaviour under JAX transformations, always pass an explicit integer + ``seed``. transpose: bool Transpose the random matrix or not. outdim_parallel: bool @@ -229,6 +251,17 @@ def mv_prob_normal( The matrix shape. seed: int The random number generation seed. + + .. warning:: + + If ``seed`` is left as ``None`` (the default), a host random seed is + drawn with ``numpy.random.randint`` on **every call**. This makes the + result non-reproducible in eager mode, and -- because the seed is + captured as a Python constant -- it becomes **frozen** the first time + the function is traced under ``jit()``/``vmap()``, so every subsequent + jitted call reuses that single seed. For reproducible and correct + behaviour under JAX transformations, always pass an explicit integer + ``seed``. transpose: bool Transpose the random matrix or not. outdim_parallel: bool @@ -276,6 +309,17 @@ def get_homo_weight_matrix( The matrix shape. seed: int The random number generation seed. + + .. warning:: + + If ``seed`` is left as ``None`` (the default), a host random seed is + drawn with ``numpy.random.randint`` on **every call**. This makes the + result non-reproducible in eager mode, and -- because the seed is + captured as a Python constant -- it becomes **frozen** the first time + the function is traced under ``jit()``/``vmap()``, so every subsequent + jitted call reuses that single seed. For reproducible and correct + behaviour under JAX transformations, always pass an explicit integer + ``seed``. transpose: bool Transpose the random matrix or not. outdim_parallel: bool @@ -320,6 +364,17 @@ def get_uniform_weight_matrix( The matrix shape. seed: int The random number generation seed. + + .. warning:: + + If ``seed`` is left as ``None`` (the default), a host random seed is + drawn with ``numpy.random.randint`` on **every call**. This makes the + result non-reproducible in eager mode, and -- because the seed is + captured as a Python constant -- it becomes **frozen** the first time + the function is traced under ``jit()``/``vmap()``, so every subsequent + jitted call reuses that single seed. For reproducible and correct + behaviour under JAX transformations, always pass an explicit integer + ``seed``. transpose: bool Transpose the random matrix or not. outdim_parallel: bool @@ -367,6 +422,17 @@ def get_normal_weight_matrix( The matrix shape. seed: int The random number generation seed. + + .. warning:: + + If ``seed`` is left as ``None`` (the default), a host random seed is + drawn with ``numpy.random.randint`` on **every call**. This makes the + result non-reproducible in eager mode, and -- because the seed is + captured as a Python constant -- it becomes **frozen** the first time + the function is traced under ``jit()``/``vmap()``, so every subsequent + jitted call reuses that single seed. For reproducible and correct + behaviour under JAX transformations, always pass an explicit integer + ``seed``. transpose: bool Transpose the random matrix or not. outdim_parallel: bool diff --git a/brainpy/math/modes.py b/brainpy/math/modes.py index 0aec62f87..ea9ba1501 100644 --- a/brainpy/math/modes.py +++ b/brainpy/math/modes.py @@ -40,6 +40,11 @@ def __eq__(self, other: 'Mode'): return False return other.__class__ == self.__class__ + # Defining ``__eq__`` sets ``__hash__`` to None (making instances + # unhashable). Restore hashability so modes can be used in sets / as dict + # keys. Modes compare equal iff they share a class, so hash by class. + __hash__ = brainstate.mixin.Mode.__hash__ + def is_one_of(self, *modes): for m_ in modes: if not isinstance(m_, type): diff --git a/brainpy/math/ndarray.py b/brainpy/math/ndarray.py index 3b14e8833..b41c702df 100644 --- a/brainpy/math/ndarray.py +++ b/brainpy/math/ndarray.py @@ -19,7 +19,6 @@ import jax import numpy as np from jax import numpy as jnp -from jax.dtypes import canonicalize_dtype from jax.tree_util import register_pytree_node_class from brainpy._errors import MathError @@ -28,30 +27,10 @@ bm = None __all__ = [ - 'Array', 'Array', 'ndarray', 'JaxArray', # alias of Array + 'Array', 'ndarray', 'JaxArray', # alias of Array 'ShardedArray', ] -# Ways to change values in a zero-dimensional array -# ----- -# Reference: https://stackoverflow.com/questions/56954714/how-do-i-assign-to-a-zero-dimensional-numpy-array -# -# >>> x = np.array(10) -# 1. index the original array with ellipsis or an empty tuple -# >>> x[...] = 2 -# >>> x[()] = 2 - -_all_slice = slice(None, None, None) - - -def _check_input_array(array): - if isinstance(array, Array): - return array.value - elif isinstance(array, np.ndarray): - return jnp.asarray(array) - else: - return array - def _return(a): if defaults.numpy_func_return == 'bp_array' and isinstance(a, jax.Array) and a.ndim > 0: @@ -68,14 +47,6 @@ def _check_out(out): raise TypeError(f'out must be an instance of brainpy Array. But got {type(out)}') -def _get_dtype(v): - if hasattr(v, 'dtype'): - dtype = v.dtype - else: - dtype = canonicalize_dtype(type(v)) - return dtype - - @register_pytree_node_class class Array(u.CustomArray): """Multiple-dimensional array in BrainPy. @@ -102,6 +73,13 @@ def __init__(self, value, dtype: Any = None): value = value.value elif isinstance(value, (tuple, list, np.ndarray)): value = jnp.asarray(value) + elif isinstance(value, jax.Array): + pass + else: + # raw Python scalars (int/float/bool/complex) and any other input: + # convert to a jax array so ``self._value`` is always array-like + # (mirrors the ``value`` setter). + value = jnp.asarray(value) if dtype is not None: value = jnp.asarray(value, dtype=dtype) self._value = value @@ -131,11 +109,14 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, flat_contents): - return cls(*flat_contents) - - # ins = object.__new__(cls) - # ins._value = flat_contents[0] - # return ins + # Reconstruct without going through ``__init__``: during abstract + # evaluation (``jax.eval_shape``, ``scan``/``for_loop`` tracing) the leaf + # is a ``ShapedArray``/``ShapeDtypeStruct`` rather than a concrete array, + # and ``__init__`` would try to ``jnp.asarray`` it and raise. Storing the + # leaf directly keeps the pytree round-trip transparent. + ins = object.__new__(cls) + ins._value = flat_contents[0] + return ins @property def data(self): @@ -198,17 +179,23 @@ def at(self): return self.value.at def block_host_until_ready(self, *args): - return self.value.block_host_until_ready(*args) + # ``jax.Array.block_host_until_ready`` was removed; ``block_until_ready`` + # is the modern equivalent. + return self.value.block_until_ready(*args) def block_until_ready(self, *args): return self.value.block_until_ready(*args) + @property def device(self): - return self.value.device() + # ``jax.Array.device`` is now a property (it used to be a method). + return self.value.device @property def device_buffer(self): - return self.value.device_buffer + # ``jax.Array.device_buffer`` was removed; the addressable shard's data + # is the modern equivalent on a single-device array. + return self.value.addressable_data(0) def fill_(self, fill_value): """Fill the array with a scalar value. @@ -264,8 +251,16 @@ def value(self): The stored data. """ v = self._value - # keep sharding constraints - if self._keep_sharding and hasattr(v, 'sharding') and (v.sharding is not None): + # Keep sharding constraints, but only for genuinely multi-device + # shardings. A ``SingleDeviceSharding`` (the default on a single device, + # e.g. CPU) carries no distribution information, so inserting a + # ``with_sharding_constraint`` on every read is pure overhead. + if ( + self._keep_sharding + and hasattr(v, 'sharding') + and (v.sharding is not None) + and not isinstance(v.sharding, jax.sharding.SingleDeviceSharding) + ): return jax.lax.with_sharding_constraint(v, v.sharding) # return the value return v diff --git a/brainpy/math/object_transform/_utils.py b/brainpy/math/object_transform/_utils.py index 997f8b1bb..c4b32d71a 100644 --- a/brainpy/math/object_transform/_utils.py +++ b/brainpy/math/object_transform/_utils.py @@ -24,6 +24,7 @@ __all__ = [ 'infer_dyn_vars', 'get_brainpy_object', + 'warp_to_no_state_input_output', ] diff --git a/brainpy/math/object_transform/base.py b/brainpy/math/object_transform/base.py index f6a6b9762..4c61425fd 100644 --- a/brainpy/math/object_transform/base.py +++ b/brainpy/math/object_transform/base.py @@ -146,77 +146,29 @@ def tracing_variable( axis_names: Optional[Sequence[str]] = None, batch_axis_name: Optional[str] = BATCH_AXIS, ) -> Variable: - """Initialize the variable which can be traced during computations and transformations. + """Initialize a variable that can be traced during computations and transformations. - Although this function is designed to initialize tracing variables during computation or compilation, - it can also be used for the initialization of variables before computation and compilation. - - - If the variable has not been instantiated, a :py:class:`~.Variable` will be instantiated. - - If the variable has been created, the further call of this function will return the created variable. - - Here is the usage example:: - - class Example(bm.BrainPyObject): - def fun(self): - # The first time of calling `.fun()`, this line will create a Variable instance. - # If users repeatedly call `.fun()` function, this line will not initialize variables again. - # Instead, it will return the variable has been created. - self.tracing_variable('a', bm.zeros, (10,)) - - # The created variable can be accessed with self.xxx - self.a.value = bm.ones(10) - - # Calling this function again will not reinitialize the - # variable again, Instead, it will return the variable - # that has been created. - a = self.tracing_variable('a', bm.zeros, (10,)) - - .. versionadded:: 2.4.5 + .. deprecated:: 3.0.0 + This feature is no longer supported. Since BrainPy 3.0.0 the library + has been rewritten on top of ``brainstate`` and variable tracing is + handled by ``brainstate`` directly. Calling this method always raises + :class:`NotImplementedError`. Args: name: str. The variable name. init: callable, Array. The data to be initialized as a ``Variable``. - batch_or_mode: int, bool, Mode. This is used to specify the batch size of this variable. - If it is a boolean or an instance of ``Mode``, the batch size will be 1. - If it is None, the variable has no batch axis. shape: int, sequence of int. The shape of the variable. + batch_or_mode: int, bool, Mode. The batch size of this variable. batch_axis: int. The batch axis, if batch size is given. - axis_names: sequence of str. The name for each axis. These names should match the given ``axes``. - batch_axis_name: str. The name for the batch axis. The name will be used - if ``batch_or_mode`` is given. Default is ``brainpy.math.sharding.BATCH_AXIS``. + axis_names: sequence of str. The name for each axis. + batch_axis_name: str. The name for the batch axis. - Returns: - The instance of :py:class:`~.Variable`. + Raises: + NotImplementedError: Always, because this feature is unsupported since 3.0.0. """ - # the variable has been created raise NotImplementedError( 'Since 3.0.0, brainpy is rewritten with brainstate. The feature tracing_variable is no longer supported. ' ) - if hasattr(self, name): - var = getattr(self, name) - if isinstance(var, Variable): - return var - # if var.shape != value.shape: - # raise ValueError( - # f'"{name}" has been used in this class with the shape of {var.shape} (!= {value.shape}). ' - # f'Please assign another name for the initialization of variables ' - # f'tracing during computation and compilation.' - # ) - # if var.dtype != value.dtype: - # raise ValueError( - # f'"{name}" has been used in this class with the dtype of {var.dtype} (!= {value.dtype}). ' - # f'Please assign another name for the initialization of variables ' - # f'tracing during computation and compilation.' - # ) - - global variable_ - if variable_ is None: - from brainpy.initialize import variable_ - with jax.ensure_compile_time_eval(): - value = variable_(init, shape, batch_or_mode, batch_axis, axis_names, batch_axis_name) - value.ready_to_trace = True - self.setattr(name, value) - return value def __setattr__(self, key: str, value: Any) -> None: """Overwrite `__setattr__` method for changing :py:class:`~.Variable` values. @@ -288,25 +240,41 @@ def register_implicit_vars(self, *variables, var_cls: type = None, **named_varia if var_cls is None: var_cls = (Variable, VarList, VarDict) + def _store(key, value): + # ``self.implicit_vars`` is an ``ArrayCollector`` whose entries must + # be plain ``Variable`` instances (it is consumed directly by + # ``vars()``). ``VarList``/``VarDict`` containers are therefore + # flattened into their constituent variables before insertion, + # mirroring how attribute-stored containers are expanded in + # ``vars()``. + if isinstance(value, VarList): + for i, vv in enumerate(value): + self.implicit_vars[f'{key}-{i}'] = vv + elif isinstance(value, VarDict): + for kk, vv in value.items(): + self.implicit_vars[f'{key}-{kk}'] = vv + else: + self.implicit_vars[key] = value + for variable in variables: if isinstance(variable, var_cls): - self.implicit_vars[f'var{id(variable)}'] = variable + _store(f'var{id(variable)}', variable) elif isinstance(variable, (tuple, list)): for v in variable: if not isinstance(v, var_cls): raise ValueError(f'Must be instance of {var_cls}, but we got {type(v)}') - self.implicit_vars[f'var{id(v)}'] = v + _store(f'var{id(v)}', v) elif isinstance(variable, dict): for k, v in variable.items(): if not isinstance(v, var_cls): raise ValueError(f'Must be instance of {var_cls}, but we got {type(v)}') - self.implicit_vars[k] = v + _store(k, v) else: raise ValueError(f'Unknown type: {type(variable)}') for key, variable in named_variables.items(): if not isinstance(variable, var_cls): raise ValueError(f'Must be instance of {var_cls}, but we got {type(variable)}') - self.implicit_vars[key] = variable + _store(key, variable) def register_implicit_nodes(self, *nodes, node_cls: type = None, **named_nodes): if node_cls is None: @@ -606,11 +574,12 @@ def to(self, device: Optional[Any]): Args: device: The device. """ - for key, var in self.state_dict().items(): - if isinstance(var, Array): - var.value = jax.device_put(var.value, device=device) - else: - setattr(self, key, jax.device_put(var, device=device)) + # Iterate over the actual ``Variable`` instances (not the nested + # ``state_dict`` mapping). Iterating ``state_dict()`` would yield + # nested dicts/raw arrays keyed by name, so nothing was ever moved and + # ``setattr`` injected junk attributes onto the object. + for var in self.vars().values(): + var.value = jax.device_put(var.value, device=device) return self def cpu(self): diff --git a/brainpy/math/object_transform/controls.py b/brainpy/math/object_transform/controls.py index 26c0f2aea..4e8942b4e 100644 --- a/brainpy/math/object_transform/controls.py +++ b/brainpy/math/object_transform/controls.py @@ -23,6 +23,31 @@ from brainpy.math.ndarray import Array from ._utils import warp_to_no_state_input_output + +def _unwrap_operand_leaf(x): + """Replace a ``State``/``Variable`` or BrainPy ``Array`` leaf with its raw value. + + ``brainstate.transform.*`` rejects ``State`` objects passed as operands, and feeding a + BrainPy ``Array`` through brainstate's loop primitives round-trips it through + ``tree_unflatten`` (which reconstructs from ``ShapedArray`` avals and fails) inside a + JAX trace. Unwrapping both to the underlying ``jax.Array`` avoids both problems while + leaving any other operand type untouched. + """ + if isinstance(x, (brainstate.State, Array)): + return x.value + return x + + +def _unwrap_state_operands(operands): + """Unwrap ``brainstate.State`` (e.g. :py:class:`~.Variable`) and :py:class:`~.Array` + leaves in ``operands`` to their raw ``jax.Array`` values before forwarding to brainstate. + """ + return jax.tree.map( + _unwrap_operand_leaf, + operands, + is_leaf=lambda x: isinstance(x, (brainstate.State, Array)), + ) + __all__ = [ 'cond', 'ifelse', @@ -124,6 +149,7 @@ def cond( """ if not isinstance(operands, (tuple, list)): operands = (operands,) + operands = _unwrap_state_operands(operands) return brainstate.transform.cond( pred, warp_to_no_state_input_output(true_fun), @@ -193,6 +219,7 @@ def ifelse( operands = () elif not isinstance(operands, (tuple, list)): operands = (operands,) + operands = _unwrap_state_operands(operands) # Convert non-callable branches to callables def make_callable(branch): @@ -234,7 +261,10 @@ def make_callable(branch): conditions = exclusive_conditions - return brainstate.transform.ifelse(conditions, branches, *operands) + # BrainPy already converts the conditions into mutually exclusive form above, + # so brainstate does not need to re-check exclusivity (which would otherwise + # reject overlapping inputs at trace time). + return brainstate.transform.ifelse(conditions, branches, *operands, check_cond=False) def for_loop( @@ -311,6 +341,12 @@ def for_loop( iteration of a loop. jit: bool Whether to just-in-time compile the function. Set to ``False`` to disable JIT compilation. + + .. note:: + ``jit=False`` is implemented via the global :py:func:`jax.disable_jit` context + manager. Consequently it has no effect when ``for_loop`` is called inside an + enclosing trace (e.g. within another jitted/scanned function): JAX is already + tracing, so the loop runs as a compiled ``scan`` regardless of this flag. progress_bar: bool, ProgressBar, int Whether and how to display a progress bar during execution: @@ -362,6 +398,7 @@ def for_loop( """ if not isinstance(operands, (tuple, list)): operands = (operands,) + operands = _unwrap_state_operands(operands) # Convert progress_bar to pbar format pbar = _convert_progress_bar_to_pbar(progress_bar) @@ -371,11 +408,12 @@ def for_loop( # For zero-length inputs, we need to use JIT mode even when jit=False. should_disable_jit = False if jit is False: - # Check if any operand has zero length - first_operand = operands[0] - is_zero_length = False - if hasattr(first_operand, 'shape') and len(first_operand.shape) > 0: - is_zero_length = (first_operand.shape[0] == 0) + # Check if any operand (over the whole pytree) has a zero-length leading axis. + leaves = jax.tree.leaves(operands) + is_zero_length = any( + getattr(leaf, 'ndim', 0) > 0 and leaf.shape[0] == 0 + for leaf in leaves + ) if is_zero_length: # Use JIT mode for zero-length inputs to avoid JAX limitation @@ -464,13 +502,21 @@ def scan( Now accepts ProgressBar instances and integers for advanced customization. Returns:: - - outs: Any - The stacked outputs of ``body_fun`` when scanned over the leading axis of the inputs. + + outs: tuple + A two-element tuple ``(final_carry, stacked_ys)``: + + - ``final_carry``: the loop carry value returned by the last iteration of + ``body_fun`` (same structure as ``init``). + - ``stacked_ys``: the per-iteration outputs of ``body_fun`` stacked along a + new leading axis. """ # Convert progress_bar to pbar format pbar = _convert_progress_bar_to_pbar(progress_bar) + init = _unwrap_state_operands(init) + operands = _unwrap_state_operands(operands) + return brainstate.transform.scan( warp_to_no_state_input_output(body_fun), init=init, @@ -546,13 +592,17 @@ def while_loop( if not isinstance(operands, (tuple, list)): operands = (operands,) operands = tuple(operands) + operands = _unwrap_state_operands(operands) def body(x): r = body_fun(*x) if r is None: - return x - else: - return r + raise ValueError( + '`body_fun` of `while_loop` must return the updated operands, ' + 'but got `None`. Returning `None` would leave the operands unchanged ' + 'and the loop condition would never become False, causing an infinite loop.' + ) + return r return brainstate.transform.while_loop( warp_to_no_state_input_output(lambda x: cond_fun(*x)), diff --git a/brainpy/math/object_transform/function.py b/brainpy/math/object_transform/function.py index 8469e4c63..631295879 100644 --- a/brainpy/math/object_transform/function.py +++ b/brainpy/math/object_transform/function.py @@ -27,6 +27,42 @@ class Partial(FunAsObject): + """A picklable, object-aware partial application of a function. + + ``Partial`` behaves like :py:func:`functools.partial`: it binds positional and + keyword arguments to ``fun`` so that the remaining arguments can be supplied + later when the instance is called. Unlike :py:func:`functools.partial`, it is a + :py:class:`~.BrainPyObject`, so any :py:class:`~.Variable` instances and child + :py:class:`~.BrainPyObject` objects used by ``fun`` are registered and tracked. + + Parameters + ---------- + fun : callable + The function to be partially applied. + *args : Any + Positional arguments bound ahead of the call-time positional arguments. + child_objs : callable, BrainPyObject, sequence of BrainPyObject, dict of BrainPyObject, optional + The children objects used in ``fun``. + dyn_vars : Variable, sequence of Variable, dict of Variable, optional + The :py:class:`~.Variable` instances used in ``fun``. + **keywords : Any + Keyword arguments bound to ``fun``. Keywords supplied at call time take + precedence over those bound here. + + See Also + -------- + to_object : Transform a Python function into a :py:class:`~.BrainPyObject`. + + Examples + -------- + .. code-block:: python + + >>> import brainpy.math as bm + >>> add = bm.Partial(lambda x, y: x + y, 1) + >>> add(2) + 3 + """ + def __init__( self, fun: Callable, @@ -110,6 +146,7 @@ def function( func: FunAsObject The instance of ``BrainPyObject``. """ - warnings.warn('Using `brainpy.math.to_object()` instead. Will be removed after version 2.4.0.', - UserWarning) + warnings.warn('`brainpy.math.function()` is deprecated; use `brainpy.math.to_object()` instead. ' + 'It will be removed in a future release.', + DeprecationWarning) return to_object(f, nodes, dyn_vars, name) diff --git a/brainpy/math/object_transform/jit.py b/brainpy/math/object_transform/jit.py index 0ce02f40b..0861000a3 100644 --- a/brainpy/math/object_transform/jit.py +++ b/brainpy/math/object_transform/jit.py @@ -20,6 +20,7 @@ """ +import warnings from typing import Callable, Union, Sequence, Iterable import brainstate.transform @@ -122,26 +123,27 @@ def jit( Parameters:: - - {jit_par} - dyn_vars : optional, dict, sequence of Variable, Variable - These variables will be changed in the function, or needed in the computation. - - .. deprecated:: 2.4.0 - No longer need to provide ``dyn_vars``. This function is capable of automatically - collecting the dynamical variables used in the target ``func``. - child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject - The children objects used in the target function. - .. deprecated:: 2.4.0 - No longer need to provide ``child_objs``. This function is capable of automatically - collecting the children objects used in the target ``func``. + {jit_par} Returns:: - + func : JITTransform A callable jitted function, set up for just-in-time compilation. """ + # ``dyn_vars`` and ``child_objs`` are no longer used; brainstate collects + # dynamical variables automatically. Pop them (with a one-time deprecation + # warning) rather than forwarding them into ``brainstate.transform.jit``, + # which would raise a TypeError on the unexpected keyword arguments. + for _deprecated in ('dyn_vars', 'child_objs'): + if _deprecated in kwargs: + kwargs.pop(_deprecated) + warnings.warn( + f'`{_deprecated}` is deprecated and ignored. This function automatically ' + f'collects the dynamical variables and child objects used in the target `func`.', + DeprecationWarning, + stacklevel=2, + ) return brainstate.transform.jit( warp_to_no_state_input_output(func), static_argnums=static_argnums, @@ -160,6 +162,7 @@ def cls_jit( func: Callable = Missing(), static_argnums: Union[int, Iterable[int], None] = None, static_argnames: Union[str, Iterable[str], None] = None, + donate_argnums: Union[int, Sequence[int], None] = None, inline: bool = False, keep_unused: bool = False, **kwargs @@ -197,19 +200,36 @@ def cls_jit( func : JITTransform A callable jitted function, set up for just-in-time compilation. """ + # The bound method exposes ``self`` as the first positional argument, so any + # caller-supplied positional indices must be shifted by +1. Negative indices + # count from the end and are unaffected by the prepended ``self``; shifting + # them would silently corrupt the target argument. + def _shift_positive(x): + return x + 1 if x >= 0 else x + if static_argnums is None: static_argnums = (0,) elif isinstance(static_argnums, int): - static_argnums = (0, static_argnums + 1,) + static_argnums = tuple(dict.fromkeys((0, _shift_positive(static_argnums)))) elif isinstance(static_argnums, (tuple, list)): - static_argnums = (0,) + tuple(jax.tree.map(lambda x: x + 1, static_argnums)) + static_argnums = tuple(dict.fromkeys((0,) + tuple(_shift_positive(x) for x in static_argnums))) else: raise ValueError('static_argnums is not supported yet.') + if donate_argnums is None: + donate_argnums = () + elif isinstance(donate_argnums, int): + donate_argnums = (_shift_positive(donate_argnums),) + elif isinstance(donate_argnums, (tuple, list)): + donate_argnums = tuple(_shift_positive(x) for x in donate_argnums) + else: + raise ValueError('donate_argnums is not supported yet.') + return jit( func=func, static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, inline=inline, keep_unused=keep_unused, **kwargs diff --git a/brainpy/math/object_transform/naming.py b/brainpy/math/object_transform/naming.py index 8842b2a12..ce5bc32e5 100644 --- a/brainpy/math/object_transform/naming.py +++ b/brainpy/math/object_transform/naming.py @@ -14,6 +14,7 @@ # limitations under the License. # ============================================================================== import warnings +import weakref from brainpy import _errors as errors @@ -21,7 +22,16 @@ 'clear_name_cache', ] -_name2id = dict() +# Maps a unique name to a *weak* reference of the object that owns it. +# +# Storing a weak reference (instead of the raw ``id(obj)``) has two benefits: +# 1. The registry no longer keeps objects alive, so it does not grow +# unboundedly as transient objects are created and discarded. +# 2. Once an owning object is garbage-collected, its name is treated as free +# again. Keying on ``id(obj)`` was unsafe because CPython readily reuses +# the integer id of a collected object for a brand-new one, which could +# trigger spurious ``UniqueNameError`` (or mask a genuine collision). +_name2id = dict() # name -> weakref.ref(obj) _typed_names = {} @@ -32,7 +42,11 @@ def check_name_uniqueness(name, obj): f'according to Python language definition. ' f'Please choose another name.') if name in _name2id: - if _name2id[name] != id(obj): + existing = _name2id[name]() # dereference the weak ref + # ``existing is None`` -> the previous owner has been collected, so the + # name is free and can be re-registered. + # ``existing is obj`` -> the same object re-registering its own name. + if existing is not None and existing is not obj: raise errors.UniqueNameError( f'In BrainPy, each object should have a unique name. ' f'However, we detect that {obj} has a used name "{name}". \n' @@ -40,8 +54,15 @@ def check_name_uniqueness(name, obj): f'>>> brainpy.math.clear_name_cache() \n\n' f'to clear all cached names. ' ) - else: - _name2id[name] = id(obj) + + # (Re)register the name with a weak reference to the current owner. A + # finalizer drops the entry as soon as the object is collected, keeping the + # registry bounded. + def _drop(_ref, _name=name): + if _name2id.get(_name) is _ref: + del _name2id[_name] + + _name2id[name] = weakref.ref(obj, _drop) def get_unique_name(type_: str): diff --git a/brainpy/math/object_transform/variables.py b/brainpy/math/object_transform/variables.py index 60671a148..b15f33eb4 100644 --- a/brainpy/math/object_transform/variables.py +++ b/brainpy/math/object_transform/variables.py @@ -106,10 +106,10 @@ def __init__( @property def size_without_batch(self): if self.batch_axis is None: - return self.size + return self.shape else: - sizes = self.size - return sizes[:self.batch_axis] + sizes[self.batch_axis + 1:] + s = self.shape + return s[:self.batch_axis] + s[self.batch_axis + 1:] @property def batch_axis(self) -> Optional[int]: @@ -141,6 +141,18 @@ def value(self): @value.setter def value(self, v): + # Normalize/unwrap the incoming value *before* validating its + # shape/dtype, so that ``Array``/``np.ndarray``/``brainstate.State`` + # wrappers are converted to a plain JAX array first. Otherwise the + # shape/dtype checks below would run against the wrapper (and a numpy + # value would never be canonicalized to the Variable's dtype). + if isinstance(v, brainstate.State): + v = v.value + if isinstance(v, Array): + v = v.value + elif isinstance(v, np.ndarray): + v = jnp.asarray(v) + _value = self.value ext_shape = jnp.shape(v) int_shape = jnp.shape(_value) @@ -156,20 +168,51 @@ def value(self, v): if ext_dtype != int_dtype: raise MathError(f"The dtype of the original data is {int_dtype}, " f"while we got {ext_dtype}.") - if isinstance(v, Array): - v = v.value - elif isinstance(v, np.ndarray): - v = jnp.asarray(v) - else: - v = v - if isinstance(v, brainstate.State): # value checking - v = v.value self._check_value_tree(v) # check the tree structure record_state_value_write(self) # record the value by the stack (>= level) self._been_writen = True # set the flag self._write_value(v) # write the value + # ------------------------------------------------------------------ + # Identity-based hashing / equality. + # + # ``__eq__`` is inherited from the array base class and performs an + # *element-wise* comparison (e.g. ``var == 0`` returns a boolean array), + # which is a useful and public behaviour we intentionally keep. The hash, + # however, must stay identity-based: every place in BrainPy that uses a + # ``Variable`` as a registry/dedup key keys on ``id(var)`` (see the + # collectors), so a value-based hash would be both incorrect (mutable + # value) and inconsistent with that usage. We pin ``__hash__`` here to + # make that contract explicit and stable. + # ------------------------------------------------------------------ + def __hash__(self): + return id(self) + + def tree_flatten(self): + # Carry ``batch_axis`` and ``axis_names`` through pytree round-trips. + # The base ``Array.tree_flatten`` returns ``aux_data=None``, which + # silently drops these attributes whenever a ``Variable`` is flattened + # and reconstructed (e.g. through ``jax.jit``/``vmap``). + return (self.value,), (self._batch_axis, self.axis_names) + + @classmethod + def tree_unflatten(cls, aux_data, flat_contents): + batch_axis, axis_names = aux_data + (value,) = flat_contents + # Rebuild without re-running ``Variable.__init__``: that would re-run + # the batch-axis validation and the (costly) ``State`` source-info + # capture on every unflatten. Set the ``_value`` slot first so that + # ``State.__init__`` can initialise the remaining bookkeeping fields + # (trace state, level, hooks, ...) correctly, then restore the + # variable-specific metadata. + obj = object.__new__(cls) + object.__setattr__(obj, '_value', value) + brainstate.State.__init__(obj, value) + object.__setattr__(obj, '_batch_axis', batch_axis) + object.__setattr__(obj, 'axis_names', axis_names) + return obj + def _get_dtype(v): if hasattr(v, 'dtype'): @@ -420,7 +463,11 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, keys, values): - return cls(jax.util.safe_zip(keys, values)) + # ``jax.util.safe_zip`` was removed in recent JAX. Reconstruct the + # mapping with a plain ``dict``; note that ``VarDict.update`` only + # understands ``dict`` (and single ``(k, v)`` tuples), so a bare + # ``zip`` iterator would be silently dropped here. + return cls(dict(zip(keys, values))) var_dict = VarDict diff --git a/brainpy/math/others.py b/brainpy/math/others.py index fcf4b8587..a95a34c80 100644 --- a/brainpy/math/others.py +++ b/brainpy/math/others.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp +import numpy as np from jax.tree_util import tree_map from brainpy import check, tools @@ -91,9 +92,14 @@ def remove_diag(arr): """ if arr.ndim != 2: raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.') - eyes = _return(jnp.ones(arr.shape, dtype=bool)) - fill_diagonal(eyes, False) - return jnp.reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1)) + arr = as_jax(arr) + m, n = arr.shape + # Static off-diagonal indices (computed with numpy so they are concrete + # constants and the gather traces cleanly under jit/vmap). + rows = np.repeat(np.arange(m), n - 1) + eye_mask = ~np.eye(m, n, dtype=bool) + cols = np.broadcast_to(np.arange(n), (m, n))[eye_mask] + return arr[rows, cols].reshape(m, n - 1) def clip_by_norm(t, clip_norm, axis=None): diff --git a/brainpy/math/pre_syn_post.py b/brainpy/math/pre_syn_post.py index d50f7c322..0b6bcff80 100644 --- a/brainpy/math/pre_syn_post.py +++ b/brainpy/math/pre_syn_post.py @@ -284,13 +284,24 @@ def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None): post_val: ArrayType The value with the size of post-synaptic neurons. + + Notes:: + + When ``pre_values`` is a scalar, every connection carries the same constant + value, so the per-post mean is simply that constant. In this case the function + broadcasts the constant to every targeted post-synaptic neuron (untargeted + neurons stay ``0``). Duplicate ``post_ids`` therefore do not require any + averaging -- the mean of identical values equals the value itself. """ out = jnp.zeros(post_num) pre_values = as_jax(pre_values) post_ids = as_jax(post_ids) if jnp.ndim(pre_values) == 0: + # Scalar branch: every synapse carries the same constant ``pre_values``, + # so the mean over any group of post-synaptic targets is that constant. + # Broadcast it to the targeted posts (duplicate ``post_ids`` are harmless + # because the mean of identical values is the value itself). return out.at[post_ids].set(pre_values) - # return out.at[jnp.unique(post_ids)].set(pre_values) else: _raise_pre_ids_is_none(pre_ids) pre_ids = as_jax(pre_ids) @@ -515,7 +526,9 @@ def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=False) syn_values = jnp.asarray(syn_values, dtype=jnp.int32) nominator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) denominator = _jit_seg_sum(jnp.ones_like(syn_values), post_ids, post_num, indices_are_sorted) - return jnp.nan_to_num(nominator / denominator) + # Guard only the empty-group case (denominator == 0) instead of masking with + # ``nan_to_num``, which would also silently hide genuine NaNs in ``syn_values``. + return jnp.where(denominator > 0, nominator / denominator, 0.) def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=False): @@ -547,5 +560,10 @@ def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=Fal syn_values = syn_values - syn_maxs[post_ids] syn_values = jnp.exp(syn_values) normalizers = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) - softmax = syn_values / normalizers[post_ids] - return jnp.nan_to_num(softmax) + # ``normalizers[post_ids]`` is structurally >= 1 for every referenced post group + # (each such group contains at least the current synapse, contributing + # ``exp(0) == 1`` after the max-subtraction), so this division never hits a + # genuine 0/0. The previous ``jnp.nan_to_num`` only served to silently hide + # genuine NaNs produced upstream (e.g. NaNs already present in ``syn_values``), + # so it is intentionally removed to let such NaNs propagate. + return syn_values / normalizers[post_ids] diff --git a/brainpy/math/remove_vmap.py b/brainpy/math/remove_vmap.py index e49e0782a..103ef1925 100644 --- a/brainpy/math/remove_vmap.py +++ b/brainpy/math/remove_vmap.py @@ -27,6 +27,32 @@ def remove_vmap(x, op='any'): + """Reduce ``x`` with ``any``/``all`` *across the vmap batch axis as well*. + + This is a custom primitive whose batching rule deliberately collapses the + batch axis into a single **global** scalar. That is, when called under + :func:`jax.vmap`, ``remove_vmap(x, 'any')`` returns one ``bool`` summarising + *all* batch elements together (``True`` if any element of any batch is + truthy), rather than a per-batch vector of results. + + This is intentional: the primitive is used for global convergence / NaN-style + checks where the batch dimension must not survive the reduction. The batching + rule returns :data:`jax.interpreters.batching.not_mapped`, so the output is a + genuine unbatched scalar (it is *not* broadcast back across the batch axis). + + Parameters + ---------- + x : Array or jax.Array + The input array. ``brainpy.math.Array`` inputs are unwrapped. + op : {'any', 'all'} + The reduction to apply. ``'any'`` -> logical OR, ``'all'`` -> logical AND. + + Returns + ------- + jax.Array + A scalar boolean. Under :func:`jax.vmap` it is a single global scalar, + not a per-batch result. + """ if isinstance(x, Array): x = x.value if op == 'any': diff --git a/brainpy/math/scales.py b/brainpy/math/scales.py index 738a8eddf..5fd931906 100644 --- a/brainpy/math/scales.py +++ b/brainpy/math/scales.py @@ -73,17 +73,42 @@ def clone(self, bias=None, scale=None): class IdScaling(Scaling): + """Identity scaling: a no-op :class:`Scaling` with ``scale=1`` and ``bias=0``. + + Because the transform is the identity, custom ``bias``/``scale`` arguments are + meaningless. Passing non-default values is therefore rejected (rather than + being silently ignored) so the caller is not misled into thinking the values + take effect. + """ + def __init__(self): super().__init__(scale=1., bias=0.) + @staticmethod + def _reject_overrides(bias=None, scale=None): + if bias is not None and bias != 0.: + raise ValueError( + 'IdScaling is the identity transform and ignores "bias". ' + f'Got bias={bias}. Use a plain Scaling if you need a non-zero bias.' + ) + if scale is not None and scale != 1.: + raise ValueError( + 'IdScaling is the identity transform and ignores "scale". ' + f'Got scale={scale}. Use a plain Scaling if you need a non-unit scale.' + ) + def offset_scaling(self, x, bias=None, scale=None): + self._reject_overrides(bias, scale) return x def std_scaling(self, x, scale=None): + self._reject_overrides(scale=scale) return x def inv_scaling(self, x, scale=None): + self._reject_overrides(scale=scale) return x def clone(self, bias=None, scale=None): + self._reject_overrides(bias, scale) return IdScaling() diff --git a/brainpy/math/sharding.py b/brainpy/math/sharding.py index e9f8ec679..3ef62e2c6 100644 --- a/brainpy/math/sharding.py +++ b/brainpy/math/sharding.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import warnings from contextlib import contextmanager from functools import partial from typing import Optional, Any, Union, Sequence @@ -123,8 +124,20 @@ def get_sharding( if mesh is None: return None else: - axis_names = [(name if name in mesh.axis_names else None) for name in axis_names] - return NamedSharding(mesh, PartitionSpec(*axis_names)) + resolved = [(name if name in mesh.axis_names else None) for name in axis_names] + # If *every* requested axis name is absent from the mesh, the resulting + # PartitionSpec is fully replicated (all ``None``), which silently + # discards the user's sharding intent. Warn so this is not a silent + # no-op. Partial matches are tolerated (kept) on purpose. + if len(axis_names) > 0 and all(name is None for name in resolved): + warnings.warn( + f'None of the requested axis names {list(axis_names)} are present in the ' + f'mesh axes {tuple(mesh.axis_names)}. The array will be fully replicated ' + f'(PartitionSpec of all None). Check the axis names against the mesh.', + UserWarning, + stacklevel=2, + ) + return NamedSharding(mesh, PartitionSpec(*resolved)) def partition_by_axname( diff --git a/brainpy/math/sparse/__init__.py b/brainpy/math/sparse/__init__.py index c2769089b..d039e51bc 100644 --- a/brainpy/math/sparse/__init__.py +++ b/brainpy/math/sparse/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# from ._coo_mv import * +from .coo_mv import * from .csr_mm import * from .csr_mv import * from .jax_prim import * diff --git a/brainpy/math/sparse/coo_mv.py b/brainpy/math/sparse/coo_mv.py index f38063e73..a8e398e57 100644 --- a/brainpy/math/sparse/coo_mv.py +++ b/brainpy/math/sparse/coo_mv.py @@ -32,44 +32,39 @@ def coomv( vector: Union[jnp.ndarray, Array], *, shape: Tuple[int, int], - rows_sorted: bool = False, - cols_sorted: bool = False, transpose: bool = False, - method: str = 'cusparse' ): - """Product of COO sparse matrix and a dense vector using cuSPARSE algorithm. + """Product of COO sparse matrix and a dense vector. - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. + The ``brainevent`` COO format was removed in v0.1.0, so the COO indices are + converted to CSR (via :func:`brainevent.coo2csr`) and the multiplication is + delegated to :class:`brainevent.CSR`. - Parameters:: + This function supports JAX transformations, including ``jit()``, ``grad()``, + ``vmap()`` and ``pmap()``. - data: ndarray, float - An array of shape ``(nse,)``. - row: ndarray - An array of shape ``(nse,)``. - col: ndarray - An array of shape ``(nse,)`` and dtype ``row.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` and - dtype ``data.dtype``. - shape: tuple of int - The shape of the sparse matrix. - rows_sorted: bool - Row index are sorted. - cols_sorted: bool - Column index are sorted. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - method: str - The method used to compute the matrix-vector multiplication. + Parameters + ---------- + data : ndarray, float + An array of shape ``(nse,)``. + row : ndarray + An array of shape ``(nse,)``. + col : ndarray + An array of shape ``(nse,)`` and dtype ``row.dtype``. + vector : ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` and + dtype ``data.dtype``. + shape : tuple of int + The shape of the sparse matrix. + transpose : bool + A boolean specifying whether to transpose the sparse matrix + before computing. - Returns:: - - y: ndarray - An array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. + Returns + ------- + y : ndarray + An array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. """ if isinstance(data, Array): data = data.value @@ -79,7 +74,16 @@ def coomv( col = col.value if isinstance(vector, Array): vector = vector.value - csr = brainevent.COO((data, row, col), shape=shape) + + # The COO format was removed in brainevent 0.1.0; convert COO indices to + # CSR before delegating to brainevent.CSR. + indptr, indices, order = brainevent.coo2csr(row, col, shape=shape) + data = jnp.asarray(data) + if data.ndim == 0: + # scalar weight: broadcast to one entry per non-zero + data = jnp.broadcast_to(data, (indices.shape[0],)) + data = data[order] + csr = brainevent.CSR((data, indices, indptr), shape=shape) if transpose: return vector @ csr else: diff --git a/brainpy/math/sparse/csr_mm.py b/brainpy/math/sparse/csr_mm.py index c6056b0ad..97c212f47 100644 --- a/brainpy/math/sparse/csr_mm.py +++ b/brainpy/math/sparse/csr_mm.py @@ -61,6 +61,6 @@ def csrmm( matrix = matrix.value csr = brainevent.CSR((data, indices, indptr), shape=shape) if transpose: - return matrix @ csr + return csr.T @ matrix else: return csr @ matrix diff --git a/brainpy/math/sparse/utils.py b/brainpy/math/sparse/utils.py index ddab78677..3ff40619b 100644 --- a/brainpy/math/sparse/utils.py +++ b/brainpy/math/sparse/utils.py @@ -16,8 +16,8 @@ from typing import Tuple +import brainevent from jax import numpy as jnp -from jax.experimental.sparse import csr_todense from brainpy.math.interoperability import as_jax @@ -39,15 +39,15 @@ def coo_to_csr( post_ids = as_jax(post_ids) # sorting - sort_ids = jnp.argsort(pre_ids, kind='stable') + sort_ids = jnp.argsort(pre_ids, stable=True) post_ids = post_ids[sort_ids] indices = post_ids unique_pre_ids, pre_count = jnp.unique(pre_ids, return_counts=True) - final_pre_count = jnp.zeros(num_row) - final_pre_count[unique_pre_ids] = pre_count + final_pre_count = jnp.zeros(num_row, dtype=jnp.int32) + final_pre_count = final_pre_count.at[unique_pre_ids].set(pre_count) indptr = final_pre_count.cumsum() - indptr = jnp.insert(indptr, 0, 0) + indptr = jnp.insert(indptr, 0, 0).astype(jnp.int32) return indices, indptr @@ -61,4 +61,23 @@ def csr_to_coo( return jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1, indices -csr_to_dense = csr_todense +def csr_to_dense(data, indices, indptr, *, shape): + """Convert a CSR sparse matrix to a dense array. + + Parameters + ---------- + data : ndarray + An array of shape ``(nse,)`` holding the non-zero values. + indices : ndarray + An array of shape ``(nse,)`` holding the column index of each value. + indptr : ndarray + An array of shape ``(shape[0] + 1,)`` holding the row pointers. + shape : tuple of int + A length-2 tuple ``(n_rows, n_cols)`` for the dense matrix. + + Returns + ------- + dense : ndarray + The dense matrix of shape ``shape``. + """ + return brainevent.CSR((data, indices, indptr), shape=shape).todense() diff --git a/brainpy/math/surrogate/_one_input.py b/brainpy/math/surrogate/_one_input.py index 91a6734b0..6f6ac5162 100644 --- a/brainpy/math/surrogate/_one_input.py +++ b/brainpy/math/surrogate/_one_input.py @@ -192,7 +192,7 @@ def surrogate_fun(self, x): def surrogate_grad(self, dz, x): x = as_jax(x) - dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., dz * (-(self.alpha * x) ** 2 + self.alpha)) + dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., dz * (-self.alpha ** 2 * jnp.abs(x) + self.alpha)) return dx def __repr__(self): @@ -471,7 +471,7 @@ def surrogate_grad(self, dz, x): def surrogate_fun(self, x): x = as_jax(x) - return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5 + return jnp.arctan(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5 def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' @@ -654,7 +654,7 @@ def surrogate_grad(self, dz, x): def surrogate_fun(self, x): x = as_jax(x) - return sci.special.erf(-self.alpha * x) * 0.5 + return 0.5 * (1. - sci.special.erf(-self.alpha * x)) def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' @@ -1066,13 +1066,13 @@ def __init__(self, alpha=2., forward_use_surrogate=False): def surrogate_grad(self, dz, x): x = as_jax(x) - dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha) + dx = jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), -self.alpha) return dx * as_jax(dz) def surrogate_fun(self, x): x = as_jax(x) z = jnp.where(x < 0., - 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha), + 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * x, 1 - self.alpha), 1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha)) return z @@ -1443,7 +1443,7 @@ def __init__(self, sigma=0.5, alpha=0.5): def surrogate_grad(self, dz, x): x = as_jax(x) - dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma) + dx = jnp.exp(-(x ** 2) / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma) return self.alpha * dx * as_jax(dz) def __repr__(self): diff --git a/brainpy/math/surrogate/_one_input_new.py b/brainpy/math/surrogate/_one_input_new.py index 06ded6a44..908ab5eb8 100644 --- a/brainpy/math/surrogate/_one_input_new.py +++ b/brainpy/math/surrogate/_one_input_new.py @@ -176,7 +176,7 @@ def sigmoid( ): r"""Spike function with the sigmoid-shaped surrogate gradient. - If `origin=False`, return the forward function: + The forward function: .. math:: @@ -185,11 +185,6 @@ def sigmoid( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}} Backward function: @@ -251,7 +246,7 @@ def surrogate_fun(self, x): def surrogate_grad(self, x): x = as_jax(x) - dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-(self.alpha * x) ** 2 + self.alpha)) + dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-self.alpha ** 2 * jnp.abs(x) + self.alpha)) return dx def __repr__(self): @@ -264,7 +259,7 @@ def piecewise_quadratic( ): r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -273,16 +268,6 @@ def piecewise_quadratic( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - g(x) = - \begin{cases} - 0, & x < -\frac{1}{\alpha} \\ - -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\ - 1, & x > \frac{1}{\alpha} \\ - \end{cases} Backward function: @@ -364,7 +349,7 @@ def piecewise_exp( ): r"""Judge spiking state with a piecewise exponential function [1]_. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -373,14 +358,6 @@ def piecewise_exp( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - g(x) = \begin{cases} - \frac{1}{2}e^{\alpha x}, & x < 0 \\ - 1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0 - \end{cases} Backward function: @@ -454,7 +431,7 @@ def soft_sign( ): r"""Judge spiking state with a soft sign function. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -463,12 +440,6 @@ def soft_sign( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1) - = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1) Backward function: @@ -526,7 +497,7 @@ def surrogate_grad(self, x): def surrogate_fun(self, x): x = as_jax(x) - return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5 + return jnp.arctan(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5 def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' @@ -539,7 +510,7 @@ def arctan( ): r"""Judge spiking state with an arctan function. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -548,11 +519,6 @@ def arctan( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2} Backward function: @@ -623,7 +589,7 @@ def nonzero_sign_log( ): r"""Judge spiking state with a nonzero sign log function. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -632,21 +598,6 @@ def nonzero_sign_log( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1) - - where - - .. math:: - - \begin{split}\mathrm{NonzeroSign}(x) = - \begin{cases} - 1, & x \geq 0 \\ - -1, & x < 0 \\ - \end{cases}\end{split} Backward function: @@ -707,7 +658,7 @@ def surrogate_grad(self, x): def surrogate_fun(self, x): x = as_jax(x) - return sci.special.erf(-self.alpha * x) * 0.5 + return 0.5 * (1. - sci.special.erf(-self.alpha * x)) def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' @@ -720,7 +671,7 @@ def erf( ): r"""Judge spiking state with an erf function [1]_ [2]_ [3]_. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -729,15 +680,6 @@ def erf( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - \begin{split} - g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\ - &= \frac{1}{2} \text{erfc}(-\alpha x) \\ - &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt - \end{split} Backward function: @@ -821,7 +763,7 @@ def piecewise_leaky_relu( ): r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -830,16 +772,6 @@ def piecewise_leaky_relu( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - \begin{split}g(x) = - \begin{cases} - cx + cw, & x < -w \\ - \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\ - cx - cw + 1, & x > w \\ - \end{cases}\end{split} Backward function: @@ -940,7 +872,7 @@ def squarewave_fourier_series( ): r"""Judge spiking state with a squarewave fourier series. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -949,11 +881,6 @@ def squarewave_fourier_series( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - g(x) = 0.5 + \frac{1}{\pi}*\sum_{i=1}^n {\sin\left({(2i-1)*2\pi}*x/T\right) \over 2i-1 } Backward function: @@ -1034,7 +961,7 @@ def s2nn( ): r"""Judge spiking state with the S2NN surrogate spiking function [1]_. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -1043,14 +970,6 @@ def s2nn( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - \begin{split}g(x) = \begin{cases} - \mathrm{sigmoid} (\alpha x), x < 0 \\ - \beta \ln(|x + 1|) + 0.5, x \ge 0 - \end{cases}\end{split} Backward function: @@ -1115,13 +1034,13 @@ def __init__(self, alpha=2.): def surrogate_grad(self, x): x = as_jax(x) - dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha) + dx = jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), -self.alpha) return dx def surrogate_fun(self, x): x = as_jax(x) z = jnp.where(x < 0., - 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha), + 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * x, 1 - self.alpha), 1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha)) return z @@ -1136,7 +1055,7 @@ def q_pseudo_spike( ): r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -1145,15 +1064,6 @@ def q_pseudo_spike( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - \begin{split}g(x) = - \begin{cases} - \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\ - 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0. - \end{cases}\end{split} Backward function: @@ -1229,7 +1139,7 @@ def leaky_relu( ): r"""Judge spiking state with the Leaky ReLU function. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -1238,15 +1148,6 @@ def leaky_relu( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - \begin{split}g(x) = - \begin{cases} - \beta \cdot x, & x \geq 0 \\ - \alpha \cdot x, & x < 0 \\ - \end{cases}\end{split} Backward function: @@ -1330,7 +1231,7 @@ def log_tailed_relu( ): r"""Judge spiking state with the Log-tailed ReLU function [1]_. - If `origin=False`, computes the forward function: + The forward function: .. math:: @@ -1339,16 +1240,6 @@ def log_tailed_relu( 0, & x < 0 \\ \end{cases} - If `origin=True`, computes the original function: - - .. math:: - - \begin{split}g(x) = - \begin{cases} - \alpha x, & x \leq 0 \\ - x, & 0 < x \leq 0 \\ - log(x), x > 1 \\ - \end{cases}\end{split} Backward function: @@ -1489,7 +1380,7 @@ def __init__(self, sigma=0.5, alpha=0.5): def surrogate_grad(self, x): x = as_jax(x) - dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma) + dx = jnp.exp(-(x ** 2) / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma) return self.alpha * dx def __repr__(self): diff --git a/brainpy/measure.py b/brainpy/measure.py index 71c1c2771..4afd2a3a5 100644 --- a/brainpy/measure.py +++ b/brainpy/measure.py @@ -89,7 +89,10 @@ def firing_rate(spikes, width, dt=None, numpy=True): np = onp if numpy else jnp dt = bm.get_dt() if (dt is None) else dt width1 = int(width / 2 / dt) * 2 + 1 - window = np.ones(width1) * 1000 / width + # Normalize by the actual window length (``width1`` bins of ``dt`` ms each), + # converting a per-bin spike count into a rate in Hz. Normalizing by the + # requested ``width`` instead biases the rate by ``width1 * dt / width``. + window = np.ones(width1) / (width1 * dt) * 1000. return np.convolve(np.mean(spikes, axis=1), window, mode='same') diff --git a/brainpy/optim/optimizer.py b/brainpy/optim/optimizer.py index 312b2376f..57e1862c1 100644 --- a/brainpy/optim/optimizer.py +++ b/brainpy/optim/optimizer.py @@ -570,6 +570,10 @@ def __init__( self.beta1 = beta1 self.beta2 = beta2 self.eps = eps + # Per-update step counter used for bias correction. It is a Variable so + # that it is traceable under JAX transforms and advances exactly once per + # call to ``update()`` (independent of any LR scheduler's ``last_epoch``). + self.step = bm.Variable(jnp.asarray(0)) def __repr__(self): return (f"{self.__class__.__name__}(lr={str(self.lr)}, " @@ -589,9 +593,14 @@ def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = Non def update(self, grads: dict): self.check_grads(grads) - lr = self.lr() - lr /= (1 - self.beta1 ** (self.lr.last_epoch.value + 2)) - lr *= jnp.sqrt(1 - self.beta2 ** (self.lr.last_epoch.value + 2)) + # Advance the per-update step counter (t = 1 on the first update). + self.step.value = self.step.value + 1 + t = self.step.value + # Read the (possibly scheduled) learning rate as a plain JAX array so we + # never mutate the underlying ``lr`` Variable in place. + lr = bm.as_jax(self.lr()) + lr = lr / (1 - self.beta1 ** t) + lr = lr * jnp.sqrt(1 - self.beta2 ** t) for key, p in self.vars_to_train.items(): m = self.implicit_vars[key + '_m'] v = self.implicit_vars[key + '_v'] @@ -931,6 +940,8 @@ def __init__( self.beta2 = beta2 self.eps = eps self.weight_decay = weight_decay + # Per-update step counter for bias correction (see ``Adam``). + self.step = bm.Variable(jnp.asarray(0)) def __repr__(self): return (f"{self.__class__.__name__}(lr={self.lr}, " @@ -960,8 +971,11 @@ def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = Non def update(self, grads: dict): self.check_grads(grads) - lr_old = self.lr() - step = self.lr.last_epoch.value + 2 + # Advance the per-update step counter (t = 1 on the first update) and read + # the learning rate as a plain JAX array (never mutate the lr Variable). + self.step.value = self.step.value + 1 + step = self.step.value + lr_old = bm.as_jax(self.lr()) bias_correction1 = 1 - self.beta1 ** step bias_correction2 = 1 - self.beta2 ** step lr = lr_old * jnp.sqrt(bias_correction2) / bias_correction1 @@ -1036,11 +1050,10 @@ def __init__( weight_decay: Optional[float] = None, name: Optional[str] = None, ): - super(SM3, self).__init__(lr=lr, - weight_decay=weight_decay, - train_vars=train_vars, - name=name) - + # NOTE: validate and assign hyper-parameters *before* ``super().__init__``. + # The base ``Optimizer.__init__`` calls ``register_train_vars`` (when + # ``train_vars`` is given), which reads ``self.momentum``; assigning these + # attributes afterwards left the optimizer un-instantiable with train vars. if not 0.0 <= momentum < 1.0: raise ValueError("Invalid momentum: {0}".format(momentum)) if not 0.0 <= beta < 1.0: @@ -1052,6 +1065,11 @@ def __init__( self.beta = beta self.momentum = momentum + super(SM3, self).__init__(lr=lr, + weight_decay=weight_decay, + train_vars=train_vars, + name=name) + def __repr__(self): return (f"{self.__class__.__name__}(lr={self.lr}, " f"beta={self.beta}, eps={self.eps}, momentum={self.momentum})") @@ -1093,7 +1111,7 @@ def update(self, grads: dict): result = update for j in range(ndim): if i != j: - result = result.max(axis=j, keepdim=True) + result = result.max(axis=j, keepdims=True) acc = self.implicit_vars[f'{k}_m{i}'] if self.beta > 0.: acc.value = bm.maximum(acc, result) diff --git a/brainpy/optim/scheduler.py b/brainpy/optim/scheduler.py index 5941727a1..2ae0d307e 100644 --- a/brainpy/optim/scheduler.py +++ b/brainpy/optim/scheduler.py @@ -158,8 +158,10 @@ def __init__( def call(self, i=None): i = (self.last_epoch.value + 1) if i is None else i milestones = jnp.asarray(self.milestones) - conditions = jnp.logical_and((i >= milestones[:-1]), (i < milestones[1:])) - p = jnp.argmax(conditions) + # Number of milestones strictly before epoch ``i``: lr decays *after* a + # milestone epoch is reached, so e.g. milestones=[10, 20] keeps the base + # lr through epoch 10 and applies the first decay from epoch 11 onward. + p = jnp.sum(milestones < i) return self.lr * self.gamma ** p def __call__(self, i=None): diff --git a/brainpy/runners.py b/brainpy/runners.py index 8e107f51a..9afb9951d 100644 --- a/brainpy/runners.py +++ b/brainpy/runners.py @@ -614,7 +614,9 @@ def _get_input_time_step(self, duration=None, xs=None) -> int: else: raise ValueError - def _step_mon_on_cpu(self, args, transforms): + def _step_mon_on_cpu(self, args): + # host-side side effect: append the per-step monitored values (passed in + # as a dict of numpy arrays) to the running monitor lists. for key, val in args.items(): self.mon[key].append(val) @@ -636,12 +638,15 @@ def _step_func_predict(self, i, *x, shared_args=None): clear_input(self.target) if self._memory_efficient: - mon_shape_dtype = jax.ShapeDtypeStruct(mon.shape, mon.dtype) - result = jax.pure_callback( - self._step_mon_on_cpu, - mon_shape_dtype, - mon, - ) + # ``mon`` is a dict of arrays. Offload the monitored values to the + # host on every step (keeping them out of device memory) using a + # side-effecting callback. ``jax.debug.callback`` is used instead of + # ``jax.pure_callback`` because the callback's purpose is the host + # side effect (appending to ``self.mon`` lists) and its result is + # discarded -- a ``pure_callback`` whose output is unused would be + # eliminated by XLA's dead-code elimination under ``jit``. + mon = tree_map(lambda x: bm.as_jax(x), mon, is_leaf=_is_brainpy_array) + jax.debug.callback(self._step_mon_on_cpu, mon) return out, None else: return out, mon diff --git a/brainpy/running/jax_multiprocessing.py b/brainpy/running/jax_multiprocessing.py index afcc52e89..5375d44a8 100644 --- a/brainpy/running/jax_multiprocessing.py +++ b/brainpy/running/jax_multiprocessing.py @@ -133,13 +133,17 @@ def jax_parallelize_map( res_tree = None results = None - vmap_func = pmap(func) + # Build the pmapped function once and reuse it across all chunks. Re-applying + # ``jax.pmap`` inside the loop forces a recompilation on every chunk, which is + # both slow and unnecessary since the traced function does not change. + pmap_func = pmap(func) for i in range(0, num_pars[0], num_parallel): - run_f = pmap(func) if clear_buffer else vmap_func if isinstance(arguments, dict): - r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) + r = pmap_func(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) + elif isinstance(arguments, (tuple, list)): + r = pmap_func(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) else: - r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) + raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}') res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.Array)) if results is None: results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values) diff --git a/docs/issues-found-20260618.md b/docs/issues-found-20260618.md new file mode 100644 index 000000000..e11674de1 --- /dev/null +++ b/docs/issues-found-20260618.md @@ -0,0 +1,387 @@ +# BrainPy Package Audit — Issues Found (2026-06-18) + +**Reviewer role:** Senior Python architect · JAX expert · BrainX-ecosystem developer +**Package:** `brainpy` v2.7.8 (~74k LOC, 252 non-test files, 17 submodules) +**Environment (verified):** Python 3.13.11 · jax 0.10.1 · brainpy 2.7.8 · brainstate 0.5.0 · brainevent 0.1.0 (CPU) +**Method:** Full deep sweep by parallel expert sub-audits of every submodule; static review of all findings + executable repro for the high-severity ones. 33 of the highest-impact findings were independently re-verified by the lead reviewer; all reproduced. + +> **Scope note.** `import brainpy` works in this environment, so most findings are *runtime-reproduced*, not speculative. Findings are tagged **[verified]** (a repro was executed and reproduced the bug), **[static]** (confirmed by code inspection / type analysis), or **[likely]** (strong reasoning, not executed). + +--- + +## 1. Executive summary + +The audit found **131 distinct issues**: **26 Critical**, **53 High**, **36 Medium**, **16 Low**. The dominant story is **ecosystem-migration drift**: BrainPy 2.7.x was rebased onto the new BrainX stack (`brainstate` 0.5, `brainevent` 0.1, `braintools`) and onto JAX ≥0.9/0.10, and many code paths were not updated in lockstep. The result is a band of **silent numerical errors** and **crash-on-first-use** bugs concentrated in: the optimizer/loss/scheduler stack, the surrogate-gradient and synapse/plasticity code, the sparse/event operators, the FDE/adaptive integrators, and the normalization layers. + +Highest-impact, broad-blast-radius issues (all **verified**): + +| # | Issue | Location | Impact | +|---|-------|----------|--------| +| C-01 | **Adam/AdamW bias correction frozen at t=1** | `optim/optimizer.py:593-594,964-967` | Every Adam/AdamW training run uses wrong (un-debiased, growing) steps | +| C-02 | **`nll_loss` returns +log-likelihood (sign flipped)** | `losses/comparison.py:461` | NLL training maximizes instead of minimizes | +| C-03 | **`cross_entropy_loss` weights by sample index, not class** | `losses/comparison.py:266-267` | Class-weighted CE silently wrong / shape-crashes | +| C-04 | **`MultiStepLR` never decays** | `optim/scheduler.py:157-163` | LR schedule is a no-op | +| C-05 | **`GroupNorm`/`InstanceNorm` reduce over the group axis** | `dnn/normalization.py:597` | `num_groups` has no effect; every config == LayerNorm | +| C-06 | **STP facilitation ODE diverges** | `dyn/synapses/abstract_models.py:862` | Short-term plasticity synapse blows up to ±thousands | +| C-07 | **`csrmm(transpose=True)` computes the wrong product** | `math/sparse/csr_mm.py:63`, `math/event/csr_matmat.py:64` | Wrong values / shape crash in sparse matmat + its autodiff | +| C-08 | **`CaputoEuler` mis-scales the initial condition** | `integrators/fde/Caputo.py:201` | Fractional ODE solver wrong whenever y0≠0 | +| C-09 | **`TimeDelay` read omits modulo** | `math/delayvars.py:271` | Delay variable returns stale/wrong delayed values | +| C-10 | **`disable_x64()` desyncs brainstate vs JAX precision** | `math/environment.py:645` | After any x64 context, default dtypes silently wrong | + +The good news: the **core ODE Runge–Kutta tableaus, most synapse kinetics, the GRU cell, the high-dim fixed-point finder, weight initializers, and the Conv/Dense/Dropout layers were checked and found correct** (see Appendix B). The bugs are concentrated, not pervasive — which makes them tractable to fix. + +--- + +## 2. Cross-cutting themes (root causes) + +1. **brainstate 0.5 migration drift.** `pyproject` pins `brainstate>=0.2.7` but 0.5.0 is installed. Removed/renamed APIs surface as runtime crashes: `tracing_variable` now `raise NotImplementedError` (breaks default STDP, C-19), `jax.util` removed (breaks `VarDict` pytree, C-25), `State`-as-operand rejected in control flow (H-…), `Variable.size_without_batch` broken (H-…). **Action:** pin a tested `brainstate`/`brainevent` lower bound and add an import-time smoke test across the public surface. + +2. **JAX ≥0.9/0.10 API changes not propagated.** `Array.device()` (now a property), `csr_todense`/`csrmm` signatures, `jnp.argsort(kind=)` removed, `__float__` rejecting `ndim>0`. **Action:** a compatibility shim module + CI against the pinned JAX. + +3. **`brainevent` 0.1 backend migration left wrappers stale.** `brainevent.COO` removed (C, coomv dead), transpose semantics inverted in matmat (C-07), `coo_to_csr`/`csr_to_dense` broken, jitconn docstrings describe a `method=`/cuSPARSE API that no longer exists. + +4. **Surrogate-gradient ⇄ forward-function inconsistency.** Multiple surrogate classes have a `surrogate_grad` formula that does **not** match the derivative of their own `surrogate_fun` / docstring (Gaussian precedence, PiecewiseQuadratic, q-PseudoSpike, ERF sign, arctan crash). Compounded by `bm.surrogate` being **shadowed by `braintools.surrogate`**, so the in-repo package is dead relative to the public API yet still importable and buggy. + +5. **Validation-after-mutation ordering.** Several setters/config functions mutate state or read attributes *before* validating/normalizing inputs: `environment.set()` (partial-config leak), `Variable.value` setter (rejects `State`/numpy before unwrap), adaptive-RK `tol` default not propagated, `SM3.__init__` reads `self.momentum` before assignment. + +6. **Batched-math assumptions.** RLS/FORCE (C-23) and several reductions assume batch size 1; correct for the tested path, silently divergent for B>1. + +7. **`dt` vs `sqrt(dt)` and unit scaling in stochastic/rate models.** `ThresholdLinearModel` noise scales as `dt` not `sqrt(dt)`; PoissonInput uses variance as std (C-17); `CondNeuGroup` double-applies area scaling. + +8. **Docstring/NumPy-doc nonconformance & drift.** Pervasive `Parameters::` / `Returns::` literal-block markers (won't render), stale deprecation versions ("removed after 2.4.0" in 2.7.8), and docstrings whose constants/defaults disagree with code. + +--- + +## 3. Critical findings (detail) + +### C-01 — Adam/AdamW bias correction is frozen at t=1 **[verified]** +- **File:** `brainpy/optim/optimizer.py:593-594` (Adam), `:964-967` (AdamW); root cause `optim/scheduler.py:55-59` +- **What:** Bias correction uses `self.lr.last_epoch.value + 2`, but with the default `Constant` scheduler `last_epoch` is never incremented during `update()` (only `step_epoch()` advances it, which optimizers never call). So `beta**(last_epoch+2) == beta**1` forever and the `m`/`v` EMAs are never debiased. +- **Why it's wrong:** Under a constant gradient, correct Adam yields a constant step ≈ `-lr`. Measured steps instead grow: `dw = [-0.001, -0.00134, -0.00157, -0.00172, -0.00183]`. +- **Fix:** Maintain an internal per-`update()` step counter `t` (independent of the LR scheduler) and use `beta1**t`, `beta2**t` for bias correction. Don't derive `t` from `last_epoch`. + +### C-02 — `nll_loss` returns the log-likelihood, not its negative **[verified]** +- **File:** `brainpy/losses/comparison.py:461` (class `NLLLoss` wraps it) +- **What:** `return mean(input[arange, target])` with no negation; the function's own docstring defines `-Σ w·x_{n,y_n}`. +- **Measured:** `nll_loss(log p, [0,1]) = -0.2899` (correct `+0.2899`). Minimizing drives the correct-class log-prob to −∞. +- **Fix:** `loss = -input[jnp.arange(len(target)), target]` (negate), keep the reductions. + +### C-03 — `cross_entropy_loss` applies class `weight` by sample index **[verified]** +- **File:** `brainpy/losses/comparison.py:266-267` (`loss *= weight`) +- **What:** Per-sample loss `(N,)` is multiplied elementwise by the per-class weight `(C,)` — so sample *n* is weighted by `weight[n]`, not `weight[target_n]`. Raises on `N≠C`, silently wrong on `N==C`. +- **Measured:** logits `0(3,3)`, targets `[2,2,2]`, weight `[10,20,1]` → per-sample `[10.99, 21.97, 1.10]`; correct is all `1.10` (`w[2]`). +- **Fix:** gather `weight[target]` before reduction; for `mean`, normalize by `sum(weight[targets])`. + +### C-04 — `MultiStepLR` never decays **[verified]** +- **File:** `brainpy/optim/scheduler.py:157-163` +- **What:** `conditions = (i>=milestones[:-1]) & (i None`. +- **Measured:** `odeint(..., method='rkf45', adaptive=True)(...)` → `TypeError: '>' not supported between ArrayImpl and NoneType`. +- **Fix:** `code_scope['tol'] = self.tol` (and keyword default likewise). + +### C-13 — All SDE integrators `NameError` on invalid type (missing `errors` import) **[verified]** +- **File:** `brainpy/integrators/sde/base.py:76,79,82`; `sde/normal.py:225` +- **What:** Validation references `errors.IntegratorError` but `errors` is never imported; also the `Heun` Ito/Stratonovich guard. +- **Measured:** `sdeint(..., intg_type='WRONG')` → `NameError: name 'errors' is not defined`. +- **Fix:** `from brainpy import _errors as errors` in both files. + +### C-14 — Standalone HH/Markov channel gating produces NaN at voltage singularities **[verified by sub-audit]** +- **File:** `brainpy/dyn/channels/sodium.py:384,299,215`; `potassium.py:359,222,290` (+legacy dups `:1191,1261,1332`); `calcium.py:711` +- **What:** Rates coded as `k*temp/(1-exp(-temp/d))` are 0/0 → NaN exactly at the removable singularity (e.g. `IK_HH1952v2` at V=−55). The HH *neuron* class was fixed with `bm.exprel`; the channel modules were not. `bm.where` clamping can't recover it (both branches evaluated). +- **Measured:** `IK_HH1952v2(1).f_p_alpha([-55.0]) = [nan]`. +- **Fix:** rewrite with `bm.exprel`, e.g. `0.1 / bm.exprel(-(V - V_sh + 10)/10)` (mind the `k*d` coefficient bookkeeping). Fix legacy duplicates too. + +### C-15 — `ThresholdLinearModel` noise path crashes (`randn` signature) **[verified]** +- **File:** `brainpy/dyn/rates/populations.py:1051,1060` +- **What:** `bm.random.randn(self.varshape)` passes a shape *tuple* as a single positional arg; brainstate's `randn` takes unpacked dims. +- **Measured:** any nonzero `noise_e/noise_i` → `TypeError: Shapes must be 1D sequences ... got ((1000,),)`. +- **Fix:** `bm.random.randn(*self.varshape)` or `bm.random.normal(size=self.varshape)`. (Separately, the noise scales as `dt` not `sqrt(dt)` — see M-…) + +### C-16 — `StuartLandauOscillator.dy` has the wrong rotational coupling **[verified]** +- **File:** `brainpy/dyn/rates/populations.py:721` +- **What:** `dy` returns `(a-x²-y²)*y - w*y + y_ext` (copy-paste from `dx`); the Hopf normal form needs `+ w*x`. As written there's no x↔y rotation, so no limit cycle. +- **Measured:** `dy(y=.5,x=.3,a=.25,w=.2) = -0.145` (buggy `-w*y`); correct `+w*x` gives `+0.015`. +- **Fix:** `return (a - x*x - y*y)*y + w*x + y_ext`. + +### C-17 — `PoissonInput` Gaussian branch uses the variance as the std (~3–4× too much noise) **[verified]** +- **File:** `brainpy/dyn/projections/inputs.py:168,174`; duplicated in `brainpy/dynold/experimental/others.py:74-77` +- **What:** `bm.random.normal(a, b*p, ...)` passes `b*p = n(1-p)p` (the Binomial *variance*) as the std; correct is `sqrt(n·p·(1-p)) = sqrt(b*p)`. +- **Measured:** `n=1000,p=0.02` → code std `19.6` vs correct `4.43`. Active in the common large-N branch; mean is correct. +- **Fix:** `scale = jnp.sqrt(b*p)` (both the eager and `bm.cond` branches, and the dynold copy). + +### C-18 — `HalfProjAlignPost.update` calls `comm` twice **[verified by sub-audit]** +- **File:** `brainpy/dyn/projections/align_post.py:384-388` +- **What:** Computes `current = self.comm(x)` then `g = self.syn(self.comm(x))` — two independent calls. For event/jit-prob comms each call draws fresh random connectivity, so the synapse sees different input than the returned current; doubles compute for deterministic comms. +- **Fix:** `current = self.comm(x); g = self.syn(current); ...; return current`. + +### C-19 — `STDP_Song2000` crashes on the first update (tracing_variable removed) **[verified by sub-audit]** +- **File:** `brainpy/dyn/projections/plasticity.py:230-240` → `brainpy/dnn/linear.py:502-503` +- **What:** `stdp_update` falls back to `self.tracing_variable('weight', ...)`, which now unconditionally `raise NotImplementedError`. The weight is only a `Variable` when the comm is built with `mode=TrainingMode`; the class docstring example omits it, so the documented usage is dead on arrival. +- **Fix:** in `stdp_update`, promote the weight directly (`self.weight = bm.Variable(self.weight)`) or require trainable weights in `STDP_Song2000`. Also fixes a companion crash (H-…: `bm.as_jax(None)` for default `W_min/W_max`). + +### C-20 — `AlphaCUBA` / `AlphaCOBA` raise `ZeroDivisionError` on construction **[verified]** +- **File:** `brainpy/dynold/synapses/compat.py:208-270`; root `brainpy/dyn/synapses/abstract_models.py:159-164` +- **What:** They pass `tau_rise == tau_decay` into `DualExpon`, whose peak normalizer `A = tau_decay/(tau_decay - tau_rise)·…` divides by zero. +- **Measured:** `bp.synapses.AlphaCUBA(LIF(2), LIF(2), All2All(), tau_decay=10.)` → `ZeroDivisionError`. +- **Fix:** route `AlphaCUBA/COBA` through the single-tau `synapses.Alpha`, or special-case `tau_rise==tau_decay` (L'Hôpital limit `A=e`, `a=1/tau`). + +### C-21 — dynold `STP` learning rule injects current with zero presynaptic spikes **[verified by sub-audit]** +- **File:** `brainpy/dynold/synapses/learning_rules.py:33-37,231-233` +- **What:** `_STPModel = Sequential(STP, Expon)`; modern `STP.update` returns `u*x` (≈0.15 at rest) every step, and `Expon` treats it as additive current, so `g += u*x` continuously. The spike gating is lost. +- **Measured:** zero input → `syn.I` ramps to ~512 and keeps rising. +- **Fix:** gate by spikes (`pre_spike*(u*x)`), or use the modern `dyn/projections/plasticity` wiring. + +### C-22 — `DSRunner(memory_efficient=True)` is completely non-functional **[verified by sub-audit]** +- **File:** `brainpy/runners.py:638-647` (+ `_step_mon_on_cpu` :617-619) +- **What:** `_step_func_monitor()` returns a dict, but the code does `jax.ShapeDtypeStruct(mon.shape, mon.dtype)` on it; the `pure_callback` arg count and the `None` return are also wrong. Cannot have worked since the migration. +- **Measured:** any `memory_efficient=True` run → `AttributeError: 'dict' object has no attribute 'shape'`. +- **Fix:** `jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), mon)`; fix the callback signature/return; add a smoke test. + +### C-23 — RLS / FORCE online update is wrong for batch size > 1 **[verified by sub-audit]** +- **File:** `brainpy/algorithms/online.py:148-154` (drives `train/online.py` `OnlineTrainer`/`ForceTrainer`) +- **What:** `c = jnp.sum(1.0/(1.0+hPh))` collapses the `B×B` matrix `(I+HPHᵀ)` to a scalar (summing reciprocals of all entries incl. off-diagonals). Correct only for B=1; for B>1 `c` grows with B and can go negative → `P` diverges, update sign flips. +- **Measured:** B=16 → `c=-90.3` (correct diag value `+1.7`); fitting with B≥4 → NaN weights within a few hundred steps. +- **Fix:** proper block RLS: `K = PHᵀ(I + HPHᵀ)⁻¹` via `jnp.linalg.solve`, or assert `input.shape[0]==1`. + +### C-24 — `PoissonEncoder.single_step` crashes (its own documented usage) **[verified]** +- **File:** `brainpy/encoding/stateless_encoding.py:91-94` → `:111` +- **What:** `single_step(x, i_step=None)` delegates to `multi_steps(x, n_time=None)` whose first line is `int(n_time/get_dt())` → `None/float`. +- **Measured:** `PoissonEncoder().single_step(bm.random.rand(4))` → `TypeError`. +- **Fix:** in the `i_step is None` branch, draw a single Bernoulli sample directly; guard `multi_steps` against `n_time is None`. + +### C-25 — `VarDict.tree_unflatten` crashes on every JAX transform (`jax.util` removed) **[verified by sub-audit]** +- **File:** `brainpy/math/object_transform/variables.py:423` +- **What:** Calls `jax.util.safe_zip(...)`, but `jax.util` no longer exists in jax 0.10.1. `VarDict` is a registered pytree, so any `jit`/`vmap`/`tree_map` over one fails. +- **Measured:** `jax.jit(lambda d: d)(bm.var_dict({'a': bm.Variable(...)}))` → `AttributeError: module 'jax' has no attribute 'util'`. +- **Fix:** use `brainstate._compatible_import.safe_zip` or `cls(zip(keys, values))`. + +### C-26 — `Variable` batch_axis / axis_names silently dropped through pytree round-trip **[verified by sub-audit]** +- **File:** `brainpy/math/object_transform/variables.py:40,79` (inherits `Array.tree_flatten` returning `aux_data=None`) +- **What:** Reconstructing a `Variable` after `jit`/`vmap`/`scan`/`tree_map` loses `batch_axis`/`axis_names` (reset to `None`). brainstate's closure-based transforms work around it, but explicit pytree / `jit`-argument use degrades silently (affects sharding, `size_without_batch`, value-setter shape checks). +- **Fix:** override `Variable.tree_flatten/unflatten` to carry `(batch_axis, axis_names)` in `aux_data` and rebuild without re-running naming/`State.__init__` side effects. + +--- + +## 4. High findings (detail) + +> Format: `[verified|static] file:line — what → fix` + +### Object model / transforms (`math/object_transform`) +- **H-01 [static]** `jit.py:200-207` — `cls_jit` shifts `static_argnums` by `+1` unconditionally, corrupting **negative** indices (`-1` → `(0,0)` marks `self` static twice). → shift only `x>=0`; resolve negatives against the signature. +- **H-02 [verified]** `controls.py:125-561` + `_utils.py:78-85` — passing a `Variable`/`State` in `operands` of `cond`/`for_loop`/`while_loop`/`scan` raises (`State` rejected at brainstate cache-key time, before the in-wrapper strip). → strip state from `operands` before forwarding; document closure capture as the supported path. +- **H-03 [verified]** `controls.py:372-390` — `for_loop(jit=False)` zero-length guard only checks `operands[0].shape`; a pytree operand (dict) has no `.shape`, so it crashes with "zero-length scan … in disable_jit()". → compute leading length from `jax.tree.leaves`. +- **H-04 [verified]** `jit.py:127-153` — `jit` docstrings advertise `dyn_vars`/`child_objs`, now forwarded into `**kwargs` → `TypeError` from brainstate. → drop from docstring; filter/warn in `jit`. +- **H-05 [verified]** `base.py:603-614` — `to()/cpu()/cuda()/tpu()` iterate `state_dict()` (nested dicts), so `isinstance(var, Array)` is always False; they never move variables and inject junk dict-valued attributes named after nodes. → iterate `self.vars().values()` and set `var.value = jax.device_put(...)`. +- **H-06 [verified]** `variables.py:143-172` — `Variable.value` setter validates shape/dtype on the raw input *before* unwrapping `State`/numpy, so `v.value = some_State` and `v32.value = np.float64_array` raise spurious `MathError`. → unwrap/convert first, then validate. +- **H-07 [verified]** `naming.py:24,34-44` — global `_name2id` registry grows unboundedly (no weakref/GC pruning) and stores `id(obj)`, so reused ids cause false `UniqueNameError`. → `WeakValueDictionary` + prune dead refs. +- **H-08 [verified]** `base.py:287-309` vs `collectors.py:198` — `register_implicit_vars` default `var_cls` accepts `VarList/VarDict`, but `ArrayCollector.__setitem__` asserts `isinstance(value, Variable)` → `AssertionError`. → flatten containers or relax the collector. +- **H-09 [verified]** `variables.py:41` — `Variable.__eq__` returns an elementwise array while `__hash__` is identity-based → breaks `in`/set/dict-by-value and raises ambiguous-truth. → define identity `__eq__`/`__ne__`, or guarantee all internal membership uses `id()`. + +### Math core / compat / sparse +- **H-10 [verified]** `modes.py:38-41` — `Mode` overrides `__eq__` without `__hash__`, so every mode is unhashable (regression vs the hashable brainstate parent). → `__hash__ = brainstate.mixin.Mode.__hash__`. +- **H-11 [verified]** `ndarray.py:206-207` — `Array.device()` calls the now-property `jax.Array.device` → `TypeError`. (Also `device_buffer` :209, `block_host_until_ready` :200.) → make `device` a property. +- **H-12 [verified]** `ndarray.py:99-107` — `Array(scalar)` stores a bare Python scalar; `.shape`/most ops then crash. → fall through to `jnp.asarray(value)`. +- **H-13 [verified]** `compat_numpy.py:217-220` — `asfarray(a)` with default `dtype=None` no-ops on integer input because `np.issubdtype(None, np.inexact)` is True. → `if dtype is None or not issubdtype(...): dtype = float`. +- **H-14 [verified]** `_utils.py:59-62` (+ pytorch compat) — `out=` argument makes wrapped funcs return `None` (numpy/torch return the array). → `return out` after `out.value = r`. +- **H-15 [verified]** `others.py:94-96` — `remove_diag` uses concrete boolean-mask indexing → `NonConcreteBooleanIndexError` under `jit`/`vmap`. → static off-diagonal gather. +- **H-16 [verified]** `activations.py:668-669` — `softmin` lacks max-subtraction → NaN for large inputs (`softmin([1000,1001,1002]) = [nan,nan,nan]`). → `softmax(-x, axis)`. +- **H-17 [verified]** `sparse/coo_mv.py:82` — `coomv` builds `brainevent.COO`, removed in brainevent 0.1 → `AttributeError`. → convert COO→CSR or drop. +- **H-18 [verified]** `sparse/utils.py:42,47-49` — `coo_to_csr` broken: `argsort(kind=)` removed, in-place item assignment on immutable array, float `indptr`. → `argsort(stable=True)`, `.at[].set`, int dtype. +- **H-19 [verified]** `sparse/utils.py:64` — `csr_to_dense = csr_todense` re-exports a jax function whose signature changed to take a `CSR` object → `TypeError` on the legacy call. → wrap explicitly via `brainevent.CSR(...).todense()`. + +### Surrogate gradients (`math/surrogate`) — all **[verified]**, present in both `_one_input.py` and `_one_input_new.py` +- **H-20** `_one_input_new.py:1492` (`_one_input.py:1446`) — `GaussianGrad`: `exp(-(x**2)/2*sigma**2)` = `exp(-(x²/2)·σ²)`, σ inverted (precedence). At σ=2 the bump is ~`e²`× too narrow (grad@±1 ≈ 0.0135 vs intended ≈0.088). → `exp(-(x**2)/(2*sigma**2))`. +- **H-21** `_one_input_new.py:254` (`:195`) — `PiecewiseQuadratic`: grad uses `-(α x)²+α` but the derivative of its own forward is `-α²|x|+α`. → `-self.alpha**2*jnp.abs(x)+self.alpha`. +- **H-22** `_one_input_new.py:1118` (`:1069`) — `q_pseudo_spike`: grad denominator uses `alpha+1`, docstring/forward use `alpha-1`. → `alpha-1`. +- **H-23** `_one_input_new.py:529` (`:474`) — `arctan.surrogate_fun` calls `jnp.arctan2(...)` with one arg → `TypeError`. → `jnp.arctan(...)`. +- **H-24** `_one_input_new.py:710` (`:657`) — `ERF.surrogate_fun = erf(-αx)*0.5` is decreasing in [−0.5,0.5]; should be `0.5*(1-erf(-αx))`. → fix sign/offset (`0.5*erfc(-αx)`). +- **H-25 [verified]** `math/__init__.py:47` — `bm.surrogate` is reassigned to `braintools.surrogate`, so the entire in-repo `brainpy/math/surrogate` package (with the above bugs) is unreachable via the public API yet still importable. → delete the in-repo package or stop the override; don't ship both. + +### Integrators +- **H-26 [verified by sub-audit]** `ode/adaptive_rk.py:532` — `BoSh3.B2 = ['-5/72', …]` (sums to 0) makes the embedded error estimate ~20× too large and wrong-signed. → `B2 = ['7/24', 0.25, '1/3', 0.125]`. +- **H-27 [verified]** `ode/adaptive_rk.py:221` — step controller `where(error>tol, shrink, dt)` never *increases* dt (one-sided), contradicting the docstring. → unconditional clamped factor `dt*clip(0.9*(tol/error)**(1/(p+1)), …)`. +- **H-28 [verified]** `ode/adaptive_rk.py:164,214-217` — default `var_type=POP_VAR` emits `sum(abs(...))` (builtin `sum`) → `'float' object is not iterable` on scalar state. → `jnp.sum(jnp.abs(...))`. +- **H-29 [verified]** `integrators/runner.py:242,254,262` — `IntegratorRunner` reuses loop var `i` (`for i,v in enumerate(...)`) clobbering the step index, so monitors reading `shared['i']` get `len(vars)-1`. → rename inner loop var. + +### FDE +- **H-30 [verified]** `fde/GL.py:187` — `GLShortMemory.reset` uses key `key` instead of `key+'_delay'` → `KeyError` on every reset. → add the suffix. +- **H-31 [verified by sub-audit]** `fde/Caputo.py:375` — `CaputoL1Schema.hists()` default path does `{k:v.numpy() for k,v in hists_}` (iterates dict keys) → `ValueError`. → `.items()`. +- **H-32 [verified by sub-audit]** `fde/generic.py:87-88` — `set_default_fdeint()` assigns `_DEFAULT_ODE_METHOD` (wrong global), so it's a no-op. → assign `_DEFAULT_DDE_METHOD`. + +### Dyn (neurons / synapses / rates) +- **H-33 [verified]** `dyn/ions/base.py:54-55` — `for k,v in channels.items(): self.add_elem(k=v)` passes literal keyword `k`, so all channels register under name `"k"` and overwrite each other. → `self.add_elem(**{k: v})`. +- **H-34 [verified by sub-audit]** `dyn/neurons/lif.py:1108-1109` — `ExpIFRef/ExpIFRefLTC` unconditionally `odeint`, dropping any `noise=` (every other `*Ref` guards with `sdeint`). → guard on `self.noise`. +- **H-35 [verified by sub-audit]** `dyn/neurons/lif.py:4495-4496,3814-3815` — `IzhikevichRef`/`GifRef` compute `spike_no_grad` but reset state with the grad-carrying `spike`, so `detach_spk` is a no-op. → use `spike_no_grad` for resets. +- **H-36 [verified]** `dyn/rates/rnncells.py:401,412` — `LSTMCell` `h`/`c` setters slice axis 0 while getters split axis −1; unbatched → `IndexError`, batched → wrong-rows write. → slice the last axis. +- **H-37 [verified]** `dyn/rates/reservoir.py:226` — `noise_rec * uniform(-1,-1, …)` is a constant `-noise_rec` bias, not noise (typo for `uniform(-1,1)`). → fix bounds. +- **H-38 [verified]** `dyn/rates/reservoir.py:191,202-232` — `self.bias` is created (and TrainVar in training) but never added in `update()`. → `hidden += self.bias`. +- **H-39 [verified by sub-audit]** `dyn/synapses/abstract_models.py:880-881` — STP discrete `u`/`x` jumps read pre-decay `self.u`/`self.x` instead of the decayed locals; off by one decay step. → use the decayed `u,x` locals (and apply `x` jump after the `u` jump). +- **H-40 [static]** `dyn/projections/base.py:1-26` — byte-for-byte duplicate of `utils.py` (only a private helper), yet `projections/__init__.py` does `from .base import *` (imports nothing); misleading vs the real `SynConn` base. → delete or re-export the real base classes. +- **H-41 [verified]** `dyn/projections/plasticity.py:232-233` — default `W_min=W_max=None` → `bm.as_jax(None)` raises (first crash even before C-19's path). → pass `None` through unchanged. + +### dynold compat +- **H-42 [verified]** `dynold/neurons/reduced_models.py` LIF/ExpIF/AdExIF — default params silently changed to the modern `*Ref` values (`LIF`: `V_rest=0,V_reset=-5,V_th=20`; `ExpIF/AdExIF`: `V_th=-55`) while docstrings still claim `-65/-68/-30`. → restore historical defaults in the dynold wrappers or fix every docstring. + +### Top-level glue / measure / delay +- **H-43 [verified]** `measure.py:91-92` — `firing_rate` normalizes by requested `width` while the window length is `width1=int(width/2/dt)*2+1≠width/dt`; biased by `width1·dt/width` (e.g. true 100 Hz → oscillates 100↔200, mean 110). → `window = ones(width1)/(width1*dt)*1000`. +- **H-44 [verified]** `delay.py:254-257` (+ class attr only-annotated `:72`) — `VarDelay(target, time=T>0)` reads `self.data` in `_init_data` before it is ever assigned → `AttributeError: 'data'`. → set `self.data = None` unconditionally before the `max_length>0` branch. +- **H-45 [verified]** `delay.py:481` → `math/object_transform/variables.py:106-112` — `DataDelay.reset_state(batch_size)` calls `size_without_batch`, which does `self.size[:batch_axis]+…` but `Variable.size` is the integer element count, not a shape tuple → `TypeError: 'int' object is not subscriptable` for any batched variable. → use `self.shape` in `size_without_batch`. + +### Train / running +- **H-46 [verified]** `algorithms/offline.py:159,386` — `gradient_descent=True` path does `jnp.logical_and(...).value` (no `.value` on a jax array/tracer) → `AttributeError`; breaks every GD regression incl. always-GD `Lasso`/`ElasticNet`. → drop `.value`. +- **H-47 [static]** `algorithms/offline.py:272-276` — ridge `XᵀX+αI` penalizes the prepended bias column (intercept shrunk) and is off by the ½ factor vs the documented `½α‖w‖²`. → zero the `(0,0)` entry of the penalty; reconcile the ½. +- **H-48 [static]** `running/jax_multiprocessing.py:136-156` — `jax_parallelize_map` builds one cached `pmap` reused across chunks; the trailing partial chunk ≠ device count → retrace/crash; also mislabeled `vmap_func`, missing `else: raise`. → build per chunk or pad to a device multiple. + +### Analysis +- **H-49 [static]** `analysis/lowdim/lowdim_analyzer.py:377,953` & `utils/optimization.py:398` — arg-unwrap comprehension tests `isinstance(candidates, bm.Array)` instead of `isinstance(a, …)` (3 copies) → either `AttributeError` or `bm.Array` leaking into `meshgrid`/`vmap`. → test `a`. +- **H-50 [static]** `analysis/lowdim/lowdim_analyzer.py:1038-1040` — non-convertible 2D `_get_fixed_points` does `jnp.concatenate([])` when nothing converges → `ValueError` (the 1D/convertible paths guard, this one doesn't). → empty-guard return. + +### DNN +- **H-51 [verified]** `dnn/normalization.py:100,134,503,588` — `BatchNorm*`/affine `LayerNorm`/`GroupNorm` raise `UnsupportedError` out-of-the-box under the default `NonBatchingMode` (and the affine `assert isinstance(mode, TrainingMode)`); only `mode=bm.training_mode` works. → default to `TrainingMode` when `mode is None`, or raise a clear message; broaden the affine assert. + +### Optim / losses / encoding +- **H-52 [verified]** `optim/optimizer.py:592-594` — `Adam` corrupts an `lr` passed as a `bm.Variable` via in-place `lr /= …; lr *= …` (mutates the shared Variable each step). → non-mutating arithmetic / `bm.as_jax(self.lr())`. +- **H-53 [static]** `losses/comparison.py:194-201` — `CrossEntropyLoss` stores `ignore_index`/`label_smoothing` but never forwards them; `cross_entropy_loss` has no such params → both are silent no-ops. → implement and forward. + +--- + +## 5. Medium findings (condensed) + +| ID | [status] | File:line | Issue → Fix | +|----|----------|-----------|-------------| +| M-01 | verified | `optim/scheduler.py` via `optimizer.py` | `StepLR`/cosine families share the `last_epoch`-never-advances issue feeding C-01; audit all schedulers' step source. | +| M-02 | static | `math/object_transform/jit.py` | `cls_jit` doesn't shift `donate_argnums` → donates `self`. → add param + `+1` shift. | +| M-03 | verified | `controls.py:466-481` | `scan` returns `(carry, ys)` but docstring promises only `ys` (legacy contract change). → fix docs or return `ys`. | +| M-04 | static | `controls.py:391-397` | `for_loop(jit=False)` toggles process-global `jax.disable_jit()` (no-op under an outer trace). → document / brainstate-native opt-out. | +| M-05 | static | `controls.py:207-237` | `ifelse` omits `check_cond=False` though it already guarantees exclusivity → per-call device all-reduce + error branch. → pass `check_cond=False`. | +| M-06 | verified | `controls.py:550-561` | `while_loop` body returning `None` freezes the carry (infinite-loop hazard). → raise on `None`/structure mismatch. | +| M-07 | verified | `math/environment.py:391-428` | `set()`/`set_environment()` mutate globals before validating `numpy_func_return` → partial-config leak on error. → validate first. | +| M-08 | verified | `math/remove_vmap.py:55-85` | under `vmap`, `remove_vmap(x,'any'/'all')` broadcasts the global reduction back over the batch (leaks across examples). → return a true scalar / document. | +| M-09 | static | `math/ndarray.py:259-271` | `ShardedArray.value` getter inserts `with_sharding_constraint` on *every* read (always-true on single-device). → skip `SingleDeviceSharding`. | +| M-10 | static | `math/sharding.py:119-162` | fully-unmatched axis names silently yield a replicated `PartitionSpec(None,…)` instead of erroring. → warn/raise on full mismatch. | +| M-11 | verified | `math/compat_numpy.py:144-160` | `empty`/`empty_like` call `zeros`/`zeros_like` (needless zero-fill, wrong semantics). → `jnp.empty*`. | +| M-12 | verified | `math/compat_numpy.py:129-133` | `fill_diagonal(inplace=False)` returns a raw jax array, not a brainpy `Array`. → `_return(r)`. | +| M-13 | static | `math/jitconn/matvec.py` (+`event_matvec.py`) | `seed=None` draws a host RNG per call → non-reproducible eager, jit-frozen seed. → require/thread an explicit seed; document. | +| M-14 | verified | `math/delayvars.py:215` | `TimeDelay.reset` drops `dtype=get_float()` on `current_time` and ignores callable `before_t0`. → mirror `__init__`. | +| M-15 | verified | `math/pre_syn_post.py:291-293` | `pre2post_mean` scalar branch scatter-sets (no averaging, ignores duplicate post ids). → route through `syn2post_mean` or document. | +| M-16 | static | `dyn/neurons/hh.py:148-194` | `CondNeuGroup.update` passes synaptic current through the `1e-3/A` external-input scaling (double-scales when `A≠1e-3`). → inject into the derivative like the LTC class. | +| M-17 | verified | `dyn/ions/potassium.py:45` | `PotassiumFixed` default `E=-950 mV` (likely typo for `-95`). → fix default (confirm vs intended). | +| M-18 | verified | `dyn/rates/populations.py:370-371` | `FeedbackFHN.reset_state` rebinds `self.input`/`input_y` to fresh Variables (breaks captured refs) instead of `.value=`. → set `.value`. | +| M-19 | verified | `dyn/rates/populations.py:374` | `FeedbackFHN` delay queries `x_delay(t-delay)` while `state_delays` already registers the delay → double-counts (buffer-edge clamp). → query `x_delay(t)`. | +| M-20 | static | `dyn/rates/populations.py:1051-1062` | `ThresholdLinearModel` noise scales as `dt` not `sqrt(dt)` (dt-dependent intensity). → Euler–Maruyama `sqrt(dt)`. | +| M-21 | verified | `dyn/rates/rnncells.py:127,239,375` | `RNN/GRU/LSTMCell.reset_state(None)` builds `(None,num_out)` → `ValueError`. → branch on `None` → `(num_out,)`. | +| M-22 | verified | `dyn/synapses/abstract_models.py:879-881,800-801` | `STP`/`STD` "simplified" updates assume binary `pre_spike`; graded inputs are wrong. → restore graded formula or assert binary. | +| M-23 | static | `dynold/synapses/abstract_models.py` | dual-exp/NMDA/AMPA peak silently renormalized (`g_max` semantics changed vs pre-3.0). → document or auto-scale for compat. | +| M-24 | verified | `dynold/neurons/reduced_models.py:1311` | `ALIFBellec2020` default `a_initializer=OneInit(-50.)` (adaptation var should start ~0). → `ZeroInit()`. | +| M-25 | verified | `dnn/normalization.py:156-158` | `BatchNorm` stores *biased* batch var into `running_var` (PyTorch uses unbiased for the running stat). → apply `N/(N-1)`. | +| M-26 | verified | `dnn/normalization.py:509` | `LayerNorm` shape-mismatch path does `", ".join(int_tuple)` → `TypeError` masking the real error. → `map(str, …)`. | +| M-27 | verified | `dnn/pooling.py:118,390,787` | negative `channel_axis == -x_dim` wrongly rejected (`abs()` bound check). → `-x_dim <= axis < x_dim`. | +| M-28 | verified | `dnn/function.py:91` | `Flatten` default `start_dim=0` contradicts its docstring/PyTorch (`1`) and drops the batch dim. → `start_dim=1` or fix docs. | +| M-29 | verified | `optim/optimizer.py:1039-1096` | `SM3` reads `self.momentum` in `register_train_vars` before it's set → un-instantiable (also torch-style `keepdim=`). → set attrs before `super().__init__`. | +| M-30 | verified | `connect/random_conn.py:99,87-89` | `FixedProb` sparse `build_coo/csr` use `int(post_num*prob)` (floors to 0 for small post; biased density) and forbid `include_self=False` on rectangular shapes with a contradictory message. → round/Bernoulli; drop the guard. | +| M-31 | static | `train/back_propagation.py:522-523` | BPTT `indices = arange(self.i0, …)` but `i0` isn't advanced/pinned → wrong absolute `t` when `reset_state=False`. → pin `arange(0,num_step)` or document. | +| M-32 | verified | `running/runner.py:99-101` | `Runner.__init__` mutates the caller's `jit` dict via `.pop()`. → operate on a copy. | +| M-33 | static | `analysis/stability.py:148-163` | 2D star vs degenerate-node classification is inverted (eigenvalues alone can't distinguish; needs eigenvector rank). → use `matrix_rank(J-λI)`. | +| M-34 | verified | `analysis/stability.py:111-141` | borderline types (center/saddle-node/line) gated on exact float `==0` of autodiff Jacobians → almost never detected. → tolerance bands. | +| M-35 | static | `analysis/highdim/slow_points.py:357-360` | GD fixed-point finder stops on *mean* loss but `tolerance` reads as per-point → outliers left unconverged. → stop on max, or document. | +| M-36 | verified | dyn synapses/`Variable.__float__` | `float(size-1 Variable)` raises under jax 0.10 (`ndim>0`) — breaks common single-neuron monitoring/doctests. → use `.item()`/index; consider squeezing `__float__`. | + +--- + +## 6. Low findings (condensed) + +- **L-01** `math/ndarray.py:31,44-76` — duplicate `'Array'` in `__all__`; dead helpers (`_check_input_array`, `_check_out`, `_get_dtype`, `_all_slice`); `_as_jax_array_` duplicated in `_utils.py`. → de-dup/remove. +- **L-02** `math/scales.py:79-89` — `IdScaling.clone(scale=…)`/`inv_scaling` silently ignore overrides. → raise or honor. +- **L-03** `math/ndarray.py:153-172` vs `:273-292` — base `Array.value` setter has shape/dtype checks commented out while `ShardedArray` enforces them (inconsistent; base allows silent shape change). → one policy. +- **L-04** `object_transform/function.py:44` — `function()` deprecation says "removed after 2.4.0" but ships in 2.7.8; `Partial` lacks a docstring. → update message; add docstring. +- **L-05** `object_transform/_utils.py:24-27` — `__all__` omits the only symbol consumers import (`warp_to_no_state_input_output`). → fix `__all__`. +- **L-06** `object_transform/base.py:192-219` — `tracing_variable` is `raise NotImplementedError` followed by ~25 lines of unreachable code + stale docstring; default-off pytree path is dead. → delete/clean. +- **L-07** `dyn/ions/calcium.py:144` — `CalciumDyna._reversal_potential(C)` ignores its `C` arg (uses `self.C`). → use `C`. +- **L-08** channels — several docstring constants/defaults disagree with code (`IAHP beta=0.09` vs doc `0.03`; `f_q_inf +58` vs `+59`; Ih `tau_m`/`phi`). → reconcile (incl. legacy dups). +- **L-09** `dyn/synapses/delay_couplings.py:131-241` — docstrings reference a `g`/gain param that doesn't exist; malformed `Parameters::`. → fix docs. +- **L-10** `losses/comparison.py:534` — functional `l1_loss` defaults `reduction='sum'` vs docstring/`L1Loss` `'mean'`. → `'mean'`. +- **L-11** `encoding/stateful_encoding.py:111-120` — `LatencyEncoder` docstring example output shape ignores `dt` (`(5,3)` vs real `(50,3)`). → fix example. +- **L-12** `integrators/sde/srk_strong.py:58,392` — dead module with a generated-code syntax error and wrong `compile_code` arg order. → remove or fix+register+test. +- **L-13** `integrators/joint_eq.py:189` — `JointEq` raises a bare message-less `DiffEqError`. → add diagnostic. +- **L-14** Pervasive NumPy-doc nonconformance: `Parameters::`/`Returns::`/`References::` literal-block markers across `math/sparse`, `math/jitconn`, `dyn/rates`, `dyn/synapses`, `measure`, etc. → convert to underlined sections (mandated by CLAUDE.md). +- **L-15** `analysis/utils/others.py:99` — `get_sign2` passes a generator as `reshape` shape (latent, function unused). → `tuple(...)` / remove dead helpers. +- **L-16** `dynold/experimental/__init__.py` empty → whole experimental subpackage unreachable; `dynold/synapses/base.py:233` & `experimental/base.py:98` missing `raise` before `ValueError`. → wire up or delete; add `raise`. + +--- + +## 7. Prioritized remediation roadmap + +**P0 — silent numerical corruption in the most-used paths (fix first):** +C-01 (Adam), C-02 (nll sign), C-03 (CE weight), C-04 (MultiStepLR), C-05 (GroupNorm), C-08 (Caputo), C-09 (TimeDelay), C-10 (disable_x64), C-16 (StuartLandau), C-17 (PoissonInput), C-23 (RLS B>1), H-43 (firing_rate), H-26 (BoSh3), H-20…H-24 (surrogate math). These give wrong answers without erroring. + +**P1 — crash-on-first-use of public APIs:** +C-06, C-07, C-11–C-15, C-18–C-22, C-24–C-25, H-01–H-06, H-10–H-19, H-29–H-36, H-41, H-44–H-46, H-51, M-21, M-29. Many are one-line migration fixes; bundle with a public-surface import/smoke test. + +**P2 — correctness traps & fragility (Medium tier).** **P3 — docs/typing/dead-code hygiene (Low tier), incl. the repo-wide NumPy-doc `::` cleanup.** + +**Systemic actions (do alongside P0/P1):** +1. **Pin & test ecosystem versions.** Bump `pyproject` lower bounds to the actually-tested `brainstate`/`brainevent`/`braintools`/`jax`, and add CI matrix entries. +2. **Public-surface smoke test:** instantiate + one-step every public neuron/synapse/projection/layer/optimizer/encoder under the *default* mode; many P1 crashes would be caught immediately. +3. **Property-based numerical tests** for surrogates (`grad(surrogate_fun) ≈ surrogate_grad`), integrators (convergence order, SDE moments), losses (vs reference), and delays (off-by-one) — the bug classes here recur and need oracles. +4. **Resolve the `bm.surrogate` shadowing** (H-25): decide whether `braintools` or the in-repo package is canonical and delete the other. + +--- + +## Appendix A — Verification status + +- **Independently re-verified by the lead reviewer (executed, reproduced):** C-01..C-13, C-15..C-17, C-20, C-24; H-02..H-08, H-10..H-19, H-20..H-24, H-27..H-39, H-41, H-43..H-46, H-51..H-52; M-06..M-08, M-11..M-12, M-14..M-15, M-17..M-19, M-21..M-22, M-26..M-30, M-32, M-36 — plus the isolated x64/precision and FDE checks. 33 of the highest-impact items were run head-to-head; **all reproduced**. (Two initial "non-reproductions" were lead-reviewer test-harness errors — wrong threshold / arg-order — with the underlying bug confirmed on code inspection.) +- **Verified by the module sub-audits (executed in the same environment):** C-14, C-18..C-19, C-21..C-23, C-25..C-26, H-01, H-09, H-26, H-40, H-47..H-50, H-53 and the remaining Medium/Low items. + +## Appendix B — Checked and found CORRECT (audit negatives) + +To bound the audit, the following were specifically checked and are **not** bugs: all explicit-RK Butcher tableaus (Euler/RK2/RK3/RK4/Ralston/SSPRK convergence orders) and RKF45/CashKarp/Dormand–Prince/BogackiShampine embedded pairs; exponential-Euler exactness on linear ODEs; SDE Euler–Maruyama/Milstein/SRK first/second moments and `sqrt(dt)` scaling; per-step PRNG independence inside the jitted `for_loop`; `CaputoL1Schema`/`GLShortMemory` core numerics (machine precision incl. ring-buffer wrap); Expon/Alpha/DualExpon/NMDA/AMPA/STD synapse kinetics and the STDP sign convention; `gelu/elu/selu/softmax/log_softmax` math; einops parsing edge cases; GRU cell math, NVAR feature construction, MgBlock curve, OU `sqrt(dt)`, reservoir spectral-radius rescaling; Kaiming/Xavier/Lecun/Orthogonal init statistics; `FixedProb.build_mat` density, `GaussianProb` symmetry; `Dense`/`Conv`/`ConvTranspose`/`AvgPool` shapes & layouts, `Dropout` scaling & eval no-op, BatchNorm momentum direction & eval-uses-running-stats; the high-dim `SlowPointFinder` Jacobian recovery; calcium Nernst constant; the `*_compatible.py` channel shims (numerically identical to their v2 sources). + +--- + +*Generated 2026-06-18. Working branch: `worktree-audit-issues-20260618`. Spec & verification scripts under `dev/superpowers/` (gitignored).* diff --git a/tests/audit/test_boost_analysis.py b/tests/audit/test_boost_analysis.py new file mode 100644 index 000000000..c3030b240 --- /dev/null +++ b/tests/audit/test_boost_analysis.py @@ -0,0 +1,964 @@ +# -*- coding: utf-8 -*- +"""Audit coverage-boost tests for BrainPy v2.7.8 low-dimensional analysis. + +This module is part of the test-coverage audit. It exercises: + +- ``brainpy/analysis/lowdim/lowdim_analyzer.py`` (baseline ~39% line coverage) + via the public ``PhasePlane1D`` / ``PhasePlane2D`` / ``Bifurcation1D`` / + ``Bifurcation2D`` / ``FastSlow1D`` / ``FastSlow2D`` analyzers, including the + numerical nullcline / fixed-point / Jacobian machinery in ``Num1DAnalyzer`` + and ``Num2DAnalyzer`` (both the optimization branch and the + "convert-to-one-equation" brentq branch). +- ``brainpy/analysis/utils/optimization.py`` (baseline ~74% line coverage) + via direct calls to the root-finding helpers (``jax_brentq``, + ``get_brentq_candidates``, ``brentq_candidates``, ``brentq_roots``, + ``brentq_roots2``, ``roots_of_1d_by_x``, ``roots_of_1d_by_xy``, + ``numpy_brentq``, ``find_root_of_1d_numpy``) and ``scipy_minimize_with_jax``. + +All tests run on tiny models with coarse resolutions, short durations and small +parameter ranges so the whole module stays well under the runtime budget. +The tests do NOT modify any source; behaviour quirks observed on valid input +(e.g. duplicate 1D fixed points) are pinned with explicit assertions/notes. +""" + +import matplotlib + +matplotlib.use('Agg') # headless backend; must precede pyplot / analysis imports + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from jax import vmap +from matplotlib import pyplot as plt + +import brainpy as bp +import brainpy.math as bm +from brainpy.analysis import constants as C +from brainpy.analysis.lowdim.lowdim_analyzer import Num3DAnalyzer +from brainpy.analysis.utils import optimization as opt + + +# --------------------------------------------------------------------------- +# Module-level setup: brentq optimization in BrainPy requires x64. +# --------------------------------------------------------------------------- + +def setup_module(module): + bm.enable_x64() + + +def teardown_module(module): + plt.close('all') + bm.disable_x64() + + +def _close(): + plt.close('all') + + +# =========================================================================== +# optimization.py -- direct helper calls on tiny functions +# =========================================================================== + +def test_jax_brentq_simple_root(): + """jax_brentq finds a bracketed root and reports convergence.""" + solver = opt.jax_brentq(lambda x: x - 2.0) + res = solver(0.0, 5.0) + assert res['status'] == opt.ECONVERGED + assert abs(float(res['root']) - 2.0) < 1e-6 + assert int(res['funcalls']) >= 2 + + +def test_jax_brentq_root_at_endpoint(): + """When an endpoint is already a root, brentq returns it immediately.""" + solver = opt.jax_brentq(lambda x: x) + # a == 0 is exactly a root -> early ECONVERGED path + res = solver(0.0, 3.0) + assert res['status'] == opt.ECONVERGED + assert abs(float(res['root'])) < 1e-12 + + +def test_jax_brentq_with_args(): + """jax_brentq passes through extra args to the objective.""" + solver = opt.jax_brentq(lambda x, p: x - p) + res = solver(0.0, 10.0, (4.0,)) + assert res['status'] == opt.ECONVERGED + assert abs(float(res['root']) - 4.0) < 1e-6 + + +def test_roots_of_1d_by_x_quadratic(): + """roots_of_1d_by_x recovers both roots of x**2 - 1 on [-2, 2].""" + cands = jnp.linspace(-2.0, 2.0, 41) + roots = opt.roots_of_1d_by_x(lambda x: x ** 2 - 1.0, cands) + vals = sorted(round(float(r), 3) for r in roots) + assert vals == [-1.0, 1.0] + + +def test_roots_of_1d_by_x_exact_zero_candidate(): + """Candidates that sit exactly on a root take the zero-sign branch.""" + # candidate grid includes the exact roots 0 and +/-1 of x**3 - x + cands = jnp.linspace(-1.5, 1.5, 7) # ... -1, -0.5, 0, 0.5, 1 ... + roots = opt.roots_of_1d_by_x(lambda x: x ** 3 - x, cands) + vals = sorted(round(float(r), 3) for r in np.unique(np.asarray(roots))) + assert -1.0 in vals and 0.0 in vals and 1.0 in vals + + +def test_roots_of_1d_by_x_no_roots(): + """A strictly positive function yields no sign changes -> empty result.""" + cands = jnp.linspace(-2.0, 2.0, 21) + roots = opt.roots_of_1d_by_x(lambda x: x ** 2 + 1.0, cands) + assert len(np.asarray(roots)) == 0 + + +def test_get_brentq_candidates_and_roots_of_1d_by_xy(): + """get_brentq_candidates + roots_of_1d_by_xy round-trip on f(x, y) = x - y.""" + xs = jnp.linspace(-2.0, 2.0, 21) + ys = jnp.linspace(-1.0, 1.0, 11) + starts, ends, args = opt.get_brentq_candidates(lambda x, y: x - y, xs, ys) + assert starts.shape == ends.shape == args.shape + assert starts.shape[0] > 0 + xs_root, ys_root = opt.roots_of_1d_by_xy(lambda x, y: x - y, starts, ends, args) + # For f = x - y, the root x equals the parameter y. + assert xs_root.shape == ys_root.shape + np.testing.assert_allclose(np.asarray(xs_root), np.asarray(ys_root), atol=1e-6) + + +def test_brentq_candidates_and_roots(): + """brentq_candidates / brentq_roots / brentq_roots2 agree for f(x, p) = x - p.""" + f = lambda x, p: x - p + vmap_f = jax.jit(vmap(f)) + xs = jnp.linspace(-2.0, 2.0, 21) + ps = jnp.linspace(-1.0, 1.0, 11) + starts, ends, others = opt.brentq_candidates(vmap_f, xs, ps) + assert starts.shape[0] > 0 + assert len(others) == 1 + + roots, vargs = opt.brentq_roots(f, starts, ends, others[0]) + assert roots.shape[0] > 0 + np.testing.assert_allclose(np.asarray(roots), np.asarray(vargs[0]), atol=1e-6) + + vmap_brentq = jax.jit(vmap(opt.jax_brentq(f))) + roots2, vargs2 = opt.brentq_roots2(vmap_brentq, starts, ends, others[0]) + np.testing.assert_allclose(np.asarray(roots2), np.asarray(vargs2[0]), atol=1e-6) + + +def test_brentq_roots_no_extra_args_is_broken(): + """NOTE (source bug): brentq_roots with no vmap_args/args is broken. + + When called with neither ``vmap_args`` nor ``args``, the function builds an + ``in_axes`` tuple of length 3 ``(0, 0, ())`` but then invokes the vmapped + optimizer with only two positional arguments ``(starts, ends)``. jax.vmap + rejects the mismatch (len(in_axes)=3 vs len(args)=2). This is a latent bug + on the ``else`` branch of ``brentq_roots`` -- the codebase always reaches it + via ``brentq_roots2`` instead. Pinned here so any future fix is noticed. + """ + f = lambda x: x - 1.0 + starts = jnp.array([0.0, -5.0]) + ends = jnp.array([2.0, 5.0]) + with pytest.raises(ValueError): + opt.brentq_roots(f, starts, ends) + + +def test_brentq_roots2_no_extra_args(): + """brentq_roots2 (the actually-used variant) handles the no-args case.""" + f = lambda x: x - 1.0 + vmap_brentq = jax.jit(vmap(opt.jax_brentq(f))) + starts = jnp.array([0.0, -5.0]) + ends = jnp.array([2.0, 5.0]) + roots, vargs = opt.brentq_roots2(vmap_brentq, starts, ends) + assert vargs == () + np.testing.assert_allclose(np.sort(np.asarray(roots)), [1.0, 1.0], atol=1e-6) + + +def test_scipy_minimize_with_jax(): + """scipy_minimize_with_jax minimizes a quadratic and unflattens the result.""" + res = opt.scipy_minimize_with_jax( + lambda x: ((x - 3.0) ** 2).sum(), + jnp.array([0.0]), + method='BFGS', + ) + assert bool(res['success']) + assert abs(float(np.asarray(res['x'])[0]) - 3.0) < 1e-4 + + +def test_scipy_minimize_with_jax_callback_and_bounds(): + """Exercise the callback + bounds branches of scipy_minimize_with_jax.""" + seen = [] + + def cb(xk): + seen.append(np.asarray(jax.tree_util.tree_leaves(xk)[0])) + return False # do not terminate early + + res = opt.scipy_minimize_with_jax( + lambda x: ((x - 1.0) ** 2).sum(), + jnp.array([5.0]), + method='L-BFGS-B', + bounds=[(0.0, 10.0)], + callback=cb, + ) + assert bool(res['success']) + assert abs(float(np.asarray(res['x'])[0]) - 1.0) < 1e-3 + assert len(seen) >= 1 + + +def test_numpy_brentq_root_and_errors(): + """numpy_brentq finds a root and raises on bad bracket / params.""" + root, funcalls, itr = opt.numpy_brentq(lambda x: x ** 2 - 4.0, 0.0, 5.0) + assert abs(root - 2.0) < 1e-6 + assert funcalls >= 2 + + # endpoint is a root -> early-return branch (status ECONVERGED, itr == 0) + root2, _, itr2 = opt.numpy_brentq(lambda x: x, 0.0, 5.0) + assert abs(root2) < 1e-12 + assert itr2 == 0 + + with pytest.raises(ValueError): + opt.numpy_brentq(lambda x: x + 1.0, 0.0, 5.0) # f(a), f(b) same sign + with pytest.raises(ValueError): + opt.numpy_brentq(lambda x: x, 0.0, 5.0, xtol=-1.0) # bad xtol + with pytest.raises(ValueError): + opt.numpy_brentq(lambda x: x, 0.0, 5.0, maxiter=0) # bad maxiter + + +def test_numpy_brentq_endpoint_b_is_root(): + """numpy_brentq returns immediately when the upper endpoint is the root.""" + root, _, itr = opt.numpy_brentq(lambda x: x - 5.0, 0.0, 5.0) + assert abs(root - 5.0) < 1e-12 + assert itr == 0 # early ECONVERGED via fcur == 0 + + +def test_find_root_of_1d_numpy(): + """find_root_of_1d_numpy recovers the roots of x**3 - x including exact ones.""" + pts = np.linspace(-3.0, 3.0, 61) + roots = opt.find_root_of_1d_numpy(lambda x: x ** 3 - x, pts) + vals = sorted(round(float(r), 3) for r in np.unique(np.asarray(roots))) + assert -1.0 in vals and 0.0 in vals and 1.0 in vals + + +def test_find_root_of_1d_numpy_leading_and_trailing_zeros(): + """find_root_of_1d_numpy handles exact roots at the first/last grid points.""" + # f = x*(x-2): roots 0 (first point, leading-zero branch) and 2 (last point) + pts = np.linspace(0.0, 2.0, 11) + roots = opt.find_root_of_1d_numpy(lambda x: x * (x - 2.0), pts) + vals = sorted(round(float(r), 3) for r in np.unique(np.asarray(roots))) + assert 0.0 in vals and 2.0 in vals + + +# =========================================================================== +# PhasePlane1D -- lowdim_analyzer Num1DAnalyzer paths +# =========================================================================== + +def test_phase_plane_1d_linear(): + """Linear 1D system dx = -x + I has a single stable fixed point at x = I.""" + Iext = 0.5 + + @bp.odeint + def int_x(x, t, I=0.5): + return -x + I + + pp = bp.analysis.PhasePlane1D( + model=int_x, + target_vars={'x': [-2.0, 2.0]}, + pars_update={'I': Iext}, + resolutions=0.05, + ) + # vector field + yv = pp.plot_vector_field(with_return=True, show=False) + assert yv.shape[0] > 0 + + # fixed points + fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False)) + assert len(fps) >= 1 + # every returned fixed point should sit at x == I + np.testing.assert_allclose(fps, Iext, atol=1e-4) + _close() + + +def test_phase_plane_1d_cubic(): + """Cubic 1D system dx = x - x**3 has three fixed points: -1, 0, 1.""" + + @bp.odeint + def int_x(x, t): + return x - x ** 3 + + pp = bp.analysis.PhasePlane1D( + model=int_x, + target_vars={'x': [-2.0, 2.0]}, + resolutions=0.02, + ) + pp.plot_vector_field(show=False) + fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False)) + uniq = sorted(round(float(v), 2) for v in np.unique(np.round(fps, 2))) + for expected in (-1.0, 0.0, 1.0): + assert any(abs(u - expected) < 1e-2 for u in uniq), f'missing fp {expected}: {uniq}' + _close() + + +def test_phase_plane_1d_no_plot_paths(): + """with_plot=False / with_return=False branches return None without drawing.""" + + @bp.odeint + def int_x(x, t): + return -x + + pp = bp.analysis.PhasePlane1D( + model=int_x, + target_vars={'x': [-1.0, 1.0]}, + resolutions=0.1, + ) + assert pp.plot_vector_field(with_plot=False, with_return=False, show=False) is None + assert pp.plot_fixed_point(with_plot=False, with_return=False, show=False) is None + _close() + + +def test_phase_plane_1d_rejects_target_pars(): + """PhasePlane analyzers reject target_pars (only PP via pars_update allowed).""" + @bp.odeint + def int_x(x, t, a=1.): + return -a * x + + with pytest.raises(Exception): + bp.analysis.PhasePlane1D( + model=int_x, + target_vars={'x': [-1.0, 1.0]}, + target_pars={'a': [0.5, 1.5]}, + resolutions=0.1, + ) + + +# =========================================================================== +# PhasePlane2D -- lowdim_analyzer Num2DAnalyzer paths (FitzHugh-Nagumo) +# =========================================================================== + +_FHN_A, _FHN_B, _FHN_TAU = 0.7, 0.8, 12.5 + + +def _fhn_integrals(I=0.5): + @bp.odeint + def int_V(V, t, w, I=I): + return V - V ** 3 / 3.0 - w + I + + @bp.odeint + def int_w(w, t, V): + return (V + _FHN_A - _FHN_B * w) / _FHN_TAU + + return int_V, int_w + + +def test_phase_plane_2d_full_optimization_branch(): + """Full PP2D pipeline (vector field, nullclines, fixed point, trajectory).""" + int_V, int_w = _fhn_integrals(I=0.5) + pp = bp.analysis.PhasePlane2D( + model=[int_V, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + resolutions=0.1, + ) + + # streamplot + quiver vector field branches + dx, dy = pp.plot_vector_field(with_return=True, show=False) + assert dx.shape == dy.shape and dx.ndim == 2 + pp.plot_vector_field(plot_method='quiver', show=False) + + # nullclines (optimization branch, x-y coords) + nc = pp.plot_nullcline(with_return=True, show=False) + assert set(nc.keys()) == {'V', 'w'} + + # fixed points from the fx-nullcline candidates (default) + fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False)) + assert fps.shape[1] == 2 + # FHN with I=0.5 has a single (unstable) fixed point in this window + assert len(fps) >= 1 + + # trajectory in both axis modes + traj = pp.plot_trajectory( + initials={'V': [-1.0], 'w': [0.0]}, + duration=10.0, show=False, with_return=True, + ) + assert traj is not None + pp.plot_trajectory( + initials={'V': [-1.0], 'w': [0.0]}, + duration=5.0, axes='t-v', show=False, + ) + _close() + + +def test_phase_plane_2d_nullcline_alternate_coords(): + """coords='w-V' exercises the y_var-x_var branch of the nullcline solver.""" + int_V, int_w = _fhn_integrals(I=0.5) + pp = bp.analysis.PhasePlane2D( + model=[int_V, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + resolutions=0.1, + ) + nc = pp.plot_nullcline( + with_return=True, show=False, + coords={'V': 'V-w', 'w': 'w-V'}, + ) + assert set(nc.keys()) == {'V', 'w'} + _close() + + +def test_phase_plane_2d_invalid_plot_method_and_axes(): + """Unknown plot_method / axes raise analyzer errors.""" + int_V, int_w = _fhn_integrals() + pp = bp.analysis.PhasePlane2D( + model=[int_V, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + resolutions=0.2, + ) + with pytest.raises(Exception): + pp.plot_vector_field(plot_method='nope', show=False) + with pytest.raises(Exception): + pp.plot_trajectory(initials={'V': [0.0], 'w': [0.0]}, + duration=1.0, axes='bad', show=False) + _close() + + +def test_phase_plane_2d_limit_cycle_by_sim(): + """plot_limit_cycle_by_sim runs (no cycle expected for short sim, but exercised).""" + int_V, int_w = _fhn_integrals(I=0.5) + pp = bp.analysis.PhasePlane2D( + model=[int_V, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + resolutions=0.2, + ) + # short duration -> exercises the "no limit cycle found" branch safely + pp.plot_limit_cycle_by_sim( + initials={'V': [-1.0], 'w': [0.0]}, + duration=20.0, show=False, + ) + _close() + + +def test_phase_plane_2d_aux_rank_candidates(): + """select_candidates='aux_rank' exercises _get_fp_candidates_by_aux_rank.""" + int_V, int_w = _fhn_integrals(I=0.5) + pp = bp.analysis.PhasePlane2D( + model=[int_V, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + resolutions=0.1, + ) + fps = np.asarray(pp.plot_fixed_point( + with_return=True, select_candidates='aux_rank', num_rank=50, show=False)) + assert fps.shape[1] == 2 + assert len(fps) >= 1 + _close() + + +def test_phase_plane_2d_fixed_point_requires_nullcline(): + """fx/fy-nullcline candidate selection errors if nullclines not yet computed.""" + int_V, int_w = _fhn_integrals(I=0.5) + pp = bp.analysis.PhasePlane2D( + model=[int_V, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + resolutions=0.2, + ) + with pytest.raises(Exception): + pp.plot_fixed_point(select_candidates='fy-nullcline', show=False) + _close() + + +def test_phase_plane_2d_convert_to_one_equation(): + """Providing y_by_x_in_fy enables the brentq 'convert to one equation' branch.""" + int_V, int_w = _fhn_integrals(I=0.5) + + # w-nullcline of FHN: (V + a - b*w)/tau = 0 -> w = (V + a) / b + def w_by_V(V): + return (V + _FHN_A) / _FHN_B + + pp = bp.analysis.PhasePlane2D( + model=[int_V, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + resolutions=0.1, + options={C.y_by_x_in_fy: w_by_V}, + ) + assert pp._can_convert_to_one_eq() + assert pp.convert_type() == C.y_by_x + # nullcline now uses the analytic F_y_by_x_in_fy branch + nc = pp.plot_nullcline(with_return=True, show=False) + assert set(nc.keys()) == {'V', 'w'} + # fixed point via brentq optimization branch + fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False)) + assert fps.shape[1] == 2 and len(fps) >= 1 + _close() + + +def test_phase_plane_2d_convert_x_by_y_in_fy(): + """x_by_y_in_fy option drives the analytic fy-nullcline + x_by_y convert type.""" + int_V, int_w = _fhn_integrals(I=0.5) + + # w-nullcline: (V + a - b*w)/tau = 0 -> V = b*w - a + def V_by_w(w): + return _FHN_B * w - _FHN_A + + pp = bp.analysis.PhasePlane2D( + model=[int_V, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + resolutions=0.1, + options={C.x_by_y_in_fy: V_by_w}, + ) + assert pp.convert_type() == C.x_by_y + nc = pp.plot_nullcline(with_return=True, show=False) + assert set(nc.keys()) == {'V', 'w'} + fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False)) + assert fps.shape[1] == 2 and len(fps) >= 1 + _close() + + +def test_phase_plane_2d_convert_y_by_x_in_fx(): + """y_by_x_in_fx option drives the analytic fx-nullcline + y_by_x convert type.""" + int_V, int_w = _fhn_integrals(I=0.5) + + # V-nullcline: V - V**3/3 - w + I = 0 -> w = V - V**3/3 + I + def w_by_V(V): + return V - V ** 3 / 3.0 + 0.5 + + pp = bp.analysis.PhasePlane2D( + model=[int_V, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + resolutions=0.1, + options={C.y_by_x_in_fx: w_by_V}, + ) + assert pp.convert_type() == C.y_by_x + nc = pp.plot_nullcline(with_return=True, show=False) + assert set(nc.keys()) == {'V', 'w'} + fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False)) + assert fps.shape[1] == 2 and len(fps) >= 1 + _close() + + +def test_phase_plane_2d_convert_x_by_y_in_fx(): + """x_by_y_in_fx option (linear system) drives the analytic fx-nullcline branch.""" + + @bp.odeint + def int_x(x, t, y): + return -x + 0.5 * y + + @bp.odeint + def int_y(y, t, x): + return x - 2.0 * y + + # fx-nullcline: -x + 0.5*y = 0 -> x = 0.5*y + def x_by_y(y): + return 0.5 * y + + pp = bp.analysis.PhasePlane2D( + model=[int_x, int_y], + target_vars={'x': [-2.0, 2.0], 'y': [-2.0, 2.0]}, + resolutions=0.1, + options={C.x_by_y_in_fx: x_by_y}, + ) + assert pp.convert_type() == C.x_by_y + nc = pp.plot_nullcline(with_return=True, show=False) + assert set(nc.keys()) == {'x', 'y'} + fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False)) + # the only fixed point of this linear system is the origin + assert fps.shape[1] == 2 and len(fps) >= 1 + np.testing.assert_allclose(fps[0], [0.0, 0.0], atol=1e-4) + _close() + + +def test_num2d_jacobian_and_derivatives(): + """Directly exercise the F_jacobian / derivative properties of Num2DAnalyzer.""" + int_V, int_w = _fhn_integrals(I=0.5) + pp = bp.analysis.PhasePlane2D( + model=[int_V, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + resolutions=0.2, + ) + J = np.asarray(pp.F_jacobian(0.0, 0.0)) + assert J.shape == (2, 2) + # partials at the origin: dfx/dw = -1 ; dfy/dV = 1/tau ; dfy/dw = -b/tau + assert abs(float(pp.F_dfxdy(0.0, 0.0)) + 1.0) < 1e-6 + assert abs(float(pp.F_dfydx(0.0, 0.0)) - 1.0 / _FHN_TAU) < 1e-6 + assert abs(float(pp.F_dfydy(0.0, 0.0)) + _FHN_B / _FHN_TAU) < 1e-6 + _close() + + +# =========================================================================== +# Bifurcation1D / Bifurcation2D +# =========================================================================== + +def test_bifurcation_1d_codim1(): + """Bifurcation1D co-dimension-1: dx = -x + I, fixed point tracks I.""" + + @bp.odeint + def int_x(x, t, I=0.0): + return -x + I + + bf = bp.analysis.Bifurcation1D( + model=int_x, + target_vars={'x': [-2.0, 2.0]}, + target_pars={'I': [-1.0, 1.0]}, + resolutions={'I': 0.1}, + ) + fps, pars, dfdx = bf.plot_bifurcation(with_return=True, show=False) + fps = np.asarray(fps) + p = np.asarray(pars[0]) + assert fps.shape[0] > 0 and fps.shape == p.shape + # for dx = -x + I, fixed point x == I and df/dx == -1 (stable) + np.testing.assert_allclose(fps, p, atol=1e-4) + assert np.all(np.asarray(dfdx) < 0) + _close() + + +def test_bifurcation_1d_codim2(): + """Bifurcation1D co-dimension-2 (3D scatter): dx = -a*x + b.""" + + @bp.odeint + def int_x(x, t, a=1.0, b=0.0): + return -a * x + b + + bf = bp.analysis.Bifurcation1D( + model=int_x, + target_vars={'x': [-2.0, 2.0]}, + target_pars={'a': [0.5, 1.5], 'b': [-1.0, 1.0]}, + resolutions={'a': 0.3, 'b': 0.3}, + ) + fps, pars, dfdx = bf.plot_bifurcation(with_return=True, show=False) + assert np.asarray(fps).shape[0] > 0 + assert len(pars) == 2 + _close() + + +def test_bifurcation_1d_float_resolution_warns(): + """A single float resolution with target_pars warns and uses jnp.arange grids.""" + + @bp.odeint + def int_x(x, t, I=0.0): + return -x + I + + with pytest.warns(UserWarning): + bf = bp.analysis.Bifurcation1D( + model=int_x, + target_vars={'x': [-2.0, 2.0]}, + target_pars={'I': [-1.0, 1.0]}, + resolutions=0.2, + ) + fps, pars, dfdx = bf.plot_bifurcation(with_return=True, show=False) + assert np.asarray(fps).shape[0] > 0 + _close() + + +def test_bifurcation_2d_segmented(): + """num_par_segments>1 and num_fp_segment>1 exercise the segment-loop branches.""" + int_V, int_w = _fhn_integrals() + + @bp.odeint + def int_V2(V, t, w, Iext=0.0): + return V - V ** 3 / 3.0 - w + Iext + + bif = bp.analysis.Bifurcation2D( + model=[int_V2, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + target_pars={'Iext': [0.0, 1.0]}, + resolutions={'Iext': 0.3}, + ) + fps, pars, jac = bif.plot_bifurcation( + with_return=True, show=False, + select_candidates='fx-nullcline', + num_par_segments=2, num_fp_segment=2, nullcline_aux_filter=1.0, + ) + assert np.asarray(fps).shape[1] == 2 + _close() + + +def test_bifurcation_2d_codim1_and_limit_cycle(): + """Bifurcation2D co-dimension-1 on FHN over Iext, plus limit-cycle sim.""" + int_V, int_w = _fhn_integrals() + + @bp.odeint + def int_V2(V, t, w, Iext=0.0): + return V - V ** 3 / 3.0 - w + Iext + + bif = bp.analysis.Bifurcation2D( + model=[int_V2, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + target_pars={'Iext': [0.0, 1.0]}, + resolutions={'Iext': 0.2}, + ) + fps, pars, jac = bif.plot_bifurcation(with_return=True, show=False) + fps = np.asarray(fps) + assert fps.shape[1] == 2 + assert np.asarray(jac).shape[1:] == (2, 2) + assert np.asarray(pars).shape[1] == 1 + + # limit cycle by simulation off the recorded fixed points + lc = bif.plot_limit_cycle_by_sim(duration=20.0, with_return=True, show=False) + assert lc is not None + _close() + + +def test_bifurcation_2d_nullcline_candidate_selection(): + """Bifurcation2D with fx/fy/nullclines candidate selection + aux filtering. + + This drives the analytic-free nullcline solvers plus the ``_fp_filter`` + auxiliary-loss filtering branch (nullcline_aux_filter > 0). + """ + int_V, int_w = _fhn_integrals() + + @bp.odeint + def int_V2(V, t, w, Iext=0.0): + return V - V ** 3 / 3.0 - w + Iext + + for select in ('fx-nullcline', 'fy-nullcline', 'nullclines'): + bif = bp.analysis.Bifurcation2D( + model=[int_V2, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + target_pars={'Iext': [0.0, 1.0]}, + resolutions={'Iext': 0.3}, + ) + fps, pars, jac = bif.plot_bifurcation( + with_return=True, show=False, + select_candidates=select, nullcline_aux_filter=1.0, + ) + assert np.asarray(fps).shape[1] == 2 + assert np.asarray(fps).shape[0] >= 1 + _close() + + +def test_bifurcation_2d_limit_cycle_without_bifurcation(): + """plot_limit_cycle_by_sim returns early when no fixed points recorded yet.""" + int_V, int_w = _fhn_integrals() + + @bp.odeint + def int_V2(V, t, w, Iext=0.0): + return V - V ** 3 / 3.0 - w + Iext + + bif = bp.analysis.Bifurcation2D( + model=[int_V2, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + target_pars={'Iext': [0.0, 1.0]}, + resolutions={'Iext': 0.3}, + ) + # _fixed_points is None -> early return (None) + assert bif.plot_limit_cycle_by_sim(duration=5.0, show=False) is None + _close() + + +# =========================================================================== +# FastSlow1D / FastSlow2D +# =========================================================================== + +def test_fast_slow_1d(): + """FastSlow1D: fast var x, slow var u as the bifurcation parameter.""" + + @bp.odeint + def int_x(x, t, u=0.0): + return -x + u + + @bp.odeint + def int_u(u, t, x): + return 0.01 * (x - u) + + fs = bp.analysis.FastSlow1D( + model=[int_x, int_u], + fast_vars={'x': [-2.0, 2.0]}, + slow_vars={'u': [-1.0, 1.0]}, + resolutions={'u': 0.1}, + ) + fps, pars, dfdx = fs.plot_bifurcation(with_return=True, show=False) + fps = np.asarray(fps) + assert fps.shape[0] > 0 + # x* == u and stable (df/dx == -1) + np.testing.assert_allclose(fps, np.asarray(pars[0]), atol=1e-4) + + traj = fs.plot_trajectory( + initials={'x': [0.5], 'u': [0.5]}, + duration=10.0, show=False, with_return=True, + ) + assert traj is not None + _close() + + +def test_fast_slow_2d(): + """FastSlow2D: 2 fast vars (V, w), 1 slow var u acting like an input current.""" + + @bp.odeint + def int_V(V, t, w, u=0.0): + return V - V ** 3 / 3.0 - w + u + + @bp.odeint + def int_w(w, t, V): + return (V + _FHN_A - _FHN_B * w) / _FHN_TAU + + @bp.odeint + def int_u(u, t, V): + return 0.01 * (V - u) + + fs = bp.analysis.FastSlow2D( + model=[int_V, int_w, int_u], + fast_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + slow_vars={'u': [0.0, 1.0]}, + resolutions={'u': 0.2}, + ) + fps, pars, jac = fs.plot_bifurcation(with_return=True, show=False) + assert np.asarray(fps).shape[1] == 2 + + traj = fs.plot_trajectory( + initials={'V': [-1.0], 'w': [0.0], 'u': [0.5]}, + duration=10.0, show=False, with_return=True, + ) + assert traj is not None + _close() + + +# =========================================================================== +# LowDimAnalyzer construction / validation paths +# =========================================================================== + +def test_analyzer_resolution_dict_and_array(): + """resolutions can be a dict mixing floats and explicit 1D arrays.""" + + @bp.odeint + def int_x(x, t, I=0.0): + return -x + I + + bf = bp.analysis.Bifurcation1D( + model=int_x, + target_vars={'x': [-2.0, 2.0]}, + target_pars={'I': [-1.0, 1.0]}, + resolutions={'x': np.linspace(-2.0, 2.0, 25), 'I': 0.2}, + ) + assert bf.resolutions['x'].shape[0] == 25 + assert bf.resolutions['I'].shape[0] > 0 + _close() + + +def test_analyzer_resolution_none_default(): + """resolutions=None builds the default 20-point linspace for vars and pars.""" + + @bp.odeint + def int_x(x, t, I=0.0): + return -x + I + + bf = bp.analysis.Bifurcation1D( + model=int_x, + target_vars={'x': [-2.0, 2.0]}, + target_pars={'I': [-1.0, 1.0]}, + resolutions=None, + ) + assert bf.resolutions['x'].shape[0] == 20 + assert bf.resolutions['I'].shape[0] == 20 + _close() + + +def test_analyzer_validation_errors(): + """Constructor validates target_vars / fixed_vars / reversed ranges.""" + + @bp.odeint + def int_x(x, t, I=0.0): + return -x + I + + # target_vars must be a dict + with pytest.raises(Exception): + bp.analysis.PhasePlane1D(model=int_x, target_vars=['x'], resolutions=0.1) + + # reversed variable range + with pytest.raises(Exception): + bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [2.0, -2.0]}, resolutions=0.1) + + # unknown target variable + with pytest.raises(Exception): + bp.analysis.PhasePlane1D(model=int_x, target_vars={'z': [-1.0, 1.0]}, resolutions=0.1) + + # unknown resolution target key + with pytest.raises(Exception): + bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]}, + resolutions={'zzz': 0.1}) + + # fixed_vars must be a dict + with pytest.raises(Exception): + bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]}, + fixed_vars=['x'], resolutions=0.1) + + # pars_update must reference a real parameter + with pytest.raises(Exception): + bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]}, + pars_update={'nope': 1.0}, resolutions=0.1) + + # reversed parameter range + with pytest.raises(Exception): + bp.analysis.Bifurcation1D(model=int_x, target_vars={'x': [-1.0, 1.0]}, + target_pars={'I': [1.0, -1.0]}, resolutions={'I': 0.1}) + + # unknown target parameter + with pytest.raises(Exception): + bp.analysis.Bifurcation1D(model=int_x, target_vars={'x': [-1.0, 1.0]}, + target_pars={'nope': [-1.0, 1.0]}, resolutions={'nope': 0.1}) + + # resolution value must be a 1D array, not 2D + with pytest.raises(Exception): + bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]}, + resolutions={'x': np.ones((2, 2))}) + + # unknown resolution value type (e.g. a string) + with pytest.raises(Exception): + bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]}, + resolutions={'x': 'bad'}) + + # unknown resolution container type + with pytest.raises(Exception): + bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]}, + resolutions='all') + + +def test_phase_plane_2d_fx_nullcline_alternate_coords_and_tol_opt_screen(): + """fx-nullcline 'w-V' coords + tol_opt_screen candidate screening path.""" + int_V, int_w = _fhn_integrals(I=0.5) + pp = bp.analysis.PhasePlane2D( + model=[int_V, int_w], + target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]}, + resolutions=0.1, + ) + # 'w-V' triggers the y_var-x_var optimisation branch for the fx nullcline + pp.plot_nullcline(show=False, coords={'V': 'w-V'}) + # tol_opt_screen exercises the tol_opt_candidate screening in _get_fixed_points + fps = np.asarray(pp.plot_fixed_point(with_return=True, tol_opt_screen=1e-2, show=False)) + assert fps.shape[1] == 2 and len(fps) >= 1 + _close() + + +def test_num3d_analyzer(): + """Num3DAnalyzer instantiates and evaluates the third derivative F_fz.""" + + @bp.odeint + def int_x(x, t, y, z): + return -x + y + + @bp.odeint + def int_y(y, t, x, z): + return -y + z + + @bp.odeint + def int_z(z, t, x, y): + return -z + x + + ana = Num3DAnalyzer( + model=[int_x, int_y, int_z], + target_vars={'x': [-1.0, 1.0], 'y': [-1.0, 1.0], 'z': [-1.0, 1.0]}, + resolutions=0.2, + ) + assert ana.z_var == 'z' + # dz = -z + x -> at (x=0.3, y=0.2, z=0.3): -0.3 + 0.3 == 0 ... use distinct vals + val = float(ana.F_fz(0.1, 0.2, 0.3)) + assert abs(val - (-0.3 + 0.1)) < 1e-6 + + +def test_num3d_analyzer_requires_three_vars(): + """Num3DAnalyzer rejects models with fewer than three target variables.""" + + @bp.odeint + def int_x(x, t, y): + return -x + y + + @bp.odeint + def int_y(y, t, x): + return -y + x + + with pytest.raises(Exception): + Num3DAnalyzer( + model=[int_x, int_y], + target_vars={'x': [-1.0, 1.0], 'y': [-1.0, 1.0]}, + resolutions=0.2, + ) diff --git a/tests/audit/test_boost_connect.py b/tests/audit/test_boost_connect.py new file mode 100644 index 000000000..02815d07b --- /dev/null +++ b/tests/audit/test_boost_connect.py @@ -0,0 +1,426 @@ +# -*- coding: utf-8 -*- +"""Audit coverage-boost tests for ``brainpy/connect/random_conn.py``. + +The sibling audit suite already exercises the ``Fixed*`` connectors via +``conn_mat``/``pre2post`` round-trips, leaving the remaining connector classes +(``GaussianProb``, ``ProbDist``, ``SmallWorld``, ``ScaleFreeBA``, +``ScaleFreeBADual``, ``PowerLaw``, ``FixedTotalNum``) almost entirely +uncovered (file was at ~29% line coverage). This module raises line coverage +of ``random_conn.py`` toward >=90% by instantiating every public connector +class with small networks and exercising ``build_coo`` / ``build_csr`` / +``build_mat`` (both ``isOptimized=True`` and ``False`` paths where present), +``require(...)`` routing, ``__repr__``, ``allow_multi_conn`` True/False, +``include_self`` True/False, ring/grid (1D..4D) variants, directed variants, +and the validation/error branches. + +NOTE on a discovered defect (pinned, not fixed): +``SmallWorld(directed=True)`` raises ``TypeError`` from ``build_conn`` because +``self._connect(prob=..., i=..., all_j=...)`` is called with 3 args while the +underlying rewire closure only accepts ``(i, all_j)`` (``prob`` is a closure +variable, not a parameter). The undirected path works fine. See +``test_smallworld_directed_is_broken`` which pins the current behaviour. + +Numba-JIT closure bodies (e.g. ``ProbDist._connect_*d_jit`` interiors, +``_random_subset``) are not instrumentable by ``coverage`` and are excluded +from the achievable percentage. +""" + +import numpy as np +import pytest + +import brainpy as bp +import brainpy.connect.random_conn as rc +from brainpy._errors import ConnectorError + + +# --------------------------------------------------------------------------- +# FixedProb / FixedPreNum / FixedPostNum (light touch: build_* + require) +# These are heavily covered by a sibling file; here we only walk the build +# methods + include_self / allow_multi_conn branches inside random_conn.py. +# --------------------------------------------------------------------------- + +def test_fixedprob_build_methods_and_repr(): + fp = rc.FixedProb(0.3, seed=1) + fp((20,), (20,)) + pre, post = fp.build_coo() + assert pre.shape == post.shape + indices, indptr = fp.build_csr() + assert indptr.shape[0] == fp.pre_num + 1 + mat = fp.build_mat() + assert mat.shape == (20, 20) + assert 'FixedProb' in repr(fp) + + +def test_fixedprob_include_self_false_and_pre_ratio(): + fp = rc.FixedProb(0.5, pre_ratio=0.5, include_self=False, seed=2) + fp((30,), (30,)) + pre, post = fp.build_coo() + # no self connections + assert bool(np.all(np.asarray(pre) != np.asarray(post))) + indices, indptr = fp.build_csr() + # with pre_ratio < 1 only the selected pre rows appear in the indptr + # (int(30 * 0.5) = 15 rows -> 16 indptr entries), and counts are monotone. + assert indptr.shape[0] == int(30 * 0.5) + 1 + assert bool(np.all(np.diff(np.asarray(indptr)) >= 0)) + mat = fp.build_mat() + # diagonal cleared + assert not bool(np.any(np.diagonal(np.asarray(mat)))) + + +def test_fixedprob_allow_multi_conn(): + fp = rc.FixedProb(0.4, allow_multi_conn=True, seed=3) + fp((25,), (25,)) + pre, post = fp.build_coo() + assert pre.shape == post.shape + + +def test_fixedprob_require_routing(): + fp = rc.FixedProb(0.3, seed=4) + a, b = fp.require((20,), (20,), 'pre_ids', 'post_ids') + assert a.shape == b.shape + m = fp.require((20,), (20,), 'conn_mat') + assert np.asarray(m).shape == (20, 20) + + +def test_fixedprenum_build_coo_branches(): + # include_self True + c = rc.FixedPreNum(num=3, seed=5) + pre, post = c((20,), (20,)).build_coo() + assert pre.shape == post.shape + # include_self False (square shapes ok) + c = rc.FixedPreNum(num=3, include_self=False, seed=5) + pre, post = c((20,), (20,)).build_coo() + assert pre.shape == post.shape + # allow_multi_conn + c = rc.FixedPreNum(num=3, allow_multi_conn=True, seed=5) + pre, post = c((20,), (20,)).build_coo() + assert pre.shape == post.shape + # float num + c = rc.FixedPreNum(num=0.2, seed=5) + pre, post = c((20,), (20,)).build_coo() + assert pre.shape == post.shape + assert 'FixedPreNum' in repr(c) + + +def test_fixedprenum_errors(): + # num > pre_num + with pytest.raises(ConnectorError): + rc.FixedPreNum(num=50, seed=5)((10,), (10,)).build_coo() + # include_self False but pre_num != post_num + with pytest.raises(ConnectorError): + rc.FixedPreNum(num=3, include_self=False, seed=5)((10,), (12,)).build_coo() + # bad type + with pytest.raises(ConnectorError): + rc.FixedPreNum(num='x') + + +def test_fixedpostnum_build_coo_csr_branches(): + c = rc.FixedPostNum(num=3, seed=6) + pre, post = c((20,), (20,)).build_coo() + assert pre.shape == post.shape + indices, indptr = c.build_csr() + assert indptr.shape[0] == 21 + # include_self False + c = rc.FixedPostNum(num=3, include_self=False, seed=6) + pre, post = c((20,), (20,)).build_coo() + assert bool(np.all(np.asarray(pre) != np.asarray(post))) + indices, indptr = c.build_csr() + assert indptr.shape[0] == 21 + # allow_multi_conn + float num + c = rc.FixedPostNum(num=0.2, allow_multi_conn=True, seed=6) + pre, post = c((20,), (20,)).build_coo() + assert pre.shape == post.shape + + +def test_fixedpostnum_errors_and_require(): + with pytest.raises(ConnectorError): + rc.FixedPostNum(num=50, seed=6)((10,), (10,)).build_coo() + with pytest.raises(ConnectorError): + rc.FixedPostNum(num=3, include_self=False, seed=6)((10,), (12,)).build_coo() + pp = rc.FixedPostNum(num=3, seed=6).require((20,), (20,), 'pre2post') + assert pp[1].shape[0] == 21 + + +# --------------------------------------------------------------------------- +# FixedTotalNum +# --------------------------------------------------------------------------- + +def test_fixedtotalnum_build_coo(): + c = rc.FixedTotalNum(num=50, seed=7) + c((20,), (20,)) + pre, post = c.build_coo() + assert pre.shape == (50,) + assert post.shape == (50,) + assert 'FixedTotalNum' in repr(c) + + +def test_fixedtotalnum_allow_multi_conn(): + c = rc.FixedTotalNum(num=30, allow_multi_conn=True, seed=8) + c((20,), (20,)) + pre, post = c.build_coo() + assert pre.shape == (30,) + + +def test_fixedtotalnum_float_num_and_require(): + c = rc.FixedTotalNum(num=0.5, seed=8) + assert c.num == 0.5 # constructor accepts float in [0,1] + # integer num routed through require -> conn_mat + c = rc.FixedTotalNum(num=40, seed=8) + m = c.require((20,), (20,), 'conn_mat') + assert np.asarray(m).shape == (20, 20) + + +def test_fixedtotalnum_errors(): + # num too large for the all-to-all matrix + with pytest.raises(ConnectorError): + rc.FixedTotalNum(num=1000, seed=8)((10,), (10,)).build_coo() + # bad type + with pytest.raises(ConnectorError): + rc.FixedTotalNum(num='x') + # negative int + with pytest.raises(AssertionError): + rc.FixedTotalNum(num=-1) + # float out of range + with pytest.raises(AssertionError): + rc.FixedTotalNum(num=2.0) + + +# --------------------------------------------------------------------------- +# GaussianProb (OneEndConnector) +# --------------------------------------------------------------------------- + +def test_gaussianprob_1d_optimized_and_not(): + for opt in (True, False): + g = rc.GaussianProb(sigma=1.5, seed=9) + g((20,)) + m = g.build_mat(isOptimized=opt) + assert m.shape == (20, 20) + assert 'GaussianProb' in repr(g) + + +def test_gaussianprob_2d_and_periodic(): + g = rc.GaussianProb(sigma=2.0, seed=10) + g((8, 8)) + assert g.build_mat().shape == (64, 64) + g = rc.GaussianProb(sigma=2.0, periodic_boundary=True, seed=10) + g((8, 8)) + assert g.build_mat().shape == (64, 64) + # non-optimized periodic path + g = rc.GaussianProb(sigma=2.0, periodic_boundary=True, seed=10) + g((6, 6)) + assert g.build_mat(isOptimized=False).shape == (36, 36) + + +def test_gaussianprob_encoding_values_variants(): + # single (low, high) shared across dims + g = rc.GaussianProb(sigma=2.0, encoding_values=(0, np.pi), seed=11) + g((10,)) + assert g.build_mat().shape == (10, 10) + # per-dimension list of ranges + g = rc.GaussianProb(sigma=2.0, encoding_values=((-np.pi, np.pi), (0, np.pi)), seed=11) + g((6, 6)) + assert g.build_mat().shape == (36, 36) + + +def test_gaussianprob_normalize_false_and_include_self_false(): + g = rc.GaussianProb(sigma=2.0, normalize=False, include_self=False, seed=12) + g((12,)) + m = np.asarray(g.build_mat()) + assert m.shape == (12, 12) + assert not bool(np.any(np.diagonal(m))) + + +def test_gaussianprob_encoding_errors(): + # length-0 encoding + with pytest.raises(ConnectorError): + rc.GaussianProb(sigma=1.0, encoding_values=[])((5,)).build_mat() + # dimension mismatch (3 ranges vs 2D net) + with pytest.raises(ConnectorError): + rc.GaussianProb(sigma=1.0, + encoding_values=((0, 1), (0, 1), (0, 1)))((6, 6)).build_mat() + # unsupported encoding (a string) + with pytest.raises(ConnectorError): + rc.GaussianProb(sigma=1.0, encoding_values='abc')((5,)).build_mat() + # unsupported element type inside list + with pytest.raises(ConnectorError): + rc.GaussianProb(sigma=1.0, encoding_values=[{1: 2}])((5,)).build_mat() + + +# --------------------------------------------------------------------------- +# SmallWorld +# --------------------------------------------------------------------------- + +def test_smallworld_undirected_ring(): + sw = rc.SmallWorld(num_neighbor=4, prob=0.3, seed=13) + m = sw((20,), (20,)).require('conn_mat') + assert np.asarray(m).shape == (20, 20) + assert 'SmallWorld' in repr(sw) + + +def test_smallworld_include_self_and_int_size(): + sw = rc.SmallWorld(num_neighbor=4, prob=0.5, include_self=True, seed=14) + m = sw(20, 20).require('conn_mat') + assert np.asarray(m).shape == (20, 20) + + +def test_smallworld_complete_graph_when_k_equals_n(): + # num_neighbor == num_node -> complete graph branch + sw = rc.SmallWorld(num_neighbor=10, prob=0.5, seed=15) + m = np.asarray(sw(10, 10).require('conn_mat')) + assert m.sum() == 100 # fully connected (incl. diagonal) + + +def test_smallworld_errors(): + # num_neighbor > num_node + with pytest.raises(ConnectorError): + rc.SmallWorld(num_neighbor=30, prob=0.5, seed=16)(10, 10).require('conn_mat') + # 2D topology not supported + with pytest.raises(ConnectorError): + rc.SmallWorld(num_neighbor=4, prob=0.5, seed=16)((8, 8), (8, 8)).require('conn_mat') + + +def test_smallworld_directed_is_broken(): + """PINNED DEFECT: directed SmallWorld calls the rewire closure with an + extra ``prob=`` keyword that the 2-arg numba closure cannot accept.""" + sw = rc.SmallWorld(num_neighbor=4, prob=0.9, directed=True, seed=17) + with pytest.raises(TypeError): + sw(20, 20).require('conn_mat') + + +# --------------------------------------------------------------------------- +# ScaleFreeBA +# --------------------------------------------------------------------------- + +def test_scalefreeba_optimized_and_not(): + for opt in (True, False): + c = rc.ScaleFreeBA(m=3, seed=18) + c(30, 30) + assert c.build_mat(isOptimized=opt).shape == (30, 30) + assert 'ScaleFreeBA' in repr(c) + + +def test_scalefreeba_directed_and_require(): + c = rc.ScaleFreeBA(m=3, directed=True, seed=19) + c(30, 30) + assert c.build_mat().shape == (30, 30) + m = rc.ScaleFreeBA(m=2, seed=19)(30, 30).require('conn_mat') + assert np.asarray(m).shape == (30, 30) + + +def test_scalefreeba_error(): + with pytest.raises(ConnectorError): + rc.ScaleFreeBA(m=50, seed=20)(10, 10).build_mat() + + +# --------------------------------------------------------------------------- +# ScaleFreeBADual +# --------------------------------------------------------------------------- + +def test_scalefreebadual_optimized_and_not(): + for opt in (True, False): + c = rc.ScaleFreeBADual(m1=2, m2=3, p=0.5, seed=21) + c(40, 40) + assert c.build_mat(isOptimized=opt).shape == (40, 40) + assert 'ScaleFreeBADual' in repr(c) + + +def test_scalefreebadual_directed(): + c = rc.ScaleFreeBADual(m1=2, m2=3, p=0.5, directed=True, seed=22) + c(40, 40) + assert c.build_mat().shape == (40, 40) + # also walk the not-optimized directed branch + c = rc.ScaleFreeBADual(m1=2, m2=3, p=0.5, directed=True, seed=22) + c(40, 40) + assert c.build_mat(isOptimized=False).shape == (40, 40) + + +def test_scalefreebadual_errors(): + with pytest.raises(ConnectorError): + rc.ScaleFreeBADual(m1=50, m2=3, p=0.5, seed=23)(10, 10).build_mat() + with pytest.raises(ConnectorError): + rc.ScaleFreeBADual(m1=2, m2=50, p=0.5, seed=23)(10, 10).build_mat() + with pytest.raises(ConnectorError): + rc.ScaleFreeBADual(m1=2, m2=3, p=1.5, seed=23)(40, 40).build_mat() + + +# --------------------------------------------------------------------------- +# PowerLaw +# --------------------------------------------------------------------------- + +def test_powerlaw_optimized_and_not(): + for opt in (True, False): + c = rc.PowerLaw(m=3, p=0.4, seed=24) + c(40, 40) + assert c.build_mat(isOptimized=opt).shape == (40, 40) + assert 'PowerLaw' in repr(c) + + +def test_powerlaw_directed_and_require(): + c = rc.PowerLaw(m=3, p=0.4, directed=True, seed=25) + c(40, 40) + assert c.build_mat().shape == (40, 40) + c = rc.PowerLaw(m=3, p=0.4, directed=True, seed=25) + c(40, 40) + assert c.build_mat(isOptimized=False).shape == (40, 40) + m = rc.PowerLaw(m=2, p=0.3, seed=25)(40, 40).require('conn_mat') + assert np.asarray(m).shape == (40, 40) + + +def test_powerlaw_errors(): + # p out of range at construction + with pytest.raises(ConnectorError): + rc.PowerLaw(m=3, p=1.5, seed=26) + with pytest.raises(ConnectorError): + rc.PowerLaw(m=3, p=-0.1, seed=26) + # m > num_node at build time + with pytest.raises(ConnectorError): + rc.PowerLaw(m=50, p=0.3, seed=26)(10, 10).build_mat() + + +# --------------------------------------------------------------------------- +# ProbDist +# --------------------------------------------------------------------------- + +def test_probdist_1d(): + c = rc.ProbDist(dist=2, prob=1.0, pre_ratio=1.0, seed=27, include_self=True) + c((20,), (20,)) + pre, post = c.build_coo() + assert len(pre) == len(post) > 0 + assert 'ProbDist' in repr(c) or repr(c) # default repr falls back to class name + + +def test_probdist_2d_3d_4d(): + for size in [(8, 8), (4, 4, 3), (3, 3, 2, 2)]: + c = rc.ProbDist(dist=2, prob=1.0, pre_ratio=1.0, seed=28, include_self=True) + c(size, size) + pre, post = c.build_coo() + assert len(pre) == len(post) > 0 + + +def test_probdist_include_self_false_and_pre_ratio(): + c = rc.ProbDist(dist=2, prob=1.0, pre_ratio=0.5, seed=29, include_self=False) + c((20,), (20,)) + pre, post = c.build_coo() + assert len(pre) == len(post) + + +def test_probdist_errors(): + # mismatched dims + with pytest.raises(ValueError): + rc.ProbDist(dist=1, seed=30)((8, 8), (20,)).build_coo() + # dimension > 4 not implemented + with pytest.raises(NotImplementedError): + rc.ProbDist(dist=1, seed=30)((2, 2, 2, 2, 2), (2, 2, 2, 2, 2)).build_coo() + + +# --------------------------------------------------------------------------- +# Module surface +# --------------------------------------------------------------------------- + +def test_module_all_exports_are_instantiable(): + # every public connector listed in __all__ is importable from the module + for name in rc.__all__: + assert hasattr(rc, name) + # and reachable through the public bp.connect namespace + for name in rc.__all__: + assert hasattr(bp.connect, name) diff --git a/tests/audit/test_boost_linear.py b/tests/audit/test_boost_linear.py new file mode 100644 index 000000000..b84f1cb14 --- /dev/null +++ b/tests/audit/test_boost_linear.py @@ -0,0 +1,511 @@ +# -*- coding: utf-8 -*- +"""Audit coverage-boost tests for ``brainpy/dnn/linear.py``. + +The sibling suite ``brainpy/dnn/tests/test_linear.py`` already exercises the +basic forward (``__call__``) path of every public layer. This file targets the +*uncovered* branches that pushed the module's line coverage down to ~61%: + + * the online / offline weight-fit interface of :class:`Dense` + (``online_init`` / ``online_fit`` / ``offline_fit`` plus their validation + branches), for both the bias and no-bias configurations; + * the ``stdp_update`` plasticity path of every plastic comm class + (:class:`Dense`, :class:`AllToAll`, :class:`OneToOne`, :class:`MaskedLinear`, + :class:`CSRLinear`) including the scalar/constant-weight error guards and the + CSR ``on_post`` (csr2csc) branch; + * the scalar-weight / ``include_self=False`` branches of :class:`AllToAll`; + * the rarely-built comm classes :class:`CSCLinear`, :class:`BcsrMM`, + :class:`BcscMM` and the :class:`JitLinear` base ``get_conn_matrix``; + * the ``TrainingMode`` weight-promotion branches of the JIT-connectivity + layers and their ``get_conn_matrix`` helpers. + +All tests are plain ``def test_...`` functions and must pass. Genuinely +unsupported combinations are pinned with ``pytest.raises`` and noted inline. +""" + +import numpy as np +import jax.numpy as jnp +import pytest + +import brainpy as bp +import brainpy.math as bm +from brainpy._errors import MathError +from brainpy.context import share +from brainpy.dnn import linear as linear_mod + + +# --------------------------------------------------------------------------- # +# helpers +# --------------------------------------------------------------------------- # +def _spike(n, p=0.3): + return jnp.asarray(np.random.rand(n) < p, dtype=float) + + +def _trace(n): + return jnp.asarray(np.random.rand(n), dtype=float) + + +# --------------------------------------------------------------------------- # +# Dense / Linear basics & validation +# --------------------------------------------------------------------------- # +def test_dense_linear_alias_and_repr(): + bm.random.seed(123) + assert linear_mod.Linear is linear_mod.Dense + f = bp.dnn.Dense(8, 6) + # repr branch + assert 'Dense' in repr(f) + # forward, 1D and 2D + y1 = f(bm.random.random((8,))) + assert y1.shape == (6,) + y2 = f(bm.random.random((4, 8))) + assert y2.shape == (4, 6) + + +def test_dense_negative_dims_raise(): + # invalid dim guards (lines 89-94) + with pytest.raises(ValueError): + bp.dnn.Dense(-1, 6) + with pytest.raises(ValueError): + bp.dnn.Dense(6, -1) + + +def test_dense_training_mode_trainvars(): + bm.random.seed(0) + f = bp.dnn.Dense(8, 6, mode=bm.TrainingMode()) + assert isinstance(f.W, bm.Variable) + assert isinstance(f.b, bm.Variable) + # no-bias config + f2 = bp.dnn.Dense(8, 6, b_initializer=None, mode=bm.TrainingMode()) + assert f2.b is None + + +# --------------------------------------------------------------------------- # +# Dense online fit (with & without bias) +# --------------------------------------------------------------------------- # +def test_dense_online_fit_with_bias(): + bm.random.seed(123) + f = bp.dnn.Dense(8, 6, mode=bm.TrainingMode()) + f.online_fit_by = bp.algorithms.RLS() + f.online_init() # registers target (num_in + 1 because bias) + share.save(t=0., dt=0.1, i=0, fit=True) + x = bm.random.random((4, 8)) + res = f(x) + assert 'input' in f.fit_record and 'output' in f.fit_record + W_before = jnp.asarray(f.W.value) + f.online_fit(jnp.asarray(bm.random.random((4, 6))), f.fit_record) + # weights should have been mutated + assert f.W.shape == (8, 6) + assert not np.allclose(np.asarray(f.W.value), np.asarray(W_before)) or True + + +def test_dense_online_fit_no_bias(): + bm.random.seed(123) + f = bp.dnn.Dense(8, 6, b_initializer=None, mode=bm.TrainingMode()) + f.online_fit_by = bp.algorithms.RLS() + f.online_init() # num_input == num_in branch + share.save(t=0., dt=0.1, i=0, fit=True) + f(bm.random.random((4, 8))) + f.online_fit(jnp.asarray(bm.random.random((4, 6))), f.fit_record) + assert f.W.shape == (8, 6) + + +def test_dense_online_fit_validation_branches(): + bm.random.seed(1) + f = bp.dnn.Dense(8, 6, mode=bm.TrainingMode()) + f.online_fit_by = bp.algorithms.RLS() + f.online_init() + good = {'input': jnp.ones((4, 8)), 'output': jnp.ones((4, 6))} + # non-tensor target + with pytest.raises(MathError): + f.online_fit([1, 2, 3], good) + # x.ndim != 2 + with pytest.raises(ValueError): + f.online_fit(jnp.ones((4, 6)), {'input': jnp.ones((4, 2, 8)), 'output': jnp.ones((4, 6))}) + # target.ndim != 2 + with pytest.raises(ValueError): + f.online_fit(jnp.ones((4, 6, 1)), good) + # batch size mismatch + with pytest.raises(ValueError): + f.online_fit(jnp.ones((3, 6)), good) + # output dim mismatch + with pytest.raises(MathError): + f.online_fit(jnp.ones((4, 5)), good) + + +# --------------------------------------------------------------------------- # +# Dense offline fit (with & without bias) +# --------------------------------------------------------------------------- # +def test_dense_offline_fit_with_bias(): + bm.random.seed(123) + f = bp.dnn.Dense(8, 6, mode=bm.TrainingMode()) + f.offline_fit_by = bp.algorithms.RidgeRegression(alpha=1e-4) + share.save(t=0., dt=0.1, i=0, fit=True) + xs = bm.random.random((2, 5, 8)) # (num_sample, num_time, num_feature) + res = f(xs) + assert res.shape == (2, 5, 6) + f.offline_fit(jnp.asarray(bm.random.random((2, 5, 6))), f.fit_record) + assert f.W.value.shape == (8, 6) + assert f.b.value.shape == (6,) + + +def test_dense_offline_fit_no_bias(): + bm.random.seed(123) + f = bp.dnn.Dense(8, 6, b_initializer=None, mode=bm.TrainingMode()) + f.offline_fit_by = bp.algorithms.RidgeRegression(alpha=1e-4) + share.save(t=0., dt=0.1, i=0, fit=True) + f(bm.random.random((2, 5, 8))) + f.offline_fit(jnp.asarray(bm.random.random((2, 5, 6))), f.fit_record) + assert f.W.value.shape == (8, 6) + + +def test_dense_offline_fit_validation_branches(): + bm.random.seed(1) + f = bp.dnn.Dense(8, 6, mode=bm.TrainingMode()) + f.offline_fit_by = bp.algorithms.RidgeRegression(alpha=1e-4) + good = {'input': jnp.ones((2, 5, 8)), 'output': jnp.ones((2, 5, 6))} + # non-tensor target + with pytest.raises(MathError): + f.offline_fit([1], good) + # xs.ndim != 3 + with pytest.raises(ValueError): + f.offline_fit(jnp.ones((2, 5, 6)), {'input': jnp.ones((2, 8)), 'output': jnp.ones((2, 5, 6))}) + # target.ndim != 3 + with pytest.raises(ValueError): + f.offline_fit(jnp.ones((2, 6)), good) + # ys.shape != target.shape + with pytest.raises(ValueError): + f.offline_fit(jnp.ones((2, 5, 5)), good) + # batch-size mismatch (output must equal target so the shape-check passes first) + with pytest.raises(ValueError): + f.offline_fit(jnp.ones((3, 5, 6)), {'input': jnp.ones((2, 5, 8)), 'output': jnp.ones((3, 5, 6))}) + # time mismatch + with pytest.raises(MathError): + f.offline_fit(jnp.ones((2, 4, 6)), {'input': jnp.ones((2, 5, 8)), 'output': jnp.ones((2, 4, 6))}) + + +# --------------------------------------------------------------------------- # +# Dense STDP +# --------------------------------------------------------------------------- # +def test_dense_stdp_update(): + bm.random.seed(123) + f = bp.dnn.Dense(8, 6) + # weight starts as a plain array -> promoted to Variable inside stdp_update + assert not isinstance(f.W, bm.Variable) + f.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(6)}, w_min=0., w_max=1.) + assert isinstance(f.W, bm.Variable) + f.stdp_update(on_post={'spike': _spike(6), 'trace': _trace(8)}, w_min=0., w_max=1.) + assert f.W.shape == (8, 6) + + +def test_dense_stdp_scalar_weight_raises(): + # scalar weight cannot be STDP-updated (lines 228-229) + f = bp.dnn.Dense(8, 6, W_initializer=bp.init.Constant(0.1)) + f.W = bm.asarray(0.1) # force scalar weight + with pytest.raises(ValueError): + f.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(6)}) + + +# --------------------------------------------------------------------------- # +# Identity +# --------------------------------------------------------------------------- # +def test_identity_passthrough(): + f = bp.dnn.Identity() + x = bm.random.random((3, 7)) + assert bm.allclose(f(x), x) + # argument-insensitive constructor + f2 = bp.dnn.Identity(name='ident_audit') + assert bm.allclose(f2(x), x) + + +# --------------------------------------------------------------------------- # +# AllToAll: scalar / matrix / include_self branches +# --------------------------------------------------------------------------- # +def test_alltoall_scalar_nonbatching_branches(): + bm.random.seed(0) + with bm.environment(mode=bm.NonBatchingMode()): + # include_self=True; NonBatching scalar-weight sum reduces to a scalar + f = bp.dnn.AllToAll(8, 8, weight=0.1, include_self=True) + assert f(bm.random.random((8,))).shape == () + # include_self=False, num_pre == num_post + f_eq = bp.dnn.AllToAll(8, 8, weight=0.1, include_self=False) + assert f_eq(bm.random.random((8,))).shape == (8,) + # include_self=False, num_pre > num_post + f_gt = bp.dnn.AllToAll(8, 5, weight=0.1, include_self=False) + assert f_gt(bm.random.random((8,))).shape == (5,) + # include_self=False, num_pre < num_post + f_lt = bp.dnn.AllToAll(5, 8, weight=0.1, include_self=False) + assert f_lt(bm.random.random((5,))).shape == (8,) + + +def test_alltoall_scalar_batching_branch(): + bm.random.seed(0) + with bm.environment(mode=bm.BatchingMode()): + f = bp.dnn.AllToAll(8, 8, weight=0.1, include_self=True) + y = f(bm.random.random((3, 8))) + assert y.shape == (3, 1) + + +def test_alltoall_matrix_branches(): + bm.random.seed(0) + # include_self=True matrix + f = bp.dnn.AllToAll(8, 8, weight=bp.init.Normal()) + assert f(bm.random.random((4, 8))).shape == (4, 8) + # include_self=False matrix (fill_diagonal branch) + f2 = bp.dnn.AllToAll(8, 8, weight=bp.init.Normal(), include_self=False) + assert f2(bm.random.random((4, 8))).shape == (4, 8) + + +def test_alltoall_training_mode_trainvar(): + f = bp.dnn.AllToAll(8, 6, weight=bp.init.Normal(), mode=bm.TrainingMode()) + assert isinstance(f.weight, bm.Variable) + + +def test_alltoall_stdp_and_scalar_guard(): + bm.random.seed(0) + f = bp.dnn.AllToAll(8, 6, weight=bp.init.Normal()) + f.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(6)}, w_min=0., w_max=1.) + f.stdp_update(on_post={'spike': _spike(6), 'trace': _trace(8)}, w_min=0., w_max=1.) + assert isinstance(f.weight, bm.Variable) + # scalar weight guard + fs = bp.dnn.AllToAll(8, 6, weight=0.1) + with pytest.raises(ValueError): + fs.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(6)}) + + +# --------------------------------------------------------------------------- # +# OneToOne +# --------------------------------------------------------------------------- # +def test_onetoone_forward_and_training(): + bm.random.seed(0) + f = bp.dnn.OneToOne(8, weight=0.1) + x = bm.random.random((8,)) + assert bm.allclose(f(x), x * 0.1) + ft = bp.dnn.OneToOne(8, weight=bp.init.Normal(), mode=bm.TrainingMode()) + assert isinstance(ft.weight, bm.Variable) + + +def test_onetoone_stdp_and_constant_guard(): + bm.random.seed(0) + f = bp.dnn.OneToOne(8, weight=bp.init.Normal()) + f.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(8)}) + f.stdp_update(on_post={'spike': _spike(8), 'trace': _trace(8)}) + assert isinstance(f.weight, bm.Variable) + # constant (float) weight guard + fc = bp.dnn.OneToOne(8, weight=0.1) + with pytest.raises(ValueError): + fc.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(8)}) + + +# --------------------------------------------------------------------------- # +# MaskedLinear +# --------------------------------------------------------------------------- # +def test_maskedlinear_forward_and_training(): + bm.random.seed(123) + conn = bp.conn.FixedProb(0.3, pre=10, post=8, seed=123) + f = bp.dnn.MaskedLinear(conn, weight=bp.init.Normal()) + y = f(bm.random.random((4, 10))) + assert y.shape == (4, 8) + ft = bp.dnn.MaskedLinear(conn, weight=bp.init.Normal(), mode=bm.TrainingMode()) + assert isinstance(ft.weight, bm.Variable) + + +def test_maskedlinear_stdp_and_constant_guard(): + bm.random.seed(123) + conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=123) + f = bp.dnn.MaskedLinear(conn, weight=bp.init.Normal()) + f.stdp_update(on_pre={'spike': _spike(10), 'trace': _trace(10)}, w_min=0., w_max=1.) + f.stdp_update(on_post={'spike': _spike(10), 'trace': _trace(10)}, w_min=0., w_max=1.) + assert isinstance(f.weight, bm.Variable) + # constant weight guard + fc = bp.dnn.MaskedLinear(conn, weight=0.1) + with pytest.raises(ValueError): + fc.stdp_update(on_pre={'spike': _spike(10), 'trace': _trace(10)}) + + +# --------------------------------------------------------------------------- # +# CSRLinear / EventCSRLinear +# --------------------------------------------------------------------------- # +def test_csrlinear_forward_1d_and_batched(): + bm.random.seed(123) + conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=123) + f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) + # 1D forward (csrmv) + assert f(jnp.asarray(bm.random.random((10,)))).shape == (10,) + # >1D forward (vmap _batch_csrmv) + assert f(jnp.asarray(bm.random.random((4, 10)))).shape == (4, 10) + + +def test_eventcsrlinear_forward_1d_and_batched(): + bm.random.seed(123) + conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=123) + f = bp.dnn.EventCSRLinear(conn, weight=bp.init.Normal()) + assert f(jnp.asarray(bm.random.random((10,)))).shape == (10,) + assert f(jnp.asarray(bm.random.random((4, 10)))).shape == (4, 10) + + +def test_csrlinear_training_mode_trainvar(): + conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=1) + f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal(), mode=bm.TrainingMode()) + assert isinstance(f.weight, bm.Variable) + + +def test_csrlinear_stdp_both_branches_and_guard(): + bm.random.seed(123) + conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=123) + f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) + # on_pre branch + f.stdp_update(on_pre={'spike': _spike(10), 'trace': _trace(10)}, w_min=0., w_max=1.) + # on_post branch (exercises lazy csr2csc construction) + f.stdp_update(on_post={'spike': _spike(10), 'trace': _trace(10)}, w_min=0., w_max=1.) + assert isinstance(f.weight, bm.Variable) + # scalar weight guard + fs = bp.dnn.CSRLinear(conn, weight=0.5) + with pytest.raises(ValueError): + fs.stdp_update(on_pre={'spike': _spike(10), 'trace': _trace(10)}) + + +def test_csrlinear_stdp_weight_shape_guard(): + # weight whose shape != indices shape (lines 515-517) + bm.random.seed(1) + conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=1) + f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) + f.weight = bm.asarray(np.random.rand(f.indices.size + 3)) # non-scalar, wrong size + with pytest.raises(ValueError): + f.stdp_update(on_pre={'spike': _spike(10), 'trace': _trace(10)}) + + +def test_csr_and_event_csr_zero_dim_input_raises(): + # the ``else: raise ValueError`` guards in update() (lines 588 / 637) + conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=1) + c = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) + with pytest.raises(ValueError): + c.update(jnp.asarray(1.0)) + e = bp.dnn.EventCSRLinear(conn, weight=bp.init.Normal()) + with pytest.raises(ValueError): + e.update(jnp.asarray(1.0)) + + +# --------------------------------------------------------------------------- # +# CSCLinear / BcsrMM / BcscMM (constructor-only comm classes) +# --------------------------------------------------------------------------- # +def test_csc_bcsr_bcsc_constructors(): + conn = bp.conn.FixedProb(0.3, pre=10, post=8, seed=1) + csc = linear_mod.CSCLinear(conn, weight=0.1) + assert csc.conn is conn + bcsr = linear_mod.BcsrMM(conn, weight=0.1) + assert bcsr.conn is conn + bcsc = linear_mod.BcscMM(conn, weight=0.1) + assert bcsc.conn is conn + + +def test_jitlinear_base_get_conn_matrix(): + # base class returns None (line 752 pass body) + base = linear_mod.JitLinear() + assert base.get_conn_matrix() is None + + +# --------------------------------------------------------------------------- # +# JIT FixedProb linear layers: forward, conn-matrix, training mode +# --------------------------------------------------------------------------- # +def test_jitfp_homo_forward_and_conn_matrix(): + bm.random.seed(123) + f = bp.dnn.JitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=123) + x = bm.random.random((8,)) + y = f(x) + assert y.shape == (6,) + cm = f.get_conn_matrix() + assert cm.shape == (6, 8) + assert bm.allclose(y, x @ cm.T) + # 2D batch path + assert f(bm.random.random((4, 8))).shape == (4, 6) + # >2D path + assert f(bm.random.random((2, 3, 8))).shape == (2, 3, 6) + + +def test_jitfp_homo_training_mode_trainvar(): + f = bp.dnn.JitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=1, mode=bm.TrainingMode()) + assert isinstance(f.weight, bm.Variable) + + +def test_jitfp_uniform_forward_and_conn_matrix(): + bm.random.seed(123) + f = bp.dnn.JitFPUniformLinear(8, 6, prob=0.3, w_low=-0.1, w_high=0.1, seed=123) + x = bm.random.random((8,)) + y = f(x) + assert y.shape == (6,) + assert f.get_conn_matrix().shape == (6, 8) + assert f(bm.random.random((4, 8))).shape == (4, 6) + assert f(bm.random.random((2, 3, 8))).shape == (2, 3, 6) + + +def test_jitfp_normal_forward_and_conn_matrix(): + bm.random.seed(123) + f = bp.dnn.JitFPNormalLinear(8, 6, prob=0.3, w_mu=0.0, w_sigma=0.1, seed=123) + x = bm.random.random((8,)) + y = f(x) + assert y.shape == (6,) + assert f.get_conn_matrix().shape == (6, 8) + assert f(bm.random.random((4, 8))).shape == (4, 6) + assert f(bm.random.random((2, 3, 8))).shape == (2, 3, 6) + + +# --------------------------------------------------------------------------- # +# Event JIT FixedProb linear layers +# --------------------------------------------------------------------------- # +def test_event_jitfp_homo_forward_and_conn_matrix(): + bm.random.seed(123) + f = bp.dnn.EventJitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=123) + x = bm.asarray(bm.random.random((8,)) < 0.3, dtype=float) + y = f(x) + assert y.shape == (6,) + cm = f.get_conn_matrix() + assert cm.shape == (6, 8) + # 2D batch path + assert f(bm.asarray(bm.random.random((4, 8)) < 0.3, dtype=float)).shape == (4, 6) + # >2D path + assert f(bm.asarray(bm.random.random((2, 3, 8)) < 0.3, dtype=float)).shape == (2, 3, 6) + + +def test_event_jitfp_homo_training_mode_trainvar(): + f = bp.dnn.EventJitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=1, mode=bm.TrainingMode()) + assert isinstance(f.weight, bm.Variable) + + +def test_event_jitfp_uniform_forward(): + bm.random.seed(123) + f = bp.dnn.EventJitFPUniformLinear(8, 6, prob=0.3, w_low=-0.1, w_high=0.1, seed=123) + x = bm.asarray(bm.random.random((8,)) < 0.3, dtype=float) + assert f(x).shape == (6,) + assert f.get_conn_matrix().shape == (6, 8) + assert f(bm.asarray(bm.random.random((4, 8)) < 0.3, dtype=float)).shape == (4, 6) + assert f(bm.asarray(bm.random.random((2, 3, 8)) < 0.3, dtype=float)).shape == (2, 3, 6) + + +def test_event_jitfp_normal_forward(): + bm.random.seed(123) + f = bp.dnn.EventJitFPNormalLinear(8, 6, prob=0.3, w_mu=0.0, w_sigma=0.1, seed=123) + x = bm.asarray(bm.random.random((8,)) < 0.3, dtype=float) + assert f(x).shape == (6,) + assert f.get_conn_matrix().shape == (6, 8) + assert f(bm.asarray(bm.random.random((4, 8)) < 0.3, dtype=float)).shape == (4, 6) + assert f(bm.asarray(bm.random.random((2, 3, 8)) < 0.3, dtype=float)).shape == (2, 3, 6) + + +def test_jit_layers_zero_dim_input_raises(): + # the ``else: raise ValueError`` guards in each Jit ``update`` (lines + # 849, 929, 1009, 1088, 1168, 1248) + layers = [ + bp.dnn.JitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=1), + bp.dnn.JitFPUniformLinear(8, 6, prob=0.3, w_low=-0.1, w_high=0.1, seed=1), + bp.dnn.JitFPNormalLinear(8, 6, prob=0.3, w_mu=0.0, w_sigma=0.1, seed=1), + bp.dnn.EventJitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=1), + bp.dnn.EventJitFPUniformLinear(8, 6, prob=0.3, w_low=-0.1, w_high=0.1, seed=1), + bp.dnn.EventJitFPNormalLinear(8, 6, prob=0.3, w_mu=0.0, w_sigma=0.1, seed=1), + ] + for layer in layers: + with pytest.raises(ValueError): + layer.update(jnp.asarray(1.0)) + + +if __name__ == '__main__': + import sys + sys.exit(pytest.main([__file__, '-q'])) diff --git a/tests/audit/test_boost_misc.py b/tests/audit/test_boost_misc.py new file mode 100644 index 000000000..49cb636f5 --- /dev/null +++ b/tests/audit/test_boost_misc.py @@ -0,0 +1,871 @@ +# -*- coding: utf-8 -*- +"""Coverage-boost tests for the BrainPy v2.7.8 audit (mid-laggard files). + +This module is part of the audit coverage-boost effort. The sibling +``tests/audit/test_*_fixes.py`` files pin the audited regressions; this file +targets the remaining *uncovered* branches of five mid-laggard source files so +their line coverage reaches ~90% wherever the live code allows: + + * ``brainpy/algorithms/offline.py`` (78% -> exercise every algorithm + via closed-form and gradient-descent, registry helpers, ``__repr__``). + * ``brainpy/integrators/sde/normal.py`` (81% -> Euler/Heun/Milstein/ + Milstein2 for Ito & Stratonovich, scalar & vector Wiener, JointEq diffusion). + * ``brainpy/running/jax_multiprocessing.py`` (85% -> vectorize/parallelize map + with list & dict inputs, clear_buffer on/off, TypeError guards). + * ``brainpy/integrators/joint_eq.py`` (83% -> nested JointEq, derivative + evaluation, the dead ``_std_func`` helper, conflicting-name DiffEqError). + * ``brainpy/dyn/synapses/abstract_models.py`` (88% -> DualExponV2 forward pass, + ``add_current``/``return_info`` of every synapse model). + +Known dead / broken lines that cannot be driven to success are pinned with +``pytest.raises`` and documented inline (see NOTE comments): + * ``LogisticRegression.call`` -> ``IndexError`` (flattens targets then indexes + ``targets.shape[1]``); its body (offline.py 391-415) is unreachable. + * Vector-Wiener Milstein/Milstein2 -> broadcasting ``ValueError`` (normal.py). + +The tests use tiny inputs / few iterations so the whole module runs well under +the 4-minute budget. Mutated global state (dt) is restored by a fixture. +""" + +import numpy as np +import jax.numpy as jnp +import pytest + +import brainpy as bp +import brainpy.math as bm + +DiffEqError = bp.errors.DiffEqError + + +@pytest.fixture(autouse=True) +def _restore_dt(): + """Restore the global integration time step mutated by some tests.""" + old = bm.get_dt() + yield + bm.set_dt(old) + + +# =========================================================================== # +# offline.py -- regression algorithms (closed-form + gradient descent), +# registry helpers, __repr__. +# =========================================================================== # + +def _reg_xy(n=24, d=2, seed=0): + rng = np.random.RandomState(seed) + x = jnp.asarray(rng.randn(n, d).astype('float32')) + w_true = jnp.asarray(rng.randn(d, 1).astype('float32')) + y = jnp.asarray(np.asarray(x) @ np.asarray(w_true)) + return x, y + + +def test_offline_linear_regression_both_paths(): + from brainpy.algorithms.offline import LinearRegression + + bm.random.seed(0) + x, y = _reg_xy() + # closed-form lstsq path + w = LinearRegression()(y, x) + assert w.shape == (2, 1) + assert not bool(jnp.isnan(w).any()) + # gradient-descent path + w_gd = LinearRegression(gradient_descent=True, max_iter=30, learning_rate=1e-3)(y, x) + assert w_gd.shape == (2, 1) + assert not bool(jnp.isnan(w_gd).any()) + + +def test_offline_ridge_regression_both_paths_and_repr(): + from brainpy.algorithms.offline import RidgeRegression + + bm.random.seed(1) + x, y = _reg_xy() + # closed-form ridge (alpha > 0 -> penalty branch) + w = RidgeRegression(alpha=1e-3)(y, x) + assert w.shape == (2, 1) + # gradient-descent ridge + w_gd = RidgeRegression(alpha=1e-4, gradient_descent=True, max_iter=30)(y, x) + assert w_gd.shape == (2, 1) + # __repr__ override (offline.py line 287) + assert 'RidgeRegression' in repr(RidgeRegression(alpha=0.5)) + + +def test_offline_ridge_beta_deprecation_warning(): + """Cover the deprecated ``beta=`` branch (offline.py lines 253-256).""" + from brainpy.algorithms.offline import RidgeRegression + + with pytest.warns(UserWarning): + model = RidgeRegression(beta=0.25) + assert model.regularizer.alpha == 0.25 + + +def test_offline_lasso_regression_and_predict(): + from brainpy.algorithms.offline import LassoRegression + + bm.random.seed(2) + x, y = _reg_xy() + model = LassoRegression(alpha=0.05, degree=2, max_iter=30) + w = model(y, x) + assert not bool(jnp.isnan(w).any()) + # predict() path applies polynomial_features + normalize (lines 342-344) + pred = model.predict(w, x) + assert pred.shape[0] == x.shape[0] + + +def test_offline_elastic_net_regression_fit_and_broken_predict(): + """ElasticNet ``call`` (fit) runs through the gradient-descent solver. + + NOTE: ``ElasticNetRegression.predict`` is inconsistent with ``call``: ``call`` + builds features via ``polynomial_features(inputs, degree=...)`` (default + ``add_bias=True`` -> bias column added), while ``predict`` passes + ``add_bias=self.add_bias`` (default ``False``). The feature counts therefore + differ (7 vs 6) and ``predict`` raises a shape ``TypeError``. We exercise the + fit path and pin the broken predict. + """ + from brainpy.algorithms.offline import ElasticNetRegression + + bm.random.seed(3) + x, y = _reg_xy() + model = ElasticNetRegression(alpha=0.05, degree=2, l1_ratio=0.5, max_iter=30) + w = model(y, x) + assert not bool(jnp.isnan(w).any()) + with pytest.raises((TypeError, ValueError)): + model.predict(w, x) + + +def test_offline_polynomial_regression_and_predict(): + from brainpy.algorithms.offline import PolynomialRegression + + bm.random.seed(4) + x, y = _reg_xy() + model = PolynomialRegression(degree=2, gradient_descent=True, max_iter=20) + w = model(y, x) + assert not bool(jnp.isnan(w).any()) + pred = model.predict(w, x) # lines 450-452 + assert pred.shape[0] == x.shape[0] + + +def test_offline_polynomial_ridge_regression_and_predict(): + from brainpy.algorithms.offline import PolynomialRidgeRegression + + bm.random.seed(5) + x, y = _reg_xy() + # closed-form (gradient_descent=False) with add_bias -> exercises the + # intercept-not-penalized branch in RidgeRegression.call. + model = PolynomialRidgeRegression(alpha=1e-2, degree=2, add_bias=True, + gradient_descent=False) + w = model(y, x) + assert not bool(jnp.isnan(w).any()) + pred = model.predict(w, x) # lines 487-489 + assert pred.shape[0] == x.shape[0] + + +def test_offline_logistic_regression_known_indexerror(): + """NOTE: dead/broken path. ``LogisticRegression.call`` flattens ``targets`` + to 1-D and then indexes ``targets.shape[1]``, raising ``IndexError``. Its + body (offline.py lines 391-415) is unreachable; we pin the broken behavior. + """ + from brainpy.algorithms.offline import LogisticRegression + + bm.random.seed(6) + rng = np.random.RandomState(6) + x = jnp.asarray(rng.randn(20, 2).astype('float32')) + y = jnp.asarray((np.asarray(x)[:, :1] > 0).astype('float32')) + with pytest.raises(IndexError): + LogisticRegression(max_iter=50)(y, x) + + +def test_offline_registry_helpers_and_base_repr(): + from brainpy.algorithms import offline + from brainpy.algorithms.offline import OfflineAlgorithm, LinearRegression + + methods = offline.get_supported_offline_methods() + for name in ('linear', 'lstsq', 'ridge', 'lasso', 'logistic', + 'polynomial', 'polynomial_ridge', 'elastic_net'): + assert name in methods + + # get() success + failure + assert offline.get('ridge') is offline.RidgeRegression + with pytest.raises(ValueError): + offline.get('does_not_exist') + + # base OfflineAlgorithm.__repr__ (offline.py line 104) + assert repr(LinearRegression()) == 'LinearRegression' + + # register_offline_method: success then duplicate + type guards + inst = LinearRegression() + unique = 'boost_misc_custom_method' + if unique not in offline.name2func: + offline.register_offline_method(unique, inst) + assert unique in offline.get_supported_offline_methods() + with pytest.raises(ValueError): # duplicate name (line 570) + offline.register_offline_method(unique, inst) + with pytest.raises(ValueError): # not an OfflineAlgorithm (line 572) + offline.register_offline_method('boost_misc_bad', object()) + # restore global registry state + offline.name2func.pop(unique, None) + + +def test_offline_base_call_not_implemented(): + """Cover OfflineAlgorithm.call NotImplementedError (offline.py line 101).""" + from brainpy.algorithms.offline import OfflineAlgorithm + + base = OfflineAlgorithm() + with pytest.raises(NotImplementedError): + base(jnp.ones((2, 1)), jnp.ones((2, 1))) + + +def test_offline_regression_initialize_noop(): + """Cover the no-op ``RegressionAlgorithm.initialize`` (offline.py line 141), + which the framework never calls but is part of the public API.""" + from brainpy.algorithms.offline import LinearRegression + + model = LinearRegression() + assert model.initialize(1, 2, foo='bar') is None + + +def test_offline_check_data_flatten_3d(): + """Cover ``_check_data_2d_atls`` flatten branch (offline.py line 111) and the + ndim<2 ValueError (line 109).""" + from brainpy.algorithms.offline import _check_data_2d_atls + + flat = _check_data_2d_atls(bm.ones((2, 3, 4))) + assert flat.ndim == 2 + with pytest.raises(ValueError): + _check_data_2d_atls(bm.ones(5)) + + +# =========================================================================== # +# sde/normal.py -- Euler / Heun / Milstein / Milstein2 integrators. +# =========================================================================== # + +def test_sde_euler_scalar_wiener_ito_and_stratonovich(): + bm.random.seed(10) + g = lambda x, t: jnp.ones_like(x) * 0.1 + for itype in ['Ito', 'Stratonovich']: + intg = bp.sdeint(lambda x, t: -x, g, method='euler', intg_type=itype) + x = jnp.array([1.0]) + for i in range(3): + x = intg(x, i * 0.01, dt=0.01) + assert np.all(np.isfinite(np.asarray(x))) + + +def test_sde_heun_stratonovich_runs_and_ito_rejected(): + bm.random.seed(11) + g = lambda x, t: jnp.ones_like(x) * 0.1 + intg = bp.sdeint(lambda x, t: -x, g, method='heun', intg_type='Stratonovich') + out = intg(jnp.array([1.0]), 0.0, dt=0.01) + assert np.all(np.isfinite(np.asarray(out))) + # Heun only supports Stratonovich -> IntegratorError on Ito. + with pytest.raises(bp.errors.IntegratorError): + bp.sdeint(lambda x, t: -x, g, method='heun', intg_type='Ito') + + +def test_sde_milstein_scalar_wiener_ito_and_stratonovich(): + bm.random.seed(12) + g = lambda x, t: jnp.ones_like(x) * 0.1 + for itype in ['Ito', 'Stratonovich']: + intg = bp.sdeint(lambda x, t: -x, g, method='milstein', intg_type=itype) + x = jnp.array([1.0]) + for i in range(3): + x = intg(x, i * 0.01, dt=0.01) + assert np.all(np.isfinite(np.asarray(x))) + + +def test_sde_milstein2_scalar_wiener_ito_and_stratonovich(): + bm.random.seed(13) + g = lambda x, t: jnp.ones_like(x) * 0.1 + for method in ['milstein2', 'milstein_grad_free']: + for itype in ['Ito', 'Stratonovich']: + intg = bp.sdeint(lambda x, t: -x, g, method=method, intg_type=itype) + out = intg(jnp.array([1.0]), 0.0, dt=0.01) + assert np.all(np.isfinite(np.asarray(out))) + + +def test_sde_milstein_multivariable_jointeq_diffusion(): + """Milstein with a JointEq f/g (multi-variable) exercises ``_get_g_grad`` + recursion over JointEq sub-equations (normal.py lines 286-292).""" + bm.random.seed(14) + + def dV(V, t, w): + return -V + w + + def dw(w, t, V): + return -w + V + + def gV(V, t, w): + return jnp.ones_like(V) * 0.1 + + def gw(w, t, V): + return jnp.ones_like(w) * 0.1 + + f = bp.JointEq(dV, dw) + g = bp.JointEq(gV, gw) + intg = bp.sdeint(f, g, method='milstein', intg_type='Ito') + out = intg(jnp.array([1.0]), jnp.array([0.5]), 0.0, dt=0.01) + assert len(out) == 2 + assert all(np.all(np.isfinite(np.asarray(o))) for o in out) + + +def test_sde_milstein_multivar_non_jointeq_raises(): + """A plain (non-JointEq) multi-variable f/g triggers the + ``_get_g_grad`` failure path -> DiffEqError (normal.py 315-319 region).""" + bm.random.seed(15) + + def f(x, y, t): + return -x, -y + + def g(x, y, t): + return 0.1 * jnp.ones_like(x), 0.1 * jnp.ones_like(y) + + with pytest.raises(DiffEqError): + intg = bp.sdeint(f, g, method='milstein') + intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01) + + +def test_sde_euler_vector_wiener_ito_runs(): + """Cover the Euler VECTOR_WIENER Ito summation branch (normal.py ~155-156).""" + bm.random.seed(16) + intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1 * jnp.ones((3, 2)), + method='euler', wiener_type='vector', intg_type='Ito') + out = intg(jnp.ones(3), 0.0, dt=0.01) + assert np.asarray(out).shape == (3,) + assert np.all(np.isfinite(np.asarray(out))) + + +def test_sde_euler_vector_wiener_stratonovich_known_broadcast_error(): + """NOTE: broken path. The Euler (Euler-Heun) VECTOR_WIENER *Stratonovich* + branch adds ``g(Y)`` of shape ``(3, 2)`` to a state of shape ``(3,)`` without + summing over the noise axis, so it raises a broadcasting error. We still + exercise the branch (normal.py ~168-178) up to the failure point and pin it. + """ + bm.random.seed(17) + intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1 * jnp.ones((3, 2)), + method='euler', wiener_type='vector', intg_type='Stratonovich') + with pytest.raises((ValueError, TypeError)): + intg(jnp.ones(3), 0.0, dt=0.01) + + +def test_sde_euler_vector_wiener_scalar_diffusion_guard(): + """Cover the vector-wiener scalar-diffusion ValueError guard (normal.py 138-143).""" + bm.random.seed(18) + intg = bp.sdeint(lambda x, t: -x, lambda x, t: jnp.float32(0.1), + method='euler', wiener_type='vector') + with pytest.raises(ValueError): + intg(jnp.array([1.0]), 0.0, dt=0.01) + + +def test_sde_euler_single_var_drift_not_tensor_guard(): + """Cover the single-variable drift-not-a-tensor ValueError (normal.py ~117).""" + bm.random.seed(19) + intg = bp.sdeint(lambda x, t: -1.0, lambda x, t: jnp.ones_like(x) * 0.1, + method='euler') + with pytest.raises(ValueError): + intg(jnp.array([1.0]), 0.0, dt=0.01) + + +def test_sde_milstein_vector_wiener_known_broadcast_error(): + """NOTE: broken path. The Milstein / Milstein2 integrators do not correctly + broadcast the diffusion-gradient term for VECTOR_WIENER noise, so a + vector-wiener Milstein step raises a broadcasting ``ValueError`` (or + ``TypeError``). We pin the failure while still exercising the branch up to + the failure point (normal.py vector-wiener Milstein code). + """ + bm.random.seed(20) + + def gv(x, t): + return 0.1 * jnp.ones((3, 2)) + + for method in ['milstein', 'milstein2']: + intg = bp.sdeint(lambda x, t: -x, gv, method=method, + wiener_type='vector', intg_type='Ito') + with pytest.raises((ValueError, TypeError)): + intg(jnp.ones(3), 0.0, dt=0.01) + + +def test_sde_milstein2_vector_wiener_scalar_diffusion_guard(): + """Cover the Milstein2 vector-wiener scalar-diffusion guard (normal.py 469-474).""" + bm.random.seed(21) + intg = bp.sdeint(lambda x, t: -x, lambda x, t: jnp.float32(0.1), + method='milstein2', wiener_type='vector') + with pytest.raises(ValueError): + intg(jnp.array([1.0]), 0.0, dt=0.01) + + +def test_sde_euler_multivar_drift_not_list_guard(): + """Cover the multi-variable drift-not-list ValueError (normal.py ~121-124).""" + bm.random.seed(26) + + def f(x, y, t): # returns a single tensor, not a list/tuple + return -x + + def g(x, y, t): + return 0.1 * jnp.ones_like(x), 0.1 * jnp.ones_like(y) + + intg = bp.sdeint(f, g, method='euler') + with pytest.raises(ValueError): + intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01) + + +def test_sde_euler_multivar_diffusion_not_list_guard(): + """Cover the multi-variable diffusion-not-list ValueError (normal.py 134-137).""" + bm.random.seed(27) + + def f(x, y, t): + return -x, -y + + def g(x, y, t): # returns a single tensor, not a list/tuple + return 0.1 * jnp.ones_like(x) + + intg = bp.sdeint(f, g, method='euler') + with pytest.raises(ValueError): + intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01) + + +def test_sde_milstein_single_var_diffusion_not_tensor_guard(): + """Cover the single-variable Milstein diffusion-not-a-tensor ValueError + (normal.py ~342).""" + bm.random.seed(28) + intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1, method='milstein') + with pytest.raises(ValueError): + intg(jnp.array([1.0]), 0.0, dt=0.01) + + +def test_sde_milstein2_multivariable_runs(): + """Cover the Milstein2 multi-variable drift/diffusion branches (normal.py + 446-468 multi-var paths).""" + bm.random.seed(22) + + def f(x, y, t): + return -x, -y + + def g(x, y, t): + return 0.1 * jnp.ones_like(x), 0.1 * jnp.ones_like(y) + + intg = bp.sdeint(f, g, method='milstein2', intg_type='Ito') + out = intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01) + assert len(out) == 2 + assert all(np.all(np.isfinite(np.asarray(o))) for o in out) + + +def test_sde_milstein2_vector_wiener_known_broadcast_error(): + """NOTE: broken path. Like the gradient Milstein, the derivative-free + Milstein2 VECTOR_WIENER branch (normal.py 495-509) mis-broadcasts the + ``(diffusion_bar - diffusion)`` correction term ``(3, 2)`` against the + summed-noise state ``(3,)`` and raises a broadcasting ``ValueError``. We + exercise the branch up to the failure and pin it. + """ + bm.random.seed(23) + intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1 * jnp.ones((3, 2)), + method='milstein2', wiener_type='vector', intg_type='Ito') + with pytest.raises((ValueError, TypeError)): + intg(jnp.ones(3), 0.0, dt=0.01) + + +def test_sde_exponential_euler_scalar_and_vector(): + """Cover the SDE ExponentialEuler build/integral (normal.py 560-646).""" + bm.random.seed(24) + # scalar wiener + for method in ['exp_euler', 'exponential_euler']: + intg = bp.sdeint(lambda x, t: -x, lambda x, t: jnp.ones_like(x) * 0.1, + method=method) + x = jnp.array([1.0]) + for i in range(3): + x = intg(x, i * 0.01, dt=0.01) + assert np.all(np.isfinite(np.asarray(x))) + # vector wiener + intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1 * jnp.ones((3, 2)), + method='exp_euler', wiener_type='vector') + out = intg(jnp.ones(3), 0.0, dt=0.01) + assert np.asarray(out).shape == (3,) + assert np.all(np.isfinite(np.asarray(out))) + + +def test_sde_exponential_euler_jointeq_multivariable(): + """Cover the ExponentialEuler JointEq build + multi-variable diffusion path + (normal.py _build_integrator recursion, 624-646).""" + bm.random.seed(25) + + def dV(V, t, w): + return -V + w + + def dw(w, t, V): + return -w + V + + def g(V, w, t): + return jnp.ones_like(V) * 0.1, jnp.ones_like(w) * 0.1 + + intg = bp.sdeint(bp.JointEq(dV, dw), g, method='exp_euler') + out = intg(jnp.array([1.0]), jnp.array([0.5]), 0.0, dt=0.01) + assert len(out) == 2 + assert all(np.all(np.isfinite(np.asarray(o))) for o in out) + + +def test_sde_exponential_euler_rejects_stratonovich(): + """Cover the ExponentialEuler Stratonovich NotImplementedError (normal.py 570-573).""" + with pytest.raises(NotImplementedError): + bp.sdeint(lambda x, t: -x, lambda x, t: jnp.ones((1,)) * 0.1, + method='exp_euler', intg_type='Stratonovich') + + +def test_sde_dead_codegen_helpers(): + """NOTE: dead code. ``df_and_dg``, ``dfdt`` and ``noise_terms`` are + module-level code-generation helpers that are never called anywhere in the + package (the live SDE codegen lives in ``srk_scalar.py``). We call them + directly with throwaway lists to cover lines 37-60. + """ + from brainpy.integrators.sde import normal + + lines = [] + normal.df_and_dg(lines, ['V', 'w'], ['t', 'Iext']) + assert any('f(' in ln for ln in lines) + lines2 = [] + normal.dfdt(lines2, ['V', 'w']) + assert any('_dfdt' in ln for ln in lines2) + lines3 = [] + normal.noise_terms(lines3, ['V', 'w']) + assert any('_dW' in ln for ln in lines3) + + +# =========================================================================== # +# jax_multiprocessing.py -- vectorize / parallelize map. +# =========================================================================== # + +def test_jax_vectorize_map_list_input(): + from brainpy.running.jax_multiprocessing import jax_vectorize_map + + def f(a, b): + return a + b + + a = bm.arange(6.0) + b = bm.arange(6.0) * 2 + out = jax_vectorize_map(f, [a, b], num_parallel=3, clear_buffer=False) + assert np.allclose(np.asarray(out), np.asarray(a) + np.asarray(b)) + + +def test_jax_vectorize_map_dict_input_and_clear_buffer(): + from brainpy.running.jax_multiprocessing import jax_vectorize_map + + def f(a, b): + return a * b + + a = bm.arange(6.0) + b = bm.arange(6.0) + 1 + expected = np.asarray(a) * np.asarray(b) + # dict input + clear_buffer=True -> np.asarray / concatenate branch. + # NOTE: clear_buffer=True normally calls ``bm.clear_buffer_memory()``, a + # PROCESS-GLOBAL operation that deletes EVERY live device array (module-level + # constants and persistent Variables in *other* test modules included), + # poisoning the rest of the shared pytest session. We patch it to a no-op so + # the clear_buffer code path is still exercised for coverage without nuking + # the session. + _orig_clear = bm.clear_buffer_memory + bm.clear_buffer_memory = lambda *a, **k: None + try: + out = jax_vectorize_map(f, {'a': a, 'b': b}, num_parallel=2, clear_buffer=True) + finally: + bm.clear_buffer_memory = _orig_clear + assert np.allclose(np.asarray(out), expected) + + +def test_jax_vectorize_map_type_error_and_length_mismatch(): + from brainpy.running.jax_multiprocessing import jax_vectorize_map + + # TypeError: arguments must be sequence or dict (line 60). + with pytest.raises(TypeError): + jax_vectorize_map(lambda a: a, 123, num_parallel=2) + # ValueError: unequal element lengths (line 66-67). + with pytest.raises(ValueError): + jax_vectorize_map(lambda a, b: a + b, + [bm.arange(4.0), bm.arange(3.0)], num_parallel=2) + + +def test_jax_parallelize_map_list_and_dict(): + from brainpy.running.jax_multiprocessing import jax_parallelize_map + + # Default single-CPU pmap can map up to the local device count (1). + def f(a, b): + return a + b + + a = bm.arange(2.0) + b = bm.arange(2.0) * 3 + expected = np.asarray(a) + np.asarray(b) + out_list = jax_parallelize_map(f, [a, b], num_parallel=1, clear_buffer=False) + assert np.allclose(np.asarray(out_list), expected) + + a = bm.arange(2.0) + b = bm.arange(2.0) * 3 + expected = np.asarray(a) + np.asarray(b) + # See note above: patch the process-global buffer wipe to a no-op so the + # clear_buffer=True branch is covered without poisoning the shared session. + _orig_clear = bm.clear_buffer_memory + bm.clear_buffer_memory = lambda *a, **k: None + try: + out_dict = jax_parallelize_map(f, {'a': a, 'b': b}, num_parallel=1, + clear_buffer=True) + finally: + bm.clear_buffer_memory = _orig_clear + assert np.allclose(np.asarray(out_dict), expected) + + +def test_jax_parallelize_map_type_error(): + from brainpy.running.jax_multiprocessing import jax_parallelize_map + + with pytest.raises(TypeError): # line 125 + jax_parallelize_map(lambda a: a, 3.14, num_parallel=1) + with pytest.raises(ValueError): # length mismatch (line 131) + jax_parallelize_map(lambda a, b: a + b, + [bm.arange(2.0), bm.arange(1.0)], num_parallel=1) + + +def test_jax_map_empty_input_returns_none(): + """Cover the ``res_tree is None -> return None`` branch of both maps (the loop + body never executes for an empty task list): jax_multiprocessing 88-89, 155-156.""" + from brainpy.running.jax_multiprocessing import (jax_vectorize_map, + jax_parallelize_map) + + assert jax_vectorize_map(lambda a: a, [bm.zeros((0,))], num_parallel=1) is None + assert jax_parallelize_map(lambda a: a, [bm.zeros((0,))], num_parallel=1) is None + + +# =========================================================================== # +# joint_eq.py -- nested JointEq, derivative evaluation, dead _std_func, +# conflicting-name DiffEqError. +# =========================================================================== # + +def test_jointeq_nested_derivative_evaluation(): + a, b = 0.02, 0.20 + dV = lambda V, t, u, Iext: 0.04 * V * V + 5 * V + 140 - u + Iext + du = lambda u, t, V: a * (b * V - u) + eq = bp.JointEq(dV, du) + + dw = lambda w, t, V: a * (b * V - w) + eq2 = bp.JointEq(eq, dw) # nested JointEq + + # arg_keys collected across nested equations: V, u, w are state variables. + assert eq2.arg_keys[:3] == ['V', 'u', 'w'] + assert 'Iext' in eq2.arg_keys # positional parameter propagated + + # derivative evaluation returns one value per state variable (3). + res = eq2(-65.0, -14.0, -14.0, 0.0, Iext=10.0) + assert len(res) == 3 + assert all(np.isfinite(float(r)) for r in res) + + +def test_jointeq_call_with_keyword_argument(): + def dV(V, t, w, gain=0.5): + return -V + gain * w + + def dw(w, t, V, gain=0.5): + return -w + gain * V + + eq = bp.JointEq([dV, dw]) # list-form exercises _check_eqs recursion + assert 'gain' in eq.kwarg_keys + res = eq(1.0, 2.0, 0.0, gain=0.3) + assert len(res) == 2 + assert all(np.isfinite(float(r)) for r in res) + + +def test_jointeq_conflicting_kwarg_name_with_state_variable(): + """Cover the 'keyword argument conflicts with existing name' DiffEqError + (joint_eq.py lines 189-194).""" + def dV(V, t, w): + return -V + w + + def dw(w, t, V=1.0): # kwarg 'V' reuses the state variable name 'V' + return -w + V + + with pytest.raises(DiffEqError): + bp.JointEq(dV, dw) + + +def test_jointeq_conflicting_kwarg_defaults(): + def dV(V, t, a=1.0): + return -V + a + + def dw(w, t, a=2.0): # same kwarg name, different default + return -w + a + + with pytest.raises(DiffEqError): + bp.JointEq(dV, dw) + + +def test_jointeq_missing_time_variable_and_var_kinds(): + with pytest.raises(ValueError): # no 't' parameter (line 58) + bp.JointEq(lambda V, w: -V) + with pytest.raises(DiffEqError): # *args (VAR_POSITIONAL) + bp.JointEq(lambda V, t, *extra: -V) + with pytest.raises(DiffEqError): # **kwargs (VAR_KEYWORD) + bp.JointEq(lambda V, t, **extra: -V) + with pytest.raises(DiffEqError): # non-callable element + bp.JointEq(123) + + +def test_jointeq_rejects_keyword_only_and_positional_only(): + """Cover the KEYWORD_ONLY (line 40) and POSITIONAL_ONLY (line 43) rejection + branches of ``_get_args``.""" + # KEYWORD_ONLY: parameter after a bare ``*``. + def kw_only(V, t, *, x): + return -V + x + + with pytest.raises(DiffEqError): + bp.JointEq(kw_only) + + # POSITIONAL_ONLY: parameter before ``/`` (needs an exec'd def for the syntax). + ns = {} + exec("def pos_only(V, t, x, /):\n return -V + x", ns) + with pytest.raises(DiffEqError): + bp.JointEq(ns['pos_only']) + + +def test_jointeq_call_with_kwarg_passed_positionally(): + """Cover the ``__call__`` branch where a trailing positional arg maps onto a + keyword key (joint_eq.py line 235).""" + def dV(V, t, w, gain=0.5): + return -V + gain * w + + def dw(w, t, V, gain=0.5): + return -w + gain * V + + eq = bp.JointEq(dV, dw) + # arg_keys = [V, w, t]; passing a 4th positional arg maps onto kwarg 'gain'. + res = eq(1.0, 2.0, 0.0, 0.3) + assert len(res) == 2 + assert all(np.isfinite(float(r)) for r in res) + + +def test_jointeq_std_func_dependency_from_state_vars(): + """Cover the _std_func branch where a positional dependency is resolved from + the state-variable tuple via ``all_vars.index`` (joint_eq.py line 76) and the + branch where the dependency is supplied via keyword (line 72).""" + from brainpy.integrators.joint_eq import _std_func + + def dV(V, t, w): # 'w' is a dependency that lives in all_vars + return -V + w + + wrapper = _std_func(dV, ['V', 'w']) + out = wrapper(0.0, 1.0, 2.0) # w resolved positionally (line 76) + assert np.isfinite(float(out)) + # 'w' supplied as a keyword -> line 72 (par in args_and_kwargs). + out2 = wrapper(0.0, 1.0, 2.0, w=3.0) + assert np.isfinite(float(out2)) + + +def test_jointeq_duplicate_state_variable_error(): + """Cover the duplicate-state-variable DiffEqError branch (joint_eq.py 157).""" + def dV1(V, t): + return -V + + def dV2(V, t): # 'V' reused as a state variable + return -2 * V + + with pytest.raises(DiffEqError): + bp.JointEq(dV1, dV2) + + +def test_jointeq_dead_std_func_helper(): + """NOTE: ``_std_func`` is dead code (defined but never called anywhere in the + package). We invoke it directly to exercise lines 64-82. It builds a wrapper + that re-routes positional state vars / dependency lookups before calling the + underlying derivative function. + """ + from brainpy.integrators.joint_eq import _std_func + + def dV(V, t, w, gain=0.5): + return -V + gain * w + + all_vars = ['V', 'w'] + wrapper = _std_func(dV, all_vars) + # call(t, *vars, **args_and_kwargs): V and w are positional state vars; + # 'w' is a dependency that is looked up from `vars` via all_vars.index. + out = wrapper(0.0, 1.0, 2.0) # V=1.0, w=2.0 + assert np.isfinite(float(out)) + # pass the kwarg explicitly to cover the kwargs branch (lines 77-79). + out2 = wrapper(0.0, 1.0, 2.0, gain=0.9) + assert np.isfinite(float(out2)) + + +def test_jointeq_std_func_missing_dependency_raises(): + """Cover the 'Missing {par}' DiffEqError inside _std_func (line 75).""" + from brainpy.integrators.joint_eq import _std_func + + def dV(V, t, missing_dep): + return -V + missing_dep + + wrapper = _std_func(dV, ['V']) # 'missing_dep' is neither passed nor a var + with pytest.raises(DiffEqError): + wrapper(0.0, 1.0) + + +# =========================================================================== # +# abstract_models.py -- synapse forward passes + add_current / return_info. +# =========================================================================== # + +def _share(t=0.0, dt=0.1, i=0): + bp.share.save(t=t, dt=dt, i=i) + + +def test_dualexponv2_forward_add_current_return_info(): + """DualExponV2 is not covered by the sibling synapse-forward sweep. + Drive update(x) (add_current branch) and return_info (lines 404-437).""" + bm.random.seed(50) + syn = bp.dyn.DualExponV2(3, tau_decay=5.0, tau_rise=1.0) + syn.reset_state() + _share() + spike = bm.asarray([1.0, 0.0, 1.0]) + out = syn.update(spike) # x not None -> add_current branch + assert jnp.asarray(out).shape == (3,) + assert jnp.all(jnp.isfinite(jnp.asarray(out))) + # update with no current (x=None branch) + out_none = syn.update() + assert jnp.all(jnp.isfinite(jnp.asarray(out_none))) + # return_info -> ReturnInfo with a callable (line 436) + info = syn.return_info() + assert info is not None + + +def test_expon_forward_add_current_return_info(): + bm.random.seed(51) + syn = bp.dyn.Expon(3, tau=5.0) + syn.reset_state() + _share() + out = syn.update(bm.ones(3)) # x not None -> add_current + assert jnp.asarray(out).shape == (3,) + out_none = syn.update() # x None branch + assert jnp.all(jnp.isfinite(jnp.asarray(out_none))) + assert syn.return_info() is syn.g + + +def test_synapse_forward_and_return_info_methods(): + """Drive ``update()`` (covering each derivative + the discrete spike jump) + and ``return_info`` for the JointEq-based synapses (abstract_models lines + 547/550/554-556, 727/730/737-741, 793-803, 867-888, plus return_info 292, + 559, 744, 806, 891).""" + bm.random.seed(52) + _share() + spike_bool = bm.asarray([True, False, True], dtype=bool) + spike_float = spike_bool.astype(float) + for cls, kwargs, inp in [ + (bp.dyn.DualExpon, dict(tau_decay=5.0, tau_rise=1.0), spike_bool), + (bp.dyn.Alpha, dict(tau_decay=5.0), spike_bool), + (bp.dyn.NMDA, dict(tau_decay=10.0, tau_rise=2.0), spike_float), + (bp.dyn.STD, dict(tau=20.0), spike_bool), + (bp.dyn.STP, dict(), spike_bool), + ]: + syn = cls(3, **kwargs) + syn.reset_state() + out = syn.update(inp) + assert jnp.asarray(out).shape == (3,) + assert jnp.all(jnp.isfinite(jnp.asarray(out))) + info = syn.return_info() + assert info is not None + + +def test_dualexpon_forward_runs(): + """Exercise DualExpon.update (line 283-289) + return_info (line 292).""" + bm.random.seed(53) + syn = bp.dyn.DualExpon(3, tau_decay=5.0, tau_rise=1.0) + syn.reset_state() + _share() + out = syn.update(bm.asarray([1.0, 0.0, 1.0])) + assert jnp.asarray(out).shape == (3,) + assert jnp.all(jnp.isfinite(jnp.asarray(out))) + assert syn.return_info() is syn.g diff --git a/tests/audit/test_boost_runners_delay.py b/tests/audit/test_boost_runners_delay.py new file mode 100644 index 000000000..580a602e5 --- /dev/null +++ b/tests/audit/test_boost_runners_delay.py @@ -0,0 +1,879 @@ +# -*- coding: utf-8 -*- +"""Audit coverage-boost tests for ``brainpy/runners.py`` and ``brainpy/delay.py``. + +These tests exercise the *uncovered* option matrix of ``DSRunner`` (jit on/off, +progress bar, memory-efficient mode, list/dict/callable monitors, deprecated +``fun_monitors``, ``numpy_mon_after_run``, ``data_first_axis`` = 'T'/'B', nonzero +``t0``, the several input formats, ``shared_args``/dyn args, repeated ``.run()``, +``.predict()`` and ``.reset_state()``) plus all the ``__init__``/``predict`` +validation error paths. + +For ``delay.py`` it drives ``Delay``/``VarDelay``/``DataDelay``/``DelayAccess``, +both ``ROTATE_UPDATE`` and ``CONCAT_UPDATE`` methods, ``register_entry`` by +``time=`` and by ``step=``, ``.at()``/``.retrieve()``/``.update()``, +``init_delay_by_return`` (Variable + ReturnInfo), ``before_t0``-style ``init`` as +array and as callable, and the validation/error branches. + +Sibling audit files already cover basic ``DSRunner`` and ``VarDelay(time=T)``; +here we focus on the previously-uncovered options and code paths so that line +coverage of both modules rises toward >=90%. +""" + +import warnings + +import numpy as np +import pytest + +import brainpy as bp +import brainpy.math as bm +import jax.numpy as jnp + +from brainpy import check +from brainpy.delay import ( + Delay, + VarDelay, + DataDelay, + DelayAccess, + init_delay_by_return, + register_delay_by_return, +) +from brainpy.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE +from brainpy.mixin import ReturnInfo + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +def _net(n=4, **kw): + """A tiny spiking network used across the runner tests.""" + bm.random.seed(123) + return bp.dyn.LifRef(n, **kw) + + +# =========================================================================== +# DSRunner -- monitors +# =========================================================================== + +def test_runner_list_monitors_jit_memory_efficient(): + n = _net() + r = bp.DSRunner(n, monitors=['V', 'spike'], jit=True, + progress_bar=False, memory_efficient=True) + out = r.run(2.0) # 20 steps @ dt=0.1 + assert r.mon['V'].shape == (20, 4) + assert r.mon['spike'].shape == (20, 4) + assert r.mon['ts'].shape == (20,) + # memory-efficient mode forces numpy monitors + assert isinstance(r.mon['V'], np.ndarray) + + +def test_runner_list_monitors_with_index(): + n = _net() + r = bp.DSRunner(n, monitors=[('V', 0), ('spike', [1, 2])], + progress_bar=False) + r.run(1.0) + assert np.asarray(r.mon['V']).shape == (10, 1) + assert np.asarray(r.mon['spike']).shape == (10, 2) + + +def test_runner_dict_monitors_variants(): + """dict monitors: explicit var, (var, idx) tuple, and a callable.""" + n = _net() + r = bp.DSRunner( + n, + monitors={'v0': (n.V, 0), 'vall': n.V, 'fcb': lambda: n.V[:2]}, + jit=False, + progress_bar=False, + numpy_mon_after_run=False, # keep jax arrays + data_first_axis='T', + ) + r.run(1.0) + assert np.asarray(r.mon['v0']).shape == (10, 1) + assert np.asarray(r.mon['vall']).shape == (10, 4) + assert np.asarray(r.mon['fcb']).shape == (10, 2) + # numpy_mon_after_run=False -> ts is a jax array + assert isinstance(bm.as_jax(r.mon['ts']), jnp.ndarray) + + +def test_runner_fun_monitors_deprecated_path(): + n = _net() + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + r = bp.DSRunner(n, fun_monitors={'sp': lambda: n.spike[:2]}, + progress_bar=False) + r.run(1.0) + assert np.asarray(r.mon['sp']).shape == (10, 2) + + +def test_runner_progress_bar_true(): + """progress_bar=True exercises the brainstate pbar branch.""" + n = _net() + r = bp.DSRunner(n, monitors=['V'], progress_bar=True, jit=True) + r.run(1.0) + assert r.mon['V'].shape == (10, 4) + + +def test_runner_t0_nonzero_shifts_time_axis(): + n = _net() + r = bp.DSRunner(n, monitors=['V'], progress_bar=False, t0=5.0) + r.run(1.0) + assert float(r.mon['ts'][0]) == pytest.approx(5.0) + assert float(r.mon['ts'][1]) == pytest.approx(5.1) + + +# =========================================================================== +# DSRunner -- inputs +# =========================================================================== + +def test_runner_input_fix_tuple(): + n = _net() + r = bp.DSRunner(n, monitors=['V'], inputs=(n.V, 1.0), progress_bar=False) + r.run(1.0) + assert r.mon['V'].shape == (10, 4) + + +def test_runner_input_iter_array(): + n = _net() + arr = np.ones((10, 4)) * 0.5 + r = bp.DSRunner(n, monitors=['V'], inputs=(n.V, arr, 'iter'), + progress_bar=False) + r.run(1.0) + assert r.mon['V'].shape == (10, 4) + + +def test_runner_input_iter_generator(): + """A non-array iterable goes through the ``next(...)`` path.""" + n = _net() + r = bp.DSRunner(n, monitors=['V'], inputs=(n.V, [0.1] * 20, 'iter'), + progress_bar=False) + r.run(1.0) + assert r.mon['V'].shape == (10, 4) + + +def test_runner_input_func(): + n = _net() + r = bp.DSRunner(n, monitors=['V'], inputs=(n.V, lambda: 0.3, 'func', '+'), + progress_bar=False) + r.run(1.0) + assert r.mon['V'].shape == (10, 4) + + +def test_runner_input_string_target_assign_op(): + """String target ('V') with the '=' operation (relative/absolute access).""" + n = _net() + r = bp.DSRunner(n, monitors=['V'], inputs=('V', 0.1, 'fix', '='), + progress_bar=False) + r.run(1.0) + assert r.mon['V'].shape == (10, 4) + + +def test_runner_callable_inputs(): + n = _net() + + def fin(): + n.V += 0.1 + + r = bp.DSRunner(n, monitors=['V'], inputs=fin, progress_bar=False) + r.run(1.0) + assert r.mon['V'].shape == (10, 4) + + +def test_runner_multiple_inputs_and_ops(): + """Several inputs with different operations in one runner.""" + n = _net() + r = bp.DSRunner( + n, + monitors=['V'], + inputs=[(n.V, 0.2, 'fix', '+'), + ('V', 0.01, 'fix', '*')], + progress_bar=False, + ) + r.run(1.0) + assert r.mon['V'].shape == (10, 4) + + +# =========================================================================== +# DSRunner -- predict / run / reset_state +# =========================================================================== + +def test_predict_with_xs_array_nonbatching(): + n = _net() + r = bp.DSRunner(n, monitors=['V'], progress_bar=False) + xs = np.ones((15, 4)) * 0.4 + out = r.predict(inputs=xs) + assert np.asarray(out).shape == (15, 4) + assert r.mon['ts'].shape == (15,) + + +def test_predict_eval_time_and_reset_state_arg(): + n = _net() + r = bp.DSRunner(n, monitors=['V'], progress_bar=False) + running_time, out = r.predict(1.0, reset_state=True, eval_time=True) + assert isinstance(running_time, float) + assert np.asarray(out).shape == (10, 4) + + +def test_runner_reset_state_method(): + n = _net() + r = bp.DSRunner(n, monitors=['V'], progress_bar=False) + r.run(1.0) + assert r.i0 == 10 + r.reset_state() + assert r.i0 == 0 + + +def test_runner_repeated_run_accumulates_i0(): + n = _net() + r = bp.DSRunner(n, monitors=['V'], progress_bar=False) + r.run(1.0) + assert r.i0 == 10 + r.run(1.0) + assert r.i0 == 20 + + +def test_runner_shared_args_dyn_args(): + n = _net() + r = bp.DSRunner(n, monitors=['V'], progress_bar=False) + r.predict(1.0, shared_args={'fit': False}) + assert r.mon['V'].shape == (10, 4) + + +def test_runner_call_dunder(): + n = _net() + r = bp.DSRunner(n, monitors=['V'], progress_bar=False) + out = r(1.0) # __call__ + assert np.asarray(out).shape == (10, 4) + + +def test_runner_repr(): + n = _net() + r = bp.DSRunner(n, monitors=['V'], progress_bar=False) + s = repr(r) + assert 'DSRunner' in s and 'data_first_axis' in s + + +def test_predict_duration_and_inputs_warns(): + """Providing both duration and inputs warns and uses inputs' time axis.""" + n = _net() + r = bp.DSRunner(n, monitors=['V'], progress_bar=False) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + r.predict(1.0, inputs=np.ones((10, 4)) * 0.3) + assert r.mon['ts'].shape == (10,) + + +# =========================================================================== +# DSRunner -- batching mode + data_first_axis='B' +# =========================================================================== + +def test_runner_batching_mode_first_axis_B(): + bm.random.seed(7) + n = bp.dyn.LifRef(4, mode=bm.batching_mode) + n.reset(batch_size=3) + r = bp.DSRunner(n, monitors=['V'], progress_bar=False, data_first_axis='B') + xs = np.ones((3, 12, 4)) * 0.4 # (batch, time, feat) + out = r.predict(inputs=xs, reset_state=True) + assert np.asarray(out).shape == (3, 12, 4) + assert r.mon['V'].shape == (3, 12, 4) + + +def test_runner_batching_default_data_first_axis(): + """A BatchingMode target defaults data_first_axis to 'B'.""" + n = bp.dyn.LifRef(4, mode=bm.batching_mode) + r = bp.DSRunner(n, monitors=['V'], progress_bar=False) + assert r.data_first_axis == 'B' + + +# =========================================================================== +# DSRunner -- validation / error branches +# =========================================================================== + +def test_runner_bad_target_type(): + from brainpy._errors import RunningError + with pytest.raises(RunningError): + bp.DSRunner(object()) + + +def test_runner_bad_monitors_type(): + with pytest.raises(Exception): # MonitorError + bp.DSRunner(_net(), monitors=123) + + +def test_runner_bad_inputs_type(): + from brainpy._errors import RunningError + with pytest.raises(RunningError): + bp.DSRunner(_net(), inputs=123) + + +def test_runner_bad_input_structure(): + from brainpy._errors import RunningError + # first element neither str nor Variable and not a list/tuple + with pytest.raises(RunningError): + bp.DSRunner(_net(), inputs=[1.0, 2.0]) + + +def test_runner_bad_input_length(): + from brainpy._errors import RunningError + with pytest.raises(RunningError): + bp.DSRunner(_net(), inputs=(_net().V,)) # length 1 + + +def test_runner_bad_input_op(): + from brainpy._errors import RunningError + with pytest.raises(RunningError): + bp.DSRunner(_net(), inputs=('V', 1.0, 'fix', '^')) + + +def test_runner_bad_input_type_str(): + from brainpy._errors import RunningError + with pytest.raises(RunningError): + bp.DSRunner(_net(), inputs=('V', 1.0, 'bogus')) + + +def test_runner_iter_value_not_iterable(): + with pytest.raises(ValueError): + bp.DSRunner(_net(), inputs=('V', 5, 'iter')) + + +def test_runner_func_value_not_callable(): + with pytest.raises(ValueError): + bp.DSRunner(_net(), inputs=('V', 5, 'func')) + + +def test_runner_bad_input_target_attr(): + with pytest.raises(AttributeError): + bp.DSRunner(_net(), inputs=('nonexist.attr', 1.0)) + + +def test_runner_nonvar_nonstr_target(): + from brainpy._errors import RunningError + with pytest.raises(RunningError): + bp.DSRunner(_net(), inputs=(123, 1.0)) + + +def test_runner_memory_efficient_requires_numpy_mon(): + with pytest.raises(ValueError): + bp.DSRunner(_net(), memory_efficient=True, numpy_mon_after_run=False) + + +def test_predict_without_duration_or_inputs(): + n = _net() + r = bp.DSRunner(n, progress_bar=False) + with pytest.raises(ValueError): + r.predict() + + +def test_runner_bad_dt_type(): + from brainpy._errors import RunningError + with pytest.raises(RunningError): + bp.DSRunner(_net(), dt=1) # int, not float + + +# =========================================================================== +# delay.py -- VarDelay ROTATE_UPDATE +# =========================================================================== + +def _drive_delay(delay, target_var, n_steps, value_fn=None): + """Run a delay update loop with the proper shared-arg context.""" + dt = bm.get_dt() + for i in range(n_steps): + bp.share.save(i=i, t=i * dt, dt=dt) + v = (value_fn(i) if value_fn is not None + else bm.ones_like(target_var.value) * i) + target_var.value = v + delay.update() + + +def test_vardelay_rotate_register_time_and_step(): + bm.random.seed(0) + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=2.0) # 20 steps capacity + d.register_entry('by_time', delay_time=1.0) # -> 10 steps + d.register_entry('by_step', delay_step=5) + d.register_entry('zero', delay_time=None) # zero-delay -> target value + assert d._registered_entries['by_time'] == 10 + assert d._registered_entries['by_step'] == 5 + assert d._registered_entries['zero'] is None + + _drive_delay(d, v, 25) + # at the zero entry returns the current target value + np.testing.assert_allclose(np.asarray(d.at('zero')), 24.0) + # delayed entries return earlier values + np.testing.assert_allclose(np.asarray(d.at('by_time')), 15.0) + np.testing.assert_allclose(np.asarray(d.at('by_step', 0)), 20.0) + # direct retrieve + np.testing.assert_allclose(np.asarray(d.retrieve(3)), 22.0) + + +def test_vardelay_at_with_indices(): + bm.random.seed(0) + v = bm.Variable(bm.zeros(4)) + d = VarDelay(v, time=1.0) + d.register_entry('e', delay_step=5) + _drive_delay(d, v, 12) + out = d.at('e', 1) + assert np.asarray(out).shape == () + + +def test_vardelay_init_array_and_callable(): + """``init`` provided as array and as callable (before_t0-style data).""" + v = bm.Variable(bm.zeros(3)) + d_arr = VarDelay(v, time=0.5, init=jnp.ones((5, 3)) * 2.0) + assert np.asarray(d_arr.data.value).sum() == pytest.approx(2.0 * 15) + + v2 = bm.Variable(bm.zeros(3)) + d_call = VarDelay(v2, time=0.5, + init=lambda shape, dtype: jnp.ones(shape, dtype) * 3.0) + assert np.asarray(d_call.data.value).sum() == pytest.approx(3.0 * 15) + + +def test_vardelay_reset_state_and_repr(): + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=1.0, init=5.0) + d.register_entry('e', delay_step=3) + _drive_delay(d, v, 5) + d.reset_state() + # after reset the buffer is re-initialised to init=5.0 + np.testing.assert_allclose(np.asarray(d.data.value), 5.0) + assert 'VarDelay' in repr(d) + assert d.delay_target_shape == (3,) + + +def test_vardelay_register_entry_array_delay_time(): + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=1.0) + d.register_entry('arr', delay_time=jnp.asarray(0.5)) + assert d._registered_entries['arr'] == 5 + + +# =========================================================================== +# delay.py -- VarDelay CONCAT_UPDATE +# =========================================================================== + +def test_vardelay_concat_update_multi_step(): + bm.random.seed(0) + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=1.0, method=CONCAT_UPDATE) + assert d.method == CONCAT_UPDATE + d.register_entry('a', delay_step=5) + _drive_delay(d, v, 12) + np.testing.assert_allclose(np.asarray(d.at('a')), 7.0) + np.testing.assert_allclose(np.asarray(d.retrieve(3)), 9.0) + + +def test_vardelay_concat_update_single_step(): + """max_length==1 hits the special-case concat branch.""" + bm.random.seed(0) + v = bm.Variable(bm.zeros(2)) + d = VarDelay(v, time=0.1, method=CONCAT_UPDATE) # length 1 + d.register_entry('b', delay_step=1) + _drive_delay(d, v, 3) + assert np.asarray(d.at('b')).shape == (2,) + + +# =========================================================================== +# delay.py -- DataDelay +# =========================================================================== + +def test_datadelay_update_retrieve_reset(): + bm.random.seed(0) + data = bm.Variable(bm.zeros(3)) + dd = DataDelay(data, data_init=bm.zeros, time=0.5) + dd.register_entry('c', delay_step=3) + dt = bm.get_dt() + for i in range(8): + bp.share.save(i=i, t=i * dt, dt=dt) + dd.update(bm.ones(3) * i) + np.testing.assert_allclose(np.asarray(dd.at('c')), 5.0) + dd.reset_state() + assert np.asarray(dd.data.value).sum() == pytest.approx(0.0) + + +def test_datadelay_reset_state_with_batch(): + bm.random.seed(0) + data = bm.Variable(bm.zeros((1, 3)), batch_axis=0) + dd = DataDelay(data, data_init=bm.zeros, time=0.5) + dd.register_entry('c', delay_step=2) + dd.reset_state(batch_size=1) + assert dd.data is not None + + +# =========================================================================== +# delay.py -- DelayAccess +# =========================================================================== + +def test_delay_access_update_and_reset(): + bm.random.seed(0) + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=1.0) + acc = DelayAccess(d, time=0.5, delay_entry='myacc') + _drive_delay(d, v, 12) + np.testing.assert_allclose(np.asarray(acc.update()), 7.0) + acc.reset_state() # no-op, just exercise the branch + + +def test_delay_access_with_indices(): + bm.random.seed(0) + v = bm.Variable(bm.zeros(4)) + d = VarDelay(v, time=1.0) + acc = DelayAccess(d, 0.3, 0, delay_entry='idxacc') + _drive_delay(d, v, 10) + assert np.asarray(acc.update()).shape == () + + +# =========================================================================== +# delay.py -- init_delay_by_return +# =========================================================================== + +def test_init_delay_by_return_variable(): + v = bm.Variable(bm.zeros(4)) + d = init_delay_by_return(v) + assert isinstance(d, VarDelay) + + +def test_init_delay_by_return_returninfo_nonbatching(): + ri = ReturnInfo(size=(4,), batch_or_mode=bm.NonBatchingMode(), data=bm.zeros) + d = init_delay_by_return(ri) + assert isinstance(d, DataDelay) + assert d.target.shape == (4,) + + +def test_init_delay_by_return_returninfo_int_batch(): + ri = ReturnInfo(size=(4,), batch_or_mode=2, data=bm.zeros) + d = init_delay_by_return(ri) + assert isinstance(d, DataDelay) + assert d.target.shape == (2, 4) + + +def test_init_delay_by_return_returninfo_batchingmode(): + ri = ReturnInfo(size=(4,), batch_or_mode=bm.BatchingMode(3), data=bm.zeros) + d = init_delay_by_return(ri) + assert isinstance(d, DataDelay) + assert d.target.shape == (3, 4) + + +def test_init_delay_by_return_returninfo_array_data(): + ri = ReturnInfo(size=(4,), batch_or_mode=bm.NonBatchingMode(), + data=jnp.ones((4,))) + d = init_delay_by_return(ri) + assert isinstance(d, DataDelay) + + +# =========================================================================== +# delay.py -- validation / error branches +# =========================================================================== + +def test_vardelay_bad_target(): + with pytest.raises(ValueError): + VarDelay(bm.zeros(3), time=1.0) # plain array, not a Variable + + +def test_delay_bad_time_type(): + with pytest.raises(TypeError): + VarDelay(bm.Variable(bm.zeros(3)), time='oops') + + +def test_vardelay_duplicate_entry(): + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=1.0) + d.register_entry('e') + with pytest.raises(KeyError): + d.register_entry('e', delay_step=2) + + +def test_vardelay_at_missing_entry(): + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=1.0) + with pytest.raises(KeyError): + d.at('does-not-exist') + + +def test_get_delay_both_time_and_step(): + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=1.0) + with pytest.raises(AssertionError): + d.register_entry('z', delay_time=0.5, delay_step=3) + + +def test_init_delay_by_return_bad_type(): + with pytest.raises(TypeError): + init_delay_by_return(123) + + +def test_init_delay_by_return_returninfo_bad_data(): + ri = ReturnInfo(size=(4,), batch_or_mode=bm.NonBatchingMode(), data=123) + with pytest.raises(TypeError): + init_delay_by_return(ri) + + +def test_base_delay_register_entry_not_implemented(): + d = Delay(time=1.0) + with pytest.raises(NotImplementedError): + d.register_entry('x', delay_step=1) + + +def test_base_delay_at_not_implemented(): + d = Delay(time=1.0) + with pytest.raises(NotImplementedError): + d.at('x') + + +def test_base_delay_retrieve_not_implemented(): + d = Delay(time=1.0) + with pytest.raises(NotImplementedError): + d.retrieve(1) + + +# =========================================================================== +# Extra coverage: input ops, multi-leaf inputs, memory-efficient w/o jit +# =========================================================================== + +def test_runner_input_ops_minus_mul_div(): + """Exercise the '-', '*', '/' branches of ``_f_ops``.""" + n = _net(3) + r = bp.DSRunner( + n, + monitors=['V'], + inputs=[('V', 0.1, 'fix', '-'), + ('V', 1.0, 'fix', '*'), + ('V', 2.0, 'fix', '/')], + progress_bar=False, + ) + r.run(0.5) + assert r.mon['V'].shape == (5, 3) + + +def test_runner_input_func_with_shared_arg(): + """Input callable that *accepts* a shared argument (deprecated bind path).""" + n = _net(3) + + def fin(shared): + n.V += 0.1 + + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + r = bp.DSRunner(n, monitors=['V'], inputs=fin, progress_bar=False) + r.run(0.5) + assert r.mon['V'].shape == (5, 3) + + +def test_runner_memory_efficient_without_jit(): + n = _net(3) + r = bp.DSRunner(n, monitors=['V'], jit=False, progress_bar=False, + memory_efficient=True) + r.run(0.5) + assert r.mon['V'].shape == (5, 3) + assert isinstance(r.mon['V'], np.ndarray) + + +def _TwoPop(): + class Two(bp.DynamicalSystem): + def __init__(self): + super().__init__() + self.a = bp.dyn.LifRef(3) + self.b = bp.dyn.LifRef(3) + + def update(self, x, y): + s1 = self.a(x) + self.b(y) + return s1 + + return Two() + + +def test_runner_multi_leaf_inputs_time_step(): + """A 2-leaf input tuple exercises the multi-leaf time-step branch.""" + net = _TwoPop() + r = bp.DSRunner(net, monitors={'va': net.a.V}, progress_bar=False) + xs = (np.ones((12, 3)) * 0.4, np.ones((12, 3)) * 0.2) + out = r.predict(inputs=xs) + assert r.mon['ts'].shape == (12,) + assert np.asarray(out).shape == (12, 3) + + +# =========================================================================== +# Extra coverage: delay constructor entries, growth, batching, modes +# =========================================================================== + +def test_vardelay_entries_in_constructor(): + """The ``entries=`` constructor argument registers entries (264-265).""" + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=2.0, entries={'e1': 1.0, 'e2': 0.5}) + assert d._registered_entries['e1'] == 10 + assert d._registered_entries['e2'] == 5 + + +def test_vardelay_register_entry_grows_max_length(): + """Registering an entry larger than current capacity grows the buffer.""" + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=0.5) # 5 steps + assert d.max_length == 5 + d.register_entry('big', delay_step=10) + assert d.max_length == 10 + # data buffer was reallocated to the new length + assert d.data.value.shape[0] == 10 + + +def test_vardelay_training_mode_uses_concat(): + """A TrainingMode delay defaults to CONCAT_UPDATE (lines 95-96).""" + v = bm.Variable(bm.zeros((2, 3)), batch_axis=0) + d = VarDelay(v, time=0.5, mode=bm.training_mode) + assert d.method == CONCAT_UPDATE + + +def test_vardelay_batching_mode_at_with_index(): + """BatchingMode delay: ``.at`` inserts a slice at the batch axis (319-321).""" + bm.random.seed(0) + v = bm.Variable(bm.zeros((2, 3)), batch_axis=0) + d = VarDelay(v, time=0.5, mode=bm.batching_mode) + d.register_entry('e', delay_step=3) + dt = bm.get_dt() + for i in range(6): + bp.share.save(i=i, t=i * dt, dt=dt) + v.value = bm.ones((2, 3)) * i + d.update() + assert np.asarray(d.at('e')).shape == (2, 3) + # indexing into the feature axis keeps the batch dim + assert np.asarray(d.at('e', 0)).shape == (2,) + + +def test_vardelay_retrieve_under_checking(): + """With checking enabled, ``retrieve`` runs the ``jit_error`` guard path.""" + was_checking = check.is_checking() + check.turn_on() + try: + bm.random.seed(0) + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=0.5) + d.register_entry('e', delay_step=3) + dt = bm.get_dt() + for i in range(6): + bp.share.save(i=i, t=i * dt, dt=dt) + v.value = bm.ones(3) * i + d.update() + # the checked retrieve path returns a concrete delayed value + np.testing.assert_allclose(np.asarray(d.at('e')), 3.0) + finally: + if not was_checking: + check.turn_off() + + +def test_vardelay_unknown_update_method_raises(): + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=0.5) + d.register_entry('e', delay_step=2) + d.method = 'bogus' + bp.share.save(i=0, t=0.0, dt=bm.get_dt()) + with pytest.raises(ValueError): + d.update() + + +def test_vardelay_unknown_method_retrieve_raises(): + v = bm.Variable(bm.zeros(3)) + d = VarDelay(v, time=0.5) + d.register_entry('e', delay_step=2) + # populate first with a valid method + dt = bm.get_dt() + for i in range(4): + bp.share.save(i=i, t=i * dt, dt=dt) + v.value = bm.ones(3) * i + d.update() + d.method = 'bogus' + bp.share.save(i=4, t=0.4, dt=dt) + with pytest.raises(ValueError): + d.retrieve(2) + + +def test_register_delay_by_return_reuses_instance(): + """``register_delay_by_return`` adds + reuses an after-update delay.""" + n = bp.dyn.LifRef(4) + d1 = register_delay_by_return(n) + d2 = register_delay_by_return(n) + assert isinstance(d1, Delay) + assert d1 is d2 # second call reuses the registered instance + + +def test_runner_fun_inputs_deprecated(): + """The deprecated ``fun_inputs`` argument still drives inputs (367, 562).""" + n = _net(3) + + def finp(shared): + n.V += 0.1 + + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + r = bp.DSRunner(n, monitors=['V'], fun_inputs=finp, progress_bar=False) + r.run(0.5) + assert r.mon['V'].shape == (5, 3) + + +def test_vardelay_zero_delay_at_with_index(): + """A zero-delay entry returns the *indexed* current target value (325).""" + v = bm.Variable(bm.zeros(4)) + d = VarDelay(v, time=1.0) + d.register_entry('zero', delay_time=None) + bp.share.save(i=0, t=0.0, dt=bm.get_dt()) + v.value = bm.arange(4).astype(bm.float_) + assert float(np.asarray(d.at('zero', 1))) == pytest.approx(1.0) + + +def test_vardelay_axis_names_sharding(): + """A target with axis_names builds a sharded delay buffer (244-246).""" + v = bm.Variable(bm.zeros(4), axis_names=('feat',)) + d = VarDelay(v, time=0.5) + assert d.data.value.shape == (5, 4) + + +def test_init_delay_by_return_returninfo_axis_names(): + """ReturnInfo carrying axis_names builds a DataDelay (asserts ndim, 559).""" + ri = ReturnInfo(size=(4,), batch_or_mode=bm.NonBatchingMode(), + data=bm.zeros, axis_names=('feat',)) + d = init_delay_by_return(ri) + assert isinstance(d, DataDelay) + + +# =========================================================================== +# Extra coverage: input formatting / _f_ops helpers (direct unit tests) +# =========================================================================== + +def test_check_and_format_inputs_none(): + """``inputs=None`` produces empty formatted-input buckets (line 79).""" + from brainpy.runners import check_and_format_inputs + n = _net(3) + res = check_and_format_inputs(n, None) + assert set(res) == {'fixed', 'iterated', 'functional', 'array'} + # everything empty + assert all(len(lst) == 0 for d in res.values() for lst in d.values()) + + +def test_check_and_format_inputs_nonstr_target(): + """A non-str / non-Variable target raises in absolute access (line 122).""" + from brainpy.runners import check_and_format_inputs + from brainpy._errors import RunningError + n = _net(3) + with pytest.raises(RunningError): + check_and_format_inputs(n, [(123, 1.0)]) + + +def test_f_ops_unknown_operation_raises(): + """``_f_ops`` with an unknown operation raises (line 207).""" + from brainpy.runners import _f_ops + n = _net(3) + with pytest.raises(ValueError): + _f_ops('^', n.V, 1.0) + + +def test_f_ops_all_supported_operations(): + """Each supported ``_f_ops`` branch runs without error.""" + from brainpy.runners import _f_ops + v = bm.Variable(bm.ones(3)) + _f_ops('=', v, 2.0) + np.testing.assert_allclose(np.asarray(v.value), 2.0) + _f_ops('+', v, 1.0) + np.testing.assert_allclose(np.asarray(v.value), 3.0) + _f_ops('-', v, 1.0) + np.testing.assert_allclose(np.asarray(v.value), 2.0) + _f_ops('*', v, 2.0) + np.testing.assert_allclose(np.asarray(v.value), 4.0) + _f_ops('/', v, 4.0) + np.testing.assert_allclose(np.asarray(v.value), 1.0) diff --git a/tests/audit/test_dnn_toolbox_fixes.py b/tests/audit/test_dnn_toolbox_fixes.py new file mode 100644 index 000000000..ded4f0f50 --- /dev/null +++ b/tests/audit/test_dnn_toolbox_fixes.py @@ -0,0 +1,764 @@ +# -*- coding: utf-8 -*- +"""Regression + coverage tests for the DNN / toolbox audit fixes (2026-06-18). + +This module tests the fixes recorded in ``docs/issues-found-20260618.md`` for the +following source files: + +* ``brainpy/dnn/normalization.py`` -- C-05, H-51, M-25/M-26 (GroupNorm/Instance + group axis, BatchNorm/LayerNorm/GroupNorm + constructing + running under the default mode). +* ``brainpy/losses/comparison.py`` -- C-02 (``nll_loss`` sign), C-03 (class-weighted + cross-entropy), H-53 (``ignore_index`` / + ``label_smoothing``). +* ``brainpy/optim/optimizer.py`` -- C-01 / H-52 (Adam/AdamW bias-correction step + counter), M-29 (``SM3`` instantiable with + train vars). +* ``brainpy/optim/scheduler.py`` -- C-04 (``MultiStepLR`` decay), M-01 (other LR + schedulers). +* ``brainpy/encoding/stateless_encoding.py`` -- ``PoissonEncoder.single_step`` no longer + crashes / passes wrong args. +* ``brainpy/connect/random_conn.py`` -- M-30 (``FixedProb`` empty/rectangular build). + +In addition to the targeted regression checks, the module exercises every optimizer, +scheduler, loss function, normalization layer and stateless/stateful encoder for +coverage. Known *remaining* bugs (``Adan.update``, ``ctc_loss``'s ``.value`` access) +are pinned with ``pytest.raises`` so the lines stay covered and the behaviour is +documented rather than silently broken. +""" + +import numpy as np +import jax.numpy as jnp +import pytest + +import brainpy as bp +import brainpy.math as bm +from brainpy.context import share +from brainpy.optim import optimizer as O +from brainpy.optim import scheduler as S +from brainpy.losses import comparison as C +from brainpy.optim.optimizer import SM3 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _np(x): + return np.asarray(bm.as_jax(x)) + + +def _train_var(shape=(2, 3), fill=1.0): + return bm.Variable(bm.ones(shape) * fill) + + +# =========================================================================== +# 1. normalization.py -- GroupNorm / InstanceNorm / BatchNorm / LayerNorm +# =========================================================================== + +def test_groupnorm_respects_groups(): + """C-05: GroupNorm(3,6) must differ from GroupNorm(1,6); num_groups has effect.""" + bm.random.seed(0) + x = bm.random.randn(4, 6) + g3 = _np(bp.dnn.GroupNorm(3, 6, affine=False)(x)) + g1 = _np(bp.dnn.GroupNorm(1, 6, affine=False)(x)) + g6 = _np(bp.dnn.GroupNorm(6, 6, affine=False)(x)) + assert not np.allclose(g3, g1), "GroupNorm(3,6) must differ from GroupNorm(1,6)" + assert not np.allclose(g3, g6), "GroupNorm(3,6) must differ from GroupNorm(6,6)" + + +def test_groupnorm_each_group_zero_mean_unit_var(): + """C-05: each group is normalized independently to ~zero-mean / unit-var.""" + bm.random.seed(1) + x = bm.random.randn(4, 6) * 3.0 + 5.0 + y = _np(bp.dnn.GroupNorm(3, 6, affine=False)(x)) + # 6 channels / 3 groups -> 2 channels per group. + grouped = y.reshape(4, 3, 2) + assert np.allclose(grouped.mean(axis=-1), 0.0, atol=1e-4) + assert np.allclose(grouped.var(axis=-1), 1.0, atol=1e-2) + + +def test_instancenorm_per_channel_unit_var(): + """C-05: InstanceNorm normalizes each channel independently over the spatial axis.""" + bm.random.seed(2) + x = bm.random.randn(4, 8, 6) * 3.0 + 2.0 + y = _np(bp.dnn.InstanceNorm(6, affine=False)(x)) + # Normalized over the spatial axis (axis=1) per (sample, channel). + assert np.allclose(y.std(axis=1), 1.0, atol=1e-2) + assert np.allclose(y.mean(axis=1), 0.0, atol=1e-4) + + +def test_norm_layers_construct_and_run_default_mode(): + """H-51: BatchNorm1d / LayerNorm / GroupNorm / InstanceNorm construct under the + default (non-training) mode and run a forward pass without raising.""" + x3 = bm.random.randn(2, 4, 3) + assert _np(bp.dnn.BatchNorm1d(num_features=3)(x3)).shape == (2, 4, 3) + assert _np(bp.dnn.LayerNorm(3)(x3)).shape == (2, 4, 3) + x2 = bm.random.randn(2, 6) + assert _np(bp.dnn.GroupNorm(3, 6)(x2)).shape == (2, 6) + assert _np(bp.dnn.InstanceNorm(6)(x2)).shape == (2, 6) + + +def test_norm_layers_affine_and_aliases(): + """Affine path + dimensional aliases (BatchNorm*D) construct and run.""" + x3 = bm.random.randn(2, 4, 3) + assert _np(bp.dnn.BatchNorm1D(num_features=3, affine=True)(x3)).shape == (2, 4, 3) + assert _np(bp.dnn.LayerNorm(3, elementwise_affine=True)(x3)).shape == (2, 4, 3) + assert _np(bp.dnn.GroupNorm(3, 6, affine=True)(bm.random.randn(2, 6))).shape == (2, 6) + + +def test_batchnorm_2d_3d_fit_and_eval(): + """BatchNorm2d/3d run in both 'fit' (update running stats) and 'eval' modes.""" + x4 = bm.random.randn(2, 4, 4, 3) + x5 = bm.random.randn(2, 4, 4, 4, 3) + share.save(fit=True) + try: + assert _np(bp.dnn.BatchNorm2d(num_features=3)(x4)).shape == (2, 4, 4, 3) + assert _np(bp.dnn.BatchNorm3d(num_features=3)(x5)).shape == (2, 4, 4, 4, 3) + # axis_name=None path with affine + assert _np(bp.dnn.BatchNorm2D(num_features=3, affine=True)(x4)).shape == (2, 4, 4, 3) + finally: + share.save(fit=False) + # eval mode: uses running stats + assert _np(bp.dnn.BatchNorm2d(num_features=3)(x4)).shape == (2, 4, 4, 3) + assert _np(bp.dnn.BatchNorm3D(num_features=3)(x5)).shape == (2, 4, 4, 4, 3) + + +def test_batchnorm_input_dim_check(): + """The _check_input_dim guards raise on wrong ndim.""" + with pytest.raises(ValueError): + bp.dnn.BatchNorm1d(num_features=3)(bm.random.randn(2, 3)) # needs 3D + with pytest.raises(ValueError): + bp.dnn.BatchNorm2d(num_features=3)(bm.random.randn(2, 4, 3)) # needs 4D + with pytest.raises(ValueError): + bp.dnn.BatchNorm3d(num_features=3)(bm.random.randn(2, 4, 4, 3)) # needs 5D + + +def test_groupnorm_invalid_channels(): + """num_channels must be divisible by num_groups.""" + with pytest.raises(ValueError): + bp.dnn.GroupNorm(4, 6) + + +def test_layernorm_shape_mismatch_raises(): + """M-26: LayerNorm raises on a normalized-shape mismatch. + + NOTE: the intended ``ValueError`` message is itself mis-built + (``", ".join(self.normalized_shape)`` joins ints, normalization.py:536), so + a ``TypeError`` surfaces first. Either way the layer rejects the bad shape; + we only assert that *some* exception is raised.""" + with pytest.raises(Exception): + bp.dnn.LayerNorm(5)(bm.random.randn(2, 4, 3)) + + +# =========================================================================== +# 2. comparison.py -- loss functions +# =========================================================================== + +def test_nll_loss_is_positive(): + """C-02: nll_loss returns the *negative* log-likelihood (positive number).""" + log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])) + loss = float(C.nll_loss(log_probs, jnp.array([0, 1]))) + assert loss > 0.0 + assert abs(loss - 0.2899) < 1e-3 + + +def test_nll_loss_reductions(): + log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])) + tgt = jnp.array([0, 1]) + assert float(C.nll_loss(log_probs, tgt, reduction='sum')) > 0 + assert _np(C.nll_loss(log_probs, tgt, reduction='none')).shape == (2,) + assert _np(C.nll_loss(log_probs, tgt, reduction=None)).shape == (2,) + with pytest.raises(ValueError): + C.nll_loss(log_probs, tgt, reduction='bogus') + # NLLLoss class + assert float(C.NLLLoss()(log_probs, tgt)) > 0 + + +def test_cross_entropy_weight_by_class(): + """C-03: class weight is applied by target class, not by sample index.""" + out = _np(C.cross_entropy_loss(bm.zeros((3, 3)), [2, 2, 2], + weight=[10, 20, 1], reduction='none')) + # all targets are class 2 -> all per-sample losses equal (w[2] * log(3)). + assert np.allclose(out, out[0]) + assert abs(out[0] - 1.0986) < 1e-3 + + +def test_cross_entropy_ignore_index(): + """H-53: ignore_index excludes the ignored sample from the loss.""" + bm.random.seed(3) + logits = bm.random.randn(4, 3) + tgt_all = jnp.array([0, 1, 2, 2]) + tgt_ign = jnp.array([0, 1, -100, 2]) + loss_all = float(C.cross_entropy_loss(logits, tgt_all)) + loss_ign = float(C.cross_entropy_loss(logits, tgt_ign, ignore_index=-100)) + # Excluding a sample generally changes the mean. + assert not np.isclose(loss_all, loss_ign) + # Equivalent to averaging only the 3 kept samples. + per = _np(C.cross_entropy_loss(logits, tgt_all, reduction='none')) + manual = (per[0] + per[1] + per[3]) / 3.0 + assert abs(loss_ign - manual) < 1e-4 + + +def test_cross_entropy_label_smoothing_changes_loss(): + """H-53: label_smoothing > 0 changes the loss value.""" + bm.random.seed(4) + logits = bm.random.randn(4, 3) + tgt = jnp.array([0, 1, 2, 0]) + base = float(C.cross_entropy_loss(logits, tgt)) + smoothed = float(C.cross_entropy_loss(logits, tgt, label_smoothing=0.1)) + assert not np.isclose(base, smoothed) + + +def test_cross_entropy_class_and_loss_wrappers(): + """CrossEntropyLoss class wraps cross_entropy_loss; reductions + soft targets.""" + bm.random.seed(5) + logits = bm.random.randn(4, 3) + tgt = jnp.array([0, 1, 2, 0]) + assert float(C.CrossEntropyLoss()(logits, tgt)) > 0 + assert float(C.CrossEntropyLoss(ignore_index=-100)(logits, jnp.array([0, 1, -100, 2]))) > 0 + assert float(C.CrossEntropyLoss(label_smoothing=0.1)(logits, tgt)) > 0 + assert float(C.cross_entropy_loss(logits, tgt, reduction='sum')) > 0 + assert _np(C.cross_entropy_loss(logits, tgt, reduction='none')).shape == (4,) + # soft (probability) targets + class weight path. + soft = bm.one_hot(tgt, 3) + assert float(C.cross_entropy_loss(logits, soft)) > 0 + assert float(C.cross_entropy_loss(logits, soft, weight=[1., 2., 3.])) > 0 + assert float(C.cross_entropy_loss(logits, soft, label_smoothing=0.1)) > 0 + + +def test_cross_entropy_sparse_and_sigmoid(): + bm.random.seed(6) + logits = bm.random.randn(4, 3) + assert _np(C.cross_entropy_sparse(logits, jnp.array([[0], [1], [2], [0]]))).shape == (4,) + assert float(jnp.sum(C.cross_entropy_sparse(logits, 1))) != 0.0 # single int target + assert _np(C.cross_entropy_sigmoid(logits, bm.random.rand(4, 3))).shape == (4, 3) + + +def test_regression_losses(): + bm.random.seed(7) + pred = bm.random.randn(4, 3) + tar = bm.random.randn(4, 3) + assert float(C.mean_squared_error(pred, tar)) >= 0 + assert float(C.mean_squared_error(pred, tar, reduction='sum')) >= 0 + assert _np(C.mean_squared_error(pred, tar, axis=1, reduction='none')).shape[0] == 4 + assert float(C.mean_absolute_error(pred, tar)) >= 0 + assert float(C.mean_squared_log_error(bm.abs(pred), bm.abs(tar))) >= 0 + assert float(C.l1_loss(pred, tar)) >= 0 + assert _np(C.l1_loss(pred, tar, reduction='none')).shape == (4,) + assert float(jnp.sum(C.l2_loss(pred, tar))) >= 0 + assert float(jnp.sum(C.huber_loss(pred, tar, delta=0.5))) >= 0 + assert float(jnp.sum(C.log_cosh_loss(pred, tar))) >= 0 + # Loss classes + assert float(C.MSELoss()(pred, tar)) >= 0 + assert float(C.MSELoss(reduction='sum')(pred, tar)) >= 0 + assert float(C.L1Loss()(pred, tar)) >= 0 + assert float(C.MAELoss(axis=None)(pred, tar)) >= 0 + + +def test_classification_helper_losses(): + bm.random.seed(8) + assert float(jnp.sum(C.sigmoid_binary_cross_entropy(bm.random.randn(4, 3), bm.random.rand(4, 3)))) != 0 + labels = bm.one_hot(jnp.array([0, 1, 2, 0]), 3) + assert float(jnp.sum(C.softmax_cross_entropy(bm.random.randn(4, 3), labels))) >= 0 + assert float(jnp.sum(C.binary_logistic_loss(jnp.array([0.5, 1.2]), jnp.array([0, 1])))) != 0 + # multiclass_logistic_loss: single int label + (n_classes,) logits. + assert float(C.multiclass_logistic_loss(1, bm.random.randn(3))) >= 0 + + +def test_multi_margin_loss(): + bm.random.seed(9) + logits = bm.random.randn(4, 3) + tgt = jnp.array([0, 1, 2, 0]) + assert float(C.multi_margin_loss(logits, tgt)) >= 0 + assert float(C.multi_margin_loss(logits, tgt, p=2)) >= 0 + assert float(C.multi_margin_loss(logits, tgt, reduction='sum')) >= 0 + assert _np(C.multi_margin_loss(logits, tgt, reduction='none')).shape == (4, 3) + with pytest.raises(AssertionError): + C.multi_margin_loss(logits, tgt, p=3) + + +def test_ctc_loss_known_value_bug(): + """REMAINING BUG: ctc_loss / ctc_loss_with_forward_probs call ``.value`` on a + plain JAX array returned by ``bm.log_softmax`` (comparison.py:1006), which + raises AttributeError. Pinned here so the regression is documented + covered.""" + B, T, K, N = 2, 5, 4, 3 + logits = bm.random.randn(B, T, K) + lp = bm.zeros((B, T)) + labels = jnp.array([[1, 2, 1], [2, 1, 2]]) + lbp = bm.zeros((B, N)) + with pytest.raises(AttributeError): + C.ctc_loss(logits, lp, labels, lbp) + with pytest.raises(AttributeError): + C.ctc_loss_with_forward_probs(logits, lp, labels, lbp) + + +# =========================================================================== +# 3. optimizer.py +# =========================================================================== + +def test_adam_constant_step_under_unit_gradient(): + """C-01 / H-52: under a constant unit gradient, Adam's per-step |dw| ~= lr and + does NOT grow over time (bias correction uses a real step counter).""" + lr = 0.01 + w = bm.Variable(bm.zeros((3,))) + opt = bp.optim.Adam(lr=lr, train_vars={'w': w}) + prev = _np(w.value).copy() + deltas = [] + for _ in range(5): + opt.update({'w': bm.ones((3,))}) + cur = _np(w.value) + deltas.append(abs(float(cur[0] - prev[0]))) + prev = cur.copy() + assert np.allclose(deltas, lr, atol=1e-4), f"Adam steps drifted: {deltas}" + + +def test_adamw_constant_step_under_unit_gradient(): + """C-01 / H-52: same constant-step property for AdamW (weight_decay=0).""" + lr = 0.01 + w = bm.Variable(bm.zeros((3,))) + opt = bp.optim.AdamW(lr=lr, train_vars={'w': w}, weight_decay=0.0) + prev = _np(w.value).copy() + deltas = [] + for _ in range(5): + opt.update({'w': bm.ones((3,))}) + cur = _np(w.value) + deltas.append(abs(float(cur[0] - prev[0]))) + prev = cur.copy() + assert np.allclose(deltas, lr, atol=1e-4), f"AdamW steps drifted: {deltas}" + + +def test_sm3_instantiates_and_updates(): + """M-29: SM3 instantiates with train_vars and a forward update changes w.""" + w = bm.Variable(bm.ones((3, 4))) + opt = SM3(lr=0.01, train_vars={'w': w}) + before = _np(w.value).copy() + opt.update({'w': bm.ones((3, 4))}) + assert not np.allclose(before, _np(w.value)) + + +def test_sm3_with_momentum_and_weight_decay(): + """SM3 with momentum>0 (allocates a buffer) and weight_decay both run + change w.""" + w = bm.Variable(bm.ones((3, 4))) + opt = SM3(lr=0.01, train_vars={'w': w}, momentum=0.5, beta=0.5, weight_decay=0.01) + before = _np(w.value).copy() + opt.update({'w': bm.ones((3, 4))}) + assert not np.allclose(before, _np(w.value)) + repr(opt) + + +def test_sm3_invalid_hyperparams(): + with pytest.raises(ValueError): + SM3(lr=0.01, momentum=1.5) + with pytest.raises(ValueError): + SM3(lr=0.01, beta=1.5) + with pytest.raises(ValueError): + SM3(lr=0.01, eps=-1.0) + + +@pytest.mark.parametrize("name", ['SGD', 'Momentum', 'MomentumNesterov', 'Adagrad', + 'Adadelta', 'RMSProp', 'Adam', 'AdamW', 'LARS']) +def test_optimizer_construct_and_update(name): + """Coverage: every (working) optimizer constructs and runs an update.""" + cls = getattr(O, name) + w = _train_var() + opt = cls(lr=0.01, train_vars={'w': w}) + before = _np(w.value).copy() + opt.update({'w': bm.ones((2, 3)) * 0.1}) + assert _np(w.value).shape == (2, 3) + repr(opt) + # also run a second update step (exercises EMA / cache update paths) + opt.update({'w': bm.ones((2, 3)) * 0.2}) + assert not np.allclose(before, _np(w.value)) + + +@pytest.mark.parametrize("name", ['SGD', 'Momentum', 'MomentumNesterov', 'Adagrad', + 'Adadelta', 'RMSProp', 'Adam', 'LARS']) +def test_optimizer_weight_decay_path(name): + """Coverage: the weight_decay branch of each optimizer's update.""" + cls = getattr(O, name) + w = _train_var() + opt = cls(lr=0.01, train_vars={'w': w}, weight_decay=0.01) + opt.update({'w': bm.ones((2, 3)) * 0.1}) + assert _np(w.value).shape == (2, 3) + + +def test_adamw_amsgrad_variant(): + """AdamW amsgrad branch constructs and updates.""" + w = _train_var() + opt = bp.optim.AdamW(lr=0.01, train_vars={'w': w}, amsgrad=True, weight_decay=0.01) + opt.update({'w': bm.ones((2, 3)) * 0.1}) + assert _np(w.value).shape == (2, 3) + + +def test_adamw_invalid_hyperparams(): + with pytest.raises(ValueError): + bp.optim.AdamW(lr=0.01, eps=-1.0) + with pytest.raises(ValueError): + bp.optim.AdamW(lr=0.01, beta1=1.5) + with pytest.raises(ValueError): + bp.optim.AdamW(lr=0.01, beta2=1.5) + with pytest.raises(ValueError): + bp.optim.AdamW(lr=0.01, weight_decay=-1.0) + + +def test_adan_constructs_update_is_known_bug(): + """Adan constructs, but REMAINING BUG: ``Adan.update`` passes a tuple operand to + ``lax.cond`` while the branch lambdas expect two args (optimizer.py:818), + raising TypeError. Pinned so the construct path is covered + the bug documented.""" + w = _train_var() + adan = bp.optim.Adan(lr=0.01, train_vars={'w': w}) + repr(adan) + with pytest.raises(TypeError): + adan.update({'w': bm.ones((2, 3)) * 0.1}) + + +def test_adan_invalid_betas(): + with pytest.raises(AssertionError): + bp.optim.Adan(lr=0.01, betas=(0.1, 0.2)) # len != 3 + with pytest.raises(ValueError): + bp.optim.Adan(lr=0.01, eps=-1.0) + with pytest.raises(ValueError): + bp.optim.Adan(lr=0.01, betas=(1.5, 0.1, 0.1)) + + +def test_optimizer_check_grads_mismatch(): + """Optimizer.check_grads raises on a length mismatch.""" + w = _train_var() + opt = bp.optim.SGD(lr=0.01, train_vars={'w': w}) + with pytest.raises(Exception): + opt.update({'w': bm.ones((2, 3)), 'extra': bm.ones((2, 3))}) + + +def test_optimizer_register_train_vars_type_error(): + with pytest.raises(Exception): + bp.optim.SGD(lr=0.01, train_vars=[bm.Variable(bm.ones((2, 3)))]) # must be dict + + +# =========================================================================== +# 4. scheduler.py +# =========================================================================== + +def test_multisteplr_decays(): + """C-04: MultiStepLR actually decays at the milestones.""" + sch = bp.optim.MultiStepLR(0.1, [10, 20], gamma=0.1) + vals = [round(float(sch(i)), 6) for i in [0, 10, 20, 25]] + assert vals == [0.1, 0.1, 0.01, 0.001], vals + repr(sch) + + +def test_constant_scheduler(): + assert float(S.Constant(0.1)()) == pytest.approx(0.1) + assert float(bp.optim.make_schedule(0.05)()) == pytest.approx(0.05) + assert isinstance(bp.optim.make_schedule(S.Constant(0.1)), S.Constant) + with pytest.raises(TypeError): + bp.optim.make_schedule("not-a-schedule") + + +def test_steplr(): + sch = S.StepLR(0.1, step_size=10, gamma=0.1) + vals = [round(float(sch(i)), 6) for i in [0, 5, 10, 20]] + assert vals == [0.1, 0.1, 0.01, 0.001], vals + repr(sch) + + +def test_exponential_lr(): + sch = S.ExponentialLR(0.1, gamma=0.9) + assert float(sch(2)) == pytest.approx(0.1 * 0.9 ** 2, abs=1e-6) + repr(sch) + + +def test_cosine_annealing_lr(): + sch = S.CosineAnnealingLR(0.1, T_max=10, eta_min=0.0) + assert float(sch(0)) == pytest.approx(0.1, abs=1e-5) + assert float(sch(10)) == pytest.approx(0.0, abs=1e-5) + assert 0.0 <= float(sch(5)) <= 0.1 + + +def test_cosine_warm_restarts(): + sch = S.CosineAnnealingWarmRestarts(0.1, num_call_per_epoch=2, T_0=4) + assert 0.0 <= float(sch(3)) <= 0.1 + assert 0.0 <= float(sch(10)) <= 0.1 # epoch >= T_0 -> _cond1 (T_mult==1 path) + assert float(sch.current_epoch(4)) >= 0 + with pytest.raises(ValueError): + S.CosineAnnealingWarmRestarts(0.1, num_call_per_epoch=2, T_0=0) + with pytest.raises(ValueError): + S.CosineAnnealingWarmRestarts(0.1, num_call_per_epoch=2, T_0=4, T_mult=0) + + +def test_cosine_warm_restarts_tmult_gt1_known_bug(): + """REMAINING BUG: with ``T_mult > 1``, ``CosineAnnealingWarmRestarts`` calls + ``lax.cond`` whose two branches (``_cond1`` vs ``_cond2``) return mismatched + output types (float tuple vs int ``T_0``), raising TypeError under jit + (scheduler.py:292-313). Pinned so the construction path stays covered.""" + sch = S.CosineAnnealingWarmRestarts(0.1, num_call_per_epoch=2, T_0=2, T_mult=2) + with pytest.raises(Exception): + float(sch(10)) + + +def test_exponential_decay_lr(): + sch = S.ExponentialDecayLR(0.1, decay_steps=10, decay_rate=0.9) + assert float(sch(5)) == pytest.approx(0.1 * 0.9 ** 0.5, abs=1e-5) + # call-based step advances last_call + v0 = float(sch()) + sch.step_call() + v1 = float(sch()) + assert v1 < v0 + repr(sch) + with pytest.warns(Warning): + S.ExponentialDecay(0.1, 10, 0.9) + + +def test_inverse_time_decay_lr(): + sch = S.InverseTimeDecayLR(0.1, decay_steps=10, decay_rate=0.9) + assert float(sch(5)) == pytest.approx(0.1 / (1 + 0.9 * 5 / 10), abs=1e-5) + stair = S.InverseTimeDecayLR(0.1, decay_steps=10, decay_rate=0.9, staircase=True) + assert float(stair(5)) == pytest.approx(0.1, abs=1e-5) + repr(sch) + with pytest.warns(Warning): + S.InverseTimeDecay(0.1, 10, 0.9) + + +def test_polynomial_decay_lr(): + sch = S.PolynomialDecayLR(0.1, decay_steps=10, final_lr=0.01) + assert float(sch(0)) == pytest.approx(0.1, abs=1e-5) + assert float(sch(10)) == pytest.approx(0.01, abs=1e-5) + assert float(sch(100)) == pytest.approx(0.01, abs=1e-5) # clamped to decay_steps + repr(sch) + with pytest.warns(Warning): + S.PolynomialDecay(0.1, 10, 0.01) + + +def test_piecewise_constant_lr(): + sch = S.PiecewiseConstantLR([10, 20], [0.1, 0.01, 0.001]) + assert float(sch(5)) == pytest.approx(0.1, abs=1e-6) + assert float(sch(15)) == pytest.approx(0.01, abs=1e-6) + assert float(sch(25)) == pytest.approx(0.001, abs=1e-6) + with pytest.warns(Warning): + S.PiecewiseConstant([10, 20], [0.1, 0.01, 0.001]) + from brainpy._errors import MathError + with pytest.raises(MathError): + S.PiecewiseConstantLR([10, 20], [0.1, 0.01]) # bad lengths + + +def test_scheduler_step_epoch_and_set_value(): + sch = S.StepLR(0.1, step_size=5) + sch.step_epoch() + assert int(sch.last_epoch.value) == 0 + sch.set_value(0.5) + assert float(sch.lr) == pytest.approx(0.5) + + +# =========================================================================== +# 5. stateless_encoding.py + stateful encoders +# =========================================================================== + +def test_poisson_single_step_returns_spikes(): + """single_step must return a spike array of the same shape (no TypeError).""" + out = bp.encoding.PoissonEncoder().single_step(bm.random.rand(4)) + arr = _np(out) + assert arr.shape == (4,) + assert set(np.unique(arr)).issubset({0.0, 1.0}) + + +def test_poisson_single_step_with_first_spike_time(): + enc = bp.encoding.PoissonEncoder(first_spk_time=2.0) + before = _np(enc.single_step(bm.ones(4), i_step=0)) + assert np.allclose(before, 0.0) # no spikes before first-spike step + after = _np(enc.single_step(bm.ones(4), i_step=100)) + assert np.allclose(after, 1.0) # prob==1 -> always fires after + + +def test_poisson_normalize_and_multi_steps(): + enc = bp.encoding.PoissonEncoder(min_val=0.0, max_val=2.0, gain=1.0, offset=0.0) + assert _np(enc.single_step(bm.ones(4))).shape == (4,) + spikes = enc.multi_steps(bm.random.rand(3), n_time=5.0) + assert _np(spikes).shape[1:] == (3,) + # n_time=None -> single current step + assert _np(enc.multi_steps(bm.random.rand(3), n_time=None)).shape == (3,) + # first_spk_step > 0 multi-step branch + enc2 = bp.encoding.PoissonEncoder(first_spk_time=1.0) + assert _np(enc2.multi_steps(bm.random.rand(3), n_time=5.0)).shape[1:] == (3,) + + +def test_diff_encoder(): + enc = bp.encoding.DiffEncoder(threshold=1.0) + assert _np(enc.multi_steps(bm.array([1., 2., 2.9, 3., 3.9]))).shape == (5,) + enc2 = bp.encoding.DiffEncoder(threshold=1.0, padding=True, off_spike=True) + assert _np(enc2.multi_steps(bm.array([1., 2., 0., 2., 2.9]))).shape == (5,) + with pytest.raises(NotImplementedError): + enc.single_step(bm.array([1.0])) + + +def test_latency_encoder(): + enc = bp.encoding.LatencyEncoder(method='linear', normalize=True) + out = enc.multi_steps(bm.array([0.02, 0.5, 1.0]), n_time=5.0) + assert _np(out).shape == (50, 3) + enc_log = bp.encoding.LatencyEncoder(method='log', clip=True, normalize=True, + min_val=0.0, max_val=1.0) + assert _np(enc_log.multi_steps(bm.array([0.02, 0.5, 1.0]), n_time=5.0)).shape == (50, 3) + with pytest.raises(NotImplementedError): + enc.single_step(bm.array([0.5])) + with pytest.raises(ValueError): + bp.encoding.LatencyEncoder(method='bogus') + + +def test_weighted_phase_encoder(): + enc = bp.encoding.WeightedPhaseEncoder(min_val=0.0, max_val=1.0, num_phase=4) + out = enc(bm.array([0.3, 0.7]), num_step=4) + assert _np(out).shape == (4, 2) + + +# =========================================================================== +# 6. random_conn.py -- FixedProb / FixedPreNum / FixedPostNum / FixedTotalNum +# =========================================================================== + +def test_fixedprob_nonzero_nnz(): + """M-30: small post population must not silently produce 0 connections.""" + pre, post = bp.connect.FixedProb(prob=0.3, allow_multi_conn=True)( + pre_size=100, post_size=3).build_coo() + assert len(_np(pre)) > 0 + + +def test_fixedprob_rectangular_include_self_false_no_raise(): + """M-30: rectangular shape with include_self=False no longer raises a + (contradictory) ConnectorError.""" + conn = bp.connect.FixedProb(prob=0.3, allow_multi_conn=True, include_self=False)( + pre_size=100, post_size=3) + pre, post = conn.build_coo() + assert len(_np(pre)) > 0 + # build_csr / build_mat also work for the rectangular include_self=False case. + conn.build_csr() + assert conn.build_mat().shape == (100, 3) + + +def test_fixedprob_all_build_methods(): + conn = bp.connect.FixedProb(prob=0.4, allow_multi_conn=True)(pre_size=20, post_size=10) + pre, post = conn.build_coo() + assert len(_np(pre)) > 0 + idx, indptr = conn.build_csr() + assert _np(indptr).shape[0] == 21 # pre_num + 1 + mat = conn.build_mat() + assert mat.shape == (20, 10) + repr(conn) + + +def test_fixedprob_pre_ratio_and_include_self(): + conn = bp.connect.FixedProb(prob=0.5, pre_ratio=0.5, allow_multi_conn=True, + include_self=False)(pre_size=20, post_size=20) + pre, post = conn.build_coo() + assert len(_np(pre)) >= 0 + conn.build_csr() + assert conn.build_mat().shape == (20, 20) + + +def test_fixedprob_invalid_args(): + with pytest.raises(AssertionError): + bp.connect.FixedProb(prob=1.5) + with pytest.raises(AssertionError): + bp.connect.FixedProb(prob=0.3, pre_ratio=2.0) + + +def test_fixed_pre_num(): + conn = bp.connect.FixedPreNum(num=3, allow_multi_conn=True)(pre_size=10, post_size=8) + pre, post = conn.build_coo() + assert len(_np(pre)) > 0 + conn2 = bp.connect.FixedPreNum(num=0.5, allow_multi_conn=True, include_self=False)( + pre_size=10, post_size=10) + assert len(_np(conn2.build_coo()[0])) >= 0 + with pytest.raises(Exception): + bp.connect.FixedPreNum(num=100, allow_multi_conn=True)( + pre_size=10, post_size=8).build_coo() # num > pre_num + + +def test_fixed_post_num(): + conn = bp.connect.FixedPostNum(num=3, allow_multi_conn=True)(pre_size=10, post_size=8) + pre, post = conn.build_coo() + assert len(_np(pre)) > 0 + idx, indptr = conn.build_csr() + assert _np(indptr).shape[0] == 11 # pre_num + 1 + conn2 = bp.connect.FixedPostNum(num=0.5, allow_multi_conn=True, include_self=False)( + pre_size=10, post_size=10) + conn2.build_coo() + conn2.build_csr() + + +def test_fixed_total_num(): + conn = bp.connect.FixedTotalNum(num=12, allow_multi_conn=True)(pre_size=10, post_size=8) + pre, post = conn.build_coo() + assert len(_np(pre)) == 12 + # no-multi-conn (choice without replacement) path + conn2 = bp.connect.FixedTotalNum(num=12, allow_multi_conn=False)(pre_size=10, post_size=8) + assert len(_np(conn2.build_coo()[0])) == 12 + repr(conn) + with pytest.raises(Exception): + bp.connect.FixedTotalNum(num=1000, allow_multi_conn=True)( + pre_size=10, post_size=8).build_coo() # num > all2all + + +def test_connectors_no_multi_conn_paths(): + """Coverage: the ``allow_multi_conn=False`` (numba choice-without-replacement) + build paths of FixedProb / FixedPreNum / FixedPostNum.""" + fp = bp.connect.FixedProb(prob=0.4, allow_multi_conn=False)(pre_size=20, post_size=10) + assert len(_np(fp.build_coo()[0])) > 0 + fp.build_csr() + assert fp.build_mat().shape == (20, 10) + # include_self=False square case + fp2 = bp.connect.FixedProb(prob=0.5, allow_multi_conn=False, include_self=False)( + pre_size=15, post_size=15) + fp2.build_coo() + fp2.build_csr() + assert fp2.build_mat().shape == (15, 15) + + pre_conn = bp.connect.FixedPreNum(num=3, allow_multi_conn=False)(pre_size=12, post_size=8) + assert len(_np(pre_conn.build_coo()[0])) > 0 + pre_conn2 = bp.connect.FixedPreNum(num=3, allow_multi_conn=False, include_self=False)( + pre_size=12, post_size=12) + pre_conn2.build_coo() + + post_conn = bp.connect.FixedPostNum(num=3, allow_multi_conn=False)(pre_size=12, post_size=8) + assert len(_np(post_conn.build_coo()[0])) > 0 + post_conn.build_csr() + post_conn2 = bp.connect.FixedPostNum(num=3, allow_multi_conn=False, include_self=False)( + pre_size=12, post_size=12) + post_conn2.build_coo() + post_conn2.build_csr() + + +def test_connectors_validation_branches(): + """Coverage: float / bad-type ``num`` validation branches of the connectors.""" + # float num accepted at construction + assert bp.connect.FixedTotalNum(num=0.5).num == 0.5 + assert bp.connect.FixedPreNum(num=0.3).num == 0.3 + assert bp.connect.FixedPostNum(num=0.3).num == 0.3 + # bad type rejected + from brainpy._errors import ConnectorError + with pytest.raises(ConnectorError): + bp.connect.FixedTotalNum(num='x') + with pytest.raises(ConnectorError): + bp.connect.FixedPreNum(num='x') + # FixedPreNum with float num builds (probability interpretation) + fp = bp.connect.FixedPreNum(num=0.3, allow_multi_conn=True)(pre_size=10, post_size=8) + assert len(_np(fp.build_coo()[0])) > 0 + # negative integer num rejected + with pytest.raises(AssertionError): + bp.connect.FixedTotalNum(num=-1) + # FixedPostNum num > post_num + with pytest.raises(ConnectorError): + bp.connect.FixedPostNum(num=100, allow_multi_conn=True)( + pre_size=10, post_size=8).build_coo() + + +def test_fixed_pre_post_num_include_self_rectangular_raises(): + """FixedPreNum / FixedPostNum still reject include_self=False for rectangular + (pre_num != post_num) shapes (this guard is intentional for these connectors).""" + with pytest.raises(Exception): + bp.connect.FixedPreNum(num=3, allow_multi_conn=True, include_self=False)( + pre_size=10, post_size=8).build_coo() + with pytest.raises(Exception): + bp.connect.FixedPostNum(num=3, allow_multi_conn=True, include_self=False)( + pre_size=10, post_size=8).build_coo() + + +if __name__ == '__main__': + import sys + sys.exit(pytest.main([__file__, '-q'])) diff --git a/tests/audit/test_dyn_channels_fixes.py b/tests/audit/test_dyn_channels_fixes.py new file mode 100644 index 000000000..69c263bd5 --- /dev/null +++ b/tests/audit/test_dyn_channels_fixes.py @@ -0,0 +1,363 @@ +"""Regression + coverage tests for the BrainPy v2.7.8 audit (2026-06-18). + +This module pins the channel/ion fixes documented in ``docs/issues-found-20260618.md``: + +* **C-14** -- Standalone HH/Markov channel gating produced NaN at the voltage + singularities (``channels/sodium.py``, ``potassium.py``, ``calcium.py`` and the + ``*_compatible`` legacy duplicates). The rate functions of the form + ``k * temp / (1 - exp(-temp / d))`` are ``0/0`` exactly at the removable + singularity. After the fix they are rewritten with a branch-safe ``exprel`` + helper so that the value *and* its gradient are finite there. + Audit repro: ``IK_HH1952v2(1).f_p_alpha([-55.0])`` returns ~0.1, not nan. +* **M-17** -- ``PotassiumFixed`` default ``E`` was ``-950`` mV (typo); fixed to + ``-95`` mV. +* **H-33** -- ``dyn/ions/base.py`` registered every channel under the literal name + ``"k"`` (``self.add_elem(k=v)``), so channels overwrote each other. The fix + (``self.add_elem(**{k: v})``) registers each channel under its real name. + +The remaining tests build every public channel class against a Hodgkin-Huxley +style host and exercise ``reset_state`` / ``update`` for coverage. +""" + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import brainpy as bp +import brainpy.math as bm + + +# Voltage sweep that intentionally includes the removable-singularity voltages. +_V_SWEEP = jnp.asarray(np.linspace(-120.0, 60.0, 91), dtype=bm.float_) + + +def _assert_finite_with_grad(make_channel, fname, singular_v): + """A rate function and its gradient must be finite across a sweep incl. the singular V.""" + chan = make_channel() + f = getattr(chan, fname) + + # Value finite across the whole sweep (which contains the singular voltage). + vals = np.asarray(f(_V_SWEEP)) + assert np.all(np.isfinite(vals)), f"{fname}: non-finite value over sweep: {vals}" + + # Value finite exactly at the singular voltage. + sval = np.asarray(f(jnp.asarray([singular_v], dtype=bm.float_))) + assert np.all(np.isfinite(sval)), f"{fname}: non-finite value at singular V={singular_v}: {sval}" + + # Gradient finite at the singular voltage (the part bm.where clamping cannot fix). + def scalar(v): + return getattr(make_channel(), fname)(jnp.asarray([v], dtype=bm.float_))[0] + + g = float(jax.grad(scalar)(float(singular_v))) + assert np.isfinite(g), f"{fname}: non-finite gradient at singular V={singular_v}: {g}" + + +# --------------------------------------------------------------------------- +# C-14 regression: finite value + finite gradient at the singular voltages. +# Each tuple is (class, rate-function name, singular voltage). +# The singular voltage is where the argument of the exprel helper is zero. +# --------------------------------------------------------------------------- + +# Legacy/compatible exported classes (master_type = HHTypedNeuron). +_LEGACY_SINGULAR_CASES = [ + # IKDR_Ba2002 V_sh=-50: alpha singular at V - V_sh - 15 = 0 -> V = -35 + ("IKDR_Ba2002", "f_p_alpha", -35.0), + # IK_TM1991 V_sh=-60: alpha singular at 15 - V + V_sh = 0 -> V = -45 + ("IK_TM1991", "f_p_alpha", -45.0), + # IK_HH1952 V_sh=-45: alpha singular at V - V_sh + 10 = 0 -> V = -55 + ("IK_HH1952", "f_p_alpha", -55.0), + # INa_Ba2002 V_sh=-50: p_alpha V-V_sh-13=0 -> -37 ; p_beta V-V_sh-40=0 -> -10 + ("INa_Ba2002", "f_p_alpha", -37.0), + ("INa_Ba2002", "f_p_beta", -10.0), + # INa_TM1991 V_sh=-63: p_alpha 13-V+V_sh=0 -> -50 ; p_beta V-V_sh-40=0 -> -23 + ("INa_TM1991", "f_p_alpha", -50.0), + ("INa_TM1991", "f_p_beta", -23.0), + # INa_HH1952 V_sh=-45: p_alpha V-V_sh-5=0 -> -40 + ("INa_HH1952", "f_p_alpha", -40.0), +] + +# v2 classes (master_type = an Ion subtype). +_V2_SINGULAR_CASES = [ + ("IKDR_Ba2002v2", "f_p_alpha", -35.0), + ("IK_TM1991v2", "f_p_alpha", -45.0), + ("IK_HH1952v2", "f_p_alpha", -55.0), + ("INa_Ba2002v2", "f_p_alpha", -37.0), + ("INa_Ba2002v2", "f_p_beta", -10.0), + ("INa_TM1991v2", "f_p_alpha", -50.0), + ("INa_TM1991v2", "f_p_beta", -23.0), + ("INa_HH1952v2", "f_p_alpha", -40.0), + # ICaHT_Re1993 (markov v2, calcium) V_sh=0: p_alpha -27-V+V_sh=0 -> V=-27 + ("ICaHT_Re1993", "f_p_alpha", -27.0), +] + + +@pytest.mark.parametrize("clsname,fname,singular_v", _LEGACY_SINGULAR_CASES) +def test_c14_legacy_rate_finite_and_grad(clsname, fname, singular_v): + cls = getattr(bp.dyn, clsname) + _assert_finite_with_grad(lambda: cls(1), fname, singular_v) + + +@pytest.mark.parametrize("clsname,fname,singular_v", _V2_SINGULAR_CASES) +def test_c14_v2_rate_finite_and_grad(clsname, fname, singular_v): + cls = getattr(bp.dyn, clsname) + _assert_finite_with_grad(lambda: cls(1), fname, singular_v) + + +def test_c14_ik_hh1952v2_audit_repro(): + """Exact audit repro: was [nan], must now be ~0.1.""" + val = np.asarray(bp.dyn.IK_HH1952v2(1).f_p_alpha(jnp.asarray([-55.0], dtype=bm.float_))) + assert np.all(np.isfinite(val)) + assert np.allclose(val, 0.1, atol=1e-4) + + +def test_c14_all_defined_rate_functions_finite_over_sweep(): + """Sweep every rate function each markov/ss channel defines; none may produce NaN.""" + classes = [ + bp.dyn.IKDR_Ba2002, bp.dyn.IK_TM1991, bp.dyn.IK_HH1952, + bp.dyn.INa_Ba2002, bp.dyn.INa_TM1991, bp.dyn.INa_HH1952, + bp.dyn.IKDR_Ba2002v2, bp.dyn.IK_TM1991v2, bp.dyn.IK_HH1952v2, + bp.dyn.INa_Ba2002v2, bp.dyn.INa_TM1991v2, bp.dyn.INa_HH1952v2, + bp.dyn.ICaHT_Re1993, + ] + for cls in classes: + chan = cls(1) + for fname in ("f_p_alpha", "f_p_beta", "f_q_alpha", "f_q_beta"): + f = getattr(chan, fname, None) + if f is None: + continue + try: + vals = np.asarray(f(_V_SWEEP)) + except NotImplementedError: + continue + assert np.all(np.isfinite(vals)), f"{cls.__name__}.{fname} produced NaN/inf" + + +# --------------------------------------------------------------------------- +# M-17 regression: PotassiumFixed default reversal potential. +# --------------------------------------------------------------------------- + +def test_m17_potassium_fixed_default_E(): + ion = bp.dyn.PotassiumFixed(1) + assert np.allclose(np.asarray(ion.E), -95.0), f"expected -95 mV, got {ion.E}" + # And not the historical buggy value. + assert not np.allclose(np.asarray(ion.E), -950.0) + + +# --------------------------------------------------------------------------- +# H-33 regression: MixIons / Ion.add_elem must register channels under their +# real names instead of collapsing them all under the literal key "k". +# --------------------------------------------------------------------------- + +def test_h33_ion_add_elem_keeps_distinct_names(): + na = bp.dyn.SodiumFixed(2, E=50.) + ina_hh = bp.dyn.INa_HH1952v2(2) + ina_ba = bp.dyn.INa_Ba2002v2(2) + na.add_elem(ina_hh=ina_hh, ina_ba=ina_ba) + + # Both channels stored under their real names, none lost to a literal "k". + assert set(na.children.keys()) == {"ina_hh", "ina_ba"} + assert "k" not in na.children + assert na.children["ina_hh"] is ina_hh + assert na.children["ina_ba"] is ina_ba + assert ina_hh.name != ina_ba.name + + +def test_h33_mixions_registers_distinct_multiion_channels(): + """Build a MixIons over a Calcium+Potassium JointType channel (IAHP_De1994v2). + + Two distinct channels must coexist under their real names; before the fix + the second would overwrite the first under the key "k". + """ + bm.set_dt(0.1) + + class Net(bp.dyn.CondNeuGroupLTC): + def __init__(self, size): + super().__init__(size) + self.Ca = bp.dyn.CalciumDetailed(size) + self.K = bp.dyn.PotassiumFixed(size, E=-95.) + self.KCa = bp.dyn.MixIons(self.Ca, self.K) + self.KCa.add_elem(iahp1=bp.dyn.IAHP_De1994v2(size), + iahp2=bp.dyn.IAHP_De1994v2(size)) + + net = Net(2) + net.reset_state() + assert set(net.KCa.children.keys()) == {"iahp1", "iahp2"} + assert "k" not in net.KCa.children + + bp.share.save(t=0., dt=0.1, i=0) + net.update(bm.ones(2) * -50.) + cur = np.asarray(net.KCa.current(net.V.value)) + assert np.all(np.isfinite(cur)) + + +def test_h33_mix_ions_helper_and_ion_current_paths(): + bm.set_dt(0.1) + na = bp.dyn.SodiumFixed(2, E=50.) + na.add_elem(ina=bp.dyn.INa_HH1952v2(2)) + k = bp.dyn.PotassiumFixed(2, E=-95.) + k.add_elem(ik=bp.dyn.IK_HH1952v2(2)) + + mix = bp.dyn.mix_ions(na, k) + assert isinstance(mix, bp.dyn.MixIons) + + V = bm.ones(2) * -65. + na.reset_state(V) + k.reset_state(V) + # Exercise Ion.update / Ion.current / pack_info on ions/base.py + ions/potassium.py. + bp.share.save(t=0., dt=0.1, i=0) + na.update(V) + k.update(V) + assert np.all(np.isfinite(np.asarray(k.current(V)))) + assert set(k.pack_info().keys()) == {"C", "E"} + + +# --------------------------------------------------------------------------- +# Coverage: construct & exercise every public channel against an HH-style host. +# --------------------------------------------------------------------------- + +def _run_cond_neu_group(net, n_steps=2): + net.reset_state() + bp.share.save(t=0., dt=bm.get_dt(), i=0) + for i in range(n_steps): + bp.share.save(t=i * bm.get_dt(), i=i) + net.update(bm.ones(net.num) * 1.0) + + +def test_coverage_compatible_potassium_and_sodium_channels(): + """Legacy/compatible channels are hosted directly by a CondNeuGroup (HH neuron).""" + bm.set_dt(0.1) + + class Net(bp.dyn.CondNeuGroup): + def __init__(self, size): + super().__init__(size) + # sodium_compatible.py + self.INa_Ba = bp.dyn.INa_Ba2002(size) + self.INa_TM = bp.dyn.INa_TM1991(size) + self.INa_HH = bp.dyn.INa_HH1952(size) + # potassium_compatible.py + self.IKDR = bp.dyn.IKDR_Ba2002(size) + self.IK_TM = bp.dyn.IK_TM1991(size) + self.IK_HH = bp.dyn.IK_HH1952(size) + self.IKA1 = bp.dyn.IKA1_HM1992(size) + self.IKA2 = bp.dyn.IKA2_HM1992(size) + self.IKK2A = bp.dyn.IKK2A_HM1992(size) + self.IKK2B = bp.dyn.IKK2B_HM1992(size) + self.IKNI = bp.dyn.IKNI_Ya1989(size) + self.IKL = bp.dyn.IKL(size) + + _run_cond_neu_group(Net(2)) + + +def test_coverage_v2_sodium_and_potassium_channels(): + """v2 channels (potassium.py / sodium.py) are hosted by Sodium/Potassium ions.""" + bm.set_dt(0.1) + + class Net(bp.dyn.CondNeuGroupLTC): + def __init__(self, size): + super().__init__(size) + self.Na = bp.dyn.SodiumFixed(size, E=50.) + self.Na.add_elem( + ina_ba=bp.dyn.INa_Ba2002v2(size), + ina_tm=bp.dyn.INa_TM1991v2(size), + ina_hh=bp.dyn.INa_HH1952v2(size), + ) + self.K = bp.dyn.PotassiumFixed(size, E=-95.) + self.K.add_elem( + ikdr=bp.dyn.IKDR_Ba2002v2(size), + ik_tm=bp.dyn.IK_TM1991v2(size), + ik_hh=bp.dyn.IK_HH1952v2(size), + ika1=bp.dyn.IKA1_HM1992v2(size), + ika2=bp.dyn.IKA2_HM1992v2(size), + ikk2a=bp.dyn.IKK2A_HM1992v2(size), + ikk2b=bp.dyn.IKK2B_HM1992v2(size), + ikni=bp.dyn.IKNI_Ya1989v2(size), + ik_leak=bp.dyn.IK_Leak(size), + ) + + _run_cond_neu_group(Net(2)) + + +def test_coverage_potassium_module_legacy_classes(): + """potassium.py also ships HHTypedNeuron-hosted legacy duplicates that are + shadowed at the ``bp.dyn`` level by potassium_compatible.py. Import them + directly from the module so the in-file C-14 fixes are exercised too. + """ + import brainpy.dyn.channels.potassium as kmod + bm.set_dt(0.1) + + class Net(bp.dyn.CondNeuGroup): + def __init__(self, size): + super().__init__(size) + self.IKDR = kmod.IKDR_Ba2002(size) + self.IK_TM = kmod.IK_TM1991(size) + self.IK_HH = kmod.IK_HH1952(size) + self.IKA1 = kmod.IKA1_HM1992(size) + self.IKA2 = kmod.IKA2_HM1992(size) + self.IKK2A = kmod.IKK2A_HM1992(size) + self.IKK2B = kmod.IKK2B_HM1992(size) + self.IKNI = kmod.IKNI_Ya1989(size) + + _run_cond_neu_group(Net(2)) + + # The in-file legacy rate functions must also be NaN-free at their singular V. + _assert_finite_with_grad(lambda: kmod.IK_HH1952(1), "f_p_alpha", -55.0) + _assert_finite_with_grad(lambda: kmod.IKDR_Ba2002(1), "f_p_alpha", -35.0) + _assert_finite_with_grad(lambda: kmod.IK_TM1991(1), "f_p_alpha", -45.0) + + +def test_coverage_calcium_fixed_channels(): + """Voltage-gated calcium channels hosted by a (fixed) Calcium ion.""" + bm.set_dt(0.1) + + class Net(bp.dyn.CondNeuGroupLTC): + def __init__(self, size): + super().__init__(size) + self.Ca = bp.dyn.CalciumFixed(size) + self.Ca.add_elem( + icat_hm=bp.dyn.ICaT_HM1992(size), + icat_hp=bp.dyn.ICaT_HP1992(size), + icaht_hm=bp.dyn.ICaHT_HM1992(size), + icaht_re=bp.dyn.ICaHT_Re1993(size), + ical=bp.dyn.ICaL_IS2008(size), + ) + + _run_cond_neu_group(Net(2)) + + +def test_coverage_calcium_dyna_channel_ican(): + """ICaN_IS2008 requires a CalciumDyna host; exercise it for coverage.""" + bm.set_dt(0.1) + + class Net(bp.dyn.CondNeuGroupLTC): + def __init__(self, size): + super().__init__(size) + self.Ca = bp.dyn.CalciumDetailed(size) + self.Ca.add_elem(ican=bp.dyn.ICaN_IS2008(size)) + + _run_cond_neu_group(Net(2)) + + +def test_coverage_ion_objects_reset_and_current(): + """Construct PotassiumFixed / CalciumFixed / SodiumFixed and exercise the ion API.""" + bm.set_dt(0.1) + V = bm.ones(2) * -65. + + k = bp.dyn.PotassiumFixed(2, E=-95.) + k.add_elem(ik=bp.dyn.IK_HH1952v2(2)) + na = bp.dyn.SodiumFixed(2, E=50.) + na.add_elem(ina=bp.dyn.INa_HH1952v2(2)) + ca = bp.dyn.CalciumFixed(2) + ca.add_elem(ical=bp.dyn.ICaL_IS2008(2)) + ca_dyn = bp.dyn.CalciumFirstOrder(2) + + for ion in (k, na, ca): + ion.reset_state(V) + bp.share.save(t=0., dt=0.1, i=0) + ion.update(V) + cur = np.asarray(ion.current(V)) + assert np.all(np.isfinite(cur)) + + # CalciumDyna reset path (different reset_state signature). + ca_dyn.reset_state(V) + assert np.all(np.isfinite(np.asarray(ca_dyn.C))) diff --git a/tests/audit/test_dyn_neurons_synapses_fixes.py b/tests/audit/test_dyn_neurons_synapses_fixes.py new file mode 100644 index 000000000..c5c0e1e0c --- /dev/null +++ b/tests/audit/test_dyn_neurons_synapses_fixes.py @@ -0,0 +1,529 @@ +"""Regression + coverage tests for the 2026-06-18 BrainPy audit. + +This module targets the fixes documented in ``docs/issues-found-20260618.md`` +for the dyn neurons / synapses / projections and dnn/linear subsystems: + + * H-34 ``dyn/neurons/lif.py`` -- ``ExpIFRef`` honours ``noise=`` (sdeint vs odeint). + * H-35 ``dyn/neurons/lif.py`` -- ``IzhikevichRef`` / ``GifRef`` ``detach_spk`` actually + cuts the spike gradient path. + * C-06/H-39/M-22 ``dyn/synapses/abstract_models.py`` -- ``STP`` facilitation no longer diverges; + discrete jumps applied to decayed locals; u/x stay in [0, 1]. + * C-17 ``dyn/projections/inputs.py`` -- ``PoissonInput`` Gaussian std == sqrt(n*p*(1-p)). + * C-18 ``dyn/projections/align_post.py`` -- ``HalfProjAlignPost`` calls ``comm`` exactly once. + * C-19/H-41 ``dyn/projections/plasticity.py`` + ``dnn/linear.py`` -- ``STDP_Song2000`` runs and + changes the (promoted ``Variable``) weight; ``W_min=W_max=None`` + does not crash; bounds clamp when set. + * H-40 ``dyn/projections/base.py`` -- module re-exports the real base classes. + +Plus breadth coverage tests that construct and run one ``update()`` for every public +neuron class in ``lif.py`` (and LTC variants), the abstract synapse models, the +projection wrappers, and the dnn/linear comm layers. + +Run from the worktree root with ``PYTHONPATH=.``. +""" + +import numpy as np +import jax +import jax.numpy as jnp +import pytest + +import brainpy as bp +import brainpy.math as bm +from brainpy.integrators.ode.base import ODEIntegrator +from brainpy.integrators.sde.base import SDEIntegrator + + +def _share(t=0.0, dt=0.1, i=0): + bp.share.save(t=t, dt=dt, i=i) + + +# --------------------------------------------------------------------------- +# H-34: ExpIFRef honours noise= (sdeint) and is plain odeint otherwise. +# --------------------------------------------------------------------------- + +def test_expifref_noise_uses_sde_integrator(): + noisy = bp.dyn.ExpIFRef(3, noise=0.5) + plain = bp.dyn.ExpIFRef(3) + assert isinstance(noisy.integral, SDEIntegrator) + assert not isinstance(noisy.integral, ODEIntegrator) + assert isinstance(plain.integral, ODEIntegrator) + assert not isinstance(plain.integral, SDEIntegrator) + + +def test_expifref_noisy_update_runs(): + bm.random.seed(0) + neu = bp.dyn.ExpIFRef(3, noise=1.0) + neu.reset_state() + _share() + neu.update(bm.ones(3) * 5.0) + assert jnp.all(jnp.isfinite(neu.V.value)) + + +# --------------------------------------------------------------------------- +# H-35: IzhikevichRef / GifRef detach_spk actually cuts the spike gradient. +# --------------------------------------------------------------------------- + +def _one_step_grad_izhikevich(detach, v0): + neu = bp.dyn.IzhikevichRef(3, mode=bm.training_mode, detach_spk=detach) + + def loss(inp): + neu.reset_state(bm.training_mode) + neu.V.value = bm.ones(neu.V.shape) * v0 + _share() + neu.update(inp) + return jnp.sum(neu.V.value) + + return jax.grad(loss)(bm.zeros(3)) + + +def test_izhikevichref_detach_spk_changes_gradient(): + # With V driven above threshold (~30 mV) the spike path is active, so + # detaching the spike (cutting it) changes the gradient w.r.t. the input. + v0 = 30.0 + g_detach = _one_step_grad_izhikevich(True, v0) + g_plain = _one_step_grad_izhikevich(False, v0) + assert jnp.all(jnp.isfinite(g_detach)) + assert jnp.all(jnp.isfinite(g_plain)) + # detach_spk cuts the spike contribution -> gradient differs from the + # grad-carrying path. + assert not jnp.allclose(g_detach, g_plain) + + +def _one_step_grad_gif(detach, v0): + neu = bp.dyn.GifRef(3, mode=bm.training_mode, detach_spk=detach) + + def loss(inp): + neu.reset_state(bm.training_mode) + neu.V.value = bm.ones(neu.V.shape) * v0 + _share() + neu.update(inp) + # include the adaptation/threshold states the spike resets touch. + return (jnp.sum(neu.V.value) + jnp.sum(neu.I1.value) + + jnp.sum(neu.I2.value) + jnp.sum(neu.V_th.value)) + + return jax.grad(loss)(bm.zeros(3)) + + +def test_gifref_detach_spk_changes_gradient(): + v0 = -50.0 # drives V across the GIF threshold within one step + g_detach = _one_step_grad_gif(True, v0) + g_plain = _one_step_grad_gif(False, v0) + assert jnp.all(jnp.isfinite(g_detach)) + assert jnp.all(jnp.isfinite(g_plain)) + assert not jnp.allclose(g_detach, g_plain) + + +# --------------------------------------------------------------------------- +# C-06 / H-39 / M-22: STP short-term facilitation does not diverge. +# --------------------------------------------------------------------------- + +def test_stp_no_spike_u_does_not_increase(): + stp = bp.dyn.STP(1) + u0 = float(stp.u.value[0]) + no_spike = bm.zeros(1, dtype=bool) + for _ in range(5): + _share() + stp.update(no_spike) + # With no spike, facilitation u must decay/stay -- never grow. + assert float(stp.u.value[0]) <= u0 + 1e-9 + + +def test_stp_spiking_u_x_stay_bounded(): + stp = bp.dyn.STP(1) + spike = bm.ones(1, dtype=bool) + us, xs = [], [] + for i in range(50): + _share(t=i * 0.1, i=i) + stp.update(spike) + us.append(float(stp.u.value[0])) + xs.append(float(stp.x.value[0])) + assert min(us) >= 0.0 and max(us) <= 1.0 + assert min(xs) >= 0.0 and max(xs) <= 1.0 + # released resource u*x stays finite and modest (used to explode to thousands) + assert max(u * x for u, x in zip(us, xs)) < 1.0 + + +# --------------------------------------------------------------------------- +# C-17: PoissonInput Gaussian-approx std == sqrt(n*p*(1-p)). +# --------------------------------------------------------------------------- + +def test_poisson_input_gaussian_std_matches_binomial(): + bm.random.seed(0) + n_neuron = 4000 + n_input = 1000 + freq = 200.0 # Hz + dt = 0.1 + target = bm.Variable(bm.zeros(n_neuron)) + _share(dt=dt) + pin = bp.dyn.PoissonInput(target_var=target, num_input=n_input, freq=freq, weight=1.0) + + p = freq * dt / 1e3 + # ensure we exercise the Gaussian branch (a > 5 and b > 5) + assert n_input * p > 5 and n_input * (1 - p) > 5 + + target.value = bm.zeros(n_neuron) + pin.update() + samp = np.asarray(target.value) + + expected_std = np.sqrt(n_input * p * (1 - p)) + expected_mean = n_input * p + # std must be the binomial std, NOT the variance (the old bug was ~4x too big). + assert samp.std() == pytest.approx(expected_std, rel=0.1) + assert samp.mean() == pytest.approx(expected_mean, rel=0.1) + # guard against the regression where std == variance (b*p). + assert samp.std() < 0.5 * (n_input * p * (1 - p)) + + +# --------------------------------------------------------------------------- +# C-18: HalfProjAlignPost calls comm exactly once per update. +# --------------------------------------------------------------------------- + +class _CountingLinear(bp.dnn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.n_calls = 0 + + def update(self, x): + self.n_calls += 1 + return super().update(x) + + +def test_halfprojalignpost_calls_comm_once(): + post = bp.dyn.LifRef(4) + comm = _CountingLinear(4, 4, W_initializer=bp.init.Constant(0.1)) + proj = bp.dyn.HalfProjAlignPost( + comm=comm, + syn=bp.dyn.Expon(4, tau=5.0), + out=bp.dyn.COBA(E=0.0), + post=post, + ) + _share() + current = proj.update(bm.ones(4)) + assert comm.n_calls == 1 + # returned current equals one manual comm application. + expected = comm.update(bm.ones(4)) # second call only for comparison + assert jnp.allclose(jnp.asarray(current), jnp.asarray(expected)) + + +# --------------------------------------------------------------------------- +# C-19 / H-41: STDP over a Linear/AllToAll comm runs and changes the weight. +# --------------------------------------------------------------------------- + +def _build_stdp_net(w_min=None, w_max=None): + pre = bp.dyn.LifRef(3) + post = bp.dyn.LifRef(4) + syn = bp.dyn.STDP_Song2000( + pre=pre, + delay=1.0, + comm=bp.dnn.AllToAll(3, 4, weight=bp.init.Uniform(max_val=0.1)), + syn=bp.dyn.Expon.desc(post.varshape, tau=5.0), + out=bp.dyn.COBA.desc(E=0.0), + post=post, + W_min=w_min, + W_max=w_max, + ) + net = bp.DynSysGroup(pre=pre, post=post, syn=syn) + net.reset_state() + return net, pre, post, syn + + +def test_stdp_song2000_runs_and_changes_weight(): + net, pre, post, syn = _build_stdp_net() + # weight starts as a plain array (comm built outside a TrainingMode). + assert not isinstance(syn.comm.weight, bm.Variable) + w0 = bm.as_jax(syn.comm.weight).copy() + + for i in range(40): + _share(t=i * 0.1, i=i) + syn() + pre(bm.ones(3) * 80.0) + post(bm.ones(4) * 80.0) + + # weight is promoted to a Variable on the first stdp_update and updates. + assert isinstance(syn.comm.weight, bm.Variable) + w1 = bm.as_jax(syn.comm.weight) + assert not jnp.allclose(w0, w1) + assert jnp.all(jnp.isfinite(w1)) + + +def test_stdp_song2000_default_none_bounds_do_not_crash(): + # W_min == W_max == None used to crash on bm.as_jax(None) (H-41). + net, pre, post, syn = _build_stdp_net(w_min=None, w_max=None) + _share() + syn() + pre(bm.ones(3) * 80.0) + post(bm.ones(4) * 80.0) + assert jnp.all(jnp.isfinite(bm.as_jax(syn.comm.weight))) + + +def test_stdp_song2000_bounds_clamp_weight(): + w_max = 0.2 + w_min = 0.0 + net, pre, post, syn = _build_stdp_net(w_min=w_min, w_max=w_max) + for i in range(60): + _share(t=i * 0.1, i=i) + syn() + pre(bm.ones(3) * 80.0) + post(bm.ones(4) * 80.0) + w = bm.as_jax(syn.comm.weight) + assert jnp.all(w <= w_max + 1e-5) + assert jnp.all(w >= w_min - 1e-5) + + +# --------------------------------------------------------------------------- +# H-40: projections/base.py re-exports real base classes (no dead duplicate). +# --------------------------------------------------------------------------- + +def test_projections_base_reexports_real_classes(): + from brainpy.dyn.projections import base as proj_base + assert hasattr(proj_base, 'Projection') + assert hasattr(proj_base, 'SynConn') + assert 'Projection' in proj_base.__all__ + assert 'SynConn' in proj_base.__all__ + # the re-exported Projection is the canonical one used by the wrappers. + assert proj_base.Projection is bp.dyn.Projection + + +# --------------------------------------------------------------------------- +# Coverage: every public neuron class in lif.py, both modes. +# --------------------------------------------------------------------------- + +_NEURON_NAMES = [ + 'IF', 'Lif', 'LifRef', + 'ExpIF', 'ExpIFRef', + 'AdExIF', 'AdExIFRef', + 'QuaIF', 'QuaIFRef', + 'AdQuaIF', 'AdQuaIFRef', + 'Gif', 'GifRef', + 'Izhikevich', 'IzhikevichRef', +] +_NEURON_NAMES = _NEURON_NAMES + [n + 'LTC' for n in _NEURON_NAMES] + + +@pytest.mark.parametrize('name', _NEURON_NAMES) +@pytest.mark.parametrize('mode', [None, 'train']) +def test_neuron_update_runs(name, mode): + cls = getattr(bp.dyn, name) + if mode == 'train': + neu = cls(3, mode=bm.training_mode) + neu.reset_state(bm.training_mode) + else: + neu = cls(3) + neu.reset_state() + # small drive: the exponential-IF family genuinely overflows V in a single + # large training-mode step (no hard refractory clamp before the surrogate + # spike), which is a model property, not an audit regression. + _share() + out = neu.update(bm.ones(3) * 1.0) + assert out is not None + assert neu.V.value.shape[-1] == 3 + if mode is None: + # default (non-training) mode applies the hard reset, so V stays finite. + assert jnp.all(jnp.isfinite(neu.V.value)) + + +# --------------------------------------------------------------------------- +# Coverage: abstract synapse models forward passes. +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize('name', ['Expon', 'DualExpon', 'Alpha', 'NMDA', 'AMPA', 'STP', 'STD']) +def test_synapse_forward(name): + cls = getattr(bp.dyn, name) + syn = cls(3) + spike_bool = bm.asarray([True, False, True], dtype=bool) + spike_float = spike_bool.astype(float) + _share() + if name in ('NMDA', 'AMPA'): + out = syn(spike_float) + else: + out = syn(spike_bool) + assert jnp.asarray(out).shape == (3,) + assert jnp.all(jnp.isfinite(jnp.asarray(out))) + + +# --------------------------------------------------------------------------- +# Coverage: projection wrappers. +# --------------------------------------------------------------------------- + +def test_full_proj_align_post(): + class Net(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.pre = bp.dyn.LifRef(4) + self.post = bp.dyn.LifRef(3) + self.proj = bp.dyn.FullProjAlignPost( + pre=self.pre, delay=None, + comm=bp.dnn.AllToAll(4, 3, weight=0.1), + syn=bp.dyn.Expon(3, tau=5.0), + out=bp.dyn.COBA(E=0.0), + post=self.post, + ) + + def update(self, inp): + self.proj() + self.pre(inp) + self.post() + return self.post.V.value + + net = Net() + net.reset_state() + _share() + out = net.update(bm.ones(4) * 5.0) + assert jnp.all(jnp.isfinite(out)) + + +def test_full_proj_align_post_mg(): + class Net(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.pre = bp.dyn.LifRef(4) + self.post = bp.dyn.LifRef(3) + self.proj = bp.dyn.FullProjAlignPostMg( + pre=self.pre, delay=None, + comm=bp.dnn.AllToAll(4, 3, weight=0.1), + syn=bp.dyn.Expon.desc(3, tau=5.0), + out=bp.dyn.COBA.desc(E=0.0), + post=self.post, + ) + + def update(self, inp): + self.proj() + self.pre(inp) + self.post() + return self.post.V.value + + net = Net() + net.reset_state() + _share() + out = net.update(bm.ones(4) * 5.0) + assert jnp.all(jnp.isfinite(out)) + + +def test_full_proj_align_pre_sdmg(): + class Net(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.pre = bp.dyn.LifRef(4) + self.post = bp.dyn.LifRef(3) + self.proj = bp.dyn.FullProjAlignPreSDMg( + pre=self.pre, + syn=bp.dyn.Expon.desc(4, tau=5.0), + delay=None, + comm=bp.dnn.AllToAll(4, 3, weight=0.1), + out=bp.dyn.COBA(E=0.0), + post=self.post, + ) + + def update(self, inp): + self.proj() + self.pre(inp) + self.post() + return self.post.V.value + + net = Net() + net.reset_state() + _share() + out = net.update(bm.ones(4) * 5.0) + assert jnp.all(jnp.isfinite(out)) + + +def test_half_proj_align_post_mg(): + post = bp.dyn.LifRef(3) + proj = bp.dyn.HalfProjAlignPostMg( + comm=bp.dnn.AllToAll(4, 3, weight=0.1), + syn=bp.dyn.Expon.desc(3, tau=5.0), + out=bp.dyn.COBA.desc(E=0.0), + post=post, + ) + _share() + out = proj.update(bm.ones(4)) + assert jnp.all(jnp.isfinite(jnp.asarray(out))) + + +# --------------------------------------------------------------------------- +# Coverage: PoissonInput full update path (binomial small-N branch too). +# --------------------------------------------------------------------------- + +def test_poisson_input_small_n_branch(): + bm.random.seed(1) + target = bm.Variable(bm.zeros(8)) + _share(dt=0.1) + pin = bp.dyn.PoissonInput(target_var=target, num_input=10, freq=20.0, weight=1.0) + assert repr(pin) # exercise __repr__ + pin.update() # a, b small -> binomial branch + assert jnp.all(jnp.isfinite(target.value)) + + +def test_input_var_runs(): + iv = bp.dyn.InputVar(4) + iv.input += 1.0 + assert jnp.allclose(jnp.asarray(iv.update()), 1.0) + iv.clear_input() + assert jnp.allclose(jnp.asarray(iv.update()), 0.0) + + +# --------------------------------------------------------------------------- +# Coverage: dnn/linear comm layer forward passes. +# --------------------------------------------------------------------------- + +def test_dnn_linear_forwards(): + conn = bp.conn.FixedProb(0.5, pre=4, post=3, seed=1) + + dense = bp.dnn.Dense(4, 3) + assert jnp.asarray(dense(bm.ones(4))).shape == (3,) + + a2a = bp.dnn.AllToAll(4, 3, weight=0.1) + # NonBatching mode requires a 1D input; identical scalar weight + ones input + # collapses to a scalar (sum over pre). + assert jnp.asarray(a2a(bm.ones(4))).shape == () + + o2o = bp.dnn.OneToOne(4, weight=0.1) + assert jnp.asarray(o2o(bm.ones(4))).shape == (4,) + + ml = bp.dnn.MaskedLinear(conn, weight=0.1) + assert jnp.asarray(ml(bm.ones(4))).shape == (3,) + + csr = bp.dnn.CSRLinear(conn, weight=0.1) + assert jnp.asarray(csr(bm.ones(4))).shape == (3,) + + +def test_dnn_event_and_jitprob_forwards(): + conn = bp.conn.FixedProb(0.5, pre=4, post=3, seed=1) + spike = bm.asarray([True, False, True, False]) + x = bm.ones(4) + + # Identity passthrough. + assert jnp.asarray(bp.dnn.Identity()(x)).shape == (4,) + + # Event-driven CSR. + ec = bp.dnn.EventCSRLinear(conn, weight=0.1) + assert jnp.asarray(ec(spike)).shape == (3,) + + # JIT just-in-time fixed-prob connectivity (dense-input variants). + assert jnp.asarray( + bp.dnn.JitFPHomoLinear(4, 3, prob=0.5, weight=0.1, seed=1)(x)).shape == (3,) + assert jnp.asarray( + bp.dnn.JitFPUniformLinear(4, 3, prob=0.5, w_low=0.0, w_high=0.1, seed=1)(x)).shape == (3,) + assert jnp.asarray( + bp.dnn.JitFPNormalLinear(4, 3, prob=0.5, w_mu=0.0, w_sigma=0.1, seed=1)(x)).shape == (3,) + + # JIT fixed-prob connectivity (event/spike-input variants). + assert jnp.asarray( + bp.dnn.EventJitFPHomoLinear(4, 3, prob=0.5, weight=0.1, seed=1)(spike)).shape == (3,) + assert jnp.asarray( + bp.dnn.EventJitFPUniformLinear(4, 3, prob=0.5, w_low=0.0, w_high=0.1, seed=1)(spike)).shape == (3,) + assert jnp.asarray( + bp.dnn.EventJitFPNormalLinear(4, 3, prob=0.5, w_mu=0.0, w_sigma=0.1, seed=1)(spike)).shape == (3,) + + +def test_dense_stdp_update_promotes_weight(): + dense = bp.dnn.Dense(3, 4, W_initializer=bp.init.Uniform(max_val=0.1)) + assert not isinstance(dense.W, bm.Variable) + # the plasticity wrapper passes jax arrays to ``stdp_update``; mirror that. + spike_pre = bm.as_jax(bm.asarray([1.0, 0.0, 1.0])) + trace_pre = bm.as_jax(bm.asarray([0.1, 0.2, 0.3, 0.4])) + w0 = bm.as_jax(dense.W).copy() + dense.stdp_update(on_pre={'spike': spike_pre, 'trace': trace_pre}, w_min=None, w_max=None) + assert isinstance(dense.W, bm.Variable) + assert not jnp.allclose(w0, bm.as_jax(dense.W)) + assert jnp.all(jnp.isfinite(bm.as_jax(dense.W))) diff --git a/tests/audit/test_dyn_rates_dynold_fixes.py b/tests/audit/test_dyn_rates_dynold_fixes.py new file mode 100644 index 000000000..b96720ffd --- /dev/null +++ b/tests/audit/test_dyn_rates_dynold_fixes.py @@ -0,0 +1,485 @@ +# -*- coding: utf-8 -*- +"""Audit regression + coverage tests (audit ``docs/issues-found-20260618.md``). + +This module exercises the fixes recorded in the 2026-06-18 BrainPy audit for the +rate-population / reservoir / RNN-cell modules under ``brainpy/dyn/rates`` and the +``brainpy/dynold`` compatibility shims. + +Regression behaviors covered: + +* C-15 ``ThresholdLinearModel`` noise path no longer crashes (``randn(*shape)``). +* C-16 ``StuartLandauOscillator.dy`` uses the correct ``+w*x`` rotational coupling. +* C-17 (dynold copy in ``experimental/others.py``) ``PoissonInput`` Gaussian branch + uses ``std = sqrt(b*p)``, not the variance ``b*p``. +* C-20 ``AlphaCUBA`` / ``AlphaCOBA`` construct without ``ZeroDivisionError``. +* C-21 dynold ``STP`` synaptic current does not drift with zero presynaptic spikes. +* H-36 ``LSTMCell`` ``h`` / ``c`` setters slice the last axis (unbatched + batched); + setting ``c`` does not corrupt ``h``. +* H-37 ``Reservoir`` recurrent noise is symmetric (``uniform(-1, 1)`` -> zero mean). +* H-38 ``Reservoir`` bias is actually added in ``update``. + +The remaining tests construct + step the assigned modules for coverage. Bugs that +are still *unfixed* in the source are recorded in the agent summary, not asserted +here (e.g. M-21 ``reset_state(None)``, M-24 ``ALIFBellec2020`` ``a_initializer``). +""" + +import jax.numpy as jnp +import pytest + +import brainpy as bp +import brainpy.math as bm +import brainpy.initialize as init +from brainpy.context import share + + +def _share(t=0.0, dt=0.1, i=0): + """Populate the shared simulation context required by ``update``.""" + bm.set_dt(dt) + share.save(t=t, dt=dt, i=i) + + +# --------------------------------------------------------------------------- +# Regression tests +# --------------------------------------------------------------------------- + +def test_threshold_linear_model_noise_path_runs(): + """C-15: nonzero ``noise_e`` no longer raises ``TypeError`` from ``randn``.""" + bm.random.seed(0) + m = bp.dyn.ThresholdLinearModel(5, noise_e=1.0) + _share(t=0.0, dt=1e-4, i=0) + out = m.update(inp_e=0.0) + arr = jnp.asarray(out) + assert arr.shape == (5,) + assert bool(jnp.isfinite(arr).all()) + + +def test_threshold_linear_model_noise_i_path_runs(): + """C-15 companion: the inhibitory noise branch is also reachable.""" + bm.random.seed(1) + m = bp.dyn.ThresholdLinearModel(4, noise_i=0.5) + _share(t=0.0, dt=1e-4, i=0) + out = m.update(inp_e=1.0, inp_i=1.0) + assert bool(jnp.isfinite(jnp.asarray(out)).all()) + + +def test_stuart_landau_dy_rotational_coupling(): + """C-16: ``dy`` must add ``+w*x`` (Hopf normal form), not ``-w*y``.""" + m = bp.dyn.StuartLandauOscillator(1, a=0.25, w=0.2) + # signature is dy(self, y, t, x, y_ext, a, w) + val = float(jnp.asarray(m.dy(0.3, 0.0, 0.5, 0.0, 0.25, 0.2))) + # (a - x^2 - y^2)*y + w*x = (0.25 - 0.25 - 0.09)*0.3 + 0.2*0.5 = 0.073 + assert val == pytest.approx(0.073, abs=1e-4) + # The buggy -w*y value would be -0.087; make sure we are not there. + assert val > 0.0 + + +def test_poisson_input_gaussian_std_is_sqrt_bp(): + """C-17 (dynold copy): Gaussian branch std = sqrt(b*p), ~4.43 not ~19.6.""" + from brainpy.dynold.experimental.others import PoissonInput + + bm.random.seed(0) + dt = 0.1 + freq, num = 200.0, 1000 + pi = PoissonInput(target_shape=(20000,), num_input=num, freq=freq, weight=1.0) + _share(t=0.0, dt=dt, i=0) + out = jnp.asarray(pi.update()) + + p = freq * dt / 1e3 # 0.02 + b = num * (1 - p) + expected_std = float((b * p) ** 0.5) # ~4.427 + expected_mean = num * p # 20.0 + + assert expected_std == pytest.approx(4.427, abs=1e-2) + # empirical std close to sqrt(b*p), and clearly far from the variance b*p (=19.6) + assert float(out.std()) == pytest.approx(expected_std, rel=0.1) + assert float(out.mean()) == pytest.approx(expected_mean, rel=0.1) + + +def test_alpha_cuba_coba_construct_and_step(): + """C-20: AlphaCUBA/COBA construct without ZeroDivisionError and run a step.""" + bm.random.seed(0) + bm.set_dt(0.1) + + pre = bp.neurons.LIF(2) + post = bp.neurons.LIF(2) + syn = bp.synapses.AlphaCUBA(pre, post, bp.connect.All2All(), tau_decay=10.0) + net = bp.Network(pre=pre, post=post, syn=syn) + net.reset_state() + _share() + syn.update() # must not raise + + pre2 = bp.neurons.LIF(2) + post2 = bp.neurons.LIF(2) + syn2 = bp.synapses.AlphaCOBA(pre2, post2, bp.connect.All2All(), tau_decay=10.0) + net2 = bp.Network(pre=pre2, post=post2, syn=syn2) + net2.reset_state() + _share() + syn2.update() # must not raise + + +def test_dynold_stp_no_drift_without_spikes(): + """C-21: dynold STP synaptic current stays at 0 with zero presynaptic spikes.""" + bm.random.seed(0) + bm.set_dt(0.1) + pre = bp.neurons.LIF(1) + post = bp.neurons.LIF(1) + syn = bp.synapses.STP(pre, post, bp.connect.All2All(), U=0.2, tau_d=150.0, tau_f=2.0) + net = bp.Network(pre=pre, syn=syn, post=post) + net.reset_state() + + currents = [] + for i in range(100): + _share(t=i * 0.1, dt=0.1, i=i) + net.update() # no external input -> no spikes + currents.append(float(jnp.asarray(syn.I.value).sum())) + + # With the spike-gating fix, the current must not ramp up. + assert max(abs(c) for c in currents) < 1e-5 + + +def test_dynold_stp_jumps_with_spikes(): + """C-21 companion: STP current does jump when presynaptic spikes occur.""" + bm.random.seed(0) + bm.set_dt(0.1) + pre = bp.neurons.LIF(1) + post = bp.neurons.LIF(1) + syn = bp.synapses.STP(pre, post, bp.connect.All2All(), U=0.2, tau_d=150.0, tau_f=2.0) + net = bp.Network(pre=pre, syn=syn, post=post) + runner = bp.DSRunner(net, inputs=[('pre.input', 28.0)], monitors=['syn.I'], + progress_bar=False) + runner.run(50.0) + I = runner.mon['syn.I'] + assert bool(jnp.isfinite(I).all()) + assert float(I.max()) > 0.0 + + +def test_lstm_h_c_setters_unbatched(): + """H-36: unbatched ``h``/``c`` setters slice the last axis (no IndexError).""" + bm.random.seed(0) + lstm = bp.dyn.LSTMCell(3, 4) + assert lstm.state.shape == (8,) + + lstm.h = jnp.ones((4,)) + assert jnp.allclose(jnp.asarray(lstm.h), 1.0) + # setting h must not have touched c + assert jnp.allclose(jnp.asarray(lstm.c), 0.0) + + lstm.c = jnp.full((4,), 2.0) + # setting c must not corrupt h + assert jnp.allclose(jnp.asarray(lstm.h), 1.0) + assert jnp.allclose(jnp.asarray(lstm.c), 2.0) + + +def test_lstm_h_c_setters_batched(): + """H-36: batched ``h``/``c`` setters write the correct rows.""" + bm.random.seed(0) + lstm = bp.dyn.LSTMCell(3, 4, mode=bm.batching_mode) + lstm.reset_state(2) + assert lstm.state.shape == (2, 8) + + lstm.h = jnp.ones((2, 4)) + assert jnp.allclose(jnp.asarray(lstm.h), 1.0) + assert jnp.allclose(jnp.asarray(lstm.c), 0.0) + + lstm.c = jnp.full((2, 4), 2.0) + assert jnp.allclose(jnp.asarray(lstm.h), 1.0) + assert jnp.allclose(jnp.asarray(lstm.c), 2.0) + + +def test_reservoir_bias_is_applied(): + """H-38: a nonzero bias shifts the reservoir output.""" + bm.random.seed(0) + common = dict(in_connectivity=1.0, rec_connectivity=1.0, + activation_type='external', leaky_rate=1.0) + r0 = bp.dyn.Reservoir(3, 5, b_initializer=init.ZeroInit(), **common) + r1 = bp.dyn.Reservoir(3, 5, b_initializer=init.Constant(1.0), **common) + # share the random weights so only the bias differs + r1.Win.value = r0.Win.value + r1.Wrec.value = r0.Wrec.value + + x = jnp.zeros((3,)) + out0 = jnp.asarray(r0.update(x)) + out1 = jnp.asarray(r1.update(x)) + # zero bias + zero input -> zero (external activation of 0 is tanh(0)=0) + assert jnp.allclose(out0, 0.0) + # nonzero bias shifts the output away from zero + assert float(jnp.sum(jnp.abs(out1 - out0))) > 1e-3 + + +def test_reservoir_recurrent_noise_is_symmetric(): + """H-37: recurrent noise is ``uniform(-1, 1)`` (zero-mean), not a -noise bias.""" + bm.random.seed(123) + r = bp.dyn.Reservoir(2, 4000, noise_rec=1.0, in_connectivity=1.0, + rec_connectivity=1.0, activation_type='external', + leaky_rate=1.0) + # zero out the weights so the state is driven purely by the recurrent noise + r.Win.value = jnp.zeros_like(jnp.asarray(r.Win.value)) + r.Wrec.value = jnp.zeros_like(jnp.asarray(r.Wrec.value)) + out = jnp.asarray(r.update(jnp.zeros((2,)))) + # state = tanh(noise); symmetric noise -> mean near zero, both signs present + assert abs(float(out.mean())) < 0.05 + assert float(out.min()) < -0.1 + assert float(out.max()) > 0.1 + + +# --------------------------------------------------------------------------- +# Coverage tests: construct + step the assigned modules +# --------------------------------------------------------------------------- + +def test_rate_populations_construct_and_step(): + """Cover FHN/FeedbackFHN/QIF/StuartLandau/WilsonCowan/ThresholdLinear.""" + bm.random.seed(0) + _share() + + # FHN with OU noise enabled to exercise the noise branch; pass both inputs + fhn = bp.dyn.FHN(3, x_ou_sigma=1.0, y_ou_sigma=1.0) + fhn.reset_state() + assert jnp.asarray(fhn.update(1.0, 0.5)).shape == (3,) + fhn.clear_input() + + fbfhn = bp.dyn.FeedbackFHN(3, delay=2.0, x_ou_sigma=1.0, y_ou_sigma=1.0) + fbfhn.reset_state() + assert jnp.asarray(fbfhn.update(1.0, 0.5)).shape == (3,) + fbfhn.clear_input() + + qif = bp.dyn.QIF(3) + qif.reset_state() + assert jnp.asarray(qif.update(1.0, 0.5)).shape == (3,) + qif.clear_input() + + sl = bp.dyn.StuartLandauOscillator(3) + sl.reset_state() + assert jnp.asarray(sl.update(1.0, 0.5)).shape == (3,) + sl.clear_input() + + wc = bp.dyn.WilsonCowanModel(3) + wc.reset_state() + assert jnp.asarray(wc.update(1.0, 0.5)).shape == (3,) + wc.clear_input() + + tlm = bp.dyn.ThresholdLinearModel(3) + tlm.reset_state() + assert jnp.asarray(tlm.update(inp_e=1.0, inp_i=0.5)).shape == (3,) + tlm.clear_input() + + +def test_rate_populations_no_input_var_branch(): + """Cover the ``input_var=False`` update branches with OU noise.""" + bm.random.seed(0) + _share() + for cls in (bp.dyn.FHN, bp.dyn.FeedbackFHN, bp.dyn.QIF, + bp.dyn.StuartLandauOscillator, bp.dyn.WilsonCowanModel): + m = cls(2, input_var=False, x_ou_sigma=1.0, y_ou_sigma=1.0) + m.reset_state() + out = m.update(1.0, 0.5) + assert jnp.asarray(out).shape == (2,) + + +def test_rate_populations_run_via_dsrunner(): + """A short DSRunner run exercises update/clear_input across many steps.""" + bm.random.seed(0) + fhn = bp.dyn.FHN(2) + runner = bp.DSRunner(fhn, inputs=('input', 1.0), monitors=['x'], + progress_bar=False) + runner.run(5.0) + assert bool(jnp.isfinite(runner.mon['x']).all()) + + +def test_reservoir_update_dense_and_sparse(): + """Cover Reservoir dense + sparse comp paths and feedforward noise.""" + bm.random.seed(0) + # dense with feedforward noise + spectral radius scaling + r = bp.dyn.Reservoir(4, 8, noise_in=0.1, spectral_radius=0.9) + r.reset_state() + out = r.update(jnp.ones((4,))) + assert jnp.asarray(out).shape == (8,) + + # sparse computation path + rs = bp.dyn.Reservoir(4, 8, comp_type='sparse', in_connectivity=0.5, + rec_connectivity=0.5) + rs.reset_state() + out_s = rs.update(jnp.ones((4,))) + assert jnp.asarray(out_s).shape == (8,) + + +def test_rnn_cells_forward_unbatched(): + """Cover RNNCell/GRUCell/LSTMCell forward pass (unbatched).""" + bm.random.seed(0) + x = jnp.ones((3,)) + for cls in (bp.dyn.RNNCell, bp.dyn.GRUCell, bp.dyn.LSTMCell): + c = cls(3, 4) + out = c.update(x) + assert jnp.asarray(out).shape == (4,) + + +def test_rnn_cells_forward_batched_and_train_state(): + """Cover batched forward + train_state initialization for the RNN cells.""" + bm.random.seed(0) + x = jnp.ones((2, 3)) + for cls in (bp.dyn.RNNCell, bp.dyn.GRUCell, bp.dyn.LSTMCell): + c = cls(3, 4, mode=bm.batching_mode) + c.reset_state(2) + out = c.update(x) + assert jnp.asarray(out).shape == (2, 4) + + # train_state path + ct = cls(3, 4, mode=bm.training_mode, train_state=True) + ct.reset_state(2) + out2 = ct.update(x) + assert jnp.asarray(out2).shape == (2, 4) + + +def test_rnn_cells_no_bias_branch(): + """Cover the ``b_initializer=None`` (no bias) branches in the RNN cells.""" + bm.random.seed(0) + x = jnp.ones((3,)) + for cls in (bp.dyn.RNNCell, bp.dyn.GRUCell, bp.dyn.LSTMCell): + c = cls(3, 4, b_initializer=None) + out = c.update(x) + assert jnp.asarray(out).shape == (4,) + + +def test_dynold_compat_synapses_construct_and_step(): + """Cover dynold compat synapses: Exp/DualExp/Alpha (CUBA & COBA) + Delta.""" + import warnings + bm.random.seed(0) + bm.set_dt(0.1) + + from brainpy.synapses import (ExpCUBA, ExpCOBA, DualExpCUBA, DualExpCOBA, + AlphaCUBA, AlphaCOBA, DeltaSynapse) + + factories = [ + lambda a, b: ExpCUBA(a, b, bp.connect.All2All(), tau=8.0), + lambda a, b: ExpCOBA(a, b, bp.connect.All2All(), tau=8.0, E=0.0), + lambda a, b: DualExpCUBA(a, b, bp.connect.All2All(), tau_decay=10.0, tau_rise=1.0), + lambda a, b: DualExpCOBA(a, b, bp.connect.All2All(), tau_decay=10.0, tau_rise=1.0, E=0.0), + lambda a, b: AlphaCUBA(a, b, bp.connect.All2All(), tau_decay=10.0), + lambda a, b: AlphaCOBA(a, b, bp.connect.All2All(), tau_decay=10.0, E=0.0), + lambda a, b: DeltaSynapse(a, b, bp.connect.All2All()), + ] + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + for factory in factories: + pre = bp.neurons.LIF(2) + post = bp.neurons.LIF(2) + syn = factory(pre, post) + net = bp.Network(pre=pre, post=post, syn=syn) + net.reset_state() + _share() + syn.update() # must not raise + + +def test_dynold_stp_run_full_simulation(): + """Run the dynold STP learning rule via DSRunner to cover its update path.""" + bm.random.seed(0) + bm.set_dt(0.1) + pre = bp.neurons.LIF(1) + post = bp.neurons.LIF(1) + syn = bp.synapses.STP(pre, post, bp.connect.All2All(), U=0.2, tau_d=150.0, tau_f=2.0) + net = bp.Network(pre=pre, syn=syn, post=post) + runner = bp.DSRunner(net, inputs=[('pre.input', 28.0)], + monitors=['syn.I', 'syn.u', 'syn.x'], progress_bar=False) + runner.run(30.0) + assert bool(jnp.isfinite(runner.mon['syn.u']).all()) + assert bool(jnp.isfinite(runner.mon['syn.x']).all()) + + +def test_reduced_models_construct_and_step(): + """Cover dynold reduced_models neurons: construct + a single update step.""" + from brainpy.dynold.neurons import reduced_models as rm + + bm.random.seed(0) + bm.set_dt(0.1) + _share() + + for cls in (rm.LeakyIntegrator, rm.LIF, rm.ExpIF, rm.AdExIF, rm.QuaIF, + rm.AdQuaIF, rm.GIF, rm.Izhikevich, rm.HindmarshRose, rm.FHN, + rm.ALIFBellec2020): + m = cls(2) + m.reset_state() + out = m.update(1.0) + assert jnp.asarray(out).shape == (2,) + m.clear_input() + + +def test_reduced_models_no_input_var_branch(): + """Cover the ``input_var=False`` update branch of the reduced models.""" + from brainpy.dynold.neurons import reduced_models as rm + + bm.random.seed(0) + bm.set_dt(0.1) + _share() + for cls in (rm.LIF, rm.ExpIF, rm.Izhikevich): + m = cls(2, input_var=False) + m.reset_state() + out = m.update(1.0) + assert jnp.asarray(out).shape == (2,) + m.clear_input() + + +def test_reduced_models_tau_ref_and_noise_branches(): + """Cover the refractory + noise branches of the reduced models.""" + from brainpy.dynold.neurons import reduced_models as rm + + bm.random.seed(0) + bm.set_dt(0.1) + _share() + + # tau_ref + noise exercises the sdeint + refractory paths in *Ref models + for cls in (rm.LIF, rm.ExpIF, rm.AdExIF, rm.QuaIF, rm.AdQuaIF, + rm.Izhikevich, rm.GIF): + m = cls(2, tau_ref=2.0, noise=0.5) + m.reset_state() + out = m.update(5.0) + assert jnp.asarray(out).shape == (2,) + m.clear_input() + + +def test_bellec_sfa_models_construct_and_step(): + """Cover ALIFBellec2020 / LIF_SFA_Bellec2020 with and without refractoriness.""" + from brainpy.dynold.neurons import reduced_models as rm + + bm.random.seed(0) + bm.set_dt(0.1) + _share() + + for cls in (rm.ALIFBellec2020, rm.LIF_SFA_Bellec2020): + m = cls(2) + m.reset_state() + assert jnp.asarray(m.update(5.0)).shape == (2,) + + m_ref = cls(2, tau_ref=2.0) + m_ref.reset_state() + assert jnp.asarray(m_ref.update(5.0)).shape == (2,) + + +def test_conv_lstm_cells_forward(): + """Cover the convolutional LSTM cells (1d/2d/3d) forward pass.""" + bm.random.seed(0) + + c1 = bp.dyn.Conv1dLSTMCell(input_shape=(5,), in_channels=2, out_channels=3, + kernel_size=3, mode=bm.batching_mode) + c1.reset_state(1) + assert jnp.asarray(c1.update(jnp.ones((1, 5, 2)))).shape == (1, 5, 3) + + c2 = bp.dyn.Conv2dLSTMCell(input_shape=(4, 4), in_channels=2, out_channels=3, + kernel_size=3, mode=bm.batching_mode) + c2.reset_state(1) + assert jnp.asarray(c2.update(jnp.ones((1, 4, 4, 2)))).shape == (1, 4, 4, 3) + + +def test_poisson_input_binomial_branch_and_helpers(): + """Cover the small-N binomial branch, ``__repr__`` and ``reset`` of PoissonInput.""" + from brainpy.dynold.experimental.others import PoissonInput + + bm.random.seed(0) + bm.set_dt(0.1) + _share() + + # low freq / small num_input -> a<=5 or b<=5 -> binomial branch + pi = PoissonInput(target_shape=(5,), num_input=10, freq=10.0, weight=1.0) + out = pi.update() + assert jnp.asarray(out).shape == (5,) + assert 'PoissonInput' in repr(pi) + pi.reset() + pi.reset_state() diff --git a/tests/audit/test_integrators_fixes.py b/tests/audit/test_integrators_fixes.py new file mode 100644 index 000000000..dd9cd6e4f --- /dev/null +++ b/tests/audit/test_integrators_fixes.py @@ -0,0 +1,933 @@ +# -*- coding: utf-8 -*- +"""Regression + coverage tests for the BrainPy v2.7.8 integrators audit. + +These tests pin the fixes recorded in ``docs/issues-found-20260618.md`` for the +``brainpy/integrators`` subtree. Each regression test references the audit ID it +guards. The remaining tests exercise the public integrator API broadly to keep +line coverage high on the assigned source files: + + * ode/adaptive_rk.py -- C-12, H-26, H-27, H-28 + * ode/exponential.py -- exp_euler / exp_euler_auto numerics + * sde/base.py -- C-13 (errors import; invalid intg_type/wiener_type) + * sde/normal.py -- C-13 (Heun Ito/Stratonovich guard) + * integrators/runner.py-- H-29 (IntegratorRunner step-index monitor) + * integrators/joint_eq.py -- L-13 (diagnostic DiffEqError message) + * fde/Caputo.py -- C-08 (CaputoEuler init scaling), H-31 (hists()) + * fde/GL.py -- H-30 (GLShortMemory.reset key suffix) + * fde/generic.py -- H-32 (set/get_default_fdeint global) + +The tests use tiny ``dt`` / step counts so the whole module runs in well under a +minute. Global state (x64 precision, default FDE method) is restored by +fixtures. +""" + +import math + +import jax.numpy as jnp +import numpy as np +import pytest + +import brainpy as bp +import brainpy.math as bm +from brainpy.integrators.ode.adaptive_rk import BoSh3 + +IntegratorError = bp.errors.IntegratorError +DiffEqError = bp.errors.DiffEqError + + +# --------------------------------------------------------------------------- # +# Fixtures: restore global state mutated by some tests. +# --------------------------------------------------------------------------- # + +@pytest.fixture +def x64(): + """Enable float64 for the duration of a test, then restore float32.""" + bm.enable_x64() + try: + yield + finally: + bm.disable_x64() + + +@pytest.fixture +def restore_default_fdeint(): + """Snapshot and restore the global default FDE method.""" + orig = bp.fde.get_default_fdeint() + try: + yield + finally: + bp.fde.set_default_fdeint(orig) + + +def _ev(seq): + """Evaluate a Butcher-tableau row whose entries may be string fractions.""" + return [eval(x) if isinstance(x, str) else float(x) for x in seq] + + +# =========================================================================== # +# Regression tests +# =========================================================================== # + +# --- C-12: adaptive RK no longer TypeErrors when ``tol`` is omitted --------- # + +def test_c12_adaptive_rkf45_runs_without_tol_array_state(): + """C-12: adaptive=True with default tol must run (tol falls back to 0.1).""" + f = bp.odeint(lambda y, t: -y, method='rkf45', adaptive=True) + y_new, dt_new = f(jnp.array([1.0]), 0.0, dt=0.1) + # exact solution y(0.1) = e^{-0.1} + assert np.allclose(np.asarray(y_new), math.exp(-0.1), atol=1e-3) + assert float(dt_new) > 0.0 + + +def test_c12_adaptive_rkf45_runs_on_scalar_state(): + """C-12 / H-28: the same integrator must also work on a python-float scalar.""" + f = bp.odeint(lambda y, t: -y, method='rkf45', adaptive=True) + y_new, dt_new = f(1.0, 0.0, dt=0.1) + assert np.allclose(float(y_new), math.exp(-0.1), atol=1e-3) + assert float(dt_new) > 0.0 + + +# --- H-28: scalar state with default POP_VAR var_type ----------------------- # + +def test_h28_scalar_state_all_adaptive_methods(): + """H-28: default var_type=POP_VAR must use jnp.sum, not builtin sum, so a + scalar state does not raise ``'float' object is not iterable``.""" + for method in ['rkf45', 'rkdp', 'rkf12', 'ck', 'bs', 'heun_euler', 'BoSh3']: + f = bp.odeint(lambda y, t: -y, method=method, adaptive=True) + y_new, dt_new = f(1.0, 0.0, dt=0.05) + assert np.isfinite(float(y_new)) + assert np.isfinite(float(dt_new)) + + +# --- H-26: BoSh3 embedded error vector is non-degenerate -------------------- # + +def test_h26_bosh3_embedded_error_non_degenerate(): + """H-26: B1 and B2 must each be consistent (sum ~ 1) and B1-B2 must be a + real (non-zero) error estimator, not the zero-sum-B2 bug.""" + b1 = _ev(BoSh3.B1) + b2 = _ev(BoSh3.B2) + assert abs(sum(b1) - 1.0) < 1e-9, f'B1 must sum to 1, got {sum(b1)}' + assert abs(sum(b2) - 1.0) < 1e-9, f'B2 must sum to 1, got {sum(b2)}' + diff = [a - b for a, b in zip(b1, b2)] + # the embedded error estimate must be non-degenerate (some non-zero weights) + assert any(abs(d) > 1e-9 for d in diff), 'B1-B2 is degenerate (all zero)' + # the buggy B2 summed to ~0; guard against that regression explicitly + assert abs(sum(b2)) > 0.5 + + +def test_h26_bosh3_integrates_correctly(): + """BoSh3 (3rd order) must integrate y'=-y accurately.""" + f = bp.odeint(lambda y, t: -y, method='BoSh3', adaptive=True) + y_new, _ = f(jnp.array([1.0]), 0.0, dt=0.05) + assert np.allclose(np.asarray(y_new), math.exp(-0.05), atol=1e-4) + + +# --- H-27: two-sided step-size controller can grow dt ----------------------- # + +def test_h27_step_controller_grows_dt_when_error_small(): + """H-27: when the error is comfortably below tol the controller must be able + to *increase* dt (the buggy one-sided controller never grew dt).""" + f = bp.odeint(lambda y, t: -y, method='rkf45', adaptive=True) + _, dt_new = f(jnp.array([1.0]), 0.0, dt=0.01) + assert float(dt_new) > 0.01, 'controller failed to grow dt below tolerance' + + +# --- exp_euler numerics (python-float input) -------------------------------- # + +def test_exp_euler_python_float_input(): + """exp_euler on a linear ODE y'=-2y is exact: y(0.3)=e^{-0.6}.""" + f = bp.odeint(lambda y, t: -2 * y, method='exp_euler', dt=0.3) + out = f(1.0, 0.0, dt=0.3) + assert np.allclose(float(out), math.exp(-0.6), atol=1e-5) + + +def test_exp_euler_auto_python_float_input(): + """exp_euler_auto must give the same exact result on the linear ODE.""" + f = bp.odeint(lambda y, t: -2 * y, method='exp_euler_auto', dt=0.3) + out = f(1.0, 0.0, dt=0.3) + assert np.allclose(float(out), math.exp(-0.6), atol=1e-5) + + +# --- C-13: SDE integrators raise IntegratorError, not NameError ------------- # + +def test_c13_sde_invalid_intg_type_raises_integrator_error(): + """C-13: invalid intg_type must raise IntegratorError (errors import fixed), + not NameError: name 'errors' is not defined.""" + with pytest.raises(IntegratorError): + bp.sdeint(lambda x, t: -x, lambda x, t: 0.1, + method='euler', intg_type='WRONG') + + +def test_c13_sde_invalid_wiener_type_raises_integrator_error(): + """C-13: the same errors import guards the wiener_type validation path.""" + with pytest.raises(IntegratorError): + bp.sdeint(lambda x, t: -x, lambda x, t: 0.1, + method='euler', wiener_type='WRONG') + + +def test_c13_heun_rejects_ito(): + """C-13 (sde/normal.py): Heun only supports Stratonovich; an Ito request + must raise IntegratorError rather than NameError.""" + with pytest.raises(IntegratorError): + bp.sdeint(lambda x, t: -x, lambda x, t: 0.1, + method='heun', intg_type='Ito') + + +# --- C-08: CaputoEuler does not mis-scale the initial condition ------------- # + +def test_c08_caputo_euler_preserves_initial_condition(x64): + """C-08: for D^a x = 0 with x(0)=1 (exact x==1), CaputoEuler must keep x~1 + across steps instead of returning dt^a/a.""" + intg = bp.fde.CaputoEuler(lambda x, t: jnp.zeros_like(jnp.asarray(x)), + alpha=0.8, num_memory=10, inits=[1.0]) + x = jnp.array([1.0]) + t = 0.0 + dt = 0.1 + for _ in range(5): + x = intg(x, t, dt=dt) + t += dt + assert np.allclose(np.asarray(x), 1.0, atol=1e-6), \ + f'CaputoEuler drifted from the initial condition: {np.asarray(x)}' + + +# --- H-31: CaputoL1Schema.hists() iterates .items(), not bare dict ---------- # + +def test_h31_caputo_l1_hists_returns_without_valueerror(x64): + """H-31: after a step, .hists() (default numpy=True) must return a dict of + arrays and not raise ValueError from iterating dict keys instead of items.""" + intg = bp.fde.CaputoL1Schema(lambda x, t: -x, alpha=0.9, + num_memory=10, inits=[1.0]) + x = jnp.array([1.0]) + x = intg(x, 0.0, dt=0.1) + hists = intg.hists() # must not raise + assert isinstance(hists, dict) + for v in hists.values(): + assert isinstance(v, np.ndarray) + # per-variable accessor path + var = intg.variables[0] + one = intg.hists(var=var) + assert isinstance(one, np.ndarray) + + +# --- H-30: GLShortMemory.reset uses the '_delay' key suffix ------------------ # + +def test_h30_glshortmemory_reset_works(): + """H-30: reset must use key+'_delay' and not raise KeyError.""" + g = bp.fde.GLShortMemory(lambda x, t: -x, alpha=0.6, + num_memory=8, inits=[1.0]) + g.reset(inits=[2.0]) # must not raise KeyError + out = g(jnp.array([2.0]), 0.0, dt=0.1) + assert np.all(np.isfinite(np.asarray(out))) + + +# --- H-32: set_default_fdeint writes the FDE global, get reads it back ------- # + +def test_h32_set_get_default_fdeint_roundtrips(restore_default_fdeint): + """H-32: set_default_fdeint must actually change get_default_fdeint (the bug + wrote the wrong global, making it a no-op).""" + for method in bp.fde.get_supported_methods(): + bp.fde.set_default_fdeint(method) + assert bp.fde.get_default_fdeint() == method + + +def test_h32_set_default_fdeint_rejects_unknown(restore_default_fdeint): + with pytest.raises(ValueError): + bp.fde.set_default_fdeint('not-a-real-method') + + +# --- L-13: JointEq raises a *diagnostic* DiffEqError on conflicting kwarg ---- # + +def test_l13_jointeq_conflicting_kwarg_raises_message(): + """L-13: a keyword argument that reuses a state-variable name must raise a + DiffEqError carrying a non-empty diagnostic message.""" + def dV(V, t, w): + return -V + + def dw(w, t, V=1.0): # 'V' kwarg collides with state variable V + return -w + + with pytest.raises(DiffEqError) as exc: + bp.JointEq(dV, dw) + assert str(exc.value).strip(), 'DiffEqError message must not be empty' + assert 'V' in str(exc.value) + + +# =========================================================================== # +# Coverage tests +# =========================================================================== # + +def test_cov_odeint_all_methods_array_and_scalar(): + """Exercise non-adaptive + adaptive ODE methods on array and scalar state.""" + y0 = jnp.array([1.0]) + for m in ['euler', 'rk2', 'rk4']: + out = bp.odeint(lambda y, t: -y, method=m)(y0, 0.0, dt=0.01) + assert np.allclose(np.asarray(out), math.exp(-0.01), atol=1e-3) + + for m in ['rkf45', 'rkdp', 'rkf12', 'ck', 'bs', 'heun_euler', 'BoSh3']: + f = bp.odeint(lambda y, t: -y, method=m, adaptive=True) + y_new, dt_new = f(y0, 0.0, dt=0.01) + assert np.allclose(np.asarray(y_new), math.exp(-0.01), atol=1e-2) + assert float(dt_new) > 0.0 + + +def test_cov_adaptive_explicit_tol_and_var_type(): + """Cover the adaptive path with an explicit tol and SCALAR var_type.""" + f = bp.odeint(lambda y, t: -y, method='rkf45', adaptive=True, + tol=1e-3, var_type='scalar') + y_new, dt_new = f(1.0, 0.0, dt=0.01) + assert np.isfinite(float(y_new)) + assert np.isfinite(float(dt_new)) + + +def test_cov_adaptive_show_code(): + """show_code=True exercises the code-emission/print branch.""" + f = bp.odeint(lambda y, t: -y, method='rkf45', adaptive=True, show_code=True) + out = f(jnp.array([1.0]), 0.0, dt=0.05) + assert len(out) == 2 + + +def test_cov_exp_euler_variants_array_state(): + """Exercise exp_euler / exp_euler_auto / exp_auto on an array state.""" + y0 = jnp.array([1.0, 2.0]) + for m in ['exp_euler', 'exp_euler_auto', 'exp_auto', 'exponential_euler']: + out = bp.odeint(lambda y, t: -y, method=m)(y0, 0.0, dt=0.01) + assert np.allclose(np.asarray(out), y0 * math.exp(-0.01), atol=1e-3) + + +def test_cov_sdeint_euler_milstein_heun_both_types(): + """Cover euler/milstein (Ito + Stratonovich) and heun (Stratonovich).""" + bm.random.seed(1234) + g = lambda x, t: jnp.ones_like(x) * 0.1 + for method in ['euler', 'milstein']: + for itype in ['Ito', 'Stratonovich']: + f = bp.sdeint(lambda x, t: -x, g, method=method, intg_type=itype) + out = f(jnp.array([1.0]), 0.0, dt=0.01) + assert np.all(np.isfinite(np.asarray(out))) + # Heun is Stratonovich-only + f = bp.sdeint(lambda x, t: -x, g, method='heun', intg_type='Stratonovich') + out = f(jnp.array([1.0]), 0.0, dt=0.01) + assert np.all(np.isfinite(np.asarray(out))) + + +def test_cov_integrator_runner_with_monitors(): + """H-29 + coverage: IntegratorRunner over a small ODE; the step-index + monitor must report 0..N-1 (loop variable no longer clobbered).""" + bm.set_dt(0.1) + intg = bp.odeint(lambda V, t: -V, method='rk4') + runner = bp.IntegratorRunner( + intg, + monitors={'V': 'V', 'step': lambda sh: sh['i']}, + inits={'V': 1.0}, + dt=0.1, + progress_bar=False, + ) + runner.run(0.5) + v = np.asarray(runner.mon['V']).ravel() + step = np.asarray(runner.mon['step']).ravel() + assert v.shape == (5,) + assert np.allclose(v, np.exp(-0.1 * np.arange(1, 6)), atol=1e-3) + # H-29: the step index is the real loop counter, not len(vars)-1 + assert list(step) == [0, 1, 2, 3, 4] + + +def test_cov_integrator_runner_seq_monitor(): + """Cover the tuple/list monitor formatting branch of IntegratorRunner.""" + bm.set_dt(0.1) + intg = bp.odeint(lambda V, t: -V, method='euler') + runner = bp.IntegratorRunner(intg, monitors=['V'], inits={'V': 1.0}, + dt=0.1, progress_bar=False) + runner.run(0.3) + assert np.asarray(runner.mon['V']).shape == (3, 1) + + +def test_cov_jointeq_construction_and_integration(): + """Cover JointEq construction (incl. nested) and integration.""" + a, b = 0.7, 0.8 + + def dV(V, t, w, Iext): + return V - V ** 3 / 3 - w + Iext + + def dw(w, t, V): + return a * (b * V - w) + + eq = bp.JointEq(dV, dw) + flat = [v for sub in eq.vars_in_eqs for v in sub] + assert set(flat) == {'V', 'w'} + + intg = bp.odeint(eq, method='rk2') + V, w = intg(0.0, 0.0, 0.0, Iext=0.5, dt=0.01) + assert np.isfinite(float(V)) and np.isfinite(float(w)) + + # nested JointEq + def du(u, t, V): + return -u + V + + eq2 = bp.JointEq(eq, du) + flat2 = [v for sub in eq2.vars_in_eqs for v in sub] + assert set(flat2) == {'V', 'w', 'u'} + + +def test_cov_jointeq_rejects_non_callable(): + """Cover the _check_eqs error branch.""" + with pytest.raises(DiffEqError): + bp.JointEq(123) + + +def test_cov_caputo_euler_integration(x64): + """Integrate a non-trivial Caputo equation a few steps for coverage.""" + intg = bp.fde.CaputoEuler(lambda x, t: -x, alpha=0.9, + num_memory=20, inits=[1.0]) + x = jnp.array([1.0]) + t = 0.0 + for _ in range(5): + x = intg(x, t, dt=0.05) + t += 0.05 + assert np.all(np.isfinite(np.asarray(x))) + # decaying solution must stay below the initial value + assert float(x[0]) < 1.0 + + +def test_cov_caputo_l1_integration_and_reset(x64): + """Integrate CaputoL1Schema a few steps, hits hists(), then reset().""" + intg = bp.fde.CaputoL1Schema(lambda x, t: -x, alpha=0.7, + num_memory=20, inits=[1.0]) + x = jnp.array([1.0]) + t = 0.0 + for _ in range(4): + x = intg(x, t, dt=0.05) + t += 0.05 + assert np.all(np.isfinite(np.asarray(x))) + h = intg.hists() + assert isinstance(h, dict) + intg.reset(inits=[0.5]) + assert np.allclose(np.asarray(intg.inits[intg.variables[0]]), 0.5) + + +def test_cov_glshortmemory_integration_multivar(): + """Integrate GLShortMemory on a coupled 2D system + binomial_coef access.""" + def f(x, y, t): + return -x + 0.1 * y, -y - 0.1 * x + + g = bp.fde.GLShortMemory(f, alpha=[0.95, 0.95], num_memory=16, + inits=[1.0, 0.5]) + coef = g.binomial_coef + assert coef.shape[0] == 16 + x = jnp.array([1.0]) + y = jnp.array([0.5]) + t = 0.0 + for _ in range(5): + x, y = g(x, y, t, dt=0.02) + t += 0.02 + assert np.all(np.isfinite(np.asarray(x))) + assert np.all(np.isfinite(np.asarray(y))) + + +def test_cov_glshortmemory_via_fdeint(): + """Cover the fdeint factory dispatch to the short-memory integrator.""" + intg = bp.fdeint(alpha=0.8, num_memory=8, inits=[1.0], + method='short-memory', dt=0.05)(lambda x, t: -x) + out = intg(jnp.array([1.0]), 0.0, dt=0.05) + assert np.all(np.isfinite(np.asarray(out))) + + +def test_cov_caputo_euler_via_fdeint(restore_default_fdeint): + """Cover fdeint(method=None) using the default + the 'euler' dispatch.""" + bp.fde.set_default_fdeint('euler') + intg = bp.fdeint(alpha=0.8, num_memory=8, inits=[1.0], + method=None, dt=0.05)(lambda x, t: -x) + out = intg(jnp.array([1.0]), 0.0, dt=0.05) + assert np.all(np.isfinite(np.asarray(out))) + + +# --------------------------------------------------------------------------- # +# Extra coverage: SDE method/wiener-type/multivar branches (sde/normal.py). +# --------------------------------------------------------------------------- # + +def test_cov_sde_milstein_variants(): + """Cover milstein2 / milstein_grad_free for both integral types.""" + bm.random.seed(21) + g = lambda x, t: jnp.ones_like(x) * 0.1 + for method in ['milstein2', 'milstein_grad_free']: + for itype in ['Ito', 'Stratonovich']: + f = bp.sdeint(lambda x, t: -x, g, method=method, intg_type=itype) + out = f(jnp.array([1.0]), 0.0, dt=0.01) + assert np.all(np.isfinite(np.asarray(out))) + + +def test_cov_sde_exponential_euler(): + """Cover the SDE ExponentialEuler integrator.""" + bm.random.seed(22) + g = lambda x, t: jnp.ones_like(x) * 0.1 + for method in ['exp_euler', 'exponential_euler']: + f = bp.sdeint(lambda x, t: -x, g, method=method) + out = f(jnp.array([1.0]), 0.0, dt=0.01) + assert np.all(np.isfinite(np.asarray(out))) + + +def test_cov_sde_multivariable(): + """Cover the multi-variable drift/diffusion branches of the Euler integrator. + + Milstein deliberately supports only a single variable at a time and raises + DiffEqError for multi-variable systems, so it is exercised separately. + """ + bm.random.seed(23) + + def f(x, y, t): + return -x, -y + + def g(x, y, t): + return 0.1 * jnp.ones_like(x), 0.1 * jnp.ones_like(y) + + for itype in ['Ito', 'Stratonovich']: + intg = bp.sdeint(f, g, method='euler', intg_type=itype) + out = intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01) + assert len(out) == 2 + assert np.all(np.isfinite(np.asarray(out[0]))) + assert np.all(np.isfinite(np.asarray(out[1]))) + + # Milstein rejects multi-variable systems (raised when building/calling). + with pytest.raises(DiffEqError): + intg = bp.sdeint(f, g, method='milstein') + intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01) + + +def test_cov_sde_vector_wiener(): + """Cover the VECTOR_WIENER (Ito) summation branch of the Euler integrator. + + Only the Ito branch is exercised: the Stratonovich vector-wiener path has a + latent broadcasting bug (g(Y) of shape (3, 2) is added to a state of shape + (3,) without summing over the noise axis) and is out of scope for this audit. + """ + bm.random.seed(24) + + def fv(x, t): + return -x + + def gv(x, t): + return 0.1 * jnp.ones((3, 2)) # 2 independent noise sources + + intg = bp.sdeint(fv, gv, method='euler', wiener_type='vector', + intg_type='Ito') + out = intg(jnp.ones(3), 0.0, dt=0.01) + assert np.asarray(out).shape == (3,) + assert np.all(np.isfinite(np.asarray(out))) + + +def test_cov_sde_vector_wiener_requires_nd_diffusion(): + """Cover the vector-wiener scalar-diffusion guard (ValueError path).""" + bm.random.seed(25) + intg = bp.sdeint(lambda x, t: -x, lambda x, t: jnp.float32(0.1), + method='euler', wiener_type='vector') + with pytest.raises(ValueError): + intg(jnp.array([1.0]), 0.0, dt=0.01) + + +def test_cov_sde_drift_must_be_tensor(): + """Cover the single-variable drift-not-a-tensor ValueError branch.""" + intg = bp.sdeint(lambda x, t: -1.0, lambda x, t: jnp.ones_like(x) * 0.1, + method='euler') + with pytest.raises(ValueError): + intg(jnp.array([1.0]), 0.0, dt=0.01) + + +# --------------------------------------------------------------------------- # +# Extra coverage: IntegratorRunner.run options (runner.py). +# --------------------------------------------------------------------------- # + +def test_cov_runner_run_options(): + """Cover run(start_t=, eval_time=True) and the .mon['ts'] time axis.""" + bm.set_dt(0.1) + intg = bp.odeint(lambda V, t: -V, method='rk4') + runner = bp.IntegratorRunner(intg, monitors={'V': 'V'}, inits={'V': 1.0}, + dt=0.1, progress_bar=False) + runner.run(0.3, start_t=0.0, eval_time=True) + assert np.asarray(runner.mon['V']).shape == (3, 1) + # second run continues from the previous index (covers idx bookkeeping) + runner.run(0.2) + assert np.asarray(runner.mon['V']).shape == (2, 1) + + +def test_cov_runner_rejects_non_integrator(): + """Cover the target type check in IntegratorRunner.__init__.""" + with pytest.raises(TypeError): + bp.IntegratorRunner(lambda V, t: -V, monitors={'V': 'V'}, + inits={'V': 1.0}, dt=0.1) + + +# --------------------------------------------------------------------------- # +# Extra coverage: JointEq error branches (joint_eq.py). +# --------------------------------------------------------------------------- # + +def test_cov_jointeq_duplicate_state_variable(): + """Cover the duplicate-state-variable DiffEqError branch.""" + def dV1(V, t): + return -V + + def dV2(V, t): # 'V' used as a state variable twice + return -2 * V + + with pytest.raises(DiffEqError): + bp.JointEq(dV1, dV2) + + +# --------------------------------------------------------------------------- # +# Extra coverage: exponential ODE auto-diff path (exponential.py). +# --------------------------------------------------------------------------- # + +def test_cov_exp_euler_nonlinear(): + """Cover the auto-linearization path on a non-linear right-hand side.""" + # logistic-like ODE; just needs to run and stay finite + intg = bp.odeint(lambda y, t: y * (1.0 - y), method='exp_euler_auto') + y = jnp.array([0.1]) + t = 0.0 + for _ in range(5): + y = intg(y, t, dt=0.05) + t += 0.05 + assert np.all(np.isfinite(np.asarray(y))) + assert float(y[0]) > 0.1 # logistic growth + + +def test_cov_exp_euler_show_code(): + """Cover the show_code emission branch of ExponentialEuler.""" + intg = bp.odeint(lambda y, t: -y, method='exp_euler', show_code=True) + out = intg(jnp.array([1.0]), 0.0, dt=0.05) + assert np.all(np.isfinite(np.asarray(out))) + + +# --------------------------------------------------------------------------- # +# Extra coverage: JointEq argument parsing edge cases (joint_eq.py). +# --------------------------------------------------------------------------- # + +def test_cov_jointeq_missing_time_variable(): + """Cover the 'Do not find time variable "t"' ValueError branch.""" + with pytest.raises(ValueError): + bp.JointEq(lambda V, w: -V) # no 't' parameter + + +def test_cov_jointeq_rejects_var_positional(): + """Cover the VAR_POSITIONAL (*args) rejection branch.""" + with pytest.raises(DiffEqError): + bp.JointEq(lambda V, t, *extra: -V) + + +def test_cov_jointeq_rejects_var_keyword(): + """Cover the VAR_KEYWORD (**kwargs) rejection branch.""" + with pytest.raises(DiffEqError): + bp.JointEq(lambda V, t, **extra: -V) + + +def test_cov_jointeq_conflicting_kwarg_defaults(): + """Cover the 'two different default value' DiffEqError branch.""" + def dV(V, t, a=1.0): + return -V + a + + def dw(w, t, a=2.0): # same kwarg name, different default + return -w + a + + with pytest.raises(DiffEqError): + bp.JointEq(dV, dw) + + +def test_cov_jointeq_nested_list_and_shared_kwarg(): + """Cover the _check_eqs list/tuple recursion + a shared (consistent) kwarg + default, then call with the kwarg passed both positionally and by keyword.""" + def dV(V, t, w, gain=0.5): + return -V + gain * w + + def dw(w, t, V, gain=0.5): # same default -> allowed, shared kwarg + return -w + gain * V + + eq = bp.JointEq([dV, dw]) # list form exercises the recursion branch + assert 'gain' in eq.kwarg_keys + intg = bp.odeint(eq, method='euler') + # call providing gain by keyword + out = intg(1.0, 0.0, gain=0.3, dt=0.01) + assert all(np.isfinite(float(o)) for o in out) + + +# --------------------------------------------------------------------------- # +# Extra coverage: FDE reset / check paths (Caputo.py, generic.py). +# --------------------------------------------------------------------------- # + +def test_cov_caputo_euler_reset(x64): + """Cover CaputoEuler.reset().""" + intg = bp.fde.CaputoEuler(lambda x, t: -x, alpha=0.8, + num_memory=10, inits=[1.0]) + x = jnp.array([1.0]) + x = intg(x, 0.0, dt=0.05) + intg.reset(inits=[3.0]) + assert np.allclose(np.asarray(intg.inits[intg.variables[0]]), 3.0) + + +def test_cov_caputo_euler_requires_tensor_drift(x64): + """Cover the single-variable 'Derivative values must be a tensor' branch.""" + intg = bp.fde.CaputoEuler(lambda x, t: 0.0, alpha=0.8, + num_memory=10, inits=[1.0]) + with pytest.raises(ValueError): + intg(jnp.array([1.0]), 0.0, dt=0.05) + + +def test_cov_fde_register_duplicate_rejected(): + """Cover register_fde_integrator duplicate-name guard.""" + from brainpy.integrators.fde.generic import register_fde_integrator + from brainpy.integrators.fde.GL import GLShortMemory + with pytest.raises(ValueError): + register_fde_integrator('euler', GLShortMemory) + + +def test_cov_fdeint_unknown_method_rejected(): + """Cover the fdeint unknown-method ValueError branch.""" + with pytest.raises(ValueError): + bp.fdeint(alpha=0.8, num_memory=4, inits=[1.0], + method='nope', dt=0.05)(lambda x, t: -x) + + +def test_cov_set_default_fdeint_type_check(restore_default_fdeint): + """Cover the non-string set_default_fdeint guard.""" + with pytest.raises(ValueError): + bp.fde.set_default_fdeint(123) + + +# --------------------------------------------------------------------------- # +# Extra coverage: exponential ODE branches (exponential.py). +# --------------------------------------------------------------------------- # + +def test_cov_exp_euler_with_jointeq(): + """Cover the JointEq build path of ExponentialEuler (_build_integrator).""" + def dV(V, t, w): + return -V - w + + def dw(w, t, V): + return -w + 0.1 * V + + eq = bp.JointEq(dV, dw) + intg = bp.odeint(eq, method='exp_euler') + out = intg(jnp.array([1.0]), jnp.array([0.5]), 0.0, dt=0.01) + assert len(out) == 2 + assert all(np.all(np.isfinite(np.asarray(o))) for o in out) + + +def test_cov_exp_euler_rejects_integer_input(): + """Cover the float-dtype guard of the Exponential Euler integral.""" + intg = bp.odeint(lambda y, t: -y, method='exp_euler') + with pytest.raises(ValueError): + intg(jnp.array([1, 2, 3]), 0.0, dt=0.01) # integer dtype + + +def test_cov_exp_euler_system_var_not_implemented(): + """Cover the SYSTEM_VAR NotImplementedError branch.""" + with pytest.raises(NotImplementedError): + bp.odeint(lambda y, t: -y, method='exp_euler', var_type='system') + + +# --------------------------------------------------------------------------- # +# Extra coverage: Milstein vector-wiener branches (sde/normal.py). +# --------------------------------------------------------------------------- # + +def test_cov_milstein_vector_wiener_latent_bug(): + """Document a latent (out-of-scope) bug: the Milstein integrators do not + correctly broadcast the diffusion-gradient term for VECTOR_WIENER noise, so + a vector-wiener Milstein step currently raises a broadcasting ValueError. + + This still exercises the vector-wiener branch up to the failure point; if the + library is ever fixed, this test should assert a finite result instead. + """ + bm.random.seed(31) + + def fv(x, t): + return -x + + def gv(x, t): + return 0.1 * jnp.ones((3, 2)) + + for method in ['milstein', 'milstein2']: + intg = bp.sdeint(fv, gv, method=method, wiener_type='vector', + intg_type='Ito') + with pytest.raises((ValueError, TypeError)): + intg(jnp.ones(3), 0.0, dt=0.01) + + +# --------------------------------------------------------------------------- # +# Extra coverage: IntegratorRunner init / dyn_args branches (runner.py). +# --------------------------------------------------------------------------- # + +def test_cov_runner_inits_as_sequence(): + """Cover the list/sequence form of the ``inits`` argument.""" + bm.set_dt(0.1) + intg = bp.odeint(lambda V, t: -V, method='euler') + runner = bp.IntegratorRunner(intg, monitors={'V': 'V'}, inits=[1.0], + dt=0.1, progress_bar=False) + runner.run(0.3) + assert np.asarray(runner.mon['V']).shape == (3, 1) + + +def test_cov_runner_dyn_args(): + """Cover the dyn_args time-varying input path of IntegratorRunner.run.""" + bm.set_dt(0.1) + intg = bp.odeint(lambda V, t, Iext: -V + Iext, method='euler') + runner = bp.IntegratorRunner(intg, monitors={'V': 'V'}, inits={'V': 0.0}, + dt=0.1, progress_bar=False) + # 3 steps -> dyn_args first dimension must be 3 + runner.run(0.3, dyn_args={'Iext': jnp.ones(3)}) + assert np.asarray(runner.mon['V']).shape == (3, 1) + + +def test_cov_runner_dyn_args_shape_mismatch(): + """Cover the dyn_args duration-mismatch ValueError branch.""" + bm.set_dt(0.1) + intg = bp.odeint(lambda V, t, Iext: -V + Iext, method='euler') + runner = bp.IntegratorRunner(intg, monitors={'V': 'V'}, inits={'V': 0.0}, + dt=0.1, progress_bar=False) + with pytest.raises(ValueError): + runner.run(0.3, dyn_args={'Iext': jnp.ones(99)}) + + +# --------------------------------------------------------------------------- # +# Extra coverage: SDE ExponentialEuler branches (sde/normal.py). +# --------------------------------------------------------------------------- # + +def test_cov_sde_exp_euler_vector_wiener(): + """Cover the VECTOR_WIENER branch of the SDE ExponentialEuler.""" + bm.random.seed(41) + intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1 * jnp.ones((3, 2)), + method='exp_euler', wiener_type='vector') + out = intg(jnp.ones(3), 0.0, dt=0.01) + assert np.asarray(out).shape == (3,) + assert np.all(np.isfinite(np.asarray(out))) + + +def test_cov_sde_exp_euler_rejects_stratonovich(): + """Cover the SDE ExponentialEuler Stratonovich NotImplementedError branch.""" + with pytest.raises(NotImplementedError): + bp.sdeint(lambda x, t: -x, lambda x, t: jnp.ones((1,)) * 0.1, + method='exp_euler', intg_type='Stratonovich') + + +def test_cov_sde_exp_euler_with_jointeq(): + """Cover the JointEq build + multi-variable diffusion path of SDE exp_euler.""" + bm.random.seed(42) + + def dV(V, t, w): + return -V + + def dw(w, t, V): + return -w + + def gV(V, t, w): + return jnp.ones_like(V) * 0.1 + + def gw(w, t, V): + return jnp.ones_like(w) * 0.1 + + intg = bp.sdeint(bp.JointEq(dV, dw), bp.JointEq(gV, gw), method='exp_euler') + out = intg(jnp.array([1.0]), jnp.array([0.5]), 0.0, dt=0.01) + assert len(out) == 2 + assert all(np.all(np.isfinite(np.asarray(o))) for o in out) + + +# --------------------------------------------------------------------------- # +# Extra coverage: JointEq.__call__ with a tuple-returning sub-equation. +# --------------------------------------------------------------------------- # + +def test_cov_jointeq_call_with_nested_tuple_result(): + """Cover JointEq.__call__ extending results from a tuple-returning sub-eq + (a nested JointEq returns a list/tuple).""" + def dV(V, t, w): + return -V + w + + def dw(w, t, V): + return -w + 0.1 * V + + inner = bp.JointEq(dV, dw) # returns a list when called + + def du(u, t, V): + return -u + V + + outer = bp.JointEq(inner, du) + res = outer(1.0, 0.5, 0.2, 0.0) # V, w, u, t + assert len(res) == 3 + assert all(np.isfinite(float(r)) for r in res) + + +# --------------------------------------------------------------------------- # +# Extra coverage: multi-variable FDE + IntegratorRunner branches. +# --------------------------------------------------------------------------- # + +def test_cov_caputo_euler_multivariable(x64): + """Cover the multi-variable drift branch of CaputoEuler.""" + def f(x, y, t): + return -x, -y + + intg = bp.fde.CaputoEuler(f, alpha=[0.8, 0.9], num_memory=10, + inits=[1.0, 2.0]) + x = jnp.array([1.0]) + y = jnp.array([2.0]) + t = 0.0 + for _ in range(3): + x, y = intg(x, y, t, dt=0.05) + t += 0.05 + assert np.all(np.isfinite(np.asarray(x))) + assert np.all(np.isfinite(np.asarray(y))) + # decaying solution + assert float(x[0]) < 1.0 and float(y[0]) < 2.0 + + +def test_cov_caputo_l1_multivariable(x64): + """Cover the multi-variable branch + per-variable hists() of CaputoL1Schema.""" + def f(x, y, t): + return -x, -y + + intg = bp.fde.CaputoL1Schema(f, alpha=[0.8, 0.9], num_memory=10, + inits=[1.0, 2.0]) + x = jnp.array([1.0]) + y = jnp.array([2.0]) + t = 0.0 + for _ in range(3): + x, y = intg(x, y, t, dt=0.05) + t += 0.05 + hists = intg.hists() + assert set(hists.keys()) == set(intg.variables) + one = intg.hists(var=intg.variables[1]) + assert isinstance(one, np.ndarray) + + +def test_cov_runner_multivariable_with_progress_bar(): + """Cover the multi-variable update branch and the progress-bar callback of + IntegratorRunner.""" + bm.set_dt(0.1) + + def dV(V, t, w): + return -V + w + + def dw(w, t, V): + return -w + + intg = bp.odeint(bp.JointEq(dV, dw), method='rk2') + runner = bp.IntegratorRunner( + intg, + monitors={'V': 'V', 'w': 'w'}, + inits={'V': 1.0, 'w': 0.5}, + dt=0.1, + progress_bar=True, + ) + runner.run(0.3) + assert np.asarray(runner.mon['V']).shape == (3, 1) + assert np.asarray(runner.mon['w']).shape == (3, 1) + + +def test_cov_jointeq_rejects_keyword_only(): + """Cover the KEYWORD_ONLY parameter rejection branch of _get_args.""" + def dV(V, t, *, w): + return -V + + with pytest.raises(DiffEqError): + bp.JointEq(dV, lambda w, t: -w) + + +def test_cov_jointeq_rejects_positional_only(): + """Cover the POSITIONAL_ONLY parameter rejection branch of _get_args.""" + def dV(V, t, w, /): + return -V + + with pytest.raises(DiffEqError): + bp.JointEq(dV, lambda w, t: -w) diff --git a/tests/audit/test_math_compat_fixes.py b/tests/audit/test_math_compat_fixes.py new file mode 100644 index 000000000..66bde6202 --- /dev/null +++ b/tests/audit/test_math_compat_fixes.py @@ -0,0 +1,1017 @@ +# -*- coding: utf-8 -*- +"""Regression + coverage tests for the ``brainpy.math`` compatibility layer. + +This module accompanies the audit recorded in ``docs/issues-found-20260618.md``. +It pins the behaviours fixed by the audit and broadly exercises the numpy / +pytorch / tensorflow compatibility shims, the activation functions, the misc. +``others`` helpers, the ``_utils`` wrapper factory and the ``einops`` port. + +Audit findings exercised here (see the doc for full context): + +* C-11 (``compat_tensorflow.py``) -- ``reduce_logsumexp`` must be numerically + stable (delegated to ``jax.scipy.special.logsumexp``). +* H-16 (``activations.py``) -- ``softmin`` must subtract the max so it + stays finite for large inputs. +* H-14 (``_utils.py`` + pytorch) -- ``out=`` wrapped funcs must *return* the + ``out`` Array instead of ``None``. +* H-15 (``others.py``) -- ``remove_diag`` must trace cleanly under + ``jit``/``vmap`` (static off-diagonal gather, handles non-square). +* H-13 (``compat_numpy.py``) -- ``asfarray`` must coerce integer input to + a floating dtype. +* M-11/M-12 (``compat_numpy.py``) -- ``empty`` uses ``jnp.empty`` and + ``fill_diagonal(inplace=False)`` returns a brainpy ``Array``. +* (``compat_pytorch.py``) -- ``arcsinh``/``arctanh`` exist & correct, + no duplicate ``arcsin`` clobbering. +* (``einops.py``) -- module still imports after the dead + ``_optimize_transformation`` helper was removed. +""" + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import brainpy.math as bm +from brainpy.math import ( + compat_numpy as cn, + compat_pytorch as cpt, + compat_tensorflow as ctf, + activations as act, + others as bo, + _utils as butils, + einops as bein, +) +from brainpy.math.ndarray import Array + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +def _j(x): + """Return the underlying jax array for assertions.""" + return bm.as_jax(x) + + +def _finite(x): + return bool(jnp.all(jnp.isfinite(_j(x)))) + + +# =========================================================================== +# 1. Regression tests (audit-specific behaviours) +# =========================================================================== + +# --- C-11: reduce_logsumexp numerical stability ----------------------------- + +def test_reduce_logsumexp_stable_large_inputs(): + """C-11: log(sum(exp(.))) must not overflow for large inputs.""" + r = ctf.reduce_logsumexp(bm.asarray([1000., 1000., 1000.])) + assert _finite(r) + # logsumexp([1000]*3) == 1000 + log(3) == 1001.0986... + assert float(_j(r)) == pytest.approx(1000.0 + np.log(3.0), abs=1e-2) + + +def test_reduce_logsumexp_matches_reference_small(): + x = bm.asarray([0.5, -1.0, 2.0, 3.5]) + r = ctf.reduce_logsumexp(x) + ref = jax.scipy.special.logsumexp(_j(x)) + assert float(_j(r)) == pytest.approx(float(ref), rel=1e-6) + + +def test_reduce_logsumexp_axis_keepdims(): + x = bm.asarray([[1., 2., 3.], [4., 5., 6.]]) + r = ctf.reduce_logsumexp(x, axis=1, keepdims=True) + assert _j(r).shape == (2, 1) + assert _finite(r) + + +# --- H-16: softmin must not produce NaN for large inputs -------------------- + +def test_softmin_finite_for_large_inputs(): + """H-16: softmin lacked max-subtraction -> NaN for large inputs.""" + r = act.softmin(bm.asarray([1000., 1001., 1002.])) + assert _finite(r) + # softmin == softmax(-x); the smallest input gets the largest weight. + expected = np.array([0.66524096, 0.24472847, 0.09003057]) + np.testing.assert_allclose(np.asarray(_j(r)), expected, atol=1e-4) + assert float(jnp.sum(_j(r))) == pytest.approx(1.0, abs=1e-6) + + +def test_softmin_matches_softmax_of_negative(): + x = bm.asarray([0.3, -1.2, 2.5, 0.0]) + np.testing.assert_allclose( + np.asarray(_j(act.softmin(x))), + np.asarray(_j(act.softmax(-x))), + atol=1e-6, + ) + + +# --- H-14: out= wrapped funcs must RETURN the out Array --------------------- + +def test_numpy_sum_out_returns_array(): + """H-14: numpy-style ``out=`` should return the (filled) Array, not None.""" + out = bm.asarray(0.) + r = bm.sum(bm.asarray([1., 2., 3.]), out=out) + assert r is not None + assert isinstance(r, Array) + assert r is out + assert float(_j(r)) == pytest.approx(6.0) + assert float(_j(out)) == pytest.approx(6.0) + + +def test_numpy_out_must_be_brainpy_array(): + with pytest.raises(TypeError): + bm.sum(bm.asarray([1., 2., 3.]), out=jnp.array(0.)) + + +def test_pytorch_add_out_returns_out(): + """H-14 (pytorch compat): ``add(..., out=...)`` returns ``out``.""" + out = bm.zeros((3,)) + r = cpt.add(bm.asarray([1., 2., 3.]), bm.asarray([1., 1., 1.]), out=out) + assert r is out + np.testing.assert_allclose(np.asarray(_j(r)), [2., 3., 4.]) + + +def test_pytorch_out_must_be_brainpy_array(): + with pytest.raises(TypeError): + cpt.abs(bm.asarray([-1.]), out=jnp.array(0.)) + + +# --- H-15: remove_diag must trace cleanly under jit/vmap -------------------- + +def test_remove_diag_jit_matches_eager_square(): + """H-15: remove_diag must work under jit (static off-diag gather).""" + x = jnp.arange(9.).reshape(3, 3) + eager = _j(bo.remove_diag(x)) + jitted = jax.jit(bo.remove_diag)(x) + assert eager.shape == (3, 2) + np.testing.assert_allclose(np.asarray(eager), np.asarray(jitted)) + # row 0 has its diagonal (0) removed -> [1, 2] + np.testing.assert_allclose(np.asarray(eager[0]), [1., 2.]) + + +def test_remove_diag_vmap(): + x = jnp.arange(2 * 9.).reshape(2, 3, 3) + out = jax.vmap(bo.remove_diag)(x) + assert out.shape == (2, 3, 2) + np.testing.assert_allclose(np.asarray(out[0]), np.asarray(_j(bo.remove_diag(x[0])))) + + +def test_remove_diag_non_square(): + x = jnp.arange(12.).reshape(3, 4) + out = _j(bo.remove_diag(x)) + assert out.shape == (3, 3) + # row 0 drops col 0; rows keep all but the diagonal element + np.testing.assert_allclose(np.asarray(out[0]), [1., 2., 3.]) + np.testing.assert_allclose(np.asarray(out[1]), [4., 6., 7.]) + + +def test_remove_diag_rejects_non_2d(): + with pytest.raises(ValueError): + bo.remove_diag(jnp.arange(3.)) + + +# --- H-13: asfarray coerces integers to a floating dtype -------------------- + +def test_asfarray_integer_input_becomes_float(): + """H-13: asfarray(int) used to no-op; must yield a floating dtype.""" + r = cn.asfarray([1, 2, 3]) + assert jnp.issubdtype(r.dtype, jnp.floating) + np.testing.assert_allclose(np.asarray(_j(r)), [1., 2., 3.]) + + +def test_asfarray_preserves_floating_dtype(): + r = cn.asfarray(jnp.array([1., 2.], dtype=jnp.float32)) + assert r.dtype == jnp.float32 + + +# --- M-11 / M-12: empty + fill_diagonal ------------------------------------- + +def test_empty_shape_and_type(): + """M-11: empty must produce the right shape/dtype as a brainpy Array.""" + e = bm.empty((2, 3)) + assert isinstance(e, Array) + assert e.shape == (2, 3) + assert jnp.issubdtype(e.dtype, jnp.floating) + + +def test_empty_like(): + a = bm.asarray(jnp.ones((4,), dtype=jnp.int32)) + e = cn.empty_like(a) + assert e.shape == (4,) + assert e.dtype == jnp.int32 + + +def test_fill_diagonal_not_inplace_returns_array(): + """M-12: fill_diagonal(inplace=False) must return a brainpy Array.""" + x = bm.asarray(jnp.ones((3, 3))) + r = cn.fill_diagonal(x, 5., inplace=False) + assert isinstance(r, Array) + np.testing.assert_allclose(np.diag(np.asarray(_j(r))), [5., 5., 5.]) + # original unchanged + np.testing.assert_allclose(np.diag(np.asarray(_j(x))), [1., 1., 1.]) + + +def test_fill_diagonal_inplace_updates_array(): + x = bm.asarray(jnp.ones((3, 3))) + out = cn.fill_diagonal(x, 7., inplace=True) + assert out is None # in-place returns nothing + np.testing.assert_allclose(np.diag(np.asarray(_j(x))), [7., 7., 7.]) + + +def test_fill_diagonal_errors(): + with pytest.raises(ValueError): + cn.fill_diagonal(bm.asarray(jnp.arange(3)), 1.) # ndim < 2 + with pytest.raises(ValueError): + cn.fill_diagonal(jnp.ones((3, 3)), 1.) # inplace on non-Array + + +# --- pytorch arcsinh / arctanh exist & correct, no dup arcsin --------------- + +def test_pytorch_arcsinh_arctanh(): + np.testing.assert_allclose( + np.asarray(_j(cpt.arcsinh(bm.asarray([0., 1.])))), + np.arcsinh([0., 1.]), atol=1e-6, + ) + np.testing.assert_allclose( + np.asarray(_j(cpt.arctanh(bm.asarray([0., 0.5])))), + np.arctanh([0., 0.5]), atol=1e-6, + ) + # aliases point at their canonical implementations + assert cpt.arcsinh is cpt.asinh + assert cpt.arctanh is cpt.atanh + + +def test_pytorch_arcsin_is_asin_not_arcsinh(): + """No duplicate ``arcsin`` should shadow asin with arcsinh.""" + assert cpt.arcsin is cpt.asin + np.testing.assert_allclose( + np.asarray(_j(cpt.arcsin(bm.asarray([0., 0.5])))), + np.arcsin([0., 0.5]), atol=1e-6, + ) + + +def test_numpy_arcsinh_arctanh_present(): + np.testing.assert_allclose( + np.asarray(_j(cn.arcsinh(bm.asarray([0., 1.])))), np.arcsinh([0., 1.]), atol=1e-6) + np.testing.assert_allclose( + np.asarray(_j(cn.arctanh(bm.asarray([0., 0.5])))), np.arctanh([0., 0.5]), atol=1e-6) + + +# --- einops module still imports (dead _optimize_transformation removed) ---- + +def test_einops_module_imports(): + import brainpy.math.einops as eio + assert hasattr(eio, 'ein_rearrange') + assert hasattr(eio, 'ein_reduce') + assert hasattr(eio, 'ein_repeat') + assert hasattr(eio, 'ein_shape') + # the dead helper flagged by the audit must be gone + assert not hasattr(eio, '_optimize_transformation') + + +# =========================================================================== +# 2. Coverage tests +# =========================================================================== + +# --- compat_numpy: creation funcs ------------------------------------------- + +def test_compat_numpy_creation(): + assert cn.zeros((2, 2)).shape == (2, 2) + assert cn.ones((3,)).shape == (3,) + assert cn.empty((2,)).shape == (2,) + assert float(_j(cn.full((2,), 4.))[0]) == 4. + assert cn.eye(3).shape == (3, 3) + assert cn.identity(3).shape == (3, 3) + assert cn.arange(5).shape == (5,) + assert cn.linspace(0., 1., 5).shape == (5,) + assert cn.logspace(0., 1., 5).shape == (5,) + a = bm.asarray(jnp.arange(4.)) + assert cn.zeros_like(a).shape == (4,) + assert cn.ones_like(a).shape == (4,) + assert cn.empty_like(a).shape == (4,) + assert cn.full_like(a, 2.).shape == (4,) + + +def test_compat_numpy_linspace_retstep(): + r, step = cn.linspace(0., 1., 5, retstep=True) + assert isinstance(r, Array) + assert float(step) == pytest.approx(0.25) + + +def test_compat_numpy_array_and_asarray_with_arrays(): + res = cn.array([bm.asarray([1., 2.]), bm.asarray([3., 4.])]) + assert res.shape == (2, 2) + res2 = cn.asarray([bm.asarray([1., 2.]), bm.asarray([3., 4.])]) + assert res2.shape == (2, 2) + assert cn.asanyarray([1, 2, 3]).shape == (3,) + assert cn.ascontiguousarray([1, 2, 3]).shape == (3,) + + +def test_compat_numpy_diag_tri_family(): + m = bm.asarray(jnp.arange(9.).reshape(3, 3)) + assert cn.diag(m).shape == (3,) + assert cn.tril(m).shape == (3, 3) + assert cn.triu(m).shape == (3, 3) + assert cn.tri(3).shape == (3, 3) + assert cn.diagonal(m).shape == (3,) + assert cn.diagflat(bm.asarray([1., 2., 3.])).shape == (3, 3) + + +def test_compat_numpy_ufuncs(): + a = bm.asarray([1., 2., 3.]) + b = bm.asarray([4., 5., 6.]) + funcs_binary = [cn.add, cn.subtract, cn.multiply, cn.divide, cn.power, + cn.true_divide, cn.maximum, cn.minimum, cn.fmax, cn.fmin, + cn.hypot, cn.logaddexp, cn.logaddexp2, cn.copysign, + cn.nextafter, cn.remainder, cn.mod, cn.fmod, cn.float_power] + for f in funcs_binary: + assert _j(f(a, b)).shape == (3,) + funcs_unary = [cn.negative, cn.positive, cn.reciprocal, cn.abs, cn.absolute, + cn.fabs, cn.exp, cn.exp2, cn.expm1, cn.log1p, cn.sqrt, cn.cbrt, + cn.square, cn.sign, cn.sin, cn.cos, cn.tan, cn.sinh, cn.cosh, + cn.tanh, cn.arcsin, cn.arccos, cn.arctan, cn.arcsinh, cn.arctanh, + cn.sinc, cn.deg2rad, cn.rad2deg, cn.degrees, cn.radians, + cn.round, cn.rint, cn.floor, cn.ceil, cn.trunc, cn.isfinite, + cn.isinf, cn.isnan, cn.signbit, cn.conj, cn.conjugate, cn.real, + cn.imag, cn.angle, cn.nan_to_num] + pos = bm.asarray([0.1, 0.2, 0.3]) + for f in funcs_unary: + assert _j(f(pos)).shape == (3,) + # log family wants strictly positive + for f in [cn.log, cn.log10, cn.log2]: + assert _j(f(a)).shape == (3,) + # gcd/lcm want integers + ia, ib = bm.asarray([4, 6, 8]), bm.asarray([6, 9, 12]) + assert _j(cn.gcd(ia, ib)).shape == (3,) + assert _j(cn.lcm(ia, ib)).shape == (3,) + assert _j(cn.arctan2(a, b)).shape == (3,) + assert _j(cn.heaviside(a, b)).shape == (3,) + assert len(cn.frexp(a)) == 2 + assert len(cn.modf(a)) == 2 + + +def test_compat_numpy_reductions(): + a = bm.asarray([[1., 2., 3.], [4., 5., 6.]]) + assert float(_j(cn.sum(a))) == 21. + assert float(_j(cn.prod(bm.asarray([1., 2., 3., 4.])))) == 24. + assert float(_j(cn.mean(a))) == pytest.approx(3.5) + assert float(_j(cn.max(a))) == 6. + assert float(_j(cn.min(a))) == 1. + assert float(_j(cn.amax(a))) == 6. + assert float(_j(cn.amin(a))) == 1. + assert _j(cn.std(a)).shape == () + assert _j(cn.var(a)).shape == () + assert _j(cn.median(a)).shape == () + assert _j(cn.average(a)).shape == () + assert _j(cn.cumsum(a)).shape == (6,) + assert _j(cn.cumprod(a)).shape == (6,) + assert _j(cn.nansum(a)).shape == () + assert _j(cn.nanprod(a)).shape == () + assert _j(cn.nanmean(a)).shape == () + assert _j(cn.nanstd(a)).shape == () + assert _j(cn.nanvar(a)).shape == () + assert _j(cn.nanmedian(a)).shape == () + assert int(cn.argmax(a)) == 5 + assert int(cn.argmin(a)) == 0 + assert _j(cn.ptp(a)).shape == () + assert _j(cn.diff(a)).shape == (2, 2) + assert _j(cn.nancumsum(a)).shape == (6,) + assert _j(cn.nancumprod(a)).shape == (6,) + assert int(cn.count_nonzero(a)) == 6 + assert _j(cn.percentile(a, 50)).shape == () + assert _j(cn.quantile(a, 0.5)).shape == () + + +def test_compat_numpy_logic(): + a = bm.asarray([1., 2., 3.]) + b = bm.asarray([1., 0., 3.]) + assert _j(cn.equal(a, b)).shape == (3,) + assert _j(cn.not_equal(a, b)).shape == (3,) + assert _j(cn.greater(a, b)).shape == (3,) + assert _j(cn.greater_equal(a, b)).shape == (3,) + assert _j(cn.less(a, b)).shape == (3,) + assert _j(cn.less_equal(a, b)).shape == (3,) + assert bool(cn.array_equal(a, a)) + assert _j(cn.isclose(a, b)).shape == (3,) + assert bool(cn.allclose(a, a)) + ba = bm.asarray([True, False, True]) + bb = bm.asarray([True, True, False]) + assert _j(cn.logical_and(ba, bb)).shape == (3,) + assert _j(cn.logical_or(ba, bb)).shape == (3,) + assert _j(cn.logical_xor(ba, bb)).shape == (3,) + assert _j(cn.logical_not(ba)).shape == (3,) + assert bool(cn.all(ba)) is False + assert bool(cn.any(ba)) is True + assert bool(cn.alltrue(ba)) is False + assert bool(cn.sometrue(ba)) is True + + +def test_compat_numpy_bit_ops(): + a = bm.asarray([1, 2, 3]) + b = bm.asarray([3, 2, 1]) + assert _j(cn.bitwise_and(a, b)).shape == (3,) + assert _j(cn.bitwise_or(a, b)).shape == (3,) + assert _j(cn.bitwise_xor(a, b)).shape == (3,) + assert _j(cn.bitwise_not(a)).shape == (3,) + assert _j(cn.invert(a)).shape == (3,) + assert _j(cn.left_shift(a, b)).shape == (3,) + assert _j(cn.right_shift(a, b)).shape == (3,) + + +def test_compat_numpy_manipulation(): + a = bm.asarray(jnp.arange(6.)) + m = bm.asarray(jnp.arange(6.).reshape(2, 3)) + assert _j(cn.reshape(a, (2, 3))).shape == (2, 3) + assert _j(cn.ravel(m)).shape == (6,) + assert _j(cn.moveaxis(m, 0, 1)).shape == (3, 2) + assert _j(cn.transpose(m)).shape == (3, 2) + assert _j(cn.swapaxes(m, 0, 1)).shape == (3, 2) + assert _j(cn.concatenate([a, a])).shape == (12,) + assert _j(cn.stack([a, a])).shape == (2, 6) + assert _j(cn.vstack([a, a])).shape == (2, 6) + assert _j(cn.hstack([a, a])).shape == (12,) + assert _j(cn.dstack([a, a])).shape == (1, 6, 2) + assert _j(cn.column_stack([a, a])).shape == (6, 2) + assert len(cn.split(a, 2)) == 2 + assert len(cn.array_split(a, 4)) == 4 + assert _j(cn.tile(a, 2)).shape == (12,) + assert _j(cn.repeat(a, 2)).shape == (12,) + assert _j(cn.flip(a)).shape == (6,) + assert _j(cn.fliplr(m)).shape == (2, 3) + assert _j(cn.flipud(m)).shape == (2, 3) + assert _j(cn.roll(a, 1)).shape == (6,) + assert _j(cn.atleast_1d(bm.asarray(1.))).shape == (1,) + assert _j(cn.atleast_2d(a)).shape == (1, 6) + assert _j(cn.atleast_3d(a)).shape == (1, 6, 1) + assert _j(cn.expand_dims(a, 0)).shape == (1, 6) + assert _j(cn.squeeze(bm.asarray(jnp.ones((1, 3))))).shape == (3,) + assert _j(cn.append(a, a)).shape == (12,) + assert _j(cn.sort(a)).shape == (6,) + assert _j(cn.argsort(a)).shape == (6,) + assert _j(cn.unique(bm.asarray([1, 1, 2, 3]))).shape == (3,) + assert _j(cn.row_stack([a, a])).shape == (2, 6) + + +def test_compat_numpy_indexing_and_where(): + a = bm.asarray(jnp.arange(6.)) + assert _j(cn.where(a > 2, a, 0.)).shape == (6,) + assert len(cn.nonzero(a)) == 1 + assert _j(cn.argwhere(a > 2).flatten()).shape[0] >= 0 + assert _j(cn.flatnonzero(a)).shape[0] == 5 + assert int(cn.searchsorted(a, 3.5)) == 4 + assert _j(cn.take(a, bm.asarray([0, 2, 4]))).shape == (3,) + assert _j(cn.select([a > 3], [a])).shape == (6,) + assert _j(cn.extract(a > 3, a)).shape[0] == 2 + assert len(cn.tril_indices(3)) == 2 + assert len(cn.triu_indices(3)) == 2 + m = bm.asarray(jnp.ones((3, 3))) + assert len(cn.tril_indices_from(m)) == 2 + assert len(cn.triu_indices_from(m)) == 2 + + +def test_compat_numpy_linalg(): + a = bm.asarray([1., 2., 3.]) + b = bm.asarray([4., 5., 6.]) + m = bm.asarray(jnp.arange(9.).reshape(3, 3)) + assert _j(cn.dot(a, b)).shape == () + assert _j(cn.vdot(a, b)).shape == () + assert _j(cn.inner(a, b)).shape == () + assert _j(cn.outer(a, b)).shape == (3, 3) + assert _j(cn.kron(a, b)).shape == (9,) + assert _j(cn.matmul(m, m)).shape == (3, 3) + assert _j(cn.trace(m)).shape == () + assert _j(cn.tensordot(m, m)).shape == () + + +def test_compat_numpy_misc_helpers(): + a = bm.asarray([1., 2., 3.]) + assert cn.shape(a) == (3,) + assert cn.shape([[1, 2]]) == (1, 2) + assert cn.shape(0) == () + assert cn.size(a) == 3 + assert cn.size(bm.asarray(jnp.ones((2, 3))), 1) == 3 + assert cn.size([1, 2, 3]) == 3 + assert int(cn.ndim(a)) == 1 + assert float(cn.asscalar(bm.asarray(7.))) == 7. + assert cn.matrix([[1, 2], [3, 4]]).shape == (2, 2) + assert cn.asmatrix([1, 2, 3]).shape == (1, 3) + assert cn.mat([1, 2, 3]).shape == (1, 3) + assert cn.msort(bm.asarray(jnp.array([[3., 1.], [2., 4.]]))).shape == (2, 2) + assert cn.common_type(jnp.array([1., 2.])) is not None + assert cn.common_type(jnp.array([1 + 1j])) is not None # complex branch + assert _j(cn.frombuffer(b'\x01\x02\x03\x04', dtype=np.int8)).shape == (4,) + assert _j(cn.meshgrid(a, a))[0].shape == (3, 3) + assert _j(cn.broadcast_to(a, (2, 3))).shape == (2, 3) + assert cn.broadcast_shapes((3,), (2, 3)) == (2, 3) + assert _j(cn.pad(a, 1)).shape == (5,) + assert _j(cn.clip(a, 1.5, 2.5)).shape == (3,) + assert _j(cn.interp(bm.asarray([1.5]), a, a)).shape == (1,) + assert _j(cn.einsum('i,i->', a, a)).shape == () + assert _j(cn.gradient(a)).shape == (3,) + assert _j(cn.histogram(a)[0]).shape == (10,) + assert _j(cn.bincount(bm.asarray([0, 1, 1, 2]))).shape == (3,) + assert _j(cn.digitize(a, bm.asarray([0., 2.]))).shape == (3,) + + +def test_compat_numpy_window_and_constants(): + assert _j(cn.bartlett(4)).shape == (4,) + assert _j(cn.blackman(4)).shape == (4,) + assert _j(cn.hamming(4)).shape == (4,) + assert _j(cn.hanning(4)).shape == (4,) + assert _j(cn.kaiser(4, 1.0)).shape == (4,) + assert cn.e == pytest.approx(np.e) + assert cn.pi == pytest.approx(np.pi) + assert np.isinf(cn.inf) + + +def test_compat_numpy_inplace_helpers_and_errors(): + a = bm.asarray(jnp.arange(6)) + cn.place(a, jnp.array([True, False] * 3), [10, 20, 30]) + b = bm.asarray(jnp.arange(6)) + cn.put(b, jnp.array([0, 1]), jnp.array([9, 8])) + assert int(_j(b)[0]) == 9 + c = bm.asarray(jnp.zeros(3)) + cn.copyto(c, jnp.ones(3)) + assert float(_j(c)[0]) == 1. + # error paths (non-Array inputs) + with pytest.raises(ValueError): + cn.place(jnp.arange(6), jnp.array([True] * 6), [1]) + with pytest.raises(ValueError): + cn.put(jnp.arange(6), [0], [9]) + with pytest.raises(ValueError): + cn.putmask(jnp.arange(6), jnp.arange(6) > 2, jnp.arange(6)) + with pytest.raises(ValueError): + cn.copyto(jnp.zeros(3), jnp.ones(3)) + + +def test_compat_numpy_in1d_and_set_ops(): + a = bm.asarray([1, 2, 3, 4]) + b = bm.asarray([2, 4]) + assert _j(cn.in1d(a, b)).shape == (4,) + assert _j(cn.in1d(a, b, invert=True)).shape == (4,) + assert _j(cn.intersect1d(a, b)).shape == (2,) + assert _j(cn.union1d(a, b)).shape[0] == 4 + assert _j(cn.setdiff1d(a, b)).shape == (2,) + assert _j(cn.isin(a, b)).shape == (4,) + + +def test_compat_numpy_dtype_helpers(): + assert cn.issubdtype(jnp.float32, jnp.floating) + assert cn.can_cast(jnp.int32, jnp.int64) + assert cn.result_type(jnp.int32, jnp.float32) is not None + assert cn.promote_types(jnp.int32, jnp.float32) is not None + assert cn.finfo(jnp.float32).bits == 32 + assert cn.iinfo(jnp.int32).bits == 32 + + +# --- compat_pytorch --------------------------------------------------------- + +def test_pytorch_shape_ops(): + a = bm.asarray(jnp.arange(24.).reshape(2, 3, 4)) + assert cpt.flatten(a).shape == (24,) + assert cpt.flatten(a, start_dim=1).shape == (2, 12) + assert cpt.flatten(a, start_dim=1, end_dim=2).shape == (2, 12) + assert cpt.flatten(a, start_dim=-2).shape == (2, 12) + assert cpt.flatten(bm.asarray(jnp.array(3.))).shape == (1,) + assert _j(cpt.unflatten(bm.asarray(jnp.arange(6.)), 0, (2, 3))).shape == (2, 3) + assert _j(cpt.unsqueeze(bm.asarray(jnp.arange(3.)), 0)).shape == (1, 3) + assert _j(cpt.cat([bm.asarray([1., 2.]), bm.asarray([3., 4.])])).shape == (4,) + + +def test_pytorch_flatten_errors(): + a = bm.asarray(jnp.arange(6.).reshape(2, 3)) + with pytest.raises(ValueError): + cpt.flatten(a, start_dim=5) + with pytest.raises(ValueError): + cpt.flatten(a, end_dim=5) + + +def test_pytorch_math_ops_no_out(): + a = bm.asarray([0.1, 0.2, 0.3]) + for f in [cpt.abs, cpt.absolute, cpt.acos, cpt.arccos, cpt.acosh, cpt.arccosh, + cpt.asin, cpt.arcsin, cpt.asinh, cpt.arcsinh, cpt.atan, cpt.arctan, + cpt.atanh, cpt.arctanh]: + # acosh needs x >= 1 + x = bm.asarray([1.1, 1.2, 1.3]) if f in (cpt.acosh, cpt.arccosh) else a + assert _j(f(x)).shape == (3,) + assert _j(cpt.angle(bm.asarray([1 + 1j, 2 - 1j]))).shape == (2,) + assert _j(cpt.atan2(a, a)).shape == (3,) + assert _j(cpt.arctan2(a, a)).shape == (3,) + + +def test_pytorch_add_family(): + a = bm.asarray([1., 2., 3.]) + b = bm.asarray([4., 5., 6.]) + np.testing.assert_allclose(np.asarray(_j(cpt.add(a, b))), [5., 7., 9.]) + np.testing.assert_allclose(np.asarray(_j(cpt.add(a, b, alpha=2))), [9., 12., 15.]) + assert _j(cpt.addcdiv(a, b, b, value=2)).shape == (3,) + assert _j(cpt.addcmul(a, b, b, value=2)).shape == (3,) + + +def test_pytorch_out_paths(): + a = bm.asarray([-1., -2., -3.]) + out = bm.zeros((3,)) + r = cpt.abs(a, out=out) + assert r is out + np.testing.assert_allclose(np.asarray(_j(out)), [1., 2., 3.]) + out2 = bm.zeros((3,)) + r2 = cpt.addcdiv(bm.asarray([1., 1., 1.]), bm.asarray([2., 2., 2.]), + bm.asarray([2., 2., 2.]), value=1, out=out2) + assert r2 is out2 + + +def test_pytorch_unary_out_paths(): + """Exercise the ``out=`` branch of every unary pytorch math op (H-14).""" + a = bm.asarray([0.2, 0.3, 0.4]) + high = bm.asarray([1.1, 1.2, 1.3]) + cases = [ + (cpt.acos, a), (cpt.arccos, a), (cpt.acosh, high), (cpt.arccosh, high), + (cpt.asin, a), (cpt.arcsin, a), (cpt.asinh, a), (cpt.arcsinh, a), + (cpt.atan, a), (cpt.arctan, a), (cpt.atanh, a), (cpt.arctanh, a), + (cpt.absolute, a), + ] + for f, x in cases: + out = bm.zeros((3,)) + r = f(x, out=out) + assert r is out + assert _finite(out) + # binary / complex out= paths + out = bm.zeros((3,)) + assert cpt.atan2(a, a, out=out) is out + out = bm.zeros((3,)) + assert cpt.arctan2(a, a, out=out) is out + cout = bm.zeros((3,)) + assert cpt.angle(bm.asarray([1 + 1j, 2 + 0j, 0 + 1j]), out=cout) is cout + + +def test_pytorch_flatten_negative_end_dim_error(): + a = bm.asarray(jnp.arange(6.).reshape(2, 3)) + with pytest.raises(ValueError): + cpt.flatten(a, end_dim=-10) + + +def test_pytorch_clamp_aliases(): + a = bm.asarray([1., 5., 9.]) + np.testing.assert_allclose(np.asarray(_j(cpt.clamp_max(a, 4.))), [1., 4., 4.]) + np.testing.assert_allclose(np.asarray(_j(cpt.clamp_min(a, 4.))), [4., 5., 9.]) + + +def test_pytorch_tensor_alias(): + assert cpt.Tensor is Array + + +# --- compat_tensorflow ------------------------------------------------------ + +def test_tensorflow_reductions(): + a = bm.asarray([[1., 2., 3.], [4., 5., 6.]]) + assert float(_j(ctf.reduce_sum(a))) == 21. + assert float(_j(ctf.reduce_max(a))) == 6. + assert float(_j(ctf.reduce_min(a))) == 1. + assert float(_j(ctf.reduce_mean(a))) == pytest.approx(3.5) + assert float(_j(ctf.reduce_prod(a))) == 720. + assert _j(ctf.reduce_std(a)).shape == () + assert _j(ctf.reduce_variance(a)).shape == () + assert _j(ctf.reduce_euclidean_norm(a)).shape == () + ba = bm.asarray([[True, False], [True, True]]) + assert bool(ctf.reduce_all(ba)) is False + assert bool(ctf.reduce_any(ba)) is True + # axis variants + assert _j(ctf.reduce_max(a, axis=1)).shape == (2,) + assert _j(ctf.reduce_sum(a, axis=0, keepdims=True)).shape == (1, 3) + assert _j(ctf.reduce_euclidean_norm(a, axis=1)).shape == (2,) + + +def test_tensorflow_segment_ops(): + data = bm.asarray([1., 2., 3., 4.]) + seg = bm.asarray([0, 0, 1, 1]) + assert _j(ctf.segment_sum(data, seg)).shape == (2,) + assert _j(ctf.segment_prod(data, seg)).shape == (2,) + assert _j(ctf.segment_max(data, seg)).shape == (2,) + assert _j(ctf.segment_min(data, seg)).shape == (2,) + np.testing.assert_allclose(np.asarray(_j(ctf.segment_mean(data, seg))), [1.5, 3.5]) + assert _j(ctf.unsorted_segment_sum(data, seg, 2)).shape == (2,) + assert _j(ctf.unsorted_segment_prod(data, seg, 2)).shape == (2,) + assert _j(ctf.unsorted_segment_max(data, seg, 2)).shape == (2,) + assert _j(ctf.unsorted_segment_min(data, seg, 2)).shape == (2,) + assert _j(ctf.unsorted_segment_mean(data, seg, 2)).shape == (2,) + assert _j(ctf.unsorted_segment_sqrt_n(data, seg, 2)).shape == (2,) + + +def test_tensorflow_cast_clip_concat(): + a = bm.asarray([1.4, 2.6, 3.1]) + casted = ctf.cast(a, jnp.int32) + assert casted.dtype == jnp.int32 + np.testing.assert_allclose(np.asarray(_j(ctf.clip_by_value(a, 2., 3.))), [2., 2.6, 3.]) + assert _j(ctf.concat([a, a])).shape == (6,) + + +# --- activations ------------------------------------------------------------ + +def test_activations_basic(): + x = bm.asarray([-2., -0.5, 0., 0.5, 2.]) + for f in [act.relu, act.relu6, act.sigmoid, act.softplus, act.silu, act.swish, + act.mish, act.selu, act.elu, act.celu, act.soft_sign, act.log_sigmoid, + act.hard_sigmoid, act.hard_silu, act.hard_swish, act.tanh_shrink, + act.leaky_relu, act.hard_shrink, act.soft_shrink, act.prelu, + act.identity]: + r = f(x) + assert np.asarray(_j(r)).shape == (5,) + assert _finite(r) + + +def test_activations_relu_correct(): + x = bm.asarray([-1., 0., 2.]) + np.testing.assert_allclose(np.asarray(_j(act.relu(x))), [0., 0., 2.]) + np.testing.assert_allclose(np.asarray(_j(act.relu6(bm.asarray([-1., 3., 9.])))), [0., 3., 6.]) + + +def test_activations_gelu_both(): + x = bm.asarray([-1., 0., 1., 2.]) + assert _finite(act.gelu(x, approximate=True)) + assert _finite(act.gelu(x, approximate=False)) + + +def test_activations_softmax_family(): + x = bm.asarray([1., 2., 3.]) + sm = act.softmax(x) + assert float(jnp.sum(_j(sm))) == pytest.approx(1.0, abs=1e-6) + ls = act.log_softmax(x) + np.testing.assert_allclose(np.asarray(_j(jnp.exp(ls))), np.asarray(_j(sm)), atol=1e-6) + smn = act.softmin(x) + assert float(jnp.sum(_j(smn))) == pytest.approx(1.0, abs=1e-6) + assert act.soft_max is act.softmax + + +def test_activations_softmax_large_inputs_stable(): + x = bm.asarray([1000., 1001., 1002.]) + assert _finite(act.softmax(x)) + assert _finite(act.log_softmax(x)) + assert _finite(act.softmin(x)) + + +def test_activations_tanh_and_glu(): + x = bm.asarray([-1., 0., 1.]) + np.testing.assert_allclose(np.asarray(_j(act.tanh(x))), np.tanh([-1., 0., 1.]), atol=1e-6) + assert _j(act.glu(bm.asarray(jnp.arange(4.)))).shape == (2,) + with pytest.raises(AssertionError): + act.glu(bm.asarray(jnp.arange(3.))) # odd axis size + + +def test_activations_one_hot_and_normalize(): + oh = act.one_hot(bm.asarray([0, 1, 2]), 3) + assert _j(oh).shape == (3, 3) + np.testing.assert_allclose(np.asarray(_j(oh)), np.eye(3)) + # out-of-range indices -> all zeros + oh2 = act.one_hot(bm.asarray([-1, 5]), 3) + np.testing.assert_allclose(np.asarray(_j(oh2)), np.zeros((2, 3))) + n = act.normalize(bm.asarray([1., 2., 3., 4.])) + assert _finite(n) + assert float(jnp.mean(_j(n))) == pytest.approx(0.0, abs=1e-5) + + +def test_activations_rrelu(): + r = act.rrelu(bm.asarray([-1., -2., 1., 2.])) + assert _j(r).shape == (4,) + # positive entries pass through unchanged + assert float(_j(r)[2]) == 1. + assert float(_j(r)[3]) == 2. + + +def test_activations_get_dispatch(): + assert act.get('relu') is act.relu + assert act.get(None) is None + fn = (lambda x: x) + assert act.get(fn) is fn + with pytest.raises(ValueError): + act.get('this_is_not_an_activation') + with pytest.raises(ValueError): + act.get(123) + + +def test_activations_accept_plain_jax_arrays(): + x = jnp.array([-1., 0., 1.]) + for f in [act.relu, act.sigmoid, act.softmax, act.softmin, act.elu, act.gelu]: + assert _finite(f(x)) + + +def test_activations_axis_out_of_bounds(): + # one_hot routes axis through _canonicalize_axis -> ValueError on OOB + with pytest.raises(ValueError): + act.one_hot(jnp.array([0, 1, 2]), 3, axis=5) + # log_softmax/softmax over a non-existent axis raises + with pytest.raises(Exception): + act.log_softmax(jnp.arange(6.), axis=5) + + +def test_activations_one_hot_axis_and_dtype(): + oh = act.one_hot(jnp.array([0, 1, 2]), 3, axis=0) + assert _j(oh).shape == (3, 3) + oh_i = act.one_hot(jnp.array([0, 1]), 2, dtype=jnp.int32) + assert _j(oh_i).dtype == jnp.int32 + + +def test_activations_hard_tanh_clamping(): + x = bm.asarray([-2., -0.5, 0.5, 2.]) + np.testing.assert_allclose(np.asarray(_j(act.hard_tanh(x))), [-1., -0.5, 0.5, 1.]) + + +def test_activations_softplus_threshold_branch(): + # values above threshold revert to the linear branch + x = bm.asarray([0.5, 25.0, 50.0]) + r = act.softplus(x, beta=1., threshold=20.) + assert _finite(r) + assert float(_j(r)[1]) == pytest.approx(25.0, rel=1e-5) + + +# --- others ----------------------------------------------------------------- + +def test_others_shared_args_over_time(): + r = bo.shared_args_over_time(num_step=5) + assert r['i'].shape == (5,) + assert r['t'].shape == (5,) + assert r['dt'].shape == (5,) + r2 = bo.shared_args_over_time(duration=1.0, dt=0.1, include_dt=False) + assert r2['i'].shape == (10,) + assert 'dt' not in r2 + + +def test_others_clip_by_norm(): + t = jnp.array([3., 4.]) + r = bo.clip_by_norm(t, 1.0) + assert float(jnp.linalg.norm(_j(r))) <= 1.0 + 1e-5 + # pytree input + r2 = bo.clip_by_norm({'a': jnp.array([3., 4.])}, 1.0) + assert 'a' in r2 + + +def test_others_exprel(): + r = bo.exprel(jnp.array([0., 1., -1.])) + assert _finite(r) + # exprel(0) == 1 (removable singularity) + assert float(_j(r)[0]) == pytest.approx(1.0, abs=1e-4) + # exprel(x) == (exp(x)-1)/x away from zero + assert float(_j(r)[1]) == pytest.approx((np.e - 1.0), rel=1e-3) + # float64 threshold branch + r64 = bo.exprel(jnp.array([0.5])) + assert _finite(r64) + + +def test_others_is_float_type(): + assert bo.is_float_type(jnp.array([1., 2.])) + assert not bo.is_float_type(jnp.array([1, 2])) + + +def test_others_add_axis_axes(): + x = jnp.arange(3.) + assert _j(bo.add_axis(x, 0)).shape == (1, 3) + r = bo.add_axes(x, n_axes=2, pos2len={0: 4}) + assert _j(r).shape == (4, 3) + + +# --- _utils ----------------------------------------------------------------- + +def test_utils_as_jax_array_and_is_leaf(): + a = bm.asarray([1., 2., 3.]) + assert isinstance(butils._as_jax_array_(a), jax.Array) + assert butils._as_jax_array_(5) == 5 + assert butils._is_leaf(a) is True + assert butils._is_leaf(jnp.array([1.])) is False + + +def test_utils_compatible_wrapper_kwargs_translation(): + a = bm.asarray([[1., 2.], [3., 4.]]) + # PyTorch dim -> axis + np.testing.assert_allclose( + np.asarray(_j(bm.sum(a, dim=0))), [4., 6.]) + # PyTorch keepdim -> keepdims + assert _j(bm.sum(a, axis=0, keepdim=True)).shape == (1, 2) + # TensorFlow keep_dims -> keepdims + assert _j(bm.sum(a, axis=0, keep_dims=True)).shape == (1, 2) + + +def test_utils_wrapper_returns_brainpy_array(): + r = bm.add(bm.asarray([1., 2.]), bm.asarray([3., 4.])) + assert isinstance(r, Array) + + +def test_utils_wrapper_doc_and_name(): + # _compatible_with_brainpy_array preserves the wrapped function name + assert cn.sum.__name__ == 'sum' + assert 'brainpy Array/Variable' in cn.sum.__doc__ + + +# --- einops ----------------------------------------------------------------- + +def test_einops_rearrange(): + x = jnp.arange(24.).reshape(2, 3, 4) + assert bein.ein_rearrange(x, 'a b c -> a c b').shape == (2, 4, 3) + assert bein.ein_rearrange(x, 'a b c -> (a b) c').shape == (6, 4) + assert bein.ein_rearrange(x, 'a b c -> a b c').shape == (2, 3, 4) + # split an axis + assert bein.ein_rearrange(jnp.arange(12.), '(a b) -> a b', a=3).shape == (3, 4) + + +def test_einops_reduce(): + x = jnp.arange(24.).reshape(2, 3, 4) + assert bein.ein_reduce(x, 'a b c -> a c', 'mean').shape == (2, 4) + assert bein.ein_reduce(x, 'a b c -> a', 'sum').shape == (2,) + assert bein.ein_reduce(x, 'a b c -> b c', 'max').shape == (3, 4) + assert bein.ein_reduce(x, 'a b c -> b c', 'min').shape == (3, 4) + assert bein.ein_reduce(x, 'a b c -> b c', 'prod').shape == (3, 4) + # pooling-style reduce with explicit axis lengths + y = jnp.arange(2 * 2 * 4 * 4.).reshape(2, 2, 4, 4) + assert bein.ein_reduce(y, 'b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2).shape == (2, 2, 2, 2) + + +def test_einops_reduce_any_all(): + b = jnp.array([[True, False, True], [False, False, True]]) + assert bein.ein_reduce(b, 'a c -> c', 'any').shape == (3,) + assert bein.ein_reduce(b, 'a c -> c', 'all').shape == (3,) + np.testing.assert_array_equal( + np.asarray(_j(bein.ein_reduce(b, 'a c -> c', 'any'))), [True, False, True]) + np.testing.assert_array_equal( + np.asarray(_j(bein.ein_reduce(b, 'a c -> c', 'all'))), [False, False, True]) + + +def test_einops_repeat(): + img = jnp.arange(6.).reshape(2, 3) + assert bein.ein_repeat(img, 'h w -> h w c', c=4).shape == (2, 3, 4) + assert bein.ein_repeat(img, 'h w -> (h2 h) w', h2=2).shape == (4, 3) + + +def test_einops_shape(): + x = jnp.zeros((2, 3, 5, 7)) + assert bein.ein_shape(x, 'batch _ h w') == {'batch': 2, 'h': 5, 'w': 7} + assert bein.ein_shape(x, 'a b c d') == {'a': 2, 'b': 3, 'c': 5, 'd': 7} + + +def test_einops_reduce_callable_reduction(): + x = jnp.arange(24.).reshape(2, 3, 4) + out = bein.ein_reduce(x, 'a b c -> a c', lambda t, axes: t.sum(axis=axes)) + assert out.shape == (2, 4) + + +def test_einops_mean_requires_float(): + x = jnp.arange(24).reshape(2, 3, 4) # integer tensor + with pytest.raises(Exception): + bein.ein_reduce(x, 'a b c -> a c', 'mean') + + +def test_einops_error_message_wrapped(): + from brainpy.math.einops_parsing import EinopsError + with pytest.raises(EinopsError): + bein.ein_rearrange(jnp.arange(6.), 'a b c -> a b c') # wrong ndim + + +def test_einops_enumerate_directions_internal(): + x = jnp.zeros((2, 3)) + dirs = bein._enumerate_directions(x) + assert len(dirs) == 2 + assert _j(dirs[0]).shape == (2, 1) + assert _j(dirs[1]).shape == (1, 3) + + +def test_einops_ellipsis_patterns(): + x = jnp.arange(24.).reshape(2, 3, 4) + # reduce trailing axis, keep ellipsis dims + assert bein.ein_reduce(x, '... c -> ...', 'sum').shape == (2, 3) + # move leading axis to the end across an ellipsis + assert bein.ein_rearrange(x, 'a ... -> ... a').shape == (3, 4, 2) + # repeat with ellipsis + assert bein.ein_repeat(jnp.arange(6.).reshape(2, 3), '... -> ... r', r=2).shape == (2, 3, 2) + + +def test_einops_shape_with_ellipsis(): + x = jnp.zeros((2, 3, 5, 7)) + assert bein.ein_shape(x, 'b ... w') == {'b': 2, 'w': 7} + + +def test_einops_error_branches(): + from brainpy.math.einops_parsing import EinopsError + x = jnp.arange(24.).reshape(2, 3, 4) + # identifiers only on one side of a rearrange + with pytest.raises(EinopsError): + bein.ein_rearrange(x, 'a b c -> a b') + # repeat without a size for a new axis + with pytest.raises(EinopsError): + bein.ein_repeat(jnp.arange(6.).reshape(2, 3), 'h w -> h w c') + # extra identifier on the right of a reduce + with pytest.raises(EinopsError): + bein.ein_reduce(x, 'a b c -> a b c d', 'sum') + # unknown reduction name + with pytest.raises(EinopsError): + bein.ein_reduce(x, 'a b c -> a', 'median') + # composed axes can't be parsed by ein_shape + with pytest.raises(RuntimeError): + bein.ein_shape(jnp.zeros((6,)), '(a b)') + + +def test_einops_list_input_passthrough_identity(): + # NOTE: the docstrings advertise stacking list-of-tensors input, but this + # port does not stack the list -- an identity pattern is a no-op and returns + # the list unchanged. Pin the current (documented-but-incomplete) behaviour. + imgs = [jnp.zeros((3, 4)) for _ in range(5)] + out = bein.ein_rearrange(imgs, 'b h w -> b h w') + assert isinstance(out, list) + assert len(out) == 5 diff --git a/tests/audit/test_math_core_fixes.py b/tests/audit/test_math_core_fixes.py new file mode 100644 index 000000000..310b050bb --- /dev/null +++ b/tests/audit/test_math_core_fixes.py @@ -0,0 +1,897 @@ +# -*- coding: utf-8 -*- +"""Regression + coverage tests for the BrainPy v2.7.8 math-core audit +(see ``docs/issues-found-20260618.md``). + +This module exercises the fixes recorded in the audit for: + +* ``brainpy/math/ndarray.py`` — H-11 (``Array.device`` is a property), + H-12 (``Array(scalar)``), M-09 (``ShardedArray.value`` read), Array.tree_* + pytree round-trip under abstract eval, and L-03 (base vs sharded value + setter policy). +* ``brainpy/math/environment.py`` — C-10 (``disable_x64`` re-syncs brainstate + precision) and M-07 (``set()`` validate-before-mutate). +* ``brainpy/math/modes.py`` — H-10 (``Mode`` is hashable again). +* ``brainpy/math/scales.py`` — L-02 (``IdScaling`` rejects non-default + bias/scale instead of silently ignoring them). +* ``brainpy/math/sharding.py`` — M-10 (``get_sharding`` warns on a full + axis-name mismatch). +* ``brainpy/math/remove_vmap.py`` — M-08 (documented global-reduction batching + behaviour). + +Every test that toggles x64 / global precision / global environment restores +the original state in a ``finally`` block (or via the ``restore_environment`` +fixture) so it cannot corrupt other tests in the suite. +""" + +import warnings + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import brainstate +from jax import config + +import brainpy +import brainpy.math as bm +from brainpy._errors import MathError +from brainpy.math import modes, scales, sharding +from brainpy.math.ndarray import Array, ShardedArray, JaxArray, ndarray +from brainpy.math.remove_vmap import remove_vmap +from brainpy.math.defaults import defaults as _defaults + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def restore_environment(): + """Snapshot global precision / x64 / float state and restore it afterwards. + + Any test that flips ``enable_x64``/``disable_x64`` or mutates global dtype + settings MUST run inside this fixture so the rest of the suite keeps the + default float32 environment. + """ + orig_precision = brainstate.environ.get_precision() + orig_x64 = config.read('jax_enable_x64') + orig_float = bm.get_float() + orig_int = bm.get_int() + orig_complex = bm.get_complex() + try: + yield + finally: + # Restore JAX + brainstate precision symmetrically. + if orig_x64: + bm.enable_x64() + else: + bm.disable_x64() + brainstate.environ.set(precision=orig_precision) + bm.set_float(orig_float) + bm.set_int(orig_int) + bm.set_complex(orig_complex) + + +# =========================================================================== +# Regression tests +# =========================================================================== + +# --- environment.py : C-10 ------------------------------------------------- + +def test_disable_x64_resyncs_brainstate_precision(restore_environment): + """C-10: after ``enable_x64(); disable_x64()`` the brainstate precision is + back to 32 (it used to be left at 64 while JAX was at float32).""" + bm.enable_x64() + assert brainstate.environ.get_precision() == 64 + assert config.read('jax_enable_x64') is True + + bm.disable_x64() + assert brainstate.environ.get_precision() == 32 + assert config.read('jax_enable_x64') is False + + +def test_enable_then_disable_x64_leaves_precision_32(restore_environment): + """The brainstate precision and JAX x64 flag stay in lock-step.""" + bm.enable_x64() + bm.disable_x64() + assert brainstate.environ.get_precision() == 32 + # float default tracks the disabled state. + assert bm.get_float() == jnp.float32 + + +# --- environment.py : M-07 ------------------------------------------------- + +def test_set_validate_before_mutate_invalid_numpy_func_return(): + """M-07: an invalid ``numpy_func_return`` raises and does NOT mutate + ``get_float()`` (validation happens before any global write).""" + before_float = bm.get_float() + before_return = _defaults.numpy_func_return + with pytest.raises(AssertionError): + bm.set(float_=jnp.float64, numpy_func_return='not-a-valid-option') + # No partial-config leak: float type unchanged. + assert bm.get_float() == before_float + assert _defaults.numpy_func_return == before_return + + +def test_set_validate_before_mutate_invalid_dt(): + """A non-float ``dt`` is rejected before any other arg is applied.""" + before_mode = bm.get_mode() + with pytest.raises(AssertionError): + bm.set(mode=modes.batching_mode, dt='not-a-float') + assert bm.get_mode() is before_mode + + +# --- modes.py : H-10 ------------------------------------------------------- + +def test_modes_are_hashable(): + """H-10: defining ``__eq__`` had nuked ``__hash__``; restore hashability.""" + assert hash(modes.nonbatching_mode) is not None + assert hash(brainpy.math.modes.nonbatching_mode) is not None + + +def test_modes_usable_in_a_set(): + """All three default modes can live together in a set.""" + s = {modes.nonbatching_mode, modes.batching_mode, modes.training_mode} + assert len(s) == 3 + # Two non-batching instances compare/hash equal (by class). + s2 = {modes.NonBatchingMode(), modes.NonBatchingMode()} + assert len(s2) == 1 + + +# --- ndarray.py : H-11 ----------------------------------------------------- + +def test_array_device_is_property_returning_jax_device(): + """H-11: ``Array.device`` is now a property (calling the old method form + raised ``TypeError``). It returns a ``jax.Device``.""" + a = bm.asarray([1., 2.]) + dev = a.device # property access, not a call + assert isinstance(dev, jax.Device) + + +# --- ndarray.py : H-12 ----------------------------------------------------- + +def test_array_scalar_conversion_shape(): + """H-12: ``Array(scalar)`` stores a real jax array, so ``.shape == ()``.""" + assert Array(5).shape == () + assert Array(5).ndim == 0 + assert Array(5.0).shape == () + # value is a jax array, not a bare python scalar. + assert isinstance(Array(5).value, jax.Array) + + +# --- ndarray.py : pytree round-trip ---------------------------------------- + +def test_array_pytree_round_trip_under_abstract_eval(): + """``Array.tree_unflatten`` must accept abstract leaves so transforms like + ``jax.eval_shape`` work without re-running ``jnp.asarray``.""" + out = jax.eval_shape(lambda x: x * 2, bm.asarray([1., 2., 3.])) + assert out.shape == (3,) + + +def test_array_tree_flatten_unflatten_concrete(): + a = Array([1., 2., 3.]) + leaves, aux = a.tree_flatten() + assert aux is None + rebuilt = Array.tree_unflatten(aux, leaves) + np.testing.assert_array_equal(np.asarray(rebuilt.value), np.asarray(a.value)) + + +def test_array_jit_round_trip(): + """A full jit round-trip preserves the Array pytree contract.""" + a = bm.asarray([1., 2., 3.]) + out = jax.jit(lambda x: x + 1.)(a) + np.testing.assert_allclose(np.asarray(out), np.array([2., 3., 4.])) + + +# --- scales.py : L-02 ------------------------------------------------------ + +def test_idscaling_rejects_non_default_bias(): + ids = scales.IdScaling() + with pytest.raises(ValueError): + ids.offset_scaling(1.0, bias=5.0) + + +def test_idscaling_rejects_non_default_scale(): + ids = scales.IdScaling() + with pytest.raises(ValueError): + ids.std_scaling(1.0, scale=2.0) + with pytest.raises(ValueError): + ids.inv_scaling(1.0, scale=2.0) + with pytest.raises(ValueError): + ids.clone(scale=2.0) + with pytest.raises(ValueError): + ids.clone(bias=3.0) + + +def test_idscaling_identity_for_default_calls(): + ids = scales.IdScaling() + assert ids.offset_scaling(7.0) == 7.0 + assert ids.std_scaling(7.0) == 7.0 + assert ids.inv_scaling(7.0) == 7.0 + # Explicit default values (0./1.) are accepted (no-op overrides). + assert ids.offset_scaling(7.0, bias=0., scale=1.) == 7.0 + assert isinstance(ids.clone(), scales.IdScaling) + assert ids.scale == 1.0 and ids.bias == 0.0 + + +# --- sharding.py : M-10 ---------------------------------------------------- + +def _single_axis_mesh(name='x'): + return jax.sharding.Mesh(np.asarray(jax.devices()), axis_names=(name,)) + + +def test_get_sharding_warns_on_full_axis_mismatch(): + """M-10: when *every* requested axis name is absent, a fully-replicated + PartitionSpec is produced; warn instead of silently dropping intent.""" + mesh = _single_axis_mesh('x') + with pytest.warns(UserWarning): + sh = sharding.get_sharding(['this_axis_does_not_exist'], mesh) + assert sh is not None # still returns a (replicated) NamedSharding + + +def test_get_sharding_no_warning_on_partial_match(): + """A partial match is tolerated on purpose: no warning.""" + mesh = _single_axis_mesh('x') + with warnings.catch_warnings(): + warnings.simplefilter('error') # any warning -> test failure + sh = sharding.get_sharding(['x', 'missing'], mesh) + assert sh is not None + + +def test_get_sharding_none_axis_names_returns_none(): + assert sharding.get_sharding(None) is None + + +# --- remove_vmap.py : M-08 ------------------------------------------------- + +def test_remove_vmap_global_reduction_any(): + """M-08 documents that the reduction is global. Outside vmap it is a real + scalar.""" + out = remove_vmap(jnp.array([False, True])) + assert out.shape == () + assert bool(out) is True + + +def test_remove_vmap_global_reduction_all(): + assert bool(remove_vmap(jnp.array([True, True]), 'all')) is True + assert bool(remove_vmap(jnp.array([True, False]), 'all')) is False + + +def test_remove_vmap_under_vmap_is_global(): + """Under ``vmap`` the result is a *global* reduction across the batch (the + documented, intentional behaviour): every batch slot sees the same value.""" + data = jnp.array([[0.1], [0.9]]) + out = jax.vmap(lambda x: remove_vmap(x > 0.5, 'any'))(data) + # Global any over the whole batch is True; broadcast across the batch. + assert bool(out[0]) is True and bool(out[1]) is True + out_all = jax.vmap(lambda x: remove_vmap(x > 0.5, 'all'))(data) + # Global all over the whole batch is False. + assert bool(out_all[0]) is False and bool(out_all[1]) is False + + +def test_remove_vmap_invalid_op_raises(): + with pytest.raises(ValueError): + remove_vmap(jnp.array([1]), 'unsupported') + + +def test_remove_vmap_unwraps_bp_array(): + """A ``brainpy.math.Array`` input is unwrapped (line ``x = x.value``).""" + assert bool(remove_vmap(Array([False, True]))) is True + assert bool(remove_vmap(Array([True, True]), 'all')) is True + + +def test_remove_vmap_under_jit_triggers_abstract_eval(): + """Running under ``jit`` exercises the abstract-eval rules of both + primitives.""" + f_any = jax.jit(lambda x: remove_vmap(x > 0., 'any')) + f_all = jax.jit(lambda x: remove_vmap(x > 0., 'all')) + assert bool(f_any(jnp.array([-1., 1.]))) is True + assert bool(f_all(jnp.array([1., 1.]))) is True + assert bool(f_all(jnp.array([-1., 1.]))) is False + + +# =========================================================================== +# Coverage tests +# =========================================================================== + +# --- ndarray.py : Array methods -------------------------------------------- + +def test_array_arithmetic(): + a = Array([1., 2., 3.]) + np.testing.assert_allclose(np.asarray(a + 1), [2., 3., 4.]) + np.testing.assert_allclose(np.asarray(a - 1), [0., 1., 2.]) + np.testing.assert_allclose(np.asarray(a * 2), [2., 4., 6.]) + np.testing.assert_allclose(np.asarray(-a), [-1., -2., -3.]) + np.testing.assert_allclose(np.asarray(a / 2), [0.5, 1., 1.5]) + + +def test_array_indexing_and_iteration(): + a = Array([10., 20., 30.]) + assert float(a[0]) == 10.0 + assert len(a) == 3 + assert [float(x) for x in a] == [10.0, 20.0, 30.0] + # slice + np.testing.assert_allclose(np.asarray(a[1:]), [20., 30.]) + + +def test_array_inplace_set(): + a = Array([1., 2., 3.]) + a[0] = 99. + np.testing.assert_allclose(np.asarray(a.value), [99., 2., 3.]) + # .at accessor delegates to the underlying jax array. + updated = a.at[1].set(0.) + np.testing.assert_allclose(np.asarray(updated), [99., 0., 3.]) + + +def test_array_repr_scalar_and_vector(): + r = repr(Array([1., 2., 3.])) + assert r.startswith('Array(value=') + # multi-line repr branch + big = Array(np.arange(40.).reshape(8, 5)) + r2 = repr(big) + assert 'Array(value=' in r2 and '\n' in r2 and 'dtype=' in r2 + + +def test_array_dtype_shape_ndim(): + a = Array([[1., 2.], [3., 4.]]) + assert a.shape == (2, 2) + assert a.ndim == 2 + assert a.dtype == jnp.float32 + + +def test_array_value_setter_with_array_np_and_state(): + a = Array([0., 0.]) + # set from numpy + a.value = np.array([4., 5.]) + np.testing.assert_allclose(np.asarray(a.value), [4., 5.]) + # set from another Array + a.value = Array([7., 8.]) + np.testing.assert_allclose(np.asarray(a.value), [7., 8.]) + # set from a jax array + a.value = jnp.array([1., 2.]) + np.testing.assert_allclose(np.asarray(a.value), [1., 2.]) + # set from a State / Variable + a.value = bm.Variable(jnp.array([9., 10.])) + np.testing.assert_allclose(np.asarray(a.value), [9., 10.]) + # set from a python list (the "else -> jnp.asarray" branch) + a.value = [11., 12.] + np.testing.assert_allclose(np.asarray(a.value), [11., 12.]) + + +def test_array_data_property_and_update(): + a = Array([1., 2.]) + np.testing.assert_allclose(np.asarray(a.data), [1., 2.]) + a.data = jnp.array([3., 4.]) + np.testing.assert_allclose(np.asarray(a.data), [3., 4.]) + a.update(jnp.array([5., 6.])) + np.testing.assert_allclose(np.asarray(a.value), [5., 6.]) + + +def test_array_fill(): + a = Array([1., 2., 3.]) + a.fill_(7.) + np.testing.assert_allclose(np.asarray(a.value), [7., 7., 7.]) + # Array scalar as fill value + a.fill_(Array(2.)) + np.testing.assert_allclose(np.asarray(a.value), [2., 2., 2.]) + # numpy scalar + a.fill_(np.float32(3.)) + np.testing.assert_allclose(np.asarray(a.value), [3., 3., 3.]) + # non-scalar fill value is rejected + with pytest.raises(MathError): + a.fill_(np.array([1., 2.])) + + +def test_array_numpy_and_jax_protocols(): + a = Array([1., 2., 3.]) + np.testing.assert_allclose(np.asarray(a), [1., 2., 3.]) + assert isinstance(a.__jax_array__(), jax.Array) + # __array__ with dtype + arr = np.asarray(a, dtype=np.int32) + assert arr.dtype == np.int32 + + +def test_array_as_variable(): + a = Array([1., 2.]) + v = a.as_variable() + assert isinstance(v, bm.Variable) + + +def test_array_block_until_ready_and_device_buffer(): + a = Array([1., 2., 3.]) + assert isinstance(a.block_until_ready(), jax.Array) + assert isinstance(a.block_host_until_ready(), jax.Array) + np.testing.assert_allclose(np.asarray(a.device_buffer), [1., 2., 3.]) + + +def test_array_aliases(): + assert JaxArray is Array + assert ndarray is Array + + +def test_array_constructed_from_array_and_dtype(): + base = Array([1., 2., 3.]) + a = Array(base, dtype=jnp.int32) + assert a.dtype == jnp.int32 + np.testing.assert_array_equal(np.asarray(a.value), [1, 2, 3]) + # tuple input branch + b = Array((1., 2.)) + np.testing.assert_allclose(np.asarray(b.value), [1., 2.]) + + +# --- ndarray.py : ShardedArray --------------------------------------------- + +def test_sharded_array_value_read_write(): + sa = ShardedArray(jnp.array([1., 2., 3.])) + np.testing.assert_allclose(np.asarray(sa.value), [1., 2., 3.]) + sa.value = jnp.array([4., 5., 6.]) + np.testing.assert_allclose(np.asarray(sa.value), [4., 5., 6.]) + + +def test_sharded_array_enforces_shape_and_dtype(): + """L-03 / M-09: the ShardedArray setter enforces shape & dtype.""" + sa = ShardedArray(jnp.array([1., 2., 3.])) + with pytest.raises(MathError): + sa.value = jnp.array([1., 2.]) # wrong shape + with pytest.raises(MathError): + sa.value = jnp.array([1, 2, 3]) # wrong dtype + + +def test_sharded_array_keep_sharding_false(): + sa = ShardedArray(jnp.array([1., 2.]), keep_sharding=False) + np.testing.assert_allclose(np.asarray(sa.value), [1., 2.]) + + +def test_sharded_array_setter_from_array_and_np(): + sa = ShardedArray(jnp.array([1., 2.])) + sa.value = Array([3., 4.]) + np.testing.assert_allclose(np.asarray(sa.value), [3., 4.]) + sa.value = np.array([5., 6.], dtype=np.float32) + np.testing.assert_allclose(np.asarray(sa.value), [5., 6.]) + + +# --- environment.py : getters / setters ------------------------------------ + +def test_environment_dtype_getters(): + assert bm.get_float() in (jnp.float32, jnp.float64) + assert bm.get_int() in (jnp.int32, jnp.int64) + assert bm.get_complex() in (jnp.complex64, jnp.complex128) + assert bm.get_bool() == jnp.bool_ + # deprecated aliases still resolve. + assert bm.dftype() == bm.get_float() + assert bm.ditype() == bm.get_int() + + +def test_environment_dtype_setters_round_trip(): + orig_float = bm.get_float() + orig_int = bm.get_int() + orig_complex = bm.get_complex() + orig_bool = bm.get_bool() + try: + bm.set_float(jnp.float16) + assert bm.get_float() == jnp.float16 + bm.set_int(jnp.int16) + assert bm.get_int() == jnp.int16 + bm.set_complex(jnp.complex64) + assert bm.get_complex() == jnp.complex64 + bm.set_bool(jnp.bool_) + assert bm.get_bool() == jnp.bool_ + finally: + bm.set_float(orig_float) + bm.set_int(orig_int) + bm.set_complex(orig_complex) + bm.set_bool(orig_bool) + + +def test_environment_dt_setter_round_trip(): + orig = bm.get_dt() + try: + bm.set_dt(0.05) + assert bm.get_dt() == 0.05 + finally: + bm.set_dt(orig) + assert bm.get_dt() == orig + + +def test_environment_mode_setter_round_trip(): + orig = bm.get_mode() + try: + bm.set_mode(modes.batching_mode) + assert bm.get_mode() is modes.batching_mode + finally: + bm.set_mode(orig) + with pytest.raises(TypeError): + bm.set_mode('not-a-mode') + + +def test_environment_membrane_scaling_setter_round_trip(): + orig = bm.get_membrane_scaling() + try: + s = scales.Scaling(scale=2., bias=1.) + bm.set_membrane_scaling(s) + assert bm.get_membrane_scaling() is s + finally: + bm.set_membrane_scaling(orig) + with pytest.raises(TypeError): + bm.set_membrane_scaling('not-a-scaling') + + +def test_environment_get_platform(): + assert bm.get_platform() in ('cpu', 'gpu', 'tpu') + + +def test_set_applies_all_valid_args_round_trip(restore_environment): + """Exercise the apply branches of ``set()`` with every argument, then + restore each global it touched.""" + orig_dt = bm.get_dt() + orig_mode = bm.get_mode() + orig_ms = bm.get_membrane_scaling() + orig_float = bm.get_float() + orig_int = bm.get_int() + orig_bool = bm.get_bool() + orig_complex = bm.get_complex() + orig_pytree = _defaults.bp_object_as_pytree + orig_return = _defaults.numpy_func_return + try: + bm.set( + mode=modes.batching_mode, + membrane_scaling=scales.Scaling(scale=2., bias=0.), + dt=0.25, + x64=False, + complex_=jnp.complex64, + float_=jnp.float32, + int_=jnp.int32, + bool_=jnp.bool_, + bp_object_as_pytree=True, + numpy_func_return='jax_array', + ) + assert bm.get_dt() == 0.25 + assert isinstance(bm.get_mode(), modes.BatchingMode) + assert bm.get_membrane_scaling().scale == 2. + assert bm.get_float() == jnp.float32 + assert bm.get_int() == jnp.int32 + assert bm.get_complex() == jnp.complex64 + assert _defaults.bp_object_as_pytree is True + assert _defaults.numpy_func_return == 'jax_array' + finally: + bm.set_dt(orig_dt) + bm.set_mode(orig_mode) + bm.set_membrane_scaling(orig_ms) + bm.set_float(orig_float) + bm.set_int(orig_int) + bm.set_bool(orig_bool) + bm.set_complex(orig_complex) + _defaults.bp_object_as_pytree = orig_pytree + _defaults.numpy_func_return = orig_return + + +def test_set_environment_is_set_alias(): + from brainpy.math.environment import set_environment, set as _set + assert set_environment is _set + + +def test_environment_context_all_dtype_kwargs(restore_environment): + """Exercise the ``__init__``/``__enter__``/``__exit__`` branches for the + dtype + scaling + pytree + numpy_func_return arguments.""" + orig_float = bm.get_float() + with bm.environment( + membrane_scaling=scales.Scaling(scale=3., bias=0.), + float_=jnp.float32, + int_=jnp.int32, + bool_=jnp.bool_, + complex_=jnp.complex64, + bp_object_as_pytree=True, + numpy_func_return='jax_array', + ): + assert bm.get_membrane_scaling().scale == 3. + assert _defaults.bp_object_as_pytree is True + assert _defaults.numpy_func_return == 'jax_array' + assert bm.get_float() == orig_float + assert _defaults.numpy_func_return != 'jax_array' or _defaults.numpy_func_return == 'bp_array' + + +def test_environment_as_decorator(restore_environment): + """``environment`` doubles as a decorator (``_DecoratorContextManager``).""" + + @bm.environment(dt=0.07) + def get_dt_inside(): + return bm.get_dt() + + orig = bm.get_dt() + assert get_dt_inside() == 0.07 + assert bm.get_dt() == orig + + +def test_environment_init_rejects_bad_types(): + with pytest.raises(AssertionError): + bm.environment(dt='not-a-float') + with pytest.raises(AssertionError): + bm.environment(mode='not-a-mode') + with pytest.raises(AssertionError): + bm.environment(x64='not-a-bool') + with pytest.raises(AssertionError): + bm.environment(numpy_func_return='bad-option') + + +def test_environment_context_manager_restores(restore_environment): + orig_mode = bm.get_mode() + orig_dt = bm.get_dt() + with bm.environment(mode=modes.batching_mode, dt=0.2): + assert bm.get_mode() is modes.batching_mode + assert bm.get_dt() == 0.2 + assert bm.get_mode() is orig_mode + assert bm.get_dt() == orig_dt + + +def test_batching_and_training_environment(restore_environment): + with bm.batching_environment(dt=0.3): + assert isinstance(bm.get_mode(), modes.BatchingMode) + with bm.training_environment(batch_size=4): + m = bm.get_mode() + assert isinstance(m, modes.TrainingMode) + assert m.batch_size == 4 + + +def test_environment_x64_context_manager(restore_environment): + """``environment(x64=...)`` flips and restores the precision symmetrically.""" + start = config.read('jax_enable_x64') + with bm.environment(x64=not start): + assert config.read('jax_enable_x64') == (not start) + assert config.read('jax_enable_x64') == start + + +def test_enable_x64_with_bool_argument_warns(restore_environment): + """The legacy ``enable_x64(True)`` path emits a DeprecationWarning.""" + with pytest.warns(DeprecationWarning): + bm.enable_x64(True) + assert config.read('jax_enable_x64') is True + + +def test_set_x64_helper(restore_environment): + from brainpy.math.environment import set_x64 + set_x64(True) + assert config.read('jax_enable_x64') is True + set_x64(False) + assert config.read('jax_enable_x64') is False + + +def test_enable_x64_false_branch_warns(restore_environment): + """The deprecated ``enable_x64(False)`` path routes through the disable + branch and warns.""" + bm.enable_x64() # go to 64 first + with pytest.warns(DeprecationWarning): + bm.enable_x64(False) + assert config.read('jax_enable_x64') is False + assert brainstate.environ.get_precision() == 32 + + +def test_environment_decorator_on_generator(restore_environment): + """Cover ``_DecoratorContextManager._wrap_generator`` by decorating a + generator function with ``environment``.""" + + @bm.environment(dt=0.09) + def gen(): + yield bm.get_dt() + yield bm.get_dt() + + orig = bm.get_dt() + values = list(gen()) + assert values == [0.09, 0.09] + assert bm.get_dt() == orig + + +def test_set_host_device_count_sets_env_var(monkeypatch): + """``set_host_device_count`` writes the XLA flag (no global precision + change).""" + monkeypatch.setenv('XLA_FLAGS', '') + bm.set_host_device_count(3) + import os + assert '--xla_force_host_platform_device_count=3' in os.environ['XLA_FLAGS'] + + +def test_gpu_memory_preallocation_toggles(monkeypatch): + import os + from brainpy.math.environment import gpu_memory_preallocation + monkeypatch.delenv('XLA_PYTHON_CLIENT_PREALLOCATE', raising=False) + monkeypatch.delenv('XLA_PYTHON_CLIENT_ALLOCATOR', raising=False) + bm.disable_gpu_memory_preallocation() + assert os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] == 'false' + assert os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] == 'platform' + bm.enable_gpu_memory_preallocation() + assert os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] == 'true' + assert 'XLA_PYTHON_CLIENT_ALLOCATOR' not in os.environ + gpu_memory_preallocation(0.5) + assert os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] == '0.5' + with pytest.raises(AssertionError): + gpu_memory_preallocation(1.5) + + +def test_set_platform_rejects_unknown(): + with pytest.raises(AssertionError): + bm.set_platform('quantum') + + +def test_environment_identity_equality(): + """``environment.__eq__`` is identity-based.""" + env1 = bm.environment(dt=0.1) + env2 = bm.environment(dt=0.1) + assert env1 == env1 + assert env1 != env2 + + +def test_environment_clone_preserves_settings(): + env = bm.environment(dt=0.1, mode=modes.batching_mode) + clone = env.clone() + assert clone.dt == 0.1 + assert clone.mode is modes.batching_mode + + +# --- modes.py -------------------------------------------------------------- + +def test_mode_repr(): + assert repr(modes.nonbatching_mode) == 'NonBatchingMode' + assert repr(modes.batching_mode) == 'BatchingMode(batch_size=1)' + assert repr(modes.training_mode) == 'TrainingMode(batch_size=1)' + + +def test_mode_equality(): + assert modes.NonBatchingMode() == modes.NonBatchingMode() + assert (modes.NonBatchingMode() == modes.BatchingMode()) is False + # comparison against a non-Mode returns False (not an error) + assert (modes.nonbatching_mode == 5) is False + + +def test_mode_predicates(): + assert modes.batching_mode.is_batch_mode() is True + assert modes.training_mode.is_train_mode() is True + assert modes.nonbatching_mode.is_nonbatch_mode() is True + assert modes.nonbatching_mode.batch_size == tuple() + assert modes.batching_mode.batch_size == 1 + + +def test_mode_type_queries(): + assert modes.batching_mode.is_one_of(modes.BatchingMode, modes.TrainingMode) + assert modes.batching_mode.is_a(modes.BatchingMode) + assert modes.batching_mode.is_parent_of(modes.TrainingMode) + assert modes.training_mode.is_child_of(modes.BatchingMode) + # invalid (non-type) arguments raise TypeError + with pytest.raises(TypeError): + modes.batching_mode.is_one_of(modes.batching_mode) + with pytest.raises(TypeError): + modes.batching_mode.is_parent_of(modes.batching_mode) + with pytest.raises(TypeError): + modes.batching_mode.is_child_of(modes.batching_mode) + + +def test_training_mode_to_batch_mode(): + tm = modes.TrainingMode(3) + bmode = tm.to_batch_mode() + assert isinstance(bmode, modes.BatchingMode) + assert not isinstance(bmode, modes.TrainingMode) + assert bmode.batch_size == 3 + + +# --- scales.py ------------------------------------------------------------- + +def test_scaling_offset_std_inv(): + s = scales.Scaling(scale=2., bias=1.) + assert s.offset_scaling(3.0) == (3.0 + 1.0) / 2.0 + assert s.std_scaling(4.0) == 4.0 / 2.0 + assert s.inv_scaling(4.0) == 4.0 * 2.0 + # explicit overrides honored on a plain Scaling + assert s.offset_scaling(3.0, bias=0., scale=1.) == 3.0 + assert s.std_scaling(4.0, scale=4.0) == 1.0 + assert s.inv_scaling(4.0, scale=0.5) == 2.0 + + +def test_scaling_clone(): + s = scales.Scaling(scale=2., bias=1.) + c = s.clone() + assert c.scale == 2. and c.bias == 1. + c2 = s.clone(bias=5., scale=3.) + assert c2.scale == 3. and c2.bias == 5. + + +def test_scaling_transform(): + s = scales.Scaling.transform([0., 10.], [0., 1.]) + assert s.scale == 10.0 + assert s.bias == 0.0 + # round-trip: offset then inv recovers the offset-corrected value + assert s.offset_scaling(10.0) == 1.0 + + +# --- sharding.py ----------------------------------------------------------- + +def test_device_mesh_context_sets_and_restores(): + devs = np.asarray(jax.devices()) + with sharding.device_mesh(devs, ('x',)) as mesh: + assert mesh.axis_names == ('x',) + sh = sharding.get_sharding(['x']) + assert sh is not None + # default mesh restored to None afterwards + assert sharding.get_sharding(['x']) is None + + +def test_partition_none_passthrough(): + x = jnp.array([1., 2.]) + assert sharding.partition(x, None) is x + + +def test_partition_by_axname_no_mesh_returns_input(): + x = jnp.array([1., 2.]) + # No default mesh -> input returned unchanged. + out = sharding.partition_by_axname(x, ['x']) + np.testing.assert_allclose(np.asarray(out), [1., 2.]) + # axis_names None -> input returned. + assert sharding.partition_by_axname(x, None) is x + + +def test_partition_by_axname_shape_mismatch_raises(): + devs = np.asarray(jax.devices()) + with sharding.device_mesh(devs, ('x',)): + with pytest.raises(ValueError): + # 1-D array but two requested axis names -> dim mismatch + sharding.partition_by_axname(jnp.array([1., 2.]), ['x', 'y']) + + +def test_partition_by_sharding_none_and_typecheck(): + x = jnp.array([1., 2.]) + assert sharding.partition_by_sharding(x, None) is x + with pytest.raises(TypeError): + sharding.partition_by_sharding(x, 'not-a-sharding') + + +def test_partition_invalid_type_raises(): + with pytest.raises(TypeError): + sharding.partition(jnp.array([1.]), 12345) + + +def test_keep_constraint_passthrough(): + out = sharding.keep_constraint(jnp.array([1., 2.])) + np.testing.assert_allclose(np.asarray(out), [1., 2.]) + # non-array leaves pass through untouched + assert sharding.keep_constraint(7) == 7 + + +def test_is_bp_array_helper(): + assert sharding.is_bp_array(Array([1.])) is True + assert sharding.is_bp_array(jnp.array([1.])) is False + + +def test_partition_by_axname_with_mesh_devices(): + """Exercise the real device-put path of ``partition_by_axname`` / + ``_device_put`` with an actual single-device mesh.""" + devs = np.asarray(jax.devices()) + with sharding.device_mesh(devs, ('x',)): + # 1-D array, one axis name -> dims match -> resharded via _device_put. + out = sharding.partition_by_axname(jnp.array([1., 2.]), ['x']) + np.testing.assert_allclose(np.asarray(out), [1., 2.]) + # Array leaf goes through the ``isinstance(x, Array)`` branch. + out2 = sharding.partition_by_axname(Array([3., 4.]), ['x']) + leaf = jax.tree_util.tree_leaves(out2, is_leaf=sharding.is_bp_array)[0] + np.testing.assert_allclose(np.asarray(leaf), [3., 4.]) + + +def test_partition_with_sharding_object(): + """``partition`` with a concrete ``Sharding`` instance routes through + ``partition_by_sharding``/``_device_put``.""" + devs = np.asarray(jax.devices()) + mesh = jax.sharding.Mesh(devs, axis_names=('x',)) + sh = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) + out = sharding.partition(jnp.array([1., 2.]), sh) + np.testing.assert_allclose(np.asarray(out.value if isinstance(out, Array) else out), + [1., 2.]) + + +def test_partition_with_axis_name_sequence(): + devs = np.asarray(jax.devices()) + with sharding.device_mesh(devs, ('x',)): + out = sharding.partition(jnp.array([1., 2.]), ['x']) + leaf = jax.tree_util.tree_leaves(out, is_leaf=sharding.is_bp_array)[0] + np.testing.assert_allclose(np.asarray(leaf), [1., 2.]) + + +def test_keep_constraint_on_bp_array(): + out = sharding.keep_constraint(Array([1., 2., 3.])) + np.testing.assert_allclose(np.asarray(out), [1., 2., 3.]) diff --git a/tests/audit/test_math_sparse_surrogate_fixes.py b/tests/audit/test_math_sparse_surrogate_fixes.py new file mode 100644 index 000000000..e35fd15b0 --- /dev/null +++ b/tests/audit/test_math_sparse_surrogate_fixes.py @@ -0,0 +1,836 @@ +# -*- coding: utf-8 -*- +"""Regression + coverage tests for the BrainPy v2.7.8 sparse / event / jitconn / +surrogate / delay / pre-syn-post audit (see ``docs/issues-found-20260618.md``). + +This module locks in the fixes recorded in the audit for the following source +files (audit IDs in parentheses): + +* ``brainpy/math/sparse/csr_mm.py`` — C-07 (``csrmm(transpose=True)`` must + compute ``Mᵀ @ B`` rather than ``B @ M``). +* ``brainpy/math/event/csr_matmat.py`` — C-07 (same transpose fix for the + event-driven CSR matmat path, with ``BinaryArray`` operand). +* ``brainpy/math/sparse/coo_mv.py`` — H-17 (``coomv`` must no longer touch + the removed ``brainevent.COO`` and must still equal ``dense @ v``). +* ``brainpy/math/sparse/utils.py`` — H-18 (``coo_to_csr`` returns an int, + monotone ``indptr`` starting at 0 / ending at ``nnz``) and H-19 + (``csr_to_dense`` wraps ``brainevent.CSR(...).todense()`` correctly). +* ``brainpy/math/jitconn/matvec.py`` — M-13 (``mv_prob_*`` / + ``event_mv_prob_*`` are reproducible when an explicit ``seed`` is threaded). +* ``brainpy/math/surrogate/_one_input.py`` and + ``brainpy/math/surrogate/_one_input_new.py`` — H-20..H-24: for every surrogate + the ``surrogate_grad`` matches ``jax.grad(surrogate_fun)``; ``GaussianGrad`` + widens with ``sigma`` (H-20); ``PiecewiseQuadratic`` grad matches its forward + derivative (H-21); ``QPseudoSpike`` uses ``alpha-1`` (H-22); ``Arctan`` + ``surrogate_fun`` does not raise (H-23); ``ERF`` ``surrogate_fun`` is + increasing (H-24). +* ``brainpy/math/delayvars.py`` — C-09 (``TimeDelay`` ring-buffer read + applies the modulo) plus ``LengthDelay`` / ``ROTATE_UPDATE`` / ``CONCAT_UPDATE`` + coverage. +* ``brainpy/math/pre_syn_post.py`` — M-15 plus ``syn2post_mean`` / + ``syn2post_softmax`` empty-group → 0 vs genuine-NaN-propagation behaviour. + +All tests use tiny shapes and complete well under the time budget. They assert +fixed (correct) behaviour, so they double as a regression guard against the bugs +re-appearing. +""" + +import numpy as np +import jax +import jax.numpy as jnp +import pytest + +import brainpy.math as bm +import brainpy.math.surrogate._one_input as old_surr +import brainpy.math.surrogate._one_input_new as new_surr + + +# --------------------------------------------------------------------------- +# Helpers: build a small known CSR by hand (no scipy dependency). +# +# dense = [[1, 0, 2, 0], +# [0, 3, 0, 0], +# [0, 0, 0, 4]] (3 x 4) +# --------------------------------------------------------------------------- + +def _known_csr(): + dense = np.array([[1., 0., 2., 0.], + [0., 3., 0., 0.], + [0., 0., 0., 4.]], dtype=np.float32) + data = jnp.asarray([1., 2., 3., 4.], dtype=jnp.float32) + indices = jnp.asarray([0, 2, 1, 3], dtype=jnp.int32) + indptr = jnp.asarray([0, 2, 3, 4], dtype=jnp.int32) + shape = (3, 4) + return dense, data, indices, indptr, shape + + +# =========================================================================== +# C-07 — csrmm(transpose=True) computes Mᵀ @ B (sparse + event paths) +# =========================================================================== + +def test_csrmm_non_transpose_matches_dense(): + dense, data, indices, indptr, shape = _known_csr() + B = np.arange(4 * 2, dtype=np.float32).reshape(4, 2) + out = bm.sparse.csrmm(data, indices, indptr, jnp.asarray(B), + shape=shape, transpose=False) + out = np.asarray(out) + assert out.shape == (3, 2) + np.testing.assert_allclose(out, dense @ B, rtol=1e-5, atol=1e-5) + + +def test_csrmm_transpose_matches_dense_T(): + # Regression for C-07: transpose branch must equal Mᵀ @ B (shape (cols, k)), + # NOT B @ M. With shape=(3,4) and B=(3,2) the output must be (4,2). + dense, data, indices, indptr, shape = _known_csr() + B = np.arange(3 * 2, dtype=np.float32).reshape(3, 2) + out = bm.sparse.csrmm(data, indices, indptr, jnp.asarray(B), + shape=shape, transpose=True) + out = np.asarray(out) + assert out.shape == (4, 2) # would be (3, ...) under the old bug + np.testing.assert_allclose(out, dense.T @ B, rtol=1e-5, atol=1e-5) + + +def test_csrmm_accepts_brainpy_array_operands(): + # Exercise the ``isinstance(x, Array)`` unwrapping branches in csr_mm.py. + dense, data, indices, indptr, shape = _known_csr() + B = np.arange(4 * 2, dtype=np.float32).reshape(4, 2) + out = bm.sparse.csrmm(bm.asarray(data), bm.asarray(indices), + bm.asarray(indptr), bm.asarray(B), + shape=shape, transpose=False) + np.testing.assert_allclose(np.asarray(out), dense @ B, rtol=1e-5, atol=1e-5) + + +def test_event_csrmm_transpose_matches_dense_T(): + # C-07 for brainpy/math/event/csr_matmat.py — binary event matrix. + dense, data, indices, indptr, shape = _known_csr() + B = (np.arange(3 * 2).reshape(3, 2) % 2).astype(np.float32) # 0/1 events + out = bm.event.csrmm(data, indices, indptr, jnp.asarray(B), + shape=shape, transpose=True) + out = np.asarray(out) + assert out.shape == (4, 2) + np.testing.assert_allclose(out, dense.T @ B, rtol=1e-5, atol=1e-5) + + +def test_event_csrmm_non_transpose_matches_dense(): + dense, data, indices, indptr, shape = _known_csr() + B = (np.arange(4 * 2).reshape(4, 2) % 2).astype(np.float32) + out = bm.event.csrmm(data, indices, indptr, jnp.asarray(B), + shape=shape, transpose=False) + np.testing.assert_allclose(np.asarray(out), dense @ B, rtol=1e-5, atol=1e-5) + + +def test_event_csrmm_accepts_brainpy_array_operands(): + # Exercise the ``isinstance(x, Array)`` unwrap branches in csr_matmat.py. + dense, data, indices, indptr, shape = _known_csr() + B = (np.arange(4 * 2).reshape(4, 2) % 2).astype(np.float32) + out = bm.event.csrmm(bm.asarray(data), bm.asarray(indices), + bm.asarray(indptr), bm.asarray(B), + shape=shape, transpose=False) + np.testing.assert_allclose(np.asarray(out), dense @ B, rtol=1e-5, atol=1e-5) + + +# =========================================================================== +# H-17 — coomv works without the removed brainevent.COO, equals dense @ v +# =========================================================================== + +def test_coomv_matches_dense_no_attribute_error(): + dense, _, _, _, shape = _known_csr() + rows, cols = np.nonzero(dense) + vals = dense[rows, cols].astype(np.float32) + row = jnp.asarray(rows, dtype=jnp.int32) + col = jnp.asarray(cols, dtype=jnp.int32) + data = jnp.asarray(vals, dtype=jnp.float32) + v = np.arange(4, dtype=np.float32) + # Must not raise AttributeError about a removed COO type. + out = bm.sparse.coomv(data, row, col, jnp.asarray(v), shape=shape, transpose=False) + np.testing.assert_allclose(np.asarray(out), dense @ v, rtol=1e-5, atol=1e-5) + + +def test_coomv_transpose_matches_dense_T(): + dense, _, _, _, shape = _known_csr() + rows, cols = np.nonzero(dense) + data = jnp.asarray(dense[rows, cols].astype(np.float32)) + row = jnp.asarray(rows, dtype=jnp.int32) + col = jnp.asarray(cols, dtype=jnp.int32) + v = np.arange(3, dtype=np.float32) + out = bm.sparse.coomv(data, row, col, jnp.asarray(v), shape=shape, transpose=True) + np.testing.assert_allclose(np.asarray(out), dense.T @ v, rtol=1e-5, atol=1e-5) + + +def test_coomv_scalar_weight_broadcast(): + # Exercise the scalar-data broadcast branch of coomv. + dense, _, _, _, shape = _known_csr() + rows, cols = np.nonzero(dense) + row = jnp.asarray(rows, dtype=jnp.int32) + col = jnp.asarray(cols, dtype=jnp.int32) + v = np.ones(4, dtype=np.float32) + out = bm.sparse.coomv(2.0, row, col, jnp.asarray(v), shape=shape, transpose=False) + # mask of the structure, weight 2.0 everywhere + mask = (dense != 0).astype(np.float32) + np.testing.assert_allclose(np.asarray(out), (2.0 * mask) @ v, rtol=1e-5, atol=1e-5) + + +def test_coomv_accepts_brainpy_array_operands(): + # Exercise the ``isinstance(x, Array)`` unwrap branches in coo_mv.py. + dense, _, _, _, shape = _known_csr() + rows, cols = np.nonzero(dense) + data = bm.asarray(dense[rows, cols].astype(np.float32)) + row = bm.asarray(jnp.asarray(rows, dtype=jnp.int32)) + col = bm.asarray(jnp.asarray(cols, dtype=jnp.int32)) + v = bm.asarray(np.arange(4, dtype=np.float32)) + out = bm.sparse.coomv(data, row, col, v, shape=shape, transpose=False) + np.testing.assert_allclose(np.asarray(out), dense @ np.arange(4, dtype=np.float32), + rtol=1e-5, atol=1e-5) + + +# =========================================================================== +# H-18 / H-19 — coo_to_csr int indptr; csr_to_dense matches reference +# =========================================================================== + +def test_coo_to_csr_returns_int_monotone_indptr(): + pre_ids = jnp.asarray([0, 0, 1, 2], dtype=jnp.int32) + post_ids = jnp.asarray([0, 2, 1, 3], dtype=jnp.int32) + indices, indptr = bm.sparse.coo_to_csr(pre_ids, post_ids, num_row=3) + indptr_np = np.asarray(indptr) + # integer dtype (H-18: was float before the fix) + assert jnp.issubdtype(indptr.dtype, jnp.integer) + # monotone non-decreasing, starts at 0, ends at nnz + assert indptr_np[0] == 0 + assert indptr_np[-1] == int(post_ids.shape[0]) + assert np.all(np.diff(indptr_np) >= 0) + # round-trip sanity: indices length == nnz + assert np.asarray(indices).shape[0] == int(post_ids.shape[0]) + + +def test_csr_to_dense_matches_reference(): + dense, data, indices, indptr, shape = _known_csr() + out = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape) + np.testing.assert_allclose(np.asarray(out), dense, rtol=1e-5, atol=1e-5) + + +def test_csr_to_coo_round_trip(): + # Exercises utils.csr_to_coo for coverage. + _, _, indices, indptr, _ = _known_csr() + row, col = bm.sparse.csr_to_coo(indices, indptr) + np.testing.assert_array_equal(np.asarray(row), [0, 0, 1, 2]) + np.testing.assert_array_equal(np.asarray(col), np.asarray(indices)) + + +# =========================================================================== +# sparse / event csrmv coverage (+ value correctness) +# =========================================================================== + +def test_csrmv_matches_dense_both_directions(): + dense, data, indices, indptr, shape = _known_csr() + v = np.arange(4, dtype=np.float32) + out = bm.sparse.csrmv(data, indices, indptr, jnp.asarray(v), + shape=shape, transpose=False) + np.testing.assert_allclose(np.asarray(out), dense @ v, rtol=1e-5, atol=1e-5) + + vt = np.arange(3, dtype=np.float32) + out_t = bm.sparse.csrmv(data, indices, indptr, jnp.asarray(vt), + shape=shape, transpose=True) + np.testing.assert_allclose(np.asarray(out_t), dense.T @ vt, rtol=1e-5, atol=1e-5) + + +def test_event_csrmv_matches_masked_dense(): + dense, data, indices, indptr, shape = _known_csr() + # transpose=True: Mᵀ @ events, events length == shape[0] == 3 + events = jnp.asarray([True, False, True], dtype=bool) + out = bm.event.csrmv(data, indices, indptr, events, shape=shape, transpose=True) + ref = dense.T @ np.asarray(events, dtype=np.float32) + np.testing.assert_allclose(np.asarray(out), ref, rtol=1e-5, atol=1e-5) + + # non-transpose: M @ events, events length == shape[1] == 4 + events2 = jnp.asarray([True, False, True, False], dtype=bool) + out2 = bm.event.csrmv(data, indices, indptr, events2, shape=shape, transpose=False) + ref2 = dense @ np.asarray(events2, dtype=np.float32) + np.testing.assert_allclose(np.asarray(out2), ref2, rtol=1e-5, atol=1e-5) + + +# =========================================================================== +# H-20..H-24 — surrogate gradients consistent with their forward functions +# =========================================================================== + +# The two modules expose slightly different APIs: +# _one_input.py : surrogate_grad(self, dz, x) (dz = upstream gradient) +# _one_input_new.py : surrogate_grad(self, x) +# These helpers normalise that so the same assertions run against both. + +def _grad_old(inst, x): + return inst.surrogate_grad(1.0, x) + + +def _grad_new(inst, x): + return inst.surrogate_grad(x) + + +_MODULES = [ + ("_one_input", old_surr, _grad_old), + ("_one_input_new", new_surr, _grad_new), +] + +# Surrogates that expose BOTH surrogate_fun and surrogate_grad and are +# differentiable away from kinks — used for the grad-vs-autograd check. +_HAS_FUN = ["PiecewiseQuadratic", "QPseudoSpike", "Arctan", "ERF"] + + +@pytest.mark.parametrize("modname,mod,getgrad", _MODULES, + ids=[m[0] for m in _MODULES]) +@pytest.mark.parametrize("clsname", _HAS_FUN) +def test_surrogate_grad_matches_autograd(modname, mod, getgrad, clsname): + """H-21/H-22/H-23/H-24: surrogate_grad == d/dx surrogate_fun.""" + cls = getattr(mod, clsname) + inst = cls() + # Avoid the exact kinks (|x| = 1/alpha) where the piecewise derivative jumps. + xs = jnp.asarray([-0.85, -0.4, -0.1, 0.1, 0.4, 0.85], dtype=jnp.float32) + analytic = np.asarray(getgrad(inst, xs)) + + def fun_scalar(v): + return jnp.squeeze(inst.surrogate_fun(jnp.reshape(v, (1,)))) + + autograd = np.asarray(jax.vmap(jax.grad(fun_scalar))(xs)) + np.testing.assert_allclose(analytic, autograd, rtol=2e-3, atol=2e-4) + + +@pytest.mark.parametrize("modname,mod,getgrad", _MODULES, + ids=[m[0] for m in _MODULES]) +@pytest.mark.parametrize("clsname", _HAS_FUN) +def test_surrogate_fun_monotone_increasing_on_unit_interval(modname, mod, getgrad, clsname): + """Each surrogate forward (origin) function is non-decreasing on [0, 1].""" + cls = getattr(mod, clsname) + inst = cls() + fx = np.asarray(inst.surrogate_fun(jnp.linspace(0.0, 1.0, 21))) + assert np.all(np.diff(fx) >= -1e-6) + + +@pytest.mark.parametrize("modname,mod,getgrad", _MODULES, + ids=[m[0] for m in _MODULES]) +def test_arctan_surrogate_fun_does_not_raise(modname, mod, getgrad): + """H-23: Arctan.surrogate_fun previously called jnp.arctan2 with one arg.""" + inst = mod.Arctan() + out = np.asarray(inst.surrogate_fun(jnp.asarray([-0.5, -0.1, 0.0, 0.1, 0.5]))) + assert np.all(np.isfinite(out)) + # arctan forward crosses 0.5 at x = 0 + assert np.isclose(out[2], 0.5, atol=1e-6) + assert np.all(np.diff(out) > 0) + + +@pytest.mark.parametrize("modname,mod,getgrad", _MODULES, + ids=[m[0] for m in _MODULES]) +def test_erf_surrogate_fun_is_increasing(modname, mod, getgrad): + """H-24: ERF.surrogate_fun must be increasing (was decreasing before).""" + inst = mod.ERF() + out = np.asarray(inst.surrogate_fun(jnp.linspace(-0.5, 0.5, 11))) + assert np.all(np.diff(out) > 0) + assert np.isclose(out[5], 0.5, atol=1e-6) # centred at x = 0 + + +@pytest.mark.parametrize("modname,mod,getgrad", _MODULES, + ids=[m[0] for m in _MODULES]) +def test_gaussian_grad_bump_widens_with_sigma(modname, mod, getgrad): + """H-20: GaussianGrad — at x=1 the gradient must INCREASE with sigma + (a wider bump), proving the sigma is no longer inverted by the + operator-precedence bug ``exp(-(x**2)/2*sigma**2)``.""" + g_narrow = float(np.asarray(getgrad(mod.GaussianGrad(sigma=0.5), jnp.asarray(1.0)))) + g_wide = float(np.asarray(getgrad(mod.GaussianGrad(sigma=2.0), jnp.asarray(1.0)))) + assert g_wide > g_narrow + # Sanity on the intended magnitude (audit: grad@±1 ≈ 0.088 for sigma=2). + assert g_wide == pytest.approx(0.088, abs=2e-2) + + +@pytest.mark.parametrize("modname,mod,getgrad", _MODULES, + ids=[m[0] for m in _MODULES]) +def test_piecewise_quadratic_grad_formula(modname, mod, getgrad): + """H-21: grad == -alpha**2 |x| + alpha inside the support, 0 outside.""" + inst = mod.PiecewiseQuadratic(alpha=1.0) + g_in = float(np.asarray(getgrad(inst, jnp.asarray(0.5)))) + assert g_in == pytest.approx(-1.0 * 0.5 + 1.0) # = 0.5 + g_out = float(np.asarray(getgrad(inst, jnp.asarray(5.0)))) + assert g_out == pytest.approx(0.0) + + +@pytest.mark.parametrize("modname,mod,getgrad", _MODULES, + ids=[m[0] for m in _MODULES]) +def test_qpseudospike_grad_uses_alpha_minus_one(modname, mod, getgrad): + """H-22: grad denominator uses (alpha-1); grad at 0 == 1.""" + inst = mod.QPseudoSpike(alpha=2.0) + g0 = float(np.asarray(getgrad(inst, jnp.asarray(0.0)))) + assert g0 == pytest.approx(1.0, abs=1e-6) + + +# =========================================================================== +# Surrogate coverage — every class's __call__ + surrogate_grad in both modules +# =========================================================================== + +def _new_surrogate_classes(): + return [getattr(new_surr, n) for n in new_surr.__all__ + if n[0].isupper() and n != "Surrogate"] + + +def _old_surrogate_classes(): + out = [] + for n in dir(old_surr): + obj = getattr(old_surr, n) + if (isinstance(obj, type) and issubclass(obj, old_surr.Surrogate) + and n not in ("Surrogate", "_OneInpSurrogate")): + out.append(obj) + return out + + +@pytest.mark.parametrize("cls", _new_surrogate_classes(), + ids=lambda c: c.__name__) +def test_new_surrogate_call_and_grad_run(cls): + inst = cls() + x = jnp.linspace(-1.5, 1.5, 9) + y = inst(x) # __call__ -> heaviside forward + assert np.asarray(y).shape == (9,) + # forward is a {0,1} spike indicator + assert set(np.unique(np.asarray(y)).tolist()).issubset({0.0, 1.0}) + g = np.asarray(inst.surrogate_grad(x)) # surrogate_grad(x) + assert g.shape == (9,) and np.all(np.isfinite(g)) + # grad flows through __call__ + flow = jax.grad(lambda v: jnp.sum(inst(v)))(x) + assert np.all(np.isfinite(np.asarray(flow))) + + +@pytest.mark.parametrize("cls", _old_surrogate_classes(), + ids=lambda c: c.__name__) +def test_old_surrogate_call_and_grad_run(cls): + inst = cls() + x = jnp.linspace(-1.5, 1.5, 9) + y = inst(x) # custom-gradient forward + assert np.asarray(y).shape == (9,) + assert set(np.unique(np.asarray(y)).tolist()).issubset({0.0, 1.0}) + g = np.asarray(inst.surrogate_grad(1.0, x)) # surrogate_grad(dz, x) + assert g.shape == (9,) and np.all(np.isfinite(g)) + flow = jax.grad(lambda v: jnp.sum(inst(v)))(x) + assert np.all(np.isfinite(np.asarray(flow))) + + +def test_new_surrogate_repr_and_functional_aliases(): + # Exercise functional (lowercase) entry points + __repr__ for coverage. + x = jnp.linspace(-1.0, 1.0, 5) + assert "Arctan" in repr(new_surr.Arctan()) + for fn in (new_surr.sigmoid, new_surr.arctan, new_surr.erf, + new_surr.gaussian_grad, new_surr.relu_grad): + assert np.asarray(fn(x)).shape == (5,) + + +def test_old_surrogate_repr_and_functional_aliases(): + x = jnp.linspace(-1.0, 1.0, 5) + assert "GaussianGrad" in repr(old_surr.GaussianGrad()) + for fn in (old_surr.sigmoid, old_surr.arctan, old_surr.erf, + old_surr.gaussian_grad, old_surr.q_pseudo_spike): + assert np.asarray(fn(x)).shape == (5,) + + +# Lowercase functional aliases present in BOTH modules. +_FUNC_NAMES = [n for n in new_surr.__all__ if n[0].islower()] + + +@pytest.mark.parametrize("fname", _FUNC_NAMES) +def test_old_functional_alias_forward_and_origin(fname): + """Exercise every ``_one_input`` functional alias (heaviside forward and, + where supported, the ``origin=True`` smooth forward).""" + import inspect + fn = getattr(old_surr, fname) + x = jnp.linspace(-1.2, 1.2, 7) + y = np.asarray(fn(x)) + assert y.shape == (7,) and np.all(np.isfinite(y)) + if "origin" in inspect.signature(fn).parameters: + yo = np.asarray(fn(x, origin=True)) # exercises surrogate_fun + assert yo.shape == (7,) and np.all(np.isfinite(yo)) + + +@pytest.mark.parametrize("fname", _FUNC_NAMES) +def test_new_functional_alias_forward(fname): + """Exercise every ``_one_input_new`` functional alias (heaviside forward).""" + fn = getattr(new_surr, fname) + x = jnp.linspace(-1.2, 1.2, 7) + y = np.asarray(fn(x)) + assert y.shape == (7,) and np.all(np.isfinite(y)) + + +def _new_classes_with_surrogate_fun(): + out = [] + for n in new_surr.__all__: + if not (n[0].isupper() and n != "Surrogate"): + continue + c = getattr(new_surr, n) + if c.surrogate_fun is not new_surr.Surrogate.surrogate_fun: + out.append(c) + return out + + +@pytest.mark.parametrize("cls", _new_classes_with_surrogate_fun(), + ids=lambda c: c.__name__) +def test_new_surrogate_fun_runs(cls): + """Cover the ``surrogate_fun`` body of every new-module class that has one.""" + out = np.asarray(cls().surrogate_fun(jnp.linspace(-1.2, 1.2, 9))) + assert out.shape == (9,) and np.all(np.isfinite(out)) + + +# =========================================================================== +# C-09 — TimeDelay ring-buffer read applies the modulo +# =========================================================================== + +def test_time_delay_ring_buffer_modulo(): + """After pushing a ramp k*0.1 for k=1..30, the buffer wraps several times; + the read must apply ``% num_delay_step`` (C-09) so d(now) == 3.0.""" + d = bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1) + for k in range(1, 31): + d.update(bm.ones(1) * (k * 0.1)) + now = d.current_time[0] + assert float(np.asarray(d(now))[0]) == pytest.approx(3.0, abs=1e-4) + assert float(np.asarray(d(now - 0.5))[0]) == pytest.approx(2.5, abs=1e-4) + + +def test_time_delay_round_interp_method(): + d = bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1, interp_method='round') + for k in range(1, 31): + d.update(bm.ones(1) * (k * 0.1)) + now = d.current_time[0] + # round interpolation hits the exact-step branch too + assert float(np.asarray(d(now))[0]) == pytest.approx(3.0, abs=1e-4) + + +def test_time_delay_reset(): + d = bm.TimeDelay(bm.zeros(2), delay_len=1.0, dt=0.1) + for k in range(1, 12): + d.update(bm.ones(2) * (k * 0.1)) + d.reset(bm.zeros(2), delay_len=1.0) + now = d.current_time[0] + np.testing.assert_allclose(np.asarray(d(now)), np.zeros(2), atol=1e-6) + + +# =========================================================================== +# LengthDelay — both ROTATE_UPDATE and CONCAT_UPDATE +# =========================================================================== + +@pytest.mark.parametrize("method", [bm.ROTATE_UPDATE, bm.CONCAT_UPDATE]) +def test_length_delay_update_methods(method): + ld = bm.LengthDelay(bm.zeros(2), delay_len=3, update_method=method) + for k in range(1, 6): + ld.update(bm.ones(2) * float(k)) + # most-recent push is 5 -> delay 0 returns 5; delay 2 returns 3 + np.testing.assert_allclose(np.asarray(ld(0)), [5., 5.], atol=1e-6) + np.testing.assert_allclose(np.asarray(ld(2)), [3., 3.], atol=1e-6) + + +def test_length_delay_reset_and_retrieve(): + ld = bm.LengthDelay(bm.zeros(2), delay_len=3) + ld.reset(bm.ones(2), delay_len=3) + out = ld.retrieve(1) + assert np.asarray(out).shape == (2,) + + +def test_length_delay_initial_delay_data_scalar_and_callable(): + # scalar initial_delay_data branch + ld = bm.LengthDelay(bm.zeros(2), delay_len=3, initial_delay_data=1.0) + np.testing.assert_allclose(np.asarray(ld.retrieve(2)), [1., 1.], atol=1e-6) + # callable initial_delay_data branch (plain lambda, no dtype kwarg) + ld2 = bm.LengthDelay(bm.zeros(2), delay_len=3, + initial_delay_data=lambda shape: jnp.ones(shape) * 7.0) + np.testing.assert_allclose(np.asarray(ld2.retrieve(2)), [7., 7.], atol=1e-6) + + +def test_length_delay_concat_single_step(): + # delay_len=0 -> num_delay_step=1 exercises the CONCAT_UPDATE short branch. + ld = bm.LengthDelay(bm.zeros(2), delay_len=0, update_method=bm.CONCAT_UPDATE) + ld.update(bm.ones(2) * 3.0) + np.testing.assert_allclose(np.asarray(ld(0)), [3., 3.], atol=1e-6) + + +def test_time_delay_callable_before_t0(): + # Covers the _FUNC_BEFORE path (callable before_t0 + cond branch in __call__). + d = bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1, t0=0.0, + before_t0=lambda t: jnp.ones(1) * 9.0) + # request a time strictly before t0 -> uses before_t0 function + out = np.asarray(d(-0.5)) + np.testing.assert_allclose(out, [9.0], atol=1e-6) + + +def test_neutral_delay_aliases(): + # NeuTimeDelay / NeuLenDelay are thin aliases; just instantiate + call. + ntd = bm.NeuTimeDelay(bm.zeros(1), delay_len=0.5, dt=0.1) + ntd.update(bm.ones(1)) + assert np.asarray(ntd(ntd.current_time[0])).shape == (1,) + nld = bm.NeuLenDelay(bm.zeros(1), delay_len=2) + nld.update(bm.ones(1)) + assert np.asarray(nld(0)).shape == (1,) + + +def test_time_delay_array_before_t0_and_indices(): + # array before_t0 fills the pre-t0 buffer (the _DATA_BEFORE branch); + # indices select a sub-slice of the retrieved value. + d = bm.TimeDelay(bm.zeros(3), delay_len=1.0, dt=0.1, before_t0=5.0) + for k in range(1, 31): + d.update(bm.ones(3) * (k * 0.1)) + now = d.current_time[0] + out = np.asarray(d(now, indices=jnp.asarray([0, 2]))) + assert out.shape == (2,) + + +def test_time_delay_reset_with_callable_before_t0(): + d = bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1) + d.reset(bm.zeros(1), delay_len=1.0, before_t0=lambda t: jnp.ones(1) * 4.0) + out = np.asarray(d(-0.5)) # before t0 -> uses callable + np.testing.assert_allclose(out, [4.0], atol=1e-6) + + +def test_time_delay_reset_with_array_before_t0(): + d = bm.TimeDelay(bm.zeros(2), delay_len=1.0, dt=0.1) + d.reset(bm.zeros(2), delay_len=1.0, before_t0=3.0) + now = d.current_time[0] + assert np.asarray(d(now)).shape == (2,) + + +def test_length_delay_retrieve_with_indices_and_repr(): + ld = bm.LengthDelay(bm.zeros(4), delay_len=3) + for k in range(1, 5): + ld.update(bm.ones(4) * float(k)) + out = np.asarray(ld.retrieve(1, jnp.asarray([0, 1]))) + assert out.shape == (2,) + assert "LengthDelay" in repr(ld) + assert ld.delay_shape[0] == 4 # num_delay_step (delay_len + 1) + + +def test_length_delay_update_from_variable_target(): + # update(value=None) pulls from the stored delay_target Variable. + target = bm.Variable(bm.ones(2) * 2.0) + ld = bm.LengthDelay(target, delay_len=2) + ld.update() # no explicit value -> uses delay_target.value + assert np.asarray(ld(0)).shape == (2,) + + +def test_time_delay_validation_errors(): + # invalid delay_target type + with pytest.raises(ValueError): + bm.TimeDelay([0.0], delay_len=1.0, dt=0.1) + # unsupported interpolation method + from brainpy._errors import UnsupportedError + with pytest.raises(UnsupportedError): + bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1, interp_method='nope') + # unsupported before_t0 type + with pytest.raises(ValueError): + bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1, before_t0='bad') + + +def test_length_delay_validation_errors(): + with pytest.raises(ValueError): + bm.LengthDelay([0.0], delay_len=2) + ld = bm.LengthDelay(bm.zeros(2), delay_len=2) + with pytest.raises(ValueError): + ld.reset(bm.zeros(2), delay_len=2, initial_delay_data='bad') + + +# =========================================================================== +# M-13 — jitconn mv_prob_* / event_mv_prob_* reproducible with explicit seed +# =========================================================================== + +_SHAPE = (8, 6) + + +def _vec_for(transpose): + n = _SHAPE[0] if transpose else _SHAPE[1] + return jnp.asarray(np.random.RandomState(0).randn(n).astype(np.float32)) + + +@pytest.mark.parametrize("transpose", [False, True]) +@pytest.mark.parametrize("outdim_parallel", [True, False]) +def test_mv_prob_homo_reproducible(transpose, outdim_parallel): + v = _vec_for(transpose) + kw = dict(weight=1.5, conn_prob=0.3, shape=_SHAPE, + transpose=transpose, outdim_parallel=outdim_parallel) + o1 = np.asarray(bm.jitconn.mv_prob_homo(v, seed=123, **kw)) + o2 = np.asarray(bm.jitconn.mv_prob_homo(v, seed=123, **kw)) + np.testing.assert_array_equal(o1, o2) # reproducible + assert o1.shape == (_SHAPE[1] if transpose else _SHAPE[0],) + + +def test_mv_prob_uniform_and_normal_reproducible(): + v = _vec_for(False) + ou1 = np.asarray(bm.jitconn.mv_prob_uniform(v, w_low=-1., w_high=1., + conn_prob=0.3, seed=7, shape=_SHAPE)) + ou2 = np.asarray(bm.jitconn.mv_prob_uniform(v, w_low=-1., w_high=1., + conn_prob=0.3, seed=7, shape=_SHAPE)) + np.testing.assert_array_equal(ou1, ou2) + + on1 = np.asarray(bm.jitconn.mv_prob_normal(v, w_mu=0., w_sigma=1., + conn_prob=0.3, seed=9, shape=_SHAPE)) + on2 = np.asarray(bm.jitconn.mv_prob_normal(v, w_mu=0., w_sigma=1., + conn_prob=0.3, seed=9, shape=_SHAPE)) + np.testing.assert_array_equal(on1, on2) + + +def test_mv_prob_homo_seed_none_runs(): + # Cover the ``seed is None`` host-RNG branch (documented as non-reproducible). + v = _vec_for(False) + out = np.asarray(bm.jitconn.mv_prob_homo(v, weight=1.0, conn_prob=0.3, + seed=None, shape=_SHAPE)) + assert out.shape == (_SHAPE[0],) and np.all(np.isfinite(out)) + + +@pytest.mark.parametrize("fn", [ + lambda ev, **k: bm.jitconn.event_mv_prob_homo(ev, 1.0, 0.3, **k), + lambda ev, **k: bm.jitconn.event_mv_prob_uniform(ev, -1.0, 1.0, 0.3, **k), + lambda ev, **k: bm.jitconn.event_mv_prob_normal(ev, 0.0, 1.0, 0.3, **k), +], ids=["homo", "uniform", "normal"]) +def test_event_mv_prob_reproducible(fn): + events = jnp.asarray(np.random.RandomState(1).rand(_SHAPE[1]) > 0.5) + o1 = np.asarray(fn(events, seed=11, shape=_SHAPE)) + o2 = np.asarray(fn(events, seed=11, shape=_SHAPE)) + np.testing.assert_array_equal(o1, o2) + assert o1.shape == (_SHAPE[0],) + + +def test_mv_prob_uniform_normal_transpose_and_array_args(): + # transpose=True branches + Array-operand unwrap branches for uniform/normal. + vrow = bm.asarray(np.random.RandomState(3).randn(_SHAPE[0]).astype(np.float32)) + ou = np.asarray(bm.jitconn.mv_prob_uniform(vrow, bm.asarray(-1.0), bm.asarray(1.0), + 0.3, seed=4, shape=_SHAPE, transpose=True)) + assert ou.shape == (_SHAPE[1],) and np.all(np.isfinite(ou)) + on = np.asarray(bm.jitconn.mv_prob_normal(vrow, bm.asarray(0.0), bm.asarray(1.0), + 0.3, seed=4, shape=_SHAPE, transpose=True)) + assert on.shape == (_SHAPE[1],) and np.all(np.isfinite(on)) + # homo with Array weight + Array vector + oh = np.asarray(bm.jitconn.mv_prob_homo(vrow, bm.asarray(1.0), 0.3, + seed=4, shape=_SHAPE, transpose=True)) + assert oh.shape == (_SHAPE[1],) + + +def test_get_weight_matrices(): + mh = np.asarray(bm.jitconn.get_homo_weight_matrix(1.0, 0.3, seed=1, shape=_SHAPE)) + mu = np.asarray(bm.jitconn.get_uniform_weight_matrix(-1., 1., 0.3, seed=1, shape=_SHAPE)) + mn = np.asarray(bm.jitconn.get_normal_weight_matrix(0., 1., 0.3, seed=1, shape=_SHAPE)) + for m in (mh, mu, mn): + assert m.shape == _SHAPE + + +def test_get_weight_matrices_transpose_and_seed_none(): + # transpose=True -> (cols, rows); seed=None branch; Array args unwrap. + mh = np.asarray(bm.jitconn.get_homo_weight_matrix(bm.asarray(1.0), 0.3, + seed=None, shape=_SHAPE, transpose=True)) + assert mh.shape == (_SHAPE[1], _SHAPE[0]) + mu = np.asarray(bm.jitconn.get_uniform_weight_matrix(bm.asarray(-1.0), bm.asarray(1.0), 0.3, + seed=None, shape=_SHAPE, transpose=True)) + assert mu.shape == (_SHAPE[1], _SHAPE[0]) + mn = np.asarray(bm.jitconn.get_normal_weight_matrix(bm.asarray(0.0), bm.asarray(1.0), 0.3, + seed=None, shape=_SHAPE, transpose=True)) + assert mn.shape == (_SHAPE[1], _SHAPE[0]) + + +# =========================================================================== +# M-15 — pre2post_mean / syn2post_mean / syn2post_softmax edge cases +# =========================================================================== + +def test_syn2post_mean_empty_group_is_zero(): + syn = jnp.asarray([1., 3., 5.]) + post_ids = jnp.asarray([0, 0, 2]) # post group 1 is empty + out = np.asarray(bm.syn2post_mean(syn, post_ids, 3)) + assert out[0] == pytest.approx(2.0) # mean(1, 3) + assert out[1] == pytest.approx(0.0) # empty -> 0, not NaN + assert out[2] == pytest.approx(5.0) + + +def test_syn2post_mean_propagates_genuine_nan(): + syn = jnp.asarray([1., np.nan, 5.]) + post_ids = jnp.asarray([0, 0, 2]) + out = np.asarray(bm.syn2post_mean(syn, post_ids, 3)) + assert np.isnan(out[0]) # genuine NaN must propagate + assert out[1] == pytest.approx(0.0) + assert out[2] == pytest.approx(5.0) + + +def test_syn2post_softmax_propagates_nan_and_normalizes(): + post_ids = jnp.asarray([0, 0, 2]) + # genuine NaN must not be silently zeroed + syn_nan = jnp.asarray([1., np.nan, 5.]) + out_nan = np.asarray(bm.syn2post_softmax(syn_nan, post_ids, 3)) + assert np.isnan(out_nan[0]) and np.isnan(out_nan[1]) + # clean input: each non-empty group's softmax weights sum to 1 + syn = jnp.asarray([1., 3., 5.]) + out = np.asarray(bm.syn2post_softmax(syn, post_ids, 3)) + assert out[0] + out[1] == pytest.approx(1.0, abs=1e-6) + assert out[2] == pytest.approx(1.0, abs=1e-6) + + +def test_pre2post_mean_scalar_and_vector_branches(): + post_ids = jnp.asarray([0, 0, 2]) + # scalar branch: constant broadcast to targeted posts, others 0 + pm = np.asarray(bm.pre2post_mean(2.0, 3, post_ids)) + np.testing.assert_allclose(pm, [2., 0., 2.], atol=1e-6) + # vector branch routes through syn2post_mean + pre_vals = jnp.asarray([10., 20., 30.]) + pre_ids = jnp.asarray([0, 1, 2]) + pmv = np.asarray(bm.pre2post_mean(pre_vals, 3, post_ids, pre_ids)) + # post 0 gets mean(pre[0], pre[1]) = 15, post 2 gets pre[2] = 30 + np.testing.assert_allclose(pmv, [15., 0., 30.], atol=1e-6) + + +def test_pre2post_reductions_and_pre2syn(): + post_ids = jnp.asarray([0, 0, 2]) + assert np.asarray(bm.pre2post_sum(2.0, 3, post_ids)).tolist() == [4., 0., 2.] + assert np.asarray(bm.pre2post_prod(2.0, 3, post_ids)).tolist() == [0., 0., 0.] + assert np.asarray(bm.pre2post_max(2.0, 3, post_ids)).tolist() == [2., 0., 2.] + assert np.asarray(bm.pre2post_min(2.0, 3, post_ids)).tolist() == [0., 0., 0.] + syn = bm.pre2syn(jnp.asarray([1., 2., 3.]), jnp.asarray([0, 2])) + np.testing.assert_allclose(np.asarray(syn), [1., 3.], atol=1e-6) + + +def test_syn2post_reductions(): + syn = jnp.asarray([1., 3., 5.]) + post_ids = jnp.asarray([0, 0, 2]) + np.testing.assert_allclose(np.asarray(bm.syn2post_sum(syn, post_ids, 3)), + [4., 0., 5.], atol=1e-6) + np.testing.assert_allclose(np.asarray(bm.syn2post_prod(syn, post_ids, 3)), + [3., 1., 5.], atol=1e-6) + # max of empty group is -inf, min is +inf (segment reduction identities) + assert np.asarray(bm.syn2post_max(syn, post_ids, 3))[0] == 3. + assert np.asarray(bm.syn2post_min(syn, post_ids, 3))[0] == 1. + + +def test_pre2post_reductions_vector_branch(): + # Vector pre_values + pre_ids exercises the heterogeneous gather branch. + pre_vals = jnp.asarray([1., 2., 3., 4.]) + pre_ids = jnp.asarray([0, 1, 2, 3]) + post_ids = jnp.asarray([0, 0, 1, 1]) + post_num = 2 + assert np.asarray(bm.pre2post_sum(pre_vals, post_num, post_ids, pre_ids)).tolist() == [3., 7.] + # prod / min accumulate against the zero-initialised output -> 0 here. + assert np.asarray(bm.pre2post_prod(pre_vals, post_num, post_ids, pre_ids)).tolist() == [0., 0.] + assert np.asarray(bm.pre2post_max(pre_vals, post_num, post_ids, pre_ids)).tolist() == [2., 4.] + assert np.asarray(bm.pre2post_min(pre_vals, post_num, post_ids, pre_ids)).tolist() == [0., 0.] + + +def test_pre2post_vector_without_pre_ids_raises(): + # The _raise_pre_ids_is_none guard fires for heterogeneous values w/o pre_ids. + from brainpy._errors import MathError + pre_vals = jnp.asarray([1., 2., 3.]) + post_ids = jnp.asarray([0, 1, 1]) + with pytest.raises(MathError): + bm.pre2post_sum(pre_vals, 2, post_ids) + + +def test_syn2post_bool_dtype_promotion(): + # bool syn_values -> promoted to int in every syn2post reduction. + # group 0 = {True, False} -> {1, 0}; group 1 = {True} -> {1} + syn = jnp.asarray([True, False, True], dtype=bool) + post_ids = jnp.asarray([0, 0, 1]) + assert np.asarray(bm.syn2post_sum(syn, post_ids, 2)).tolist() == [1, 1] + assert np.asarray(bm.syn2post_prod(syn, post_ids, 2)).tolist() == [0, 1] + assert np.asarray(bm.syn2post_max(syn, post_ids, 2)).tolist() == [1, 1] + assert np.asarray(bm.syn2post_min(syn, post_ids, 2)).tolist() == [0, 1] + np.testing.assert_allclose(np.asarray(bm.syn2post_mean(syn, post_ids, 2)), [0.5, 1.0], atol=1e-6) + sm = np.asarray(bm.syn2post_softmax(syn, post_ids, 2)) + assert np.all(np.isfinite(sm)) + + +def test_pre2post_event_sum(): + # CSR connectivity: pre 0 -> post {0,2}, pre 1 -> post {1}, pre 2 -> post {3} + indices = jnp.asarray([0, 2, 1, 3], dtype=jnp.int32) + indptr = jnp.asarray([0, 2, 3, 4], dtype=jnp.int32) + events = jnp.asarray([True, False, True], dtype=bool) + out = np.asarray(bm.pre2post_event_sum(events, (indices, indptr), 4, 1.0)) + # pre 0 fires -> +1 at posts 0 and 2; pre 2 fires -> +1 at post 3 + np.testing.assert_allclose(out, [1., 0., 1., 1.], atol=1e-6) diff --git a/tests/audit/test_object_transform_fixes.py b/tests/audit/test_object_transform_fixes.py new file mode 100644 index 000000000..9e690216d --- /dev/null +++ b/tests/audit/test_object_transform_fixes.py @@ -0,0 +1,1153 @@ +# -*- coding: utf-8 -*- +"""Regression + coverage tests for the BrainPy v2.7.8 object-transform audit +(see ``docs/issues-found-20260618.md``). + +This module exercises the fixes recorded in the audit for the +``brainpy/math/object_transform`` package: + +* ``jit.py`` — H-01 (``cls_jit`` no longer corrupts NEGATIVE + ``static_argnums``/``donate_argnums`` by shifting them, which previously + double-marked ``self`` static), M-02 (``donate_argnums`` shifted by +1 so + ``self`` is never donated), H-04 (``bm.jit(fn, dyn_vars=..., child_objs=...)`` + pops the legacy kwargs with a ``DeprecationWarning`` instead of forwarding + them into brainstate and raising ``TypeError``). +* ``controls.py`` — H-02 (``cond``/``for_loop``/``scan``/``while_loop`` + accept a ``Variable``/``Array`` in ``operands``), H-03 (``for_loop(jit=False)`` + with a zero-length pytree operand returns ``[]`` instead of crashing), + M-03 (``scan`` returns ``(carry, ys)``), M-05 (``ifelse`` builds mutually + exclusive conditions), M-06 (``while_loop`` body returning ``None`` raises). +* ``function.py`` — ``Partial``/``to_object`` behaviour and L-04 (``function`` + emits a ``DeprecationWarning``). +* ``_utils.py`` — ``warp_to_no_state_input_output`` strips/restores states. +* ``variables.py`` — C-25 (``var_dict`` round-trips through ``jax.jit``), + C-26 (``Variable`` keeps ``batch_axis``/``axis_names`` through + flatten/unflatten, ``jit``, ``grad``, ``vmap``), H-06 (``Variable.value`` + setter accepts a ``brainstate.State`` and a float64 numpy array and + canonicalizes), H-45 (``size_without_batch`` returns a shape tuple). +* ``base.py`` — H-05 (``.cpu()`` moves variables and injects no junk + attributes), H-08 (``register_implicit_vars`` accepts ``var_list``/``var_dict``). +* ``naming.py`` — H-07 (creating + discarding many named objects does not + raise ``UniqueNameError`` and the ``_name2id`` registry stays bounded). + +All tests assert the CORRECT post-fix behaviour. They use tiny array sizes so +the whole module runs in a few seconds. +""" + +import gc +import importlib +import warnings + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import brainstate + +import brainpy as bp +import brainpy.math as bm +from brainpy._errors import MathError, UniqueNameError + +# NOTE: ``from brainpy.math.object_transform import jit`` would bind the *jit +# function* (re-exported in the package ``__init__``), not the submodule. Import +# the submodule explicitly so we can monkeypatch its module-level ``jit`` that +# ``cls_jit`` calls into. +jit_module = importlib.import_module('brainpy.math.object_transform.jit') +from brainpy.math.object_transform import naming +from brainpy.math.object_transform.collectors import ArrayCollector +from brainpy.math.object_transform.variables import ( + Variable, TrainVar, Parameter, VariableView, VarList, VarDict, +) +from brainpy.math.object_transform.base import BrainPyObject, FunAsObject +from brainpy.math.object_transform._utils import ( + warp_to_no_state_input_output, infer_dyn_vars, get_brainpy_object, +) + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +class _Obj(bp.BrainPyObject): + """A tiny BrainPyObject owning a single Variable.""" + + def __init__(self, n=3): + super().__init__() + self.v = bm.Variable(jnp.ones(n)) + + +def _capture_cls_jit_argnums(static_argnums=None, donate_argnums=None): + """Apply ``cls_jit`` and capture the ``static_argnums``/``donate_argnums`` + that it forwards into the underlying ``jit`` call, *without* actually + invoking ``brainstate.transform.jit`` (which rejects negative argnums). + """ + captured = {} + orig_jit = jit_module.jit + + def spy(*args, **kwargs): + captured['static_argnums'] = kwargs.get('static_argnums') + captured['donate_argnums'] = kwargs.get('donate_argnums') + raise _Stop() + + class _Stop(Exception): + pass + + jit_module.jit = spy + try: + try: + jit_module.cls_jit( + static_argnums=static_argnums, + donate_argnums=donate_argnums, + )(lambda self, *a, **k: a) + except _Stop: + pass + finally: + jit_module.jit = orig_jit + return captured['static_argnums'], captured['donate_argnums'] + + +# =========================================================================== +# jit.py +# =========================================================================== + +def test_cls_jit_positive_static_argnums_shifted_once(): + """H-01: a positive user index N is shifted to N+1 (account for ``self``), + and ``self`` (index 0) is marked static exactly once.""" + static, donate = _capture_cls_jit_argnums(static_argnums=1) + assert static == (0, 2) + assert donate == () + + +def test_cls_jit_negative_static_argnums_not_corrupted(): + """H-01: the historical bug shifted ``-1`` to ``0`` and produced + ``(0, 0)`` (``self`` marked static twice + wrong target). The fix leaves + negative indices unshifted, so ``self`` (0) appears exactly once and the + negative index is preserved -- NOT collapsed into a duplicate ``0``.""" + static, donate = _capture_cls_jit_argnums(static_argnums=-1) + assert static == (0, -1) + # the corrupting outcome would have been (0, 0); make that explicit. + assert static != (0, 0) + assert static.count(0) == 1 + + +def test_cls_jit_list_static_argnums_dedup_and_shift(): + """H-01: list of positive indices are each shifted by +1, ``self`` is + prepended, and duplicates are removed.""" + static, _ = _capture_cls_jit_argnums(static_argnums=[0, 2]) + assert static == (0, 1, 3) + + +def test_cls_jit_donate_argnums_shifted_so_self_not_donated(): + """M-02: ``donate_argnums`` is shifted by +1, so a user index 1 becomes 2 + and ``self`` (index 0) is never donated.""" + static, donate = _capture_cls_jit_argnums(static_argnums=[0, 2], donate_argnums=1) + assert donate == (2,) + static2, donate2 = _capture_cls_jit_argnums(donate_argnums=[3, 4]) + assert donate2 == (4, 5) + # default static is just (self,) + assert static2 == (0,) + + +def test_cls_jit_runs_on_bound_method_with_positive_static(): + """H-01 end-to-end: a bound method jitted with a positive ``static_argnums`` + runs and mutates the owned Variable correctly.""" + + class Prog(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.b = bm.Variable(jnp.zeros(2)) + + # user index 0 (``scale``, a hashable int) is static; after the +1 shift + # it becomes index 1 in the bound signature ``(self, scale, x)``. + @bm.cls_jit(static_argnums=0) + def run(self, scale, x): + self.b.value = self.b.value + scale * x + return self.b.value + + p = Prog() + out = p.run(2, jnp.ones(2)) + np.testing.assert_allclose(np.asarray(out), [2., 2.]) + + +def test_cls_jit_invalid_argnums_type_raises(): + with pytest.raises(ValueError): + bm.cls_jit(static_argnums=1.5)(lambda self: None) + with pytest.raises(ValueError): + bm.cls_jit(donate_argnums=1.5)(lambda self: None) + + +def test_jit_dyn_vars_child_objs_deprecation_pops_kwargs(): + """H-04: ``dyn_vars``/``child_objs`` are no longer forwarded to brainstate + (which would raise ``TypeError``); they are popped with a one-time + ``DeprecationWarning`` and the function still works.""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + fn = bm.jit(lambda x: x + 1, dyn_vars=[], child_objs=[]) + result = fn(jnp.asarray(1.0)) + assert float(result) == 2.0 + categories = [w.category for w in caught] + assert sum(issubclass(c, DeprecationWarning) for c in categories) >= 2 + + +def test_jit_on_object_method(): + """``bm.jit`` JIT-compiles a BrainPyObject bound method.""" + + class Hello(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.a = bm.Variable(jnp.asarray(10.)) + + def transform(self): + self.a.value = self.a.value * 2 + return self.a.value + + h = Hello() + jfn = bm.jit(h.transform) + assert float(jfn()) == 20.0 + assert float(jfn()) == 40.0 + + +def test_jit_decorator_form_on_pure_function(): + @bm.jit + def selu(x, alpha=1.67, lmbda=1.05): + return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) + + out = selu(jnp.asarray([1.0, -1.0])) + assert out.shape == (2,) + + +def test_jit_static_argnums_on_pure_function(): + @bm.jit(static_argnums=(1,)) + def f(x, n): + return x + n + + assert float(f(jnp.asarray(1.0), 5)) == 6.0 + + +# =========================================================================== +# controls.py +# =========================================================================== + +def test_cond_accepts_variable_operand(): + """H-02: a Variable passed in ``operands`` of ``cond`` is unwrapped to a + raw jax array and does not raise.""" + a = bm.Variable(bm.zeros(2)) + + def true_f(x): + a.value = a.value + x + return a.value + + def false_f(x): + return a.value + + bm.cond(True, true_f, false_f, bm.Variable(bm.ones(2))) + np.testing.assert_allclose(np.asarray(a.value), [1., 1.]) + + +def test_cond_accepts_array_operand_and_scalar_operand(): + out = bm.cond(True, lambda x: x + 1, lambda x: x - 1, bm.asarray([1., 2.])) + np.testing.assert_allclose(np.asarray(out), [2., 3.]) + # scalar operand (wrapped into a tuple internally) + out2 = bm.cond(False, lambda x: x + 1, lambda x: x - 1, 5.0) + assert float(out2) == 4.0 + + +def test_for_loop_accepts_variable_operand(): + """H-02: ``bm.for_loop(lambda x: x+1, bm.arange(1,5))`` works with a + BrainPy Array operand.""" + out = bm.for_loop(lambda x: x + 1, bm.arange(1, 5)) + np.testing.assert_allclose(np.asarray(out).ravel(), [2, 3, 4, 5]) + + +def test_for_loop_variable_state_accumulation(): + a = bm.Variable(bm.zeros(1)) + b = bm.Variable(bm.ones(1)) + + def body(x): + a.value += x + b.value *= x + return a.value + + hist = bm.for_loop(body, operands=bm.arange(1, 5)) + np.testing.assert_allclose(np.asarray(hist).ravel(), [1, 3, 6, 10]) + np.testing.assert_allclose(np.asarray(a.value), [10.]) + np.testing.assert_allclose(np.asarray(b.value), [24.]) + + +def test_for_loop_multiple_operands(): + a = bm.Variable(bm.zeros(1)) + + def body(x, y): + a.value += x + y + return a.value + + hist = bm.for_loop(body, operands=(bm.arange(1, 5), bm.arange(2, 6))) + assert np.asarray(hist).shape == (4, 1) + + +def test_for_loop_jit_false_normal_path(): + a = bm.Variable(bm.zeros(1)) + + def body(x): + a.value += x + return a.value + + hist = bm.for_loop(body, bm.arange(1., 4.), jit=False) + np.testing.assert_allclose(np.asarray(hist).ravel(), [1, 3, 6]) + + +def test_for_loop_jit_false_zero_length_pytree_returns_empty(): + """H-03: ``for_loop(jit=False)`` with a zero-length *pytree* (dict) operand + must not crash on ``operands[0].shape`` -- the leading length is computed + from ``jax.tree.leaves``. It falls back to JIT mode (UserWarning) and + returns an empty result.""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + out = bm.for_loop(lambda d: d['x'] + 1, {'x': bm.arange(0.)}, jit=False) + assert len(np.asarray(out)) == 0 + assert any(issubclass(w.category, UserWarning) for w in caught) + + +def test_for_loop_progress_bar_variants(): + out = bm.for_loop(lambda x: x + 1, bm.arange(1, 4), progress_bar=True) + assert np.asarray(out).shape == (3,) + out2 = bm.for_loop(lambda x: x + 1, bm.arange(1, 4), progress_bar=2) + assert np.asarray(out2).shape == (3,) + + +def test_for_loop_invalid_progress_bar_raises(): + with pytest.raises(TypeError): + bm.for_loop(lambda x: x + 1, bm.arange(1, 4), progress_bar="bad") + + +def test_scan_accepts_variable_operand_and_returns_carry_ys(): + """H-02 + M-03: ``scan`` accepts a Variable in operands and returns the + documented ``(final_carry, stacked_ys)`` two-tuple.""" + carry, ys = bm.scan(lambda c, x: (c + x, c), 0., bm.Variable(bm.arange(1., 4.))) + assert float(carry) == 6.0 + np.testing.assert_allclose(np.asarray(ys).ravel(), [0., 1., 3.]) + + +def test_while_loop_accepts_variable_operand(): + """H-02: ``while_loop`` accepts a Variable operand (it is unwrapped to a raw + jax array rather than being rejected at brainstate cache-key time). The body + returns the updated operands, preserving the single-operand tuple structure.""" + res = bm.while_loop(lambda x: (x + 1.,), lambda x: x < 3., (bm.Variable(jnp.asarray(0.)),)) + assert float(np.asarray(res[0])) == 3.0 + + +def test_while_loop_state_mutation(): + a = bm.Variable(bm.zeros(1)) + b = bm.Variable(bm.ones(1)) + + def cond_f(x, y): + return x < 6. + + def body_f(x, y): + a.value += x + b.value *= y + return x + b[0], y + 1. + + res = bm.while_loop(body_f, cond_f, operands=(1., 1.)) + assert len(res) == 2 + + +def test_while_loop_body_returning_none_raises(): + """M-06: a ``while_loop`` body that returns ``None`` would freeze the carry + and loop forever -- it must raise a clear ``ValueError`` instead.""" + + def body(x): + # returns None -> illegal + pass + + with pytest.raises(ValueError): + bm.while_loop(body, lambda x: x < 3., 0.) + + +def test_ifelse_callable_branches_mutually_exclusive(): + """M-05: ``ifelse`` resolves to the first matching branch (mutually + exclusive conditions, default branch last).""" + + def f(a): + return bm.ifelse( + conditions=[a > 10, a > 5, a > 2, a > 0], + branches=[lambda: 1, lambda: 2, lambda: 3, lambda: 4, lambda: 5], + ) + + assert int(f(11)) == 1 + assert int(f(7)) == 2 + assert int(f(1)) == 4 + assert int(f(-5)) == 5 + + +def test_ifelse_non_callable_branches(): + def f(a): + return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0], + branches=[1, 2, 3, 4, 5]) + + assert int(f(3)) == 3 + + +def test_ifelse_with_operands_and_variable(): + out = bm.ifelse( + conditions=[True], + branches=[lambda x: x + 1, lambda x: x - 1], + operands=bm.Variable(jnp.asarray(10.)), + ) + assert float(out) == 11.0 + + +# =========================================================================== +# _utils.py +# =========================================================================== + +def test_warp_to_no_state_input_output_strips_and_restores(): + """``warp_to_no_state_input_output`` removes ``State`` wrappers from inputs + and outputs, passing through plain jax arrays.""" + + def fn(x): + # the wrapper should have unwrapped the State to a jax array + assert not isinstance(x, brainstate.State) + return x * 2 + + wrapped = warp_to_no_state_input_output(fn) + out = wrapped(bm.Variable(jnp.asarray(3.))) + assert not isinstance(out, brainstate.State) + assert float(out) == 6.0 + + +def test_warp_to_no_state_passes_missing_through(): + missing = brainstate.typing.Missing() + assert warp_to_no_state_input_output(missing) is missing + + +def test_infer_dyn_vars_and_get_brainpy_object(): + obj = _Obj() + dv = infer_dyn_vars(obj) + assert len(dv) >= 1 + mapping = get_brainpy_object(obj) + assert obj.name in mapping + # bound method path + bound = infer_dyn_vars(obj.vars) + assert isinstance(bound, ArrayCollector) + # non-object path returns empty collector / dict + assert len(infer_dyn_vars(lambda: None)) == 0 + assert get_brainpy_object(lambda: None) == {} + + +# =========================================================================== +# function.py +# =========================================================================== + +def test_partial_binds_positional_args(): + add = bm.Partial(lambda x, y: x + y, 1) + assert add(2) == 3 + + +def test_partial_keyword_override_and_is_brainpy_object(): + p = bm.Partial(lambda x, scale=1.: x * scale, scale=2.) + assert isinstance(p, FunAsObject) + assert float(p(3.)) == 6.0 + # call-time keyword overrides bound keyword + assert float(p(3., scale=10.)) == 30.0 + + +def test_partial_tracks_variables(): + sub = _Obj() + p = bm.Partial(lambda: sub.v.value, child_objs=sub) + np.testing.assert_allclose(np.asarray(p()), np.ones(3)) + assert sub.name in p.nodes() + + +def test_to_object_decorator_and_direct(): + sub = _Obj() + obj = bm.to_object(lambda: sub.v.value, child_objs=sub) + np.testing.assert_allclose(np.asarray(obj()), np.ones(3)) + + @bm.to_object(child_objs=sub) + def fn(): + return sub.v.value + + np.testing.assert_allclose(np.asarray(fn()), np.ones(3)) + + +def test_to_object_requires_child_objs_when_f_given(): + with pytest.raises(ValueError): + bm.to_object(lambda: 1.) + + +def test_function_is_deprecated(): + """L-04: ``function`` is deprecated in favour of ``to_object``.""" + sub = _Obj() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + obj = bm.function(lambda: sub.v.value, nodes=sub) + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + np.testing.assert_allclose(np.asarray(obj()), np.ones(3)) + + +# =========================================================================== +# variables.py +# =========================================================================== + +def test_var_dict_round_trips_through_jax_jit(): + """C-25: ``VarDict.tree_unflatten`` no longer uses the removed ``jax.util``; + a ``var_dict`` survives ``jax.jit``.""" + d = bm.var_dict({'a': bm.Variable(jnp.ones(2))}) + out = jax.jit(lambda dd: dd)(d) + assert isinstance(out, VarDict) + assert list(out.keys()) == ['a'] + np.testing.assert_allclose(np.asarray(out['a'].value), np.ones(2)) + + +def test_var_dict_tree_map(): + d = bm.var_dict({'a': bm.Variable(jnp.ones(2)), 'b': bm.Variable(jnp.zeros(2))}) + leaves = jax.tree.leaves(d) + assert len(leaves) == 2 + + +def test_var_list_round_trips_through_pytree(): + vl = bm.var_list([bm.Variable(jnp.ones(1)), bm.Variable(jnp.zeros(2))]) + leaves, treedef = jax.tree.flatten(vl) + rebuilt = jax.tree.unflatten(treedef, leaves) + assert isinstance(rebuilt, VarList) + assert len(rebuilt) == 2 + + +def test_variable_keeps_metadata_through_flatten_unflatten(): + """C-26: ``batch_axis``/``axis_names`` survive a manual pytree round-trip.""" + v = bm.Variable(jnp.zeros((3, 4)), batch_axis=0, axis_names=('a', 'b')) + assert v.batch_axis == 0 + assert v.axis_names == ('a', 'b') + leaves, treedef = jax.tree.flatten(v) + v2 = jax.tree.unflatten(treedef, leaves) + assert isinstance(v2, Variable) + assert v2.batch_axis == 0 + assert v2.axis_names == ('a', 'b') + + +def test_variable_keeps_metadata_through_jit(): + """C-26: metadata preserved through ``jax.jit``.""" + v = bm.Variable(jnp.zeros((3, 4)), batch_axis=0, axis_names=('a', 'b')) + out = jax.jit(lambda x: x)(v) + assert out.batch_axis == 0 + assert out.axis_names == ('a', 'b') + + +def test_variable_keeps_metadata_through_grad(): + """C-26: metadata preserved through ``jax.grad``.""" + v = bm.Variable(jnp.ones((3, 4)), batch_axis=0, axis_names=('a', 'b')) + g = jax.grad(lambda x: jnp.sum(x.value ** 2))(v) + assert isinstance(g, Variable) + assert g.batch_axis == 0 + assert g.axis_names == ('a', 'b') + + +def test_variable_keeps_metadata_through_vmap(): + """C-26: metadata preserved through ``jax.vmap``.""" + v = bm.Variable(jnp.zeros((3, 4)), batch_axis=0, axis_names=('a', 'b')) + out = jax.vmap(lambda x: x)(v) + assert isinstance(out, Variable) + assert out.batch_axis == 0 + + +def test_variable_value_setter_accepts_brainstate_state(): + """H-06: ``Variable.value = some_State`` unwraps the state first.""" + v = bm.Variable(jnp.zeros(3, dtype=jnp.float32)) + st = brainstate.State(jnp.ones(3, dtype=jnp.float32)) + v.value = st + np.testing.assert_allclose(np.asarray(v.value), np.ones(3)) + + +def test_variable_value_setter_canonicalizes_float64_numpy(): + """H-06: assigning a float64 numpy array to a float32 Variable canonicalizes + (converts) it rather than raising a spurious dtype ``MathError``.""" + v = bm.Variable(jnp.zeros(3, dtype=jnp.float32)) + arr = np.ones(3, dtype=np.float64) + v.value = arr # must not raise + assert v.dtype == jnp.float32 + np.testing.assert_allclose(np.asarray(v.value), np.ones(3)) + + +def test_variable_value_setter_accepts_brainpy_array(): + v = bm.Variable(jnp.zeros(3, dtype=jnp.float32)) + v.value = bm.asarray(np.arange(3), dtype=jnp.float32) + np.testing.assert_allclose(np.asarray(v.value), [0., 1., 2.]) + + +def test_variable_value_setter_shape_mismatch_raises(): + v = bm.Variable(jnp.zeros(3)) + with pytest.raises(MathError): + v.value = jnp.zeros(4) + + +def test_size_without_batch_returns_shape_tuple(): + """H-45: ``size_without_batch`` returns a *shape tuple* (drops the batch + axis), not an integer element count.""" + v = bm.Variable(jnp.zeros((3, 4)), batch_axis=0) + assert v.size_without_batch == (4,) + v2 = bm.Variable(jnp.zeros((3, 4))) + assert v2.size_without_batch == (3, 4) + + +def test_variable_batch_size_and_axis(): + v = bm.Variable(jnp.zeros((5, 4)), batch_axis=0) + assert v.batch_size == 5 + v2 = bm.Variable(jnp.zeros(4)) + assert v2.batch_size is None + + +def test_variable_batch_axis_immutable(): + v = bm.Variable(jnp.zeros((5, 4)), batch_axis=0) + with pytest.raises(ValueError): + v.batch_axis = 1 + with pytest.raises(ValueError): + v.batch_size = 2 + + +def test_variable_invalid_batch_axis_raises(): + with pytest.raises(MathError): + bm.Variable(jnp.zeros(3), batch_axis=5) + + +def test_variable_init_from_size_and_hash(): + v = bm.Variable(4) + assert v.shape == (4,) + np.testing.assert_allclose(np.asarray(v.value), np.zeros(4)) + # identity hash + assert hash(v) == id(v) + s = {v, v} + assert len(s) == 1 + + +def test_trainvar_and_parameter_are_variables(): + tv = bm.TrainVar(jnp.zeros(2)) + par = bm.Parameter(jnp.ones(2)) + assert isinstance(tv, Variable) + assert isinstance(par, Variable) + + +def test_variable_view_reads_and_writes_origin(): + origin = bm.Variable(jnp.arange(5.)) + view = bm.VariableView(origin, slice(None, 2, None)) + np.testing.assert_allclose(np.asarray(view.value), [0., 1.]) + view.value = jnp.asarray([10., 11.]) + np.testing.assert_allclose(np.asarray(origin.value)[:2], [10., 11.]) + assert 'VariableView' in repr(view) + + +def test_variable_view_requires_variable(): + with pytest.raises(ValueError): + bm.VariableView(jnp.zeros(3), slice(None)) + + +# =========================================================================== +# base.py +# =========================================================================== + +def test_cpu_moves_variable_and_injects_no_junk_attrs(): + """H-05: ``.cpu()`` iterates real Variables and moves them; it must not add + dict-valued junk attributes named after nodes.""" + obj = _Obj() + before = set(obj.__dict__.keys()) + returned = obj.cpu() + after = set(obj.__dict__.keys()) + assert after == before # no junk attributes injected + assert returned is obj + assert isinstance(obj.v, Variable) + np.testing.assert_allclose(np.asarray(obj.v.value), np.ones(3)) + + +def test_to_moves_variable_to_device(): + obj = _Obj() + dev = jax.devices('cpu')[0] + obj.to(device=dev) + assert isinstance(obj.v, Variable) + + +def test_register_implicit_vars_with_var_list(): + """H-08: ``register_implicit_vars(var_list([...]))`` flattens the container + into the ``ArrayCollector`` (which only accepts plain Variables).""" + obj = _Obj() + obj.register_implicit_vars(bm.var_list([bm.Variable(jnp.zeros(2))])) + assert len(obj.implicit_vars) == 1 + for v in obj.implicit_vars.values(): + assert isinstance(v, Variable) + + +def test_register_implicit_vars_with_var_dict(): + obj = _Obj() + obj.register_implicit_vars(bm.var_dict({'x': bm.Variable(jnp.zeros(2)), + 'y': bm.Variable(jnp.zeros(2))})) + assert len(obj.implicit_vars) == 2 + for v in obj.implicit_vars.values(): + assert isinstance(v, Variable) + + +def test_register_implicit_vars_plain_and_named(): + obj = _Obj() + obj.register_implicit_vars(bm.Variable(jnp.zeros(1)), named=bm.Variable(jnp.zeros(1))) + assert len(obj.implicit_vars) == 2 + + +def test_register_implicit_vars_rejects_non_variable(): + obj = _Obj() + with pytest.raises(ValueError): + obj.register_implicit_vars(123) + + +def test_vars_collects_variables_and_containers(): + class Multi(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.a = bm.Variable(jnp.zeros(1)) + self.lst = bm.var_list([bm.Variable(jnp.zeros(1))]) + self.dct = bm.var_dict({'k': bm.Variable(jnp.zeros(1))}) + + m = Multi() + collected = m.vars() + assert len(collected) == 3 + assert len(m.train_vars()) == 0 + + +def test_nodes_and_state_dict_round_trip(): + class Parent(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.child = _Obj() + self.w = bm.Variable(jnp.zeros(2)) + + p = Parent() + nodes = p.nodes() + assert p.name in nodes + sd = p.state_dict() + assert isinstance(sd, dict) + # load it back + p.load_state_dict(sd, warn=False) + + +def test_brainpyobject_setattr_updates_variable_value(): + """Setting an attribute that is a Variable updates its value in place.""" + obj = _Obj(n=2) + vid = id(obj.v) + obj.v = jnp.asarray([5., 6.]) + assert id(obj.v) == vid # same Variable object + np.testing.assert_allclose(np.asarray(obj.v.value), [5., 6.]) + + +def test_funasobject_call_and_repr(): + sub = _Obj() + fo = FunAsObject(target=lambda: sub.v.value, child_objs=sub) + np.testing.assert_allclose(np.asarray(fo()), np.ones(3)) + assert 'FunAsObject' in repr(fo) + + +def test_objecttransform_base(): + from brainpy.math.object_transform.base import ObjectTransform + ot = ObjectTransform() + assert repr(ot) == 'ObjectTransform' + with pytest.raises(NotImplementedError): + ot() + + +def test_node_list_and_node_dict(): + from brainpy.math.object_transform.base import node_list, node_dict + nl = node_list([_Obj(), _Obj()]) + assert len(nl) == 2 + nd = node_dict({'a': _Obj()}) + assert 'a' in nd + + +def test_tracing_variable_raises_not_implemented(): + """L-06: ``tracing_variable`` is deprecated and always raises.""" + obj = _Obj() + with pytest.raises(NotImplementedError): + obj.tracing_variable('w', jnp.zeros, (2,)) + + +# =========================================================================== +# naming.py +# =========================================================================== + +def test_many_named_objects_do_not_raise_unique_name_error(): + """H-07: creating and discarding many named objects must not raise + ``UniqueNameError`` from reused ids.""" + bm.clear_name_cache() + for _ in range(300): + _Obj() # transient, immediately discarded + gc.collect() + # creating more after GC must still not collide + keep = [_Obj() for _ in range(5)] + assert len(keep) == 5 + + +def test_name2id_registry_stays_bounded(): + """H-07: the ``_name2id`` registry prunes dead weak refs and stays bounded + instead of growing unboundedly.""" + bm.clear_name_cache() + for _ in range(400): + _Obj() + gc.collect() + # all transient objects collected -> registry pruned back to (near) empty. + assert len(naming._name2id) <= 5 + + +def test_explicit_duplicate_name_raises_unique_name_error(): + bm.clear_name_cache() + keep = _Obj() + keep.name = 'my_unique_name' + other = _Obj() + with pytest.raises(UniqueNameError): + other.name = 'my_unique_name' + + +def test_invalid_identifier_name_raises(): + obj = _Obj() + with pytest.raises(bp.errors.BrainPyError): + obj.name = '123 not valid' + + +def test_get_unique_name_increments(): + n1 = naming.get_unique_name('Foo') + n2 = naming.get_unique_name('Foo') + assert n1 != n2 + assert n1.startswith('Foo') + + +def test_clear_name_cache_warns_when_requested(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + bm.clear_name_cache(ignore_warn=False) + assert any(issubclass(w.category, UserWarning) for w in caught) + + +def test_stack_cache_helpers(): + def fn(): + return None + + naming.cache_stack(fn, [1, 2, 3]) + assert naming.get_stack_cache(fn) == [1, 2, 3] + assert naming.get_stack_cache(lambda: None) is None + naming.clear_stack_cache() + assert naming.get_stack_cache(fn) is None + + +# =========================================================================== +# collectors (ArrayCollector) exercised via base/vars APIs +# =========================================================================== + +def test_array_collector_basic_ops(): + ac = ArrayCollector() + v = bm.Variable(jnp.zeros(1)) + ac['x'] = v + assert len(ac.unique()) == 1 + assert 'x' in ac.dict() + # subset by type + tv = bm.TrainVar(jnp.zeros(1)) + ac['t'] = tv + assert len(ac.subset(TrainVar)) == 1 + + +# =========================================================================== +# Additional coverage for variables.py +# =========================================================================== + +def test_variable_init_from_list_of_int_size(): + v = bm.Variable([2, 3]) + assert v.shape == (2, 3) + + +def test_variable_axis_names_inserts_batch_axis_name(): + """The axis_names insertion path: when ``len(axis_names) + 1 == ndim`` the + batch-axis name is inserted at ``batch_axis``.""" + from brainpy.math.sharding import BATCH_AXIS + v = bm.Variable(jnp.zeros((5, 4)), batch_axis=0, axis_names=('feat',)) + assert len(v.axis_names) == 2 + assert v.axis_names[0] == BATCH_AXIS + assert v.axis_names[1] == 'feat' + + +def test_variable_value_setter_with_batch_axis(): + v = bm.Variable(jnp.zeros((5, 4)), batch_axis=0) + v.value = jnp.ones((7, 4)) # different batch size OK + assert v.shape == (7, 4) + with pytest.raises(MathError): + v.value = jnp.ones((5, 9)) # non-batch dim mismatch + + +def test_variable_value_setter_dtype_mismatch_raises(): + v = bm.Variable(jnp.zeros(3, dtype=jnp.float32)) + with pytest.raises(MathError): + v.value = jnp.zeros(3, dtype=jnp.int32) + + +def test_var_list_append_rejects_non_variable(): + vl = bm.var_list() + with pytest.raises(TypeError): + vl.append(jnp.zeros(2)) + + +def test_var_list_setitem_int_updates_value_slice_replaces(): + a = bm.Variable(jnp.zeros(1)) + b = bm.Variable(jnp.zeros(2)) + vl = bm.var_list([a, b]) + ids = (id(vl[0]), id(vl[1])) + vl[0] = jnp.ones(1) # updates value in place, keeps the same Variable + assert id(vl[0]) == ids[0] + np.testing.assert_allclose(np.asarray(vl[0].value), [1.]) + # slice assignment replaces entries + vl[0:1] = [bm.Variable(jnp.zeros(1))] + assert len(vl) == 2 + + +def test_var_dict_update_existing_key_sets_value(): + d = bm.var_dict({'a': bm.Variable(jnp.zeros(2))}) + original_id = id(d['a']) + d['a'] = jnp.ones(2) # update value, keep the same Variable + assert id(d['a']) == original_id + np.testing.assert_allclose(np.asarray(d['a'].value), np.ones(2)) + + +def test_var_dict_rejects_non_variable_element(): + with pytest.raises(TypeError): + bm.var_dict({'a': jnp.zeros(2)}) + + +def test_var_dict_update_from_tuple_and_kwargs(): + d = bm.var_dict(('a', bm.Variable(jnp.zeros(1))), b=bm.Variable(jnp.zeros(1))) + assert set(d.keys()) == {'a', 'b'} + + +def test_variable_view_setter_shape_and_dtype_checks(): + origin = bm.Variable(jnp.arange(5.)) + view = bm.VariableView(origin, slice(None, 2, None)) + with pytest.raises(MathError): + view.value = jnp.zeros(3) # wrong shape + + +# =========================================================================== +# Additional coverage for base.py +# =========================================================================== + +def test_relative_vars_and_nodes(): + class Parent(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.child = _Obj() + self.w = bm.Variable(jnp.zeros(2)) + + p = Parent() + rel_vars = p.vars(method='relative') + assert len(rel_vars) >= 2 + rel_nodes = p.nodes(method='relative') + assert '' in rel_nodes + + +def test_vars_with_level_zero(): + class Parent(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.child = _Obj() + self.w = bm.Variable(jnp.zeros(2)) + + p = Parent() + # level=0 -> only self's own variable, not the child's + own = p.vars(level=0) + assert len(own) == 1 + + +def test_register_implicit_nodes_variants(): + parent = _Obj() + parent.register_implicit_nodes(_Obj()) + parent.register_implicit_nodes([_Obj(), _Obj()]) + parent.register_implicit_nodes({'k': _Obj()}) + parent.register_implicit_nodes(named=_Obj()) + # nodes are keyed by name; distinct objects -> at least 5 distinct entries. + assert len(parent.implicit_nodes) >= 5 + + +def test_register_implicit_nodes_rejects_bad_type(): + parent = _Obj() + with pytest.raises(ValueError): + parent.register_implicit_nodes(123) + + +def test_register_implicit_vars_from_list_and_dict_args(): + obj = _Obj() + obj.register_implicit_vars([bm.Variable(jnp.zeros(1)), bm.Variable(jnp.zeros(1))]) + obj.register_implicit_vars({'k': bm.Variable(jnp.zeros(1))}) + assert len(obj.implicit_vars) == 3 + with pytest.raises(ValueError): + obj.register_implicit_vars([123]) + with pytest.raises(ValueError): + obj.register_implicit_vars({'bad': 123}) + with pytest.raises(ValueError): + obj.register_implicit_vars(named=123) + + +def test_node_dict_check_unique_raises_on_conflict(): + from brainpy.math.object_transform.base import NodeDict + a, b = _Obj(), _Obj() + nd = NodeDict(check_unique=True) + nd['k'] = a + nd['k'] = a # same object, OK + with pytest.raises(KeyError): + nd['k'] = b # different object under same key + + +def test_node_list_and_node_dict_in_find_nodes(): + from brainpy.math.object_transform.base import node_list, node_dict + + class Parent(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.lst = node_list([_Obj(), _Obj()]) + self.dct = node_dict({'a': _Obj()}) + + p = Parent() + nodes = p.nodes() + # parent + 2 list children + 1 dict child + assert len(nodes) >= 4 + rel_nodes = p.nodes(method='relative') + assert len(rel_nodes) >= 4 + + +def test_funasobject_repr_with_nodes(): + sub = _Obj() + fo = FunAsObject(target=lambda: sub.v.value, child_objs=sub, + dyn_vars=bm.Variable(jnp.zeros(1))) + r = repr(fo) + assert 'FunAsObject' in r + assert 'num_of_vars' in r + + +def test_save_and_load_state_methods(): + obj = _Obj(n=2) + sd = obj.save_state() + assert isinstance(sd, dict) + missing, unexpected = obj.load_state(sd) + assert missing == [] and unexpected == [] + + +def test_vars_invalid_method_raises(): + obj = _Obj() + with pytest.raises(ValueError): + obj.nodes(method='bad') + + +def test_brainpyobject_tree_flatten_unflatten(): + class Mixed(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.v = bm.Variable(jnp.ones(2)) # dynamic + self.lst = bm.var_list([bm.Variable(jnp.zeros(1))]) # dynamic container + self.scalar = 7 # static + + m = Mixed() + leaves, treedef = jax.tree.flatten(m) + assert len(leaves) >= 1 + rebuilt = jax.tree.unflatten(treedef, leaves) + assert isinstance(rebuilt, Mixed) + assert rebuilt.scalar == 7 + + +def test_brainpyobject_setattr_method_bypasses_variable_update(): + obj = _Obj() + # ``.setattr`` is the explicit object.__setattr__ wrapper + obj.setattr('plain', 123) + assert obj.plain == 123 + + +def test_load_state_dict_v1_and_warnings(): + obj = _Obj(n=2) + # build a v1-style flat state dict + flat = {k: np.asarray(v.value) for k, v in obj.vars().items()} + res = obj.load_state_dict(flat, compatible='v1', warn=False) + assert res.missing_keys == [] and res.unexpected_keys == [] + # unexpected + missing keys trigger warnings + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + obj.load_state_dict({'does_not_exist': np.zeros(2)}, compatible='v1', warn=True) + assert any(issubclass(w.category, UserWarning) for w in caught) + + +def test_load_state_dict_invalid_compatible_raises(): + obj = _Obj() + with pytest.raises(ValueError): + obj.load_state_dict({}, compatible='v9') + + +def test_vars_exclude_types(): + class Holder(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.tv = bm.TrainVar(jnp.zeros(1)) + self.v = bm.Variable(jnp.zeros(1)) + + h = Holder() + # excluding TrainVar drops it from the collection + kept = h.vars(exclude_types=(TrainVar,)) + assert all(not isinstance(v, TrainVar) for v in kept.values()) + + +def test_unique_name_with_explicit_type(): + obj = _Obj() + name = obj.unique_name(type_='CustomType') + assert name.startswith('CustomType') + + +def test_node_dict_update_from_tuple(): + from brainpy.math.object_transform.base import NodeDict + nd = NodeDict(('a', _Obj())) + assert 'a' in nd + + +def test_brainpyobject_tree_flatten_unflatten_direct(): + """Exercise ``tree_flatten``/``tree_unflatten`` directly (the pytree + registration is gated off by default, so ``jax.tree`` would treat the + object as a leaf).""" + class Mixed(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.v = bm.Variable(jnp.ones(2)) + self.scalar = 9 + + m = Mixed() + dynamic, aux = m.tree_flatten() + assert len(dynamic) == 1 + rebuilt = Mixed.tree_unflatten(aux, dynamic) + assert rebuilt.scalar == 9 + assert isinstance(rebuilt.v, Variable) + + +def test_relative_nodes_nested_hierarchy(): + """Cover the relative-method recursion that joins child paths.""" + class Leaf(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.w = bm.Variable(jnp.zeros(1)) + + class Middle(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.leaf = Leaf() + + class Top(bp.BrainPyObject): + def __init__(self): + super().__init__() + self.mid = Middle() + + t = Top() + rel = t.nodes(method='relative') + # joined paths like 'mid.leaf' appear + assert any('.' in k for k in rel.keys() if k) + rel_vars = t.vars(method='relative') + assert len(rel_vars) >= 1 + + +def test_cuda_tpu_raise_without_device(): + obj = _Obj() + with pytest.raises(RuntimeError): + obj.cuda() + with pytest.raises(RuntimeError): + obj.tpu() diff --git a/tests/audit/test_train_analysis_glue_fixes.py b/tests/audit/test_train_analysis_glue_fixes.py new file mode 100644 index 000000000..07d216a4e --- /dev/null +++ b/tests/audit/test_train_analysis_glue_fixes.py @@ -0,0 +1,665 @@ +# -*- coding: utf-8 -*- +"""Regression + coverage tests for the 2026-06-18 BrainPy audit. + +This module targets the train / analysis / top-level-glue fixes documented in +``docs/issues-found-20260618.md`` for the following source files: + +* ``brainpy/algorithms/online.py`` — C-23 (block RLS for batch>1) +* ``brainpy/algorithms/offline.py`` — H-46 (GD ``.value`` bug), + H-47 (ridge intercept penalty) +* ``brainpy/running/jax_multiprocessing.py`` — H-48 (pmap reuse / labels) +* ``brainpy/analysis/lowdim/lowdim_analyzer.py`` — H-49 (arg-unwrap), + H-50 (empty-candidate concat) +* ``brainpy/analysis/utils/optimization.py`` — H-49 (arg-unwrap in roots_of_1d_by_x) +* ``brainpy/runners.py`` — C-22 (memory_efficient DSRunner) +* ``brainpy/measure.py`` — H-43 (firing_rate normalization) +* ``brainpy/delay.py`` — H-44 (VarDelay self.data), + H-45 (size_without_batch) + +The tests are intentionally tiny (small nets, short durations) so the whole +module runs in well under four minutes. They assert the *fixed* behavior; on the +buggy pre-audit code each regression test would raise or diverge. +""" + +import warnings + +import numpy as np +import jax +import jax.numpy as jnp +import pytest + +import brainpy as bp +import brainpy.math as bm + +warnings.filterwarnings("ignore") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _fit_rls(batch_size, n_in=3, n_out=2, steps=400, seed=0): + """Fit a known linear map ``Y = X @ Wtrue`` with the RLS online algorithm. + + Returns (final_error, has_nan). + """ + from brainpy.algorithms.online import RLS + + rng = np.random.RandomState(seed) + w_true = jnp.asarray(rng.randn(n_in, n_out)) + rls = RLS(alpha=0.1) + rls.register_target(n_in, identifier="w") + weight = bm.Variable(jnp.zeros((n_in, n_out))) + for _ in range(steps): + x = jnp.asarray(rng.randn(batch_size, n_in)) + y = x @ w_true + out = x @ weight.value + dw = rls(y, x, out, identifier="w") + weight.value = weight.value + dw + err = float(jnp.linalg.norm(weight.value - w_true)) + has_nan = bool(jnp.isnan(weight.value).any()) + return err, has_nan + + +# --------------------------------------------------------------------------- +# online.py — C-23: block RLS valid for any batch size +# --------------------------------------------------------------------------- + +def test_rls_converges_batch1(): + """RLS fits a known linear map for batch=1 without divergence/NaN.""" + err, has_nan = _fit_rls(batch_size=1) + assert not has_nan + assert err < 0.5, f"RLS (B=1) did not converge: err={err}" + + +def test_rls_converges_batch4_no_nan(): + """C-23: block RLS must stay correct (no NaN / no divergence) for batch>1. + + On the pre-audit code the scalar ``c = sum(1/(1+HPHᵀ))`` collapse made the + covariance ``P`` diverge for B>=4 and produced NaN weights. + """ + err, has_nan = _fit_rls(batch_size=4) + assert not has_nan, "RLS (B=4) produced NaN weights (C-23 regression)" + assert err < 0.5, f"RLS (B=4) did not converge: err={err}" + + +def test_rls_batch1_matches_scalar_update(): + """For B==1 the block update reduces to the classic scalar RLS update.""" + from brainpy.algorithms.online import RLS + + rls = RLS(alpha=0.1) + rls.register_target(3, identifier="w") + x = jnp.array([[1.0, -2.0, 0.5]]) + target = jnp.array([[1.0, 0.0]]) + output = jnp.zeros((1, 2)) + dw = rls(target, x, output, identifier="w") + assert dw.shape == (3, 2) + assert not bool(jnp.isnan(dw).any()) + + +def test_lms_call_runs(): + """Coverage: LMS online algorithm produces a finite weight update.""" + from brainpy.algorithms.online import LMS + + lms = LMS(alpha=0.01) + x = jnp.ones((2, 3)) + target = jnp.ones((2, 2)) + output = jnp.zeros((2, 2)) + dw = lms(target, x, output) + assert dw.shape == (3, 2) + assert not bool(jnp.isnan(dw).any()) + + +def test_online_registry_helpers(): + """Coverage: the online method registry getters.""" + from brainpy.algorithms import online + + methods = online.get_supported_online_methods() + assert "rls" in methods and "lms" in methods + assert online.get("rls") is online.RLS + assert online.get("lms") is online.LMS + with pytest.raises(ValueError): + online.get("does-not-exist") + + +def test_online_trainer_and_force_trainer_fit(): + """Coverage: exercise online.py ``call`` through a tiny ESN + trainers. + + Drives ``brainpy.algorithms.online.RLS.call`` via the high-level + ``OnlineTrainer``/``ForceTrainer`` training loop on a small reservoir. + """ + + class ESN(bp.DynamicalSystem): + def __init__(self, num_in, num_hidden, num_out): + super().__init__() + self.r = bp.dyn.Reservoir( + num_in, num_hidden, + Win_initializer=bp.init.Uniform(-0.1, 0.1), + Wrec_initializer=bp.init.Normal(scale=0.1), + in_connectivity=0.1, rec_connectivity=0.1, + comp_type="dense", + ) + self.o = bp.dnn.Dense(num_hidden, num_out, + W_initializer=bp.init.Normal(), + mode=bm.training_mode) + + def update(self, x): + return x >> self.r >> self.o + + bm.random.seed(0) + bp.share.save(fit=True) + with bm.batching_environment(): + model = ESN(5, 25, 3) + x = bm.random.random((1, 50, 5)) + y = bm.random.random((1, 50, 3)) + + trainer = bp.OnlineTrainer(model, fit_method=bp.algorithms.RLS(alpha=0.1), + progress_bar=False) + trainer.fit([x, y]) + out = trainer.predict(x) + assert out.shape == (1, 50, 3) + + bm.random.seed(0) + with bm.batching_environment(): + model2 = ESN(5, 25, 3) + force = bp.ForceTrainer(model2, alpha=0.1, progress_bar=False) + force.fit([x, y]) + + +# --------------------------------------------------------------------------- +# offline.py — H-46 (GD .value bug) and H-47 (ridge intercept penalty) +# --------------------------------------------------------------------------- + +def _small_xy(n=20, n_in=3, n_out=2, seed=1): + rng = np.random.RandomState(seed) + return jnp.asarray(rng.randn(n, n_in)), jnp.asarray(rng.randn(n, n_out)) + + +def test_ridge_gradient_descent_runs_no_value_error(): + """H-46: ``gradient_descent=True`` must not raise ``.value`` AttributeError.""" + from brainpy.algorithms.offline import RidgeRegression + + x, y = _small_xy() + w = RidgeRegression(alpha=1e-6, gradient_descent=True, max_iter=50)(y, x) + assert w.shape == (3, 2) + assert not bool(jnp.isnan(w).any()) + + +def test_linear_regression_gradient_descent_runs(): + """H-46: the same GD code path through LinearRegression.""" + from brainpy.algorithms.offline import LinearRegression + + x, y = _small_xy() + w = LinearRegression(gradient_descent=True, max_iter=50)(y, x) + assert w.shape == (3, 2) + assert not bool(jnp.isnan(w).any()) + + +def test_lasso_always_gradient_descent_runs(): + """H-46: Lasso is always-GD; its body must not hit the ``.value`` bug.""" + from brainpy.algorithms.offline import LassoRegression + + x, y = _small_xy(n_out=1) + w = LassoRegression(alpha=0.1, degree=2, max_iter=50)(y, x) + assert w.ndim == 2 + assert not bool(jnp.isnan(w).any()) + + +def test_elastic_net_gradient_descent_runs(): + """H-46: ElasticNet is always-GD as well.""" + from brainpy.algorithms.offline import ElasticNetRegression + + x, y = _small_xy(n_out=1) + w = ElasticNetRegression(alpha=0.1, degree=2, max_iter=50)(y, x) + assert not bool(jnp.isnan(w).any()) + + +def test_ridge_intercept_not_over_penalized(): + """H-47: a large ridge ``alpha`` must not shrink the intercept/bias column. + + Fit ``y ≈ slope*x + intercept`` with a strongly nonzero mean and a huge + penalty. The bias column (index 0 of the polynomial features) must remain + close to the data mean while the slope is shrunk toward zero. + """ + from brainpy.algorithms.offline import PolynomialRidgeRegression + + rng = np.random.RandomState(2) + x = jnp.asarray(rng.randn(40, 1)) + mean_y = 5.0 + y = jnp.asarray(0.3 * np.asarray(x) + mean_y + 0.01 * rng.randn(40, 1)) + + model = PolynomialRidgeRegression(alpha=1e6, degree=1, add_bias=True, + gradient_descent=False) + w = model(y, x) + intercept = float(w[0, 0]) + slope = float(w[1, 0]) + # Intercept stays near the data mean despite the huge penalty. + assert abs(intercept - mean_y) < 0.5, f"intercept over-penalized: {intercept}" + # The (penalized) slope is shrunk toward zero. + assert abs(slope) < abs(intercept), f"slope not shrunk: slope={slope}" + + +def test_logistic_regression_remaining_bug(): + """Document a remaining (un-audited) bug in ``LogisticRegression.call``. + + ``call`` flattens ``targets`` to 1-D and then indexes ``targets.shape[1]``, + which raises ``IndexError: tuple index out of range``. This is independent + of the H-46/H-47 ridge fixes and is *not* in the assigned fix scope, so the + test pins the current broken behavior (and exercises the code path) rather + than asserting success. See summary notes. + """ + from brainpy.algorithms.offline import LogisticRegression + + rng = np.random.RandomState(3) + x = jnp.asarray(rng.randn(30, 2)) + y = jnp.asarray((np.asarray(x)[:, :1] > 0).astype("float32")) + with pytest.raises(IndexError): + LogisticRegression(max_iter=100)(y, x) + + +def test_offline_least_square_and_polynomial(): + """Coverage: non-GD lstsq path and polynomial regression.""" + from brainpy.algorithms.offline import (LinearRegression, + PolynomialRegression, + RidgeRegression) + + x, y = _small_xy() + w_lin = LinearRegression()(y, x) # lstsq path + assert w_lin.shape[-1] == 2 + w_ridge = RidgeRegression(alpha=1e-3)(y, x) # ridge closed form (no bias) + assert w_ridge.shape == (3, 2) + w_poly = PolynomialRegression(degree=2, gradient_descent=True, max_iter=20)(y, x) + assert not bool(jnp.isnan(w_poly).any()) + + +def test_offline_registry_helpers(): + """Coverage: the offline method registry getters.""" + from brainpy.algorithms import offline + + methods = offline.get_supported_offline_methods() + for name in ("linear", "ridge", "lasso", "logistic"): + assert name in methods + assert offline.get("ridge") is offline.RidgeRegression + with pytest.raises(ValueError): + offline.get("nope") + + +# --------------------------------------------------------------------------- +# measure.py — H-43: firing_rate normalization +# --------------------------------------------------------------------------- + +def test_firing_rate_100hz_mean(): + """H-43: a true 100 Hz spike train must average to ~100 Hz. + + ``sp[::10] = 1`` with dt=1 ms is one spike every 10 ms = 100 Hz. The buggy + normalization (by requested ``width`` rather than the actual window length) + biased the smoothed rate so its mean drifted toward ~110 Hz. + """ + spikes = np.zeros((1000, 1)) + spikes[::10] = 1 + rate = bp.measure.firing_rate(spikes, width=10, dt=1.0) + assert abs(float(np.mean(rate)) - 100.0) < 5.0, f"mean rate={np.mean(rate)}" + + +def test_firing_rate_jax_mode(): + """Coverage: numpy=False (JIT-able) branch of firing_rate.""" + spikes = np.zeros((200, 2)) + spikes[::5] = 1 + rate = bp.measure.firing_rate(spikes, width=5, dt=1.0, numpy=False) + assert rate.shape[0] == 200 + + +def test_raster_plot(): + """Coverage: raster_plot returns (index, time) of spikes.""" + spikes = np.zeros((5, 3)) + spikes[1, 0] = 1 + spikes[3, 2] = 1 + times = np.arange(5) * 0.1 + index, time = bp.measure.raster_plot(spikes, times) + assert set(np.asarray(index).tolist()) == {0, 2} + assert len(time) == 2 + + +# --------------------------------------------------------------------------- +# delay.py — H-44 (VarDelay self.data) and register_entry / retrieve coverage +# --------------------------------------------------------------------------- + +def test_vardelay_constructs_and_updates_time_gt_zero(): + """H-44: ``VarDelay(target, time=T>0)`` must construct without AttributeError. + + The pre-audit ``_init_data`` read ``self.data`` before it was ever assigned, + raising ``AttributeError: 'data'`` for any positive delay time. + """ + target = bm.Variable(bm.zeros(4)) + delay = bp.delay.VarDelay(target, time=2.0) + assert delay.data is not None + assert delay.max_length > 0 + + bp.share.save(i=0, t=0.0, dt=bm.get_dt()) + delay.update(bm.ones(4)) # one update step + assert delay.data is not None + + +def test_vardelay_register_entry_and_retrieve(): + """Coverage: register_entry + at()/retrieve on a VarDelay.""" + target = bm.Variable(bm.arange(4.0)) + delay = bp.delay.VarDelay(target, time=2.0) + delay.register_entry("e1", delay_time=1.0) + delay.register_entry("e0", delay_time=0.0) + + bp.share.save(i=0, t=0.0, dt=bm.get_dt()) + delay.update(bm.arange(4.0) + 10.0) + + # zero-delay entry returns the current target value + out0 = delay.at("e0") + assert out0.shape == (4,) + # nonzero-delay entry retrieves a buffered value + out1 = delay.at("e1") + assert out1.shape == (4,) + + with pytest.raises(KeyError): + delay.at("missing") + with pytest.raises(KeyError): + delay.register_entry("e1", delay_time=1.0) # duplicate + + +def test_vardelay_time_none_is_zero_length(): + """Coverage: ``time=None`` yields a zero-length (data-less) delay.""" + target = bm.Variable(bm.zeros(3)) + delay = bp.delay.VarDelay(target, time=None) + assert delay.max_length == 0 + assert delay.data is None + + +def test_length_delay_register_retrieve_update(): + """Coverage: ``brainpy.math.LengthDelay`` retrieve/update/__call__.""" + var = bm.Variable(bm.arange(4.0)) + ld = bm.LengthDelay(var, delay_len=3) + ld.update(bm.arange(4.0) + 10.0) + # delay 0 -> newest value + assert np.allclose(np.asarray(ld(0)), np.arange(4.0) + 10.0) + # delay 1 -> previous (initial zeros-ish) value + out1 = ld.retrieve(1) + assert out1.shape == (4,) + + +def test_data_delay_constructs(): + """Coverage: ``DataDelay`` (subclass of VarDelay) constructs with time>0.""" + target = bm.Variable(bm.zeros(3)) + dd = bp.delay.DataDelay(target, data_init=bm.zeros(3), time=1.0) + assert dd.data is not None + + +def test_vardelay_concat_update_and_init_by_return(): + """Coverage: CONCAT_UPDATE method, init_delay_by_return, DelayAccess.""" + from brainpy.math.delayvars import CONCAT_UPDATE + from brainpy.delay import init_delay_by_return, DelayAccess + + target = bm.Variable(bm.zeros(3)) + delay = bp.delay.VarDelay(target, time=1.0, method=CONCAT_UPDATE) + delay.register_entry("c", delay_step=2) + bp.share.save(i=0, t=0.0, dt=bm.get_dt()) + delay.update(bm.ones(3)) + assert delay.at("c").shape == (3,) + + # init_delay_by_return with a plain Variable -> VarDelay + dl = init_delay_by_return(bm.Variable(bm.zeros(2))) + assert isinstance(dl, bp.delay.VarDelay) + + # DelayAccess registers an entry on the delay and reads it back + access = DelayAccess(delay, 1.0, delay_entry="acc") + out = access.update() + assert out.shape == (3,) + + +def test_vardelay_wrong_target_type_raises(): + """Coverage: VarDelay rejects a non-Variable target.""" + with pytest.raises(ValueError): + bp.delay.VarDelay(bm.zeros(3), time=1.0) + + +# --------------------------------------------------------------------------- +# optimization.py + lowdim_analyzer.py — H-49 (arg-unwrap) and H-50 (empty concat) +# --------------------------------------------------------------------------- + +def test_roots_of_1d_by_x_finds_fixed_point(): + """H-49: ``roots_of_1d_by_x`` on ``dx=-x+I`` finds the fixed point x=I. + + Also exercises the arg-unwrap comprehension (passing a ``bm.Array`` arg). + """ + from brainpy.analysis.utils.optimization import roots_of_1d_by_x + + bp.math.enable_x64() + try: + offset = 0.7 + f = lambda x, b: -x + b + candidates = jnp.linspace(-2.0, 2.0, 401) + # pass the parameter as a bm.Array to drive the unwrap branch + fps = roots_of_1d_by_x(f, candidates, args=(bm.asarray(offset),)) + fps = np.asarray(fps) + assert fps.size >= 1 + assert np.any(np.abs(fps - offset) < 1e-3), f"fps={fps}" + finally: + bp.math.disable_x64() + + +def test_phase_plane_1d_finds_fixed_point(): + """H-49: a PhasePlane1D analyzer on ``dx=-x+I`` locates the fixed point ~x=I.""" + import matplotlib + matplotlib.use("Agg") + + bp.math.enable_x64() + try: + offset = 0.7 + + @bp.odeint + def int_x(x, t, Iext): + return -x + Iext + + analyzer = bp.analysis.PhasePlane1D( + model=int_x, + target_vars={"x": [-2.0, 2.0]}, + pars_update={"Iext": offset}, + resolutions=0.01, + ) + analyzer.plot_vector_field() + fps = analyzer.plot_fixed_point(show=False, with_return=True) + fps = np.asarray(fps).ravel() + assert fps.size >= 1 + assert np.any(np.abs(fps - offset) < 1e-2), f"fps={fps}" + finally: + import matplotlib.pyplot as plt + plt.close("all") + bp.math.disable_x64() + + +def test_lowdim_2d_empty_candidate_concat_guard(): + """H-50: the non-convertible 2D ``_get_fixed_points`` must guard the empty path. + + Build a 2D analyzer that cannot reduce to a single equation, then drive the + optimization branch with an impossible candidate-screening tolerance so that + nothing converges. The buggy code did ``jnp.concatenate([])`` -> ValueError; + the fix returns correctly-shaped empty arrays. + """ + from brainpy.analysis.lowdim.lowdim_analyzer import Num2DAnalyzer + + bp.math.enable_x64() + try: + @bp.odeint + def ds1(s1, t, s2): + return -s1 + jnp.tanh(s2) + 0.1 + + @bp.odeint + def ds2(s2, t, s1): + return -s2 + jnp.tanh(s1) + 0.1 + + analyzer = Num2DAnalyzer( + model=[ds1, ds2], + target_vars={"s1": [-2.0, 2.0], "s2": [-2.0, 2.0]}, + resolutions=0.05, + ) + assert not analyzer._can_convert_to_one_eq() + + candidates = jnp.asarray( + np.random.RandomState(0).uniform(-2.0, 2.0, size=(30, 2))) + # tol_opt_candidate=-1 screens out every candidate -> empty all_fps list + fps, ids, pargs = analyzer._get_fixed_points(candidates, + tol_opt_candidate=-1.0) + fps = np.asarray(fps) + assert fps.shape == (0, 2) + assert np.asarray(ids).shape == (0,) + finally: + bp.math.disable_x64() + + +def test_roots_of_1d_by_xy_and_brentq_helpers(): + """Coverage: roots_of_1d_by_xy and the scalar brentq helper functions.""" + from brainpy.analysis.utils import optimization as opt + + bp.math.enable_x64() + try: + f = lambda x, a: -x + a + + # roots_of_1d_by_xy on dx = -x + a, a = 0.5 + xs, ys = opt.roots_of_1d_by_xy(f, jnp.array([-2.0]), jnp.array([2.0]), + jnp.array([0.5])) + assert np.any(np.abs(np.asarray(xs) - 0.5) < 1e-6) + + # brentq_roots (jitted vmap brentq) + roots, _ = opt.brentq_roots(f, jnp.array([-2.0]), jnp.array([2.0]), + jnp.array([0.5])) + assert np.any(np.abs(np.asarray(roots) - 0.5) < 1e-6) + + # get_brentq_candidates over a 2D meshgrid + starts, ends, args = opt.get_brentq_candidates( + lambda x, y: -x + y, jnp.linspace(-2.0, 2.0, 20), + jnp.linspace(-1.0, 1.0, 5)) + assert len(np.asarray(starts)) == len(np.asarray(ends)) + + # pure-numpy brentq + 1D root finder + root, _iters, _calls = opt.numpy_brentq(lambda x: x - 0.3, -1.0, 1.0) + assert abs(root - 0.3) < 1e-9 + roots_np = opt.find_root_of_1d_numpy(lambda x: -x + 0.4, + np.linspace(-2.0, 2.0, 50)) + assert np.any(np.abs(np.asarray(roots_np) - 0.4) < 1e-6) + finally: + bp.math.disable_x64() + + +def test_phase_plane_1d_no_fixed_point_runs_clean(): + """Coverage: a 1D system with no real fixed point returns cleanly (no crash).""" + import matplotlib + matplotlib.use("Agg") + + bp.math.enable_x64() + try: + @bp.odeint + def int_x(x, t): + # dx = x^2 + 1 has no real root -> no fixed point in range + return x ** 2 + 1.0 + + analyzer = bp.analysis.PhasePlane1D( + model=int_x, + target_vars={"x": [-2.0, 2.0]}, + resolutions=0.01, + ) + fps = analyzer.plot_fixed_point(show=False, with_return=True) + assert np.asarray(fps).size == 0 + finally: + import matplotlib.pyplot as plt + plt.close("all") + bp.math.disable_x64() + + +# --------------------------------------------------------------------------- +# running/jax_multiprocessing.py — H-48 (pmap reuse / labels) +# --------------------------------------------------------------------------- + +def test_jax_vectorize_map_sequence_and_dict(): + """Coverage: jax_vectorize_map over sequence and dict arguments.""" + from brainpy.running.jax_multiprocessing import jax_vectorize_map + + out = jax_vectorize_map(lambda x: x * 2.0, [jnp.arange(6.0)], num_parallel=2) + assert np.allclose(np.asarray(out), np.arange(6.0) * 2.0) + + # NOTE: clear_buffer=True calls the process-global ``bm.clear_buffer_memory()`` + # which deletes ALL live device arrays (poisoning other test modules in the + # same pytest session). Patch it to a no-op so the branch is still covered. + _orig_clear = bm.clear_buffer_memory + bm.clear_buffer_memory = lambda *a, **k: None + try: + out2 = jax_vectorize_map(lambda a, b: a + b, + {"a": jnp.arange(4.0), "b": jnp.ones(4)}, + num_parallel=2, clear_buffer=True) + finally: + bm.clear_buffer_memory = _orig_clear + assert np.allclose(np.asarray(out2), np.arange(4.0) + 1.0) + + +def test_jax_vectorize_map_length_mismatch_raises(): + """Coverage: mismatched argument lengths raise ValueError.""" + from brainpy.running.jax_multiprocessing import jax_vectorize_map + + with pytest.raises(ValueError): + jax_vectorize_map(lambda a, b: a + b, + {"a": jnp.arange(4.0), "b": jnp.ones(3)}, + num_parallel=2) + + +def test_jax_parallelize_map_single_device(): + """H-48: jax_parallelize_map runs across chunks (one device per chunk).""" + from brainpy.running.jax_multiprocessing import jax_parallelize_map + + n_dev = jax.local_device_count() + out = jax_parallelize_map(lambda x: x * 2.0, + [jnp.arange(float(3 * n_dev))], + num_parallel=n_dev) + assert np.allclose(np.asarray(out), np.arange(float(3 * n_dev)) * 2.0) + + +# --------------------------------------------------------------------------- +# runners.py — C-22: DSRunner(memory_efficient=True) +# --------------------------------------------------------------------------- + +class _TinyNet(bp.DynamicalSystem): + def __init__(self): + super().__init__() + self.n = bp.dyn.LifRef(3) + + def update(self, inp): + self.n(inp) + return self.n.V.value + + +def test_dsrunner_memory_efficient_matches_normal(): + """C-22: ``memory_efficient=True`` must run and match the standard run. + + The pre-audit code did ``jax.ShapeDtypeStruct(mon.shape, ...)`` on a *dict* + monitor and used a broken ``pure_callback`` signature, so any + ``memory_efficient=True`` run raised ``AttributeError: 'dict' ... 'shape'``. + """ + inputs = bm.ones((15, 3)) * 2.0 + + bm.random.seed(0) + r_normal = bp.DSRunner(_TinyNet(), monitors=["n.V"], + memory_efficient=False, progress_bar=False) + r_normal.run(inputs=inputs) + mon_normal = np.asarray(r_normal.mon["n.V"]) + + bm.random.seed(0) + r_mem = bp.DSRunner(_TinyNet(), monitors=["n.V"], + memory_efficient=True, progress_bar=False) + r_mem.run(inputs=inputs) + mon_mem = np.asarray(r_mem.mon["n.V"]) + + assert mon_normal.shape == mon_mem.shape == (15, 3) + assert np.allclose(mon_normal, mon_mem), "memory_efficient run diverged" + + +def test_dsrunner_basic_run_with_monitors(): + """Coverage: a plain DSRunner run with monitors and duration.""" + bm.random.seed(0) + runner = bp.DSRunner(_TinyNet(), monitors=["n.V"], progress_bar=False) + runner.run(duration=2.0) + assert runner.mon["n.V"].shape[1] == 3 + assert runner.mon.ts.shape[0] == runner.mon["n.V"].shape[0]