mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-10 22:43:18 +00:00
feat(ollama): add dimensions to OllamaEmbeddings (#36543)
Fixes #34623 Add `dimensions` field to `OllamaEmbeddings` to allow users to specify output embedding size for models that support variable dimensions . The field is passed directly to the Ollama client's `embed()` call for both sync and async methods. **How I verified it works:** - Ran unit tests: `python -m pytest tests/unit_tests/ -v` - Ran integration tests against a live Ollama instance: `OLLAMA_HOST=http://ollama:11434 python -m pytest tests/integration_tests/ -v` - Confirmed that passing `dimensions=768` no longer raises `extra_forbidden` Pydantic validation error and returns embeddings of the expected size. --------- Co-authored-by: Mason Daugherty <mason@langchain.dev> Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
@@ -6,7 +6,13 @@ from typing import Any
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from ollama import AsyncClient, Client
|
||||
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_ollama._utils import (
|
||||
@@ -124,6 +130,20 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
model: str
|
||||
"""Model name to use."""
|
||||
|
||||
dimensions: int | None = None
|
||||
"""Number of dimensions for the output embedding vectors.
|
||||
|
||||
If not provided, the model's default embedding dimensionality is used.
|
||||
"""
|
||||
|
||||
@field_validator("dimensions")
|
||||
@classmethod
|
||||
def _validate_dimensions(cls, v: int | None) -> int | None:
|
||||
if v is not None and v < 1:
|
||||
msg = "`dimensions` must be a positive integer."
|
||||
raise ValueError(msg)
|
||||
return v
|
||||
|
||||
validate_model_on_init: bool = False
|
||||
"""Whether to validate the model exists in ollama locally on initialization.
|
||||
|
||||
@@ -303,7 +323,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
return self._client.embed(
|
||||
self.model, texts, options=self._default_params, keep_alive=self.keep_alive
|
||||
self.model,
|
||||
texts,
|
||||
dimensions=self.dimensions,
|
||||
options=self._default_params,
|
||||
keep_alive=self.keep_alive,
|
||||
)["embeddings"]
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
@@ -322,6 +346,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
await self._async_client.embed(
|
||||
self.model,
|
||||
texts,
|
||||
dimensions=self.dimensions,
|
||||
options=self._default_params,
|
||||
keep_alive=self.keep_alive,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Test embedding model integration."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -54,6 +54,60 @@ def test_embed_documents_passes_options(mock_client_class: Any) -> None:
|
||||
assert options["temperature"] == 0.5
|
||||
|
||||
|
||||
@patch("langchain_ollama.embeddings.Client")
|
||||
def test_embed_documents_passes_dimensions(mock_client_class: Any) -> None:
|
||||
"""Test that embed_documents passes dimensions to the embed call."""
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
|
||||
|
||||
embeddings = OllamaEmbeddings(model=MODEL_NAME, dimensions=512)
|
||||
embeddings.embed_documents(["test text"])
|
||||
|
||||
call_args = mock_client.embed.call_args
|
||||
assert call_args.kwargs["dimensions"] == 512
|
||||
|
||||
|
||||
@patch("langchain_ollama.embeddings.Client")
|
||||
def test_embed_documents_dimensions_none_by_default(mock_client_class: Any) -> None:
|
||||
"""Test that dimensions defaults to None when not specified."""
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
|
||||
|
||||
embeddings = OllamaEmbeddings(model=MODEL_NAME)
|
||||
embeddings.embed_documents(["test text"])
|
||||
|
||||
call_args = mock_client.embed.call_args
|
||||
assert call_args.kwargs["dimensions"] is None
|
||||
|
||||
|
||||
@patch("langchain_ollama.embeddings.AsyncClient")
|
||||
@patch("langchain_ollama.embeddings.Client")
|
||||
async def test_aembed_documents_passes_dimensions(
|
||||
mock_client_class: Any, mock_async_client_class: Any
|
||||
) -> None:
|
||||
"""Test that aembed_documents passes dimensions to the async embed call."""
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client_class.return_value = mock_async_client
|
||||
mock_async_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
|
||||
|
||||
embeddings = OllamaEmbeddings(model=MODEL_NAME, dimensions=512)
|
||||
await embeddings.aembed_documents(["test text"])
|
||||
|
||||
call_args = mock_async_client.embed.call_args
|
||||
assert call_args.kwargs["dimensions"] == 512
|
||||
|
||||
|
||||
def test_dimensions_validation() -> None:
|
||||
"""Test that dimensions must be a positive integer."""
|
||||
with pytest.raises(ValueError, match="must be a positive integer"):
|
||||
OllamaEmbeddings(model=MODEL_NAME, dimensions=0)
|
||||
|
||||
with pytest.raises(ValueError, match="must be a positive integer"):
|
||||
OllamaEmbeddings(model=MODEL_NAME, dimensions=-1)
|
||||
|
||||
|
||||
def test_embed_documents_raises_when_client_none() -> None:
|
||||
"""Test that embed_documents raises RuntimeError when client is None."""
|
||||
with patch("langchain_ollama.embeddings.Client") as mock_client_class:
|
||||
|
||||
Reference in New Issue
Block a user