more tests

Mason Daugherty 2025-08-04 17:31:47 -04:00
parent c91d681c83
commit db41bfe1f8
3 changed files with 112 additions and 72 deletions


@@ -621,7 +621,7 @@ class ChatOllamaV1(BaseChatModelV1):
         chat_params = self._chat_params(messages, stop, **kwargs)

         if chat_params["stream"]:
-            async for part in await self._async_client.chat(**chat_params):
+            async for part in self._async_client.chat(**chat_params):  # type: ignore[attr-defined]
                 if not isinstance(part, str):
                     # Skip empty load responses
                     if (
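
For context on the streaming change: when an async generator function is used as a MagicMock side_effect (as the unit-test fixture below does), the mocked chat(...) call returns an async generator that is consumed with `async for` directly and never awaited. A minimal, self-contained sketch of that behavior, with an invented payload:

import asyncio
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import MagicMock


async def fake_stream(*_args: Any, **_kwargs: Any) -> AsyncIterator[dict[str, Any]]:
    # An async generator function: calling it returns an async generator,
    # not a coroutine, so the result is iterated rather than awaited.
    yield {"message": {"role": "assistant", "content": "hi"}, "done": True}


async def main() -> None:
    client = MagicMock()
    client.chat.side_effect = fake_stream
    # Mirrors the changed line above: no `await` before the call.
    async for part in client.chat(stream=True):
        print(part["message"]["content"])


asyncio.run(main())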


@@ -2,6 +2,7 @@
 import json
 import logging
+from collections.abc import AsyncIterator, Iterator
 from typing import Any
 from unittest.mock import MagicMock, patch

@@ -160,52 +161,115 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
     def chat_model_params(self) -> dict:
         return {"model": MODEL_NAME}

+    @property
+    def supports_non_standard_blocks(self) -> bool:
+        """Override to indicate Ollama doesn't support non-standard content blocks."""
+        return False
+
+    @pytest.fixture
+    def model(self) -> Iterator[ChatOllamaV1]:
+        """Create a ChatOllamaV1 instance for testing."""
+        sync_patcher = patch("langchain_ollama.chat_models_v1.Client")
+        async_patcher = patch("langchain_ollama.chat_models_v1.AsyncClient")
+        mock_sync_client_class = sync_patcher.start()
+        mock_async_client_class = async_patcher.start()
+
+        mock_sync_client = MagicMock()
+        mock_async_client = MagicMock()
+        mock_sync_client_class.return_value = mock_sync_client
+        mock_async_client_class.return_value = mock_async_client
+
+        def mock_chat_response(*_args: Any, **_kwargs: Any) -> Iterator[dict[str, Any]]:
+            return iter(
+                [
+                    {
+                        "model": MODEL_NAME,
+                        "created_at": "2024-01-01T00:00:00Z",
+                        "message": {"role": "assistant", "content": "Test response"},
+                        "done": True,
+                        "done_reason": "stop",
+                    }
+                ]
+            )
+
+        async def mock_async_chat_iterator(
+            *_args: Any, **_kwargs: Any
+        ) -> AsyncIterator[dict[str, Any]]:
+            for item in mock_chat_response(*_args, **_kwargs):
+                yield item
+
+        mock_sync_client.chat.side_effect = mock_chat_response
+        mock_async_client.chat.side_effect = mock_async_chat_iterator
+
+        model_instance = self.chat_model_class(**self.chat_model_params)
+        yield model_instance
+
+        sync_patcher.stop()
+        async_patcher.stop()
+
     def test_initialization(self) -> None:
         """Test `ChatOllamaV1` initialization."""
-        llm = ChatOllamaV1(model=MODEL_NAME)
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
+            llm = ChatOllamaV1(model=MODEL_NAME)

         assert llm.model == MODEL_NAME
         assert llm._llm_type == "chat-ollama-v1"

     def test_chat_params(self) -> None:
         """Test `_chat_params()`."""
-        llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.7)
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
+            llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.7)

         messages: list[MessageV1] = [HumanMessageV1("Hello")]
         params = llm._chat_params(messages)

         assert params["model"] == MODEL_NAME
         assert len(params["messages"]) == 1
         assert params["messages"][0]["role"] == "user"
         assert params["messages"][0]["content"] == "Hello"

         # Ensure options carry over
         assert params["options"].temperature == 0.7

     def test_ls_params(self) -> None:
         """Test LangSmith parameters."""
-        llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.5)
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
+            llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.5)

         ls_params = llm._get_ls_params()
         assert ls_params.get("ls_provider") == "ollama"
         assert ls_params.get("ls_model_name") == MODEL_NAME
         assert ls_params.get("ls_model_type") == "chat"
         assert ls_params.get("ls_temperature") == 0.5

     def test_bind_tools_basic(self) -> None:
         """Test basic tool binding functionality."""
-        llm = ChatOllamaV1(model=MODEL_NAME)
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
+            llm = ChatOllamaV1(model=MODEL_NAME)

         def test_tool(query: str) -> str:
             """A test tool."""
             return f"Result for: {query}"

         bound_llm = llm.bind_tools([test_tool])

         # Should return a bound model
         assert bound_llm is not None

     # Missing: `test_arbitrary_roles_accepted_in_chatmessages`
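
The `model` fixture added above starts its patchers manually so the mocked clients stay active for the whole test and are restored on teardown. A stripped-down sketch of that start/stop pattern, using a placeholder `myapp.Client` patch target rather than a real module path:

from collections.abc import Iterator
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def mock_client() -> Iterator[MagicMock]:
    # patcher.start() applies the patch without a `with` block, so the
    # mock remains in place for the entire test that uses this fixture.
    patcher = patch("myapp.Client")  # placeholder target
    mock_class = patcher.start()
    instance = MagicMock()
    mock_class.return_value = instance
    yield instance
    # Teardown: restore the real class once the test finishes.
    patcher.stop()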
@@ -213,9 +277,15 @@
     # But can be added if needed in the future.


-@patch("langchain_ollama.chat_models.validate_model")
-def test_validate_model_on_init(mock_validate_model: Any) -> None:
+@patch("langchain_ollama.chat_models_v1.validate_model")
+@patch("langchain_ollama.chat_models_v1.Client")
+def test_validate_model_on_init(
+    mock_client_class: Any, mock_validate_model: Any
+) -> None:
     """Test that the model is validated on initialization when requested."""
+    mock_client = MagicMock()
+    mock_client_class.return_value = mock_client
     # Test that validate_model is called when validate_model_on_init=True
     ChatOllamaV1(model=MODEL_NAME, validate_model_on_init=True)
     mock_validate_model.assert_called_once()
@@ -314,16 +384,16 @@ def test_load_response_with_empty_content_is_skipped(
         }
     ]

-    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
+    with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
         mock_client = MagicMock()
         mock_client_class.return_value = mock_client
-        mock_client.chat.return_value = load_only_response
+        mock_client.chat.return_value = iter(load_only_response)

         llm = ChatOllamaV1(model="test-model")

         with (
             caplog.at_level(logging.WARNING),
-            pytest.raises(ValueError, match="No data received from Ollama stream"),
+            pytest.raises(ValueError, match="No generations found in stream"),
         ):
             llm.invoke([HumanMessageV1("Hello")])
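
The switch to `iter(load_only_response)` presumably reflects the v1 code path consuming the client response as a one-shot stream rather than a reusable list. The difference matters when mocking, as this small sketch shows:

from unittest.mock import MagicMock

mock = MagicMock()
# A list as return_value is replayed on every call; an iterator behaves
# like a real stream and is exhausted after a single pass.
mock.chat.return_value = iter([{"done": True, "done_reason": "stop"}])

assert list(mock.chat()) == [{"done": True, "done_reason": "stop"}]
assert list(mock.chat()) == []  # same iterator object, already exhausted

This is also why the shared fixture uses `side_effect` with a function that builds a fresh iterator per call, while these single-invoke tests can rely on a one-shot `return_value`.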
@@ -344,16 +414,16 @@ def test_load_response_with_whitespace_content_is_skipped(
         }
     ]

-    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
+    with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
         mock_client = MagicMock()
         mock_client_class.return_value = mock_client
-        mock_client.chat.return_value = load_whitespace_response
+        mock_client.chat.return_value = iter(load_whitespace_response)

         llm = ChatOllamaV1(model="test-model")

         with (
             caplog.at_level(logging.WARNING),
-            pytest.raises(ValueError, match="No data received from Ollama stream"),
+            pytest.raises(ValueError, match="No generations found in stream"),
         ):
             llm.invoke([HumanMessageV1("Hello")])

     assert "Ollama returned empty response with done_reason='load'" in caplog.text
@@ -383,10 +453,10 @@ def test_load_followed_by_content_response(
         },
     ]

-    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
+    with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
         mock_client = MagicMock()
         mock_client_class.return_value = mock_client
-        mock_client.chat.return_value = load_then_content_response
+        mock_client.chat.return_value = iter(load_then_content_response)

         llm = ChatOllamaV1(model="test-model")
@@ -394,7 +464,8 @@ def test_load_followed_by_content_response(
             result = llm.invoke([HumanMessageV1("Hello")])

     assert "Ollama returned empty response with done_reason='load'" in caplog.text
-    assert result.content == "Hello! How can I help you today?"
+    assert len(result.content) == 1
+    assert result.text == "Hello! How can I help you today?"
     assert result.response_metadata.get("done_reason") == "stop"
@@ -412,16 +483,17 @@ def test_load_response_with_actual_content_is_not_skipped(
         }
     ]

-    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
+    with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
         mock_client = MagicMock()
         mock_client_class.return_value = mock_client
-        mock_client.chat.return_value = load_with_content_response
+        mock_client.chat.return_value = iter(load_with_content_response)

         llm = ChatOllamaV1(model="test-model")

         with caplog.at_level(logging.WARNING):
             result = llm.invoke([HumanMessageV1("Hello")])

-        assert result.content == "This is actual content"
+        assert len(result.content) == 1
+        assert result.text == "This is actual content"
         assert result.response_metadata.get("done_reason") == "load"
         assert not caplog.text


@@ -207,38 +207,6 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
         params = model._identifying_params
         assert isinstance(params, dict)  # Should be dict-like mapping

-    # Token Counting Tests
-
-    def test_get_token_ids(self, model: BaseChatModelV1) -> None:
-        """Test that ``get_token_ids`` returns a list of integers."""
-        text = "Hello, world!"
-        token_ids = model.get_token_ids(text)
-
-        assert isinstance(token_ids, list)
-        assert all(isinstance(token_id, int) for token_id in token_ids)
-        assert len(token_ids) > 0
-
-    def test_get_num_tokens(self, model: BaseChatModelV1) -> None:
-        """Test that ``get_num_tokens`` returns a positive integer."""
-        text = "Hello, world!"
-        num_tokens = model.get_num_tokens(text)
-
-        assert isinstance(num_tokens, int)
-        assert num_tokens > 0
-
-    def test_get_num_tokens_from_messages(self, model: BaseChatModelV1) -> None:
-        """Test that ``get_num_tokens_from_messages`` returns a positive integer."""
-        messages = [HumanMessage("Hello, world!")]
-        num_tokens = model.get_num_tokens_from_messages(messages)  # type: ignore[arg-type]
-
-        assert isinstance(num_tokens, int)
-        assert num_tokens > 0
-
-    def test_token_counting_consistency(self, model: BaseChatModelV1) -> None:
-        """Test that token counting methods are consistent with each other."""
-        text = "Hello, world!"
-        token_ids = model.get_token_ids(text)
-        num_tokens = model.get_num_tokens(text)
-
-        # Number of tokens should match length of token IDs list
-        assert len(token_ids) == num_tokens
-
     # Serialization Tests

     def test_dump_serialization(self, model: BaseChatModelV1) -> None:
         """Test that ``dump()`` returns proper serialization."""