mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 13:55:03 +00:00
Convert numpy arrays to lists in HuggingFaceEmbeddings (#714)
`SentenceTransformer` returns a NumPy array, not a `List[List[float]]` or `List[float]` as specified in the interface of `Embeddings`. That PR makes it consistent with the interface.
This commit is contained in:
parent
97c3544a1e
commit
d4f719c34b
@ -54,7 +54,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
"""
|
"""
|
||||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||||
embeddings = self.client.encode(texts)
|
embeddings = self.client.encode(texts)
|
||||||
return embeddings
|
return embeddings.tolist()
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Compute query embeddings using a HuggingFace transformer model.
|
"""Compute query embeddings using a HuggingFace transformer model.
|
||||||
@ -67,4 +67,4 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
"""
|
"""
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
embedding = self.client.encode(text)
|
embedding = self.client.encode(text)
|
||||||
return embedding
|
return embedding.tolist()
|
||||||
|
Loading…
Reference in New Issue
Block a user