mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 09:04:03 +00:00
community[patch]: Add Progress bar to HuggingFaceEmbeddings (#16758)
- **Description:** Adds a function parameter to HuggingFaceEmbeddings called `show_progress` that enables a `tqdm` progress bar if enabled. Does not function if `multi_process = True`. - **Issue:** n/a - **Dependencies:** n/a
This commit is contained in:
parent
ae33979813
commit
304f3f5fc1
@ -49,6 +49,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
"""Keyword arguments to pass when calling the `encode` method of the model."""
|
||||
multi_process: bool = False
|
||||
"""Run encode() on multiple GPUs."""
|
||||
show_progress: bool = False
|
||||
"""Whether to show a progress bar."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the sentence_transformer."""
|
||||
@ -88,7 +90,9 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = self.client.encode_multi_process(texts, pool)
|
||||
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
|
||||
else:
|
||||
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
||||
embeddings = self.client.encode(
|
||||
texts, show_progress_bar=self.show_progress, **self.encode_kwargs
|
||||
)
|
||||
|
||||
return embeddings.tolist()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user