mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 23:26:34 +00:00
more tests
This commit is contained in:
parent
c91d681c83
commit
db41bfe1f8
@ -621,7 +621,7 @@ class ChatOllamaV1(BaseChatModelV1):
|
||||
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||
|
||||
if chat_params["stream"]:
|
||||
async for part in await self._async_client.chat(**chat_params):
|
||||
async for part in self._async_client.chat(**chat_params): # type: ignore[attr-defined]
|
||||
if not isinstance(part, str):
|
||||
# Skip empty load responses
|
||||
if (
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -160,8 +161,59 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
|
||||
def chat_model_params(self) -> dict:
|
||||
return {"model": MODEL_NAME}
|
||||
|
||||
@property
|
||||
def supports_non_standard_blocks(self) -> bool:
|
||||
"""Override to indicate Ollama doesn't support non-standard content blocks."""
|
||||
return False
|
||||
|
||||
@pytest.fixture
|
||||
def model(self) -> Iterator[ChatOllamaV1]:
|
||||
"""Create a ChatOllamaV1 instance for testing."""
|
||||
sync_patcher = patch("langchain_ollama.chat_models_v1.Client")
|
||||
async_patcher = patch("langchain_ollama.chat_models_v1.AsyncClient")
|
||||
|
||||
mock_sync_client_class = sync_patcher.start()
|
||||
mock_async_client_class = async_patcher.start()
|
||||
|
||||
mock_sync_client = MagicMock()
|
||||
mock_async_client = MagicMock()
|
||||
|
||||
mock_sync_client_class.return_value = mock_sync_client
|
||||
mock_async_client_class.return_value = mock_async_client
|
||||
|
||||
def mock_chat_response(*_args: Any, **_kwargs: Any) -> Iterator[dict[str, Any]]:
|
||||
return iter(
|
||||
[
|
||||
{
|
||||
"model": MODEL_NAME,
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"message": {"role": "assistant", "content": "Test response"},
|
||||
"done": True,
|
||||
"done_reason": "stop",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
async def mock_async_chat_iterator(
|
||||
*_args: Any, **_kwargs: Any
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
for item in mock_chat_response(*_args, **_kwargs):
|
||||
yield item
|
||||
|
||||
mock_sync_client.chat.side_effect = mock_chat_response
|
||||
mock_async_client.chat.side_effect = mock_async_chat_iterator
|
||||
|
||||
model_instance = self.chat_model_class(**self.chat_model_params)
|
||||
yield model_instance
|
||||
sync_patcher.stop()
|
||||
async_patcher.stop()
|
||||
|
||||
def test_initialization(self) -> None:
|
||||
"""Test `ChatOllamaV1` initialization."""
|
||||
with (
|
||||
patch("langchain_ollama.chat_models_v1.Client"),
|
||||
patch("langchain_ollama.chat_models_v1.AsyncClient"),
|
||||
):
|
||||
llm = ChatOllamaV1(model=MODEL_NAME)
|
||||
|
||||
assert llm.model == MODEL_NAME
|
||||
@ -169,6 +221,10 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
|
||||
|
||||
def test_chat_params(self) -> None:
|
||||
"""Test `_chat_params()`."""
|
||||
with (
|
||||
patch("langchain_ollama.chat_models_v1.Client"),
|
||||
patch("langchain_ollama.chat_models_v1.AsyncClient"),
|
||||
):
|
||||
llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.7)
|
||||
|
||||
messages: list[MessageV1] = [HumanMessageV1("Hello")]
|
||||
@ -185,6 +241,10 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
|
||||
|
||||
def test_ls_params(self) -> None:
|
||||
"""Test LangSmith parameters."""
|
||||
with (
|
||||
patch("langchain_ollama.chat_models_v1.Client"),
|
||||
patch("langchain_ollama.chat_models_v1.AsyncClient"),
|
||||
):
|
||||
llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.5)
|
||||
|
||||
ls_params = llm._get_ls_params()
|
||||
@ -196,6 +256,10 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
|
||||
|
||||
def test_bind_tools_basic(self) -> None:
|
||||
"""Test basic tool binding functionality."""
|
||||
with (
|
||||
patch("langchain_ollama.chat_models_v1.Client"),
|
||||
patch("langchain_ollama.chat_models_v1.AsyncClient"),
|
||||
):
|
||||
llm = ChatOllamaV1(model=MODEL_NAME)
|
||||
|
||||
def test_tool(query: str) -> str:
|
||||
@ -213,9 +277,15 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
|
||||
# But can be added if needed in the future.
|
||||
|
||||
|
||||
@patch("langchain_ollama.chat_models.validate_model")
|
||||
def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
@patch("langchain_ollama.chat_models_v1.validate_model")
|
||||
@patch("langchain_ollama.chat_models_v1.Client")
|
||||
def test_validate_model_on_init(
|
||||
mock_client_class: Any, mock_validate_model: Any
|
||||
) -> None:
|
||||
"""Test that the model is validated on initialization when requested."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Test that validate_model is called when validate_model_on_init=True
|
||||
ChatOllamaV1(model=MODEL_NAME, validate_model_on_init=True)
|
||||
mock_validate_model.assert_called_once()
|
||||
@ -314,16 +384,16 @@ def test_load_response_with_empty_content_is_skipped(
|
||||
}
|
||||
]
|
||||
|
||||
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
|
||||
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.chat.return_value = load_only_response
|
||||
mock_client.chat.return_value = iter(load_only_response)
|
||||
|
||||
llm = ChatOllamaV1(model="test-model")
|
||||
|
||||
with (
|
||||
caplog.at_level(logging.WARNING),
|
||||
pytest.raises(ValueError, match="No data received from Ollama stream"),
|
||||
pytest.raises(ValueError, match="No generations found in stream"),
|
||||
):
|
||||
llm.invoke([HumanMessageV1("Hello")])
|
||||
|
||||
@ -344,16 +414,16 @@ def test_load_response_with_whitespace_content_is_skipped(
|
||||
}
|
||||
]
|
||||
|
||||
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
|
||||
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.chat.return_value = load_whitespace_response
|
||||
mock_client.chat.return_value = iter(load_whitespace_response)
|
||||
|
||||
llm = ChatOllamaV1(model="test-model")
|
||||
|
||||
with (
|
||||
caplog.at_level(logging.WARNING),
|
||||
pytest.raises(ValueError, match="No data received from Ollama stream"),
|
||||
pytest.raises(ValueError, match="No generations found in stream"),
|
||||
):
|
||||
llm.invoke([HumanMessageV1("Hello")])
|
||||
assert "Ollama returned empty response with done_reason='load'" in caplog.text
|
||||
@ -383,10 +453,10 @@ def test_load_followed_by_content_response(
|
||||
},
|
||||
]
|
||||
|
||||
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
|
||||
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.chat.return_value = load_then_content_response
|
||||
mock_client.chat.return_value = iter(load_then_content_response)
|
||||
|
||||
llm = ChatOllamaV1(model="test-model")
|
||||
|
||||
@ -394,7 +464,8 @@ def test_load_followed_by_content_response(
|
||||
result = llm.invoke([HumanMessageV1("Hello")])
|
||||
|
||||
assert "Ollama returned empty response with done_reason='load'" in caplog.text
|
||||
assert result.content == "Hello! How can I help you today?"
|
||||
assert len(result.content) == 1
|
||||
assert result.text == "Hello! How can I help you today?"
|
||||
assert result.response_metadata.get("done_reason") == "stop"
|
||||
|
||||
|
||||
@ -412,16 +483,17 @@ def test_load_response_with_actual_content_is_not_skipped(
|
||||
}
|
||||
]
|
||||
|
||||
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
|
||||
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.chat.return_value = load_with_content_response
|
||||
mock_client.chat.return_value = iter(load_with_content_response)
|
||||
|
||||
llm = ChatOllamaV1(model="test-model")
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
result = llm.invoke([HumanMessageV1("Hello")])
|
||||
|
||||
assert result.content == "This is actual content"
|
||||
assert len(result.content) == 1
|
||||
assert result.text == "This is actual content"
|
||||
assert result.response_metadata.get("done_reason") == "load"
|
||||
assert not caplog.text
|
||||
|
@ -207,38 +207,6 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
||||
params = model._identifying_params
|
||||
assert isinstance(params, dict) # Should be dict-like mapping
|
||||
|
||||
# Token Counting Tests
|
||||
def test_get_token_ids(self, model: BaseChatModelV1) -> None:
|
||||
"""Test that ``get_token_ids`` returns a list of integers."""
|
||||
text = "Hello, world!"
|
||||
token_ids = model.get_token_ids(text)
|
||||
assert isinstance(token_ids, list)
|
||||
assert all(isinstance(token_id, int) for token_id in token_ids)
|
||||
assert len(token_ids) > 0
|
||||
|
||||
def test_get_num_tokens(self, model: BaseChatModelV1) -> None:
|
||||
"""Test that ``get_num_tokens`` returns a positive integer."""
|
||||
text = "Hello, world!"
|
||||
num_tokens = model.get_num_tokens(text)
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens > 0
|
||||
|
||||
def test_get_num_tokens_from_messages(self, model: BaseChatModelV1) -> None:
|
||||
"""Test that ``get_num_tokens_from_messages`` returns a positive integer."""
|
||||
messages = [HumanMessage("Hello, world!")]
|
||||
num_tokens = model.get_num_tokens_from_messages(messages) # type: ignore[arg-type]
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens > 0
|
||||
|
||||
def test_token_counting_consistency(self, model: BaseChatModelV1) -> None:
|
||||
"""Test that token counting methods are consistent with each other."""
|
||||
text = "Hello, world!"
|
||||
token_ids = model.get_token_ids(text)
|
||||
num_tokens = model.get_num_tokens(text)
|
||||
|
||||
# Number of tokens should match length of token IDs list
|
||||
assert len(token_ids) == num_tokens
|
||||
|
||||
# Serialization Tests
|
||||
def test_dump_serialization(self, model: BaseChatModelV1) -> None:
|
||||
"""Test that ``dump()`` returns proper serialization."""
|
||||
|
Loading…
Reference in New Issue
Block a user