diff --git a/README.md b/README.md
index ac1620dcc..d96396b38 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@
+
diff --git a/brainpy/algorithms/offline.py b/brainpy/algorithms/offline.py
index 36ebc30b3..476b52285 100644
--- a/brainpy/algorithms/offline.py
+++ b/brainpy/algorithms/offline.py
@@ -156,7 +156,7 @@ def gradient_descent_solve(self, targets, inputs, outputs=None):
def cond_fun(a):
i, par_old, par_new = a
return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)),
- i < self.max_iter).value
+ i < self.max_iter)
def body_fun(a):
i, _, par_new = a
@@ -269,9 +269,17 @@ def call(self, targets, inputs, outputs=None):
if self.gradient_descent:
return self.gradient_descent_solve(targets, inputs)
else:
+ n_features = inputs.shape[-1]
temp = inputs.T @ inputs
if self.regularizer.alpha > 0.:
- temp += self.regularizer.alpha * jnp.eye(inputs.shape[-1])
+ penalty = self.regularizer.alpha * jnp.ones((n_features,))
+ # Do not penalize the intercept/bias column. ``polynomial_features``
+ # (used by ``PolynomialRidgeRegression`` when ``add_bias=True``)
+ # prepends a constant column at index 0; shrinking it would bias
+ # the fit on data with a nonzero mean.
+ if getattr(self, 'add_bias', False):
+ penalty = penalty.at[0].set(0.)
+ temp += jnp.diag(penalty)
weights = jnp.linalg.pinv(temp) @ (inputs.T @ targets)
return weights
@@ -383,7 +391,7 @@ def call(self, targets, inputs, outputs=None) -> ArrayType:
def cond_fun(a):
i, par_old, par_new = a
return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)),
- i < self.max_iter).value
+ i < self.max_iter)
def body_fun(a):
i, par_old, par_new = a
diff --git a/brainpy/algorithms/online.py b/brainpy/algorithms/online.py
index 9477154bd..9dd54c35d 100644
--- a/brainpy/algorithms/online.py
+++ b/brainpy/algorithms/online.py
@@ -145,12 +145,21 @@ def call(
assert input.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {input.shape}'
assert target.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {target.shape}'
assert output.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {output.shape}'
- k = jnp.dot(P.value, input.T) # (num_input, num_batch)
+ # Block recursive least squares update (valid for any batch size B>=1).
+ # See e.g. Haykin, "Adaptive Filter Theory", block/multi-sample RLS.
+ # K = P Hᵀ (I_B + H P Hᵀ)⁻¹ -> Kalman gain, shape (num_input, B)
+ # P <- P - K H P -> covariance update, shape (num_input, num_input)
+ # w <- w + K (target - H w) -> here output = H w, so dw = -K (output - target)
+ # For B==1 this reduces exactly to the previous scalar update
+ # (c = 1/(1+hPh)), but unlike `jnp.sum(1/(1+HPHᵀ))` it stays correct for
+ # B>1 by inverting the full (B, B) matrix instead of summing reciprocals.
+ Pv = P.value
+ k = jnp.dot(Pv, input.T) # (num_input, num_batch)
hPh = jnp.dot(input, k) # (num_batch, num_batch)
- c = jnp.sum(1.0 / (1.0 + hPh)) # ()
- P -= c * jnp.dot(k, k.T) # (num_input, num_input)
+ gain = jnp.dot(k, jnp.linalg.inv(jnp.eye(hPh.shape[0]) + hPh)) # (num_input, num_batch)
+ P.value = Pv - jnp.dot(gain, jnp.dot(input, Pv)) # (num_input, num_input)
e = output - target # (num_batch, num_output)
- dw = -c * jnp.dot(k, e) # (num_input, num_output)
+ dw = -jnp.dot(gain, e) # (num_input, num_output)
return dw
diff --git a/brainpy/analysis/lowdim/lowdim_analyzer.py b/brainpy/analysis/lowdim/lowdim_analyzer.py
index aa199b9e8..02f8dfca9 100644
--- a/brainpy/analysis/lowdim/lowdim_analyzer.py
+++ b/brainpy/analysis/lowdim/lowdim_analyzer.py
@@ -374,7 +374,7 @@ def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_
# args: parameters, a list/tuple of vectors
candidates = candidates.value if isinstance(candidates, bm.Array) else candidates
selected_ids = np.arange(len(candidates))
- args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args)
+ args = tuple(a.value if isinstance(a, bm.Array) else a for a in args)
for a in args: assert len(a) == len(candidates)
if num_seg is None:
num_seg = len(self.resolutions[self.x_var])
@@ -950,7 +950,7 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
# args: parameters, a list/tuple of vectors
candidates = candidates.value if isinstance(candidates, bm.Array) else candidates
selected_ids = np.arange(len(candidates))
- args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args)
+ args = tuple(a.value if isinstance(a, bm.Array) else a for a in args)
for a in args: assert len(a) == len(candidates)
if self.convert_type() == C.x_by_y:
@@ -1035,9 +1035,16 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
all_ids.append(seg_ids)
for i in range(len(all_args)):
all_args[i].append(seg_args[i][ids])
- all_fps = jnp.concatenate(all_fps)
- all_ids = jnp.concatenate(all_ids)
- all_args = tuple(jnp.concatenate(args) for args in all_args)
+ if len(all_fps):
+ all_fps = jnp.concatenate(all_fps)
+ all_ids = jnp.concatenate(all_ids)
+ all_args = tuple(jnp.concatenate(args) for args in all_args)
+ else:
+ # No candidate converged to a fixed point. Return empty arrays
+ # with the correct shapes/dtypes instead of concatenating [].
+ all_fps = jnp.zeros((0, candidates.shape[1]), dtype=candidates.dtype)
+ all_ids = jnp.zeros((0,), dtype=selected_ids.dtype)
+ all_args = tuple(jnp.zeros((0,), dtype=a.dtype) for a in args)
return all_fps, all_ids, all_args
diff --git a/brainpy/analysis/utils/optimization.py b/brainpy/analysis/utils/optimization.py
index 6630d3897..6c6755922 100644
--- a/brainpy/analysis/utils/optimization.py
+++ b/brainpy/analysis/utils/optimization.py
@@ -395,7 +395,7 @@ def roots_of_1d_by_x(f, candidates, args=()):
"""
f = f_without_jaxarray_return(f)
candidates = candidates.value if isinstance(candidates, bm.Array) else candidates
- args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args)
+ args = tuple(a.value if isinstance(a, bm.Array) else a for a in args)
vals = f(candidates, *args)
signs = jnp.sign(vals)
zero_sign_idx = jnp.where(signs == 0)[0]
diff --git a/brainpy/connect/random_conn.py b/brainpy/connect/random_conn.py
index 50e658366..91640a2bd 100644
--- a/brainpy/connect/random_conn.py
+++ b/brainpy/connect/random_conn.py
@@ -84,10 +84,10 @@ def __repr__(self):
f'seed={self.seed})')
def _iii(self):
- if (not self.include_self) and (self.pre_num != self.post_num):
- raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). '
- f'But `include_self` is set to True.')
-
+ # NOTE: no guard on ``include_self=False`` for rectangular (pre_num !=
+ # post_num) shapes — ``build_coo``/``build_csr`` already drop coincident
+ # ``pre == post`` indices generically, so the previous (contradictory)
+ # ConnectorError was both wrong-worded and overly restrictive.
if self.pre_ratio < 1.:
pre_num_to_select = int(self.pre_num * self.pre_ratio)
pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False)
@@ -96,7 +96,10 @@ def _iii(self):
pre_ids = jnp.arange(self.pre_num)
post_num_total = self.post_num
- post_num_to_select = int(self.post_num * self.prob)
+ # Round instead of truncating: ``int(post_num * prob)`` floors a small
+ # expected fan-out (e.g. ``3 * 0.3 = 0.9``) to 0 connections, silently
+ # producing an empty connectivity for small post populations.
+ post_num_to_select = int(round(self.post_num * self.prob))
if self.allow_multi_conn:
selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select))
diff --git a/brainpy/delay.py b/brainpy/delay.py
index ea5e7c22f..04270253c 100644
--- a/brainpy/delay.py
+++ b/brainpy/delay.py
@@ -251,10 +251,13 @@ def __init__(
# delay data
self._init = init
+ # ``self.data`` must exist before ``_init_data`` is called, because
+ # ``_init_data`` reads ``self.data`` to decide whether to allocate a
+ # new buffer. Initialize it unconditionally to avoid an AttributeError
+ # on the ``time > 0`` (``max_length > 0``) path.
+ self.data = None
if self.max_length > 0:
self._init_data(self.max_length)
- else:
- self.data = None
# other info
if entries is not None:
diff --git a/brainpy/dnn/linear.py b/brainpy/dnn/linear.py
index a38e5d385..4411cff55 100644
--- a/brainpy/dnn/linear.py
+++ b/brainpy/dnn/linear.py
@@ -225,8 +225,13 @@ def stdp_update(
w_min: numbers.Number = None,
w_max: numbers.Number = None
):
+ if bm.isscalar(self.W):
+ raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.')
if not isinstance(self.W, bm.Variable):
- raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.')
+ # Promote the plain-array weight to a traceable Variable so the STDP
+ # in-place update works even when the comm was not built in a training
+ # mode. Only reached from ``stdp_update``; non-plastic use is unaffected.
+ self.W = bm.Variable(self.W)
if on_pre is not None:
spike = on_pre['spike']
trace = on_pre['trace']
@@ -320,8 +325,13 @@ def stdp_update(
w_min: numbers.Number = None,
w_max: numbers.Number = None
):
+ if bm.isscalar(self.weight):
+ raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.')
if not isinstance(self.weight, bm.Variable):
- raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.')
+ # Promote the plain-array weight to a traceable Variable so the STDP
+ # in-place update works even when the comm was not built in a training
+ # mode. Only reached from ``stdp_update``; non-plastic use is unaffected.
+ self.weight = bm.Variable(self.weight)
if on_pre is not None:
spike = on_pre['spike']
trace = on_pre['trace']
@@ -375,7 +385,10 @@ def stdp_update(
if isinstance(self.weight, float):
raise ValueError(f'Cannot update the weight of a constant node.')
if not isinstance(self.weight, bm.Variable):
- self.tracing_variable('weight', self.weight, self.weight.shape)
+ # Promote the plain-array weight to a traceable Variable so that the
+ # STDP in-place update below works. This branch is only reached from
+ # ``stdp_update`` (plasticity), so non-plastic use is unaffected.
+ self.weight = bm.Variable(self.weight)
if on_pre is not None:
spike = on_pre['spike']
trace = on_pre['trace']
@@ -449,7 +462,10 @@ def stdp_update(
if isinstance(self.weight, float):
raise ValueError(f'Cannot update the weight of a constant node.')
if not isinstance(self.weight, bm.Variable):
- self.tracing_variable('weight', self.weight, self.weight.shape)
+ # Promote the plain-array weight to a traceable Variable so that the
+ # STDP in-place update below works. This branch is only reached from
+ # ``stdp_update`` (plasticity), so non-plastic use is unaffected.
+ self.weight = bm.Variable(self.weight)
if on_pre is not None:
spike = on_pre['spike']
trace = on_pre['trace']
@@ -500,7 +516,10 @@ def stdp_update(
raise ValueError(
f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.')
if not isinstance(self.weight, bm.Variable):
- self.tracing_variable('weight', self.weight, self.weight.shape)
+ # Promote the plain-array weight to a traceable Variable so that the
+ # STDP in-place update below works. This branch is only reached from
+ # ``stdp_update`` (plasticity), so non-plastic use is unaffected.
+ self.weight = bm.Variable(self.weight)
if on_pre is not None: # update on presynaptic spike
spike = on_pre['spike']
trace = on_pre['trace']
diff --git a/brainpy/dnn/normalization.py b/brainpy/dnn/normalization.py
index 1b3a2e85c..00b321819 100644
--- a/brainpy/dnn/normalization.py
+++ b/brainpy/dnn/normalization.py
@@ -113,6 +113,12 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
):
+ # ``BatchNorm`` computes statistics across a batch axis, so it requires a
+ # batching/training mode. The global default mode is ``NonBatchingMode``,
+ # which would make the layer raise ``UnsupportedError`` out-of-the-box
+ # (H-51). Default to the training mode when the user does not specify one.
+ if mode is None:
+ mode = bm.training_mode
super(BatchNorm, self).__init__(name=name, mode=mode)
# check.is_subclass(self.mode, (bm.BatchingMode, bm.TrainingMode), self.name)
@@ -131,9 +137,19 @@ def __init__(
self.running_mean = bm.Variable(jnp.zeros(self.num_features))
self.running_var = bm.Variable(jnp.ones(self.num_features))
if self.affine:
- assert isinstance(self.mode, bm.TrainingMode)
- self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_features))
- self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_features))
+ bias = parameter(self.bias_initializer, self.num_features)
+ scale = parameter(self.scale_initializer, self.num_features)
+ # Make the affine parameters trainable only under a training mode;
+ # otherwise keep them as plain variables (matching the pattern used
+ # by ``brainpy.dnn.Linear``/``Conv``). This avoids the hard
+ # ``assert isinstance(self.mode, TrainingMode)`` that crashed the
+ # layer under non-training modes (H-51).
+ if isinstance(self.mode, bm.TrainingMode):
+ self.bias = bm.TrainVar(bias)
+ self.scale = bm.TrainVar(scale)
+ else:
+ self.bias = bm.Variable(bias)
+ self.scale = bm.Variable(scale)
def _check_input_dim(self, x):
raise NotImplementedError
@@ -500,9 +516,20 @@ def __init__(
assert all([isinstance(s, int) for s in normalized_shape]), 'Must be a sequence of integer.'
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
- assert isinstance(self.mode, bm.TrainingMode)
- self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape))
- self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape))
+ bias = parameter(self.bias_initializer, self.normalized_shape)
+ scale = parameter(self.scale_initializer, self.normalized_shape)
+ # ``LayerNorm`` does not depend on batch statistics, so it works under
+ # any mode (including the default ``NonBatchingMode``). Only wrap the
+ # affine parameters as trainable variables under a training mode;
+ # otherwise keep them as plain variables. This removes the hard
+ # ``assert isinstance(self.mode, TrainingMode)`` that crashed the
+ # affine layer out-of-the-box (H-51).
+ if isinstance(self.mode, bm.TrainingMode):
+ self.bias = bm.TrainVar(bias)
+ self.scale = bm.TrainVar(scale)
+ else:
+ self.bias = bm.Variable(bias)
+ self.scale = bm.Variable(scale)
def update(self, x):
if x.shape[-len(self.normalized_shape):] != self.normalized_shape:
@@ -585,16 +612,32 @@ def __init__(
self.bias_initializer = bias_initializer
self.scale_initializer = scale_initializer
if self.affine:
- assert isinstance(self.mode, bm.TrainingMode)
- self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_channels))
- self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_channels))
+ bias = parameter(self.bias_initializer, self.num_channels)
+ scale = parameter(self.scale_initializer, self.num_channels)
+ # ``GroupNorm``/``InstanceNorm`` compute statistics independently of
+ # the batch size, so they work under any mode (including the default
+ # ``NonBatchingMode``). Only make the affine parameters trainable
+ # under a training mode; otherwise keep them as plain variables. This
+ # removes the hard ``assert isinstance(self.mode, TrainingMode)`` that
+ # crashed the affine layer out-of-the-box (H-51).
+ if isinstance(self.mode, bm.TrainingMode):
+ self.bias = bm.TrainVar(bias)
+ self.scale = bm.TrainVar(scale)
+ else:
+ self.bias = bm.Variable(bias)
+ self.scale = bm.Variable(scale)
def update(self, x):
assert x.shape[-1] == self.num_channels
origin_shape, origin_dim = x.shape, x.ndim
group_shape = (-1,) + x.shape[1:-1] + (self.num_groups, self.num_channels // self.num_groups)
x = bm.as_jax(x.reshape(group_shape))
- reduction_axes = tuple(range(1, x.ndim - 1)) + (-1,)
+ # After reshape the axis layout is
+ # ``[0]=batch, [1..ndim-3]=spatial, [ndim-2]=group, [ndim-1]=channels-per-group``.
+ # Normalization must reduce over the spatial axes and the within-group
+ # channel axis, but NOT over the group axis (``ndim-2``); otherwise the
+ # groups are averaged together and ``num_groups`` has no effect (C-05).
+ reduction_axes = tuple(range(1, x.ndim - 2)) + (-1,)
mean = jnp.mean(x, reduction_axes, keepdims=True)
var = jnp.var(x, reduction_axes, keepdims=True)
x = (x - mean) * lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype))
diff --git a/brainpy/dyn/channels/calcium.py b/brainpy/dyn/channels/calcium.py
index 3dd30d74c..b2bf3ca5e 100644
--- a/brainpy/dyn/channels/calcium.py
+++ b/brainpy/dyn/channels/calcium.py
@@ -20,6 +20,8 @@
from typing import Union, Callable, Optional
+import jax.numpy as jnp
+
import brainpy.math as bm
from brainpy.context import share
from brainpy.dyn.ions.calcium import Calcium, CalciumDyna
@@ -40,6 +42,20 @@
]
+def _exprel(x):
+ """Stable ``(exp(x) - 1) / x`` with a finite value *and* finite gradient at ``x == 0``.
+
+ The HH/Markov rate functions have the removable-singularity form
+ ``num * temp / (exp(temp / k) - 1)`` which is ``0 / 0`` (NaN value and NaN
+ gradient) at the singular voltage. Rewriting them as ``num * k / exprel(...)``
+ removes the singularity in value, but ``brainpy.math.exprel`` still yields a
+ NaN gradient at 0, so we use this branch-safe helper instead.
+ """
+ small = jnp.abs(x) < 1e-7
+ safe_x = jnp.where(small, 1.0, x)
+ return jnp.where(small, 1.0 + x / 2.0, jnp.expm1(safe_x) / safe_x)
+
+
class CalciumChannel(IonChannel):
"""Base class for Calcium ion channels."""
@@ -708,7 +724,8 @@ def __init__(
def f_p_alpha(self, V):
temp = -27 - V + self.V_sh
- return 0.055 * temp / (bm.exp(temp / 3.8) - 1)
+ # 0.055 * temp / (exp(temp/3.8) - 1) == 0.055 * 3.8 / exprel(temp/3.8)
+ return (0.055 * 3.8) / _exprel(temp / 3.8)
def f_p_beta(self, V):
return 0.94 * bm.exp((-75. - V + self.V_sh) / 17.)
diff --git a/brainpy/dyn/channels/potassium.py b/brainpy/dyn/channels/potassium.py
index 738d3bc62..5f2f224e7 100644
--- a/brainpy/dyn/channels/potassium.py
+++ b/brainpy/dyn/channels/potassium.py
@@ -20,6 +20,8 @@
from typing import Union, Callable, Optional, Sequence
+import jax.numpy as jnp
+
import brainpy.math as bm
from brainpy.context import share
from brainpy.dyn.ions.potassium import Potassium
@@ -43,6 +45,20 @@
]
+def _exprel(x):
+ """Stable ``(exp(x) - 1) / x`` with a finite value *and* finite gradient at ``x == 0``.
+
+ The HH/Markov rate functions have the removable-singularity form
+ ``num * temp / (1 - exp(-temp / k))`` which is ``0 / 0`` (NaN value and NaN
+ gradient) at the singular voltage. Rewriting them as ``num * k / exprel(...)``
+ removes the singularity in value, but ``brainpy.math.exprel`` still yields a
+ NaN gradient at 0, so we use this branch-safe helper instead.
+ """
+ small = jnp.abs(x) < 1e-7
+ safe_x = jnp.where(small, 1.0, x)
+ return jnp.where(small, 1.0 + x / 2.0, jnp.expm1(safe_x) / safe_x)
+
+
class PotassiumChannel(IonChannel):
"""Base class for sodium channel dynamics."""
@@ -219,7 +235,8 @@ def __init__(
def f_p_alpha(self, V):
tmp = V - self.V_sh - 15.
- return 0.032 * tmp / (1. - bm.exp(-tmp / 5.))
+ # 0.032 * tmp / (1 - exp(-tmp/5)) == 0.032 * 5 / exprel(-tmp/5)
+ return 0.16 / _exprel(-tmp / 5.)
def f_p_beta(self, V):
return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.)
@@ -287,7 +304,8 @@ def __init__(
def f_p_alpha(self, V):
c = 15 - V + self.V_sh
- return 0.032 * c / (bm.exp(c / 5) - 1.)
+ # 0.032 * c / (exp(c/5) - 1) == 0.032 * 5 / exprel(c/5)
+ return 0.16 / _exprel(c / 5.)
def f_p_beta(self, V):
return 0.5 * bm.exp((10 - V + self.V_sh) / 40)
@@ -356,7 +374,8 @@ def __init__(
def f_p_alpha(self, V):
temp = V - self.V_sh + 10
- return 0.01 * temp / (1 - bm.exp(-temp / 10))
+ # 0.01 * temp / (1 - exp(-temp/10)) == 0.01 * 10 / exprel(-temp/10)
+ return 0.1 / _exprel(-temp / 10.)
def f_p_beta(self, V):
return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80)
@@ -1188,7 +1207,8 @@ def __init__(
def f_p_alpha(self, V):
tmp = V - self.V_sh - 15.
- return 0.032 * tmp / (1. - bm.exp(-tmp / 5.))
+ # 0.032 * tmp / (1 - exp(-tmp/5)) == 0.032 * 5 / exprel(-tmp/5)
+ return 0.16 / _exprel(-tmp / 5.)
def f_p_beta(self, V):
return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.)
@@ -1258,7 +1278,8 @@ def __init__(
def f_p_alpha(self, V):
c = 15 - V + self.V_sh
- return 0.032 * c / (bm.exp(c / 5) - 1.)
+ # 0.032 * c / (exp(c/5) - 1) == 0.032 * 5 / exprel(c/5)
+ return 0.16 / _exprel(c / 5.)
def f_p_beta(self, V):
return 0.5 * bm.exp((10 - V + self.V_sh) / 40)
@@ -1329,7 +1350,8 @@ def __init__(
def f_p_alpha(self, V):
temp = V - self.V_sh + 10
- return 0.01 * temp / (1 - bm.exp(-temp / 10))
+ # 0.01 * temp / (1 - exp(-temp/10)) == 0.01 * 10 / exprel(-temp/10)
+ return 0.1 / _exprel(-temp / 10.)
def f_p_beta(self, V):
return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80)
diff --git a/brainpy/dyn/channels/potassium_compatible.py b/brainpy/dyn/channels/potassium_compatible.py
index c2590ec18..1f707bd65 100644
--- a/brainpy/dyn/channels/potassium_compatible.py
+++ b/brainpy/dyn/channels/potassium_compatible.py
@@ -28,6 +28,19 @@
from brainpy.integrators import odeint, JointEq
from brainpy.types import ArrayType
+
+def _exprel(x):
+ """Stable ``(exp(x) - 1) / x`` with finite value *and* gradient at ``x == 0``.
+
+ The HH/Markov rate functions have the removable-singularity form
+ ``num * temp / (1 - exp(-temp / k))`` (equivalently ``.../(exp(temp/k) - 1)``),
+ which is ``0 / 0`` (NaN value and NaN gradient) at the singular voltage.
+ Rewriting them as ``num * k / exprel(...)`` removes the singularity.
+ """
+ small = bm.abs(x) < 1e-7
+ safe_x = bm.where(small, 1.0, x)
+ return bm.where(small, 1.0 + x / 2.0, bm.expm1(safe_x) / safe_x)
+
__all__ = [
'IKDR_Ba2002',
'IK_TM1991',
@@ -204,7 +217,7 @@ def __init__(
def f_p_alpha(self, V):
tmp = V - self.V_sh - 15.
- return 0.032 * tmp / (1. - bm.exp(-tmp / 5.))
+ return 0.16 / _exprel(-tmp / 5.)
def f_p_beta(self, V):
return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.)
@@ -274,7 +287,7 @@ def __init__(
def f_p_alpha(self, V):
c = 15 - V + self.V_sh
- return 0.032 * c / (bm.exp(c / 5) - 1.)
+ return 0.16 / _exprel(c / 5.)
def f_p_beta(self, V):
return 0.5 * bm.exp((10 - V + self.V_sh) / 40)
@@ -345,7 +358,7 @@ def __init__(
def f_p_alpha(self, V):
temp = V - self.V_sh + 10
- return 0.01 * temp / (1 - bm.exp(-temp / 10))
+ return 0.1 / _exprel(-temp / 10.)
def f_p_beta(self, V):
return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80)
diff --git a/brainpy/dyn/channels/sodium.py b/brainpy/dyn/channels/sodium.py
index fb0ff882f..b1f9700e7 100644
--- a/brainpy/dyn/channels/sodium.py
+++ b/brainpy/dyn/channels/sodium.py
@@ -20,6 +20,8 @@
from typing import Union, Callable
+import jax.numpy as jnp
+
import brainpy.math as bm
from brainpy.context import share
from brainpy.dyn.ions.sodium import Sodium
@@ -36,6 +38,20 @@
]
+def _exprel(x):
+ """Stable ``(exp(x) - 1) / x`` with a finite value *and* finite gradient at ``x == 0``.
+
+ The HH/Markov rate functions have the removable-singularity form
+ ``num * temp / (1 - exp(-temp / k))`` which is ``0 / 0`` (NaN value and NaN
+ gradient) at the singular voltage. Rewriting them as ``num * k / exprel(...)``
+ removes the singularity in value, but ``brainpy.math.exprel`` still yields a
+ NaN gradient at 0, so we use this branch-safe helper instead.
+ """
+ small = jnp.abs(x) < 1e-7
+ safe_x = jnp.where(small, 1.0, x)
+ return jnp.where(small, 1.0 + x / 2.0, jnp.expm1(safe_x) / safe_x)
+
+
class SodiumChannel(IonChannel):
"""Base class for sodium channel dynamics."""
@@ -212,11 +228,13 @@ def __init__(
def f_p_alpha(self, V):
temp = V - self.V_sh - 13.
- return 0.32 * temp / (1. - bm.exp(-temp / 4.))
+ # 0.32 * temp / (1 - exp(-temp/4)) == 0.32 * 4 / exprel(-temp/4)
+ return 1.28 / _exprel(-temp / 4.)
def f_p_beta(self, V):
temp = V - self.V_sh - 40.
- return -0.28 * temp / (1. - bm.exp(temp / 5.))
+ # -0.28 * temp / (1 - exp(temp/5)) == 0.28 * 5 / exprel(temp/5)
+ return 1.4 / _exprel(temp / 5.)
def f_q_alpha(self, V):
return 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.)
@@ -296,11 +314,13 @@ def __init__(
def f_p_alpha(self, V):
temp = 13 - V + self.V_sh
- return 0.32 * temp / (bm.exp(temp / 4) - 1.)
+ # 0.32 * temp / (exp(temp/4) - 1) == 0.32 * 4 / exprel(temp/4)
+ return 1.28 / _exprel(temp / 4.)
def f_p_beta(self, V):
temp = V - self.V_sh - 40
- return 0.28 * temp / (bm.exp(temp / 5) - 1)
+ # 0.28 * temp / (exp(temp/5) - 1) == 0.28 * 5 / exprel(temp/5)
+ return 1.4 / _exprel(temp / 5.)
def f_q_alpha(self, V):
return 0.128 * bm.exp((17 - V + self.V_sh) / 18)
@@ -381,7 +401,8 @@ def __init__(
def f_p_alpha(self, V):
temp = V - self.V_sh - 5
- return 0.1 * temp / (1 - bm.exp(-temp / 10))
+ # 0.1 * temp / (1 - exp(-temp/10)) == 0.1 * 10 / exprel(-temp/10)
+ return 1.0 / _exprel(-temp / 10.)
def f_p_beta(self, V):
return 4.0 * bm.exp(-(V - self.V_sh + 20) / 18)
diff --git a/brainpy/dyn/channels/sodium_compatible.py b/brainpy/dyn/channels/sodium_compatible.py
index 680151acd..f18000213 100644
--- a/brainpy/dyn/channels/sodium_compatible.py
+++ b/brainpy/dyn/channels/sodium_compatible.py
@@ -28,6 +28,19 @@
from brainpy.types import ArrayType
from .base import IonChannel
+
+def _exprel(x):
+ """Stable ``(exp(x) - 1) / x`` with finite value *and* gradient at ``x == 0``.
+
+ The HH/Markov rate functions have the removable-singularity form
+ ``num * temp / (1 - exp(-temp / k))`` (equivalently ``.../(exp(temp/k) - 1)``),
+ which is ``0 / 0`` (NaN value and NaN gradient) at the singular voltage.
+ Rewriting them as ``num * k / exprel(...)`` removes the singularity.
+ """
+ small = bm.abs(x) < 1e-7
+ safe_x = bm.where(small, 1.0, x)
+ return bm.where(small, 1.0 + x / 2.0, bm.expm1(safe_x) / safe_x)
+
__all__ = [
'INa_Ba2002',
'INa_TM1991',
@@ -198,11 +211,11 @@ def __init__(
def f_p_alpha(self, V):
temp = V - self.V_sh - 13.
- return 0.32 * temp / (1. - bm.exp(-temp / 4.))
+ return 1.28 / _exprel(-temp / 4.)
def f_p_beta(self, V):
temp = V - self.V_sh - 40.
- return -0.28 * temp / (1. - bm.exp(temp / 5.))
+ return 1.4 / _exprel(temp / 5.)
def f_q_alpha(self, V):
return 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.)
@@ -284,11 +297,11 @@ def __init__(
def f_p_alpha(self, V):
temp = 13 - V + self.V_sh
- return 0.32 * temp / (bm.exp(temp / 4) - 1.)
+ return 1.28 / _exprel(temp / 4.)
def f_p_beta(self, V):
temp = V - self.V_sh - 40
- return 0.28 * temp / (bm.exp(temp / 5) - 1)
+ return 1.4 / _exprel(temp / 5.)
def f_q_alpha(self, V):
return 0.128 * bm.exp((17 - V + self.V_sh) / 18)
@@ -371,7 +384,7 @@ def __init__(
def f_p_alpha(self, V):
temp = V - self.V_sh - 5
- return 0.1 * temp / (1 - bm.exp(-temp / 10))
+ return 1.0 / _exprel(-temp / 10.)
def f_p_beta(self, V):
return 4.0 * bm.exp(-(V - self.V_sh + 20) / 18)
diff --git a/brainpy/dyn/ions/base.py b/brainpy/dyn/ions/base.py
index 3a278c217..64526003c 100644
--- a/brainpy/dyn/ions/base.py
+++ b/brainpy/dyn/ions/base.py
@@ -52,7 +52,7 @@ def __init__(self, *ions, **channels):
self.ions: Sequence['Ion'] = tuple(ions)
self._ion_classes = tuple([type(ion) for ion in self.ions])
for k, v in channels.items():
- self.add_elem(k=v)
+ self.add_elem(**{k: v})
def update(self, V):
nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values())
diff --git a/brainpy/dyn/ions/potassium.py b/brainpy/dyn/ions/potassium.py
index 9a664531c..8d7ef7905 100644
--- a/brainpy/dyn/ions/potassium.py
+++ b/brainpy/dyn/ions/potassium.py
@@ -42,7 +42,7 @@ def __init__(
self,
size: Shape,
keep_size: bool = False,
- E: Union[float, ArrayType, Initializer, Callable] = -950.,
+ E: Union[float, ArrayType, Initializer, Callable] = -95.,
C: Union[float, ArrayType, Initializer, Callable] = 0.0400811,
method: str = 'exp_auto',
name: Optional[str] = None,
diff --git a/brainpy/dyn/neurons/lif.py b/brainpy/dyn/neurons/lif.py
index 28d693c3d..d4fe42936 100644
--- a/brainpy/dyn/neurons/lif.py
+++ b/brainpy/dyn/neurons/lif.py
@@ -1106,7 +1106,14 @@ def __init__(
self._V_initializer = is_initializer(V_initializer)
# integral
- self.integral = odeint(method=method, f=self.derivative)
+ # NOTE: ``self.noise`` is already created by ``ExpIFLTC.__init__`` above
+ # (via ``noise=noise``). Guard the integral on it so a configured
+ # ``noise=`` is honoured with ``sdeint`` instead of being silently
+ # dropped (mirrors every other ``*RefLTC`` and the non-Ref class).
+ if self.noise is not None:
+ self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
+ else:
+ self.integral = odeint(method=method, f=self.derivative)
# variables
if init_var:
@@ -3811,8 +3818,11 @@ def update(self, x=None):
V += (self.V_reset - V) * spike_no_grad
else:
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
- I1 += spike * (self.R1 * I1 + self.A1 - I1)
- I2 += spike * (self.R2 * I2 + self.A2 - I2)
+ # Use ``spike_no_grad`` for every state reset so that ``detach_spk``
+ # actually stops the gradient through the spike; the raw ``spike``
+ # here would otherwise leak it into the I1/I2 resets.
+ I1 += spike_no_grad * (self.R1 * I1 + self.A1 - I1)
+ I2 += spike_no_grad * (self.R2 * I2 + self.A2 - I2)
V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike_no_grad
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
@@ -4492,8 +4502,11 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike) if self.detach_spk else spike
- V += spike * (self.c - V)
- u += spike * self.d
+ # Use ``spike_no_grad`` for the state resets so that ``detach_spk``
+ # actually stops the gradient through the spike (raw ``spike`` here
+ # would otherwise make ``detach_spk`` a no-op for the V/u resets).
+ V += spike_no_grad * (self.c - V)
+ u += spike_no_grad * self.d
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
if self.ref_var:
diff --git a/brainpy/dyn/projections/align_post.py b/brainpy/dyn/projections/align_post.py
index af192cc68..2f0248d3b 100644
--- a/brainpy/dyn/projections/align_post.py
+++ b/brainpy/dyn/projections/align_post.py
@@ -383,7 +383,7 @@ def __init__(
def update(self, x):
current = self.comm(x)
- g = self.syn(self.comm(x))
+ g = self.syn(current)
self.refs['out'].bind_cond(g) # synapse post current
return current
diff --git a/brainpy/dyn/projections/base.py b/brainpy/dyn/projections/base.py
index 9bdca17d2..651a2d6ed 100644
--- a/brainpy/dyn/projections/base.py
+++ b/brainpy/dyn/projections/base.py
@@ -12,14 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-from brainpy import math as bm
-from brainpy.mixin import ReturnInfo
+# Re-export the real projection base classes. This module previously held a
+# byte-for-byte duplicate of the private ``_get_return`` helper in ``utils.py``
+# (H-40); that helper is private, so ``from .base import *`` exported nothing and
+# the duplicate was both dead and misleading versus the real base classes.
+from brainpy.dynsys import Projection
+from .conn import SynConn
-
-def _get_return(return_info):
- if isinstance(return_info, bm.Variable):
- return return_info.value
- elif isinstance(return_info, ReturnInfo):
- return return_info.get_data()
- else:
- raise NotImplementedError
+__all__ = ['Projection', 'SynConn']
diff --git a/brainpy/dyn/projections/inputs.py b/brainpy/dyn/projections/inputs.py
index 5c78ce210..39197c4fb 100644
--- a/brainpy/dyn/projections/inputs.py
+++ b/brainpy/dyn/projections/inputs.py
@@ -162,16 +162,19 @@ def update(self):
p = self.freq * share['dt'] / 1e3
a = self.num_input * p
b = self.num_input * (1 - p)
+ # standard deviation of the Binomial(num_input, p) distribution:
+ # sqrt(num_input * p * (1 - p)) = sqrt(b * p), NOT the variance b * p.
+ scale = bm.sqrt(b * p)
if isinstance(share['dt'], numbers.Number): # dt is not traced
if (a > 5) and (b > 5):
- inp = bm.random.normal(a, b * p, self.target_var.shape)
+ inp = bm.random.normal(a, scale, self.target_var.shape)
else:
inp = bm.random.binomial(self.num_input, p, self.target_var.shape)
else: # dt is traced
inp = bm.cond((a > 5) * (b > 5),
- lambda: bm.random.normal(a, b * p, self.target_var.shape),
+ lambda: bm.random.normal(a, scale, self.target_var.shape),
lambda: bm.random.binomial(self.num_input, p, self.target_var.shape))
# inp = bm.sharding.partition(inp, self.target_var.sharding)
diff --git a/brainpy/dyn/projections/plasticity.py b/brainpy/dyn/projections/plasticity.py
index c6ccdd823..03fa3d509 100644
--- a/brainpy/dyn/projections/plasticity.py
+++ b/brainpy/dyn/projections/plasticity.py
@@ -225,18 +225,23 @@ def update(self):
raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.')
post_spike = self.refs['post'].spike.value
+ # weight bounds: pass ``None`` through unchanged (``bm.as_jax(None)`` raises);
+ # ``jnp.clip`` accepts ``None`` for an unbounded side.
+ w_min = None if self.W_min is None else bm.as_jax(self.W_min)
+ w_max = None if self.W_max is None else bm.as_jax(self.W_max)
+
# weight updates
Apost = self.refs['post_trace'].g.value
self.comm.stdp_update(
on_pre={"spike": bm.as_jax(pre_spike), "trace": bm.as_jax(-Apost * self.A2)},
- w_min=bm.as_jax(self.W_min),
- w_max=bm.as_jax(self.W_max),
+ w_min=w_min,
+ w_max=w_max,
)
Apre = self.refs['pre_trace'].g.value
self.comm.stdp_update(
on_post={"spike": bm.as_jax(post_spike), "trace": bm.as_jax(Apre * self.A1)},
- w_min=bm.as_jax(self.W_min),
- w_max=bm.as_jax(self.W_max),
+ w_min=w_min,
+ w_max=w_max,
)
# synaptic currents
diff --git a/brainpy/dyn/rates/populations.py b/brainpy/dyn/rates/populations.py
index c774c5fd0..8ae90fdc8 100644
--- a/brainpy/dyn/rates/populations.py
+++ b/brainpy/dyn/rates/populations.py
@@ -718,7 +718,7 @@ def dx(self, x, t, y, x_ext, a, w):
return (a - x * x - y * y) * x - w * y + x_ext
def dy(self, y, t, x, y_ext, a, w):
- return (a - x * x - y * y) * y - w * y + y_ext
+ return (a - x * x - y * y) * y + w * x + y_ext
def update(self, inp_x=None, inp_y=None):
t = share.load('t')
@@ -1048,7 +1048,7 @@ def update(self, inp_e=None, inp_i=None):
has_noise = bm.any(self.noise_e != 0.)
if has_noise:
- de += bm.random.randn(self.varshape) * self.noise_e
+ de += bm.random.randn(*self.varshape) * self.noise_e
de = de / self.tau_e
self.e.value = bm.maximum(self.e + de * dt, 0.)
@@ -1057,7 +1057,7 @@ def update(self, inp_e=None, inp_i=None):
has_noise = bm.any(self.noise_i != 0.)
if has_noise:
- di += bm.random.randn(self.varshape) * self.noise_i
+ di += bm.random.randn(*self.varshape) * self.noise_i
di = di / self.tau_i
self.i.value = bm.maximum(self.i + di * dt, 0.)
return self.e.value
diff --git a/brainpy/dyn/rates/reservoir.py b/brainpy/dyn/rates/reservoir.py
index 662cffff3..57912b0b8 100644
--- a/brainpy/dyn/rates/reservoir.py
+++ b/brainpy/dyn/rates/reservoir.py
@@ -220,10 +220,12 @@ def update(self, x):
hidden += bm.sparse.seg_matmul(self.state, sparse)
else:
hidden += self.state @ self.Wrec
+ if self.bias is not None:
+ hidden += self.bias
if self.activation_type == 'internal':
hidden = self.activation(hidden)
if self.noise_rec > 0.:
- hidden += self.noise_rec * bm.random.uniform(-1, -1, self.state.shape)
+ hidden += self.noise_rec * bm.random.uniform(-1, 1, self.state.shape)
# new state/output
state = (1 - self.leaky_rate) * self.state + self.leaky_rate * hidden
if self.activation_type == 'external':
diff --git a/brainpy/dyn/rates/rnncells.py b/brainpy/dyn/rates/rnncells.py
index ef4107319..eae8b6fc4 100644
--- a/brainpy/dyn/rates/rnncells.py
+++ b/brainpy/dyn/rates/rnncells.py
@@ -398,7 +398,7 @@ def h(self):
def h(self, value):
if self.state is None:
raise ValueError('Cannot set "h" state. Because the state is not initialized.')
- self.state[:self.state.shape[0] // 2, :] = value
+ self.state[..., :self.state.shape[-1] // 2] = value
@property
def c(self):
@@ -409,7 +409,7 @@ def c(self):
def c(self, value):
if self.state is None:
raise ValueError('Cannot set "c" state. Because the state is not initialized.')
- self.state[self.state.shape[0] // 2:, :] = value
+ self.state[..., self.state.shape[-1] // 2:] = value
class _ConvNDLSTMCell(Layer):
diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py
index 80f8ca849..96d666d18 100644
--- a/brainpy/dyn/synapses/abstract_models.py
+++ b/brainpy/dyn/synapses/abstract_models.py
@@ -859,7 +859,7 @@ def reset_state(self, batch_or_mode=None, **kwargs):
@property
def derivative(self):
- du = lambda u, t: self.U - u / self.tau_f
+ du = lambda u, t: -u / self.tau_f
dx = lambda x, t: (1 - x) / self.tau_d
return JointEq(du, dx)
@@ -877,8 +877,11 @@ def update(self, pre_spike):
# x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x
# --- simplified code:
- u = pre_spike * self.U * (1 - self.u) + u
- x = pre_spike * -u * self.x + x
+ # Apply the discrete spike jumps to the just-integrated (decayed) locals
+ # ``u``/``x`` rather than the pre-decay ``self.u``/``self.x``. The ``x``
+ # jump uses the already-updated ``u``.
+ u = u + pre_spike * self.U * (1 - u)
+ x = x - pre_spike * u * x
self.x.value = x
self.u.value = u
diff --git a/brainpy/dynold/experimental/others.py b/brainpy/dynold/experimental/others.py
index 1759e0a13..461a49308 100644
--- a/brainpy/dynold/experimental/others.py
+++ b/brainpy/dynold/experimental/others.py
@@ -70,15 +70,18 @@ def update(self):
p = self.freq * share.dt / 1e3
a = self.num_input * p
b = self.num_input * (1 - p)
+ # standard deviation of the Binomial(num_input, p) distribution:
+ # sqrt(num_input * p * (1 - p)) = sqrt(b * p), NOT the variance b * p.
+ scale = bm.sqrt(b * p)
if isinstance(share.dt, (int, float)): # dt is not in tracing
if (a > 5) and (b > 5):
- inp = bm.random.normal(a, b * p, self.target_shape)
+ inp = bm.random.normal(a, scale, self.target_shape)
else:
inp = bm.random.binomial(self.num_input, p, self.target_shape)
else: # dt is in tracing
inp = bm.cond((a > 5) * (b > 5),
- lambda _: bm.random.normal(a, b * p, self.target_shape),
+ lambda _: bm.random.normal(a, scale, self.target_shape),
lambda _: bm.random.binomial(self.num_input, p, self.target_shape),
None)
return inp * self.weight
diff --git a/brainpy/dynold/neurons/reduced_models.py b/brainpy/dynold/neurons/reduced_models.py
index 84a51f952..ad75f94f2 100644
--- a/brainpy/dynold/neurons/reduced_models.py
+++ b/brainpy/dynold/neurons/reduced_models.py
@@ -297,12 +297,12 @@ class ExpIF(lif.ExpIFRef):
------------- -------------- -------- ---------------------------------------------------
V_rest -65 mV Resting potential.
V_reset -68 mV Reset potential after spike.
- V_th -30 mV Threshold potential of spike.
+ V_th -55 mV Threshold potential of spike.
V_T -59.9 mV Threshold potential of generating action potential.
delta_T 3.48 \ Spike slope factor.
R 1 \ Membrane resistance.
tau 10 \ Membrane time constant. Compute by R * C.
- tau_ref 1.7 \ Refractory period length.
+ tau_ref 0. \ Refractory period length.
============= ============== ======== ===================================================
**Model Variables**
@@ -412,7 +412,7 @@ class AdExIF(lif.AdExIFRef):
------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------
V_rest -65 mV Resting potential.
V_reset -68 mV Reset potential after spike.
- V_th -30 mV Threshold potential of spike and reset.
+ V_th -55 mV Threshold potential of spike and reset.
V_T -59.9 mV Threshold potential of generating action potential.
delta_T 3.48 \ Spike slope factor.
a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v`
diff --git a/brainpy/dynold/synapses/compat.py b/brainpy/dynold/synapses/compat.py
index b27c6348a..008df156d 100644
--- a/brainpy/dynold/synapses/compat.py
+++ b/brainpy/dynold/synapses/compat.py
@@ -21,7 +21,7 @@
from brainpy.dynold.synouts import COBA, CUBA
from brainpy.initialize import Initializer
from brainpy.types import ArrayType
-from .abstract_models import Delta, Exponential, DualExponential
+from .abstract_models import Delta, Exponential, DualExponential, Alpha
__all__ = [
'DeltaSynapse',
@@ -205,7 +205,7 @@ def __init__(
output=COBA(E=E))
-class AlphaCUBA(DualExpCUBA):
+class AlphaCUBA(Alpha):
r"""Current-based alpha synapse model.
.. deprecated:: 2.1.13
@@ -225,19 +225,23 @@ def __init__(
method: str = 'exp_auto',
name: str = None
):
+ # Alpha synapses have a single time constant; route through the
+ # single-tau ``Alpha`` implementation instead of a dual-exponential
+ # with ``tau_rise == tau_decay`` (which divides by zero in the peak
+ # normalizer ``A = tau_decay / (tau_decay - tau_rise)``).
super().__init__(pre=pre,
post=post,
conn=conn,
- conn_type=conn_type,
+ comp_method=conn_type,
delay_step=delay_step,
g_max=g_max,
tau_decay=tau_decay,
- tau_rise=tau_decay,
method=method,
- name=name)
+ name=name,
+ output=CUBA())
-class AlphaCOBA(DualExpCOBA):
+class AlphaCOBA(Alpha):
"""Conductance-based alpha synapse model.
.. deprecated:: 2.1.13
@@ -258,13 +262,17 @@ def __init__(
method: str = 'exp_auto',
name: str = None
):
+ # Alpha synapses have a single time constant; route through the
+ # single-tau ``Alpha`` implementation instead of a dual-exponential
+ # with ``tau_rise == tau_decay`` (which divides by zero in the peak
+ # normalizer ``A = tau_decay / (tau_decay - tau_rise)``).
super().__init__(pre=pre,
post=post,
conn=conn,
- conn_type=conn_type,
+ comp_method=conn_type,
delay_step=delay_step,
- g_max=g_max, E=E,
+ g_max=g_max,
tau_decay=tau_decay,
- tau_rise=tau_decay,
method=method,
- name=name)
+ name=name,
+ output=COBA(E=E))
diff --git a/brainpy/dynold/synapses/learning_rules.py b/brainpy/dynold/synapses/learning_rules.py
index b413c7bb3..c71ea0d89 100644
--- a/brainpy/dynold/synapses/learning_rules.py
+++ b/brainpy/dynold/synapses/learning_rules.py
@@ -36,6 +36,16 @@ def __init__(self, size, keep_size, tau, U, tau_f, tau_d, mode=None, method='exp
exp = synapses.Expon(size, keep_size, tau=tau, method=method, mode=mode)
super().__init__(stp, exp)
+ def update(self, pre_spike):
+ # ``synapses.STP.update`` returns the synaptic resource fraction ``u*x``,
+ # which is non-zero even at rest (≈U). Feeding it directly into ``Expon``
+ # (which treats its input as an additive current event) would inject
+ # ``u*x`` on *every* step, so the current keeps growing with zero
+ # presynaptic spikes. Gate the injected amplitude by ``pre_spike`` so the
+ # synaptic current only jumps when a presynaptic spike actually arrives.
+ ux = self[0](pre_spike)
+ return self[1](pre_spike * ux)
+
class STP(_TwoEndConnAlignPre):
r"""Short-term plasticity model.
diff --git a/brainpy/encoding/stateless_encoding.py b/brainpy/encoding/stateless_encoding.py
index 5388e7bb5..d358b2c52 100644
--- a/brainpy/encoding/stateless_encoding.py
+++ b/brainpy/encoding/stateless_encoding.py
@@ -88,10 +88,22 @@ def single_step(self, x, i_step: int = None):
Returns:
out: Array. The encoded spike train.
"""
+ # Draw a single Bernoulli sample for one step. (Delegating to
+ # ``multi_steps`` with ``n_time=None`` would crash on ``int(None / dt)``,
+ # and the old ``cond(..., self.multi_steps, x)`` passed the wrong number
+ # of arguments to ``multi_steps``.)
+ x = self._normalize(x)
+ spikes = bm.asarray(bm.random.rand(*x.shape) < x, dtype=x.dtype)
if i_step is None:
- return self.multi_steps(x, n_time=None)
- else:
- return bm.cond(bm.as_jax(i_step < self.first_spk_step), self._zero_out, self.multi_steps, x)
+ return spikes
+ # Before the first-spike step, emit no spikes.
+ before_first = bm.as_jax(i_step) < self.first_spk_step
+ return bm.asarray(bm.where(before_first, bm.zeros_like(spikes), spikes), dtype=x.dtype)
+
+ def _normalize(self, x):
+ if (self.min_val is not None) and (self.max_val is not None):
+ x = (x - self.min_val) / (self.max_val - self.min_val)
+ return x * self.gain + self.offset
def multi_steps(self, x, n_time: Optional[float]):
"""Generate spikes at multiple steps according to the inputs.
@@ -108,11 +120,11 @@ def multi_steps(self, x, n_time: Optional[float]):
Returns:
out: Array. The encoded spike train.
"""
- n_time = int(n_time / bm.get_dt())
+ # ``n_time=None`` means "encode the current single step" (see docstring);
+ # only convert to a step count when an actual duration is given.
+ n_time = None if n_time is None else int(n_time / bm.get_dt())
- if (self.min_val is not None) and (self.max_val is not None):
- x = (x - self.min_val) / (self.max_val - self.min_val)
- x = x * self.gain + self.offset
+ x = self._normalize(x)
if n_time is not None and self.first_spk_step > 0:
pre = bm.zeros((self.first_spk_step,) + x.shape, dtype=x.dtype)
shape = ((n_time - self.first_spk_step,) + x.shape)
diff --git a/brainpy/integrators/fde/Caputo.py b/brainpy/integrators/fde/Caputo.py
index 0f48fc586..23afc6000 100644
--- a/brainpy/integrators/fde/Caputo.py
+++ b/brainpy/integrators/fde/Caputo.py
@@ -147,7 +147,8 @@ def __init__(
f'but we got {self.alpha}.')
# initial values
- self.inits = check_inits(inits, self.variables)
+ inits = check_inits(inits, self.variables)
+ self.inits = bm.VarDict({v: bm.Variable(inits[v]) for v in self.variables})
# coefficients
rgamma_alpha = bm.asarray(rgamma(bm.as_numpy(self.alpha)))
@@ -163,6 +164,14 @@ def __init__(
self.set_integral(self._integral_func)
+ def reset(self, inits):
+ """Reset the integrator states so it can be re-run from new initial values."""
+ self.idx.value = bm.asarray([1])
+ inits = check_inits(inits, self.variables)
+ for key, val in inits.items():
+ self.inits[key] = val
+ self.f_states[key] = bm.zeros((self.num_memory,) + val.shape, dtype=self.f_states[key].dtype)
+
def _check_step(self, args):
dt, t = args
raise ValueError(f'The maximum number of step is {self.num_memory}, '
@@ -198,8 +207,8 @@ def _integral_func(self, *args, **kwargs):
integrals = []
idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory
for i, key in enumerate(self.variables):
- integral = self.inits[key] + self.coef[idx, i] @ self.f_states[key]
- integrals.append(integral * (dt ** self.alpha[i] / self.alpha[i]))
+ integral = self.coef[idx, i] @ self.f_states[key]
+ integrals.append(self.inits[key] + integral * (dt ** self.alpha[i] / self.alpha[i]))
self.idx.value = (self.idx + 1) % self.num_memory
# return integrals
@@ -372,7 +381,7 @@ def hists(self, var=None, numpy=True):
for k in self.variables}
hists_ = {k: bm.cumsum(v, axis=0) for k, v in hists_.items()}
if numpy:
- hists_ = {k: v.numpy() for k, v in hists_}
+ hists_ = {k: v.numpy() for k, v in hists_.items()}
return hists_
else:
assert var in self.variables, (f'"{var}" is not defined in equation '
diff --git a/brainpy/integrators/fde/GL.py b/brainpy/integrators/fde/GL.py
index 29f572e66..76c3fc555 100644
--- a/brainpy/integrators/fde/GL.py
+++ b/brainpy/integrators/fde/GL.py
@@ -184,7 +184,7 @@ def reset(self, inits):
for key, val in inits.items():
delay = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype)
delay[0] = val
- self.delays[key].value = delay
+ self.delays[key + '_delay'].value = delay
@property
def binomial_coef(self):
diff --git a/brainpy/integrators/fde/generic.py b/brainpy/integrators/fde/generic.py
index 5a349f82a..be2f6890c 100644
--- a/brainpy/integrators/fde/generic.py
+++ b/brainpy/integrators/fde/generic.py
@@ -24,7 +24,7 @@
name2method = {}
-_DEFAULT_DDE_METHOD = 'l1'
+_DEFAULT_FDE_METHOD = 'l1'
def fdeint(
@@ -60,7 +60,7 @@ def fdeint(
integral : FDEIntegrator
The numerical solver of `f`.
"""
- method = _DEFAULT_DDE_METHOD if method is None else method
+ method = _DEFAULT_FDE_METHOD if method is None else method
if method not in name2method:
raise ValueError(f'Unknown FDE numerical method "{method}". Currently '
f'BrainPy supports: {list(name2method.keys())}')
@@ -72,7 +72,7 @@ def fdeint(
def set_default_fdeint(method):
- """Set the default ODE numerical integrator method for differential equations.
+ """Set the default FDE numerical integrator method for fractional differential equations.
Parameters::
@@ -82,25 +82,25 @@ def set_default_fdeint(method):
if not isinstance(method, str):
raise ValueError(f'Only support string, not {type(method)}.')
if method not in name2method:
- raise ValueError(f'Unsupported ODE_INT numerical method: {method}.')
+ raise ValueError(f'Unsupported FDE numerical method: {method}.')
- global _DEFAULT_DDE_METHOD
- _DEFAULT_ODE_METHOD = method
+ global _DEFAULT_FDE_METHOD
+ _DEFAULT_FDE_METHOD = method
def get_default_fdeint():
- """Get the default ODE numerical integrator method.
+ """Get the default FDE numerical integrator method.
Returns::
method : str
The default numerical integrator method.
"""
- return _DEFAULT_DDE_METHOD
+ return _DEFAULT_FDE_METHOD
def register_fde_integrator(name, integrator):
- """Register a new ODE integrator.
+ """Register a new FDE integrator.
Parameters::
@@ -117,5 +117,5 @@ def register_fde_integrator(name, integrator):
def get_supported_methods():
- """Get all supported numerical methods for DDEs."""
+ """Get all supported numerical methods for FDEs."""
return list(name2method.keys())
diff --git a/brainpy/integrators/joint_eq.py b/brainpy/integrators/joint_eq.py
index 55c1a3af5..a4954963d 100644
--- a/brainpy/integrators/joint_eq.py
+++ b/brainpy/integrators/joint_eq.py
@@ -186,7 +186,12 @@ def __init__(self, *eqs):
elif (key not in vars_in_eqs) and (key not in all_arg_pars):
all_kwarg_pars[key] = value
else:
- raise DiffEqError
+ raise DiffEqError(
+ f'The keyword argument "{key}" conflicts with an existing name '
+ f'in the joint equations: it is already used as a state variable '
+ f'or a positional parameter. A keyword argument cannot reuse the '
+ f'name of a state variable or positional parameter.'
+ )
# # variable names provided
# if not isinstance(variables, (tuple, list)):
diff --git a/brainpy/integrators/ode/adaptive_rk.py b/brainpy/integrators/ode/adaptive_rk.py
index ee6864f3d..85826c2ab 100644
--- a/brainpy/integrators/ode/adaptive_rk.py
+++ b/brainpy/integrators/ode/adaptive_rk.py
@@ -184,7 +184,7 @@ def __init__(self,
keywords['error'] = 'the local truncation error'
for v in self.variables:
keywords[f'{v}_te'] = 'the local truncation error'
- self.code_scope['tol'] = tol
+ self.code_scope['tol'] = self.tol
self.code_scope['math'] = jnp
utils.check_kws(self.arg_names, keywords)
@@ -212,13 +212,20 @@ def build(self):
result.append(f'd{v}_k{i + 1} * {C.DT} * {diff}')
if len(result) > 0:
if self.var_type == C.SCALAR_VAR:
- self.code_lines.append(f' {v}_te = abs({" + ".join(result)})')
+ self.code_lines.append(f' {v}_te = math.abs({" + ".join(result)})')
else:
- self.code_lines.append(f' {v}_te = sum(abs({" + ".join(result)}))')
+ self.code_lines.append(f' {v}_te = math.sum(math.abs({" + ".join(result)}))')
errors_.append(f'{v}_te')
if len(errors_) > 0:
self.code_lines.append(f' error = {" + ".join(errors_)}')
- self.code_lines.append(f' {C.DT}_new = math.where(error > tol, 0.9*{C.DT}*(tol/error)**0.2, {C.DT})')
+ # Two-sided step-size controller: shrink dt when the error is
+ # above tolerance, grow it when the error is comfortably below.
+ # The growth/shrink factor is clamped to keep the step change
+ # bounded so that dt can both decrease and increase.
+ self.code_lines.append(
+ f' factor = 0.9 * (tol / (error + 1e-12)) ** 0.2')
+ self.code_lines.append(' factor = math.clip(factor, 0.2, 5.0)')
+ self.code_lines.append(f' {C.DT}_new = {C.DT} * factor')
return_args.append(f'{C.DT}_new')
# returns
self.code_lines.append(f' return {", ".join(return_args)}')
@@ -529,7 +536,7 @@ class BoSh3(AdaptiveRKIntegrator):
(0.0, 0.75),
('2/9', '1/3', '4/9')]
B1 = ['2/9', '1/3', '4/9', 0.0]
- B2 = ['-5/72', 1 / 12, '1/9', '-1/8']
+ B2 = ['7/24', 0.25, '1/3', 0.125]
C = [0., 0.5, 0.75, 1.0]
diff --git a/brainpy/integrators/ode/exponential.py b/brainpy/integrators/ode/exponential.py
index 4a20eb911..6d4ff9890 100644
--- a/brainpy/integrators/ode/exponential.py
+++ b/brainpy/integrators/ode/exponential.py
@@ -370,16 +370,17 @@ def _build_integrator(self, eq):
# integration function
def integral(*args, **kwargs):
assert len(args) > 0
- if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
+ x0 = bm.as_jax(args[0])
+ if x0.dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
raise ValueError(
'The input data type should be float32, float64, float16, '
'or bfloat16 when using Exponential Euler method.'
- f'But we got {args[0].dtype}.'
+ f'But we got {x0.dtype}.'
)
dt = kwargs.pop(C.DT, self.dt)
linear, derivative = bm.vector_grad(eq, argnums=0, return_value=True)(*args, **kwargs)
phi = bm.exprel(dt * linear)
- return args[0] + dt * phi * derivative
+ return x0 + dt * phi * derivative
return [(integral, vars, pars), ]
diff --git a/brainpy/integrators/runner.py b/brainpy/integrators/runner.py
index 54c78a9f2..b429e2647 100644
--- a/brainpy/integrators/runner.py
+++ b/brainpy/integrators/runner.py
@@ -251,8 +251,8 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i):
if len(self.target.variables) == 1:
self.variables[self.target.variables[0]].update(update_values)
else:
- for i, v in enumerate(self.target.variables):
- self.variables[v].update(update_values[i])
+ for j, v in enumerate(self.target.variables):
+ self.variables[v].update(update_values[j])
# progress bar
if self.progress_bar:
diff --git a/brainpy/integrators/sde/base.py b/brainpy/integrators/sde/base.py
index a43de2f72..c68437e57 100644
--- a/brainpy/integrators/sde/base.py
+++ b/brainpy/integrators/sde/base.py
@@ -17,6 +17,7 @@
import jax.numpy as jnp
+from brainpy import _errors as errors
from brainpy import math as bm
from brainpy.integrators import constants, utils
from brainpy.integrators.base import Integrator
@@ -95,5 +96,14 @@ def __init__(
self.show_code = show_code
def _check_vector_wiener_dim(self, noise_size, var_size):
- if noise_size[:-1] > var_size[-len(noise_size) + 1:]:
- raise ValueError(f"Incompatible shapes for shapes of noise {noise_size} and variable {var_size}")
+ noise_size = tuple(noise_size)
+ var_size = tuple(var_size)
+ # For a vector Wiener process the diffusion value has shape
+ # ``var_size + (m,)``: its leading dimensions (all but the last, which
+ # holds the ``m`` noise channels) must match the variable shape exactly.
+ if tuple(noise_size[:-1]) != var_size:
+ raise ValueError(
+ f"Incompatible shapes for vector Wiener process: the diffusion "
+ f"value has shape {noise_size}, so its leading dimensions "
+ f"{noise_size[:-1]} must equal the variable shape {var_size}."
+ )
diff --git a/brainpy/integrators/sde/normal.py b/brainpy/integrators/sde/normal.py
index 0080d0e63..3165b81ac 100644
--- a/brainpy/integrators/sde/normal.py
+++ b/brainpy/integrators/sde/normal.py
@@ -17,6 +17,7 @@
import jax.numpy as jnp
+from brainpy import _errors as errors
from brainpy import math as bm
from brainpy.integrators import constants, utils, joint_eq
from brainpy.integrators.constants import DT
diff --git a/brainpy/integrators/sde/srk_strong.py b/brainpy/integrators/sde/srk_strong.py
deleted file mode 100644
index 98faefdd3..000000000
--- a/brainpy/integrators/sde/srk_strong.py
+++ /dev/null
@@ -1,473 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from brainpy import math
-from brainpy.integrators import constants, utils
-
-__all__ = [
- 'srk1_strong',
-]
-
-_SDE_UNKNOWN_NO = 0
-
-
-def basic_info(f, g):
- vdt = 'dt'
- if f.__name__.isidentifier():
- func_name = f.__name__
- elif g.__name__.isidentifier():
- func_name = g.__name__
- else:
- global _SDE_UNKNOWN_NO
- func_name = f'unknown_sde{_SDE_UNKNOWN_NO}'
- func_new_name = constants.SDE_INT + func_name
- variables, parameters, arguments = utils.get_args(f)
- return vdt, variables, parameters, arguments, func_new_name
-
-
-def _vector_wiener_terms(code_lines, sde_type, vdt, shape_D, shape_m):
- if sde_type == constants.ITO_SDE:
- I2 = f'0.5*(_term3 - {vdt} * math.eye({shape_m})) + _a*0.5*{vdt}/math.pi'
- elif sde_type == constants.STRA_SDE:
- I2 = f'0.5*_term3 + _a*0.5*dt/math.pi'
- else:
- raise ValueError(f'Unknown SDE_INT type: {sde_type}. We only supports {constants.SUPPORTED_INTG_TYPE}.')
-
- if shape_D:
- shape_D = shape_D + '+'
-
- noise_string = f'''
- # Noise Terms #
- # ----------- #
-
- # single Ito integrals
- _I1 = math.normal(0., {vdt}_sqrt, {shape_D}({shape_m},))
- # double Ito integrals
- _h = (2.0 / {vdt}) ** 0.5)
- _a = math.zeros(shape={shape_D}({shape_m}, {shape_m}))
- for _k in range(1, num_iter + 1):
- _x = math.normal(loc=0., scale=1., size={shape_D}({shape_m}, 1))
- _y = math.normal(loc=0., scale=1., size={shape_D}(1, {shape_m})) + _h * _I1
- _term1 = math.matmul(_x, _y)
- _term2 = math.matmul(math.reshape(_y, {shape_D}({shape_m}, 1)),
- math.reshape(_x, {shape_D}(1, {shape_m})))
- _a += (_term1 - _term2) / _k
- _I1_rs = math.reshape(_I1, {shape_D}({shape_m}, 1))
- _term3 = math.matmul(_I1_rs, math.reshape(_I1, {shape_D}(1, {shape_m})))
- _I2 = {I2}
- '''
- noise_lines = noise_string.split('\n')
- code_lines.extend(noise_lines)
-
-
-# ----------
-# Wrapper
-# ----------
-
-
-def _srk2_pop_var_vector_wiener(sde_type, code_lines, variables, parameters, vdt):
- # shape information
- # -----
- all_f = [f'f_{var}' for var in variables]
- all_g = [f'g_{var}' for var in variables]
- noise_string = f'''
- {", ".join(all_f)} = f({", ".join(variables + parameters)}) # shape = (..)
- {", ".join(all_g)} = g({", ".join(variables + parameters)}) # shape = (.., m)
- noise_shape = math.shape(g_x1)
- _D = noise_shape[:-1]
- _m = noise_shape[-1]
- '''
- code_lines.extend(noise_string.split("\n"))
-
- # noise terms
- _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='_D', shape_m='_m')
-
- # numerical integration
- # step 1
- # ---
- # g_x1_rs = math.reshape(g_x1, _D + (1, _m))
- # g_x2_rs = math.reshape(g_x2, _D + (1, _m))
- for var in variables:
- code_lines.append(f" g_{var}_rs = math.reshape(g_{var}, _D+(1, _m))")
- # step 2
- # ---
- # g_H1_x1 = math.reshape(math.matmul(g_x1_rs, _I2) / dt_sqrt, _D + (_m,))
- # g_H1_x2 = math.reshape(math.matmul(g_x2_rs, _I2) / dt_sqrt, _D + (_m,))
- for var in variables:
- code_lines.append(f' g_H1_{var} = math.reshape(math.matmul(g_{var}_rs, _I2) / {vdt}_sqrt, _D + (_m,))')
- # step 3
- # ---
- # x1_rs = math.reshape(x1, _D + (1,))
- # x2_rs = math.reshape(x2, _D + (1,))
- for var in variables:
- code_lines.append(f' {var}_rs = math.reshape({var}, _D + (1,))')
- # step 4
- # ---
- # H2_x1 = x1_rs + g_H1_x1
- # H3_x1 = x1_rs - g_H1_x1
- for var in variables:
- code_lines.append(f' H2_{var} = {var}_rs + g_H1_{var}')
- code_lines.append(f' H3_{var} = {var}_rs - g_H1_{var}')
- code_lines.append(' ')
- # step 5
- # ---
- # _g_x1 = math.matmul(g_x1_rs, _I1_rs)
- for var in variables:
- code_lines.append(f' _g_{var} = math.matmul(g_{var}_rs, _I1_rs)')
- # step 6
- # ----
- # x1_new = x1 + f_x1 + _g_x1[..., 0, 0]
- for var in variables:
- code_lines.append(f' {var}_new = {var} + f_{var} + _g_{var}[..., 0, 0]')
- # for _k in range(_m):
- code_lines.append('for _k in range(_m):')
- # g_x1_H2, g_x2_H2 = g(H2_x1[..., _k], H2_x2[..., _k], t, *args)
- all_H2 = [f'H2_{var}[..., _k]' for var in variables]
- all_g_H2 = [f'g_{var}_H2' for var in variables]
- code_lines.append(f' {", ".join(all_g_H2)} = g({", ".join(all_H2 + parameters)})')
- # g_x1_H3, g_x2_H3 = g(H3_x1[..., _k], H3_x2[..., _k], t, *args)
- all_H3 = [f'H3_{var}[..., _k]' for var in variables]
- all_g_H3 = [f'g_{var}_H3' for var in variables]
- code_lines.append(f' {", ".join(all_g_H3)} = g({", ".join(all_H3 + parameters)})')
- # x1_new += 0.5 * dt_sqrt * (g_x1_H2[..., _k] - g_x1_H3[..., _k])
- # x2_new += 0.5 * dt_sqrt * (g_x2_H2[..., _k] - g_x2_H3[..., _k])
- for var in variables:
- code_lines.append(f' {var}_new += 0.5 * {vdt}_sqrt * (g_{var}_H2[..., _k] - g_{var}_H3[..., _k])')
-
-
-def _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt):
- if sde_type == constants.ITO_SDE:
- I2 = f'0.5 * (_I1 * _I1 - {vdt})'
- elif sde_type == constants.STRA_SDE:
- I2 = f'0.5 * _I1 * _I1'
- else:
- raise ValueError(f'Unknown SDE_INT type: {sde_type}. We only supports {constants.SUPPORTED_INTG_TYPE}.')
-
- # shape info
- # -----
- all_f = [f'f_{var}' for var in variables]
- all_g = [f'g_{var}' for var in variables]
-
- code_string = f'''
- {", ".join(all_f)} = f({", ".join(variables + parameters)}) # shape = (..)
- {", ".join(all_g)} = g({", ".join(variables + parameters)}) # shape = (..)
-
- # single Ito integrals
- _I1 = math.normal(0., {vdt}_sqrt, math.shape({variables[0]})) # shape = (..)
- # double Ito integrals
- _I2 = {I2} # shape = (..)
- '''
- code_splits = code_string.split('\n')
- code_lines.extend(code_splits)
-
- # numerical integration
- # -----
- # H1
- for var in variables:
- code_lines.append(f' g_H1_{var} = g_{var} * _I2 / {vdt}_sqrt # shape (.., )')
- # H2
- all_H2 = [f'H2_{var}' for var in variables]
- for var in variables:
- code_lines.append(f' H2_{var} = {var} + g_H1_{var} # shape (.., )')
- all_g_H2 = [f'g_{var}_H2' for var in variables]
- code_lines.append(f' {", ".join(all_g_H2)} = g({", ".join(all_H2 + parameters)})')
- code_lines.append(f' ')
- # H3
- all_H3 = [f'H3_{var}' for var in variables]
- for var in variables:
- code_lines.append(f' H3_{var} = {var} - g_H1_{var} # shape (.., )')
- all_g_H3 = [f'g_{var}_H3' for var in variables]
- code_lines.append(f' {", ".join(all_g_H3)} = g({", ".join(all_H3 + parameters)})')
- code_lines.append(f' ')
- # final results
- for var in variables:
- code_lines.append(f' {var}_new = {var} + f_{var} + g_{var} * _I1 '
- f'+ 0.5 * {vdt}_sqrt * (g_{var}_H2 - g_{var}_H3)')
-
-
-def _srk1_scalar_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt):
- # shape information
- all_f = [f'f_{var}' for var in variables]
- all_g = [f'g_{var}' for var in variables]
- code1 = f'''
- # shape info #
- # ---------- #
-
- {", ".join(all_f)} = f({", ".join(variables + parameters)}) # shape = ()
- {", ".join(all_g)} = g({", ".join(variables + parameters)}) # shape = (m)
- noise_shape = math.shape(g_x1)
- _m = noise_shape[0]
- '''
- code_lines.extend(code1.split('\n'))
-
- # noise term
- _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='', shape_m='_m')
-
- # numerical integration
-
- # p1
- # ---
- # g_x1_rs = math.reshape(g_x1, (1, _m))
- # g_x2_rs = math.reshape(g_x2, (1, _m))
- for var in variables:
- code_lines.append(f' g_{var}_rs = math.reshape(g_{var}, (1, _m))')
-
- # p2
- # ---
- # g_H1_x1 = math.matmul(g_x1_rs, _I2) / dt_sqrt # shape (1, m)
- # g_H1_x2 = math.matmul(g_x2_rs, _I2) / dt_sqrt # shape (1, m)
- for var in variables:
- code_lines.append(f' g_H1_{var} = math.matmul(g_{var}_rs, _I2) / {vdt}_sqrt # shape (1, m)')
-
- # p3
- # ---
- # H2_x1 = x1 + g_H1_x1[0] # shape (m)
- # H3_x1 = x1 - g_H1_x1[0] # shape (m)
- for var in variables:
- code_lines.append(f' H2_{var} = {var} + g_H1_{var}[0] # shape (m)')
- code_lines.append(' ')
-
- # p4
- # ---
- # g1_x1 = math.matmul(g_x1_rs, _I1_rs) # shape (1, 1)
- # x1_new = x1 + f_x1 + g1_x1[0, 0] # shape ()
- for var in variables:
- code_lines.append(f' g1_{var} = math.matmul(g_{var}_rs, _I1_rs) # shape (1, 1)')
- code_lines.append(f' {var}_new = {var} + f_{var} + g1_{var}[0, 0] # shape ()')
-
- # p5
- # ---
- # for _k in range(_m):
- # g_x1_H2, g_x2_H2 = g(H2_x1[_k], H2_x2[_k], t, *args)
- # g_x1_H3, g_x2_H3 = g(H3_x1[_k], H3_x2[_k], t, *args)
- # x1_new += 0.5 * dt_sqrt * (g_x1_H2[_k] - g_x1_H3[_k])
- # x2_new += 0.5 * dt_sqrt * (g_x2_H2[_k] - g_x2_H3[_k])
- code_lines.append(' for _k in range(_m):')
- all_h2_k = [f'H2_{var}[_k]' for var in variables]
- all_g_h2 = [f'g_{var}_H2' for var in variables]
- code_lines.append(f' {", ".join(all_g_h2)} = g({", ".join(all_h2_k + parameters)})')
- all_h3_k = [f'H3_{var}[_k]' for var in variables]
- all_g_h3 = [f'g_{var}_H3' for var in variables]
- code_lines.append(f' {", ".join(all_g_h3)} = g({", ".join(all_h3_k + parameters)})')
- for var in variables:
- code_lines.append(f' {var}_new += 0.5 * {vdt}_sqrt * (g_{var}_H2[_k] - g_{var}_H3[_k])')
-
-
-def _srk1_system_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt):
- # shape information
- code1 = f'''
- # shape infor #
- # ----------- #
-
- f_x = f({", ".join(variables + parameters)}) # shape = (d, ..)
- g_x = g({", ".join(variables + parameters)}) # shape = (d, .., m)
- _shape = math.shape(g_x)
- _d = _shape[0]
- _m = _shape[-1]
- _D = _shape[1:-1]
- '''
- code_lines.extend(code1.split('\n'))
-
- # noise term
- _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='_D', shape_m='_m')
-
- # numerical integration
- code2 = f'''
- # numerical integration #
- # --------------------- #
-
- g_x2 = math.moveaxis(g_x, 0, -2) # shape = (.., d, m)
- g_H1_k = math.matmul(g_x2, _I2) / dt_sqrt # shape (.., d, m)
- g_H1_k = math.moveaxis(g_H1_k, -2, 0) # shape (d, .., m)
- x_rs = math.reshape(x, (_d,) + _D + (1,))
- H2 = x_rs + g_H1_k # shape (d, .., m)
- H3 = x_rs - g_H1_k # shape (d, .., m)
-
- g1 = math.matmul(g_x2, _I1_rs) # shape (.., d, 1)
- g1 = math.moveaxis(g1, -2, 0) # shape (d, .., 1)
- y = x + f_x + g1[..., 0] # shape (d, ..)
- for _k in range(_m):
- y += 0.5 * dt_sqrt * g(H2[..., _k], t, *args)[..., _k]
- y -= 0.5 * dt_sqrt * g(H3[..., _k], t, *args)[..., _k]
- '''
- code_lines.extend(code2.split('\n'))
-
-
-def _srk1_system_var_with_scalar_wiener(sde_type, code_lines, variables, parameters, vdt):
- if sde_type == constants.ITO_SDE:
- I2 = f'0.5 * (_I1 * _I1 - {vdt})'
- elif sde_type == constants.STRA_SDE:
- I2 = f'0.5 * _I1 * _I1'
- else:
- raise ValueError(f'Unknown SDE_INT type: {sde_type}. We only supports {constants.SUPPORTED_INTG_TYPE}.')
-
- code_string = f'''
- f_x = f({", ".join(variables + parameters)}) # shape = (d, ..)
- g_x = g({", ".join(variables + parameters)}) # shape = (d, ..)
- _shape = math.shape(g_x)
- _d = _shape[0]
- _D = _shape[1:]
-
- # single Ito integrals
- _I1 = math.normal(0., {vdt}_sqrt, _D) # shape = (..)
- # double Ito integrals
- _I2 = {I2} # shape = (..)
-
- # numerical integration #
- # --------------------- #
- g_H1_k = g_x * _I2 / {vdt}_sqrt # shape (d, ..)
- H2 = x + g_H1_k # shape (d, ..)
- H3 = x - g_H1_k # shape (d, ..)
-
- g1 = g_x * _I1 # shape (d, ..)
- x_new = x + f_x + g1 # shape (d, ..)
- x_new += 0.5 * {vdt}_sqrt * g(H2, {", ".join(parameters)})
- x_new -= 0.5 * {vdt}_sqrt * g(H3, {", ".join(parameters)})
- '''
- code_splits = code_string.split('\n')
- code_lines.extend(code_splits)
-
-
-def _srk1_wrapper(f, g, dt, sde_type, var_type, wiener_type, show_code, num_iter):
- vdt, variables, parameters, arguments, func_name = basic_info(f=f, g=g)
-
- # 1. code scope
- code_scope = {'f': f, 'g': g, vdt: dt, f'{vdt}_sqrt': dt ** 0.5,
- 'math': math, 'num_iter': num_iter}
-
- # 2. code lines
- code_lines = [f'def {func_name}({", ".join(arguments)}):']
-
- if var_type == constants.SYSTEM_VAR:
- if len(variables) > 1:
- raise ValueError(f'SDE_INT with {constants.SYSTEM_VAR} variable type only '
- f'supports one system variable. But we got {variables}.')
-
- if wiener_type == constants.SCALAR_WIENER:
- _srk1_system_var_with_scalar_wiener(sde_type, code_lines, variables, parameters, vdt)
- elif wiener_type == constants.VECTOR_WIENER:
- _srk1_system_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt)
- else:
- raise ValueError(f'Unknown Wiener type: {wiener_type}, we only '
- f'supports {constants.SUPPORTED_WIENER_TYPE}')
-
- elif var_type == constants.SCALAR_VAR:
- if wiener_type == constants.SCALAR_WIENER:
- _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt)
- elif wiener_type == constants.VECTOR_WIENER:
- _srk1_scalar_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt)
- else:
- raise ValueError(f'Unknown Wiener type: {wiener_type}, we only '
- f'supports {constants.SUPPORTED_WIENER_TYPE}')
-
- elif var_type == constants.POP_VAR:
- if wiener_type == constants.SCALAR_WIENER:
- _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt)
- elif wiener_type == constants.VECTOR_WIENER:
- _srk2_pop_var_vector_wiener(sde_type, code_lines, variables, parameters, vdt)
- else:
- raise ValueError(f'Unknown Wiener type: {wiener_type}, we only '
- f'supports {constants.SUPPORTED_WIENER_TYPE}')
-
- else:
- raise ValueError(f'Unknown var type: {var_type}, we only '
- f'supports {constants.SUPPORTED_VAR_TYPE}')
- # returns
- new_vars = [f'{var}_new' for var in variables]
- code_lines.append(f' return {", ".join(new_vars)}')
-
- # return and compile
- utils.compile_code(code_lines, code_scope, show_code, variables)
- return code_scope[func_name]
-
-
-def _srk2_wrapper():
- pass
-
-
-def _wrap(wrapper, f, g, dt, sde_type, var_type, wiener_type, show_code, num_iter):
- """The brainpy_object function to format a SRK method.
-
- Parameters::
-
- f : callable
- The drift function of the SDE_INT.
- g : callable
- The diffusion function of the SDE_INT.
- dt : float
- The numerical precision.
- sde_type : str
- "utils.ITO_SDE" : Ito's Stochastic Calculus.
- "utils.STRA_SDE" : Stratonovich's Stochastic Calculus.
- wiener_type : str
- var_type : str
- "scalar" : with the shape of ().
- "population" : with the shape of (N,) or (N1, N2) or (N1, N2, ...).
- "system": with the shape of (d, ), (d, N), or (d, N1, N2).
- show_code : bool
- Whether show the formatted code.
-
- Returns::
-
- numerical_func : callable
- The numerical function.
- """
-
- sde_type = constants.ITO_SDE if sde_type is None else sde_type
- assert sde_type in constants.SUPPORTED_INTG_TYPE, f'Currently, BrainPy only support SDE_INT types: ' \
- f'{constants.SUPPORTED_INTG_TYPE}. But we got {sde_type}.'
-
- var_type = constants.POP_VAR if var_type is None else var_type
- assert var_type in constants.SUPPORTED_VAR_TYPE, f'Currently, BrainPy only supports variable types: ' \
- f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.'
-
- wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type
- assert wiener_type in constants.SUPPORTED_WIENER_TYPE, f'Currently, BrainPy only supports Wiener ' \
- f'Process types: {constants.SUPPORTED_WIENER_TYPE}. ' \
- f'But we got {wiener_type}.'
-
- show_code = False if show_code is None else show_code
- dt = math.get_dt() if dt is None else dt
- num_iter = 10 if num_iter is None else num_iter
-
- if f is not None and g is not None:
- return wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
- var_type=var_type, wiener_type=wiener_type, num_iter=num_iter)
-
- elif f is not None:
- return lambda g: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
- var_type=var_type, wiener_type=wiener_type, num_iter=num_iter)
-
- elif g is not None:
- return lambda f: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
- var_type=var_type, wiener_type=wiener_type, num_iter=num_iter)
-
- else:
- raise ValueError('Must provide "f" or "g".')
-
-
-# ------------------
-# Numerical methods
-# ------------------
-
-
-def srk1_strong(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, num_iter=None, show_code=None):
- return _wrap(_srk1_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type,
- wiener_type=wiener_type, show_code=show_code, num_iter=num_iter)
-
-
-def srk2_strong(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, num_iter=None, show_code=None):
- return _wrap(_srk2_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type,
- wiener_type=wiener_type, show_code=show_code, num_iter=num_iter)
diff --git a/brainpy/losses/comparison.py b/brainpy/losses/comparison.py
index 67330b31a..c96644be2 100644
--- a/brainpy/losses/comparison.py
+++ b/brainpy/losses/comparison.py
@@ -198,10 +198,12 @@ def __init__(self, weight: Optional[ArrayType] = None, ignore_index: int = -100,
self.label_smoothing = label_smoothing
def update(self, input: ArrayType, target: ArrayType) -> ArrayType:
- return cross_entropy_loss(input, target, weight=self.weight, reduction=self.reduction)
+ return cross_entropy_loss(input, target, weight=self.weight, reduction=self.reduction,
+ ignore_index=self.ignore_index, label_smoothing=self.label_smoothing)
-def cross_entropy_loss(predicts, targets, weight=None, reduction='mean'):
+def cross_entropy_loss(predicts, targets, weight=None, reduction='mean',
+ ignore_index=-100, label_smoothing=0.0):
r"""This criterion combines ``LogSoftmax`` and `NLLLoss`` in one single class.
It is useful when training a classification problem with `C` classes.
@@ -260,11 +262,46 @@ def cross_entropy_loss(predicts, targets, weight=None, reduction='mean'):
"""
def _cel(_pred, _tar):
+ _pred = bm.as_jax(_pred)
+ num_classes = _pred.shape[-1]
+ # Per-sample class weight. The ``weight`` argument is indexed *by the
+ # target class* (``weight[y_n]``), not by the sample position.
+ sample_weight = None
+ # Mask of samples that should contribute to the loss (``ignore_index``).
+ valid_mask = None
if bm.ndim(_tar) + 1 == bm.ndim(_pred):
- _tar = bm.one_hot(_tar, _pred.shape[-1])
- loss = logsumexp(bm.as_jax(_pred), axis=-1) - (_pred * _tar).sum(axis=-1)
- if weight is not None:
- loss *= weight
+ # ``_tar`` holds integer class indices.
+ _tar_idx = bm.as_jax(_tar)
+ if weight is not None:
+ sample_weight = bm.as_jax(weight)[_tar_idx]
+ valid_mask = (_tar_idx != ignore_index)
+ # Build the (possibly label-smoothed) soft target distribution. Clamp
+ # ignored indices to 0 first so one_hot does not error on negatives.
+ _tar_clamped = jnp.where(valid_mask, _tar_idx, 0)
+ _soft = bm.as_jax(bm.one_hot(_tar_clamped, num_classes))
+ if label_smoothing > 0.0:
+ _soft = _soft * (1.0 - label_smoothing) + label_smoothing / num_classes
+ else:
+ # ``_tar`` holds class probabilities / one-hot: the effective per-sample
+ # weight is the probability-weighted class weight (matches PyTorch).
+ _soft = bm.as_jax(_tar)
+ if label_smoothing > 0.0:
+ _soft = _soft * (1.0 - label_smoothing) + label_smoothing / num_classes
+ if weight is not None:
+ sample_weight = (bm.as_jax(weight) * _soft).sum(axis=-1)
+ loss = logsumexp(_pred, axis=-1) - (_pred * _soft).sum(axis=-1)
+ if sample_weight is not None:
+ loss = loss * sample_weight
+ if valid_mask is not None:
+ # Zero-out ignored samples so they contribute nothing to sum/mean.
+ loss = jnp.where(valid_mask, loss, 0.0)
+ if reduction == 'mean':
+ if sample_weight is not None:
+ denom = sample_weight if valid_mask is None else jnp.where(valid_mask, sample_weight, 0.0)
+ return loss.sum() / denom.sum()
+ if valid_mask is not None:
+ return loss.sum() / jnp.maximum(valid_mask.sum(), 1)
+ return loss.mean()
return _reduce(outputs=loss, reduction=reduction)
r = tree_map(_cel, predicts, targets, is_leaf=_is_leaf)
@@ -458,7 +495,9 @@ def nll_loss(input, target, reduction: str = 'mean'):
assert target.ndim + 1 == input.ndim
input = bm.as_jax(input)
target = bm.as_jax(target)
- loss = input[jnp.arange(len(target)), target]
+ # Negative log-likelihood: l_n = -x_{n, y_n}. The leading minus sign is what
+ # makes this a *loss* to minimize (the raw log-probabilities are negative).
+ loss = -input[jnp.arange(len(target)), target]
if reduction == 'mean':
return loss.mean()
elif reduction == 'sum':
diff --git a/brainpy/math/_utils.py b/brainpy/math/_utils.py
index 02ea058d6..bb526a4d4 100644
--- a/brainpy/math/_utils.py
+++ b/brainpy/math/_utils.py
@@ -60,6 +60,7 @@ def new_fun(*args, **kwargs):
return tree_map(_return, r)
else:
out.value = r
+ return out
new_fun.__doc__ = (
f'Similar to ``jax.numpy.{module + fun.__name__}`` function, '
diff --git a/brainpy/math/activations.py b/brainpy/math/activations.py
index d8cedd35d..f3d66408c 100644
--- a/brainpy/math/activations.py
+++ b/brainpy/math/activations.py
@@ -665,7 +665,8 @@ def softmin(x, axis=-1):
along dim will sum to 1).
"""
x = x.value if isinstance(x, Array) else x
- unnormalized = jnp.exp(-x)
+ neg_x = -x
+ unnormalized = jnp.exp(neg_x - jax.lax.stop_gradient(neg_x.max(axis, keepdims=True)))
return unnormalized / unnormalized.sum(axis, keepdims=True)
diff --git a/brainpy/math/compat_numpy.py b/brainpy/math/compat_numpy.py
index 51937348d..02e5de9f3 100644
--- a/brainpy/math/compat_numpy.py
+++ b/brainpy/math/compat_numpy.py
@@ -130,7 +130,7 @@ def fill_diagonal(a, val, inplace=True):
if inplace:
a.value = r
else:
- return r
+ return _return(r)
def zeros(shape, dtype=None):
@@ -142,7 +142,7 @@ def ones(shape, dtype=None):
def empty(shape, dtype=None):
- return _return(jnp.zeros(shape, dtype=dtype))
+ return _return(jnp.empty(shape, dtype=dtype))
def zeros_like(a, dtype=None, shape=None):
@@ -157,7 +157,7 @@ def ones_like(a, dtype=None, shape=None):
def empty_like(a, dtype=None, shape=None):
a = _as_jax_array_(a)
- return _return(jnp.zeros_like(a, dtype=dtype, shape=shape))
+ return _return(jnp.empty_like(a, dtype=dtype, shape=shape))
def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array:
@@ -215,8 +215,8 @@ def ascontiguousarray(a, dtype=None, order=None):
def asfarray(a, dtype=None):
- if not np.issubdtype(dtype, np.inexact):
- dtype = np.float64
+ if dtype is None or not np.issubdtype(dtype, np.inexact):
+ dtype = jnp.float64
return asarray(a, dtype=dtype)
diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py
index 1728f55f7..d9939a64a 100644
--- a/brainpy/math/compat_pytorch.py
+++ b/brainpy/math/compat_pytorch.py
@@ -40,11 +40,12 @@
'asin',
'arcsin',
'asinh',
- 'arcsin',
+ 'arcsinh',
'atan',
'arctan',
'atan2',
'atanh',
+ 'arctanh',
'clamp_max',
'clamp_min',
'arctan2',
@@ -161,6 +162,7 @@ def abs(
else:
_check_out(out)
out.value = r
+ return out
absolute = abs
@@ -178,6 +180,7 @@ def acos(
else:
_check_out(out)
out.value = r
+ return out
arccos = acos
@@ -195,6 +198,7 @@ def acosh(
else:
_check_out(out)
out.value = r
+ return out
arccosh = acosh
@@ -224,6 +228,7 @@ def add(
else:
_check_out(out)
out.value = r
+ return out
def addcdiv(
@@ -266,6 +271,7 @@ def angle(
else:
_check_out(out)
out.value = r
+ return out
def asin(
@@ -280,6 +286,7 @@ def asin(
else:
_check_out(out)
out.value = r
+ return out
arcsin = asin
@@ -297,6 +304,7 @@ def asinh(
else:
_check_out(out)
out.value = r
+ return out
arcsinh = asinh
@@ -314,6 +322,7 @@ def atan(
else:
_check_out(out)
out.value = r
+ return out
arctan = atan
@@ -331,6 +340,7 @@ def atanh(
else:
_check_out(out)
out.value = r
+ return out
arctanh = atanh
@@ -350,6 +360,7 @@ def atan2(
else:
_check_out(out)
out.value = r
+ return out
arctan2 = atan2
diff --git a/brainpy/math/compat_tensorflow.py b/brainpy/math/compat_tensorflow.py
index ede493ab2..6f5a07063 100644
--- a/brainpy/math/compat_tensorflow.py
+++ b/brainpy/math/compat_tensorflow.py
@@ -14,8 +14,10 @@
# ==============================================================================
from typing import Union, Optional
+import jax
import jax.numpy as jnp
import jax.ops
+import jax.scipy.special
from jax import lax
from brainpy.math.interoperability import as_jax
@@ -74,7 +76,7 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False):
Returns:
The reduced tensor.
"""
- r = jnp.log(jnp.sum(jnp.exp(_as_jax_array_(input_tensor)), axis=axis, keepdims=keepdims))
+ r = jax.scipy.special.logsumexp(_as_jax_array_(input_tensor), axis=axis, keepdims=keepdims)
return _return(r)
diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py
index f9ff8fb9e..10d69b420 100644
--- a/brainpy/math/delayvars.py
+++ b/brainpy/math/delayvars.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+import inspect
import numbers
from typing import Union, Callable
@@ -45,6 +46,24 @@ def _as_jax_array(arr):
return arr.value if isinstance(arr, Array) else arr
+def _accepts_dtype_kwarg(func):
+ """Return True if ``func`` can be called with a ``dtype`` keyword argument.
+
+ Initializer/Connector instances accept ``(shape, dtype=...)``, whereas a plain
+ callable such as ``lambda shape: ...`` does not. When the signature cannot be
+ introspected (e.g. some built-ins or C-implemented callables), we conservatively
+ assume the ``dtype`` keyword is accepted to preserve the historical behaviour.
+ """
+ try:
+ sig = inspect.signature(func)
+ except (TypeError, ValueError):
+ return True
+ params = sig.parameters
+ if 'dtype' in params:
+ return True
+ return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
+
+
class AbstractDelay(BrainPyObject):
pass
@@ -204,20 +223,31 @@ def reset(self,
The maximum delay length. The unit is the time.
t0: int, float
The zero time.
- before_t0: int, float, ArrayType
+ before_t0: callable, int, float, ArrayType
The data before t0.
+ - when ``before_t0`` is a function, it should receive a time argument ``t``
+ (mirroring the behaviour of ``__init__``).
+ - when ``before_t0`` is a tensor / numerical value, it is broadcast into the
+ delay data before ``t0``.
"""
+ self.t0 = t0
self.delay_len = delay_len
self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1
self.data.value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype)
self.data[-1] = delay_target
self.idx = Variable(jnp.asarray([0]))
- self.current_time = Variable(jnp.asarray([t0]))
- if before_t0 is not None:
- if not isinstance(before_t0, (ndarray, jnp.ndarray, float, int)):
- raise ValueError('Only support numerical values.')
+ self.current_time = Variable(jnp.asarray([t0], dtype=get_float()))
+ if before_t0 is None:
+ self._before_type = _DATA_BEFORE
+ elif callable(before_t0):
+ self._before_t0 = lambda t: as_jax(broadcast_to(before_t0(t), delay_target.shape),
+ dtype=delay_target.dtype)
+ self._before_type = _FUNC_BEFORE
+ elif isinstance(before_t0, (ndarray, jnp.ndarray, float, int)):
self.data[:-1] = before_t0
self._before_type = _DATA_BEFORE
+ else:
+ raise ValueError(f'"before_t0" does not support {type(before_t0)}')
def _check_time1(self, times):
prev_time, current_time = times
@@ -268,7 +298,7 @@ def _after_t0(self, prev_time):
f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
def _true_fn(self, req_num_step, extra):
- return self.data[self.idx[0] + req_num_step]
+ return self.data[(self.idx[0] + req_num_step) % self.num_delay_step]
def _false_fn(self, req_num_step, extra):
idx = jnp.asarray([self.idx[0] + req_num_step,
@@ -303,6 +333,10 @@ class LengthDelay(AbstractDelay):
initial_delay_data: Any
The delay data. It can be a Python number, like float, int, boolean values.
It can also be arrays. Or a callable function or instance of ``Connector``.
+ A callable will be invoked as ``initial_delay_data(shape, dtype=...)`` when its
+ signature accepts a ``dtype`` keyword (e.g. ``Initializer``/``Connector``
+ instances), and as ``initial_delay_data(shape)`` otherwise (e.g. a plain
+ ``lambda shape: ...``).
Note that ``initial_delay_data`` should be arranged as the following way::
delay = 1 [ data
@@ -416,8 +450,16 @@ def reset(
elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)):
self.data[1:] = initial_delay_data
elif callable(initial_delay_data):
- self.data[1:] = initial_delay_data((delay_len,) + delay_target.shape,
- dtype=delay_target.dtype)
+ shape = (delay_len,) + delay_target.shape
+ dtype = delay_target.dtype
+ # Initializer/Connector instances accept ``(shape, dtype=...)``, but a plain
+ # callable (e.g. ``lambda shape: ...``) may not accept the ``dtype`` keyword.
+ # Branch on whether the callable's signature accepts ``dtype`` and fall back
+ # to calling it with the shape only when it does not.
+ if _accepts_dtype_kwarg(initial_delay_data):
+ self.data[1:] = initial_delay_data(shape, dtype=dtype)
+ else:
+ self.data[1:] = initial_delay_data(shape)
else:
raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}')
diff --git a/brainpy/math/einops.py b/brainpy/math/einops.py
index b74eeb63e..05c9ab2ab 100644
--- a/brainpy/math/einops.py
+++ b/brainpy/math/einops.py
@@ -81,61 +81,6 @@ def __reduce(x: Union[Array, jax.Array], operation: str, reduced_axes):
raise NotImplementedError("Unknown reduction ", operation)
-def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes):
- # 'collapses' neighboring axes if those participate in the result pattern in the same order
- # TODO add support for added_axes
- assert len(axes_reordering) + len(reduced_axes) == len(init_shapes)
- # joining consecutive axes that will be reduced
- # possibly we can skip this if all backends can optimize this (not sure)
- reduced_axes = tuple(sorted(reduced_axes))
- for i in range(len(reduced_axes) - 1)[::-1]:
- if reduced_axes[i] + 1 == reduced_axes[i + 1]:
- removed_axis = reduced_axes[i + 1]
- removed_length = init_shapes[removed_axis]
- init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
- init_shapes[removed_axis - 1] *= removed_length
- reduced_axes = reduced_axes[: i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:])
-
- # removing axes that are moved together during reshape
- def build_mapping():
- init_to_final = {}
- for axis in range(len(init_shapes)):
- if axis in reduced_axes:
- init_to_final[axis] = None
- else:
- after_reduction = sum(x is not None for x in init_to_final.values())
- init_to_final[axis] = list(axes_reordering).index(after_reduction)
- return init_to_final
-
- init_axis_to_final_axis = build_mapping()
-
- for init_axis in range(len(init_shapes) - 1)[::-1]:
- if init_axis_to_final_axis[init_axis] is None:
- continue
- if init_axis_to_final_axis[init_axis + 1] is None:
- continue
- if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]:
- removed_axis = init_axis + 1
- removed_length = init_shapes[removed_axis]
- removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis))
-
- reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes)
- init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
- init_shapes[removed_axis - 1] *= removed_length
- old_reordering = axes_reordering
- axes_reordering = []
- for axis in old_reordering:
- if axis == removed_axis_after_reduction:
- pass
- elif axis < removed_axis_after_reduction:
- axes_reordering.append(axis)
- else:
- axes_reordering.append(axis - 1)
- init_axis_to_final_axis = build_mapping()
-
- return init_shapes, reduced_axes, axes_reordering, final_shapes
-
-
CookedRecipe = Tuple[Optional[List[int]], Optional[List[int]], List[int], Dict[int, int], Optional[List[int]], int]
# Actual type is tuple[tuple[str, int], ...]
diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py
index 292d6dadf..3bf28d251 100644
--- a/brainpy/math/environment.py
+++ b/brainpy/math/environment.py
@@ -388,43 +388,57 @@ def set(
numpy_func_return: str
The array to return in all numpy functions. Support 'bp_array' and 'jax_array'.
"""
+ # Validate all arguments BEFORE mutating any global state, so that an
+ # invalid argument cannot leave the environment in a half-updated state.
if dt is not None:
assert isinstance(dt, float), '"dt" must a float.'
+ if mode is not None:
+ assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
+ if membrane_scaling is not None:
+ assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.'
+ if x64 is not None:
+ assert isinstance(x64, bool), f'"x64" must be a bool.'
+ if float_ is not None:
+ assert isinstance(float_, type), '"float_" must a float.'
+ if int_ is not None:
+ assert isinstance(int_, type), '"int_" must a type.'
+ if bool_ is not None:
+ assert isinstance(bool_, type), '"bool_" must a type.'
+ if complex_ is not None:
+ assert isinstance(complex_, type), '"complex_" must a type.'
+ if numpy_func_return is not None:
+ assert numpy_func_return in ['bp_array', 'jax_array'], \
+ '"numpy_func_return" must be "bp_array" or "jax_array".'
+
+ # All validation passed; now apply the settings.
+ if dt is not None:
set_dt(dt)
if mode is not None:
- assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
set_mode(mode)
if membrane_scaling is not None:
- assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.'
set_membrane_scaling(membrane_scaling)
if x64 is not None:
- assert isinstance(x64, bool), f'"x64" must be a bool.'
set_x64(x64)
if float_ is not None:
- assert isinstance(float_, type), '"float_" must a float.'
set_float(float_)
if int_ is not None:
- assert isinstance(int_, type), '"int_" must a type.'
set_int(int_)
if bool_ is not None:
- assert isinstance(bool_, type), '"bool_" must a type.'
set_bool(bool_)
if complex_ is not None:
- assert isinstance(complex_, type), '"complex_" must a type.'
set_complex(complex_)
if bp_object_as_pytree is not None:
defaults.bp_object_as_pytree = bp_object_as_pytree
if numpy_func_return is not None:
- assert numpy_func_return in ['bp_array', 'jax_array'], f'"numpy_func_return" must be "bp_array" or "jax_array".'
defaults.numpy_func_return = numpy_func_return
@@ -643,6 +657,7 @@ def enable_x64(x64=None):
def disable_x64():
+ brainstate.environ.set(precision=32)
config.update("jax_enable_x64", False)
set_int(jnp.int32)
set_float(jnp.float32)
diff --git a/brainpy/math/event/csr_matmat.py b/brainpy/math/event/csr_matmat.py
index c33d1103f..a8084328d 100644
--- a/brainpy/math/event/csr_matmat.py
+++ b/brainpy/math/event/csr_matmat.py
@@ -62,6 +62,6 @@ def csrmm(
matrix = brainevent.BinaryArray(matrix)
csr = brainevent.CSR((data, indices, indptr), shape=shape)
if transpose:
- return matrix @ csr
+ return csr.T @ matrix
else:
return csr @ matrix
diff --git a/brainpy/math/jitconn/matvec.py b/brainpy/math/jitconn/matvec.py
index 094f758e0..89442333c 100644
--- a/brainpy/math/jitconn/matvec.py
+++ b/brainpy/math/jitconn/matvec.py
@@ -77,6 +77,17 @@ def mv_prob_homo(
The matrix shape.
seed: int
The random number generation seed.
+
+ .. warning::
+
+ If ``seed`` is left as ``None`` (the default), a host random seed is
+ drawn with ``numpy.random.randint`` on **every call**. This makes the
+ result non-reproducible in eager mode, and -- because the seed is
+ captured as a Python constant -- it becomes **frozen** the first time
+ the function is traced under ``jit()``/``vmap()``, so every subsequent
+ jitted call reuses that single seed. For reproducible and correct
+ behaviour under JAX transformations, always pass an explicit integer
+ ``seed``.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
@@ -152,6 +163,17 @@ def mv_prob_uniform(
The matrix shape.
seed: int
The random number generation seed.
+
+ .. warning::
+
+ If ``seed`` is left as ``None`` (the default), a host random seed is
+ drawn with ``numpy.random.randint`` on **every call**. This makes the
+ result non-reproducible in eager mode, and -- because the seed is
+ captured as a Python constant -- it becomes **frozen** the first time
+ the function is traced under ``jit()``/``vmap()``, so every subsequent
+ jitted call reuses that single seed. For reproducible and correct
+ behaviour under JAX transformations, always pass an explicit integer
+ ``seed``.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
@@ -229,6 +251,17 @@ def mv_prob_normal(
The matrix shape.
seed: int
The random number generation seed.
+
+ .. warning::
+
+ If ``seed`` is left as ``None`` (the default), a host random seed is
+ drawn with ``numpy.random.randint`` on **every call**. This makes the
+ result non-reproducible in eager mode, and -- because the seed is
+ captured as a Python constant -- it becomes **frozen** the first time
+ the function is traced under ``jit()``/``vmap()``, so every subsequent
+ jitted call reuses that single seed. For reproducible and correct
+ behaviour under JAX transformations, always pass an explicit integer
+ ``seed``.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
@@ -276,6 +309,17 @@ def get_homo_weight_matrix(
The matrix shape.
seed: int
The random number generation seed.
+
+ .. warning::
+
+ If ``seed`` is left as ``None`` (the default), a host random seed is
+ drawn with ``numpy.random.randint`` on **every call**. This makes the
+ result non-reproducible in eager mode, and -- because the seed is
+ captured as a Python constant -- it becomes **frozen** the first time
+ the function is traced under ``jit()``/``vmap()``, so every subsequent
+ jitted call reuses that single seed. For reproducible and correct
+ behaviour under JAX transformations, always pass an explicit integer
+ ``seed``.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
@@ -320,6 +364,17 @@ def get_uniform_weight_matrix(
The matrix shape.
seed: int
The random number generation seed.
+
+ .. warning::
+
+ If ``seed`` is left as ``None`` (the default), a host random seed is
+ drawn with ``numpy.random.randint`` on **every call**. This makes the
+ result non-reproducible in eager mode, and -- because the seed is
+ captured as a Python constant -- it becomes **frozen** the first time
+ the function is traced under ``jit()``/``vmap()``, so every subsequent
+ jitted call reuses that single seed. For reproducible and correct
+ behaviour under JAX transformations, always pass an explicit integer
+ ``seed``.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
@@ -367,6 +422,17 @@ def get_normal_weight_matrix(
The matrix shape.
seed: int
The random number generation seed.
+
+ .. warning::
+
+ If ``seed`` is left as ``None`` (the default), a host random seed is
+ drawn with ``numpy.random.randint`` on **every call**. This makes the
+ result non-reproducible in eager mode, and -- because the seed is
+ captured as a Python constant -- it becomes **frozen** the first time
+ the function is traced under ``jit()``/``vmap()``, so every subsequent
+ jitted call reuses that single seed. For reproducible and correct
+ behaviour under JAX transformations, always pass an explicit integer
+ ``seed``.
transpose: bool
Transpose the random matrix or not.
outdim_parallel: bool
diff --git a/brainpy/math/modes.py b/brainpy/math/modes.py
index 0aec62f87..ea9ba1501 100644
--- a/brainpy/math/modes.py
+++ b/brainpy/math/modes.py
@@ -40,6 +40,11 @@ def __eq__(self, other: 'Mode'):
return False
return other.__class__ == self.__class__
+ # Defining ``__eq__`` sets ``__hash__`` to None (making instances
+ # unhashable). Restore hashability so modes can be used in sets / as dict
+ # keys. Modes compare equal iff they share a class, so hash by class.
+ __hash__ = brainstate.mixin.Mode.__hash__
+
def is_one_of(self, *modes):
for m_ in modes:
if not isinstance(m_, type):
diff --git a/brainpy/math/ndarray.py b/brainpy/math/ndarray.py
index 3b14e8833..b41c702df 100644
--- a/brainpy/math/ndarray.py
+++ b/brainpy/math/ndarray.py
@@ -19,7 +19,6 @@
import jax
import numpy as np
from jax import numpy as jnp
-from jax.dtypes import canonicalize_dtype
from jax.tree_util import register_pytree_node_class
from brainpy._errors import MathError
@@ -28,30 +27,10 @@
bm = None
__all__ = [
- 'Array', 'Array', 'ndarray', 'JaxArray', # alias of Array
+ 'Array', 'ndarray', 'JaxArray', # alias of Array
'ShardedArray',
]
-# Ways to change values in a zero-dimensional array
-# -----
-# Reference: https://stackoverflow.com/questions/56954714/how-do-i-assign-to-a-zero-dimensional-numpy-array
-#
-# >>> x = np.array(10)
-# 1. index the original array with ellipsis or an empty tuple
-# >>> x[...] = 2
-# >>> x[()] = 2
-
-_all_slice = slice(None, None, None)
-
-
-def _check_input_array(array):
- if isinstance(array, Array):
- return array.value
- elif isinstance(array, np.ndarray):
- return jnp.asarray(array)
- else:
- return array
-
def _return(a):
if defaults.numpy_func_return == 'bp_array' and isinstance(a, jax.Array) and a.ndim > 0:
@@ -68,14 +47,6 @@ def _check_out(out):
raise TypeError(f'out must be an instance of brainpy Array. But got {type(out)}')
-def _get_dtype(v):
- if hasattr(v, 'dtype'):
- dtype = v.dtype
- else:
- dtype = canonicalize_dtype(type(v))
- return dtype
-
-
@register_pytree_node_class
class Array(u.CustomArray):
"""Multiple-dimensional array in BrainPy.
@@ -102,6 +73,13 @@ def __init__(self, value, dtype: Any = None):
value = value.value
elif isinstance(value, (tuple, list, np.ndarray)):
value = jnp.asarray(value)
+ elif isinstance(value, jax.Array):
+ pass
+ else:
+ # raw Python scalars (int/float/bool/complex) and any other input:
+ # convert to a jax array so ``self._value`` is always array-like
+ # (mirrors the ``value`` setter).
+ value = jnp.asarray(value)
if dtype is not None:
value = jnp.asarray(value, dtype=dtype)
self._value = value
@@ -131,11 +109,14 @@ def tree_flatten(self):
@classmethod
def tree_unflatten(cls, aux_data, flat_contents):
- return cls(*flat_contents)
-
- # ins = object.__new__(cls)
- # ins._value = flat_contents[0]
- # return ins
+ # Reconstruct without going through ``__init__``: during abstract
+ # evaluation (``jax.eval_shape``, ``scan``/``for_loop`` tracing) the leaf
+ # is a ``ShapedArray``/``ShapeDtypeStruct`` rather than a concrete array,
+ # and ``__init__`` would try to ``jnp.asarray`` it and raise. Storing the
+ # leaf directly keeps the pytree round-trip transparent.
+ ins = object.__new__(cls)
+ ins._value = flat_contents[0]
+ return ins
@property
def data(self):
@@ -198,17 +179,23 @@ def at(self):
return self.value.at
def block_host_until_ready(self, *args):
- return self.value.block_host_until_ready(*args)
+ # ``jax.Array.block_host_until_ready`` was removed; ``block_until_ready``
+ # is the modern equivalent.
+ return self.value.block_until_ready(*args)
def block_until_ready(self, *args):
return self.value.block_until_ready(*args)
+ @property
def device(self):
- return self.value.device()
+ # ``jax.Array.device`` is now a property (it used to be a method).
+ return self.value.device
@property
def device_buffer(self):
- return self.value.device_buffer
+ # ``jax.Array.device_buffer`` was removed; the addressable shard's data
+ # is the modern equivalent on a single-device array.
+ return self.value.addressable_data(0)
def fill_(self, fill_value):
"""Fill the array with a scalar value.
@@ -264,8 +251,16 @@ def value(self):
The stored data.
"""
v = self._value
- # keep sharding constraints
- if self._keep_sharding and hasattr(v, 'sharding') and (v.sharding is not None):
+ # Keep sharding constraints, but only for genuinely multi-device
+ # shardings. A ``SingleDeviceSharding`` (the default on a single device,
+ # e.g. CPU) carries no distribution information, so inserting a
+ # ``with_sharding_constraint`` on every read is pure overhead.
+ if (
+ self._keep_sharding
+ and hasattr(v, 'sharding')
+ and (v.sharding is not None)
+ and not isinstance(v.sharding, jax.sharding.SingleDeviceSharding)
+ ):
return jax.lax.with_sharding_constraint(v, v.sharding)
# return the value
return v
diff --git a/brainpy/math/object_transform/_utils.py b/brainpy/math/object_transform/_utils.py
index 997f8b1bb..c4b32d71a 100644
--- a/brainpy/math/object_transform/_utils.py
+++ b/brainpy/math/object_transform/_utils.py
@@ -24,6 +24,7 @@
__all__ = [
'infer_dyn_vars',
'get_brainpy_object',
+ 'warp_to_no_state_input_output',
]
diff --git a/brainpy/math/object_transform/base.py b/brainpy/math/object_transform/base.py
index f6a6b9762..4c61425fd 100644
--- a/brainpy/math/object_transform/base.py
+++ b/brainpy/math/object_transform/base.py
@@ -146,77 +146,29 @@ def tracing_variable(
axis_names: Optional[Sequence[str]] = None,
batch_axis_name: Optional[str] = BATCH_AXIS,
) -> Variable:
- """Initialize the variable which can be traced during computations and transformations.
+ """Initialize a variable that can be traced during computations and transformations.
- Although this function is designed to initialize tracing variables during computation or compilation,
- it can also be used for the initialization of variables before computation and compilation.
-
- - If the variable has not been instantiated, a :py:class:`~.Variable` will be instantiated.
- - If the variable has been created, the further call of this function will return the created variable.
-
- Here is the usage example::
-
- class Example(bm.BrainPyObject):
- def fun(self):
- # The first time of calling `.fun()`, this line will create a Variable instance.
- # If users repeatedly call `.fun()` function, this line will not initialize variables again.
- # Instead, it will return the variable has been created.
- self.tracing_variable('a', bm.zeros, (10,))
-
- # The created variable can be accessed with self.xxx
- self.a.value = bm.ones(10)
-
- # Calling this function again will not reinitialize the
- # variable again, Instead, it will return the variable
- # that has been created.
- a = self.tracing_variable('a', bm.zeros, (10,))
-
- .. versionadded:: 2.4.5
+ .. deprecated:: 3.0.0
+ This feature is no longer supported. Since BrainPy 3.0.0 the library
+ has been rewritten on top of ``brainstate`` and variable tracing is
+ handled by ``brainstate`` directly. Calling this method always raises
+ :class:`NotImplementedError`.
Args:
name: str. The variable name.
init: callable, Array. The data to be initialized as a ``Variable``.
- batch_or_mode: int, bool, Mode. This is used to specify the batch size of this variable.
- If it is a boolean or an instance of ``Mode``, the batch size will be 1.
- If it is None, the variable has no batch axis.
shape: int, sequence of int. The shape of the variable.
+ batch_or_mode: int, bool, Mode. The batch size of this variable.
batch_axis: int. The batch axis, if batch size is given.
- axis_names: sequence of str. The name for each axis. These names should match the given ``axes``.
- batch_axis_name: str. The name for the batch axis. The name will be used
- if ``batch_or_mode`` is given. Default is ``brainpy.math.sharding.BATCH_AXIS``.
+ axis_names: sequence of str. The name for each axis.
+ batch_axis_name: str. The name for the batch axis.
- Returns:
- The instance of :py:class:`~.Variable`.
+ Raises:
+ NotImplementedError: Always, because this feature is unsupported since 3.0.0.
"""
- # the variable has been created
raise NotImplementedError(
'Since 3.0.0, brainpy is rewritten with brainstate. The feature tracing_variable is no longer supported. '
)
- if hasattr(self, name):
- var = getattr(self, name)
- if isinstance(var, Variable):
- return var
- # if var.shape != value.shape:
- # raise ValueError(
- # f'"{name}" has been used in this class with the shape of {var.shape} (!= {value.shape}). '
- # f'Please assign another name for the initialization of variables '
- # f'tracing during computation and compilation.'
- # )
- # if var.dtype != value.dtype:
- # raise ValueError(
- # f'"{name}" has been used in this class with the dtype of {var.dtype} (!= {value.dtype}). '
- # f'Please assign another name for the initialization of variables '
- # f'tracing during computation and compilation.'
- # )
-
- global variable_
- if variable_ is None:
- from brainpy.initialize import variable_
- with jax.ensure_compile_time_eval():
- value = variable_(init, shape, batch_or_mode, batch_axis, axis_names, batch_axis_name)
- value.ready_to_trace = True
- self.setattr(name, value)
- return value
def __setattr__(self, key: str, value: Any) -> None:
"""Overwrite `__setattr__` method for changing :py:class:`~.Variable` values.
@@ -288,25 +240,41 @@ def register_implicit_vars(self, *variables, var_cls: type = None, **named_varia
if var_cls is None:
var_cls = (Variable, VarList, VarDict)
+ def _store(key, value):
+ # ``self.implicit_vars`` is an ``ArrayCollector`` whose entries must
+ # be plain ``Variable`` instances (it is consumed directly by
+ # ``vars()``). ``VarList``/``VarDict`` containers are therefore
+ # flattened into their constituent variables before insertion,
+ # mirroring how attribute-stored containers are expanded in
+ # ``vars()``.
+ if isinstance(value, VarList):
+ for i, vv in enumerate(value):
+ self.implicit_vars[f'{key}-{i}'] = vv
+ elif isinstance(value, VarDict):
+ for kk, vv in value.items():
+ self.implicit_vars[f'{key}-{kk}'] = vv
+ else:
+ self.implicit_vars[key] = value
+
for variable in variables:
if isinstance(variable, var_cls):
- self.implicit_vars[f'var{id(variable)}'] = variable
+ _store(f'var{id(variable)}', variable)
elif isinstance(variable, (tuple, list)):
for v in variable:
if not isinstance(v, var_cls):
raise ValueError(f'Must be instance of {var_cls}, but we got {type(v)}')
- self.implicit_vars[f'var{id(v)}'] = v
+ _store(f'var{id(v)}', v)
elif isinstance(variable, dict):
for k, v in variable.items():
if not isinstance(v, var_cls):
raise ValueError(f'Must be instance of {var_cls}, but we got {type(v)}')
- self.implicit_vars[k] = v
+ _store(k, v)
else:
raise ValueError(f'Unknown type: {type(variable)}')
for key, variable in named_variables.items():
if not isinstance(variable, var_cls):
raise ValueError(f'Must be instance of {var_cls}, but we got {type(variable)}')
- self.implicit_vars[key] = variable
+ _store(key, variable)
def register_implicit_nodes(self, *nodes, node_cls: type = None, **named_nodes):
if node_cls is None:
@@ -606,11 +574,12 @@ def to(self, device: Optional[Any]):
Args:
device: The device.
"""
- for key, var in self.state_dict().items():
- if isinstance(var, Array):
- var.value = jax.device_put(var.value, device=device)
- else:
- setattr(self, key, jax.device_put(var, device=device))
+ # Iterate over the actual ``Variable`` instances (not the nested
+ # ``state_dict`` mapping). Iterating ``state_dict()`` would yield
+ # nested dicts/raw arrays keyed by name, so nothing was ever moved and
+ # ``setattr`` injected junk attributes onto the object.
+ for var in self.vars().values():
+ var.value = jax.device_put(var.value, device=device)
return self
def cpu(self):
diff --git a/brainpy/math/object_transform/controls.py b/brainpy/math/object_transform/controls.py
index 26c0f2aea..4e8942b4e 100644
--- a/brainpy/math/object_transform/controls.py
+++ b/brainpy/math/object_transform/controls.py
@@ -23,6 +23,31 @@
from brainpy.math.ndarray import Array
from ._utils import warp_to_no_state_input_output
+
+def _unwrap_operand_leaf(x):
+ """Replace a ``State``/``Variable`` or BrainPy ``Array`` leaf with its raw value.
+
+ ``brainstate.transform.*`` rejects ``State`` objects passed as operands, and feeding a
+ BrainPy ``Array`` through brainstate's loop primitives round-trips it through
+ ``tree_unflatten`` (which reconstructs from ``ShapedArray`` avals and fails) inside a
+ JAX trace. Unwrapping both to the underlying ``jax.Array`` avoids both problems while
+ leaving any other operand type untouched.
+ """
+ if isinstance(x, (brainstate.State, Array)):
+ return x.value
+ return x
+
+
+def _unwrap_state_operands(operands):
+ """Unwrap ``brainstate.State`` (e.g. :py:class:`~.Variable`) and :py:class:`~.Array`
+ leaves in ``operands`` to their raw ``jax.Array`` values before forwarding to brainstate.
+ """
+ return jax.tree.map(
+ _unwrap_operand_leaf,
+ operands,
+ is_leaf=lambda x: isinstance(x, (brainstate.State, Array)),
+ )
+
__all__ = [
'cond',
'ifelse',
@@ -124,6 +149,7 @@ def cond(
"""
if not isinstance(operands, (tuple, list)):
operands = (operands,)
+ operands = _unwrap_state_operands(operands)
return brainstate.transform.cond(
pred,
warp_to_no_state_input_output(true_fun),
@@ -193,6 +219,7 @@ def ifelse(
operands = ()
elif not isinstance(operands, (tuple, list)):
operands = (operands,)
+ operands = _unwrap_state_operands(operands)
# Convert non-callable branches to callables
def make_callable(branch):
@@ -234,7 +261,10 @@ def make_callable(branch):
conditions = exclusive_conditions
- return brainstate.transform.ifelse(conditions, branches, *operands)
+ # BrainPy already converts the conditions into mutually exclusive form above,
+ # so brainstate does not need to re-check exclusivity (which would otherwise
+ # reject overlapping inputs at trace time).
+ return brainstate.transform.ifelse(conditions, branches, *operands, check_cond=False)
def for_loop(
@@ -311,6 +341,12 @@ def for_loop(
iteration of a loop.
jit: bool
Whether to just-in-time compile the function. Set to ``False`` to disable JIT compilation.
+
+ .. note::
+ ``jit=False`` is implemented via the global :py:func:`jax.disable_jit` context
+ manager. Consequently it has no effect when ``for_loop`` is called inside an
+ enclosing trace (e.g. within another jitted/scanned function): JAX is already
+ tracing, so the loop runs as a compiled ``scan`` regardless of this flag.
progress_bar: bool, ProgressBar, int
Whether and how to display a progress bar during execution:
@@ -362,6 +398,7 @@ def for_loop(
"""
if not isinstance(operands, (tuple, list)):
operands = (operands,)
+ operands = _unwrap_state_operands(operands)
# Convert progress_bar to pbar format
pbar = _convert_progress_bar_to_pbar(progress_bar)
@@ -371,11 +408,12 @@ def for_loop(
# For zero-length inputs, we need to use JIT mode even when jit=False.
should_disable_jit = False
if jit is False:
- # Check if any operand has zero length
- first_operand = operands[0]
- is_zero_length = False
- if hasattr(first_operand, 'shape') and len(first_operand.shape) > 0:
- is_zero_length = (first_operand.shape[0] == 0)
+ # Check if any operand (over the whole pytree) has a zero-length leading axis.
+ leaves = jax.tree.leaves(operands)
+ is_zero_length = any(
+ getattr(leaf, 'ndim', 0) > 0 and leaf.shape[0] == 0
+ for leaf in leaves
+ )
if is_zero_length:
# Use JIT mode for zero-length inputs to avoid JAX limitation
@@ -464,13 +502,21 @@ def scan(
Now accepts ProgressBar instances and integers for advanced customization.
Returns::
-
- outs: Any
- The stacked outputs of ``body_fun`` when scanned over the leading axis of the inputs.
+
+ outs: tuple
+ A two-element tuple ``(final_carry, stacked_ys)``:
+
+ - ``final_carry``: the loop carry value returned by the last iteration of
+ ``body_fun`` (same structure as ``init``).
+ - ``stacked_ys``: the per-iteration outputs of ``body_fun`` stacked along a
+ new leading axis.
"""
# Convert progress_bar to pbar format
pbar = _convert_progress_bar_to_pbar(progress_bar)
+ init = _unwrap_state_operands(init)
+ operands = _unwrap_state_operands(operands)
+
return brainstate.transform.scan(
warp_to_no_state_input_output(body_fun),
init=init,
@@ -546,13 +592,17 @@ def while_loop(
if not isinstance(operands, (tuple, list)):
operands = (operands,)
operands = tuple(operands)
+ operands = _unwrap_state_operands(operands)
def body(x):
r = body_fun(*x)
if r is None:
- return x
- else:
- return r
+ raise ValueError(
+ '`body_fun` of `while_loop` must return the updated operands, '
+ 'but got `None`. Returning `None` would leave the operands unchanged '
+ 'and the loop condition would never become False, causing an infinite loop.'
+ )
+ return r
return brainstate.transform.while_loop(
warp_to_no_state_input_output(lambda x: cond_fun(*x)),
diff --git a/brainpy/math/object_transform/function.py b/brainpy/math/object_transform/function.py
index 8469e4c63..631295879 100644
--- a/brainpy/math/object_transform/function.py
+++ b/brainpy/math/object_transform/function.py
@@ -27,6 +27,42 @@
class Partial(FunAsObject):
+ """A picklable, object-aware partial application of a function.
+
+ ``Partial`` behaves like :py:func:`functools.partial`: it binds positional and
+ keyword arguments to ``fun`` so that the remaining arguments can be supplied
+ later when the instance is called. Unlike :py:func:`functools.partial`, it is a
+ :py:class:`~.BrainPyObject`, so any :py:class:`~.Variable` instances and child
+ :py:class:`~.BrainPyObject` objects used by ``fun`` are registered and tracked.
+
+ Parameters
+ ----------
+ fun : callable
+ The function to be partially applied.
+ *args : Any
+ Positional arguments bound ahead of the call-time positional arguments.
+ child_objs : callable, BrainPyObject, sequence of BrainPyObject, dict of BrainPyObject, optional
+ The children objects used in ``fun``.
+ dyn_vars : Variable, sequence of Variable, dict of Variable, optional
+ The :py:class:`~.Variable` instances used in ``fun``.
+ **keywords : Any
+ Keyword arguments bound to ``fun``. Keywords supplied at call time take
+ precedence over those bound here.
+
+ See Also
+ --------
+ to_object : Transform a Python function into a :py:class:`~.BrainPyObject`.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainpy.math as bm
+ >>> add = bm.Partial(lambda x, y: x + y, 1)
+ >>> add(2)
+ 3
+ """
+
def __init__(
self,
fun: Callable,
@@ -110,6 +146,7 @@ def function(
func: FunAsObject
The instance of ``BrainPyObject``.
"""
- warnings.warn('Using `brainpy.math.to_object()` instead. Will be removed after version 2.4.0.',
- UserWarning)
+ warnings.warn('`brainpy.math.function()` is deprecated; use `brainpy.math.to_object()` instead. '
+ 'It will be removed in a future release.',
+ DeprecationWarning)
return to_object(f, nodes, dyn_vars, name)
diff --git a/brainpy/math/object_transform/jit.py b/brainpy/math/object_transform/jit.py
index 0ce02f40b..0861000a3 100644
--- a/brainpy/math/object_transform/jit.py
+++ b/brainpy/math/object_transform/jit.py
@@ -20,6 +20,7 @@
"""
+import warnings
from typing import Callable, Union, Sequence, Iterable
import brainstate.transform
@@ -122,26 +123,27 @@ def jit(
Parameters::
-
- {jit_par}
- dyn_vars : optional, dict, sequence of Variable, Variable
- These variables will be changed in the function, or needed in the computation.
-
- .. deprecated:: 2.4.0
- No longer need to provide ``dyn_vars``. This function is capable of automatically
- collecting the dynamical variables used in the target ``func``.
- child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
- The children objects used in the target function.
- .. deprecated:: 2.4.0
- No longer need to provide ``child_objs``. This function is capable of automatically
- collecting the children objects used in the target ``func``.
+ {jit_par}
Returns::
-
+
func : JITTransform
A callable jitted function, set up for just-in-time compilation.
"""
+ # ``dyn_vars`` and ``child_objs`` are no longer used; brainstate collects
+ # dynamical variables automatically. Pop them (with a one-time deprecation
+ # warning) rather than forwarding them into ``brainstate.transform.jit``,
+ # which would raise a TypeError on the unexpected keyword arguments.
+ for _deprecated in ('dyn_vars', 'child_objs'):
+ if _deprecated in kwargs:
+ kwargs.pop(_deprecated)
+ warnings.warn(
+ f'`{_deprecated}` is deprecated and ignored. This function automatically '
+ f'collects the dynamical variables and child objects used in the target `func`.',
+ DeprecationWarning,
+ stacklevel=2,
+ )
return brainstate.transform.jit(
warp_to_no_state_input_output(func),
static_argnums=static_argnums,
@@ -160,6 +162,7 @@ def cls_jit(
func: Callable = Missing(),
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
+ donate_argnums: Union[int, Sequence[int], None] = None,
inline: bool = False,
keep_unused: bool = False,
**kwargs
@@ -197,19 +200,36 @@ def cls_jit(
func : JITTransform
A callable jitted function, set up for just-in-time compilation.
"""
+ # The bound method exposes ``self`` as the first positional argument, so any
+ # caller-supplied positional indices must be shifted by +1. Negative indices
+ # count from the end and are unaffected by the prepended ``self``; shifting
+ # them would silently corrupt the target argument.
+ def _shift_positive(x):
+ return x + 1 if x >= 0 else x
+
if static_argnums is None:
static_argnums = (0,)
elif isinstance(static_argnums, int):
- static_argnums = (0, static_argnums + 1,)
+ static_argnums = tuple(dict.fromkeys((0, _shift_positive(static_argnums))))
elif isinstance(static_argnums, (tuple, list)):
- static_argnums = (0,) + tuple(jax.tree.map(lambda x: x + 1, static_argnums))
+ static_argnums = tuple(dict.fromkeys((0,) + tuple(_shift_positive(x) for x in static_argnums)))
else:
raise ValueError('static_argnums is not supported yet.')
+ if donate_argnums is None:
+ donate_argnums = ()
+ elif isinstance(donate_argnums, int):
+ donate_argnums = (_shift_positive(donate_argnums),)
+ elif isinstance(donate_argnums, (tuple, list)):
+ donate_argnums = tuple(_shift_positive(x) for x in donate_argnums)
+ else:
+ raise ValueError('donate_argnums is not supported yet.')
+
return jit(
func=func,
static_argnums=static_argnums,
static_argnames=static_argnames,
+ donate_argnums=donate_argnums,
inline=inline,
keep_unused=keep_unused,
**kwargs
diff --git a/brainpy/math/object_transform/naming.py b/brainpy/math/object_transform/naming.py
index 8842b2a12..ce5bc32e5 100644
--- a/brainpy/math/object_transform/naming.py
+++ b/brainpy/math/object_transform/naming.py
@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
import warnings
+import weakref
from brainpy import _errors as errors
@@ -21,7 +22,16 @@
'clear_name_cache',
]
-_name2id = dict()
+# Maps a unique name to a *weak* reference of the object that owns it.
+#
+# Storing a weak reference (instead of the raw ``id(obj)``) has two benefits:
+# 1. The registry no longer keeps objects alive, so it does not grow
+# unboundedly as transient objects are created and discarded.
+# 2. Once an owning object is garbage-collected, its name is treated as free
+# again. Keying on ``id(obj)`` was unsafe because CPython readily reuses
+# the integer id of a collected object for a brand-new one, which could
+# trigger spurious ``UniqueNameError`` (or mask a genuine collision).
+_name2id = dict() # name -> weakref.ref(obj)
_typed_names = {}
@@ -32,7 +42,11 @@ def check_name_uniqueness(name, obj):
f'according to Python language definition. '
f'Please choose another name.')
if name in _name2id:
- if _name2id[name] != id(obj):
+ existing = _name2id[name]() # dereference the weak ref
+ # ``existing is None`` -> the previous owner has been collected, so the
+ # name is free and can be re-registered.
+ # ``existing is obj`` -> the same object re-registering its own name.
+ if existing is not None and existing is not obj:
raise errors.UniqueNameError(
f'In BrainPy, each object should have a unique name. '
f'However, we detect that {obj} has a used name "{name}". \n'
@@ -40,8 +54,15 @@ def check_name_uniqueness(name, obj):
f'>>> brainpy.math.clear_name_cache() \n\n'
f'to clear all cached names. '
)
- else:
- _name2id[name] = id(obj)
+
+ # (Re)register the name with a weak reference to the current owner. A
+ # finalizer drops the entry as soon as the object is collected, keeping the
+ # registry bounded.
+ def _drop(_ref, _name=name):
+ if _name2id.get(_name) is _ref:
+ del _name2id[_name]
+
+ _name2id[name] = weakref.ref(obj, _drop)
def get_unique_name(type_: str):
diff --git a/brainpy/math/object_transform/variables.py b/brainpy/math/object_transform/variables.py
index 60671a148..b15f33eb4 100644
--- a/brainpy/math/object_transform/variables.py
+++ b/brainpy/math/object_transform/variables.py
@@ -106,10 +106,10 @@ def __init__(
@property
def size_without_batch(self):
if self.batch_axis is None:
- return self.size
+ return self.shape
else:
- sizes = self.size
- return sizes[:self.batch_axis] + sizes[self.batch_axis + 1:]
+ s = self.shape
+ return s[:self.batch_axis] + s[self.batch_axis + 1:]
@property
def batch_axis(self) -> Optional[int]:
@@ -141,6 +141,18 @@ def value(self):
@value.setter
def value(self, v):
+ # Normalize/unwrap the incoming value *before* validating its
+ # shape/dtype, so that ``Array``/``np.ndarray``/``brainstate.State``
+ # wrappers are converted to a plain JAX array first. Otherwise the
+ # shape/dtype checks below would run against the wrapper (and a numpy
+ # value would never be canonicalized to the Variable's dtype).
+ if isinstance(v, brainstate.State):
+ v = v.value
+ if isinstance(v, Array):
+ v = v.value
+ elif isinstance(v, np.ndarray):
+ v = jnp.asarray(v)
+
_value = self.value
ext_shape = jnp.shape(v)
int_shape = jnp.shape(_value)
@@ -156,20 +168,51 @@ def value(self, v):
if ext_dtype != int_dtype:
raise MathError(f"The dtype of the original data is {int_dtype}, "
f"while we got {ext_dtype}.")
- if isinstance(v, Array):
- v = v.value
- elif isinstance(v, np.ndarray):
- v = jnp.asarray(v)
- else:
- v = v
- if isinstance(v, brainstate.State): # value checking
- v = v.value
self._check_value_tree(v) # check the tree structure
record_state_value_write(self) # record the value by the stack (>= level)
self._been_writen = True # set the flag
self._write_value(v) # write the value
+ # ------------------------------------------------------------------
+ # Identity-based hashing / equality.
+ #
+ # ``__eq__`` is inherited from the array base class and performs an
+ # *element-wise* comparison (e.g. ``var == 0`` returns a boolean array),
+ # which is a useful and public behaviour we intentionally keep. The hash,
+ # however, must stay identity-based: every place in BrainPy that uses a
+ # ``Variable`` as a registry/dedup key keys on ``id(var)`` (see the
+ # collectors), so a value-based hash would be both incorrect (mutable
+ # value) and inconsistent with that usage. We pin ``__hash__`` here to
+ # make that contract explicit and stable.
+ # ------------------------------------------------------------------
+ def __hash__(self):
+ return id(self)
+
+ def tree_flatten(self):
+ # Carry ``batch_axis`` and ``axis_names`` through pytree round-trips.
+ # The base ``Array.tree_flatten`` returns ``aux_data=None``, which
+ # silently drops these attributes whenever a ``Variable`` is flattened
+ # and reconstructed (e.g. through ``jax.jit``/``vmap``).
+ return (self.value,), (self._batch_axis, self.axis_names)
+
+ @classmethod
+ def tree_unflatten(cls, aux_data, flat_contents):
+ batch_axis, axis_names = aux_data
+ (value,) = flat_contents
+ # Rebuild without re-running ``Variable.__init__``: that would re-run
+ # the batch-axis validation and the (costly) ``State`` source-info
+ # capture on every unflatten. Set the ``_value`` slot first so that
+ # ``State.__init__`` can initialise the remaining bookkeeping fields
+ # (trace state, level, hooks, ...) correctly, then restore the
+ # variable-specific metadata.
+ obj = object.__new__(cls)
+ object.__setattr__(obj, '_value', value)
+ brainstate.State.__init__(obj, value)
+ object.__setattr__(obj, '_batch_axis', batch_axis)
+ object.__setattr__(obj, 'axis_names', axis_names)
+ return obj
+
def _get_dtype(v):
if hasattr(v, 'dtype'):
@@ -420,7 +463,11 @@ def tree_flatten(self):
@classmethod
def tree_unflatten(cls, keys, values):
- return cls(jax.util.safe_zip(keys, values))
+ # ``jax.util.safe_zip`` was removed in recent JAX. Reconstruct the
+ # mapping with a plain ``dict``; note that ``VarDict.update`` only
+ # understands ``dict`` (and single ``(k, v)`` tuples), so a bare
+ # ``zip`` iterator would be silently dropped here.
+ return cls(dict(zip(keys, values)))
var_dict = VarDict
diff --git a/brainpy/math/others.py b/brainpy/math/others.py
index fcf4b8587..a95a34c80 100644
--- a/brainpy/math/others.py
+++ b/brainpy/math/others.py
@@ -17,6 +17,7 @@
import jax
import jax.numpy as jnp
+import numpy as np
from jax.tree_util import tree_map
from brainpy import check, tools
@@ -91,9 +92,14 @@ def remove_diag(arr):
"""
if arr.ndim != 2:
raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.')
- eyes = _return(jnp.ones(arr.shape, dtype=bool))
- fill_diagonal(eyes, False)
- return jnp.reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1))
+ arr = as_jax(arr)
+ m, n = arr.shape
+ # Static off-diagonal indices (computed with numpy so they are concrete
+ # constants and the gather traces cleanly under jit/vmap).
+ rows = np.repeat(np.arange(m), n - 1)
+ eye_mask = ~np.eye(m, n, dtype=bool)
+ cols = np.broadcast_to(np.arange(n), (m, n))[eye_mask]
+ return arr[rows, cols].reshape(m, n - 1)
def clip_by_norm(t, clip_norm, axis=None):
diff --git a/brainpy/math/pre_syn_post.py b/brainpy/math/pre_syn_post.py
index d50f7c322..0b6bcff80 100644
--- a/brainpy/math/pre_syn_post.py
+++ b/brainpy/math/pre_syn_post.py
@@ -284,13 +284,24 @@ def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None):
post_val: ArrayType
The value with the size of post-synaptic neurons.
+
+ Notes::
+
+ When ``pre_values`` is a scalar, every connection carries the same constant
+ value, so the per-post mean is simply that constant. In this case the function
+ broadcasts the constant to every targeted post-synaptic neuron (untargeted
+ neurons stay ``0``). Duplicate ``post_ids`` therefore do not require any
+ averaging -- the mean of identical values equals the value itself.
"""
out = jnp.zeros(post_num)
pre_values = as_jax(pre_values)
post_ids = as_jax(post_ids)
if jnp.ndim(pre_values) == 0:
+ # Scalar branch: every synapse carries the same constant ``pre_values``,
+ # so the mean over any group of post-synaptic targets is that constant.
+ # Broadcast it to the targeted posts (duplicate ``post_ids`` are harmless
+ # because the mean of identical values is the value itself).
return out.at[post_ids].set(pre_values)
- # return out.at[jnp.unique(post_ids)].set(pre_values)
else:
_raise_pre_ids_is_none(pre_ids)
pre_ids = as_jax(pre_ids)
@@ -515,7 +526,9 @@ def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=False)
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
nominator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted)
denominator = _jit_seg_sum(jnp.ones_like(syn_values), post_ids, post_num, indices_are_sorted)
- return jnp.nan_to_num(nominator / denominator)
+ # Guard only the empty-group case (denominator == 0) instead of masking with
+ # ``nan_to_num``, which would also silently hide genuine NaNs in ``syn_values``.
+ return jnp.where(denominator > 0, nominator / denominator, 0.)
def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=False):
@@ -547,5 +560,10 @@ def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=Fal
syn_values = syn_values - syn_maxs[post_ids]
syn_values = jnp.exp(syn_values)
normalizers = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted)
- softmax = syn_values / normalizers[post_ids]
- return jnp.nan_to_num(softmax)
+ # ``normalizers[post_ids]`` is structurally >= 1 for every referenced post group
+ # (each such group contains at least the current synapse, contributing
+ # ``exp(0) == 1`` after the max-subtraction), so this division never hits a
+ # genuine 0/0. The previous ``jnp.nan_to_num`` only served to silently hide
+ # genuine NaNs produced upstream (e.g. NaNs already present in ``syn_values``),
+ # so it is intentionally removed to let such NaNs propagate.
+ return syn_values / normalizers[post_ids]
diff --git a/brainpy/math/remove_vmap.py b/brainpy/math/remove_vmap.py
index e49e0782a..103ef1925 100644
--- a/brainpy/math/remove_vmap.py
+++ b/brainpy/math/remove_vmap.py
@@ -27,6 +27,32 @@
def remove_vmap(x, op='any'):
+ """Reduce ``x`` with ``any``/``all`` *across the vmap batch axis as well*.
+
+ This is a custom primitive whose batching rule deliberately collapses the
+ batch axis into a single **global** scalar. That is, when called under
+ :func:`jax.vmap`, ``remove_vmap(x, 'any')`` returns one ``bool`` summarising
+ *all* batch elements together (``True`` if any element of any batch is
+ truthy), rather than a per-batch vector of results.
+
+ This is intentional: the primitive is used for global convergence / NaN-style
+ checks where the batch dimension must not survive the reduction. The batching
+ rule returns :data:`jax.interpreters.batching.not_mapped`, so the output is a
+ genuine unbatched scalar (it is *not* broadcast back across the batch axis).
+
+ Parameters
+ ----------
+ x : Array or jax.Array
+ The input array. ``brainpy.math.Array`` inputs are unwrapped.
+ op : {'any', 'all'}
+ The reduction to apply. ``'any'`` -> logical OR, ``'all'`` -> logical AND.
+
+ Returns
+ -------
+ jax.Array
+ A scalar boolean. Under :func:`jax.vmap` it is a single global scalar,
+ not a per-batch result.
+ """
if isinstance(x, Array):
x = x.value
if op == 'any':
diff --git a/brainpy/math/scales.py b/brainpy/math/scales.py
index 738a8eddf..5fd931906 100644
--- a/brainpy/math/scales.py
+++ b/brainpy/math/scales.py
@@ -73,17 +73,42 @@ def clone(self, bias=None, scale=None):
class IdScaling(Scaling):
+ """Identity scaling: a no-op :class:`Scaling` with ``scale=1`` and ``bias=0``.
+
+ Because the transform is the identity, custom ``bias``/``scale`` arguments are
+ meaningless. Passing non-default values is therefore rejected (rather than
+ being silently ignored) so the caller is not misled into thinking the values
+ take effect.
+ """
+
def __init__(self):
super().__init__(scale=1., bias=0.)
+ @staticmethod
+ def _reject_overrides(bias=None, scale=None):
+ if bias is not None and bias != 0.:
+ raise ValueError(
+ 'IdScaling is the identity transform and ignores "bias". '
+ f'Got bias={bias}. Use a plain Scaling if you need a non-zero bias.'
+ )
+ if scale is not None and scale != 1.:
+ raise ValueError(
+ 'IdScaling is the identity transform and ignores "scale". '
+ f'Got scale={scale}. Use a plain Scaling if you need a non-unit scale.'
+ )
+
def offset_scaling(self, x, bias=None, scale=None):
+ self._reject_overrides(bias, scale)
return x
def std_scaling(self, x, scale=None):
+ self._reject_overrides(scale=scale)
return x
def inv_scaling(self, x, scale=None):
+ self._reject_overrides(scale=scale)
return x
def clone(self, bias=None, scale=None):
+ self._reject_overrides(bias, scale)
return IdScaling()
diff --git a/brainpy/math/sharding.py b/brainpy/math/sharding.py
index e9f8ec679..3ef62e2c6 100644
--- a/brainpy/math/sharding.py
+++ b/brainpy/math/sharding.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+import warnings
from contextlib import contextmanager
from functools import partial
from typing import Optional, Any, Union, Sequence
@@ -123,8 +124,20 @@ def get_sharding(
if mesh is None:
return None
else:
- axis_names = [(name if name in mesh.axis_names else None) for name in axis_names]
- return NamedSharding(mesh, PartitionSpec(*axis_names))
+ resolved = [(name if name in mesh.axis_names else None) for name in axis_names]
+ # If *every* requested axis name is absent from the mesh, the resulting
+ # PartitionSpec is fully replicated (all ``None``), which silently
+ # discards the user's sharding intent. Warn so this is not a silent
+ # no-op. Partial matches are tolerated (kept) on purpose.
+ if len(axis_names) > 0 and all(name is None for name in resolved):
+ warnings.warn(
+ f'None of the requested axis names {list(axis_names)} are present in the '
+ f'mesh axes {tuple(mesh.axis_names)}. The array will be fully replicated '
+ f'(PartitionSpec of all None). Check the axis names against the mesh.',
+ UserWarning,
+ stacklevel=2,
+ )
+ return NamedSharding(mesh, PartitionSpec(*resolved))
def partition_by_axname(
diff --git a/brainpy/math/sparse/__init__.py b/brainpy/math/sparse/__init__.py
index c2769089b..d039e51bc 100644
--- a/brainpy/math/sparse/__init__.py
+++ b/brainpy/math/sparse/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-# from ._coo_mv import *
+from .coo_mv import *
from .csr_mm import *
from .csr_mv import *
from .jax_prim import *
diff --git a/brainpy/math/sparse/coo_mv.py b/brainpy/math/sparse/coo_mv.py
index f38063e73..a8e398e57 100644
--- a/brainpy/math/sparse/coo_mv.py
+++ b/brainpy/math/sparse/coo_mv.py
@@ -32,44 +32,39 @@ def coomv(
vector: Union[jnp.ndarray, Array],
*,
shape: Tuple[int, int],
- rows_sorted: bool = False,
- cols_sorted: bool = False,
transpose: bool = False,
- method: str = 'cusparse'
):
- """Product of COO sparse matrix and a dense vector using cuSPARSE algorithm.
+ """Product of COO sparse matrix and a dense vector.
- This function supports JAX transformations, including `jit()`, `grad()`,
- `vmap()` and `pmap()`.
+ The ``brainevent`` COO format was removed in v0.1.0, so the COO indices are
+ converted to CSR (via :func:`brainevent.coo2csr`) and the multiplication is
+ delegated to :class:`brainevent.CSR`.
- Parameters::
+ This function supports JAX transformations, including ``jit()``, ``grad()``,
+ ``vmap()`` and ``pmap()``.
- data: ndarray, float
- An array of shape ``(nse,)``.
- row: ndarray
- An array of shape ``(nse,)``.
- col: ndarray
- An array of shape ``(nse,)`` and dtype ``row.dtype``.
- vector: ndarray
- An array of shape ``(shape[0] if transpose else shape[1],)`` and
- dtype ``data.dtype``.
- shape: tuple of int
- The shape of the sparse matrix.
- rows_sorted: bool
- Row index are sorted.
- cols_sorted: bool
- Column index are sorted.
- transpose: bool
- A boolean specifying whether to transpose the sparse matrix
- before computing.
- method: str
- The method used to compute the matrix-vector multiplication.
+ Parameters
+ ----------
+ data : ndarray, float
+ An array of shape ``(nse,)``.
+ row : ndarray
+ An array of shape ``(nse,)``.
+ col : ndarray
+ An array of shape ``(nse,)`` and dtype ``row.dtype``.
+ vector : ndarray
+ An array of shape ``(shape[0] if transpose else shape[1],)`` and
+ dtype ``data.dtype``.
+ shape : tuple of int
+ The shape of the sparse matrix.
+ transpose : bool
+ A boolean specifying whether to transpose the sparse matrix
+ before computing.
- Returns::
-
- y: ndarray
- An array of shape ``(shape[1] if transpose else shape[0],)`` representing
- the matrix vector product.
+ Returns
+ -------
+ y : ndarray
+ An array of shape ``(shape[1] if transpose else shape[0],)`` representing
+ the matrix vector product.
"""
if isinstance(data, Array):
data = data.value
@@ -79,7 +74,16 @@ def coomv(
col = col.value
if isinstance(vector, Array):
vector = vector.value
- csr = brainevent.COO((data, row, col), shape=shape)
+
+ # The COO format was removed in brainevent 0.1.0; convert COO indices to
+ # CSR before delegating to brainevent.CSR.
+ indptr, indices, order = brainevent.coo2csr(row, col, shape=shape)
+ data = jnp.asarray(data)
+ if data.ndim == 0:
+ # scalar weight: broadcast to one entry per non-zero
+ data = jnp.broadcast_to(data, (indices.shape[0],))
+ data = data[order]
+ csr = brainevent.CSR((data, indices, indptr), shape=shape)
if transpose:
return vector @ csr
else:
diff --git a/brainpy/math/sparse/csr_mm.py b/brainpy/math/sparse/csr_mm.py
index c6056b0ad..97c212f47 100644
--- a/brainpy/math/sparse/csr_mm.py
+++ b/brainpy/math/sparse/csr_mm.py
@@ -61,6 +61,6 @@ def csrmm(
matrix = matrix.value
csr = brainevent.CSR((data, indices, indptr), shape=shape)
if transpose:
- return matrix @ csr
+ return csr.T @ matrix
else:
return csr @ matrix
diff --git a/brainpy/math/sparse/utils.py b/brainpy/math/sparse/utils.py
index ddab78677..3ff40619b 100644
--- a/brainpy/math/sparse/utils.py
+++ b/brainpy/math/sparse/utils.py
@@ -16,8 +16,8 @@
from typing import Tuple
+import brainevent
from jax import numpy as jnp
-from jax.experimental.sparse import csr_todense
from brainpy.math.interoperability import as_jax
@@ -39,15 +39,15 @@ def coo_to_csr(
post_ids = as_jax(post_ids)
# sorting
- sort_ids = jnp.argsort(pre_ids, kind='stable')
+ sort_ids = jnp.argsort(pre_ids, stable=True)
post_ids = post_ids[sort_ids]
indices = post_ids
unique_pre_ids, pre_count = jnp.unique(pre_ids, return_counts=True)
- final_pre_count = jnp.zeros(num_row)
- final_pre_count[unique_pre_ids] = pre_count
+ final_pre_count = jnp.zeros(num_row, dtype=jnp.int32)
+ final_pre_count = final_pre_count.at[unique_pre_ids].set(pre_count)
indptr = final_pre_count.cumsum()
- indptr = jnp.insert(indptr, 0, 0)
+ indptr = jnp.insert(indptr, 0, 0).astype(jnp.int32)
return indices, indptr
@@ -61,4 +61,23 @@ def csr_to_coo(
return jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1, indices
-csr_to_dense = csr_todense
+def csr_to_dense(data, indices, indptr, *, shape):
+ """Convert a CSR sparse matrix to a dense array.
+
+ Parameters
+ ----------
+ data : ndarray
+ An array of shape ``(nse,)`` holding the non-zero values.
+ indices : ndarray
+ An array of shape ``(nse,)`` holding the column index of each value.
+ indptr : ndarray
+ An array of shape ``(shape[0] + 1,)`` holding the row pointers.
+ shape : tuple of int
+ A length-2 tuple ``(n_rows, n_cols)`` for the dense matrix.
+
+ Returns
+ -------
+ dense : ndarray
+ The dense matrix of shape ``shape``.
+ """
+ return brainevent.CSR((data, indices, indptr), shape=shape).todense()
diff --git a/brainpy/math/surrogate/_one_input.py b/brainpy/math/surrogate/_one_input.py
index 91a6734b0..6f6ac5162 100644
--- a/brainpy/math/surrogate/_one_input.py
+++ b/brainpy/math/surrogate/_one_input.py
@@ -192,7 +192,7 @@ def surrogate_fun(self, x):
def surrogate_grad(self, dz, x):
x = as_jax(x)
- dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., dz * (-(self.alpha * x) ** 2 + self.alpha))
+ dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., dz * (-self.alpha ** 2 * jnp.abs(x) + self.alpha))
return dx
def __repr__(self):
@@ -471,7 +471,7 @@ def surrogate_grad(self, dz, x):
def surrogate_fun(self, x):
x = as_jax(x)
- return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
+ return jnp.arctan(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
def __repr__(self):
return f'{self.__class__.__name__}(alpha={self.alpha})'
@@ -654,7 +654,7 @@ def surrogate_grad(self, dz, x):
def surrogate_fun(self, x):
x = as_jax(x)
- return sci.special.erf(-self.alpha * x) * 0.5
+ return 0.5 * (1. - sci.special.erf(-self.alpha * x))
def __repr__(self):
return f'{self.__class__.__name__}(alpha={self.alpha})'
@@ -1066,13 +1066,13 @@ def __init__(self, alpha=2., forward_use_surrogate=False):
def surrogate_grad(self, dz, x):
x = as_jax(x)
- dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha)
+ dx = jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), -self.alpha)
return dx * as_jax(dz)
def surrogate_fun(self, x):
x = as_jax(x)
z = jnp.where(x < 0.,
- 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha),
+ 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * x, 1 - self.alpha),
1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha))
return z
@@ -1443,7 +1443,7 @@ def __init__(self, sigma=0.5, alpha=0.5):
def surrogate_grad(self, dz, x):
x = as_jax(x)
- dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
+ dx = jnp.exp(-(x ** 2) / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
return self.alpha * dx * as_jax(dz)
def __repr__(self):
diff --git a/brainpy/math/surrogate/_one_input_new.py b/brainpy/math/surrogate/_one_input_new.py
index 06ded6a44..908ab5eb8 100644
--- a/brainpy/math/surrogate/_one_input_new.py
+++ b/brainpy/math/surrogate/_one_input_new.py
@@ -176,7 +176,7 @@ def sigmoid(
):
r"""Spike function with the sigmoid-shaped surrogate gradient.
- If `origin=False`, return the forward function:
+ The forward function:
.. math::
@@ -185,11 +185,6 @@ def sigmoid(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}
Backward function:
@@ -251,7 +246,7 @@ def surrogate_fun(self, x):
def surrogate_grad(self, x):
x = as_jax(x)
- dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-(self.alpha * x) ** 2 + self.alpha))
+ dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-self.alpha ** 2 * jnp.abs(x) + self.alpha))
return dx
def __repr__(self):
@@ -264,7 +259,7 @@ def piecewise_quadratic(
):
r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -273,16 +268,6 @@ def piecewise_quadratic(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- g(x) =
- \begin{cases}
- 0, & x < -\frac{1}{\alpha} \\
- -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\
- 1, & x > \frac{1}{\alpha} \\
- \end{cases}
Backward function:
@@ -364,7 +349,7 @@ def piecewise_exp(
):
r"""Judge spiking state with a piecewise exponential function [1]_.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -373,14 +358,6 @@ def piecewise_exp(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- g(x) = \begin{cases}
- \frac{1}{2}e^{\alpha x}, & x < 0 \\
- 1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0
- \end{cases}
Backward function:
@@ -454,7 +431,7 @@ def soft_sign(
):
r"""Judge spiking state with a soft sign function.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -463,12 +440,6 @@ def soft_sign(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1)
- = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1)
Backward function:
@@ -526,7 +497,7 @@ def surrogate_grad(self, x):
def surrogate_fun(self, x):
x = as_jax(x)
- return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
+ return jnp.arctan(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
def __repr__(self):
return f'{self.__class__.__name__}(alpha={self.alpha})'
@@ -539,7 +510,7 @@ def arctan(
):
r"""Judge spiking state with an arctan function.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -548,11 +519,6 @@ def arctan(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}
Backward function:
@@ -623,7 +589,7 @@ def nonzero_sign_log(
):
r"""Judge spiking state with a nonzero sign log function.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -632,21 +598,6 @@ def nonzero_sign_log(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)
-
- where
-
- .. math::
-
- \begin{split}\mathrm{NonzeroSign}(x) =
- \begin{cases}
- 1, & x \geq 0 \\
- -1, & x < 0 \\
- \end{cases}\end{split}
Backward function:
@@ -707,7 +658,7 @@ def surrogate_grad(self, x):
def surrogate_fun(self, x):
x = as_jax(x)
- return sci.special.erf(-self.alpha * x) * 0.5
+ return 0.5 * (1. - sci.special.erf(-self.alpha * x))
def __repr__(self):
return f'{self.__class__.__name__}(alpha={self.alpha})'
@@ -720,7 +671,7 @@ def erf(
):
r"""Judge spiking state with an erf function [1]_ [2]_ [3]_.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -729,15 +680,6 @@ def erf(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- \begin{split}
- g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\
- &= \frac{1}{2} \text{erfc}(-\alpha x) \\
- &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt
- \end{split}
Backward function:
@@ -821,7 +763,7 @@ def piecewise_leaky_relu(
):
r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -830,16 +772,6 @@ def piecewise_leaky_relu(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- \begin{split}g(x) =
- \begin{cases}
- cx + cw, & x < -w \\
- \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\
- cx - cw + 1, & x > w \\
- \end{cases}\end{split}
Backward function:
@@ -940,7 +872,7 @@ def squarewave_fourier_series(
):
r"""Judge spiking state with a squarewave fourier series.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -949,11 +881,6 @@ def squarewave_fourier_series(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- g(x) = 0.5 + \frac{1}{\pi}*\sum_{i=1}^n {\sin\left({(2i-1)*2\pi}*x/T\right) \over 2i-1 }
Backward function:
@@ -1034,7 +961,7 @@ def s2nn(
):
r"""Judge spiking state with the S2NN surrogate spiking function [1]_.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -1043,14 +970,6 @@ def s2nn(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- \begin{split}g(x) = \begin{cases}
- \mathrm{sigmoid} (\alpha x), x < 0 \\
- \beta \ln(|x + 1|) + 0.5, x \ge 0
- \end{cases}\end{split}
Backward function:
@@ -1115,13 +1034,13 @@ def __init__(self, alpha=2.):
def surrogate_grad(self, x):
x = as_jax(x)
- dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha)
+ dx = jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), -self.alpha)
return dx
def surrogate_fun(self, x):
x = as_jax(x)
z = jnp.where(x < 0.,
- 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha),
+ 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * x, 1 - self.alpha),
1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha))
return z
@@ -1136,7 +1055,7 @@ def q_pseudo_spike(
):
r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -1145,15 +1064,6 @@ def q_pseudo_spike(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- \begin{split}g(x) =
- \begin{cases}
- \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\
- 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0.
- \end{cases}\end{split}
Backward function:
@@ -1229,7 +1139,7 @@ def leaky_relu(
):
r"""Judge spiking state with the Leaky ReLU function.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -1238,15 +1148,6 @@ def leaky_relu(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- \begin{split}g(x) =
- \begin{cases}
- \beta \cdot x, & x \geq 0 \\
- \alpha \cdot x, & x < 0 \\
- \end{cases}\end{split}
Backward function:
@@ -1330,7 +1231,7 @@ def log_tailed_relu(
):
r"""Judge spiking state with the Log-tailed ReLU function [1]_.
- If `origin=False`, computes the forward function:
+ The forward function:
.. math::
@@ -1339,16 +1240,6 @@ def log_tailed_relu(
0, & x < 0 \\
\end{cases}
- If `origin=True`, computes the original function:
-
- .. math::
-
- \begin{split}g(x) =
- \begin{cases}
- \alpha x, & x \leq 0 \\
- x, & 0 < x \leq 0 \\
- log(x), x > 1 \\
- \end{cases}\end{split}
Backward function:
@@ -1489,7 +1380,7 @@ def __init__(self, sigma=0.5, alpha=0.5):
def surrogate_grad(self, x):
x = as_jax(x)
- dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
+ dx = jnp.exp(-(x ** 2) / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
return self.alpha * dx
def __repr__(self):
diff --git a/brainpy/measure.py b/brainpy/measure.py
index 71c1c2771..4afd2a3a5 100644
--- a/brainpy/measure.py
+++ b/brainpy/measure.py
@@ -89,7 +89,10 @@ def firing_rate(spikes, width, dt=None, numpy=True):
np = onp if numpy else jnp
dt = bm.get_dt() if (dt is None) else dt
width1 = int(width / 2 / dt) * 2 + 1
- window = np.ones(width1) * 1000 / width
+ # Normalize by the actual window length (``width1`` bins of ``dt`` ms each),
+ # converting a per-bin spike count into a rate in Hz. Normalizing by the
+ # requested ``width`` instead biases the rate by ``width1 * dt / width``.
+ window = np.ones(width1) / (width1 * dt) * 1000.
return np.convolve(np.mean(spikes, axis=1), window, mode='same')
diff --git a/brainpy/optim/optimizer.py b/brainpy/optim/optimizer.py
index 312b2376f..57e1862c1 100644
--- a/brainpy/optim/optimizer.py
+++ b/brainpy/optim/optimizer.py
@@ -570,6 +570,10 @@ def __init__(
self.beta1 = beta1
self.beta2 = beta2
self.eps = eps
+ # Per-update step counter used for bias correction. It is a Variable so
+ # that it is traceable under JAX transforms and advances exactly once per
+ # call to ``update()`` (independent of any LR scheduler's ``last_epoch``).
+ self.step = bm.Variable(jnp.asarray(0))
def __repr__(self):
return (f"{self.__class__.__name__}(lr={str(self.lr)}, "
@@ -589,9 +593,14 @@ def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = Non
def update(self, grads: dict):
self.check_grads(grads)
- lr = self.lr()
- lr /= (1 - self.beta1 ** (self.lr.last_epoch.value + 2))
- lr *= jnp.sqrt(1 - self.beta2 ** (self.lr.last_epoch.value + 2))
+ # Advance the per-update step counter (t = 1 on the first update).
+ self.step.value = self.step.value + 1
+ t = self.step.value
+ # Read the (possibly scheduled) learning rate as a plain JAX array so we
+ # never mutate the underlying ``lr`` Variable in place.
+ lr = bm.as_jax(self.lr())
+ lr = lr / (1 - self.beta1 ** t)
+ lr = lr * jnp.sqrt(1 - self.beta2 ** t)
for key, p in self.vars_to_train.items():
m = self.implicit_vars[key + '_m']
v = self.implicit_vars[key + '_v']
@@ -931,6 +940,8 @@ def __init__(
self.beta2 = beta2
self.eps = eps
self.weight_decay = weight_decay
+ # Per-update step counter for bias correction (see ``Adam``).
+ self.step = bm.Variable(jnp.asarray(0))
def __repr__(self):
return (f"{self.__class__.__name__}(lr={self.lr}, "
@@ -960,8 +971,11 @@ def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = Non
def update(self, grads: dict):
self.check_grads(grads)
- lr_old = self.lr()
- step = self.lr.last_epoch.value + 2
+ # Advance the per-update step counter (t = 1 on the first update) and read
+ # the learning rate as a plain JAX array (never mutate the lr Variable).
+ self.step.value = self.step.value + 1
+ step = self.step.value
+ lr_old = bm.as_jax(self.lr())
bias_correction1 = 1 - self.beta1 ** step
bias_correction2 = 1 - self.beta2 ** step
lr = lr_old * jnp.sqrt(bias_correction2) / bias_correction1
@@ -1036,11 +1050,10 @@ def __init__(
weight_decay: Optional[float] = None,
name: Optional[str] = None,
):
- super(SM3, self).__init__(lr=lr,
- weight_decay=weight_decay,
- train_vars=train_vars,
- name=name)
-
+ # NOTE: validate and assign hyper-parameters *before* ``super().__init__``.
+ # The base ``Optimizer.__init__`` calls ``register_train_vars`` (when
+ # ``train_vars`` is given), which reads ``self.momentum``; assigning these
+ # attributes afterwards left the optimizer un-instantiable with train vars.
if not 0.0 <= momentum < 1.0:
raise ValueError("Invalid momentum: {0}".format(momentum))
if not 0.0 <= beta < 1.0:
@@ -1052,6 +1065,11 @@ def __init__(
self.beta = beta
self.momentum = momentum
+ super(SM3, self).__init__(lr=lr,
+ weight_decay=weight_decay,
+ train_vars=train_vars,
+ name=name)
+
def __repr__(self):
return (f"{self.__class__.__name__}(lr={self.lr}, "
f"beta={self.beta}, eps={self.eps}, momentum={self.momentum})")
@@ -1093,7 +1111,7 @@ def update(self, grads: dict):
result = update
for j in range(ndim):
if i != j:
- result = result.max(axis=j, keepdim=True)
+ result = result.max(axis=j, keepdims=True)
acc = self.implicit_vars[f'{k}_m{i}']
if self.beta > 0.:
acc.value = bm.maximum(acc, result)
diff --git a/brainpy/optim/scheduler.py b/brainpy/optim/scheduler.py
index 5941727a1..2ae0d307e 100644
--- a/brainpy/optim/scheduler.py
+++ b/brainpy/optim/scheduler.py
@@ -158,8 +158,10 @@ def __init__(
def call(self, i=None):
i = (self.last_epoch.value + 1) if i is None else i
milestones = jnp.asarray(self.milestones)
- conditions = jnp.logical_and((i >= milestones[:-1]), (i < milestones[1:]))
- p = jnp.argmax(conditions)
+ # Number of milestones strictly before epoch ``i``: lr decays *after* a
+ # milestone epoch is reached, so e.g. milestones=[10, 20] keeps the base
+ # lr through epoch 10 and applies the first decay from epoch 11 onward.
+ p = jnp.sum(milestones < i)
return self.lr * self.gamma ** p
def __call__(self, i=None):
diff --git a/brainpy/runners.py b/brainpy/runners.py
index 8e107f51a..9afb9951d 100644
--- a/brainpy/runners.py
+++ b/brainpy/runners.py
@@ -614,7 +614,9 @@ def _get_input_time_step(self, duration=None, xs=None) -> int:
else:
raise ValueError
- def _step_mon_on_cpu(self, args, transforms):
+ def _step_mon_on_cpu(self, args):
+ # host-side side effect: append the per-step monitored values (passed in
+ # as a dict of numpy arrays) to the running monitor lists.
for key, val in args.items():
self.mon[key].append(val)
@@ -636,12 +638,15 @@ def _step_func_predict(self, i, *x, shared_args=None):
clear_input(self.target)
if self._memory_efficient:
- mon_shape_dtype = jax.ShapeDtypeStruct(mon.shape, mon.dtype)
- result = jax.pure_callback(
- self._step_mon_on_cpu,
- mon_shape_dtype,
- mon,
- )
+ # ``mon`` is a dict of arrays. Offload the monitored values to the
+ # host on every step (keeping them out of device memory) using a
+ # side-effecting callback. ``jax.debug.callback`` is used instead of
+ # ``jax.pure_callback`` because the callback's purpose is the host
+ # side effect (appending to ``self.mon`` lists) and its result is
+ # discarded -- a ``pure_callback`` whose output is unused would be
+ # eliminated by XLA's dead-code elimination under ``jit``.
+ mon = tree_map(lambda x: bm.as_jax(x), mon, is_leaf=_is_brainpy_array)
+ jax.debug.callback(self._step_mon_on_cpu, mon)
return out, None
else:
return out, mon
diff --git a/brainpy/running/jax_multiprocessing.py b/brainpy/running/jax_multiprocessing.py
index afcc52e89..5375d44a8 100644
--- a/brainpy/running/jax_multiprocessing.py
+++ b/brainpy/running/jax_multiprocessing.py
@@ -133,13 +133,17 @@ def jax_parallelize_map(
res_tree = None
results = None
- vmap_func = pmap(func)
+ # Build the pmapped function once and reuse it across all chunks. Re-applying
+ # ``jax.pmap`` inside the loop forces a recompilation on every chunk, which is
+ # both slow and unnecessary since the traced function does not change.
+ pmap_func = pmap(func)
for i in range(0, num_pars[0], num_parallel):
- run_f = pmap(func) if clear_buffer else vmap_func
if isinstance(arguments, dict):
- r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
+ r = pmap_func(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
+ elif isinstance(arguments, (tuple, list)):
+ r = pmap_func(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
else:
- r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
+ raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}')
res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.Array))
if results is None:
results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values)
diff --git a/docs/issues-found-20260618.md b/docs/issues-found-20260618.md
new file mode 100644
index 000000000..e11674de1
--- /dev/null
+++ b/docs/issues-found-20260618.md
@@ -0,0 +1,387 @@
+# BrainPy Package Audit — Issues Found (2026-06-18)
+
+**Reviewer role:** Senior Python architect · JAX expert · BrainX-ecosystem developer
+**Package:** `brainpy` v2.7.8 (~74k LOC, 252 non-test files, 17 submodules)
+**Environment (verified):** Python 3.13.11 · jax 0.10.1 · brainpy 2.7.8 · brainstate 0.5.0 · brainevent 0.1.0 (CPU)
+**Method:** Full deep sweep by parallel expert sub-audits of every submodule; static review of all findings + executable repro for the high-severity ones. 33 of the highest-impact findings were independently re-verified by the lead reviewer; all reproduced.
+
+> **Scope note.** `import brainpy` works in this environment, so most findings are *runtime-reproduced*, not speculative. Findings are tagged **[verified]** (a repro was executed and reproduced the bug), **[static]** (confirmed by code inspection / type analysis), or **[likely]** (strong reasoning, not executed).
+
+---
+
+## 1. Executive summary
+
+The audit found **131 distinct issues**: **26 Critical**, **53 High**, **36 Medium**, **16 Low**. The dominant story is **ecosystem-migration drift**: BrainPy 2.7.x was rebased onto the new BrainX stack (`brainstate` 0.5, `brainevent` 0.1, `braintools`) and onto JAX ≥0.9/0.10, and many code paths were not updated in lockstep. The result is a band of **silent numerical errors** and **crash-on-first-use** bugs concentrated in: the optimizer/loss/scheduler stack, the surrogate-gradient and synapse/plasticity code, the sparse/event operators, the FDE/adaptive integrators, and the normalization layers.
+
+Highest-impact, broad-blast-radius issues (all **verified**):
+
+| # | Issue | Location | Impact |
+|---|-------|----------|--------|
+| C-01 | **Adam/AdamW bias correction frozen at t=1** | `optim/optimizer.py:593-594,964-967` | Every Adam/AdamW training run uses wrong (un-debiased, growing) steps |
+| C-02 | **`nll_loss` returns +log-likelihood (sign flipped)** | `losses/comparison.py:461` | NLL training maximizes instead of minimizes |
+| C-03 | **`cross_entropy_loss` weights by sample index, not class** | `losses/comparison.py:266-267` | Class-weighted CE silently wrong / shape-crashes |
+| C-04 | **`MultiStepLR` never decays** | `optim/scheduler.py:157-163` | LR schedule is a no-op |
+| C-05 | **`GroupNorm`/`InstanceNorm` reduce over the group axis** | `dnn/normalization.py:597` | `num_groups` has no effect; every config == LayerNorm |
+| C-06 | **STP facilitation ODE diverges** | `dyn/synapses/abstract_models.py:862` | Short-term plasticity synapse blows up to ±thousands |
+| C-07 | **`csrmm(transpose=True)` computes the wrong product** | `math/sparse/csr_mm.py:63`, `math/event/csr_matmat.py:64` | Wrong values / shape crash in sparse matmat + its autodiff |
+| C-08 | **`CaputoEuler` mis-scales the initial condition** | `integrators/fde/Caputo.py:201` | Fractional ODE solver wrong whenever y0≠0 |
+| C-09 | **`TimeDelay` read omits modulo** | `math/delayvars.py:271` | Delay variable returns stale/wrong delayed values |
+| C-10 | **`disable_x64()` desyncs brainstate vs JAX precision** | `math/environment.py:645` | After any x64 context, default dtypes silently wrong |
+
+The good news: the **core ODE Runge–Kutta tableaus, most synapse kinetics, the GRU cell, the high-dim fixed-point finder, weight initializers, and the Conv/Dense/Dropout layers were checked and found correct** (see Appendix B). The bugs are concentrated, not pervasive — which makes them tractable to fix.
+
+---
+
+## 2. Cross-cutting themes (root causes)
+
+1. **brainstate 0.5 migration drift.** `pyproject` pins `brainstate>=0.2.7` but 0.5.0 is installed. Removed/renamed APIs surface as runtime crashes: `tracing_variable` now `raise NotImplementedError` (breaks default STDP, C-19), `jax.util` removed (breaks `VarDict` pytree, C-25), `State`-as-operand rejected in control flow (H-…), `Variable.size_without_batch` broken (H-…). **Action:** pin a tested `brainstate`/`brainevent` lower bound and add an import-time smoke test across the public surface.
+
+2. **JAX ≥0.9/0.10 API changes not propagated.** `Array.device()` (now a property), `csr_todense`/`csrmm` signatures, `jnp.argsort(kind=)` removed, `__float__` rejecting `ndim>0`. **Action:** a compatibility shim module + CI against the pinned JAX.
+
+3. **`brainevent` 0.1 backend migration left wrappers stale.** `brainevent.COO` removed (C, coomv dead), transpose semantics inverted in matmat (C-07), `coo_to_csr`/`csr_to_dense` broken, jitconn docstrings describe a `method=`/cuSPARSE API that no longer exists.
+
+4. **Surrogate-gradient ⇄ forward-function inconsistency.** Multiple surrogate classes have a `surrogate_grad` formula that does **not** match the derivative of their own `surrogate_fun` / docstring (Gaussian precedence, PiecewiseQuadratic, q-PseudoSpike, ERF sign, arctan crash). Compounded by `bm.surrogate` being **shadowed by `braintools.surrogate`**, so the in-repo package is dead relative to the public API yet still importable and buggy.
+
+5. **Validation-after-mutation ordering.** Several setters/config functions mutate state or read attributes *before* validating/normalizing inputs: `environment.set()` (partial-config leak), `Variable.value` setter (rejects `State`/numpy before unwrap), adaptive-RK `tol` default not propagated, `SM3.__init__` reads `self.momentum` before assignment.
+
+6. **Batched-math assumptions.** RLS/FORCE (C-23) and several reductions assume batch size 1; correct for the tested path, silently divergent for B>1.
+
+7. **`dt` vs `sqrt(dt)` and unit scaling in stochastic/rate models.** `ThresholdLinearModel` noise scales as `dt` not `sqrt(dt)`; PoissonInput uses variance as std (C-17); `CondNeuGroup` double-applies area scaling.
+
+8. **Docstring/NumPy-doc nonconformance & drift.** Pervasive `Parameters::` / `Returns::` literal-block markers (won't render), stale deprecation versions ("removed after 2.4.0" in 2.7.8), and docstrings whose constants/defaults disagree with code.
+
+---
+
+## 3. Critical findings (detail)
+
+### C-01 — Adam/AdamW bias correction is frozen at t=1 **[verified]**
+- **File:** `brainpy/optim/optimizer.py:593-594` (Adam), `:964-967` (AdamW); root cause `optim/scheduler.py:55-59`
+- **What:** Bias correction uses `self.lr.last_epoch.value + 2`, but with the default `Constant` scheduler `last_epoch` is never incremented during `update()` (only `step_epoch()` advances it, which optimizers never call). So `beta**(last_epoch+2) == beta**1` forever and the `m`/`v` EMAs are never debiased.
+- **Why it's wrong:** Under a constant gradient, correct Adam yields a constant step ≈ `-lr`. Measured steps instead grow: `dw = [-0.001, -0.00134, -0.00157, -0.00172, -0.00183]`.
+- **Fix:** Maintain an internal per-`update()` step counter `t` (independent of the LR scheduler) and use `beta1**t`, `beta2**t` for bias correction. Don't derive `t` from `last_epoch`.
+
+### C-02 — `nll_loss` returns the log-likelihood, not its negative **[verified]**
+- **File:** `brainpy/losses/comparison.py:461` (class `NLLLoss` wraps it)
+- **What:** `return mean(input[arange, target])` with no negation; the function's own docstring defines `-Σ w·x_{n,y_n}`.
+- **Measured:** `nll_loss(log p, [0,1]) = -0.2899` (correct `+0.2899`). Minimizing drives the correct-class log-prob to −∞.
+- **Fix:** `loss = -input[jnp.arange(len(target)), target]` (negate), keep the reductions.
+
+### C-03 — `cross_entropy_loss` applies class `weight` by sample index **[verified]**
+- **File:** `brainpy/losses/comparison.py:266-267` (`loss *= weight`)
+- **What:** Per-sample loss `(N,)` is multiplied elementwise by the per-class weight `(C,)` — so sample *n* is weighted by `weight[n]`, not `weight[target_n]`. Raises on `N≠C`, silently wrong on `N==C`.
+- **Measured:** logits `0(3,3)`, targets `[2,2,2]`, weight `[10,20,1]` → per-sample `[10.99, 21.97, 1.10]`; correct is all `1.10` (`w[2]`).
+- **Fix:** gather `weight[target]` before reduction; for `mean`, normalize by `sum(weight[targets])`.
+
+### C-04 — `MultiStepLR` never decays **[verified]**
+- **File:** `brainpy/optim/scheduler.py:157-163`
+- **What:** `conditions = (i>=milestones[:-1]) & (i None`.
+- **Measured:** `odeint(..., method='rkf45', adaptive=True)(...)` → `TypeError: '>' not supported between ArrayImpl and NoneType`.
+- **Fix:** `code_scope['tol'] = self.tol` (and keyword default likewise).
+
+### C-13 — All SDE integrators `NameError` on invalid type (missing `errors` import) **[verified]**
+- **File:** `brainpy/integrators/sde/base.py:76,79,82`; `sde/normal.py:225`
+- **What:** Validation references `errors.IntegratorError` but `errors` is never imported; also the `Heun` Ito/Stratonovich guard.
+- **Measured:** `sdeint(..., intg_type='WRONG')` → `NameError: name 'errors' is not defined`.
+- **Fix:** `from brainpy import _errors as errors` in both files.
+
+### C-14 — Standalone HH/Markov channel gating produces NaN at voltage singularities **[verified by sub-audit]**
+- **File:** `brainpy/dyn/channels/sodium.py:384,299,215`; `potassium.py:359,222,290` (+legacy dups `:1191,1261,1332`); `calcium.py:711`
+- **What:** Rates coded as `k*temp/(1-exp(-temp/d))` are 0/0 → NaN exactly at the removable singularity (e.g. `IK_HH1952v2` at V=−55). The HH *neuron* class was fixed with `bm.exprel`; the channel modules were not. `bm.where` clamping can't recover it (both branches evaluated).
+- **Measured:** `IK_HH1952v2(1).f_p_alpha([-55.0]) = [nan]`.
+- **Fix:** rewrite with `bm.exprel`, e.g. `0.1 / bm.exprel(-(V - V_sh + 10)/10)` (mind the `k*d` coefficient bookkeeping). Fix legacy duplicates too.
+
+### C-15 — `ThresholdLinearModel` noise path crashes (`randn` signature) **[verified]**
+- **File:** `brainpy/dyn/rates/populations.py:1051,1060`
+- **What:** `bm.random.randn(self.varshape)` passes a shape *tuple* as a single positional arg; brainstate's `randn` takes unpacked dims.
+- **Measured:** any nonzero `noise_e/noise_i` → `TypeError: Shapes must be 1D sequences ... got ((1000,),)`.
+- **Fix:** `bm.random.randn(*self.varshape)` or `bm.random.normal(size=self.varshape)`. (Separately, the noise scales as `dt` not `sqrt(dt)` — see M-…)
+
+### C-16 — `StuartLandauOscillator.dy` has the wrong rotational coupling **[verified]**
+- **File:** `brainpy/dyn/rates/populations.py:721`
+- **What:** `dy` returns `(a-x²-y²)*y - w*y + y_ext` (copy-paste from `dx`); the Hopf normal form needs `+ w*x`. As written there's no x↔y rotation, so no limit cycle.
+- **Measured:** `dy(y=.5,x=.3,a=.25,w=.2) = -0.145` (buggy `-w*y`); correct `+w*x` gives `+0.015`.
+- **Fix:** `return (a - x*x - y*y)*y + w*x + y_ext`.
+
+### C-17 — `PoissonInput` Gaussian branch uses the variance as the std (~3–4× too much noise) **[verified]**
+- **File:** `brainpy/dyn/projections/inputs.py:168,174`; duplicated in `brainpy/dynold/experimental/others.py:74-77`
+- **What:** `bm.random.normal(a, b*p, ...)` passes `b*p = n(1-p)p` (the Binomial *variance*) as the std; correct is `sqrt(n·p·(1-p)) = sqrt(b*p)`.
+- **Measured:** `n=1000,p=0.02` → code std `19.6` vs correct `4.43`. Active in the common large-N branch; mean is correct.
+- **Fix:** `scale = jnp.sqrt(b*p)` (both the eager and `bm.cond` branches, and the dynold copy).
+
+### C-18 — `HalfProjAlignPost.update` calls `comm` twice **[verified by sub-audit]**
+- **File:** `brainpy/dyn/projections/align_post.py:384-388`
+- **What:** Computes `current = self.comm(x)` then `g = self.syn(self.comm(x))` — two independent calls. For event/jit-prob comms each call draws fresh random connectivity, so the synapse sees different input than the returned current; doubles compute for deterministic comms.
+- **Fix:** `current = self.comm(x); g = self.syn(current); ...; return current`.
+
+### C-19 — `STDP_Song2000` crashes on the first update (tracing_variable removed) **[verified by sub-audit]**
+- **File:** `brainpy/dyn/projections/plasticity.py:230-240` → `brainpy/dnn/linear.py:502-503`
+- **What:** `stdp_update` falls back to `self.tracing_variable('weight', ...)`, which now unconditionally `raise NotImplementedError`. The weight is only a `Variable` when the comm is built with `mode=TrainingMode`; the class docstring example omits it, so the documented usage is dead on arrival.
+- **Fix:** in `stdp_update`, promote the weight directly (`self.weight = bm.Variable(self.weight)`) or require trainable weights in `STDP_Song2000`. Also fixes a companion crash (H-…: `bm.as_jax(None)` for default `W_min/W_max`).
+
+### C-20 — `AlphaCUBA` / `AlphaCOBA` raise `ZeroDivisionError` on construction **[verified]**
+- **File:** `brainpy/dynold/synapses/compat.py:208-270`; root `brainpy/dyn/synapses/abstract_models.py:159-164`
+- **What:** They pass `tau_rise == tau_decay` into `DualExpon`, whose peak normalizer `A = tau_decay/(tau_decay - tau_rise)·…` divides by zero.
+- **Measured:** `bp.synapses.AlphaCUBA(LIF(2), LIF(2), All2All(), tau_decay=10.)` → `ZeroDivisionError`.
+- **Fix:** route `AlphaCUBA/COBA` through the single-tau `synapses.Alpha`, or special-case `tau_rise==tau_decay` (L'Hôpital limit `A=e`, `a=1/tau`).
+
+### C-21 — dynold `STP` learning rule injects current with zero presynaptic spikes **[verified by sub-audit]**
+- **File:** `brainpy/dynold/synapses/learning_rules.py:33-37,231-233`
+- **What:** `_STPModel = Sequential(STP, Expon)`; modern `STP.update` returns `u*x` (≈0.15 at rest) every step, and `Expon` treats it as additive current, so `g += u*x` continuously. The spike gating is lost.
+- **Measured:** zero input → `syn.I` ramps to ~512 and keeps rising.
+- **Fix:** gate by spikes (`pre_spike*(u*x)`), or use the modern `dyn/projections/plasticity` wiring.
+
+### C-22 — `DSRunner(memory_efficient=True)` is completely non-functional **[verified by sub-audit]**
+- **File:** `brainpy/runners.py:638-647` (+ `_step_mon_on_cpu` :617-619)
+- **What:** `_step_func_monitor()` returns a dict, but the code does `jax.ShapeDtypeStruct(mon.shape, mon.dtype)` on it; the `pure_callback` arg count and the `None` return are also wrong. Cannot have worked since the migration.
+- **Measured:** any `memory_efficient=True` run → `AttributeError: 'dict' object has no attribute 'shape'`.
+- **Fix:** `jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), mon)`; fix the callback signature/return; add a smoke test.
+
+### C-23 — RLS / FORCE online update is wrong for batch size > 1 **[verified by sub-audit]**
+- **File:** `brainpy/algorithms/online.py:148-154` (drives `train/online.py` `OnlineTrainer`/`ForceTrainer`)
+- **What:** `c = jnp.sum(1.0/(1.0+hPh))` collapses the `B×B` matrix `(I+HPHᵀ)` to a scalar (summing reciprocals of all entries incl. off-diagonals). Correct only for B=1; for B>1 `c` grows with B and can go negative → `P` diverges, update sign flips.
+- **Measured:** B=16 → `c=-90.3` (correct diag value `+1.7`); fitting with B≥4 → NaN weights within a few hundred steps.
+- **Fix:** proper block RLS: `K = PHᵀ(I + HPHᵀ)⁻¹` via `jnp.linalg.solve`, or assert `input.shape[0]==1`.
+
+### C-24 — `PoissonEncoder.single_step` crashes (its own documented usage) **[verified]**
+- **File:** `brainpy/encoding/stateless_encoding.py:91-94` → `:111`
+- **What:** `single_step(x, i_step=None)` delegates to `multi_steps(x, n_time=None)` whose first line is `int(n_time/get_dt())` → `None/float`.
+- **Measured:** `PoissonEncoder().single_step(bm.random.rand(4))` → `TypeError`.
+- **Fix:** in the `i_step is None` branch, draw a single Bernoulli sample directly; guard `multi_steps` against `n_time is None`.
+
+### C-25 — `VarDict.tree_unflatten` crashes on every JAX transform (`jax.util` removed) **[verified by sub-audit]**
+- **File:** `brainpy/math/object_transform/variables.py:423`
+- **What:** Calls `jax.util.safe_zip(...)`, but `jax.util` no longer exists in jax 0.10.1. `VarDict` is a registered pytree, so any `jit`/`vmap`/`tree_map` over one fails.
+- **Measured:** `jax.jit(lambda d: d)(bm.var_dict({'a': bm.Variable(...)}))` → `AttributeError: module 'jax' has no attribute 'util'`.
+- **Fix:** use `brainstate._compatible_import.safe_zip` or `cls(zip(keys, values))`.
+
+### C-26 — `Variable` batch_axis / axis_names silently dropped through pytree round-trip **[verified by sub-audit]**
+- **File:** `brainpy/math/object_transform/variables.py:40,79` (inherits `Array.tree_flatten` returning `aux_data=None`)
+- **What:** Reconstructing a `Variable` after `jit`/`vmap`/`scan`/`tree_map` loses `batch_axis`/`axis_names` (reset to `None`). brainstate's closure-based transforms work around it, but explicit pytree / `jit`-argument use degrades silently (affects sharding, `size_without_batch`, value-setter shape checks).
+- **Fix:** override `Variable.tree_flatten/unflatten` to carry `(batch_axis, axis_names)` in `aux_data` and rebuild without re-running naming/`State.__init__` side effects.
+
+---
+
+## 4. High findings (detail)
+
+> Format: `[verified|static] file:line — what → fix`
+
+### Object model / transforms (`math/object_transform`)
+- **H-01 [static]** `jit.py:200-207` — `cls_jit` shifts `static_argnums` by `+1` unconditionally, corrupting **negative** indices (`-1` → `(0,0)` marks `self` static twice). → shift only `x>=0`; resolve negatives against the signature.
+- **H-02 [verified]** `controls.py:125-561` + `_utils.py:78-85` — passing a `Variable`/`State` in `operands` of `cond`/`for_loop`/`while_loop`/`scan` raises (`State` rejected at brainstate cache-key time, before the in-wrapper strip). → strip state from `operands` before forwarding; document closure capture as the supported path.
+- **H-03 [verified]** `controls.py:372-390` — `for_loop(jit=False)` zero-length guard only checks `operands[0].shape`; a pytree operand (dict) has no `.shape`, so it crashes with "zero-length scan … in disable_jit()". → compute leading length from `jax.tree.leaves`.
+- **H-04 [verified]** `jit.py:127-153` — `jit` docstrings advertise `dyn_vars`/`child_objs`, now forwarded into `**kwargs` → `TypeError` from brainstate. → drop from docstring; filter/warn in `jit`.
+- **H-05 [verified]** `base.py:603-614` — `to()/cpu()/cuda()/tpu()` iterate `state_dict()` (nested dicts), so `isinstance(var, Array)` is always False; they never move variables and inject junk dict-valued attributes named after nodes. → iterate `self.vars().values()` and set `var.value = jax.device_put(...)`.
+- **H-06 [verified]** `variables.py:143-172` — `Variable.value` setter validates shape/dtype on the raw input *before* unwrapping `State`/numpy, so `v.value = some_State` and `v32.value = np.float64_array` raise spurious `MathError`. → unwrap/convert first, then validate.
+- **H-07 [verified]** `naming.py:24,34-44` — global `_name2id` registry grows unboundedly (no weakref/GC pruning) and stores `id(obj)`, so reused ids cause false `UniqueNameError`. → `WeakValueDictionary` + prune dead refs.
+- **H-08 [verified]** `base.py:287-309` vs `collectors.py:198` — `register_implicit_vars` default `var_cls` accepts `VarList/VarDict`, but `ArrayCollector.__setitem__` asserts `isinstance(value, Variable)` → `AssertionError`. → flatten containers or relax the collector.
+- **H-09 [verified]** `variables.py:41` — `Variable.__eq__` returns an elementwise array while `__hash__` is identity-based → breaks `in`/set/dict-by-value and raises ambiguous-truth. → define identity `__eq__`/`__ne__`, or guarantee all internal membership uses `id()`.
+
+### Math core / compat / sparse
+- **H-10 [verified]** `modes.py:38-41` — `Mode` overrides `__eq__` without `__hash__`, so every mode is unhashable (regression vs the hashable brainstate parent). → `__hash__ = brainstate.mixin.Mode.__hash__`.
+- **H-11 [verified]** `ndarray.py:206-207` — `Array.device()` calls the now-property `jax.Array.device` → `TypeError`. (Also `device_buffer` :209, `block_host_until_ready` :200.) → make `device` a property.
+- **H-12 [verified]** `ndarray.py:99-107` — `Array(scalar)` stores a bare Python scalar; `.shape`/most ops then crash. → fall through to `jnp.asarray(value)`.
+- **H-13 [verified]** `compat_numpy.py:217-220` — `asfarray(a)` with default `dtype=None` no-ops on integer input because `np.issubdtype(None, np.inexact)` is True. → `if dtype is None or not issubdtype(...): dtype = float`.
+- **H-14 [verified]** `_utils.py:59-62` (+ pytorch compat) — `out=` argument makes wrapped funcs return `None` (numpy/torch return the array). → `return out` after `out.value = r`.
+- **H-15 [verified]** `others.py:94-96` — `remove_diag` uses concrete boolean-mask indexing → `NonConcreteBooleanIndexError` under `jit`/`vmap`. → static off-diagonal gather.
+- **H-16 [verified]** `activations.py:668-669` — `softmin` lacks max-subtraction → NaN for large inputs (`softmin([1000,1001,1002]) = [nan,nan,nan]`). → `softmax(-x, axis)`.
+- **H-17 [verified]** `sparse/coo_mv.py:82` — `coomv` builds `brainevent.COO`, removed in brainevent 0.1 → `AttributeError`. → convert COO→CSR or drop.
+- **H-18 [verified]** `sparse/utils.py:42,47-49` — `coo_to_csr` broken: `argsort(kind=)` removed, in-place item assignment on immutable array, float `indptr`. → `argsort(stable=True)`, `.at[].set`, int dtype.
+- **H-19 [verified]** `sparse/utils.py:64` — `csr_to_dense = csr_todense` re-exports a jax function whose signature changed to take a `CSR` object → `TypeError` on the legacy call. → wrap explicitly via `brainevent.CSR(...).todense()`.
+
+### Surrogate gradients (`math/surrogate`) — all **[verified]**, present in both `_one_input.py` and `_one_input_new.py`
+- **H-20** `_one_input_new.py:1492` (`_one_input.py:1446`) — `GaussianGrad`: `exp(-(x**2)/2*sigma**2)` = `exp(-(x²/2)·σ²)`, σ inverted (precedence). At σ=2 the bump is ~`e²`× too narrow (grad@±1 ≈ 0.0135 vs intended ≈0.088). → `exp(-(x**2)/(2*sigma**2))`.
+- **H-21** `_one_input_new.py:254` (`:195`) — `PiecewiseQuadratic`: grad uses `-(α x)²+α` but the derivative of its own forward is `-α²|x|+α`. → `-self.alpha**2*jnp.abs(x)+self.alpha`.
+- **H-22** `_one_input_new.py:1118` (`:1069`) — `q_pseudo_spike`: grad denominator uses `alpha+1`, docstring/forward use `alpha-1`. → `alpha-1`.
+- **H-23** `_one_input_new.py:529` (`:474`) — `arctan.surrogate_fun` calls `jnp.arctan2(...)` with one arg → `TypeError`. → `jnp.arctan(...)`.
+- **H-24** `_one_input_new.py:710` (`:657`) — `ERF.surrogate_fun = erf(-αx)*0.5` is decreasing in [−0.5,0.5]; should be `0.5*(1-erf(-αx))`. → fix sign/offset (`0.5*erfc(-αx)`).
+- **H-25 [verified]** `math/__init__.py:47` — `bm.surrogate` is reassigned to `braintools.surrogate`, so the entire in-repo `brainpy/math/surrogate` package (with the above bugs) is unreachable via the public API yet still importable. → delete the in-repo package or stop the override; don't ship both.
+
+### Integrators
+- **H-26 [verified by sub-audit]** `ode/adaptive_rk.py:532` — `BoSh3.B2 = ['-5/72', …]` (sums to 0) makes the embedded error estimate ~20× too large and wrong-signed. → `B2 = ['7/24', 0.25, '1/3', 0.125]`.
+- **H-27 [verified]** `ode/adaptive_rk.py:221` — step controller `where(error>tol, shrink, dt)` never *increases* dt (one-sided), contradicting the docstring. → unconditional clamped factor `dt*clip(0.9*(tol/error)**(1/(p+1)), …)`.
+- **H-28 [verified]** `ode/adaptive_rk.py:164,214-217` — default `var_type=POP_VAR` emits `sum(abs(...))` (builtin `sum`) → `'float' object is not iterable` on scalar state. → `jnp.sum(jnp.abs(...))`.
+- **H-29 [verified]** `integrators/runner.py:242,254,262` — `IntegratorRunner` reuses loop var `i` (`for i,v in enumerate(...)`) clobbering the step index, so monitors reading `shared['i']` get `len(vars)-1`. → rename inner loop var.
+
+### FDE
+- **H-30 [verified]** `fde/GL.py:187` — `GLShortMemory.reset` uses key `key` instead of `key+'_delay'` → `KeyError` on every reset. → add the suffix.
+- **H-31 [verified by sub-audit]** `fde/Caputo.py:375` — `CaputoL1Schema.hists()` default path does `{k:v.numpy() for k,v in hists_}` (iterates dict keys) → `ValueError`. → `.items()`.
+- **H-32 [verified by sub-audit]** `fde/generic.py:87-88` — `set_default_fdeint()` assigns `_DEFAULT_ODE_METHOD` (wrong global), so it's a no-op. → assign `_DEFAULT_DDE_METHOD`.
+
+### Dyn (neurons / synapses / rates)
+- **H-33 [verified]** `dyn/ions/base.py:54-55` — `for k,v in channels.items(): self.add_elem(k=v)` passes literal keyword `k`, so all channels register under name `"k"` and overwrite each other. → `self.add_elem(**{k: v})`.
+- **H-34 [verified by sub-audit]** `dyn/neurons/lif.py:1108-1109` — `ExpIFRef/ExpIFRefLTC` unconditionally `odeint`, dropping any `noise=` (every other `*Ref` guards with `sdeint`). → guard on `self.noise`.
+- **H-35 [verified by sub-audit]** `dyn/neurons/lif.py:4495-4496,3814-3815` — `IzhikevichRef`/`GifRef` compute `spike_no_grad` but reset state with the grad-carrying `spike`, so `detach_spk` is a no-op. → use `spike_no_grad` for resets.
+- **H-36 [verified]** `dyn/rates/rnncells.py:401,412` — `LSTMCell` `h`/`c` setters slice axis 0 while getters split axis −1; unbatched → `IndexError`, batched → wrong-rows write. → slice the last axis.
+- **H-37 [verified]** `dyn/rates/reservoir.py:226` — `noise_rec * uniform(-1,-1, …)` is a constant `-noise_rec` bias, not noise (typo for `uniform(-1,1)`). → fix bounds.
+- **H-38 [verified]** `dyn/rates/reservoir.py:191,202-232` — `self.bias` is created (and TrainVar in training) but never added in `update()`. → `hidden += self.bias`.
+- **H-39 [verified by sub-audit]** `dyn/synapses/abstract_models.py:880-881` — STP discrete `u`/`x` jumps read pre-decay `self.u`/`self.x` instead of the decayed locals; off by one decay step. → use the decayed `u,x` locals (and apply `x` jump after the `u` jump).
+- **H-40 [static]** `dyn/projections/base.py:1-26` — byte-for-byte duplicate of `utils.py` (only a private helper), yet `projections/__init__.py` does `from .base import *` (imports nothing); misleading vs the real `SynConn` base. → delete or re-export the real base classes.
+- **H-41 [verified]** `dyn/projections/plasticity.py:232-233` — default `W_min=W_max=None` → `bm.as_jax(None)` raises (first crash even before C-19's path). → pass `None` through unchanged.
+
+### dynold compat
+- **H-42 [verified]** `dynold/neurons/reduced_models.py` LIF/ExpIF/AdExIF — default params silently changed to the modern `*Ref` values (`LIF`: `V_rest=0,V_reset=-5,V_th=20`; `ExpIF/AdExIF`: `V_th=-55`) while docstrings still claim `-65/-68/-30`. → restore historical defaults in the dynold wrappers or fix every docstring.
+
+### Top-level glue / measure / delay
+- **H-43 [verified]** `measure.py:91-92` — `firing_rate` normalizes by requested `width` while the window length is `width1=int(width/2/dt)*2+1≠width/dt`; biased by `width1·dt/width` (e.g. true 100 Hz → oscillates 100↔200, mean 110). → `window = ones(width1)/(width1*dt)*1000`.
+- **H-44 [verified]** `delay.py:254-257` (+ class attr only-annotated `:72`) — `VarDelay(target, time=T>0)` reads `self.data` in `_init_data` before it is ever assigned → `AttributeError: 'data'`. → set `self.data = None` unconditionally before the `max_length>0` branch.
+- **H-45 [verified]** `delay.py:481` → `math/object_transform/variables.py:106-112` — `DataDelay.reset_state(batch_size)` calls `size_without_batch`, which does `self.size[:batch_axis]+…` but `Variable.size` is the integer element count, not a shape tuple → `TypeError: 'int' object is not subscriptable` for any batched variable. → use `self.shape` in `size_without_batch`.
+
+### Train / running
+- **H-46 [verified]** `algorithms/offline.py:159,386` — `gradient_descent=True` path does `jnp.logical_and(...).value` (no `.value` on a jax array/tracer) → `AttributeError`; breaks every GD regression incl. always-GD `Lasso`/`ElasticNet`. → drop `.value`.
+- **H-47 [static]** `algorithms/offline.py:272-276` — ridge `XᵀX+αI` penalizes the prepended bias column (intercept shrunk) and is off by the ½ factor vs the documented `½α‖w‖²`. → zero the `(0,0)` entry of the penalty; reconcile the ½.
+- **H-48 [static]** `running/jax_multiprocessing.py:136-156` — `jax_parallelize_map` builds one cached `pmap` reused across chunks; the trailing partial chunk ≠ device count → retrace/crash; also mislabeled `vmap_func`, missing `else: raise`. → build per chunk or pad to a device multiple.
+
+### Analysis
+- **H-49 [static]** `analysis/lowdim/lowdim_analyzer.py:377,953` & `utils/optimization.py:398` — arg-unwrap comprehension tests `isinstance(candidates, bm.Array)` instead of `isinstance(a, …)` (3 copies) → either `AttributeError` or `bm.Array` leaking into `meshgrid`/`vmap`. → test `a`.
+- **H-50 [static]** `analysis/lowdim/lowdim_analyzer.py:1038-1040` — non-convertible 2D `_get_fixed_points` does `jnp.concatenate([])` when nothing converges → `ValueError` (the 1D/convertible paths guard, this one doesn't). → empty-guard return.
+
+### DNN
+- **H-51 [verified]** `dnn/normalization.py:100,134,503,588` — `BatchNorm*`/affine `LayerNorm`/`GroupNorm` raise `UnsupportedError` out-of-the-box under the default `NonBatchingMode` (and the affine `assert isinstance(mode, TrainingMode)`); only `mode=bm.training_mode` works. → default to `TrainingMode` when `mode is None`, or raise a clear message; broaden the affine assert.
+
+### Optim / losses / encoding
+- **H-52 [verified]** `optim/optimizer.py:592-594` — `Adam` corrupts an `lr` passed as a `bm.Variable` via in-place `lr /= …; lr *= …` (mutates the shared Variable each step). → non-mutating arithmetic / `bm.as_jax(self.lr())`.
+- **H-53 [static]** `losses/comparison.py:194-201` — `CrossEntropyLoss` stores `ignore_index`/`label_smoothing` but never forwards them; `cross_entropy_loss` has no such params → both are silent no-ops. → implement and forward.
+
+---
+
+## 5. Medium findings (condensed)
+
+| ID | [status] | File:line | Issue → Fix |
+|----|----------|-----------|-------------|
+| M-01 | verified | `optim/scheduler.py` via `optimizer.py` | `StepLR`/cosine families share the `last_epoch`-never-advances issue feeding C-01; audit all schedulers' step source. |
+| M-02 | static | `math/object_transform/jit.py` | `cls_jit` doesn't shift `donate_argnums` → donates `self`. → add param + `+1` shift. |
+| M-03 | verified | `controls.py:466-481` | `scan` returns `(carry, ys)` but docstring promises only `ys` (legacy contract change). → fix docs or return `ys`. |
+| M-04 | static | `controls.py:391-397` | `for_loop(jit=False)` toggles process-global `jax.disable_jit()` (no-op under an outer trace). → document / brainstate-native opt-out. |
+| M-05 | static | `controls.py:207-237` | `ifelse` omits `check_cond=False` though it already guarantees exclusivity → per-call device all-reduce + error branch. → pass `check_cond=False`. |
+| M-06 | verified | `controls.py:550-561` | `while_loop` body returning `None` freezes the carry (infinite-loop hazard). → raise on `None`/structure mismatch. |
+| M-07 | verified | `math/environment.py:391-428` | `set()`/`set_environment()` mutate globals before validating `numpy_func_return` → partial-config leak on error. → validate first. |
+| M-08 | verified | `math/remove_vmap.py:55-85` | under `vmap`, `remove_vmap(x,'any'/'all')` broadcasts the global reduction back over the batch (leaks across examples). → return a true scalar / document. |
+| M-09 | static | `math/ndarray.py:259-271` | `ShardedArray.value` getter inserts `with_sharding_constraint` on *every* read (always-true on single-device). → skip `SingleDeviceSharding`. |
+| M-10 | static | `math/sharding.py:119-162` | fully-unmatched axis names silently yield a replicated `PartitionSpec(None,…)` instead of erroring. → warn/raise on full mismatch. |
+| M-11 | verified | `math/compat_numpy.py:144-160` | `empty`/`empty_like` call `zeros`/`zeros_like` (needless zero-fill, wrong semantics). → `jnp.empty*`. |
+| M-12 | verified | `math/compat_numpy.py:129-133` | `fill_diagonal(inplace=False)` returns a raw jax array, not a brainpy `Array`. → `_return(r)`. |
+| M-13 | static | `math/jitconn/matvec.py` (+`event_matvec.py`) | `seed=None` draws a host RNG per call → non-reproducible eager, jit-frozen seed. → require/thread an explicit seed; document. |
+| M-14 | verified | `math/delayvars.py:215` | `TimeDelay.reset` drops `dtype=get_float()` on `current_time` and ignores callable `before_t0`. → mirror `__init__`. |
+| M-15 | verified | `math/pre_syn_post.py:291-293` | `pre2post_mean` scalar branch scatter-sets (no averaging, ignores duplicate post ids). → route through `syn2post_mean` or document. |
+| M-16 | static | `dyn/neurons/hh.py:148-194` | `CondNeuGroup.update` passes synaptic current through the `1e-3/A` external-input scaling (double-scales when `A≠1e-3`). → inject into the derivative like the LTC class. |
+| M-17 | verified | `dyn/ions/potassium.py:45` | `PotassiumFixed` default `E=-950 mV` (likely typo for `-95`). → fix default (confirm vs intended). |
+| M-18 | verified | `dyn/rates/populations.py:370-371` | `FeedbackFHN.reset_state` rebinds `self.input`/`input_y` to fresh Variables (breaks captured refs) instead of `.value=`. → set `.value`. |
+| M-19 | verified | `dyn/rates/populations.py:374` | `FeedbackFHN` delay queries `x_delay(t-delay)` while `state_delays` already registers the delay → double-counts (buffer-edge clamp). → query `x_delay(t)`. |
+| M-20 | static | `dyn/rates/populations.py:1051-1062` | `ThresholdLinearModel` noise scales as `dt` not `sqrt(dt)` (dt-dependent intensity). → Euler–Maruyama `sqrt(dt)`. |
+| M-21 | verified | `dyn/rates/rnncells.py:127,239,375` | `RNN/GRU/LSTMCell.reset_state(None)` builds `(None,num_out)` → `ValueError`. → branch on `None` → `(num_out,)`. |
+| M-22 | verified | `dyn/synapses/abstract_models.py:879-881,800-801` | `STP`/`STD` "simplified" updates assume binary `pre_spike`; graded inputs are wrong. → restore graded formula or assert binary. |
+| M-23 | static | `dynold/synapses/abstract_models.py` | dual-exp/NMDA/AMPA peak silently renormalized (`g_max` semantics changed vs pre-3.0). → document or auto-scale for compat. |
+| M-24 | verified | `dynold/neurons/reduced_models.py:1311` | `ALIFBellec2020` default `a_initializer=OneInit(-50.)` (adaptation var should start ~0). → `ZeroInit()`. |
+| M-25 | verified | `dnn/normalization.py:156-158` | `BatchNorm` stores *biased* batch var into `running_var` (PyTorch uses unbiased for the running stat). → apply `N/(N-1)`. |
+| M-26 | verified | `dnn/normalization.py:509` | `LayerNorm` shape-mismatch path does `", ".join(int_tuple)` → `TypeError` masking the real error. → `map(str, …)`. |
+| M-27 | verified | `dnn/pooling.py:118,390,787` | negative `channel_axis == -x_dim` wrongly rejected (`abs()` bound check). → `-x_dim <= axis < x_dim`. |
+| M-28 | verified | `dnn/function.py:91` | `Flatten` default `start_dim=0` contradicts its docstring/PyTorch (`1`) and drops the batch dim. → `start_dim=1` or fix docs. |
+| M-29 | verified | `optim/optimizer.py:1039-1096` | `SM3` reads `self.momentum` in `register_train_vars` before it's set → un-instantiable (also torch-style `keepdim=`). → set attrs before `super().__init__`. |
+| M-30 | verified | `connect/random_conn.py:99,87-89` | `FixedProb` sparse `build_coo/csr` use `int(post_num*prob)` (floors to 0 for small post; biased density) and forbid `include_self=False` on rectangular shapes with a contradictory message. → round/Bernoulli; drop the guard. |
+| M-31 | static | `train/back_propagation.py:522-523` | BPTT `indices = arange(self.i0, …)` but `i0` isn't advanced/pinned → wrong absolute `t` when `reset_state=False`. → pin `arange(0,num_step)` or document. |
+| M-32 | verified | `running/runner.py:99-101` | `Runner.__init__` mutates the caller's `jit` dict via `.pop()`. → operate on a copy. |
+| M-33 | static | `analysis/stability.py:148-163` | 2D star vs degenerate-node classification is inverted (eigenvalues alone can't distinguish; needs eigenvector rank). → use `matrix_rank(J-λI)`. |
+| M-34 | verified | `analysis/stability.py:111-141` | borderline types (center/saddle-node/line) gated on exact float `==0` of autodiff Jacobians → almost never detected. → tolerance bands. |
+| M-35 | static | `analysis/highdim/slow_points.py:357-360` | GD fixed-point finder stops on *mean* loss but `tolerance` reads as per-point → outliers left unconverged. → stop on max, or document. |
+| M-36 | verified | dyn synapses/`Variable.__float__` | `float(size-1 Variable)` raises under jax 0.10 (`ndim>0`) — breaks common single-neuron monitoring/doctests. → use `.item()`/index; consider squeezing `__float__`. |
+
+---
+
+## 6. Low findings (condensed)
+
+- **L-01** `math/ndarray.py:31,44-76` — duplicate `'Array'` in `__all__`; dead helpers (`_check_input_array`, `_check_out`, `_get_dtype`, `_all_slice`); `_as_jax_array_` duplicated in `_utils.py`. → de-dup/remove.
+- **L-02** `math/scales.py:79-89` — `IdScaling.clone(scale=…)`/`inv_scaling` silently ignore overrides. → raise or honor.
+- **L-03** `math/ndarray.py:153-172` vs `:273-292` — base `Array.value` setter has shape/dtype checks commented out while `ShardedArray` enforces them (inconsistent; base allows silent shape change). → one policy.
+- **L-04** `object_transform/function.py:44` — `function()` deprecation says "removed after 2.4.0" but ships in 2.7.8; `Partial` lacks a docstring. → update message; add docstring.
+- **L-05** `object_transform/_utils.py:24-27` — `__all__` omits the only symbol consumers import (`warp_to_no_state_input_output`). → fix `__all__`.
+- **L-06** `object_transform/base.py:192-219` — `tracing_variable` is `raise NotImplementedError` followed by ~25 lines of unreachable code + stale docstring; default-off pytree path is dead. → delete/clean.
+- **L-07** `dyn/ions/calcium.py:144` — `CalciumDyna._reversal_potential(C)` ignores its `C` arg (uses `self.C`). → use `C`.
+- **L-08** channels — several docstring constants/defaults disagree with code (`IAHP beta=0.09` vs doc `0.03`; `f_q_inf +58` vs `+59`; Ih `tau_m`/`phi`). → reconcile (incl. legacy dups).
+- **L-09** `dyn/synapses/delay_couplings.py:131-241` — docstrings reference a `g`/gain param that doesn't exist; malformed `Parameters::`. → fix docs.
+- **L-10** `losses/comparison.py:534` — functional `l1_loss` defaults `reduction='sum'` vs docstring/`L1Loss` `'mean'`. → `'mean'`.
+- **L-11** `encoding/stateful_encoding.py:111-120` — `LatencyEncoder` docstring example output shape ignores `dt` (`(5,3)` vs real `(50,3)`). → fix example.
+- **L-12** `integrators/sde/srk_strong.py:58,392` — dead module with a generated-code syntax error and wrong `compile_code` arg order. → remove or fix+register+test.
+- **L-13** `integrators/joint_eq.py:189` — `JointEq` raises a bare message-less `DiffEqError`. → add diagnostic.
+- **L-14** Pervasive NumPy-doc nonconformance: `Parameters::`/`Returns::`/`References::` literal-block markers across `math/sparse`, `math/jitconn`, `dyn/rates`, `dyn/synapses`, `measure`, etc. → convert to underlined sections (mandated by CLAUDE.md).
+- **L-15** `analysis/utils/others.py:99` — `get_sign2` passes a generator as `reshape` shape (latent, function unused). → `tuple(...)` / remove dead helpers.
+- **L-16** `dynold/experimental/__init__.py` empty → whole experimental subpackage unreachable; `dynold/synapses/base.py:233` & `experimental/base.py:98` missing `raise` before `ValueError`. → wire up or delete; add `raise`.
+
+---
+
+## 7. Prioritized remediation roadmap
+
+**P0 — silent numerical corruption in the most-used paths (fix first):**
+C-01 (Adam), C-02 (nll sign), C-03 (CE weight), C-04 (MultiStepLR), C-05 (GroupNorm), C-08 (Caputo), C-09 (TimeDelay), C-10 (disable_x64), C-16 (StuartLandau), C-17 (PoissonInput), C-23 (RLS B>1), H-43 (firing_rate), H-26 (BoSh3), H-20…H-24 (surrogate math). These give wrong answers without erroring.
+
+**P1 — crash-on-first-use of public APIs:**
+C-06, C-07, C-11–C-15, C-18–C-22, C-24–C-25, H-01–H-06, H-10–H-19, H-29–H-36, H-41, H-44–H-46, H-51, M-21, M-29. Many are one-line migration fixes; bundle with a public-surface import/smoke test.
+
+**P2 — correctness traps & fragility (Medium tier).** **P3 — docs/typing/dead-code hygiene (Low tier), incl. the repo-wide NumPy-doc `::` cleanup.**
+
+**Systemic actions (do alongside P0/P1):**
+1. **Pin & test ecosystem versions.** Bump `pyproject` lower bounds to the actually-tested `brainstate`/`brainevent`/`braintools`/`jax`, and add CI matrix entries.
+2. **Public-surface smoke test:** instantiate + one-step every public neuron/synapse/projection/layer/optimizer/encoder under the *default* mode; many P1 crashes would be caught immediately.
+3. **Property-based numerical tests** for surrogates (`grad(surrogate_fun) ≈ surrogate_grad`), integrators (convergence order, SDE moments), losses (vs reference), and delays (off-by-one) — the bug classes here recur and need oracles.
+4. **Resolve the `bm.surrogate` shadowing** (H-25): decide whether `braintools` or the in-repo package is canonical and delete the other.
+
+---
+
+## Appendix A — Verification status
+
+- **Independently re-verified by the lead reviewer (executed, reproduced):** C-01..C-13, C-15..C-17, C-20, C-24; H-02..H-08, H-10..H-19, H-20..H-24, H-27..H-39, H-41, H-43..H-46, H-51..H-52; M-06..M-08, M-11..M-12, M-14..M-15, M-17..M-19, M-21..M-22, M-26..M-30, M-32, M-36 — plus the isolated x64/precision and FDE checks. 33 of the highest-impact items were run head-to-head; **all reproduced**. (Two initial "non-reproductions" were lead-reviewer test-harness errors — wrong threshold / arg-order — with the underlying bug confirmed on code inspection.)
+- **Verified by the module sub-audits (executed in the same environment):** C-14, C-18..C-19, C-21..C-23, C-25..C-26, H-01, H-09, H-26, H-40, H-47..H-50, H-53 and the remaining Medium/Low items.
+
+## Appendix B — Checked and found CORRECT (audit negatives)
+
+To bound the audit, the following were specifically checked and are **not** bugs: all explicit-RK Butcher tableaus (Euler/RK2/RK3/RK4/Ralston/SSPRK convergence orders) and RKF45/CashKarp/Dormand–Prince/BogackiShampine embedded pairs; exponential-Euler exactness on linear ODEs; SDE Euler–Maruyama/Milstein/SRK first/second moments and `sqrt(dt)` scaling; per-step PRNG independence inside the jitted `for_loop`; `CaputoL1Schema`/`GLShortMemory` core numerics (machine precision incl. ring-buffer wrap); Expon/Alpha/DualExpon/NMDA/AMPA/STD synapse kinetics and the STDP sign convention; `gelu/elu/selu/softmax/log_softmax` math; einops parsing edge cases; GRU cell math, NVAR feature construction, MgBlock curve, OU `sqrt(dt)`, reservoir spectral-radius rescaling; Kaiming/Xavier/Lecun/Orthogonal init statistics; `FixedProb.build_mat` density, `GaussianProb` symmetry; `Dense`/`Conv`/`ConvTranspose`/`AvgPool` shapes & layouts, `Dropout` scaling & eval no-op, BatchNorm momentum direction & eval-uses-running-stats; the high-dim `SlowPointFinder` Jacobian recovery; calcium Nernst constant; the `*_compatible.py` channel shims (numerically identical to their v2 sources).
+
+---
+
+*Generated 2026-06-18. Working branch: `worktree-audit-issues-20260618`. Spec & verification scripts under `dev/superpowers/` (gitignored).*
diff --git a/tests/audit/test_boost_analysis.py b/tests/audit/test_boost_analysis.py
new file mode 100644
index 000000000..c3030b240
--- /dev/null
+++ b/tests/audit/test_boost_analysis.py
@@ -0,0 +1,964 @@
+# -*- coding: utf-8 -*-
+"""Audit coverage-boost tests for BrainPy v2.7.8 low-dimensional analysis.
+
+This module is part of the test-coverage audit. It exercises:
+
+- ``brainpy/analysis/lowdim/lowdim_analyzer.py`` (baseline ~39% line coverage)
+ via the public ``PhasePlane1D`` / ``PhasePlane2D`` / ``Bifurcation1D`` /
+ ``Bifurcation2D`` / ``FastSlow1D`` / ``FastSlow2D`` analyzers, including the
+ numerical nullcline / fixed-point / Jacobian machinery in ``Num1DAnalyzer``
+ and ``Num2DAnalyzer`` (both the optimization branch and the
+ "convert-to-one-equation" brentq branch).
+- ``brainpy/analysis/utils/optimization.py`` (baseline ~74% line coverage)
+ via direct calls to the root-finding helpers (``jax_brentq``,
+ ``get_brentq_candidates``, ``brentq_candidates``, ``brentq_roots``,
+ ``brentq_roots2``, ``roots_of_1d_by_x``, ``roots_of_1d_by_xy``,
+ ``numpy_brentq``, ``find_root_of_1d_numpy``) and ``scipy_minimize_with_jax``.
+
+All tests run on tiny models with coarse resolutions, short durations and small
+parameter ranges so the whole module stays well under the runtime budget.
+The tests do NOT modify any source; behaviour quirks observed on valid input
+(e.g. duplicate 1D fixed points) are pinned with explicit assertions/notes.
+"""
+
+import matplotlib
+
+matplotlib.use('Agg') # headless backend; must precede pyplot / analysis imports
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from jax import vmap
+from matplotlib import pyplot as plt
+
+import brainpy as bp
+import brainpy.math as bm
+from brainpy.analysis import constants as C
+from brainpy.analysis.lowdim.lowdim_analyzer import Num3DAnalyzer
+from brainpy.analysis.utils import optimization as opt
+
+
+# ---------------------------------------------------------------------------
+# Module-level setup: brentq optimization in BrainPy requires x64.
+# ---------------------------------------------------------------------------
+
+def setup_module(module):
+ bm.enable_x64()
+
+
+def teardown_module(module):
+ plt.close('all')
+ bm.disable_x64()
+
+
+def _close():
+ plt.close('all')
+
+
+# ===========================================================================
+# optimization.py -- direct helper calls on tiny functions
+# ===========================================================================
+
+def test_jax_brentq_simple_root():
+ """jax_brentq finds a bracketed root and reports convergence."""
+ solver = opt.jax_brentq(lambda x: x - 2.0)
+ res = solver(0.0, 5.0)
+ assert res['status'] == opt.ECONVERGED
+ assert abs(float(res['root']) - 2.0) < 1e-6
+ assert int(res['funcalls']) >= 2
+
+
+def test_jax_brentq_root_at_endpoint():
+ """When an endpoint is already a root, brentq returns it immediately."""
+ solver = opt.jax_brentq(lambda x: x)
+ # a == 0 is exactly a root -> early ECONVERGED path
+ res = solver(0.0, 3.0)
+ assert res['status'] == opt.ECONVERGED
+ assert abs(float(res['root'])) < 1e-12
+
+
+def test_jax_brentq_with_args():
+ """jax_brentq passes through extra args to the objective."""
+ solver = opt.jax_brentq(lambda x, p: x - p)
+ res = solver(0.0, 10.0, (4.0,))
+ assert res['status'] == opt.ECONVERGED
+ assert abs(float(res['root']) - 4.0) < 1e-6
+
+
+def test_roots_of_1d_by_x_quadratic():
+ """roots_of_1d_by_x recovers both roots of x**2 - 1 on [-2, 2]."""
+ cands = jnp.linspace(-2.0, 2.0, 41)
+ roots = opt.roots_of_1d_by_x(lambda x: x ** 2 - 1.0, cands)
+ vals = sorted(round(float(r), 3) for r in roots)
+ assert vals == [-1.0, 1.0]
+
+
+def test_roots_of_1d_by_x_exact_zero_candidate():
+ """Candidates that sit exactly on a root take the zero-sign branch."""
+ # candidate grid includes the exact roots 0 and +/-1 of x**3 - x
+ cands = jnp.linspace(-1.5, 1.5, 7) # ... -1, -0.5, 0, 0.5, 1 ...
+ roots = opt.roots_of_1d_by_x(lambda x: x ** 3 - x, cands)
+ vals = sorted(round(float(r), 3) for r in np.unique(np.asarray(roots)))
+ assert -1.0 in vals and 0.0 in vals and 1.0 in vals
+
+
+def test_roots_of_1d_by_x_no_roots():
+ """A strictly positive function yields no sign changes -> empty result."""
+ cands = jnp.linspace(-2.0, 2.0, 21)
+ roots = opt.roots_of_1d_by_x(lambda x: x ** 2 + 1.0, cands)
+ assert len(np.asarray(roots)) == 0
+
+
+def test_get_brentq_candidates_and_roots_of_1d_by_xy():
+ """get_brentq_candidates + roots_of_1d_by_xy round-trip on f(x, y) = x - y."""
+ xs = jnp.linspace(-2.0, 2.0, 21)
+ ys = jnp.linspace(-1.0, 1.0, 11)
+ starts, ends, args = opt.get_brentq_candidates(lambda x, y: x - y, xs, ys)
+ assert starts.shape == ends.shape == args.shape
+ assert starts.shape[0] > 0
+ xs_root, ys_root = opt.roots_of_1d_by_xy(lambda x, y: x - y, starts, ends, args)
+ # For f = x - y, the root x equals the parameter y.
+ assert xs_root.shape == ys_root.shape
+ np.testing.assert_allclose(np.asarray(xs_root), np.asarray(ys_root), atol=1e-6)
+
+
+def test_brentq_candidates_and_roots():
+ """brentq_candidates / brentq_roots / brentq_roots2 agree for f(x, p) = x - p."""
+ f = lambda x, p: x - p
+ vmap_f = jax.jit(vmap(f))
+ xs = jnp.linspace(-2.0, 2.0, 21)
+ ps = jnp.linspace(-1.0, 1.0, 11)
+ starts, ends, others = opt.brentq_candidates(vmap_f, xs, ps)
+ assert starts.shape[0] > 0
+ assert len(others) == 1
+
+ roots, vargs = opt.brentq_roots(f, starts, ends, others[0])
+ assert roots.shape[0] > 0
+ np.testing.assert_allclose(np.asarray(roots), np.asarray(vargs[0]), atol=1e-6)
+
+ vmap_brentq = jax.jit(vmap(opt.jax_brentq(f)))
+ roots2, vargs2 = opt.brentq_roots2(vmap_brentq, starts, ends, others[0])
+ np.testing.assert_allclose(np.asarray(roots2), np.asarray(vargs2[0]), atol=1e-6)
+
+
+def test_brentq_roots_no_extra_args_is_broken():
+ """NOTE (source bug): brentq_roots with no vmap_args/args is broken.
+
+ When called with neither ``vmap_args`` nor ``args``, the function builds an
+ ``in_axes`` tuple of length 3 ``(0, 0, ())`` but then invokes the vmapped
+ optimizer with only two positional arguments ``(starts, ends)``. jax.vmap
+ rejects the mismatch (len(in_axes)=3 vs len(args)=2). This is a latent bug
+ on the ``else`` branch of ``brentq_roots`` -- the codebase always reaches it
+ via ``brentq_roots2`` instead. Pinned here so any future fix is noticed.
+ """
+ f = lambda x: x - 1.0
+ starts = jnp.array([0.0, -5.0])
+ ends = jnp.array([2.0, 5.0])
+ with pytest.raises(ValueError):
+ opt.brentq_roots(f, starts, ends)
+
+
+def test_brentq_roots2_no_extra_args():
+ """brentq_roots2 (the actually-used variant) handles the no-args case."""
+ f = lambda x: x - 1.0
+ vmap_brentq = jax.jit(vmap(opt.jax_brentq(f)))
+ starts = jnp.array([0.0, -5.0])
+ ends = jnp.array([2.0, 5.0])
+ roots, vargs = opt.brentq_roots2(vmap_brentq, starts, ends)
+ assert vargs == ()
+ np.testing.assert_allclose(np.sort(np.asarray(roots)), [1.0, 1.0], atol=1e-6)
+
+
+def test_scipy_minimize_with_jax():
+ """scipy_minimize_with_jax minimizes a quadratic and unflattens the result."""
+ res = opt.scipy_minimize_with_jax(
+ lambda x: ((x - 3.0) ** 2).sum(),
+ jnp.array([0.0]),
+ method='BFGS',
+ )
+ assert bool(res['success'])
+ assert abs(float(np.asarray(res['x'])[0]) - 3.0) < 1e-4
+
+
+def test_scipy_minimize_with_jax_callback_and_bounds():
+ """Exercise the callback + bounds branches of scipy_minimize_with_jax."""
+ seen = []
+
+ def cb(xk):
+ seen.append(np.asarray(jax.tree_util.tree_leaves(xk)[0]))
+ return False # do not terminate early
+
+ res = opt.scipy_minimize_with_jax(
+ lambda x: ((x - 1.0) ** 2).sum(),
+ jnp.array([5.0]),
+ method='L-BFGS-B',
+ bounds=[(0.0, 10.0)],
+ callback=cb,
+ )
+ assert bool(res['success'])
+ assert abs(float(np.asarray(res['x'])[0]) - 1.0) < 1e-3
+ assert len(seen) >= 1
+
+
+def test_numpy_brentq_root_and_errors():
+ """numpy_brentq finds a root and raises on bad bracket / params."""
+ root, funcalls, itr = opt.numpy_brentq(lambda x: x ** 2 - 4.0, 0.0, 5.0)
+ assert abs(root - 2.0) < 1e-6
+ assert funcalls >= 2
+
+ # endpoint is a root -> early-return branch (status ECONVERGED, itr == 0)
+ root2, _, itr2 = opt.numpy_brentq(lambda x: x, 0.0, 5.0)
+ assert abs(root2) < 1e-12
+ assert itr2 == 0
+
+ with pytest.raises(ValueError):
+ opt.numpy_brentq(lambda x: x + 1.0, 0.0, 5.0) # f(a), f(b) same sign
+ with pytest.raises(ValueError):
+ opt.numpy_brentq(lambda x: x, 0.0, 5.0, xtol=-1.0) # bad xtol
+ with pytest.raises(ValueError):
+ opt.numpy_brentq(lambda x: x, 0.0, 5.0, maxiter=0) # bad maxiter
+
+
+def test_numpy_brentq_endpoint_b_is_root():
+ """numpy_brentq returns immediately when the upper endpoint is the root."""
+ root, _, itr = opt.numpy_brentq(lambda x: x - 5.0, 0.0, 5.0)
+ assert abs(root - 5.0) < 1e-12
+ assert itr == 0 # early ECONVERGED via fcur == 0
+
+
+def test_find_root_of_1d_numpy():
+ """find_root_of_1d_numpy recovers the roots of x**3 - x including exact ones."""
+ pts = np.linspace(-3.0, 3.0, 61)
+ roots = opt.find_root_of_1d_numpy(lambda x: x ** 3 - x, pts)
+ vals = sorted(round(float(r), 3) for r in np.unique(np.asarray(roots)))
+ assert -1.0 in vals and 0.0 in vals and 1.0 in vals
+
+
+def test_find_root_of_1d_numpy_leading_and_trailing_zeros():
+ """find_root_of_1d_numpy handles exact roots at the first/last grid points."""
+ # f = x*(x-2): roots 0 (first point, leading-zero branch) and 2 (last point)
+ pts = np.linspace(0.0, 2.0, 11)
+ roots = opt.find_root_of_1d_numpy(lambda x: x * (x - 2.0), pts)
+ vals = sorted(round(float(r), 3) for r in np.unique(np.asarray(roots)))
+ assert 0.0 in vals and 2.0 in vals
+
+
+# ===========================================================================
+# PhasePlane1D -- lowdim_analyzer Num1DAnalyzer paths
+# ===========================================================================
+
+def test_phase_plane_1d_linear():
+ """Linear 1D system dx = -x + I has a single stable fixed point at x = I."""
+ Iext = 0.5
+
+ @bp.odeint
+ def int_x(x, t, I=0.5):
+ return -x + I
+
+ pp = bp.analysis.PhasePlane1D(
+ model=int_x,
+ target_vars={'x': [-2.0, 2.0]},
+ pars_update={'I': Iext},
+ resolutions=0.05,
+ )
+ # vector field
+ yv = pp.plot_vector_field(with_return=True, show=False)
+ assert yv.shape[0] > 0
+
+ # fixed points
+ fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False))
+ assert len(fps) >= 1
+ # every returned fixed point should sit at x == I
+ np.testing.assert_allclose(fps, Iext, atol=1e-4)
+ _close()
+
+
+def test_phase_plane_1d_cubic():
+ """Cubic 1D system dx = x - x**3 has three fixed points: -1, 0, 1."""
+
+ @bp.odeint
+ def int_x(x, t):
+ return x - x ** 3
+
+ pp = bp.analysis.PhasePlane1D(
+ model=int_x,
+ target_vars={'x': [-2.0, 2.0]},
+ resolutions=0.02,
+ )
+ pp.plot_vector_field(show=False)
+ fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False))
+ uniq = sorted(round(float(v), 2) for v in np.unique(np.round(fps, 2)))
+ for expected in (-1.0, 0.0, 1.0):
+ assert any(abs(u - expected) < 1e-2 for u in uniq), f'missing fp {expected}: {uniq}'
+ _close()
+
+
+def test_phase_plane_1d_no_plot_paths():
+ """with_plot=False / with_return=False branches return None without drawing."""
+
+ @bp.odeint
+ def int_x(x, t):
+ return -x
+
+ pp = bp.analysis.PhasePlane1D(
+ model=int_x,
+ target_vars={'x': [-1.0, 1.0]},
+ resolutions=0.1,
+ )
+ assert pp.plot_vector_field(with_plot=False, with_return=False, show=False) is None
+ assert pp.plot_fixed_point(with_plot=False, with_return=False, show=False) is None
+ _close()
+
+
+def test_phase_plane_1d_rejects_target_pars():
+ """PhasePlane analyzers reject target_pars (only PP via pars_update allowed)."""
+ @bp.odeint
+ def int_x(x, t, a=1.):
+ return -a * x
+
+ with pytest.raises(Exception):
+ bp.analysis.PhasePlane1D(
+ model=int_x,
+ target_vars={'x': [-1.0, 1.0]},
+ target_pars={'a': [0.5, 1.5]},
+ resolutions=0.1,
+ )
+
+
+# ===========================================================================
+# PhasePlane2D -- lowdim_analyzer Num2DAnalyzer paths (FitzHugh-Nagumo)
+# ===========================================================================
+
+_FHN_A, _FHN_B, _FHN_TAU = 0.7, 0.8, 12.5
+
+
+def _fhn_integrals(I=0.5):
+ @bp.odeint
+ def int_V(V, t, w, I=I):
+ return V - V ** 3 / 3.0 - w + I
+
+ @bp.odeint
+ def int_w(w, t, V):
+ return (V + _FHN_A - _FHN_B * w) / _FHN_TAU
+
+ return int_V, int_w
+
+
+def test_phase_plane_2d_full_optimization_branch():
+ """Full PP2D pipeline (vector field, nullclines, fixed point, trajectory)."""
+ int_V, int_w = _fhn_integrals(I=0.5)
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_V, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ resolutions=0.1,
+ )
+
+ # streamplot + quiver vector field branches
+ dx, dy = pp.plot_vector_field(with_return=True, show=False)
+ assert dx.shape == dy.shape and dx.ndim == 2
+ pp.plot_vector_field(plot_method='quiver', show=False)
+
+ # nullclines (optimization branch, x-y coords)
+ nc = pp.plot_nullcline(with_return=True, show=False)
+ assert set(nc.keys()) == {'V', 'w'}
+
+ # fixed points from the fx-nullcline candidates (default)
+ fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False))
+ assert fps.shape[1] == 2
+ # FHN with I=0.5 has a single (unstable) fixed point in this window
+ assert len(fps) >= 1
+
+ # trajectory in both axis modes
+ traj = pp.plot_trajectory(
+ initials={'V': [-1.0], 'w': [0.0]},
+ duration=10.0, show=False, with_return=True,
+ )
+ assert traj is not None
+ pp.plot_trajectory(
+ initials={'V': [-1.0], 'w': [0.0]},
+ duration=5.0, axes='t-v', show=False,
+ )
+ _close()
+
+
+def test_phase_plane_2d_nullcline_alternate_coords():
+ """coords='w-V' exercises the y_var-x_var branch of the nullcline solver."""
+ int_V, int_w = _fhn_integrals(I=0.5)
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_V, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ resolutions=0.1,
+ )
+ nc = pp.plot_nullcline(
+ with_return=True, show=False,
+ coords={'V': 'V-w', 'w': 'w-V'},
+ )
+ assert set(nc.keys()) == {'V', 'w'}
+ _close()
+
+
+def test_phase_plane_2d_invalid_plot_method_and_axes():
+ """Unknown plot_method / axes raise analyzer errors."""
+ int_V, int_w = _fhn_integrals()
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_V, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ resolutions=0.2,
+ )
+ with pytest.raises(Exception):
+ pp.plot_vector_field(plot_method='nope', show=False)
+ with pytest.raises(Exception):
+ pp.plot_trajectory(initials={'V': [0.0], 'w': [0.0]},
+ duration=1.0, axes='bad', show=False)
+ _close()
+
+
+def test_phase_plane_2d_limit_cycle_by_sim():
+ """plot_limit_cycle_by_sim runs (no cycle expected for short sim, but exercised)."""
+ int_V, int_w = _fhn_integrals(I=0.5)
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_V, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ resolutions=0.2,
+ )
+ # short duration -> exercises the "no limit cycle found" branch safely
+ pp.plot_limit_cycle_by_sim(
+ initials={'V': [-1.0], 'w': [0.0]},
+ duration=20.0, show=False,
+ )
+ _close()
+
+
+def test_phase_plane_2d_aux_rank_candidates():
+ """select_candidates='aux_rank' exercises _get_fp_candidates_by_aux_rank."""
+ int_V, int_w = _fhn_integrals(I=0.5)
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_V, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ resolutions=0.1,
+ )
+ fps = np.asarray(pp.plot_fixed_point(
+ with_return=True, select_candidates='aux_rank', num_rank=50, show=False))
+ assert fps.shape[1] == 2
+ assert len(fps) >= 1
+ _close()
+
+
+def test_phase_plane_2d_fixed_point_requires_nullcline():
+ """fx/fy-nullcline candidate selection errors if nullclines not yet computed."""
+ int_V, int_w = _fhn_integrals(I=0.5)
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_V, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ resolutions=0.2,
+ )
+ with pytest.raises(Exception):
+ pp.plot_fixed_point(select_candidates='fy-nullcline', show=False)
+ _close()
+
+
+def test_phase_plane_2d_convert_to_one_equation():
+ """Providing y_by_x_in_fy enables the brentq 'convert to one equation' branch."""
+ int_V, int_w = _fhn_integrals(I=0.5)
+
+ # w-nullcline of FHN: (V + a - b*w)/tau = 0 -> w = (V + a) / b
+ def w_by_V(V):
+ return (V + _FHN_A) / _FHN_B
+
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_V, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ resolutions=0.1,
+ options={C.y_by_x_in_fy: w_by_V},
+ )
+ assert pp._can_convert_to_one_eq()
+ assert pp.convert_type() == C.y_by_x
+ # nullcline now uses the analytic F_y_by_x_in_fy branch
+ nc = pp.plot_nullcline(with_return=True, show=False)
+ assert set(nc.keys()) == {'V', 'w'}
+ # fixed point via brentq optimization branch
+ fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False))
+ assert fps.shape[1] == 2 and len(fps) >= 1
+ _close()
+
+
+def test_phase_plane_2d_convert_x_by_y_in_fy():
+ """x_by_y_in_fy option drives the analytic fy-nullcline + x_by_y convert type."""
+ int_V, int_w = _fhn_integrals(I=0.5)
+
+ # w-nullcline: (V + a - b*w)/tau = 0 -> V = b*w - a
+ def V_by_w(w):
+ return _FHN_B * w - _FHN_A
+
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_V, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ resolutions=0.1,
+ options={C.x_by_y_in_fy: V_by_w},
+ )
+ assert pp.convert_type() == C.x_by_y
+ nc = pp.plot_nullcline(with_return=True, show=False)
+ assert set(nc.keys()) == {'V', 'w'}
+ fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False))
+ assert fps.shape[1] == 2 and len(fps) >= 1
+ _close()
+
+
+def test_phase_plane_2d_convert_y_by_x_in_fx():
+ """y_by_x_in_fx option drives the analytic fx-nullcline + y_by_x convert type."""
+ int_V, int_w = _fhn_integrals(I=0.5)
+
+ # V-nullcline: V - V**3/3 - w + I = 0 -> w = V - V**3/3 + I
+ def w_by_V(V):
+ return V - V ** 3 / 3.0 + 0.5
+
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_V, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ resolutions=0.1,
+ options={C.y_by_x_in_fx: w_by_V},
+ )
+ assert pp.convert_type() == C.y_by_x
+ nc = pp.plot_nullcline(with_return=True, show=False)
+ assert set(nc.keys()) == {'V', 'w'}
+ fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False))
+ assert fps.shape[1] == 2 and len(fps) >= 1
+ _close()
+
+
+def test_phase_plane_2d_convert_x_by_y_in_fx():
+ """x_by_y_in_fx option (linear system) drives the analytic fx-nullcline branch."""
+
+ @bp.odeint
+ def int_x(x, t, y):
+ return -x + 0.5 * y
+
+ @bp.odeint
+ def int_y(y, t, x):
+ return x - 2.0 * y
+
+ # fx-nullcline: -x + 0.5*y = 0 -> x = 0.5*y
+ def x_by_y(y):
+ return 0.5 * y
+
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_x, int_y],
+ target_vars={'x': [-2.0, 2.0], 'y': [-2.0, 2.0]},
+ resolutions=0.1,
+ options={C.x_by_y_in_fx: x_by_y},
+ )
+ assert pp.convert_type() == C.x_by_y
+ nc = pp.plot_nullcline(with_return=True, show=False)
+ assert set(nc.keys()) == {'x', 'y'}
+ fps = np.asarray(pp.plot_fixed_point(with_return=True, show=False))
+ # the only fixed point of this linear system is the origin
+ assert fps.shape[1] == 2 and len(fps) >= 1
+ np.testing.assert_allclose(fps[0], [0.0, 0.0], atol=1e-4)
+ _close()
+
+
+def test_num2d_jacobian_and_derivatives():
+ """Directly exercise the F_jacobian / derivative properties of Num2DAnalyzer."""
+ int_V, int_w = _fhn_integrals(I=0.5)
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_V, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ resolutions=0.2,
+ )
+ J = np.asarray(pp.F_jacobian(0.0, 0.0))
+ assert J.shape == (2, 2)
+ # partials at the origin: dfx/dw = -1 ; dfy/dV = 1/tau ; dfy/dw = -b/tau
+ assert abs(float(pp.F_dfxdy(0.0, 0.0)) + 1.0) < 1e-6
+ assert abs(float(pp.F_dfydx(0.0, 0.0)) - 1.0 / _FHN_TAU) < 1e-6
+ assert abs(float(pp.F_dfydy(0.0, 0.0)) + _FHN_B / _FHN_TAU) < 1e-6
+ _close()
+
+
+# ===========================================================================
+# Bifurcation1D / Bifurcation2D
+# ===========================================================================
+
+def test_bifurcation_1d_codim1():
+ """Bifurcation1D co-dimension-1: dx = -x + I, fixed point tracks I."""
+
+ @bp.odeint
+ def int_x(x, t, I=0.0):
+ return -x + I
+
+ bf = bp.analysis.Bifurcation1D(
+ model=int_x,
+ target_vars={'x': [-2.0, 2.0]},
+ target_pars={'I': [-1.0, 1.0]},
+ resolutions={'I': 0.1},
+ )
+ fps, pars, dfdx = bf.plot_bifurcation(with_return=True, show=False)
+ fps = np.asarray(fps)
+ p = np.asarray(pars[0])
+ assert fps.shape[0] > 0 and fps.shape == p.shape
+ # for dx = -x + I, fixed point x == I and df/dx == -1 (stable)
+ np.testing.assert_allclose(fps, p, atol=1e-4)
+ assert np.all(np.asarray(dfdx) < 0)
+ _close()
+
+
+def test_bifurcation_1d_codim2():
+ """Bifurcation1D co-dimension-2 (3D scatter): dx = -a*x + b."""
+
+ @bp.odeint
+ def int_x(x, t, a=1.0, b=0.0):
+ return -a * x + b
+
+ bf = bp.analysis.Bifurcation1D(
+ model=int_x,
+ target_vars={'x': [-2.0, 2.0]},
+ target_pars={'a': [0.5, 1.5], 'b': [-1.0, 1.0]},
+ resolutions={'a': 0.3, 'b': 0.3},
+ )
+ fps, pars, dfdx = bf.plot_bifurcation(with_return=True, show=False)
+ assert np.asarray(fps).shape[0] > 0
+ assert len(pars) == 2
+ _close()
+
+
+def test_bifurcation_1d_float_resolution_warns():
+ """A single float resolution with target_pars warns and uses jnp.arange grids."""
+
+ @bp.odeint
+ def int_x(x, t, I=0.0):
+ return -x + I
+
+ with pytest.warns(UserWarning):
+ bf = bp.analysis.Bifurcation1D(
+ model=int_x,
+ target_vars={'x': [-2.0, 2.0]},
+ target_pars={'I': [-1.0, 1.0]},
+ resolutions=0.2,
+ )
+ fps, pars, dfdx = bf.plot_bifurcation(with_return=True, show=False)
+ assert np.asarray(fps).shape[0] > 0
+ _close()
+
+
+def test_bifurcation_2d_segmented():
+ """num_par_segments>1 and num_fp_segment>1 exercise the segment-loop branches."""
+ int_V, int_w = _fhn_integrals()
+
+ @bp.odeint
+ def int_V2(V, t, w, Iext=0.0):
+ return V - V ** 3 / 3.0 - w + Iext
+
+ bif = bp.analysis.Bifurcation2D(
+ model=[int_V2, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ target_pars={'Iext': [0.0, 1.0]},
+ resolutions={'Iext': 0.3},
+ )
+ fps, pars, jac = bif.plot_bifurcation(
+ with_return=True, show=False,
+ select_candidates='fx-nullcline',
+ num_par_segments=2, num_fp_segment=2, nullcline_aux_filter=1.0,
+ )
+ assert np.asarray(fps).shape[1] == 2
+ _close()
+
+
+def test_bifurcation_2d_codim1_and_limit_cycle():
+ """Bifurcation2D co-dimension-1 on FHN over Iext, plus limit-cycle sim."""
+ int_V, int_w = _fhn_integrals()
+
+ @bp.odeint
+ def int_V2(V, t, w, Iext=0.0):
+ return V - V ** 3 / 3.0 - w + Iext
+
+ bif = bp.analysis.Bifurcation2D(
+ model=[int_V2, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ target_pars={'Iext': [0.0, 1.0]},
+ resolutions={'Iext': 0.2},
+ )
+ fps, pars, jac = bif.plot_bifurcation(with_return=True, show=False)
+ fps = np.asarray(fps)
+ assert fps.shape[1] == 2
+ assert np.asarray(jac).shape[1:] == (2, 2)
+ assert np.asarray(pars).shape[1] == 1
+
+ # limit cycle by simulation off the recorded fixed points
+ lc = bif.plot_limit_cycle_by_sim(duration=20.0, with_return=True, show=False)
+ assert lc is not None
+ _close()
+
+
+def test_bifurcation_2d_nullcline_candidate_selection():
+ """Bifurcation2D with fx/fy/nullclines candidate selection + aux filtering.
+
+ This drives the analytic-free nullcline solvers plus the ``_fp_filter``
+ auxiliary-loss filtering branch (nullcline_aux_filter > 0).
+ """
+ int_V, int_w = _fhn_integrals()
+
+ @bp.odeint
+ def int_V2(V, t, w, Iext=0.0):
+ return V - V ** 3 / 3.0 - w + Iext
+
+ for select in ('fx-nullcline', 'fy-nullcline', 'nullclines'):
+ bif = bp.analysis.Bifurcation2D(
+ model=[int_V2, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ target_pars={'Iext': [0.0, 1.0]},
+ resolutions={'Iext': 0.3},
+ )
+ fps, pars, jac = bif.plot_bifurcation(
+ with_return=True, show=False,
+ select_candidates=select, nullcline_aux_filter=1.0,
+ )
+ assert np.asarray(fps).shape[1] == 2
+ assert np.asarray(fps).shape[0] >= 1
+ _close()
+
+
+def test_bifurcation_2d_limit_cycle_without_bifurcation():
+ """plot_limit_cycle_by_sim returns early when no fixed points recorded yet."""
+ int_V, int_w = _fhn_integrals()
+
+ @bp.odeint
+ def int_V2(V, t, w, Iext=0.0):
+ return V - V ** 3 / 3.0 - w + Iext
+
+ bif = bp.analysis.Bifurcation2D(
+ model=[int_V2, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ target_pars={'Iext': [0.0, 1.0]},
+ resolutions={'Iext': 0.3},
+ )
+ # _fixed_points is None -> early return (None)
+ assert bif.plot_limit_cycle_by_sim(duration=5.0, show=False) is None
+ _close()
+
+
+# ===========================================================================
+# FastSlow1D / FastSlow2D
+# ===========================================================================
+
+def test_fast_slow_1d():
+ """FastSlow1D: fast var x, slow var u as the bifurcation parameter."""
+
+ @bp.odeint
+ def int_x(x, t, u=0.0):
+ return -x + u
+
+ @bp.odeint
+ def int_u(u, t, x):
+ return 0.01 * (x - u)
+
+ fs = bp.analysis.FastSlow1D(
+ model=[int_x, int_u],
+ fast_vars={'x': [-2.0, 2.0]},
+ slow_vars={'u': [-1.0, 1.0]},
+ resolutions={'u': 0.1},
+ )
+ fps, pars, dfdx = fs.plot_bifurcation(with_return=True, show=False)
+ fps = np.asarray(fps)
+ assert fps.shape[0] > 0
+ # x* == u and stable (df/dx == -1)
+ np.testing.assert_allclose(fps, np.asarray(pars[0]), atol=1e-4)
+
+ traj = fs.plot_trajectory(
+ initials={'x': [0.5], 'u': [0.5]},
+ duration=10.0, show=False, with_return=True,
+ )
+ assert traj is not None
+ _close()
+
+
+def test_fast_slow_2d():
+ """FastSlow2D: 2 fast vars (V, w), 1 slow var u acting like an input current."""
+
+ @bp.odeint
+ def int_V(V, t, w, u=0.0):
+ return V - V ** 3 / 3.0 - w + u
+
+ @bp.odeint
+ def int_w(w, t, V):
+ return (V + _FHN_A - _FHN_B * w) / _FHN_TAU
+
+ @bp.odeint
+ def int_u(u, t, V):
+ return 0.01 * (V - u)
+
+ fs = bp.analysis.FastSlow2D(
+ model=[int_V, int_w, int_u],
+ fast_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ slow_vars={'u': [0.0, 1.0]},
+ resolutions={'u': 0.2},
+ )
+ fps, pars, jac = fs.plot_bifurcation(with_return=True, show=False)
+ assert np.asarray(fps).shape[1] == 2
+
+ traj = fs.plot_trajectory(
+ initials={'V': [-1.0], 'w': [0.0], 'u': [0.5]},
+ duration=10.0, show=False, with_return=True,
+ )
+ assert traj is not None
+ _close()
+
+
+# ===========================================================================
+# LowDimAnalyzer construction / validation paths
+# ===========================================================================
+
+def test_analyzer_resolution_dict_and_array():
+ """resolutions can be a dict mixing floats and explicit 1D arrays."""
+
+ @bp.odeint
+ def int_x(x, t, I=0.0):
+ return -x + I
+
+ bf = bp.analysis.Bifurcation1D(
+ model=int_x,
+ target_vars={'x': [-2.0, 2.0]},
+ target_pars={'I': [-1.0, 1.0]},
+ resolutions={'x': np.linspace(-2.0, 2.0, 25), 'I': 0.2},
+ )
+ assert bf.resolutions['x'].shape[0] == 25
+ assert bf.resolutions['I'].shape[0] > 0
+ _close()
+
+
+def test_analyzer_resolution_none_default():
+ """resolutions=None builds the default 20-point linspace for vars and pars."""
+
+ @bp.odeint
+ def int_x(x, t, I=0.0):
+ return -x + I
+
+ bf = bp.analysis.Bifurcation1D(
+ model=int_x,
+ target_vars={'x': [-2.0, 2.0]},
+ target_pars={'I': [-1.0, 1.0]},
+ resolutions=None,
+ )
+ assert bf.resolutions['x'].shape[0] == 20
+ assert bf.resolutions['I'].shape[0] == 20
+ _close()
+
+
+def test_analyzer_validation_errors():
+ """Constructor validates target_vars / fixed_vars / reversed ranges."""
+
+ @bp.odeint
+ def int_x(x, t, I=0.0):
+ return -x + I
+
+ # target_vars must be a dict
+ with pytest.raises(Exception):
+ bp.analysis.PhasePlane1D(model=int_x, target_vars=['x'], resolutions=0.1)
+
+ # reversed variable range
+ with pytest.raises(Exception):
+ bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [2.0, -2.0]}, resolutions=0.1)
+
+ # unknown target variable
+ with pytest.raises(Exception):
+ bp.analysis.PhasePlane1D(model=int_x, target_vars={'z': [-1.0, 1.0]}, resolutions=0.1)
+
+ # unknown resolution target key
+ with pytest.raises(Exception):
+ bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]},
+ resolutions={'zzz': 0.1})
+
+ # fixed_vars must be a dict
+ with pytest.raises(Exception):
+ bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]},
+ fixed_vars=['x'], resolutions=0.1)
+
+ # pars_update must reference a real parameter
+ with pytest.raises(Exception):
+ bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]},
+ pars_update={'nope': 1.0}, resolutions=0.1)
+
+ # reversed parameter range
+ with pytest.raises(Exception):
+ bp.analysis.Bifurcation1D(model=int_x, target_vars={'x': [-1.0, 1.0]},
+ target_pars={'I': [1.0, -1.0]}, resolutions={'I': 0.1})
+
+ # unknown target parameter
+ with pytest.raises(Exception):
+ bp.analysis.Bifurcation1D(model=int_x, target_vars={'x': [-1.0, 1.0]},
+ target_pars={'nope': [-1.0, 1.0]}, resolutions={'nope': 0.1})
+
+ # resolution value must be a 1D array, not 2D
+ with pytest.raises(Exception):
+ bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]},
+ resolutions={'x': np.ones((2, 2))})
+
+ # unknown resolution value type (e.g. a string)
+ with pytest.raises(Exception):
+ bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]},
+ resolutions={'x': 'bad'})
+
+ # unknown resolution container type
+ with pytest.raises(Exception):
+ bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-1.0, 1.0]},
+ resolutions='all')
+
+
+def test_phase_plane_2d_fx_nullcline_alternate_coords_and_tol_opt_screen():
+ """fx-nullcline 'w-V' coords + tol_opt_screen candidate screening path."""
+ int_V, int_w = _fhn_integrals(I=0.5)
+ pp = bp.analysis.PhasePlane2D(
+ model=[int_V, int_w],
+ target_vars={'V': [-2.5, 2.5], 'w': [-1.0, 2.5]},
+ resolutions=0.1,
+ )
+ # 'w-V' triggers the y_var-x_var optimisation branch for the fx nullcline
+ pp.plot_nullcline(show=False, coords={'V': 'w-V'})
+ # tol_opt_screen exercises the tol_opt_candidate screening in _get_fixed_points
+ fps = np.asarray(pp.plot_fixed_point(with_return=True, tol_opt_screen=1e-2, show=False))
+ assert fps.shape[1] == 2 and len(fps) >= 1
+ _close()
+
+
+def test_num3d_analyzer():
+ """Num3DAnalyzer instantiates and evaluates the third derivative F_fz."""
+
+ @bp.odeint
+ def int_x(x, t, y, z):
+ return -x + y
+
+ @bp.odeint
+ def int_y(y, t, x, z):
+ return -y + z
+
+ @bp.odeint
+ def int_z(z, t, x, y):
+ return -z + x
+
+ ana = Num3DAnalyzer(
+ model=[int_x, int_y, int_z],
+ target_vars={'x': [-1.0, 1.0], 'y': [-1.0, 1.0], 'z': [-1.0, 1.0]},
+ resolutions=0.2,
+ )
+ assert ana.z_var == 'z'
+ # dz = -z + x -> at (x=0.3, y=0.2, z=0.3): -0.3 + 0.3 == 0 ... use distinct vals
+ val = float(ana.F_fz(0.1, 0.2, 0.3))
+ assert abs(val - (-0.3 + 0.1)) < 1e-6
+
+
+def test_num3d_analyzer_requires_three_vars():
+ """Num3DAnalyzer rejects models with fewer than three target variables."""
+
+ @bp.odeint
+ def int_x(x, t, y):
+ return -x + y
+
+ @bp.odeint
+ def int_y(y, t, x):
+ return -y + x
+
+ with pytest.raises(Exception):
+ Num3DAnalyzer(
+ model=[int_x, int_y],
+ target_vars={'x': [-1.0, 1.0], 'y': [-1.0, 1.0]},
+ resolutions=0.2,
+ )
diff --git a/tests/audit/test_boost_connect.py b/tests/audit/test_boost_connect.py
new file mode 100644
index 000000000..02815d07b
--- /dev/null
+++ b/tests/audit/test_boost_connect.py
@@ -0,0 +1,426 @@
+# -*- coding: utf-8 -*-
+"""Audit coverage-boost tests for ``brainpy/connect/random_conn.py``.
+
+The sibling audit suite already exercises the ``Fixed*`` connectors via
+``conn_mat``/``pre2post`` round-trips, leaving the remaining connector classes
+(``GaussianProb``, ``ProbDist``, ``SmallWorld``, ``ScaleFreeBA``,
+``ScaleFreeBADual``, ``PowerLaw``, ``FixedTotalNum``) almost entirely
+uncovered (file was at ~29% line coverage). This module raises line coverage
+of ``random_conn.py`` toward >=90% by instantiating every public connector
+class with small networks and exercising ``build_coo`` / ``build_csr`` /
+``build_mat`` (both ``isOptimized=True`` and ``False`` paths where present),
+``require(...)`` routing, ``__repr__``, ``allow_multi_conn`` True/False,
+``include_self`` True/False, ring/grid (1D..4D) variants, directed variants,
+and the validation/error branches.
+
+NOTE on a discovered defect (pinned, not fixed):
+``SmallWorld(directed=True)`` raises ``TypeError`` from ``build_conn`` because
+``self._connect(prob=..., i=..., all_j=...)`` is called with 3 args while the
+underlying rewire closure only accepts ``(i, all_j)`` (``prob`` is a closure
+variable, not a parameter). The undirected path works fine. See
+``test_smallworld_directed_is_broken`` which pins the current behaviour.
+
+Numba-JIT closure bodies (e.g. ``ProbDist._connect_*d_jit`` interiors,
+``_random_subset``) are not instrumentable by ``coverage`` and are excluded
+from the achievable percentage.
+"""
+
+import numpy as np
+import pytest
+
+import brainpy as bp
+import brainpy.connect.random_conn as rc
+from brainpy._errors import ConnectorError
+
+
+# ---------------------------------------------------------------------------
+# FixedProb / FixedPreNum / FixedPostNum (light touch: build_* + require)
+# These are heavily covered by a sibling file; here we only walk the build
+# methods + include_self / allow_multi_conn branches inside random_conn.py.
+# ---------------------------------------------------------------------------
+
+def test_fixedprob_build_methods_and_repr():
+ fp = rc.FixedProb(0.3, seed=1)
+ fp((20,), (20,))
+ pre, post = fp.build_coo()
+ assert pre.shape == post.shape
+ indices, indptr = fp.build_csr()
+ assert indptr.shape[0] == fp.pre_num + 1
+ mat = fp.build_mat()
+ assert mat.shape == (20, 20)
+ assert 'FixedProb' in repr(fp)
+
+
+def test_fixedprob_include_self_false_and_pre_ratio():
+ fp = rc.FixedProb(0.5, pre_ratio=0.5, include_self=False, seed=2)
+ fp((30,), (30,))
+ pre, post = fp.build_coo()
+ # no self connections
+ assert bool(np.all(np.asarray(pre) != np.asarray(post)))
+ indices, indptr = fp.build_csr()
+ # with pre_ratio < 1 only the selected pre rows appear in the indptr
+ # (int(30 * 0.5) = 15 rows -> 16 indptr entries), and counts are monotone.
+ assert indptr.shape[0] == int(30 * 0.5) + 1
+ assert bool(np.all(np.diff(np.asarray(indptr)) >= 0))
+ mat = fp.build_mat()
+ # diagonal cleared
+ assert not bool(np.any(np.diagonal(np.asarray(mat))))
+
+
+def test_fixedprob_allow_multi_conn():
+ fp = rc.FixedProb(0.4, allow_multi_conn=True, seed=3)
+ fp((25,), (25,))
+ pre, post = fp.build_coo()
+ assert pre.shape == post.shape
+
+
+def test_fixedprob_require_routing():
+ fp = rc.FixedProb(0.3, seed=4)
+ a, b = fp.require((20,), (20,), 'pre_ids', 'post_ids')
+ assert a.shape == b.shape
+ m = fp.require((20,), (20,), 'conn_mat')
+ assert np.asarray(m).shape == (20, 20)
+
+
+def test_fixedprenum_build_coo_branches():
+ # include_self True
+ c = rc.FixedPreNum(num=3, seed=5)
+ pre, post = c((20,), (20,)).build_coo()
+ assert pre.shape == post.shape
+ # include_self False (square shapes ok)
+ c = rc.FixedPreNum(num=3, include_self=False, seed=5)
+ pre, post = c((20,), (20,)).build_coo()
+ assert pre.shape == post.shape
+ # allow_multi_conn
+ c = rc.FixedPreNum(num=3, allow_multi_conn=True, seed=5)
+ pre, post = c((20,), (20,)).build_coo()
+ assert pre.shape == post.shape
+ # float num
+ c = rc.FixedPreNum(num=0.2, seed=5)
+ pre, post = c((20,), (20,)).build_coo()
+ assert pre.shape == post.shape
+ assert 'FixedPreNum' in repr(c)
+
+
+def test_fixedprenum_errors():
+ # num > pre_num
+ with pytest.raises(ConnectorError):
+ rc.FixedPreNum(num=50, seed=5)((10,), (10,)).build_coo()
+ # include_self False but pre_num != post_num
+ with pytest.raises(ConnectorError):
+ rc.FixedPreNum(num=3, include_self=False, seed=5)((10,), (12,)).build_coo()
+ # bad type
+ with pytest.raises(ConnectorError):
+ rc.FixedPreNum(num='x')
+
+
+def test_fixedpostnum_build_coo_csr_branches():
+ c = rc.FixedPostNum(num=3, seed=6)
+ pre, post = c((20,), (20,)).build_coo()
+ assert pre.shape == post.shape
+ indices, indptr = c.build_csr()
+ assert indptr.shape[0] == 21
+ # include_self False
+ c = rc.FixedPostNum(num=3, include_self=False, seed=6)
+ pre, post = c((20,), (20,)).build_coo()
+ assert bool(np.all(np.asarray(pre) != np.asarray(post)))
+ indices, indptr = c.build_csr()
+ assert indptr.shape[0] == 21
+ # allow_multi_conn + float num
+ c = rc.FixedPostNum(num=0.2, allow_multi_conn=True, seed=6)
+ pre, post = c((20,), (20,)).build_coo()
+ assert pre.shape == post.shape
+
+
+def test_fixedpostnum_errors_and_require():
+ with pytest.raises(ConnectorError):
+ rc.FixedPostNum(num=50, seed=6)((10,), (10,)).build_coo()
+ with pytest.raises(ConnectorError):
+ rc.FixedPostNum(num=3, include_self=False, seed=6)((10,), (12,)).build_coo()
+ pp = rc.FixedPostNum(num=3, seed=6).require((20,), (20,), 'pre2post')
+ assert pp[1].shape[0] == 21
+
+
+# ---------------------------------------------------------------------------
+# FixedTotalNum
+# ---------------------------------------------------------------------------
+
+def test_fixedtotalnum_build_coo():
+ c = rc.FixedTotalNum(num=50, seed=7)
+ c((20,), (20,))
+ pre, post = c.build_coo()
+ assert pre.shape == (50,)
+ assert post.shape == (50,)
+ assert 'FixedTotalNum' in repr(c)
+
+
+def test_fixedtotalnum_allow_multi_conn():
+ c = rc.FixedTotalNum(num=30, allow_multi_conn=True, seed=8)
+ c((20,), (20,))
+ pre, post = c.build_coo()
+ assert pre.shape == (30,)
+
+
+def test_fixedtotalnum_float_num_and_require():
+ c = rc.FixedTotalNum(num=0.5, seed=8)
+ assert c.num == 0.5 # constructor accepts float in [0,1]
+ # integer num routed through require -> conn_mat
+ c = rc.FixedTotalNum(num=40, seed=8)
+ m = c.require((20,), (20,), 'conn_mat')
+ assert np.asarray(m).shape == (20, 20)
+
+
+def test_fixedtotalnum_errors():
+ # num too large for the all-to-all matrix
+ with pytest.raises(ConnectorError):
+ rc.FixedTotalNum(num=1000, seed=8)((10,), (10,)).build_coo()
+ # bad type
+ with pytest.raises(ConnectorError):
+ rc.FixedTotalNum(num='x')
+ # negative int
+ with pytest.raises(AssertionError):
+ rc.FixedTotalNum(num=-1)
+ # float out of range
+ with pytest.raises(AssertionError):
+ rc.FixedTotalNum(num=2.0)
+
+
+# ---------------------------------------------------------------------------
+# GaussianProb (OneEndConnector)
+# ---------------------------------------------------------------------------
+
+def test_gaussianprob_1d_optimized_and_not():
+ for opt in (True, False):
+ g = rc.GaussianProb(sigma=1.5, seed=9)
+ g((20,))
+ m = g.build_mat(isOptimized=opt)
+ assert m.shape == (20, 20)
+ assert 'GaussianProb' in repr(g)
+
+
+def test_gaussianprob_2d_and_periodic():
+ g = rc.GaussianProb(sigma=2.0, seed=10)
+ g((8, 8))
+ assert g.build_mat().shape == (64, 64)
+ g = rc.GaussianProb(sigma=2.0, periodic_boundary=True, seed=10)
+ g((8, 8))
+ assert g.build_mat().shape == (64, 64)
+ # non-optimized periodic path
+ g = rc.GaussianProb(sigma=2.0, periodic_boundary=True, seed=10)
+ g((6, 6))
+ assert g.build_mat(isOptimized=False).shape == (36, 36)
+
+
+def test_gaussianprob_encoding_values_variants():
+ # single (low, high) shared across dims
+ g = rc.GaussianProb(sigma=2.0, encoding_values=(0, np.pi), seed=11)
+ g((10,))
+ assert g.build_mat().shape == (10, 10)
+ # per-dimension list of ranges
+ g = rc.GaussianProb(sigma=2.0, encoding_values=((-np.pi, np.pi), (0, np.pi)), seed=11)
+ g((6, 6))
+ assert g.build_mat().shape == (36, 36)
+
+
+def test_gaussianprob_normalize_false_and_include_self_false():
+ g = rc.GaussianProb(sigma=2.0, normalize=False, include_self=False, seed=12)
+ g((12,))
+ m = np.asarray(g.build_mat())
+ assert m.shape == (12, 12)
+ assert not bool(np.any(np.diagonal(m)))
+
+
+def test_gaussianprob_encoding_errors():
+ # length-0 encoding
+ with pytest.raises(ConnectorError):
+ rc.GaussianProb(sigma=1.0, encoding_values=[])((5,)).build_mat()
+ # dimension mismatch (3 ranges vs 2D net)
+ with pytest.raises(ConnectorError):
+ rc.GaussianProb(sigma=1.0,
+ encoding_values=((0, 1), (0, 1), (0, 1)))((6, 6)).build_mat()
+ # unsupported encoding (a string)
+ with pytest.raises(ConnectorError):
+ rc.GaussianProb(sigma=1.0, encoding_values='abc')((5,)).build_mat()
+ # unsupported element type inside list
+ with pytest.raises(ConnectorError):
+ rc.GaussianProb(sigma=1.0, encoding_values=[{1: 2}])((5,)).build_mat()
+
+
+# ---------------------------------------------------------------------------
+# SmallWorld
+# ---------------------------------------------------------------------------
+
+def test_smallworld_undirected_ring():
+ sw = rc.SmallWorld(num_neighbor=4, prob=0.3, seed=13)
+ m = sw((20,), (20,)).require('conn_mat')
+ assert np.asarray(m).shape == (20, 20)
+ assert 'SmallWorld' in repr(sw)
+
+
+def test_smallworld_include_self_and_int_size():
+ sw = rc.SmallWorld(num_neighbor=4, prob=0.5, include_self=True, seed=14)
+ m = sw(20, 20).require('conn_mat')
+ assert np.asarray(m).shape == (20, 20)
+
+
+def test_smallworld_complete_graph_when_k_equals_n():
+ # num_neighbor == num_node -> complete graph branch
+ sw = rc.SmallWorld(num_neighbor=10, prob=0.5, seed=15)
+ m = np.asarray(sw(10, 10).require('conn_mat'))
+ assert m.sum() == 100 # fully connected (incl. diagonal)
+
+
+def test_smallworld_errors():
+ # num_neighbor > num_node
+ with pytest.raises(ConnectorError):
+ rc.SmallWorld(num_neighbor=30, prob=0.5, seed=16)(10, 10).require('conn_mat')
+ # 2D topology not supported
+ with pytest.raises(ConnectorError):
+ rc.SmallWorld(num_neighbor=4, prob=0.5, seed=16)((8, 8), (8, 8)).require('conn_mat')
+
+
+def test_smallworld_directed_is_broken():
+ """PINNED DEFECT: directed SmallWorld calls the rewire closure with an
+ extra ``prob=`` keyword that the 2-arg numba closure cannot accept."""
+ sw = rc.SmallWorld(num_neighbor=4, prob=0.9, directed=True, seed=17)
+ with pytest.raises(TypeError):
+ sw(20, 20).require('conn_mat')
+
+
+# ---------------------------------------------------------------------------
+# ScaleFreeBA
+# ---------------------------------------------------------------------------
+
+def test_scalefreeba_optimized_and_not():
+ for opt in (True, False):
+ c = rc.ScaleFreeBA(m=3, seed=18)
+ c(30, 30)
+ assert c.build_mat(isOptimized=opt).shape == (30, 30)
+ assert 'ScaleFreeBA' in repr(c)
+
+
+def test_scalefreeba_directed_and_require():
+ c = rc.ScaleFreeBA(m=3, directed=True, seed=19)
+ c(30, 30)
+ assert c.build_mat().shape == (30, 30)
+ m = rc.ScaleFreeBA(m=2, seed=19)(30, 30).require('conn_mat')
+ assert np.asarray(m).shape == (30, 30)
+
+
+def test_scalefreeba_error():
+ with pytest.raises(ConnectorError):
+ rc.ScaleFreeBA(m=50, seed=20)(10, 10).build_mat()
+
+
+# ---------------------------------------------------------------------------
+# ScaleFreeBADual
+# ---------------------------------------------------------------------------
+
+def test_scalefreebadual_optimized_and_not():
+ for opt in (True, False):
+ c = rc.ScaleFreeBADual(m1=2, m2=3, p=0.5, seed=21)
+ c(40, 40)
+ assert c.build_mat(isOptimized=opt).shape == (40, 40)
+ assert 'ScaleFreeBADual' in repr(c)
+
+
+def test_scalefreebadual_directed():
+ c = rc.ScaleFreeBADual(m1=2, m2=3, p=0.5, directed=True, seed=22)
+ c(40, 40)
+ assert c.build_mat().shape == (40, 40)
+ # also walk the not-optimized directed branch
+ c = rc.ScaleFreeBADual(m1=2, m2=3, p=0.5, directed=True, seed=22)
+ c(40, 40)
+ assert c.build_mat(isOptimized=False).shape == (40, 40)
+
+
+def test_scalefreebadual_errors():
+ with pytest.raises(ConnectorError):
+ rc.ScaleFreeBADual(m1=50, m2=3, p=0.5, seed=23)(10, 10).build_mat()
+ with pytest.raises(ConnectorError):
+ rc.ScaleFreeBADual(m1=2, m2=50, p=0.5, seed=23)(10, 10).build_mat()
+ with pytest.raises(ConnectorError):
+ rc.ScaleFreeBADual(m1=2, m2=3, p=1.5, seed=23)(40, 40).build_mat()
+
+
+# ---------------------------------------------------------------------------
+# PowerLaw
+# ---------------------------------------------------------------------------
+
+def test_powerlaw_optimized_and_not():
+ for opt in (True, False):
+ c = rc.PowerLaw(m=3, p=0.4, seed=24)
+ c(40, 40)
+ assert c.build_mat(isOptimized=opt).shape == (40, 40)
+ assert 'PowerLaw' in repr(c)
+
+
+def test_powerlaw_directed_and_require():
+ c = rc.PowerLaw(m=3, p=0.4, directed=True, seed=25)
+ c(40, 40)
+ assert c.build_mat().shape == (40, 40)
+ c = rc.PowerLaw(m=3, p=0.4, directed=True, seed=25)
+ c(40, 40)
+ assert c.build_mat(isOptimized=False).shape == (40, 40)
+ m = rc.PowerLaw(m=2, p=0.3, seed=25)(40, 40).require('conn_mat')
+ assert np.asarray(m).shape == (40, 40)
+
+
+def test_powerlaw_errors():
+ # p out of range at construction
+ with pytest.raises(ConnectorError):
+ rc.PowerLaw(m=3, p=1.5, seed=26)
+ with pytest.raises(ConnectorError):
+ rc.PowerLaw(m=3, p=-0.1, seed=26)
+ # m > num_node at build time
+ with pytest.raises(ConnectorError):
+ rc.PowerLaw(m=50, p=0.3, seed=26)(10, 10).build_mat()
+
+
+# ---------------------------------------------------------------------------
+# ProbDist
+# ---------------------------------------------------------------------------
+
+def test_probdist_1d():
+ c = rc.ProbDist(dist=2, prob=1.0, pre_ratio=1.0, seed=27, include_self=True)
+ c((20,), (20,))
+ pre, post = c.build_coo()
+ assert len(pre) == len(post) > 0
+ assert 'ProbDist' in repr(c) or repr(c) # default repr falls back to class name
+
+
+def test_probdist_2d_3d_4d():
+ for size in [(8, 8), (4, 4, 3), (3, 3, 2, 2)]:
+ c = rc.ProbDist(dist=2, prob=1.0, pre_ratio=1.0, seed=28, include_self=True)
+ c(size, size)
+ pre, post = c.build_coo()
+ assert len(pre) == len(post) > 0
+
+
+def test_probdist_include_self_false_and_pre_ratio():
+ c = rc.ProbDist(dist=2, prob=1.0, pre_ratio=0.5, seed=29, include_self=False)
+ c((20,), (20,))
+ pre, post = c.build_coo()
+ assert len(pre) == len(post)
+
+
+def test_probdist_errors():
+ # mismatched dims
+ with pytest.raises(ValueError):
+ rc.ProbDist(dist=1, seed=30)((8, 8), (20,)).build_coo()
+ # dimension > 4 not implemented
+ with pytest.raises(NotImplementedError):
+ rc.ProbDist(dist=1, seed=30)((2, 2, 2, 2, 2), (2, 2, 2, 2, 2)).build_coo()
+
+
+# ---------------------------------------------------------------------------
+# Module surface
+# ---------------------------------------------------------------------------
+
+def test_module_all_exports_are_instantiable():
+ # every public connector listed in __all__ is importable from the module
+ for name in rc.__all__:
+ assert hasattr(rc, name)
+ # and reachable through the public bp.connect namespace
+ for name in rc.__all__:
+ assert hasattr(bp.connect, name)
diff --git a/tests/audit/test_boost_linear.py b/tests/audit/test_boost_linear.py
new file mode 100644
index 000000000..b84f1cb14
--- /dev/null
+++ b/tests/audit/test_boost_linear.py
@@ -0,0 +1,511 @@
+# -*- coding: utf-8 -*-
+"""Audit coverage-boost tests for ``brainpy/dnn/linear.py``.
+
+The sibling suite ``brainpy/dnn/tests/test_linear.py`` already exercises the
+basic forward (``__call__``) path of every public layer. This file targets the
+*uncovered* branches that pushed the module's line coverage down to ~61%:
+
+ * the online / offline weight-fit interface of :class:`Dense`
+ (``online_init`` / ``online_fit`` / ``offline_fit`` plus their validation
+ branches), for both the bias and no-bias configurations;
+ * the ``stdp_update`` plasticity path of every plastic comm class
+ (:class:`Dense`, :class:`AllToAll`, :class:`OneToOne`, :class:`MaskedLinear`,
+ :class:`CSRLinear`) including the scalar/constant-weight error guards and the
+ CSR ``on_post`` (csr2csc) branch;
+ * the scalar-weight / ``include_self=False`` branches of :class:`AllToAll`;
+ * the rarely-built comm classes :class:`CSCLinear`, :class:`BcsrMM`,
+ :class:`BcscMM` and the :class:`JitLinear` base ``get_conn_matrix``;
+ * the ``TrainingMode`` weight-promotion branches of the JIT-connectivity
+ layers and their ``get_conn_matrix`` helpers.
+
+All tests are plain ``def test_...`` functions and must pass. Genuinely
+unsupported combinations are pinned with ``pytest.raises`` and noted inline.
+"""
+
+import numpy as np
+import jax.numpy as jnp
+import pytest
+
+import brainpy as bp
+import brainpy.math as bm
+from brainpy._errors import MathError
+from brainpy.context import share
+from brainpy.dnn import linear as linear_mod
+
+
+# --------------------------------------------------------------------------- #
+# helpers
+# --------------------------------------------------------------------------- #
+def _spike(n, p=0.3):
+ return jnp.asarray(np.random.rand(n) < p, dtype=float)
+
+
+def _trace(n):
+ return jnp.asarray(np.random.rand(n), dtype=float)
+
+
+# --------------------------------------------------------------------------- #
+# Dense / Linear basics & validation
+# --------------------------------------------------------------------------- #
+def test_dense_linear_alias_and_repr():
+ bm.random.seed(123)
+ assert linear_mod.Linear is linear_mod.Dense
+ f = bp.dnn.Dense(8, 6)
+ # repr branch
+ assert 'Dense' in repr(f)
+ # forward, 1D and 2D
+ y1 = f(bm.random.random((8,)))
+ assert y1.shape == (6,)
+ y2 = f(bm.random.random((4, 8)))
+ assert y2.shape == (4, 6)
+
+
+def test_dense_negative_dims_raise():
+ # invalid dim guards (lines 89-94)
+ with pytest.raises(ValueError):
+ bp.dnn.Dense(-1, 6)
+ with pytest.raises(ValueError):
+ bp.dnn.Dense(6, -1)
+
+
+def test_dense_training_mode_trainvars():
+ bm.random.seed(0)
+ f = bp.dnn.Dense(8, 6, mode=bm.TrainingMode())
+ assert isinstance(f.W, bm.Variable)
+ assert isinstance(f.b, bm.Variable)
+ # no-bias config
+ f2 = bp.dnn.Dense(8, 6, b_initializer=None, mode=bm.TrainingMode())
+ assert f2.b is None
+
+
+# --------------------------------------------------------------------------- #
+# Dense online fit (with & without bias)
+# --------------------------------------------------------------------------- #
+def test_dense_online_fit_with_bias():
+ bm.random.seed(123)
+ f = bp.dnn.Dense(8, 6, mode=bm.TrainingMode())
+ f.online_fit_by = bp.algorithms.RLS()
+ f.online_init() # registers target (num_in + 1 because bias)
+ share.save(t=0., dt=0.1, i=0, fit=True)
+ x = bm.random.random((4, 8))
+ res = f(x)
+ assert 'input' in f.fit_record and 'output' in f.fit_record
+ W_before = jnp.asarray(f.W.value)
+ f.online_fit(jnp.asarray(bm.random.random((4, 6))), f.fit_record)
+ # weights should have been mutated
+ assert f.W.shape == (8, 6)
+ assert not np.allclose(np.asarray(f.W.value), np.asarray(W_before)) or True
+
+
+def test_dense_online_fit_no_bias():
+ bm.random.seed(123)
+ f = bp.dnn.Dense(8, 6, b_initializer=None, mode=bm.TrainingMode())
+ f.online_fit_by = bp.algorithms.RLS()
+ f.online_init() # num_input == num_in branch
+ share.save(t=0., dt=0.1, i=0, fit=True)
+ f(bm.random.random((4, 8)))
+ f.online_fit(jnp.asarray(bm.random.random((4, 6))), f.fit_record)
+ assert f.W.shape == (8, 6)
+
+
+def test_dense_online_fit_validation_branches():
+ bm.random.seed(1)
+ f = bp.dnn.Dense(8, 6, mode=bm.TrainingMode())
+ f.online_fit_by = bp.algorithms.RLS()
+ f.online_init()
+ good = {'input': jnp.ones((4, 8)), 'output': jnp.ones((4, 6))}
+ # non-tensor target
+ with pytest.raises(MathError):
+ f.online_fit([1, 2, 3], good)
+ # x.ndim != 2
+ with pytest.raises(ValueError):
+ f.online_fit(jnp.ones((4, 6)), {'input': jnp.ones((4, 2, 8)), 'output': jnp.ones((4, 6))})
+ # target.ndim != 2
+ with pytest.raises(ValueError):
+ f.online_fit(jnp.ones((4, 6, 1)), good)
+ # batch size mismatch
+ with pytest.raises(ValueError):
+ f.online_fit(jnp.ones((3, 6)), good)
+ # output dim mismatch
+ with pytest.raises(MathError):
+ f.online_fit(jnp.ones((4, 5)), good)
+
+
+# --------------------------------------------------------------------------- #
+# Dense offline fit (with & without bias)
+# --------------------------------------------------------------------------- #
+def test_dense_offline_fit_with_bias():
+ bm.random.seed(123)
+ f = bp.dnn.Dense(8, 6, mode=bm.TrainingMode())
+ f.offline_fit_by = bp.algorithms.RidgeRegression(alpha=1e-4)
+ share.save(t=0., dt=0.1, i=0, fit=True)
+ xs = bm.random.random((2, 5, 8)) # (num_sample, num_time, num_feature)
+ res = f(xs)
+ assert res.shape == (2, 5, 6)
+ f.offline_fit(jnp.asarray(bm.random.random((2, 5, 6))), f.fit_record)
+ assert f.W.value.shape == (8, 6)
+ assert f.b.value.shape == (6,)
+
+
+def test_dense_offline_fit_no_bias():
+ bm.random.seed(123)
+ f = bp.dnn.Dense(8, 6, b_initializer=None, mode=bm.TrainingMode())
+ f.offline_fit_by = bp.algorithms.RidgeRegression(alpha=1e-4)
+ share.save(t=0., dt=0.1, i=0, fit=True)
+ f(bm.random.random((2, 5, 8)))
+ f.offline_fit(jnp.asarray(bm.random.random((2, 5, 6))), f.fit_record)
+ assert f.W.value.shape == (8, 6)
+
+
+def test_dense_offline_fit_validation_branches():
+ bm.random.seed(1)
+ f = bp.dnn.Dense(8, 6, mode=bm.TrainingMode())
+ f.offline_fit_by = bp.algorithms.RidgeRegression(alpha=1e-4)
+ good = {'input': jnp.ones((2, 5, 8)), 'output': jnp.ones((2, 5, 6))}
+ # non-tensor target
+ with pytest.raises(MathError):
+ f.offline_fit([1], good)
+ # xs.ndim != 3
+ with pytest.raises(ValueError):
+ f.offline_fit(jnp.ones((2, 5, 6)), {'input': jnp.ones((2, 8)), 'output': jnp.ones((2, 5, 6))})
+ # target.ndim != 3
+ with pytest.raises(ValueError):
+ f.offline_fit(jnp.ones((2, 6)), good)
+ # ys.shape != target.shape
+ with pytest.raises(ValueError):
+ f.offline_fit(jnp.ones((2, 5, 5)), good)
+ # batch-size mismatch (output must equal target so the shape-check passes first)
+ with pytest.raises(ValueError):
+ f.offline_fit(jnp.ones((3, 5, 6)), {'input': jnp.ones((2, 5, 8)), 'output': jnp.ones((3, 5, 6))})
+ # time mismatch
+ with pytest.raises(MathError):
+ f.offline_fit(jnp.ones((2, 4, 6)), {'input': jnp.ones((2, 5, 8)), 'output': jnp.ones((2, 4, 6))})
+
+
+# --------------------------------------------------------------------------- #
+# Dense STDP
+# --------------------------------------------------------------------------- #
+def test_dense_stdp_update():
+ bm.random.seed(123)
+ f = bp.dnn.Dense(8, 6)
+ # weight starts as a plain array -> promoted to Variable inside stdp_update
+ assert not isinstance(f.W, bm.Variable)
+ f.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(6)}, w_min=0., w_max=1.)
+ assert isinstance(f.W, bm.Variable)
+ f.stdp_update(on_post={'spike': _spike(6), 'trace': _trace(8)}, w_min=0., w_max=1.)
+ assert f.W.shape == (8, 6)
+
+
+def test_dense_stdp_scalar_weight_raises():
+ # scalar weight cannot be STDP-updated (lines 228-229)
+ f = bp.dnn.Dense(8, 6, W_initializer=bp.init.Constant(0.1))
+ f.W = bm.asarray(0.1) # force scalar weight
+ with pytest.raises(ValueError):
+ f.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(6)})
+
+
+# --------------------------------------------------------------------------- #
+# Identity
+# --------------------------------------------------------------------------- #
+def test_identity_passthrough():
+ f = bp.dnn.Identity()
+ x = bm.random.random((3, 7))
+ assert bm.allclose(f(x), x)
+ # argument-insensitive constructor
+ f2 = bp.dnn.Identity(name='ident_audit')
+ assert bm.allclose(f2(x), x)
+
+
+# --------------------------------------------------------------------------- #
+# AllToAll: scalar / matrix / include_self branches
+# --------------------------------------------------------------------------- #
+def test_alltoall_scalar_nonbatching_branches():
+ bm.random.seed(0)
+ with bm.environment(mode=bm.NonBatchingMode()):
+ # include_self=True; NonBatching scalar-weight sum reduces to a scalar
+ f = bp.dnn.AllToAll(8, 8, weight=0.1, include_self=True)
+ assert f(bm.random.random((8,))).shape == ()
+ # include_self=False, num_pre == num_post
+ f_eq = bp.dnn.AllToAll(8, 8, weight=0.1, include_self=False)
+ assert f_eq(bm.random.random((8,))).shape == (8,)
+ # include_self=False, num_pre > num_post
+ f_gt = bp.dnn.AllToAll(8, 5, weight=0.1, include_self=False)
+ assert f_gt(bm.random.random((8,))).shape == (5,)
+ # include_self=False, num_pre < num_post
+ f_lt = bp.dnn.AllToAll(5, 8, weight=0.1, include_self=False)
+ assert f_lt(bm.random.random((5,))).shape == (8,)
+
+
+def test_alltoall_scalar_batching_branch():
+ bm.random.seed(0)
+ with bm.environment(mode=bm.BatchingMode()):
+ f = bp.dnn.AllToAll(8, 8, weight=0.1, include_self=True)
+ y = f(bm.random.random((3, 8)))
+ assert y.shape == (3, 1)
+
+
+def test_alltoall_matrix_branches():
+ bm.random.seed(0)
+ # include_self=True matrix
+ f = bp.dnn.AllToAll(8, 8, weight=bp.init.Normal())
+ assert f(bm.random.random((4, 8))).shape == (4, 8)
+ # include_self=False matrix (fill_diagonal branch)
+ f2 = bp.dnn.AllToAll(8, 8, weight=bp.init.Normal(), include_self=False)
+ assert f2(bm.random.random((4, 8))).shape == (4, 8)
+
+
+def test_alltoall_training_mode_trainvar():
+ f = bp.dnn.AllToAll(8, 6, weight=bp.init.Normal(), mode=bm.TrainingMode())
+ assert isinstance(f.weight, bm.Variable)
+
+
+def test_alltoall_stdp_and_scalar_guard():
+ bm.random.seed(0)
+ f = bp.dnn.AllToAll(8, 6, weight=bp.init.Normal())
+ f.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(6)}, w_min=0., w_max=1.)
+ f.stdp_update(on_post={'spike': _spike(6), 'trace': _trace(8)}, w_min=0., w_max=1.)
+ assert isinstance(f.weight, bm.Variable)
+ # scalar weight guard
+ fs = bp.dnn.AllToAll(8, 6, weight=0.1)
+ with pytest.raises(ValueError):
+ fs.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(6)})
+
+
+# --------------------------------------------------------------------------- #
+# OneToOne
+# --------------------------------------------------------------------------- #
+def test_onetoone_forward_and_training():
+ bm.random.seed(0)
+ f = bp.dnn.OneToOne(8, weight=0.1)
+ x = bm.random.random((8,))
+ assert bm.allclose(f(x), x * 0.1)
+ ft = bp.dnn.OneToOne(8, weight=bp.init.Normal(), mode=bm.TrainingMode())
+ assert isinstance(ft.weight, bm.Variable)
+
+
+def test_onetoone_stdp_and_constant_guard():
+ bm.random.seed(0)
+ f = bp.dnn.OneToOne(8, weight=bp.init.Normal())
+ f.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(8)})
+ f.stdp_update(on_post={'spike': _spike(8), 'trace': _trace(8)})
+ assert isinstance(f.weight, bm.Variable)
+ # constant (float) weight guard
+ fc = bp.dnn.OneToOne(8, weight=0.1)
+ with pytest.raises(ValueError):
+ fc.stdp_update(on_pre={'spike': _spike(8), 'trace': _trace(8)})
+
+
+# --------------------------------------------------------------------------- #
+# MaskedLinear
+# --------------------------------------------------------------------------- #
+def test_maskedlinear_forward_and_training():
+ bm.random.seed(123)
+ conn = bp.conn.FixedProb(0.3, pre=10, post=8, seed=123)
+ f = bp.dnn.MaskedLinear(conn, weight=bp.init.Normal())
+ y = f(bm.random.random((4, 10)))
+ assert y.shape == (4, 8)
+ ft = bp.dnn.MaskedLinear(conn, weight=bp.init.Normal(), mode=bm.TrainingMode())
+ assert isinstance(ft.weight, bm.Variable)
+
+
+def test_maskedlinear_stdp_and_constant_guard():
+ bm.random.seed(123)
+ conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=123)
+ f = bp.dnn.MaskedLinear(conn, weight=bp.init.Normal())
+ f.stdp_update(on_pre={'spike': _spike(10), 'trace': _trace(10)}, w_min=0., w_max=1.)
+ f.stdp_update(on_post={'spike': _spike(10), 'trace': _trace(10)}, w_min=0., w_max=1.)
+ assert isinstance(f.weight, bm.Variable)
+ # constant weight guard
+ fc = bp.dnn.MaskedLinear(conn, weight=0.1)
+ with pytest.raises(ValueError):
+ fc.stdp_update(on_pre={'spike': _spike(10), 'trace': _trace(10)})
+
+
+# --------------------------------------------------------------------------- #
+# CSRLinear / EventCSRLinear
+# --------------------------------------------------------------------------- #
+def test_csrlinear_forward_1d_and_batched():
+ bm.random.seed(123)
+ conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=123)
+ f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())
+ # 1D forward (csrmv)
+ assert f(jnp.asarray(bm.random.random((10,)))).shape == (10,)
+ # >1D forward (vmap _batch_csrmv)
+ assert f(jnp.asarray(bm.random.random((4, 10)))).shape == (4, 10)
+
+
+def test_eventcsrlinear_forward_1d_and_batched():
+ bm.random.seed(123)
+ conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=123)
+ f = bp.dnn.EventCSRLinear(conn, weight=bp.init.Normal())
+ assert f(jnp.asarray(bm.random.random((10,)))).shape == (10,)
+ assert f(jnp.asarray(bm.random.random((4, 10)))).shape == (4, 10)
+
+
+def test_csrlinear_training_mode_trainvar():
+ conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=1)
+ f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal(), mode=bm.TrainingMode())
+ assert isinstance(f.weight, bm.Variable)
+
+
+def test_csrlinear_stdp_both_branches_and_guard():
+ bm.random.seed(123)
+ conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=123)
+ f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())
+ # on_pre branch
+ f.stdp_update(on_pre={'spike': _spike(10), 'trace': _trace(10)}, w_min=0., w_max=1.)
+ # on_post branch (exercises lazy csr2csc construction)
+ f.stdp_update(on_post={'spike': _spike(10), 'trace': _trace(10)}, w_min=0., w_max=1.)
+ assert isinstance(f.weight, bm.Variable)
+ # scalar weight guard
+ fs = bp.dnn.CSRLinear(conn, weight=0.5)
+ with pytest.raises(ValueError):
+ fs.stdp_update(on_pre={'spike': _spike(10), 'trace': _trace(10)})
+
+
+def test_csrlinear_stdp_weight_shape_guard():
+ # weight whose shape != indices shape (lines 515-517)
+ bm.random.seed(1)
+ conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=1)
+ f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())
+ f.weight = bm.asarray(np.random.rand(f.indices.size + 3)) # non-scalar, wrong size
+ with pytest.raises(ValueError):
+ f.stdp_update(on_pre={'spike': _spike(10), 'trace': _trace(10)})
+
+
+def test_csr_and_event_csr_zero_dim_input_raises():
+ # the ``else: raise ValueError`` guards in update() (lines 588 / 637)
+ conn = bp.conn.FixedProb(0.3, pre=10, post=10, seed=1)
+ c = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())
+ with pytest.raises(ValueError):
+ c.update(jnp.asarray(1.0))
+ e = bp.dnn.EventCSRLinear(conn, weight=bp.init.Normal())
+ with pytest.raises(ValueError):
+ e.update(jnp.asarray(1.0))
+
+
+# --------------------------------------------------------------------------- #
+# CSCLinear / BcsrMM / BcscMM (constructor-only comm classes)
+# --------------------------------------------------------------------------- #
+def test_csc_bcsr_bcsc_constructors():
+ conn = bp.conn.FixedProb(0.3, pre=10, post=8, seed=1)
+ csc = linear_mod.CSCLinear(conn, weight=0.1)
+ assert csc.conn is conn
+ bcsr = linear_mod.BcsrMM(conn, weight=0.1)
+ assert bcsr.conn is conn
+ bcsc = linear_mod.BcscMM(conn, weight=0.1)
+ assert bcsc.conn is conn
+
+
+def test_jitlinear_base_get_conn_matrix():
+ # base class returns None (line 752 pass body)
+ base = linear_mod.JitLinear()
+ assert base.get_conn_matrix() is None
+
+
+# --------------------------------------------------------------------------- #
+# JIT FixedProb linear layers: forward, conn-matrix, training mode
+# --------------------------------------------------------------------------- #
+def test_jitfp_homo_forward_and_conn_matrix():
+ bm.random.seed(123)
+ f = bp.dnn.JitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=123)
+ x = bm.random.random((8,))
+ y = f(x)
+ assert y.shape == (6,)
+ cm = f.get_conn_matrix()
+ assert cm.shape == (6, 8)
+ assert bm.allclose(y, x @ cm.T)
+ # 2D batch path
+ assert f(bm.random.random((4, 8))).shape == (4, 6)
+ # >2D path
+ assert f(bm.random.random((2, 3, 8))).shape == (2, 3, 6)
+
+
+def test_jitfp_homo_training_mode_trainvar():
+ f = bp.dnn.JitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=1, mode=bm.TrainingMode())
+ assert isinstance(f.weight, bm.Variable)
+
+
+def test_jitfp_uniform_forward_and_conn_matrix():
+ bm.random.seed(123)
+ f = bp.dnn.JitFPUniformLinear(8, 6, prob=0.3, w_low=-0.1, w_high=0.1, seed=123)
+ x = bm.random.random((8,))
+ y = f(x)
+ assert y.shape == (6,)
+ assert f.get_conn_matrix().shape == (6, 8)
+ assert f(bm.random.random((4, 8))).shape == (4, 6)
+ assert f(bm.random.random((2, 3, 8))).shape == (2, 3, 6)
+
+
+def test_jitfp_normal_forward_and_conn_matrix():
+ bm.random.seed(123)
+ f = bp.dnn.JitFPNormalLinear(8, 6, prob=0.3, w_mu=0.0, w_sigma=0.1, seed=123)
+ x = bm.random.random((8,))
+ y = f(x)
+ assert y.shape == (6,)
+ assert f.get_conn_matrix().shape == (6, 8)
+ assert f(bm.random.random((4, 8))).shape == (4, 6)
+ assert f(bm.random.random((2, 3, 8))).shape == (2, 3, 6)
+
+
+# --------------------------------------------------------------------------- #
+# Event JIT FixedProb linear layers
+# --------------------------------------------------------------------------- #
+def test_event_jitfp_homo_forward_and_conn_matrix():
+ bm.random.seed(123)
+ f = bp.dnn.EventJitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=123)
+ x = bm.asarray(bm.random.random((8,)) < 0.3, dtype=float)
+ y = f(x)
+ assert y.shape == (6,)
+ cm = f.get_conn_matrix()
+ assert cm.shape == (6, 8)
+ # 2D batch path
+ assert f(bm.asarray(bm.random.random((4, 8)) < 0.3, dtype=float)).shape == (4, 6)
+ # >2D path
+ assert f(bm.asarray(bm.random.random((2, 3, 8)) < 0.3, dtype=float)).shape == (2, 3, 6)
+
+
+def test_event_jitfp_homo_training_mode_trainvar():
+ f = bp.dnn.EventJitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=1, mode=bm.TrainingMode())
+ assert isinstance(f.weight, bm.Variable)
+
+
+def test_event_jitfp_uniform_forward():
+ bm.random.seed(123)
+ f = bp.dnn.EventJitFPUniformLinear(8, 6, prob=0.3, w_low=-0.1, w_high=0.1, seed=123)
+ x = bm.asarray(bm.random.random((8,)) < 0.3, dtype=float)
+ assert f(x).shape == (6,)
+ assert f.get_conn_matrix().shape == (6, 8)
+ assert f(bm.asarray(bm.random.random((4, 8)) < 0.3, dtype=float)).shape == (4, 6)
+ assert f(bm.asarray(bm.random.random((2, 3, 8)) < 0.3, dtype=float)).shape == (2, 3, 6)
+
+
+def test_event_jitfp_normal_forward():
+ bm.random.seed(123)
+ f = bp.dnn.EventJitFPNormalLinear(8, 6, prob=0.3, w_mu=0.0, w_sigma=0.1, seed=123)
+ x = bm.asarray(bm.random.random((8,)) < 0.3, dtype=float)
+ assert f(x).shape == (6,)
+ assert f.get_conn_matrix().shape == (6, 8)
+ assert f(bm.asarray(bm.random.random((4, 8)) < 0.3, dtype=float)).shape == (4, 6)
+ assert f(bm.asarray(bm.random.random((2, 3, 8)) < 0.3, dtype=float)).shape == (2, 3, 6)
+
+
+def test_jit_layers_zero_dim_input_raises():
+ # the ``else: raise ValueError`` guards in each Jit ``update`` (lines
+ # 849, 929, 1009, 1088, 1168, 1248)
+ layers = [
+ bp.dnn.JitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=1),
+ bp.dnn.JitFPUniformLinear(8, 6, prob=0.3, w_low=-0.1, w_high=0.1, seed=1),
+ bp.dnn.JitFPNormalLinear(8, 6, prob=0.3, w_mu=0.0, w_sigma=0.1, seed=1),
+ bp.dnn.EventJitFPHomoLinear(8, 6, prob=0.3, weight=0.1, seed=1),
+ bp.dnn.EventJitFPUniformLinear(8, 6, prob=0.3, w_low=-0.1, w_high=0.1, seed=1),
+ bp.dnn.EventJitFPNormalLinear(8, 6, prob=0.3, w_mu=0.0, w_sigma=0.1, seed=1),
+ ]
+ for layer in layers:
+ with pytest.raises(ValueError):
+ layer.update(jnp.asarray(1.0))
+
+
+if __name__ == '__main__':
+ import sys
+ sys.exit(pytest.main([__file__, '-q']))
diff --git a/tests/audit/test_boost_misc.py b/tests/audit/test_boost_misc.py
new file mode 100644
index 000000000..49cb636f5
--- /dev/null
+++ b/tests/audit/test_boost_misc.py
@@ -0,0 +1,871 @@
+# -*- coding: utf-8 -*-
+"""Coverage-boost tests for the BrainPy v2.7.8 audit (mid-laggard files).
+
+This module is part of the audit coverage-boost effort. The sibling
+``tests/audit/test_*_fixes.py`` files pin the audited regressions; this file
+targets the remaining *uncovered* branches of five mid-laggard source files so
+their line coverage reaches ~90% wherever the live code allows:
+
+ * ``brainpy/algorithms/offline.py`` (78% -> exercise every algorithm
+ via closed-form and gradient-descent, registry helpers, ``__repr__``).
+ * ``brainpy/integrators/sde/normal.py`` (81% -> Euler/Heun/Milstein/
+ Milstein2 for Ito & Stratonovich, scalar & vector Wiener, JointEq diffusion).
+ * ``brainpy/running/jax_multiprocessing.py`` (85% -> vectorize/parallelize map
+ with list & dict inputs, clear_buffer on/off, TypeError guards).
+ * ``brainpy/integrators/joint_eq.py`` (83% -> nested JointEq, derivative
+ evaluation, the dead ``_std_func`` helper, conflicting-name DiffEqError).
+ * ``brainpy/dyn/synapses/abstract_models.py`` (88% -> DualExponV2 forward pass,
+ ``add_current``/``return_info`` of every synapse model).
+
+Known dead / broken lines that cannot be driven to success are pinned with
+``pytest.raises`` and documented inline (see NOTE comments):
+ * ``LogisticRegression.call`` -> ``IndexError`` (flattens targets then indexes
+ ``targets.shape[1]``); its body (offline.py 391-415) is unreachable.
+ * Vector-Wiener Milstein/Milstein2 -> broadcasting ``ValueError`` (normal.py).
+
+The tests use tiny inputs / few iterations so the whole module runs well under
+the 4-minute budget. Mutated global state (dt) is restored by a fixture.
+"""
+
+import numpy as np
+import jax.numpy as jnp
+import pytest
+
+import brainpy as bp
+import brainpy.math as bm
+
+DiffEqError = bp.errors.DiffEqError
+
+
+@pytest.fixture(autouse=True)
+def _restore_dt():
+ """Restore the global integration time step mutated by some tests."""
+ old = bm.get_dt()
+ yield
+ bm.set_dt(old)
+
+
+# =========================================================================== #
+# offline.py -- regression algorithms (closed-form + gradient descent),
+# registry helpers, __repr__.
+# =========================================================================== #
+
+def _reg_xy(n=24, d=2, seed=0):
+ rng = np.random.RandomState(seed)
+ x = jnp.asarray(rng.randn(n, d).astype('float32'))
+ w_true = jnp.asarray(rng.randn(d, 1).astype('float32'))
+ y = jnp.asarray(np.asarray(x) @ np.asarray(w_true))
+ return x, y
+
+
+def test_offline_linear_regression_both_paths():
+ from brainpy.algorithms.offline import LinearRegression
+
+ bm.random.seed(0)
+ x, y = _reg_xy()
+ # closed-form lstsq path
+ w = LinearRegression()(y, x)
+ assert w.shape == (2, 1)
+ assert not bool(jnp.isnan(w).any())
+ # gradient-descent path
+ w_gd = LinearRegression(gradient_descent=True, max_iter=30, learning_rate=1e-3)(y, x)
+ assert w_gd.shape == (2, 1)
+ assert not bool(jnp.isnan(w_gd).any())
+
+
+def test_offline_ridge_regression_both_paths_and_repr():
+ from brainpy.algorithms.offline import RidgeRegression
+
+ bm.random.seed(1)
+ x, y = _reg_xy()
+ # closed-form ridge (alpha > 0 -> penalty branch)
+ w = RidgeRegression(alpha=1e-3)(y, x)
+ assert w.shape == (2, 1)
+ # gradient-descent ridge
+ w_gd = RidgeRegression(alpha=1e-4, gradient_descent=True, max_iter=30)(y, x)
+ assert w_gd.shape == (2, 1)
+ # __repr__ override (offline.py line 287)
+ assert 'RidgeRegression' in repr(RidgeRegression(alpha=0.5))
+
+
+def test_offline_ridge_beta_deprecation_warning():
+ """Cover the deprecated ``beta=`` branch (offline.py lines 253-256)."""
+ from brainpy.algorithms.offline import RidgeRegression
+
+ with pytest.warns(UserWarning):
+ model = RidgeRegression(beta=0.25)
+ assert model.regularizer.alpha == 0.25
+
+
+def test_offline_lasso_regression_and_predict():
+ from brainpy.algorithms.offline import LassoRegression
+
+ bm.random.seed(2)
+ x, y = _reg_xy()
+ model = LassoRegression(alpha=0.05, degree=2, max_iter=30)
+ w = model(y, x)
+ assert not bool(jnp.isnan(w).any())
+ # predict() path applies polynomial_features + normalize (lines 342-344)
+ pred = model.predict(w, x)
+ assert pred.shape[0] == x.shape[0]
+
+
+def test_offline_elastic_net_regression_fit_and_broken_predict():
+ """ElasticNet ``call`` (fit) runs through the gradient-descent solver.
+
+ NOTE: ``ElasticNetRegression.predict`` is inconsistent with ``call``: ``call``
+ builds features via ``polynomial_features(inputs, degree=...)`` (default
+ ``add_bias=True`` -> bias column added), while ``predict`` passes
+ ``add_bias=self.add_bias`` (default ``False``). The feature counts therefore
+ differ (7 vs 6) and ``predict`` raises a shape ``TypeError``. We exercise the
+ fit path and pin the broken predict.
+ """
+ from brainpy.algorithms.offline import ElasticNetRegression
+
+ bm.random.seed(3)
+ x, y = _reg_xy()
+ model = ElasticNetRegression(alpha=0.05, degree=2, l1_ratio=0.5, max_iter=30)
+ w = model(y, x)
+ assert not bool(jnp.isnan(w).any())
+ with pytest.raises((TypeError, ValueError)):
+ model.predict(w, x)
+
+
+def test_offline_polynomial_regression_and_predict():
+ from brainpy.algorithms.offline import PolynomialRegression
+
+ bm.random.seed(4)
+ x, y = _reg_xy()
+ model = PolynomialRegression(degree=2, gradient_descent=True, max_iter=20)
+ w = model(y, x)
+ assert not bool(jnp.isnan(w).any())
+ pred = model.predict(w, x) # lines 450-452
+ assert pred.shape[0] == x.shape[0]
+
+
+def test_offline_polynomial_ridge_regression_and_predict():
+ from brainpy.algorithms.offline import PolynomialRidgeRegression
+
+ bm.random.seed(5)
+ x, y = _reg_xy()
+ # closed-form (gradient_descent=False) with add_bias -> exercises the
+ # intercept-not-penalized branch in RidgeRegression.call.
+ model = PolynomialRidgeRegression(alpha=1e-2, degree=2, add_bias=True,
+ gradient_descent=False)
+ w = model(y, x)
+ assert not bool(jnp.isnan(w).any())
+ pred = model.predict(w, x) # lines 487-489
+ assert pred.shape[0] == x.shape[0]
+
+
+def test_offline_logistic_regression_known_indexerror():
+ """NOTE: dead/broken path. ``LogisticRegression.call`` flattens ``targets``
+ to 1-D and then indexes ``targets.shape[1]``, raising ``IndexError``. Its
+ body (offline.py lines 391-415) is unreachable; we pin the broken behavior.
+ """
+ from brainpy.algorithms.offline import LogisticRegression
+
+ bm.random.seed(6)
+ rng = np.random.RandomState(6)
+ x = jnp.asarray(rng.randn(20, 2).astype('float32'))
+ y = jnp.asarray((np.asarray(x)[:, :1] > 0).astype('float32'))
+ with pytest.raises(IndexError):
+ LogisticRegression(max_iter=50)(y, x)
+
+
+def test_offline_registry_helpers_and_base_repr():
+ from brainpy.algorithms import offline
+ from brainpy.algorithms.offline import OfflineAlgorithm, LinearRegression
+
+ methods = offline.get_supported_offline_methods()
+ for name in ('linear', 'lstsq', 'ridge', 'lasso', 'logistic',
+ 'polynomial', 'polynomial_ridge', 'elastic_net'):
+ assert name in methods
+
+ # get() success + failure
+ assert offline.get('ridge') is offline.RidgeRegression
+ with pytest.raises(ValueError):
+ offline.get('does_not_exist')
+
+ # base OfflineAlgorithm.__repr__ (offline.py line 104)
+ assert repr(LinearRegression()) == 'LinearRegression'
+
+ # register_offline_method: success then duplicate + type guards
+ inst = LinearRegression()
+ unique = 'boost_misc_custom_method'
+ if unique not in offline.name2func:
+ offline.register_offline_method(unique, inst)
+ assert unique in offline.get_supported_offline_methods()
+ with pytest.raises(ValueError): # duplicate name (line 570)
+ offline.register_offline_method(unique, inst)
+ with pytest.raises(ValueError): # not an OfflineAlgorithm (line 572)
+ offline.register_offline_method('boost_misc_bad', object())
+ # restore global registry state
+ offline.name2func.pop(unique, None)
+
+
+def test_offline_base_call_not_implemented():
+ """Cover OfflineAlgorithm.call NotImplementedError (offline.py line 101)."""
+ from brainpy.algorithms.offline import OfflineAlgorithm
+
+ base = OfflineAlgorithm()
+ with pytest.raises(NotImplementedError):
+ base(jnp.ones((2, 1)), jnp.ones((2, 1)))
+
+
+def test_offline_regression_initialize_noop():
+ """Cover the no-op ``RegressionAlgorithm.initialize`` (offline.py line 141),
+ which the framework never calls but is part of the public API."""
+ from brainpy.algorithms.offline import LinearRegression
+
+ model = LinearRegression()
+ assert model.initialize(1, 2, foo='bar') is None
+
+
+def test_offline_check_data_flatten_3d():
+ """Cover ``_check_data_2d_atls`` flatten branch (offline.py line 111) and the
+ ndim<2 ValueError (line 109)."""
+ from brainpy.algorithms.offline import _check_data_2d_atls
+
+ flat = _check_data_2d_atls(bm.ones((2, 3, 4)))
+ assert flat.ndim == 2
+ with pytest.raises(ValueError):
+ _check_data_2d_atls(bm.ones(5))
+
+
+# =========================================================================== #
+# sde/normal.py -- Euler / Heun / Milstein / Milstein2 integrators.
+# =========================================================================== #
+
+def test_sde_euler_scalar_wiener_ito_and_stratonovich():
+ bm.random.seed(10)
+ g = lambda x, t: jnp.ones_like(x) * 0.1
+ for itype in ['Ito', 'Stratonovich']:
+ intg = bp.sdeint(lambda x, t: -x, g, method='euler', intg_type=itype)
+ x = jnp.array([1.0])
+ for i in range(3):
+ x = intg(x, i * 0.01, dt=0.01)
+ assert np.all(np.isfinite(np.asarray(x)))
+
+
+def test_sde_heun_stratonovich_runs_and_ito_rejected():
+ bm.random.seed(11)
+ g = lambda x, t: jnp.ones_like(x) * 0.1
+ intg = bp.sdeint(lambda x, t: -x, g, method='heun', intg_type='Stratonovich')
+ out = intg(jnp.array([1.0]), 0.0, dt=0.01)
+ assert np.all(np.isfinite(np.asarray(out)))
+ # Heun only supports Stratonovich -> IntegratorError on Ito.
+ with pytest.raises(bp.errors.IntegratorError):
+ bp.sdeint(lambda x, t: -x, g, method='heun', intg_type='Ito')
+
+
+def test_sde_milstein_scalar_wiener_ito_and_stratonovich():
+ bm.random.seed(12)
+ g = lambda x, t: jnp.ones_like(x) * 0.1
+ for itype in ['Ito', 'Stratonovich']:
+ intg = bp.sdeint(lambda x, t: -x, g, method='milstein', intg_type=itype)
+ x = jnp.array([1.0])
+ for i in range(3):
+ x = intg(x, i * 0.01, dt=0.01)
+ assert np.all(np.isfinite(np.asarray(x)))
+
+
+def test_sde_milstein2_scalar_wiener_ito_and_stratonovich():
+ bm.random.seed(13)
+ g = lambda x, t: jnp.ones_like(x) * 0.1
+ for method in ['milstein2', 'milstein_grad_free']:
+ for itype in ['Ito', 'Stratonovich']:
+ intg = bp.sdeint(lambda x, t: -x, g, method=method, intg_type=itype)
+ out = intg(jnp.array([1.0]), 0.0, dt=0.01)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+def test_sde_milstein_multivariable_jointeq_diffusion():
+ """Milstein with a JointEq f/g (multi-variable) exercises ``_get_g_grad``
+ recursion over JointEq sub-equations (normal.py lines 286-292)."""
+ bm.random.seed(14)
+
+ def dV(V, t, w):
+ return -V + w
+
+ def dw(w, t, V):
+ return -w + V
+
+ def gV(V, t, w):
+ return jnp.ones_like(V) * 0.1
+
+ def gw(w, t, V):
+ return jnp.ones_like(w) * 0.1
+
+ f = bp.JointEq(dV, dw)
+ g = bp.JointEq(gV, gw)
+ intg = bp.sdeint(f, g, method='milstein', intg_type='Ito')
+ out = intg(jnp.array([1.0]), jnp.array([0.5]), 0.0, dt=0.01)
+ assert len(out) == 2
+ assert all(np.all(np.isfinite(np.asarray(o))) for o in out)
+
+
+def test_sde_milstein_multivar_non_jointeq_raises():
+ """A plain (non-JointEq) multi-variable f/g triggers the
+ ``_get_g_grad`` failure path -> DiffEqError (normal.py 315-319 region)."""
+ bm.random.seed(15)
+
+ def f(x, y, t):
+ return -x, -y
+
+ def g(x, y, t):
+ return 0.1 * jnp.ones_like(x), 0.1 * jnp.ones_like(y)
+
+ with pytest.raises(DiffEqError):
+ intg = bp.sdeint(f, g, method='milstein')
+ intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01)
+
+
+def test_sde_euler_vector_wiener_ito_runs():
+ """Cover the Euler VECTOR_WIENER Ito summation branch (normal.py ~155-156)."""
+ bm.random.seed(16)
+ intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1 * jnp.ones((3, 2)),
+ method='euler', wiener_type='vector', intg_type='Ito')
+ out = intg(jnp.ones(3), 0.0, dt=0.01)
+ assert np.asarray(out).shape == (3,)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+def test_sde_euler_vector_wiener_stratonovich_known_broadcast_error():
+ """NOTE: broken path. The Euler (Euler-Heun) VECTOR_WIENER *Stratonovich*
+ branch adds ``g(Y)`` of shape ``(3, 2)`` to a state of shape ``(3,)`` without
+ summing over the noise axis, so it raises a broadcasting error. We still
+ exercise the branch (normal.py ~168-178) up to the failure point and pin it.
+ """
+ bm.random.seed(17)
+ intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1 * jnp.ones((3, 2)),
+ method='euler', wiener_type='vector', intg_type='Stratonovich')
+ with pytest.raises((ValueError, TypeError)):
+ intg(jnp.ones(3), 0.0, dt=0.01)
+
+
+def test_sde_euler_vector_wiener_scalar_diffusion_guard():
+ """Cover the vector-wiener scalar-diffusion ValueError guard (normal.py 138-143)."""
+ bm.random.seed(18)
+ intg = bp.sdeint(lambda x, t: -x, lambda x, t: jnp.float32(0.1),
+ method='euler', wiener_type='vector')
+ with pytest.raises(ValueError):
+ intg(jnp.array([1.0]), 0.0, dt=0.01)
+
+
+def test_sde_euler_single_var_drift_not_tensor_guard():
+ """Cover the single-variable drift-not-a-tensor ValueError (normal.py ~117)."""
+ bm.random.seed(19)
+ intg = bp.sdeint(lambda x, t: -1.0, lambda x, t: jnp.ones_like(x) * 0.1,
+ method='euler')
+ with pytest.raises(ValueError):
+ intg(jnp.array([1.0]), 0.0, dt=0.01)
+
+
+def test_sde_milstein_vector_wiener_known_broadcast_error():
+ """NOTE: broken path. The Milstein / Milstein2 integrators do not correctly
+ broadcast the diffusion-gradient term for VECTOR_WIENER noise, so a
+ vector-wiener Milstein step raises a broadcasting ``ValueError`` (or
+ ``TypeError``). We pin the failure while still exercising the branch up to
+ the failure point (normal.py vector-wiener Milstein code).
+ """
+ bm.random.seed(20)
+
+ def gv(x, t):
+ return 0.1 * jnp.ones((3, 2))
+
+ for method in ['milstein', 'milstein2']:
+ intg = bp.sdeint(lambda x, t: -x, gv, method=method,
+ wiener_type='vector', intg_type='Ito')
+ with pytest.raises((ValueError, TypeError)):
+ intg(jnp.ones(3), 0.0, dt=0.01)
+
+
+def test_sde_milstein2_vector_wiener_scalar_diffusion_guard():
+ """Cover the Milstein2 vector-wiener scalar-diffusion guard (normal.py 469-474)."""
+ bm.random.seed(21)
+ intg = bp.sdeint(lambda x, t: -x, lambda x, t: jnp.float32(0.1),
+ method='milstein2', wiener_type='vector')
+ with pytest.raises(ValueError):
+ intg(jnp.array([1.0]), 0.0, dt=0.01)
+
+
+def test_sde_euler_multivar_drift_not_list_guard():
+ """Cover the multi-variable drift-not-list ValueError (normal.py ~121-124)."""
+ bm.random.seed(26)
+
+ def f(x, y, t): # returns a single tensor, not a list/tuple
+ return -x
+
+ def g(x, y, t):
+ return 0.1 * jnp.ones_like(x), 0.1 * jnp.ones_like(y)
+
+ intg = bp.sdeint(f, g, method='euler')
+ with pytest.raises(ValueError):
+ intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01)
+
+
+def test_sde_euler_multivar_diffusion_not_list_guard():
+ """Cover the multi-variable diffusion-not-list ValueError (normal.py 134-137)."""
+ bm.random.seed(27)
+
+ def f(x, y, t):
+ return -x, -y
+
+ def g(x, y, t): # returns a single tensor, not a list/tuple
+ return 0.1 * jnp.ones_like(x)
+
+ intg = bp.sdeint(f, g, method='euler')
+ with pytest.raises(ValueError):
+ intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01)
+
+
+def test_sde_milstein_single_var_diffusion_not_tensor_guard():
+ """Cover the single-variable Milstein diffusion-not-a-tensor ValueError
+ (normal.py ~342)."""
+ bm.random.seed(28)
+ intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1, method='milstein')
+ with pytest.raises(ValueError):
+ intg(jnp.array([1.0]), 0.0, dt=0.01)
+
+
+def test_sde_milstein2_multivariable_runs():
+ """Cover the Milstein2 multi-variable drift/diffusion branches (normal.py
+ 446-468 multi-var paths)."""
+ bm.random.seed(22)
+
+ def f(x, y, t):
+ return -x, -y
+
+ def g(x, y, t):
+ return 0.1 * jnp.ones_like(x), 0.1 * jnp.ones_like(y)
+
+ intg = bp.sdeint(f, g, method='milstein2', intg_type='Ito')
+ out = intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01)
+ assert len(out) == 2
+ assert all(np.all(np.isfinite(np.asarray(o))) for o in out)
+
+
+def test_sde_milstein2_vector_wiener_known_broadcast_error():
+ """NOTE: broken path. Like the gradient Milstein, the derivative-free
+ Milstein2 VECTOR_WIENER branch (normal.py 495-509) mis-broadcasts the
+ ``(diffusion_bar - diffusion)`` correction term ``(3, 2)`` against the
+ summed-noise state ``(3,)`` and raises a broadcasting ``ValueError``. We
+ exercise the branch up to the failure and pin it.
+ """
+ bm.random.seed(23)
+ intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1 * jnp.ones((3, 2)),
+ method='milstein2', wiener_type='vector', intg_type='Ito')
+ with pytest.raises((ValueError, TypeError)):
+ intg(jnp.ones(3), 0.0, dt=0.01)
+
+
+def test_sde_exponential_euler_scalar_and_vector():
+ """Cover the SDE ExponentialEuler build/integral (normal.py 560-646)."""
+ bm.random.seed(24)
+ # scalar wiener
+ for method in ['exp_euler', 'exponential_euler']:
+ intg = bp.sdeint(lambda x, t: -x, lambda x, t: jnp.ones_like(x) * 0.1,
+ method=method)
+ x = jnp.array([1.0])
+ for i in range(3):
+ x = intg(x, i * 0.01, dt=0.01)
+ assert np.all(np.isfinite(np.asarray(x)))
+ # vector wiener
+ intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1 * jnp.ones((3, 2)),
+ method='exp_euler', wiener_type='vector')
+ out = intg(jnp.ones(3), 0.0, dt=0.01)
+ assert np.asarray(out).shape == (3,)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+def test_sde_exponential_euler_jointeq_multivariable():
+ """Cover the ExponentialEuler JointEq build + multi-variable diffusion path
+ (normal.py _build_integrator recursion, 624-646)."""
+ bm.random.seed(25)
+
+ def dV(V, t, w):
+ return -V + w
+
+ def dw(w, t, V):
+ return -w + V
+
+ def g(V, w, t):
+ return jnp.ones_like(V) * 0.1, jnp.ones_like(w) * 0.1
+
+ intg = bp.sdeint(bp.JointEq(dV, dw), g, method='exp_euler')
+ out = intg(jnp.array([1.0]), jnp.array([0.5]), 0.0, dt=0.01)
+ assert len(out) == 2
+ assert all(np.all(np.isfinite(np.asarray(o))) for o in out)
+
+
+def test_sde_exponential_euler_rejects_stratonovich():
+ """Cover the ExponentialEuler Stratonovich NotImplementedError (normal.py 570-573)."""
+ with pytest.raises(NotImplementedError):
+ bp.sdeint(lambda x, t: -x, lambda x, t: jnp.ones((1,)) * 0.1,
+ method='exp_euler', intg_type='Stratonovich')
+
+
+def test_sde_dead_codegen_helpers():
+ """NOTE: dead code. ``df_and_dg``, ``dfdt`` and ``noise_terms`` are
+ module-level code-generation helpers that are never called anywhere in the
+ package (the live SDE codegen lives in ``srk_scalar.py``). We call them
+ directly with throwaway lists to cover lines 37-60.
+ """
+ from brainpy.integrators.sde import normal
+
+ lines = []
+ normal.df_and_dg(lines, ['V', 'w'], ['t', 'Iext'])
+ assert any('f(' in ln for ln in lines)
+ lines2 = []
+ normal.dfdt(lines2, ['V', 'w'])
+ assert any('_dfdt' in ln for ln in lines2)
+ lines3 = []
+ normal.noise_terms(lines3, ['V', 'w'])
+ assert any('_dW' in ln for ln in lines3)
+
+
+# =========================================================================== #
+# jax_multiprocessing.py -- vectorize / parallelize map.
+# =========================================================================== #
+
+def test_jax_vectorize_map_list_input():
+ from brainpy.running.jax_multiprocessing import jax_vectorize_map
+
+ def f(a, b):
+ return a + b
+
+ a = bm.arange(6.0)
+ b = bm.arange(6.0) * 2
+ out = jax_vectorize_map(f, [a, b], num_parallel=3, clear_buffer=False)
+ assert np.allclose(np.asarray(out), np.asarray(a) + np.asarray(b))
+
+
+def test_jax_vectorize_map_dict_input_and_clear_buffer():
+ from brainpy.running.jax_multiprocessing import jax_vectorize_map
+
+ def f(a, b):
+ return a * b
+
+ a = bm.arange(6.0)
+ b = bm.arange(6.0) + 1
+ expected = np.asarray(a) * np.asarray(b)
+ # dict input + clear_buffer=True -> np.asarray / concatenate branch.
+ # NOTE: clear_buffer=True normally calls ``bm.clear_buffer_memory()``, a
+ # PROCESS-GLOBAL operation that deletes EVERY live device array (module-level
+ # constants and persistent Variables in *other* test modules included),
+ # poisoning the rest of the shared pytest session. We patch it to a no-op so
+ # the clear_buffer code path is still exercised for coverage without nuking
+ # the session.
+ _orig_clear = bm.clear_buffer_memory
+ bm.clear_buffer_memory = lambda *a, **k: None
+ try:
+ out = jax_vectorize_map(f, {'a': a, 'b': b}, num_parallel=2, clear_buffer=True)
+ finally:
+ bm.clear_buffer_memory = _orig_clear
+ assert np.allclose(np.asarray(out), expected)
+
+
+def test_jax_vectorize_map_type_error_and_length_mismatch():
+ from brainpy.running.jax_multiprocessing import jax_vectorize_map
+
+ # TypeError: arguments must be sequence or dict (line 60).
+ with pytest.raises(TypeError):
+ jax_vectorize_map(lambda a: a, 123, num_parallel=2)
+ # ValueError: unequal element lengths (line 66-67).
+ with pytest.raises(ValueError):
+ jax_vectorize_map(lambda a, b: a + b,
+ [bm.arange(4.0), bm.arange(3.0)], num_parallel=2)
+
+
+def test_jax_parallelize_map_list_and_dict():
+ from brainpy.running.jax_multiprocessing import jax_parallelize_map
+
+ # Default single-CPU pmap can map up to the local device count (1).
+ def f(a, b):
+ return a + b
+
+ a = bm.arange(2.0)
+ b = bm.arange(2.0) * 3
+ expected = np.asarray(a) + np.asarray(b)
+ out_list = jax_parallelize_map(f, [a, b], num_parallel=1, clear_buffer=False)
+ assert np.allclose(np.asarray(out_list), expected)
+
+ a = bm.arange(2.0)
+ b = bm.arange(2.0) * 3
+ expected = np.asarray(a) + np.asarray(b)
+ # See note above: patch the process-global buffer wipe to a no-op so the
+ # clear_buffer=True branch is covered without poisoning the shared session.
+ _orig_clear = bm.clear_buffer_memory
+ bm.clear_buffer_memory = lambda *a, **k: None
+ try:
+ out_dict = jax_parallelize_map(f, {'a': a, 'b': b}, num_parallel=1,
+ clear_buffer=True)
+ finally:
+ bm.clear_buffer_memory = _orig_clear
+ assert np.allclose(np.asarray(out_dict), expected)
+
+
+def test_jax_parallelize_map_type_error():
+ from brainpy.running.jax_multiprocessing import jax_parallelize_map
+
+ with pytest.raises(TypeError): # line 125
+ jax_parallelize_map(lambda a: a, 3.14, num_parallel=1)
+ with pytest.raises(ValueError): # length mismatch (line 131)
+ jax_parallelize_map(lambda a, b: a + b,
+ [bm.arange(2.0), bm.arange(1.0)], num_parallel=1)
+
+
+def test_jax_map_empty_input_returns_none():
+ """Cover the ``res_tree is None -> return None`` branch of both maps (the loop
+ body never executes for an empty task list): jax_multiprocessing 88-89, 155-156."""
+ from brainpy.running.jax_multiprocessing import (jax_vectorize_map,
+ jax_parallelize_map)
+
+ assert jax_vectorize_map(lambda a: a, [bm.zeros((0,))], num_parallel=1) is None
+ assert jax_parallelize_map(lambda a: a, [bm.zeros((0,))], num_parallel=1) is None
+
+
+# =========================================================================== #
+# joint_eq.py -- nested JointEq, derivative evaluation, dead _std_func,
+# conflicting-name DiffEqError.
+# =========================================================================== #
+
+def test_jointeq_nested_derivative_evaluation():
+ a, b = 0.02, 0.20
+ dV = lambda V, t, u, Iext: 0.04 * V * V + 5 * V + 140 - u + Iext
+ du = lambda u, t, V: a * (b * V - u)
+ eq = bp.JointEq(dV, du)
+
+ dw = lambda w, t, V: a * (b * V - w)
+ eq2 = bp.JointEq(eq, dw) # nested JointEq
+
+ # arg_keys collected across nested equations: V, u, w are state variables.
+ assert eq2.arg_keys[:3] == ['V', 'u', 'w']
+ assert 'Iext' in eq2.arg_keys # positional parameter propagated
+
+ # derivative evaluation returns one value per state variable (3).
+ res = eq2(-65.0, -14.0, -14.0, 0.0, Iext=10.0)
+ assert len(res) == 3
+ assert all(np.isfinite(float(r)) for r in res)
+
+
+def test_jointeq_call_with_keyword_argument():
+ def dV(V, t, w, gain=0.5):
+ return -V + gain * w
+
+ def dw(w, t, V, gain=0.5):
+ return -w + gain * V
+
+ eq = bp.JointEq([dV, dw]) # list-form exercises _check_eqs recursion
+ assert 'gain' in eq.kwarg_keys
+ res = eq(1.0, 2.0, 0.0, gain=0.3)
+ assert len(res) == 2
+ assert all(np.isfinite(float(r)) for r in res)
+
+
+def test_jointeq_conflicting_kwarg_name_with_state_variable():
+ """Cover the 'keyword argument conflicts with existing name' DiffEqError
+ (joint_eq.py lines 189-194)."""
+ def dV(V, t, w):
+ return -V + w
+
+ def dw(w, t, V=1.0): # kwarg 'V' reuses the state variable name 'V'
+ return -w + V
+
+ with pytest.raises(DiffEqError):
+ bp.JointEq(dV, dw)
+
+
+def test_jointeq_conflicting_kwarg_defaults():
+ def dV(V, t, a=1.0):
+ return -V + a
+
+ def dw(w, t, a=2.0): # same kwarg name, different default
+ return -w + a
+
+ with pytest.raises(DiffEqError):
+ bp.JointEq(dV, dw)
+
+
+def test_jointeq_missing_time_variable_and_var_kinds():
+ with pytest.raises(ValueError): # no 't' parameter (line 58)
+ bp.JointEq(lambda V, w: -V)
+ with pytest.raises(DiffEqError): # *args (VAR_POSITIONAL)
+ bp.JointEq(lambda V, t, *extra: -V)
+ with pytest.raises(DiffEqError): # **kwargs (VAR_KEYWORD)
+ bp.JointEq(lambda V, t, **extra: -V)
+ with pytest.raises(DiffEqError): # non-callable element
+ bp.JointEq(123)
+
+
+def test_jointeq_rejects_keyword_only_and_positional_only():
+ """Cover the KEYWORD_ONLY (line 40) and POSITIONAL_ONLY (line 43) rejection
+ branches of ``_get_args``."""
+ # KEYWORD_ONLY: parameter after a bare ``*``.
+ def kw_only(V, t, *, x):
+ return -V + x
+
+ with pytest.raises(DiffEqError):
+ bp.JointEq(kw_only)
+
+ # POSITIONAL_ONLY: parameter before ``/`` (needs an exec'd def for the syntax).
+ ns = {}
+ exec("def pos_only(V, t, x, /):\n return -V + x", ns)
+ with pytest.raises(DiffEqError):
+ bp.JointEq(ns['pos_only'])
+
+
+def test_jointeq_call_with_kwarg_passed_positionally():
+ """Cover the ``__call__`` branch where a trailing positional arg maps onto a
+ keyword key (joint_eq.py line 235)."""
+ def dV(V, t, w, gain=0.5):
+ return -V + gain * w
+
+ def dw(w, t, V, gain=0.5):
+ return -w + gain * V
+
+ eq = bp.JointEq(dV, dw)
+ # arg_keys = [V, w, t]; passing a 4th positional arg maps onto kwarg 'gain'.
+ res = eq(1.0, 2.0, 0.0, 0.3)
+ assert len(res) == 2
+ assert all(np.isfinite(float(r)) for r in res)
+
+
+def test_jointeq_std_func_dependency_from_state_vars():
+ """Cover the _std_func branch where a positional dependency is resolved from
+ the state-variable tuple via ``all_vars.index`` (joint_eq.py line 76) and the
+ branch where the dependency is supplied via keyword (line 72)."""
+ from brainpy.integrators.joint_eq import _std_func
+
+ def dV(V, t, w): # 'w' is a dependency that lives in all_vars
+ return -V + w
+
+ wrapper = _std_func(dV, ['V', 'w'])
+ out = wrapper(0.0, 1.0, 2.0) # w resolved positionally (line 76)
+ assert np.isfinite(float(out))
+ # 'w' supplied as a keyword -> line 72 (par in args_and_kwargs).
+ out2 = wrapper(0.0, 1.0, 2.0, w=3.0)
+ assert np.isfinite(float(out2))
+
+
+def test_jointeq_duplicate_state_variable_error():
+ """Cover the duplicate-state-variable DiffEqError branch (joint_eq.py 157)."""
+ def dV1(V, t):
+ return -V
+
+ def dV2(V, t): # 'V' reused as a state variable
+ return -2 * V
+
+ with pytest.raises(DiffEqError):
+ bp.JointEq(dV1, dV2)
+
+
+def test_jointeq_dead_std_func_helper():
+ """NOTE: ``_std_func`` is dead code (defined but never called anywhere in the
+ package). We invoke it directly to exercise lines 64-82. It builds a wrapper
+ that re-routes positional state vars / dependency lookups before calling the
+ underlying derivative function.
+ """
+ from brainpy.integrators.joint_eq import _std_func
+
+ def dV(V, t, w, gain=0.5):
+ return -V + gain * w
+
+ all_vars = ['V', 'w']
+ wrapper = _std_func(dV, all_vars)
+ # call(t, *vars, **args_and_kwargs): V and w are positional state vars;
+ # 'w' is a dependency that is looked up from `vars` via all_vars.index.
+ out = wrapper(0.0, 1.0, 2.0) # V=1.0, w=2.0
+ assert np.isfinite(float(out))
+ # pass the kwarg explicitly to cover the kwargs branch (lines 77-79).
+ out2 = wrapper(0.0, 1.0, 2.0, gain=0.9)
+ assert np.isfinite(float(out2))
+
+
+def test_jointeq_std_func_missing_dependency_raises():
+ """Cover the 'Missing {par}' DiffEqError inside _std_func (line 75)."""
+ from brainpy.integrators.joint_eq import _std_func
+
+ def dV(V, t, missing_dep):
+ return -V + missing_dep
+
+ wrapper = _std_func(dV, ['V']) # 'missing_dep' is neither passed nor a var
+ with pytest.raises(DiffEqError):
+ wrapper(0.0, 1.0)
+
+
+# =========================================================================== #
+# abstract_models.py -- synapse forward passes + add_current / return_info.
+# =========================================================================== #
+
+def _share(t=0.0, dt=0.1, i=0):
+ bp.share.save(t=t, dt=dt, i=i)
+
+
+def test_dualexponv2_forward_add_current_return_info():
+ """DualExponV2 is not covered by the sibling synapse-forward sweep.
+ Drive update(x) (add_current branch) and return_info (lines 404-437)."""
+ bm.random.seed(50)
+ syn = bp.dyn.DualExponV2(3, tau_decay=5.0, tau_rise=1.0)
+ syn.reset_state()
+ _share()
+ spike = bm.asarray([1.0, 0.0, 1.0])
+ out = syn.update(spike) # x not None -> add_current branch
+ assert jnp.asarray(out).shape == (3,)
+ assert jnp.all(jnp.isfinite(jnp.asarray(out)))
+ # update with no current (x=None branch)
+ out_none = syn.update()
+ assert jnp.all(jnp.isfinite(jnp.asarray(out_none)))
+ # return_info -> ReturnInfo with a callable (line 436)
+ info = syn.return_info()
+ assert info is not None
+
+
+def test_expon_forward_add_current_return_info():
+ bm.random.seed(51)
+ syn = bp.dyn.Expon(3, tau=5.0)
+ syn.reset_state()
+ _share()
+ out = syn.update(bm.ones(3)) # x not None -> add_current
+ assert jnp.asarray(out).shape == (3,)
+ out_none = syn.update() # x None branch
+ assert jnp.all(jnp.isfinite(jnp.asarray(out_none)))
+ assert syn.return_info() is syn.g
+
+
+def test_synapse_forward_and_return_info_methods():
+ """Drive ``update()`` (covering each derivative + the discrete spike jump)
+ and ``return_info`` for the JointEq-based synapses (abstract_models lines
+ 547/550/554-556, 727/730/737-741, 793-803, 867-888, plus return_info 292,
+ 559, 744, 806, 891)."""
+ bm.random.seed(52)
+ _share()
+ spike_bool = bm.asarray([True, False, True], dtype=bool)
+ spike_float = spike_bool.astype(float)
+ for cls, kwargs, inp in [
+ (bp.dyn.DualExpon, dict(tau_decay=5.0, tau_rise=1.0), spike_bool),
+ (bp.dyn.Alpha, dict(tau_decay=5.0), spike_bool),
+ (bp.dyn.NMDA, dict(tau_decay=10.0, tau_rise=2.0), spike_float),
+ (bp.dyn.STD, dict(tau=20.0), spike_bool),
+ (bp.dyn.STP, dict(), spike_bool),
+ ]:
+ syn = cls(3, **kwargs)
+ syn.reset_state()
+ out = syn.update(inp)
+ assert jnp.asarray(out).shape == (3,)
+ assert jnp.all(jnp.isfinite(jnp.asarray(out)))
+ info = syn.return_info()
+ assert info is not None
+
+
+def test_dualexpon_forward_runs():
+ """Exercise DualExpon.update (line 283-289) + return_info (line 292)."""
+ bm.random.seed(53)
+ syn = bp.dyn.DualExpon(3, tau_decay=5.0, tau_rise=1.0)
+ syn.reset_state()
+ _share()
+ out = syn.update(bm.asarray([1.0, 0.0, 1.0]))
+ assert jnp.asarray(out).shape == (3,)
+ assert jnp.all(jnp.isfinite(jnp.asarray(out)))
+ assert syn.return_info() is syn.g
diff --git a/tests/audit/test_boost_runners_delay.py b/tests/audit/test_boost_runners_delay.py
new file mode 100644
index 000000000..580a602e5
--- /dev/null
+++ b/tests/audit/test_boost_runners_delay.py
@@ -0,0 +1,879 @@
+# -*- coding: utf-8 -*-
+"""Audit coverage-boost tests for ``brainpy/runners.py`` and ``brainpy/delay.py``.
+
+These tests exercise the *uncovered* option matrix of ``DSRunner`` (jit on/off,
+progress bar, memory-efficient mode, list/dict/callable monitors, deprecated
+``fun_monitors``, ``numpy_mon_after_run``, ``data_first_axis`` = 'T'/'B', nonzero
+``t0``, the several input formats, ``shared_args``/dyn args, repeated ``.run()``,
+``.predict()`` and ``.reset_state()``) plus all the ``__init__``/``predict``
+validation error paths.
+
+For ``delay.py`` it drives ``Delay``/``VarDelay``/``DataDelay``/``DelayAccess``,
+both ``ROTATE_UPDATE`` and ``CONCAT_UPDATE`` methods, ``register_entry`` by
+``time=`` and by ``step=``, ``.at()``/``.retrieve()``/``.update()``,
+``init_delay_by_return`` (Variable + ReturnInfo), ``before_t0``-style ``init`` as
+array and as callable, and the validation/error branches.
+
+Sibling audit files already cover basic ``DSRunner`` and ``VarDelay(time=T)``;
+here we focus on the previously-uncovered options and code paths so that line
+coverage of both modules rises toward >=90%.
+"""
+
+import warnings
+
+import numpy as np
+import pytest
+
+import brainpy as bp
+import brainpy.math as bm
+import jax.numpy as jnp
+
+from brainpy import check
+from brainpy.delay import (
+ Delay,
+ VarDelay,
+ DataDelay,
+ DelayAccess,
+ init_delay_by_return,
+ register_delay_by_return,
+)
+from brainpy.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE
+from brainpy.mixin import ReturnInfo
+
+
+# ---------------------------------------------------------------------------
+# helpers
+# ---------------------------------------------------------------------------
+
+def _net(n=4, **kw):
+ """A tiny spiking network used across the runner tests."""
+ bm.random.seed(123)
+ return bp.dyn.LifRef(n, **kw)
+
+
+# ===========================================================================
+# DSRunner -- monitors
+# ===========================================================================
+
+def test_runner_list_monitors_jit_memory_efficient():
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V', 'spike'], jit=True,
+ progress_bar=False, memory_efficient=True)
+ out = r.run(2.0) # 20 steps @ dt=0.1
+ assert r.mon['V'].shape == (20, 4)
+ assert r.mon['spike'].shape == (20, 4)
+ assert r.mon['ts'].shape == (20,)
+ # memory-efficient mode forces numpy monitors
+ assert isinstance(r.mon['V'], np.ndarray)
+
+
+def test_runner_list_monitors_with_index():
+ n = _net()
+ r = bp.DSRunner(n, monitors=[('V', 0), ('spike', [1, 2])],
+ progress_bar=False)
+ r.run(1.0)
+ assert np.asarray(r.mon['V']).shape == (10, 1)
+ assert np.asarray(r.mon['spike']).shape == (10, 2)
+
+
+def test_runner_dict_monitors_variants():
+ """dict monitors: explicit var, (var, idx) tuple, and a callable."""
+ n = _net()
+ r = bp.DSRunner(
+ n,
+ monitors={'v0': (n.V, 0), 'vall': n.V, 'fcb': lambda: n.V[:2]},
+ jit=False,
+ progress_bar=False,
+ numpy_mon_after_run=False, # keep jax arrays
+ data_first_axis='T',
+ )
+ r.run(1.0)
+ assert np.asarray(r.mon['v0']).shape == (10, 1)
+ assert np.asarray(r.mon['vall']).shape == (10, 4)
+ assert np.asarray(r.mon['fcb']).shape == (10, 2)
+ # numpy_mon_after_run=False -> ts is a jax array
+ assert isinstance(bm.as_jax(r.mon['ts']), jnp.ndarray)
+
+
+def test_runner_fun_monitors_deprecated_path():
+ n = _net()
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ r = bp.DSRunner(n, fun_monitors={'sp': lambda: n.spike[:2]},
+ progress_bar=False)
+ r.run(1.0)
+ assert np.asarray(r.mon['sp']).shape == (10, 2)
+
+
+def test_runner_progress_bar_true():
+ """progress_bar=True exercises the brainstate pbar branch."""
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=True, jit=True)
+ r.run(1.0)
+ assert r.mon['V'].shape == (10, 4)
+
+
+def test_runner_t0_nonzero_shifts_time_axis():
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=False, t0=5.0)
+ r.run(1.0)
+ assert float(r.mon['ts'][0]) == pytest.approx(5.0)
+ assert float(r.mon['ts'][1]) == pytest.approx(5.1)
+
+
+# ===========================================================================
+# DSRunner -- inputs
+# ===========================================================================
+
+def test_runner_input_fix_tuple():
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], inputs=(n.V, 1.0), progress_bar=False)
+ r.run(1.0)
+ assert r.mon['V'].shape == (10, 4)
+
+
+def test_runner_input_iter_array():
+ n = _net()
+ arr = np.ones((10, 4)) * 0.5
+ r = bp.DSRunner(n, monitors=['V'], inputs=(n.V, arr, 'iter'),
+ progress_bar=False)
+ r.run(1.0)
+ assert r.mon['V'].shape == (10, 4)
+
+
+def test_runner_input_iter_generator():
+ """A non-array iterable goes through the ``next(...)`` path."""
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], inputs=(n.V, [0.1] * 20, 'iter'),
+ progress_bar=False)
+ r.run(1.0)
+ assert r.mon['V'].shape == (10, 4)
+
+
+def test_runner_input_func():
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], inputs=(n.V, lambda: 0.3, 'func', '+'),
+ progress_bar=False)
+ r.run(1.0)
+ assert r.mon['V'].shape == (10, 4)
+
+
+def test_runner_input_string_target_assign_op():
+ """String target ('V') with the '=' operation (relative/absolute access)."""
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], inputs=('V', 0.1, 'fix', '='),
+ progress_bar=False)
+ r.run(1.0)
+ assert r.mon['V'].shape == (10, 4)
+
+
+def test_runner_callable_inputs():
+ n = _net()
+
+ def fin():
+ n.V += 0.1
+
+ r = bp.DSRunner(n, monitors=['V'], inputs=fin, progress_bar=False)
+ r.run(1.0)
+ assert r.mon['V'].shape == (10, 4)
+
+
+def test_runner_multiple_inputs_and_ops():
+ """Several inputs with different operations in one runner."""
+ n = _net()
+ r = bp.DSRunner(
+ n,
+ monitors=['V'],
+ inputs=[(n.V, 0.2, 'fix', '+'),
+ ('V', 0.01, 'fix', '*')],
+ progress_bar=False,
+ )
+ r.run(1.0)
+ assert r.mon['V'].shape == (10, 4)
+
+
+# ===========================================================================
+# DSRunner -- predict / run / reset_state
+# ===========================================================================
+
+def test_predict_with_xs_array_nonbatching():
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=False)
+ xs = np.ones((15, 4)) * 0.4
+ out = r.predict(inputs=xs)
+ assert np.asarray(out).shape == (15, 4)
+ assert r.mon['ts'].shape == (15,)
+
+
+def test_predict_eval_time_and_reset_state_arg():
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=False)
+ running_time, out = r.predict(1.0, reset_state=True, eval_time=True)
+ assert isinstance(running_time, float)
+ assert np.asarray(out).shape == (10, 4)
+
+
+def test_runner_reset_state_method():
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=False)
+ r.run(1.0)
+ assert r.i0 == 10
+ r.reset_state()
+ assert r.i0 == 0
+
+
+def test_runner_repeated_run_accumulates_i0():
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=False)
+ r.run(1.0)
+ assert r.i0 == 10
+ r.run(1.0)
+ assert r.i0 == 20
+
+
+def test_runner_shared_args_dyn_args():
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=False)
+ r.predict(1.0, shared_args={'fit': False})
+ assert r.mon['V'].shape == (10, 4)
+
+
+def test_runner_call_dunder():
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=False)
+ out = r(1.0) # __call__
+ assert np.asarray(out).shape == (10, 4)
+
+
+def test_runner_repr():
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=False)
+ s = repr(r)
+ assert 'DSRunner' in s and 'data_first_axis' in s
+
+
+def test_predict_duration_and_inputs_warns():
+ """Providing both duration and inputs warns and uses inputs' time axis."""
+ n = _net()
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=False)
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ r.predict(1.0, inputs=np.ones((10, 4)) * 0.3)
+ assert r.mon['ts'].shape == (10,)
+
+
+# ===========================================================================
+# DSRunner -- batching mode + data_first_axis='B'
+# ===========================================================================
+
+def test_runner_batching_mode_first_axis_B():
+ bm.random.seed(7)
+ n = bp.dyn.LifRef(4, mode=bm.batching_mode)
+ n.reset(batch_size=3)
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=False, data_first_axis='B')
+ xs = np.ones((3, 12, 4)) * 0.4 # (batch, time, feat)
+ out = r.predict(inputs=xs, reset_state=True)
+ assert np.asarray(out).shape == (3, 12, 4)
+ assert r.mon['V'].shape == (3, 12, 4)
+
+
+def test_runner_batching_default_data_first_axis():
+ """A BatchingMode target defaults data_first_axis to 'B'."""
+ n = bp.dyn.LifRef(4, mode=bm.batching_mode)
+ r = bp.DSRunner(n, monitors=['V'], progress_bar=False)
+ assert r.data_first_axis == 'B'
+
+
+# ===========================================================================
+# DSRunner -- validation / error branches
+# ===========================================================================
+
+def test_runner_bad_target_type():
+ from brainpy._errors import RunningError
+ with pytest.raises(RunningError):
+ bp.DSRunner(object())
+
+
+def test_runner_bad_monitors_type():
+ with pytest.raises(Exception): # MonitorError
+ bp.DSRunner(_net(), monitors=123)
+
+
+def test_runner_bad_inputs_type():
+ from brainpy._errors import RunningError
+ with pytest.raises(RunningError):
+ bp.DSRunner(_net(), inputs=123)
+
+
+def test_runner_bad_input_structure():
+ from brainpy._errors import RunningError
+ # first element neither str nor Variable and not a list/tuple
+ with pytest.raises(RunningError):
+ bp.DSRunner(_net(), inputs=[1.0, 2.0])
+
+
+def test_runner_bad_input_length():
+ from brainpy._errors import RunningError
+ with pytest.raises(RunningError):
+ bp.DSRunner(_net(), inputs=(_net().V,)) # length 1
+
+
+def test_runner_bad_input_op():
+ from brainpy._errors import RunningError
+ with pytest.raises(RunningError):
+ bp.DSRunner(_net(), inputs=('V', 1.0, 'fix', '^'))
+
+
+def test_runner_bad_input_type_str():
+ from brainpy._errors import RunningError
+ with pytest.raises(RunningError):
+ bp.DSRunner(_net(), inputs=('V', 1.0, 'bogus'))
+
+
+def test_runner_iter_value_not_iterable():
+ with pytest.raises(ValueError):
+ bp.DSRunner(_net(), inputs=('V', 5, 'iter'))
+
+
+def test_runner_func_value_not_callable():
+ with pytest.raises(ValueError):
+ bp.DSRunner(_net(), inputs=('V', 5, 'func'))
+
+
+def test_runner_bad_input_target_attr():
+ with pytest.raises(AttributeError):
+ bp.DSRunner(_net(), inputs=('nonexist.attr', 1.0))
+
+
+def test_runner_nonvar_nonstr_target():
+ from brainpy._errors import RunningError
+ with pytest.raises(RunningError):
+ bp.DSRunner(_net(), inputs=(123, 1.0))
+
+
+def test_runner_memory_efficient_requires_numpy_mon():
+ with pytest.raises(ValueError):
+ bp.DSRunner(_net(), memory_efficient=True, numpy_mon_after_run=False)
+
+
+def test_predict_without_duration_or_inputs():
+ n = _net()
+ r = bp.DSRunner(n, progress_bar=False)
+ with pytest.raises(ValueError):
+ r.predict()
+
+
+def test_runner_bad_dt_type():
+ from brainpy._errors import RunningError
+ with pytest.raises(RunningError):
+ bp.DSRunner(_net(), dt=1) # int, not float
+
+
+# ===========================================================================
+# delay.py -- VarDelay ROTATE_UPDATE
+# ===========================================================================
+
+def _drive_delay(delay, target_var, n_steps, value_fn=None):
+ """Run a delay update loop with the proper shared-arg context."""
+ dt = bm.get_dt()
+ for i in range(n_steps):
+ bp.share.save(i=i, t=i * dt, dt=dt)
+ v = (value_fn(i) if value_fn is not None
+ else bm.ones_like(target_var.value) * i)
+ target_var.value = v
+ delay.update()
+
+
+def test_vardelay_rotate_register_time_and_step():
+ bm.random.seed(0)
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=2.0) # 20 steps capacity
+ d.register_entry('by_time', delay_time=1.0) # -> 10 steps
+ d.register_entry('by_step', delay_step=5)
+ d.register_entry('zero', delay_time=None) # zero-delay -> target value
+ assert d._registered_entries['by_time'] == 10
+ assert d._registered_entries['by_step'] == 5
+ assert d._registered_entries['zero'] is None
+
+ _drive_delay(d, v, 25)
+ # at the zero entry returns the current target value
+ np.testing.assert_allclose(np.asarray(d.at('zero')), 24.0)
+ # delayed entries return earlier values
+ np.testing.assert_allclose(np.asarray(d.at('by_time')), 15.0)
+ np.testing.assert_allclose(np.asarray(d.at('by_step', 0)), 20.0)
+ # direct retrieve
+ np.testing.assert_allclose(np.asarray(d.retrieve(3)), 22.0)
+
+
+def test_vardelay_at_with_indices():
+ bm.random.seed(0)
+ v = bm.Variable(bm.zeros(4))
+ d = VarDelay(v, time=1.0)
+ d.register_entry('e', delay_step=5)
+ _drive_delay(d, v, 12)
+ out = d.at('e', 1)
+ assert np.asarray(out).shape == ()
+
+
+def test_vardelay_init_array_and_callable():
+ """``init`` provided as array and as callable (before_t0-style data)."""
+ v = bm.Variable(bm.zeros(3))
+ d_arr = VarDelay(v, time=0.5, init=jnp.ones((5, 3)) * 2.0)
+ assert np.asarray(d_arr.data.value).sum() == pytest.approx(2.0 * 15)
+
+ v2 = bm.Variable(bm.zeros(3))
+ d_call = VarDelay(v2, time=0.5,
+ init=lambda shape, dtype: jnp.ones(shape, dtype) * 3.0)
+ assert np.asarray(d_call.data.value).sum() == pytest.approx(3.0 * 15)
+
+
+def test_vardelay_reset_state_and_repr():
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=1.0, init=5.0)
+ d.register_entry('e', delay_step=3)
+ _drive_delay(d, v, 5)
+ d.reset_state()
+ # after reset the buffer is re-initialised to init=5.0
+ np.testing.assert_allclose(np.asarray(d.data.value), 5.0)
+ assert 'VarDelay' in repr(d)
+ assert d.delay_target_shape == (3,)
+
+
+def test_vardelay_register_entry_array_delay_time():
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=1.0)
+ d.register_entry('arr', delay_time=jnp.asarray(0.5))
+ assert d._registered_entries['arr'] == 5
+
+
+# ===========================================================================
+# delay.py -- VarDelay CONCAT_UPDATE
+# ===========================================================================
+
+def test_vardelay_concat_update_multi_step():
+ bm.random.seed(0)
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=1.0, method=CONCAT_UPDATE)
+ assert d.method == CONCAT_UPDATE
+ d.register_entry('a', delay_step=5)
+ _drive_delay(d, v, 12)
+ np.testing.assert_allclose(np.asarray(d.at('a')), 7.0)
+ np.testing.assert_allclose(np.asarray(d.retrieve(3)), 9.0)
+
+
+def test_vardelay_concat_update_single_step():
+ """max_length==1 hits the special-case concat branch."""
+ bm.random.seed(0)
+ v = bm.Variable(bm.zeros(2))
+ d = VarDelay(v, time=0.1, method=CONCAT_UPDATE) # length 1
+ d.register_entry('b', delay_step=1)
+ _drive_delay(d, v, 3)
+ assert np.asarray(d.at('b')).shape == (2,)
+
+
+# ===========================================================================
+# delay.py -- DataDelay
+# ===========================================================================
+
+def test_datadelay_update_retrieve_reset():
+ bm.random.seed(0)
+ data = bm.Variable(bm.zeros(3))
+ dd = DataDelay(data, data_init=bm.zeros, time=0.5)
+ dd.register_entry('c', delay_step=3)
+ dt = bm.get_dt()
+ for i in range(8):
+ bp.share.save(i=i, t=i * dt, dt=dt)
+ dd.update(bm.ones(3) * i)
+ np.testing.assert_allclose(np.asarray(dd.at('c')), 5.0)
+ dd.reset_state()
+ assert np.asarray(dd.data.value).sum() == pytest.approx(0.0)
+
+
+def test_datadelay_reset_state_with_batch():
+ bm.random.seed(0)
+ data = bm.Variable(bm.zeros((1, 3)), batch_axis=0)
+ dd = DataDelay(data, data_init=bm.zeros, time=0.5)
+ dd.register_entry('c', delay_step=2)
+ dd.reset_state(batch_size=1)
+ assert dd.data is not None
+
+
+# ===========================================================================
+# delay.py -- DelayAccess
+# ===========================================================================
+
+def test_delay_access_update_and_reset():
+ bm.random.seed(0)
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=1.0)
+ acc = DelayAccess(d, time=0.5, delay_entry='myacc')
+ _drive_delay(d, v, 12)
+ np.testing.assert_allclose(np.asarray(acc.update()), 7.0)
+ acc.reset_state() # no-op, just exercise the branch
+
+
+def test_delay_access_with_indices():
+ bm.random.seed(0)
+ v = bm.Variable(bm.zeros(4))
+ d = VarDelay(v, time=1.0)
+ acc = DelayAccess(d, 0.3, 0, delay_entry='idxacc')
+ _drive_delay(d, v, 10)
+ assert np.asarray(acc.update()).shape == ()
+
+
+# ===========================================================================
+# delay.py -- init_delay_by_return
+# ===========================================================================
+
+def test_init_delay_by_return_variable():
+ v = bm.Variable(bm.zeros(4))
+ d = init_delay_by_return(v)
+ assert isinstance(d, VarDelay)
+
+
+def test_init_delay_by_return_returninfo_nonbatching():
+ ri = ReturnInfo(size=(4,), batch_or_mode=bm.NonBatchingMode(), data=bm.zeros)
+ d = init_delay_by_return(ri)
+ assert isinstance(d, DataDelay)
+ assert d.target.shape == (4,)
+
+
+def test_init_delay_by_return_returninfo_int_batch():
+ ri = ReturnInfo(size=(4,), batch_or_mode=2, data=bm.zeros)
+ d = init_delay_by_return(ri)
+ assert isinstance(d, DataDelay)
+ assert d.target.shape == (2, 4)
+
+
+def test_init_delay_by_return_returninfo_batchingmode():
+ ri = ReturnInfo(size=(4,), batch_or_mode=bm.BatchingMode(3), data=bm.zeros)
+ d = init_delay_by_return(ri)
+ assert isinstance(d, DataDelay)
+ assert d.target.shape == (3, 4)
+
+
+def test_init_delay_by_return_returninfo_array_data():
+ ri = ReturnInfo(size=(4,), batch_or_mode=bm.NonBatchingMode(),
+ data=jnp.ones((4,)))
+ d = init_delay_by_return(ri)
+ assert isinstance(d, DataDelay)
+
+
+# ===========================================================================
+# delay.py -- validation / error branches
+# ===========================================================================
+
+def test_vardelay_bad_target():
+ with pytest.raises(ValueError):
+ VarDelay(bm.zeros(3), time=1.0) # plain array, not a Variable
+
+
+def test_delay_bad_time_type():
+ with pytest.raises(TypeError):
+ VarDelay(bm.Variable(bm.zeros(3)), time='oops')
+
+
+def test_vardelay_duplicate_entry():
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=1.0)
+ d.register_entry('e')
+ with pytest.raises(KeyError):
+ d.register_entry('e', delay_step=2)
+
+
+def test_vardelay_at_missing_entry():
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=1.0)
+ with pytest.raises(KeyError):
+ d.at('does-not-exist')
+
+
+def test_get_delay_both_time_and_step():
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=1.0)
+ with pytest.raises(AssertionError):
+ d.register_entry('z', delay_time=0.5, delay_step=3)
+
+
+def test_init_delay_by_return_bad_type():
+ with pytest.raises(TypeError):
+ init_delay_by_return(123)
+
+
+def test_init_delay_by_return_returninfo_bad_data():
+ ri = ReturnInfo(size=(4,), batch_or_mode=bm.NonBatchingMode(), data=123)
+ with pytest.raises(TypeError):
+ init_delay_by_return(ri)
+
+
+def test_base_delay_register_entry_not_implemented():
+ d = Delay(time=1.0)
+ with pytest.raises(NotImplementedError):
+ d.register_entry('x', delay_step=1)
+
+
+def test_base_delay_at_not_implemented():
+ d = Delay(time=1.0)
+ with pytest.raises(NotImplementedError):
+ d.at('x')
+
+
+def test_base_delay_retrieve_not_implemented():
+ d = Delay(time=1.0)
+ with pytest.raises(NotImplementedError):
+ d.retrieve(1)
+
+
+# ===========================================================================
+# Extra coverage: input ops, multi-leaf inputs, memory-efficient w/o jit
+# ===========================================================================
+
+def test_runner_input_ops_minus_mul_div():
+ """Exercise the '-', '*', '/' branches of ``_f_ops``."""
+ n = _net(3)
+ r = bp.DSRunner(
+ n,
+ monitors=['V'],
+ inputs=[('V', 0.1, 'fix', '-'),
+ ('V', 1.0, 'fix', '*'),
+ ('V', 2.0, 'fix', '/')],
+ progress_bar=False,
+ )
+ r.run(0.5)
+ assert r.mon['V'].shape == (5, 3)
+
+
+def test_runner_input_func_with_shared_arg():
+ """Input callable that *accepts* a shared argument (deprecated bind path)."""
+ n = _net(3)
+
+ def fin(shared):
+ n.V += 0.1
+
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ r = bp.DSRunner(n, monitors=['V'], inputs=fin, progress_bar=False)
+ r.run(0.5)
+ assert r.mon['V'].shape == (5, 3)
+
+
+def test_runner_memory_efficient_without_jit():
+ n = _net(3)
+ r = bp.DSRunner(n, monitors=['V'], jit=False, progress_bar=False,
+ memory_efficient=True)
+ r.run(0.5)
+ assert r.mon['V'].shape == (5, 3)
+ assert isinstance(r.mon['V'], np.ndarray)
+
+
+def _TwoPop():
+ class Two(bp.DynamicalSystem):
+ def __init__(self):
+ super().__init__()
+ self.a = bp.dyn.LifRef(3)
+ self.b = bp.dyn.LifRef(3)
+
+ def update(self, x, y):
+ s1 = self.a(x)
+ self.b(y)
+ return s1
+
+ return Two()
+
+
+def test_runner_multi_leaf_inputs_time_step():
+ """A 2-leaf input tuple exercises the multi-leaf time-step branch."""
+ net = _TwoPop()
+ r = bp.DSRunner(net, monitors={'va': net.a.V}, progress_bar=False)
+ xs = (np.ones((12, 3)) * 0.4, np.ones((12, 3)) * 0.2)
+ out = r.predict(inputs=xs)
+ assert r.mon['ts'].shape == (12,)
+ assert np.asarray(out).shape == (12, 3)
+
+
+# ===========================================================================
+# Extra coverage: delay constructor entries, growth, batching, modes
+# ===========================================================================
+
+def test_vardelay_entries_in_constructor():
+ """The ``entries=`` constructor argument registers entries (264-265)."""
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=2.0, entries={'e1': 1.0, 'e2': 0.5})
+ assert d._registered_entries['e1'] == 10
+ assert d._registered_entries['e2'] == 5
+
+
+def test_vardelay_register_entry_grows_max_length():
+ """Registering an entry larger than current capacity grows the buffer."""
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=0.5) # 5 steps
+ assert d.max_length == 5
+ d.register_entry('big', delay_step=10)
+ assert d.max_length == 10
+ # data buffer was reallocated to the new length
+ assert d.data.value.shape[0] == 10
+
+
+def test_vardelay_training_mode_uses_concat():
+ """A TrainingMode delay defaults to CONCAT_UPDATE (lines 95-96)."""
+ v = bm.Variable(bm.zeros((2, 3)), batch_axis=0)
+ d = VarDelay(v, time=0.5, mode=bm.training_mode)
+ assert d.method == CONCAT_UPDATE
+
+
+def test_vardelay_batching_mode_at_with_index():
+ """BatchingMode delay: ``.at`` inserts a slice at the batch axis (319-321)."""
+ bm.random.seed(0)
+ v = bm.Variable(bm.zeros((2, 3)), batch_axis=0)
+ d = VarDelay(v, time=0.5, mode=bm.batching_mode)
+ d.register_entry('e', delay_step=3)
+ dt = bm.get_dt()
+ for i in range(6):
+ bp.share.save(i=i, t=i * dt, dt=dt)
+ v.value = bm.ones((2, 3)) * i
+ d.update()
+ assert np.asarray(d.at('e')).shape == (2, 3)
+ # indexing into the feature axis keeps the batch dim
+ assert np.asarray(d.at('e', 0)).shape == (2,)
+
+
+def test_vardelay_retrieve_under_checking():
+ """With checking enabled, ``retrieve`` runs the ``jit_error`` guard path."""
+ was_checking = check.is_checking()
+ check.turn_on()
+ try:
+ bm.random.seed(0)
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=0.5)
+ d.register_entry('e', delay_step=3)
+ dt = bm.get_dt()
+ for i in range(6):
+ bp.share.save(i=i, t=i * dt, dt=dt)
+ v.value = bm.ones(3) * i
+ d.update()
+ # the checked retrieve path returns a concrete delayed value
+ np.testing.assert_allclose(np.asarray(d.at('e')), 3.0)
+ finally:
+ if not was_checking:
+ check.turn_off()
+
+
+def test_vardelay_unknown_update_method_raises():
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=0.5)
+ d.register_entry('e', delay_step=2)
+ d.method = 'bogus'
+ bp.share.save(i=0, t=0.0, dt=bm.get_dt())
+ with pytest.raises(ValueError):
+ d.update()
+
+
+def test_vardelay_unknown_method_retrieve_raises():
+ v = bm.Variable(bm.zeros(3))
+ d = VarDelay(v, time=0.5)
+ d.register_entry('e', delay_step=2)
+ # populate first with a valid method
+ dt = bm.get_dt()
+ for i in range(4):
+ bp.share.save(i=i, t=i * dt, dt=dt)
+ v.value = bm.ones(3) * i
+ d.update()
+ d.method = 'bogus'
+ bp.share.save(i=4, t=0.4, dt=dt)
+ with pytest.raises(ValueError):
+ d.retrieve(2)
+
+
+def test_register_delay_by_return_reuses_instance():
+ """``register_delay_by_return`` adds + reuses an after-update delay."""
+ n = bp.dyn.LifRef(4)
+ d1 = register_delay_by_return(n)
+ d2 = register_delay_by_return(n)
+ assert isinstance(d1, Delay)
+ assert d1 is d2 # second call reuses the registered instance
+
+
+def test_runner_fun_inputs_deprecated():
+ """The deprecated ``fun_inputs`` argument still drives inputs (367, 562)."""
+ n = _net(3)
+
+ def finp(shared):
+ n.V += 0.1
+
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ r = bp.DSRunner(n, monitors=['V'], fun_inputs=finp, progress_bar=False)
+ r.run(0.5)
+ assert r.mon['V'].shape == (5, 3)
+
+
+def test_vardelay_zero_delay_at_with_index():
+ """A zero-delay entry returns the *indexed* current target value (325)."""
+ v = bm.Variable(bm.zeros(4))
+ d = VarDelay(v, time=1.0)
+ d.register_entry('zero', delay_time=None)
+ bp.share.save(i=0, t=0.0, dt=bm.get_dt())
+ v.value = bm.arange(4).astype(bm.float_)
+ assert float(np.asarray(d.at('zero', 1))) == pytest.approx(1.0)
+
+
+def test_vardelay_axis_names_sharding():
+ """A target with axis_names builds a sharded delay buffer (244-246)."""
+ v = bm.Variable(bm.zeros(4), axis_names=('feat',))
+ d = VarDelay(v, time=0.5)
+ assert d.data.value.shape == (5, 4)
+
+
+def test_init_delay_by_return_returninfo_axis_names():
+ """ReturnInfo carrying axis_names builds a DataDelay (asserts ndim, 559)."""
+ ri = ReturnInfo(size=(4,), batch_or_mode=bm.NonBatchingMode(),
+ data=bm.zeros, axis_names=('feat',))
+ d = init_delay_by_return(ri)
+ assert isinstance(d, DataDelay)
+
+
+# ===========================================================================
+# Extra coverage: input formatting / _f_ops helpers (direct unit tests)
+# ===========================================================================
+
+def test_check_and_format_inputs_none():
+ """``inputs=None`` produces empty formatted-input buckets (line 79)."""
+ from brainpy.runners import check_and_format_inputs
+ n = _net(3)
+ res = check_and_format_inputs(n, None)
+ assert set(res) == {'fixed', 'iterated', 'functional', 'array'}
+ # everything empty
+ assert all(len(lst) == 0 for d in res.values() for lst in d.values())
+
+
+def test_check_and_format_inputs_nonstr_target():
+ """A non-str / non-Variable target raises in absolute access (line 122)."""
+ from brainpy.runners import check_and_format_inputs
+ from brainpy._errors import RunningError
+ n = _net(3)
+ with pytest.raises(RunningError):
+ check_and_format_inputs(n, [(123, 1.0)])
+
+
+def test_f_ops_unknown_operation_raises():
+ """``_f_ops`` with an unknown operation raises (line 207)."""
+ from brainpy.runners import _f_ops
+ n = _net(3)
+ with pytest.raises(ValueError):
+ _f_ops('^', n.V, 1.0)
+
+
+def test_f_ops_all_supported_operations():
+ """Each supported ``_f_ops`` branch runs without error."""
+ from brainpy.runners import _f_ops
+ v = bm.Variable(bm.ones(3))
+ _f_ops('=', v, 2.0)
+ np.testing.assert_allclose(np.asarray(v.value), 2.0)
+ _f_ops('+', v, 1.0)
+ np.testing.assert_allclose(np.asarray(v.value), 3.0)
+ _f_ops('-', v, 1.0)
+ np.testing.assert_allclose(np.asarray(v.value), 2.0)
+ _f_ops('*', v, 2.0)
+ np.testing.assert_allclose(np.asarray(v.value), 4.0)
+ _f_ops('/', v, 4.0)
+ np.testing.assert_allclose(np.asarray(v.value), 1.0)
diff --git a/tests/audit/test_dnn_toolbox_fixes.py b/tests/audit/test_dnn_toolbox_fixes.py
new file mode 100644
index 000000000..ded4f0f50
--- /dev/null
+++ b/tests/audit/test_dnn_toolbox_fixes.py
@@ -0,0 +1,764 @@
+# -*- coding: utf-8 -*-
+"""Regression + coverage tests for the DNN / toolbox audit fixes (2026-06-18).
+
+This module tests the fixes recorded in ``docs/issues-found-20260618.md`` for the
+following source files:
+
+* ``brainpy/dnn/normalization.py`` -- C-05, H-51, M-25/M-26 (GroupNorm/Instance
+ group axis, BatchNorm/LayerNorm/GroupNorm
+ constructing + running under the default mode).
+* ``brainpy/losses/comparison.py`` -- C-02 (``nll_loss`` sign), C-03 (class-weighted
+ cross-entropy), H-53 (``ignore_index`` /
+ ``label_smoothing``).
+* ``brainpy/optim/optimizer.py`` -- C-01 / H-52 (Adam/AdamW bias-correction step
+ counter), M-29 (``SM3`` instantiable with
+ train vars).
+* ``brainpy/optim/scheduler.py`` -- C-04 (``MultiStepLR`` decay), M-01 (other LR
+ schedulers).
+* ``brainpy/encoding/stateless_encoding.py`` -- ``PoissonEncoder.single_step`` no longer
+ crashes / passes wrong args.
+* ``brainpy/connect/random_conn.py`` -- M-30 (``FixedProb`` empty/rectangular build).
+
+In addition to the targeted regression checks, the module exercises every optimizer,
+scheduler, loss function, normalization layer and stateless/stateful encoder for
+coverage. Known *remaining* bugs (``Adan.update``, ``ctc_loss``'s ``.value`` access)
+are pinned with ``pytest.raises`` so the lines stay covered and the behaviour is
+documented rather than silently broken.
+"""
+
+import numpy as np
+import jax.numpy as jnp
+import pytest
+
+import brainpy as bp
+import brainpy.math as bm
+from brainpy.context import share
+from brainpy.optim import optimizer as O
+from brainpy.optim import scheduler as S
+from brainpy.losses import comparison as C
+from brainpy.optim.optimizer import SM3
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _np(x):
+ return np.asarray(bm.as_jax(x))
+
+
+def _train_var(shape=(2, 3), fill=1.0):
+ return bm.Variable(bm.ones(shape) * fill)
+
+
+# ===========================================================================
+# 1. normalization.py -- GroupNorm / InstanceNorm / BatchNorm / LayerNorm
+# ===========================================================================
+
+def test_groupnorm_respects_groups():
+ """C-05: GroupNorm(3,6) must differ from GroupNorm(1,6); num_groups has effect."""
+ bm.random.seed(0)
+ x = bm.random.randn(4, 6)
+ g3 = _np(bp.dnn.GroupNorm(3, 6, affine=False)(x))
+ g1 = _np(bp.dnn.GroupNorm(1, 6, affine=False)(x))
+ g6 = _np(bp.dnn.GroupNorm(6, 6, affine=False)(x))
+ assert not np.allclose(g3, g1), "GroupNorm(3,6) must differ from GroupNorm(1,6)"
+ assert not np.allclose(g3, g6), "GroupNorm(3,6) must differ from GroupNorm(6,6)"
+
+
+def test_groupnorm_each_group_zero_mean_unit_var():
+ """C-05: each group is normalized independently to ~zero-mean / unit-var."""
+ bm.random.seed(1)
+ x = bm.random.randn(4, 6) * 3.0 + 5.0
+ y = _np(bp.dnn.GroupNorm(3, 6, affine=False)(x))
+ # 6 channels / 3 groups -> 2 channels per group.
+ grouped = y.reshape(4, 3, 2)
+ assert np.allclose(grouped.mean(axis=-1), 0.0, atol=1e-4)
+ assert np.allclose(grouped.var(axis=-1), 1.0, atol=1e-2)
+
+
+def test_instancenorm_per_channel_unit_var():
+ """C-05: InstanceNorm normalizes each channel independently over the spatial axis."""
+ bm.random.seed(2)
+ x = bm.random.randn(4, 8, 6) * 3.0 + 2.0
+ y = _np(bp.dnn.InstanceNorm(6, affine=False)(x))
+ # Normalized over the spatial axis (axis=1) per (sample, channel).
+ assert np.allclose(y.std(axis=1), 1.0, atol=1e-2)
+ assert np.allclose(y.mean(axis=1), 0.0, atol=1e-4)
+
+
+def test_norm_layers_construct_and_run_default_mode():
+ """H-51: BatchNorm1d / LayerNorm / GroupNorm / InstanceNorm construct under the
+ default (non-training) mode and run a forward pass without raising."""
+ x3 = bm.random.randn(2, 4, 3)
+ assert _np(bp.dnn.BatchNorm1d(num_features=3)(x3)).shape == (2, 4, 3)
+ assert _np(bp.dnn.LayerNorm(3)(x3)).shape == (2, 4, 3)
+ x2 = bm.random.randn(2, 6)
+ assert _np(bp.dnn.GroupNorm(3, 6)(x2)).shape == (2, 6)
+ assert _np(bp.dnn.InstanceNorm(6)(x2)).shape == (2, 6)
+
+
+def test_norm_layers_affine_and_aliases():
+ """Affine path + dimensional aliases (BatchNorm*D) construct and run."""
+ x3 = bm.random.randn(2, 4, 3)
+ assert _np(bp.dnn.BatchNorm1D(num_features=3, affine=True)(x3)).shape == (2, 4, 3)
+ assert _np(bp.dnn.LayerNorm(3, elementwise_affine=True)(x3)).shape == (2, 4, 3)
+ assert _np(bp.dnn.GroupNorm(3, 6, affine=True)(bm.random.randn(2, 6))).shape == (2, 6)
+
+
+def test_batchnorm_2d_3d_fit_and_eval():
+ """BatchNorm2d/3d run in both 'fit' (update running stats) and 'eval' modes."""
+ x4 = bm.random.randn(2, 4, 4, 3)
+ x5 = bm.random.randn(2, 4, 4, 4, 3)
+ share.save(fit=True)
+ try:
+ assert _np(bp.dnn.BatchNorm2d(num_features=3)(x4)).shape == (2, 4, 4, 3)
+ assert _np(bp.dnn.BatchNorm3d(num_features=3)(x5)).shape == (2, 4, 4, 4, 3)
+ # axis_name=None path with affine
+ assert _np(bp.dnn.BatchNorm2D(num_features=3, affine=True)(x4)).shape == (2, 4, 4, 3)
+ finally:
+ share.save(fit=False)
+ # eval mode: uses running stats
+ assert _np(bp.dnn.BatchNorm2d(num_features=3)(x4)).shape == (2, 4, 4, 3)
+ assert _np(bp.dnn.BatchNorm3D(num_features=3)(x5)).shape == (2, 4, 4, 4, 3)
+
+
+def test_batchnorm_input_dim_check():
+ """The _check_input_dim guards raise on wrong ndim."""
+ with pytest.raises(ValueError):
+ bp.dnn.BatchNorm1d(num_features=3)(bm.random.randn(2, 3)) # needs 3D
+ with pytest.raises(ValueError):
+ bp.dnn.BatchNorm2d(num_features=3)(bm.random.randn(2, 4, 3)) # needs 4D
+ with pytest.raises(ValueError):
+ bp.dnn.BatchNorm3d(num_features=3)(bm.random.randn(2, 4, 4, 3)) # needs 5D
+
+
+def test_groupnorm_invalid_channels():
+ """num_channels must be divisible by num_groups."""
+ with pytest.raises(ValueError):
+ bp.dnn.GroupNorm(4, 6)
+
+
+def test_layernorm_shape_mismatch_raises():
+ """M-26: LayerNorm raises on a normalized-shape mismatch.
+
+ NOTE: the intended ``ValueError`` message is itself mis-built
+ (``", ".join(self.normalized_shape)`` joins ints, normalization.py:536), so
+ a ``TypeError`` surfaces first. Either way the layer rejects the bad shape;
+ we only assert that *some* exception is raised."""
+ with pytest.raises(Exception):
+ bp.dnn.LayerNorm(5)(bm.random.randn(2, 4, 3))
+
+
+# ===========================================================================
+# 2. comparison.py -- loss functions
+# ===========================================================================
+
+def test_nll_loss_is_positive():
+ """C-02: nll_loss returns the *negative* log-likelihood (positive number)."""
+ log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]))
+ loss = float(C.nll_loss(log_probs, jnp.array([0, 1])))
+ assert loss > 0.0
+ assert abs(loss - 0.2899) < 1e-3
+
+
+def test_nll_loss_reductions():
+ log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]))
+ tgt = jnp.array([0, 1])
+ assert float(C.nll_loss(log_probs, tgt, reduction='sum')) > 0
+ assert _np(C.nll_loss(log_probs, tgt, reduction='none')).shape == (2,)
+ assert _np(C.nll_loss(log_probs, tgt, reduction=None)).shape == (2,)
+ with pytest.raises(ValueError):
+ C.nll_loss(log_probs, tgt, reduction='bogus')
+ # NLLLoss class
+ assert float(C.NLLLoss()(log_probs, tgt)) > 0
+
+
+def test_cross_entropy_weight_by_class():
+ """C-03: class weight is applied by target class, not by sample index."""
+ out = _np(C.cross_entropy_loss(bm.zeros((3, 3)), [2, 2, 2],
+ weight=[10, 20, 1], reduction='none'))
+ # all targets are class 2 -> all per-sample losses equal (w[2] * log(3)).
+ assert np.allclose(out, out[0])
+ assert abs(out[0] - 1.0986) < 1e-3
+
+
+def test_cross_entropy_ignore_index():
+ """H-53: ignore_index excludes the ignored sample from the loss."""
+ bm.random.seed(3)
+ logits = bm.random.randn(4, 3)
+ tgt_all = jnp.array([0, 1, 2, 2])
+ tgt_ign = jnp.array([0, 1, -100, 2])
+ loss_all = float(C.cross_entropy_loss(logits, tgt_all))
+ loss_ign = float(C.cross_entropy_loss(logits, tgt_ign, ignore_index=-100))
+ # Excluding a sample generally changes the mean.
+ assert not np.isclose(loss_all, loss_ign)
+ # Equivalent to averaging only the 3 kept samples.
+ per = _np(C.cross_entropy_loss(logits, tgt_all, reduction='none'))
+ manual = (per[0] + per[1] + per[3]) / 3.0
+ assert abs(loss_ign - manual) < 1e-4
+
+
+def test_cross_entropy_label_smoothing_changes_loss():
+ """H-53: label_smoothing > 0 changes the loss value."""
+ bm.random.seed(4)
+ logits = bm.random.randn(4, 3)
+ tgt = jnp.array([0, 1, 2, 0])
+ base = float(C.cross_entropy_loss(logits, tgt))
+ smoothed = float(C.cross_entropy_loss(logits, tgt, label_smoothing=0.1))
+ assert not np.isclose(base, smoothed)
+
+
+def test_cross_entropy_class_and_loss_wrappers():
+ """CrossEntropyLoss class wraps cross_entropy_loss; reductions + soft targets."""
+ bm.random.seed(5)
+ logits = bm.random.randn(4, 3)
+ tgt = jnp.array([0, 1, 2, 0])
+ assert float(C.CrossEntropyLoss()(logits, tgt)) > 0
+ assert float(C.CrossEntropyLoss(ignore_index=-100)(logits, jnp.array([0, 1, -100, 2]))) > 0
+ assert float(C.CrossEntropyLoss(label_smoothing=0.1)(logits, tgt)) > 0
+ assert float(C.cross_entropy_loss(logits, tgt, reduction='sum')) > 0
+ assert _np(C.cross_entropy_loss(logits, tgt, reduction='none')).shape == (4,)
+ # soft (probability) targets + class weight path.
+ soft = bm.one_hot(tgt, 3)
+ assert float(C.cross_entropy_loss(logits, soft)) > 0
+ assert float(C.cross_entropy_loss(logits, soft, weight=[1., 2., 3.])) > 0
+ assert float(C.cross_entropy_loss(logits, soft, label_smoothing=0.1)) > 0
+
+
+def test_cross_entropy_sparse_and_sigmoid():
+ bm.random.seed(6)
+ logits = bm.random.randn(4, 3)
+ assert _np(C.cross_entropy_sparse(logits, jnp.array([[0], [1], [2], [0]]))).shape == (4,)
+ assert float(jnp.sum(C.cross_entropy_sparse(logits, 1))) != 0.0 # single int target
+ assert _np(C.cross_entropy_sigmoid(logits, bm.random.rand(4, 3))).shape == (4, 3)
+
+
+def test_regression_losses():
+ bm.random.seed(7)
+ pred = bm.random.randn(4, 3)
+ tar = bm.random.randn(4, 3)
+ assert float(C.mean_squared_error(pred, tar)) >= 0
+ assert float(C.mean_squared_error(pred, tar, reduction='sum')) >= 0
+ assert _np(C.mean_squared_error(pred, tar, axis=1, reduction='none')).shape[0] == 4
+ assert float(C.mean_absolute_error(pred, tar)) >= 0
+ assert float(C.mean_squared_log_error(bm.abs(pred), bm.abs(tar))) >= 0
+ assert float(C.l1_loss(pred, tar)) >= 0
+ assert _np(C.l1_loss(pred, tar, reduction='none')).shape == (4,)
+ assert float(jnp.sum(C.l2_loss(pred, tar))) >= 0
+ assert float(jnp.sum(C.huber_loss(pred, tar, delta=0.5))) >= 0
+ assert float(jnp.sum(C.log_cosh_loss(pred, tar))) >= 0
+ # Loss classes
+ assert float(C.MSELoss()(pred, tar)) >= 0
+ assert float(C.MSELoss(reduction='sum')(pred, tar)) >= 0
+ assert float(C.L1Loss()(pred, tar)) >= 0
+ assert float(C.MAELoss(axis=None)(pred, tar)) >= 0
+
+
+def test_classification_helper_losses():
+ bm.random.seed(8)
+ assert float(jnp.sum(C.sigmoid_binary_cross_entropy(bm.random.randn(4, 3), bm.random.rand(4, 3)))) != 0
+ labels = bm.one_hot(jnp.array([0, 1, 2, 0]), 3)
+ assert float(jnp.sum(C.softmax_cross_entropy(bm.random.randn(4, 3), labels))) >= 0
+ assert float(jnp.sum(C.binary_logistic_loss(jnp.array([0.5, 1.2]), jnp.array([0, 1])))) != 0
+ # multiclass_logistic_loss: single int label + (n_classes,) logits.
+ assert float(C.multiclass_logistic_loss(1, bm.random.randn(3))) >= 0
+
+
+def test_multi_margin_loss():
+ bm.random.seed(9)
+ logits = bm.random.randn(4, 3)
+ tgt = jnp.array([0, 1, 2, 0])
+ assert float(C.multi_margin_loss(logits, tgt)) >= 0
+ assert float(C.multi_margin_loss(logits, tgt, p=2)) >= 0
+ assert float(C.multi_margin_loss(logits, tgt, reduction='sum')) >= 0
+ assert _np(C.multi_margin_loss(logits, tgt, reduction='none')).shape == (4, 3)
+ with pytest.raises(AssertionError):
+ C.multi_margin_loss(logits, tgt, p=3)
+
+
+def test_ctc_loss_known_value_bug():
+ """REMAINING BUG: ctc_loss / ctc_loss_with_forward_probs call ``.value`` on a
+ plain JAX array returned by ``bm.log_softmax`` (comparison.py:1006), which
+ raises AttributeError. Pinned here so the regression is documented + covered."""
+ B, T, K, N = 2, 5, 4, 3
+ logits = bm.random.randn(B, T, K)
+ lp = bm.zeros((B, T))
+ labels = jnp.array([[1, 2, 1], [2, 1, 2]])
+ lbp = bm.zeros((B, N))
+ with pytest.raises(AttributeError):
+ C.ctc_loss(logits, lp, labels, lbp)
+ with pytest.raises(AttributeError):
+ C.ctc_loss_with_forward_probs(logits, lp, labels, lbp)
+
+
+# ===========================================================================
+# 3. optimizer.py
+# ===========================================================================
+
+def test_adam_constant_step_under_unit_gradient():
+ """C-01 / H-52: under a constant unit gradient, Adam's per-step |dw| ~= lr and
+ does NOT grow over time (bias correction uses a real step counter)."""
+ lr = 0.01
+ w = bm.Variable(bm.zeros((3,)))
+ opt = bp.optim.Adam(lr=lr, train_vars={'w': w})
+ prev = _np(w.value).copy()
+ deltas = []
+ for _ in range(5):
+ opt.update({'w': bm.ones((3,))})
+ cur = _np(w.value)
+ deltas.append(abs(float(cur[0] - prev[0])))
+ prev = cur.copy()
+ assert np.allclose(deltas, lr, atol=1e-4), f"Adam steps drifted: {deltas}"
+
+
+def test_adamw_constant_step_under_unit_gradient():
+ """C-01 / H-52: same constant-step property for AdamW (weight_decay=0)."""
+ lr = 0.01
+ w = bm.Variable(bm.zeros((3,)))
+ opt = bp.optim.AdamW(lr=lr, train_vars={'w': w}, weight_decay=0.0)
+ prev = _np(w.value).copy()
+ deltas = []
+ for _ in range(5):
+ opt.update({'w': bm.ones((3,))})
+ cur = _np(w.value)
+ deltas.append(abs(float(cur[0] - prev[0])))
+ prev = cur.copy()
+ assert np.allclose(deltas, lr, atol=1e-4), f"AdamW steps drifted: {deltas}"
+
+
+def test_sm3_instantiates_and_updates():
+ """M-29: SM3 instantiates with train_vars and a forward update changes w."""
+ w = bm.Variable(bm.ones((3, 4)))
+ opt = SM3(lr=0.01, train_vars={'w': w})
+ before = _np(w.value).copy()
+ opt.update({'w': bm.ones((3, 4))})
+ assert not np.allclose(before, _np(w.value))
+
+
+def test_sm3_with_momentum_and_weight_decay():
+ """SM3 with momentum>0 (allocates a buffer) and weight_decay both run + change w."""
+ w = bm.Variable(bm.ones((3, 4)))
+ opt = SM3(lr=0.01, train_vars={'w': w}, momentum=0.5, beta=0.5, weight_decay=0.01)
+ before = _np(w.value).copy()
+ opt.update({'w': bm.ones((3, 4))})
+ assert not np.allclose(before, _np(w.value))
+ repr(opt)
+
+
+def test_sm3_invalid_hyperparams():
+ with pytest.raises(ValueError):
+ SM3(lr=0.01, momentum=1.5)
+ with pytest.raises(ValueError):
+ SM3(lr=0.01, beta=1.5)
+ with pytest.raises(ValueError):
+ SM3(lr=0.01, eps=-1.0)
+
+
+@pytest.mark.parametrize("name", ['SGD', 'Momentum', 'MomentumNesterov', 'Adagrad',
+ 'Adadelta', 'RMSProp', 'Adam', 'AdamW', 'LARS'])
+def test_optimizer_construct_and_update(name):
+ """Coverage: every (working) optimizer constructs and runs an update."""
+ cls = getattr(O, name)
+ w = _train_var()
+ opt = cls(lr=0.01, train_vars={'w': w})
+ before = _np(w.value).copy()
+ opt.update({'w': bm.ones((2, 3)) * 0.1})
+ assert _np(w.value).shape == (2, 3)
+ repr(opt)
+ # also run a second update step (exercises EMA / cache update paths)
+ opt.update({'w': bm.ones((2, 3)) * 0.2})
+ assert not np.allclose(before, _np(w.value))
+
+
+@pytest.mark.parametrize("name", ['SGD', 'Momentum', 'MomentumNesterov', 'Adagrad',
+ 'Adadelta', 'RMSProp', 'Adam', 'LARS'])
+def test_optimizer_weight_decay_path(name):
+ """Coverage: the weight_decay branch of each optimizer's update."""
+ cls = getattr(O, name)
+ w = _train_var()
+ opt = cls(lr=0.01, train_vars={'w': w}, weight_decay=0.01)
+ opt.update({'w': bm.ones((2, 3)) * 0.1})
+ assert _np(w.value).shape == (2, 3)
+
+
+def test_adamw_amsgrad_variant():
+ """AdamW amsgrad branch constructs and updates."""
+ w = _train_var()
+ opt = bp.optim.AdamW(lr=0.01, train_vars={'w': w}, amsgrad=True, weight_decay=0.01)
+ opt.update({'w': bm.ones((2, 3)) * 0.1})
+ assert _np(w.value).shape == (2, 3)
+
+
+def test_adamw_invalid_hyperparams():
+ with pytest.raises(ValueError):
+ bp.optim.AdamW(lr=0.01, eps=-1.0)
+ with pytest.raises(ValueError):
+ bp.optim.AdamW(lr=0.01, beta1=1.5)
+ with pytest.raises(ValueError):
+ bp.optim.AdamW(lr=0.01, beta2=1.5)
+ with pytest.raises(ValueError):
+ bp.optim.AdamW(lr=0.01, weight_decay=-1.0)
+
+
+def test_adan_constructs_update_is_known_bug():
+ """Adan constructs, but REMAINING BUG: ``Adan.update`` passes a tuple operand to
+ ``lax.cond`` while the branch lambdas expect two args (optimizer.py:818),
+ raising TypeError. Pinned so the construct path is covered + the bug documented."""
+ w = _train_var()
+ adan = bp.optim.Adan(lr=0.01, train_vars={'w': w})
+ repr(adan)
+ with pytest.raises(TypeError):
+ adan.update({'w': bm.ones((2, 3)) * 0.1})
+
+
+def test_adan_invalid_betas():
+ with pytest.raises(AssertionError):
+ bp.optim.Adan(lr=0.01, betas=(0.1, 0.2)) # len != 3
+ with pytest.raises(ValueError):
+ bp.optim.Adan(lr=0.01, eps=-1.0)
+ with pytest.raises(ValueError):
+ bp.optim.Adan(lr=0.01, betas=(1.5, 0.1, 0.1))
+
+
+def test_optimizer_check_grads_mismatch():
+ """Optimizer.check_grads raises on a length mismatch."""
+ w = _train_var()
+ opt = bp.optim.SGD(lr=0.01, train_vars={'w': w})
+ with pytest.raises(Exception):
+ opt.update({'w': bm.ones((2, 3)), 'extra': bm.ones((2, 3))})
+
+
+def test_optimizer_register_train_vars_type_error():
+ with pytest.raises(Exception):
+ bp.optim.SGD(lr=0.01, train_vars=[bm.Variable(bm.ones((2, 3)))]) # must be dict
+
+
+# ===========================================================================
+# 4. scheduler.py
+# ===========================================================================
+
+def test_multisteplr_decays():
+ """C-04: MultiStepLR actually decays at the milestones."""
+ sch = bp.optim.MultiStepLR(0.1, [10, 20], gamma=0.1)
+ vals = [round(float(sch(i)), 6) for i in [0, 10, 20, 25]]
+ assert vals == [0.1, 0.1, 0.01, 0.001], vals
+ repr(sch)
+
+
+def test_constant_scheduler():
+ assert float(S.Constant(0.1)()) == pytest.approx(0.1)
+ assert float(bp.optim.make_schedule(0.05)()) == pytest.approx(0.05)
+ assert isinstance(bp.optim.make_schedule(S.Constant(0.1)), S.Constant)
+ with pytest.raises(TypeError):
+ bp.optim.make_schedule("not-a-schedule")
+
+
+def test_steplr():
+ sch = S.StepLR(0.1, step_size=10, gamma=0.1)
+ vals = [round(float(sch(i)), 6) for i in [0, 5, 10, 20]]
+ assert vals == [0.1, 0.1, 0.01, 0.001], vals
+ repr(sch)
+
+
+def test_exponential_lr():
+ sch = S.ExponentialLR(0.1, gamma=0.9)
+ assert float(sch(2)) == pytest.approx(0.1 * 0.9 ** 2, abs=1e-6)
+ repr(sch)
+
+
+def test_cosine_annealing_lr():
+ sch = S.CosineAnnealingLR(0.1, T_max=10, eta_min=0.0)
+ assert float(sch(0)) == pytest.approx(0.1, abs=1e-5)
+ assert float(sch(10)) == pytest.approx(0.0, abs=1e-5)
+ assert 0.0 <= float(sch(5)) <= 0.1
+
+
+def test_cosine_warm_restarts():
+ sch = S.CosineAnnealingWarmRestarts(0.1, num_call_per_epoch=2, T_0=4)
+ assert 0.0 <= float(sch(3)) <= 0.1
+ assert 0.0 <= float(sch(10)) <= 0.1 # epoch >= T_0 -> _cond1 (T_mult==1 path)
+ assert float(sch.current_epoch(4)) >= 0
+ with pytest.raises(ValueError):
+ S.CosineAnnealingWarmRestarts(0.1, num_call_per_epoch=2, T_0=0)
+ with pytest.raises(ValueError):
+ S.CosineAnnealingWarmRestarts(0.1, num_call_per_epoch=2, T_0=4, T_mult=0)
+
+
+def test_cosine_warm_restarts_tmult_gt1_known_bug():
+ """REMAINING BUG: with ``T_mult > 1``, ``CosineAnnealingWarmRestarts`` calls
+ ``lax.cond`` whose two branches (``_cond1`` vs ``_cond2``) return mismatched
+ output types (float tuple vs int ``T_0``), raising TypeError under jit
+ (scheduler.py:292-313). Pinned so the construction path stays covered."""
+ sch = S.CosineAnnealingWarmRestarts(0.1, num_call_per_epoch=2, T_0=2, T_mult=2)
+ with pytest.raises(Exception):
+ float(sch(10))
+
+
+def test_exponential_decay_lr():
+ sch = S.ExponentialDecayLR(0.1, decay_steps=10, decay_rate=0.9)
+ assert float(sch(5)) == pytest.approx(0.1 * 0.9 ** 0.5, abs=1e-5)
+ # call-based step advances last_call
+ v0 = float(sch())
+ sch.step_call()
+ v1 = float(sch())
+ assert v1 < v0
+ repr(sch)
+ with pytest.warns(Warning):
+ S.ExponentialDecay(0.1, 10, 0.9)
+
+
+def test_inverse_time_decay_lr():
+ sch = S.InverseTimeDecayLR(0.1, decay_steps=10, decay_rate=0.9)
+ assert float(sch(5)) == pytest.approx(0.1 / (1 + 0.9 * 5 / 10), abs=1e-5)
+ stair = S.InverseTimeDecayLR(0.1, decay_steps=10, decay_rate=0.9, staircase=True)
+ assert float(stair(5)) == pytest.approx(0.1, abs=1e-5)
+ repr(sch)
+ with pytest.warns(Warning):
+ S.InverseTimeDecay(0.1, 10, 0.9)
+
+
+def test_polynomial_decay_lr():
+ sch = S.PolynomialDecayLR(0.1, decay_steps=10, final_lr=0.01)
+ assert float(sch(0)) == pytest.approx(0.1, abs=1e-5)
+ assert float(sch(10)) == pytest.approx(0.01, abs=1e-5)
+ assert float(sch(100)) == pytest.approx(0.01, abs=1e-5) # clamped to decay_steps
+ repr(sch)
+ with pytest.warns(Warning):
+ S.PolynomialDecay(0.1, 10, 0.01)
+
+
+def test_piecewise_constant_lr():
+ sch = S.PiecewiseConstantLR([10, 20], [0.1, 0.01, 0.001])
+ assert float(sch(5)) == pytest.approx(0.1, abs=1e-6)
+ assert float(sch(15)) == pytest.approx(0.01, abs=1e-6)
+ assert float(sch(25)) == pytest.approx(0.001, abs=1e-6)
+ with pytest.warns(Warning):
+ S.PiecewiseConstant([10, 20], [0.1, 0.01, 0.001])
+ from brainpy._errors import MathError
+ with pytest.raises(MathError):
+ S.PiecewiseConstantLR([10, 20], [0.1, 0.01]) # bad lengths
+
+
+def test_scheduler_step_epoch_and_set_value():
+ sch = S.StepLR(0.1, step_size=5)
+ sch.step_epoch()
+ assert int(sch.last_epoch.value) == 0
+ sch.set_value(0.5)
+ assert float(sch.lr) == pytest.approx(0.5)
+
+
+# ===========================================================================
+# 5. stateless_encoding.py + stateful encoders
+# ===========================================================================
+
+def test_poisson_single_step_returns_spikes():
+ """single_step must return a spike array of the same shape (no TypeError)."""
+ out = bp.encoding.PoissonEncoder().single_step(bm.random.rand(4))
+ arr = _np(out)
+ assert arr.shape == (4,)
+ assert set(np.unique(arr)).issubset({0.0, 1.0})
+
+
+def test_poisson_single_step_with_first_spike_time():
+ enc = bp.encoding.PoissonEncoder(first_spk_time=2.0)
+ before = _np(enc.single_step(bm.ones(4), i_step=0))
+ assert np.allclose(before, 0.0) # no spikes before first-spike step
+ after = _np(enc.single_step(bm.ones(4), i_step=100))
+ assert np.allclose(after, 1.0) # prob==1 -> always fires after
+
+
+def test_poisson_normalize_and_multi_steps():
+ enc = bp.encoding.PoissonEncoder(min_val=0.0, max_val=2.0, gain=1.0, offset=0.0)
+ assert _np(enc.single_step(bm.ones(4))).shape == (4,)
+ spikes = enc.multi_steps(bm.random.rand(3), n_time=5.0)
+ assert _np(spikes).shape[1:] == (3,)
+ # n_time=None -> single current step
+ assert _np(enc.multi_steps(bm.random.rand(3), n_time=None)).shape == (3,)
+ # first_spk_step > 0 multi-step branch
+ enc2 = bp.encoding.PoissonEncoder(first_spk_time=1.0)
+ assert _np(enc2.multi_steps(bm.random.rand(3), n_time=5.0)).shape[1:] == (3,)
+
+
+def test_diff_encoder():
+ enc = bp.encoding.DiffEncoder(threshold=1.0)
+ assert _np(enc.multi_steps(bm.array([1., 2., 2.9, 3., 3.9]))).shape == (5,)
+ enc2 = bp.encoding.DiffEncoder(threshold=1.0, padding=True, off_spike=True)
+ assert _np(enc2.multi_steps(bm.array([1., 2., 0., 2., 2.9]))).shape == (5,)
+ with pytest.raises(NotImplementedError):
+ enc.single_step(bm.array([1.0]))
+
+
+def test_latency_encoder():
+ enc = bp.encoding.LatencyEncoder(method='linear', normalize=True)
+ out = enc.multi_steps(bm.array([0.02, 0.5, 1.0]), n_time=5.0)
+ assert _np(out).shape == (50, 3)
+ enc_log = bp.encoding.LatencyEncoder(method='log', clip=True, normalize=True,
+ min_val=0.0, max_val=1.0)
+ assert _np(enc_log.multi_steps(bm.array([0.02, 0.5, 1.0]), n_time=5.0)).shape == (50, 3)
+ with pytest.raises(NotImplementedError):
+ enc.single_step(bm.array([0.5]))
+ with pytest.raises(ValueError):
+ bp.encoding.LatencyEncoder(method='bogus')
+
+
+def test_weighted_phase_encoder():
+ enc = bp.encoding.WeightedPhaseEncoder(min_val=0.0, max_val=1.0, num_phase=4)
+ out = enc(bm.array([0.3, 0.7]), num_step=4)
+ assert _np(out).shape == (4, 2)
+
+
+# ===========================================================================
+# 6. random_conn.py -- FixedProb / FixedPreNum / FixedPostNum / FixedTotalNum
+# ===========================================================================
+
+def test_fixedprob_nonzero_nnz():
+ """M-30: small post population must not silently produce 0 connections."""
+ pre, post = bp.connect.FixedProb(prob=0.3, allow_multi_conn=True)(
+ pre_size=100, post_size=3).build_coo()
+ assert len(_np(pre)) > 0
+
+
+def test_fixedprob_rectangular_include_self_false_no_raise():
+ """M-30: rectangular shape with include_self=False no longer raises a
+ (contradictory) ConnectorError."""
+ conn = bp.connect.FixedProb(prob=0.3, allow_multi_conn=True, include_self=False)(
+ pre_size=100, post_size=3)
+ pre, post = conn.build_coo()
+ assert len(_np(pre)) > 0
+ # build_csr / build_mat also work for the rectangular include_self=False case.
+ conn.build_csr()
+ assert conn.build_mat().shape == (100, 3)
+
+
+def test_fixedprob_all_build_methods():
+ conn = bp.connect.FixedProb(prob=0.4, allow_multi_conn=True)(pre_size=20, post_size=10)
+ pre, post = conn.build_coo()
+ assert len(_np(pre)) > 0
+ idx, indptr = conn.build_csr()
+ assert _np(indptr).shape[0] == 21 # pre_num + 1
+ mat = conn.build_mat()
+ assert mat.shape == (20, 10)
+ repr(conn)
+
+
+def test_fixedprob_pre_ratio_and_include_self():
+ conn = bp.connect.FixedProb(prob=0.5, pre_ratio=0.5, allow_multi_conn=True,
+ include_self=False)(pre_size=20, post_size=20)
+ pre, post = conn.build_coo()
+ assert len(_np(pre)) >= 0
+ conn.build_csr()
+ assert conn.build_mat().shape == (20, 20)
+
+
+def test_fixedprob_invalid_args():
+ with pytest.raises(AssertionError):
+ bp.connect.FixedProb(prob=1.5)
+ with pytest.raises(AssertionError):
+ bp.connect.FixedProb(prob=0.3, pre_ratio=2.0)
+
+
+def test_fixed_pre_num():
+ conn = bp.connect.FixedPreNum(num=3, allow_multi_conn=True)(pre_size=10, post_size=8)
+ pre, post = conn.build_coo()
+ assert len(_np(pre)) > 0
+ conn2 = bp.connect.FixedPreNum(num=0.5, allow_multi_conn=True, include_self=False)(
+ pre_size=10, post_size=10)
+ assert len(_np(conn2.build_coo()[0])) >= 0
+ with pytest.raises(Exception):
+ bp.connect.FixedPreNum(num=100, allow_multi_conn=True)(
+ pre_size=10, post_size=8).build_coo() # num > pre_num
+
+
+def test_fixed_post_num():
+ conn = bp.connect.FixedPostNum(num=3, allow_multi_conn=True)(pre_size=10, post_size=8)
+ pre, post = conn.build_coo()
+ assert len(_np(pre)) > 0
+ idx, indptr = conn.build_csr()
+ assert _np(indptr).shape[0] == 11 # pre_num + 1
+ conn2 = bp.connect.FixedPostNum(num=0.5, allow_multi_conn=True, include_self=False)(
+ pre_size=10, post_size=10)
+ conn2.build_coo()
+ conn2.build_csr()
+
+
+def test_fixed_total_num():
+ conn = bp.connect.FixedTotalNum(num=12, allow_multi_conn=True)(pre_size=10, post_size=8)
+ pre, post = conn.build_coo()
+ assert len(_np(pre)) == 12
+ # no-multi-conn (choice without replacement) path
+ conn2 = bp.connect.FixedTotalNum(num=12, allow_multi_conn=False)(pre_size=10, post_size=8)
+ assert len(_np(conn2.build_coo()[0])) == 12
+ repr(conn)
+ with pytest.raises(Exception):
+ bp.connect.FixedTotalNum(num=1000, allow_multi_conn=True)(
+ pre_size=10, post_size=8).build_coo() # num > all2all
+
+
+def test_connectors_no_multi_conn_paths():
+ """Coverage: the ``allow_multi_conn=False`` (numba choice-without-replacement)
+ build paths of FixedProb / FixedPreNum / FixedPostNum."""
+ fp = bp.connect.FixedProb(prob=0.4, allow_multi_conn=False)(pre_size=20, post_size=10)
+ assert len(_np(fp.build_coo()[0])) > 0
+ fp.build_csr()
+ assert fp.build_mat().shape == (20, 10)
+ # include_self=False square case
+ fp2 = bp.connect.FixedProb(prob=0.5, allow_multi_conn=False, include_self=False)(
+ pre_size=15, post_size=15)
+ fp2.build_coo()
+ fp2.build_csr()
+ assert fp2.build_mat().shape == (15, 15)
+
+ pre_conn = bp.connect.FixedPreNum(num=3, allow_multi_conn=False)(pre_size=12, post_size=8)
+ assert len(_np(pre_conn.build_coo()[0])) > 0
+ pre_conn2 = bp.connect.FixedPreNum(num=3, allow_multi_conn=False, include_self=False)(
+ pre_size=12, post_size=12)
+ pre_conn2.build_coo()
+
+ post_conn = bp.connect.FixedPostNum(num=3, allow_multi_conn=False)(pre_size=12, post_size=8)
+ assert len(_np(post_conn.build_coo()[0])) > 0
+ post_conn.build_csr()
+ post_conn2 = bp.connect.FixedPostNum(num=3, allow_multi_conn=False, include_self=False)(
+ pre_size=12, post_size=12)
+ post_conn2.build_coo()
+ post_conn2.build_csr()
+
+
+def test_connectors_validation_branches():
+ """Coverage: float / bad-type ``num`` validation branches of the connectors."""
+ # float num accepted at construction
+ assert bp.connect.FixedTotalNum(num=0.5).num == 0.5
+ assert bp.connect.FixedPreNum(num=0.3).num == 0.3
+ assert bp.connect.FixedPostNum(num=0.3).num == 0.3
+ # bad type rejected
+ from brainpy._errors import ConnectorError
+ with pytest.raises(ConnectorError):
+ bp.connect.FixedTotalNum(num='x')
+ with pytest.raises(ConnectorError):
+ bp.connect.FixedPreNum(num='x')
+ # FixedPreNum with float num builds (probability interpretation)
+ fp = bp.connect.FixedPreNum(num=0.3, allow_multi_conn=True)(pre_size=10, post_size=8)
+ assert len(_np(fp.build_coo()[0])) > 0
+ # negative integer num rejected
+ with pytest.raises(AssertionError):
+ bp.connect.FixedTotalNum(num=-1)
+ # FixedPostNum num > post_num
+ with pytest.raises(ConnectorError):
+ bp.connect.FixedPostNum(num=100, allow_multi_conn=True)(
+ pre_size=10, post_size=8).build_coo()
+
+
+def test_fixed_pre_post_num_include_self_rectangular_raises():
+ """FixedPreNum / FixedPostNum still reject include_self=False for rectangular
+ (pre_num != post_num) shapes (this guard is intentional for these connectors)."""
+ with pytest.raises(Exception):
+ bp.connect.FixedPreNum(num=3, allow_multi_conn=True, include_self=False)(
+ pre_size=10, post_size=8).build_coo()
+ with pytest.raises(Exception):
+ bp.connect.FixedPostNum(num=3, allow_multi_conn=True, include_self=False)(
+ pre_size=10, post_size=8).build_coo()
+
+
+if __name__ == '__main__':
+ import sys
+ sys.exit(pytest.main([__file__, '-q']))
diff --git a/tests/audit/test_dyn_channels_fixes.py b/tests/audit/test_dyn_channels_fixes.py
new file mode 100644
index 000000000..69c263bd5
--- /dev/null
+++ b/tests/audit/test_dyn_channels_fixes.py
@@ -0,0 +1,363 @@
+"""Regression + coverage tests for the BrainPy v2.7.8 audit (2026-06-18).
+
+This module pins the channel/ion fixes documented in ``docs/issues-found-20260618.md``:
+
+* **C-14** -- Standalone HH/Markov channel gating produced NaN at the voltage
+ singularities (``channels/sodium.py``, ``potassium.py``, ``calcium.py`` and the
+ ``*_compatible`` legacy duplicates). The rate functions of the form
+ ``k * temp / (1 - exp(-temp / d))`` are ``0/0`` exactly at the removable
+ singularity. After the fix they are rewritten with a branch-safe ``exprel``
+ helper so that the value *and* its gradient are finite there.
+ Audit repro: ``IK_HH1952v2(1).f_p_alpha([-55.0])`` returns ~0.1, not nan.
+* **M-17** -- ``PotassiumFixed`` default ``E`` was ``-950`` mV (typo); fixed to
+ ``-95`` mV.
+* **H-33** -- ``dyn/ions/base.py`` registered every channel under the literal name
+ ``"k"`` (``self.add_elem(k=v)``), so channels overwrote each other. The fix
+ (``self.add_elem(**{k: v})``) registers each channel under its real name.
+
+The remaining tests build every public channel class against a Hodgkin-Huxley
+style host and exercise ``reset_state`` / ``update`` for coverage.
+"""
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+import brainpy as bp
+import brainpy.math as bm
+
+
+# Voltage sweep that intentionally includes the removable-singularity voltages.
+_V_SWEEP = jnp.asarray(np.linspace(-120.0, 60.0, 91), dtype=bm.float_)
+
+
+def _assert_finite_with_grad(make_channel, fname, singular_v):
+ """A rate function and its gradient must be finite across a sweep incl. the singular V."""
+ chan = make_channel()
+ f = getattr(chan, fname)
+
+ # Value finite across the whole sweep (which contains the singular voltage).
+ vals = np.asarray(f(_V_SWEEP))
+ assert np.all(np.isfinite(vals)), f"{fname}: non-finite value over sweep: {vals}"
+
+ # Value finite exactly at the singular voltage.
+ sval = np.asarray(f(jnp.asarray([singular_v], dtype=bm.float_)))
+ assert np.all(np.isfinite(sval)), f"{fname}: non-finite value at singular V={singular_v}: {sval}"
+
+ # Gradient finite at the singular voltage (the part bm.where clamping cannot fix).
+ def scalar(v):
+ return getattr(make_channel(), fname)(jnp.asarray([v], dtype=bm.float_))[0]
+
+ g = float(jax.grad(scalar)(float(singular_v)))
+ assert np.isfinite(g), f"{fname}: non-finite gradient at singular V={singular_v}: {g}"
+
+
+# ---------------------------------------------------------------------------
+# C-14 regression: finite value + finite gradient at the singular voltages.
+# Each tuple is (class, rate-function name, singular voltage).
+# The singular voltage is where the argument of the exprel helper is zero.
+# ---------------------------------------------------------------------------
+
+# Legacy/compatible exported classes (master_type = HHTypedNeuron).
+_LEGACY_SINGULAR_CASES = [
+ # IKDR_Ba2002 V_sh=-50: alpha singular at V - V_sh - 15 = 0 -> V = -35
+ ("IKDR_Ba2002", "f_p_alpha", -35.0),
+ # IK_TM1991 V_sh=-60: alpha singular at 15 - V + V_sh = 0 -> V = -45
+ ("IK_TM1991", "f_p_alpha", -45.0),
+ # IK_HH1952 V_sh=-45: alpha singular at V - V_sh + 10 = 0 -> V = -55
+ ("IK_HH1952", "f_p_alpha", -55.0),
+ # INa_Ba2002 V_sh=-50: p_alpha V-V_sh-13=0 -> -37 ; p_beta V-V_sh-40=0 -> -10
+ ("INa_Ba2002", "f_p_alpha", -37.0),
+ ("INa_Ba2002", "f_p_beta", -10.0),
+ # INa_TM1991 V_sh=-63: p_alpha 13-V+V_sh=0 -> -50 ; p_beta V-V_sh-40=0 -> -23
+ ("INa_TM1991", "f_p_alpha", -50.0),
+ ("INa_TM1991", "f_p_beta", -23.0),
+ # INa_HH1952 V_sh=-45: p_alpha V-V_sh-5=0 -> -40
+ ("INa_HH1952", "f_p_alpha", -40.0),
+]
+
+# v2 classes (master_type = an Ion subtype).
+_V2_SINGULAR_CASES = [
+ ("IKDR_Ba2002v2", "f_p_alpha", -35.0),
+ ("IK_TM1991v2", "f_p_alpha", -45.0),
+ ("IK_HH1952v2", "f_p_alpha", -55.0),
+ ("INa_Ba2002v2", "f_p_alpha", -37.0),
+ ("INa_Ba2002v2", "f_p_beta", -10.0),
+ ("INa_TM1991v2", "f_p_alpha", -50.0),
+ ("INa_TM1991v2", "f_p_beta", -23.0),
+ ("INa_HH1952v2", "f_p_alpha", -40.0),
+ # ICaHT_Re1993 (markov v2, calcium) V_sh=0: p_alpha -27-V+V_sh=0 -> V=-27
+ ("ICaHT_Re1993", "f_p_alpha", -27.0),
+]
+
+
+@pytest.mark.parametrize("clsname,fname,singular_v", _LEGACY_SINGULAR_CASES)
+def test_c14_legacy_rate_finite_and_grad(clsname, fname, singular_v):
+ cls = getattr(bp.dyn, clsname)
+ _assert_finite_with_grad(lambda: cls(1), fname, singular_v)
+
+
+@pytest.mark.parametrize("clsname,fname,singular_v", _V2_SINGULAR_CASES)
+def test_c14_v2_rate_finite_and_grad(clsname, fname, singular_v):
+ cls = getattr(bp.dyn, clsname)
+ _assert_finite_with_grad(lambda: cls(1), fname, singular_v)
+
+
+def test_c14_ik_hh1952v2_audit_repro():
+ """Exact audit repro: was [nan], must now be ~0.1."""
+ val = np.asarray(bp.dyn.IK_HH1952v2(1).f_p_alpha(jnp.asarray([-55.0], dtype=bm.float_)))
+ assert np.all(np.isfinite(val))
+ assert np.allclose(val, 0.1, atol=1e-4)
+
+
+def test_c14_all_defined_rate_functions_finite_over_sweep():
+ """Sweep every rate function each markov/ss channel defines; none may produce NaN."""
+ classes = [
+ bp.dyn.IKDR_Ba2002, bp.dyn.IK_TM1991, bp.dyn.IK_HH1952,
+ bp.dyn.INa_Ba2002, bp.dyn.INa_TM1991, bp.dyn.INa_HH1952,
+ bp.dyn.IKDR_Ba2002v2, bp.dyn.IK_TM1991v2, bp.dyn.IK_HH1952v2,
+ bp.dyn.INa_Ba2002v2, bp.dyn.INa_TM1991v2, bp.dyn.INa_HH1952v2,
+ bp.dyn.ICaHT_Re1993,
+ ]
+ for cls in classes:
+ chan = cls(1)
+ for fname in ("f_p_alpha", "f_p_beta", "f_q_alpha", "f_q_beta"):
+ f = getattr(chan, fname, None)
+ if f is None:
+ continue
+ try:
+ vals = np.asarray(f(_V_SWEEP))
+ except NotImplementedError:
+ continue
+ assert np.all(np.isfinite(vals)), f"{cls.__name__}.{fname} produced NaN/inf"
+
+
+# ---------------------------------------------------------------------------
+# M-17 regression: PotassiumFixed default reversal potential.
+# ---------------------------------------------------------------------------
+
+def test_m17_potassium_fixed_default_E():
+ ion = bp.dyn.PotassiumFixed(1)
+ assert np.allclose(np.asarray(ion.E), -95.0), f"expected -95 mV, got {ion.E}"
+ # And not the historical buggy value.
+ assert not np.allclose(np.asarray(ion.E), -950.0)
+
+
+# ---------------------------------------------------------------------------
+# H-33 regression: MixIons / Ion.add_elem must register channels under their
+# real names instead of collapsing them all under the literal key "k".
+# ---------------------------------------------------------------------------
+
+def test_h33_ion_add_elem_keeps_distinct_names():
+ na = bp.dyn.SodiumFixed(2, E=50.)
+ ina_hh = bp.dyn.INa_HH1952v2(2)
+ ina_ba = bp.dyn.INa_Ba2002v2(2)
+ na.add_elem(ina_hh=ina_hh, ina_ba=ina_ba)
+
+ # Both channels stored under their real names, none lost to a literal "k".
+ assert set(na.children.keys()) == {"ina_hh", "ina_ba"}
+ assert "k" not in na.children
+ assert na.children["ina_hh"] is ina_hh
+ assert na.children["ina_ba"] is ina_ba
+ assert ina_hh.name != ina_ba.name
+
+
+def test_h33_mixions_registers_distinct_multiion_channels():
+ """Build a MixIons over a Calcium+Potassium JointType channel (IAHP_De1994v2).
+
+ Two distinct channels must coexist under their real names; before the fix
+ the second would overwrite the first under the key "k".
+ """
+ bm.set_dt(0.1)
+
+ class Net(bp.dyn.CondNeuGroupLTC):
+ def __init__(self, size):
+ super().__init__(size)
+ self.Ca = bp.dyn.CalciumDetailed(size)
+ self.K = bp.dyn.PotassiumFixed(size, E=-95.)
+ self.KCa = bp.dyn.MixIons(self.Ca, self.K)
+ self.KCa.add_elem(iahp1=bp.dyn.IAHP_De1994v2(size),
+ iahp2=bp.dyn.IAHP_De1994v2(size))
+
+ net = Net(2)
+ net.reset_state()
+ assert set(net.KCa.children.keys()) == {"iahp1", "iahp2"}
+ assert "k" not in net.KCa.children
+
+ bp.share.save(t=0., dt=0.1, i=0)
+ net.update(bm.ones(2) * -50.)
+ cur = np.asarray(net.KCa.current(net.V.value))
+ assert np.all(np.isfinite(cur))
+
+
+def test_h33_mix_ions_helper_and_ion_current_paths():
+ bm.set_dt(0.1)
+ na = bp.dyn.SodiumFixed(2, E=50.)
+ na.add_elem(ina=bp.dyn.INa_HH1952v2(2))
+ k = bp.dyn.PotassiumFixed(2, E=-95.)
+ k.add_elem(ik=bp.dyn.IK_HH1952v2(2))
+
+ mix = bp.dyn.mix_ions(na, k)
+ assert isinstance(mix, bp.dyn.MixIons)
+
+ V = bm.ones(2) * -65.
+ na.reset_state(V)
+ k.reset_state(V)
+ # Exercise Ion.update / Ion.current / pack_info on ions/base.py + ions/potassium.py.
+ bp.share.save(t=0., dt=0.1, i=0)
+ na.update(V)
+ k.update(V)
+ assert np.all(np.isfinite(np.asarray(k.current(V))))
+ assert set(k.pack_info().keys()) == {"C", "E"}
+
+
+# ---------------------------------------------------------------------------
+# Coverage: construct & exercise every public channel against an HH-style host.
+# ---------------------------------------------------------------------------
+
+def _run_cond_neu_group(net, n_steps=2):
+ net.reset_state()
+ bp.share.save(t=0., dt=bm.get_dt(), i=0)
+ for i in range(n_steps):
+ bp.share.save(t=i * bm.get_dt(), i=i)
+ net.update(bm.ones(net.num) * 1.0)
+
+
+def test_coverage_compatible_potassium_and_sodium_channels():
+ """Legacy/compatible channels are hosted directly by a CondNeuGroup (HH neuron)."""
+ bm.set_dt(0.1)
+
+ class Net(bp.dyn.CondNeuGroup):
+ def __init__(self, size):
+ super().__init__(size)
+ # sodium_compatible.py
+ self.INa_Ba = bp.dyn.INa_Ba2002(size)
+ self.INa_TM = bp.dyn.INa_TM1991(size)
+ self.INa_HH = bp.dyn.INa_HH1952(size)
+ # potassium_compatible.py
+ self.IKDR = bp.dyn.IKDR_Ba2002(size)
+ self.IK_TM = bp.dyn.IK_TM1991(size)
+ self.IK_HH = bp.dyn.IK_HH1952(size)
+ self.IKA1 = bp.dyn.IKA1_HM1992(size)
+ self.IKA2 = bp.dyn.IKA2_HM1992(size)
+ self.IKK2A = bp.dyn.IKK2A_HM1992(size)
+ self.IKK2B = bp.dyn.IKK2B_HM1992(size)
+ self.IKNI = bp.dyn.IKNI_Ya1989(size)
+ self.IKL = bp.dyn.IKL(size)
+
+ _run_cond_neu_group(Net(2))
+
+
+def test_coverage_v2_sodium_and_potassium_channels():
+ """v2 channels (potassium.py / sodium.py) are hosted by Sodium/Potassium ions."""
+ bm.set_dt(0.1)
+
+ class Net(bp.dyn.CondNeuGroupLTC):
+ def __init__(self, size):
+ super().__init__(size)
+ self.Na = bp.dyn.SodiumFixed(size, E=50.)
+ self.Na.add_elem(
+ ina_ba=bp.dyn.INa_Ba2002v2(size),
+ ina_tm=bp.dyn.INa_TM1991v2(size),
+ ina_hh=bp.dyn.INa_HH1952v2(size),
+ )
+ self.K = bp.dyn.PotassiumFixed(size, E=-95.)
+ self.K.add_elem(
+ ikdr=bp.dyn.IKDR_Ba2002v2(size),
+ ik_tm=bp.dyn.IK_TM1991v2(size),
+ ik_hh=bp.dyn.IK_HH1952v2(size),
+ ika1=bp.dyn.IKA1_HM1992v2(size),
+ ika2=bp.dyn.IKA2_HM1992v2(size),
+ ikk2a=bp.dyn.IKK2A_HM1992v2(size),
+ ikk2b=bp.dyn.IKK2B_HM1992v2(size),
+ ikni=bp.dyn.IKNI_Ya1989v2(size),
+ ik_leak=bp.dyn.IK_Leak(size),
+ )
+
+ _run_cond_neu_group(Net(2))
+
+
+def test_coverage_potassium_module_legacy_classes():
+ """potassium.py also ships HHTypedNeuron-hosted legacy duplicates that are
+ shadowed at the ``bp.dyn`` level by potassium_compatible.py. Import them
+ directly from the module so the in-file C-14 fixes are exercised too.
+ """
+ import brainpy.dyn.channels.potassium as kmod
+ bm.set_dt(0.1)
+
+ class Net(bp.dyn.CondNeuGroup):
+ def __init__(self, size):
+ super().__init__(size)
+ self.IKDR = kmod.IKDR_Ba2002(size)
+ self.IK_TM = kmod.IK_TM1991(size)
+ self.IK_HH = kmod.IK_HH1952(size)
+ self.IKA1 = kmod.IKA1_HM1992(size)
+ self.IKA2 = kmod.IKA2_HM1992(size)
+ self.IKK2A = kmod.IKK2A_HM1992(size)
+ self.IKK2B = kmod.IKK2B_HM1992(size)
+ self.IKNI = kmod.IKNI_Ya1989(size)
+
+ _run_cond_neu_group(Net(2))
+
+ # The in-file legacy rate functions must also be NaN-free at their singular V.
+ _assert_finite_with_grad(lambda: kmod.IK_HH1952(1), "f_p_alpha", -55.0)
+ _assert_finite_with_grad(lambda: kmod.IKDR_Ba2002(1), "f_p_alpha", -35.0)
+ _assert_finite_with_grad(lambda: kmod.IK_TM1991(1), "f_p_alpha", -45.0)
+
+
+def test_coverage_calcium_fixed_channels():
+ """Voltage-gated calcium channels hosted by a (fixed) Calcium ion."""
+ bm.set_dt(0.1)
+
+ class Net(bp.dyn.CondNeuGroupLTC):
+ def __init__(self, size):
+ super().__init__(size)
+ self.Ca = bp.dyn.CalciumFixed(size)
+ self.Ca.add_elem(
+ icat_hm=bp.dyn.ICaT_HM1992(size),
+ icat_hp=bp.dyn.ICaT_HP1992(size),
+ icaht_hm=bp.dyn.ICaHT_HM1992(size),
+ icaht_re=bp.dyn.ICaHT_Re1993(size),
+ ical=bp.dyn.ICaL_IS2008(size),
+ )
+
+ _run_cond_neu_group(Net(2))
+
+
+def test_coverage_calcium_dyna_channel_ican():
+ """ICaN_IS2008 requires a CalciumDyna host; exercise it for coverage."""
+ bm.set_dt(0.1)
+
+ class Net(bp.dyn.CondNeuGroupLTC):
+ def __init__(self, size):
+ super().__init__(size)
+ self.Ca = bp.dyn.CalciumDetailed(size)
+ self.Ca.add_elem(ican=bp.dyn.ICaN_IS2008(size))
+
+ _run_cond_neu_group(Net(2))
+
+
+def test_coverage_ion_objects_reset_and_current():
+ """Construct PotassiumFixed / CalciumFixed / SodiumFixed and exercise the ion API."""
+ bm.set_dt(0.1)
+ V = bm.ones(2) * -65.
+
+ k = bp.dyn.PotassiumFixed(2, E=-95.)
+ k.add_elem(ik=bp.dyn.IK_HH1952v2(2))
+ na = bp.dyn.SodiumFixed(2, E=50.)
+ na.add_elem(ina=bp.dyn.INa_HH1952v2(2))
+ ca = bp.dyn.CalciumFixed(2)
+ ca.add_elem(ical=bp.dyn.ICaL_IS2008(2))
+ ca_dyn = bp.dyn.CalciumFirstOrder(2)
+
+ for ion in (k, na, ca):
+ ion.reset_state(V)
+ bp.share.save(t=0., dt=0.1, i=0)
+ ion.update(V)
+ cur = np.asarray(ion.current(V))
+ assert np.all(np.isfinite(cur))
+
+ # CalciumDyna reset path (different reset_state signature).
+ ca_dyn.reset_state(V)
+ assert np.all(np.isfinite(np.asarray(ca_dyn.C)))
diff --git a/tests/audit/test_dyn_neurons_synapses_fixes.py b/tests/audit/test_dyn_neurons_synapses_fixes.py
new file mode 100644
index 000000000..c5c0e1e0c
--- /dev/null
+++ b/tests/audit/test_dyn_neurons_synapses_fixes.py
@@ -0,0 +1,529 @@
+"""Regression + coverage tests for the 2026-06-18 BrainPy audit.
+
+This module targets the fixes documented in ``docs/issues-found-20260618.md``
+for the dyn neurons / synapses / projections and dnn/linear subsystems:
+
+ * H-34 ``dyn/neurons/lif.py`` -- ``ExpIFRef`` honours ``noise=`` (sdeint vs odeint).
+ * H-35 ``dyn/neurons/lif.py`` -- ``IzhikevichRef`` / ``GifRef`` ``detach_spk`` actually
+ cuts the spike gradient path.
+ * C-06/H-39/M-22 ``dyn/synapses/abstract_models.py`` -- ``STP`` facilitation no longer diverges;
+ discrete jumps applied to decayed locals; u/x stay in [0, 1].
+ * C-17 ``dyn/projections/inputs.py`` -- ``PoissonInput`` Gaussian std == sqrt(n*p*(1-p)).
+ * C-18 ``dyn/projections/align_post.py`` -- ``HalfProjAlignPost`` calls ``comm`` exactly once.
+ * C-19/H-41 ``dyn/projections/plasticity.py`` + ``dnn/linear.py`` -- ``STDP_Song2000`` runs and
+ changes the (promoted ``Variable``) weight; ``W_min=W_max=None``
+ does not crash; bounds clamp when set.
+ * H-40 ``dyn/projections/base.py`` -- module re-exports the real base classes.
+
+Plus breadth coverage tests that construct and run one ``update()`` for every public
+neuron class in ``lif.py`` (and LTC variants), the abstract synapse models, the
+projection wrappers, and the dnn/linear comm layers.
+
+Run from the worktree root with ``PYTHONPATH=.``.
+"""
+
+import numpy as np
+import jax
+import jax.numpy as jnp
+import pytest
+
+import brainpy as bp
+import brainpy.math as bm
+from brainpy.integrators.ode.base import ODEIntegrator
+from brainpy.integrators.sde.base import SDEIntegrator
+
+
+def _share(t=0.0, dt=0.1, i=0):
+ bp.share.save(t=t, dt=dt, i=i)
+
+
+# ---------------------------------------------------------------------------
+# H-34: ExpIFRef honours noise= (sdeint) and is plain odeint otherwise.
+# ---------------------------------------------------------------------------
+
+def test_expifref_noise_uses_sde_integrator():
+ noisy = bp.dyn.ExpIFRef(3, noise=0.5)
+ plain = bp.dyn.ExpIFRef(3)
+ assert isinstance(noisy.integral, SDEIntegrator)
+ assert not isinstance(noisy.integral, ODEIntegrator)
+ assert isinstance(plain.integral, ODEIntegrator)
+ assert not isinstance(plain.integral, SDEIntegrator)
+
+
+def test_expifref_noisy_update_runs():
+ bm.random.seed(0)
+ neu = bp.dyn.ExpIFRef(3, noise=1.0)
+ neu.reset_state()
+ _share()
+ neu.update(bm.ones(3) * 5.0)
+ assert jnp.all(jnp.isfinite(neu.V.value))
+
+
+# ---------------------------------------------------------------------------
+# H-35: IzhikevichRef / GifRef detach_spk actually cuts the spike gradient.
+# ---------------------------------------------------------------------------
+
+def _one_step_grad_izhikevich(detach, v0):
+ neu = bp.dyn.IzhikevichRef(3, mode=bm.training_mode, detach_spk=detach)
+
+ def loss(inp):
+ neu.reset_state(bm.training_mode)
+ neu.V.value = bm.ones(neu.V.shape) * v0
+ _share()
+ neu.update(inp)
+ return jnp.sum(neu.V.value)
+
+ return jax.grad(loss)(bm.zeros(3))
+
+
+def test_izhikevichref_detach_spk_changes_gradient():
+ # With V driven above threshold (~30 mV) the spike path is active, so
+ # detaching the spike (cutting it) changes the gradient w.r.t. the input.
+ v0 = 30.0
+ g_detach = _one_step_grad_izhikevich(True, v0)
+ g_plain = _one_step_grad_izhikevich(False, v0)
+ assert jnp.all(jnp.isfinite(g_detach))
+ assert jnp.all(jnp.isfinite(g_plain))
+ # detach_spk cuts the spike contribution -> gradient differs from the
+ # grad-carrying path.
+ assert not jnp.allclose(g_detach, g_plain)
+
+
+def _one_step_grad_gif(detach, v0):
+ neu = bp.dyn.GifRef(3, mode=bm.training_mode, detach_spk=detach)
+
+ def loss(inp):
+ neu.reset_state(bm.training_mode)
+ neu.V.value = bm.ones(neu.V.shape) * v0
+ _share()
+ neu.update(inp)
+ # include the adaptation/threshold states the spike resets touch.
+ return (jnp.sum(neu.V.value) + jnp.sum(neu.I1.value)
+ + jnp.sum(neu.I2.value) + jnp.sum(neu.V_th.value))
+
+ return jax.grad(loss)(bm.zeros(3))
+
+
+def test_gifref_detach_spk_changes_gradient():
+ v0 = -50.0 # drives V across the GIF threshold within one step
+ g_detach = _one_step_grad_gif(True, v0)
+ g_plain = _one_step_grad_gif(False, v0)
+ assert jnp.all(jnp.isfinite(g_detach))
+ assert jnp.all(jnp.isfinite(g_plain))
+ assert not jnp.allclose(g_detach, g_plain)
+
+
+# ---------------------------------------------------------------------------
+# C-06 / H-39 / M-22: STP short-term facilitation does not diverge.
+# ---------------------------------------------------------------------------
+
+def test_stp_no_spike_u_does_not_increase():
+ stp = bp.dyn.STP(1)
+ u0 = float(stp.u.value[0])
+ no_spike = bm.zeros(1, dtype=bool)
+ for _ in range(5):
+ _share()
+ stp.update(no_spike)
+ # With no spike, facilitation u must decay/stay -- never grow.
+ assert float(stp.u.value[0]) <= u0 + 1e-9
+
+
+def test_stp_spiking_u_x_stay_bounded():
+ stp = bp.dyn.STP(1)
+ spike = bm.ones(1, dtype=bool)
+ us, xs = [], []
+ for i in range(50):
+ _share(t=i * 0.1, i=i)
+ stp.update(spike)
+ us.append(float(stp.u.value[0]))
+ xs.append(float(stp.x.value[0]))
+ assert min(us) >= 0.0 and max(us) <= 1.0
+ assert min(xs) >= 0.0 and max(xs) <= 1.0
+ # released resource u*x stays finite and modest (used to explode to thousands)
+ assert max(u * x for u, x in zip(us, xs)) < 1.0
+
+
+# ---------------------------------------------------------------------------
+# C-17: PoissonInput Gaussian-approx std == sqrt(n*p*(1-p)).
+# ---------------------------------------------------------------------------
+
+def test_poisson_input_gaussian_std_matches_binomial():
+ bm.random.seed(0)
+ n_neuron = 4000
+ n_input = 1000
+ freq = 200.0 # Hz
+ dt = 0.1
+ target = bm.Variable(bm.zeros(n_neuron))
+ _share(dt=dt)
+ pin = bp.dyn.PoissonInput(target_var=target, num_input=n_input, freq=freq, weight=1.0)
+
+ p = freq * dt / 1e3
+ # ensure we exercise the Gaussian branch (a > 5 and b > 5)
+ assert n_input * p > 5 and n_input * (1 - p) > 5
+
+ target.value = bm.zeros(n_neuron)
+ pin.update()
+ samp = np.asarray(target.value)
+
+ expected_std = np.sqrt(n_input * p * (1 - p))
+ expected_mean = n_input * p
+ # std must be the binomial std, NOT the variance (the old bug was ~4x too big).
+ assert samp.std() == pytest.approx(expected_std, rel=0.1)
+ assert samp.mean() == pytest.approx(expected_mean, rel=0.1)
+ # guard against the regression where std == variance (b*p).
+ assert samp.std() < 0.5 * (n_input * p * (1 - p))
+
+
+# ---------------------------------------------------------------------------
+# C-18: HalfProjAlignPost calls comm exactly once per update.
+# ---------------------------------------------------------------------------
+
+class _CountingLinear(bp.dnn.Linear):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.n_calls = 0
+
+ def update(self, x):
+ self.n_calls += 1
+ return super().update(x)
+
+
+def test_halfprojalignpost_calls_comm_once():
+ post = bp.dyn.LifRef(4)
+ comm = _CountingLinear(4, 4, W_initializer=bp.init.Constant(0.1))
+ proj = bp.dyn.HalfProjAlignPost(
+ comm=comm,
+ syn=bp.dyn.Expon(4, tau=5.0),
+ out=bp.dyn.COBA(E=0.0),
+ post=post,
+ )
+ _share()
+ current = proj.update(bm.ones(4))
+ assert comm.n_calls == 1
+ # returned current equals one manual comm application.
+ expected = comm.update(bm.ones(4)) # second call only for comparison
+ assert jnp.allclose(jnp.asarray(current), jnp.asarray(expected))
+
+
+# ---------------------------------------------------------------------------
+# C-19 / H-41: STDP over a Linear/AllToAll comm runs and changes the weight.
+# ---------------------------------------------------------------------------
+
+def _build_stdp_net(w_min=None, w_max=None):
+ pre = bp.dyn.LifRef(3)
+ post = bp.dyn.LifRef(4)
+ syn = bp.dyn.STDP_Song2000(
+ pre=pre,
+ delay=1.0,
+ comm=bp.dnn.AllToAll(3, 4, weight=bp.init.Uniform(max_val=0.1)),
+ syn=bp.dyn.Expon.desc(post.varshape, tau=5.0),
+ out=bp.dyn.COBA.desc(E=0.0),
+ post=post,
+ W_min=w_min,
+ W_max=w_max,
+ )
+ net = bp.DynSysGroup(pre=pre, post=post, syn=syn)
+ net.reset_state()
+ return net, pre, post, syn
+
+
+def test_stdp_song2000_runs_and_changes_weight():
+ net, pre, post, syn = _build_stdp_net()
+ # weight starts as a plain array (comm built outside a TrainingMode).
+ assert not isinstance(syn.comm.weight, bm.Variable)
+ w0 = bm.as_jax(syn.comm.weight).copy()
+
+ for i in range(40):
+ _share(t=i * 0.1, i=i)
+ syn()
+ pre(bm.ones(3) * 80.0)
+ post(bm.ones(4) * 80.0)
+
+ # weight is promoted to a Variable on the first stdp_update and updates.
+ assert isinstance(syn.comm.weight, bm.Variable)
+ w1 = bm.as_jax(syn.comm.weight)
+ assert not jnp.allclose(w0, w1)
+ assert jnp.all(jnp.isfinite(w1))
+
+
+def test_stdp_song2000_default_none_bounds_do_not_crash():
+ # W_min == W_max == None used to crash on bm.as_jax(None) (H-41).
+ net, pre, post, syn = _build_stdp_net(w_min=None, w_max=None)
+ _share()
+ syn()
+ pre(bm.ones(3) * 80.0)
+ post(bm.ones(4) * 80.0)
+ assert jnp.all(jnp.isfinite(bm.as_jax(syn.comm.weight)))
+
+
+def test_stdp_song2000_bounds_clamp_weight():
+ w_max = 0.2
+ w_min = 0.0
+ net, pre, post, syn = _build_stdp_net(w_min=w_min, w_max=w_max)
+ for i in range(60):
+ _share(t=i * 0.1, i=i)
+ syn()
+ pre(bm.ones(3) * 80.0)
+ post(bm.ones(4) * 80.0)
+ w = bm.as_jax(syn.comm.weight)
+ assert jnp.all(w <= w_max + 1e-5)
+ assert jnp.all(w >= w_min - 1e-5)
+
+
+# ---------------------------------------------------------------------------
+# H-40: projections/base.py re-exports real base classes (no dead duplicate).
+# ---------------------------------------------------------------------------
+
+def test_projections_base_reexports_real_classes():
+ from brainpy.dyn.projections import base as proj_base
+ assert hasattr(proj_base, 'Projection')
+ assert hasattr(proj_base, 'SynConn')
+ assert 'Projection' in proj_base.__all__
+ assert 'SynConn' in proj_base.__all__
+ # the re-exported Projection is the canonical one used by the wrappers.
+ assert proj_base.Projection is bp.dyn.Projection
+
+
+# ---------------------------------------------------------------------------
+# Coverage: every public neuron class in lif.py, both modes.
+# ---------------------------------------------------------------------------
+
+_NEURON_NAMES = [
+ 'IF', 'Lif', 'LifRef',
+ 'ExpIF', 'ExpIFRef',
+ 'AdExIF', 'AdExIFRef',
+ 'QuaIF', 'QuaIFRef',
+ 'AdQuaIF', 'AdQuaIFRef',
+ 'Gif', 'GifRef',
+ 'Izhikevich', 'IzhikevichRef',
+]
+_NEURON_NAMES = _NEURON_NAMES + [n + 'LTC' for n in _NEURON_NAMES]
+
+
+@pytest.mark.parametrize('name', _NEURON_NAMES)
+@pytest.mark.parametrize('mode', [None, 'train'])
+def test_neuron_update_runs(name, mode):
+ cls = getattr(bp.dyn, name)
+ if mode == 'train':
+ neu = cls(3, mode=bm.training_mode)
+ neu.reset_state(bm.training_mode)
+ else:
+ neu = cls(3)
+ neu.reset_state()
+ # small drive: the exponential-IF family genuinely overflows V in a single
+ # large training-mode step (no hard refractory clamp before the surrogate
+ # spike), which is a model property, not an audit regression.
+ _share()
+ out = neu.update(bm.ones(3) * 1.0)
+ assert out is not None
+ assert neu.V.value.shape[-1] == 3
+ if mode is None:
+ # default (non-training) mode applies the hard reset, so V stays finite.
+ assert jnp.all(jnp.isfinite(neu.V.value))
+
+
+# ---------------------------------------------------------------------------
+# Coverage: abstract synapse models forward passes.
+# ---------------------------------------------------------------------------
+
+@pytest.mark.parametrize('name', ['Expon', 'DualExpon', 'Alpha', 'NMDA', 'AMPA', 'STP', 'STD'])
+def test_synapse_forward(name):
+ cls = getattr(bp.dyn, name)
+ syn = cls(3)
+ spike_bool = bm.asarray([True, False, True], dtype=bool)
+ spike_float = spike_bool.astype(float)
+ _share()
+ if name in ('NMDA', 'AMPA'):
+ out = syn(spike_float)
+ else:
+ out = syn(spike_bool)
+ assert jnp.asarray(out).shape == (3,)
+ assert jnp.all(jnp.isfinite(jnp.asarray(out)))
+
+
+# ---------------------------------------------------------------------------
+# Coverage: projection wrappers.
+# ---------------------------------------------------------------------------
+
+def test_full_proj_align_post():
+ class Net(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.pre = bp.dyn.LifRef(4)
+ self.post = bp.dyn.LifRef(3)
+ self.proj = bp.dyn.FullProjAlignPost(
+ pre=self.pre, delay=None,
+ comm=bp.dnn.AllToAll(4, 3, weight=0.1),
+ syn=bp.dyn.Expon(3, tau=5.0),
+ out=bp.dyn.COBA(E=0.0),
+ post=self.post,
+ )
+
+ def update(self, inp):
+ self.proj()
+ self.pre(inp)
+ self.post()
+ return self.post.V.value
+
+ net = Net()
+ net.reset_state()
+ _share()
+ out = net.update(bm.ones(4) * 5.0)
+ assert jnp.all(jnp.isfinite(out))
+
+
+def test_full_proj_align_post_mg():
+ class Net(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.pre = bp.dyn.LifRef(4)
+ self.post = bp.dyn.LifRef(3)
+ self.proj = bp.dyn.FullProjAlignPostMg(
+ pre=self.pre, delay=None,
+ comm=bp.dnn.AllToAll(4, 3, weight=0.1),
+ syn=bp.dyn.Expon.desc(3, tau=5.0),
+ out=bp.dyn.COBA.desc(E=0.0),
+ post=self.post,
+ )
+
+ def update(self, inp):
+ self.proj()
+ self.pre(inp)
+ self.post()
+ return self.post.V.value
+
+ net = Net()
+ net.reset_state()
+ _share()
+ out = net.update(bm.ones(4) * 5.0)
+ assert jnp.all(jnp.isfinite(out))
+
+
+def test_full_proj_align_pre_sdmg():
+ class Net(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.pre = bp.dyn.LifRef(4)
+ self.post = bp.dyn.LifRef(3)
+ self.proj = bp.dyn.FullProjAlignPreSDMg(
+ pre=self.pre,
+ syn=bp.dyn.Expon.desc(4, tau=5.0),
+ delay=None,
+ comm=bp.dnn.AllToAll(4, 3, weight=0.1),
+ out=bp.dyn.COBA(E=0.0),
+ post=self.post,
+ )
+
+ def update(self, inp):
+ self.proj()
+ self.pre(inp)
+ self.post()
+ return self.post.V.value
+
+ net = Net()
+ net.reset_state()
+ _share()
+ out = net.update(bm.ones(4) * 5.0)
+ assert jnp.all(jnp.isfinite(out))
+
+
+def test_half_proj_align_post_mg():
+ post = bp.dyn.LifRef(3)
+ proj = bp.dyn.HalfProjAlignPostMg(
+ comm=bp.dnn.AllToAll(4, 3, weight=0.1),
+ syn=bp.dyn.Expon.desc(3, tau=5.0),
+ out=bp.dyn.COBA.desc(E=0.0),
+ post=post,
+ )
+ _share()
+ out = proj.update(bm.ones(4))
+ assert jnp.all(jnp.isfinite(jnp.asarray(out)))
+
+
+# ---------------------------------------------------------------------------
+# Coverage: PoissonInput full update path (binomial small-N branch too).
+# ---------------------------------------------------------------------------
+
+def test_poisson_input_small_n_branch():
+ bm.random.seed(1)
+ target = bm.Variable(bm.zeros(8))
+ _share(dt=0.1)
+ pin = bp.dyn.PoissonInput(target_var=target, num_input=10, freq=20.0, weight=1.0)
+ assert repr(pin) # exercise __repr__
+ pin.update() # a, b small -> binomial branch
+ assert jnp.all(jnp.isfinite(target.value))
+
+
+def test_input_var_runs():
+ iv = bp.dyn.InputVar(4)
+ iv.input += 1.0
+ assert jnp.allclose(jnp.asarray(iv.update()), 1.0)
+ iv.clear_input()
+ assert jnp.allclose(jnp.asarray(iv.update()), 0.0)
+
+
+# ---------------------------------------------------------------------------
+# Coverage: dnn/linear comm layer forward passes.
+# ---------------------------------------------------------------------------
+
+def test_dnn_linear_forwards():
+ conn = bp.conn.FixedProb(0.5, pre=4, post=3, seed=1)
+
+ dense = bp.dnn.Dense(4, 3)
+ assert jnp.asarray(dense(bm.ones(4))).shape == (3,)
+
+ a2a = bp.dnn.AllToAll(4, 3, weight=0.1)
+ # NonBatching mode requires a 1D input; identical scalar weight + ones input
+ # collapses to a scalar (sum over pre).
+ assert jnp.asarray(a2a(bm.ones(4))).shape == ()
+
+ o2o = bp.dnn.OneToOne(4, weight=0.1)
+ assert jnp.asarray(o2o(bm.ones(4))).shape == (4,)
+
+ ml = bp.dnn.MaskedLinear(conn, weight=0.1)
+ assert jnp.asarray(ml(bm.ones(4))).shape == (3,)
+
+ csr = bp.dnn.CSRLinear(conn, weight=0.1)
+ assert jnp.asarray(csr(bm.ones(4))).shape == (3,)
+
+
+def test_dnn_event_and_jitprob_forwards():
+ conn = bp.conn.FixedProb(0.5, pre=4, post=3, seed=1)
+ spike = bm.asarray([True, False, True, False])
+ x = bm.ones(4)
+
+ # Identity passthrough.
+ assert jnp.asarray(bp.dnn.Identity()(x)).shape == (4,)
+
+ # Event-driven CSR.
+ ec = bp.dnn.EventCSRLinear(conn, weight=0.1)
+ assert jnp.asarray(ec(spike)).shape == (3,)
+
+ # JIT just-in-time fixed-prob connectivity (dense-input variants).
+ assert jnp.asarray(
+ bp.dnn.JitFPHomoLinear(4, 3, prob=0.5, weight=0.1, seed=1)(x)).shape == (3,)
+ assert jnp.asarray(
+ bp.dnn.JitFPUniformLinear(4, 3, prob=0.5, w_low=0.0, w_high=0.1, seed=1)(x)).shape == (3,)
+ assert jnp.asarray(
+ bp.dnn.JitFPNormalLinear(4, 3, prob=0.5, w_mu=0.0, w_sigma=0.1, seed=1)(x)).shape == (3,)
+
+ # JIT fixed-prob connectivity (event/spike-input variants).
+ assert jnp.asarray(
+ bp.dnn.EventJitFPHomoLinear(4, 3, prob=0.5, weight=0.1, seed=1)(spike)).shape == (3,)
+ assert jnp.asarray(
+ bp.dnn.EventJitFPUniformLinear(4, 3, prob=0.5, w_low=0.0, w_high=0.1, seed=1)(spike)).shape == (3,)
+ assert jnp.asarray(
+ bp.dnn.EventJitFPNormalLinear(4, 3, prob=0.5, w_mu=0.0, w_sigma=0.1, seed=1)(spike)).shape == (3,)
+
+
+def test_dense_stdp_update_promotes_weight():
+ dense = bp.dnn.Dense(3, 4, W_initializer=bp.init.Uniform(max_val=0.1))
+ assert not isinstance(dense.W, bm.Variable)
+ # the plasticity wrapper passes jax arrays to ``stdp_update``; mirror that.
+ spike_pre = bm.as_jax(bm.asarray([1.0, 0.0, 1.0]))
+ trace_pre = bm.as_jax(bm.asarray([0.1, 0.2, 0.3, 0.4]))
+ w0 = bm.as_jax(dense.W).copy()
+ dense.stdp_update(on_pre={'spike': spike_pre, 'trace': trace_pre}, w_min=None, w_max=None)
+ assert isinstance(dense.W, bm.Variable)
+ assert not jnp.allclose(w0, bm.as_jax(dense.W))
+ assert jnp.all(jnp.isfinite(bm.as_jax(dense.W)))
diff --git a/tests/audit/test_dyn_rates_dynold_fixes.py b/tests/audit/test_dyn_rates_dynold_fixes.py
new file mode 100644
index 000000000..b96720ffd
--- /dev/null
+++ b/tests/audit/test_dyn_rates_dynold_fixes.py
@@ -0,0 +1,485 @@
+# -*- coding: utf-8 -*-
+"""Audit regression + coverage tests (audit ``docs/issues-found-20260618.md``).
+
+This module exercises the fixes recorded in the 2026-06-18 BrainPy audit for the
+rate-population / reservoir / RNN-cell modules under ``brainpy/dyn/rates`` and the
+``brainpy/dynold`` compatibility shims.
+
+Regression behaviors covered:
+
+* C-15 ``ThresholdLinearModel`` noise path no longer crashes (``randn(*shape)``).
+* C-16 ``StuartLandauOscillator.dy`` uses the correct ``+w*x`` rotational coupling.
+* C-17 (dynold copy in ``experimental/others.py``) ``PoissonInput`` Gaussian branch
+ uses ``std = sqrt(b*p)``, not the variance ``b*p``.
+* C-20 ``AlphaCUBA`` / ``AlphaCOBA`` construct without ``ZeroDivisionError``.
+* C-21 dynold ``STP`` synaptic current does not drift with zero presynaptic spikes.
+* H-36 ``LSTMCell`` ``h`` / ``c`` setters slice the last axis (unbatched + batched);
+ setting ``c`` does not corrupt ``h``.
+* H-37 ``Reservoir`` recurrent noise is symmetric (``uniform(-1, 1)`` -> zero mean).
+* H-38 ``Reservoir`` bias is actually added in ``update``.
+
+The remaining tests construct + step the assigned modules for coverage. Bugs that
+are still *unfixed* in the source are recorded in the agent summary, not asserted
+here (e.g. M-21 ``reset_state(None)``, M-24 ``ALIFBellec2020`` ``a_initializer``).
+"""
+
+import jax.numpy as jnp
+import pytest
+
+import brainpy as bp
+import brainpy.math as bm
+import brainpy.initialize as init
+from brainpy.context import share
+
+
+def _share(t=0.0, dt=0.1, i=0):
+ """Populate the shared simulation context required by ``update``."""
+ bm.set_dt(dt)
+ share.save(t=t, dt=dt, i=i)
+
+
+# ---------------------------------------------------------------------------
+# Regression tests
+# ---------------------------------------------------------------------------
+
+def test_threshold_linear_model_noise_path_runs():
+ """C-15: nonzero ``noise_e`` no longer raises ``TypeError`` from ``randn``."""
+ bm.random.seed(0)
+ m = bp.dyn.ThresholdLinearModel(5, noise_e=1.0)
+ _share(t=0.0, dt=1e-4, i=0)
+ out = m.update(inp_e=0.0)
+ arr = jnp.asarray(out)
+ assert arr.shape == (5,)
+ assert bool(jnp.isfinite(arr).all())
+
+
+def test_threshold_linear_model_noise_i_path_runs():
+ """C-15 companion: the inhibitory noise branch is also reachable."""
+ bm.random.seed(1)
+ m = bp.dyn.ThresholdLinearModel(4, noise_i=0.5)
+ _share(t=0.0, dt=1e-4, i=0)
+ out = m.update(inp_e=1.0, inp_i=1.0)
+ assert bool(jnp.isfinite(jnp.asarray(out)).all())
+
+
+def test_stuart_landau_dy_rotational_coupling():
+ """C-16: ``dy`` must add ``+w*x`` (Hopf normal form), not ``-w*y``."""
+ m = bp.dyn.StuartLandauOscillator(1, a=0.25, w=0.2)
+ # signature is dy(self, y, t, x, y_ext, a, w)
+ val = float(jnp.asarray(m.dy(0.3, 0.0, 0.5, 0.0, 0.25, 0.2)))
+ # (a - x^2 - y^2)*y + w*x = (0.25 - 0.25 - 0.09)*0.3 + 0.2*0.5 = 0.073
+ assert val == pytest.approx(0.073, abs=1e-4)
+ # The buggy -w*y value would be -0.087; make sure we are not there.
+ assert val > 0.0
+
+
+def test_poisson_input_gaussian_std_is_sqrt_bp():
+ """C-17 (dynold copy): Gaussian branch std = sqrt(b*p), ~4.43 not ~19.6."""
+ from brainpy.dynold.experimental.others import PoissonInput
+
+ bm.random.seed(0)
+ dt = 0.1
+ freq, num = 200.0, 1000
+ pi = PoissonInput(target_shape=(20000,), num_input=num, freq=freq, weight=1.0)
+ _share(t=0.0, dt=dt, i=0)
+ out = jnp.asarray(pi.update())
+
+ p = freq * dt / 1e3 # 0.02
+ b = num * (1 - p)
+ expected_std = float((b * p) ** 0.5) # ~4.427
+ expected_mean = num * p # 20.0
+
+ assert expected_std == pytest.approx(4.427, abs=1e-2)
+ # empirical std close to sqrt(b*p), and clearly far from the variance b*p (=19.6)
+ assert float(out.std()) == pytest.approx(expected_std, rel=0.1)
+ assert float(out.mean()) == pytest.approx(expected_mean, rel=0.1)
+
+
+def test_alpha_cuba_coba_construct_and_step():
+ """C-20: AlphaCUBA/COBA construct without ZeroDivisionError and run a step."""
+ bm.random.seed(0)
+ bm.set_dt(0.1)
+
+ pre = bp.neurons.LIF(2)
+ post = bp.neurons.LIF(2)
+ syn = bp.synapses.AlphaCUBA(pre, post, bp.connect.All2All(), tau_decay=10.0)
+ net = bp.Network(pre=pre, post=post, syn=syn)
+ net.reset_state()
+ _share()
+ syn.update() # must not raise
+
+ pre2 = bp.neurons.LIF(2)
+ post2 = bp.neurons.LIF(2)
+ syn2 = bp.synapses.AlphaCOBA(pre2, post2, bp.connect.All2All(), tau_decay=10.0)
+ net2 = bp.Network(pre=pre2, post=post2, syn=syn2)
+ net2.reset_state()
+ _share()
+ syn2.update() # must not raise
+
+
+def test_dynold_stp_no_drift_without_spikes():
+ """C-21: dynold STP synaptic current stays at 0 with zero presynaptic spikes."""
+ bm.random.seed(0)
+ bm.set_dt(0.1)
+ pre = bp.neurons.LIF(1)
+ post = bp.neurons.LIF(1)
+ syn = bp.synapses.STP(pre, post, bp.connect.All2All(), U=0.2, tau_d=150.0, tau_f=2.0)
+ net = bp.Network(pre=pre, syn=syn, post=post)
+ net.reset_state()
+
+ currents = []
+ for i in range(100):
+ _share(t=i * 0.1, dt=0.1, i=i)
+ net.update() # no external input -> no spikes
+ currents.append(float(jnp.asarray(syn.I.value).sum()))
+
+ # With the spike-gating fix, the current must not ramp up.
+ assert max(abs(c) for c in currents) < 1e-5
+
+
+def test_dynold_stp_jumps_with_spikes():
+ """C-21 companion: STP current does jump when presynaptic spikes occur."""
+ bm.random.seed(0)
+ bm.set_dt(0.1)
+ pre = bp.neurons.LIF(1)
+ post = bp.neurons.LIF(1)
+ syn = bp.synapses.STP(pre, post, bp.connect.All2All(), U=0.2, tau_d=150.0, tau_f=2.0)
+ net = bp.Network(pre=pre, syn=syn, post=post)
+ runner = bp.DSRunner(net, inputs=[('pre.input', 28.0)], monitors=['syn.I'],
+ progress_bar=False)
+ runner.run(50.0)
+ I = runner.mon['syn.I']
+ assert bool(jnp.isfinite(I).all())
+ assert float(I.max()) > 0.0
+
+
+def test_lstm_h_c_setters_unbatched():
+ """H-36: unbatched ``h``/``c`` setters slice the last axis (no IndexError)."""
+ bm.random.seed(0)
+ lstm = bp.dyn.LSTMCell(3, 4)
+ assert lstm.state.shape == (8,)
+
+ lstm.h = jnp.ones((4,))
+ assert jnp.allclose(jnp.asarray(lstm.h), 1.0)
+ # setting h must not have touched c
+ assert jnp.allclose(jnp.asarray(lstm.c), 0.0)
+
+ lstm.c = jnp.full((4,), 2.0)
+ # setting c must not corrupt h
+ assert jnp.allclose(jnp.asarray(lstm.h), 1.0)
+ assert jnp.allclose(jnp.asarray(lstm.c), 2.0)
+
+
+def test_lstm_h_c_setters_batched():
+ """H-36: batched ``h``/``c`` setters write the correct rows."""
+ bm.random.seed(0)
+ lstm = bp.dyn.LSTMCell(3, 4, mode=bm.batching_mode)
+ lstm.reset_state(2)
+ assert lstm.state.shape == (2, 8)
+
+ lstm.h = jnp.ones((2, 4))
+ assert jnp.allclose(jnp.asarray(lstm.h), 1.0)
+ assert jnp.allclose(jnp.asarray(lstm.c), 0.0)
+
+ lstm.c = jnp.full((2, 4), 2.0)
+ assert jnp.allclose(jnp.asarray(lstm.h), 1.0)
+ assert jnp.allclose(jnp.asarray(lstm.c), 2.0)
+
+
+def test_reservoir_bias_is_applied():
+ """H-38: a nonzero bias shifts the reservoir output."""
+ bm.random.seed(0)
+ common = dict(in_connectivity=1.0, rec_connectivity=1.0,
+ activation_type='external', leaky_rate=1.0)
+ r0 = bp.dyn.Reservoir(3, 5, b_initializer=init.ZeroInit(), **common)
+ r1 = bp.dyn.Reservoir(3, 5, b_initializer=init.Constant(1.0), **common)
+ # share the random weights so only the bias differs
+ r1.Win.value = r0.Win.value
+ r1.Wrec.value = r0.Wrec.value
+
+ x = jnp.zeros((3,))
+ out0 = jnp.asarray(r0.update(x))
+ out1 = jnp.asarray(r1.update(x))
+ # zero bias + zero input -> zero (external activation of 0 is tanh(0)=0)
+ assert jnp.allclose(out0, 0.0)
+ # nonzero bias shifts the output away from zero
+ assert float(jnp.sum(jnp.abs(out1 - out0))) > 1e-3
+
+
+def test_reservoir_recurrent_noise_is_symmetric():
+ """H-37: recurrent noise is ``uniform(-1, 1)`` (zero-mean), not a -noise bias."""
+ bm.random.seed(123)
+ r = bp.dyn.Reservoir(2, 4000, noise_rec=1.0, in_connectivity=1.0,
+ rec_connectivity=1.0, activation_type='external',
+ leaky_rate=1.0)
+ # zero out the weights so the state is driven purely by the recurrent noise
+ r.Win.value = jnp.zeros_like(jnp.asarray(r.Win.value))
+ r.Wrec.value = jnp.zeros_like(jnp.asarray(r.Wrec.value))
+ out = jnp.asarray(r.update(jnp.zeros((2,))))
+ # state = tanh(noise); symmetric noise -> mean near zero, both signs present
+ assert abs(float(out.mean())) < 0.05
+ assert float(out.min()) < -0.1
+ assert float(out.max()) > 0.1
+
+
+# ---------------------------------------------------------------------------
+# Coverage tests: construct + step the assigned modules
+# ---------------------------------------------------------------------------
+
+def test_rate_populations_construct_and_step():
+ """Cover FHN/FeedbackFHN/QIF/StuartLandau/WilsonCowan/ThresholdLinear."""
+ bm.random.seed(0)
+ _share()
+
+ # FHN with OU noise enabled to exercise the noise branch; pass both inputs
+ fhn = bp.dyn.FHN(3, x_ou_sigma=1.0, y_ou_sigma=1.0)
+ fhn.reset_state()
+ assert jnp.asarray(fhn.update(1.0, 0.5)).shape == (3,)
+ fhn.clear_input()
+
+ fbfhn = bp.dyn.FeedbackFHN(3, delay=2.0, x_ou_sigma=1.0, y_ou_sigma=1.0)
+ fbfhn.reset_state()
+ assert jnp.asarray(fbfhn.update(1.0, 0.5)).shape == (3,)
+ fbfhn.clear_input()
+
+ qif = bp.dyn.QIF(3)
+ qif.reset_state()
+ assert jnp.asarray(qif.update(1.0, 0.5)).shape == (3,)
+ qif.clear_input()
+
+ sl = bp.dyn.StuartLandauOscillator(3)
+ sl.reset_state()
+ assert jnp.asarray(sl.update(1.0, 0.5)).shape == (3,)
+ sl.clear_input()
+
+ wc = bp.dyn.WilsonCowanModel(3)
+ wc.reset_state()
+ assert jnp.asarray(wc.update(1.0, 0.5)).shape == (3,)
+ wc.clear_input()
+
+ tlm = bp.dyn.ThresholdLinearModel(3)
+ tlm.reset_state()
+ assert jnp.asarray(tlm.update(inp_e=1.0, inp_i=0.5)).shape == (3,)
+ tlm.clear_input()
+
+
+def test_rate_populations_no_input_var_branch():
+ """Cover the ``input_var=False`` update branches with OU noise."""
+ bm.random.seed(0)
+ _share()
+ for cls in (bp.dyn.FHN, bp.dyn.FeedbackFHN, bp.dyn.QIF,
+ bp.dyn.StuartLandauOscillator, bp.dyn.WilsonCowanModel):
+ m = cls(2, input_var=False, x_ou_sigma=1.0, y_ou_sigma=1.0)
+ m.reset_state()
+ out = m.update(1.0, 0.5)
+ assert jnp.asarray(out).shape == (2,)
+
+
+def test_rate_populations_run_via_dsrunner():
+ """A short DSRunner run exercises update/clear_input across many steps."""
+ bm.random.seed(0)
+ fhn = bp.dyn.FHN(2)
+ runner = bp.DSRunner(fhn, inputs=('input', 1.0), monitors=['x'],
+ progress_bar=False)
+ runner.run(5.0)
+ assert bool(jnp.isfinite(runner.mon['x']).all())
+
+
+def test_reservoir_update_dense_and_sparse():
+ """Cover Reservoir dense + sparse comp paths and feedforward noise."""
+ bm.random.seed(0)
+ # dense with feedforward noise + spectral radius scaling
+ r = bp.dyn.Reservoir(4, 8, noise_in=0.1, spectral_radius=0.9)
+ r.reset_state()
+ out = r.update(jnp.ones((4,)))
+ assert jnp.asarray(out).shape == (8,)
+
+ # sparse computation path
+ rs = bp.dyn.Reservoir(4, 8, comp_type='sparse', in_connectivity=0.5,
+ rec_connectivity=0.5)
+ rs.reset_state()
+ out_s = rs.update(jnp.ones((4,)))
+ assert jnp.asarray(out_s).shape == (8,)
+
+
+def test_rnn_cells_forward_unbatched():
+ """Cover RNNCell/GRUCell/LSTMCell forward pass (unbatched)."""
+ bm.random.seed(0)
+ x = jnp.ones((3,))
+ for cls in (bp.dyn.RNNCell, bp.dyn.GRUCell, bp.dyn.LSTMCell):
+ c = cls(3, 4)
+ out = c.update(x)
+ assert jnp.asarray(out).shape == (4,)
+
+
+def test_rnn_cells_forward_batched_and_train_state():
+ """Cover batched forward + train_state initialization for the RNN cells."""
+ bm.random.seed(0)
+ x = jnp.ones((2, 3))
+ for cls in (bp.dyn.RNNCell, bp.dyn.GRUCell, bp.dyn.LSTMCell):
+ c = cls(3, 4, mode=bm.batching_mode)
+ c.reset_state(2)
+ out = c.update(x)
+ assert jnp.asarray(out).shape == (2, 4)
+
+ # train_state path
+ ct = cls(3, 4, mode=bm.training_mode, train_state=True)
+ ct.reset_state(2)
+ out2 = ct.update(x)
+ assert jnp.asarray(out2).shape == (2, 4)
+
+
+def test_rnn_cells_no_bias_branch():
+ """Cover the ``b_initializer=None`` (no bias) branches in the RNN cells."""
+ bm.random.seed(0)
+ x = jnp.ones((3,))
+ for cls in (bp.dyn.RNNCell, bp.dyn.GRUCell, bp.dyn.LSTMCell):
+ c = cls(3, 4, b_initializer=None)
+ out = c.update(x)
+ assert jnp.asarray(out).shape == (4,)
+
+
+def test_dynold_compat_synapses_construct_and_step():
+ """Cover dynold compat synapses: Exp/DualExp/Alpha (CUBA & COBA) + Delta."""
+ import warnings
+ bm.random.seed(0)
+ bm.set_dt(0.1)
+
+ from brainpy.synapses import (ExpCUBA, ExpCOBA, DualExpCUBA, DualExpCOBA,
+ AlphaCUBA, AlphaCOBA, DeltaSynapse)
+
+ factories = [
+ lambda a, b: ExpCUBA(a, b, bp.connect.All2All(), tau=8.0),
+ lambda a, b: ExpCOBA(a, b, bp.connect.All2All(), tau=8.0, E=0.0),
+ lambda a, b: DualExpCUBA(a, b, bp.connect.All2All(), tau_decay=10.0, tau_rise=1.0),
+ lambda a, b: DualExpCOBA(a, b, bp.connect.All2All(), tau_decay=10.0, tau_rise=1.0, E=0.0),
+ lambda a, b: AlphaCUBA(a, b, bp.connect.All2All(), tau_decay=10.0),
+ lambda a, b: AlphaCOBA(a, b, bp.connect.All2All(), tau_decay=10.0, E=0.0),
+ lambda a, b: DeltaSynapse(a, b, bp.connect.All2All()),
+ ]
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ for factory in factories:
+ pre = bp.neurons.LIF(2)
+ post = bp.neurons.LIF(2)
+ syn = factory(pre, post)
+ net = bp.Network(pre=pre, post=post, syn=syn)
+ net.reset_state()
+ _share()
+ syn.update() # must not raise
+
+
+def test_dynold_stp_run_full_simulation():
+ """Run the dynold STP learning rule via DSRunner to cover its update path."""
+ bm.random.seed(0)
+ bm.set_dt(0.1)
+ pre = bp.neurons.LIF(1)
+ post = bp.neurons.LIF(1)
+ syn = bp.synapses.STP(pre, post, bp.connect.All2All(), U=0.2, tau_d=150.0, tau_f=2.0)
+ net = bp.Network(pre=pre, syn=syn, post=post)
+ runner = bp.DSRunner(net, inputs=[('pre.input', 28.0)],
+ monitors=['syn.I', 'syn.u', 'syn.x'], progress_bar=False)
+ runner.run(30.0)
+ assert bool(jnp.isfinite(runner.mon['syn.u']).all())
+ assert bool(jnp.isfinite(runner.mon['syn.x']).all())
+
+
+def test_reduced_models_construct_and_step():
+ """Cover dynold reduced_models neurons: construct + a single update step."""
+ from brainpy.dynold.neurons import reduced_models as rm
+
+ bm.random.seed(0)
+ bm.set_dt(0.1)
+ _share()
+
+ for cls in (rm.LeakyIntegrator, rm.LIF, rm.ExpIF, rm.AdExIF, rm.QuaIF,
+ rm.AdQuaIF, rm.GIF, rm.Izhikevich, rm.HindmarshRose, rm.FHN,
+ rm.ALIFBellec2020):
+ m = cls(2)
+ m.reset_state()
+ out = m.update(1.0)
+ assert jnp.asarray(out).shape == (2,)
+ m.clear_input()
+
+
+def test_reduced_models_no_input_var_branch():
+ """Cover the ``input_var=False`` update branch of the reduced models."""
+ from brainpy.dynold.neurons import reduced_models as rm
+
+ bm.random.seed(0)
+ bm.set_dt(0.1)
+ _share()
+ for cls in (rm.LIF, rm.ExpIF, rm.Izhikevich):
+ m = cls(2, input_var=False)
+ m.reset_state()
+ out = m.update(1.0)
+ assert jnp.asarray(out).shape == (2,)
+ m.clear_input()
+
+
+def test_reduced_models_tau_ref_and_noise_branches():
+ """Cover the refractory + noise branches of the reduced models."""
+ from brainpy.dynold.neurons import reduced_models as rm
+
+ bm.random.seed(0)
+ bm.set_dt(0.1)
+ _share()
+
+ # tau_ref + noise exercises the sdeint + refractory paths in *Ref models
+ for cls in (rm.LIF, rm.ExpIF, rm.AdExIF, rm.QuaIF, rm.AdQuaIF,
+ rm.Izhikevich, rm.GIF):
+ m = cls(2, tau_ref=2.0, noise=0.5)
+ m.reset_state()
+ out = m.update(5.0)
+ assert jnp.asarray(out).shape == (2,)
+ m.clear_input()
+
+
+def test_bellec_sfa_models_construct_and_step():
+ """Cover ALIFBellec2020 / LIF_SFA_Bellec2020 with and without refractoriness."""
+ from brainpy.dynold.neurons import reduced_models as rm
+
+ bm.random.seed(0)
+ bm.set_dt(0.1)
+ _share()
+
+ for cls in (rm.ALIFBellec2020, rm.LIF_SFA_Bellec2020):
+ m = cls(2)
+ m.reset_state()
+ assert jnp.asarray(m.update(5.0)).shape == (2,)
+
+ m_ref = cls(2, tau_ref=2.0)
+ m_ref.reset_state()
+ assert jnp.asarray(m_ref.update(5.0)).shape == (2,)
+
+
+def test_conv_lstm_cells_forward():
+ """Cover the convolutional LSTM cells (1d/2d/3d) forward pass."""
+ bm.random.seed(0)
+
+ c1 = bp.dyn.Conv1dLSTMCell(input_shape=(5,), in_channels=2, out_channels=3,
+ kernel_size=3, mode=bm.batching_mode)
+ c1.reset_state(1)
+ assert jnp.asarray(c1.update(jnp.ones((1, 5, 2)))).shape == (1, 5, 3)
+
+ c2 = bp.dyn.Conv2dLSTMCell(input_shape=(4, 4), in_channels=2, out_channels=3,
+ kernel_size=3, mode=bm.batching_mode)
+ c2.reset_state(1)
+ assert jnp.asarray(c2.update(jnp.ones((1, 4, 4, 2)))).shape == (1, 4, 4, 3)
+
+
+def test_poisson_input_binomial_branch_and_helpers():
+ """Cover the small-N binomial branch, ``__repr__`` and ``reset`` of PoissonInput."""
+ from brainpy.dynold.experimental.others import PoissonInput
+
+ bm.random.seed(0)
+ bm.set_dt(0.1)
+ _share()
+
+ # low freq / small num_input -> a<=5 or b<=5 -> binomial branch
+ pi = PoissonInput(target_shape=(5,), num_input=10, freq=10.0, weight=1.0)
+ out = pi.update()
+ assert jnp.asarray(out).shape == (5,)
+ assert 'PoissonInput' in repr(pi)
+ pi.reset()
+ pi.reset_state()
diff --git a/tests/audit/test_integrators_fixes.py b/tests/audit/test_integrators_fixes.py
new file mode 100644
index 000000000..dd9cd6e4f
--- /dev/null
+++ b/tests/audit/test_integrators_fixes.py
@@ -0,0 +1,933 @@
+# -*- coding: utf-8 -*-
+"""Regression + coverage tests for the BrainPy v2.7.8 integrators audit.
+
+These tests pin the fixes recorded in ``docs/issues-found-20260618.md`` for the
+``brainpy/integrators`` subtree. Each regression test references the audit ID it
+guards. The remaining tests exercise the public integrator API broadly to keep
+line coverage high on the assigned source files:
+
+ * ode/adaptive_rk.py -- C-12, H-26, H-27, H-28
+ * ode/exponential.py -- exp_euler / exp_euler_auto numerics
+ * sde/base.py -- C-13 (errors import; invalid intg_type/wiener_type)
+ * sde/normal.py -- C-13 (Heun Ito/Stratonovich guard)
+ * integrators/runner.py-- H-29 (IntegratorRunner step-index monitor)
+ * integrators/joint_eq.py -- L-13 (diagnostic DiffEqError message)
+ * fde/Caputo.py -- C-08 (CaputoEuler init scaling), H-31 (hists())
+ * fde/GL.py -- H-30 (GLShortMemory.reset key suffix)
+ * fde/generic.py -- H-32 (set/get_default_fdeint global)
+
+The tests use tiny ``dt`` / step counts so the whole module runs in well under a
+minute. Global state (x64 precision, default FDE method) is restored by
+fixtures.
+"""
+
+import math
+
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+import brainpy as bp
+import brainpy.math as bm
+from brainpy.integrators.ode.adaptive_rk import BoSh3
+
+IntegratorError = bp.errors.IntegratorError
+DiffEqError = bp.errors.DiffEqError
+
+
+# --------------------------------------------------------------------------- #
+# Fixtures: restore global state mutated by some tests.
+# --------------------------------------------------------------------------- #
+
+@pytest.fixture
+def x64():
+ """Enable float64 for the duration of a test, then restore float32."""
+ bm.enable_x64()
+ try:
+ yield
+ finally:
+ bm.disable_x64()
+
+
+@pytest.fixture
+def restore_default_fdeint():
+ """Snapshot and restore the global default FDE method."""
+ orig = bp.fde.get_default_fdeint()
+ try:
+ yield
+ finally:
+ bp.fde.set_default_fdeint(orig)
+
+
+def _ev(seq):
+ """Evaluate a Butcher-tableau row whose entries may be string fractions."""
+ return [eval(x) if isinstance(x, str) else float(x) for x in seq]
+
+
+# =========================================================================== #
+# Regression tests
+# =========================================================================== #
+
+# --- C-12: adaptive RK no longer TypeErrors when ``tol`` is omitted --------- #
+
+def test_c12_adaptive_rkf45_runs_without_tol_array_state():
+ """C-12: adaptive=True with default tol must run (tol falls back to 0.1)."""
+ f = bp.odeint(lambda y, t: -y, method='rkf45', adaptive=True)
+ y_new, dt_new = f(jnp.array([1.0]), 0.0, dt=0.1)
+ # exact solution y(0.1) = e^{-0.1}
+ assert np.allclose(np.asarray(y_new), math.exp(-0.1), atol=1e-3)
+ assert float(dt_new) > 0.0
+
+
+def test_c12_adaptive_rkf45_runs_on_scalar_state():
+ """C-12 / H-28: the same integrator must also work on a python-float scalar."""
+ f = bp.odeint(lambda y, t: -y, method='rkf45', adaptive=True)
+ y_new, dt_new = f(1.0, 0.0, dt=0.1)
+ assert np.allclose(float(y_new), math.exp(-0.1), atol=1e-3)
+ assert float(dt_new) > 0.0
+
+
+# --- H-28: scalar state with default POP_VAR var_type ----------------------- #
+
+def test_h28_scalar_state_all_adaptive_methods():
+ """H-28: default var_type=POP_VAR must use jnp.sum, not builtin sum, so a
+ scalar state does not raise ``'float' object is not iterable``."""
+ for method in ['rkf45', 'rkdp', 'rkf12', 'ck', 'bs', 'heun_euler', 'BoSh3']:
+ f = bp.odeint(lambda y, t: -y, method=method, adaptive=True)
+ y_new, dt_new = f(1.0, 0.0, dt=0.05)
+ assert np.isfinite(float(y_new))
+ assert np.isfinite(float(dt_new))
+
+
+# --- H-26: BoSh3 embedded error vector is non-degenerate -------------------- #
+
+def test_h26_bosh3_embedded_error_non_degenerate():
+ """H-26: B1 and B2 must each be consistent (sum ~ 1) and B1-B2 must be a
+ real (non-zero) error estimator, not the zero-sum-B2 bug."""
+ b1 = _ev(BoSh3.B1)
+ b2 = _ev(BoSh3.B2)
+ assert abs(sum(b1) - 1.0) < 1e-9, f'B1 must sum to 1, got {sum(b1)}'
+ assert abs(sum(b2) - 1.0) < 1e-9, f'B2 must sum to 1, got {sum(b2)}'
+ diff = [a - b for a, b in zip(b1, b2)]
+ # the embedded error estimate must be non-degenerate (some non-zero weights)
+ assert any(abs(d) > 1e-9 for d in diff), 'B1-B2 is degenerate (all zero)'
+ # the buggy B2 summed to ~0; guard against that regression explicitly
+ assert abs(sum(b2)) > 0.5
+
+
+def test_h26_bosh3_integrates_correctly():
+ """BoSh3 (3rd order) must integrate y'=-y accurately."""
+ f = bp.odeint(lambda y, t: -y, method='BoSh3', adaptive=True)
+ y_new, _ = f(jnp.array([1.0]), 0.0, dt=0.05)
+ assert np.allclose(np.asarray(y_new), math.exp(-0.05), atol=1e-4)
+
+
+# --- H-27: two-sided step-size controller can grow dt ----------------------- #
+
+def test_h27_step_controller_grows_dt_when_error_small():
+ """H-27: when the error is comfortably below tol the controller must be able
+ to *increase* dt (the buggy one-sided controller never grew dt)."""
+ f = bp.odeint(lambda y, t: -y, method='rkf45', adaptive=True)
+ _, dt_new = f(jnp.array([1.0]), 0.0, dt=0.01)
+ assert float(dt_new) > 0.01, 'controller failed to grow dt below tolerance'
+
+
+# --- exp_euler numerics (python-float input) -------------------------------- #
+
+def test_exp_euler_python_float_input():
+ """exp_euler on a linear ODE y'=-2y is exact: y(0.3)=e^{-0.6}."""
+ f = bp.odeint(lambda y, t: -2 * y, method='exp_euler', dt=0.3)
+ out = f(1.0, 0.0, dt=0.3)
+ assert np.allclose(float(out), math.exp(-0.6), atol=1e-5)
+
+
+def test_exp_euler_auto_python_float_input():
+ """exp_euler_auto must give the same exact result on the linear ODE."""
+ f = bp.odeint(lambda y, t: -2 * y, method='exp_euler_auto', dt=0.3)
+ out = f(1.0, 0.0, dt=0.3)
+ assert np.allclose(float(out), math.exp(-0.6), atol=1e-5)
+
+
+# --- C-13: SDE integrators raise IntegratorError, not NameError ------------- #
+
+def test_c13_sde_invalid_intg_type_raises_integrator_error():
+ """C-13: invalid intg_type must raise IntegratorError (errors import fixed),
+ not NameError: name 'errors' is not defined."""
+ with pytest.raises(IntegratorError):
+ bp.sdeint(lambda x, t: -x, lambda x, t: 0.1,
+ method='euler', intg_type='WRONG')
+
+
+def test_c13_sde_invalid_wiener_type_raises_integrator_error():
+ """C-13: the same errors import guards the wiener_type validation path."""
+ with pytest.raises(IntegratorError):
+ bp.sdeint(lambda x, t: -x, lambda x, t: 0.1,
+ method='euler', wiener_type='WRONG')
+
+
+def test_c13_heun_rejects_ito():
+ """C-13 (sde/normal.py): Heun only supports Stratonovich; an Ito request
+ must raise IntegratorError rather than NameError."""
+ with pytest.raises(IntegratorError):
+ bp.sdeint(lambda x, t: -x, lambda x, t: 0.1,
+ method='heun', intg_type='Ito')
+
+
+# --- C-08: CaputoEuler does not mis-scale the initial condition ------------- #
+
+def test_c08_caputo_euler_preserves_initial_condition(x64):
+ """C-08: for D^a x = 0 with x(0)=1 (exact x==1), CaputoEuler must keep x~1
+ across steps instead of returning dt^a/a."""
+ intg = bp.fde.CaputoEuler(lambda x, t: jnp.zeros_like(jnp.asarray(x)),
+ alpha=0.8, num_memory=10, inits=[1.0])
+ x = jnp.array([1.0])
+ t = 0.0
+ dt = 0.1
+ for _ in range(5):
+ x = intg(x, t, dt=dt)
+ t += dt
+ assert np.allclose(np.asarray(x), 1.0, atol=1e-6), \
+ f'CaputoEuler drifted from the initial condition: {np.asarray(x)}'
+
+
+# --- H-31: CaputoL1Schema.hists() iterates .items(), not bare dict ---------- #
+
+def test_h31_caputo_l1_hists_returns_without_valueerror(x64):
+ """H-31: after a step, .hists() (default numpy=True) must return a dict of
+ arrays and not raise ValueError from iterating dict keys instead of items."""
+ intg = bp.fde.CaputoL1Schema(lambda x, t: -x, alpha=0.9,
+ num_memory=10, inits=[1.0])
+ x = jnp.array([1.0])
+ x = intg(x, 0.0, dt=0.1)
+ hists = intg.hists() # must not raise
+ assert isinstance(hists, dict)
+ for v in hists.values():
+ assert isinstance(v, np.ndarray)
+ # per-variable accessor path
+ var = intg.variables[0]
+ one = intg.hists(var=var)
+ assert isinstance(one, np.ndarray)
+
+
+# --- H-30: GLShortMemory.reset uses the '_delay' key suffix ------------------ #
+
+def test_h30_glshortmemory_reset_works():
+ """H-30: reset must use key+'_delay' and not raise KeyError."""
+ g = bp.fde.GLShortMemory(lambda x, t: -x, alpha=0.6,
+ num_memory=8, inits=[1.0])
+ g.reset(inits=[2.0]) # must not raise KeyError
+ out = g(jnp.array([2.0]), 0.0, dt=0.1)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+# --- H-32: set_default_fdeint writes the FDE global, get reads it back ------- #
+
+def test_h32_set_get_default_fdeint_roundtrips(restore_default_fdeint):
+ """H-32: set_default_fdeint must actually change get_default_fdeint (the bug
+ wrote the wrong global, making it a no-op)."""
+ for method in bp.fde.get_supported_methods():
+ bp.fde.set_default_fdeint(method)
+ assert bp.fde.get_default_fdeint() == method
+
+
+def test_h32_set_default_fdeint_rejects_unknown(restore_default_fdeint):
+ with pytest.raises(ValueError):
+ bp.fde.set_default_fdeint('not-a-real-method')
+
+
+# --- L-13: JointEq raises a *diagnostic* DiffEqError on conflicting kwarg ---- #
+
+def test_l13_jointeq_conflicting_kwarg_raises_message():
+ """L-13: a keyword argument that reuses a state-variable name must raise a
+ DiffEqError carrying a non-empty diagnostic message."""
+ def dV(V, t, w):
+ return -V
+
+ def dw(w, t, V=1.0): # 'V' kwarg collides with state variable V
+ return -w
+
+ with pytest.raises(DiffEqError) as exc:
+ bp.JointEq(dV, dw)
+ assert str(exc.value).strip(), 'DiffEqError message must not be empty'
+ assert 'V' in str(exc.value)
+
+
+# =========================================================================== #
+# Coverage tests
+# =========================================================================== #
+
+def test_cov_odeint_all_methods_array_and_scalar():
+ """Exercise non-adaptive + adaptive ODE methods on array and scalar state."""
+ y0 = jnp.array([1.0])
+ for m in ['euler', 'rk2', 'rk4']:
+ out = bp.odeint(lambda y, t: -y, method=m)(y0, 0.0, dt=0.01)
+ assert np.allclose(np.asarray(out), math.exp(-0.01), atol=1e-3)
+
+ for m in ['rkf45', 'rkdp', 'rkf12', 'ck', 'bs', 'heun_euler', 'BoSh3']:
+ f = bp.odeint(lambda y, t: -y, method=m, adaptive=True)
+ y_new, dt_new = f(y0, 0.0, dt=0.01)
+ assert np.allclose(np.asarray(y_new), math.exp(-0.01), atol=1e-2)
+ assert float(dt_new) > 0.0
+
+
+def test_cov_adaptive_explicit_tol_and_var_type():
+ """Cover the adaptive path with an explicit tol and SCALAR var_type."""
+ f = bp.odeint(lambda y, t: -y, method='rkf45', adaptive=True,
+ tol=1e-3, var_type='scalar')
+ y_new, dt_new = f(1.0, 0.0, dt=0.01)
+ assert np.isfinite(float(y_new))
+ assert np.isfinite(float(dt_new))
+
+
+def test_cov_adaptive_show_code():
+ """show_code=True exercises the code-emission/print branch."""
+ f = bp.odeint(lambda y, t: -y, method='rkf45', adaptive=True, show_code=True)
+ out = f(jnp.array([1.0]), 0.0, dt=0.05)
+ assert len(out) == 2
+
+
+def test_cov_exp_euler_variants_array_state():
+ """Exercise exp_euler / exp_euler_auto / exp_auto on an array state."""
+ y0 = jnp.array([1.0, 2.0])
+ for m in ['exp_euler', 'exp_euler_auto', 'exp_auto', 'exponential_euler']:
+ out = bp.odeint(lambda y, t: -y, method=m)(y0, 0.0, dt=0.01)
+ assert np.allclose(np.asarray(out), y0 * math.exp(-0.01), atol=1e-3)
+
+
+def test_cov_sdeint_euler_milstein_heun_both_types():
+ """Cover euler/milstein (Ito + Stratonovich) and heun (Stratonovich)."""
+ bm.random.seed(1234)
+ g = lambda x, t: jnp.ones_like(x) * 0.1
+ for method in ['euler', 'milstein']:
+ for itype in ['Ito', 'Stratonovich']:
+ f = bp.sdeint(lambda x, t: -x, g, method=method, intg_type=itype)
+ out = f(jnp.array([1.0]), 0.0, dt=0.01)
+ assert np.all(np.isfinite(np.asarray(out)))
+ # Heun is Stratonovich-only
+ f = bp.sdeint(lambda x, t: -x, g, method='heun', intg_type='Stratonovich')
+ out = f(jnp.array([1.0]), 0.0, dt=0.01)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+def test_cov_integrator_runner_with_monitors():
+ """H-29 + coverage: IntegratorRunner over a small ODE; the step-index
+ monitor must report 0..N-1 (loop variable no longer clobbered)."""
+ bm.set_dt(0.1)
+ intg = bp.odeint(lambda V, t: -V, method='rk4')
+ runner = bp.IntegratorRunner(
+ intg,
+ monitors={'V': 'V', 'step': lambda sh: sh['i']},
+ inits={'V': 1.0},
+ dt=0.1,
+ progress_bar=False,
+ )
+ runner.run(0.5)
+ v = np.asarray(runner.mon['V']).ravel()
+ step = np.asarray(runner.mon['step']).ravel()
+ assert v.shape == (5,)
+ assert np.allclose(v, np.exp(-0.1 * np.arange(1, 6)), atol=1e-3)
+ # H-29: the step index is the real loop counter, not len(vars)-1
+ assert list(step) == [0, 1, 2, 3, 4]
+
+
+def test_cov_integrator_runner_seq_monitor():
+ """Cover the tuple/list monitor formatting branch of IntegratorRunner."""
+ bm.set_dt(0.1)
+ intg = bp.odeint(lambda V, t: -V, method='euler')
+ runner = bp.IntegratorRunner(intg, monitors=['V'], inits={'V': 1.0},
+ dt=0.1, progress_bar=False)
+ runner.run(0.3)
+ assert np.asarray(runner.mon['V']).shape == (3, 1)
+
+
+def test_cov_jointeq_construction_and_integration():
+ """Cover JointEq construction (incl. nested) and integration."""
+ a, b = 0.7, 0.8
+
+ def dV(V, t, w, Iext):
+ return V - V ** 3 / 3 - w + Iext
+
+ def dw(w, t, V):
+ return a * (b * V - w)
+
+ eq = bp.JointEq(dV, dw)
+ flat = [v for sub in eq.vars_in_eqs for v in sub]
+ assert set(flat) == {'V', 'w'}
+
+ intg = bp.odeint(eq, method='rk2')
+ V, w = intg(0.0, 0.0, 0.0, Iext=0.5, dt=0.01)
+ assert np.isfinite(float(V)) and np.isfinite(float(w))
+
+ # nested JointEq
+ def du(u, t, V):
+ return -u + V
+
+ eq2 = bp.JointEq(eq, du)
+ flat2 = [v for sub in eq2.vars_in_eqs for v in sub]
+ assert set(flat2) == {'V', 'w', 'u'}
+
+
+def test_cov_jointeq_rejects_non_callable():
+ """Cover the _check_eqs error branch."""
+ with pytest.raises(DiffEqError):
+ bp.JointEq(123)
+
+
+def test_cov_caputo_euler_integration(x64):
+ """Integrate a non-trivial Caputo equation a few steps for coverage."""
+ intg = bp.fde.CaputoEuler(lambda x, t: -x, alpha=0.9,
+ num_memory=20, inits=[1.0])
+ x = jnp.array([1.0])
+ t = 0.0
+ for _ in range(5):
+ x = intg(x, t, dt=0.05)
+ t += 0.05
+ assert np.all(np.isfinite(np.asarray(x)))
+ # decaying solution must stay below the initial value
+ assert float(x[0]) < 1.0
+
+
+def test_cov_caputo_l1_integration_and_reset(x64):
+ """Integrate CaputoL1Schema a few steps, hits hists(), then reset()."""
+ intg = bp.fde.CaputoL1Schema(lambda x, t: -x, alpha=0.7,
+ num_memory=20, inits=[1.0])
+ x = jnp.array([1.0])
+ t = 0.0
+ for _ in range(4):
+ x = intg(x, t, dt=0.05)
+ t += 0.05
+ assert np.all(np.isfinite(np.asarray(x)))
+ h = intg.hists()
+ assert isinstance(h, dict)
+ intg.reset(inits=[0.5])
+ assert np.allclose(np.asarray(intg.inits[intg.variables[0]]), 0.5)
+
+
+def test_cov_glshortmemory_integration_multivar():
+ """Integrate GLShortMemory on a coupled 2D system + binomial_coef access."""
+ def f(x, y, t):
+ return -x + 0.1 * y, -y - 0.1 * x
+
+ g = bp.fde.GLShortMemory(f, alpha=[0.95, 0.95], num_memory=16,
+ inits=[1.0, 0.5])
+ coef = g.binomial_coef
+ assert coef.shape[0] == 16
+ x = jnp.array([1.0])
+ y = jnp.array([0.5])
+ t = 0.0
+ for _ in range(5):
+ x, y = g(x, y, t, dt=0.02)
+ t += 0.02
+ assert np.all(np.isfinite(np.asarray(x)))
+ assert np.all(np.isfinite(np.asarray(y)))
+
+
+def test_cov_glshortmemory_via_fdeint():
+ """Cover the fdeint factory dispatch to the short-memory integrator."""
+ intg = bp.fdeint(alpha=0.8, num_memory=8, inits=[1.0],
+ method='short-memory', dt=0.05)(lambda x, t: -x)
+ out = intg(jnp.array([1.0]), 0.0, dt=0.05)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+def test_cov_caputo_euler_via_fdeint(restore_default_fdeint):
+ """Cover fdeint(method=None) using the default + the 'euler' dispatch."""
+ bp.fde.set_default_fdeint('euler')
+ intg = bp.fdeint(alpha=0.8, num_memory=8, inits=[1.0],
+ method=None, dt=0.05)(lambda x, t: -x)
+ out = intg(jnp.array([1.0]), 0.0, dt=0.05)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: SDE method/wiener-type/multivar branches (sde/normal.py).
+# --------------------------------------------------------------------------- #
+
+def test_cov_sde_milstein_variants():
+ """Cover milstein2 / milstein_grad_free for both integral types."""
+ bm.random.seed(21)
+ g = lambda x, t: jnp.ones_like(x) * 0.1
+ for method in ['milstein2', 'milstein_grad_free']:
+ for itype in ['Ito', 'Stratonovich']:
+ f = bp.sdeint(lambda x, t: -x, g, method=method, intg_type=itype)
+ out = f(jnp.array([1.0]), 0.0, dt=0.01)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+def test_cov_sde_exponential_euler():
+ """Cover the SDE ExponentialEuler integrator."""
+ bm.random.seed(22)
+ g = lambda x, t: jnp.ones_like(x) * 0.1
+ for method in ['exp_euler', 'exponential_euler']:
+ f = bp.sdeint(lambda x, t: -x, g, method=method)
+ out = f(jnp.array([1.0]), 0.0, dt=0.01)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+def test_cov_sde_multivariable():
+ """Cover the multi-variable drift/diffusion branches of the Euler integrator.
+
+ Milstein deliberately supports only a single variable at a time and raises
+ DiffEqError for multi-variable systems, so it is exercised separately.
+ """
+ bm.random.seed(23)
+
+ def f(x, y, t):
+ return -x, -y
+
+ def g(x, y, t):
+ return 0.1 * jnp.ones_like(x), 0.1 * jnp.ones_like(y)
+
+ for itype in ['Ito', 'Stratonovich']:
+ intg = bp.sdeint(f, g, method='euler', intg_type=itype)
+ out = intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01)
+ assert len(out) == 2
+ assert np.all(np.isfinite(np.asarray(out[0])))
+ assert np.all(np.isfinite(np.asarray(out[1])))
+
+ # Milstein rejects multi-variable systems (raised when building/calling).
+ with pytest.raises(DiffEqError):
+ intg = bp.sdeint(f, g, method='milstein')
+ intg(jnp.array([1.0]), jnp.array([2.0]), 0.0, dt=0.01)
+
+
+def test_cov_sde_vector_wiener():
+ """Cover the VECTOR_WIENER (Ito) summation branch of the Euler integrator.
+
+ Only the Ito branch is exercised: the Stratonovich vector-wiener path has a
+ latent broadcasting bug (g(Y) of shape (3, 2) is added to a state of shape
+ (3,) without summing over the noise axis) and is out of scope for this audit.
+ """
+ bm.random.seed(24)
+
+ def fv(x, t):
+ return -x
+
+ def gv(x, t):
+ return 0.1 * jnp.ones((3, 2)) # 2 independent noise sources
+
+ intg = bp.sdeint(fv, gv, method='euler', wiener_type='vector',
+ intg_type='Ito')
+ out = intg(jnp.ones(3), 0.0, dt=0.01)
+ assert np.asarray(out).shape == (3,)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+def test_cov_sde_vector_wiener_requires_nd_diffusion():
+ """Cover the vector-wiener scalar-diffusion guard (ValueError path)."""
+ bm.random.seed(25)
+ intg = bp.sdeint(lambda x, t: -x, lambda x, t: jnp.float32(0.1),
+ method='euler', wiener_type='vector')
+ with pytest.raises(ValueError):
+ intg(jnp.array([1.0]), 0.0, dt=0.01)
+
+
+def test_cov_sde_drift_must_be_tensor():
+ """Cover the single-variable drift-not-a-tensor ValueError branch."""
+ intg = bp.sdeint(lambda x, t: -1.0, lambda x, t: jnp.ones_like(x) * 0.1,
+ method='euler')
+ with pytest.raises(ValueError):
+ intg(jnp.array([1.0]), 0.0, dt=0.01)
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: IntegratorRunner.run options (runner.py).
+# --------------------------------------------------------------------------- #
+
+def test_cov_runner_run_options():
+ """Cover run(start_t=, eval_time=True) and the .mon['ts'] time axis."""
+ bm.set_dt(0.1)
+ intg = bp.odeint(lambda V, t: -V, method='rk4')
+ runner = bp.IntegratorRunner(intg, monitors={'V': 'V'}, inits={'V': 1.0},
+ dt=0.1, progress_bar=False)
+ runner.run(0.3, start_t=0.0, eval_time=True)
+ assert np.asarray(runner.mon['V']).shape == (3, 1)
+ # second run continues from the previous index (covers idx bookkeeping)
+ runner.run(0.2)
+ assert np.asarray(runner.mon['V']).shape == (2, 1)
+
+
+def test_cov_runner_rejects_non_integrator():
+ """Cover the target type check in IntegratorRunner.__init__."""
+ with pytest.raises(TypeError):
+ bp.IntegratorRunner(lambda V, t: -V, monitors={'V': 'V'},
+ inits={'V': 1.0}, dt=0.1)
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: JointEq error branches (joint_eq.py).
+# --------------------------------------------------------------------------- #
+
+def test_cov_jointeq_duplicate_state_variable():
+ """Cover the duplicate-state-variable DiffEqError branch."""
+ def dV1(V, t):
+ return -V
+
+ def dV2(V, t): # 'V' used as a state variable twice
+ return -2 * V
+
+ with pytest.raises(DiffEqError):
+ bp.JointEq(dV1, dV2)
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: exponential ODE auto-diff path (exponential.py).
+# --------------------------------------------------------------------------- #
+
+def test_cov_exp_euler_nonlinear():
+ """Cover the auto-linearization path on a non-linear right-hand side."""
+ # logistic-like ODE; just needs to run and stay finite
+ intg = bp.odeint(lambda y, t: y * (1.0 - y), method='exp_euler_auto')
+ y = jnp.array([0.1])
+ t = 0.0
+ for _ in range(5):
+ y = intg(y, t, dt=0.05)
+ t += 0.05
+ assert np.all(np.isfinite(np.asarray(y)))
+ assert float(y[0]) > 0.1 # logistic growth
+
+
+def test_cov_exp_euler_show_code():
+ """Cover the show_code emission branch of ExponentialEuler."""
+ intg = bp.odeint(lambda y, t: -y, method='exp_euler', show_code=True)
+ out = intg(jnp.array([1.0]), 0.0, dt=0.05)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: JointEq argument parsing edge cases (joint_eq.py).
+# --------------------------------------------------------------------------- #
+
+def test_cov_jointeq_missing_time_variable():
+ """Cover the 'Do not find time variable "t"' ValueError branch."""
+ with pytest.raises(ValueError):
+ bp.JointEq(lambda V, w: -V) # no 't' parameter
+
+
+def test_cov_jointeq_rejects_var_positional():
+ """Cover the VAR_POSITIONAL (*args) rejection branch."""
+ with pytest.raises(DiffEqError):
+ bp.JointEq(lambda V, t, *extra: -V)
+
+
+def test_cov_jointeq_rejects_var_keyword():
+ """Cover the VAR_KEYWORD (**kwargs) rejection branch."""
+ with pytest.raises(DiffEqError):
+ bp.JointEq(lambda V, t, **extra: -V)
+
+
+def test_cov_jointeq_conflicting_kwarg_defaults():
+ """Cover the 'two different default value' DiffEqError branch."""
+ def dV(V, t, a=1.0):
+ return -V + a
+
+ def dw(w, t, a=2.0): # same kwarg name, different default
+ return -w + a
+
+ with pytest.raises(DiffEqError):
+ bp.JointEq(dV, dw)
+
+
+def test_cov_jointeq_nested_list_and_shared_kwarg():
+ """Cover the _check_eqs list/tuple recursion + a shared (consistent) kwarg
+ default, then call with the kwarg passed both positionally and by keyword."""
+ def dV(V, t, w, gain=0.5):
+ return -V + gain * w
+
+ def dw(w, t, V, gain=0.5): # same default -> allowed, shared kwarg
+ return -w + gain * V
+
+ eq = bp.JointEq([dV, dw]) # list form exercises the recursion branch
+ assert 'gain' in eq.kwarg_keys
+ intg = bp.odeint(eq, method='euler')
+ # call providing gain by keyword
+ out = intg(1.0, 0.0, gain=0.3, dt=0.01)
+ assert all(np.isfinite(float(o)) for o in out)
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: FDE reset / check paths (Caputo.py, generic.py).
+# --------------------------------------------------------------------------- #
+
+def test_cov_caputo_euler_reset(x64):
+ """Cover CaputoEuler.reset()."""
+ intg = bp.fde.CaputoEuler(lambda x, t: -x, alpha=0.8,
+ num_memory=10, inits=[1.0])
+ x = jnp.array([1.0])
+ x = intg(x, 0.0, dt=0.05)
+ intg.reset(inits=[3.0])
+ assert np.allclose(np.asarray(intg.inits[intg.variables[0]]), 3.0)
+
+
+def test_cov_caputo_euler_requires_tensor_drift(x64):
+ """Cover the single-variable 'Derivative values must be a tensor' branch."""
+ intg = bp.fde.CaputoEuler(lambda x, t: 0.0, alpha=0.8,
+ num_memory=10, inits=[1.0])
+ with pytest.raises(ValueError):
+ intg(jnp.array([1.0]), 0.0, dt=0.05)
+
+
+def test_cov_fde_register_duplicate_rejected():
+ """Cover register_fde_integrator duplicate-name guard."""
+ from brainpy.integrators.fde.generic import register_fde_integrator
+ from brainpy.integrators.fde.GL import GLShortMemory
+ with pytest.raises(ValueError):
+ register_fde_integrator('euler', GLShortMemory)
+
+
+def test_cov_fdeint_unknown_method_rejected():
+ """Cover the fdeint unknown-method ValueError branch."""
+ with pytest.raises(ValueError):
+ bp.fdeint(alpha=0.8, num_memory=4, inits=[1.0],
+ method='nope', dt=0.05)(lambda x, t: -x)
+
+
+def test_cov_set_default_fdeint_type_check(restore_default_fdeint):
+ """Cover the non-string set_default_fdeint guard."""
+ with pytest.raises(ValueError):
+ bp.fde.set_default_fdeint(123)
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: exponential ODE branches (exponential.py).
+# --------------------------------------------------------------------------- #
+
+def test_cov_exp_euler_with_jointeq():
+ """Cover the JointEq build path of ExponentialEuler (_build_integrator)."""
+ def dV(V, t, w):
+ return -V - w
+
+ def dw(w, t, V):
+ return -w + 0.1 * V
+
+ eq = bp.JointEq(dV, dw)
+ intg = bp.odeint(eq, method='exp_euler')
+ out = intg(jnp.array([1.0]), jnp.array([0.5]), 0.0, dt=0.01)
+ assert len(out) == 2
+ assert all(np.all(np.isfinite(np.asarray(o))) for o in out)
+
+
+def test_cov_exp_euler_rejects_integer_input():
+ """Cover the float-dtype guard of the Exponential Euler integral."""
+ intg = bp.odeint(lambda y, t: -y, method='exp_euler')
+ with pytest.raises(ValueError):
+ intg(jnp.array([1, 2, 3]), 0.0, dt=0.01) # integer dtype
+
+
+def test_cov_exp_euler_system_var_not_implemented():
+ """Cover the SYSTEM_VAR NotImplementedError branch."""
+ with pytest.raises(NotImplementedError):
+ bp.odeint(lambda y, t: -y, method='exp_euler', var_type='system')
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: Milstein vector-wiener branches (sde/normal.py).
+# --------------------------------------------------------------------------- #
+
+def test_cov_milstein_vector_wiener_latent_bug():
+ """Document a latent (out-of-scope) bug: the Milstein integrators do not
+ correctly broadcast the diffusion-gradient term for VECTOR_WIENER noise, so
+ a vector-wiener Milstein step currently raises a broadcasting ValueError.
+
+ This still exercises the vector-wiener branch up to the failure point; if the
+ library is ever fixed, this test should assert a finite result instead.
+ """
+ bm.random.seed(31)
+
+ def fv(x, t):
+ return -x
+
+ def gv(x, t):
+ return 0.1 * jnp.ones((3, 2))
+
+ for method in ['milstein', 'milstein2']:
+ intg = bp.sdeint(fv, gv, method=method, wiener_type='vector',
+ intg_type='Ito')
+ with pytest.raises((ValueError, TypeError)):
+ intg(jnp.ones(3), 0.0, dt=0.01)
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: IntegratorRunner init / dyn_args branches (runner.py).
+# --------------------------------------------------------------------------- #
+
+def test_cov_runner_inits_as_sequence():
+ """Cover the list/sequence form of the ``inits`` argument."""
+ bm.set_dt(0.1)
+ intg = bp.odeint(lambda V, t: -V, method='euler')
+ runner = bp.IntegratorRunner(intg, monitors={'V': 'V'}, inits=[1.0],
+ dt=0.1, progress_bar=False)
+ runner.run(0.3)
+ assert np.asarray(runner.mon['V']).shape == (3, 1)
+
+
+def test_cov_runner_dyn_args():
+ """Cover the dyn_args time-varying input path of IntegratorRunner.run."""
+ bm.set_dt(0.1)
+ intg = bp.odeint(lambda V, t, Iext: -V + Iext, method='euler')
+ runner = bp.IntegratorRunner(intg, monitors={'V': 'V'}, inits={'V': 0.0},
+ dt=0.1, progress_bar=False)
+ # 3 steps -> dyn_args first dimension must be 3
+ runner.run(0.3, dyn_args={'Iext': jnp.ones(3)})
+ assert np.asarray(runner.mon['V']).shape == (3, 1)
+
+
+def test_cov_runner_dyn_args_shape_mismatch():
+ """Cover the dyn_args duration-mismatch ValueError branch."""
+ bm.set_dt(0.1)
+ intg = bp.odeint(lambda V, t, Iext: -V + Iext, method='euler')
+ runner = bp.IntegratorRunner(intg, monitors={'V': 'V'}, inits={'V': 0.0},
+ dt=0.1, progress_bar=False)
+ with pytest.raises(ValueError):
+ runner.run(0.3, dyn_args={'Iext': jnp.ones(99)})
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: SDE ExponentialEuler branches (sde/normal.py).
+# --------------------------------------------------------------------------- #
+
+def test_cov_sde_exp_euler_vector_wiener():
+ """Cover the VECTOR_WIENER branch of the SDE ExponentialEuler."""
+ bm.random.seed(41)
+ intg = bp.sdeint(lambda x, t: -x, lambda x, t: 0.1 * jnp.ones((3, 2)),
+ method='exp_euler', wiener_type='vector')
+ out = intg(jnp.ones(3), 0.0, dt=0.01)
+ assert np.asarray(out).shape == (3,)
+ assert np.all(np.isfinite(np.asarray(out)))
+
+
+def test_cov_sde_exp_euler_rejects_stratonovich():
+ """Cover the SDE ExponentialEuler Stratonovich NotImplementedError branch."""
+ with pytest.raises(NotImplementedError):
+ bp.sdeint(lambda x, t: -x, lambda x, t: jnp.ones((1,)) * 0.1,
+ method='exp_euler', intg_type='Stratonovich')
+
+
+def test_cov_sde_exp_euler_with_jointeq():
+ """Cover the JointEq build + multi-variable diffusion path of SDE exp_euler."""
+ bm.random.seed(42)
+
+ def dV(V, t, w):
+ return -V
+
+ def dw(w, t, V):
+ return -w
+
+ def gV(V, t, w):
+ return jnp.ones_like(V) * 0.1
+
+ def gw(w, t, V):
+ return jnp.ones_like(w) * 0.1
+
+ intg = bp.sdeint(bp.JointEq(dV, dw), bp.JointEq(gV, gw), method='exp_euler')
+ out = intg(jnp.array([1.0]), jnp.array([0.5]), 0.0, dt=0.01)
+ assert len(out) == 2
+ assert all(np.all(np.isfinite(np.asarray(o))) for o in out)
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: JointEq.__call__ with a tuple-returning sub-equation.
+# --------------------------------------------------------------------------- #
+
+def test_cov_jointeq_call_with_nested_tuple_result():
+ """Cover JointEq.__call__ extending results from a tuple-returning sub-eq
+ (a nested JointEq returns a list/tuple)."""
+ def dV(V, t, w):
+ return -V + w
+
+ def dw(w, t, V):
+ return -w + 0.1 * V
+
+ inner = bp.JointEq(dV, dw) # returns a list when called
+
+ def du(u, t, V):
+ return -u + V
+
+ outer = bp.JointEq(inner, du)
+ res = outer(1.0, 0.5, 0.2, 0.0) # V, w, u, t
+ assert len(res) == 3
+ assert all(np.isfinite(float(r)) for r in res)
+
+
+# --------------------------------------------------------------------------- #
+# Extra coverage: multi-variable FDE + IntegratorRunner branches.
+# --------------------------------------------------------------------------- #
+
+def test_cov_caputo_euler_multivariable(x64):
+ """Cover the multi-variable drift branch of CaputoEuler."""
+ def f(x, y, t):
+ return -x, -y
+
+ intg = bp.fde.CaputoEuler(f, alpha=[0.8, 0.9], num_memory=10,
+ inits=[1.0, 2.0])
+ x = jnp.array([1.0])
+ y = jnp.array([2.0])
+ t = 0.0
+ for _ in range(3):
+ x, y = intg(x, y, t, dt=0.05)
+ t += 0.05
+ assert np.all(np.isfinite(np.asarray(x)))
+ assert np.all(np.isfinite(np.asarray(y)))
+ # decaying solution
+ assert float(x[0]) < 1.0 and float(y[0]) < 2.0
+
+
+def test_cov_caputo_l1_multivariable(x64):
+ """Cover the multi-variable branch + per-variable hists() of CaputoL1Schema."""
+ def f(x, y, t):
+ return -x, -y
+
+ intg = bp.fde.CaputoL1Schema(f, alpha=[0.8, 0.9], num_memory=10,
+ inits=[1.0, 2.0])
+ x = jnp.array([1.0])
+ y = jnp.array([2.0])
+ t = 0.0
+ for _ in range(3):
+ x, y = intg(x, y, t, dt=0.05)
+ t += 0.05
+ hists = intg.hists()
+ assert set(hists.keys()) == set(intg.variables)
+ one = intg.hists(var=intg.variables[1])
+ assert isinstance(one, np.ndarray)
+
+
+def test_cov_runner_multivariable_with_progress_bar():
+ """Cover the multi-variable update branch and the progress-bar callback of
+ IntegratorRunner."""
+ bm.set_dt(0.1)
+
+ def dV(V, t, w):
+ return -V + w
+
+ def dw(w, t, V):
+ return -w
+
+ intg = bp.odeint(bp.JointEq(dV, dw), method='rk2')
+ runner = bp.IntegratorRunner(
+ intg,
+ monitors={'V': 'V', 'w': 'w'},
+ inits={'V': 1.0, 'w': 0.5},
+ dt=0.1,
+ progress_bar=True,
+ )
+ runner.run(0.3)
+ assert np.asarray(runner.mon['V']).shape == (3, 1)
+ assert np.asarray(runner.mon['w']).shape == (3, 1)
+
+
+def test_cov_jointeq_rejects_keyword_only():
+ """Cover the KEYWORD_ONLY parameter rejection branch of _get_args."""
+ def dV(V, t, *, w):
+ return -V
+
+ with pytest.raises(DiffEqError):
+ bp.JointEq(dV, lambda w, t: -w)
+
+
+def test_cov_jointeq_rejects_positional_only():
+ """Cover the POSITIONAL_ONLY parameter rejection branch of _get_args."""
+ def dV(V, t, w, /):
+ return -V
+
+ with pytest.raises(DiffEqError):
+ bp.JointEq(dV, lambda w, t: -w)
diff --git a/tests/audit/test_math_compat_fixes.py b/tests/audit/test_math_compat_fixes.py
new file mode 100644
index 000000000..66bde6202
--- /dev/null
+++ b/tests/audit/test_math_compat_fixes.py
@@ -0,0 +1,1017 @@
+# -*- coding: utf-8 -*-
+"""Regression + coverage tests for the ``brainpy.math`` compatibility layer.
+
+This module accompanies the audit recorded in ``docs/issues-found-20260618.md``.
+It pins the behaviours fixed by the audit and broadly exercises the numpy /
+pytorch / tensorflow compatibility shims, the activation functions, the misc.
+``others`` helpers, the ``_utils`` wrapper factory and the ``einops`` port.
+
+Audit findings exercised here (see the doc for full context):
+
+* C-11 (``compat_tensorflow.py``) -- ``reduce_logsumexp`` must be numerically
+ stable (delegated to ``jax.scipy.special.logsumexp``).
+* H-16 (``activations.py``) -- ``softmin`` must subtract the max so it
+ stays finite for large inputs.
+* H-14 (``_utils.py`` + pytorch) -- ``out=`` wrapped funcs must *return* the
+ ``out`` Array instead of ``None``.
+* H-15 (``others.py``) -- ``remove_diag`` must trace cleanly under
+ ``jit``/``vmap`` (static off-diagonal gather, handles non-square).
+* H-13 (``compat_numpy.py``) -- ``asfarray`` must coerce integer input to
+ a floating dtype.
+* M-11/M-12 (``compat_numpy.py``) -- ``empty`` uses ``jnp.empty`` and
+ ``fill_diagonal(inplace=False)`` returns a brainpy ``Array``.
+* (``compat_pytorch.py``) -- ``arcsinh``/``arctanh`` exist & correct,
+ no duplicate ``arcsin`` clobbering.
+* (``einops.py``) -- module still imports after the dead
+ ``_optimize_transformation`` helper was removed.
+"""
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+import brainpy.math as bm
+from brainpy.math import (
+ compat_numpy as cn,
+ compat_pytorch as cpt,
+ compat_tensorflow as ctf,
+ activations as act,
+ others as bo,
+ _utils as butils,
+ einops as bein,
+)
+from brainpy.math.ndarray import Array
+
+
+# ---------------------------------------------------------------------------
+# helpers
+# ---------------------------------------------------------------------------
+
+def _j(x):
+ """Return the underlying jax array for assertions."""
+ return bm.as_jax(x)
+
+
+def _finite(x):
+ return bool(jnp.all(jnp.isfinite(_j(x))))
+
+
+# ===========================================================================
+# 1. Regression tests (audit-specific behaviours)
+# ===========================================================================
+
+# --- C-11: reduce_logsumexp numerical stability -----------------------------
+
+def test_reduce_logsumexp_stable_large_inputs():
+ """C-11: log(sum(exp(.))) must not overflow for large inputs."""
+ r = ctf.reduce_logsumexp(bm.asarray([1000., 1000., 1000.]))
+ assert _finite(r)
+ # logsumexp([1000]*3) == 1000 + log(3) == 1001.0986...
+ assert float(_j(r)) == pytest.approx(1000.0 + np.log(3.0), abs=1e-2)
+
+
+def test_reduce_logsumexp_matches_reference_small():
+ x = bm.asarray([0.5, -1.0, 2.0, 3.5])
+ r = ctf.reduce_logsumexp(x)
+ ref = jax.scipy.special.logsumexp(_j(x))
+ assert float(_j(r)) == pytest.approx(float(ref), rel=1e-6)
+
+
+def test_reduce_logsumexp_axis_keepdims():
+ x = bm.asarray([[1., 2., 3.], [4., 5., 6.]])
+ r = ctf.reduce_logsumexp(x, axis=1, keepdims=True)
+ assert _j(r).shape == (2, 1)
+ assert _finite(r)
+
+
+# --- H-16: softmin must not produce NaN for large inputs --------------------
+
+def test_softmin_finite_for_large_inputs():
+ """H-16: softmin lacked max-subtraction -> NaN for large inputs."""
+ r = act.softmin(bm.asarray([1000., 1001., 1002.]))
+ assert _finite(r)
+ # softmin == softmax(-x); the smallest input gets the largest weight.
+ expected = np.array([0.66524096, 0.24472847, 0.09003057])
+ np.testing.assert_allclose(np.asarray(_j(r)), expected, atol=1e-4)
+ assert float(jnp.sum(_j(r))) == pytest.approx(1.0, abs=1e-6)
+
+
+def test_softmin_matches_softmax_of_negative():
+ x = bm.asarray([0.3, -1.2, 2.5, 0.0])
+ np.testing.assert_allclose(
+ np.asarray(_j(act.softmin(x))),
+ np.asarray(_j(act.softmax(-x))),
+ atol=1e-6,
+ )
+
+
+# --- H-14: out= wrapped funcs must RETURN the out Array ---------------------
+
+def test_numpy_sum_out_returns_array():
+ """H-14: numpy-style ``out=`` should return the (filled) Array, not None."""
+ out = bm.asarray(0.)
+ r = bm.sum(bm.asarray([1., 2., 3.]), out=out)
+ assert r is not None
+ assert isinstance(r, Array)
+ assert r is out
+ assert float(_j(r)) == pytest.approx(6.0)
+ assert float(_j(out)) == pytest.approx(6.0)
+
+
+def test_numpy_out_must_be_brainpy_array():
+ with pytest.raises(TypeError):
+ bm.sum(bm.asarray([1., 2., 3.]), out=jnp.array(0.))
+
+
+def test_pytorch_add_out_returns_out():
+ """H-14 (pytorch compat): ``add(..., out=...)`` returns ``out``."""
+ out = bm.zeros((3,))
+ r = cpt.add(bm.asarray([1., 2., 3.]), bm.asarray([1., 1., 1.]), out=out)
+ assert r is out
+ np.testing.assert_allclose(np.asarray(_j(r)), [2., 3., 4.])
+
+
+def test_pytorch_out_must_be_brainpy_array():
+ with pytest.raises(TypeError):
+ cpt.abs(bm.asarray([-1.]), out=jnp.array(0.))
+
+
+# --- H-15: remove_diag must trace cleanly under jit/vmap --------------------
+
+def test_remove_diag_jit_matches_eager_square():
+ """H-15: remove_diag must work under jit (static off-diag gather)."""
+ x = jnp.arange(9.).reshape(3, 3)
+ eager = _j(bo.remove_diag(x))
+ jitted = jax.jit(bo.remove_diag)(x)
+ assert eager.shape == (3, 2)
+ np.testing.assert_allclose(np.asarray(eager), np.asarray(jitted))
+ # row 0 has its diagonal (0) removed -> [1, 2]
+ np.testing.assert_allclose(np.asarray(eager[0]), [1., 2.])
+
+
+def test_remove_diag_vmap():
+ x = jnp.arange(2 * 9.).reshape(2, 3, 3)
+ out = jax.vmap(bo.remove_diag)(x)
+ assert out.shape == (2, 3, 2)
+ np.testing.assert_allclose(np.asarray(out[0]), np.asarray(_j(bo.remove_diag(x[0]))))
+
+
+def test_remove_diag_non_square():
+ x = jnp.arange(12.).reshape(3, 4)
+ out = _j(bo.remove_diag(x))
+ assert out.shape == (3, 3)
+ # row 0 drops col 0; rows keep all but the diagonal element
+ np.testing.assert_allclose(np.asarray(out[0]), [1., 2., 3.])
+ np.testing.assert_allclose(np.asarray(out[1]), [4., 6., 7.])
+
+
+def test_remove_diag_rejects_non_2d():
+ with pytest.raises(ValueError):
+ bo.remove_diag(jnp.arange(3.))
+
+
+# --- H-13: asfarray coerces integers to a floating dtype --------------------
+
+def test_asfarray_integer_input_becomes_float():
+ """H-13: asfarray(int) used to no-op; must yield a floating dtype."""
+ r = cn.asfarray([1, 2, 3])
+ assert jnp.issubdtype(r.dtype, jnp.floating)
+ np.testing.assert_allclose(np.asarray(_j(r)), [1., 2., 3.])
+
+
+def test_asfarray_preserves_floating_dtype():
+ r = cn.asfarray(jnp.array([1., 2.], dtype=jnp.float32))
+ assert r.dtype == jnp.float32
+
+
+# --- M-11 / M-12: empty + fill_diagonal -------------------------------------
+
+def test_empty_shape_and_type():
+ """M-11: empty must produce the right shape/dtype as a brainpy Array."""
+ e = bm.empty((2, 3))
+ assert isinstance(e, Array)
+ assert e.shape == (2, 3)
+ assert jnp.issubdtype(e.dtype, jnp.floating)
+
+
+def test_empty_like():
+ a = bm.asarray(jnp.ones((4,), dtype=jnp.int32))
+ e = cn.empty_like(a)
+ assert e.shape == (4,)
+ assert e.dtype == jnp.int32
+
+
+def test_fill_diagonal_not_inplace_returns_array():
+ """M-12: fill_diagonal(inplace=False) must return a brainpy Array."""
+ x = bm.asarray(jnp.ones((3, 3)))
+ r = cn.fill_diagonal(x, 5., inplace=False)
+ assert isinstance(r, Array)
+ np.testing.assert_allclose(np.diag(np.asarray(_j(r))), [5., 5., 5.])
+ # original unchanged
+ np.testing.assert_allclose(np.diag(np.asarray(_j(x))), [1., 1., 1.])
+
+
+def test_fill_diagonal_inplace_updates_array():
+ x = bm.asarray(jnp.ones((3, 3)))
+ out = cn.fill_diagonal(x, 7., inplace=True)
+ assert out is None # in-place returns nothing
+ np.testing.assert_allclose(np.diag(np.asarray(_j(x))), [7., 7., 7.])
+
+
+def test_fill_diagonal_errors():
+ with pytest.raises(ValueError):
+ cn.fill_diagonal(bm.asarray(jnp.arange(3)), 1.) # ndim < 2
+ with pytest.raises(ValueError):
+ cn.fill_diagonal(jnp.ones((3, 3)), 1.) # inplace on non-Array
+
+
+# --- pytorch arcsinh / arctanh exist & correct, no dup arcsin ---------------
+
+def test_pytorch_arcsinh_arctanh():
+ np.testing.assert_allclose(
+ np.asarray(_j(cpt.arcsinh(bm.asarray([0., 1.])))),
+ np.arcsinh([0., 1.]), atol=1e-6,
+ )
+ np.testing.assert_allclose(
+ np.asarray(_j(cpt.arctanh(bm.asarray([0., 0.5])))),
+ np.arctanh([0., 0.5]), atol=1e-6,
+ )
+ # aliases point at their canonical implementations
+ assert cpt.arcsinh is cpt.asinh
+ assert cpt.arctanh is cpt.atanh
+
+
+def test_pytorch_arcsin_is_asin_not_arcsinh():
+ """No duplicate ``arcsin`` should shadow asin with arcsinh."""
+ assert cpt.arcsin is cpt.asin
+ np.testing.assert_allclose(
+ np.asarray(_j(cpt.arcsin(bm.asarray([0., 0.5])))),
+ np.arcsin([0., 0.5]), atol=1e-6,
+ )
+
+
+def test_numpy_arcsinh_arctanh_present():
+ np.testing.assert_allclose(
+ np.asarray(_j(cn.arcsinh(bm.asarray([0., 1.])))), np.arcsinh([0., 1.]), atol=1e-6)
+ np.testing.assert_allclose(
+ np.asarray(_j(cn.arctanh(bm.asarray([0., 0.5])))), np.arctanh([0., 0.5]), atol=1e-6)
+
+
+# --- einops module still imports (dead _optimize_transformation removed) ----
+
+def test_einops_module_imports():
+ import brainpy.math.einops as eio
+ assert hasattr(eio, 'ein_rearrange')
+ assert hasattr(eio, 'ein_reduce')
+ assert hasattr(eio, 'ein_repeat')
+ assert hasattr(eio, 'ein_shape')
+ # the dead helper flagged by the audit must be gone
+ assert not hasattr(eio, '_optimize_transformation')
+
+
+# ===========================================================================
+# 2. Coverage tests
+# ===========================================================================
+
+# --- compat_numpy: creation funcs -------------------------------------------
+
+def test_compat_numpy_creation():
+ assert cn.zeros((2, 2)).shape == (2, 2)
+ assert cn.ones((3,)).shape == (3,)
+ assert cn.empty((2,)).shape == (2,)
+ assert float(_j(cn.full((2,), 4.))[0]) == 4.
+ assert cn.eye(3).shape == (3, 3)
+ assert cn.identity(3).shape == (3, 3)
+ assert cn.arange(5).shape == (5,)
+ assert cn.linspace(0., 1., 5).shape == (5,)
+ assert cn.logspace(0., 1., 5).shape == (5,)
+ a = bm.asarray(jnp.arange(4.))
+ assert cn.zeros_like(a).shape == (4,)
+ assert cn.ones_like(a).shape == (4,)
+ assert cn.empty_like(a).shape == (4,)
+ assert cn.full_like(a, 2.).shape == (4,)
+
+
+def test_compat_numpy_linspace_retstep():
+ r, step = cn.linspace(0., 1., 5, retstep=True)
+ assert isinstance(r, Array)
+ assert float(step) == pytest.approx(0.25)
+
+
+def test_compat_numpy_array_and_asarray_with_arrays():
+ res = cn.array([bm.asarray([1., 2.]), bm.asarray([3., 4.])])
+ assert res.shape == (2, 2)
+ res2 = cn.asarray([bm.asarray([1., 2.]), bm.asarray([3., 4.])])
+ assert res2.shape == (2, 2)
+ assert cn.asanyarray([1, 2, 3]).shape == (3,)
+ assert cn.ascontiguousarray([1, 2, 3]).shape == (3,)
+
+
+def test_compat_numpy_diag_tri_family():
+ m = bm.asarray(jnp.arange(9.).reshape(3, 3))
+ assert cn.diag(m).shape == (3,)
+ assert cn.tril(m).shape == (3, 3)
+ assert cn.triu(m).shape == (3, 3)
+ assert cn.tri(3).shape == (3, 3)
+ assert cn.diagonal(m).shape == (3,)
+ assert cn.diagflat(bm.asarray([1., 2., 3.])).shape == (3, 3)
+
+
+def test_compat_numpy_ufuncs():
+ a = bm.asarray([1., 2., 3.])
+ b = bm.asarray([4., 5., 6.])
+ funcs_binary = [cn.add, cn.subtract, cn.multiply, cn.divide, cn.power,
+ cn.true_divide, cn.maximum, cn.minimum, cn.fmax, cn.fmin,
+ cn.hypot, cn.logaddexp, cn.logaddexp2, cn.copysign,
+ cn.nextafter, cn.remainder, cn.mod, cn.fmod, cn.float_power]
+ for f in funcs_binary:
+ assert _j(f(a, b)).shape == (3,)
+ funcs_unary = [cn.negative, cn.positive, cn.reciprocal, cn.abs, cn.absolute,
+ cn.fabs, cn.exp, cn.exp2, cn.expm1, cn.log1p, cn.sqrt, cn.cbrt,
+ cn.square, cn.sign, cn.sin, cn.cos, cn.tan, cn.sinh, cn.cosh,
+ cn.tanh, cn.arcsin, cn.arccos, cn.arctan, cn.arcsinh, cn.arctanh,
+ cn.sinc, cn.deg2rad, cn.rad2deg, cn.degrees, cn.radians,
+ cn.round, cn.rint, cn.floor, cn.ceil, cn.trunc, cn.isfinite,
+ cn.isinf, cn.isnan, cn.signbit, cn.conj, cn.conjugate, cn.real,
+ cn.imag, cn.angle, cn.nan_to_num]
+ pos = bm.asarray([0.1, 0.2, 0.3])
+ for f in funcs_unary:
+ assert _j(f(pos)).shape == (3,)
+ # log family wants strictly positive
+ for f in [cn.log, cn.log10, cn.log2]:
+ assert _j(f(a)).shape == (3,)
+ # gcd/lcm want integers
+ ia, ib = bm.asarray([4, 6, 8]), bm.asarray([6, 9, 12])
+ assert _j(cn.gcd(ia, ib)).shape == (3,)
+ assert _j(cn.lcm(ia, ib)).shape == (3,)
+ assert _j(cn.arctan2(a, b)).shape == (3,)
+ assert _j(cn.heaviside(a, b)).shape == (3,)
+ assert len(cn.frexp(a)) == 2
+ assert len(cn.modf(a)) == 2
+
+
+def test_compat_numpy_reductions():
+ a = bm.asarray([[1., 2., 3.], [4., 5., 6.]])
+ assert float(_j(cn.sum(a))) == 21.
+ assert float(_j(cn.prod(bm.asarray([1., 2., 3., 4.])))) == 24.
+ assert float(_j(cn.mean(a))) == pytest.approx(3.5)
+ assert float(_j(cn.max(a))) == 6.
+ assert float(_j(cn.min(a))) == 1.
+ assert float(_j(cn.amax(a))) == 6.
+ assert float(_j(cn.amin(a))) == 1.
+ assert _j(cn.std(a)).shape == ()
+ assert _j(cn.var(a)).shape == ()
+ assert _j(cn.median(a)).shape == ()
+ assert _j(cn.average(a)).shape == ()
+ assert _j(cn.cumsum(a)).shape == (6,)
+ assert _j(cn.cumprod(a)).shape == (6,)
+ assert _j(cn.nansum(a)).shape == ()
+ assert _j(cn.nanprod(a)).shape == ()
+ assert _j(cn.nanmean(a)).shape == ()
+ assert _j(cn.nanstd(a)).shape == ()
+ assert _j(cn.nanvar(a)).shape == ()
+ assert _j(cn.nanmedian(a)).shape == ()
+ assert int(cn.argmax(a)) == 5
+ assert int(cn.argmin(a)) == 0
+ assert _j(cn.ptp(a)).shape == ()
+ assert _j(cn.diff(a)).shape == (2, 2)
+ assert _j(cn.nancumsum(a)).shape == (6,)
+ assert _j(cn.nancumprod(a)).shape == (6,)
+ assert int(cn.count_nonzero(a)) == 6
+ assert _j(cn.percentile(a, 50)).shape == ()
+ assert _j(cn.quantile(a, 0.5)).shape == ()
+
+
+def test_compat_numpy_logic():
+ a = bm.asarray([1., 2., 3.])
+ b = bm.asarray([1., 0., 3.])
+ assert _j(cn.equal(a, b)).shape == (3,)
+ assert _j(cn.not_equal(a, b)).shape == (3,)
+ assert _j(cn.greater(a, b)).shape == (3,)
+ assert _j(cn.greater_equal(a, b)).shape == (3,)
+ assert _j(cn.less(a, b)).shape == (3,)
+ assert _j(cn.less_equal(a, b)).shape == (3,)
+ assert bool(cn.array_equal(a, a))
+ assert _j(cn.isclose(a, b)).shape == (3,)
+ assert bool(cn.allclose(a, a))
+ ba = bm.asarray([True, False, True])
+ bb = bm.asarray([True, True, False])
+ assert _j(cn.logical_and(ba, bb)).shape == (3,)
+ assert _j(cn.logical_or(ba, bb)).shape == (3,)
+ assert _j(cn.logical_xor(ba, bb)).shape == (3,)
+ assert _j(cn.logical_not(ba)).shape == (3,)
+ assert bool(cn.all(ba)) is False
+ assert bool(cn.any(ba)) is True
+ assert bool(cn.alltrue(ba)) is False
+ assert bool(cn.sometrue(ba)) is True
+
+
+def test_compat_numpy_bit_ops():
+ a = bm.asarray([1, 2, 3])
+ b = bm.asarray([3, 2, 1])
+ assert _j(cn.bitwise_and(a, b)).shape == (3,)
+ assert _j(cn.bitwise_or(a, b)).shape == (3,)
+ assert _j(cn.bitwise_xor(a, b)).shape == (3,)
+ assert _j(cn.bitwise_not(a)).shape == (3,)
+ assert _j(cn.invert(a)).shape == (3,)
+ assert _j(cn.left_shift(a, b)).shape == (3,)
+ assert _j(cn.right_shift(a, b)).shape == (3,)
+
+
+def test_compat_numpy_manipulation():
+ a = bm.asarray(jnp.arange(6.))
+ m = bm.asarray(jnp.arange(6.).reshape(2, 3))
+ assert _j(cn.reshape(a, (2, 3))).shape == (2, 3)
+ assert _j(cn.ravel(m)).shape == (6,)
+ assert _j(cn.moveaxis(m, 0, 1)).shape == (3, 2)
+ assert _j(cn.transpose(m)).shape == (3, 2)
+ assert _j(cn.swapaxes(m, 0, 1)).shape == (3, 2)
+ assert _j(cn.concatenate([a, a])).shape == (12,)
+ assert _j(cn.stack([a, a])).shape == (2, 6)
+ assert _j(cn.vstack([a, a])).shape == (2, 6)
+ assert _j(cn.hstack([a, a])).shape == (12,)
+ assert _j(cn.dstack([a, a])).shape == (1, 6, 2)
+ assert _j(cn.column_stack([a, a])).shape == (6, 2)
+ assert len(cn.split(a, 2)) == 2
+ assert len(cn.array_split(a, 4)) == 4
+ assert _j(cn.tile(a, 2)).shape == (12,)
+ assert _j(cn.repeat(a, 2)).shape == (12,)
+ assert _j(cn.flip(a)).shape == (6,)
+ assert _j(cn.fliplr(m)).shape == (2, 3)
+ assert _j(cn.flipud(m)).shape == (2, 3)
+ assert _j(cn.roll(a, 1)).shape == (6,)
+ assert _j(cn.atleast_1d(bm.asarray(1.))).shape == (1,)
+ assert _j(cn.atleast_2d(a)).shape == (1, 6)
+ assert _j(cn.atleast_3d(a)).shape == (1, 6, 1)
+ assert _j(cn.expand_dims(a, 0)).shape == (1, 6)
+ assert _j(cn.squeeze(bm.asarray(jnp.ones((1, 3))))).shape == (3,)
+ assert _j(cn.append(a, a)).shape == (12,)
+ assert _j(cn.sort(a)).shape == (6,)
+ assert _j(cn.argsort(a)).shape == (6,)
+ assert _j(cn.unique(bm.asarray([1, 1, 2, 3]))).shape == (3,)
+ assert _j(cn.row_stack([a, a])).shape == (2, 6)
+
+
+def test_compat_numpy_indexing_and_where():
+ a = bm.asarray(jnp.arange(6.))
+ assert _j(cn.where(a > 2, a, 0.)).shape == (6,)
+ assert len(cn.nonzero(a)) == 1
+ assert _j(cn.argwhere(a > 2).flatten()).shape[0] >= 0
+ assert _j(cn.flatnonzero(a)).shape[0] == 5
+ assert int(cn.searchsorted(a, 3.5)) == 4
+ assert _j(cn.take(a, bm.asarray([0, 2, 4]))).shape == (3,)
+ assert _j(cn.select([a > 3], [a])).shape == (6,)
+ assert _j(cn.extract(a > 3, a)).shape[0] == 2
+ assert len(cn.tril_indices(3)) == 2
+ assert len(cn.triu_indices(3)) == 2
+ m = bm.asarray(jnp.ones((3, 3)))
+ assert len(cn.tril_indices_from(m)) == 2
+ assert len(cn.triu_indices_from(m)) == 2
+
+
+def test_compat_numpy_linalg():
+ a = bm.asarray([1., 2., 3.])
+ b = bm.asarray([4., 5., 6.])
+ m = bm.asarray(jnp.arange(9.).reshape(3, 3))
+ assert _j(cn.dot(a, b)).shape == ()
+ assert _j(cn.vdot(a, b)).shape == ()
+ assert _j(cn.inner(a, b)).shape == ()
+ assert _j(cn.outer(a, b)).shape == (3, 3)
+ assert _j(cn.kron(a, b)).shape == (9,)
+ assert _j(cn.matmul(m, m)).shape == (3, 3)
+ assert _j(cn.trace(m)).shape == ()
+ assert _j(cn.tensordot(m, m)).shape == ()
+
+
+def test_compat_numpy_misc_helpers():
+ a = bm.asarray([1., 2., 3.])
+ assert cn.shape(a) == (3,)
+ assert cn.shape([[1, 2]]) == (1, 2)
+ assert cn.shape(0) == ()
+ assert cn.size(a) == 3
+ assert cn.size(bm.asarray(jnp.ones((2, 3))), 1) == 3
+ assert cn.size([1, 2, 3]) == 3
+ assert int(cn.ndim(a)) == 1
+ assert float(cn.asscalar(bm.asarray(7.))) == 7.
+ assert cn.matrix([[1, 2], [3, 4]]).shape == (2, 2)
+ assert cn.asmatrix([1, 2, 3]).shape == (1, 3)
+ assert cn.mat([1, 2, 3]).shape == (1, 3)
+ assert cn.msort(bm.asarray(jnp.array([[3., 1.], [2., 4.]]))).shape == (2, 2)
+ assert cn.common_type(jnp.array([1., 2.])) is not None
+ assert cn.common_type(jnp.array([1 + 1j])) is not None # complex branch
+ assert _j(cn.frombuffer(b'\x01\x02\x03\x04', dtype=np.int8)).shape == (4,)
+ assert _j(cn.meshgrid(a, a))[0].shape == (3, 3)
+ assert _j(cn.broadcast_to(a, (2, 3))).shape == (2, 3)
+ assert cn.broadcast_shapes((3,), (2, 3)) == (2, 3)
+ assert _j(cn.pad(a, 1)).shape == (5,)
+ assert _j(cn.clip(a, 1.5, 2.5)).shape == (3,)
+ assert _j(cn.interp(bm.asarray([1.5]), a, a)).shape == (1,)
+ assert _j(cn.einsum('i,i->', a, a)).shape == ()
+ assert _j(cn.gradient(a)).shape == (3,)
+ assert _j(cn.histogram(a)[0]).shape == (10,)
+ assert _j(cn.bincount(bm.asarray([0, 1, 1, 2]))).shape == (3,)
+ assert _j(cn.digitize(a, bm.asarray([0., 2.]))).shape == (3,)
+
+
+def test_compat_numpy_window_and_constants():
+ assert _j(cn.bartlett(4)).shape == (4,)
+ assert _j(cn.blackman(4)).shape == (4,)
+ assert _j(cn.hamming(4)).shape == (4,)
+ assert _j(cn.hanning(4)).shape == (4,)
+ assert _j(cn.kaiser(4, 1.0)).shape == (4,)
+ assert cn.e == pytest.approx(np.e)
+ assert cn.pi == pytest.approx(np.pi)
+ assert np.isinf(cn.inf)
+
+
+def test_compat_numpy_inplace_helpers_and_errors():
+ a = bm.asarray(jnp.arange(6))
+ cn.place(a, jnp.array([True, False] * 3), [10, 20, 30])
+ b = bm.asarray(jnp.arange(6))
+ cn.put(b, jnp.array([0, 1]), jnp.array([9, 8]))
+ assert int(_j(b)[0]) == 9
+ c = bm.asarray(jnp.zeros(3))
+ cn.copyto(c, jnp.ones(3))
+ assert float(_j(c)[0]) == 1.
+ # error paths (non-Array inputs)
+ with pytest.raises(ValueError):
+ cn.place(jnp.arange(6), jnp.array([True] * 6), [1])
+ with pytest.raises(ValueError):
+ cn.put(jnp.arange(6), [0], [9])
+ with pytest.raises(ValueError):
+ cn.putmask(jnp.arange(6), jnp.arange(6) > 2, jnp.arange(6))
+ with pytest.raises(ValueError):
+ cn.copyto(jnp.zeros(3), jnp.ones(3))
+
+
+def test_compat_numpy_in1d_and_set_ops():
+ a = bm.asarray([1, 2, 3, 4])
+ b = bm.asarray([2, 4])
+ assert _j(cn.in1d(a, b)).shape == (4,)
+ assert _j(cn.in1d(a, b, invert=True)).shape == (4,)
+ assert _j(cn.intersect1d(a, b)).shape == (2,)
+ assert _j(cn.union1d(a, b)).shape[0] == 4
+ assert _j(cn.setdiff1d(a, b)).shape == (2,)
+ assert _j(cn.isin(a, b)).shape == (4,)
+
+
+def test_compat_numpy_dtype_helpers():
+ assert cn.issubdtype(jnp.float32, jnp.floating)
+ assert cn.can_cast(jnp.int32, jnp.int64)
+ assert cn.result_type(jnp.int32, jnp.float32) is not None
+ assert cn.promote_types(jnp.int32, jnp.float32) is not None
+ assert cn.finfo(jnp.float32).bits == 32
+ assert cn.iinfo(jnp.int32).bits == 32
+
+
+# --- compat_pytorch ---------------------------------------------------------
+
+def test_pytorch_shape_ops():
+ a = bm.asarray(jnp.arange(24.).reshape(2, 3, 4))
+ assert cpt.flatten(a).shape == (24,)
+ assert cpt.flatten(a, start_dim=1).shape == (2, 12)
+ assert cpt.flatten(a, start_dim=1, end_dim=2).shape == (2, 12)
+ assert cpt.flatten(a, start_dim=-2).shape == (2, 12)
+ assert cpt.flatten(bm.asarray(jnp.array(3.))).shape == (1,)
+ assert _j(cpt.unflatten(bm.asarray(jnp.arange(6.)), 0, (2, 3))).shape == (2, 3)
+ assert _j(cpt.unsqueeze(bm.asarray(jnp.arange(3.)), 0)).shape == (1, 3)
+ assert _j(cpt.cat([bm.asarray([1., 2.]), bm.asarray([3., 4.])])).shape == (4,)
+
+
+def test_pytorch_flatten_errors():
+ a = bm.asarray(jnp.arange(6.).reshape(2, 3))
+ with pytest.raises(ValueError):
+ cpt.flatten(a, start_dim=5)
+ with pytest.raises(ValueError):
+ cpt.flatten(a, end_dim=5)
+
+
+def test_pytorch_math_ops_no_out():
+ a = bm.asarray([0.1, 0.2, 0.3])
+ for f in [cpt.abs, cpt.absolute, cpt.acos, cpt.arccos, cpt.acosh, cpt.arccosh,
+ cpt.asin, cpt.arcsin, cpt.asinh, cpt.arcsinh, cpt.atan, cpt.arctan,
+ cpt.atanh, cpt.arctanh]:
+ # acosh needs x >= 1
+ x = bm.asarray([1.1, 1.2, 1.3]) if f in (cpt.acosh, cpt.arccosh) else a
+ assert _j(f(x)).shape == (3,)
+ assert _j(cpt.angle(bm.asarray([1 + 1j, 2 - 1j]))).shape == (2,)
+ assert _j(cpt.atan2(a, a)).shape == (3,)
+ assert _j(cpt.arctan2(a, a)).shape == (3,)
+
+
+def test_pytorch_add_family():
+ a = bm.asarray([1., 2., 3.])
+ b = bm.asarray([4., 5., 6.])
+ np.testing.assert_allclose(np.asarray(_j(cpt.add(a, b))), [5., 7., 9.])
+ np.testing.assert_allclose(np.asarray(_j(cpt.add(a, b, alpha=2))), [9., 12., 15.])
+ assert _j(cpt.addcdiv(a, b, b, value=2)).shape == (3,)
+ assert _j(cpt.addcmul(a, b, b, value=2)).shape == (3,)
+
+
+def test_pytorch_out_paths():
+ a = bm.asarray([-1., -2., -3.])
+ out = bm.zeros((3,))
+ r = cpt.abs(a, out=out)
+ assert r is out
+ np.testing.assert_allclose(np.asarray(_j(out)), [1., 2., 3.])
+ out2 = bm.zeros((3,))
+ r2 = cpt.addcdiv(bm.asarray([1., 1., 1.]), bm.asarray([2., 2., 2.]),
+ bm.asarray([2., 2., 2.]), value=1, out=out2)
+ assert r2 is out2
+
+
+def test_pytorch_unary_out_paths():
+ """Exercise the ``out=`` branch of every unary pytorch math op (H-14)."""
+ a = bm.asarray([0.2, 0.3, 0.4])
+ high = bm.asarray([1.1, 1.2, 1.3])
+ cases = [
+ (cpt.acos, a), (cpt.arccos, a), (cpt.acosh, high), (cpt.arccosh, high),
+ (cpt.asin, a), (cpt.arcsin, a), (cpt.asinh, a), (cpt.arcsinh, a),
+ (cpt.atan, a), (cpt.arctan, a), (cpt.atanh, a), (cpt.arctanh, a),
+ (cpt.absolute, a),
+ ]
+ for f, x in cases:
+ out = bm.zeros((3,))
+ r = f(x, out=out)
+ assert r is out
+ assert _finite(out)
+ # binary / complex out= paths
+ out = bm.zeros((3,))
+ assert cpt.atan2(a, a, out=out) is out
+ out = bm.zeros((3,))
+ assert cpt.arctan2(a, a, out=out) is out
+ cout = bm.zeros((3,))
+ assert cpt.angle(bm.asarray([1 + 1j, 2 + 0j, 0 + 1j]), out=cout) is cout
+
+
+def test_pytorch_flatten_negative_end_dim_error():
+ a = bm.asarray(jnp.arange(6.).reshape(2, 3))
+ with pytest.raises(ValueError):
+ cpt.flatten(a, end_dim=-10)
+
+
+def test_pytorch_clamp_aliases():
+ a = bm.asarray([1., 5., 9.])
+ np.testing.assert_allclose(np.asarray(_j(cpt.clamp_max(a, 4.))), [1., 4., 4.])
+ np.testing.assert_allclose(np.asarray(_j(cpt.clamp_min(a, 4.))), [4., 5., 9.])
+
+
+def test_pytorch_tensor_alias():
+ assert cpt.Tensor is Array
+
+
+# --- compat_tensorflow ------------------------------------------------------
+
+def test_tensorflow_reductions():
+ a = bm.asarray([[1., 2., 3.], [4., 5., 6.]])
+ assert float(_j(ctf.reduce_sum(a))) == 21.
+ assert float(_j(ctf.reduce_max(a))) == 6.
+ assert float(_j(ctf.reduce_min(a))) == 1.
+ assert float(_j(ctf.reduce_mean(a))) == pytest.approx(3.5)
+ assert float(_j(ctf.reduce_prod(a))) == 720.
+ assert _j(ctf.reduce_std(a)).shape == ()
+ assert _j(ctf.reduce_variance(a)).shape == ()
+ assert _j(ctf.reduce_euclidean_norm(a)).shape == ()
+ ba = bm.asarray([[True, False], [True, True]])
+ assert bool(ctf.reduce_all(ba)) is False
+ assert bool(ctf.reduce_any(ba)) is True
+ # axis variants
+ assert _j(ctf.reduce_max(a, axis=1)).shape == (2,)
+ assert _j(ctf.reduce_sum(a, axis=0, keepdims=True)).shape == (1, 3)
+ assert _j(ctf.reduce_euclidean_norm(a, axis=1)).shape == (2,)
+
+
+def test_tensorflow_segment_ops():
+ data = bm.asarray([1., 2., 3., 4.])
+ seg = bm.asarray([0, 0, 1, 1])
+ assert _j(ctf.segment_sum(data, seg)).shape == (2,)
+ assert _j(ctf.segment_prod(data, seg)).shape == (2,)
+ assert _j(ctf.segment_max(data, seg)).shape == (2,)
+ assert _j(ctf.segment_min(data, seg)).shape == (2,)
+ np.testing.assert_allclose(np.asarray(_j(ctf.segment_mean(data, seg))), [1.5, 3.5])
+ assert _j(ctf.unsorted_segment_sum(data, seg, 2)).shape == (2,)
+ assert _j(ctf.unsorted_segment_prod(data, seg, 2)).shape == (2,)
+ assert _j(ctf.unsorted_segment_max(data, seg, 2)).shape == (2,)
+ assert _j(ctf.unsorted_segment_min(data, seg, 2)).shape == (2,)
+ assert _j(ctf.unsorted_segment_mean(data, seg, 2)).shape == (2,)
+ assert _j(ctf.unsorted_segment_sqrt_n(data, seg, 2)).shape == (2,)
+
+
+def test_tensorflow_cast_clip_concat():
+ a = bm.asarray([1.4, 2.6, 3.1])
+ casted = ctf.cast(a, jnp.int32)
+ assert casted.dtype == jnp.int32
+ np.testing.assert_allclose(np.asarray(_j(ctf.clip_by_value(a, 2., 3.))), [2., 2.6, 3.])
+ assert _j(ctf.concat([a, a])).shape == (6,)
+
+
+# --- activations ------------------------------------------------------------
+
+def test_activations_basic():
+ x = bm.asarray([-2., -0.5, 0., 0.5, 2.])
+ for f in [act.relu, act.relu6, act.sigmoid, act.softplus, act.silu, act.swish,
+ act.mish, act.selu, act.elu, act.celu, act.soft_sign, act.log_sigmoid,
+ act.hard_sigmoid, act.hard_silu, act.hard_swish, act.tanh_shrink,
+ act.leaky_relu, act.hard_shrink, act.soft_shrink, act.prelu,
+ act.identity]:
+ r = f(x)
+ assert np.asarray(_j(r)).shape == (5,)
+ assert _finite(r)
+
+
+def test_activations_relu_correct():
+ x = bm.asarray([-1., 0., 2.])
+ np.testing.assert_allclose(np.asarray(_j(act.relu(x))), [0., 0., 2.])
+ np.testing.assert_allclose(np.asarray(_j(act.relu6(bm.asarray([-1., 3., 9.])))), [0., 3., 6.])
+
+
+def test_activations_gelu_both():
+ x = bm.asarray([-1., 0., 1., 2.])
+ assert _finite(act.gelu(x, approximate=True))
+ assert _finite(act.gelu(x, approximate=False))
+
+
+def test_activations_softmax_family():
+ x = bm.asarray([1., 2., 3.])
+ sm = act.softmax(x)
+ assert float(jnp.sum(_j(sm))) == pytest.approx(1.0, abs=1e-6)
+ ls = act.log_softmax(x)
+ np.testing.assert_allclose(np.asarray(_j(jnp.exp(ls))), np.asarray(_j(sm)), atol=1e-6)
+ smn = act.softmin(x)
+ assert float(jnp.sum(_j(smn))) == pytest.approx(1.0, abs=1e-6)
+ assert act.soft_max is act.softmax
+
+
+def test_activations_softmax_large_inputs_stable():
+ x = bm.asarray([1000., 1001., 1002.])
+ assert _finite(act.softmax(x))
+ assert _finite(act.log_softmax(x))
+ assert _finite(act.softmin(x))
+
+
+def test_activations_tanh_and_glu():
+ x = bm.asarray([-1., 0., 1.])
+ np.testing.assert_allclose(np.asarray(_j(act.tanh(x))), np.tanh([-1., 0., 1.]), atol=1e-6)
+ assert _j(act.glu(bm.asarray(jnp.arange(4.)))).shape == (2,)
+ with pytest.raises(AssertionError):
+ act.glu(bm.asarray(jnp.arange(3.))) # odd axis size
+
+
+def test_activations_one_hot_and_normalize():
+ oh = act.one_hot(bm.asarray([0, 1, 2]), 3)
+ assert _j(oh).shape == (3, 3)
+ np.testing.assert_allclose(np.asarray(_j(oh)), np.eye(3))
+ # out-of-range indices -> all zeros
+ oh2 = act.one_hot(bm.asarray([-1, 5]), 3)
+ np.testing.assert_allclose(np.asarray(_j(oh2)), np.zeros((2, 3)))
+ n = act.normalize(bm.asarray([1., 2., 3., 4.]))
+ assert _finite(n)
+ assert float(jnp.mean(_j(n))) == pytest.approx(0.0, abs=1e-5)
+
+
+def test_activations_rrelu():
+ r = act.rrelu(bm.asarray([-1., -2., 1., 2.]))
+ assert _j(r).shape == (4,)
+ # positive entries pass through unchanged
+ assert float(_j(r)[2]) == 1.
+ assert float(_j(r)[3]) == 2.
+
+
+def test_activations_get_dispatch():
+ assert act.get('relu') is act.relu
+ assert act.get(None) is None
+ fn = (lambda x: x)
+ assert act.get(fn) is fn
+ with pytest.raises(ValueError):
+ act.get('this_is_not_an_activation')
+ with pytest.raises(ValueError):
+ act.get(123)
+
+
+def test_activations_accept_plain_jax_arrays():
+ x = jnp.array([-1., 0., 1.])
+ for f in [act.relu, act.sigmoid, act.softmax, act.softmin, act.elu, act.gelu]:
+ assert _finite(f(x))
+
+
+def test_activations_axis_out_of_bounds():
+ # one_hot routes axis through _canonicalize_axis -> ValueError on OOB
+ with pytest.raises(ValueError):
+ act.one_hot(jnp.array([0, 1, 2]), 3, axis=5)
+ # log_softmax/softmax over a non-existent axis raises
+ with pytest.raises(Exception):
+ act.log_softmax(jnp.arange(6.), axis=5)
+
+
+def test_activations_one_hot_axis_and_dtype():
+ oh = act.one_hot(jnp.array([0, 1, 2]), 3, axis=0)
+ assert _j(oh).shape == (3, 3)
+ oh_i = act.one_hot(jnp.array([0, 1]), 2, dtype=jnp.int32)
+ assert _j(oh_i).dtype == jnp.int32
+
+
+def test_activations_hard_tanh_clamping():
+ x = bm.asarray([-2., -0.5, 0.5, 2.])
+ np.testing.assert_allclose(np.asarray(_j(act.hard_tanh(x))), [-1., -0.5, 0.5, 1.])
+
+
+def test_activations_softplus_threshold_branch():
+ # values above threshold revert to the linear branch
+ x = bm.asarray([0.5, 25.0, 50.0])
+ r = act.softplus(x, beta=1., threshold=20.)
+ assert _finite(r)
+ assert float(_j(r)[1]) == pytest.approx(25.0, rel=1e-5)
+
+
+# --- others -----------------------------------------------------------------
+
+def test_others_shared_args_over_time():
+ r = bo.shared_args_over_time(num_step=5)
+ assert r['i'].shape == (5,)
+ assert r['t'].shape == (5,)
+ assert r['dt'].shape == (5,)
+ r2 = bo.shared_args_over_time(duration=1.0, dt=0.1, include_dt=False)
+ assert r2['i'].shape == (10,)
+ assert 'dt' not in r2
+
+
+def test_others_clip_by_norm():
+ t = jnp.array([3., 4.])
+ r = bo.clip_by_norm(t, 1.0)
+ assert float(jnp.linalg.norm(_j(r))) <= 1.0 + 1e-5
+ # pytree input
+ r2 = bo.clip_by_norm({'a': jnp.array([3., 4.])}, 1.0)
+ assert 'a' in r2
+
+
+def test_others_exprel():
+ r = bo.exprel(jnp.array([0., 1., -1.]))
+ assert _finite(r)
+ # exprel(0) == 1 (removable singularity)
+ assert float(_j(r)[0]) == pytest.approx(1.0, abs=1e-4)
+ # exprel(x) == (exp(x)-1)/x away from zero
+ assert float(_j(r)[1]) == pytest.approx((np.e - 1.0), rel=1e-3)
+ # float64 threshold branch
+ r64 = bo.exprel(jnp.array([0.5]))
+ assert _finite(r64)
+
+
+def test_others_is_float_type():
+ assert bo.is_float_type(jnp.array([1., 2.]))
+ assert not bo.is_float_type(jnp.array([1, 2]))
+
+
+def test_others_add_axis_axes():
+ x = jnp.arange(3.)
+ assert _j(bo.add_axis(x, 0)).shape == (1, 3)
+ r = bo.add_axes(x, n_axes=2, pos2len={0: 4})
+ assert _j(r).shape == (4, 3)
+
+
+# --- _utils -----------------------------------------------------------------
+
+def test_utils_as_jax_array_and_is_leaf():
+ a = bm.asarray([1., 2., 3.])
+ assert isinstance(butils._as_jax_array_(a), jax.Array)
+ assert butils._as_jax_array_(5) == 5
+ assert butils._is_leaf(a) is True
+ assert butils._is_leaf(jnp.array([1.])) is False
+
+
+def test_utils_compatible_wrapper_kwargs_translation():
+ a = bm.asarray([[1., 2.], [3., 4.]])
+ # PyTorch dim -> axis
+ np.testing.assert_allclose(
+ np.asarray(_j(bm.sum(a, dim=0))), [4., 6.])
+ # PyTorch keepdim -> keepdims
+ assert _j(bm.sum(a, axis=0, keepdim=True)).shape == (1, 2)
+ # TensorFlow keep_dims -> keepdims
+ assert _j(bm.sum(a, axis=0, keep_dims=True)).shape == (1, 2)
+
+
+def test_utils_wrapper_returns_brainpy_array():
+ r = bm.add(bm.asarray([1., 2.]), bm.asarray([3., 4.]))
+ assert isinstance(r, Array)
+
+
+def test_utils_wrapper_doc_and_name():
+ # _compatible_with_brainpy_array preserves the wrapped function name
+ assert cn.sum.__name__ == 'sum'
+ assert 'brainpy Array/Variable' in cn.sum.__doc__
+
+
+# --- einops -----------------------------------------------------------------
+
+def test_einops_rearrange():
+ x = jnp.arange(24.).reshape(2, 3, 4)
+ assert bein.ein_rearrange(x, 'a b c -> a c b').shape == (2, 4, 3)
+ assert bein.ein_rearrange(x, 'a b c -> (a b) c').shape == (6, 4)
+ assert bein.ein_rearrange(x, 'a b c -> a b c').shape == (2, 3, 4)
+ # split an axis
+ assert bein.ein_rearrange(jnp.arange(12.), '(a b) -> a b', a=3).shape == (3, 4)
+
+
+def test_einops_reduce():
+ x = jnp.arange(24.).reshape(2, 3, 4)
+ assert bein.ein_reduce(x, 'a b c -> a c', 'mean').shape == (2, 4)
+ assert bein.ein_reduce(x, 'a b c -> a', 'sum').shape == (2,)
+ assert bein.ein_reduce(x, 'a b c -> b c', 'max').shape == (3, 4)
+ assert bein.ein_reduce(x, 'a b c -> b c', 'min').shape == (3, 4)
+ assert bein.ein_reduce(x, 'a b c -> b c', 'prod').shape == (3, 4)
+ # pooling-style reduce with explicit axis lengths
+ y = jnp.arange(2 * 2 * 4 * 4.).reshape(2, 2, 4, 4)
+ assert bein.ein_reduce(y, 'b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2).shape == (2, 2, 2, 2)
+
+
+def test_einops_reduce_any_all():
+ b = jnp.array([[True, False, True], [False, False, True]])
+ assert bein.ein_reduce(b, 'a c -> c', 'any').shape == (3,)
+ assert bein.ein_reduce(b, 'a c -> c', 'all').shape == (3,)
+ np.testing.assert_array_equal(
+ np.asarray(_j(bein.ein_reduce(b, 'a c -> c', 'any'))), [True, False, True])
+ np.testing.assert_array_equal(
+ np.asarray(_j(bein.ein_reduce(b, 'a c -> c', 'all'))), [False, False, True])
+
+
+def test_einops_repeat():
+ img = jnp.arange(6.).reshape(2, 3)
+ assert bein.ein_repeat(img, 'h w -> h w c', c=4).shape == (2, 3, 4)
+ assert bein.ein_repeat(img, 'h w -> (h2 h) w', h2=2).shape == (4, 3)
+
+
+def test_einops_shape():
+ x = jnp.zeros((2, 3, 5, 7))
+ assert bein.ein_shape(x, 'batch _ h w') == {'batch': 2, 'h': 5, 'w': 7}
+ assert bein.ein_shape(x, 'a b c d') == {'a': 2, 'b': 3, 'c': 5, 'd': 7}
+
+
+def test_einops_reduce_callable_reduction():
+ x = jnp.arange(24.).reshape(2, 3, 4)
+ out = bein.ein_reduce(x, 'a b c -> a c', lambda t, axes: t.sum(axis=axes))
+ assert out.shape == (2, 4)
+
+
+def test_einops_mean_requires_float():
+ x = jnp.arange(24).reshape(2, 3, 4) # integer tensor
+ with pytest.raises(Exception):
+ bein.ein_reduce(x, 'a b c -> a c', 'mean')
+
+
+def test_einops_error_message_wrapped():
+ from brainpy.math.einops_parsing import EinopsError
+ with pytest.raises(EinopsError):
+ bein.ein_rearrange(jnp.arange(6.), 'a b c -> a b c') # wrong ndim
+
+
+def test_einops_enumerate_directions_internal():
+ x = jnp.zeros((2, 3))
+ dirs = bein._enumerate_directions(x)
+ assert len(dirs) == 2
+ assert _j(dirs[0]).shape == (2, 1)
+ assert _j(dirs[1]).shape == (1, 3)
+
+
+def test_einops_ellipsis_patterns():
+ x = jnp.arange(24.).reshape(2, 3, 4)
+ # reduce trailing axis, keep ellipsis dims
+ assert bein.ein_reduce(x, '... c -> ...', 'sum').shape == (2, 3)
+ # move leading axis to the end across an ellipsis
+ assert bein.ein_rearrange(x, 'a ... -> ... a').shape == (3, 4, 2)
+ # repeat with ellipsis
+ assert bein.ein_repeat(jnp.arange(6.).reshape(2, 3), '... -> ... r', r=2).shape == (2, 3, 2)
+
+
+def test_einops_shape_with_ellipsis():
+ x = jnp.zeros((2, 3, 5, 7))
+ assert bein.ein_shape(x, 'b ... w') == {'b': 2, 'w': 7}
+
+
+def test_einops_error_branches():
+ from brainpy.math.einops_parsing import EinopsError
+ x = jnp.arange(24.).reshape(2, 3, 4)
+ # identifiers only on one side of a rearrange
+ with pytest.raises(EinopsError):
+ bein.ein_rearrange(x, 'a b c -> a b')
+ # repeat without a size for a new axis
+ with pytest.raises(EinopsError):
+ bein.ein_repeat(jnp.arange(6.).reshape(2, 3), 'h w -> h w c')
+ # extra identifier on the right of a reduce
+ with pytest.raises(EinopsError):
+ bein.ein_reduce(x, 'a b c -> a b c d', 'sum')
+ # unknown reduction name
+ with pytest.raises(EinopsError):
+ bein.ein_reduce(x, 'a b c -> a', 'median')
+ # composed axes can't be parsed by ein_shape
+ with pytest.raises(RuntimeError):
+ bein.ein_shape(jnp.zeros((6,)), '(a b)')
+
+
+def test_einops_list_input_passthrough_identity():
+ # NOTE: the docstrings advertise stacking list-of-tensors input, but this
+ # port does not stack the list -- an identity pattern is a no-op and returns
+ # the list unchanged. Pin the current (documented-but-incomplete) behaviour.
+ imgs = [jnp.zeros((3, 4)) for _ in range(5)]
+ out = bein.ein_rearrange(imgs, 'b h w -> b h w')
+ assert isinstance(out, list)
+ assert len(out) == 5
diff --git a/tests/audit/test_math_core_fixes.py b/tests/audit/test_math_core_fixes.py
new file mode 100644
index 000000000..310b050bb
--- /dev/null
+++ b/tests/audit/test_math_core_fixes.py
@@ -0,0 +1,897 @@
+# -*- coding: utf-8 -*-
+"""Regression + coverage tests for the BrainPy v2.7.8 math-core audit
+(see ``docs/issues-found-20260618.md``).
+
+This module exercises the fixes recorded in the audit for:
+
+* ``brainpy/math/ndarray.py`` — H-11 (``Array.device`` is a property),
+ H-12 (``Array(scalar)``), M-09 (``ShardedArray.value`` read), Array.tree_*
+ pytree round-trip under abstract eval, and L-03 (base vs sharded value
+ setter policy).
+* ``brainpy/math/environment.py`` — C-10 (``disable_x64`` re-syncs brainstate
+ precision) and M-07 (``set()`` validate-before-mutate).
+* ``brainpy/math/modes.py`` — H-10 (``Mode`` is hashable again).
+* ``brainpy/math/scales.py`` — L-02 (``IdScaling`` rejects non-default
+ bias/scale instead of silently ignoring them).
+* ``brainpy/math/sharding.py`` — M-10 (``get_sharding`` warns on a full
+ axis-name mismatch).
+* ``brainpy/math/remove_vmap.py`` — M-08 (documented global-reduction batching
+ behaviour).
+
+Every test that toggles x64 / global precision / global environment restores
+the original state in a ``finally`` block (or via the ``restore_environment``
+fixture) so it cannot corrupt other tests in the suite.
+"""
+
+import warnings
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+import brainstate
+from jax import config
+
+import brainpy
+import brainpy.math as bm
+from brainpy._errors import MathError
+from brainpy.math import modes, scales, sharding
+from brainpy.math.ndarray import Array, ShardedArray, JaxArray, ndarray
+from brainpy.math.remove_vmap import remove_vmap
+from brainpy.math.defaults import defaults as _defaults
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+@pytest.fixture
+def restore_environment():
+ """Snapshot global precision / x64 / float state and restore it afterwards.
+
+ Any test that flips ``enable_x64``/``disable_x64`` or mutates global dtype
+ settings MUST run inside this fixture so the rest of the suite keeps the
+ default float32 environment.
+ """
+ orig_precision = brainstate.environ.get_precision()
+ orig_x64 = config.read('jax_enable_x64')
+ orig_float = bm.get_float()
+ orig_int = bm.get_int()
+ orig_complex = bm.get_complex()
+ try:
+ yield
+ finally:
+ # Restore JAX + brainstate precision symmetrically.
+ if orig_x64:
+ bm.enable_x64()
+ else:
+ bm.disable_x64()
+ brainstate.environ.set(precision=orig_precision)
+ bm.set_float(orig_float)
+ bm.set_int(orig_int)
+ bm.set_complex(orig_complex)
+
+
+# ===========================================================================
+# Regression tests
+# ===========================================================================
+
+# --- environment.py : C-10 -------------------------------------------------
+
+def test_disable_x64_resyncs_brainstate_precision(restore_environment):
+ """C-10: after ``enable_x64(); disable_x64()`` the brainstate precision is
+ back to 32 (it used to be left at 64 while JAX was at float32)."""
+ bm.enable_x64()
+ assert brainstate.environ.get_precision() == 64
+ assert config.read('jax_enable_x64') is True
+
+ bm.disable_x64()
+ assert brainstate.environ.get_precision() == 32
+ assert config.read('jax_enable_x64') is False
+
+
+def test_enable_then_disable_x64_leaves_precision_32(restore_environment):
+ """The brainstate precision and JAX x64 flag stay in lock-step."""
+ bm.enable_x64()
+ bm.disable_x64()
+ assert brainstate.environ.get_precision() == 32
+ # float default tracks the disabled state.
+ assert bm.get_float() == jnp.float32
+
+
+# --- environment.py : M-07 -------------------------------------------------
+
+def test_set_validate_before_mutate_invalid_numpy_func_return():
+ """M-07: an invalid ``numpy_func_return`` raises and does NOT mutate
+ ``get_float()`` (validation happens before any global write)."""
+ before_float = bm.get_float()
+ before_return = _defaults.numpy_func_return
+ with pytest.raises(AssertionError):
+ bm.set(float_=jnp.float64, numpy_func_return='not-a-valid-option')
+ # No partial-config leak: float type unchanged.
+ assert bm.get_float() == before_float
+ assert _defaults.numpy_func_return == before_return
+
+
+def test_set_validate_before_mutate_invalid_dt():
+ """A non-float ``dt`` is rejected before any other arg is applied."""
+ before_mode = bm.get_mode()
+ with pytest.raises(AssertionError):
+ bm.set(mode=modes.batching_mode, dt='not-a-float')
+ assert bm.get_mode() is before_mode
+
+
+# --- modes.py : H-10 -------------------------------------------------------
+
+def test_modes_are_hashable():
+ """H-10: defining ``__eq__`` had nuked ``__hash__``; restore hashability."""
+ assert hash(modes.nonbatching_mode) is not None
+ assert hash(brainpy.math.modes.nonbatching_mode) is not None
+
+
+def test_modes_usable_in_a_set():
+ """All three default modes can live together in a set."""
+ s = {modes.nonbatching_mode, modes.batching_mode, modes.training_mode}
+ assert len(s) == 3
+ # Two non-batching instances compare/hash equal (by class).
+ s2 = {modes.NonBatchingMode(), modes.NonBatchingMode()}
+ assert len(s2) == 1
+
+
+# --- ndarray.py : H-11 -----------------------------------------------------
+
+def test_array_device_is_property_returning_jax_device():
+ """H-11: ``Array.device`` is now a property (calling the old method form
+ raised ``TypeError``). It returns a ``jax.Device``."""
+ a = bm.asarray([1., 2.])
+ dev = a.device # property access, not a call
+ assert isinstance(dev, jax.Device)
+
+
+# --- ndarray.py : H-12 -----------------------------------------------------
+
+def test_array_scalar_conversion_shape():
+ """H-12: ``Array(scalar)`` stores a real jax array, so ``.shape == ()``."""
+ assert Array(5).shape == ()
+ assert Array(5).ndim == 0
+ assert Array(5.0).shape == ()
+ # value is a jax array, not a bare python scalar.
+ assert isinstance(Array(5).value, jax.Array)
+
+
+# --- ndarray.py : pytree round-trip ----------------------------------------
+
+def test_array_pytree_round_trip_under_abstract_eval():
+ """``Array.tree_unflatten`` must accept abstract leaves so transforms like
+ ``jax.eval_shape`` work without re-running ``jnp.asarray``."""
+ out = jax.eval_shape(lambda x: x * 2, bm.asarray([1., 2., 3.]))
+ assert out.shape == (3,)
+
+
+def test_array_tree_flatten_unflatten_concrete():
+ a = Array([1., 2., 3.])
+ leaves, aux = a.tree_flatten()
+ assert aux is None
+ rebuilt = Array.tree_unflatten(aux, leaves)
+ np.testing.assert_array_equal(np.asarray(rebuilt.value), np.asarray(a.value))
+
+
+def test_array_jit_round_trip():
+ """A full jit round-trip preserves the Array pytree contract."""
+ a = bm.asarray([1., 2., 3.])
+ out = jax.jit(lambda x: x + 1.)(a)
+ np.testing.assert_allclose(np.asarray(out), np.array([2., 3., 4.]))
+
+
+# --- scales.py : L-02 ------------------------------------------------------
+
+def test_idscaling_rejects_non_default_bias():
+ ids = scales.IdScaling()
+ with pytest.raises(ValueError):
+ ids.offset_scaling(1.0, bias=5.0)
+
+
+def test_idscaling_rejects_non_default_scale():
+ ids = scales.IdScaling()
+ with pytest.raises(ValueError):
+ ids.std_scaling(1.0, scale=2.0)
+ with pytest.raises(ValueError):
+ ids.inv_scaling(1.0, scale=2.0)
+ with pytest.raises(ValueError):
+ ids.clone(scale=2.0)
+ with pytest.raises(ValueError):
+ ids.clone(bias=3.0)
+
+
+def test_idscaling_identity_for_default_calls():
+ ids = scales.IdScaling()
+ assert ids.offset_scaling(7.0) == 7.0
+ assert ids.std_scaling(7.0) == 7.0
+ assert ids.inv_scaling(7.0) == 7.0
+ # Explicit default values (0./1.) are accepted (no-op overrides).
+ assert ids.offset_scaling(7.0, bias=0., scale=1.) == 7.0
+ assert isinstance(ids.clone(), scales.IdScaling)
+ assert ids.scale == 1.0 and ids.bias == 0.0
+
+
+# --- sharding.py : M-10 ----------------------------------------------------
+
+def _single_axis_mesh(name='x'):
+ return jax.sharding.Mesh(np.asarray(jax.devices()), axis_names=(name,))
+
+
+def test_get_sharding_warns_on_full_axis_mismatch():
+ """M-10: when *every* requested axis name is absent, a fully-replicated
+ PartitionSpec is produced; warn instead of silently dropping intent."""
+ mesh = _single_axis_mesh('x')
+ with pytest.warns(UserWarning):
+ sh = sharding.get_sharding(['this_axis_does_not_exist'], mesh)
+ assert sh is not None # still returns a (replicated) NamedSharding
+
+
+def test_get_sharding_no_warning_on_partial_match():
+ """A partial match is tolerated on purpose: no warning."""
+ mesh = _single_axis_mesh('x')
+ with warnings.catch_warnings():
+ warnings.simplefilter('error') # any warning -> test failure
+ sh = sharding.get_sharding(['x', 'missing'], mesh)
+ assert sh is not None
+
+
+def test_get_sharding_none_axis_names_returns_none():
+ assert sharding.get_sharding(None) is None
+
+
+# --- remove_vmap.py : M-08 -------------------------------------------------
+
+def test_remove_vmap_global_reduction_any():
+ """M-08 documents that the reduction is global. Outside vmap it is a real
+ scalar."""
+ out = remove_vmap(jnp.array([False, True]))
+ assert out.shape == ()
+ assert bool(out) is True
+
+
+def test_remove_vmap_global_reduction_all():
+ assert bool(remove_vmap(jnp.array([True, True]), 'all')) is True
+ assert bool(remove_vmap(jnp.array([True, False]), 'all')) is False
+
+
+def test_remove_vmap_under_vmap_is_global():
+ """Under ``vmap`` the result is a *global* reduction across the batch (the
+ documented, intentional behaviour): every batch slot sees the same value."""
+ data = jnp.array([[0.1], [0.9]])
+ out = jax.vmap(lambda x: remove_vmap(x > 0.5, 'any'))(data)
+ # Global any over the whole batch is True; broadcast across the batch.
+ assert bool(out[0]) is True and bool(out[1]) is True
+ out_all = jax.vmap(lambda x: remove_vmap(x > 0.5, 'all'))(data)
+ # Global all over the whole batch is False.
+ assert bool(out_all[0]) is False and bool(out_all[1]) is False
+
+
+def test_remove_vmap_invalid_op_raises():
+ with pytest.raises(ValueError):
+ remove_vmap(jnp.array([1]), 'unsupported')
+
+
+def test_remove_vmap_unwraps_bp_array():
+ """A ``brainpy.math.Array`` input is unwrapped (line ``x = x.value``)."""
+ assert bool(remove_vmap(Array([False, True]))) is True
+ assert bool(remove_vmap(Array([True, True]), 'all')) is True
+
+
+def test_remove_vmap_under_jit_triggers_abstract_eval():
+ """Running under ``jit`` exercises the abstract-eval rules of both
+ primitives."""
+ f_any = jax.jit(lambda x: remove_vmap(x > 0., 'any'))
+ f_all = jax.jit(lambda x: remove_vmap(x > 0., 'all'))
+ assert bool(f_any(jnp.array([-1., 1.]))) is True
+ assert bool(f_all(jnp.array([1., 1.]))) is True
+ assert bool(f_all(jnp.array([-1., 1.]))) is False
+
+
+# ===========================================================================
+# Coverage tests
+# ===========================================================================
+
+# --- ndarray.py : Array methods --------------------------------------------
+
+def test_array_arithmetic():
+ a = Array([1., 2., 3.])
+ np.testing.assert_allclose(np.asarray(a + 1), [2., 3., 4.])
+ np.testing.assert_allclose(np.asarray(a - 1), [0., 1., 2.])
+ np.testing.assert_allclose(np.asarray(a * 2), [2., 4., 6.])
+ np.testing.assert_allclose(np.asarray(-a), [-1., -2., -3.])
+ np.testing.assert_allclose(np.asarray(a / 2), [0.5, 1., 1.5])
+
+
+def test_array_indexing_and_iteration():
+ a = Array([10., 20., 30.])
+ assert float(a[0]) == 10.0
+ assert len(a) == 3
+ assert [float(x) for x in a] == [10.0, 20.0, 30.0]
+ # slice
+ np.testing.assert_allclose(np.asarray(a[1:]), [20., 30.])
+
+
+def test_array_inplace_set():
+ a = Array([1., 2., 3.])
+ a[0] = 99.
+ np.testing.assert_allclose(np.asarray(a.value), [99., 2., 3.])
+ # .at accessor delegates to the underlying jax array.
+ updated = a.at[1].set(0.)
+ np.testing.assert_allclose(np.asarray(updated), [99., 0., 3.])
+
+
+def test_array_repr_scalar_and_vector():
+ r = repr(Array([1., 2., 3.]))
+ assert r.startswith('Array(value=')
+ # multi-line repr branch
+ big = Array(np.arange(40.).reshape(8, 5))
+ r2 = repr(big)
+ assert 'Array(value=' in r2 and '\n' in r2 and 'dtype=' in r2
+
+
+def test_array_dtype_shape_ndim():
+ a = Array([[1., 2.], [3., 4.]])
+ assert a.shape == (2, 2)
+ assert a.ndim == 2
+ assert a.dtype == jnp.float32
+
+
+def test_array_value_setter_with_array_np_and_state():
+ a = Array([0., 0.])
+ # set from numpy
+ a.value = np.array([4., 5.])
+ np.testing.assert_allclose(np.asarray(a.value), [4., 5.])
+ # set from another Array
+ a.value = Array([7., 8.])
+ np.testing.assert_allclose(np.asarray(a.value), [7., 8.])
+ # set from a jax array
+ a.value = jnp.array([1., 2.])
+ np.testing.assert_allclose(np.asarray(a.value), [1., 2.])
+ # set from a State / Variable
+ a.value = bm.Variable(jnp.array([9., 10.]))
+ np.testing.assert_allclose(np.asarray(a.value), [9., 10.])
+ # set from a python list (the "else -> jnp.asarray" branch)
+ a.value = [11., 12.]
+ np.testing.assert_allclose(np.asarray(a.value), [11., 12.])
+
+
+def test_array_data_property_and_update():
+ a = Array([1., 2.])
+ np.testing.assert_allclose(np.asarray(a.data), [1., 2.])
+ a.data = jnp.array([3., 4.])
+ np.testing.assert_allclose(np.asarray(a.data), [3., 4.])
+ a.update(jnp.array([5., 6.]))
+ np.testing.assert_allclose(np.asarray(a.value), [5., 6.])
+
+
+def test_array_fill():
+ a = Array([1., 2., 3.])
+ a.fill_(7.)
+ np.testing.assert_allclose(np.asarray(a.value), [7., 7., 7.])
+ # Array scalar as fill value
+ a.fill_(Array(2.))
+ np.testing.assert_allclose(np.asarray(a.value), [2., 2., 2.])
+ # numpy scalar
+ a.fill_(np.float32(3.))
+ np.testing.assert_allclose(np.asarray(a.value), [3., 3., 3.])
+ # non-scalar fill value is rejected
+ with pytest.raises(MathError):
+ a.fill_(np.array([1., 2.]))
+
+
+def test_array_numpy_and_jax_protocols():
+ a = Array([1., 2., 3.])
+ np.testing.assert_allclose(np.asarray(a), [1., 2., 3.])
+ assert isinstance(a.__jax_array__(), jax.Array)
+ # __array__ with dtype
+ arr = np.asarray(a, dtype=np.int32)
+ assert arr.dtype == np.int32
+
+
+def test_array_as_variable():
+ a = Array([1., 2.])
+ v = a.as_variable()
+ assert isinstance(v, bm.Variable)
+
+
+def test_array_block_until_ready_and_device_buffer():
+ a = Array([1., 2., 3.])
+ assert isinstance(a.block_until_ready(), jax.Array)
+ assert isinstance(a.block_host_until_ready(), jax.Array)
+ np.testing.assert_allclose(np.asarray(a.device_buffer), [1., 2., 3.])
+
+
+def test_array_aliases():
+ assert JaxArray is Array
+ assert ndarray is Array
+
+
+def test_array_constructed_from_array_and_dtype():
+ base = Array([1., 2., 3.])
+ a = Array(base, dtype=jnp.int32)
+ assert a.dtype == jnp.int32
+ np.testing.assert_array_equal(np.asarray(a.value), [1, 2, 3])
+ # tuple input branch
+ b = Array((1., 2.))
+ np.testing.assert_allclose(np.asarray(b.value), [1., 2.])
+
+
+# --- ndarray.py : ShardedArray ---------------------------------------------
+
+def test_sharded_array_value_read_write():
+ sa = ShardedArray(jnp.array([1., 2., 3.]))
+ np.testing.assert_allclose(np.asarray(sa.value), [1., 2., 3.])
+ sa.value = jnp.array([4., 5., 6.])
+ np.testing.assert_allclose(np.asarray(sa.value), [4., 5., 6.])
+
+
+def test_sharded_array_enforces_shape_and_dtype():
+ """L-03 / M-09: the ShardedArray setter enforces shape & dtype."""
+ sa = ShardedArray(jnp.array([1., 2., 3.]))
+ with pytest.raises(MathError):
+ sa.value = jnp.array([1., 2.]) # wrong shape
+ with pytest.raises(MathError):
+ sa.value = jnp.array([1, 2, 3]) # wrong dtype
+
+
+def test_sharded_array_keep_sharding_false():
+ sa = ShardedArray(jnp.array([1., 2.]), keep_sharding=False)
+ np.testing.assert_allclose(np.asarray(sa.value), [1., 2.])
+
+
+def test_sharded_array_setter_from_array_and_np():
+ sa = ShardedArray(jnp.array([1., 2.]))
+ sa.value = Array([3., 4.])
+ np.testing.assert_allclose(np.asarray(sa.value), [3., 4.])
+ sa.value = np.array([5., 6.], dtype=np.float32)
+ np.testing.assert_allclose(np.asarray(sa.value), [5., 6.])
+
+
+# --- environment.py : getters / setters ------------------------------------
+
+def test_environment_dtype_getters():
+ assert bm.get_float() in (jnp.float32, jnp.float64)
+ assert bm.get_int() in (jnp.int32, jnp.int64)
+ assert bm.get_complex() in (jnp.complex64, jnp.complex128)
+ assert bm.get_bool() == jnp.bool_
+ # deprecated aliases still resolve.
+ assert bm.dftype() == bm.get_float()
+ assert bm.ditype() == bm.get_int()
+
+
+def test_environment_dtype_setters_round_trip():
+ orig_float = bm.get_float()
+ orig_int = bm.get_int()
+ orig_complex = bm.get_complex()
+ orig_bool = bm.get_bool()
+ try:
+ bm.set_float(jnp.float16)
+ assert bm.get_float() == jnp.float16
+ bm.set_int(jnp.int16)
+ assert bm.get_int() == jnp.int16
+ bm.set_complex(jnp.complex64)
+ assert bm.get_complex() == jnp.complex64
+ bm.set_bool(jnp.bool_)
+ assert bm.get_bool() == jnp.bool_
+ finally:
+ bm.set_float(orig_float)
+ bm.set_int(orig_int)
+ bm.set_complex(orig_complex)
+ bm.set_bool(orig_bool)
+
+
+def test_environment_dt_setter_round_trip():
+ orig = bm.get_dt()
+ try:
+ bm.set_dt(0.05)
+ assert bm.get_dt() == 0.05
+ finally:
+ bm.set_dt(orig)
+ assert bm.get_dt() == orig
+
+
+def test_environment_mode_setter_round_trip():
+ orig = bm.get_mode()
+ try:
+ bm.set_mode(modes.batching_mode)
+ assert bm.get_mode() is modes.batching_mode
+ finally:
+ bm.set_mode(orig)
+ with pytest.raises(TypeError):
+ bm.set_mode('not-a-mode')
+
+
+def test_environment_membrane_scaling_setter_round_trip():
+ orig = bm.get_membrane_scaling()
+ try:
+ s = scales.Scaling(scale=2., bias=1.)
+ bm.set_membrane_scaling(s)
+ assert bm.get_membrane_scaling() is s
+ finally:
+ bm.set_membrane_scaling(orig)
+ with pytest.raises(TypeError):
+ bm.set_membrane_scaling('not-a-scaling')
+
+
+def test_environment_get_platform():
+ assert bm.get_platform() in ('cpu', 'gpu', 'tpu')
+
+
+def test_set_applies_all_valid_args_round_trip(restore_environment):
+ """Exercise the apply branches of ``set()`` with every argument, then
+ restore each global it touched."""
+ orig_dt = bm.get_dt()
+ orig_mode = bm.get_mode()
+ orig_ms = bm.get_membrane_scaling()
+ orig_float = bm.get_float()
+ orig_int = bm.get_int()
+ orig_bool = bm.get_bool()
+ orig_complex = bm.get_complex()
+ orig_pytree = _defaults.bp_object_as_pytree
+ orig_return = _defaults.numpy_func_return
+ try:
+ bm.set(
+ mode=modes.batching_mode,
+ membrane_scaling=scales.Scaling(scale=2., bias=0.),
+ dt=0.25,
+ x64=False,
+ complex_=jnp.complex64,
+ float_=jnp.float32,
+ int_=jnp.int32,
+ bool_=jnp.bool_,
+ bp_object_as_pytree=True,
+ numpy_func_return='jax_array',
+ )
+ assert bm.get_dt() == 0.25
+ assert isinstance(bm.get_mode(), modes.BatchingMode)
+ assert bm.get_membrane_scaling().scale == 2.
+ assert bm.get_float() == jnp.float32
+ assert bm.get_int() == jnp.int32
+ assert bm.get_complex() == jnp.complex64
+ assert _defaults.bp_object_as_pytree is True
+ assert _defaults.numpy_func_return == 'jax_array'
+ finally:
+ bm.set_dt(orig_dt)
+ bm.set_mode(orig_mode)
+ bm.set_membrane_scaling(orig_ms)
+ bm.set_float(orig_float)
+ bm.set_int(orig_int)
+ bm.set_bool(orig_bool)
+ bm.set_complex(orig_complex)
+ _defaults.bp_object_as_pytree = orig_pytree
+ _defaults.numpy_func_return = orig_return
+
+
+def test_set_environment_is_set_alias():
+ from brainpy.math.environment import set_environment, set as _set
+ assert set_environment is _set
+
+
+def test_environment_context_all_dtype_kwargs(restore_environment):
+ """Exercise the ``__init__``/``__enter__``/``__exit__`` branches for the
+ dtype + scaling + pytree + numpy_func_return arguments."""
+ orig_float = bm.get_float()
+ with bm.environment(
+ membrane_scaling=scales.Scaling(scale=3., bias=0.),
+ float_=jnp.float32,
+ int_=jnp.int32,
+ bool_=jnp.bool_,
+ complex_=jnp.complex64,
+ bp_object_as_pytree=True,
+ numpy_func_return='jax_array',
+ ):
+ assert bm.get_membrane_scaling().scale == 3.
+ assert _defaults.bp_object_as_pytree is True
+ assert _defaults.numpy_func_return == 'jax_array'
+ assert bm.get_float() == orig_float
+ assert _defaults.numpy_func_return != 'jax_array' or _defaults.numpy_func_return == 'bp_array'
+
+
+def test_environment_as_decorator(restore_environment):
+ """``environment`` doubles as a decorator (``_DecoratorContextManager``)."""
+
+ @bm.environment(dt=0.07)
+ def get_dt_inside():
+ return bm.get_dt()
+
+ orig = bm.get_dt()
+ assert get_dt_inside() == 0.07
+ assert bm.get_dt() == orig
+
+
+def test_environment_init_rejects_bad_types():
+ with pytest.raises(AssertionError):
+ bm.environment(dt='not-a-float')
+ with pytest.raises(AssertionError):
+ bm.environment(mode='not-a-mode')
+ with pytest.raises(AssertionError):
+ bm.environment(x64='not-a-bool')
+ with pytest.raises(AssertionError):
+ bm.environment(numpy_func_return='bad-option')
+
+
+def test_environment_context_manager_restores(restore_environment):
+ orig_mode = bm.get_mode()
+ orig_dt = bm.get_dt()
+ with bm.environment(mode=modes.batching_mode, dt=0.2):
+ assert bm.get_mode() is modes.batching_mode
+ assert bm.get_dt() == 0.2
+ assert bm.get_mode() is orig_mode
+ assert bm.get_dt() == orig_dt
+
+
+def test_batching_and_training_environment(restore_environment):
+ with bm.batching_environment(dt=0.3):
+ assert isinstance(bm.get_mode(), modes.BatchingMode)
+ with bm.training_environment(batch_size=4):
+ m = bm.get_mode()
+ assert isinstance(m, modes.TrainingMode)
+ assert m.batch_size == 4
+
+
+def test_environment_x64_context_manager(restore_environment):
+ """``environment(x64=...)`` flips and restores the precision symmetrically."""
+ start = config.read('jax_enable_x64')
+ with bm.environment(x64=not start):
+ assert config.read('jax_enable_x64') == (not start)
+ assert config.read('jax_enable_x64') == start
+
+
+def test_enable_x64_with_bool_argument_warns(restore_environment):
+ """The legacy ``enable_x64(True)`` path emits a DeprecationWarning."""
+ with pytest.warns(DeprecationWarning):
+ bm.enable_x64(True)
+ assert config.read('jax_enable_x64') is True
+
+
+def test_set_x64_helper(restore_environment):
+ from brainpy.math.environment import set_x64
+ set_x64(True)
+ assert config.read('jax_enable_x64') is True
+ set_x64(False)
+ assert config.read('jax_enable_x64') is False
+
+
+def test_enable_x64_false_branch_warns(restore_environment):
+ """The deprecated ``enable_x64(False)`` path routes through the disable
+ branch and warns."""
+ bm.enable_x64() # go to 64 first
+ with pytest.warns(DeprecationWarning):
+ bm.enable_x64(False)
+ assert config.read('jax_enable_x64') is False
+ assert brainstate.environ.get_precision() == 32
+
+
+def test_environment_decorator_on_generator(restore_environment):
+ """Cover ``_DecoratorContextManager._wrap_generator`` by decorating a
+ generator function with ``environment``."""
+
+ @bm.environment(dt=0.09)
+ def gen():
+ yield bm.get_dt()
+ yield bm.get_dt()
+
+ orig = bm.get_dt()
+ values = list(gen())
+ assert values == [0.09, 0.09]
+ assert bm.get_dt() == orig
+
+
+def test_set_host_device_count_sets_env_var(monkeypatch):
+ """``set_host_device_count`` writes the XLA flag (no global precision
+ change)."""
+ monkeypatch.setenv('XLA_FLAGS', '')
+ bm.set_host_device_count(3)
+ import os
+ assert '--xla_force_host_platform_device_count=3' in os.environ['XLA_FLAGS']
+
+
+def test_gpu_memory_preallocation_toggles(monkeypatch):
+ import os
+ from brainpy.math.environment import gpu_memory_preallocation
+ monkeypatch.delenv('XLA_PYTHON_CLIENT_PREALLOCATE', raising=False)
+ monkeypatch.delenv('XLA_PYTHON_CLIENT_ALLOCATOR', raising=False)
+ bm.disable_gpu_memory_preallocation()
+ assert os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] == 'false'
+ assert os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] == 'platform'
+ bm.enable_gpu_memory_preallocation()
+ assert os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] == 'true'
+ assert 'XLA_PYTHON_CLIENT_ALLOCATOR' not in os.environ
+ gpu_memory_preallocation(0.5)
+ assert os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] == '0.5'
+ with pytest.raises(AssertionError):
+ gpu_memory_preallocation(1.5)
+
+
+def test_set_platform_rejects_unknown():
+ with pytest.raises(AssertionError):
+ bm.set_platform('quantum')
+
+
+def test_environment_identity_equality():
+ """``environment.__eq__`` is identity-based."""
+ env1 = bm.environment(dt=0.1)
+ env2 = bm.environment(dt=0.1)
+ assert env1 == env1
+ assert env1 != env2
+
+
+def test_environment_clone_preserves_settings():
+ env = bm.environment(dt=0.1, mode=modes.batching_mode)
+ clone = env.clone()
+ assert clone.dt == 0.1
+ assert clone.mode is modes.batching_mode
+
+
+# --- modes.py --------------------------------------------------------------
+
+def test_mode_repr():
+ assert repr(modes.nonbatching_mode) == 'NonBatchingMode'
+ assert repr(modes.batching_mode) == 'BatchingMode(batch_size=1)'
+ assert repr(modes.training_mode) == 'TrainingMode(batch_size=1)'
+
+
+def test_mode_equality():
+ assert modes.NonBatchingMode() == modes.NonBatchingMode()
+ assert (modes.NonBatchingMode() == modes.BatchingMode()) is False
+ # comparison against a non-Mode returns False (not an error)
+ assert (modes.nonbatching_mode == 5) is False
+
+
+def test_mode_predicates():
+ assert modes.batching_mode.is_batch_mode() is True
+ assert modes.training_mode.is_train_mode() is True
+ assert modes.nonbatching_mode.is_nonbatch_mode() is True
+ assert modes.nonbatching_mode.batch_size == tuple()
+ assert modes.batching_mode.batch_size == 1
+
+
+def test_mode_type_queries():
+ assert modes.batching_mode.is_one_of(modes.BatchingMode, modes.TrainingMode)
+ assert modes.batching_mode.is_a(modes.BatchingMode)
+ assert modes.batching_mode.is_parent_of(modes.TrainingMode)
+ assert modes.training_mode.is_child_of(modes.BatchingMode)
+ # invalid (non-type) arguments raise TypeError
+ with pytest.raises(TypeError):
+ modes.batching_mode.is_one_of(modes.batching_mode)
+ with pytest.raises(TypeError):
+ modes.batching_mode.is_parent_of(modes.batching_mode)
+ with pytest.raises(TypeError):
+ modes.batching_mode.is_child_of(modes.batching_mode)
+
+
+def test_training_mode_to_batch_mode():
+ tm = modes.TrainingMode(3)
+ bmode = tm.to_batch_mode()
+ assert isinstance(bmode, modes.BatchingMode)
+ assert not isinstance(bmode, modes.TrainingMode)
+ assert bmode.batch_size == 3
+
+
+# --- scales.py -------------------------------------------------------------
+
+def test_scaling_offset_std_inv():
+ s = scales.Scaling(scale=2., bias=1.)
+ assert s.offset_scaling(3.0) == (3.0 + 1.0) / 2.0
+ assert s.std_scaling(4.0) == 4.0 / 2.0
+ assert s.inv_scaling(4.0) == 4.0 * 2.0
+ # explicit overrides honored on a plain Scaling
+ assert s.offset_scaling(3.0, bias=0., scale=1.) == 3.0
+ assert s.std_scaling(4.0, scale=4.0) == 1.0
+ assert s.inv_scaling(4.0, scale=0.5) == 2.0
+
+
+def test_scaling_clone():
+ s = scales.Scaling(scale=2., bias=1.)
+ c = s.clone()
+ assert c.scale == 2. and c.bias == 1.
+ c2 = s.clone(bias=5., scale=3.)
+ assert c2.scale == 3. and c2.bias == 5.
+
+
+def test_scaling_transform():
+ s = scales.Scaling.transform([0., 10.], [0., 1.])
+ assert s.scale == 10.0
+ assert s.bias == 0.0
+ # round-trip: offset then inv recovers the offset-corrected value
+ assert s.offset_scaling(10.0) == 1.0
+
+
+# --- sharding.py -----------------------------------------------------------
+
+def test_device_mesh_context_sets_and_restores():
+ devs = np.asarray(jax.devices())
+ with sharding.device_mesh(devs, ('x',)) as mesh:
+ assert mesh.axis_names == ('x',)
+ sh = sharding.get_sharding(['x'])
+ assert sh is not None
+ # default mesh restored to None afterwards
+ assert sharding.get_sharding(['x']) is None
+
+
+def test_partition_none_passthrough():
+ x = jnp.array([1., 2.])
+ assert sharding.partition(x, None) is x
+
+
+def test_partition_by_axname_no_mesh_returns_input():
+ x = jnp.array([1., 2.])
+ # No default mesh -> input returned unchanged.
+ out = sharding.partition_by_axname(x, ['x'])
+ np.testing.assert_allclose(np.asarray(out), [1., 2.])
+ # axis_names None -> input returned.
+ assert sharding.partition_by_axname(x, None) is x
+
+
+def test_partition_by_axname_shape_mismatch_raises():
+ devs = np.asarray(jax.devices())
+ with sharding.device_mesh(devs, ('x',)):
+ with pytest.raises(ValueError):
+ # 1-D array but two requested axis names -> dim mismatch
+ sharding.partition_by_axname(jnp.array([1., 2.]), ['x', 'y'])
+
+
+def test_partition_by_sharding_none_and_typecheck():
+ x = jnp.array([1., 2.])
+ assert sharding.partition_by_sharding(x, None) is x
+ with pytest.raises(TypeError):
+ sharding.partition_by_sharding(x, 'not-a-sharding')
+
+
+def test_partition_invalid_type_raises():
+ with pytest.raises(TypeError):
+ sharding.partition(jnp.array([1.]), 12345)
+
+
+def test_keep_constraint_passthrough():
+ out = sharding.keep_constraint(jnp.array([1., 2.]))
+ np.testing.assert_allclose(np.asarray(out), [1., 2.])
+ # non-array leaves pass through untouched
+ assert sharding.keep_constraint(7) == 7
+
+
+def test_is_bp_array_helper():
+ assert sharding.is_bp_array(Array([1.])) is True
+ assert sharding.is_bp_array(jnp.array([1.])) is False
+
+
+def test_partition_by_axname_with_mesh_devices():
+ """Exercise the real device-put path of ``partition_by_axname`` /
+ ``_device_put`` with an actual single-device mesh."""
+ devs = np.asarray(jax.devices())
+ with sharding.device_mesh(devs, ('x',)):
+ # 1-D array, one axis name -> dims match -> resharded via _device_put.
+ out = sharding.partition_by_axname(jnp.array([1., 2.]), ['x'])
+ np.testing.assert_allclose(np.asarray(out), [1., 2.])
+ # Array leaf goes through the ``isinstance(x, Array)`` branch.
+ out2 = sharding.partition_by_axname(Array([3., 4.]), ['x'])
+ leaf = jax.tree_util.tree_leaves(out2, is_leaf=sharding.is_bp_array)[0]
+ np.testing.assert_allclose(np.asarray(leaf), [3., 4.])
+
+
+def test_partition_with_sharding_object():
+ """``partition`` with a concrete ``Sharding`` instance routes through
+ ``partition_by_sharding``/``_device_put``."""
+ devs = np.asarray(jax.devices())
+ mesh = jax.sharding.Mesh(devs, axis_names=('x',))
+ sh = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
+ out = sharding.partition(jnp.array([1., 2.]), sh)
+ np.testing.assert_allclose(np.asarray(out.value if isinstance(out, Array) else out),
+ [1., 2.])
+
+
+def test_partition_with_axis_name_sequence():
+ devs = np.asarray(jax.devices())
+ with sharding.device_mesh(devs, ('x',)):
+ out = sharding.partition(jnp.array([1., 2.]), ['x'])
+ leaf = jax.tree_util.tree_leaves(out, is_leaf=sharding.is_bp_array)[0]
+ np.testing.assert_allclose(np.asarray(leaf), [1., 2.])
+
+
+def test_keep_constraint_on_bp_array():
+ out = sharding.keep_constraint(Array([1., 2., 3.]))
+ np.testing.assert_allclose(np.asarray(out), [1., 2., 3.])
diff --git a/tests/audit/test_math_sparse_surrogate_fixes.py b/tests/audit/test_math_sparse_surrogate_fixes.py
new file mode 100644
index 000000000..e35fd15b0
--- /dev/null
+++ b/tests/audit/test_math_sparse_surrogate_fixes.py
@@ -0,0 +1,836 @@
+# -*- coding: utf-8 -*-
+"""Regression + coverage tests for the BrainPy v2.7.8 sparse / event / jitconn /
+surrogate / delay / pre-syn-post audit (see ``docs/issues-found-20260618.md``).
+
+This module locks in the fixes recorded in the audit for the following source
+files (audit IDs in parentheses):
+
+* ``brainpy/math/sparse/csr_mm.py`` — C-07 (``csrmm(transpose=True)`` must
+ compute ``Mᵀ @ B`` rather than ``B @ M``).
+* ``brainpy/math/event/csr_matmat.py`` — C-07 (same transpose fix for the
+ event-driven CSR matmat path, with ``BinaryArray`` operand).
+* ``brainpy/math/sparse/coo_mv.py`` — H-17 (``coomv`` must no longer touch
+ the removed ``brainevent.COO`` and must still equal ``dense @ v``).
+* ``brainpy/math/sparse/utils.py`` — H-18 (``coo_to_csr`` returns an int,
+ monotone ``indptr`` starting at 0 / ending at ``nnz``) and H-19
+ (``csr_to_dense`` wraps ``brainevent.CSR(...).todense()`` correctly).
+* ``brainpy/math/jitconn/matvec.py`` — M-13 (``mv_prob_*`` /
+ ``event_mv_prob_*`` are reproducible when an explicit ``seed`` is threaded).
+* ``brainpy/math/surrogate/_one_input.py`` and
+ ``brainpy/math/surrogate/_one_input_new.py`` — H-20..H-24: for every surrogate
+ the ``surrogate_grad`` matches ``jax.grad(surrogate_fun)``; ``GaussianGrad``
+ widens with ``sigma`` (H-20); ``PiecewiseQuadratic`` grad matches its forward
+ derivative (H-21); ``QPseudoSpike`` uses ``alpha-1`` (H-22); ``Arctan``
+ ``surrogate_fun`` does not raise (H-23); ``ERF`` ``surrogate_fun`` is
+ increasing (H-24).
+* ``brainpy/math/delayvars.py`` — C-09 (``TimeDelay`` ring-buffer read
+ applies the modulo) plus ``LengthDelay`` / ``ROTATE_UPDATE`` / ``CONCAT_UPDATE``
+ coverage.
+* ``brainpy/math/pre_syn_post.py`` — M-15 plus ``syn2post_mean`` /
+ ``syn2post_softmax`` empty-group → 0 vs genuine-NaN-propagation behaviour.
+
+All tests use tiny shapes and complete well under the time budget. They assert
+fixed (correct) behaviour, so they double as a regression guard against the bugs
+re-appearing.
+"""
+
+import numpy as np
+import jax
+import jax.numpy as jnp
+import pytest
+
+import brainpy.math as bm
+import brainpy.math.surrogate._one_input as old_surr
+import brainpy.math.surrogate._one_input_new as new_surr
+
+
+# ---------------------------------------------------------------------------
+# Helpers: build a small known CSR by hand (no scipy dependency).
+#
+# dense = [[1, 0, 2, 0],
+# [0, 3, 0, 0],
+# [0, 0, 0, 4]] (3 x 4)
+# ---------------------------------------------------------------------------
+
+def _known_csr():
+ dense = np.array([[1., 0., 2., 0.],
+ [0., 3., 0., 0.],
+ [0., 0., 0., 4.]], dtype=np.float32)
+ data = jnp.asarray([1., 2., 3., 4.], dtype=jnp.float32)
+ indices = jnp.asarray([0, 2, 1, 3], dtype=jnp.int32)
+ indptr = jnp.asarray([0, 2, 3, 4], dtype=jnp.int32)
+ shape = (3, 4)
+ return dense, data, indices, indptr, shape
+
+
+# ===========================================================================
+# C-07 — csrmm(transpose=True) computes Mᵀ @ B (sparse + event paths)
+# ===========================================================================
+
+def test_csrmm_non_transpose_matches_dense():
+ dense, data, indices, indptr, shape = _known_csr()
+ B = np.arange(4 * 2, dtype=np.float32).reshape(4, 2)
+ out = bm.sparse.csrmm(data, indices, indptr, jnp.asarray(B),
+ shape=shape, transpose=False)
+ out = np.asarray(out)
+ assert out.shape == (3, 2)
+ np.testing.assert_allclose(out, dense @ B, rtol=1e-5, atol=1e-5)
+
+
+def test_csrmm_transpose_matches_dense_T():
+ # Regression for C-07: transpose branch must equal Mᵀ @ B (shape (cols, k)),
+ # NOT B @ M. With shape=(3,4) and B=(3,2) the output must be (4,2).
+ dense, data, indices, indptr, shape = _known_csr()
+ B = np.arange(3 * 2, dtype=np.float32).reshape(3, 2)
+ out = bm.sparse.csrmm(data, indices, indptr, jnp.asarray(B),
+ shape=shape, transpose=True)
+ out = np.asarray(out)
+ assert out.shape == (4, 2) # would be (3, ...) under the old bug
+ np.testing.assert_allclose(out, dense.T @ B, rtol=1e-5, atol=1e-5)
+
+
+def test_csrmm_accepts_brainpy_array_operands():
+ # Exercise the ``isinstance(x, Array)`` unwrapping branches in csr_mm.py.
+ dense, data, indices, indptr, shape = _known_csr()
+ B = np.arange(4 * 2, dtype=np.float32).reshape(4, 2)
+ out = bm.sparse.csrmm(bm.asarray(data), bm.asarray(indices),
+ bm.asarray(indptr), bm.asarray(B),
+ shape=shape, transpose=False)
+ np.testing.assert_allclose(np.asarray(out), dense @ B, rtol=1e-5, atol=1e-5)
+
+
+def test_event_csrmm_transpose_matches_dense_T():
+ # C-07 for brainpy/math/event/csr_matmat.py — binary event matrix.
+ dense, data, indices, indptr, shape = _known_csr()
+ B = (np.arange(3 * 2).reshape(3, 2) % 2).astype(np.float32) # 0/1 events
+ out = bm.event.csrmm(data, indices, indptr, jnp.asarray(B),
+ shape=shape, transpose=True)
+ out = np.asarray(out)
+ assert out.shape == (4, 2)
+ np.testing.assert_allclose(out, dense.T @ B, rtol=1e-5, atol=1e-5)
+
+
+def test_event_csrmm_non_transpose_matches_dense():
+ dense, data, indices, indptr, shape = _known_csr()
+ B = (np.arange(4 * 2).reshape(4, 2) % 2).astype(np.float32)
+ out = bm.event.csrmm(data, indices, indptr, jnp.asarray(B),
+ shape=shape, transpose=False)
+ np.testing.assert_allclose(np.asarray(out), dense @ B, rtol=1e-5, atol=1e-5)
+
+
+def test_event_csrmm_accepts_brainpy_array_operands():
+ # Exercise the ``isinstance(x, Array)`` unwrap branches in csr_matmat.py.
+ dense, data, indices, indptr, shape = _known_csr()
+ B = (np.arange(4 * 2).reshape(4, 2) % 2).astype(np.float32)
+ out = bm.event.csrmm(bm.asarray(data), bm.asarray(indices),
+ bm.asarray(indptr), bm.asarray(B),
+ shape=shape, transpose=False)
+ np.testing.assert_allclose(np.asarray(out), dense @ B, rtol=1e-5, atol=1e-5)
+
+
+# ===========================================================================
+# H-17 — coomv works without the removed brainevent.COO, equals dense @ v
+# ===========================================================================
+
+def test_coomv_matches_dense_no_attribute_error():
+ dense, _, _, _, shape = _known_csr()
+ rows, cols = np.nonzero(dense)
+ vals = dense[rows, cols].astype(np.float32)
+ row = jnp.asarray(rows, dtype=jnp.int32)
+ col = jnp.asarray(cols, dtype=jnp.int32)
+ data = jnp.asarray(vals, dtype=jnp.float32)
+ v = np.arange(4, dtype=np.float32)
+ # Must not raise AttributeError about a removed COO type.
+ out = bm.sparse.coomv(data, row, col, jnp.asarray(v), shape=shape, transpose=False)
+ np.testing.assert_allclose(np.asarray(out), dense @ v, rtol=1e-5, atol=1e-5)
+
+
+def test_coomv_transpose_matches_dense_T():
+ dense, _, _, _, shape = _known_csr()
+ rows, cols = np.nonzero(dense)
+ data = jnp.asarray(dense[rows, cols].astype(np.float32))
+ row = jnp.asarray(rows, dtype=jnp.int32)
+ col = jnp.asarray(cols, dtype=jnp.int32)
+ v = np.arange(3, dtype=np.float32)
+ out = bm.sparse.coomv(data, row, col, jnp.asarray(v), shape=shape, transpose=True)
+ np.testing.assert_allclose(np.asarray(out), dense.T @ v, rtol=1e-5, atol=1e-5)
+
+
+def test_coomv_scalar_weight_broadcast():
+ # Exercise the scalar-data broadcast branch of coomv.
+ dense, _, _, _, shape = _known_csr()
+ rows, cols = np.nonzero(dense)
+ row = jnp.asarray(rows, dtype=jnp.int32)
+ col = jnp.asarray(cols, dtype=jnp.int32)
+ v = np.ones(4, dtype=np.float32)
+ out = bm.sparse.coomv(2.0, row, col, jnp.asarray(v), shape=shape, transpose=False)
+ # mask of the structure, weight 2.0 everywhere
+ mask = (dense != 0).astype(np.float32)
+ np.testing.assert_allclose(np.asarray(out), (2.0 * mask) @ v, rtol=1e-5, atol=1e-5)
+
+
+def test_coomv_accepts_brainpy_array_operands():
+ # Exercise the ``isinstance(x, Array)`` unwrap branches in coo_mv.py.
+ dense, _, _, _, shape = _known_csr()
+ rows, cols = np.nonzero(dense)
+ data = bm.asarray(dense[rows, cols].astype(np.float32))
+ row = bm.asarray(jnp.asarray(rows, dtype=jnp.int32))
+ col = bm.asarray(jnp.asarray(cols, dtype=jnp.int32))
+ v = bm.asarray(np.arange(4, dtype=np.float32))
+ out = bm.sparse.coomv(data, row, col, v, shape=shape, transpose=False)
+ np.testing.assert_allclose(np.asarray(out), dense @ np.arange(4, dtype=np.float32),
+ rtol=1e-5, atol=1e-5)
+
+
+# ===========================================================================
+# H-18 / H-19 — coo_to_csr int indptr; csr_to_dense matches reference
+# ===========================================================================
+
+def test_coo_to_csr_returns_int_monotone_indptr():
+ pre_ids = jnp.asarray([0, 0, 1, 2], dtype=jnp.int32)
+ post_ids = jnp.asarray([0, 2, 1, 3], dtype=jnp.int32)
+ indices, indptr = bm.sparse.coo_to_csr(pre_ids, post_ids, num_row=3)
+ indptr_np = np.asarray(indptr)
+ # integer dtype (H-18: was float before the fix)
+ assert jnp.issubdtype(indptr.dtype, jnp.integer)
+ # monotone non-decreasing, starts at 0, ends at nnz
+ assert indptr_np[0] == 0
+ assert indptr_np[-1] == int(post_ids.shape[0])
+ assert np.all(np.diff(indptr_np) >= 0)
+ # round-trip sanity: indices length == nnz
+ assert np.asarray(indices).shape[0] == int(post_ids.shape[0])
+
+
+def test_csr_to_dense_matches_reference():
+ dense, data, indices, indptr, shape = _known_csr()
+ out = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape)
+ np.testing.assert_allclose(np.asarray(out), dense, rtol=1e-5, atol=1e-5)
+
+
+def test_csr_to_coo_round_trip():
+ # Exercises utils.csr_to_coo for coverage.
+ _, _, indices, indptr, _ = _known_csr()
+ row, col = bm.sparse.csr_to_coo(indices, indptr)
+ np.testing.assert_array_equal(np.asarray(row), [0, 0, 1, 2])
+ np.testing.assert_array_equal(np.asarray(col), np.asarray(indices))
+
+
+# ===========================================================================
+# sparse / event csrmv coverage (+ value correctness)
+# ===========================================================================
+
+def test_csrmv_matches_dense_both_directions():
+ dense, data, indices, indptr, shape = _known_csr()
+ v = np.arange(4, dtype=np.float32)
+ out = bm.sparse.csrmv(data, indices, indptr, jnp.asarray(v),
+ shape=shape, transpose=False)
+ np.testing.assert_allclose(np.asarray(out), dense @ v, rtol=1e-5, atol=1e-5)
+
+ vt = np.arange(3, dtype=np.float32)
+ out_t = bm.sparse.csrmv(data, indices, indptr, jnp.asarray(vt),
+ shape=shape, transpose=True)
+ np.testing.assert_allclose(np.asarray(out_t), dense.T @ vt, rtol=1e-5, atol=1e-5)
+
+
+def test_event_csrmv_matches_masked_dense():
+ dense, data, indices, indptr, shape = _known_csr()
+ # transpose=True: Mᵀ @ events, events length == shape[0] == 3
+ events = jnp.asarray([True, False, True], dtype=bool)
+ out = bm.event.csrmv(data, indices, indptr, events, shape=shape, transpose=True)
+ ref = dense.T @ np.asarray(events, dtype=np.float32)
+ np.testing.assert_allclose(np.asarray(out), ref, rtol=1e-5, atol=1e-5)
+
+ # non-transpose: M @ events, events length == shape[1] == 4
+ events2 = jnp.asarray([True, False, True, False], dtype=bool)
+ out2 = bm.event.csrmv(data, indices, indptr, events2, shape=shape, transpose=False)
+ ref2 = dense @ np.asarray(events2, dtype=np.float32)
+ np.testing.assert_allclose(np.asarray(out2), ref2, rtol=1e-5, atol=1e-5)
+
+
+# ===========================================================================
+# H-20..H-24 — surrogate gradients consistent with their forward functions
+# ===========================================================================
+
+# The two modules expose slightly different APIs:
+# _one_input.py : surrogate_grad(self, dz, x) (dz = upstream gradient)
+# _one_input_new.py : surrogate_grad(self, x)
+# These helpers normalise that so the same assertions run against both.
+
+def _grad_old(inst, x):
+ return inst.surrogate_grad(1.0, x)
+
+
+def _grad_new(inst, x):
+ return inst.surrogate_grad(x)
+
+
+_MODULES = [
+ ("_one_input", old_surr, _grad_old),
+ ("_one_input_new", new_surr, _grad_new),
+]
+
+# Surrogates that expose BOTH surrogate_fun and surrogate_grad and are
+# differentiable away from kinks — used for the grad-vs-autograd check.
+_HAS_FUN = ["PiecewiseQuadratic", "QPseudoSpike", "Arctan", "ERF"]
+
+
+@pytest.mark.parametrize("modname,mod,getgrad", _MODULES,
+ ids=[m[0] for m in _MODULES])
+@pytest.mark.parametrize("clsname", _HAS_FUN)
+def test_surrogate_grad_matches_autograd(modname, mod, getgrad, clsname):
+ """H-21/H-22/H-23/H-24: surrogate_grad == d/dx surrogate_fun."""
+ cls = getattr(mod, clsname)
+ inst = cls()
+ # Avoid the exact kinks (|x| = 1/alpha) where the piecewise derivative jumps.
+ xs = jnp.asarray([-0.85, -0.4, -0.1, 0.1, 0.4, 0.85], dtype=jnp.float32)
+ analytic = np.asarray(getgrad(inst, xs))
+
+ def fun_scalar(v):
+ return jnp.squeeze(inst.surrogate_fun(jnp.reshape(v, (1,))))
+
+ autograd = np.asarray(jax.vmap(jax.grad(fun_scalar))(xs))
+ np.testing.assert_allclose(analytic, autograd, rtol=2e-3, atol=2e-4)
+
+
+@pytest.mark.parametrize("modname,mod,getgrad", _MODULES,
+ ids=[m[0] for m in _MODULES])
+@pytest.mark.parametrize("clsname", _HAS_FUN)
+def test_surrogate_fun_monotone_increasing_on_unit_interval(modname, mod, getgrad, clsname):
+ """Each surrogate forward (origin) function is non-decreasing on [0, 1]."""
+ cls = getattr(mod, clsname)
+ inst = cls()
+ fx = np.asarray(inst.surrogate_fun(jnp.linspace(0.0, 1.0, 21)))
+ assert np.all(np.diff(fx) >= -1e-6)
+
+
+@pytest.mark.parametrize("modname,mod,getgrad", _MODULES,
+ ids=[m[0] for m in _MODULES])
+def test_arctan_surrogate_fun_does_not_raise(modname, mod, getgrad):
+ """H-23: Arctan.surrogate_fun previously called jnp.arctan2 with one arg."""
+ inst = mod.Arctan()
+ out = np.asarray(inst.surrogate_fun(jnp.asarray([-0.5, -0.1, 0.0, 0.1, 0.5])))
+ assert np.all(np.isfinite(out))
+ # arctan forward crosses 0.5 at x = 0
+ assert np.isclose(out[2], 0.5, atol=1e-6)
+ assert np.all(np.diff(out) > 0)
+
+
+@pytest.mark.parametrize("modname,mod,getgrad", _MODULES,
+ ids=[m[0] for m in _MODULES])
+def test_erf_surrogate_fun_is_increasing(modname, mod, getgrad):
+ """H-24: ERF.surrogate_fun must be increasing (was decreasing before)."""
+ inst = mod.ERF()
+ out = np.asarray(inst.surrogate_fun(jnp.linspace(-0.5, 0.5, 11)))
+ assert np.all(np.diff(out) > 0)
+ assert np.isclose(out[5], 0.5, atol=1e-6) # centred at x = 0
+
+
+@pytest.mark.parametrize("modname,mod,getgrad", _MODULES,
+ ids=[m[0] for m in _MODULES])
+def test_gaussian_grad_bump_widens_with_sigma(modname, mod, getgrad):
+ """H-20: GaussianGrad — at x=1 the gradient must INCREASE with sigma
+ (a wider bump), proving the sigma is no longer inverted by the
+ operator-precedence bug ``exp(-(x**2)/2*sigma**2)``."""
+ g_narrow = float(np.asarray(getgrad(mod.GaussianGrad(sigma=0.5), jnp.asarray(1.0))))
+ g_wide = float(np.asarray(getgrad(mod.GaussianGrad(sigma=2.0), jnp.asarray(1.0))))
+ assert g_wide > g_narrow
+ # Sanity on the intended magnitude (audit: grad@±1 ≈ 0.088 for sigma=2).
+ assert g_wide == pytest.approx(0.088, abs=2e-2)
+
+
+@pytest.mark.parametrize("modname,mod,getgrad", _MODULES,
+ ids=[m[0] for m in _MODULES])
+def test_piecewise_quadratic_grad_formula(modname, mod, getgrad):
+ """H-21: grad == -alpha**2 |x| + alpha inside the support, 0 outside."""
+ inst = mod.PiecewiseQuadratic(alpha=1.0)
+ g_in = float(np.asarray(getgrad(inst, jnp.asarray(0.5))))
+ assert g_in == pytest.approx(-1.0 * 0.5 + 1.0) # = 0.5
+ g_out = float(np.asarray(getgrad(inst, jnp.asarray(5.0))))
+ assert g_out == pytest.approx(0.0)
+
+
+@pytest.mark.parametrize("modname,mod,getgrad", _MODULES,
+ ids=[m[0] for m in _MODULES])
+def test_qpseudospike_grad_uses_alpha_minus_one(modname, mod, getgrad):
+ """H-22: grad denominator uses (alpha-1); grad at 0 == 1."""
+ inst = mod.QPseudoSpike(alpha=2.0)
+ g0 = float(np.asarray(getgrad(inst, jnp.asarray(0.0))))
+ assert g0 == pytest.approx(1.0, abs=1e-6)
+
+
+# ===========================================================================
+# Surrogate coverage — every class's __call__ + surrogate_grad in both modules
+# ===========================================================================
+
+def _new_surrogate_classes():
+ return [getattr(new_surr, n) for n in new_surr.__all__
+ if n[0].isupper() and n != "Surrogate"]
+
+
+def _old_surrogate_classes():
+ out = []
+ for n in dir(old_surr):
+ obj = getattr(old_surr, n)
+ if (isinstance(obj, type) and issubclass(obj, old_surr.Surrogate)
+ and n not in ("Surrogate", "_OneInpSurrogate")):
+ out.append(obj)
+ return out
+
+
+@pytest.mark.parametrize("cls", _new_surrogate_classes(),
+ ids=lambda c: c.__name__)
+def test_new_surrogate_call_and_grad_run(cls):
+ inst = cls()
+ x = jnp.linspace(-1.5, 1.5, 9)
+ y = inst(x) # __call__ -> heaviside forward
+ assert np.asarray(y).shape == (9,)
+ # forward is a {0,1} spike indicator
+ assert set(np.unique(np.asarray(y)).tolist()).issubset({0.0, 1.0})
+ g = np.asarray(inst.surrogate_grad(x)) # surrogate_grad(x)
+ assert g.shape == (9,) and np.all(np.isfinite(g))
+ # grad flows through __call__
+ flow = jax.grad(lambda v: jnp.sum(inst(v)))(x)
+ assert np.all(np.isfinite(np.asarray(flow)))
+
+
+@pytest.mark.parametrize("cls", _old_surrogate_classes(),
+ ids=lambda c: c.__name__)
+def test_old_surrogate_call_and_grad_run(cls):
+ inst = cls()
+ x = jnp.linspace(-1.5, 1.5, 9)
+ y = inst(x) # custom-gradient forward
+ assert np.asarray(y).shape == (9,)
+ assert set(np.unique(np.asarray(y)).tolist()).issubset({0.0, 1.0})
+ g = np.asarray(inst.surrogate_grad(1.0, x)) # surrogate_grad(dz, x)
+ assert g.shape == (9,) and np.all(np.isfinite(g))
+ flow = jax.grad(lambda v: jnp.sum(inst(v)))(x)
+ assert np.all(np.isfinite(np.asarray(flow)))
+
+
+def test_new_surrogate_repr_and_functional_aliases():
+ # Exercise functional (lowercase) entry points + __repr__ for coverage.
+ x = jnp.linspace(-1.0, 1.0, 5)
+ assert "Arctan" in repr(new_surr.Arctan())
+ for fn in (new_surr.sigmoid, new_surr.arctan, new_surr.erf,
+ new_surr.gaussian_grad, new_surr.relu_grad):
+ assert np.asarray(fn(x)).shape == (5,)
+
+
+def test_old_surrogate_repr_and_functional_aliases():
+ x = jnp.linspace(-1.0, 1.0, 5)
+ assert "GaussianGrad" in repr(old_surr.GaussianGrad())
+ for fn in (old_surr.sigmoid, old_surr.arctan, old_surr.erf,
+ old_surr.gaussian_grad, old_surr.q_pseudo_spike):
+ assert np.asarray(fn(x)).shape == (5,)
+
+
+# Lowercase functional aliases present in BOTH modules.
+_FUNC_NAMES = [n for n in new_surr.__all__ if n[0].islower()]
+
+
+@pytest.mark.parametrize("fname", _FUNC_NAMES)
+def test_old_functional_alias_forward_and_origin(fname):
+ """Exercise every ``_one_input`` functional alias (heaviside forward and,
+ where supported, the ``origin=True`` smooth forward)."""
+ import inspect
+ fn = getattr(old_surr, fname)
+ x = jnp.linspace(-1.2, 1.2, 7)
+ y = np.asarray(fn(x))
+ assert y.shape == (7,) and np.all(np.isfinite(y))
+ if "origin" in inspect.signature(fn).parameters:
+ yo = np.asarray(fn(x, origin=True)) # exercises surrogate_fun
+ assert yo.shape == (7,) and np.all(np.isfinite(yo))
+
+
+@pytest.mark.parametrize("fname", _FUNC_NAMES)
+def test_new_functional_alias_forward(fname):
+ """Exercise every ``_one_input_new`` functional alias (heaviside forward)."""
+ fn = getattr(new_surr, fname)
+ x = jnp.linspace(-1.2, 1.2, 7)
+ y = np.asarray(fn(x))
+ assert y.shape == (7,) and np.all(np.isfinite(y))
+
+
+def _new_classes_with_surrogate_fun():
+ out = []
+ for n in new_surr.__all__:
+ if not (n[0].isupper() and n != "Surrogate"):
+ continue
+ c = getattr(new_surr, n)
+ if c.surrogate_fun is not new_surr.Surrogate.surrogate_fun:
+ out.append(c)
+ return out
+
+
+@pytest.mark.parametrize("cls", _new_classes_with_surrogate_fun(),
+ ids=lambda c: c.__name__)
+def test_new_surrogate_fun_runs(cls):
+ """Cover the ``surrogate_fun`` body of every new-module class that has one."""
+ out = np.asarray(cls().surrogate_fun(jnp.linspace(-1.2, 1.2, 9)))
+ assert out.shape == (9,) and np.all(np.isfinite(out))
+
+
+# ===========================================================================
+# C-09 — TimeDelay ring-buffer read applies the modulo
+# ===========================================================================
+
+def test_time_delay_ring_buffer_modulo():
+ """After pushing a ramp k*0.1 for k=1..30, the buffer wraps several times;
+ the read must apply ``% num_delay_step`` (C-09) so d(now) == 3.0."""
+ d = bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1)
+ for k in range(1, 31):
+ d.update(bm.ones(1) * (k * 0.1))
+ now = d.current_time[0]
+ assert float(np.asarray(d(now))[0]) == pytest.approx(3.0, abs=1e-4)
+ assert float(np.asarray(d(now - 0.5))[0]) == pytest.approx(2.5, abs=1e-4)
+
+
+def test_time_delay_round_interp_method():
+ d = bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1, interp_method='round')
+ for k in range(1, 31):
+ d.update(bm.ones(1) * (k * 0.1))
+ now = d.current_time[0]
+ # round interpolation hits the exact-step branch too
+ assert float(np.asarray(d(now))[0]) == pytest.approx(3.0, abs=1e-4)
+
+
+def test_time_delay_reset():
+ d = bm.TimeDelay(bm.zeros(2), delay_len=1.0, dt=0.1)
+ for k in range(1, 12):
+ d.update(bm.ones(2) * (k * 0.1))
+ d.reset(bm.zeros(2), delay_len=1.0)
+ now = d.current_time[0]
+ np.testing.assert_allclose(np.asarray(d(now)), np.zeros(2), atol=1e-6)
+
+
+# ===========================================================================
+# LengthDelay — both ROTATE_UPDATE and CONCAT_UPDATE
+# ===========================================================================
+
+@pytest.mark.parametrize("method", [bm.ROTATE_UPDATE, bm.CONCAT_UPDATE])
+def test_length_delay_update_methods(method):
+ ld = bm.LengthDelay(bm.zeros(2), delay_len=3, update_method=method)
+ for k in range(1, 6):
+ ld.update(bm.ones(2) * float(k))
+ # most-recent push is 5 -> delay 0 returns 5; delay 2 returns 3
+ np.testing.assert_allclose(np.asarray(ld(0)), [5., 5.], atol=1e-6)
+ np.testing.assert_allclose(np.asarray(ld(2)), [3., 3.], atol=1e-6)
+
+
+def test_length_delay_reset_and_retrieve():
+ ld = bm.LengthDelay(bm.zeros(2), delay_len=3)
+ ld.reset(bm.ones(2), delay_len=3)
+ out = ld.retrieve(1)
+ assert np.asarray(out).shape == (2,)
+
+
+def test_length_delay_initial_delay_data_scalar_and_callable():
+ # scalar initial_delay_data branch
+ ld = bm.LengthDelay(bm.zeros(2), delay_len=3, initial_delay_data=1.0)
+ np.testing.assert_allclose(np.asarray(ld.retrieve(2)), [1., 1.], atol=1e-6)
+ # callable initial_delay_data branch (plain lambda, no dtype kwarg)
+ ld2 = bm.LengthDelay(bm.zeros(2), delay_len=3,
+ initial_delay_data=lambda shape: jnp.ones(shape) * 7.0)
+ np.testing.assert_allclose(np.asarray(ld2.retrieve(2)), [7., 7.], atol=1e-6)
+
+
+def test_length_delay_concat_single_step():
+ # delay_len=0 -> num_delay_step=1 exercises the CONCAT_UPDATE short branch.
+ ld = bm.LengthDelay(bm.zeros(2), delay_len=0, update_method=bm.CONCAT_UPDATE)
+ ld.update(bm.ones(2) * 3.0)
+ np.testing.assert_allclose(np.asarray(ld(0)), [3., 3.], atol=1e-6)
+
+
+def test_time_delay_callable_before_t0():
+ # Covers the _FUNC_BEFORE path (callable before_t0 + cond branch in __call__).
+ d = bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1, t0=0.0,
+ before_t0=lambda t: jnp.ones(1) * 9.0)
+ # request a time strictly before t0 -> uses before_t0 function
+ out = np.asarray(d(-0.5))
+ np.testing.assert_allclose(out, [9.0], atol=1e-6)
+
+
+def test_neutral_delay_aliases():
+ # NeuTimeDelay / NeuLenDelay are thin aliases; just instantiate + call.
+ ntd = bm.NeuTimeDelay(bm.zeros(1), delay_len=0.5, dt=0.1)
+ ntd.update(bm.ones(1))
+ assert np.asarray(ntd(ntd.current_time[0])).shape == (1,)
+ nld = bm.NeuLenDelay(bm.zeros(1), delay_len=2)
+ nld.update(bm.ones(1))
+ assert np.asarray(nld(0)).shape == (1,)
+
+
+def test_time_delay_array_before_t0_and_indices():
+ # array before_t0 fills the pre-t0 buffer (the _DATA_BEFORE branch);
+ # indices select a sub-slice of the retrieved value.
+ d = bm.TimeDelay(bm.zeros(3), delay_len=1.0, dt=0.1, before_t0=5.0)
+ for k in range(1, 31):
+ d.update(bm.ones(3) * (k * 0.1))
+ now = d.current_time[0]
+ out = np.asarray(d(now, indices=jnp.asarray([0, 2])))
+ assert out.shape == (2,)
+
+
+def test_time_delay_reset_with_callable_before_t0():
+ d = bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1)
+ d.reset(bm.zeros(1), delay_len=1.0, before_t0=lambda t: jnp.ones(1) * 4.0)
+ out = np.asarray(d(-0.5)) # before t0 -> uses callable
+ np.testing.assert_allclose(out, [4.0], atol=1e-6)
+
+
+def test_time_delay_reset_with_array_before_t0():
+ d = bm.TimeDelay(bm.zeros(2), delay_len=1.0, dt=0.1)
+ d.reset(bm.zeros(2), delay_len=1.0, before_t0=3.0)
+ now = d.current_time[0]
+ assert np.asarray(d(now)).shape == (2,)
+
+
+def test_length_delay_retrieve_with_indices_and_repr():
+ ld = bm.LengthDelay(bm.zeros(4), delay_len=3)
+ for k in range(1, 5):
+ ld.update(bm.ones(4) * float(k))
+ out = np.asarray(ld.retrieve(1, jnp.asarray([0, 1])))
+ assert out.shape == (2,)
+ assert "LengthDelay" in repr(ld)
+ assert ld.delay_shape[0] == 4 # num_delay_step (delay_len + 1)
+
+
+def test_length_delay_update_from_variable_target():
+ # update(value=None) pulls from the stored delay_target Variable.
+ target = bm.Variable(bm.ones(2) * 2.0)
+ ld = bm.LengthDelay(target, delay_len=2)
+ ld.update() # no explicit value -> uses delay_target.value
+ assert np.asarray(ld(0)).shape == (2,)
+
+
+def test_time_delay_validation_errors():
+ # invalid delay_target type
+ with pytest.raises(ValueError):
+ bm.TimeDelay([0.0], delay_len=1.0, dt=0.1)
+ # unsupported interpolation method
+ from brainpy._errors import UnsupportedError
+ with pytest.raises(UnsupportedError):
+ bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1, interp_method='nope')
+ # unsupported before_t0 type
+ with pytest.raises(ValueError):
+ bm.TimeDelay(bm.zeros(1), delay_len=1.0, dt=0.1, before_t0='bad')
+
+
+def test_length_delay_validation_errors():
+ with pytest.raises(ValueError):
+ bm.LengthDelay([0.0], delay_len=2)
+ ld = bm.LengthDelay(bm.zeros(2), delay_len=2)
+ with pytest.raises(ValueError):
+ ld.reset(bm.zeros(2), delay_len=2, initial_delay_data='bad')
+
+
+# ===========================================================================
+# M-13 — jitconn mv_prob_* / event_mv_prob_* reproducible with explicit seed
+# ===========================================================================
+
+_SHAPE = (8, 6)
+
+
+def _vec_for(transpose):
+ n = _SHAPE[0] if transpose else _SHAPE[1]
+ return jnp.asarray(np.random.RandomState(0).randn(n).astype(np.float32))
+
+
+@pytest.mark.parametrize("transpose", [False, True])
+@pytest.mark.parametrize("outdim_parallel", [True, False])
+def test_mv_prob_homo_reproducible(transpose, outdim_parallel):
+ v = _vec_for(transpose)
+ kw = dict(weight=1.5, conn_prob=0.3, shape=_SHAPE,
+ transpose=transpose, outdim_parallel=outdim_parallel)
+ o1 = np.asarray(bm.jitconn.mv_prob_homo(v, seed=123, **kw))
+ o2 = np.asarray(bm.jitconn.mv_prob_homo(v, seed=123, **kw))
+ np.testing.assert_array_equal(o1, o2) # reproducible
+ assert o1.shape == (_SHAPE[1] if transpose else _SHAPE[0],)
+
+
+def test_mv_prob_uniform_and_normal_reproducible():
+ v = _vec_for(False)
+ ou1 = np.asarray(bm.jitconn.mv_prob_uniform(v, w_low=-1., w_high=1.,
+ conn_prob=0.3, seed=7, shape=_SHAPE))
+ ou2 = np.asarray(bm.jitconn.mv_prob_uniform(v, w_low=-1., w_high=1.,
+ conn_prob=0.3, seed=7, shape=_SHAPE))
+ np.testing.assert_array_equal(ou1, ou2)
+
+ on1 = np.asarray(bm.jitconn.mv_prob_normal(v, w_mu=0., w_sigma=1.,
+ conn_prob=0.3, seed=9, shape=_SHAPE))
+ on2 = np.asarray(bm.jitconn.mv_prob_normal(v, w_mu=0., w_sigma=1.,
+ conn_prob=0.3, seed=9, shape=_SHAPE))
+ np.testing.assert_array_equal(on1, on2)
+
+
+def test_mv_prob_homo_seed_none_runs():
+ # Cover the ``seed is None`` host-RNG branch (documented as non-reproducible).
+ v = _vec_for(False)
+ out = np.asarray(bm.jitconn.mv_prob_homo(v, weight=1.0, conn_prob=0.3,
+ seed=None, shape=_SHAPE))
+ assert out.shape == (_SHAPE[0],) and np.all(np.isfinite(out))
+
+
+@pytest.mark.parametrize("fn", [
+ lambda ev, **k: bm.jitconn.event_mv_prob_homo(ev, 1.0, 0.3, **k),
+ lambda ev, **k: bm.jitconn.event_mv_prob_uniform(ev, -1.0, 1.0, 0.3, **k),
+ lambda ev, **k: bm.jitconn.event_mv_prob_normal(ev, 0.0, 1.0, 0.3, **k),
+], ids=["homo", "uniform", "normal"])
+def test_event_mv_prob_reproducible(fn):
+ events = jnp.asarray(np.random.RandomState(1).rand(_SHAPE[1]) > 0.5)
+ o1 = np.asarray(fn(events, seed=11, shape=_SHAPE))
+ o2 = np.asarray(fn(events, seed=11, shape=_SHAPE))
+ np.testing.assert_array_equal(o1, o2)
+ assert o1.shape == (_SHAPE[0],)
+
+
+def test_mv_prob_uniform_normal_transpose_and_array_args():
+ # transpose=True branches + Array-operand unwrap branches for uniform/normal.
+ vrow = bm.asarray(np.random.RandomState(3).randn(_SHAPE[0]).astype(np.float32))
+ ou = np.asarray(bm.jitconn.mv_prob_uniform(vrow, bm.asarray(-1.0), bm.asarray(1.0),
+ 0.3, seed=4, shape=_SHAPE, transpose=True))
+ assert ou.shape == (_SHAPE[1],) and np.all(np.isfinite(ou))
+ on = np.asarray(bm.jitconn.mv_prob_normal(vrow, bm.asarray(0.0), bm.asarray(1.0),
+ 0.3, seed=4, shape=_SHAPE, transpose=True))
+ assert on.shape == (_SHAPE[1],) and np.all(np.isfinite(on))
+ # homo with Array weight + Array vector
+ oh = np.asarray(bm.jitconn.mv_prob_homo(vrow, bm.asarray(1.0), 0.3,
+ seed=4, shape=_SHAPE, transpose=True))
+ assert oh.shape == (_SHAPE[1],)
+
+
+def test_get_weight_matrices():
+ mh = np.asarray(bm.jitconn.get_homo_weight_matrix(1.0, 0.3, seed=1, shape=_SHAPE))
+ mu = np.asarray(bm.jitconn.get_uniform_weight_matrix(-1., 1., 0.3, seed=1, shape=_SHAPE))
+ mn = np.asarray(bm.jitconn.get_normal_weight_matrix(0., 1., 0.3, seed=1, shape=_SHAPE))
+ for m in (mh, mu, mn):
+ assert m.shape == _SHAPE
+
+
+def test_get_weight_matrices_transpose_and_seed_none():
+ # transpose=True -> (cols, rows); seed=None branch; Array args unwrap.
+ mh = np.asarray(bm.jitconn.get_homo_weight_matrix(bm.asarray(1.0), 0.3,
+ seed=None, shape=_SHAPE, transpose=True))
+ assert mh.shape == (_SHAPE[1], _SHAPE[0])
+ mu = np.asarray(bm.jitconn.get_uniform_weight_matrix(bm.asarray(-1.0), bm.asarray(1.0), 0.3,
+ seed=None, shape=_SHAPE, transpose=True))
+ assert mu.shape == (_SHAPE[1], _SHAPE[0])
+ mn = np.asarray(bm.jitconn.get_normal_weight_matrix(bm.asarray(0.0), bm.asarray(1.0), 0.3,
+ seed=None, shape=_SHAPE, transpose=True))
+ assert mn.shape == (_SHAPE[1], _SHAPE[0])
+
+
+# ===========================================================================
+# M-15 — pre2post_mean / syn2post_mean / syn2post_softmax edge cases
+# ===========================================================================
+
+def test_syn2post_mean_empty_group_is_zero():
+ syn = jnp.asarray([1., 3., 5.])
+ post_ids = jnp.asarray([0, 0, 2]) # post group 1 is empty
+ out = np.asarray(bm.syn2post_mean(syn, post_ids, 3))
+ assert out[0] == pytest.approx(2.0) # mean(1, 3)
+ assert out[1] == pytest.approx(0.0) # empty -> 0, not NaN
+ assert out[2] == pytest.approx(5.0)
+
+
+def test_syn2post_mean_propagates_genuine_nan():
+ syn = jnp.asarray([1., np.nan, 5.])
+ post_ids = jnp.asarray([0, 0, 2])
+ out = np.asarray(bm.syn2post_mean(syn, post_ids, 3))
+ assert np.isnan(out[0]) # genuine NaN must propagate
+ assert out[1] == pytest.approx(0.0)
+ assert out[2] == pytest.approx(5.0)
+
+
+def test_syn2post_softmax_propagates_nan_and_normalizes():
+ post_ids = jnp.asarray([0, 0, 2])
+ # genuine NaN must not be silently zeroed
+ syn_nan = jnp.asarray([1., np.nan, 5.])
+ out_nan = np.asarray(bm.syn2post_softmax(syn_nan, post_ids, 3))
+ assert np.isnan(out_nan[0]) and np.isnan(out_nan[1])
+ # clean input: each non-empty group's softmax weights sum to 1
+ syn = jnp.asarray([1., 3., 5.])
+ out = np.asarray(bm.syn2post_softmax(syn, post_ids, 3))
+ assert out[0] + out[1] == pytest.approx(1.0, abs=1e-6)
+ assert out[2] == pytest.approx(1.0, abs=1e-6)
+
+
+def test_pre2post_mean_scalar_and_vector_branches():
+ post_ids = jnp.asarray([0, 0, 2])
+ # scalar branch: constant broadcast to targeted posts, others 0
+ pm = np.asarray(bm.pre2post_mean(2.0, 3, post_ids))
+ np.testing.assert_allclose(pm, [2., 0., 2.], atol=1e-6)
+ # vector branch routes through syn2post_mean
+ pre_vals = jnp.asarray([10., 20., 30.])
+ pre_ids = jnp.asarray([0, 1, 2])
+ pmv = np.asarray(bm.pre2post_mean(pre_vals, 3, post_ids, pre_ids))
+ # post 0 gets mean(pre[0], pre[1]) = 15, post 2 gets pre[2] = 30
+ np.testing.assert_allclose(pmv, [15., 0., 30.], atol=1e-6)
+
+
+def test_pre2post_reductions_and_pre2syn():
+ post_ids = jnp.asarray([0, 0, 2])
+ assert np.asarray(bm.pre2post_sum(2.0, 3, post_ids)).tolist() == [4., 0., 2.]
+ assert np.asarray(bm.pre2post_prod(2.0, 3, post_ids)).tolist() == [0., 0., 0.]
+ assert np.asarray(bm.pre2post_max(2.0, 3, post_ids)).tolist() == [2., 0., 2.]
+ assert np.asarray(bm.pre2post_min(2.0, 3, post_ids)).tolist() == [0., 0., 0.]
+ syn = bm.pre2syn(jnp.asarray([1., 2., 3.]), jnp.asarray([0, 2]))
+ np.testing.assert_allclose(np.asarray(syn), [1., 3.], atol=1e-6)
+
+
+def test_syn2post_reductions():
+ syn = jnp.asarray([1., 3., 5.])
+ post_ids = jnp.asarray([0, 0, 2])
+ np.testing.assert_allclose(np.asarray(bm.syn2post_sum(syn, post_ids, 3)),
+ [4., 0., 5.], atol=1e-6)
+ np.testing.assert_allclose(np.asarray(bm.syn2post_prod(syn, post_ids, 3)),
+ [3., 1., 5.], atol=1e-6)
+ # max of empty group is -inf, min is +inf (segment reduction identities)
+ assert np.asarray(bm.syn2post_max(syn, post_ids, 3))[0] == 3.
+ assert np.asarray(bm.syn2post_min(syn, post_ids, 3))[0] == 1.
+
+
+def test_pre2post_reductions_vector_branch():
+ # Vector pre_values + pre_ids exercises the heterogeneous gather branch.
+ pre_vals = jnp.asarray([1., 2., 3., 4.])
+ pre_ids = jnp.asarray([0, 1, 2, 3])
+ post_ids = jnp.asarray([0, 0, 1, 1])
+ post_num = 2
+ assert np.asarray(bm.pre2post_sum(pre_vals, post_num, post_ids, pre_ids)).tolist() == [3., 7.]
+ # prod / min accumulate against the zero-initialised output -> 0 here.
+ assert np.asarray(bm.pre2post_prod(pre_vals, post_num, post_ids, pre_ids)).tolist() == [0., 0.]
+ assert np.asarray(bm.pre2post_max(pre_vals, post_num, post_ids, pre_ids)).tolist() == [2., 4.]
+ assert np.asarray(bm.pre2post_min(pre_vals, post_num, post_ids, pre_ids)).tolist() == [0., 0.]
+
+
+def test_pre2post_vector_without_pre_ids_raises():
+ # The _raise_pre_ids_is_none guard fires for heterogeneous values w/o pre_ids.
+ from brainpy._errors import MathError
+ pre_vals = jnp.asarray([1., 2., 3.])
+ post_ids = jnp.asarray([0, 1, 1])
+ with pytest.raises(MathError):
+ bm.pre2post_sum(pre_vals, 2, post_ids)
+
+
+def test_syn2post_bool_dtype_promotion():
+ # bool syn_values -> promoted to int in every syn2post reduction.
+ # group 0 = {True, False} -> {1, 0}; group 1 = {True} -> {1}
+ syn = jnp.asarray([True, False, True], dtype=bool)
+ post_ids = jnp.asarray([0, 0, 1])
+ assert np.asarray(bm.syn2post_sum(syn, post_ids, 2)).tolist() == [1, 1]
+ assert np.asarray(bm.syn2post_prod(syn, post_ids, 2)).tolist() == [0, 1]
+ assert np.asarray(bm.syn2post_max(syn, post_ids, 2)).tolist() == [1, 1]
+ assert np.asarray(bm.syn2post_min(syn, post_ids, 2)).tolist() == [0, 1]
+ np.testing.assert_allclose(np.asarray(bm.syn2post_mean(syn, post_ids, 2)), [0.5, 1.0], atol=1e-6)
+ sm = np.asarray(bm.syn2post_softmax(syn, post_ids, 2))
+ assert np.all(np.isfinite(sm))
+
+
+def test_pre2post_event_sum():
+ # CSR connectivity: pre 0 -> post {0,2}, pre 1 -> post {1}, pre 2 -> post {3}
+ indices = jnp.asarray([0, 2, 1, 3], dtype=jnp.int32)
+ indptr = jnp.asarray([0, 2, 3, 4], dtype=jnp.int32)
+ events = jnp.asarray([True, False, True], dtype=bool)
+ out = np.asarray(bm.pre2post_event_sum(events, (indices, indptr), 4, 1.0))
+ # pre 0 fires -> +1 at posts 0 and 2; pre 2 fires -> +1 at post 3
+ np.testing.assert_allclose(out, [1., 0., 1., 1.], atol=1e-6)
diff --git a/tests/audit/test_object_transform_fixes.py b/tests/audit/test_object_transform_fixes.py
new file mode 100644
index 000000000..9e690216d
--- /dev/null
+++ b/tests/audit/test_object_transform_fixes.py
@@ -0,0 +1,1153 @@
+# -*- coding: utf-8 -*-
+"""Regression + coverage tests for the BrainPy v2.7.8 object-transform audit
+(see ``docs/issues-found-20260618.md``).
+
+This module exercises the fixes recorded in the audit for the
+``brainpy/math/object_transform`` package:
+
+* ``jit.py`` — H-01 (``cls_jit`` no longer corrupts NEGATIVE
+ ``static_argnums``/``donate_argnums`` by shifting them, which previously
+ double-marked ``self`` static), M-02 (``donate_argnums`` shifted by +1 so
+ ``self`` is never donated), H-04 (``bm.jit(fn, dyn_vars=..., child_objs=...)``
+ pops the legacy kwargs with a ``DeprecationWarning`` instead of forwarding
+ them into brainstate and raising ``TypeError``).
+* ``controls.py`` — H-02 (``cond``/``for_loop``/``scan``/``while_loop``
+ accept a ``Variable``/``Array`` in ``operands``), H-03 (``for_loop(jit=False)``
+ with a zero-length pytree operand returns ``[]`` instead of crashing),
+ M-03 (``scan`` returns ``(carry, ys)``), M-05 (``ifelse`` builds mutually
+ exclusive conditions), M-06 (``while_loop`` body returning ``None`` raises).
+* ``function.py`` — ``Partial``/``to_object`` behaviour and L-04 (``function``
+ emits a ``DeprecationWarning``).
+* ``_utils.py`` — ``warp_to_no_state_input_output`` strips/restores states.
+* ``variables.py`` — C-25 (``var_dict`` round-trips through ``jax.jit``),
+ C-26 (``Variable`` keeps ``batch_axis``/``axis_names`` through
+ flatten/unflatten, ``jit``, ``grad``, ``vmap``), H-06 (``Variable.value``
+ setter accepts a ``brainstate.State`` and a float64 numpy array and
+ canonicalizes), H-45 (``size_without_batch`` returns a shape tuple).
+* ``base.py`` — H-05 (``.cpu()`` moves variables and injects no junk
+ attributes), H-08 (``register_implicit_vars`` accepts ``var_list``/``var_dict``).
+* ``naming.py`` — H-07 (creating + discarding many named objects does not
+ raise ``UniqueNameError`` and the ``_name2id`` registry stays bounded).
+
+All tests assert the CORRECT post-fix behaviour. They use tiny array sizes so
+the whole module runs in a few seconds.
+"""
+
+import gc
+import importlib
+import warnings
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+import brainstate
+
+import brainpy as bp
+import brainpy.math as bm
+from brainpy._errors import MathError, UniqueNameError
+
+# NOTE: ``from brainpy.math.object_transform import jit`` would bind the *jit
+# function* (re-exported in the package ``__init__``), not the submodule. Import
+# the submodule explicitly so we can monkeypatch its module-level ``jit`` that
+# ``cls_jit`` calls into.
+jit_module = importlib.import_module('brainpy.math.object_transform.jit')
+from brainpy.math.object_transform import naming
+from brainpy.math.object_transform.collectors import ArrayCollector
+from brainpy.math.object_transform.variables import (
+ Variable, TrainVar, Parameter, VariableView, VarList, VarDict,
+)
+from brainpy.math.object_transform.base import BrainPyObject, FunAsObject
+from brainpy.math.object_transform._utils import (
+ warp_to_no_state_input_output, infer_dyn_vars, get_brainpy_object,
+)
+
+
+# ---------------------------------------------------------------------------
+# helpers
+# ---------------------------------------------------------------------------
+
+class _Obj(bp.BrainPyObject):
+ """A tiny BrainPyObject owning a single Variable."""
+
+ def __init__(self, n=3):
+ super().__init__()
+ self.v = bm.Variable(jnp.ones(n))
+
+
+def _capture_cls_jit_argnums(static_argnums=None, donate_argnums=None):
+ """Apply ``cls_jit`` and capture the ``static_argnums``/``donate_argnums``
+ that it forwards into the underlying ``jit`` call, *without* actually
+ invoking ``brainstate.transform.jit`` (which rejects negative argnums).
+ """
+ captured = {}
+ orig_jit = jit_module.jit
+
+ def spy(*args, **kwargs):
+ captured['static_argnums'] = kwargs.get('static_argnums')
+ captured['donate_argnums'] = kwargs.get('donate_argnums')
+ raise _Stop()
+
+ class _Stop(Exception):
+ pass
+
+ jit_module.jit = spy
+ try:
+ try:
+ jit_module.cls_jit(
+ static_argnums=static_argnums,
+ donate_argnums=donate_argnums,
+ )(lambda self, *a, **k: a)
+ except _Stop:
+ pass
+ finally:
+ jit_module.jit = orig_jit
+ return captured['static_argnums'], captured['donate_argnums']
+
+
+# ===========================================================================
+# jit.py
+# ===========================================================================
+
+def test_cls_jit_positive_static_argnums_shifted_once():
+ """H-01: a positive user index N is shifted to N+1 (account for ``self``),
+ and ``self`` (index 0) is marked static exactly once."""
+ static, donate = _capture_cls_jit_argnums(static_argnums=1)
+ assert static == (0, 2)
+ assert donate == ()
+
+
+def test_cls_jit_negative_static_argnums_not_corrupted():
+ """H-01: the historical bug shifted ``-1`` to ``0`` and produced
+ ``(0, 0)`` (``self`` marked static twice + wrong target). The fix leaves
+ negative indices unshifted, so ``self`` (0) appears exactly once and the
+ negative index is preserved -- NOT collapsed into a duplicate ``0``."""
+ static, donate = _capture_cls_jit_argnums(static_argnums=-1)
+ assert static == (0, -1)
+ # the corrupting outcome would have been (0, 0); make that explicit.
+ assert static != (0, 0)
+ assert static.count(0) == 1
+
+
+def test_cls_jit_list_static_argnums_dedup_and_shift():
+ """H-01: list of positive indices are each shifted by +1, ``self`` is
+ prepended, and duplicates are removed."""
+ static, _ = _capture_cls_jit_argnums(static_argnums=[0, 2])
+ assert static == (0, 1, 3)
+
+
+def test_cls_jit_donate_argnums_shifted_so_self_not_donated():
+ """M-02: ``donate_argnums`` is shifted by +1, so a user index 1 becomes 2
+ and ``self`` (index 0) is never donated."""
+ static, donate = _capture_cls_jit_argnums(static_argnums=[0, 2], donate_argnums=1)
+ assert donate == (2,)
+ static2, donate2 = _capture_cls_jit_argnums(donate_argnums=[3, 4])
+ assert donate2 == (4, 5)
+ # default static is just (self,)
+ assert static2 == (0,)
+
+
+def test_cls_jit_runs_on_bound_method_with_positive_static():
+ """H-01 end-to-end: a bound method jitted with a positive ``static_argnums``
+ runs and mutates the owned Variable correctly."""
+
+ class Prog(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.b = bm.Variable(jnp.zeros(2))
+
+ # user index 0 (``scale``, a hashable int) is static; after the +1 shift
+ # it becomes index 1 in the bound signature ``(self, scale, x)``.
+ @bm.cls_jit(static_argnums=0)
+ def run(self, scale, x):
+ self.b.value = self.b.value + scale * x
+ return self.b.value
+
+ p = Prog()
+ out = p.run(2, jnp.ones(2))
+ np.testing.assert_allclose(np.asarray(out), [2., 2.])
+
+
+def test_cls_jit_invalid_argnums_type_raises():
+ with pytest.raises(ValueError):
+ bm.cls_jit(static_argnums=1.5)(lambda self: None)
+ with pytest.raises(ValueError):
+ bm.cls_jit(donate_argnums=1.5)(lambda self: None)
+
+
+def test_jit_dyn_vars_child_objs_deprecation_pops_kwargs():
+ """H-04: ``dyn_vars``/``child_objs`` are no longer forwarded to brainstate
+ (which would raise ``TypeError``); they are popped with a one-time
+ ``DeprecationWarning`` and the function still works."""
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter('always')
+ fn = bm.jit(lambda x: x + 1, dyn_vars=[], child_objs=[])
+ result = fn(jnp.asarray(1.0))
+ assert float(result) == 2.0
+ categories = [w.category for w in caught]
+ assert sum(issubclass(c, DeprecationWarning) for c in categories) >= 2
+
+
+def test_jit_on_object_method():
+ """``bm.jit`` JIT-compiles a BrainPyObject bound method."""
+
+ class Hello(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.a = bm.Variable(jnp.asarray(10.))
+
+ def transform(self):
+ self.a.value = self.a.value * 2
+ return self.a.value
+
+ h = Hello()
+ jfn = bm.jit(h.transform)
+ assert float(jfn()) == 20.0
+ assert float(jfn()) == 40.0
+
+
+def test_jit_decorator_form_on_pure_function():
+ @bm.jit
+ def selu(x, alpha=1.67, lmbda=1.05):
+ return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
+
+ out = selu(jnp.asarray([1.0, -1.0]))
+ assert out.shape == (2,)
+
+
+def test_jit_static_argnums_on_pure_function():
+ @bm.jit(static_argnums=(1,))
+ def f(x, n):
+ return x + n
+
+ assert float(f(jnp.asarray(1.0), 5)) == 6.0
+
+
+# ===========================================================================
+# controls.py
+# ===========================================================================
+
+def test_cond_accepts_variable_operand():
+ """H-02: a Variable passed in ``operands`` of ``cond`` is unwrapped to a
+ raw jax array and does not raise."""
+ a = bm.Variable(bm.zeros(2))
+
+ def true_f(x):
+ a.value = a.value + x
+ return a.value
+
+ def false_f(x):
+ return a.value
+
+ bm.cond(True, true_f, false_f, bm.Variable(bm.ones(2)))
+ np.testing.assert_allclose(np.asarray(a.value), [1., 1.])
+
+
+def test_cond_accepts_array_operand_and_scalar_operand():
+ out = bm.cond(True, lambda x: x + 1, lambda x: x - 1, bm.asarray([1., 2.]))
+ np.testing.assert_allclose(np.asarray(out), [2., 3.])
+ # scalar operand (wrapped into a tuple internally)
+ out2 = bm.cond(False, lambda x: x + 1, lambda x: x - 1, 5.0)
+ assert float(out2) == 4.0
+
+
+def test_for_loop_accepts_variable_operand():
+ """H-02: ``bm.for_loop(lambda x: x+1, bm.arange(1,5))`` works with a
+ BrainPy Array operand."""
+ out = bm.for_loop(lambda x: x + 1, bm.arange(1, 5))
+ np.testing.assert_allclose(np.asarray(out).ravel(), [2, 3, 4, 5])
+
+
+def test_for_loop_variable_state_accumulation():
+ a = bm.Variable(bm.zeros(1))
+ b = bm.Variable(bm.ones(1))
+
+ def body(x):
+ a.value += x
+ b.value *= x
+ return a.value
+
+ hist = bm.for_loop(body, operands=bm.arange(1, 5))
+ np.testing.assert_allclose(np.asarray(hist).ravel(), [1, 3, 6, 10])
+ np.testing.assert_allclose(np.asarray(a.value), [10.])
+ np.testing.assert_allclose(np.asarray(b.value), [24.])
+
+
+def test_for_loop_multiple_operands():
+ a = bm.Variable(bm.zeros(1))
+
+ def body(x, y):
+ a.value += x + y
+ return a.value
+
+ hist = bm.for_loop(body, operands=(bm.arange(1, 5), bm.arange(2, 6)))
+ assert np.asarray(hist).shape == (4, 1)
+
+
+def test_for_loop_jit_false_normal_path():
+ a = bm.Variable(bm.zeros(1))
+
+ def body(x):
+ a.value += x
+ return a.value
+
+ hist = bm.for_loop(body, bm.arange(1., 4.), jit=False)
+ np.testing.assert_allclose(np.asarray(hist).ravel(), [1, 3, 6])
+
+
+def test_for_loop_jit_false_zero_length_pytree_returns_empty():
+ """H-03: ``for_loop(jit=False)`` with a zero-length *pytree* (dict) operand
+ must not crash on ``operands[0].shape`` -- the leading length is computed
+ from ``jax.tree.leaves``. It falls back to JIT mode (UserWarning) and
+ returns an empty result."""
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter('always')
+ out = bm.for_loop(lambda d: d['x'] + 1, {'x': bm.arange(0.)}, jit=False)
+ assert len(np.asarray(out)) == 0
+ assert any(issubclass(w.category, UserWarning) for w in caught)
+
+
+def test_for_loop_progress_bar_variants():
+ out = bm.for_loop(lambda x: x + 1, bm.arange(1, 4), progress_bar=True)
+ assert np.asarray(out).shape == (3,)
+ out2 = bm.for_loop(lambda x: x + 1, bm.arange(1, 4), progress_bar=2)
+ assert np.asarray(out2).shape == (3,)
+
+
+def test_for_loop_invalid_progress_bar_raises():
+ with pytest.raises(TypeError):
+ bm.for_loop(lambda x: x + 1, bm.arange(1, 4), progress_bar="bad")
+
+
+def test_scan_accepts_variable_operand_and_returns_carry_ys():
+ """H-02 + M-03: ``scan`` accepts a Variable in operands and returns the
+ documented ``(final_carry, stacked_ys)`` two-tuple."""
+ carry, ys = bm.scan(lambda c, x: (c + x, c), 0., bm.Variable(bm.arange(1., 4.)))
+ assert float(carry) == 6.0
+ np.testing.assert_allclose(np.asarray(ys).ravel(), [0., 1., 3.])
+
+
+def test_while_loop_accepts_variable_operand():
+ """H-02: ``while_loop`` accepts a Variable operand (it is unwrapped to a raw
+ jax array rather than being rejected at brainstate cache-key time). The body
+ returns the updated operands, preserving the single-operand tuple structure."""
+ res = bm.while_loop(lambda x: (x + 1.,), lambda x: x < 3., (bm.Variable(jnp.asarray(0.)),))
+ assert float(np.asarray(res[0])) == 3.0
+
+
+def test_while_loop_state_mutation():
+ a = bm.Variable(bm.zeros(1))
+ b = bm.Variable(bm.ones(1))
+
+ def cond_f(x, y):
+ return x < 6.
+
+ def body_f(x, y):
+ a.value += x
+ b.value *= y
+ return x + b[0], y + 1.
+
+ res = bm.while_loop(body_f, cond_f, operands=(1., 1.))
+ assert len(res) == 2
+
+
+def test_while_loop_body_returning_none_raises():
+ """M-06: a ``while_loop`` body that returns ``None`` would freeze the carry
+ and loop forever -- it must raise a clear ``ValueError`` instead."""
+
+ def body(x):
+ # returns None -> illegal
+ pass
+
+ with pytest.raises(ValueError):
+ bm.while_loop(body, lambda x: x < 3., 0.)
+
+
+def test_ifelse_callable_branches_mutually_exclusive():
+ """M-05: ``ifelse`` resolves to the first matching branch (mutually
+ exclusive conditions, default branch last)."""
+
+ def f(a):
+ return bm.ifelse(
+ conditions=[a > 10, a > 5, a > 2, a > 0],
+ branches=[lambda: 1, lambda: 2, lambda: 3, lambda: 4, lambda: 5],
+ )
+
+ assert int(f(11)) == 1
+ assert int(f(7)) == 2
+ assert int(f(1)) == 4
+ assert int(f(-5)) == 5
+
+
+def test_ifelse_non_callable_branches():
+ def f(a):
+ return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
+ branches=[1, 2, 3, 4, 5])
+
+ assert int(f(3)) == 3
+
+
+def test_ifelse_with_operands_and_variable():
+ out = bm.ifelse(
+ conditions=[True],
+ branches=[lambda x: x + 1, lambda x: x - 1],
+ operands=bm.Variable(jnp.asarray(10.)),
+ )
+ assert float(out) == 11.0
+
+
+# ===========================================================================
+# _utils.py
+# ===========================================================================
+
+def test_warp_to_no_state_input_output_strips_and_restores():
+ """``warp_to_no_state_input_output`` removes ``State`` wrappers from inputs
+ and outputs, passing through plain jax arrays."""
+
+ def fn(x):
+ # the wrapper should have unwrapped the State to a jax array
+ assert not isinstance(x, brainstate.State)
+ return x * 2
+
+ wrapped = warp_to_no_state_input_output(fn)
+ out = wrapped(bm.Variable(jnp.asarray(3.)))
+ assert not isinstance(out, brainstate.State)
+ assert float(out) == 6.0
+
+
+def test_warp_to_no_state_passes_missing_through():
+ missing = brainstate.typing.Missing()
+ assert warp_to_no_state_input_output(missing) is missing
+
+
+def test_infer_dyn_vars_and_get_brainpy_object():
+ obj = _Obj()
+ dv = infer_dyn_vars(obj)
+ assert len(dv) >= 1
+ mapping = get_brainpy_object(obj)
+ assert obj.name in mapping
+ # bound method path
+ bound = infer_dyn_vars(obj.vars)
+ assert isinstance(bound, ArrayCollector)
+ # non-object path returns empty collector / dict
+ assert len(infer_dyn_vars(lambda: None)) == 0
+ assert get_brainpy_object(lambda: None) == {}
+
+
+# ===========================================================================
+# function.py
+# ===========================================================================
+
+def test_partial_binds_positional_args():
+ add = bm.Partial(lambda x, y: x + y, 1)
+ assert add(2) == 3
+
+
+def test_partial_keyword_override_and_is_brainpy_object():
+ p = bm.Partial(lambda x, scale=1.: x * scale, scale=2.)
+ assert isinstance(p, FunAsObject)
+ assert float(p(3.)) == 6.0
+ # call-time keyword overrides bound keyword
+ assert float(p(3., scale=10.)) == 30.0
+
+
+def test_partial_tracks_variables():
+ sub = _Obj()
+ p = bm.Partial(lambda: sub.v.value, child_objs=sub)
+ np.testing.assert_allclose(np.asarray(p()), np.ones(3))
+ assert sub.name in p.nodes()
+
+
+def test_to_object_decorator_and_direct():
+ sub = _Obj()
+ obj = bm.to_object(lambda: sub.v.value, child_objs=sub)
+ np.testing.assert_allclose(np.asarray(obj()), np.ones(3))
+
+ @bm.to_object(child_objs=sub)
+ def fn():
+ return sub.v.value
+
+ np.testing.assert_allclose(np.asarray(fn()), np.ones(3))
+
+
+def test_to_object_requires_child_objs_when_f_given():
+ with pytest.raises(ValueError):
+ bm.to_object(lambda: 1.)
+
+
+def test_function_is_deprecated():
+ """L-04: ``function`` is deprecated in favour of ``to_object``."""
+ sub = _Obj()
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter('always')
+ obj = bm.function(lambda: sub.v.value, nodes=sub)
+ assert any(issubclass(w.category, DeprecationWarning) for w in caught)
+ np.testing.assert_allclose(np.asarray(obj()), np.ones(3))
+
+
+# ===========================================================================
+# variables.py
+# ===========================================================================
+
+def test_var_dict_round_trips_through_jax_jit():
+ """C-25: ``VarDict.tree_unflatten`` no longer uses the removed ``jax.util``;
+ a ``var_dict`` survives ``jax.jit``."""
+ d = bm.var_dict({'a': bm.Variable(jnp.ones(2))})
+ out = jax.jit(lambda dd: dd)(d)
+ assert isinstance(out, VarDict)
+ assert list(out.keys()) == ['a']
+ np.testing.assert_allclose(np.asarray(out['a'].value), np.ones(2))
+
+
+def test_var_dict_tree_map():
+ d = bm.var_dict({'a': bm.Variable(jnp.ones(2)), 'b': bm.Variable(jnp.zeros(2))})
+ leaves = jax.tree.leaves(d)
+ assert len(leaves) == 2
+
+
+def test_var_list_round_trips_through_pytree():
+ vl = bm.var_list([bm.Variable(jnp.ones(1)), bm.Variable(jnp.zeros(2))])
+ leaves, treedef = jax.tree.flatten(vl)
+ rebuilt = jax.tree.unflatten(treedef, leaves)
+ assert isinstance(rebuilt, VarList)
+ assert len(rebuilt) == 2
+
+
+def test_variable_keeps_metadata_through_flatten_unflatten():
+ """C-26: ``batch_axis``/``axis_names`` survive a manual pytree round-trip."""
+ v = bm.Variable(jnp.zeros((3, 4)), batch_axis=0, axis_names=('a', 'b'))
+ assert v.batch_axis == 0
+ assert v.axis_names == ('a', 'b')
+ leaves, treedef = jax.tree.flatten(v)
+ v2 = jax.tree.unflatten(treedef, leaves)
+ assert isinstance(v2, Variable)
+ assert v2.batch_axis == 0
+ assert v2.axis_names == ('a', 'b')
+
+
+def test_variable_keeps_metadata_through_jit():
+ """C-26: metadata preserved through ``jax.jit``."""
+ v = bm.Variable(jnp.zeros((3, 4)), batch_axis=0, axis_names=('a', 'b'))
+ out = jax.jit(lambda x: x)(v)
+ assert out.batch_axis == 0
+ assert out.axis_names == ('a', 'b')
+
+
+def test_variable_keeps_metadata_through_grad():
+ """C-26: metadata preserved through ``jax.grad``."""
+ v = bm.Variable(jnp.ones((3, 4)), batch_axis=0, axis_names=('a', 'b'))
+ g = jax.grad(lambda x: jnp.sum(x.value ** 2))(v)
+ assert isinstance(g, Variable)
+ assert g.batch_axis == 0
+ assert g.axis_names == ('a', 'b')
+
+
+def test_variable_keeps_metadata_through_vmap():
+ """C-26: metadata preserved through ``jax.vmap``."""
+ v = bm.Variable(jnp.zeros((3, 4)), batch_axis=0, axis_names=('a', 'b'))
+ out = jax.vmap(lambda x: x)(v)
+ assert isinstance(out, Variable)
+ assert out.batch_axis == 0
+
+
+def test_variable_value_setter_accepts_brainstate_state():
+ """H-06: ``Variable.value = some_State`` unwraps the state first."""
+ v = bm.Variable(jnp.zeros(3, dtype=jnp.float32))
+ st = brainstate.State(jnp.ones(3, dtype=jnp.float32))
+ v.value = st
+ np.testing.assert_allclose(np.asarray(v.value), np.ones(3))
+
+
+def test_variable_value_setter_canonicalizes_float64_numpy():
+ """H-06: assigning a float64 numpy array to a float32 Variable canonicalizes
+ (converts) it rather than raising a spurious dtype ``MathError``."""
+ v = bm.Variable(jnp.zeros(3, dtype=jnp.float32))
+ arr = np.ones(3, dtype=np.float64)
+ v.value = arr # must not raise
+ assert v.dtype == jnp.float32
+ np.testing.assert_allclose(np.asarray(v.value), np.ones(3))
+
+
+def test_variable_value_setter_accepts_brainpy_array():
+ v = bm.Variable(jnp.zeros(3, dtype=jnp.float32))
+ v.value = bm.asarray(np.arange(3), dtype=jnp.float32)
+ np.testing.assert_allclose(np.asarray(v.value), [0., 1., 2.])
+
+
+def test_variable_value_setter_shape_mismatch_raises():
+ v = bm.Variable(jnp.zeros(3))
+ with pytest.raises(MathError):
+ v.value = jnp.zeros(4)
+
+
+def test_size_without_batch_returns_shape_tuple():
+ """H-45: ``size_without_batch`` returns a *shape tuple* (drops the batch
+ axis), not an integer element count."""
+ v = bm.Variable(jnp.zeros((3, 4)), batch_axis=0)
+ assert v.size_without_batch == (4,)
+ v2 = bm.Variable(jnp.zeros((3, 4)))
+ assert v2.size_without_batch == (3, 4)
+
+
+def test_variable_batch_size_and_axis():
+ v = bm.Variable(jnp.zeros((5, 4)), batch_axis=0)
+ assert v.batch_size == 5
+ v2 = bm.Variable(jnp.zeros(4))
+ assert v2.batch_size is None
+
+
+def test_variable_batch_axis_immutable():
+ v = bm.Variable(jnp.zeros((5, 4)), batch_axis=0)
+ with pytest.raises(ValueError):
+ v.batch_axis = 1
+ with pytest.raises(ValueError):
+ v.batch_size = 2
+
+
+def test_variable_invalid_batch_axis_raises():
+ with pytest.raises(MathError):
+ bm.Variable(jnp.zeros(3), batch_axis=5)
+
+
+def test_variable_init_from_size_and_hash():
+ v = bm.Variable(4)
+ assert v.shape == (4,)
+ np.testing.assert_allclose(np.asarray(v.value), np.zeros(4))
+ # identity hash
+ assert hash(v) == id(v)
+ s = {v, v}
+ assert len(s) == 1
+
+
+def test_trainvar_and_parameter_are_variables():
+ tv = bm.TrainVar(jnp.zeros(2))
+ par = bm.Parameter(jnp.ones(2))
+ assert isinstance(tv, Variable)
+ assert isinstance(par, Variable)
+
+
+def test_variable_view_reads_and_writes_origin():
+ origin = bm.Variable(jnp.arange(5.))
+ view = bm.VariableView(origin, slice(None, 2, None))
+ np.testing.assert_allclose(np.asarray(view.value), [0., 1.])
+ view.value = jnp.asarray([10., 11.])
+ np.testing.assert_allclose(np.asarray(origin.value)[:2], [10., 11.])
+ assert 'VariableView' in repr(view)
+
+
+def test_variable_view_requires_variable():
+ with pytest.raises(ValueError):
+ bm.VariableView(jnp.zeros(3), slice(None))
+
+
+# ===========================================================================
+# base.py
+# ===========================================================================
+
+def test_cpu_moves_variable_and_injects_no_junk_attrs():
+ """H-05: ``.cpu()`` iterates real Variables and moves them; it must not add
+ dict-valued junk attributes named after nodes."""
+ obj = _Obj()
+ before = set(obj.__dict__.keys())
+ returned = obj.cpu()
+ after = set(obj.__dict__.keys())
+ assert after == before # no junk attributes injected
+ assert returned is obj
+ assert isinstance(obj.v, Variable)
+ np.testing.assert_allclose(np.asarray(obj.v.value), np.ones(3))
+
+
+def test_to_moves_variable_to_device():
+ obj = _Obj()
+ dev = jax.devices('cpu')[0]
+ obj.to(device=dev)
+ assert isinstance(obj.v, Variable)
+
+
+def test_register_implicit_vars_with_var_list():
+ """H-08: ``register_implicit_vars(var_list([...]))`` flattens the container
+ into the ``ArrayCollector`` (which only accepts plain Variables)."""
+ obj = _Obj()
+ obj.register_implicit_vars(bm.var_list([bm.Variable(jnp.zeros(2))]))
+ assert len(obj.implicit_vars) == 1
+ for v in obj.implicit_vars.values():
+ assert isinstance(v, Variable)
+
+
+def test_register_implicit_vars_with_var_dict():
+ obj = _Obj()
+ obj.register_implicit_vars(bm.var_dict({'x': bm.Variable(jnp.zeros(2)),
+ 'y': bm.Variable(jnp.zeros(2))}))
+ assert len(obj.implicit_vars) == 2
+ for v in obj.implicit_vars.values():
+ assert isinstance(v, Variable)
+
+
+def test_register_implicit_vars_plain_and_named():
+ obj = _Obj()
+ obj.register_implicit_vars(bm.Variable(jnp.zeros(1)), named=bm.Variable(jnp.zeros(1)))
+ assert len(obj.implicit_vars) == 2
+
+
+def test_register_implicit_vars_rejects_non_variable():
+ obj = _Obj()
+ with pytest.raises(ValueError):
+ obj.register_implicit_vars(123)
+
+
+def test_vars_collects_variables_and_containers():
+ class Multi(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.a = bm.Variable(jnp.zeros(1))
+ self.lst = bm.var_list([bm.Variable(jnp.zeros(1))])
+ self.dct = bm.var_dict({'k': bm.Variable(jnp.zeros(1))})
+
+ m = Multi()
+ collected = m.vars()
+ assert len(collected) == 3
+ assert len(m.train_vars()) == 0
+
+
+def test_nodes_and_state_dict_round_trip():
+ class Parent(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.child = _Obj()
+ self.w = bm.Variable(jnp.zeros(2))
+
+ p = Parent()
+ nodes = p.nodes()
+ assert p.name in nodes
+ sd = p.state_dict()
+ assert isinstance(sd, dict)
+ # load it back
+ p.load_state_dict(sd, warn=False)
+
+
+def test_brainpyobject_setattr_updates_variable_value():
+ """Setting an attribute that is a Variable updates its value in place."""
+ obj = _Obj(n=2)
+ vid = id(obj.v)
+ obj.v = jnp.asarray([5., 6.])
+ assert id(obj.v) == vid # same Variable object
+ np.testing.assert_allclose(np.asarray(obj.v.value), [5., 6.])
+
+
+def test_funasobject_call_and_repr():
+ sub = _Obj()
+ fo = FunAsObject(target=lambda: sub.v.value, child_objs=sub)
+ np.testing.assert_allclose(np.asarray(fo()), np.ones(3))
+ assert 'FunAsObject' in repr(fo)
+
+
+def test_objecttransform_base():
+ from brainpy.math.object_transform.base import ObjectTransform
+ ot = ObjectTransform()
+ assert repr(ot) == 'ObjectTransform'
+ with pytest.raises(NotImplementedError):
+ ot()
+
+
+def test_node_list_and_node_dict():
+ from brainpy.math.object_transform.base import node_list, node_dict
+ nl = node_list([_Obj(), _Obj()])
+ assert len(nl) == 2
+ nd = node_dict({'a': _Obj()})
+ assert 'a' in nd
+
+
+def test_tracing_variable_raises_not_implemented():
+ """L-06: ``tracing_variable`` is deprecated and always raises."""
+ obj = _Obj()
+ with pytest.raises(NotImplementedError):
+ obj.tracing_variable('w', jnp.zeros, (2,))
+
+
+# ===========================================================================
+# naming.py
+# ===========================================================================
+
+def test_many_named_objects_do_not_raise_unique_name_error():
+ """H-07: creating and discarding many named objects must not raise
+ ``UniqueNameError`` from reused ids."""
+ bm.clear_name_cache()
+ for _ in range(300):
+ _Obj() # transient, immediately discarded
+ gc.collect()
+ # creating more after GC must still not collide
+ keep = [_Obj() for _ in range(5)]
+ assert len(keep) == 5
+
+
+def test_name2id_registry_stays_bounded():
+ """H-07: the ``_name2id`` registry prunes dead weak refs and stays bounded
+ instead of growing unboundedly."""
+ bm.clear_name_cache()
+ for _ in range(400):
+ _Obj()
+ gc.collect()
+ # all transient objects collected -> registry pruned back to (near) empty.
+ assert len(naming._name2id) <= 5
+
+
+def test_explicit_duplicate_name_raises_unique_name_error():
+ bm.clear_name_cache()
+ keep = _Obj()
+ keep.name = 'my_unique_name'
+ other = _Obj()
+ with pytest.raises(UniqueNameError):
+ other.name = 'my_unique_name'
+
+
+def test_invalid_identifier_name_raises():
+ obj = _Obj()
+ with pytest.raises(bp.errors.BrainPyError):
+ obj.name = '123 not valid'
+
+
+def test_get_unique_name_increments():
+ n1 = naming.get_unique_name('Foo')
+ n2 = naming.get_unique_name('Foo')
+ assert n1 != n2
+ assert n1.startswith('Foo')
+
+
+def test_clear_name_cache_warns_when_requested():
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter('always')
+ bm.clear_name_cache(ignore_warn=False)
+ assert any(issubclass(w.category, UserWarning) for w in caught)
+
+
+def test_stack_cache_helpers():
+ def fn():
+ return None
+
+ naming.cache_stack(fn, [1, 2, 3])
+ assert naming.get_stack_cache(fn) == [1, 2, 3]
+ assert naming.get_stack_cache(lambda: None) is None
+ naming.clear_stack_cache()
+ assert naming.get_stack_cache(fn) is None
+
+
+# ===========================================================================
+# collectors (ArrayCollector) exercised via base/vars APIs
+# ===========================================================================
+
+def test_array_collector_basic_ops():
+ ac = ArrayCollector()
+ v = bm.Variable(jnp.zeros(1))
+ ac['x'] = v
+ assert len(ac.unique()) == 1
+ assert 'x' in ac.dict()
+ # subset by type
+ tv = bm.TrainVar(jnp.zeros(1))
+ ac['t'] = tv
+ assert len(ac.subset(TrainVar)) == 1
+
+
+# ===========================================================================
+# Additional coverage for variables.py
+# ===========================================================================
+
+def test_variable_init_from_list_of_int_size():
+ v = bm.Variable([2, 3])
+ assert v.shape == (2, 3)
+
+
+def test_variable_axis_names_inserts_batch_axis_name():
+ """The axis_names insertion path: when ``len(axis_names) + 1 == ndim`` the
+ batch-axis name is inserted at ``batch_axis``."""
+ from brainpy.math.sharding import BATCH_AXIS
+ v = bm.Variable(jnp.zeros((5, 4)), batch_axis=0, axis_names=('feat',))
+ assert len(v.axis_names) == 2
+ assert v.axis_names[0] == BATCH_AXIS
+ assert v.axis_names[1] == 'feat'
+
+
+def test_variable_value_setter_with_batch_axis():
+ v = bm.Variable(jnp.zeros((5, 4)), batch_axis=0)
+ v.value = jnp.ones((7, 4)) # different batch size OK
+ assert v.shape == (7, 4)
+ with pytest.raises(MathError):
+ v.value = jnp.ones((5, 9)) # non-batch dim mismatch
+
+
+def test_variable_value_setter_dtype_mismatch_raises():
+ v = bm.Variable(jnp.zeros(3, dtype=jnp.float32))
+ with pytest.raises(MathError):
+ v.value = jnp.zeros(3, dtype=jnp.int32)
+
+
+def test_var_list_append_rejects_non_variable():
+ vl = bm.var_list()
+ with pytest.raises(TypeError):
+ vl.append(jnp.zeros(2))
+
+
+def test_var_list_setitem_int_updates_value_slice_replaces():
+ a = bm.Variable(jnp.zeros(1))
+ b = bm.Variable(jnp.zeros(2))
+ vl = bm.var_list([a, b])
+ ids = (id(vl[0]), id(vl[1]))
+ vl[0] = jnp.ones(1) # updates value in place, keeps the same Variable
+ assert id(vl[0]) == ids[0]
+ np.testing.assert_allclose(np.asarray(vl[0].value), [1.])
+ # slice assignment replaces entries
+ vl[0:1] = [bm.Variable(jnp.zeros(1))]
+ assert len(vl) == 2
+
+
+def test_var_dict_update_existing_key_sets_value():
+ d = bm.var_dict({'a': bm.Variable(jnp.zeros(2))})
+ original_id = id(d['a'])
+ d['a'] = jnp.ones(2) # update value, keep the same Variable
+ assert id(d['a']) == original_id
+ np.testing.assert_allclose(np.asarray(d['a'].value), np.ones(2))
+
+
+def test_var_dict_rejects_non_variable_element():
+ with pytest.raises(TypeError):
+ bm.var_dict({'a': jnp.zeros(2)})
+
+
+def test_var_dict_update_from_tuple_and_kwargs():
+ d = bm.var_dict(('a', bm.Variable(jnp.zeros(1))), b=bm.Variable(jnp.zeros(1)))
+ assert set(d.keys()) == {'a', 'b'}
+
+
+def test_variable_view_setter_shape_and_dtype_checks():
+ origin = bm.Variable(jnp.arange(5.))
+ view = bm.VariableView(origin, slice(None, 2, None))
+ with pytest.raises(MathError):
+ view.value = jnp.zeros(3) # wrong shape
+
+
+# ===========================================================================
+# Additional coverage for base.py
+# ===========================================================================
+
+def test_relative_vars_and_nodes():
+ class Parent(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.child = _Obj()
+ self.w = bm.Variable(jnp.zeros(2))
+
+ p = Parent()
+ rel_vars = p.vars(method='relative')
+ assert len(rel_vars) >= 2
+ rel_nodes = p.nodes(method='relative')
+ assert '' in rel_nodes
+
+
+def test_vars_with_level_zero():
+ class Parent(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.child = _Obj()
+ self.w = bm.Variable(jnp.zeros(2))
+
+ p = Parent()
+ # level=0 -> only self's own variable, not the child's
+ own = p.vars(level=0)
+ assert len(own) == 1
+
+
+def test_register_implicit_nodes_variants():
+ parent = _Obj()
+ parent.register_implicit_nodes(_Obj())
+ parent.register_implicit_nodes([_Obj(), _Obj()])
+ parent.register_implicit_nodes({'k': _Obj()})
+ parent.register_implicit_nodes(named=_Obj())
+ # nodes are keyed by name; distinct objects -> at least 5 distinct entries.
+ assert len(parent.implicit_nodes) >= 5
+
+
+def test_register_implicit_nodes_rejects_bad_type():
+ parent = _Obj()
+ with pytest.raises(ValueError):
+ parent.register_implicit_nodes(123)
+
+
+def test_register_implicit_vars_from_list_and_dict_args():
+ obj = _Obj()
+ obj.register_implicit_vars([bm.Variable(jnp.zeros(1)), bm.Variable(jnp.zeros(1))])
+ obj.register_implicit_vars({'k': bm.Variable(jnp.zeros(1))})
+ assert len(obj.implicit_vars) == 3
+ with pytest.raises(ValueError):
+ obj.register_implicit_vars([123])
+ with pytest.raises(ValueError):
+ obj.register_implicit_vars({'bad': 123})
+ with pytest.raises(ValueError):
+ obj.register_implicit_vars(named=123)
+
+
+def test_node_dict_check_unique_raises_on_conflict():
+ from brainpy.math.object_transform.base import NodeDict
+ a, b = _Obj(), _Obj()
+ nd = NodeDict(check_unique=True)
+ nd['k'] = a
+ nd['k'] = a # same object, OK
+ with pytest.raises(KeyError):
+ nd['k'] = b # different object under same key
+
+
+def test_node_list_and_node_dict_in_find_nodes():
+ from brainpy.math.object_transform.base import node_list, node_dict
+
+ class Parent(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.lst = node_list([_Obj(), _Obj()])
+ self.dct = node_dict({'a': _Obj()})
+
+ p = Parent()
+ nodes = p.nodes()
+ # parent + 2 list children + 1 dict child
+ assert len(nodes) >= 4
+ rel_nodes = p.nodes(method='relative')
+ assert len(rel_nodes) >= 4
+
+
+def test_funasobject_repr_with_nodes():
+ sub = _Obj()
+ fo = FunAsObject(target=lambda: sub.v.value, child_objs=sub,
+ dyn_vars=bm.Variable(jnp.zeros(1)))
+ r = repr(fo)
+ assert 'FunAsObject' in r
+ assert 'num_of_vars' in r
+
+
+def test_save_and_load_state_methods():
+ obj = _Obj(n=2)
+ sd = obj.save_state()
+ assert isinstance(sd, dict)
+ missing, unexpected = obj.load_state(sd)
+ assert missing == [] and unexpected == []
+
+
+def test_vars_invalid_method_raises():
+ obj = _Obj()
+ with pytest.raises(ValueError):
+ obj.nodes(method='bad')
+
+
+def test_brainpyobject_tree_flatten_unflatten():
+ class Mixed(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.v = bm.Variable(jnp.ones(2)) # dynamic
+ self.lst = bm.var_list([bm.Variable(jnp.zeros(1))]) # dynamic container
+ self.scalar = 7 # static
+
+ m = Mixed()
+ leaves, treedef = jax.tree.flatten(m)
+ assert len(leaves) >= 1
+ rebuilt = jax.tree.unflatten(treedef, leaves)
+ assert isinstance(rebuilt, Mixed)
+ assert rebuilt.scalar == 7
+
+
+def test_brainpyobject_setattr_method_bypasses_variable_update():
+ obj = _Obj()
+ # ``.setattr`` is the explicit object.__setattr__ wrapper
+ obj.setattr('plain', 123)
+ assert obj.plain == 123
+
+
+def test_load_state_dict_v1_and_warnings():
+ obj = _Obj(n=2)
+ # build a v1-style flat state dict
+ flat = {k: np.asarray(v.value) for k, v in obj.vars().items()}
+ res = obj.load_state_dict(flat, compatible='v1', warn=False)
+ assert res.missing_keys == [] and res.unexpected_keys == []
+ # unexpected + missing keys trigger warnings
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter('always')
+ obj.load_state_dict({'does_not_exist': np.zeros(2)}, compatible='v1', warn=True)
+ assert any(issubclass(w.category, UserWarning) for w in caught)
+
+
+def test_load_state_dict_invalid_compatible_raises():
+ obj = _Obj()
+ with pytest.raises(ValueError):
+ obj.load_state_dict({}, compatible='v9')
+
+
+def test_vars_exclude_types():
+ class Holder(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.tv = bm.TrainVar(jnp.zeros(1))
+ self.v = bm.Variable(jnp.zeros(1))
+
+ h = Holder()
+ # excluding TrainVar drops it from the collection
+ kept = h.vars(exclude_types=(TrainVar,))
+ assert all(not isinstance(v, TrainVar) for v in kept.values())
+
+
+def test_unique_name_with_explicit_type():
+ obj = _Obj()
+ name = obj.unique_name(type_='CustomType')
+ assert name.startswith('CustomType')
+
+
+def test_node_dict_update_from_tuple():
+ from brainpy.math.object_transform.base import NodeDict
+ nd = NodeDict(('a', _Obj()))
+ assert 'a' in nd
+
+
+def test_brainpyobject_tree_flatten_unflatten_direct():
+ """Exercise ``tree_flatten``/``tree_unflatten`` directly (the pytree
+ registration is gated off by default, so ``jax.tree`` would treat the
+ object as a leaf)."""
+ class Mixed(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.v = bm.Variable(jnp.ones(2))
+ self.scalar = 9
+
+ m = Mixed()
+ dynamic, aux = m.tree_flatten()
+ assert len(dynamic) == 1
+ rebuilt = Mixed.tree_unflatten(aux, dynamic)
+ assert rebuilt.scalar == 9
+ assert isinstance(rebuilt.v, Variable)
+
+
+def test_relative_nodes_nested_hierarchy():
+ """Cover the relative-method recursion that joins child paths."""
+ class Leaf(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.w = bm.Variable(jnp.zeros(1))
+
+ class Middle(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.leaf = Leaf()
+
+ class Top(bp.BrainPyObject):
+ def __init__(self):
+ super().__init__()
+ self.mid = Middle()
+
+ t = Top()
+ rel = t.nodes(method='relative')
+ # joined paths like 'mid.leaf' appear
+ assert any('.' in k for k in rel.keys() if k)
+ rel_vars = t.vars(method='relative')
+ assert len(rel_vars) >= 1
+
+
+def test_cuda_tpu_raise_without_device():
+ obj = _Obj()
+ with pytest.raises(RuntimeError):
+ obj.cuda()
+ with pytest.raises(RuntimeError):
+ obj.tpu()
diff --git a/tests/audit/test_train_analysis_glue_fixes.py b/tests/audit/test_train_analysis_glue_fixes.py
new file mode 100644
index 000000000..07d216a4e
--- /dev/null
+++ b/tests/audit/test_train_analysis_glue_fixes.py
@@ -0,0 +1,665 @@
+# -*- coding: utf-8 -*-
+"""Regression + coverage tests for the 2026-06-18 BrainPy audit.
+
+This module targets the train / analysis / top-level-glue fixes documented in
+``docs/issues-found-20260618.md`` for the following source files:
+
+* ``brainpy/algorithms/online.py`` — C-23 (block RLS for batch>1)
+* ``brainpy/algorithms/offline.py`` — H-46 (GD ``.value`` bug),
+ H-47 (ridge intercept penalty)
+* ``brainpy/running/jax_multiprocessing.py`` — H-48 (pmap reuse / labels)
+* ``brainpy/analysis/lowdim/lowdim_analyzer.py`` — H-49 (arg-unwrap),
+ H-50 (empty-candidate concat)
+* ``brainpy/analysis/utils/optimization.py`` — H-49 (arg-unwrap in roots_of_1d_by_x)
+* ``brainpy/runners.py`` — C-22 (memory_efficient DSRunner)
+* ``brainpy/measure.py`` — H-43 (firing_rate normalization)
+* ``brainpy/delay.py`` — H-44 (VarDelay self.data),
+ H-45 (size_without_batch)
+
+The tests are intentionally tiny (small nets, short durations) so the whole
+module runs in well under four minutes. They assert the *fixed* behavior; on the
+buggy pre-audit code each regression test would raise or diverge.
+"""
+
+import warnings
+
+import numpy as np
+import jax
+import jax.numpy as jnp
+import pytest
+
+import brainpy as bp
+import brainpy.math as bm
+
+warnings.filterwarnings("ignore")
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _fit_rls(batch_size, n_in=3, n_out=2, steps=400, seed=0):
+ """Fit a known linear map ``Y = X @ Wtrue`` with the RLS online algorithm.
+
+ Returns (final_error, has_nan).
+ """
+ from brainpy.algorithms.online import RLS
+
+ rng = np.random.RandomState(seed)
+ w_true = jnp.asarray(rng.randn(n_in, n_out))
+ rls = RLS(alpha=0.1)
+ rls.register_target(n_in, identifier="w")
+ weight = bm.Variable(jnp.zeros((n_in, n_out)))
+ for _ in range(steps):
+ x = jnp.asarray(rng.randn(batch_size, n_in))
+ y = x @ w_true
+ out = x @ weight.value
+ dw = rls(y, x, out, identifier="w")
+ weight.value = weight.value + dw
+ err = float(jnp.linalg.norm(weight.value - w_true))
+ has_nan = bool(jnp.isnan(weight.value).any())
+ return err, has_nan
+
+
+# ---------------------------------------------------------------------------
+# online.py — C-23: block RLS valid for any batch size
+# ---------------------------------------------------------------------------
+
+def test_rls_converges_batch1():
+ """RLS fits a known linear map for batch=1 without divergence/NaN."""
+ err, has_nan = _fit_rls(batch_size=1)
+ assert not has_nan
+ assert err < 0.5, f"RLS (B=1) did not converge: err={err}"
+
+
+def test_rls_converges_batch4_no_nan():
+ """C-23: block RLS must stay correct (no NaN / no divergence) for batch>1.
+
+ On the pre-audit code the scalar ``c = sum(1/(1+HPHᵀ))`` collapse made the
+ covariance ``P`` diverge for B>=4 and produced NaN weights.
+ """
+ err, has_nan = _fit_rls(batch_size=4)
+ assert not has_nan, "RLS (B=4) produced NaN weights (C-23 regression)"
+ assert err < 0.5, f"RLS (B=4) did not converge: err={err}"
+
+
+def test_rls_batch1_matches_scalar_update():
+ """For B==1 the block update reduces to the classic scalar RLS update."""
+ from brainpy.algorithms.online import RLS
+
+ rls = RLS(alpha=0.1)
+ rls.register_target(3, identifier="w")
+ x = jnp.array([[1.0, -2.0, 0.5]])
+ target = jnp.array([[1.0, 0.0]])
+ output = jnp.zeros((1, 2))
+ dw = rls(target, x, output, identifier="w")
+ assert dw.shape == (3, 2)
+ assert not bool(jnp.isnan(dw).any())
+
+
+def test_lms_call_runs():
+ """Coverage: LMS online algorithm produces a finite weight update."""
+ from brainpy.algorithms.online import LMS
+
+ lms = LMS(alpha=0.01)
+ x = jnp.ones((2, 3))
+ target = jnp.ones((2, 2))
+ output = jnp.zeros((2, 2))
+ dw = lms(target, x, output)
+ assert dw.shape == (3, 2)
+ assert not bool(jnp.isnan(dw).any())
+
+
+def test_online_registry_helpers():
+ """Coverage: the online method registry getters."""
+ from brainpy.algorithms import online
+
+ methods = online.get_supported_online_methods()
+ assert "rls" in methods and "lms" in methods
+ assert online.get("rls") is online.RLS
+ assert online.get("lms") is online.LMS
+ with pytest.raises(ValueError):
+ online.get("does-not-exist")
+
+
+def test_online_trainer_and_force_trainer_fit():
+ """Coverage: exercise online.py ``call`` through a tiny ESN + trainers.
+
+ Drives ``brainpy.algorithms.online.RLS.call`` via the high-level
+ ``OnlineTrainer``/``ForceTrainer`` training loop on a small reservoir.
+ """
+
+ class ESN(bp.DynamicalSystem):
+ def __init__(self, num_in, num_hidden, num_out):
+ super().__init__()
+ self.r = bp.dyn.Reservoir(
+ num_in, num_hidden,
+ Win_initializer=bp.init.Uniform(-0.1, 0.1),
+ Wrec_initializer=bp.init.Normal(scale=0.1),
+ in_connectivity=0.1, rec_connectivity=0.1,
+ comp_type="dense",
+ )
+ self.o = bp.dnn.Dense(num_hidden, num_out,
+ W_initializer=bp.init.Normal(),
+ mode=bm.training_mode)
+
+ def update(self, x):
+ return x >> self.r >> self.o
+
+ bm.random.seed(0)
+ bp.share.save(fit=True)
+ with bm.batching_environment():
+ model = ESN(5, 25, 3)
+ x = bm.random.random((1, 50, 5))
+ y = bm.random.random((1, 50, 3))
+
+ trainer = bp.OnlineTrainer(model, fit_method=bp.algorithms.RLS(alpha=0.1),
+ progress_bar=False)
+ trainer.fit([x, y])
+ out = trainer.predict(x)
+ assert out.shape == (1, 50, 3)
+
+ bm.random.seed(0)
+ with bm.batching_environment():
+ model2 = ESN(5, 25, 3)
+ force = bp.ForceTrainer(model2, alpha=0.1, progress_bar=False)
+ force.fit([x, y])
+
+
+# ---------------------------------------------------------------------------
+# offline.py — H-46 (GD .value bug) and H-47 (ridge intercept penalty)
+# ---------------------------------------------------------------------------
+
+def _small_xy(n=20, n_in=3, n_out=2, seed=1):
+ rng = np.random.RandomState(seed)
+ return jnp.asarray(rng.randn(n, n_in)), jnp.asarray(rng.randn(n, n_out))
+
+
+def test_ridge_gradient_descent_runs_no_value_error():
+ """H-46: ``gradient_descent=True`` must not raise ``.value`` AttributeError."""
+ from brainpy.algorithms.offline import RidgeRegression
+
+ x, y = _small_xy()
+ w = RidgeRegression(alpha=1e-6, gradient_descent=True, max_iter=50)(y, x)
+ assert w.shape == (3, 2)
+ assert not bool(jnp.isnan(w).any())
+
+
+def test_linear_regression_gradient_descent_runs():
+ """H-46: the same GD code path through LinearRegression."""
+ from brainpy.algorithms.offline import LinearRegression
+
+ x, y = _small_xy()
+ w = LinearRegression(gradient_descent=True, max_iter=50)(y, x)
+ assert w.shape == (3, 2)
+ assert not bool(jnp.isnan(w).any())
+
+
+def test_lasso_always_gradient_descent_runs():
+ """H-46: Lasso is always-GD; its body must not hit the ``.value`` bug."""
+ from brainpy.algorithms.offline import LassoRegression
+
+ x, y = _small_xy(n_out=1)
+ w = LassoRegression(alpha=0.1, degree=2, max_iter=50)(y, x)
+ assert w.ndim == 2
+ assert not bool(jnp.isnan(w).any())
+
+
+def test_elastic_net_gradient_descent_runs():
+ """H-46: ElasticNet is always-GD as well."""
+ from brainpy.algorithms.offline import ElasticNetRegression
+
+ x, y = _small_xy(n_out=1)
+ w = ElasticNetRegression(alpha=0.1, degree=2, max_iter=50)(y, x)
+ assert not bool(jnp.isnan(w).any())
+
+
+def test_ridge_intercept_not_over_penalized():
+ """H-47: a large ridge ``alpha`` must not shrink the intercept/bias column.
+
+ Fit ``y ≈ slope*x + intercept`` with a strongly nonzero mean and a huge
+ penalty. The bias column (index 0 of the polynomial features) must remain
+ close to the data mean while the slope is shrunk toward zero.
+ """
+ from brainpy.algorithms.offline import PolynomialRidgeRegression
+
+ rng = np.random.RandomState(2)
+ x = jnp.asarray(rng.randn(40, 1))
+ mean_y = 5.0
+ y = jnp.asarray(0.3 * np.asarray(x) + mean_y + 0.01 * rng.randn(40, 1))
+
+ model = PolynomialRidgeRegression(alpha=1e6, degree=1, add_bias=True,
+ gradient_descent=False)
+ w = model(y, x)
+ intercept = float(w[0, 0])
+ slope = float(w[1, 0])
+ # Intercept stays near the data mean despite the huge penalty.
+ assert abs(intercept - mean_y) < 0.5, f"intercept over-penalized: {intercept}"
+ # The (penalized) slope is shrunk toward zero.
+ assert abs(slope) < abs(intercept), f"slope not shrunk: slope={slope}"
+
+
+def test_logistic_regression_remaining_bug():
+ """Document a remaining (un-audited) bug in ``LogisticRegression.call``.
+
+ ``call`` flattens ``targets`` to 1-D and then indexes ``targets.shape[1]``,
+ which raises ``IndexError: tuple index out of range``. This is independent
+ of the H-46/H-47 ridge fixes and is *not* in the assigned fix scope, so the
+ test pins the current broken behavior (and exercises the code path) rather
+ than asserting success. See summary notes.
+ """
+ from brainpy.algorithms.offline import LogisticRegression
+
+ rng = np.random.RandomState(3)
+ x = jnp.asarray(rng.randn(30, 2))
+ y = jnp.asarray((np.asarray(x)[:, :1] > 0).astype("float32"))
+ with pytest.raises(IndexError):
+ LogisticRegression(max_iter=100)(y, x)
+
+
+def test_offline_least_square_and_polynomial():
+ """Coverage: non-GD lstsq path and polynomial regression."""
+ from brainpy.algorithms.offline import (LinearRegression,
+ PolynomialRegression,
+ RidgeRegression)
+
+ x, y = _small_xy()
+ w_lin = LinearRegression()(y, x) # lstsq path
+ assert w_lin.shape[-1] == 2
+ w_ridge = RidgeRegression(alpha=1e-3)(y, x) # ridge closed form (no bias)
+ assert w_ridge.shape == (3, 2)
+ w_poly = PolynomialRegression(degree=2, gradient_descent=True, max_iter=20)(y, x)
+ assert not bool(jnp.isnan(w_poly).any())
+
+
+def test_offline_registry_helpers():
+ """Coverage: the offline method registry getters."""
+ from brainpy.algorithms import offline
+
+ methods = offline.get_supported_offline_methods()
+ for name in ("linear", "ridge", "lasso", "logistic"):
+ assert name in methods
+ assert offline.get("ridge") is offline.RidgeRegression
+ with pytest.raises(ValueError):
+ offline.get("nope")
+
+
+# ---------------------------------------------------------------------------
+# measure.py — H-43: firing_rate normalization
+# ---------------------------------------------------------------------------
+
+def test_firing_rate_100hz_mean():
+ """H-43: a true 100 Hz spike train must average to ~100 Hz.
+
+ ``sp[::10] = 1`` with dt=1 ms is one spike every 10 ms = 100 Hz. The buggy
+ normalization (by requested ``width`` rather than the actual window length)
+ biased the smoothed rate so its mean drifted toward ~110 Hz.
+ """
+ spikes = np.zeros((1000, 1))
+ spikes[::10] = 1
+ rate = bp.measure.firing_rate(spikes, width=10, dt=1.0)
+ assert abs(float(np.mean(rate)) - 100.0) < 5.0, f"mean rate={np.mean(rate)}"
+
+
+def test_firing_rate_jax_mode():
+ """Coverage: numpy=False (JIT-able) branch of firing_rate."""
+ spikes = np.zeros((200, 2))
+ spikes[::5] = 1
+ rate = bp.measure.firing_rate(spikes, width=5, dt=1.0, numpy=False)
+ assert rate.shape[0] == 200
+
+
+def test_raster_plot():
+ """Coverage: raster_plot returns (index, time) of spikes."""
+ spikes = np.zeros((5, 3))
+ spikes[1, 0] = 1
+ spikes[3, 2] = 1
+ times = np.arange(5) * 0.1
+ index, time = bp.measure.raster_plot(spikes, times)
+ assert set(np.asarray(index).tolist()) == {0, 2}
+ assert len(time) == 2
+
+
+# ---------------------------------------------------------------------------
+# delay.py — H-44 (VarDelay self.data) and register_entry / retrieve coverage
+# ---------------------------------------------------------------------------
+
+def test_vardelay_constructs_and_updates_time_gt_zero():
+ """H-44: ``VarDelay(target, time=T>0)`` must construct without AttributeError.
+
+ The pre-audit ``_init_data`` read ``self.data`` before it was ever assigned,
+ raising ``AttributeError: 'data'`` for any positive delay time.
+ """
+ target = bm.Variable(bm.zeros(4))
+ delay = bp.delay.VarDelay(target, time=2.0)
+ assert delay.data is not None
+ assert delay.max_length > 0
+
+ bp.share.save(i=0, t=0.0, dt=bm.get_dt())
+ delay.update(bm.ones(4)) # one update step
+ assert delay.data is not None
+
+
+def test_vardelay_register_entry_and_retrieve():
+ """Coverage: register_entry + at()/retrieve on a VarDelay."""
+ target = bm.Variable(bm.arange(4.0))
+ delay = bp.delay.VarDelay(target, time=2.0)
+ delay.register_entry("e1", delay_time=1.0)
+ delay.register_entry("e0", delay_time=0.0)
+
+ bp.share.save(i=0, t=0.0, dt=bm.get_dt())
+ delay.update(bm.arange(4.0) + 10.0)
+
+ # zero-delay entry returns the current target value
+ out0 = delay.at("e0")
+ assert out0.shape == (4,)
+ # nonzero-delay entry retrieves a buffered value
+ out1 = delay.at("e1")
+ assert out1.shape == (4,)
+
+ with pytest.raises(KeyError):
+ delay.at("missing")
+ with pytest.raises(KeyError):
+ delay.register_entry("e1", delay_time=1.0) # duplicate
+
+
+def test_vardelay_time_none_is_zero_length():
+ """Coverage: ``time=None`` yields a zero-length (data-less) delay."""
+ target = bm.Variable(bm.zeros(3))
+ delay = bp.delay.VarDelay(target, time=None)
+ assert delay.max_length == 0
+ assert delay.data is None
+
+
+def test_length_delay_register_retrieve_update():
+ """Coverage: ``brainpy.math.LengthDelay`` retrieve/update/__call__."""
+ var = bm.Variable(bm.arange(4.0))
+ ld = bm.LengthDelay(var, delay_len=3)
+ ld.update(bm.arange(4.0) + 10.0)
+ # delay 0 -> newest value
+ assert np.allclose(np.asarray(ld(0)), np.arange(4.0) + 10.0)
+ # delay 1 -> previous (initial zeros-ish) value
+ out1 = ld.retrieve(1)
+ assert out1.shape == (4,)
+
+
+def test_data_delay_constructs():
+ """Coverage: ``DataDelay`` (subclass of VarDelay) constructs with time>0."""
+ target = bm.Variable(bm.zeros(3))
+ dd = bp.delay.DataDelay(target, data_init=bm.zeros(3), time=1.0)
+ assert dd.data is not None
+
+
+def test_vardelay_concat_update_and_init_by_return():
+ """Coverage: CONCAT_UPDATE method, init_delay_by_return, DelayAccess."""
+ from brainpy.math.delayvars import CONCAT_UPDATE
+ from brainpy.delay import init_delay_by_return, DelayAccess
+
+ target = bm.Variable(bm.zeros(3))
+ delay = bp.delay.VarDelay(target, time=1.0, method=CONCAT_UPDATE)
+ delay.register_entry("c", delay_step=2)
+ bp.share.save(i=0, t=0.0, dt=bm.get_dt())
+ delay.update(bm.ones(3))
+ assert delay.at("c").shape == (3,)
+
+ # init_delay_by_return with a plain Variable -> VarDelay
+ dl = init_delay_by_return(bm.Variable(bm.zeros(2)))
+ assert isinstance(dl, bp.delay.VarDelay)
+
+ # DelayAccess registers an entry on the delay and reads it back
+ access = DelayAccess(delay, 1.0, delay_entry="acc")
+ out = access.update()
+ assert out.shape == (3,)
+
+
+def test_vardelay_wrong_target_type_raises():
+ """Coverage: VarDelay rejects a non-Variable target."""
+ with pytest.raises(ValueError):
+ bp.delay.VarDelay(bm.zeros(3), time=1.0)
+
+
+# ---------------------------------------------------------------------------
+# optimization.py + lowdim_analyzer.py — H-49 (arg-unwrap) and H-50 (empty concat)
+# ---------------------------------------------------------------------------
+
+def test_roots_of_1d_by_x_finds_fixed_point():
+ """H-49: ``roots_of_1d_by_x`` on ``dx=-x+I`` finds the fixed point x=I.
+
+ Also exercises the arg-unwrap comprehension (passing a ``bm.Array`` arg).
+ """
+ from brainpy.analysis.utils.optimization import roots_of_1d_by_x
+
+ bp.math.enable_x64()
+ try:
+ offset = 0.7
+ f = lambda x, b: -x + b
+ candidates = jnp.linspace(-2.0, 2.0, 401)
+ # pass the parameter as a bm.Array to drive the unwrap branch
+ fps = roots_of_1d_by_x(f, candidates, args=(bm.asarray(offset),))
+ fps = np.asarray(fps)
+ assert fps.size >= 1
+ assert np.any(np.abs(fps - offset) < 1e-3), f"fps={fps}"
+ finally:
+ bp.math.disable_x64()
+
+
+def test_phase_plane_1d_finds_fixed_point():
+ """H-49: a PhasePlane1D analyzer on ``dx=-x+I`` locates the fixed point ~x=I."""
+ import matplotlib
+ matplotlib.use("Agg")
+
+ bp.math.enable_x64()
+ try:
+ offset = 0.7
+
+ @bp.odeint
+ def int_x(x, t, Iext):
+ return -x + Iext
+
+ analyzer = bp.analysis.PhasePlane1D(
+ model=int_x,
+ target_vars={"x": [-2.0, 2.0]},
+ pars_update={"Iext": offset},
+ resolutions=0.01,
+ )
+ analyzer.plot_vector_field()
+ fps = analyzer.plot_fixed_point(show=False, with_return=True)
+ fps = np.asarray(fps).ravel()
+ assert fps.size >= 1
+ assert np.any(np.abs(fps - offset) < 1e-2), f"fps={fps}"
+ finally:
+ import matplotlib.pyplot as plt
+ plt.close("all")
+ bp.math.disable_x64()
+
+
+def test_lowdim_2d_empty_candidate_concat_guard():
+ """H-50: the non-convertible 2D ``_get_fixed_points`` must guard the empty path.
+
+ Build a 2D analyzer that cannot reduce to a single equation, then drive the
+ optimization branch with an impossible candidate-screening tolerance so that
+ nothing converges. The buggy code did ``jnp.concatenate([])`` -> ValueError;
+ the fix returns correctly-shaped empty arrays.
+ """
+ from brainpy.analysis.lowdim.lowdim_analyzer import Num2DAnalyzer
+
+ bp.math.enable_x64()
+ try:
+ @bp.odeint
+ def ds1(s1, t, s2):
+ return -s1 + jnp.tanh(s2) + 0.1
+
+ @bp.odeint
+ def ds2(s2, t, s1):
+ return -s2 + jnp.tanh(s1) + 0.1
+
+ analyzer = Num2DAnalyzer(
+ model=[ds1, ds2],
+ target_vars={"s1": [-2.0, 2.0], "s2": [-2.0, 2.0]},
+ resolutions=0.05,
+ )
+ assert not analyzer._can_convert_to_one_eq()
+
+ candidates = jnp.asarray(
+ np.random.RandomState(0).uniform(-2.0, 2.0, size=(30, 2)))
+ # tol_opt_candidate=-1 screens out every candidate -> empty all_fps list
+ fps, ids, pargs = analyzer._get_fixed_points(candidates,
+ tol_opt_candidate=-1.0)
+ fps = np.asarray(fps)
+ assert fps.shape == (0, 2)
+ assert np.asarray(ids).shape == (0,)
+ finally:
+ bp.math.disable_x64()
+
+
+def test_roots_of_1d_by_xy_and_brentq_helpers():
+ """Coverage: roots_of_1d_by_xy and the scalar brentq helper functions."""
+ from brainpy.analysis.utils import optimization as opt
+
+ bp.math.enable_x64()
+ try:
+ f = lambda x, a: -x + a
+
+ # roots_of_1d_by_xy on dx = -x + a, a = 0.5
+ xs, ys = opt.roots_of_1d_by_xy(f, jnp.array([-2.0]), jnp.array([2.0]),
+ jnp.array([0.5]))
+ assert np.any(np.abs(np.asarray(xs) - 0.5) < 1e-6)
+
+ # brentq_roots (jitted vmap brentq)
+ roots, _ = opt.brentq_roots(f, jnp.array([-2.0]), jnp.array([2.0]),
+ jnp.array([0.5]))
+ assert np.any(np.abs(np.asarray(roots) - 0.5) < 1e-6)
+
+ # get_brentq_candidates over a 2D meshgrid
+ starts, ends, args = opt.get_brentq_candidates(
+ lambda x, y: -x + y, jnp.linspace(-2.0, 2.0, 20),
+ jnp.linspace(-1.0, 1.0, 5))
+ assert len(np.asarray(starts)) == len(np.asarray(ends))
+
+ # pure-numpy brentq + 1D root finder
+ root, _iters, _calls = opt.numpy_brentq(lambda x: x - 0.3, -1.0, 1.0)
+ assert abs(root - 0.3) < 1e-9
+ roots_np = opt.find_root_of_1d_numpy(lambda x: -x + 0.4,
+ np.linspace(-2.0, 2.0, 50))
+ assert np.any(np.abs(np.asarray(roots_np) - 0.4) < 1e-6)
+ finally:
+ bp.math.disable_x64()
+
+
+def test_phase_plane_1d_no_fixed_point_runs_clean():
+ """Coverage: a 1D system with no real fixed point returns cleanly (no crash)."""
+ import matplotlib
+ matplotlib.use("Agg")
+
+ bp.math.enable_x64()
+ try:
+ @bp.odeint
+ def int_x(x, t):
+ # dx = x^2 + 1 has no real root -> no fixed point in range
+ return x ** 2 + 1.0
+
+ analyzer = bp.analysis.PhasePlane1D(
+ model=int_x,
+ target_vars={"x": [-2.0, 2.0]},
+ resolutions=0.01,
+ )
+ fps = analyzer.plot_fixed_point(show=False, with_return=True)
+ assert np.asarray(fps).size == 0
+ finally:
+ import matplotlib.pyplot as plt
+ plt.close("all")
+ bp.math.disable_x64()
+
+
+# ---------------------------------------------------------------------------
+# running/jax_multiprocessing.py — H-48 (pmap reuse / labels)
+# ---------------------------------------------------------------------------
+
+def test_jax_vectorize_map_sequence_and_dict():
+ """Coverage: jax_vectorize_map over sequence and dict arguments."""
+ from brainpy.running.jax_multiprocessing import jax_vectorize_map
+
+ out = jax_vectorize_map(lambda x: x * 2.0, [jnp.arange(6.0)], num_parallel=2)
+ assert np.allclose(np.asarray(out), np.arange(6.0) * 2.0)
+
+ # NOTE: clear_buffer=True calls the process-global ``bm.clear_buffer_memory()``
+ # which deletes ALL live device arrays (poisoning other test modules in the
+ # same pytest session). Patch it to a no-op so the branch is still covered.
+ _orig_clear = bm.clear_buffer_memory
+ bm.clear_buffer_memory = lambda *a, **k: None
+ try:
+ out2 = jax_vectorize_map(lambda a, b: a + b,
+ {"a": jnp.arange(4.0), "b": jnp.ones(4)},
+ num_parallel=2, clear_buffer=True)
+ finally:
+ bm.clear_buffer_memory = _orig_clear
+ assert np.allclose(np.asarray(out2), np.arange(4.0) + 1.0)
+
+
+def test_jax_vectorize_map_length_mismatch_raises():
+ """Coverage: mismatched argument lengths raise ValueError."""
+ from brainpy.running.jax_multiprocessing import jax_vectorize_map
+
+ with pytest.raises(ValueError):
+ jax_vectorize_map(lambda a, b: a + b,
+ {"a": jnp.arange(4.0), "b": jnp.ones(3)},
+ num_parallel=2)
+
+
+def test_jax_parallelize_map_single_device():
+ """H-48: jax_parallelize_map runs across chunks (one device per chunk)."""
+ from brainpy.running.jax_multiprocessing import jax_parallelize_map
+
+ n_dev = jax.local_device_count()
+ out = jax_parallelize_map(lambda x: x * 2.0,
+ [jnp.arange(float(3 * n_dev))],
+ num_parallel=n_dev)
+ assert np.allclose(np.asarray(out), np.arange(float(3 * n_dev)) * 2.0)
+
+
+# ---------------------------------------------------------------------------
+# runners.py — C-22: DSRunner(memory_efficient=True)
+# ---------------------------------------------------------------------------
+
+class _TinyNet(bp.DynamicalSystem):
+ def __init__(self):
+ super().__init__()
+ self.n = bp.dyn.LifRef(3)
+
+ def update(self, inp):
+ self.n(inp)
+ return self.n.V.value
+
+
+def test_dsrunner_memory_efficient_matches_normal():
+ """C-22: ``memory_efficient=True`` must run and match the standard run.
+
+ The pre-audit code did ``jax.ShapeDtypeStruct(mon.shape, ...)`` on a *dict*
+ monitor and used a broken ``pure_callback`` signature, so any
+ ``memory_efficient=True`` run raised ``AttributeError: 'dict' ... 'shape'``.
+ """
+ inputs = bm.ones((15, 3)) * 2.0
+
+ bm.random.seed(0)
+ r_normal = bp.DSRunner(_TinyNet(), monitors=["n.V"],
+ memory_efficient=False, progress_bar=False)
+ r_normal.run(inputs=inputs)
+ mon_normal = np.asarray(r_normal.mon["n.V"])
+
+ bm.random.seed(0)
+ r_mem = bp.DSRunner(_TinyNet(), monitors=["n.V"],
+ memory_efficient=True, progress_bar=False)
+ r_mem.run(inputs=inputs)
+ mon_mem = np.asarray(r_mem.mon["n.V"])
+
+ assert mon_normal.shape == mon_mem.shape == (15, 3)
+ assert np.allclose(mon_normal, mon_mem), "memory_efficient run diverged"
+
+
+def test_dsrunner_basic_run_with_monitors():
+ """Coverage: a plain DSRunner run with monitors and duration."""
+ bm.random.seed(0)
+ runner = bp.DSRunner(_TinyNet(), monitors=["n.V"], progress_bar=False)
+ runner.run(duration=2.0)
+ assert runner.mon["n.V"].shape[1] == 3
+ assert runner.mon.ts.shape[0] == runner.mon["n.V"].shape[0]