Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions cebra/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,25 +154,22 @@ def expand_index_in_trial(self, index, trial_ids, trial_borders):
trial_ids is in size of a length of self.index and indicate the trial id of the index belong to.
trial_borders is in size of a length of self.idnex and indicate the border of each trial.

Todo:
- rewrite
"""

# TODO(stes) potential room for speed improvements by pre-allocating these tensors/
# using non_blocking copy operation.
offset = torch.arange(-self.offset.left,
self.offset.right,
device=index.device)
index = torch.tensor(
[
torch.clamp(
i,
trial_borders[trial_ids[i]] + self.offset.left,
trial_borders[trial_ids[i] + 1] - self.offset.right,
) for i in index
],
device=self.device,
)

# Vectorized lookup and boundary calculation
batch_trial_ids = trial_ids[index]
min_borders = trial_borders[batch_trial_ids] + self.offset.left
max_borders = trial_borders[batch_trial_ids + 1] - self.offset.right

# Fast C-level clamp
index = torch.clamp(index, min=min_borders, max=max_borders)

return index[:, None] + offset[None, :]

@abc.abstractmethod
Expand Down