langchain[patch]: Add progress bar option to OllamaEmbeddings (#13882)

- **Description:** Adds a tqdm progress bar to OllamaEmbeddings when
embedding a list.
- **Issue:** Related to #13637, but extended to Ollama.
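- **Dependencies:** `tqdm` is optional. It is imported only when `show_progress=True`; if the import fails, a warning is logged and embedding proceeds without a progress bar.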

Thanks to @ugm2 for helping identify a common problem: embeddings can take
a very long time to finish on local machines, and a progress bar helps
users judge whether a workload is worth attempting at all.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Tyler Titsworth 2023-11-27 13:56:13 -08:00 committed by GitHub
parent ec53d983a1
commit afcfa2a5e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,9 +1,12 @@
import logging
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
import requests import requests
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra from langchain_core.pydantic_v1 import BaseModel, Extra
logger = logging.getLogger(__name__)
class OllamaEmbeddings(BaseModel, Embeddings): class OllamaEmbeddings(BaseModel, Embeddings):
"""Ollama locally runs large language models. """Ollama locally runs large language models.
@ -99,6 +102,9 @@ class OllamaEmbeddings(BaseModel, Embeddings):
to more diverse text, while a lower value (e.g., 0.5) will to more diverse text, while a lower value (e.g., 0.5) will
generate more focused and conservative text. (Default: 0.9)""" generate more focused and conservative text. (Default: 0.9)"""
show_progress: bool = False
"""Whether to show a tqdm progress bar. Must have `tqdm` installed."""
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Ollama.""" """Get the default parameters for calling Ollama."""
@ -170,15 +176,23 @@ class OllamaEmbeddings(BaseModel, Embeddings):
) )
def _embed(self, input: List[str]) -> List[List[float]]: def _embed(self, input: List[str]) -> List[List[float]]:
embeddings_list: List[List[float]] = [] if self.show_progress:
for prompt in input: try:
embeddings = self._process_emb_response(prompt) from tqdm import tqdm
embeddings_list.append(embeddings)
return embeddings_list iter_ = tqdm(input, desc="OllamaEmbeddings")
except ImportError:
logger.warning(
"Unable to show progress bar because tqdm could not be imported. "
"Please install with `pip install tqdm`."
)
iter_ = input
else:
iter_ = input
return [self._process_emb_response(prompt) for prompt in iter_]
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed documents using a Ollama deployed embedding model. """Embed documents using an Ollama deployed embedding model.
Args: Args:
texts: The list of texts to embed. texts: The list of texts to embed.