41 changes: 21 additions & 20 deletions .pre-commit-config.yaml
@@ -1,36 +1,37 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
minimum_pre_commit_version: "3.2.0"

default_language_version:
python: python3

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v6.0.0
hooks:
- id: no-commit-to-branch
args: [--branch, main]
- id: check-ast
description: Simply check whether files parse as valid python.
- id: check-yaml
- id: check-toml
- id: check-json
- id: check-merge-conflict
- id: debug-statements
- id: trailing-whitespace
description: Trims trailing whitespace
- id: end-of-file-fixer
description: Makes sure files end in a newline and only a newline.
- id: check-added-large-files
args: ['--maxkb=5000']
description: Prevent giant files from being committed.
args: ["--maxkb=5000"]
- id: check-case-conflict
description: Check for files with names that would conflict on case-insensitive filesystems like macOS/Windows.
- id: check-yaml
description: Check yaml files for syntax errors.
- repo: https://github.com/jsh9/pydoclint
rev: 0.5.3
rev: 0.8.3
hooks:
- id: pydoclint
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.13.0
rev: v0.15.12
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-check
args: [--fix]
- id: ruff-format
- repo: local
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
hooks:
- id: mypy
name: mypy
entry: mypy
language: python
types: [python]
additional_dependencies: []
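The `check-added-large-files` hook above rejects files over the `--maxkb` limit. A rough, stdlib-only sketch of the check it performs (the real hook only inspects files staged with git, which this sketch skips):

```python
import os

def find_large_files(paths, max_kb=5000):
    """Return (path, size_kb) pairs for files larger than max_kb kilobytes.

    Simplified sketch of pre-commit's check-added-large-files hook;
    the real hook restricts itself to files staged for commit.
    """
    too_large = []
    for path in paths:
        size_kb = os.path.getsize(path) / 1024
        if size_kb > max_kb:
            too_large.append((path, size_kb))
    return too_large
```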
10 changes: 6 additions & 4 deletions Makefile
@@ -1,5 +1,4 @@
clean:

VERBOSITY=

venv:
uv venv
@@ -9,7 +8,7 @@ install:
uv run pre-commit install

install-no-pre-commit:
uv pip install ".[dev,distill,inference,train]"
uv pip install ".[dev,distill,inference,train,onnx,quantization]"

install-base:
uv sync --extra dev
@@ -18,4 +17,7 @@ fix:
uv run pre-commit run --all-files

test:
uv run pytest --cov=model2vec --cov-report=term-missing
uv run pytest --cov=model2vec --cov-report=term-missing $(VERBOSITY)

test-verbose:
make test VERBOSITY="-vvv"
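The new `VERBOSITY` variable is simply appended to the pytest invocation, so `make test-verbose` runs the same command as `make test` with `-vvv` added. A hypothetical sketch of that expansion in Python:

```python
import shlex

def build_pytest_command(verbosity: str = "") -> list[str]:
    """Mirror the Makefile's test target: base command plus optional VERBOSITY flags."""
    base = "uv run pytest --cov=model2vec --cov-report=term-missing"
    if verbosity:
        base += f" {verbosity}"
    return shlex.split(base)
```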
13 changes: 5 additions & 8 deletions model2vec/distill/distillation.py
@@ -33,8 +33,7 @@ def distill_from_model(
vocabulary_quantization: int | None = None,
pooling: PoolingMode | str = PoolingMode.MEAN,
) -> StaticModel:
"""
Distill a staticmodel from a sentence transformer.
"""Distill a staticmodel from a sentence transformer.

This function creates a set of embeddings from a sentence transformer. It does this by doing either
a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed
@@ -65,7 +64,7 @@ def distill_from_model(
'first': use the first token's hidden state ([CLS] token in BERT-style models).
'pooler': use the pooler output (if available). This is often a non-linear projection of the [CLS] token.
:return: A StaticModel.
:raises: ValueError if the vocabulary is empty after preprocessing.
:raises ValueError: if the vocabulary is empty after preprocessing.

"""
quantize_to = DType(quantize_to)
@@ -168,15 +167,14 @@ def _validate_parameters(
sif_coefficient: float | None,
token_remove_pattern: str | None,
) -> tuple[float | None, re.Pattern[str] | None]:
"""
Validate the parameters passed to the distillation function.
"""Validate the parameters passed to the distillation function.

:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to
this regex pattern will be removed from the vocabulary.
:return: The SIF coefficient to use.
:raises: ValueError if the regex can't be compiled.
:raises ValueError: if the regex can't be compiled.

"""
if sif_coefficient is not None:
@@ -205,8 +203,7 @@ def distill(
vocabulary_quantization: int | None = None,
pooling: PoolingMode | str = PoolingMode.MEAN,
) -> StaticModel:
"""
Distill a staticmodel from a sentence transformer.
"""Distill a staticmodel from a sentence transformer.

This function creates a set of embeddings from a sentence transformer. It does this by doing either
a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed
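The `:raises ValueError:` fix above follows Sphinx field-list syntax. The validation it documents can be sketched roughly like this (hypothetical name; the real `_validate_parameters` lives in `model2vec/distill/distillation.py` and returns a compiled pattern alongside the coefficient):

```python
import re

def validate_distill_params(sif_coefficient, token_remove_pattern):
    """Rough sketch of the checks described in the docstring above."""
    # The SIF coefficient must satisfy 0 <= value < 1.0 (1e-4 is a good default).
    if sif_coefficient is not None and not (0 <= sif_coefficient < 1.0):
        raise ValueError("sif_coefficient must be >= 0 and < 1.0")
    compiled = None
    if token_remove_pattern is not None:
        try:
            # Tokens matching this pattern are removed from the vocabulary.
            compiled = re.compile(token_remove_pattern)
        except re.error as exc:
            raise ValueError(f"Invalid token_remove_pattern: {token_remove_pattern!r}") from exc
    return sif_coefficient, compiled
```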
21 changes: 7 additions & 14 deletions model2vec/distill/inference.py
@@ -23,8 +23,7 @@


class PoolingMode(str, Enum):
"""
Pooling modes for embedding creation.
"""Pooling modes for embedding creation.

- MEAN: masked mean over all tokens.
- LAST: last non-padding token (often EOS, common in decoder-style models).
@@ -48,8 +47,7 @@ def create_embeddings(
pad_token_id: int,
pooling: PoolingMode | str = PoolingMode.MEAN,
) -> np.ndarray:
"""
Create output embeddings for a bunch of tokens using a pretrained model.
"""Create output embeddings for a bunch of tokens using a pretrained model.

It does a forward pass for all tokens passed in `tokens`.

@@ -121,8 +119,7 @@ def create_embeddings(
def _encode_with_model(
model: PreTrainedModel, encodings: dict[str, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor]]:
"""
Move inputs to the model device, run a forward pass, and standardize dtypes.
"""Move inputs to the model device, run a forward pass, and standardize dtypes.

:param model: The model to use.
:param encodings: The encoded tokens to turn into features.
@@ -146,8 +143,7 @@

@torch.inference_mode()
def _encode_mean_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Encode a batch of tokens using mean pooling.
"""Encode a batch of tokens using mean pooling.

:param model: The model to use.
:param encodings: The encoded tokens to turn into features.
@@ -163,8 +159,7 @@ def _encode_mean_with_model(model: PreTrainedModel, encodings: dict[str, torch.T

@torch.inference_mode()
def _encode_last_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Encode a batch of tokens using last token pooling.
"""Encode a batch of tokens using last token pooling.

:param model: The model to use.
:param encodings: The encoded tokens to turn into features.
@@ -179,8 +174,7 @@ def _encode_last_with_model(model: PreTrainedModel, encodings: dict[str, torch.T

@torch.inference_mode()
def _encode_first_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Encode a batch of tokens using first token (CLS) pooling.
"""Encode a batch of tokens using first token (CLS) pooling.

:param model: The model to use.
:param encodings: The encoded tokens to turn into features.
@@ -192,8 +186,7 @@ def _encode_first_with_model(model: PreTrainedModel, encodings: dict[str, torch.

@torch.inference_mode()
def _encode_pooler_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Encode a batch of tokens using pooler output.
"""Encode a batch of tokens using pooler output.

:param model: The model to use.
:param encodings: The encoded tokens to turn into features.
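The pooling modes in `PoolingMode` can be illustrated with a dependency-free sketch over nested lists (the real `_encode_*_with_model` helpers operate on masked torch tensors; POOLER needs the model's pooler output, so it is omitted here):

```python
def pool(hidden, mask, mode="mean"):
    """Toy pooling over one sequence.

    hidden: list of per-token vectors; mask: 1 for real tokens, 0 for padding.
    Sketches the MEAN / LAST / FIRST modes from PoolingMode.
    """
    # Keep only non-padding token vectors.
    real = [vec for vec, m in zip(hidden, mask) if m]
    if mode == "mean":
        # Masked mean over all real tokens.
        n = len(real)
        return [sum(dims) / n for dims in zip(*real)]
    if mode == "last":
        # Last non-padding token (often EOS in decoder-style models).
        return real[-1]
    if mode == "first":
        # First token ([CLS] in BERT-style models).
        return hidden[0]
    raise ValueError(f"Unknown pooling mode: {mode}")
```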
3 changes: 1 addition & 2 deletions model2vec/distill/utils.py
@@ -9,8 +9,7 @@


def select_optimal_device(device: str | None) -> str:
"""
Get the optimal device to use based on backend availability.
"""Get the optimal device to use based on backend availability.

For Torch versions >= 2.8.0, MPS is disabled due to known performance regressions.

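`select_optimal_device` now skips MPS on Torch >= 2.8.0 because of the performance regressions noted in its docstring. A torch-free sketch of that gating, with hypothetical parameters standing in for the real availability checks (`torch.cuda.is_available()` and friends):

```python
def select_device(device, torch_version, cuda_available, mps_available):
    """Pick an explicit device if given; otherwise prefer CUDA, then MPS
    (only below Torch 2.8.0, due to known performance regressions), then CPU."""
    if device is not None:
        return device
    if cuda_available:
        return "cuda"
    # Compare only the major.minor components of the version string.
    major_minor = tuple(int(p) for p in torch_version.split(".")[:2])
    if mps_available and major_minor < (2, 8):
        return "mps"
    return "cpu"
```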
24 changes: 8 additions & 16 deletions model2vec/inference/model.py
@@ -52,8 +52,7 @@ def __init__(self, model: StaticModel, head: Pipeline) -> None:
def from_pretrained(
cls: type[StaticModelPipeline], path: PathLike, token: str | None = None, trust_remote_code: bool = False
) -> StaticModelPipeline:
"""
Load a StaticModel from a local path or huggingface hub path.
"""Load a StaticModel from a local path or huggingface hub path.

NOTE: if you load a private model from the huggingface hub, you need to pass a token.

@@ -74,8 +73,7 @@ def save_pretrained(self, path: str) -> None:
def push_to_hub(
self, repo_id: str, subfolder: str | None = None, token: str | None = None, private: bool = False
) -> None:
"""
Save a model to a folder, and then push that folder to the hf hub.
"""Save a model to a folder, and then push that folder to the hf hub.

:param repo_id: The id of the repository to push to.
:param subfolder: The subfolder to push to.
@@ -122,8 +120,7 @@ def predict(
multiprocessing_threshold: int = 10_000,
threshold: float = 0.5,
) -> np.ndarray:
"""
Predict the labels of the input.
"""Predict the labels of the input.

:param X: The input data to predict. Can be a list of strings or a single string.
:param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
@@ -162,8 +159,7 @@ def predict_proba(
use_multiprocessing: bool = True,
multiprocessing_threshold: int = 10_000,
) -> np.ndarray:
"""
Predict the labels of the input.
"""Predict the labels of the input.

:param X: The input data to predict. Can be a list of strings or a single string.
:param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
@@ -190,8 +186,7 @@ def evaluate(
def evaluate(
self, X: Sequence[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
) -> str | dict[str, dict[str, float]]:
"""
Evaluate the classifier on a given dataset using scikit-learn's classification report.
"""Evaluate the classifier on a given dataset using scikit-learn's classification report.

:param X: The texts to predict on.
:param y: The ground truth labels.
@@ -212,8 +207,7 @@ def _load_pipeline(
def _load_pipeline(
folder_or_repo_path: PathLike, token: str | None = None, trust_remote_code: bool = False
) -> tuple[StaticModel, Pipeline]:
"""
Load a model and an sklearn pipeline.
"""Load a model and an sklearn pipeline.

This assumes the following files are present in the repo:
- `pipeline.skops`: The head of the pipeline.
@@ -259,8 +253,7 @@ def _load_pipeline(


def save_pipeline(pipeline: StaticModelPipeline, folder_path: str | Path) -> None:
"""
Save a pipeline to a folder.
"""Save a pipeline to a folder.

:param pipeline: The pipeline to save.
:param folder_path: The path to the folder to save the pipeline to.
@@ -296,8 +289,7 @@ def evaluate_single_or_multi_label(
y: list[int] | list[str] | list[list[int]] | list[list[str]],
output_dict: bool = False,
) -> str | dict[str, dict[str, float]]:
"""
Evaluate the classifier on a given dataset using scikit-learn's classification report.
"""Evaluate the classifier on a given dataset using scikit-learn's classification report.

:param predictions: The predictions.
:param y: The ground truth labels.
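`predict` takes a `threshold` parameter that only matters for multilabel heads: every class whose probability meets the threshold is kept per sample. A hedged sketch of that thresholding step (hypothetical helper, not the pipeline's actual code):

```python
def threshold_multilabel(proba, classes, threshold=0.5):
    """Turn per-class probabilities into label lists: for each sample,
    keep every class whose probability is >= threshold."""
    return [
        [cls for cls, p in zip(classes, row) if p >= threshold]
        for row in proba
    ]
```

Note that a sample can end up with no labels at all if nothing clears the threshold, which is why multilabel evaluation reports per-class metrics rather than plain accuracy.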