more tests

Mason Daugherty 2025-08-04 17:31:47 -04:00
parent c91d681c83
commit db41bfe1f8
3 changed files with 112 additions and 72 deletions

View File

@@ -621,7 +621,7 @@ class ChatOllamaV1(BaseChatModelV1):
        chat_params = self._chat_params(messages, stop, **kwargs)

        if chat_params["stream"]:
-            async for part in await self._async_client.chat(**chat_params):
+            async for part in self._async_client.chat(**chat_params):  # type: ignore[attr-defined]
                if not isinstance(part, str):
                    # Skip empty load responses
                    if (
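The one-line change above drops the `await`: the streaming `chat()` call is consumed directly with `async for` rather than awaited first (hence the added `# type: ignore[attr-defined]`). A minimal sketch of that consumption pattern, using a stand-in client (`FakeAsyncClient` is illustrative, not the real ollama API):

import asyncio
from collections.abc import AsyncIterator


class FakeAsyncClient:
    """Illustrative stand-in, not the real ollama AsyncClient."""

    async def chat(self, **_kwargs: object) -> AsyncIterator[dict]:
        # An async generator function: calling chat() returns the async
        # iterator directly, so there is no coroutine to await first.
        for part in ({"content": "Hel"}, {"content": "lo"}):
            yield part


async def main() -> None:
    client = FakeAsyncClient()
    # `async for part in client.chat(...)`, not `await client.chat(...)`
    async for part in client.chat(stream=True):
        print(part["content"])


asyncio.run(main())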

View File

@@ -2,6 +2,7 @@
import json
import logging
from collections.abc import AsyncIterator, Iterator
+from typing import Any
from unittest.mock import MagicMock, patch
@@ -160,8 +161,59 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
    def chat_model_params(self) -> dict:
        return {"model": MODEL_NAME}

+    @property
+    def supports_non_standard_blocks(self) -> bool:
+        """Override to indicate Ollama doesn't support non-standard content blocks."""
+        return False
+
+    @pytest.fixture
+    def model(self) -> Iterator[ChatOllamaV1]:
+        """Create a ChatOllamaV1 instance for testing."""
+        sync_patcher = patch("langchain_ollama.chat_models_v1.Client")
+        async_patcher = patch("langchain_ollama.chat_models_v1.AsyncClient")
+        mock_sync_client_class = sync_patcher.start()
+        mock_async_client_class = async_patcher.start()
+
+        mock_sync_client = MagicMock()
+        mock_async_client = MagicMock()
+        mock_sync_client_class.return_value = mock_sync_client
+        mock_async_client_class.return_value = mock_async_client
+
+        def mock_chat_response(*_args: Any, **_kwargs: Any) -> Iterator[dict[str, Any]]:
+            return iter(
+                [
+                    {
+                        "model": MODEL_NAME,
+                        "created_at": "2024-01-01T00:00:00Z",
+                        "message": {"role": "assistant", "content": "Test response"},
+                        "done": True,
+                        "done_reason": "stop",
+                    }
+                ]
+            )
+
+        async def mock_async_chat_iterator(
+            *_args: Any, **_kwargs: Any
+        ) -> AsyncIterator[dict[str, Any]]:
+            for item in mock_chat_response(*_args, **_kwargs):
+                yield item
+
+        mock_sync_client.chat.side_effect = mock_chat_response
+        mock_async_client.chat.side_effect = mock_async_chat_iterator
+
+        model_instance = self.chat_model_class(**self.chat_model_params)
+        yield model_instance
+
+        sync_patcher.stop()
+        async_patcher.stop()
+
    def test_initialization(self) -> None:
        """Test `ChatOllamaV1` initialization."""
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
            llm = ChatOllamaV1(model=MODEL_NAME)
            assert llm.model == MODEL_NAME
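The fixture wires `chat` through `side_effect` rather than `return_value`: the factory runs on every call, so each call gets a fresh iterator instead of one shared, exhaustible one. A minimal sketch of the distinction the fixture relies on:

from unittest.mock import MagicMock

client = MagicMock()

# With return_value, the iterator is built once and shared, so it is
# exhausted after the first call:
client.chat.return_value = iter([{"done": True}])
assert list(client.chat()) == [{"done": True}]
assert list(client.chat()) == []  # already exhausted

# With side_effect, the factory runs on every call, so each call gets
# a fresh iterator -- the pattern the fixture above uses:
client.chat = MagicMock(side_effect=lambda **_kw: iter([{"done": True}]))
assert list(client.chat()) == [{"done": True}]
assert list(client.chat()) == [{"done": True}]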
@@ -169,6 +221,10 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
    def test_chat_params(self) -> None:
        """Test `_chat_params()`."""
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
            llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.7)
            messages: list[MessageV1] = [HumanMessageV1("Hello")]
@@ -185,6 +241,10 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
    def test_ls_params(self) -> None:
        """Test LangSmith parameters."""
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
            llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.5)
            ls_params = llm._get_ls_params()
@@ -196,6 +256,10 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
    def test_bind_tools_basic(self) -> None:
        """Test basic tool binding functionality."""
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
            llm = ChatOllamaV1(model=MODEL_NAME)

            def test_tool(query: str) -> str:
@@ -213,9 +277,15 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
# But can be added if needed in the future.

-@patch("langchain_ollama.chat_models.validate_model")
-def test_validate_model_on_init(mock_validate_model: Any) -> None:
+@patch("langchain_ollama.chat_models_v1.validate_model")
+@patch("langchain_ollama.chat_models_v1.Client")
+def test_validate_model_on_init(
+    mock_client_class: Any, mock_validate_model: Any
+) -> None:
    """Test that the model is validated on initialization when requested."""
+    mock_client = MagicMock()
+    mock_client_class.return_value = mock_client
+
    # Test that validate_model is called when validate_model_on_init=True
    ChatOllamaV1(model=MODEL_NAME, validate_model_on_init=True)
    mock_validate_model.assert_called_once()
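Note the argument order in the rewritten test: stacked `@patch` decorators apply bottom-up, so the decorator nearest the function (`Client`) supplies the first mock argument. A quick self-contained illustration against stdlib targets:

from unittest.mock import patch


@patch("os.getcwd")  # outermost: its mock arrives last
@patch("os.getpid")  # innermost: its mock arrives first
def demo(mock_getpid, mock_getcwd) -> None:
    import os

    mock_getpid.return_value = 1234
    mock_getcwd.return_value = "/tmp"
    assert os.getpid() == 1234
    assert os.getcwd() == "/tmp"


demo()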
@@ -314,16 +384,16 @@ def test_load_response_with_empty_content_is_skipped(
        }
    ]

-    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
+    with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
-        mock_client.chat.return_value = load_only_response
+        mock_client.chat.return_value = iter(load_only_response)

        llm = ChatOllamaV1(model="test-model")

        with (
            caplog.at_level(logging.WARNING),
-            pytest.raises(ValueError, match="No data received from Ollama stream"),
+            pytest.raises(ValueError, match="No generations found in stream"),
        ):
            llm.invoke([HumanMessageV1("Hello")])
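Two details here: the mocked payload is now wrapped in `iter(...)` so it behaves like the streaming client (see the `side_effect` note above), and the expected error text changed to track what the v1 stream path actually raises. Because `pytest.raises(..., match=...)` applies the pattern with `re.search`, a distinctive substring of the message suffices; a tiny sketch:

import pytest


def boom() -> None:
    raise ValueError("No generations found in stream")


# match is a regex searched within str(exception), so a substring works:
with pytest.raises(ValueError, match="No generations found"):
    boom()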
@@ -344,16 +414,16 @@ def test_load_response_with_whitespace_content_is_skipped(
        }
    ]

-    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
+    with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
-        mock_client.chat.return_value = load_whitespace_response
+        mock_client.chat.return_value = iter(load_whitespace_response)

        llm = ChatOllamaV1(model="test-model")

        with (
            caplog.at_level(logging.WARNING),
-            pytest.raises(ValueError, match="No data received from Ollama stream"),
+            pytest.raises(ValueError, match="No generations found in stream"),
        ):
            llm.invoke([HumanMessageV1("Hello")])
        assert "Ollama returned empty response with done_reason='load'" in caplog.text
@@ -383,10 +453,10 @@ def test_load_followed_by_content_response(
        },
    ]

-    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
+    with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
-        mock_client.chat.return_value = load_then_content_response
+        mock_client.chat.return_value = iter(load_then_content_response)

        llm = ChatOllamaV1(model="test-model")
@@ -394,7 +464,8 @@ def test_load_followed_by_content_response(
            result = llm.invoke([HumanMessageV1("Hello")])

        assert "Ollama returned empty response with done_reason='load'" in caplog.text
-        assert result.content == "Hello! How can I help you today?"
+        assert len(result.content) == 1
+        assert result.text == "Hello! How can I help you today?"
        assert result.response_metadata.get("done_reason") == "stop"
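The assertion rewrite reflects the v1 message shape: `content` is a list of content blocks rather than a bare string, and `.text` exposes the concatenated text blocks. A toy stand-in (the block shape here is assumed for illustration):

from dataclasses import dataclass, field


@dataclass
class FakeAIMessageV1:
    """Toy stand-in for a v1 message: content is a list of typed blocks."""

    content: list[dict] = field(default_factory=list)

    @property
    def text(self) -> str:
        # Concatenate just the text blocks, mirroring why the tests now
        # assert on `result.text` instead of comparing `result.content`
        # to a string.
        return "".join(
            block["text"] for block in self.content if block.get("type") == "text"
        )


msg = FakeAIMessageV1(
    content=[{"type": "text", "text": "Hello! How can I help you today?"}]
)
assert len(msg.content) == 1
assert msg.text == "Hello! How can I help you today?"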
@@ -412,16 +483,17 @@ def test_load_response_with_actual_content_is_not_skipped(
        }
    ]

-    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
+    with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
-        mock_client.chat.return_value = load_with_content_response
+        mock_client.chat.return_value = iter(load_with_content_response)

        llm = ChatOllamaV1(model="test-model")

        with caplog.at_level(logging.WARNING):
            result = llm.invoke([HumanMessageV1("Hello")])

-        assert result.content == "This is actual content"
+        assert len(result.content) == 1
+        assert result.text == "This is actual content"
        assert result.response_metadata.get("done_reason") == "load"
        assert not caplog.text

View File

@@ -207,38 +207,6 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
        params = model._identifying_params
        assert isinstance(params, dict)  # Should be dict-like mapping

-    # Token Counting Tests
-
-    def test_get_token_ids(self, model: BaseChatModelV1) -> None:
-        """Test that ``get_token_ids`` returns a list of integers."""
-        text = "Hello, world!"
-        token_ids = model.get_token_ids(text)
-        assert isinstance(token_ids, list)
-        assert all(isinstance(token_id, int) for token_id in token_ids)
-        assert len(token_ids) > 0
-
-    def test_get_num_tokens(self, model: BaseChatModelV1) -> None:
-        """Test that ``get_num_tokens`` returns a positive integer."""
-        text = "Hello, world!"
-        num_tokens = model.get_num_tokens(text)
-        assert isinstance(num_tokens, int)
-        assert num_tokens > 0
-
-    def test_get_num_tokens_from_messages(self, model: BaseChatModelV1) -> None:
-        """Test that ``get_num_tokens_from_messages`` returns a positive integer."""
-        messages = [HumanMessage("Hello, world!")]
-        num_tokens = model.get_num_tokens_from_messages(messages)  # type: ignore[arg-type]
-        assert isinstance(num_tokens, int)
-        assert num_tokens > 0
-
-    def test_token_counting_consistency(self, model: BaseChatModelV1) -> None:
-        """Test that token counting methods are consistent with each other."""
-        text = "Hello, world!"
-        token_ids = model.get_token_ids(text)
-        num_tokens = model.get_num_tokens(text)
-        # Number of tokens should match length of token IDs list
-        assert len(token_ids) == num_tokens
-
    # Serialization Tests

    def test_dump_serialization(self, model: BaseChatModelV1) -> None:
        """Test that ``dump()`` returns proper serialization."""