diff --git a/bindsnet/network/monitors.py b/bindsnet/network/monitors.py index 4446e678d..d91e6420e 100644 --- a/bindsnet/network/monitors.py +++ b/bindsnet/network/monitors.py @@ -86,7 +86,13 @@ def get(self, var: str) -> torch.Tensor: if self.clean: return_logs = torch.empty(0, device=self.device) else: - return_logs = torch.cat(self.recording[var], 0) + # If the actual run length is shorter than the preallocated time, + # some placeholders remain and break torch.cat, so we drop them. + entries = [e for e in self.recording[var] if torch.is_tensor(e)] + if len(entries) == 0: + return_logs = torch.empty(0, device=self.device) + else: + return_logs = torch.cat(entries, 0) if self.time is None: self.recording[var] = [] return return_logs