From 6b2a57161af92965ca4c00c06508a88372967cf0 Mon Sep 17 00:00:00 2001 From: Eli Lucherini Date: Mon, 22 Jan 2024 11:38:11 -0800 Subject: [PATCH] community[patch]: allow additional kwargs in MlflowEmbeddings for compatibility with Cohere API (#15242) - **Description:** add support for kwargs in`MlflowEmbeddings` `embed_document()` and `embed_query()` so that all the arguments required by Cohere API (and others?) can be passed down to the server. - **Issue:** #15234 - **Dependencies:** MLflow with MLflow Deployments (`pip install mlflow[genai]`) **Tests** Now this code [adapted from the docs](https://python.langchain.com/docs/integrations/providers/mlflow#embeddings-example) for the Cohere API works locally. ```python """ Setup ----- export COHERE_API_KEY=... mlflow deployments start-server --config-path examples/deployments/cohere/config.yaml Run --- python /path/to/this/file.py """ embeddings = MlflowCohereEmbeddings(target_uri="http://127.0.0.1:5000", endpoint="embeddings") print(embeddings.embed_query("hello")[:3]) print(embeddings.embed_documents(["hello", "world"])[0][:3]) ``` Output ``` [0.060455322, 0.028793335, -0.025848389] [0.031707764, 0.021057129, -0.009361267] ``` --- .../embeddings/__init__.py | 6 ++++- .../langchain_community/embeddings/mlflow.py | 22 +++++++++++++++---- .../unit_tests/embeddings/test_imports.py | 1 + 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index 9b9deba027c..f39787b3556 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -57,7 +57,10 @@ from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings from langchain_community.embeddings.llm_rails import LLMRailsEmbeddings from langchain_community.embeddings.localai import LocalAIEmbeddings from langchain_community.embeddings.minimax import MiniMaxEmbeddings -from langchain_community.embeddings.mlflow import MlflowEmbeddings +from langchain_community.embeddings.mlflow import ( + MlflowCohereEmbeddings, + MlflowEmbeddings, +) from langchain_community.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings from langchain_community.embeddings.modelscope_hub import ModelScopeEmbeddings from langchain_community.embeddings.mosaicml import MosaicMLInstructorEmbeddings @@ -102,6 +105,7 @@ __all__ = [ "LLMRailsEmbeddings", "HuggingFaceHubEmbeddings", "MlflowEmbeddings", + "MlflowCohereEmbeddings", "MlflowAIGatewayEmbeddings", "ModelScopeEmbeddings", "TensorflowHubEmbeddings", diff --git a/libs/community/langchain_community/embeddings/mlflow.py b/libs/community/langchain_community/embeddings/mlflow.py index 0ae46bcffdb..6b24dacb025 100644 --- a/libs/community/langchain_community/embeddings/mlflow.py +++ b/libs/community/langchain_community/embeddings/mlflow.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Iterator, List +from typing import Any, Dict, Iterator, List from urllib.parse import urlparse from langchain_core.embeddings import Embeddings @@ -34,6 +34,10 @@ class MlflowEmbeddings(Embeddings, BaseModel): target_uri: str """The target URI to use.""" _client: Any = PrivateAttr() + """The parameters to use for queries.""" + query_params: Dict[str, str] = {} + """The parameters to use for documents.""" + documents_params: Dict[str, str] = {} def __init__(self, **kwargs: Any): super().__init__(**kwargs) @@ -63,12 +67,22 @@ class MlflowEmbeddings(Embeddings, BaseModel): f"The scheme must be one of {allowed}." ) - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed(self, texts: List[str], params: Dict[str, str]) -> List[List[float]]: embeddings: List[List[float]] = [] for txt in _chunk(texts, 20): - resp = self._client.predict(endpoint=self.endpoint, inputs={"input": txt}) + resp = self._client.predict( + endpoint=self.endpoint, inputs={"input": txt, **params} + ) embeddings.extend(r["embedding"] for r in resp["data"]) return embeddings + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return self.embed(texts, params=self.documents_params) + def embed_query(self, text: str) -> List[float]: - return self.embed_documents([text])[0] + return self.embed([text], params=self.query_params)[0] + + +class MlflowCohereEmbeddings(MlflowEmbeddings): + query_params: Dict[str, str] = {"input_type": "search_query"} + documents_params: Dict[str, str] = {"input_type": "search_document"} diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py index 6aac6609a99..dee9b1ba836 100644 --- a/libs/community/tests/unit_tests/embeddings/test_imports.py +++ b/libs/community/tests/unit_tests/embeddings/test_imports.py @@ -18,6 +18,7 @@ EXPECTED_ALL = [ "HuggingFaceHubEmbeddings", "MlflowAIGatewayEmbeddings", "MlflowEmbeddings", + "MlflowCohereEmbeddings", "ModelScopeEmbeddings", "TensorflowHubEmbeddings", "SagemakerEndpointEmbeddings",