langchain/libs/partners/ollama/tests/unit_tests/test_embeddings.py
Mason Daugherty ee4c2510eb
feat: port various nit changes from wip-v0.4 (#32506)
Lots of work that wasn't directly related to core
improvements/messages/testing functionality
2025-08-11 15:09:08 -04:00

64 lines
2.2 KiB
Python

"""Test embedding model integration."""
from typing import Any
from unittest.mock import Mock, patch
from langchain_ollama.embeddings import OllamaEmbeddings
# Model name passed to every OllamaEmbeddings instance built in this module.
MODEL_NAME: str = "llama3.1"
def test_initialization() -> None:
    """Test embedding model initialization.

    Constructing the model must succeed, and the constructor arguments must be
    retained on the instance rather than silently dropped.
    """
    embeddings = OllamaEmbeddings(model=MODEL_NAME, keep_alive=1)
    # Pin that the fields passed to the constructor are actually stored.
    assert embeddings.model == MODEL_NAME
    assert embeddings.keep_alive == 1
@patch("langchain_ollama.embeddings.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
    """Check when construction triggers model validation.

    Validation must run exactly once when `validate_model_on_init=True`, and
    must be skipped both when the flag is explicitly `False` and when it is
    left at its default.
    """
    # Opting in: the patched validator is hit exactly once.
    OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True)
    mock_validate_model.assert_called_once()

    # Clear the recorded call so the negative checks below start from zero.
    mock_validate_model.reset_mock()

    # Opting out explicitly, then relying on the default: neither may validate.
    skip_cases: tuple[dict[str, Any], ...] = (
        {"validate_model_on_init": False},
        {},
    )
    for extra_kwargs in skip_cases:
        OllamaEmbeddings(model=MODEL_NAME, **extra_kwargs)
        mock_validate_model.assert_not_called()
@patch("langchain_ollama.embeddings.Client")
def test_embed_documents_passes_options(mock_client_class: Any) -> None:
    """Check that `embed_documents()` forwards model options such as `num_gpu`.

    The `Client` class used by the embeddings module is patched, so the
    instance it produces is a `Mock` whose `embed` method returns a canned
    response; the keyword arguments that method received are then inspected.
    """
    fake_client = Mock()
    fake_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
    mock_client_class.return_value = fake_client

    # `num_gpu` and `temperature` are the options expected to be sent along.
    embedder = OllamaEmbeddings(model=MODEL_NAME, num_gpu=4, temperature=0.5)
    vectors = embedder.embed_documents(["test text"])
    assert vectors == [[0.1, 0.2, 0.3]]

    # A single embed request is expected; inspect its keyword arguments.
    fake_client.embed.assert_called_once()
    sent_kwargs = fake_client.embed.call_args.kwargs
    for required_key in ("options", "keep_alive"):
        assert required_key in sent_kwargs

    sent_options = sent_kwargs["options"]
    assert sent_options["num_gpu"] == 4
    assert sent_options["temperature"] == 0.5