diff --git a/libs/community/langchain_community/embeddings/mlflow_gateway.py b/libs/community/langchain_community/embeddings/mlflow_gateway.py index 6e2fad408a3..e722d9e35ce 100644 --- a/libs/community/langchain_community/embeddings/mlflow_gateway.py +++ b/libs/community/langchain_community/embeddings/mlflow_gateway.py @@ -64,7 +64,12 @@ class MlflowAIGatewayEmbeddings(Embeddings, BaseModel): embeddings = [] for txt in _chunk(texts, 20): resp = mlflow.gateway.query(self.route, data={"text": txt}) - embeddings.append(resp["embeddings"]) + # response is List[List[float]] + if isinstance(resp["embeddings"][0], List): + embeddings.extend(resp["embeddings"]) + # response is List[float] + else: + embeddings.append(resp["embeddings"]) return embeddings def embed_documents(self, texts: List[str]) -> List[List[float]]: