Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
<a href="https://badge.fury.io/py/brainpy"><img alt="PyPI version" src="https://badge.fury.io/py/brainpy.svg"></a>
<a href="https://github.com/brainpy/BrainPy/actions/workflows/CI.yml"><img alt="Continuous Integration" src="https://github.com/brainpy/BrainPy/actions/workflows/CI.yml/badge.svg"></a>
<a href="https://github.com/brainpy/BrainPy/actions/workflows/CI-models.yml"><img alt="Continuous Integration with Models" src="https://github.com/brainpy/BrainPy/actions/workflows/CI-models.yml/badge.svg"></a>
<a href="https://github.com/brainpy/BrainPy"><img alt="Test Coverage" src="https://img.shields.io/badge/coverage-93%25-brightgreen"></a>
</p>


Expand Down
14 changes: 11 additions & 3 deletions brainpy/algorithms/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def gradient_descent_solve(self, targets, inputs, outputs=None):
def cond_fun(a):
i, par_old, par_new = a
return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)),
i < self.max_iter).value
i < self.max_iter)

def body_fun(a):
i, _, par_new = a
Expand Down Expand Up @@ -269,9 +269,17 @@ def call(self, targets, inputs, outputs=None):
if self.gradient_descent:
return self.gradient_descent_solve(targets, inputs)
else:
n_features = inputs.shape[-1]
temp = inputs.T @ inputs
if self.regularizer.alpha > 0.:
temp += self.regularizer.alpha * jnp.eye(inputs.shape[-1])
penalty = self.regularizer.alpha * jnp.ones((n_features,))
# Do not penalize the intercept/bias column. ``polynomial_features``
# (used by ``PolynomialRidgeRegression`` when ``add_bias=True``)
# prepends a constant column at index 0; shrinking it would bias
# the fit on data with a nonzero mean.
if getattr(self, 'add_bias', False):
penalty = penalty.at[0].set(0.)
temp += jnp.diag(penalty)
weights = jnp.linalg.pinv(temp) @ (inputs.T @ targets)
return weights

Expand Down Expand Up @@ -383,7 +391,7 @@ def call(self, targets, inputs, outputs=None) -> ArrayType:
def cond_fun(a):
i, par_old, par_new = a
return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)),
i < self.max_iter).value
i < self.max_iter)

def body_fun(a):
i, par_old, par_new = a
Expand Down
17 changes: 13 additions & 4 deletions brainpy/algorithms/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,21 @@ def call(
assert input.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {input.shape}'
assert target.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {target.shape}'
assert output.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {output.shape}'
k = jnp.dot(P.value, input.T) # (num_input, num_batch)
# Block recursive least squares update (valid for any batch size B>=1).
# See e.g. Haykin, "Adaptive Filter Theory", block/multi-sample RLS.
# K = P Hᵀ (I_B + H P Hᵀ)⁻¹ -> Kalman gain, shape (num_input, B)
# P <- P - K H P -> covariance update, shape (num_input, num_input)
# w <- w + K (target - H w) -> here output = H w, so dw = -K (output - target)
# For B==1 this reduces exactly to the previous scalar update
# (c = 1/(1+hPh)), but unlike `jnp.sum(1/(1+HPHᵀ))` it stays correct for
# B>1 by inverting the full (B, B) matrix instead of summing reciprocals.
Pv = P.value
k = jnp.dot(Pv, input.T) # (num_input, num_batch)
hPh = jnp.dot(input, k) # (num_batch, num_batch)
c = jnp.sum(1.0 / (1.0 + hPh)) # ()
P -= c * jnp.dot(k, k.T) # (num_input, num_input)
gain = jnp.dot(k, jnp.linalg.inv(jnp.eye(hPh.shape[0]) + hPh)) # (num_input, num_batch)
P.value = Pv - jnp.dot(gain, jnp.dot(input, Pv)) # (num_input, num_input)
e = output - target # (num_batch, num_output)
dw = -c * jnp.dot(k, e) # (num_input, num_output)
dw = -jnp.dot(gain, e) # (num_input, num_output)
return dw


Expand Down
17 changes: 12 additions & 5 deletions brainpy/analysis/lowdim/lowdim_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_
# args: parameters, a list/tuple of vectors
candidates = candidates.value if isinstance(candidates, bm.Array) else candidates
selected_ids = np.arange(len(candidates))
args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args)
args = tuple(a.value if isinstance(a, bm.Array) else a for a in args)
for a in args: assert len(a) == len(candidates)
if num_seg is None:
num_seg = len(self.resolutions[self.x_var])
Expand Down Expand Up @@ -950,7 +950,7 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
# args: parameters, a list/tuple of vectors
candidates = candidates.value if isinstance(candidates, bm.Array) else candidates
selected_ids = np.arange(len(candidates))
args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args)
args = tuple(a.value if isinstance(a, bm.Array) else a for a in args)
for a in args: assert len(a) == len(candidates)

