mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 00:47:27 +00:00
langchain[patch]: Adds progress bar to GooglePalmEmbeddings (#13812)
- **Description:** Adds a tqdm progress bar to GooglePalmEmbeddings when embedding a list. - **Issue:** #13637 - **Dependencies:** TQDM as a main dependency (instead of extra) Signed-off-by: ugm2 <unaigaraymaestre@gmail.com> --------- Signed-off-by: ugm2 <unaigaraymaestre@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
1cd9d5f332
commit
9e2ae866c4
@ -60,6 +60,8 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
|
||||
google_api_key: Optional[str]
|
||||
model_name: str = "models/embedding-gecko-001"
|
||||
"""Model name to use."""
|
||||
show_progress_bar: bool = False
|
||||
"""Whether to show a tqdm progress bar. Must have `tqdm` installed."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@ -79,7 +81,20 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return [self.embed_query(text) for text in texts]
|
||||
if self.show_progress_bar:
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
iter_ = tqdm(texts, desc="GooglePalmEmbeddings")
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"Unable to show progress bar because tqdm could not be imported. "
|
||||
"Please install with `pip install tqdm`."
|
||||
)
|
||||
iter_ = texts
|
||||
else:
|
||||
iter_ = texts
|
||||
return [self.embed_query(text) for text in iter_]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed query text."""
|
||||
|
Loading…
Reference in New Issue
Block a user