mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-04 18:53:02 +00:00
community[patch]: Add Progress bar to HuggingFaceEmbeddings (#16758)
- **Description:** Adds a function parameter to HuggingFaceEmbeddings called `show_progress` that enables a `tqdm` progress bar if enabled. Does not function if `multi_process = True`. - **Issue:** n/a - **Dependencies:** n/a
This commit is contained in:
parent
ae33979813
commit
304f3f5fc1
@ -49,6 +49,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
"""Keyword arguments to pass when calling the `encode` method of the model."""
|
"""Keyword arguments to pass when calling the `encode` method of the model."""
|
||||||
multi_process: bool = False
|
multi_process: bool = False
|
||||||
"""Run encode() on multiple GPUs."""
|
"""Run encode() on multiple GPUs."""
|
||||||
|
show_progress: bool = False
|
||||||
|
"""Whether to show a progress bar."""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any):
|
def __init__(self, **kwargs: Any):
|
||||||
"""Initialize the sentence_transformer."""
|
"""Initialize the sentence_transformer."""
|
||||||
@ -88,7 +90,9 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
embeddings = self.client.encode_multi_process(texts, pool)
|
embeddings = self.client.encode_multi_process(texts, pool)
|
||||||
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
|
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
|
||||||
else:
|
else:
|
||||||
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
embeddings = self.client.encode(
|
||||||
|
texts, show_progress_bar=self.show_progress, **self.encode_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
return embeddings.tolist()
|
return embeddings.tolist()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user