mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-24 05:50:18 +00:00
Support multi gpu inference for HuggingFaceEmbeddings (#4732)
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -47,6 +47,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
"""Key word arguments to pass to the model."""
|
||||
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Key word arguments to pass when calling the `encode` method of the model."""
|
||||
multi_process: bool = False
|
||||
"""Run encode() on multiple GPUs."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the sentence_transformer."""
|
||||
@@ -78,8 +80,16 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
import sentence_transformers
|
||||
|
||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
||||
if self.multi_process:
|
||||
pool = self.client.start_multi_process_pool()
|
||||
embeddings = self.client.encode_multi_process(texts, pool)
|
||||
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
|
||||
else:
|
||||
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
||||
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
@@ -91,9 +101,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
text = text.replace("\n", " ")
|
||||
embedding = self.client.encode(text, **self.encode_kwargs)
|
||||
return embedding.tolist()
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
|
||||
class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Reference in New Issue
Block a user