mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
Fixes #34623

Add a `dimensions` field to `OllamaEmbeddings` so users can specify the output embedding size for models that support variable dimensions. The field is passed directly to the Ollama client's `embed()` call in both the sync and async methods.

**How I verified it works:**

- Ran unit tests: `python -m pytest tests/unit_tests/ -v`
- Ran integration tests against a live Ollama instance: `OLLAMA_HOST=http://ollama:11434 python -m pytest tests/integration_tests/ -v`
- Confirmed that passing `dimensions=768` no longer raises a Pydantic `extra_forbidden` validation error and returns embeddings of the expected size.

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
131 lines
5.0 KiB
Python
131 lines
5.0 KiB
Python
"""Test embedding model integration."""
|
|
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
|
|
|
import pytest
|
|
|
|
from langchain_ollama.embeddings import OllamaEmbeddings
|
|
|
|
# Model identifier passed to `OllamaEmbeddings(model=...)` by the tests below.
MODEL_NAME: str = "llama3.1"
|
|
|
|
|
|
def test_initialization() -> None:
    """Test embedding model initialization.

    Construction must succeed and keep the supplied `model` and `keep_alive`
    values on the instance.
    """
    embeddings = OllamaEmbeddings(model=MODEL_NAME, keep_alive=1)

    # Without these checks the test could only fail by raising during
    # construction; pin the stored field values as well.
    assert embeddings.model == MODEL_NAME
    assert embeddings.keep_alive == 1
|
|
|
|
|
|
@patch("langchain_ollama.embeddings.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
    """Check that model validation during construction is opt-in.

    `validate_model` runs exactly once when `validate_model_on_init=True`,
    and not at all when the flag is `False` or left unset.
    """
    # Opt-in path: a single validation call during construction.
    OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True)
    mock_validate_model.assert_called_once()

    # Forget the call recorded above so the negative cases start clean.
    mock_validate_model.reset_mock()

    # Explicit opt-out first, then the default (flag omitted).
    no_validation_kwargs: list[dict[str, Any]] = [
        {"validate_model_on_init": False},
        {},
    ]
    for extra_kwargs in no_validation_kwargs:
        OllamaEmbeddings(model=MODEL_NAME, **extra_kwargs)
        mock_validate_model.assert_not_called()
|
|
|
|
|
|
@patch("langchain_ollama.embeddings.Client")
def test_embed_documents_passes_options(mock_client_class: Any) -> None:
    """Verify `embed_documents()` forwards model options such as `num_gpu`.

    The sync client is replaced by a mock so the keyword arguments handed to
    `embed()` can be inspected directly.
    """
    fake_client = Mock()
    fake_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
    mock_client_class.return_value = fake_client

    model = OllamaEmbeddings(model=MODEL_NAME, num_gpu=4, temperature=0.5)
    assert model.embed_documents(["test text"]) == [[0.1, 0.2, 0.3]]

    # Exactly one round-trip to the client is expected.
    fake_client.embed.assert_called_once()
    sent_kwargs = fake_client.embed.call_args.kwargs

    # Both the options payload and keep_alive travel as keyword arguments.
    for required_key in ("options", "keep_alive"):
        assert required_key in sent_kwargs

    sent_options = sent_kwargs["options"]
    assert sent_options["num_gpu"] == 4
    assert sent_options["temperature"] == 0.5
|
|
|
|
|
|
@patch("langchain_ollama.embeddings.Client")
def test_embed_documents_passes_dimensions(mock_client_class: Any) -> None:
    """Test that embed_documents passes dimensions to the embed call.

    The mocked sync client must be called exactly once, with
    `dimensions=512` forwarded as a keyword argument.
    """
    mock_client = Mock()
    mock_client_class.return_value = mock_client
    mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}

    embeddings = OllamaEmbeddings(model=MODEL_NAME, dimensions=512)
    result = embeddings.embed_documents(["test text"])

    assert embeddings.dimensions == 512
    assert result == [[0.1, 0.2, 0.3]]

    # Assert the call happened before reading `call_args`: if `embed` were
    # never invoked, `call_args` would be None and the failure would surface
    # as an opaque AttributeError rather than a clear assertion.
    mock_client.embed.assert_called_once()
    assert mock_client.embed.call_args.kwargs["dimensions"] == 512
|
|
|
|
|
|
@patch("langchain_ollama.embeddings.Client")
def test_embed_documents_dimensions_none_by_default(mock_client_class: Any) -> None:
    """Test that dimensions defaults to None when not specified.

    Checks both the model attribute and the value forwarded to the mocked
    sync client's `embed()` call.
    """
    mock_client = Mock()
    mock_client_class.return_value = mock_client
    mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}

    embeddings = OllamaEmbeddings(model=MODEL_NAME)
    assert embeddings.dimensions is None

    embeddings.embed_documents(["test text"])

    # Guard against `call_args` being None (embed never called), which would
    # otherwise fail with an unhelpful AttributeError.
    mock_client.embed.assert_called_once()
    assert mock_client.embed.call_args.kwargs["dimensions"] is None
|
|
|
|
|
|
@patch("langchain_ollama.embeddings.AsyncClient")
@patch("langchain_ollama.embeddings.Client")
async def test_aembed_documents_passes_dimensions(
    mock_client_class: Any, mock_async_client_class: Any
) -> None:
    """Test that aembed_documents passes dimensions to the async embed call.

    `patch` decorators are applied bottom-up, so the `Client` mock arrives as
    the first argument and the `AsyncClient` mock as the second. The sync
    client is patched only so construction does not build a real one.
    """
    mock_async_client = AsyncMock()
    mock_async_client_class.return_value = mock_async_client
    mock_async_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}

    embeddings = OllamaEmbeddings(model=MODEL_NAME, dimensions=512)
    result = await embeddings.aembed_documents(["test text"])

    assert result == [[0.1, 0.2, 0.3]]

    # `call_args` is recorded even if the coroutine is never awaited; check
    # the await explicitly, then inspect the awaited call's arguments.
    mock_async_client.embed.assert_awaited_once()
    assert mock_async_client.embed.await_args.kwargs["dimensions"] == 512
|
|
|
|
|
|
def test_dimensions_validation() -> None:
    """Reject `dimensions` values that are zero or negative at construction."""
    # Each non-positive value must be refused with the same error message.
    for invalid_dimensions in (0, -1):
        with pytest.raises(ValueError, match="must be a positive integer"):
            OllamaEmbeddings(model=MODEL_NAME, dimensions=invalid_dimensions)
|
|
|
|
|
|
def test_embed_documents_raises_when_client_none() -> None:
    """Test that embed_documents raises RuntimeError when client is None."""
    with patch("langchain_ollama.embeddings.Client") as mock_client_class:
        mock_client_class.return_value = MagicMock()
        # Use the module-wide model name, matching every other test here.
        embeddings = OllamaEmbeddings(model=MODEL_NAME)
        # Simulate a model whose sync client was never set up.
        embeddings._client = None  # type: ignore[assignment]

        with pytest.raises(RuntimeError, match="sync client is not initialized"):
            embeddings.embed_documents(["test"])
|
|
|
|
|
|
async def test_aembed_documents_raises_when_client_none() -> None:
    """Test that aembed_documents raises RuntimeError when async client is None."""
    with patch("langchain_ollama.embeddings.AsyncClient") as mock_client_class:
        mock_client_class.return_value = MagicMock()
        # Use the module-wide model name, matching every other test here.
        embeddings = OllamaEmbeddings(model=MODEL_NAME)
        # Simulate a model whose async client was never set up.
        embeddings._async_client = None  # type: ignore[assignment]

        with pytest.raises(RuntimeError, match="async client is not initialized"):
            await embeddings.aembed_documents(["test"])
|