refactor(ollama): clean up tests (#33198)

This commit is contained in:
Mason Daugherty
2025-10-01 21:52:01 -04:00
committed by GitHub
parent a89c549cb0
commit a9eda18e1e
11 changed files with 387 additions and 436 deletions

View File

@@ -14,14 +14,20 @@ service.
"""
from importlib import metadata
from importlib.metadata import PackageNotFoundError
from langchain_ollama.chat_models import ChatOllama
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_ollama.llms import OllamaLLM
def _raise_package_not_found_error() -> None:
raise PackageNotFoundError
try:
if __package__ is None:
raise metadata.PackageNotFoundError
_raise_package_not_found_error()
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.

View File

@@ -132,7 +132,6 @@ def _parse_arguments_from_tool_call(
Should be removed/changed if fixed upstream.
See https://github.com/ollama/ollama/issues/6155
"""
if "function" not in raw_tool_call:
return None
@@ -834,7 +833,7 @@ class ChatOllama(BaseChatModel):
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False, # noqa: FBT001, FBT002
verbose: bool = False, # noqa: FBT002
**kwargs: Any,
) -> ChatGenerationChunk:
final_chunk = None
@@ -860,7 +859,7 @@ class ChatOllama(BaseChatModel):
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False, # noqa: FBT001, FBT002
verbose: bool = False, # noqa: FBT002
**kwargs: Any,
) -> ChatGenerationChunk:
final_chunk = None

View File

@@ -368,7 +368,7 @@ class OllamaLLM(BaseLLM):
prompt: str,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False, # noqa: FBT001, FBT002
verbose: bool = False, # noqa: FBT002
**kwargs: Any,
) -> GenerationChunk:
final_chunk = None
@@ -410,7 +410,7 @@ class OllamaLLM(BaseLLM):
prompt: str,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False, # noqa: FBT001, FBT002
verbose: bool = False, # noqa: FBT002
**kwargs: Any,
) -> GenerationChunk:
final_chunk = None

View File

