Use extend method for embeddings concatenation in mlflow_gateway (#14358)

## Description
The embeddings obtained from MLflow are concatenated incorrectly, so the
result does not conform to the return type declared by the function.
``` python  
def _query(self, texts: List[str]) -> List[List[float]]:
```
For a **List[str]** input, the caller should get back a **List[List[float]]**.
However, the append method adds each response as a single element, which
wraps the embeddings in an extra outer list. To avoid this, the extend
method should be used, which adds the embeddings of all strings at the
same list level.
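
The difference between the two methods can be seen in a minimal sketch. The `resp` dictionary below is a made-up stand-in for a gateway response (the values are illustrative, not real embeddings):

```python
# Hypothetical response for a batch of two texts (made-up values).
resp = {"embeddings": [[0.1, 0.2], [0.3, 0.4]]}

appended = []
appended.append(resp["embeddings"])  # the whole batch becomes ONE element

extended = []
extended.extend(resp["embeddings"])  # each embedding becomes its own element

print(appended)  # [[[0.1, 0.2], [0.3, 0.4]]]
print(extended)  # [[0.1, 0.2], [0.3, 0.4]]
```

Only `extended` has the **List[List[float]]** shape.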

## Testing
I have tried using OpenAI-ADA to obtain the embeddings, and the result
of executing this snippet is as follows:

``` python  
embeds = await MlflowAIGatewayEmbeddings().aembed_documents(texts=["hi", "how are you?"])
print(embeds)
```  

``` python  
[[[-0.03512698, -0.020624293, -0.015343423, ...], [-0.021260535, -0.011461929, -0.00033121882, ...]]]
```
The expected result is:

``` python  
[[-0.03512698, -0.020624293, -0.015343423, ...], [-0.021260535, -0.011461929, -0.00033121882, ...]]
```
The above result complies with the declared return type,
**List[List[float]]**. It is obtained by using the extend method instead
of the append method.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
James Espichan Vilca 2024-08-23 09:43:43 -05:00 committed by GitHub
parent 7f1e444efa
commit 644e0d3463


@@ -64,7 +64,12 @@ class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
         embeddings = []
         for txt in _chunk(texts, 20):
             resp = mlflow.gateway.query(self.route, data={"text": txt})
-            embeddings.append(resp["embeddings"])
+            # response is List[List[float]]
+            if isinstance(resp["embeddings"][0], List):
+                embeddings.extend(resp["embeddings"])
+            # response is List[float]
+            else:
+                embeddings.append(resp["embeddings"])
         return embeddings

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
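
The branching in this hunk can be exercised without a running gateway. The following is a minimal sketch, not the library's code: `merge_embeddings` is a hypothetical helper, the responses are made-up values, and the built-in `list` is used in the `isinstance` check where the diff uses `typing.List`.

```python
from typing import Any, Dict, List


def merge_embeddings(responses: List[Dict[str, Any]]) -> List[List[float]]:
    """Flatten per-chunk responses into a single list of vectors."""
    embeddings: List[List[float]] = []
    for resp in responses:
        batch = resp["embeddings"]
        if isinstance(batch[0], list):
            # Batch of vectors: add each vector at the top level.
            embeddings.extend(batch)
        else:
            # Single vector: add it as one element.
            embeddings.append(batch)
    return embeddings


# Made-up responses: one batched, one holding a single vector.
responses = [
    {"embeddings": [[0.1, 0.2], [0.3, 0.4]]},
    {"embeddings": [0.5, 0.6]},
]
print(merge_embeddings(responses))  # [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
```

Both response shapes end up at the same list level, which is what the **List[List[float]]** return type requires.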