refactor(ollama): clean up tests (#33198)

This commit is contained in:
Mason Daugherty
2025-10-01 21:52:01 -04:00
committed by GitHub
parent a89c549cb0
commit a9eda18e1e
11 changed files with 387 additions and 436 deletions

View File

@@ -22,6 +22,22 @@ from langchain_ollama.chat_models import (
MODEL_NAME = "llama3.1"
@contextmanager
def _mock_httpx_client_stream(
*args: Any, **kwargs: Any
) -> Generator[Response, Any, Any]:
yield Response(
status_code=200,
content='{"message": {"role": "assistant", "content": "The meaning ..."}}',
request=Request(method="POST", url="http://whocares:11434"),
)
dummy_raw_tool_call = {
"function": {"name": "test_func", "arguments": ""},
}
class TestChatOllama(ChatModelUnitTests):
@property
def chat_model_class(self) -> type[ChatOllama]:
@@ -35,19 +51,24 @@ class TestChatOllama(ChatModelUnitTests):
def test__parse_arguments_from_tool_call() -> None:
"""Test that string arguments are preserved as strings in tool call parsing.
This test verifies the fix for PR #30154 which addressed an issue where
string-typed tool arguments (like IDs or long strings) were being incorrectly
PR #30154
String-typed tool arguments (like IDs or long strings) were being incorrectly
processed. The parser should preserve string values as strings rather than
attempting to parse them as JSON when they're already valid string arguments.
The test uses a long string ID to ensure string arguments maintain their
original type after parsing, which is critical for tools expecting string inputs.
Use a long string ID to ensure string arguments maintain their original type after
parsing, which is critical for tools expecting string inputs.
"""
raw_response = '{"model":"sample-model","message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"get_profile_details","arguments":{"arg_1":"12345678901234567890123456"}}}]},"done":false}' # noqa: E501
raw_response = (
'{"model":"sample-model","message":{"role":"assistant","content":"",'
'"tool_calls":[{"function":{"name":"get_profile_details",'
'"arguments":{"arg_1":"12345678901234567890123456"}}}]},"done":false}'
)
raw_tool_calls = json.loads(raw_response)["message"]["tool_calls"]
response = _parse_arguments_from_tool_call(raw_tool_calls[0])
assert response is not None
assert isinstance(response["arg_1"], str)
assert response["arg_1"] == "12345678901234567890123456"
def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
@@ -57,7 +78,6 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
that just echoes the function name. This should be filtered out for
no-argument tools to return an empty dictionary.
"""
# Test case where arguments contain functionName metadata
raw_tool_call_with_metadata = {
"function": {
"name": "magic_function_no_args",
@@ -67,7 +87,7 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
response = _parse_arguments_from_tool_call(raw_tool_call_with_metadata)
assert response == {}
# Test case where arguments contain both real args and metadata
# Arguments contain both real args and metadata
raw_tool_call_mixed = {
"function": {
"name": "some_function",
@@ -77,7 +97,7 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
response_mixed = _parse_arguments_from_tool_call(raw_tool_call_mixed)
assert response_mixed == {"real_arg": "value"}
# Test case where functionName has different value (should be preserved)
# functionName has different value (should be preserved)
raw_tool_call_different = {
"function": {"name": "function_a", "arguments": {"functionName": "function_b"}}
}
@@ -85,17 +105,6 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
assert response_different == {"functionName": "function_b"}
@contextmanager
def _mock_httpx_client_stream(
*args: Any, **kwargs: Any
) -> Generator[Response, Any, Any]:
yield Response(
status_code=200,
content='{"message": {"role": "assistant", "content": "The meaning ..."}}',
request=Request(method="POST", url="http://whocares:11434"),
)
def test_arbitrary_roles_accepted_in_chatmessages(
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -120,26 +129,16 @@ def test_arbitrary_roles_accepted_in_chatmessages(
@patch("langchain_ollama.chat_models.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
"""Test that the model is validated on initialization when requested."""
# Test that validate_model is called when validate_model_on_init=True
ChatOllama(model=MODEL_NAME, validate_model_on_init=True)
mock_validate_model.assert_called_once()
mock_validate_model.reset_mock()
# Test that validate_model is NOT called when validate_model_on_init=False
ChatOllama(model=MODEL_NAME, validate_model_on_init=False)
mock_validate_model.assert_not_called()
# Test that validate_model is NOT called by default
ChatOllama(model=MODEL_NAME)
mock_validate_model.assert_not_called()
# Define a dummy raw_tool_call for the function signature
dummy_raw_tool_call = {
"function": {"name": "test_func", "arguments": ""},
}
@pytest.mark.parametrize(
("input_string", "expected_output"),
[
@@ -164,7 +163,7 @@ def test_parse_json_string_success_cases(
def test_parse_json_string_failure_case_raises_exception() -> None:
"""Tests that `_parse_json_string` raises an exception for malformed strings."""
malformed_string = "{'key': 'value',,}"
malformed_string = "{'key': 'value',,}" # Double comma is invalid
raw_tool_call = {"function": {"name": "test_func", "arguments": malformed_string}}
with pytest.raises(OutputParserException):
_parse_json_string(
@@ -181,7 +180,7 @@ def test_parse_json_string_skip_returns_input_on_failure() -> None:
result = _parse_json_string(
malformed_string,
raw_tool_call=raw_tool_call,
skip=True,
skip=True, # We want the original invalid string back
)
assert result == malformed_string

View File

@@ -16,16 +16,12 @@ def test_initialization() -> None:
@patch("langchain_ollama.embeddings.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
"""Test that the model is validated on initialization when requested."""
# Test that validate_model is called when validate_model_on_init=True
OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True)
mock_validate_model.assert_called_once()
mock_validate_model.reset_mock()
# Test that validate_model is NOT called when validate_model_on_init=False
OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=False)
mock_validate_model.assert_not_called()
# Test that validate_model is NOT called by default
OllamaEmbeddings(model=MODEL_NAME)
mock_validate_model.assert_not_called()
@@ -33,20 +29,13 @@ def test_validate_model_on_init(mock_validate_model: Any) -> None:
@patch("langchain_ollama.embeddings.Client")
def test_embed_documents_passes_options(mock_client_class: Any) -> None:
"""Test that `embed_documents()` passes options, including `num_gpu`."""
# Create a mock client instance
mock_client = Mock()
mock_client_class.return_value = mock_client
# Mock the embed method response
mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
# Create embeddings with num_gpu parameter
embeddings = OllamaEmbeddings(model=MODEL_NAME, num_gpu=4, temperature=0.5)
# Call embed_documents
result = embeddings.embed_documents(["test text"])
# Verify the result
assert result == [[0.1, 0.2, 0.3]]
# Check that embed was called with correct arguments

View File

@@ -14,7 +14,7 @@ def test_initialization() -> None:
def test_model_params() -> None:
# Test standard tracing params
"""Test standard tracing params"""
llm = OllamaLLM(model=MODEL_NAME)
ls_params = llm._get_ls_params()
assert ls_params == {
@@ -36,16 +36,12 @@ def test_model_params() -> None:
@patch("langchain_ollama.llms.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
"""Test that the model is validated on initialization when requested."""
# Test that validate_model is called when validate_model_on_init=True
OllamaLLM(model=MODEL_NAME, validate_model_on_init=True)
mock_validate_model.assert_called_once()
mock_validate_model.reset_mock()
# Test that validate_model is NOT called when validate_model_on_init=False
OllamaLLM(model=MODEL_NAME, validate_model_on_init=False)
mock_validate_model.assert_not_called()
# Test that validate_model is NOT called by default
OllamaLLM(model=MODEL_NAME)
mock_validate_model.assert_not_called()