diff --git a/libs/langchain/langchain/embeddings/huggingface.py b/libs/langchain/langchain/embeddings/huggingface.py
index b2d4a183333..72243705d2b 100644
--- a/libs/langchain/langchain/embeddings/huggingface.py
+++ b/libs/langchain/langchain/embeddings/huggingface.py
@@ -47,6 +47,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     """Key word arguments to pass to the model."""
     encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Key word arguments to pass when calling the `encode` method of the model."""
+    multi_process: bool = False
+    """Run encode() on multiple GPUs."""
 
     def __init__(self, **kwargs: Any):
         """Initialize the sentence_transformer."""
@@ -78,8 +80,20 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
+        import sentence_transformers
+
         texts = list(map(lambda x: x.replace("\n", " "), texts))
-        embeddings = self.client.encode(texts, **self.encode_kwargs)
+        if self.multi_process:
+            pool = self.client.start_multi_process_pool()
+            try:
+                embeddings = self.client.encode_multi_process(texts, pool)
+            finally:
+                # Always shut the worker processes down, even when encoding
+                # raises, so a failed call does not leak the pool.
+                sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
+        else:
+            embeddings = self.client.encode(texts, **self.encode_kwargs)
+
         return embeddings.tolist()
 
     def embed_query(self, text: str) -> List[float]:
@@ -91,9 +105,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
         """
-        text = text.replace("\n", " ")
-        embedding = self.client.encode(text, **self.encode_kwargs)
-        return embedding.tolist()
+        return self.embed_documents([text])[0]
 
 
 class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):