Bug-Fix[Community] Fix FastEmbedEmbeddings (#26764)

#26759 

- Fix https://github.com/langchain-ai/langchain/issues/26759 
- Change `model` param from private to public, which may not be
initiated.
- Add test case
This commit is contained in:
ZhangShenao 2024-10-01 09:23:08 +08:00 committed by GitHub
parent a8e1577f85
commit e317d457cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 5 deletions

View File

@ -65,7 +65,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
Defaults to `None`.
"""
_model: Any = None # : :meta private:
model: Any = None # : :meta private:
model_config = ConfigDict(extra="allow", protected_namespaces=())
@ -91,7 +91,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
)
values["_model"] = fastembed.TextEmbedding(
values["model"] = fastembed.TextEmbedding(
model_name=model_name,
max_length=max_length,
cache_dir=cache_dir,
@ -110,11 +110,11 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
"""
embeddings: List[np.ndarray]
if self.doc_embed_type == "passage":
embeddings = self._model.passage_embed(
embeddings = self.model.passage_embed(
texts, batch_size=self.batch_size, parallel=self.parallel
)
else:
embeddings = self._model.embed(
embeddings = self.model.embed(
texts, batch_size=self.batch_size, parallel=self.parallel
)
return [e.tolist() for e in embeddings]
@ -129,7 +129,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
Embeddings for the text.
"""
query_embeddings: np.ndarray = next(
self._model.query_embed(
self.model.query_embed(
text, batch_size=self.batch_size, parallel=self.parallel
)
)

View File

@ -80,3 +80,11 @@ async def test_fastembed_async_embedding_query(
embedding = FastEmbedEmbeddings(model_name=model_name, max_length=max_length) # type: ignore[call-arg]
output = await embedding.aembed_query(document)
assert len(output) == 384
def test_fastembed_embedding_query_with_default_params() -> None:
"""Test fastembed embeddings for query with default model params"""
document = "foo bar"
embedding = FastEmbedEmbeddings() # type: ignore[call-arg]
output = embedding.embed_query(document)
assert len(output) == 384