langchain/libs/community/langchain_community/embeddings/mlflow_gateway.py
James Espichan Vilca 644e0d3463 Use extend method for embeddings concatenation in mlflow_gateway (#14358)
## Description
There is a bug in how the embeddings obtained from MLflow are concatenated: the result does not conform to the function's return type hint.
``` python  
def _query(self, texts: List[str]) -> List[List[float]]:
```
Given a **List[str]**, the caller expects a **List[List[float]]**, with one embedding per input string.
However, `append` wraps each batch response, which is already a list of embeddings, in an extra outer list.
Using the `extend` method instead adds the embeddings of all strings at the same list level.
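
A minimal illustration with plain Python lists shows the difference. The vectors below are made-up placeholder values standing in for a two-string batch response, not real embeddings.

```python
from typing import List

# Placeholder batch response: one 3-dimensional vector per input string.
batch: List[List[float]] = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

# append() inserts the whole batch as a single element: three levels of nesting.
appended: List = []
appended.append(batch)

# extend() inserts each vector of the batch: the expected List[List[float]].
extended: List[List[float]] = []
extended.extend(batch)

print(appended)  # [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]
print(extended)  # [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
```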

## Testing
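I tested this using OpenAI-ADA to obtain the embeddings. Before the fix, executing this snippet:

``` python
from langchain_community.embeddings import MlflowAIGatewayEmbeddings

# `route` is required; top-level `await` works in a notebook (use asyncio.run in a script)
embeds = await MlflowAIGatewayEmbeddings(
    route="<your-mlflow-ai-gateway-embeddings-route>"
).aembed_documents(texts=["hi", "how are you?"])
print(embeds)
```

printed:
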
``` python  
[[[-0.03512698, -0.020624293, -0.015343423, ...], [-0.021260535, -0.011461929, -0.00033121882, ...]]]
```
The expected result, however, is:

``` python  
[[-0.03512698, -0.020624293, -0.015343423, ...], [-0.021260535, -0.011461929, -0.00033121882, ...]]
```
This result matches the expected type hint, **List[List[float]]**. As mentioned above, it is achieved by using the `extend` method instead of `append`.
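
Because Python does not enforce type hints at runtime, a test could check the shape explicitly. The sketch below uses a hypothetical helper, `is_list_of_float_lists`, which is not part of LangChain or MLflow; the sample values are shortened placeholders shaped like the two outputs above.

```python
from typing import Any


def is_list_of_float_lists(value: Any) -> bool:
    """Return True if `value` has the shape List[List[float]].

    Hypothetical test helper; not part of LangChain or MLflow.
    Note: an empty list counts as valid, and ints are rejected.
    """
    return isinstance(value, list) and all(
        isinstance(vec, list) and all(isinstance(x, float) for x in vec)
        for vec in value
    )


# Placeholder values shaped like the buggy and the expected outputs.
buggy = [[[-0.035, -0.020], [-0.021, -0.011]]]
fixed = [[-0.035, -0.020], [-0.021, -0.011]]

print(is_list_of_float_lists(buggy))  # False
print(is_list_of_float_lists(fixed))  # True
```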

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
2024-08-23 14:43:43 +00:00


```python
from __future__ import annotations

import warnings
from typing import Any, Iterator, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    for i in range(0, len(texts), size):
        yield texts[i : i + size]


class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
    """MLflow AI Gateway embeddings.

    To use, you should have the ``mlflow[gateway]`` python package installed.
    For more information, see https://mlflow.org/docs/latest/gateway/index.html.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import MlflowAIGatewayEmbeddings

            embeddings = MlflowAIGatewayEmbeddings(
                gateway_uri="<your-mlflow-ai-gateway-uri>",
                route="<your-mlflow-ai-gateway-embeddings-route>"
            )
    """

    route: str
    """The route to use for the MLflow AI Gateway API."""
    gateway_uri: Optional[str] = None
    """The URI for the MLflow AI Gateway API."""

    def __init__(self, **kwargs: Any):
        warnings.warn(
            "`MlflowAIGatewayEmbeddings` is deprecated. Use `MlflowEmbeddings` or "
            "`DatabricksEmbeddings` instead.",
            DeprecationWarning,
        )
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        super().__init__(**kwargs)
        if self.gateway_uri:
            mlflow.gateway.set_gateway_uri(self.gateway_uri)

    def _query(self, texts: List[str]) -> List[List[float]]:
        try:
            import mlflow.gateway
        except ImportError as e:
            raise ImportError(
                "Could not import `mlflow.gateway` module. "
                "Please install it with `pip install mlflow[gateway]`."
            ) from e

        embeddings = []
        for txt in _chunk(texts, 20):
            resp = mlflow.gateway.query(self.route, data={"text": txt})
            # response is List[List[float]]
            if isinstance(resp["embeddings"][0], List):
                embeddings.extend(resp["embeddings"])
            # response is List[float]
            else:
                embeddings.append(resp["embeddings"])
        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return self._query(texts)

    def embed_query(self, text: str) -> List[float]:
        return self._query([text])[0]
```
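
The batching and shape handling in `_query` can be exercised without an MLflow gateway. The following is a minimal sketch, not the module's code: `fake_query` is a hypothetical stand-in for `mlflow.gateway.query`, and its behavior of returning a flat `List[float]` when a batch holds exactly one text is an assumption made only so both branches run; it is not documented MLflow behavior. `query_embeddings` and the default route `"embeddings"` are likewise hypothetical, and the sketch uses the built-in `list` in the `isinstance` check.

```python
from typing import Any, Callable, Dict, Iterator, List


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    # Same batching helper as in the module: yields slices of at most `size` texts.
    for i in range(0, len(texts), size):
        yield texts[i : i + size]


def fake_query(route: str, data: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical stand-in for mlflow.gateway.query (no network call).
    # Each placeholder "embedding" is [len(text), 0.0].
    vectors = [[float(len(t)), 0.0] for t in data["text"]]
    if len(vectors) == 1:
        # Assumed flat List[float] response, used only to hit the append branch.
        return {"embeddings": vectors[0]}
    # Batch response shaped as List[List[float]].
    return {"embeddings": vectors}


def query_embeddings(
    texts: List[str],
    query: Callable[[str, Dict[str, Any]], Dict[str, Any]],
    route: str = "embeddings",  # hypothetical placeholder route name
    batch_size: int = 20,
) -> List[List[float]]:
    # Mirrors the control flow of MlflowAIGatewayEmbeddings._query.
    embeddings: List[List[float]] = []
    for batch in _chunk(texts, batch_size):
        resp = query(route, {"text": batch})
        if isinstance(resp["embeddings"][0], list):
            # List[List[float]]: add each vector at the top level.
            embeddings.extend(resp["embeddings"])
        else:
            # List[float]: the response is one vector, add it as one element.
            embeddings.append(resp["embeddings"])
    return embeddings


texts = [f"text {i}" for i in range(21)]  # batches of 20 and 1
result = query_embeddings(texts, fake_query)
print(len(result))  # 21
print(result[-1])  # [7.0, 0.0]
print(query_embeddings(["hi"], fake_query))  # [[2.0, 0.0]]
```

With 21 texts and a batch size of 20, the first batch goes through `extend` and the single-text second batch through `append`, yet every element of the result is one embedding vector, which is the `List[List[float]]` shape `embed_documents` promises. The last call mirrors `embed_query`, which takes element `[0]` of that list.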