diff --git a/libs/langchain/langchain/embeddings/google_palm.py b/libs/langchain/langchain/embeddings/google_palm.py index a61dbd3b299..db6314d47f4 100644 --- a/libs/langchain/langchain/embeddings/google_palm.py +++ b/libs/langchain/langchain/embeddings/google_palm.py @@ -60,6 +60,8 @@ class GooglePalmEmbeddings(BaseModel, Embeddings): google_api_key: Optional[str] model_name: str = "models/embedding-gecko-001" """Model name to use.""" + show_progress_bar: bool = False + """Whether to show a tqdm progress bar. Must have `tqdm` installed.""" @root_validator() def validate_environment(cls, values: Dict) -> Dict: @@ -79,7 +81,20 @@ class GooglePalmEmbeddings(BaseModel, Embeddings): return values def embed_documents(self, texts: List[str]) -> List[List[float]]: - return [self.embed_query(text) for text in texts] + if self.show_progress_bar: + try: + from tqdm import tqdm + + iter_ = tqdm(texts, desc="GooglePalmEmbeddings") + except ImportError: + logger.warning( + "Unable to show progress bar because tqdm could not be imported. " + "Please install with `pip install tqdm`." + ) + iter_ = texts + else: + iter_ = texts + return [self.embed_query(text) for text in iter_] def embed_query(self, text: str) -> List[float]: """Embed query text."""