50 changes: 37 additions & 13 deletions model2vec/train/base.py
```diff
@@ -46,6 +46,7 @@ def __init__(
         weights: torch.Tensor | None = None,
         freeze: bool = False,
         normalize: bool = True,
+        freeze_weights: bool = False,
     ) -> None:
         """Initialize a trainable StaticModel from a StaticModel.
 
```
```diff
@@ -59,6 +60,7 @@ def __init__(
         :param weights: The weights of the model. If None, the weights are initialized to zeros.
         :param freeze: Whether to freeze the embeddings. This should be set to False in most cases.
         :param normalize: Whether to normalize the embeddings.
+        :param freeze_weights: Whether to freeze the learned token weights.
         """
         super().__init__()
         self.pad_id = pad_id
```
```diff
@@ -67,6 +69,7 @@ def __init__(
         self.hidden_dim = hidden_dim
         self.n_layers = n_layers
         self.normalize = normalize
+        self.freeze_weights = freeze_weights
 
         self.vectors = vectors
         if self.vectors.dtype != torch.float32:
```
```diff
@@ -92,10 +95,10 @@ def construct_weights(self) -> nn.Parameter:
         """Construct the weights for the model."""
         if self._weights is not None:
             w = logit(self._weights)
-            return nn.Parameter(w.float(), requires_grad=True)
-        weights = torch.zeros(len(self.token_mapping))
-        weights[self.pad_id] = -10_000
-        return nn.Parameter(weights, requires_grad=not self.freeze)
+        else:
+            w = torch.zeros(len(self.token_mapping)).float()
+            w[self.pad_id] = -10_000
+        return nn.Parameter(w, requires_grad=not self.freeze_weights)
 
     def construct_head(self) -> nn.Sequential:
         """Constructs a simple classifier head."""
```
```diff
@@ -136,7 +139,11 @@ def _initialize(self) -> None:
 
     @classmethod
     def from_pretrained(
-        cls: type[ModelType], path: str = "minishlab/potion-base-32m", *, token: str | None = None, **kwargs: Any
+        cls: type[ModelType],
+        path: str = "minishlab/potion-base-32m",
+        *,
+        token: str | None = None,
+        **kwargs: Any,
     ) -> ModelType:
         """Load the model from a pretrained model2vec model."""
         if model_name := kwargs.pop("model_name", None):
```
```diff
@@ -147,7 +154,11 @@ def from_pretrained(
 
     @classmethod
     def from_static_model(
-        cls: type[ModelType], *, model: StaticModel, pad_token: str | None = None, **kwargs: Any
+        cls: type[ModelType],
+        *,
+        model: StaticModel,
+        pad_token: str | None = None,
+        **kwargs: Any,
     ) -> ModelType:
         """Load the model from a static model."""
         model.embedding = np.nan_to_num(model.embedding)
```
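Both classmethods now list one keyword per line, which keeps future flags readable as they pass through `**kwargs`. A hedged usage sketch; that `freeze_weights` is forwarded through these constructors is an assumption based on this diff, not a verified release API:

```python
from model2vec import StaticModel
from model2vec.train import StaticModelForClassification

# Load a trainable model straight from the hub...
clf = StaticModelForClassification.from_pretrained("minishlab/potion-base-32m")

# ...or wrap an existing StaticModel (NaNs in its embedding are zeroed on load).
static = StaticModel.from_pretrained("minishlab/potion-base-32m")
clf = StaticModelForClassification.from_static_model(model=static, freeze_weights=True)
```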
```diff
@@ -179,14 +190,15 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
         :param input_ids: A 2D tensor of input ids. All input ids have to be within bounds.
         :return: The mean over the input ids, weighted by token weights.
         """
-        w = self.w[input_ids]
-        w = torch.sigmoid(w)
         zeros = (input_ids != self.pad_id).float()
-        w = w * zeros
         # Add a small epsilon to avoid division by zero
         length = zeros.sum(1) + 1e-16
         input_ids_embeddings = self.token_mapping[input_ids]
         embedded = self.embeddings(input_ids_embeddings)
 
+        w = self.w[input_ids]
+        w = torch.sigmoid(w)
+        w = w * zeros
         # Weigh each token
         embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
         # Mean pooling by dividing by the length
```
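The reorder computes the padding mask and token embeddings first and applies the sigmoid weights immediately before pooling; the result is unchanged. A self-contained sketch of this masked, weighted mean with toy tensors (the `token_mapping` indirection is omitted):

```python
import torch

batch, seq_len, dim, vocab, pad_id = 2, 5, 4, 10, 0
input_ids = torch.randint(1, vocab, (batch, seq_len))
input_ids[0, 3:] = pad_id                 # pad out the first sequence
w_logits = torch.zeros(vocab)             # stand-in for self.w
embeddings = torch.randn(vocab, dim)      # stand-in for self.embeddings

mask = (input_ids != pad_id).float()
length = mask.sum(1) + 1e-16              # epsilon avoids division by zero
embedded = embeddings[input_ids]          # (batch, seq_len, dim)
w = torch.sigmoid(w_logits[input_ids]) * mask
pooled = torch.bmm(w[:, None, :], embedded).squeeze(1) / length[:, None]
print(pooled.shape)                       # torch.Size([2, 4])
```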
```diff
@@ -235,13 +247,20 @@ def device(self) -> torch.device:
 
     def to_static_model(self) -> StaticModel:
         """Convert the model to a static model."""
-        emb = self.embeddings.weight.detach().cpu().numpy()
-        w = torch.sigmoid(self.w).detach().cpu().numpy()
+        with torch.no_grad():
+            emb = self.embeddings.weight
+            emb = emb.cpu().numpy()
+            w = torch.sigmoid(self.w).cpu().numpy()
 
         # If the weights and emb are the same length, the model was not quantized before training.
         if len(w) == len(emb):
             emb = emb * w[:, None]
             return StaticModel(
-                vectors=emb, weights=None, tokenizer=self.tokenizer, normalize=self.normalize, token_mapping=None
+                vectors=emb,
+                weights=None,
+                tokenizer=self.tokenizer,
+                normalize=self.normalize,
+                token_mapping=None,
             )
         return StaticModel(
             vectors=emb,
```
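Running the export under `torch.no_grad()` replaces the old `.detach()` calls: tensors created inside the block carry no autograd history, so `.cpu().numpy()` is safe. The folding step itself is a row-wise scale of the embedding matrix; a NumPy illustration with toy values, not the PR's code:

```python
import numpy as np

emb = np.random.randn(5, 4).astype(np.float32)  # (vocab, dim) embedding matrix
w = 1.0 / (1.0 + np.exp(-np.zeros(5)))          # sigmoid of the logit weights

# Same length means the vocabulary was not quantized before training,
# so each token's weight can be folded into its embedding row.
if len(w) == len(emb):
    emb = emb * w[:, None]
print(emb.shape, emb.dtype)
```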
```diff
@@ -265,7 +284,12 @@ def _determine_batch_size(self, batch_size: int | None, train_length: int) -> int:
         return batch_size
 
     def _check_val_split(
-        self, X: list[str], y: list, X_val: list[str] | None, y_val: list | None, test_size: float
+        self,
+        X: list[str],
+        y: list,
+        X_val: list[str] | None,
+        y_val: list | None,
+        test_size: float,
     ) -> tuple[list[str], list[str], Sequence, Sequence]:
         if (X_val is not None) != (y_val is not None):
             raise ValueError("Both X_val and y_val must be provided together, or neither.")
```
2 changes: 2 additions & 0 deletions model2vec/train/classifier.py
```diff
@@ -38,6 +38,7 @@ def __init__(
         weights: torch.Tensor | None = None,
         freeze: bool = False,
         normalize: bool = True,
+        freeze_weights: bool = False,
     ) -> None:
         """Initialize a standard classifier model."""
         # Alias: Follows scikit-learn. Set to dummy classes
```
```diff
@@ -55,6 +56,7 @@ def __init__(
             hidden_dim=hidden_dim,
             n_layers=n_layers,
             normalize=normalize,
+            freeze_weights=freeze_weights,
         )
 
     @property
```
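With the flag threaded through to the base class, a classifier can be fine-tuned while the pretrained token weighting stays fixed. A hedged end-to-end sketch on toy data (the scikit-learn-style `fit`/`predict` are part of the library; the effect of `freeze_weights` here follows from the diff above):

```python
from model2vec.train import StaticModelForClassification

clf = StaticModelForClassification.from_pretrained(
    "minishlab/potion-base-32m", freeze_weights=True
)
X = ["great film", "terrible film", "loved it", "hated it"] * 8
y = ["pos", "neg", "pos", "neg"] * 8
clf.fit(X, y)  # token weights receive no gradient updates
print(clf.predict(["a wonderful movie"]))
```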
2 changes: 2 additions & 0 deletions model2vec/train/similarity.py
```diff
@@ -30,6 +30,7 @@ def __init__(
         weights: torch.Tensor | None = None,
         freeze: bool = False,
         normalize: bool = True,
+        freeze_weights: bool = False,
     ) -> None:
         """Initialize a standard similarity model."""
         super().__init__(
```
```diff
@@ -43,6 +44,7 @@ def __init__(
             hidden_dim=hidden_dim,
             n_layers=n_layers,
             normalize=normalize,
+            freeze_weights=freeze_weights,
         )
 
     def fit(
```