From e71e6564b18316fc7b946c9c9d71b0f910baa6b7 Mon Sep 17 00:00:00 2001 From: Dat Nguyen <60676245+imtiendat0311@users.noreply.github.com> Date: Mon, 6 Apr 2026 21:50:54 -0400 Subject: [PATCH] feat(ollama): add `dimensions` to `OllamaEmbeddings` (#36543) Fixes #34623 Add `dimensions` field to `OllamaEmbeddings` to allow users to specify output embedding size for models that support variable dimensions . The field is passed directly to the Ollama client's `embed()` call for both sync and async methods. **How I verified it works:** - Ran unit tests: `python -m pytest tests/unit_tests/ -v` - Ran integration tests against a live Ollama instance: `OLLAMA_HOST=http://ollama:11434 python -m pytest tests/integration_tests/ -v` - Confirmed that passing `dimensions=768` no longer raises `extra_forbidden` Pydantic validation error and returns embeddings of the expected size. --------- Co-authored-by: Mason Daugherty Co-authored-by: Mason Daugherty --- .../ollama/langchain_ollama/embeddings.py | 29 +++++++++- .../tests/unit_tests/test_embeddings.py | 56 ++++++++++++++++++- 2 files changed, 82 insertions(+), 3 deletions(-) diff --git a/libs/partners/ollama/langchain_ollama/embeddings.py b/libs/partners/ollama/langchain_ollama/embeddings.py index b8fa44ed41e..a28f081b57c 100644 --- a/libs/partners/ollama/langchain_ollama/embeddings.py +++ b/libs/partners/ollama/langchain_ollama/embeddings.py @@ -6,7 +6,13 @@ from typing import Any from langchain_core.embeddings import Embeddings from ollama import AsyncClient, Client -from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + PrivateAttr, + field_validator, + model_validator, +) from typing_extensions import Self from langchain_ollama._utils import ( @@ -124,6 +130,20 @@ class OllamaEmbeddings(BaseModel, Embeddings): model: str """Model name to use.""" + dimensions: int | None = None + """Number of dimensions for the output embedding vectors. + + If not provided, the model's default embedding dimensionality is used. + """ + + @field_validator("dimensions") + @classmethod + def _validate_dimensions(cls, v: int | None) -> int | None: + if v is not None and v < 1: + msg = "`dimensions` must be a positive integer." + raise ValueError(msg) + return v + validate_model_on_init: bool = False """Whether to validate the model exists in ollama locally on initialization. @@ -303,7 +323,11 @@ class OllamaEmbeddings(BaseModel, Embeddings): ) raise RuntimeError(msg) return self._client.embed( - self.model, texts, options=self._default_params, keep_alive=self.keep_alive + self.model, + texts, + dimensions=self.dimensions, + options=self._default_params, + keep_alive=self.keep_alive, )["embeddings"] def embed_query(self, text: str) -> list[float]: @@ -322,6 +346,7 @@ class OllamaEmbeddings(BaseModel, Embeddings): await self._async_client.embed( self.model, texts, + dimensions=self.dimensions, options=self._default_params, keep_alive=self.keep_alive, ) diff --git a/libs/partners/ollama/tests/unit_tests/test_embeddings.py b/libs/partners/ollama/tests/unit_tests/test_embeddings.py index c6b45700110..674f1a8bb89 100644 --- a/libs/partners/ollama/tests/unit_tests/test_embeddings.py +++ b/libs/partners/ollama/tests/unit_tests/test_embeddings.py @@ -1,7 +1,7 @@ """Test embedding model integration.""" from typing import Any -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -54,6 +54,60 @@ def test_embed_documents_passes_options(mock_client_class: Any) -> None: assert options["temperature"] == 0.5 +@patch("langchain_ollama.embeddings.Client") +def test_embed_documents_passes_dimensions(mock_client_class: Any) -> None: + """Test that embed_documents passes dimensions to the embed call.""" + mock_client = Mock() + mock_client_class.return_value = mock_client + mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]} + + embeddings = OllamaEmbeddings(model=MODEL_NAME, dimensions=512) + embeddings.embed_documents(["test text"]) + + call_args = mock_client.embed.call_args + assert call_args.kwargs["dimensions"] == 512 + + +@patch("langchain_ollama.embeddings.Client") +def test_embed_documents_dimensions_none_by_default(mock_client_class: Any) -> None: + """Test that dimensions defaults to None when not specified.""" + mock_client = Mock() + mock_client_class.return_value = mock_client + mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]} + + embeddings = OllamaEmbeddings(model=MODEL_NAME) + embeddings.embed_documents(["test text"]) + + call_args = mock_client.embed.call_args + assert call_args.kwargs["dimensions"] is None + + +@patch("langchain_ollama.embeddings.AsyncClient") +@patch("langchain_ollama.embeddings.Client") +async def test_aembed_documents_passes_dimensions( + mock_client_class: Any, mock_async_client_class: Any +) -> None: + """Test that aembed_documents passes dimensions to the async embed call.""" + mock_async_client = AsyncMock() + mock_async_client_class.return_value = mock_async_client + mock_async_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]} + + embeddings = OllamaEmbeddings(model=MODEL_NAME, dimensions=512) + await embeddings.aembed_documents(["test text"]) + + call_args = mock_async_client.embed.call_args + assert call_args.kwargs["dimensions"] == 512 + + +def test_dimensions_validation() -> None: + """Test that dimensions must be a positive integer.""" + with pytest.raises(ValueError, match="must be a positive integer"): + OllamaEmbeddings(model=MODEL_NAME, dimensions=0) + + with pytest.raises(ValueError, match="must be a positive integer"): + OllamaEmbeddings(model=MODEL_NAME, dimensions=-1) + + def test_embed_documents_raises_when_client_none() -> None: """Test that embed_documents raises RuntimeError when client is None.""" with patch("langchain_ollama.embeddings.Client") as mock_client_class: