diff --git a/libs/langchain/langchain/embeddings/ollama.py b/libs/langchain/langchain/embeddings/ollama.py
index 1938d916405..c1b910af4ee 100644
--- a/libs/langchain/langchain/embeddings/ollama.py
+++ b/libs/langchain/langchain/embeddings/ollama.py
@@ -1,9 +1,12 @@
+import logging
 from typing import Any, Dict, List, Mapping, Optional
 
 import requests
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra
 
+logger = logging.getLogger(__name__)
+
 
 class OllamaEmbeddings(BaseModel, Embeddings):
     """Ollama locally runs large language models.
@@ -99,6 +102,9 @@ class OllamaEmbeddings(BaseModel, Embeddings):
     to more diverse text, while a lower value (e.g., 0.5) will
     generate more focused and conservative text. (Default: 0.9)"""
 
+    show_progress: bool = False
+    """Whether to show a tqdm progress bar. Must have `tqdm` installed."""
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Ollama."""
@@ -170,15 +176,23 @@ class OllamaEmbeddings(BaseModel, Embeddings):
         )
 
     def _embed(self, input: List[str]) -> List[List[float]]:
-        embeddings_list: List[List[float]] = []
-        for prompt in input:
-            embeddings = self._process_emb_response(prompt)
-            embeddings_list.append(embeddings)
+        if self.show_progress:
+            try:
+                from tqdm import tqdm
 
-        return embeddings_list
+                iter_ = tqdm(input, desc="OllamaEmbeddings")
+            except ImportError:
+                logger.warning(
+                    "Unable to show progress bar because tqdm could not be imported. "
+                    "Please install with `pip install tqdm`."
+                )
+                iter_ = input
+        else:
+            iter_ = input
+        return [self._process_emb_response(prompt) for prompt in iter_]
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """Embed documents using a Ollama deployed embedding model.
+        """Embed documents using an Ollama deployed embedding model.
 
         Args:
             texts: The list of texts to embed.
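
A minimal usage sketch of the new show_progress flag (the model name and input texts below are illustrative, not part of this change; tqdm must be installed for the bar to appear):

    # Hypothetical example: assumes a local Ollama server serving the default llama2 model.
    from langchain.embeddings import OllamaEmbeddings

    embedder = OllamaEmbeddings(model="llama2", show_progress=True)
    # embed_documents iterates over the texts; with show_progress=True the
    # iteration is wrapped in tqdm, so a progress bar is printed per text.
    vectors = embedder.embed_documents(["first document", "second document"])
    print(len(vectors), len(vectors[0]))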