mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
refactor(ollama): clean up tests (#33198)
This commit is contained in:
@@ -22,6 +22,22 @@ from langchain_ollama.chat_models import (
|
||||
MODEL_NAME = "llama3.1"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _mock_httpx_client_stream(
|
||||
*args: Any, **kwargs: Any
|
||||
) -> Generator[Response, Any, Any]:
|
||||
yield Response(
|
||||
status_code=200,
|
||||
content='{"message": {"role": "assistant", "content": "The meaning ..."}}',
|
||||
request=Request(method="POST", url="http://whocares:11434"),
|
||||
)
|
||||
|
||||
|
||||
dummy_raw_tool_call = {
|
||||
"function": {"name": "test_func", "arguments": ""},
|
||||
}
|
||||
|
||||
|
||||
class TestChatOllama(ChatModelUnitTests):
|
||||
@property
|
||||
def chat_model_class(self) -> type[ChatOllama]:
|
||||
@@ -35,19 +51,24 @@ class TestChatOllama(ChatModelUnitTests):
|
||||
def test__parse_arguments_from_tool_call() -> None:
|
||||
"""Test that string arguments are preserved as strings in tool call parsing.
|
||||
|
||||
This test verifies the fix for PR #30154 which addressed an issue where
|
||||
string-typed tool arguments (like IDs or long strings) were being incorrectly
|
||||
PR #30154
|
||||
String-typed tool arguments (like IDs or long strings) were being incorrectly
|
||||
processed. The parser should preserve string values as strings rather than
|
||||
attempting to parse them as JSON when they're already valid string arguments.
|
||||
|
||||
The test uses a long string ID to ensure string arguments maintain their
|
||||
original type after parsing, which is critical for tools expecting string inputs.
|
||||
Use a long string ID to ensure string arguments maintain their original type after
|
||||
parsing, which is critical for tools expecting string inputs.
|
||||
"""
|
||||
raw_response = '{"model":"sample-model","message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"get_profile_details","arguments":{"arg_1":"12345678901234567890123456"}}}]},"done":false}' # noqa: E501
|
||||
raw_response = (
|
||||
'{"model":"sample-model","message":{"role":"assistant","content":"",'
|
||||
'"tool_calls":[{"function":{"name":"get_profile_details",'
|
||||
'"arguments":{"arg_1":"12345678901234567890123456"}}}]},"done":false}'
|
||||
)
|
||||
raw_tool_calls = json.loads(raw_response)["message"]["tool_calls"]
|
||||
response = _parse_arguments_from_tool_call(raw_tool_calls[0])
|
||||
assert response is not None
|
||||
assert isinstance(response["arg_1"], str)
|
||||
assert response["arg_1"] == "12345678901234567890123456"
|
||||
|
||||
|
||||
def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
|
||||
@@ -57,7 +78,6 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
|
||||
that just echoes the function name. This should be filtered out for
|
||||
no-argument tools to return an empty dictionary.
|
||||
"""
|
||||
# Test case where arguments contain functionName metadata
|
||||
raw_tool_call_with_metadata = {
|
||||
"function": {
|
||||
"name": "magic_function_no_args",
|
||||
@@ -67,7 +87,7 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
|
||||
response = _parse_arguments_from_tool_call(raw_tool_call_with_metadata)
|
||||
assert response == {}
|
||||
|
||||
# Test case where arguments contain both real args and metadata
|
||||
# Arguments contain both real args and metadata
|
||||
raw_tool_call_mixed = {
|
||||
"function": {
|
||||
"name": "some_function",
|
||||
@@ -77,7 +97,7 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
|
||||
response_mixed = _parse_arguments_from_tool_call(raw_tool_call_mixed)
|
||||
assert response_mixed == {"real_arg": "value"}
|
||||
|
||||
# Test case where functionName has different value (should be preserved)
|
||||
# functionName has different value (should be preserved)
|
||||
raw_tool_call_different = {
|
||||
"function": {"name": "function_a", "arguments": {"functionName": "function_b"}}
|
||||
}
|
||||
@@ -85,17 +105,6 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
|
||||
assert response_different == {"functionName": "function_b"}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _mock_httpx_client_stream(
|
||||
*args: Any, **kwargs: Any
|
||||
) -> Generator[Response, Any, Any]:
|
||||
yield Response(
|
||||
status_code=200,
|
||||
content='{"message": {"role": "assistant", "content": "The meaning ..."}}',
|
||||
request=Request(method="POST", url="http://whocares:11434"),
|
||||
)
|
||||
|
||||
|
||||
def test_arbitrary_roles_accepted_in_chatmessages(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -120,26 +129,16 @@ def test_arbitrary_roles_accepted_in_chatmessages(
|
||||
@patch("langchain_ollama.chat_models.validate_model")
|
||||
def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
"""Test that the model is validated on initialization when requested."""
|
||||
# Test that validate_model is called when validate_model_on_init=True
|
||||
ChatOllama(model=MODEL_NAME, validate_model_on_init=True)
|
||||
mock_validate_model.assert_called_once()
|
||||
mock_validate_model.reset_mock()
|
||||
|
||||
# Test that validate_model is NOT called when validate_model_on_init=False
|
||||
ChatOllama(model=MODEL_NAME, validate_model_on_init=False)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
# Test that validate_model is NOT called by default
|
||||
ChatOllama(model=MODEL_NAME)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
|
||||
# Define a dummy raw_tool_call for the function signature
|
||||
dummy_raw_tool_call = {
|
||||
"function": {"name": "test_func", "arguments": ""},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("input_string", "expected_output"),
|
||||
[
|
||||
@@ -164,7 +163,7 @@ def test_parse_json_string_success_cases(
|
||||
|
||||
def test_parse_json_string_failure_case_raises_exception() -> None:
|
||||
"""Tests that `_parse_json_string` raises an exception for malformed strings."""
|
||||
malformed_string = "{'key': 'value',,}"
|
||||
malformed_string = "{'key': 'value',,}" # Double comma is invalid
|
||||
raw_tool_call = {"function": {"name": "test_func", "arguments": malformed_string}}
|
||||
with pytest.raises(OutputParserException):
|
||||
_parse_json_string(
|
||||
@@ -181,7 +180,7 @@ def test_parse_json_string_skip_returns_input_on_failure() -> None:
|
||||
result = _parse_json_string(
|
||||
malformed_string,
|
||||
raw_tool_call=raw_tool_call,
|
||||
skip=True,
|
||||
skip=True, # We want the original invalid string back
|
||||
)
|
||||
assert result == malformed_string
|
||||
|
||||
|
||||
@@ -16,16 +16,12 @@ def test_initialization() -> None:
|
||||
@patch("langchain_ollama.embeddings.validate_model")
|
||||
def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
"""Test that the model is validated on initialization when requested."""
|
||||
# Test that validate_model is called when validate_model_on_init=True
|
||||
OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True)
|
||||
mock_validate_model.assert_called_once()
|
||||
mock_validate_model.reset_mock()
|
||||
|
||||
# Test that validate_model is NOT called when validate_model_on_init=False
|
||||
OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=False)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
# Test that validate_model is NOT called by default
|
||||
OllamaEmbeddings(model=MODEL_NAME)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
@@ -33,20 +29,13 @@ def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
@patch("langchain_ollama.embeddings.Client")
|
||||
def test_embed_documents_passes_options(mock_client_class: Any) -> None:
|
||||
"""Test that `embed_documents()` passes options, including `num_gpu`."""
|
||||
# Create a mock client instance
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock the embed method response
|
||||
mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
|
||||
|
||||
# Create embeddings with num_gpu parameter
|
||||
embeddings = OllamaEmbeddings(model=MODEL_NAME, num_gpu=4, temperature=0.5)
|
||||
|
||||
# Call embed_documents
|
||||
result = embeddings.embed_documents(["test text"])
|
||||
|
||||
# Verify the result
|
||||
assert result == [[0.1, 0.2, 0.3]]
|
||||
|
||||
# Check that embed was called with correct arguments
|
||||
|
||||
@@ -14,7 +14,7 @@ def test_initialization() -> None:
|
||||
|
||||
|
||||
def test_model_params() -> None:
|
||||
# Test standard tracing params
|
||||
"""Test standard tracing params"""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
ls_params = llm._get_ls_params()
|
||||
assert ls_params == {
|
||||
@@ -36,16 +36,12 @@ def test_model_params() -> None:
|
||||
@patch("langchain_ollama.llms.validate_model")
|
||||
def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
"""Test that the model is validated on initialization when requested."""
|
||||
# Test that validate_model is called when validate_model_on_init=True
|
||||
OllamaLLM(model=MODEL_NAME, validate_model_on_init=True)
|
||||
mock_validate_model.assert_called_once()
|
||||
mock_validate_model.reset_mock()
|
||||
|
||||
# Test that validate_model is NOT called when validate_model_on_init=False
|
||||
OllamaLLM(model=MODEL_NAME, validate_model_on_init=False)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
# Test that validate_model is NOT called by default
|
||||
OllamaLLM(model=MODEL_NAME)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user