"""Ollama embeddings models."""

from __future__ import annotations

from typing import Any, Optional

from langchain_core.embeddings import Embeddings
from ollama import AsyncClient, Client
from pydantic import (
    BaseModel,
    ConfigDict,
    PrivateAttr,
    model_validator,
)
from typing_extensions import Self

from ._utils import validate_model


class OllamaEmbeddings(BaseModel, Embeddings):
    """Ollama embedding model integration.

    Set up a local Ollama instance:
        `Install the Ollama package <https://github.com/ollama/ollama>`__ and set up a
        local Ollama instance.

        You will need to choose a model to serve.

        You can view a list of available models via `the model library <https://ollama.com/library>`__.

        To fetch a model from the Ollama model library, use ``ollama pull <name-of-model>``.

        For example, to pull the llama3 model:

        .. code-block:: bash

            ollama pull llama3

        This will download the default tagged version of the model.
        Typically, the default points to the latest, smallest-sized parameter model.

        * On Mac, the models will be downloaded to ``~/.ollama/models``
        * On Linux (or WSL), the models will be stored at ``/usr/share/ollama/.ollama/models``

        You can specify the exact version of the model of interest
        as such ``ollama pull vicuna:13b-v1.5-16k-q4_0``.

        To view pulled models:

        .. code-block:: bash

            ollama list

        To start serving:

        .. code-block:: bash

            ollama serve

        View the Ollama documentation for more commands:

        .. code-block:: bash

            ollama help

    Install the langchain-ollama integration package:
        .. code-block:: bash

            pip install -U langchain_ollama

    Key init args — completion params:
        model: str
            Name of Ollama model to use.
        base_url: Optional[str]
            Base URL the model is hosted under.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        .. code-block:: python

            from langchain_ollama import OllamaEmbeddings

            embed = OllamaEmbeddings(
                model="llama3"
            )
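
    Connect to a custom server:
        A minimal sketch: ``base_url`` and ``validate_model_on_init`` are init args of
        this class; the URL shown is the conventional local Ollama endpoint and may
        differ in your deployment.

        .. code-block:: python

            embed = OllamaEmbeddings(
                model="llama3",
                base_url="http://localhost:11434",
                # Check at construction time that the model has been pulled locally.
                validate_model_on_init=True,
            )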

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            vector = embed.embed_query(input_text)
            print(vector[:3])

        .. code-block:: python

            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Embed multiple texts:
        .. code-block:: python

            input_texts = ["Document 1...", "Document 2..."]
            vectors = embed.embed_documents(input_texts)
            print(len(vectors))
            # The first 3 coordinates for the first vector
            print(vectors[0][:3])

        .. code-block:: python

            2
            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Async:
        .. code-block:: python

            vector = await embed.aembed_query(input_text)
            print(vector[:3])

            # multiple:
            # await embed.aembed_documents(input_texts)

        .. code-block:: python

            [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188]
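
    Tune the underlying httpx clients:
        ``client_kwargs`` is forwarded to both the sync and async httpx-based clients;
        ``sync_client_kwargs`` and ``async_client_kwargs`` are merged on top for the
        respective client. A minimal sketch; the specific httpx options shown
        (``timeout``, ``headers``) are assumptions about what a deployment may need.

        .. code-block:: python

            embed = OllamaEmbeddings(
                model="llama3",
                client_kwargs={"timeout": 60},  # shared by both clients
                sync_client_kwargs={"headers": {"x-request-source": "sync"}},
                async_client_kwargs={"timeout": 120},  # overrides the shared timeout
            )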

    """  # noqa: E501

    model: str
    """Model name to use."""

    validate_model_on_init: bool = False
    """Whether to validate that the model exists locally in Ollama on initialization.

    .. versionadded:: 0.3.4

    """

    base_url: Optional[str] = None
    """Base URL the model is hosted under."""

    client_kwargs: Optional[dict] = {}
    """Additional kwargs to pass to the httpx clients.

    These arguments are passed to both synchronous and async clients.

    Use ``sync_client_kwargs`` and ``async_client_kwargs`` to pass different arguments
    to synchronous and asynchronous clients.

    """

    async_client_kwargs: Optional[dict] = {}
    """Additional kwargs to merge with ``client_kwargs`` before passing to the httpx
    ``AsyncClient``.

    For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#asyncclient>`__.

    """

    sync_client_kwargs: Optional[dict] = {}
    """Additional kwargs to merge with ``client_kwargs`` before passing to the httpx
    ``Client``.

    For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.

    """

    _client: Optional[Client] = PrivateAttr(default=None)
    """The client to use for making requests."""

    _async_client: Optional[AsyncClient] = PrivateAttr(default=None)
    """The async client to use for making requests."""

    mirostat: Optional[int] = None
    """Enable Mirostat sampling for controlling perplexity.
    (Default: ``0``, ``0`` = disabled, ``1`` = Mirostat, ``2`` = Mirostat 2.0)"""

    mirostat_eta: Optional[float] = None
    """Influences how quickly the algorithm responds to feedback
    from the generated text. A lower learning rate will result in
    slower adjustments, while a higher learning rate will make
    the algorithm more responsive. (Default: ``0.1``)"""

    mirostat_tau: Optional[float] = None
    """Controls the balance between coherence and diversity
    of the output. A lower value will result in more focused and
    coherent text. (Default: ``5.0``)"""

    num_ctx: Optional[int] = None
    """Sets the size of the context window used to generate the
    next token. (Default: ``2048``)"""

    num_gpu: Optional[int] = None
    """The number of GPUs to use. On macOS it defaults to ``1`` to
    enable Metal support, ``0`` to disable."""

    keep_alive: Optional[int] = None
    """Controls how long the model will stay loaded into memory
    following the request. (Default: ``5m``)
    """

    num_thread: Optional[int] = None
    """Sets the number of threads to use during computation.
    By default, Ollama will detect this for optimal performance.
    It is recommended to set this value to the number of physical
    CPU cores your system has (as opposed to the logical number of cores)."""

    repeat_last_n: Optional[int] = None
    """Sets how far back the model looks to prevent
    repetition. (Default: ``64``, ``0`` = disabled, ``-1`` = ``num_ctx``)"""

    repeat_penalty: Optional[float] = None
    """Sets how strongly to penalize repetitions. A higher value (e.g., ``1.5``)
    will penalize repetitions more strongly, while a lower value (e.g., ``0.9``)
    will be more lenient. (Default: ``1.1``)"""

    temperature: Optional[float] = None
    """The temperature of the model. Increasing the temperature will
    make the model answer more creatively. (Default: ``0.8``)"""

    stop: Optional[list[str]] = None
    """Sets the stop tokens to use."""

    tfs_z: Optional[float] = None
    """Tail free sampling is used to reduce the impact of less probable
    tokens from the output. A higher value (e.g., ``2.0``) will reduce the
    impact more, while a value of ``1.0`` disables this setting. (Default: ``1``)"""

    top_k: Optional[int] = None
    """Reduces the probability of generating nonsense. A higher value (e.g. ``100``)
    will give more diverse answers, while a lower value (e.g. ``10``)
    will be more conservative. (Default: ``40``)"""

    top_p: Optional[float] = None
    """Works together with top-k. A higher value (e.g., ``0.95``) will lead
    to more diverse text, while a lower value (e.g., ``0.5``) will
    generate more focused and conservative text. (Default: ``0.9``)"""

    model_config = ConfigDict(
        extra="forbid",
    )

    @property
    def _default_params(self) -> dict[str, Any]:
        """Get the default parameters for calling Ollama."""
        return {
            "mirostat": self.mirostat,
            "mirostat_eta": self.mirostat_eta,
            "mirostat_tau": self.mirostat_tau,
            "num_ctx": self.num_ctx,
            "num_gpu": self.num_gpu,
            "num_thread": self.num_thread,
            "repeat_last_n": self.repeat_last_n,
            "repeat_penalty": self.repeat_penalty,
            "temperature": self.temperature,
            "stop": self.stop,
            "tfs_z": self.tfs_z,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }

    @model_validator(mode="after")
    def _set_clients(self) -> Self:
        """Set clients to use for Ollama."""
        client_kwargs = self.client_kwargs or {}

        # Per-client kwargs are layered on top of the shared ``client_kwargs``.
        sync_client_kwargs = client_kwargs
        if self.sync_client_kwargs:
            sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}

        async_client_kwargs = client_kwargs
        if self.async_client_kwargs:
            async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}

        self._client = Client(host=self.base_url, **sync_client_kwargs)
        self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
        if self.validate_model_on_init:
            validate_model(self._client, self.model)
        return self

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs."""
        if not self._client:
            msg = (
                "Ollama client is not initialized. "
                "Please ensure Ollama is running and the model is loaded."
            )
            raise ValueError(msg)
        return self._client.embed(
            self.model, texts, options=self._default_params, keep_alive=self.keep_alive
        )["embeddings"]

    def embed_query(self, text: str) -> list[float]:
        """Embed query text."""
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs."""
        if not self._async_client:
            msg = (
                "Ollama client is not initialized. "
                "Please ensure Ollama is running and the model is loaded."
            )
            raise ValueError(msg)
        return (
            await self._async_client.embed(
                self.model,
                texts,
                options=self._default_params,
                keep_alive=self.keep_alive,
            )
        )["embeddings"]

    async def aembed_query(self, text: str) -> list[float]:
        """Embed query text."""
        return (await self.aembed_documents([text]))[0]