refactor(ollama): clean up tests (#33198)

This commit is contained in:
Mason Daugherty
2025-10-01 21:52:01 -04:00
committed by GitHub
parent a89c549cb0
commit a9eda18e1e
11 changed files with 387 additions and 436 deletions

View File

@@ -3,9 +3,16 @@
from __future__ import annotations
from typing import Annotated, Optional
from unittest.mock import MagicMock, patch
import pytest
from pydantic import BaseModel, Field
from httpx import ConnectError
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.tools import tool
from ollama import ResponseError
from pydantic import BaseModel, Field, ValidationError
from typing_extensions import TypedDict
from langchain_ollama import ChatOllama
@@ -13,6 +20,43 @@ from langchain_ollama import ChatOllama
DEFAULT_MODEL_NAME = "llama3.1"
@tool
def get_current_weather(location: str) -> dict:
    """Gets the current weather in a given location."""
    # NOTE(review): the docstring above is presumably consumed by `@tool` as the
    # tool description, so it is kept verbatim - confirm before rewording it.
    # Canned data: any location mentioning Boston (case-insensitive) gets a
    # fixed snowy reading; every other location is reported as unknown.
    is_boston = "boston" in location.lower()
    temperature = "15°F" if is_boston else "unknown"
    conditions = "snow" if is_boston else "unknown"
    return {"temperature": temperature, "conditions": conditions}
@patch("langchain_ollama.chat_models.Client.list")
def test_init_model_not_found(mock_list: MagicMock) -> None:
    """Check that a `ValueError` from the model listing surfaces at init.

    `Client.list` is patched to raise, and the model is built with
    `validate_model_on_init=True`; the raised error must carry the message.
    """
    expected_message = "Test model not found"
    mock_list.side_effect = ValueError(expected_message)

    # `match` searches the stringified exception, replacing a separate
    # substring assertion on `excinfo.value`.
    with pytest.raises(ValueError, match=expected_message):
        ChatOllama(model="non-existent-model", validate_model_on_init=True)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_connection_error(mock_list: MagicMock) -> None:
    """Check that a connect failure during init raises a `ValidationError`.

    `Client.list` is patched to raise `ConnectError`; building the model with
    `validate_model_on_init=True` must fail with the connection message.
    """
    mock_list.side_effect = ConnectError("Test connection error")

    # `match` searches the stringified exception for the expected text.
    with pytest.raises(ValidationError, match="Failed to connect to Ollama"):
        ChatOllama(model="any-model", validate_model_on_init=True)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_response_error(mock_list: MagicMock) -> None:
    """Check that an Ollama `ResponseError` during init raises a `ValidationError`.

    `Client.list` is patched to raise `ResponseError`; building the model with
    `validate_model_on_init=True` must fail with the API-error message.
    """
    mock_list.side_effect = ResponseError("Test response error")

    # `match` searches the stringified exception for the expected text.
    with pytest.raises(
        ValidationError, match="Received an error from the Ollama API"
    ):
        ChatOllama(model="any-model", validate_model_on_init=True)
@pytest.mark.parametrize(("method"), [("function_calling"), ("json_schema")])
def test_structured_output(method: str) -> None:
"""Test to verify structured output via tool calling and `format` parameter."""
@@ -98,3 +142,97 @@ def test_structured_output_deeply_nested(model: str) -> None:
for chunk in chat.stream(text):
assert isinstance(chunk, Data)
@pytest.mark.parametrize("model", [DEFAULT_MODEL_NAME])
def test_tool_streaming(model: str) -> None:
    """Test that the model can stream tool calls (sync `stream`).

    Binds `get_current_weather`, streams a Boston weather question, and checks
    that exactly one tool call is produced and that its id is consistent
    across the streamed tool-call chunks.
    """
    bound_model = ChatOllama(model=model).bind_tools([get_current_weather])
    prompt = [HumanMessage("What is the weather today in Boston?")]

    streamed_tool_chunks: list[ToolCallChunk] = []
    streamed_tool_calls = []
    for message_chunk in bound_model.stream(prompt):
        assert isinstance(message_chunk, AIMessageChunk), (
            "Expected AIMessageChunk type"
        )
        # Extending with an empty/None field is a no-op, so no guard is needed.
        streamed_tool_chunks.extend(message_chunk.tool_call_chunks or [])
        streamed_tool_calls.extend(message_chunk.tool_calls or [])

    assert streamed_tool_chunks, (
        "Tool streaming did not produce any tool_call_chunks."
    )
    assert len(streamed_tool_calls) == 1, (
        f"Expected 1 final tool call, but got {len(streamed_tool_calls)}"
    )

    final_call = streamed_tool_calls[0]
    assert final_call["name"] == "get_current_weather"
    assert final_call["args"] == {"location": "Boston"}

    first_chunk = streamed_tool_chunks[0]
    assert first_chunk["name"] == "get_current_weather"

    # The ID should be consistent across chunks that have it
    expected_id = first_chunk.get("id")
    assert expected_id is not None
    present_ids = [
        piece.get("id") for piece in streamed_tool_chunks if piece.get("id")
    ]
    assert all(chunk_id == expected_id for chunk_id in present_ids)
    assert final_call["id"] == expected_id
@pytest.mark.parametrize("model", [DEFAULT_MODEL_NAME])
async def test_tool_astreaming(model: str) -> None:
    """Test that the model can stream tool calls (async `astream`).

    Binds `get_current_weather`, streams a Boston weather question, and checks
    that exactly one tool call is produced and that its id is consistent
    across the streamed tool-call chunks.
    """
    bound_model = ChatOllama(model=model).bind_tools([get_current_weather])
    prompt = [HumanMessage("What is the weather today in Boston?")]

    streamed_tool_chunks: list[ToolCallChunk] = []
    streamed_tool_calls = []
    async for message_chunk in bound_model.astream(prompt):
        assert isinstance(message_chunk, AIMessageChunk), (
            "Expected AIMessageChunk type"
        )
        # Extending with an empty/None field is a no-op, so no guard is needed.
        streamed_tool_chunks.extend(message_chunk.tool_call_chunks or [])
        streamed_tool_calls.extend(message_chunk.tool_calls or [])

    assert streamed_tool_chunks, (
        "Tool streaming did not produce any tool_call_chunks."
    )
    assert len(streamed_tool_calls) == 1, (
        f"Expected 1 final tool call, but got {len(streamed_tool_calls)}"
    )

    final_call = streamed_tool_calls[0]
    assert final_call["name"] == "get_current_weather"
    assert final_call["args"] == {"location": "Boston"}

    first_chunk = streamed_tool_chunks[0]
    assert first_chunk["name"] == "get_current_weather"

    # The ID should be consistent across chunks that have it
    expected_id = first_chunk.get("id")
    assert expected_id is not None
    present_ids = [
        piece.get("id") for piece in streamed_tool_chunks if piece.get("id")
    ]
    assert all(chunk_id == expected_id for chunk_id in present_ids)
    assert final_call["id"] == expected_id

View File

@@ -8,9 +8,10 @@ from langchain_ollama import ChatOllama
SAMPLE = "What is 3^3?"
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_no_reasoning(model: str) -> None:
"""Test streaming with `reasoning=False`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_stream_no_reasoning(model: str, use_async: bool) -> None:
"""Test streaming with ``reasoning=False``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
messages = [
{
@@ -19,12 +20,20 @@ def test_stream_no_reasoning(model: str) -> None:
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
if use_async:
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
else:
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" not in result.content
@@ -32,33 +41,10 @@ def test_stream_no_reasoning(model: str) -> None:
assert "reasoning_content" not in result.additional_kwargs
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_astream_no_reasoning(model: str) -> None:
"""Test async streaming with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" not in result.content
assert "</think>" not in result.content
assert "reasoning_content" not in result.additional_kwargs
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_reasoning_none(model: str) -> None:
"""Test streaming with `reasoning=None`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_stream_reasoning_none(model: str, use_async: bool) -> None:
"""Test streaming with ``reasoning=None``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
messages = [
{
@@ -67,12 +53,20 @@ def test_stream_reasoning_none(model: str) -> None:
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
if use_async:
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
else:
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" in result.content
@@ -82,35 +76,10 @@ def test_stream_reasoning_none(model: str) -> None:
assert "</think>" not in result.additional_kwargs.get("reasoning_content", "")
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_astream_reasoning_none(model: str) -> None:
"""Test async streaming with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" in result.content
assert "</think>" in result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.additional_kwargs.get("reasoning_content", "")
assert "</think>" not in result.additional_kwargs.get("reasoning_content", "")
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_stream(model: str) -> None:
"""Test streaming with `reasoning=True`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_reasoning_stream(model: str, use_async: bool) -> None:
"""Test streaming with ``reasoning=True``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
@@ -119,12 +88,20 @@ def test_reasoning_stream(model: str) -> None:
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
if use_async:
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
else:
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
@@ -135,63 +112,32 @@ def test_reasoning_stream(model: str) -> None:
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_reasoning_astream(model: str) -> None:
"""Test async streaming with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" not in result.content
assert "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_no_reasoning(model: str) -> None:
"""Test using invoke with `reasoning=False`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_invoke_no_reasoning(model: str, use_async: bool) -> None:
"""Test invoke with ``reasoning=False``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
if use_async:
result = await llm.ainvoke([message])
else:
result = llm.invoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content
assert "</think>" not in result.content
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_ainvoke_no_reasoning(model: str) -> None:
"""Test using async invoke with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content
assert "</think>" not in result.content
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_reasoning_none(model: str) -> None:
"""Test using invoke with `reasoning=None`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_invoke_reasoning_none(model: str, use_async: bool) -> None:
"""Test invoke with ``reasoning=None``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
if use_async:
result = await llm.ainvoke([message])
else:
result = llm.invoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content
@@ -200,26 +146,16 @@ def test_invoke_reasoning_none(model: str) -> None:
assert "</think>" not in result.additional_kwargs.get("reasoning_content", "")
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_ainvoke_reasoning_none(model: str) -> None:
"""Test using async invoke with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content
assert "</think>" in result.content
assert "<think>" not in result.additional_kwargs.get("reasoning_content", "")
assert "</think>" not in result.additional_kwargs.get("reasoning_content", "")
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_invoke(model: str) -> None:
"""Test invoke with `reasoning=True`"""
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
@pytest.mark.parametrize("use_async", [False, True])
async def test_reasoning_invoke(model: str, use_async: bool) -> None:
"""Test invoke with ``reasoning=True``."""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
if use_async:
result = await llm.ainvoke([message])
else:
result = llm.invoke([message])
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
@@ -229,22 +165,7 @@ def test_reasoning_invoke(model: str) -> None:
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_reasoning_ainvoke(model: str) -> None:
"""Test invoke with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" not in result.content
assert "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"])
def test_think_tag_stripping_necessity(model: str) -> None:
"""Test that demonstrates why ``_strip_think_tags`` is necessary.

View File

@@ -1,29 +1,14 @@
"""Test chat model integration using standard integration tests."""
from unittest.mock import MagicMock, patch
import pytest
from httpx import ConnectError
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolCallChunk
from langchain_core.tools import tool
from langchain_tests.integration_tests import ChatModelIntegrationTests
from ollama import ResponseError
from pydantic import ValidationError
from langchain_ollama.chat_models import ChatOllama
DEFAULT_MODEL_NAME = "llama3.1"
@tool
def get_current_weather(location: str) -> dict:
"""Gets the current weather in a given location."""
if "boston" in location.lower():
return {"temperature": "15°F", "conditions": "snow"}
return {"temperature": "unknown", "conditions": "unknown"}
class TestChatOllama(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> type[ChatOllama]:
@@ -47,94 +32,6 @@ class TestChatOllama(ChatModelIntegrationTests):
def supports_image_inputs(self) -> bool:
return True
def test_tool_streaming(self, model: BaseChatModel) -> None:
"""Test that the model can stream tool calls."""
chat_model_with_tools = model.bind_tools([get_current_weather])
prompt = [HumanMessage("What is the weather today in Boston?")]
# Flags and collectors for validation
tool_chunk_found = False
final_tool_calls = []
collected_tool_chunks: list[ToolCallChunk] = []
# Stream the response and inspect the chunks
for chunk in chat_model_with_tools.stream(prompt):
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
if chunk.tool_call_chunks:
tool_chunk_found = True
collected_tool_chunks.extend(chunk.tool_call_chunks)
if chunk.tool_calls:
final_tool_calls.extend(chunk.tool_calls)
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
assert len(final_tool_calls) == 1, (
f"Expected 1 final tool call, but got {len(final_tool_calls)}"
)
final_tool_call = final_tool_calls[0]
assert final_tool_call["name"] == "get_current_weather"
assert final_tool_call["args"] == {"location": "Boston"}
assert len(collected_tool_chunks) > 0
assert collected_tool_chunks[0]["name"] == "get_current_weather"
# The ID should be consistent across chunks that have it
tool_call_id = collected_tool_chunks[0].get("id")
assert tool_call_id is not None
assert all(
chunk.get("id") == tool_call_id
for chunk in collected_tool_chunks
if chunk.get("id")
)
assert final_tool_call["id"] == tool_call_id
async def test_tool_astreaming(self, model: BaseChatModel) -> None:
"""Test that the model can stream tool calls."""
chat_model_with_tools = model.bind_tools([get_current_weather])
prompt = [HumanMessage("What is the weather today in Boston?")]
# Flags and collectors for validation
tool_chunk_found = False
final_tool_calls = []
collected_tool_chunks: list[ToolCallChunk] = []
# Stream the response and inspect the chunks
async for chunk in chat_model_with_tools.astream(prompt):
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
if chunk.tool_call_chunks:
tool_chunk_found = True
collected_tool_chunks.extend(chunk.tool_call_chunks)
if chunk.tool_calls:
final_tool_calls.extend(chunk.tool_calls)
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
assert len(final_tool_calls) == 1, (
f"Expected 1 final tool call, but got {len(final_tool_calls)}"
)
final_tool_call = final_tool_calls[0]
assert final_tool_call["name"] == "get_current_weather"
assert final_tool_call["args"] == {"location": "Boston"}
assert len(collected_tool_chunks) > 0
assert collected_tool_chunks[0]["name"] == "get_current_weather"
# The ID should be consistent across chunks that have it
tool_call_id = collected_tool_chunks[0].get("id")
assert tool_call_id is not None
assert all(
chunk.get("id") == tool_call_id
for chunk in collected_tool_chunks
if chunk.get("id")
)
assert final_tool_call["id"] == tool_call_id
@pytest.mark.xfail(
reason=(
"Will sometime encounter AssertionErrors where tool responses are "
@@ -153,28 +50,13 @@ class TestChatOllama(ChatModelIntegrationTests):
async def test_tool_calling_async(self, model: BaseChatModel) -> None:
await super().test_tool_calling_async(model)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_model_not_found(self, mock_list: MagicMock) -> None:
"""Test that a ValueError is raised when the model is not found."""
mock_list.side_effect = ValueError("Test model not found")
with pytest.raises(ValueError) as excinfo:
ChatOllama(model="non-existent-model", validate_model_on_init=True)
assert "Test model not found" in str(excinfo.value)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_connection_error(self, mock_list: MagicMock) -> None:
"""Test that a ValidationError is raised on connect failure during init."""
mock_list.side_effect = ConnectError("Test connection error")
with pytest.raises(ValidationError) as excinfo:
ChatOllama(model="any-model", validate_model_on_init=True)
assert "Failed to connect to Ollama" in str(excinfo.value)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_response_error(self, mock_list: MagicMock) -> None:
"""Test that a ResponseError is raised."""
mock_list.side_effect = ResponseError("Test response error")
with pytest.raises(ValidationError) as excinfo:
ChatOllama(model="any-model", validate_model_on_init=True)
assert "Received an error from the Ollama API" in str(excinfo.value)
# Expected-failure marker: the reason string below records why this inherited
# test is not reliable against Ollama.
@pytest.mark.xfail(
reason=(
"Will sometimes fail due to Ollama's inconsistent tool call argument "
"structure (see https://github.com/ollama/ollama/issues/6155). "
"Args may contain unexpected keys like 'conversations' instead of "
"empty dict."
)
)
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
"""Run the inherited no-argument tool-calling test under the xfail marker."""
# Pure delegation: the override exists only to attach the marker above.
super().test_tool_calling_with_no_arguments(model)

View File

@@ -13,11 +13,74 @@ REASONING_MODEL_NAME = os.environ.get("OLLAMA_REASONING_TEST_MODEL", "deepseek-r
SAMPLE = "What is 3^3?"
def test_invoke() -> None:
    """Invoke `OllamaLLM` synchronously with a tagged config; expect a `str`."""
    run_config = RunnableConfig(tags=["foo"])
    completion = OllamaLLM(model=MODEL_NAME).invoke(
        "I'm Pickle Rick", config=run_config
    )
    assert isinstance(completion, str)
async def test_ainvoke() -> None:
    """Invoke `OllamaLLM` asynchronously with a tagged config; expect a `str`."""
    run_config = RunnableConfig(tags=["foo"])
    completion = await OllamaLLM(model=MODEL_NAME).ainvoke(
        "I'm Pickle Rick", config=run_config
    )
    assert isinstance(completion, str)
def test_batch() -> None:
    """Batch two prompts synchronously; every completion must be a `str`."""
    prompts = ["I'm Pickle Rick", "I'm not Pickle Rick"]
    completions = OllamaLLM(model=MODEL_NAME).batch(prompts)
    assert all(isinstance(completion, str) for completion in completions)
async def test_abatch() -> None:
    """Batch two prompts asynchronously; every completion must be a `str`."""
    prompts = ["I'm Pickle Rick", "I'm not Pickle Rick"]
    completions = await OllamaLLM(model=MODEL_NAME).abatch(prompts)
    assert all(isinstance(completion, str) for completion in completions)
def test_batch_tags() -> None:
    """Batch two prompts synchronously with tags; every completion is a `str`."""
    prompts = ["I'm Pickle Rick", "I'm not Pickle Rick"]
    completions = OllamaLLM(model=MODEL_NAME).batch(
        prompts, config={"tags": ["foo"]}
    )
    assert all(isinstance(completion, str) for completion in completions)
async def test_abatch_tags() -> None:
    """Batch two prompts asynchronously with tags; every completion is a `str`."""
    prompts = ["I'm Pickle Rick", "I'm not Pickle Rick"]
    completions = await OllamaLLM(model=MODEL_NAME).abatch(
        prompts, config={"tags": ["foo"]}
    )
    assert all(isinstance(completion, str) for completion in completions)
def test_stream_text_tokens() -> None:
    """Test streaming raw string tokens from `OllamaLLM`.

    Streams one short prompt and checks that every yielded token is a `str`.
    """
    llm = OllamaLLM(model=MODEL_NAME)
    # A single streaming pass over the short prompt, mirroring the async
    # variant. Two stacked `for token in llm.stream(...)` headers would nest,
    # shadow `token`, and start a fresh stream for every outer token.
    for token in llm.stream("Hi."):
        assert isinstance(token, str)
async def test_astream_text_tokens() -> None:
    """Stream a short prompt asynchronously; every yielded token is a `str`."""
    token_stream = OllamaLLM(model=MODEL_NAME).astream("Hi.")
    async for piece in token_stream:
        assert isinstance(piece, str)
@@ -28,7 +91,6 @@ def test__stream_no_reasoning(model: str) -> None:
result_chunk = None
for chunk in llm._stream(SAMPLE):
# Should be a GenerationChunk
assert isinstance(chunk, GenerationChunk)
if result_chunk is None:
result_chunk = chunk
@@ -38,8 +100,28 @@ def test__stream_no_reasoning(model: str) -> None:
# The final result must be a GenerationChunk with visible content
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
# No separate reasoning_content
assert "reasoning_content" not in result_chunk.generation_info # type: ignore[operator]
assert result_chunk.generation_info
assert not result_chunk.generation_info.get("reasoning_content")
@pytest.mark.parametrize("model", [REASONING_MODEL_NAME])
async def test__astream_no_reasoning(model: str) -> None:
    """Test low-level async chunk streaming without reasoning content.

    Accumulates every `GenerationChunk` from `_astream` and checks the merged
    chunk has visible text and no truthy `reasoning_content`.

    NOTE(review): the LLM is built without an explicit `reasoning` argument,
    so this relies on the default - confirm it matches the intended
    `reasoning=False` scenario.
    """
    llm = OllamaLLM(model=model, num_ctx=2**12)

    merged = None
    async for piece in llm._astream(SAMPLE):
        assert isinstance(piece, GenerationChunk)
        if merged is not None:
            merged += piece
        else:
            merged = piece

    # The final result must be a GenerationChunk with visible content
    assert isinstance(merged, GenerationChunk)
    assert merged.text

    info = merged.generation_info
    assert info
    assert not info.get("reasoning_content")
@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)])
@@ -57,40 +139,17 @@ def test__stream_with_reasoning(model: str) -> None:
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
# Should have extracted reasoning into generation_info
assert "reasoning_content" in result_chunk.generation_info # type: ignore[operator]
assert len(result_chunk.generation_info["reasoning_content"]) > 0 # type: ignore[index]
assert result_chunk.generation_info
reasoning_content = result_chunk.generation_info.get("reasoning_content")
assert reasoning_content
assert len(reasoning_content) > 0
# And neither the visible nor the hidden portion contains <think> tags
assert "<think>" not in result_chunk.text
assert "</think>" not in result_chunk.text
assert "<think>" not in result_chunk.generation_info["reasoning_content"] # type: ignore[index]
assert "</think>" not in result_chunk.generation_info["reasoning_content"] # type: ignore[index]
async def test_astream_text_tokens() -> None:
"""Test async streaming raw string tokens from `OllamaLLM`."""
llm = OllamaLLM(model=MODEL_NAME)
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token, str)
@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)])
async def test__astream_no_reasoning(model: str) -> None:
"""Test low-level async chunk streaming with `reasoning=False`."""
llm = OllamaLLM(model=model, num_ctx=2**12)
result_chunk = None
async for chunk in llm._astream(SAMPLE):
assert isinstance(chunk, GenerationChunk)
if result_chunk is None:
result_chunk = chunk
else:
result_chunk += chunk
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
assert "reasoning_content" not in result_chunk.generation_info # type: ignore[operator]
assert "<think>" not in reasoning_content
assert "</think>" not in reasoning_content
@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)])
@@ -108,49 +167,14 @@ async def test__astream_with_reasoning(model: str) -> None:
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
assert "reasoning_content" in result_chunk.generation_info # type: ignore[operator]
assert len(result_chunk.generation_info["reasoning_content"]) > 0 # type: ignore[index]
async def test_abatch() -> None:
"""Test batch sync token generation from `OllamaLLM`."""
llm = OllamaLLM(model=MODEL_NAME)
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
async def test_abatch_tags() -> None:
"""Test batch sync token generation with tags."""
llm = OllamaLLM(model=MODEL_NAME)
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token, str)
def test_batch() -> None:
"""Test batch token generation from `OllamaLLM`."""
llm = OllamaLLM(model=MODEL_NAME)
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
async def test_ainvoke() -> None:
"""Test async invoke returning a string."""
llm = OllamaLLM(model=MODEL_NAME)
result = await llm.ainvoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
assert isinstance(result, str)
def test_invoke() -> None:
"""Test sync invoke returning a string."""
llm = OllamaLLM(model=MODEL_NAME)
result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
assert isinstance(result, str)
# Should have extracted reasoning into generation_info
assert result_chunk.generation_info
reasoning_content = result_chunk.generation_info.get("reasoning_content")
assert reasoning_content
assert len(reasoning_content) > 0
# And neither the visible nor the hidden portion contains <think> tags
assert "<think>" not in result_chunk.text
assert "</think>" not in result_chunk.text
assert "<think>" not in reasoning_content
assert "</think>" not in reasoning_content