diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index 9b9deba027c..f39787b3556 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -57,7 +57,10 @@ from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings from langchain_community.embeddings.llm_rails import LLMRailsEmbeddings from langchain_community.embeddings.localai import LocalAIEmbeddings from langchain_community.embeddings.minimax import MiniMaxEmbeddings -from langchain_community.embeddings.mlflow import MlflowEmbeddings +from langchain_community.embeddings.mlflow import ( + MlflowCohereEmbeddings, + MlflowEmbeddings, +) from langchain_community.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings from langchain_community.embeddings.modelscope_hub import ModelScopeEmbeddings from langchain_community.embeddings.mosaicml import MosaicMLInstructorEmbeddings @@ -102,6 +105,7 @@ __all__ = [ "LLMRailsEmbeddings", "HuggingFaceHubEmbeddings", "MlflowEmbeddings", + "MlflowCohereEmbeddings", "MlflowAIGatewayEmbeddings", "ModelScopeEmbeddings", "TensorflowHubEmbeddings", diff --git a/libs/community/langchain_community/embeddings/mlflow.py b/libs/community/langchain_community/embeddings/mlflow.py index 0ae46bcffdb..6b24dacb025 100644 --- a/libs/community/langchain_community/embeddings/mlflow.py +++ b/libs/community/langchain_community/embeddings/mlflow.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Iterator, List +from typing import Any, Dict, Iterator, List from urllib.parse import urlparse from langchain_core.embeddings import Embeddings @@ -34,6 +34,10 @@ class MlflowEmbeddings(Embeddings, BaseModel): target_uri: str """The target URI to use.""" _client: Any = PrivateAttr() + """The parameters to use for queries.""" + query_params: Dict[str, str] = {} + """The parameters to use for documents.""" + documents_params: Dict[str, str] = {} def __init__(self, **kwargs: Any): super().__init__(**kwargs) @@ -63,12 +67,22 @@ class MlflowEmbeddings(Embeddings, BaseModel): f"The scheme must be one of {allowed}." ) - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed(self, texts: List[str], params: Dict[str, str]) -> List[List[float]]: embeddings: List[List[float]] = [] for txt in _chunk(texts, 20): - resp = self._client.predict(endpoint=self.endpoint, inputs={"input": txt}) + resp = self._client.predict( + endpoint=self.endpoint, inputs={"input": txt, **params} + ) embeddings.extend(r["embedding"] for r in resp["data"]) return embeddings + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return self.embed(texts, params=self.documents_params) + def embed_query(self, text: str) -> List[float]: - return self.embed_documents([text])[0] + return self.embed([text], params=self.query_params)[0] + + +class MlflowCohereEmbeddings(MlflowEmbeddings): + query_params: Dict[str, str] = {"input_type": "search_query"} + documents_params: Dict[str, str] = {"input_type": "search_document"} diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py index 6aac6609a99..dee9b1ba836 100644 --- a/libs/community/tests/unit_tests/embeddings/test_imports.py +++ b/libs/community/tests/unit_tests/embeddings/test_imports.py @@ -18,6 +18,7 @@ EXPECTED_ALL = [ "HuggingFaceHubEmbeddings", "MlflowAIGatewayEmbeddings", "MlflowEmbeddings", + "MlflowCohereEmbeddings", "ModelScopeEmbeddings", "TensorflowHubEmbeddings", "SagemakerEndpointEmbeddings",