Skip to content

Fix Monitor bug with preallocated buffers and torch.cat#761

Open
steampunc wants to merge 1 commit into
BindsNET:masterfrom
steampunc:fix/monitor-static-mode
Open

Fix Monitor bug with preallocated buffers and torch.cat#761
steampunc wants to merge 1 commit into
BindsNET:masterfrom
steampunc:fix/monitor-static-mode

Conversation

@steampunc
Copy link
Copy Markdown

I found an issue with the Monitor class's static buffers, which this PR fixes.

bindsnet.network.monitors.Monitor seems to support a preallocated buffer mode where you pass a simulation duration at construction, which leads reset_state_variables to fill a recording list with placeholders.

record() appends the new tensors and pops the empty placeholders from the head of the list, keeping total length at some time T. This works when the run length equals T, but if it runs for fewer than T steps, the list ends up with a mix of tensors and the [] placeholders. torch.cat(self.recording[var], 0) crashes when it's called on this mixed list.

Reproducing

import torch
from bindsnet.network import Network
from bindsnet.network.nodes import LIFNodes, Input
from bindsnet.network.topology import Connection
from bindsnet.network.monitors import Monitor

net = Network(dt=1.0)
inp = Input(n=2); lif = LIFNodes(n=2)
net.add_layer(inp, name="in"); net.add_layer(lif, name="lif")
net.add_connection(Connection(source=inp, target=lif), source="in", target="lif")

# Preallocate for 100 steps, but only run 10.
mon = Monitor(lif, state_vars=["v"], time=100)
net.add_monitor(mon, name="m")
net.run(inputs={"in": torch.zeros(10, 1, 2)}, time=10)
mon.get("v")

Causes:

File ".../bindsnet/network/monitors.py", line 89, in get
return_logs = torch.cat(self.recording[var], 0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: expected Tensor as element 0 in argument 0, but got list

The current workaround I've used in some downstream code is to just always construct monitors in dynamic mode (time=None), but this gives up the speed benefit of preallocation.

To fix this more properly, I've added a list comprehension in .get() to filter for just the tensors before passing the list to torch.cat. This drops the empty placeholders, so a short run returns a tensor of shape [truncated_run_length, ...] instead of crashing. A run that's filled the buffer still returns the full [T, ...].

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant