Merged
Changes from 1 commit
Commits (26)
2506f4f  Added backend class for SparseEncoder and also SentenceTransformersSp…  (Ryzhtus, Jul 3, 2025)
abd7ea5  Added SentenceTransformersSparseDocumentEmbedder  (Ryzhtus, Jul 3, 2025)
82b87c2  Created a separate _SentenceTransformersSparseEmbeddingBackendFactory…  (Ryzhtus, Aug 3, 2025)
73eaa97  Remove unused parameter  (Ryzhtus, Aug 3, 2025)
74c222e  Wrapped output into SparseEmbedding dataclass + fix tests  (Ryzhtus, Aug 3, 2025)
4ddde78  Return correct SparseEmbedding, imports and tests  (Ryzhtus, Aug 22, 2025)
3ed6005  Merge branch 'main' into feat/support_sparse_models_in_sentence_trans…  (Ryzhtus, Aug 22, 2025)
341767b  Merge branch 'main' into st-sparse  (anakin87, Aug 28, 2025)
71950af  fix fmt  (anakin87, Aug 28, 2025)
3c08b33  Merge branch 'deepset-ai:main' into feat/support_sparse_models_in_sen…  (Ryzhtus, Sep 6, 2025)
a469c8f  Style changes and fixes  (Ryzhtus, Sep 6, 2025)
be29552  Added a test for embed function  (Ryzhtus, Sep 14, 2025)
f7536f9  Added integration test and fixed some other tests  (Ryzhtus, Sep 14, 2025)
69dfa63  Merge branch 'main' into feat/support_sparse_models_in_sentence_trans…  (Ryzhtus, Sep 14, 2025)
90dd503  Add lint fixes  (Ryzhtus, Sep 14, 2025)
555e897  Merge branch 'feat/support_sparse_models_in_sentence_transformers' of…  (Ryzhtus, Sep 14, 2025)
21313ce  Fixed positional arguments  (Ryzhtus, Sep 14, 2025)
832b76b  Merge branch 'main' into st-sparse  (anakin87, Sep 18, 2025)
60e2805  fix types, simplify and more  (anakin87, Sep 18, 2025)
3620d09  fix  (anakin87, Sep 18, 2025)
d95b0e9  token fixes  (anakin87, Sep 19, 2025)
527e24b  pydocs, small model in test, cache improvement  (anakin87, Sep 19, 2025)
06e8d8b  try 3.9 for docs  (anakin87, Sep 19, 2025)
044652d  better to pin click  (anakin87, Sep 19, 2025)
cfc7dda  release note  (anakin87, Sep 19, 2025)
4e7850c  small fix  (anakin87, Sep 19, 2025)
Return correct SparseEmbedding, imports and tests
Ryzhtus committed Aug 22, 2025
commit 4ddde7844e9ed93082d9df92e0101bded6f29262
8 changes: 8 additions & 0 deletions haystack/components/embedders/__init__.py
@@ -16,6 +16,8 @@
     "openai_text_embedder": ["OpenAITextEmbedder"],
     "sentence_transformers_document_embedder": ["SentenceTransformersDocumentEmbedder"],
     "sentence_transformers_text_embedder": ["SentenceTransformersTextEmbedder"],
+    "sentence_transformers_sparse_document_embedder": ["SentenceTransformersSparseDocumentEmbedder"],
+    "sentence_transformers_sparse_text_embedder": ["SentenceTransformersSparseTextEmbedder"],
 }

 if TYPE_CHECKING:
@@ -28,6 +30,12 @@
     from .sentence_transformers_document_embedder import (
         SentenceTransformersDocumentEmbedder as SentenceTransformersDocumentEmbedder,
     )
+    from .sentence_transformers_sparse_document_embedder import (
+        SentenceTransformersSparseDocumentEmbedder as SentenceTransformersSparseDocumentEmbedder,
+    )
+    from .sentence_transformers_sparse_text_embedder import (
+        SentenceTransformersSparseTextEmbedder as SentenceTransformersSparseTextEmbedder,
+    )
     from .sentence_transformers_text_embedder import (
         SentenceTransformersTextEmbedder as SentenceTransformersTextEmbedder,
     )
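Note on the registry above: in Haystack, dicts of this shape typically drive lazy imports, so importing haystack.components.embedders does not pull in sentence-transformers until one of the listed classes is actually accessed, while the TYPE_CHECKING block keeps static type checkers aware of the names. A minimal sketch of such a mechanism via a module-level __getattr__ (PEP 562); this is illustrative, not necessarily Haystack's exact implementation:

# Minimal lazy-import sketch (PEP 562); the registry shape mirrors the hunk
# above, but the mechanism here is an assumption, not Haystack's actual code.
import importlib
from typing import TYPE_CHECKING

_lazy_imports = {
    "SentenceTransformersSparseDocumentEmbedder": ".sentence_transformers_sparse_document_embedder",
    "SentenceTransformersSparseTextEmbedder": ".sentence_transformers_sparse_text_embedder",
}

if TYPE_CHECKING:
    from .sentence_transformers_sparse_document_embedder import SentenceTransformersSparseDocumentEmbedder
    from .sentence_transformers_sparse_text_embedder import SentenceTransformersSparseTextEmbedder


def __getattr__(name: str):
    # Resolve the attribute on first access, then cache it on the module.
    if name in _lazy_imports:
        module = importlib.import_module(_lazy_imports[name], package=__name__)
        attr = getattr(module, name)
        globals()[name] = attr
        return attr
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")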
haystack/components/embedders/backends/sentence_transformers_backend.py
@@ -164,18 +164,17 @@ def __init__( # pylint: disable=too-many-positional-arguments
         )

     def embed(self, data: List[str], **kwargs) -> List[SparseEmbedding]:
-        embeddings = self.model.encode(data, **kwargs)
-
-        sparse_embeddings = []
-
-        if isinstance(embeddings, list):
-            for embedding in embeddings:
-                sparse_embeddings.append(
-                    SparseEmbedding(indices=embedding.indices.tolist(), values=embedding.values.tolist())
-                )
-        else:
-            sparse_embeddings.append(
-                SparseEmbedding(indices=embeddings.indices.tolist(), values=embeddings.values.tolist())
-            )
+        embeddings = self.model.encode(data, **kwargs).coalesce()
+
+        rows, columns = embeddings.indices()
+        values = embeddings.values()
+        batch_size = embeddings.size(0)
+
+        sparse_embeddings: List[SparseEmbedding] = []
+        for embedding in range(batch_size):
+            mask = rows == embedding
+            embedding_columns = columns[mask].tolist()
+            embedding_values = values[mask].tolist()
+            sparse_embeddings.append(SparseEmbedding(indices=embedding_columns, values=embedding_values))

         return sparse_embeddings
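The rewritten embed() leans on PyTorch sparse COO tensor semantics: coalesce() merges duplicate coordinates and sorts indices, indices() then returns a 2 x nnz tensor whose first row holds batch (document) positions and whose second row holds vocabulary positions, and a boolean mask on the first row slices out one document's entries. A standalone sketch of that slicing, reusing the same toy tensor the mocked backend test below constructs:

# Standalone sketch of the per-document slicing performed in embed() above.
# The (2, 5) tensor mirrors the mock in the backend test further down.
import torch

# Two "documents" over a 5-term vocabulary:
# doc 0 has weight 0.5 on term 1; doc 1 has weight 0.7 on term 3.
indices = torch.tensor([[0, 1], [1, 3]])  # row 0: document ids, row 1: term ids
values = torch.tensor([0.5, 0.7])
batch = torch.sparse_coo_tensor(indices, values, (2, 5)).coalesce()

rows, columns = batch.indices()  # both of shape (nnz,)
weights = batch.values()

for doc in range(batch.size(0)):
    mask = rows == doc
    print(doc, columns[mask].tolist(), weights[mask].tolist())
# 0 [1] [0.5]
# 1 [3] [0.699999988079071]  (0.7 rounded through float32)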
@@ -186,7 +186,6 @@ def warm_up(self):
             tokenizer_kwargs=self.tokenizer_kwargs,
             config_kwargs=self.config_kwargs,
             backend=self.backend,
-            sparse=True,
         )
         if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
             self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
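The retained lines show that a caller-supplied model_max_length flows through warm_up() into the underlying model's max_seq_length. A hedged usage sketch: the constructor arguments are assumptions modeled on Haystack's dense SentenceTransformers embedders and the tests below, and the model name is only an illustrative choice of a SPLADE-style sparse checkpoint:

# Hedged sketch: tokenizer_kwargs propagates into max_seq_length at warm-up.
# Constructor signature and model choice are assumptions, not confirmed here.
from haystack.components.embedders import SentenceTransformersSparseTextEmbedder

embedder = SentenceTransformersSparseTextEmbedder(
    model="naver/splade-cocondenser-ensembledistil",  # illustrative sparse model
    tokenizer_kwargs={"model_max_length": 256},
)
embedder.warm_up()  # after this, the backend model's max_seq_length is 256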
@@ -5,6 +5,7 @@
 from unittest.mock import patch

 import pytest
+import torch

 from haystack.components.embedders.backends.sentence_transformers_backend import (
     _SentenceTransformersEmbeddingBackendFactory,
@@ -101,6 +102,10 @@ def test_embedding_function_with_kwargs(mock_sentence_transformer):

 @patch("haystack.components.embedders.backends.sentence_transformers_backend.SparseEncoder")
 def test_sparse_embedding_function_with_kwargs(mock_sparse_encoder):
+    indices = torch.tensor([[0, 1], [1, 3]])
+    values = torch.tensor([0.5, 0.7])
+    mock_sparse_encoder.return_value.encode.return_value = torch.sparse_coo_tensor(indices, values, (2, 5))
+
     embedding_backend = _SentenceTransformersSparseEmbeddingBackendFactory.get_embedding_backend(model="model")

     data = ["sentence1", "sentence2"]
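Given that mocked (2, 5) tensor, embed() should split the batch row by row. The rest of the test body is collapsed in this view; a sketch of the assertions one would expect it to make, continuing from the snippet above (illustrative, not the collapsed code):

# Expected decomposition of the mocked sparse tensor; pytest.approx is used
# because 0.7 is not exactly representable in float32.
embeddings = embedding_backend.embed(data=data)

assert embeddings[0].indices == [1]
assert embeddings[0].values == pytest.approx([0.5])
assert embeddings[1].indices == [3]
assert embeddings[1].values == pytest.approx([0.7])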
@@ -14,6 +14,11 @@
 )
 from haystack.utils import ComponentDevice, Secret

+from haystack.components.embedders.backends.sentence_transformers_backend import (
+    _SentenceTransformersSparseEmbeddingBackendFactory,
+    _SentenceTransformersSparseEncoderEmbeddingBackend,
+)
+

 class TestSentenceTransformersDocumentEmbedder:
     def test_init_default(self):
@@ -210,7 +215,7 @@ def test_from_dict_none_device(self):
         assert component.meta_fields_to_embed == ["meta_field"]

     @patch(
-        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
+        "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory"
     )
     def test_warmup(self, mocked_factory):
         embedder = SentenceTransformersSparseDocumentEmbedder(
@@ -236,7 +241,7 @@ def test_warmup(self, mocked_factory):
     )

     @patch(
-        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
+        "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory"
     )
     def test_warmup_doesnt_reload(self, mocked_factory):
         embedder = SentenceTransformersSparseDocumentEmbedder(model="model")
@@ -322,7 +327,7 @@ def test_prefix_suffix(self):
         )

     @patch(
-        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
+        "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory"
     )
     def test_model_onnx_backend(self, mocked_factory):
         onnx_embedder = SentenceTransformersSparseDocumentEmbedder(
@@ -349,7 +354,7 @@ def test_model_onnx_backend(self, mocked_factory):
         )

     @patch(
-        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
+        "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory"
     )
     def test_model_openvino_backend(self, mocked_factory):
         openvino_embedder = SentenceTransformersSparseDocumentEmbedder(
@@ -376,7 +381,7 @@ def test_model_openvino_backend(self, mocked_factory):
         )

     @patch(
-        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
+        "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory"
     )
     @pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}])
     def test_dtype_on_gpu(self, mocked_factory, model_kwargs):