mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 07:36:08 +00:00
more tests
This commit is contained in:
parent
c91d681c83
commit
db41bfe1f8
@ -621,7 +621,7 @@ class ChatOllamaV1(BaseChatModelV1):
|
|||||||
chat_params = self._chat_params(messages, stop, **kwargs)
|
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||||
|
|
||||||
if chat_params["stream"]:
|
if chat_params["stream"]:
|
||||||
async for part in await self._async_client.chat(**chat_params):
|
async for part in self._async_client.chat(**chat_params): # type: ignore[attr-defined]
|
||||||
if not isinstance(part, str):
|
if not isinstance(part, str):
|
||||||
# Skip empty load responses
|
# Skip empty load responses
|
||||||
if (
|
if (
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
@ -160,52 +161,115 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
|
|||||||
def chat_model_params(self) -> dict:
|
def chat_model_params(self) -> dict:
|
||||||
return {"model": MODEL_NAME}
|
return {"model": MODEL_NAME}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_non_standard_blocks(self) -> bool:
|
||||||
|
"""Override to indicate Ollama doesn't support non-standard content blocks."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model(self) -> Iterator[ChatOllamaV1]:
|
||||||
|
"""Create a ChatOllamaV1 instance for testing."""
|
||||||
|
sync_patcher = patch("langchain_ollama.chat_models_v1.Client")
|
||||||
|
async_patcher = patch("langchain_ollama.chat_models_v1.AsyncClient")
|
||||||
|
|
||||||
|
mock_sync_client_class = sync_patcher.start()
|
||||||
|
mock_async_client_class = async_patcher.start()
|
||||||
|
|
||||||
|
mock_sync_client = MagicMock()
|
||||||
|
mock_async_client = MagicMock()
|
||||||
|
|
||||||
|
mock_sync_client_class.return_value = mock_sync_client
|
||||||
|
mock_async_client_class.return_value = mock_async_client
|
||||||
|
|
||||||
|
def mock_chat_response(*_args: Any, **_kwargs: Any) -> Iterator[dict[str, Any]]:
|
||||||
|
return iter(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"model": MODEL_NAME,
|
||||||
|
"created_at": "2024-01-01T00:00:00Z",
|
||||||
|
"message": {"role": "assistant", "content": "Test response"},
|
||||||
|
"done": True,
|
||||||
|
"done_reason": "stop",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_async_chat_iterator(
|
||||||
|
*_args: Any, **_kwargs: Any
|
||||||
|
) -> AsyncIterator[dict[str, Any]]:
|
||||||
|
for item in mock_chat_response(*_args, **_kwargs):
|
||||||
|
yield item
|
||||||
|
|
||||||
|
mock_sync_client.chat.side_effect = mock_chat_response
|
||||||
|
mock_async_client.chat.side_effect = mock_async_chat_iterator
|
||||||
|
|
||||||
|
model_instance = self.chat_model_class(**self.chat_model_params)
|
||||||
|
yield model_instance
|
||||||
|
sync_patcher.stop()
|
||||||
|
async_patcher.stop()
|
||||||
|
|
||||||
def test_initialization(self) -> None:
|
def test_initialization(self) -> None:
|
||||||
"""Test `ChatOllamaV1` initialization."""
|
"""Test `ChatOllamaV1` initialization."""
|
||||||
llm = ChatOllamaV1(model=MODEL_NAME)
|
with (
|
||||||
|
patch("langchain_ollama.chat_models_v1.Client"),
|
||||||
|
patch("langchain_ollama.chat_models_v1.AsyncClient"),
|
||||||
|
):
|
||||||
|
llm = ChatOllamaV1(model=MODEL_NAME)
|
||||||
|
|
||||||
assert llm.model == MODEL_NAME
|
assert llm.model == MODEL_NAME
|
||||||
assert llm._llm_type == "chat-ollama-v1"
|
assert llm._llm_type == "chat-ollama-v1"
|
||||||
|
|
||||||
def test_chat_params(self) -> None:
|
def test_chat_params(self) -> None:
|
||||||
"""Test `_chat_params()`."""
|
"""Test `_chat_params()`."""
|
||||||
llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.7)
|
with (
|
||||||
|
patch("langchain_ollama.chat_models_v1.Client"),
|
||||||
|
patch("langchain_ollama.chat_models_v1.AsyncClient"),
|
||||||
|
):
|
||||||
|
llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.7)
|
||||||
|
|
||||||
messages: list[MessageV1] = [HumanMessageV1("Hello")]
|
messages: list[MessageV1] = [HumanMessageV1("Hello")]
|
||||||
|
|
||||||
params = llm._chat_params(messages)
|
params = llm._chat_params(messages)
|
||||||
|
|
||||||
assert params["model"] == MODEL_NAME
|
assert params["model"] == MODEL_NAME
|
||||||
assert len(params["messages"]) == 1
|
assert len(params["messages"]) == 1
|
||||||
assert params["messages"][0]["role"] == "user"
|
assert params["messages"][0]["role"] == "user"
|
||||||
assert params["messages"][0]["content"] == "Hello"
|
assert params["messages"][0]["content"] == "Hello"
|
||||||
|
|
||||||
# Ensure options carry over
|
# Ensure options carry over
|
||||||
assert params["options"].temperature == 0.7
|
assert params["options"].temperature == 0.7
|
||||||
|
|
||||||
def test_ls_params(self) -> None:
|
def test_ls_params(self) -> None:
|
||||||
"""Test LangSmith parameters."""
|
"""Test LangSmith parameters."""
|
||||||
llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.5)
|
with (
|
||||||
|
patch("langchain_ollama.chat_models_v1.Client"),
|
||||||
|
patch("langchain_ollama.chat_models_v1.AsyncClient"),
|
||||||
|
):
|
||||||
|
llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.5)
|
||||||
|
|
||||||
ls_params = llm._get_ls_params()
|
ls_params = llm._get_ls_params()
|
||||||
|
|
||||||
assert ls_params.get("ls_provider") == "ollama"
|
assert ls_params.get("ls_provider") == "ollama"
|
||||||
assert ls_params.get("ls_model_name") == MODEL_NAME
|
assert ls_params.get("ls_model_name") == MODEL_NAME
|
||||||
assert ls_params.get("ls_model_type") == "chat"
|
assert ls_params.get("ls_model_type") == "chat"
|
||||||
assert ls_params.get("ls_temperature") == 0.5
|
assert ls_params.get("ls_temperature") == 0.5
|
||||||
|
|
||||||
def test_bind_tools_basic(self) -> None:
|
def test_bind_tools_basic(self) -> None:
|
||||||
"""Test basic tool binding functionality."""
|
"""Test basic tool binding functionality."""
|
||||||
llm = ChatOllamaV1(model=MODEL_NAME)
|
with (
|
||||||
|
patch("langchain_ollama.chat_models_v1.Client"),
|
||||||
|
patch("langchain_ollama.chat_models_v1.AsyncClient"),
|
||||||
|
):
|
||||||
|
llm = ChatOllamaV1(model=MODEL_NAME)
|
||||||
|
|
||||||
def test_tool(query: str) -> str:
|
def test_tool(query: str) -> str:
|
||||||
"""A test tool."""
|
"""A test tool."""
|
||||||
return f"Result for: {query}"
|
return f"Result for: {query}"
|
||||||
|
|
||||||
bound_llm = llm.bind_tools([test_tool])
|
bound_llm = llm.bind_tools([test_tool])
|
||||||
|
|
||||||
# Should return a bound model
|
# Should return a bound model
|
||||||
assert bound_llm is not None
|
assert bound_llm is not None
|
||||||
|
|
||||||
|
|
||||||
# Missing: `test_arbitrary_roles_accepted_in_chatmessages`
|
# Missing: `test_arbitrary_roles_accepted_in_chatmessages`
|
||||||
@ -213,9 +277,15 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
|
|||||||
# But can be added if needed in the future.
|
# But can be added if needed in the future.
|
||||||
|
|
||||||
|
|
||||||
@patch("langchain_ollama.chat_models.validate_model")
|
@patch("langchain_ollama.chat_models_v1.validate_model")
|
||||||
def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
@patch("langchain_ollama.chat_models_v1.Client")
|
||||||
|
def test_validate_model_on_init(
|
||||||
|
mock_client_class: Any, mock_validate_model: Any
|
||||||
|
) -> None:
|
||||||
"""Test that the model is validated on initialization when requested."""
|
"""Test that the model is validated on initialization when requested."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
# Test that validate_model is called when validate_model_on_init=True
|
# Test that validate_model is called when validate_model_on_init=True
|
||||||
ChatOllamaV1(model=MODEL_NAME, validate_model_on_init=True)
|
ChatOllamaV1(model=MODEL_NAME, validate_model_on_init=True)
|
||||||
mock_validate_model.assert_called_once()
|
mock_validate_model.assert_called_once()
|
||||||
@ -314,16 +384,16 @@ def test_load_response_with_empty_content_is_skipped(
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
|
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_client_class.return_value = mock_client
|
mock_client_class.return_value = mock_client
|
||||||
mock_client.chat.return_value = load_only_response
|
mock_client.chat.return_value = iter(load_only_response)
|
||||||
|
|
||||||
llm = ChatOllamaV1(model="test-model")
|
llm = ChatOllamaV1(model="test-model")
|
||||||
|
|
||||||
with (
|
with (
|
||||||
caplog.at_level(logging.WARNING),
|
caplog.at_level(logging.WARNING),
|
||||||
pytest.raises(ValueError, match="No data received from Ollama stream"),
|
pytest.raises(ValueError, match="No generations found in stream"),
|
||||||
):
|
):
|
||||||
llm.invoke([HumanMessageV1("Hello")])
|
llm.invoke([HumanMessageV1("Hello")])
|
||||||
|
|
||||||
@ -344,16 +414,16 @@ def test_load_response_with_whitespace_content_is_skipped(
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
|
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_client_class.return_value = mock_client
|
mock_client_class.return_value = mock_client
|
||||||
mock_client.chat.return_value = load_whitespace_response
|
mock_client.chat.return_value = iter(load_whitespace_response)
|
||||||
|
|
||||||
llm = ChatOllamaV1(model="test-model")
|
llm = ChatOllamaV1(model="test-model")
|
||||||
|
|
||||||
with (
|
with (
|
||||||
caplog.at_level(logging.WARNING),
|
caplog.at_level(logging.WARNING),
|
||||||
pytest.raises(ValueError, match="No data received from Ollama stream"),
|
pytest.raises(ValueError, match="No generations found in stream"),
|
||||||
):
|
):
|
||||||
llm.invoke([HumanMessageV1("Hello")])
|
llm.invoke([HumanMessageV1("Hello")])
|
||||||
assert "Ollama returned empty response with done_reason='load'" in caplog.text
|
assert "Ollama returned empty response with done_reason='load'" in caplog.text
|
||||||
@ -383,10 +453,10 @@ def test_load_followed_by_content_response(
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
|
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_client_class.return_value = mock_client
|
mock_client_class.return_value = mock_client
|
||||||
mock_client.chat.return_value = load_then_content_response
|
mock_client.chat.return_value = iter(load_then_content_response)
|
||||||
|
|
||||||
llm = ChatOllamaV1(model="test-model")
|
llm = ChatOllamaV1(model="test-model")
|
||||||
|
|
||||||
@ -394,7 +464,8 @@ def test_load_followed_by_content_response(
|
|||||||
result = llm.invoke([HumanMessageV1("Hello")])
|
result = llm.invoke([HumanMessageV1("Hello")])
|
||||||
|
|
||||||
assert "Ollama returned empty response with done_reason='load'" in caplog.text
|
assert "Ollama returned empty response with done_reason='load'" in caplog.text
|
||||||
assert result.content == "Hello! How can I help you today?"
|
assert len(result.content) == 1
|
||||||
|
assert result.text == "Hello! How can I help you today?"
|
||||||
assert result.response_metadata.get("done_reason") == "stop"
|
assert result.response_metadata.get("done_reason") == "stop"
|
||||||
|
|
||||||
|
|
||||||
@ -412,16 +483,17 @@ def test_load_response_with_actual_content_is_not_skipped(
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
|
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_client_class.return_value = mock_client
|
mock_client_class.return_value = mock_client
|
||||||
mock_client.chat.return_value = load_with_content_response
|
mock_client.chat.return_value = iter(load_with_content_response)
|
||||||
|
|
||||||
llm = ChatOllamaV1(model="test-model")
|
llm = ChatOllamaV1(model="test-model")
|
||||||
|
|
||||||
with caplog.at_level(logging.WARNING):
|
with caplog.at_level(logging.WARNING):
|
||||||
result = llm.invoke([HumanMessageV1("Hello")])
|
result = llm.invoke([HumanMessageV1("Hello")])
|
||||||
|
|
||||||
assert result.content == "This is actual content"
|
assert len(result.content) == 1
|
||||||
|
assert result.text == "This is actual content"
|
||||||
assert result.response_metadata.get("done_reason") == "load"
|
assert result.response_metadata.get("done_reason") == "load"
|
||||||
assert not caplog.text
|
assert not caplog.text
|
||||||
|
@ -207,38 +207,6 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
|||||||
params = model._identifying_params
|
params = model._identifying_params
|
||||||
assert isinstance(params, dict) # Should be dict-like mapping
|
assert isinstance(params, dict) # Should be dict-like mapping
|
||||||
|
|
||||||
# Token Counting Tests
|
|
||||||
def test_get_token_ids(self, model: BaseChatModelV1) -> None:
|
|
||||||
"""Test that ``get_token_ids`` returns a list of integers."""
|
|
||||||
text = "Hello, world!"
|
|
||||||
token_ids = model.get_token_ids(text)
|
|
||||||
assert isinstance(token_ids, list)
|
|
||||||
assert all(isinstance(token_id, int) for token_id in token_ids)
|
|
||||||
assert len(token_ids) > 0
|
|
||||||
|
|
||||||
def test_get_num_tokens(self, model: BaseChatModelV1) -> None:
|
|
||||||
"""Test that ``get_num_tokens`` returns a positive integer."""
|
|
||||||
text = "Hello, world!"
|
|
||||||
num_tokens = model.get_num_tokens(text)
|
|
||||||
assert isinstance(num_tokens, int)
|
|
||||||
assert num_tokens > 0
|
|
||||||
|
|
||||||
def test_get_num_tokens_from_messages(self, model: BaseChatModelV1) -> None:
|
|
||||||
"""Test that ``get_num_tokens_from_messages`` returns a positive integer."""
|
|
||||||
messages = [HumanMessage("Hello, world!")]
|
|
||||||
num_tokens = model.get_num_tokens_from_messages(messages) # type: ignore[arg-type]
|
|
||||||
assert isinstance(num_tokens, int)
|
|
||||||
assert num_tokens > 0
|
|
||||||
|
|
||||||
def test_token_counting_consistency(self, model: BaseChatModelV1) -> None:
|
|
||||||
"""Test that token counting methods are consistent with each other."""
|
|
||||||
text = "Hello, world!"
|
|
||||||
token_ids = model.get_token_ids(text)
|
|
||||||
num_tokens = model.get_num_tokens(text)
|
|
||||||
|
|
||||||
# Number of tokens should match length of token IDs list
|
|
||||||
assert len(token_ids) == num_tokens
|
|
||||||
|
|
||||||
# Serialization Tests
|
# Serialization Tests
|
||||||
def test_dump_serialization(self, model: BaseChatModelV1) -> None:
|
def test_dump_serialization(self, model: BaseChatModelV1) -> None:
|
||||||
"""Test that ``dump()`` returns proper serialization."""
|
"""Test that ``dump()`` returns proper serialization."""
|
||||||
|
Loading…
Reference in New Issue
Block a user