From 4629383babb2aae0c453c8be7ddebb58b6e54395 Mon Sep 17 00:00:00 2001 From: stephantul Date: Thu, 30 Apr 2026 09:49:15 +0200 Subject: [PATCH 1/3] improve training --- model2vec/train/base.py | 81 +++++++++++++++++++++++------------ model2vec/train/classifier.py | 17 +++----- model2vec/train/similarity.py | 5 ++- 3 files changed, 64 insertions(+), 39 deletions(-) diff --git a/model2vec/train/base.py b/model2vec/train/base.py index 8ce093e..ceefae7 100644 --- a/model2vec/train/base.py +++ b/model2vec/train/base.py @@ -46,9 +46,9 @@ def __init__( weights: torch.Tensor | None = None, freeze: bool = False, normalize: bool = True, + freeze_weights: bool = False, ) -> None: - """ - Initialize a trainable StaticModel from a StaticModel. + """Initialize a trainable StaticModel from a StaticModel. :param vectors: The embeddings of the staticmodel. :param tokenizer: The tokenizer. @@ -60,6 +60,7 @@ def __init__( :param weights: The weights of the model. If None, the weights are initialized to zeros. :param freeze: Whether to freeze the embeddings. This should be set to False in most cases. :param normalize: Whether to normalize the embeddings. + :param freeze_weights: Whether to freeze the learned token weights. """ super().__init__() self.pad_id = pad_id @@ -68,6 +69,7 @@ def __init__( self.hidden_dim = hidden_dim self.n_layers = n_layers self.normalize = normalize + self.freeze_weights = freeze_weights self.vectors = vectors if self.vectors.dtype != torch.float32: @@ -93,26 +95,31 @@ def construct_weights(self) -> nn.Parameter: """Construct the weights for the model.""" if self._weights is not None: w = logit(self._weights) - return nn.Parameter(w.float(), requires_grad=True) - weights = torch.zeros(len(self.token_mapping)) - weights[self.pad_id] = -10_000 - return nn.Parameter(weights, requires_grad=not self.freeze) + else: + w = torch.zeros(len(self.token_mapping)).float() + w[self.pad_id] = -10_000 + return nn.Parameter(w, requires_grad=not self.freeze_weights) def construct_head(self) -> nn.Sequential: + """Constructs a simple classifier head.""" + return self.construct_mlp(self.n_layers, self.embed_dim, self.hidden_dim, self.out_dim) + + @staticmethod + def construct_mlp(n_layers: int, embed_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential: """Constructs a simple classifier head.""" modules: list[nn.Module] = [] - if self.n_layers == 0: - modules.append(nn.Linear(self.embed_dim, self.out_dim)) + if n_layers == 0: + modules.append(nn.Linear(embed_dim, out_dim)) else: # If we have a hidden layer, we should first project to hidden_dim modules = [ - nn.Linear(self.embed_dim, self.hidden_dim), + nn.Linear(embed_dim, hidden_dim), nn.ReLU(), ] - for _ in range(self.n_layers - 1): - modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()]) + for _ in range(n_layers - 1): + modules.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]) # We always have a layer mapping from hidden to out. - modules.append(nn.Linear(self.hidden_dim, self.out_dim)) + modules.append(nn.Linear(hidden_dim, out_dim)) linear_modules = [module for module in modules if isinstance(module, nn.Linear)] if linear_modules: @@ -137,7 +144,11 @@ def _initialize(self) -> None: @classmethod def from_pretrained( - cls: type[ModelType], path: str = "minishlab/potion-base-32m", *, token: str | None = None, **kwargs: Any + cls: type[ModelType], + path: str = "minishlab/potion-base-32m", + *, + token: str | None = None, + **kwargs: Any, ) -> ModelType: """Load the model from a pretrained model2vec model.""" if model_name := kwargs.pop("model_name", None): @@ -148,7 +159,11 @@ def from_pretrained( @classmethod def from_static_model( - cls: type[ModelType], *, model: StaticModel, pad_token: str | None = None, **kwargs: Any + cls: type[ModelType], + *, + model: StaticModel, + pad_token: str | None = None, + **kwargs: Any, ) -> ModelType: """Load the model from a static model.""" model.embedding = np.nan_to_num(model.embedding) @@ -172,8 +187,7 @@ def from_static_model( ) def _encode(self, input_ids: torch.Tensor) -> torch.Tensor: - """ - A forward pass and mean pooling. + """A forward pass and mean pooling. This function is analogous to `StaticModel.encode`, but reimplemented to allow gradients to pass through. @@ -181,14 +195,15 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor: :param input_ids: A 2D tensor of input ids. All input ids are have to be within bounds. :return: The mean over the input ids, weighted by token weights. """ - w = self.w[input_ids] - w = torch.sigmoid(w) zeros = (input_ids != self.pad_id).float() - w = w * zeros # Add a small epsilon to avoid division by zero length = zeros.sum(1) + 1e-16 input_ids_embeddings = self.token_mapping[input_ids] embedded = self.embeddings(input_ids_embeddings) + + w = self.w[input_ids] + w = torch.sigmoid(w) + w = w * zeros # Weigh each token embedded = torch.bmm(w[:, None, :], embedded).squeeze(1) # Mean pooling by dividing by the length @@ -218,8 +233,7 @@ def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return self.head(encoded), encoded def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor: - """ - Tokenize a bunch of strings into a single padded 2D tensor. + """Tokenize a bunch of strings into a single padded 2D tensor. Note that this is not used during training. @@ -238,13 +252,22 @@ def device(self) -> torch.device: def to_static_model(self) -> StaticModel: """Convert the model to a static model.""" - emb = self.embeddings.weight.detach().cpu().numpy() - w = torch.sigmoid(self.w).detach().cpu().numpy() + with torch.no_grad(): + emb = self.embeddings.weight + emb = emb.detach().cpu().numpy() + if self.w is not None: + w = torch.sigmoid(self.w).detach().cpu().numpy() + else: + w = np.ones(len(emb)) # If the weights and emb are the same length, the model was not quantized before training. if len(w) == len(emb): emb = emb * w[:, None] return StaticModel( - vectors=emb, weights=None, tokenizer=self.tokenizer, normalize=self.normalize, token_mapping=None + vectors=emb, + weights=None, + tokenizer=self.tokenizer, + normalize=self.normalize, + token_mapping=None, ) return StaticModel( vectors=emb, @@ -268,7 +291,12 @@ def _determine_batch_size(self, batch_size: int | None, train_length: int) -> in return batch_size def _check_val_split( - self, X: list[str], y: list, X_val: list[str] | None, y_val: list | None, test_size: float + self, + X: list[str], + y: list, + X_val: list[str] | None, + y_val: list | None, + test_size: float, ) -> tuple[list[str], list[str], Sequence, Sequence]: if (X_val is not None) != (y_val is not None): raise ValueError("Both X_val and y_val must be provided together, or neither.") @@ -368,8 +396,7 @@ def _determine_val_check_interval( return val_check_interval, check_val_every_epoch def _prepare_dataset(self, X: list[str], y: torch.Tensor, max_length: int = 512) -> TextDataset: - """ - Prepare a dataset. + """Prepare a dataset. :param X: The texts. :param y: The labels. diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index e12d385..12c74d8 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -38,6 +38,7 @@ def __init__( weights: torch.Tensor | None = None, freeze: bool = False, normalize: bool = True, + freeze_weights: bool = False, ) -> None: """Initialize a standard classifier model.""" # Alias: Follows scikit-learn. Set to dummy classes @@ -55,6 +56,7 @@ def __init__( hidden_dim=hidden_dim, n_layers=n_layers, normalize=normalize, + freeze_weights=freeze_weights, ) @property @@ -65,8 +67,7 @@ def classes(self) -> np.ndarray: def predict( self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024, threshold: float = 0.5 ) -> np.ndarray: - """ - Predict labels for a set of texts. + """Predict labels for a set of texts. In single-label mode, each prediction is a single class. In multilabel mode, each prediction is a list of classes. @@ -93,8 +94,7 @@ def predict( return np.array(pred) def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray: - """ - Predict probabilities for each class. + """Predict probabilities for each class. In single-label mode, returns softmax probabilities. In multilabel mode, returns sigmoid probabilities. @@ -125,8 +125,7 @@ def fit( validation_steps: int | None = None, random_seed: int = _DEFAULT_RANDOM_SEED, ) -> StaticModelForClassification: - """ - Fit a model. + """Fit a model. This function creates a Lightning Trainer object and fits the model to the data. It supports both single-label and multi-label classification. @@ -222,8 +221,7 @@ def _determine_class_weight( def evaluate( self, X: list[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False ) -> str | dict[str, dict[str, float]]: - """ - Evaluate the classifier on a given dataset using scikit-learn's classification report. + """Evaluate the classifier on a given dataset using scikit-learn's classification report. :param X: The texts to predict on. :param y: The ground truth labels. @@ -239,8 +237,7 @@ def evaluate( return report def _initialize_on_labels(self, y: LabelType) -> None: - """ - Sets the output dimensionality, the classes, and initializes the head. + """Sets the output dimensionality, the classes, and initializes the head. :param y: The labels. :raises ValueError: If the labels are inconsistent. diff --git a/model2vec/train/similarity.py b/model2vec/train/similarity.py index 227eb4f..d61f55b 100644 --- a/model2vec/train/similarity.py +++ b/model2vec/train/similarity.py @@ -30,6 +30,7 @@ def __init__( weights: torch.Tensor | None = None, freeze: bool = False, normalize: bool = True, + freeze_weights: bool = False, ) -> None: """Initialize a standard similarity model.""" super().__init__( @@ -43,6 +44,7 @@ def __init__( hidden_dim=hidden_dim, n_layers=n_layers, normalize=normalize, + freeze_weights=freeze_weights, ) def fit( @@ -61,8 +63,7 @@ def fit( validation_steps: int | None = None, random_seed: int = _DEFAULT_RANDOM_SEED, ) -> StaticModelForSimilarity: - """ - Fit a model. + """Fit a model. This function creates a Lightning Trainer object and fits the model to the data. We use early stopping. After training, the weights of the best model are loaded back into the model. From fae4b11f16e3d82a16c29fce53b2e3431e310daa Mon Sep 17 00:00:00 2001 From: stephantul Date: Thu, 30 Apr 2026 10:09:38 +0200 Subject: [PATCH 2/3] update lock file --- uv.lock | 91 ++------------------------------------------------------- 1 file changed, 2 insertions(+), 89 deletions(-) diff --git a/uv.lock b/uv.lock index 9c5aa2f..dd3f354 100644 --- a/uv.lock +++ b/uv.lock @@ -209,50 +209,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309", size = 67548, upload-time = "2026-03-19T14:22:23.645Z" }, ] -[[package]] -name = "black" -version = "26.3.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "mypy-extensions" }, - { name = "packaging" }, - { name = "pathspec" }, - { name = "platformdirs" }, - { name = "pytokens" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/e1/c5/61175d618685d42b005847464b8fb4743a67b1b8fdb75e50e5a96c31a27a/black-26.3.1.tar.gz", hash = "sha256:2c50f5063a9641c7eed7795014ba37b0f5fa227f3d408b968936e24bc0566b07", size = 666155, upload-time = "2026-03-12T03:36:03.593Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/32/a8/11170031095655d36ebc6664fe0897866f6023892396900eec0e8fdc4299/black-26.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:86a8b5035fce64f5dcd1b794cf8ec4d31fe458cf6ce3986a30deb434df82a1d2", size = 1866562, upload-time = "2026-03-12T03:39:58.639Z" }, - { url = "https://files.pythonhosted.org/packages/69/ce/9e7548d719c3248c6c2abfd555d11169457cbd584d98d179111338423790/black-26.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5602bdb96d52d2d0672f24f6ffe5218795736dd34807fd0fd55ccd6bf206168b", size = 1703623, upload-time = "2026-03-12T03:40:00.347Z" }, - { url = "https://files.pythonhosted.org/packages/7f/0a/8d17d1a9c06f88d3d030d0b1d4373c1551146e252afe4547ed601c0e697f/black-26.3.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6c54a4a82e291a1fee5137371ab488866b7c86a3305af4026bdd4dc78642e1ac", size = 1768388, upload-time = "2026-03-12T03:40:01.765Z" }, - { url = "https://files.pythonhosted.org/packages/52/79/c1ee726e221c863cde5164f925bacf183dfdf0397d4e3f94889439b947b4/black-26.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:6e131579c243c98f35bce64a7e08e87fb2d610544754675d4a0e73a070a5aa3a", size = 1412969, upload-time = "2026-03-12T03:40:03.252Z" }, - { url = "https://files.pythonhosted.org/packages/73/a5/15c01d613f5756f68ed8f6d4ec0a1e24b82b18889fa71affd3d1f7fad058/black-26.3.1-cp310-cp310-win_arm64.whl", hash = "sha256:5ed0ca58586c8d9a487352a96b15272b7fa55d139fc8496b519e78023a8dab0a", size = 1220345, upload-time = "2026-03-12T03:40:04.892Z" }, - { url = "https://files.pythonhosted.org/packages/17/57/5f11c92861f9c92eb9dddf515530bc2d06db843e44bdcf1c83c1427824bc/black-26.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:28ef38aee69e4b12fda8dba75e21f9b4f979b490c8ac0baa7cb505369ac9e1ff", size = 1851987, upload-time = "2026-03-12T03:40:06.248Z" }, - { url = "https://files.pythonhosted.org/packages/54/aa/340a1463660bf6831f9e39646bf774086dbd8ca7fc3cded9d59bbdf4ad0a/black-26.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf9bf162ed91a26f1adba8efda0b573bc6924ec1408a52cc6f82cb73ec2b142c", size = 1689499, upload-time = "2026-03-12T03:40:07.642Z" }, - { url = "https://files.pythonhosted.org/packages/f3/01/b726c93d717d72733da031d2de10b92c9fa4c8d0c67e8a8a372076579279/black-26.3.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:474c27574d6d7037c1bc875a81d9be0a9a4f9ee95e62800dab3cfaadbf75acd5", size = 1754369, upload-time = "2026-03-12T03:40:09.279Z" }, - { url = "https://files.pythonhosted.org/packages/e3/09/61e91881ca291f150cfc9eb7ba19473c2e59df28859a11a88248b5cbbc4d/black-26.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5e9d0d86df21f2e1677cc4bd090cd0e446278bcbbe49bf3659c308c3e402843e", size = 1413613, upload-time = "2026-03-12T03:40:10.943Z" }, - { url = "https://files.pythonhosted.org/packages/16/73/544f23891b22e7efe4d8f812371ab85b57f6a01b2fc45e3ba2e52ba985b8/black-26.3.1-cp311-cp311-win_arm64.whl", hash = "sha256:9a5e9f45e5d5e1c5b5c29b3bd4265dcc90e8b92cf4534520896ed77f791f4da5", size = 1219719, upload-time = "2026-03-12T03:40:12.597Z" }, - { url = "https://files.pythonhosted.org/packages/dc/f8/da5eae4fc75e78e6dceb60624e1b9662ab00d6b452996046dfa9b8a6025b/black-26.3.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b5e6f89631eb88a7302d416594a32faeee9fb8fb848290da9d0a5f2903519fc1", size = 1895920, upload-time = "2026-03-12T03:40:13.921Z" }, - { url = "https://files.pythonhosted.org/packages/2c/9f/04e6f26534da2e1629b2b48255c264cabf5eedc5141d04516d9d68a24111/black-26.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41cd2012d35b47d589cb8a16faf8a32ef7a336f56356babd9fcf70939ad1897f", size = 1718499, upload-time = "2026-03-12T03:40:15.239Z" }, - { url = "https://files.pythonhosted.org/packages/04/91/a5935b2a63e31b331060c4a9fdb5a6c725840858c599032a6f3aac94055f/black-26.3.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f76ff19ec5297dd8e66eb64deda23631e642c9393ab592826fd4bdc97a4bce7", size = 1794994, upload-time = "2026-03-12T03:40:17.124Z" }, - { url = "https://files.pythonhosted.org/packages/e7/0a/86e462cdd311a3c2a8ece708d22aba17d0b2a0d5348ca34b40cdcbea512e/black-26.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:ddb113db38838eb9f043623ba274cfaf7d51d5b0c22ecb30afe58b1bb8322983", size = 1420867, upload-time = "2026-03-12T03:40:18.83Z" }, - { url = "https://files.pythonhosted.org/packages/5b/e5/22515a19cb7eaee3440325a6b0d95d2c0e88dd180cb011b12ae488e031d1/black-26.3.1-cp312-cp312-win_arm64.whl", hash = "sha256:dfdd51fc3e64ea4f35873d1b3fb25326773d55d2329ff8449139ebaad7357efb", size = 1230124, upload-time = "2026-03-12T03:40:20.425Z" }, - { url = "https://files.pythonhosted.org/packages/f5/77/5728052a3c0450c53d9bb3945c4c46b91baa62b2cafab6801411b6271e45/black-26.3.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:855822d90f884905362f602880ed8b5df1b7e3ee7d0db2502d4388a954cc8c54", size = 1895034, upload-time = "2026-03-12T03:40:21.813Z" }, - { url = "https://files.pythonhosted.org/packages/52/73/7cae55fdfdfbe9d19e9a8d25d145018965fe2079fa908101c3733b0c55a0/black-26.3.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8a33d657f3276328ce00e4d37fe70361e1ec7614da5d7b6e78de5426cb56332f", size = 1718503, upload-time = "2026-03-12T03:40:23.666Z" }, - { url = "https://files.pythonhosted.org/packages/e1/87/af89ad449e8254fdbc74654e6467e3c9381b61472cc532ee350d28cfdafb/black-26.3.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f1cd08e99d2f9317292a311dfe578fd2a24b15dbce97792f9c4d752275c1fa56", size = 1793557, upload-time = "2026-03-12T03:40:25.497Z" }, - { url = "https://files.pythonhosted.org/packages/43/10/d6c06a791d8124b843bf325ab4ac7d2f5b98731dff84d6064eafd687ded1/black-26.3.1-cp313-cp313-win_amd64.whl", hash = "sha256:c7e72339f841b5a237ff14f7d3880ddd0fc7f98a1199e8c4327f9a4f478c1839", size = 1422766, upload-time = "2026-03-12T03:40:27.14Z" }, - { url = "https://files.pythonhosted.org/packages/59/4f/40a582c015f2d841ac24fed6390bd68f0fc896069ff3a886317959c9daf8/black-26.3.1-cp313-cp313-win_arm64.whl", hash = "sha256:afc622538b430aa4c8c853f7f63bc582b3b8030fd8c80b70fb5fa5b834e575c2", size = 1232140, upload-time = "2026-03-12T03:40:28.882Z" }, - { url = "https://files.pythonhosted.org/packages/d5/da/e36e27c9cebc1311b7579210df6f1c86e50f2d7143ae4fcf8a5017dc8809/black-26.3.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:2d6bfaf7fd0993b420bed691f20f9492d53ce9a2bcccea4b797d34e947318a78", size = 1889234, upload-time = "2026-03-12T03:40:30.964Z" }, - { url = "https://files.pythonhosted.org/packages/0e/7b/9871acf393f64a5fa33668c19350ca87177b181f44bb3d0c33b2d534f22c/black-26.3.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:f89f2ab047c76a9c03f78d0d66ca519e389519902fa27e7a91117ef7611c0568", size = 1720522, upload-time = "2026-03-12T03:40:32.346Z" }, - { url = "https://files.pythonhosted.org/packages/03/87/e766c7f2e90c07fb7586cc787c9ae6462b1eedab390191f2b7fc7f6170a9/black-26.3.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b07fc0dab849d24a80a29cfab8d8a19187d1c4685d8a5e6385a5ce323c1f015f", size = 1787824, upload-time = "2026-03-12T03:40:33.636Z" }, - { url = "https://files.pythonhosted.org/packages/ac/94/2424338fb2d1875e9e83eed4c8e9c67f6905ec25afd826a911aea2b02535/black-26.3.1-cp314-cp314-win_amd64.whl", hash = "sha256:0126ae5b7c09957da2bdbd91a9ba1207453feada9e9fe51992848658c6c8e01c", size = 1445855, upload-time = "2026-03-12T03:40:35.442Z" }, - { url = "https://files.pythonhosted.org/packages/86/43/0c3338bd928afb8ee7471f1a4eec3bdbe2245ccb4a646092a222e8669840/black-26.3.1-cp314-cp314-win_arm64.whl", hash = "sha256:92c0ec1f2cc149551a2b7b47efc32c866406b6891b0ee4625e95967c8f4acfb1", size = 1258109, upload-time = "2026-03-12T03:40:36.832Z" }, - { url = "https://files.pythonhosted.org/packages/8e/0d/52d98722666d6fc6c3dd4c76df339501d6efd40e0ff95e6186a7b7f0befd/black-26.3.1-py3-none-any.whl", hash = "sha256:2bd5aa94fc267d38bb21a70d7410a89f1a1d318841855f698746f8e7f51acd1b", size = 207542, upload-time = "2026-03-12T03:36:01.668Z" }, -] - [[package]] name = "certifi" version = "2026.2.25" @@ -1144,16 +1100,13 @@ dependencies = [ { name = "joblib" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "rich" }, { name = "safetensors" }, - { name = "setuptools" }, { name = "tokenizers" }, { name = "tqdm" }, ] [package.optional-dependencies] dev = [ - { name = "black" }, { name = "ipython", version = "8.39.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "ipython", version = "9.10.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, { name = "ipython", version = "9.12.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, @@ -1162,6 +1115,7 @@ dev = [ { name = "pytest" }, { name = "pytest-cov" }, { name = "ruff" }, + { name = "setuptools" }, ] distill = [ { name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -1193,7 +1147,6 @@ train = [ [package.metadata] requires-dist = [ - { name = "black", marker = "extra == 'dev'" }, { name = "ipython", marker = "extra == 'dev'" }, { name = "jinja2" }, { name = "joblib" }, @@ -1204,14 +1157,13 @@ requires-dist = [ { name = "pre-commit", marker = "extra == 'dev'" }, { name = "pytest", marker = "extra == 'dev'" }, { name = "pytest-cov", marker = "extra == 'dev'" }, - { name = "rich" }, { name = "ruff", marker = "extra == 'dev'" }, { name = "safetensors" }, { name = "scikit-learn", marker = "extra == 'distill'" }, { name = "scikit-learn", marker = "extra == 'inference'" }, { name = "scikit-learn", marker = "extra == 'quantization'" }, { name = "scikit-learn", marker = "extra == 'train'" }, - { name = "setuptools" }, + { name = "setuptools", marker = "extra == 'dev'" }, { name = "skeletoken", marker = "extra == 'distill'", specifier = ">=0.3.3" }, { name = "skops", marker = "extra == 'inference'" }, { name = "skops", marker = "extra == 'train'" }, @@ -2212,45 +2164,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/67/0f/019d3949a40280f6193b62bc010177d4ce702d0fce424322286488569cd3/python_discovery-1.2.1-py3-none-any.whl", hash = "sha256:b6a957b24c1cd79252484d3566d1b49527581d46e789aaf43181005e56201502", size = 31674, upload-time = "2026-03-26T22:30:43.396Z" }, ] -[[package]] -name = "pytokens" -version = "0.4.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b6/34/b4e015b99031667a7b960f888889c5bd34ef585c85e1cb56a594b92836ac/pytokens-0.4.1.tar.gz", hash = "sha256:292052fe80923aae2260c073f822ceba21f3872ced9a68bb7953b348e561179a", size = 23015, upload-time = "2026-01-30T01:03:45.924Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/42/24/f206113e05cb8ef51b3850e7ef88f20da6f4bf932190ceb48bd3da103e10/pytokens-0.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a44ed93ea23415c54f3face3b65ef2b844d96aeb3455b8a69b3df6beab6acc5", size = 161522, upload-time = "2026-01-30T01:02:50.393Z" }, - { url = "https://files.pythonhosted.org/packages/d4/e9/06a6bf1b90c2ed81a9c7d2544232fe5d2891d1cd480e8a1809ca354a8eb2/pytokens-0.4.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:add8bf86b71a5d9fb5b89f023a80b791e04fba57960aa790cc6125f7f1d39dfe", size = 246945, upload-time = "2026-01-30T01:02:52.399Z" }, - { url = "https://files.pythonhosted.org/packages/69/66/f6fb1007a4c3d8b682d5d65b7c1fb33257587a5f782647091e3408abe0b8/pytokens-0.4.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:670d286910b531c7b7e3c0b453fd8156f250adb140146d234a82219459b9640c", size = 259525, upload-time = "2026-01-30T01:02:53.737Z" }, - { url = "https://files.pythonhosted.org/packages/04/92/086f89b4d622a18418bac74ab5db7f68cf0c21cf7cc92de6c7b919d76c88/pytokens-0.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4e691d7f5186bd2842c14813f79f8884bb03f5995f0575272009982c5ac6c0f7", size = 262693, upload-time = "2026-01-30T01:02:54.871Z" }, - { url = "https://files.pythonhosted.org/packages/b4/7b/8b31c347cf94a3f900bdde750b2e9131575a61fdb620d3d3c75832262137/pytokens-0.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:27b83ad28825978742beef057bfe406ad6ed524b2d28c252c5de7b4a6dd48fa2", size = 103567, upload-time = "2026-01-30T01:02:56.414Z" }, - { url = "https://files.pythonhosted.org/packages/3d/92/790ebe03f07b57e53b10884c329b9a1a308648fc083a6d4a39a10a28c8fc/pytokens-0.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d70e77c55ae8380c91c0c18dea05951482e263982911fc7410b1ffd1dadd3440", size = 160864, upload-time = "2026-01-30T01:02:57.882Z" }, - { url = "https://files.pythonhosted.org/packages/13/25/a4f555281d975bfdd1eba731450e2fe3a95870274da73fb12c40aeae7625/pytokens-0.4.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4a58d057208cb9075c144950d789511220b07636dd2e4708d5645d24de666bdc", size = 248565, upload-time = "2026-01-30T01:02:59.912Z" }, - { url = "https://files.pythonhosted.org/packages/17/50/bc0394b4ad5b1601be22fa43652173d47e4c9efbf0044c62e9a59b747c56/pytokens-0.4.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b49750419d300e2b5a3813cf229d4e5a4c728dae470bcc89867a9ad6f25a722d", size = 260824, upload-time = "2026-01-30T01:03:01.471Z" }, - { url = "https://files.pythonhosted.org/packages/4e/54/3e04f9d92a4be4fc6c80016bc396b923d2a6933ae94b5f557c939c460ee0/pytokens-0.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d9907d61f15bf7261d7e775bd5d7ee4d2930e04424bab1972591918497623a16", size = 264075, upload-time = "2026-01-30T01:03:04.143Z" }, - { url = "https://files.pythonhosted.org/packages/d1/1b/44b0326cb5470a4375f37988aea5d61b5cc52407143303015ebee94abfd6/pytokens-0.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:ee44d0f85b803321710f9239f335aafe16553b39106384cef8e6de40cb4ef2f6", size = 103323, upload-time = "2026-01-30T01:03:05.412Z" }, - { url = "https://files.pythonhosted.org/packages/41/5d/e44573011401fb82e9d51e97f1290ceb377800fb4eed650b96f4753b499c/pytokens-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:140709331e846b728475786df8aeb27d24f48cbcf7bcd449f8de75cae7a45083", size = 160663, upload-time = "2026-01-30T01:03:06.473Z" }, - { url = "https://files.pythonhosted.org/packages/f0/e6/5bbc3019f8e6f21d09c41f8b8654536117e5e211a85d89212d59cbdab381/pytokens-0.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6d6c4268598f762bc8e91f5dbf2ab2f61f7b95bdc07953b602db879b3c8c18e1", size = 255626, upload-time = "2026-01-30T01:03:08.177Z" }, - { url = "https://files.pythonhosted.org/packages/bf/3c/2d5297d82286f6f3d92770289fd439956b201c0a4fc7e72efb9b2293758e/pytokens-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:24afde1f53d95348b5a0eb19488661147285ca4dd7ed752bbc3e1c6242a304d1", size = 269779, upload-time = "2026-01-30T01:03:09.756Z" }, - { url = "https://files.pythonhosted.org/packages/20/01/7436e9ad693cebda0551203e0bf28f7669976c60ad07d6402098208476de/pytokens-0.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5ad948d085ed6c16413eb5fec6b3e02fa00dc29a2534f088d3302c47eb59adf9", size = 268076, upload-time = "2026-01-30T01:03:10.957Z" }, - { url = "https://files.pythonhosted.org/packages/2e/df/533c82a3c752ba13ae7ef238b7f8cdd272cf1475f03c63ac6cf3fcfb00b6/pytokens-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:3f901fe783e06e48e8cbdc82d631fca8f118333798193e026a50ce1b3757ea68", size = 103552, upload-time = "2026-01-30T01:03:12.066Z" }, - { url = "https://files.pythonhosted.org/packages/cb/dc/08b1a080372afda3cceb4f3c0a7ba2bde9d6a5241f1edb02a22a019ee147/pytokens-0.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8bdb9d0ce90cbf99c525e75a2fa415144fd570a1ba987380190e8b786bc6ef9b", size = 160720, upload-time = "2026-01-30T01:03:13.843Z" }, - { url = "https://files.pythonhosted.org/packages/64/0c/41ea22205da480837a700e395507e6a24425151dfb7ead73343d6e2d7ffe/pytokens-0.4.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5502408cab1cb18e128570f8d598981c68a50d0cbd7c61312a90507cd3a1276f", size = 254204, upload-time = "2026-01-30T01:03:14.886Z" }, - { url = "https://files.pythonhosted.org/packages/e0/d2/afe5c7f8607018beb99971489dbb846508f1b8f351fcefc225fcf4b2adc0/pytokens-0.4.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:29d1d8fb1030af4d231789959f21821ab6325e463f0503a61d204343c9b355d1", size = 268423, upload-time = "2026-01-30T01:03:15.936Z" }, - { url = "https://files.pythonhosted.org/packages/68/d4/00ffdbd370410c04e9591da9220a68dc1693ef7499173eb3e30d06e05ed1/pytokens-0.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:970b08dd6b86058b6dc07efe9e98414f5102974716232d10f32ff39701e841c4", size = 266859, upload-time = "2026-01-30T01:03:17.458Z" }, - { url = "https://files.pythonhosted.org/packages/a7/c9/c3161313b4ca0c601eeefabd3d3b576edaa9afdefd32da97210700e47652/pytokens-0.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:9bd7d7f544d362576be74f9d5901a22f317efc20046efe2034dced238cbbfe78", size = 103520, upload-time = "2026-01-30T01:03:18.652Z" }, - { url = "https://files.pythonhosted.org/packages/8f/a7/b470f672e6fc5fee0a01d9e75005a0e617e162381974213a945fcd274843/pytokens-0.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4a14d5f5fc78ce85e426aa159489e2d5961acf0e47575e08f35584009178e321", size = 160821, upload-time = "2026-01-30T01:03:19.684Z" }, - { url = "https://files.pythonhosted.org/packages/80/98/e83a36fe8d170c911f864bfded690d2542bfcfacb9c649d11a9e6eb9dc41/pytokens-0.4.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97f50fd18543be72da51dd505e2ed20d2228c74e0464e4262e4899797803d7fa", size = 254263, upload-time = "2026-01-30T01:03:20.834Z" }, - { url = "https://files.pythonhosted.org/packages/0f/95/70d7041273890f9f97a24234c00b746e8da86df462620194cef1d411ddeb/pytokens-0.4.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dc74c035f9bfca0255c1af77ddd2d6ae8419012805453e4b0e7513e17904545d", size = 268071, upload-time = "2026-01-30T01:03:21.888Z" }, - { url = "https://files.pythonhosted.org/packages/da/79/76e6d09ae19c99404656d7db9c35dfd20f2086f3eb6ecb496b5b31163bad/pytokens-0.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:f66a6bbe741bd431f6d741e617e0f39ec7257ca1f89089593479347cc4d13324", size = 271716, upload-time = "2026-01-30T01:03:23.633Z" }, - { url = "https://files.pythonhosted.org/packages/79/37/482e55fa1602e0a7ff012661d8c946bafdc05e480ea5a32f4f7e336d4aa9/pytokens-0.4.1-cp314-cp314-win_amd64.whl", hash = "sha256:b35d7e5ad269804f6697727702da3c517bb8a5228afa450ab0fa787732055fc9", size = 104539, upload-time = "2026-01-30T01:03:24.788Z" }, - { url = "https://files.pythonhosted.org/packages/30/e8/20e7db907c23f3d63b0be3b8a4fd1927f6da2395f5bcc7f72242bb963dfe/pytokens-0.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:8fcb9ba3709ff77e77f1c7022ff11d13553f3c30299a9fe246a166903e9091eb", size = 168474, upload-time = "2026-01-30T01:03:26.428Z" }, - { url = "https://files.pythonhosted.org/packages/d6/81/88a95ee9fafdd8f5f3452107748fd04c24930d500b9aba9738f3ade642cc/pytokens-0.4.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:79fc6b8699564e1f9b521582c35435f1bd32dd06822322ec44afdeba666d8cb3", size = 290473, upload-time = "2026-01-30T01:03:27.415Z" }, - { url = "https://files.pythonhosted.org/packages/cf/35/3aa899645e29b6375b4aed9f8d21df219e7c958c4c186b465e42ee0a06bf/pytokens-0.4.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d31b97b3de0f61571a124a00ffe9a81fb9939146c122c11060725bd5aea79975", size = 303485, upload-time = "2026-01-30T01:03:28.558Z" }, - { url = "https://files.pythonhosted.org/packages/52/a0/07907b6ff512674d9b201859f7d212298c44933633c946703a20c25e9d81/pytokens-0.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:967cf6e3fd4adf7de8fc73cd3043754ae79c36475c1c11d514fc72cf5490094a", size = 306698, upload-time = "2026-01-30T01:03:29.653Z" }, - { url = "https://files.pythonhosted.org/packages/39/2a/cbbf9250020a4a8dd53ba83a46c097b69e5eb49dd14e708f496f548c6612/pytokens-0.4.1-cp314-cp314t-win_amd64.whl", hash = "sha256:584c80c24b078eec1e227079d56dc22ff755e0ba8654d8383b2c549107528918", size = 116287, upload-time = "2026-01-30T01:03:30.912Z" }, - { url = "https://files.pythonhosted.org/packages/c6/78/397db326746f0a342855b81216ae1f0a32965deccfd7c830a2dbc66d2483/pytokens-0.4.1-py3-none-any.whl", hash = "sha256:26cef14744a8385f35d0e095dc8b3a7583f6c953c2e3d269c7f82484bf5ad2de", size = 13729, upload-time = "2026-01-30T01:03:45.029Z" }, -] - [[package]] name = "pytorch-lightning" version = "2.6.1" From a3911915626d6f9ec64386cbfd39ea4d09eade7a Mon Sep 17 00:00:00 2001 From: stephantul Date: Thu, 30 Apr 2026 12:14:04 +0200 Subject: [PATCH 3/3] simplify --- model2vec/train/base.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/model2vec/train/base.py b/model2vec/train/base.py index ceefae7..368b4d1 100644 --- a/model2vec/train/base.py +++ b/model2vec/train/base.py @@ -101,25 +101,20 @@ def construct_weights(self) -> nn.Parameter: return nn.Parameter(w, requires_grad=not self.freeze_weights) def construct_head(self) -> nn.Sequential: - """Constructs a simple classifier head.""" - return self.construct_mlp(self.n_layers, self.embed_dim, self.hidden_dim, self.out_dim) - - @staticmethod - def construct_mlp(n_layers: int, embed_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential: """Constructs a simple classifier head.""" modules: list[nn.Module] = [] - if n_layers == 0: - modules.append(nn.Linear(embed_dim, out_dim)) + if self.n_layers == 0: + modules.append(nn.Linear(self.embed_dim, self.out_dim)) else: # If we have a hidden layer, we should first project to hidden_dim modules = [ - nn.Linear(embed_dim, hidden_dim), + nn.Linear(self.embed_dim, self.hidden_dim), nn.ReLU(), ] - for _ in range(n_layers - 1): - modules.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]) + for _ in range(self.n_layers - 1): + modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()]) # We always have a layer mapping from hidden to out. - modules.append(nn.Linear(hidden_dim, out_dim)) + modules.append(nn.Linear(self.hidden_dim, self.out_dim)) linear_modules = [module for module in modules if isinstance(module, nn.Linear)] if linear_modules: @@ -254,11 +249,9 @@ def to_static_model(self) -> StaticModel: """Convert the model to a static model.""" with torch.no_grad(): emb = self.embeddings.weight - emb = emb.detach().cpu().numpy() - if self.w is not None: - w = torch.sigmoid(self.w).detach().cpu().numpy() - else: - w = np.ones(len(emb)) + emb = emb.cpu().numpy() + w = torch.sigmoid(self.w).cpu().numpy() + # If the weights and emb are the same length, the model was not quantized before training. if len(w) == len(emb): emb = emb * w[:, None]