@@ -60,15 +60,17 @@ ignore = [
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
"UP045", # pyupgrade: non-pep604-annotation-optional
"PLR0912",
"C901",
"PLR0915",
"FIX002", # TODOs
"TD002", # TODO authors
"TC002", # Incorrect type-checking block
"TC003", # Incorrect type-checking block
"PLR0912", # Too many branches
"PLR0915", # Too many statements
"C901", # Function too complex
"FBT001", # Boolean function param
# TODO:
# TODO
"ANN401",
"TC002",
"TC003",
"TRY301",
]
unfixable = ["B028"] # People should intentionally tune the stacklevel
@@ -91,16 +93,11 @@ asyncio_mode = "auto"
[tool.ruff.lint.extend-per-file-ignores]
"tests/**/*.py" = [
"S101", # Tests need assertions
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"PLR2004",
# TODO
"ANN401",
"ARG001",
"PT011",
"FIX",
"TD",
"S101", # Tests need assertions
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"ARG001", # Unused function arguments in tests (e.g. kwargs)
"PLR2004", # Magic value in comparisons
"PT011", # `pytest.raises()` is too broad
]
"scripts/*.py" = [
"INP001", # Not a package

View File

@@ -3,9 +3,16 @@
from __future__ import annotations
from typing import Annotated, Optional
from unittest.mock import MagicMock, patch
import pytest
from pydantic import BaseModel, Field
from httpx import ConnectError
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.tools import tool
from ollama import ResponseError
from pydantic import BaseModel, Field, ValidationError
from typing_extensions import TypedDict
from langchain_ollama import ChatOllama
@@ -13,6 +20,43 @@ from langchain_ollama import ChatOllama
DEFAULT_MODEL_NAME = "llama3.1"
@tool
def get_current_weather(location: str) -> dict:
    """Gets the current weather in a given location."""
    # Stubbed lookup used by the tool-calling tests: only Boston gets a
    # canned answer; every other location is reported as unknown.
    normalized = location.lower()
    if "boston" in normalized:
        return {"temperature": "15°F", "conditions": "snow"}
    return {"temperature": "unknown", "conditions": "unknown"}
@patch("langchain_ollama.chat_models.Client.list")
def test_init_model_not_found(mock_list: MagicMock) -> None:
    """Test that a ValueError is raised when the model is not found."""
    mock_list.side_effect = ValueError("Test model not found")
    # The model-validation error message must propagate to the caller.
    with pytest.raises(ValueError, match="Test model not found"):
        ChatOllama(model="non-existent-model", validate_model_on_init=True)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_connection_error(mock_list: MagicMock) -> None:
    """Test that a ValidationError is raised on connect failure during init."""
    mock_list.side_effect = ConnectError("Test connection error")
    # Connection failures surface through pydantic validation of the model.
    with pytest.raises(ValidationError, match="Failed to connect to Ollama"):
        ChatOllama(model="any-model", validate_model_on_init=True)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_response_error(mock_list: MagicMock) -> None:
    """Test that an Ollama API error surfaces as a ``ValidationError``.

    The underlying client raises ``ResponseError``, but model validation
    wraps it, so callers observe a pydantic ``ValidationError`` carrying
    the API error message (note: the test asserts ``ValidationError``,
    not ``ResponseError`` — the previous docstring was inaccurate).
    """
    mock_list.side_effect = ResponseError("Test response error")
    with pytest.raises(ValidationError, match="Received an error from the Ollama API"):
        ChatOllama(model="any-model", validate_model_on_init=True)
@pytest.mark.parametrize(("method"), [("function_calling"), ("json_schema")])
def test_structured_output(method: str) -> None:
"""Test to verify structured output via tool calling and `format` parameter."""
@@ -98,3 +142,97 @@ def test_structured_output_deeply_nested(model: str) -> None:
for chunk in chat.stream(text):
assert isinstance(chunk, Data)
@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)])
def test_tool_streaming(model: str) -> None:
    """Test that the model can stream tool calls."""
    llm = ChatOllama(model=model)
    chat_model_with_tools = llm.bind_tools([get_current_weather])
    prompt = [HumanMessage("What is the weather today in Boston?")]

    collected_tool_chunks: list[ToolCallChunk] = []
    final_tool_calls = []
    # Stream the response; every chunk must be an AIMessageChunk, and we
    # accumulate both partial tool-call chunks and fully-formed tool calls.
    for chunk in chat_model_with_tools.stream(prompt):
        assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
        collected_tool_chunks.extend(chunk.tool_call_chunks)
        final_tool_calls.extend(chunk.tool_calls)

    assert collected_tool_chunks, (
        "Tool streaming did not produce any tool_call_chunks."
    )
    assert len(final_tool_calls) == 1, (
        f"Expected 1 final tool call, but got {len(final_tool_calls)}"
    )

    final_tool_call = final_tool_calls[0]
    assert final_tool_call["name"] == "get_current_weather"
    assert final_tool_call["args"] == {"location": "Boston"}

    assert collected_tool_chunks[0]["name"] == "get_current_weather"
    # The tool-call ID must be consistent across every chunk that carries one.
    tool_call_id = collected_tool_chunks[0].get("id")
    assert tool_call_id is not None
    assert all(
        c.get("id") == tool_call_id for c in collected_tool_chunks if c.get("id")
    )
    assert final_tool_call["id"] == tool_call_id
@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)])
async def test_tool_astreaming(model: str) -> None:
    """Test that the model can stream tool calls."""
    llm = ChatOllama(model=model)
    chat_model_with_tools = llm.bind_tools([get_current_weather])
    prompt = [HumanMessage("What is the weather today in Boston?")]

    collected_tool_chunks: list[ToolCallChunk] = []
    final_tool_calls = []
    # Stream asynchronously; every chunk must be an AIMessageChunk, and we
    # accumulate both partial tool-call chunks and fully-formed tool calls.
    async for chunk in chat_model_with_tools.astream(prompt):
        assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
        collected_tool_chunks.extend(chunk.tool_call_chunks)
        final_tool_calls.extend(chunk.tool_calls)

    assert collected_tool_chunks, (
        "Tool streaming did not produce any tool_call_chunks."
    )
    assert len(final_tool_calls) == 1, (
        f"Expected 1 final tool call, but got {len(final_tool_calls)}"
    )

    final_tool_call = final_tool_calls[0]
    assert final_tool_call["name"] == "get_current_weather"
    assert final_tool_call["args"] == {"location": "Boston"}

    assert collected_tool_chunks[0]["name"] == "get_current_weather"
    # The tool-call ID must be consistent across every chunk that carries one.
    tool_call_id = collected_tool_chunks[0].get("id")
    assert tool_call_id is not None
    assert all(
        c.get("id") == tool_call_id for c in collected_tool_chunks if c.get("id")
    )
    assert final_tool_call["id"] == tool_call_id

View File

