mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
langchain[patch]: Add progress bar option to OllamaEmbeddings (#13882)
- **Description:** Adds a tqdm progress bar to OllamaEmbeddings when embedding a list. - **Issue:** Related to #13637, but extended to Ollama. - **Dependencies:** `tqdm` made a necessary dependency. Thanks to @ugm2 for helping identify a common problem. Embeddings take a very long time to finish on local machines, and require a progress bar to help identify if one should even attempt the workload. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
ec53d983a1
commit
afcfa2a5e7
@ -1,9 +1,12 @@
|
|||||||
|
import logging
|
||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Extra
|
from langchain_core.pydantic_v1 import BaseModel, Extra
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OllamaEmbeddings(BaseModel, Embeddings):
|
class OllamaEmbeddings(BaseModel, Embeddings):
|
||||||
"""Ollama locally runs large language models.
|
"""Ollama locally runs large language models.
|
||||||
@ -99,6 +102,9 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
|||||||
to more diverse text, while a lower value (e.g., 0.5) will
|
to more diverse text, while a lower value (e.g., 0.5) will
|
||||||
generate more focused and conservative text. (Default: 0.9)"""
|
generate more focused and conservative text. (Default: 0.9)"""
|
||||||
|
|
||||||
|
show_progress: bool = False
|
||||||
|
"""Whether to show a tqdm progress bar. Must have `tqdm` installed."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_params(self) -> Dict[str, Any]:
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
"""Get the default parameters for calling Ollama."""
|
"""Get the default parameters for calling Ollama."""
|
||||||
@ -170,15 +176,23 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _embed(self, input: List[str]) -> List[List[float]]:
|
def _embed(self, input: List[str]) -> List[List[float]]:
|
||||||
embeddings_list: List[List[float]] = []
|
if self.show_progress:
|
||||||
for prompt in input:
|
try:
|
||||||
embeddings = self._process_emb_response(prompt)
|
from tqdm import tqdm
|
||||||
embeddings_list.append(embeddings)
|
|
||||||
|
|
||||||
return embeddings_list
|
iter_ = tqdm(input, desc="OllamaEmbeddings")
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"Unable to show progress bar because tqdm could not be imported. "
|
||||||
|
"Please install with `pip install tqdm`."
|
||||||
|
)
|
||||||
|
iter_ = input
|
||||||
|
else:
|
||||||
|
iter_ = input
|
||||||
|
return [self._process_emb_response(prompt) for prompt in iter_]
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Embed documents using a Ollama deployed embedding model.
|
"""Embed documents using an Ollama deployed embedding model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: The list of texts to embed.
|
texts: The list of texts to embed.
|
||||||
|
Loading…
Reference in New Issue
Block a user