diff --git a/libs/community/langchain_community/embeddings/fastembed.py b/libs/community/langchain_community/embeddings/fastembed.py index 288061b67cb..28105dd7208 100644 --- a/libs/community/langchain_community/embeddings/fastembed.py +++ b/libs/community/langchain_community/embeddings/fastembed.py @@ -65,7 +65,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings): Defaults to `None`. """ - _model: Any = None # : :meta private: + model: Any = None # : :meta private: model_config = ConfigDict(extra="allow", protected_namespaces=()) @@ -91,7 +91,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings): 'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.' ) - values["_model"] = fastembed.TextEmbedding( + values["model"] = fastembed.TextEmbedding( model_name=model_name, max_length=max_length, cache_dir=cache_dir, @@ -110,11 +110,11 @@ class FastEmbedEmbeddings(BaseModel, Embeddings): """ embeddings: List[np.ndarray] if self.doc_embed_type == "passage": - embeddings = self._model.passage_embed( + embeddings = self.model.passage_embed( texts, batch_size=self.batch_size, parallel=self.parallel ) else: - embeddings = self._model.embed( + embeddings = self.model.embed( texts, batch_size=self.batch_size, parallel=self.parallel ) return [e.tolist() for e in embeddings] @@ -129,7 +129,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings): Embeddings for the text. """ query_embeddings: np.ndarray = next( - self._model.query_embed( + self.model.query_embed( text, batch_size=self.batch_size, parallel=self.parallel ) ) diff --git a/libs/community/tests/integration_tests/embeddings/test_fastembed.py b/libs/community/tests/integration_tests/embeddings/test_fastembed.py index 09cede659bf..f39c4cedfb4 100644 --- a/libs/community/tests/integration_tests/embeddings/test_fastembed.py +++ b/libs/community/tests/integration_tests/embeddings/test_fastembed.py @@ -80,3 +80,11 @@ async def test_fastembed_async_embedding_query( embedding = FastEmbedEmbeddings(model_name=model_name, max_length=max_length) # type: ignore[call-arg] output = await embedding.aembed_query(document) assert len(output) == 384 + + +def test_fastembed_embedding_query_with_default_params() -> None: + """Test fastembed embeddings for query with default model params""" + document = "foo bar" + embedding = FastEmbedEmbeddings() # type: ignore[call-arg] + output = embedding.embed_query(document) + assert len(output) == 384