feat(ollama): add dimensions to OllamaEmbeddings (#36543)

Fixes #34623

Add `dimensions` field to `OllamaEmbeddings` to allow users to specify 
output embedding size for models that support variable dimensions . The
field is passed
directly to the Ollama client's `embed()` call for both sync and async
methods.

**How I verified it works:**
- Ran unit tests: `python -m pytest tests/unit_tests/ -v`
- Ran integration tests against a live Ollama instance:
`OLLAMA_HOST=http://ollama:11434 python -m pytest
tests/integration_tests/ -v`
- Confirmed that passing `dimensions=768` no longer raises
`extra_forbidden`
  Pydantic validation error and returns embeddings of the expected size.

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Dat Nguyen
2026-04-06 21:50:54 -04:00
committed by GitHub
parent 050b779d97
commit e71e6564b1
2 changed files with 82 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
"""Test embedding model integration."""
from typing import Any
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
@@ -54,6 +54,60 @@ def test_embed_documents_passes_options(mock_client_class: Any) -> None:
assert options["temperature"] == 0.5
@patch("langchain_ollama.embeddings.Client")
def test_embed_documents_passes_dimensions(mock_client_class: Any) -> None:
"""Test that embed_documents passes dimensions to the embed call."""
mock_client = Mock()
mock_client_class.return_value = mock_client
mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
embeddings = OllamaEmbeddings(model=MODEL_NAME, dimensions=512)
embeddings.embed_documents(["test text"])
call_args = mock_client.embed.call_args
assert call_args.kwargs["dimensions"] == 512
@patch("langchain_ollama.embeddings.Client")
def test_embed_documents_dimensions_none_by_default(mock_client_class: Any) -> None:
"""Test that dimensions defaults to None when not specified."""
mock_client = Mock()
mock_client_class.return_value = mock_client
mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
embeddings = OllamaEmbeddings(model=MODEL_NAME)
embeddings.embed_documents(["test text"])
call_args = mock_client.embed.call_args
assert call_args.kwargs["dimensions"] is None
@patch("langchain_ollama.embeddings.AsyncClient")
@patch("langchain_ollama.embeddings.Client")
async def test_aembed_documents_passes_dimensions(
mock_client_class: Any, mock_async_client_class: Any
) -> None:
"""Test that aembed_documents passes dimensions to the async embed call."""
mock_async_client = AsyncMock()
mock_async_client_class.return_value = mock_async_client
mock_async_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
embeddings = OllamaEmbeddings(model=MODEL_NAME, dimensions=512)
await embeddings.aembed_documents(["test text"])
call_args = mock_async_client.embed.call_args
assert call_args.kwargs["dimensions"] == 512
def test_dimensions_validation() -> None:
"""Test that dimensions must be a positive integer."""
with pytest.raises(ValueError, match="must be a positive integer"):
OllamaEmbeddings(model=MODEL_NAME, dimensions=0)
with pytest.raises(ValueError, match="must be a positive integer"):
OllamaEmbeddings(model=MODEL_NAME, dimensions=-1)
def test_embed_documents_raises_when_client_none() -> None:
"""Test that embed_documents raises RuntimeError when client is None."""
with patch("langchain_ollama.embeddings.Client") as mock_client_class: