diff --git a/libs/partners/ollama/langchain_ollama/embeddings.py b/libs/partners/ollama/langchain_ollama/embeddings.py index cfc9d174471..fedf80b5ba5 100644 --- a/libs/partners/ollama/langchain_ollama/embeddings.py +++ b/libs/partners/ollama/langchain_ollama/embeddings.py @@ -296,7 +296,10 @@ class OllamaEmbeddings(BaseModel, Embeddings): raise ValueError(msg) return ( await self._async_client.embed( - self.model, texts, keep_alive=self.keep_alive + self.model, + texts, + options=self._default_params, + keep_alive=self.keep_alive, ) )["embeddings"] diff --git a/libs/partners/ollama/tests/unit_tests/test_embeddings.py b/libs/partners/ollama/tests/unit_tests/test_embeddings.py index 6ceec7c5df9..93f996f59bc 100644 --- a/libs/partners/ollama/tests/unit_tests/test_embeddings.py +++ b/libs/partners/ollama/tests/unit_tests/test_embeddings.py @@ -1,7 +1,7 @@ """Test embedding model integration.""" from typing import Any -from unittest.mock import patch +from unittest.mock import Mock, patch from langchain_ollama.embeddings import OllamaEmbeddings @@ -28,3 +28,36 @@ def test_validate_model_on_init(mock_validate_model: Any) -> None: # Test that validate_model is NOT called by default OllamaEmbeddings(model=MODEL_NAME) mock_validate_model.assert_not_called() + + +@patch("langchain_ollama.embeddings.Client") +def test_embed_documents_passes_options(mock_client_class: Any) -> None: + """Test that embed_documents method passes options including num_gpu.""" + # Create a mock client instance + mock_client = Mock() + mock_client_class.return_value = mock_client + + # Mock the embed method response + mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]} + + # Create embeddings with num_gpu parameter + embeddings = OllamaEmbeddings(model=MODEL_NAME, num_gpu=4, temperature=0.5) + + # Call embed_documents + result = embeddings.embed_documents(["test text"]) + + # Verify the result + assert result == [[0.1, 0.2, 0.3]] + + # Check that embed was called with correct arguments + mock_client.embed.assert_called_once() + call_args = mock_client.embed.call_args + + # Verify the keyword arguments + assert "options" in call_args.kwargs + assert "keep_alive" in call_args.kwargs + + # Verify options contain num_gpu and temperature + options = call_args.kwargs["options"] + assert options["num_gpu"] == 4 + assert options["temperature"] == 0.5