@@ -8,9 +8,10 @@ from langchain_ollama import ChatOllama
SAMPLE = "What is 3^3?"
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_no_reasoning(model: str) -> None:
"""Test streaming with `reasoning=False`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_stream_no_reasoning(model: str, use_async: bool) -> None:
"""Test streaming with ``reasoning=False``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
messages = [
{
@@ -19,12 +20,20 @@ def test_stream_no_reasoning(model: str) -> None:
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
if use_async:
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
else:
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" not in result.content
@@ -32,33 +41,10 @@ def test_stream_no_reasoning(model: str) -> None:
assert "reasoning_content" not in result.additional_kwargs
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_astream_no_reasoning(model: str) -> None:
    """Test async streaming with `reasoning=False`"""
    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
    messages = [
        {
            "role": "user",
            "content": SAMPLE,
        }
    ]
    merged = None
    # Accumulate every streamed chunk into a single message chunk.
    async for piece in llm.astream(messages):
        assert isinstance(piece, BaseMessageChunk)
        merged = piece if merged is None else merged + piece
    assert isinstance(merged, AIMessageChunk)
    assert merged.content
    # With reasoning disabled there must be no think tags and no
    # separate reasoning_content entry.
    assert "<think>" not in merged.content
    assert "</think>" not in merged.content
    assert "reasoning_content" not in merged.additional_kwargs
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_reasoning_none(model: str) -> None:
"""Test streaming with `reasoning=None`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_stream_reasoning_none(model: str, use_async: bool) -> None:
"""Test streaming with ``reasoning=None``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
messages = [
{
@@ -67,12 +53,20 @@ def test_stream_reasoning_none(model: str) -> None:
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
if use_async:
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
else:
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" in result.content
@@ -82,35 +76,10 @@ def test_stream_reasoning_none(model: str) -> None:
assert "</think>" not in result.additional_kwargs.get("reasoning_content", "")
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_astream_reasoning_none(model: str) -> None:
    """Test async streaming with `reasoning=None`"""
    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
    messages = [
        {
            "role": "user",
            "content": SAMPLE,
        }
    ]
    merged = None
    # Accumulate every streamed chunk into a single message chunk.
    async for piece in llm.astream(messages):
        assert isinstance(piece, BaseMessageChunk)
        merged = piece if merged is None else merged + piece
    assert isinstance(merged, AIMessageChunk)
    assert merged.content
    # With reasoning unset, think tags stay inline in the content and no
    # reasoning_content is extracted.
    assert "<think>" in merged.content
    assert "</think>" in merged.content
    assert "reasoning_content" not in merged.additional_kwargs
    assert "<think>" not in merged.additional_kwargs.get("reasoning_content", "")
    assert "</think>" not in merged.additional_kwargs.get("reasoning_content", "")
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_stream(model: str) -> None:
"""Test streaming with `reasoning=True`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_reasoning_stream(model: str, use_async: bool) -> None:
"""Test streaming with ``reasoning=True``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
@@ -119,12 +88,20 @@ def test_reasoning_stream(model: str) -> None:
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
if use_async:
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
else:
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
@@ -135,63 +112,32 @@ def test_reasoning_stream(model: str) -> None:
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_reasoning_astream(model: str) -> None:
    """Test async streaming with `reasoning=True`"""
    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
    messages = [
        {
            "role": "user",
            "content": SAMPLE,
        }
    ]
    merged = None
    # Accumulate every streamed chunk into a single message chunk.
    async for piece in llm.astream(messages):
        assert isinstance(piece, BaseMessageChunk)
        merged = piece if merged is None else merged + piece
    assert isinstance(merged, AIMessageChunk)
    assert merged.content
    # Reasoning must be extracted into additional_kwargs, with the think
    # tags stripped from both the visible content and the hidden portion.
    assert "reasoning_content" in merged.additional_kwargs
    reasoning = merged.additional_kwargs["reasoning_content"]
    assert len(reasoning) > 0
    for tag in ("<think>", "</think>"):
        assert tag not in merged.content
        assert tag not in reasoning
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_no_reasoning(model: str) -> None:
"""Test using invoke with `reasoning=False`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_invoke_no_reasoning(model: str, use_async: bool) -> None:
"""Test invoke with ``reasoning=False``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
if use_async:
result = await llm.ainvoke([message])
else:
result = llm.invoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content
assert "</think>" not in result.content
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_ainvoke_no_reasoning(model: str) -> None:
    """Test using async invoke with `reasoning=False`"""
    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
    response = await llm.ainvoke([HumanMessage(content=SAMPLE)])
    assert response.content
    # With reasoning disabled there is no reasoning_content and no
    # think tags in the visible output.
    assert "reasoning_content" not in response.additional_kwargs
    for tag in ("<think>", "</think>"):
        assert tag not in response.content
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_reasoning_none(model: str) -> None:
"""Test using invoke with `reasoning=None`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_invoke_reasoning_none(model: str, use_async: bool) -> None:
"""Test invoke with ``reasoning=None``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
if use_async:
result = await llm.ainvoke([message])
else:
result = llm.invoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content
@@ -200,26 +146,16 @@ def test_invoke_reasoning_none(model: str) -> None:
assert "</think>" not in result.additional_kwargs.get("reasoning_content", "")
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_ainvoke_reasoning_none(model: str) -> None:
    """Test using async invoke with `reasoning=None`"""
    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
    response = await llm.ainvoke([HumanMessage(content=SAMPLE)])
    assert response.content
    # With reasoning unset, think tags stay inline in the content and no
    # reasoning_content is extracted.
    assert "reasoning_content" not in response.additional_kwargs
    assert "<think>" in response.content
    assert "</think>" in response.content
    for tag in ("<think>", "</think>"):
        assert tag not in response.additional_kwargs.get("reasoning_content", "")
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_invoke(model: str) -> None:
"""Test invoke with `reasoning=True`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_reasoning_invoke(model: str, use_async: bool) -> None:
"""Test invoke with ``reasoning=True``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
if use_async:
result = await llm.ainvoke([message])
else:
result = llm.invoke([message])
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
@@ -229,22 +165,7 @@ def test_reasoning_invoke(model: str) -> None:
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_reasoning_ainvoke(model: str) -> None:
    """Test async invoke with `reasoning=True`"""
    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
    message = HumanMessage(content=SAMPLE)
    result = await llm.ainvoke([message])
    assert result.content
    # Reasoning must be extracted into additional_kwargs...
    assert "reasoning_content" in result.additional_kwargs
    assert len(result.additional_kwargs["reasoning_content"]) > 0
    # ...with the think tags stripped from both the visible content and
    # the extracted reasoning.
    assert "<think>" not in result.content
    assert "</think>" not in result.content
    assert "<think>" not in result.additional_kwargs["reasoning_content"]
    assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
def test_think_tag_stripping_necessity(model: str) -> None:
"""Test that demonstrates why ``_strip_think_tags`` is necessary.

View File

@@ -1,29 +1,14 @@
"""Test chat model integration using standard integration tests."""
from unittest.mock import MagicMock, patch
import pytest
from httpx import ConnectError
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolCallChunk
from langchain_core.tools import tool
from langchain_tests.integration_tests import ChatModelIntegrationTests
from ollama import ResponseError
from pydantic import ValidationError
from langchain_ollama.chat_models import ChatOllama
DEFAULT_MODEL_NAME = "llama3.1"
@tool
def get_current_weather(location: str) -> dict:
"""Gets the current weather in a given location."""
if "boston" in location.lower():
return {"temperature": "15°F", "conditions": "snow"}
return {"temperature": "unknown", "conditions": "unknown"}
class TestChatOllama(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> type[ChatOllama]:
@@ -47,94 +32,6 @@ class TestChatOllama(ChatModelIntegrationTests):
def supports_image_inputs(self) -> bool:
return True
def test_tool_streaming(self, model: BaseChatModel) -> None:
"""Test that the model can stream tool calls."""
chat_model_with_tools = model.bind_tools([get_current_weather])
prompt = [HumanMessage("What is the weather today in Boston?")]
# Flags and collectors for validation
tool_chunk_found = False
final_tool_calls = []
collected_tool_chunks: list[ToolCallChunk] = []
# Stream the response and inspect the chunks
for chunk in chat_model_with_tools.stream(prompt):
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
if chunk.tool_call_chunks:
tool_chunk_found = True
collected_tool_chunks.extend(chunk.tool_call_chunks)
if chunk.tool_calls:
final_tool_calls.extend(chunk.tool_calls)
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
assert len(final_tool_calls) == 1, (
f"Expected 1 final tool call, but got {len(final_tool_calls)}"
)
final_tool_call = final_tool_calls[0]
assert final_tool_call["name"] == "get_current_weather"
assert final_tool_call["args"] == {"location": "Boston"}
assert len(collected_tool_chunks) > 0
assert collected_tool_chunks[0]["name"] == "get_current_weather"
# The ID should be consistent across chunks that have it
tool_call_id = collected_tool_chunks[0].get("id")
assert tool_call_id is not None
assert all(
chunk.get("id") == tool_call_id
for chunk in collected_tool_chunks
if chunk.get("id")
)
assert final_tool_call["id"] == tool_call_id
async def test_tool_astreaming(self, model: BaseChatModel) -> None:
"""Test that the model can stream tool calls."""
chat_model_with_tools = model.bind_tools([get_current_weather])
prompt = [HumanMessage("What is the weather today in Boston?")]
# Flags and collectors for validation
tool_chunk_found = False
final_tool_calls = []
collected_tool_chunks: list[ToolCallChunk] = []
# Stream the response and inspect the chunks
async for chunk in chat_model_with_tools.astream(prompt):
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
if chunk.tool_call_chunks:
tool_chunk_found = True
collected_tool_chunks.extend(chunk.tool_call_chunks)
if chunk.tool_calls:
final_tool_calls.extend(chunk.tool_calls)
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
assert len(final_tool_calls) == 1, (
f"Expected 1 final tool call, but got {len(final_tool_calls)}"
)
final_tool_call = final_tool_calls[0]
assert final_tool_call["name"] == "get_current_weather"
assert final_tool_call["args"] == {"location": "Boston"}
assert len(collected_tool_chunks) > 0
assert collected_tool_chunks[0]["name"] == "get_current_weather"
# The ID should be consistent across chunks that have it
tool_call_id = collected_tool_chunks[0].get("id")
assert tool_call_id is not None
assert all(
chunk.get("id") == tool_call_id
for chunk in collected_tool_chunks
if chunk.get("id")
)
assert final_tool_call["id"] == tool_call_id
@pytest.mark.xfail(
reason=(
"Will sometime encounter AssertionErrors where tool responses are "
@@ -153,28 +50,13 @@ class TestChatOllama(ChatModelIntegrationTests):
async def test_tool_calling_async(self, model: BaseChatModel) -> None:
await super().test_tool_calling_async(model)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_model_not_found(self, mock_list: MagicMock) -> None:
"""Test that a ValueError is raised when the model is not found."""
mock_list.side_effect = ValueError("Test model not found")
with pytest.raises(ValueError) as excinfo:
ChatOllama(model="non-existent-model", validate_model_on_init=True)
assert "Test model not found" in str(excinfo.value)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_connection_error(self, mock_list: MagicMock) -> None:
"""Test that a ValidationError is raised on connect failure during init."""
mock_list.side_effect = ConnectError("Test connection error")
with pytest.raises(ValidationError) as excinfo:
ChatOllama(model="any-model", validate_model_on_init=True)
assert "Failed to connect to Ollama" in str(excinfo.value)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_response_error(self, mock_list: MagicMock) -> None:
"""Test that a ResponseError is raised."""
mock_list.side_effect = ResponseError("Test response error")
with pytest.raises(ValidationError) as excinfo:
ChatOllama(model="any-model", validate_model_on_init=True)
assert "Received an error from the Ollama API" in str(excinfo.value)
@pytest.mark.xfail(
reason=(
"Will sometimes fail due to Ollama's inconsistent tool call argument "
"structure (see https://github.com/ollama/ollama/issues/6155). "
"Args may contain unexpected keys like 'conversations' instead of "
"empty dict."
)
)
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)

View File

@@ -13,11 +13,74 @@ REASONING_MODEL_NAME = os.environ.get("OLLAMA_REASONING_TEST_MODEL", "deepseek-r
SAMPLE = "What is 3^3?"
def test_invoke() -> None:
    """Test sync invoke returning a string."""
    llm = OllamaLLM(model=MODEL_NAME)
    cfg = RunnableConfig(tags=["foo"])
    response = llm.invoke("I'm Pickle Rick", config=cfg)
    assert isinstance(response, str)
async def test_ainvoke() -> None:
    """Test async invoke returning a string."""
    llm = OllamaLLM(model=MODEL_NAME)
    cfg = RunnableConfig(tags=["foo"])
    response = await llm.ainvoke("I'm Pickle Rick", config=cfg)
    assert isinstance(response, str)
def test_batch() -> None:
    """Test batch sync token generation from `OllamaLLM`."""
    llm = OllamaLLM(model=MODEL_NAME)
    outputs = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    # Each batch entry must come back as a plain string.
    assert all(isinstance(output, str) for output in outputs)
async def test_abatch() -> None:
    """Test batch async token generation from `OllamaLLM`."""
    llm = OllamaLLM(model=MODEL_NAME)
    outputs = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    # Each batch entry must come back as a plain string.
    assert all(isinstance(output, str) for output in outputs)
def test_batch_tags() -> None:
    """Test batch sync token generation with tags."""
    llm = OllamaLLM(model=MODEL_NAME)
    prompts = ["I'm Pickle Rick", "I'm not Pickle Rick"]
    outputs = llm.batch(prompts, config={"tags": ["foo"]})
    # Each batch entry must come back as a plain string.
    assert all(isinstance(output, str) for output in outputs)
async def test_abatch_tags() -> None:
    """Test batch async token generation with tags."""
    llm = OllamaLLM(model=MODEL_NAME)
    prompts = ["I'm Pickle Rick", "I'm not Pickle Rick"]
    outputs = await llm.abatch(prompts, config={"tags": ["foo"]})
    # Each batch entry must come back as a plain string.
    assert all(isinstance(output, str) for output in outputs)
def test_stream_text_tokens() -> None:
    """Test streaming raw string tokens from `OllamaLLM`.

    Streams a short prompt once and checks every yielded token is a
    ``str``. (The previous text contained two nested ``for token in
    llm.stream(...)`` loops — diff-merge residue — which would stream
    the model twice per outer token; a single loop is the intent.)
    """
    llm = OllamaLLM(model=MODEL_NAME)
    for token in llm.stream("Hi."):
        assert isinstance(token, str)
async def test_astream_text_tokens() -> None:
    """Test async streaming raw string tokens from `OllamaLLM`."""
    llm = OllamaLLM(model=MODEL_NAME)
    # Every asynchronously streamed token must be a plain string.
    async for piece in llm.astream("Hi."):
        assert isinstance(piece, str)
@@ -28,7 +91,6 @@ def test__stream_no_reasoning(model: str) -> None:
result_chunk = None
for chunk in llm._stream(SAMPLE):
# Should be a GenerationChunk
assert isinstance(chunk, GenerationChunk)
if result_chunk is None:
result_chunk = chunk
@@ -38,8 +100,28 @@ def test__stream_no_reasoning(model: str) -> None:
# The final result must be a GenerationChunk with visible content
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
# No separate reasoning_content
assert "reasoning_content" not in result_chunk.generation_info # type: ignore[operator]
assert result_chunk.generation_info
assert not result_chunk.generation_info.get("reasoning_content")
@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)])
async def test__astream_no_reasoning(model: str) -> None:
    """Test low-level async chunk streaming with `reasoning=False`."""
    llm = OllamaLLM(model=model, num_ctx=2**12)
    merged = None
    # Fold the streamed chunks into a single accumulated chunk.
    async for piece in llm._astream(SAMPLE):
        assert isinstance(piece, GenerationChunk)
        merged = piece if merged is None else merged + piece
    # The final result must be a GenerationChunk with visible content
    assert isinstance(merged, GenerationChunk)
    assert merged.text
    assert merged.generation_info
    # No reasoning was requested, so nothing should have been extracted.
    assert not merged.generation_info.get("reasoning_content")
@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)])
@@ -57,40 +139,17 @@ def test__stream_with_reasoning(model: str) -> None:
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
# Should have extracted reasoning into generation_info
assert "reasoning_content" in result_chunk.generation_info # type: ignore[operator]
assert len(result_chunk.generation_info["reasoning_content"]) > 0 # type: ignore[index]
assert result_chunk.generation_info
reasoning_content = result_chunk.generation_info.get("reasoning_content")
assert reasoning_content
assert len(reasoning_content) > 0
# And neither the visible nor the hidden portion contains <think> tags
assert "<think>" not in result_chunk.text
assert "</think>" not in result_chunk.text
assert "<think>" not in result_chunk.generation_info["reasoning_content"] # type: ignore[index]
assert "</think>" not in result_chunk.generation_info["reasoning_content"] # type: ignore[index]
async def test_astream_text_tokens() -> None:
"""Test async streaming raw string tokens from `OllamaLLM`."""
llm = OllamaLLM(model=MODEL_NAME)
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token, str)
@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)])
async def test__astream_no_reasoning(model: str) -> None:
"""Test low-level async chunk streaming with `reasoning=False`."""
llm = OllamaLLM(model=model, num_ctx=2**12)
result_chunk = None
async for chunk in llm._astream(SAMPLE):
assert isinstance(chunk, GenerationChunk)
if result_chunk is None:
result_chunk = chunk
else:
result_chunk += chunk
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
assert "reasoning_content" not in result_chunk.generation_info # type: ignore[operator]
assert "<think>" not in reasoning_content
assert "</think>" not in reasoning_content
@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)])
@@ -108,49 +167,14 @@ async def test__astream_with_reasoning(model: str) -> None:
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
assert "reasoning_content" in result_chunk.generation_info # type: ignore[operator]
assert len(result_chunk.generation_info["reasoning_content"]) > 0 # type: ignore[index]
async def test_abatch() -> None:
"""Test batch sync token generation from `OllamaLLM`."""
llm = OllamaLLM(model=MODEL_NAME)
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
async def test_abatch_tags() -> None:
"""Test batch sync token generation with tags."""
llm = OllamaLLM(model=MODEL_NAME)
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token, str)
def test_batch() -> None:
"""Test batch token generation from `OllamaLLM`."""
llm = OllamaLLM(model=MODEL_NAME)
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
async def test_ainvoke() -> None:
"""Test async invoke returning a string."""
llm = OllamaLLM(model=MODEL_NAME)
result = await llm.ainvoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
assert isinstance(result, str)
def test_invoke() -> None:
    """Sync invoke with a tagged config must return a string."""
    llm = OllamaLLM(model=MODEL_NAME)
    response = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
    assert isinstance(response, str)
# Should have extracted reasoning into generation_info
assert result_chunk.generation_info
reasoning_content = result_chunk.generation_info.get("reasoning_content")
assert reasoning_content
assert len(reasoning_content) > 0
# And neither the visible nor the hidden portion contains <think> tags
assert "<think>" not in result_chunk.text
assert "</think>" not in result_chunk.text
assert "<think>" not in reasoning_content
assert "</think>" not in reasoning_content

View File

@@ -22,6 +22,22 @@ from langchain_ollama.chat_models import (
MODEL_NAME = "llama3.1"
@contextmanager
def _mock_httpx_client_stream(
    *args: Any, **kwargs: Any
) -> Generator[Response, Any, Any]:
    """Yield a canned httpx ``Response`` in place of a real streaming request."""
    canned_request = Request(method="POST", url="http://whocares:11434")
    canned_response = Response(
        status_code=200,
        content='{"message": {"role": "assistant", "content": "The meaning ..."}}',
        request=canned_request,
    )
    yield canned_response
# Minimal raw tool-call payload used to satisfy `_parse_json_string`'s signature
# in the parametrized parsing tests below.
dummy_raw_tool_call = {
    "function": {"name": "test_func", "arguments": ""},
}
class TestChatOllama(ChatModelUnitTests):
@property
def chat_model_class(self) -> type[ChatOllama]:
@@ -35,19 +51,24 @@ class TestChatOllama(ChatModelUnitTests):
def test__parse_arguments_from_tool_call() -> None:
    """Test that string arguments are preserved as strings in tool call parsing.

    PR #30154

    String-typed tool arguments (like IDs or long strings) were being incorrectly
    processed. The parser should preserve string values as strings rather than
    attempting to parse them as JSON when they're already valid string arguments.

    Use a long string ID to ensure string arguments maintain their original type
    after parsing, which is critical for tools expecting string inputs.
    """
    # Fix: the original block carried two conflicting copies of this assignment
    # (a merge/diff artifact); keep the single multi-line form.
    raw_response = (
        '{"model":"sample-model","message":{"role":"assistant","content":"",'
        '"tool_calls":[{"function":{"name":"get_profile_details",'
        '"arguments":{"arg_1":"12345678901234567890123456"}}}]},"done":false}'
    )
    raw_tool_calls = json.loads(raw_response)["message"]["tool_calls"]
    response = _parse_arguments_from_tool_call(raw_tool_calls[0])
    assert response is not None
    # The long ID must come back as the very same string, not a parsed number.
    assert isinstance(response["arg_1"], str)
    assert response["arg_1"] == "12345678901234567890123456"
def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
@@ -57,7 +78,6 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
that just echoes the function name. This should be filtered out for
no-argument tools to return an empty dictionary.
"""
# Test case where arguments contain functionName metadata
raw_tool_call_with_metadata = {
"function": {
"name": "magic_function_no_args",
@@ -67,7 +87,7 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
response = _parse_arguments_from_tool_call(raw_tool_call_with_metadata)
assert response == {}
# Test case where arguments contain both real args and metadata
# Arguments contain both real args and metadata
raw_tool_call_mixed = {
"function": {
"name": "some_function",
@@ -77,7 +97,7 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
response_mixed = _parse_arguments_from_tool_call(raw_tool_call_mixed)
assert response_mixed == {"real_arg": "value"}
# Test case where functionName has different value (should be preserved)
# functionName has different value (should be preserved)
raw_tool_call_different = {
"function": {"name": "function_a", "arguments": {"functionName": "function_b"}}
}
@@ -85,17 +105,6 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
assert response_different == {"functionName": "function_b"}
@contextmanager
def _mock_httpx_client_stream(
    *args: Any, **kwargs: Any
) -> Generator[Response, Any, Any]:
    # NOTE(review): this duplicates the `_mock_httpx_client_stream` defined
    # earlier in the file; being defined later, this copy shadows the earlier
    # one at import time. Both bodies are identical, so behavior is unchanged,
    # but one copy should be removed.
    # Yields a canned httpx Response in place of a real streaming request.
    yield Response(
        status_code=200,
        content='{"message": {"role": "assistant", "content": "The meaning ..."}}',
        request=Request(method="POST", url="http://whocares:11434"),
    )
def test_arbitrary_roles_accepted_in_chatmessages(
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -120,26 +129,16 @@ def test_arbitrary_roles_accepted_in_chatmessages(
@patch("langchain_ollama.chat_models.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
    """Test that the model is validated on initialization when requested."""
    # Explicit opt-in triggers exactly one validation call.
    ChatOllama(model=MODEL_NAME, validate_model_on_init=True)
    mock_validate_model.assert_called_once()
    mock_validate_model.reset_mock()

    # Explicit opt-out performs no validation.
    ChatOllama(model=MODEL_NAME, validate_model_on_init=False)
    mock_validate_model.assert_not_called()

    # The default is also no validation.
    ChatOllama(model=MODEL_NAME)
    mock_validate_model.assert_not_called()
# Define a dummy raw_tool_call for the function signature
# NOTE(review): this re-binds the same `dummy_raw_tool_call` defined earlier in
# the file with an identical value (a merge/diff artifact); behavior is
# unchanged, but the duplicate should be removed.
dummy_raw_tool_call = {
    "function": {"name": "test_func", "arguments": ""},
}
@pytest.mark.parametrize(
("input_string", "expected_output"),
[
@@ -164,7 +163,7 @@ def test_parse_json_string_success_cases(
def test_parse_json_string_failure_case_raises_exception() -> None:
    """Tests that `_parse_json_string` raises an exception for malformed strings."""
    malformed_string = "{'key': 'value',,}"  # Double comma is invalid
    raw_tool_call = {"function": {"name": "test_func", "arguments": malformed_string}}
    # Fix: the original carried a dead duplicate of the `malformed_string`
    # assignment and a truncated call; with skip=False, malformed input must
    # raise rather than be returned (mirrors the skip=True test below).
    with pytest.raises(OutputParserException):
        _parse_json_string(
            malformed_string,
            raw_tool_call=raw_tool_call,
            skip=False,
        )
@@ -181,7 +180,7 @@ def test_parse_json_string_skip_returns_input_on_failure() -> None:
result = _parse_json_string(
malformed_string,
raw_tool_call=raw_tool_call,
skip=True,
skip=True, # We want the original invalid string back
)
assert result == malformed_string

View File

@@ -16,16 +16,12 @@ def test_initialization() -> None:
@patch("langchain_ollama.embeddings.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
    """Test that the model is validated on initialization when requested."""
    # Explicit opt-in triggers exactly one validation call.
    OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True)
    mock_validate_model.assert_called_once()
    mock_validate_model.reset_mock()

    # Explicit opt-out performs no validation.
    OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=False)
    mock_validate_model.assert_not_called()

    # The default is also no validation.
    OllamaEmbeddings(model=MODEL_NAME)
    mock_validate_model.assert_not_called()
@@ -33,20 +29,13 @@ def test_validate_model_on_init(mock_validate_model: Any) -> None:
@patch("langchain_ollama.embeddings.Client")
def test_embed_documents_passes_options(mock_client_class: Any) -> None:
    """Test that `embed_documents()` passes options, including `num_gpu`."""
    # Create a mock client instance
    mock_client = Mock()
    mock_client_class.return_value = mock_client
    # Mock the embed method response
    mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
    # Create embeddings with num_gpu parameter
    embeddings = OllamaEmbeddings(model=MODEL_NAME, num_gpu=4, temperature=0.5)
    # Call embed_documents
    result = embeddings.embed_documents(["test text"])
    # Verify the result
    assert result == [[0.1, 0.2, 0.3]]
    # Check that embed was called with correct arguments
    # NOTE(review): the assertion that should follow the comment above appears
    # to be truncated here — confirm that `mock_client.embed`'s call arguments
    # (including `num_gpu` in options) are actually verified.
View File

@@ -14,7 +14,7 @@ def test_initialization() -> None:
def test_model_params() -> None:
# Test standard tracing params
"""Test standard tracing params"""
llm = OllamaLLM(model=MODEL_NAME)
ls_params = llm._get_ls_params()
assert ls_params == {
@@ -36,16 +36,12 @@ def test_model_params() -> None:
@patch("langchain_ollama.llms.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
    """Test that the model is validated on initialization when requested."""
    # Explicit opt-in triggers exactly one validation call.
    OllamaLLM(model=MODEL_NAME, validate_model_on_init=True)
    mock_validate_model.assert_called_once()
    mock_validate_model.reset_mock()

    # Explicit opt-out performs no validation.
    OllamaLLM(model=MODEL_NAME, validate_model_on_init=False)
    mock_validate_model.assert_not_called()

    # The default is also no validation.
    OllamaLLM(model=MODEL_NAME)
    mock_validate_model.assert_not_called()