mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
Bug-Fix[Community] Fix FastEmbedEmbeddings
(#26764)
#26759 - Fix https://github.com/langchain-ai/langchain/issues/26759 - Change `model` param from private to public, which may not be initiated. - Add test case
This commit is contained in:
parent
a8e1577f85
commit
e317d457cf
@ -65,7 +65,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
|||||||
Defaults to `None`.
|
Defaults to `None`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_model: Any = None # : :meta private:
|
model: Any = None # : :meta private:
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow", protected_namespaces=())
|
model_config = ConfigDict(extra="allow", protected_namespaces=())
|
||||||
|
|
||||||
@ -91,7 +91,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
|||||||
'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
|
'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
|
||||||
)
|
)
|
||||||
|
|
||||||
values["_model"] = fastembed.TextEmbedding(
|
values["model"] = fastembed.TextEmbedding(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
@ -110,11 +110,11 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
|||||||
"""
|
"""
|
||||||
embeddings: List[np.ndarray]
|
embeddings: List[np.ndarray]
|
||||||
if self.doc_embed_type == "passage":
|
if self.doc_embed_type == "passage":
|
||||||
embeddings = self._model.passage_embed(
|
embeddings = self.model.passage_embed(
|
||||||
texts, batch_size=self.batch_size, parallel=self.parallel
|
texts, batch_size=self.batch_size, parallel=self.parallel
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
embeddings = self._model.embed(
|
embeddings = self.model.embed(
|
||||||
texts, batch_size=self.batch_size, parallel=self.parallel
|
texts, batch_size=self.batch_size, parallel=self.parallel
|
||||||
)
|
)
|
||||||
return [e.tolist() for e in embeddings]
|
return [e.tolist() for e in embeddings]
|
||||||
@ -129,7 +129,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
|||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
query_embeddings: np.ndarray = next(
|
query_embeddings: np.ndarray = next(
|
||||||
self._model.query_embed(
|
self.model.query_embed(
|
||||||
text, batch_size=self.batch_size, parallel=self.parallel
|
text, batch_size=self.batch_size, parallel=self.parallel
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -80,3 +80,11 @@ async def test_fastembed_async_embedding_query(
|
|||||||
embedding = FastEmbedEmbeddings(model_name=model_name, max_length=max_length) # type: ignore[call-arg]
|
embedding = FastEmbedEmbeddings(model_name=model_name, max_length=max_length) # type: ignore[call-arg]
|
||||||
output = await embedding.aembed_query(document)
|
output = await embedding.aembed_query(document)
|
||||||
assert len(output) == 384
|
assert len(output) == 384
|
||||||
|
|
||||||
|
|
||||||
|
def test_fastembed_embedding_query_with_default_params() -> None:
|
||||||
|
"""Test fastembed embeddings for query with default model params"""
|
||||||
|
document = "foo bar"
|
||||||
|
embedding = FastEmbedEmbeddings() # type: ignore[call-arg]
|
||||||
|
output = embedding.embed_query(document)
|
||||||
|
assert len(output) == 384
|
||||||
|
Loading…
Reference in New Issue
Block a user