"""Ollama embeddings models."""
|
|
|
|
from typing import (
|
|
List,
|
|
Optional,
|
|
)
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
from ollama import AsyncClient, Client
|
|
from pydantic import (
|
|
BaseModel,
|
|
ConfigDict,
|
|
PrivateAttr,
|
|
model_validator,
|
|
)
|
|
from typing_extensions import Self
|
|
|
|
|
|
class OllamaEmbeddings(BaseModel, Embeddings):
    """Ollama embedding model integration.

    Set up a local Ollama instance:
        Install the Ollama package and set up a local Ollama instance
        using the instructions here: https://github.com/ollama/ollama .

        You will need to choose a model to serve.

        You can view a list of available models via the model library (https://ollama.com/library).

        To fetch a model from the Ollama model library, use ``ollama pull <name-of-model>``.

        For example, to pull the llama3 model:

        .. code-block:: bash

            ollama pull llama3

        This will download the default tagged version of the model.
        Typically, the default points to the latest model with the smallest parameter size.

        * On Mac, the models will be downloaded to ~/.ollama/models
        * On Linux (or WSL), the models will be stored at /usr/share/ollama/.ollama/models

        You can specify an exact version of the model of interest,
        e.g. ``ollama pull vicuna:13b-v1.5-16k-q4_0``.

        To view pulled models:

        .. code-block:: bash

            ollama list

        To start serving:

        .. code-block:: bash

            ollama serve

        View the Ollama documentation for more commands:

        .. code-block:: bash

            ollama help

    Install the langchain-ollama integration package:
        .. code-block:: bash

            pip install -U langchain_ollama

    Key init args:
        model: str
            Name of the Ollama model to use.
        base_url: Optional[str]
            Base URL the model is hosted under.

    See the full list of supported init args and their descriptions in the params section.

    Instantiate:
        .. code-block:: python

            from langchain_ollama import OllamaEmbeddings

            embed = OllamaEmbeddings(
                model="llama3"
            )
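
        If your Ollama server runs somewhere other than the default local
        address, point the client at it with ``base_url`` (the URL below is
        Ollama's standard local endpoint, shown here for illustration):

        .. code-block:: python

            embed = OllamaEmbeddings(
                model="llama3",
                base_url="http://localhost:11434",
            )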

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            vector = embed.embed_query(input_text)
            print(vector[:3])

        .. code-block:: python

            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Embed multiple texts:
        .. code-block:: python

            input_texts = ["Document 1...", "Document 2..."]
            vectors = embed.embed_documents(input_texts)
            print(len(vectors))
            # The first 3 coordinates for the first vector
            print(vectors[0][:3])

        .. code-block:: python

            2
            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Async:
        .. code-block:: python

            vector = await embed.aembed_query(input_text)
            print(vector[:3])

            # multiple:
            # await embed.aembed_documents(input_texts)

        .. code-block:: python

            [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188]
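
        The ``await`` examples above assume an event loop is already running
        (e.g. in a notebook). In a plain script, one way to drive them is
        ``asyncio.run`` (a usage sketch, not part of this module):

        .. code-block:: python

            import asyncio

            vector = asyncio.run(embed.aembed_query(input_text))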
    """  # noqa: E501

    model: str
    """Model name to use."""

    base_url: Optional[str] = None
    """Base URL the model is hosted under."""

    client_kwargs: Optional[dict] = {}
    """Additional kwargs to pass to the httpx Client.
    For a full list of the params, see [this link](https://pydoc.dev/httpx/latest/httpx.Client.html)
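
    For example, since these kwargs are forwarded to the underlying httpx
    client, standard ``httpx.Client`` parameters such as ``timeout`` can be
    set here (a usage sketch; the value is illustrative):

    .. code-block:: python

        embed = OllamaEmbeddings(
            model="llama3",
            client_kwargs={"timeout": 60.0},
        )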
    """

    _client: Client = PrivateAttr(default=None)  # type: ignore
    """
    The client to use for making requests.
    """

    _async_client: AsyncClient = PrivateAttr(default=None)  # type: ignore
    """
    The async client to use for making requests.
    """

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="after")
    def _set_clients(self) -> Self:
        """Set clients to use for ollama."""
        client_kwargs = self.client_kwargs or {}
        self._client = Client(host=self.base_url, **client_kwargs)
        self._async_client = AsyncClient(host=self.base_url, **client_kwargs)
        return self

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        embedded_docs = self._client.embed(self.model, texts)["embeddings"]
        return embedded_docs

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        embedded_docs = (await self._async_client.embed(self.model, texts))[
            "embeddings"
        ]
        return embedded_docs

    async def aembed_query(self, text: str) -> List[float]:
        """Embed query text."""
        return (await self.aembed_documents([text]))[0]