if self.convert_type() == C.x_by_y:
Expand Down Expand Up @@ -1035,9 +1035,16 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
all_ids.append(seg_ids)
for i in range(len(all_args)):
all_args[i].append(seg_args[i][ids])
all_fps = jnp.concatenate(all_fps)
all_ids = jnp.concatenate(all_ids)
all_args = tuple(jnp.concatenate(args) for args in all_args)
if len(all_fps):
all_fps = jnp.concatenate(all_fps)
all_ids = jnp.concatenate(all_ids)
all_args = tuple(jnp.concatenate(args) for args in all_args)
else:
# No candidate converged to a fixed point. Return empty arrays
# with the correct shapes/dtypes instead of concatenating [].
all_fps = jnp.zeros((0, candidates.shape[1]), dtype=candidates.dtype)
all_ids = jnp.zeros((0,), dtype=selected_ids.dtype)
all_args = tuple(jnp.zeros((0,), dtype=a.dtype) for a in args)
return all_fps, all_ids, all_args


Expand Down
2 changes: 1 addition & 1 deletion brainpy/analysis/utils/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def roots_of_1d_by_x(f, candidates, args=()):
"""
f = f_without_jaxarray_return(f)
candidates = candidates.value if isinstance(candidates, bm.Array) else candidates
args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args)
args = tuple(a.value if isinstance(a, bm.Array) else a for a in args)
vals = f(candidates, *args)
signs = jnp.sign(vals)
zero_sign_idx = jnp.where(signs == 0)[0]
Expand Down
13 changes: 8 additions & 5 deletions brainpy/connect/random_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ def __repr__(self):
f'seed={self.seed})')

def _iii(self):
if (not self.include_self) and (self.pre_num != self.post_num):
raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). '
f'But `include_self` is set to True.')

# NOTE: there is deliberately no guard on ``include_self=False`` for
# rectangular (pre_num != post_num) shapes. ``build_coo``/``build_csr``
# already drop coincident ``pre == post`` indices generically, so the
# previous ConnectorError was both misworded (its message contradicted
# the condition it checked) and overly restrictive.
if self.pre_ratio < 1.:
pre_num_to_select = int(self.pre_num * self.pre_ratio)
pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False)
Expand All @@ -96,7 +96,10 @@ def _iii(self):
pre_ids = jnp.arange(self.pre_num)

post_num_total = self.post_num
post_num_to_select = int(self.post_num * self.prob)
# Round instead of truncating: ``int(post_num * prob)`` floors a small
# expected fan-out (e.g. ``3 * 0.3 = 0.9``) to 0 connections, silently
# producing an empty connectivity for small post populations.
post_num_to_select = int(round(self.post_num * self.prob))

if self.allow_multi_conn:
selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select))
Expand Down
7 changes: 5 additions & 2 deletions brainpy/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,13 @@ def __init__(

# delay data
self._init = init
# ``self.data`` must exist before ``_init_data`` is called, because
# ``_init_data`` reads ``self.data`` to decide whether to allocate a
# new buffer. Initialize it unconditionally to avoid an AttributeError
# on the ``time > 0`` (``max_length > 0``) path.
self.data = None
if self.max_length > 0:
self._init_data(self.max_length)
else:
self.data = None

# other info
if entries is not None:
Expand Down
29 changes: 24 additions & 5 deletions brainpy/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,13 @@ def stdp_update(
w_min: numbers.Number = None,
w_max: numbers.Number = None
):
if bm.isscalar(self.W):
raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.')
if not isinstance(self.W, bm.Variable):
raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.')
# Promote the plain-array weight to a traceable Variable so the STDP
# in-place update works even when the comm was not built in a training
# mode. Only reached from ``stdp_update``; non-plastic use is unaffected.
self.W = bm.Variable(self.W)
if on_pre is not None:
spike = on_pre['spike']
trace = on_pre['trace']
Expand Down Expand Up @@ -320,8 +325,13 @@ def stdp_update(
w_min: numbers.Number = None,
w_max: numbers.Number = None
):
if bm.isscalar(self.weight):
raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.')
if not isinstance(self.weight, bm.Variable):
raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.')
# Promote the plain-array weight to a traceable Variable so the STDP
# in-place update works even when the comm was not built in a training
# mode. Only reached from ``stdp_update``; non-plastic use is unaffected.
self.weight = bm.Variable(self.weight)
if on_pre is not None:
spike = on_pre['spike']
trace = on_pre['trace']
Expand Down Expand Up @@ -375,7 +385,10 @@ def stdp_update(
if isinstance(self.weight, float):
raise ValueError(f'Cannot update the weight of a constant node.')
if not isinstance(self.weight, bm.Variable):
self.tracing_variable('weight', self.weight, self.weight.shape)
# Promote the plain-array weight to a traceable Variable so that the
# STDP in-place update below works. This branch is only reached from
# ``stdp_update`` (plasticity), so non-plastic use is unaffected.
self.weight = bm.Variable(self.weight)
if on_pre is not None:
spike = on_pre['spike']
trace = on_pre['trace']
Expand Down Expand Up @@ -449,7 +462,10 @@ def stdp_update(
if isinstance(self.weight, float):
raise ValueError(f'Cannot update the weight of a constant node.')
if not isinstance(self.weight, bm.Variable):
self.tracing_variable('weight', self.weight, self.weight.shape)
# Promote the plain-array weight to a traceable Variable so that the
# STDP in-place update below works. This branch is only reached from
# ``stdp_update`` (plasticity), so non-plastic use is unaffected.
self.weight = bm.Variable(self.weight)
if on_pre is not None:
spike = on_pre['spike']
trace = on_pre['trace']
Expand Down Expand Up @@ -500,7 +516,10 @@ def stdp_update(
raise ValueError(
f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.')
if not isinstance(self.weight, bm.Variable):
self.tracing_variable('weight', self.weight, self.weight.shape)
# Promote the plain-array weight to a traceable Variable so that the
# STDP in-place update below works. This branch is only reached from
# ``stdp_update`` (plasticity), so non-plastic use is unaffected.
self.weight = bm.Variable(self.weight)
if on_pre is not None: # update on presynaptic spike
spike = on_pre['spike']
trace = on_pre['trace']
Expand Down
63 changes: 53 additions & 10 deletions brainpy/dnn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
):
# ``BatchNorm`` computes statistics across a batch axis, so it requires a
# batching/training mode. The global default mode is ``NonBatchingMode``,
# which would make the layer raise ``UnsupportedError`` out-of-the-box
# (H-51). Default to the training mode when the user does not specify one.
if mode is None:
mode = bm.training_mode
super(BatchNorm, self).__init__(name=name, mode=mode)
# check.is_subclass(self.mode, (bm.BatchingMode, bm.TrainingMode), self.name)

Expand All @@ -131,9 +137,19 @@ def __init__(
self.running_mean = bm.Variable(jnp.zeros(self.num_features))
self.running_var = bm.Variable(jnp.ones(self.num_features))
if self.affine:
assert isinstance(self.mode, bm.TrainingMode)
self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_features))
self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_features))
bias = parameter(self.bias_initializer, self.num_features)
scale = parameter(self.scale_initializer, self.num_features)
# Make the affine parameters trainable only under a training mode;
# otherwise keep them as plain variables (matching the pattern used
# by ``brainpy.dnn.Linear``/``Conv``). This avoids the hard
# ``assert isinstance(self.mode, TrainingMode)`` that crashed the
# layer under non-training modes (H-51).
if isinstance(self.mode, bm.TrainingMode):
self.bias = bm.TrainVar(bias)
self.scale = bm.TrainVar(scale)
else:
self.bias = bm.Variable(bias)
self.scale = bm.Variable(scale)

def _check_input_dim(self, x):
raise NotImplementedError
Expand Down Expand Up @@ -500,9 +516,20 @@ def __init__(
assert all([isinstance(s, int) for s in normalized_shape]), 'Must be a sequence of integer.'
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
assert isinstance(self.mode, bm.TrainingMode)
self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape))
self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape))
bias = parameter(self.bias_initializer, self.normalized_shape)
scale = parameter(self.scale_initializer, self.normalized_shape)
# ``LayerNorm`` does not depend on batch statistics, so it works under
# any mode (including the default ``NonBatchingMode``). Only wrap the
# affine parameters as trainable variables under a training mode;
# otherwise keep them as plain variables. This removes the hard
# ``assert isinstance(self.mode, TrainingMode)`` that crashed the
# affine layer out-of-the-box (H-51).
if isinstance(self.mode, bm.TrainingMode):
self.bias = bm.TrainVar(bias)
self.scale = bm.TrainVar(scale)
else:
self.bias = bm.Variable(bias)
self.scale = bm.Variable(scale)

def update(self, x):
if x.shape[-len(self.normalized_shape):] != self.normalized_shape:
Expand Down Expand Up @@ -585,16 +612,32 @@ def __init__(
self.bias_initializer = bias_initializer
self.scale_initializer = scale_initializer
if self.affine:
assert isinstance(self.mode, bm.TrainingMode)
self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_channels))
self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_channels))
bias = parameter(self.bias_initializer, self.num_channels)
scale = parameter(self.scale_initializer, self.num_channels)
# ``GroupNorm``/``InstanceNorm`` compute statistics independently of
# the batch size, so they work under any mode (including the default
# ``NonBatchingMode``). Only make the affine parameters trainable
# under a training mode; otherwise keep them as plain variables. This
# removes the hard ``assert isinstance(self.mode, TrainingMode)`` that
# crashed the affine layer out-of-the-box (H-51).
if isinstance(self.mode, bm.TrainingMode):
self.bias = bm.TrainVar(bias)
self.scale = bm.TrainVar(scale)
else:
self.bias = bm.Variable(bias)
self.scale = bm.Variable(scale)

def update(self, x):
assert x.shape[-1] == self.num_channels
origin_shape, origin_dim = x.shape, x.ndim
group_shape = (-1,) + x.shape[1:-1] + (self.num_groups, self.num_channels // self.num_groups)
x = bm.as_jax(x.reshape(group_shape))
reduction_axes = tuple(range(1, x.ndim - 1)) + (-1,)
# After reshape the axis layout is
# ``[0]=batch, [1..ndim-3]=spatial, [ndim-2]=group, [ndim-1]=channels-per-group``.
# Normalization must reduce over the spatial axes and the within-group
# channel axis, but NOT over the group axis (``ndim-2``); otherwise the
# groups are averaged together and ``num_groups`` has no effect (C-05).
reduction_axes = tuple(range(1, x.ndim - 2)) + (-1,)
mean = jnp.mean(x, reduction_axes, keepdims=True)
var = jnp.var(x, reduction_axes, keepdims=True)
x = (x - mean) * lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype))
Expand Down
19 changes: 18 additions & 1 deletion brainpy/dyn/channels/calcium.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from typing import Union, Callable, Optional

import jax.numpy as jnp

import brainpy.math as bm
from brainpy.context import share
from brainpy.dyn.ions.calcium import Calcium, CalciumDyna
Expand All @@ -40,6 +42,20 @@
]


def _exprel(x):
  """Evaluate ``(exp(x) - 1) / x`` without dividing by zero near ``x == 0``.

  Rate expressions of the form ``num * temp / (exp(temp / k) - 1)`` are
  ``0 / 0`` at ``temp == 0``; rewritten as ``num * k / _exprel(temp / k)``
  they stay finite there.

  For ``|x| < 1e-7`` the first-order expansion ``1 + x / 2`` is returned.
  Elsewhere the result is ``expm1(x) / x``. The quotient is always computed
  on a copy of ``x`` in which the near-zero entries are replaced by ``1.0``,
  so the branch that is not selected never contains a division by zero
  (the usual "double ``where``" pattern for keeping gradients finite).

  Parameters
  ----------
  x : scalar or array
    Input value(s).

  Returns
  -------
  Same shape as ``x``; equals ``1`` at ``x == 0``.
  """
  near_zero = jnp.abs(x) < 1e-7
  # Substitute a harmless denominator wherever the series branch is used.
  denom = jnp.where(near_zero, 1.0, x)
  series = 1.0 + x / 2.0
  quotient = jnp.expm1(denom) / denom
  return jnp.where(near_zero, series, quotient)


class CalciumChannel(IonChannel):
"""Base class for Calcium ion channels."""

Expand Down Expand Up @@ -708,7 +724,8 @@ def __init__(

def f_p_alpha(self, V):
  """Evaluate the ``alpha`` expression of the ``p`` gate at voltage ``V``.

  With ``temp = -27 - V + self.V_sh`` the expression is
  ``0.055 * temp / (exp(temp / 3.8) - 1)``. Written that way it is ``0 / 0``
  when ``temp == 0`` (i.e. ``V == self.V_sh - 27``), so it is evaluated in
  the algebraically equivalent form ``0.055 * 3.8 / _exprel(temp / 3.8)``,
  which equals ``0.055 * 3.8`` at that point.

  Parameters
  ----------
  V : scalar or array
    Voltage value(s).

  Returns
  -------
  The rate expression, with the same shape as ``V``.
  """
  temp = -27 - V + self.V_sh
  # 0.055 * temp / (exp(temp/3.8) - 1) == 0.055 * 3.8 / exprel(temp/3.8)
  return (0.055 * 3.8) / _exprel(temp / 3.8)

def f_p_beta(self, V):
  """Evaluate the ``beta`` expression of the ``p`` gate at voltage ``V``.

  Returns ``0.94 * exp((-75 - V + self.V_sh) / 17)``.
  """
  shifted = -75. - V + self.V_sh
  return 0.94 * bm.exp(shifted / 17.)
Expand Down
Loading
Loading