mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
refactor(ollama): clean up tests (#33198)
This commit is contained in:
@@ -14,14 +14,20 @@ service.
|
||||
"""
|
||||
|
||||
from importlib import metadata
|
||||
from importlib.metadata import PackageNotFoundError
|
||||
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
from langchain_ollama.embeddings import OllamaEmbeddings
|
||||
from langchain_ollama.llms import OllamaLLM
|
||||
|
||||
|
||||
def _raise_package_not_found_error() -> None:
|
||||
raise PackageNotFoundError
|
||||
|
||||
|
||||
try:
|
||||
if __package__ is None:
|
||||
raise metadata.PackageNotFoundError
|
||||
_raise_package_not_found_error()
|
||||
__version__ = metadata.version(__package__)
|
||||
except metadata.PackageNotFoundError:
|
||||
# Case where package metadata is not available.
|
||||
|
||||
@@ -132,7 +132,6 @@ def _parse_arguments_from_tool_call(
|
||||
Should be removed/changed if fixed upstream.
|
||||
|
||||
See https://github.com/ollama/ollama/issues/6155
|
||||
|
||||
"""
|
||||
if "function" not in raw_tool_call:
|
||||
return None
|
||||
@@ -834,7 +833,7 @@ class ChatOllama(BaseChatModel):
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False, # noqa: FBT001, FBT002
|
||||
verbose: bool = False, # noqa: FBT002
|
||||
**kwargs: Any,
|
||||
) -> ChatGenerationChunk:
|
||||
final_chunk = None
|
||||
@@ -860,7 +859,7 @@ class ChatOllama(BaseChatModel):
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False, # noqa: FBT001, FBT002
|
||||
verbose: bool = False, # noqa: FBT002
|
||||
**kwargs: Any,
|
||||
) -> ChatGenerationChunk:
|
||||
final_chunk = None
|
||||
|
||||
@@ -368,7 +368,7 @@ class OllamaLLM(BaseLLM):
|
||||
prompt: str,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False, # noqa: FBT001, FBT002
|
||||
verbose: bool = False, # noqa: FBT002
|
||||
**kwargs: Any,
|
||||
) -> GenerationChunk:
|
||||
final_chunk = None
|
||||
@@ -410,7 +410,7 @@ class OllamaLLM(BaseLLM):
|
||||
prompt: str,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False, # noqa: FBT001, FBT002
|
||||
verbose: bool = False, # noqa: FBT002
|
||||
**kwargs: Any,
|
||||
) -> GenerationChunk:
|
||||
final_chunk = None
|
||||
|
||||
@@ -60,15 +60,17 @@ ignore = [
|
||||
"SLF001", # Private member access
|
||||
"UP007", # pyupgrade: non-pep604-annotation-union
|
||||
"UP045", # pyupgrade: non-pep604-annotation-optional
|
||||
"PLR0912",
|
||||
"C901",
|
||||
"PLR0915",
|
||||
"FIX002", # TODOs
|
||||
"TD002", # TODO authors
|
||||
"TC002", # Incorrect type-checking block
|
||||
"TC003", # Incorrect type-checking block
|
||||
"PLR0912", # Too many branches
|
||||
"PLR0915", # Too many statements
|
||||
"C901", # Function too complex
|
||||
"FBT001", # Boolean function param
|
||||
|
||||
# TODO:
|
||||
# TODO
|
||||
"ANN401",
|
||||
"TC002",
|
||||
"TC003",
|
||||
"TRY301",
|
||||
]
|
||||
unfixable = ["B028"] # People should intentionally tune the stacklevel
|
||||
|
||||
@@ -91,16 +93,11 @@ asyncio_mode = "auto"
|
||||
|
||||
[tool.ruff.lint.extend-per-file-ignores]
|
||||
"tests/**/*.py" = [
|
||||
"S101", # Tests need assertions
|
||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||
"PLR2004",
|
||||
|
||||
# TODO
|
||||
"ANN401",
|
||||
"ARG001",
|
||||
"PT011",
|
||||
"FIX",
|
||||
"TD",
|
||||
"S101", # Tests need assertions
|
||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||
"ARG001", # Unused function arguments in tests (e.g. kwargs)
|
||||
"PLR2004", # Magic value in comparisons
|
||||
"PT011", # `pytest.raises()` is too broad
|
||||
]
|
||||
"scripts/*.py" = [
|
||||
"INP001", # Not a package
|
||||
|
||||
@@ -3,9 +3,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from httpx import ConnectError
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.messages.human import HumanMessage
|
||||
from langchain_core.messages.tool import ToolCallChunk
|
||||
from langchain_core.tools import tool
|
||||
from ollama import ResponseError
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
@@ -13,6 +20,43 @@ from langchain_ollama import ChatOllama
|
||||
DEFAULT_MODEL_NAME = "llama3.1"
|
||||
|
||||
|
||||
@tool
|
||||
def get_current_weather(location: str) -> dict:
|
||||
"""Gets the current weather in a given location."""
|
||||
if "boston" in location.lower():
|
||||
return {"temperature": "15°F", "conditions": "snow"}
|
||||
return {"temperature": "unknown", "conditions": "unknown"}
|
||||
|
||||
|
||||
@patch("langchain_ollama.chat_models.Client.list")
|
||||
def test_init_model_not_found(mock_list: MagicMock) -> None:
|
||||
"""Test that a ValueError is raised when the model is not found."""
|
||||
mock_list.side_effect = ValueError("Test model not found")
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
ChatOllama(model="non-existent-model", validate_model_on_init=True)
|
||||
assert "Test model not found" in str(excinfo.value)
|
||||
|
||||
|
||||
@patch("langchain_ollama.chat_models.Client.list")
|
||||
def test_init_connection_error(mock_list: MagicMock) -> None:
|
||||
"""Test that a ValidationError is raised on connect failure during init."""
|
||||
mock_list.side_effect = ConnectError("Test connection error")
|
||||
|
||||
with pytest.raises(ValidationError) as excinfo:
|
||||
ChatOllama(model="any-model", validate_model_on_init=True)
|
||||
assert "Failed to connect to Ollama" in str(excinfo.value)
|
||||
|
||||
|
||||
@patch("langchain_ollama.chat_models.Client.list")
|
||||
def test_init_response_error(mock_list: MagicMock) -> None:
|
||||
"""Test that a ResponseError is raised."""
|
||||
mock_list.side_effect = ResponseError("Test response error")
|
||||
|
||||
with pytest.raises(ValidationError) as excinfo:
|
||||
ChatOllama(model="any-model", validate_model_on_init=True)
|
||||
assert "Received an error from the Ollama API" in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("method"), [("function_calling"), ("json_schema")])
|
||||
def test_structured_output(method: str) -> None:
|
||||
"""Test to verify structured output via tool calling and `format` parameter."""
|
||||
@@ -98,3 +142,97 @@ def test_structured_output_deeply_nested(model: str) -> None:
|
||||
|
||||
for chunk in chat.stream(text):
|
||||
assert isinstance(chunk, Data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)])
|
||||
def test_tool_streaming(model: str) -> None:
|
||||
"""Test that the model can stream tool calls."""
|
||||
llm = ChatOllama(model=model)
|
||||
chat_model_with_tools = llm.bind_tools([get_current_weather])
|
||||
|
||||
prompt = [HumanMessage("What is the weather today in Boston?")]
|
||||
|
||||
# Flags and collectors for validation
|
||||
tool_chunk_found = False
|
||||
final_tool_calls = []
|
||||
collected_tool_chunks: list[ToolCallChunk] = []
|
||||
|
||||
# Stream the response and inspect the chunks
|
||||
for chunk in chat_model_with_tools.stream(prompt):
|
||||
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
|
||||
|
||||
if chunk.tool_call_chunks:
|
||||
tool_chunk_found = True
|
||||
collected_tool_chunks.extend(chunk.tool_call_chunks)
|
||||
|
||||
if chunk.tool_calls:
|
||||
final_tool_calls.extend(chunk.tool_calls)
|
||||
|
||||
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
|
||||
assert len(final_tool_calls) == 1, (
|
||||
f"Expected 1 final tool call, but got {len(final_tool_calls)}"
|
||||
)
|
||||
|
||||
final_tool_call = final_tool_calls[0]
|
||||
assert final_tool_call["name"] == "get_current_weather"
|
||||
assert final_tool_call["args"] == {"location": "Boston"}
|
||||
|
||||
assert len(collected_tool_chunks) > 0
|
||||
assert collected_tool_chunks[0]["name"] == "get_current_weather"
|
||||
|
||||
# The ID should be consistent across chunks that have it
|
||||
tool_call_id = collected_tool_chunks[0].get("id")
|
||||
assert tool_call_id is not None
|
||||
assert all(
|
||||
chunk.get("id") == tool_call_id
|
||||
for chunk in collected_tool_chunks
|
||||
if chunk.get("id")
|
||||
)
|
||||
assert final_tool_call["id"] == tool_call_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)])
|
||||
async def test_tool_astreaming(model: str) -> None:
|
||||
"""Test that the model can stream tool calls."""
|
||||
llm = ChatOllama(model=model)
|
||||
chat_model_with_tools = llm.bind_tools([get_current_weather])
|
||||
|
||||
prompt = [HumanMessage("What is the weather today in Boston?")]
|
||||
|
||||
# Flags and collectors for validation
|
||||
tool_chunk_found = False
|
||||
final_tool_calls = []
|
||||
collected_tool_chunks: list[ToolCallChunk] = []
|
||||
|
||||
# Stream the response and inspect the chunks
|
||||
async for chunk in chat_model_with_tools.astream(prompt):
|
||||
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
|
||||
|
||||
if chunk.tool_call_chunks:
|
||||
tool_chunk_found = True
|
||||
collected_tool_chunks.extend(chunk.tool_call_chunks)
|
||||
|
||||
if chunk.tool_calls:
|
||||
final_tool_calls.extend(chunk.tool_calls)
|
||||
|
||||
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
|
||||
assert len(final_tool_calls) == 1, (
|
||||
f"Expected 1 final tool call, but got {len(final_tool_calls)}"
|
||||
)
|
||||
|
||||
final_tool_call = final_tool_calls[0]
|
||||
assert final_tool_call["name"] == "get_current_weather"
|
||||
assert final_tool_call["args"] == {"location": "Boston"}
|
||||
|
||||
assert len(collected_tool_chunks) > 0
|
||||
assert collected_tool_chunks[0]["name"] == "get_current_weather"
|
||||
|
||||
# The ID should be consistent across chunks that have it
|
||||
tool_call_id = collected_tool_chunks[0].get("id")
|
||||
assert tool_call_id is not None
|
||||
assert all(
|
||||
chunk.get("id") == tool_call_id
|
||||
for chunk in collected_tool_chunks
|
||||
if chunk.get("id")
|
||||
)
|
||||
assert final_tool_call["id"] == tool_call_id
|
||||
|
||||
@@ -8,9 +8,10 @@ from langchain_ollama import ChatOllama
|
||||
SAMPLE = "What is 3^3?"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_stream_no_reasoning(model: str) -> None:
|
||||
"""Test streaming with `reasoning=False`"""
|
||||
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
|
||||
@pytest.mark.parametrize("use_async", [False, True])
|
||||
async def test_stream_no_reasoning(model: str, use_async: bool) -> None:
|
||||
"""Test streaming with ``reasoning=False``."""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
|
||||
messages = [
|
||||
{
|
||||
@@ -19,12 +20,20 @@ def test_stream_no_reasoning(model: str) -> None:
|
||||
}
|
||||
]
|
||||
result = None
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
if use_async:
|
||||
async for chunk in llm.astream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
else:
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" not in result.content
|
||||
@@ -32,33 +41,10 @@ def test_stream_no_reasoning(model: str) -> None:
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
async def test_astream_no_reasoning(model: str) -> None:
|
||||
"""Test async streaming with `reasoning=False`"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SAMPLE,
|
||||
}
|
||||
]
|
||||
result = None
|
||||
async for chunk in llm.astream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" not in result.content
|
||||
assert "</think>" not in result.content
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_stream_reasoning_none(model: str) -> None:
|
||||
"""Test streaming with `reasoning=None`"""
|
||||
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
|
||||
@pytest.mark.parametrize("use_async", [False, True])
|
||||
async def test_stream_reasoning_none(model: str, use_async: bool) -> None:
|
||||
"""Test streaming with ``reasoning=None``."""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
|
||||
messages = [
|
||||
{
|
||||
@@ -67,12 +53,20 @@ def test_stream_reasoning_none(model: str) -> None:
|
||||
}
|
||||
]
|
||||
result = None
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
if use_async:
|
||||
async for chunk in llm.astream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
else:
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" in result.content
|
||||
@@ -82,35 +76,10 @@ def test_stream_reasoning_none(model: str) -> None:
|
||||
assert "</think>" not in result.additional_kwargs.get("reasoning_content", "")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
async def test_astream_reasoning_none(model: str) -> None:
|
||||
"""Test async streaming with `reasoning=None`"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SAMPLE,
|
||||
}
|
||||
]
|
||||
result = None
|
||||
async for chunk in llm.astream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" in result.content
|
||||
assert "</think>" in result.content
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
assert "<think>" not in result.additional_kwargs.get("reasoning_content", "")
|
||||
assert "</think>" not in result.additional_kwargs.get("reasoning_content", "")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_reasoning_stream(model: str) -> None:
|
||||
"""Test streaming with `reasoning=True`"""
|
||||
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
|
||||
@pytest.mark.parametrize("use_async", [False, True])
|
||||
async def test_reasoning_stream(model: str, use_async: bool) -> None:
|
||||
"""Test streaming with ``reasoning=True``."""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
|
||||
messages = [
|
||||
{
|
||||
@@ -119,12 +88,20 @@ def test_reasoning_stream(model: str) -> None:
|
||||
}
|
||||
]
|
||||
result = None
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
if use_async:
|
||||
async for chunk in llm.astream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
else:
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
@@ -135,63 +112,32 @@ def test_reasoning_stream(model: str) -> None:
|
||||
assert "</think>" not in result.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
async def test_reasoning_astream(model: str) -> None:
|
||||
"""Test async streaming with `reasoning=True`"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SAMPLE,
|
||||
}
|
||||
]
|
||||
result = None
|
||||
async for chunk in llm.astream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" not in result.content
|
||||
assert "</think>" not in result.content
|
||||
assert "<think>" not in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" not in result.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_invoke_no_reasoning(model: str) -> None:
|
||||
"""Test using invoke with `reasoning=False`"""
|
||||
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
|
||||
@pytest.mark.parametrize("use_async", [False, True])
|
||||
async def test_invoke_no_reasoning(model: str, use_async: bool) -> None:
|
||||
"""Test invoke with ``reasoning=False``."""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = llm.invoke([message])
|
||||
if use_async:
|
||||
result = await llm.ainvoke([message])
|
||||
else:
|
||||
result = llm.invoke([message])
|
||||
assert result.content
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
assert "<think>" not in result.content
|
||||
assert "</think>" not in result.content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
async def test_ainvoke_no_reasoning(model: str) -> None:
|
||||
"""Test using async invoke with `reasoning=False`"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = await llm.ainvoke([message])
|
||||
assert result.content
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
assert "<think>" not in result.content
|
||||
assert "</think>" not in result.content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_invoke_reasoning_none(model: str) -> None:
|
||||
"""Test using invoke with `reasoning=None`"""
|
||||
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
|
||||
@pytest.mark.parametrize("use_async", [False, True])
|
||||
async def test_invoke_reasoning_none(model: str, use_async: bool) -> None:
|
||||
"""Test invoke with ``reasoning=None``."""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = llm.invoke([message])
|
||||
if use_async:
|
||||
result = await llm.ainvoke([message])
|
||||
else:
|
||||
result = llm.invoke([message])
|
||||
assert result.content
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
assert "<think>" in result.content
|
||||
@@ -200,26 +146,16 @@ def test_invoke_reasoning_none(model: str) -> None:
|
||||
assert "</think>" not in result.additional_kwargs.get("reasoning_content", "")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
async def test_ainvoke_reasoning_none(model: str) -> None:
|
||||
"""Test using async invoke with `reasoning=None`"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = await llm.ainvoke([message])
|
||||
assert result.content
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
assert "<think>" in result.content
|
||||
assert "</think>" in result.content
|
||||
assert "<think>" not in result.additional_kwargs.get("reasoning_content", "")
|
||||
assert "</think>" not in result.additional_kwargs.get("reasoning_content", "")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_reasoning_invoke(model: str) -> None:
|
||||
"""Test invoke with `reasoning=True`"""
|
||||
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
|
||||
@pytest.mark.parametrize("use_async", [False, True])
|
||||
async def test_reasoning_invoke(model: str, use_async: bool) -> None:
|
||||
"""Test invoke with ``reasoning=True``."""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = llm.invoke([message])
|
||||
if use_async:
|
||||
result = await llm.ainvoke([message])
|
||||
else:
|
||||
result = llm.invoke([message])
|
||||
assert result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
@@ -229,22 +165,7 @@ def test_reasoning_invoke(model: str) -> None:
|
||||
assert "</think>" not in result.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
async def test_reasoning_ainvoke(model: str) -> None:
|
||||
"""Test invoke with `reasoning=True`"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = await llm.ainvoke([message])
|
||||
assert result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" not in result.content
|
||||
assert "</think>" not in result.content
|
||||
assert "<think>" not in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" not in result.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
|
||||
def test_think_tag_stripping_necessity(model: str) -> None:
|
||||
"""Test that demonstrates why ``_strip_think_tags`` is necessary.
|
||||
|
||||
|
||||
@@ -1,29 +1,14 @@
|
||||
"""Test chat model integration using standard integration tests."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ConnectError
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolCallChunk
|
||||
from langchain_core.tools import tool
|
||||
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
||||
from ollama import ResponseError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
|
||||
DEFAULT_MODEL_NAME = "llama3.1"
|
||||
|
||||
|
||||
@tool
|
||||
def get_current_weather(location: str) -> dict:
|
||||
"""Gets the current weather in a given location."""
|
||||
if "boston" in location.lower():
|
||||
return {"temperature": "15°F", "conditions": "snow"}
|
||||
return {"temperature": "unknown", "conditions": "unknown"}
|
||||
|
||||
|
||||
class TestChatOllama(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> type[ChatOllama]:
|
||||
@@ -47,94 +32,6 @@ class TestChatOllama(ChatModelIntegrationTests):
|
||||
def supports_image_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
def test_tool_streaming(self, model: BaseChatModel) -> None:
|
||||
"""Test that the model can stream tool calls."""
|
||||
chat_model_with_tools = model.bind_tools([get_current_weather])
|
||||
|
||||
prompt = [HumanMessage("What is the weather today in Boston?")]
|
||||
|
||||
# Flags and collectors for validation
|
||||
tool_chunk_found = False
|
||||
final_tool_calls = []
|
||||
collected_tool_chunks: list[ToolCallChunk] = []
|
||||
|
||||
# Stream the response and inspect the chunks
|
||||
for chunk in chat_model_with_tools.stream(prompt):
|
||||
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
|
||||
|
||||
if chunk.tool_call_chunks:
|
||||
tool_chunk_found = True
|
||||
collected_tool_chunks.extend(chunk.tool_call_chunks)
|
||||
|
||||
if chunk.tool_calls:
|
||||
final_tool_calls.extend(chunk.tool_calls)
|
||||
|
||||
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
|
||||
assert len(final_tool_calls) == 1, (
|
||||
f"Expected 1 final tool call, but got {len(final_tool_calls)}"
|
||||
)
|
||||
|
||||
final_tool_call = final_tool_calls[0]
|
||||
assert final_tool_call["name"] == "get_current_weather"
|
||||
assert final_tool_call["args"] == {"location": "Boston"}
|
||||
|
||||
assert len(collected_tool_chunks) > 0
|
||||
assert collected_tool_chunks[0]["name"] == "get_current_weather"
|
||||
|
||||
# The ID should be consistent across chunks that have it
|
||||
tool_call_id = collected_tool_chunks[0].get("id")
|
||||
assert tool_call_id is not None
|
||||
assert all(
|
||||
chunk.get("id") == tool_call_id
|
||||
for chunk in collected_tool_chunks
|
||||
if chunk.get("id")
|
||||
)
|
||||
assert final_tool_call["id"] == tool_call_id
|
||||
|
||||
async def test_tool_astreaming(self, model: BaseChatModel) -> None:
|
||||
"""Test that the model can stream tool calls."""
|
||||
chat_model_with_tools = model.bind_tools([get_current_weather])
|
||||
|
||||
prompt = [HumanMessage("What is the weather today in Boston?")]
|
||||
|
||||
# Flags and collectors for validation
|
||||
tool_chunk_found = False
|
||||
final_tool_calls = []
|
||||
collected_tool_chunks: list[ToolCallChunk] = []
|
||||
|
||||
# Stream the response and inspect the chunks
|
||||
async for chunk in chat_model_with_tools.astream(prompt):
|
||||
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
|
||||
|
||||
if chunk.tool_call_chunks:
|
||||
tool_chunk_found = True
|
||||
collected_tool_chunks.extend(chunk.tool_call_chunks)
|
||||
|
||||
if chunk.tool_calls:
|
||||
final_tool_calls.extend(chunk.tool_calls)
|
||||
|
||||
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
|
||||
assert len(final_tool_calls) == 1, (
|
||||
f"Expected 1 final tool call, but got {len(final_tool_calls)}"
|
||||
)
|
||||
|
||||
final_tool_call = final_tool_calls[0]
|
||||
assert final_tool_call["name"] == "get_current_weather"
|
||||
assert final_tool_call["args"] == {"location": "Boston"}
|
||||
|
||||
assert len(collected_tool_chunks) > 0
|
||||
assert collected_tool_chunks[0]["name"] == "get_current_weather"
|
||||
|
||||
# The ID should be consistent across chunks that have it
|
||||
tool_call_id = collected_tool_chunks[0].get("id")
|
||||
assert tool_call_id is not None
|
||||
assert all(
|
||||
chunk.get("id") == tool_call_id
|
||||
for chunk in collected_tool_chunks
|
||||
if chunk.get("id")
|
||||
)
|
||||
assert final_tool_call["id"] == tool_call_id
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"Will sometime encounter AssertionErrors where tool responses are "
|
||||
@@ -153,28 +50,13 @@ class TestChatOllama(ChatModelIntegrationTests):
|
||||
async def test_tool_calling_async(self, model: BaseChatModel) -> None:
|
||||
await super().test_tool_calling_async(model)
|
||||
|
||||
@patch("langchain_ollama.chat_models.Client.list")
|
||||
def test_init_model_not_found(self, mock_list: MagicMock) -> None:
|
||||
"""Test that a ValueError is raised when the model is not found."""
|
||||
mock_list.side_effect = ValueError("Test model not found")
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
ChatOllama(model="non-existent-model", validate_model_on_init=True)
|
||||
assert "Test model not found" in str(excinfo.value)
|
||||
|
||||
@patch("langchain_ollama.chat_models.Client.list")
|
||||
def test_init_connection_error(self, mock_list: MagicMock) -> None:
|
||||
"""Test that a ValidationError is raised on connect failure during init."""
|
||||
mock_list.side_effect = ConnectError("Test connection error")
|
||||
|
||||
with pytest.raises(ValidationError) as excinfo:
|
||||
ChatOllama(model="any-model", validate_model_on_init=True)
|
||||
assert "Failed to connect to Ollama" in str(excinfo.value)
|
||||
|
||||
@patch("langchain_ollama.chat_models.Client.list")
|
||||
def test_init_response_error(self, mock_list: MagicMock) -> None:
|
||||
"""Test that a ResponseError is raised."""
|
||||
mock_list.side_effect = ResponseError("Test response error")
|
||||
|
||||
with pytest.raises(ValidationError) as excinfo:
|
||||
ChatOllama(model="any-model", validate_model_on_init=True)
|
||||
assert "Received an error from the Ollama API" in str(excinfo.value)
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"Will sometimes fail due to Ollama's inconsistent tool call argument "
|
||||
"structure (see https://github.com/ollama/ollama/issues/6155). "
|
||||
"Args may contain unexpected keys like 'conversations' instead of "
|
||||
"empty dict."
|
||||
)
|
||||
)
|
||||
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_calling_with_no_arguments(model)
|
||||
|
||||
@@ -13,11 +13,74 @@ REASONING_MODEL_NAME = os.environ.get("OLLAMA_REASONING_TEST_MODEL", "deepseek-r
|
||||
SAMPLE = "What is 3^3?"
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
"""Test sync invoke returning a string."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
async def test_ainvoke() -> None:
|
||||
"""Test async invoke returning a string."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
"""Test batch sync token generation from `OllamaLLM`."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
"""Test batch async token generation from `OllamaLLM`."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_batch_tags() -> None:
|
||||
"""Test batch sync token generation with tags."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
result = llm.batch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch async token generation with tags."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_stream_text_tokens() -> None:
|
||||
"""Test streaming raw string tokens from `OllamaLLM`."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
for token in llm.stream("Hi."):
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_astream_text_tokens() -> None:
|
||||
"""Test async streaming raw string tokens from `OllamaLLM`."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
async for token in llm.astream("Hi."):
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@@ -28,7 +91,6 @@ def test__stream_no_reasoning(model: str) -> None:
|
||||
|
||||
result_chunk = None
|
||||
for chunk in llm._stream(SAMPLE):
|
||||
# Should be a GenerationChunk
|
||||
assert isinstance(chunk, GenerationChunk)
|
||||
if result_chunk is None:
|
||||
result_chunk = chunk
|
||||
@@ -38,8 +100,28 @@ def test__stream_no_reasoning(model: str) -> None:
|
||||
# The final result must be a GenerationChunk with visible content
|
||||
assert isinstance(result_chunk, GenerationChunk)
|
||||
assert result_chunk.text
|
||||
# No separate reasoning_content
|
||||
assert "reasoning_content" not in result_chunk.generation_info # type: ignore[operator]
|
||||
assert result_chunk.generation_info
|
||||
assert not result_chunk.generation_info.get("reasoning_content")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)])
|
||||
async def test__astream_no_reasoning(model: str) -> None:
|
||||
"""Test low-level async chunk streaming with `reasoning=False`."""
|
||||
llm = OllamaLLM(model=model, num_ctx=2**12)
|
||||
|
||||
result_chunk = None
|
||||
async for chunk in llm._astream(SAMPLE):
|
||||
assert isinstance(chunk, GenerationChunk)
|
||||
if result_chunk is None:
|
||||
result_chunk = chunk
|
||||
else:
|
||||
result_chunk += chunk
|
||||
|
||||
# The final result must be a GenerationChunk with visible content
|
||||
assert isinstance(result_chunk, GenerationChunk)
|
||||
assert result_chunk.text
|
||||
assert result_chunk.generation_info
|
||||
assert not result_chunk.generation_info.get("reasoning_content")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)])
|
||||
@@ -57,40 +139,17 @@ def test__stream_with_reasoning(model: str) -> None:
|
||||
|
||||
assert isinstance(result_chunk, GenerationChunk)
|
||||
assert result_chunk.text
|
||||
|
||||
# Should have extracted reasoning into generation_info
|
||||
assert "reasoning_content" in result_chunk.generation_info # type: ignore[operator]
|
||||
assert len(result_chunk.generation_info["reasoning_content"]) > 0 # type: ignore[index]
|
||||
assert result_chunk.generation_info
|
||||
reasoning_content = result_chunk.generation_info.get("reasoning_content")
|
||||
assert reasoning_content
|
||||
assert len(reasoning_content) > 0
|
||||
# And neither the visible nor the hidden portion contains <think> tags
|
||||
assert "<think>" not in result_chunk.text
|
||||
assert "</think>" not in result_chunk.text
|
||||
assert "<think>" not in result_chunk.generation_info["reasoning_content"] # type: ignore[index]
|
||||
assert "</think>" not in result_chunk.generation_info["reasoning_content"] # type: ignore[index]
|
||||
|
||||
|
||||
async def test_astream_text_tokens() -> None:
|
||||
"""Test async streaming raw string tokens from `OllamaLLM`."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)])
|
||||
async def test__astream_no_reasoning(model: str) -> None:
|
||||
"""Test low-level async chunk streaming with `reasoning=False`."""
|
||||
llm = OllamaLLM(model=model, num_ctx=2**12)
|
||||
|
||||
result_chunk = None
|
||||
async for chunk in llm._astream(SAMPLE):
|
||||
assert isinstance(chunk, GenerationChunk)
|
||||
if result_chunk is None:
|
||||
result_chunk = chunk
|
||||
else:
|
||||
result_chunk += chunk
|
||||
|
||||
assert isinstance(result_chunk, GenerationChunk)
|
||||
assert result_chunk.text
|
||||
assert "reasoning_content" not in result_chunk.generation_info # type: ignore[operator]
|
||||
assert "<think>" not in reasoning_content
|
||||
assert "</think>" not in reasoning_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)])
|
||||
@@ -108,49 +167,14 @@ async def test__astream_with_reasoning(model: str) -> None:
|
||||
|
||||
assert isinstance(result_chunk, GenerationChunk)
|
||||
assert result_chunk.text
|
||||
assert "reasoning_content" in result_chunk.generation_info # type: ignore[operator]
|
||||
assert len(result_chunk.generation_info["reasoning_content"]) > 0 # type: ignore[index]
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
"""Test batch sync token generation from `OllamaLLM`."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch sync token generation with tags."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
"""Test batch token generation from `OllamaLLM`."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_ainvoke() -> None:
|
||||
"""Test async invoke returning a string."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
"""Test sync invoke returning a string."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
|
||||
assert isinstance(result, str)
|
||||
# Should have extracted reasoning into generation_info
|
||||
assert result_chunk.generation_info
|
||||
reasoning_content = result_chunk.generation_info.get("reasoning_content")
|
||||
assert reasoning_content
|
||||
assert len(reasoning_content) > 0
|
||||
# And neither the visible nor the hidden portion contains <think> tags
|
||||
assert "<think>" not in result_chunk.text
|
||||
assert "</think>" not in result_chunk.text
|
||||
assert "<think>" not in reasoning_content
|
||||
assert "</think>" not in reasoning_content
|
||||
|
||||
@@ -22,6 +22,22 @@ from langchain_ollama.chat_models import (
|
||||
MODEL_NAME = "llama3.1"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _mock_httpx_client_stream(
|
||||
*args: Any, **kwargs: Any
|
||||
) -> Generator[Response, Any, Any]:
|
||||
yield Response(
|
||||
status_code=200,
|
||||
content='{"message": {"role": "assistant", "content": "The meaning ..."}}',
|
||||
request=Request(method="POST", url="http://whocares:11434"),
|
||||
)
|
||||
|
||||
|
||||
dummy_raw_tool_call = {
|
||||
"function": {"name": "test_func", "arguments": ""},
|
||||
}
|
||||
|
||||
|
||||
class TestChatOllama(ChatModelUnitTests):
|
||||
@property
|
||||
def chat_model_class(self) -> type[ChatOllama]:
|
||||
@@ -35,19 +51,24 @@ class TestChatOllama(ChatModelUnitTests):
|
||||
def test__parse_arguments_from_tool_call() -> None:
|
||||
"""Test that string arguments are preserved as strings in tool call parsing.
|
||||
|
||||
This test verifies the fix for PR #30154 which addressed an issue where
|
||||
string-typed tool arguments (like IDs or long strings) were being incorrectly
|
||||
PR #30154
|
||||
String-typed tool arguments (like IDs or long strings) were being incorrectly
|
||||
processed. The parser should preserve string values as strings rather than
|
||||
attempting to parse them as JSON when they're already valid string arguments.
|
||||
|
||||
The test uses a long string ID to ensure string arguments maintain their
|
||||
original type after parsing, which is critical for tools expecting string inputs.
|
||||
Use a long string ID to ensure string arguments maintain their original type after
|
||||
parsing, which is critical for tools expecting string inputs.
|
||||
"""
|
||||
raw_response = '{"model":"sample-model","message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"get_profile_details","arguments":{"arg_1":"12345678901234567890123456"}}}]},"done":false}' # noqa: E501
|
||||
raw_response = (
|
||||
'{"model":"sample-model","message":{"role":"assistant","content":"",'
|
||||
'"tool_calls":[{"function":{"name":"get_profile_details",'
|
||||
'"arguments":{"arg_1":"12345678901234567890123456"}}}]},"done":false}'
|
||||
)
|
||||
raw_tool_calls = json.loads(raw_response)["message"]["tool_calls"]
|
||||
response = _parse_arguments_from_tool_call(raw_tool_calls[0])
|
||||
assert response is not None
|
||||
assert isinstance(response["arg_1"], str)
|
||||
assert response["arg_1"] == "12345678901234567890123456"
|
||||
|
||||
|
||||
def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
|
||||
@@ -57,7 +78,6 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
|
||||
that just echoes the function name. This should be filtered out for
|
||||
no-argument tools to return an empty dictionary.
|
||||
"""
|
||||
# Test case where arguments contain functionName metadata
|
||||
raw_tool_call_with_metadata = {
|
||||
"function": {
|
||||
"name": "magic_function_no_args",
|
||||
@@ -67,7 +87,7 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
|
||||
response = _parse_arguments_from_tool_call(raw_tool_call_with_metadata)
|
||||
assert response == {}
|
||||
|
||||
# Test case where arguments contain both real args and metadata
|
||||
# Arguments contain both real args and metadata
|
||||
raw_tool_call_mixed = {
|
||||
"function": {
|
||||
"name": "some_function",
|
||||
@@ -77,7 +97,7 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
|
||||
response_mixed = _parse_arguments_from_tool_call(raw_tool_call_mixed)
|
||||
assert response_mixed == {"real_arg": "value"}
|
||||
|
||||
# Test case where functionName has different value (should be preserved)
|
||||
# functionName has different value (should be preserved)
|
||||
raw_tool_call_different = {
|
||||
"function": {"name": "function_a", "arguments": {"functionName": "function_b"}}
|
||||
}
|
||||
@@ -85,17 +105,6 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
|
||||
assert response_different == {"functionName": "function_b"}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _mock_httpx_client_stream(
|
||||
*args: Any, **kwargs: Any
|
||||
) -> Generator[Response, Any, Any]:
|
||||
yield Response(
|
||||
status_code=200,
|
||||
content='{"message": {"role": "assistant", "content": "The meaning ..."}}',
|
||||
request=Request(method="POST", url="http://whocares:11434"),
|
||||
)
|
||||
|
||||
|
||||
def test_arbitrary_roles_accepted_in_chatmessages(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -120,26 +129,16 @@ def test_arbitrary_roles_accepted_in_chatmessages(
|
||||
@patch("langchain_ollama.chat_models.validate_model")
|
||||
def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
"""Test that the model is validated on initialization when requested."""
|
||||
# Test that validate_model is called when validate_model_on_init=True
|
||||
ChatOllama(model=MODEL_NAME, validate_model_on_init=True)
|
||||
mock_validate_model.assert_called_once()
|
||||
mock_validate_model.reset_mock()
|
||||
|
||||
# Test that validate_model is NOT called when validate_model_on_init=False
|
||||
ChatOllama(model=MODEL_NAME, validate_model_on_init=False)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
# Test that validate_model is NOT called by default
|
||||
ChatOllama(model=MODEL_NAME)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
|
||||
# Define a dummy raw_tool_call for the function signature
|
||||
dummy_raw_tool_call = {
|
||||
"function": {"name": "test_func", "arguments": ""},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("input_string", "expected_output"),
|
||||
[
|
||||
@@ -164,7 +163,7 @@ def test_parse_json_string_success_cases(
|
||||
|
||||
def test_parse_json_string_failure_case_raises_exception() -> None:
|
||||
"""Tests that `_parse_json_string` raises an exception for malformed strings."""
|
||||
malformed_string = "{'key': 'value',,}"
|
||||
malformed_string = "{'key': 'value',,}" # Double comma is invalid
|
||||
raw_tool_call = {"function": {"name": "test_func", "arguments": malformed_string}}
|
||||
with pytest.raises(OutputParserException):
|
||||
_parse_json_string(
|
||||
@@ -181,7 +180,7 @@ def test_parse_json_string_skip_returns_input_on_failure() -> None:
|
||||
result = _parse_json_string(
|
||||
malformed_string,
|
||||
raw_tool_call=raw_tool_call,
|
||||
skip=True,
|
||||
skip=True, # We want the original invalid string back
|
||||
)
|
||||
assert result == malformed_string
|
||||
|
||||
|
||||
@@ -16,16 +16,12 @@ def test_initialization() -> None:
|
||||
@patch("langchain_ollama.embeddings.validate_model")
|
||||
def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
"""Test that the model is validated on initialization when requested."""
|
||||
# Test that validate_model is called when validate_model_on_init=True
|
||||
OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True)
|
||||
mock_validate_model.assert_called_once()
|
||||
mock_validate_model.reset_mock()
|
||||
|
||||
# Test that validate_model is NOT called when validate_model_on_init=False
|
||||
OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=False)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
# Test that validate_model is NOT called by default
|
||||
OllamaEmbeddings(model=MODEL_NAME)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
@@ -33,20 +29,13 @@ def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
@patch("langchain_ollama.embeddings.Client")
|
||||
def test_embed_documents_passes_options(mock_client_class: Any) -> None:
|
||||
"""Test that `embed_documents()` passes options, including `num_gpu`."""
|
||||
# Create a mock client instance
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock the embed method response
|
||||
mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
|
||||
|
||||
# Create embeddings with num_gpu parameter
|
||||
embeddings = OllamaEmbeddings(model=MODEL_NAME, num_gpu=4, temperature=0.5)
|
||||
|
||||
# Call embed_documents
|
||||
result = embeddings.embed_documents(["test text"])
|
||||
|
||||
# Verify the result
|
||||
assert result == [[0.1, 0.2, 0.3]]
|
||||
|
||||
# Check that embed was called with correct arguments
|
||||
|
||||
@@ -14,7 +14,7 @@ def test_initialization() -> None:
|
||||
|
||||
|
||||
def test_model_params() -> None:
|
||||
# Test standard tracing params
|
||||
"""Test standard tracing params"""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
ls_params = llm._get_ls_params()
|
||||
assert ls_params == {
|
||||
@@ -36,16 +36,12 @@ def test_model_params() -> None:
|
||||
@patch("langchain_ollama.llms.validate_model")
|
||||
def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
"""Test that the model is validated on initialization when requested."""
|
||||
# Test that validate_model is called when validate_model_on_init=True
|
||||
OllamaLLM(model=MODEL_NAME, validate_model_on_init=True)
|
||||
mock_validate_model.assert_called_once()
|
||||
mock_validate_model.reset_mock()
|
||||
|
||||
# Test that validate_model is NOT called when validate_model_on_init=False
|
||||
OllamaLLM(model=MODEL_NAME, validate_model_on_init=False)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
# Test that validate_model is NOT called by default
|
||||
OllamaLLM(model=MODEL_NAME)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user