mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
ollama: thinking, tool streaming, docs, tests (#31772)
* New `reasoning` (bool) param to support toggling [Ollama thinking](https://ollama.com/blog/thinking) (#31573, #31700). If `reasoning=True`, Ollama's `thinking` content will be placed in the model responses' `additional_kwargs.reasoning_content`.
  * Supported by:
    * ChatOllama (class level; invocation level TODO)
    * OllamaLLM (TODO)
* Added tests to ensure that streaming tool calls works (#29129)
* Refactored tests that relied on `extract_reasoning()`
* Myriad docs additions and consistency/typo fixes
* Improved type safety in some spots

Closes #29129. Addresses #31573 and #31700. Supersedes #31701.
This commit is contained in:
@@ -1,17 +1,28 @@
|
||||
"""Ollama specific chat model integration tests for reasoning models."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessageChunk, HumanMessage
|
||||
from pydantic import ValidationError
|
||||
from langchain_core.messages import (
|
||||
AIMessageChunk,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
SAMPLE = "What is 3^3?"
|
||||
|
||||
|
||||
class MathAnswer(BaseModel):
    """A mathematical expression and its numerical answer."""

    # Both fields are required: neither `Field(...)` call supplies a default.
    # The docstring and descriptions are kept verbatim because they are
    # runtime strings of the model.
    expression: str = Field(
        description="The mathematical expression to evaluate.",
    )
    answer: int = Field(
        description="The numerical answer to the expression.",
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
|
||||
"""Test deepseek model without parsing."""
|
||||
def test_stream_no_reasoning(model: str) -> None:
|
||||
"""Test streaming with `reasoning=False`"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12)
|
||||
messages = [
|
||||
{
|
||||
@@ -28,14 +39,41 @@ def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" in result.content and "</think>" in result.content
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "<think>" not in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" not in result.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_stream_bool(model: str) -> None:
|
||||
"""Test deepseek model with reasoning bool=True"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
|
||||
async def test_astream_no_reasoning(model: str) -> None:
    """Test async streaming with `reasoning=False`

    `reasoning` is not passed to `ChatOllama` here, so the test relies on the
    constructor default (NOTE(review): confirm the default matches the
    `reasoning=False` wording above).

    The streamed chunks are merged into one message, which must carry
    content, must not expose `reasoning_content` in `additional_kwargs`, and
    must not contain raw `<think>` tags.
    """
    llm = ChatOllama(model=model, num_ctx=2**12)
    messages = [
        {
            "role": "user",
            "content": SAMPLE,
        }
    ]
    result = None
    async for chunk in llm.astream(messages):
        assert isinstance(chunk, BaseMessageChunk)
        if result is None:
            result = chunk
            continue
        result += chunk
    assert isinstance(result, AIMessageChunk)
    assert result.content
    assert "reasoning_content" not in result.additional_kwargs
    assert "<think>" not in result.content and "</think>" not in result.content
    # The key was just asserted absent, so indexing it would always raise
    # KeyError. `.get` keeps the tag checks as a harmless sanity check.
    reasoning = result.additional_kwargs.get("reasoning_content", "")
    assert "<think>" not in reasoning
    assert "</think>" not in reasoning
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_stream_reasoning_none(model: str) -> None:
|
||||
"""Test streaming with `reasoning=None`"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -51,26 +89,41 @@ def test_deepseek_messages_stream_bool(model: str) -> None:
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" in result.additional_kwargs["reasoning_content"]
|
||||
clean_content = (
|
||||
result.additional_kwargs["reasoning_content"]
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.strip()
|
||||
)
|
||||
assert len(clean_content) > 0
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
assert "<think>" in result.content and "</think>" in result.content
|
||||
assert "<think>" not in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" not in result.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_stream_tuple(model: str) -> None:
|
||||
"""Test deepseek model with reasoning with tuple=..."""
|
||||
llm = ChatOllama(
|
||||
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
|
||||
)
|
||||
async def test_astream_reasoning_none(model: str) -> None:
    """Test async streaming with `reasoning=None`

    The streamed chunks are merged into one message. With `reasoning=None`
    the test expects the raw `<think>`/`</think>` tags to remain in
    `content` and no `reasoning_content` entry in `additional_kwargs`.
    """
    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
    messages = [
        {
            "role": "user",
            "content": SAMPLE,
        }
    ]
    result = None
    async for chunk in llm.astream(messages):
        assert isinstance(chunk, BaseMessageChunk)
        if result is None:
            result = chunk
            continue
        result += chunk
    assert isinstance(result, AIMessageChunk)
    assert result.content
    assert "reasoning_content" not in result.additional_kwargs
    assert "<think>" in result.content and "</think>" in result.content
    # The key was just asserted absent, so indexing it would always raise
    # KeyError. `.get` keeps the tag checks as a harmless sanity check.
    reasoning = result.additional_kwargs.get("reasoning_content", "")
    assert "<think>" not in reasoning
    assert "</think>" not in reasoning
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_reasoning_stream(model: str) -> None:
|
||||
"""Test streaming with `reasoning=True`"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -86,77 +139,114 @@ def test_deepseek_messages_stream_tuple(model: str) -> None:
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" in result.additional_kwargs["reasoning_content"]
|
||||
clean_content = (
|
||||
result.additional_kwargs["reasoning_content"]
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.strip()
|
||||
)
|
||||
assert len(clean_content) > 0
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "<think>" not in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" not in result.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_invoke_no_reasoning(model: str) -> None:
|
||||
"""Test deepseek model without parsing using invoke."""
|
||||
async def test_reasoning_astream(model: str) -> None:
    """Test async streaming with `reasoning=True`

    Merges every streamed chunk and checks that the merged message has
    content, a non-empty `reasoning_content` entry in `additional_kwargs`,
    and no raw `<think>` tags in either place.
    """
    chat = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
    prompt = [{"role": "user", "content": SAMPLE}]

    merged = None
    async for piece in chat.astream(prompt):
        assert isinstance(piece, BaseMessageChunk)
        if merged is None:
            merged = piece
        else:
            merged += piece

    assert isinstance(merged, AIMessageChunk)
    assert merged.content
    assert "reasoning_content" in merged.additional_kwargs
    assert len(merged.additional_kwargs["reasoning_content"]) > 0
    assert "<think>" not in merged.content
    assert "</think>" not in merged.content
    assert "<think>" not in merged.additional_kwargs["reasoning_content"]
    assert "</think>" not in merged.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_no_reasoning(model: str) -> None:
    """Test using invoke with `reasoning=False`

    `reasoning` is not passed to `ChatOllama` here, so the test relies on the
    constructor default (NOTE(review): confirm the default matches the
    `reasoning=False` wording above).

    The response must carry content, must not expose `reasoning_content` in
    `additional_kwargs`, and must not contain raw `<think>` tags.
    """
    llm = ChatOllama(model=model, num_ctx=2**12)
    message = HumanMessage(content=SAMPLE)
    result = llm.invoke([message])
    assert result.content
    assert "reasoning_content" not in result.additional_kwargs
    assert "<think>" not in result.content and "</think>" not in result.content
    # The key was just asserted absent, so indexing it would always raise
    # KeyError. `.get` keeps the tag checks as a harmless sanity check.
    reasoning = result.additional_kwargs.get("reasoning_content", "")
    assert "<think>" not in reasoning
    assert "</think>" not in reasoning
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_invoke_bool(model: str) -> None:
|
||||
"""Test deepseek model with reasoning bool=True using invoke"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
|
||||
async def test_ainvoke_no_reasoning(model: str) -> None:
    """Test using async invoke with `reasoning=False`

    `reasoning` is not passed to `ChatOllama` here, so the test relies on the
    constructor default (NOTE(review): confirm the default matches the
    `reasoning=False` wording above).

    The response must carry content, must not expose `reasoning_content` in
    `additional_kwargs`, and must not contain raw `<think>` tags.
    """
    llm = ChatOllama(model=model, num_ctx=2**12)
    message = HumanMessage(content=SAMPLE)
    result = await llm.ainvoke([message])
    assert result.content
    assert "reasoning_content" not in result.additional_kwargs
    assert "<think>" not in result.content and "</think>" not in result.content
    # The key was just asserted absent, so indexing it would always raise
    # KeyError. `.get` keeps the tag checks as a harmless sanity check.
    reasoning = result.additional_kwargs.get("reasoning_content", "")
    assert "<think>" not in reasoning
    assert "</think>" not in reasoning
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_reasoning_none(model: str) -> None:
    """Test using invoke with `reasoning=None`

    With `reasoning=None` the test expects the raw `<think>`/`</think>` tags
    to remain in `content` and no `reasoning_content` entry in
    `additional_kwargs`. The assertions match `test_ainvoke_reasoning_none`.
    """
    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
    message = HumanMessage(content=SAMPLE)
    result = llm.invoke([message])
    assert result.content
    assert "reasoning_content" not in result.additional_kwargs
    assert "<think>" in result.content and "</think>" in result.content
    # The key was just asserted absent, so indexing it would always raise
    # KeyError. `.get` keeps the tag checks as a harmless sanity check.
    reasoning = result.additional_kwargs.get("reasoning_content", "")
    assert "<think>" not in reasoning
    assert "</think>" not in reasoning
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_invoke_tuple(model: str) -> None:
|
||||
"""Test deepseek model with reasoning with tuple=... using invoke"""
|
||||
llm = ChatOllama(
|
||||
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
|
||||
)
|
||||
async def test_ainvoke_reasoning_none(model: str) -> None:
    """Test using async invoke with `reasoning=None`

    With `reasoning=None` the test expects the raw `<think>`/`</think>` tags
    to remain in `content` and no `reasoning_content` entry in
    `additional_kwargs`.
    """
    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
    message = HumanMessage(content=SAMPLE)
    result = await llm.ainvoke([message])
    assert result.content
    assert "reasoning_content" not in result.additional_kwargs
    assert "<think>" in result.content and "</think>" in result.content
    # The key was just asserted absent, so indexing it would always raise
    # KeyError. `.get` keeps the tag checks as a harmless sanity check.
    reasoning = result.additional_kwargs.get("reasoning_content", "")
    assert "<think>" not in reasoning
    assert "</think>" not in reasoning
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_invoke(model: str) -> None:
    """Test invoke with `reasoning=True`

    The response must carry content and a non-empty `reasoning_content`
    entry in `additional_kwargs`, and neither may contain raw `<think>`
    tags. The assertions match `test_reasoning_ainvoke`.
    """
    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
    message = HumanMessage(content=SAMPLE)
    result = llm.invoke([message])
    assert result.content
    assert "reasoning_content" in result.additional_kwargs
    reasoning = result.additional_kwargs["reasoning_content"]
    assert len(reasoning) > 0
    assert "<think>" not in result.content and "</think>" not in result.content
    # The tags must be stripped from the reasoning text too. Requiring them
    # and forbidding them at once could never pass.
    assert "<think>" not in reasoning
    assert "</think>" not in reasoning
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_invalid(model: str) -> None:
|
||||
"""Test deepseek model with reasoning raises ValidationError"""
|
||||
with pytest.raises(ValidationError):
|
||||
_ = ChatOllama(model=model, extract_reasoning={"invalid": "data"}) # type: ignore[arg-type]
|
||||
async def test_reasoning_ainvoke(model: str) -> None:
    """Test async invoke with `reasoning=True`

    Awaits a single `ainvoke` call and checks that the response has content,
    a non-empty `reasoning_content` entry in `additional_kwargs`, and no raw
    `<think>` tags in either place.
    """
    chat = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
    response = await chat.ainvoke([HumanMessage(content=SAMPLE)])

    assert response.content
    assert "reasoning_content" in response.additional_kwargs
    assert len(response.additional_kwargs["reasoning_content"]) > 0
    assert "<think>" not in response.content
    assert "</think>" not in response.content
    assert "<think>" not in response.additional_kwargs["reasoning_content"]
    assert "</think>" not in response.additional_kwargs["reasoning_content"]
|
||||
|
||||
@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from httpx import ConnectError
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolCallChunk
|
||||
from langchain_core.tools import tool
|
||||
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
||||
from ollama import ResponseError
|
||||
from pydantic import ValidationError
|
||||
@@ -14,6 +16,15 @@ from langchain_ollama.chat_models import ChatOllama
|
||||
DEFAULT_MODEL_NAME = "llama3.1"
|
||||
|
||||
|
||||
@tool
def get_current_weather(location: str) -> dict:
    """Gets the current weather in a given location."""
    # Docstring kept verbatim: presumably `@tool` exposes it as the tool
    # description -- verify before rewording.
    # Canned data: any location mentioning "boston" (case-insensitive) gets
    # a fixed snowy reading; everything else is reported as unknown.
    if "boston" not in location.lower():
        return {"temperature": "unknown", "conditions": "unknown"}
    return {"temperature": "15°F", "conditions": "snow"}
|
||||
|
||||
|
||||
class TestChatOllama(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> type[ChatOllama]:
|
||||
@@ -29,12 +40,104 @@ class TestChatOllama(ChatModelIntegrationTests):
|
||||
|
||||
@property
def has_tool_choice(self) -> bool:
    """Return `False`; per the TODO below this awaits an Ollama implementation."""
    # TODO: update after Ollama implements
    # https://github.com/ollama/ollama/blob/main/docs/openai.md
    return False
|
||||
|
||||
@property
def supports_image_inputs(self) -> bool:
    """Return `True` for the image-input capability flag."""
    return True
|
||||
|
||||
def test_tool_streaming(self, model: BaseChatModel) -> None:
    """Test that the model can stream tool calls.

    Binds `get_current_weather`, streams a Boston weather question, and
    checks that tool-call chunks arrive, that exactly one final tool call
    for `get_current_weather` with `{"location": "Boston"}` is produced,
    and that the tool-call ID is the same on every chunk that carries one.
    """
    bound_model = model.bind_tools([get_current_weather])
    question = [HumanMessage("What is the weather today in Boston?")]

    # Collectors for validation
    final_tool_calls = []
    collected_tool_chunks: list[ToolCallChunk] = []

    # Stream the response and gather tool-call data from every chunk
    for message_chunk in bound_model.stream(question):
        assert isinstance(
            message_chunk, AIMessageChunk
        ), "Expected AIMessageChunk type"
        if message_chunk.tool_call_chunks:
            collected_tool_chunks.extend(message_chunk.tool_call_chunks)
        if message_chunk.tool_calls:
            final_tool_calls.extend(message_chunk.tool_calls)

    assert (
        collected_tool_chunks
    ), "Tool streaming did not produce any tool_call_chunks."
    assert (
        len(final_tool_calls) == 1
    ), f"Expected 1 final tool call, but got {len(final_tool_calls)}"

    final_tool_call = final_tool_calls[0]
    assert final_tool_call["name"] == "get_current_weather"
    assert final_tool_call["args"] == {"location": "Boston"}

    first_chunk = collected_tool_chunks[0]
    assert first_chunk["name"] == "get_current_weather"

    # The ID should be consistent across chunks that have it
    tool_call_id = first_chunk.get("id")
    assert tool_call_id is not None
    ids_seen = [tc.get("id") for tc in collected_tool_chunks if tc.get("id")]
    assert all(seen == tool_call_id for seen in ids_seen)
    assert final_tool_call["id"] == tool_call_id
|
||||
|
||||
async def test_tool_astreaming(self, model: BaseChatModel) -> None:
    """Test that the model can asynchronously stream tool calls.

    Binds `get_current_weather`, streams a Boston weather question with
    `astream`, and checks that tool-call chunks arrive, that exactly one
    final tool call for `get_current_weather` with `{"location": "Boston"}`
    is produced, and that the tool-call ID is the same on every chunk that
    carries one.
    """
    bound_model = model.bind_tools([get_current_weather])
    question = [HumanMessage("What is the weather today in Boston?")]

    # Collectors for validation
    final_tool_calls = []
    collected_tool_chunks: list[ToolCallChunk] = []

    # Stream the response and gather tool-call data from every chunk
    async for message_chunk in bound_model.astream(question):
        assert isinstance(
            message_chunk, AIMessageChunk
        ), "Expected AIMessageChunk type"
        if message_chunk.tool_call_chunks:
            collected_tool_chunks.extend(message_chunk.tool_call_chunks)
        if message_chunk.tool_calls:
            final_tool_calls.extend(message_chunk.tool_calls)

    assert (
        collected_tool_chunks
    ), "Tool streaming did not produce any tool_call_chunks."
    assert (
        len(final_tool_calls) == 1
    ), f"Expected 1 final tool call, but got {len(final_tool_calls)}"

    final_tool_call = final_tool_calls[0]
    assert final_tool_call["name"] == "get_current_weather"
    assert final_tool_call["args"] == {"location": "Boston"}

    first_chunk = collected_tool_chunks[0]
    assert first_chunk["name"] == "get_current_weather"

    # The ID should be consistent across chunks that have it
    tool_call_id = first_chunk.get("id")
    assert tool_call_id is not None
    ids_seen = [tc.get("id") for tc in collected_tool_chunks if tc.get("id")]
    assert all(seen == tool_call_id for seen in ids_seen)
    assert final_tool_call["id"] == tool_call_id
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"Will sometime encounter AssertionErrors where tool responses are "
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""Test OllamaLLM llm."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from langchain_ollama.llms import OllamaLLM
|
||||
|
||||
MODEL_NAME = "llama3.1"
|
||||
|
||||
SAMPLE = "What is 3^3?"
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test streaming tokens from OllamaLLM."""
|
||||
@@ -15,6 +19,59 @@ def test_stream() -> None:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_no_reasoning(model: str) -> None:
    """Test streaming when `reasoning` is not passed to `OllamaLLM`.

    Merges every streamed chunk and checks that the merged message has
    content, no `reasoning_content` entry in `additional_kwargs`, and no raw
    `<think>` tags.
    """
    # NOTE(review): `test_stream` in this module asserts that streamed tokens
    # are `str`; confirm `OllamaLLM.stream` really yields message chunks here.
    llm = OllamaLLM(model=model, num_ctx=2**12)
    prompt = [{"role": "user", "content": SAMPLE}]

    merged = None
    for piece in llm.stream(prompt):
        assert isinstance(piece, BaseMessageChunk)
        if merged is None:
            merged = piece
        else:
            merged += piece

    assert isinstance(merged, AIMessageChunk)
    assert merged.content
    assert "reasoning_content" not in merged.additional_kwargs

    # Sanity check the old behavior isn't present
    assert "<think>" not in merged.content
    assert "</think>" not in merged.content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_stream(model: str) -> None:
    """Test streaming with `reasoning=True`

    Merges every streamed chunk and checks that the merged message has
    content, a non-empty `reasoning_content` entry in `additional_kwargs`,
    and no raw `<think>` tags in either place.
    """
    # NOTE(review): `test_stream` in this module asserts that streamed tokens
    # are `str`; confirm `OllamaLLM.stream` really yields message chunks here.
    llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
    prompt = [{"role": "user", "content": SAMPLE}]

    merged = None
    for piece in llm.stream(prompt):
        assert isinstance(piece, BaseMessageChunk)
        if merged is None:
            merged = piece
        else:
            merged += piece

    assert isinstance(merged, AIMessageChunk)
    assert merged.content
    assert "reasoning_content" in merged.additional_kwargs
    assert len(merged.additional_kwargs["reasoning_content"]) > 0

    # Sanity check the old behavior isn't present
    assert "<think>" not in merged.content
    assert "</think>" not in merged.content
    assert "<think>" not in merged.additional_kwargs["reasoning_content"]
    assert "</think>" not in merged.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
async def test_astream() -> None:
|
||||
"""Test streaming tokens from OllamaLLM."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
@@ -23,6 +80,59 @@ async def test_astream() -> None:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_astream_no_reasoning(model: str) -> None:
    """Test async streaming when `reasoning` is not passed to `OllamaLLM`.

    Merges every streamed chunk and checks that the merged message has
    content, no `reasoning_content` entry in `additional_kwargs`, and no raw
    `<think>` tags.
    """
    # NOTE(review): `test_astream` in this module asserts that streamed tokens
    # are `str`; confirm `OllamaLLM.astream` really yields message chunks here.
    llm = OllamaLLM(model=model, num_ctx=2**12)
    prompt = [{"role": "user", "content": SAMPLE}]

    merged = None
    async for piece in llm.astream(prompt):
        assert isinstance(piece, BaseMessageChunk)
        if merged is None:
            merged = piece
        else:
            merged += piece

    assert isinstance(merged, AIMessageChunk)
    assert merged.content
    assert "reasoning_content" not in merged.additional_kwargs

    # Sanity check the old behavior isn't present
    assert "<think>" not in merged.content
    assert "</think>" not in merged.content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_reasoning_astream(model: str) -> None:
    """Test async streaming with `reasoning=True`

    Merges every streamed chunk and checks that the merged message has
    content, a non-empty `reasoning_content` entry in `additional_kwargs`,
    and no raw `<think>` tags in either place.
    """
    # NOTE(review): `test_astream` in this module asserts that streamed tokens
    # are `str`; confirm `OllamaLLM.astream` really yields message chunks here.
    llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
    prompt = [{"role": "user", "content": SAMPLE}]

    merged = None
    async for piece in llm.astream(prompt):
        assert isinstance(piece, BaseMessageChunk)
        if merged is None:
            merged = piece
        else:
            merged += piece

    assert isinstance(merged, AIMessageChunk)
    assert merged.content
    assert "reasoning_content" in merged.additional_kwargs
    assert len(merged.additional_kwargs["reasoning_content"]) > 0

    # Sanity check the old behavior isn't present
    assert "<think>" not in merged.content
    assert "</think>" not in merged.content
    assert "<think>" not in merged.additional_kwargs["reasoning_content"]
    assert "</think>" not in merged.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
"""Test streaming tokens from OllamaLLM."""
|
||||
llm = OllamaLLM(model=MODEL_NAME)
|
||||
@@ -60,8 +170,68 @@ async def test_ainvoke() -> None:
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
# TODO
|
||||
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
# async def test_ainvoke_no_reasoning(model: str) -> None:
|
||||
# """Test using async invoke with `reasoning=False`"""
|
||||
# llm = OllamaLLM(model=model, num_ctx=2**12)
|
||||
# message = SAMPLE
|
||||
# result = await llm.ainvoke(message)
|
||||
# assert result.content
|
||||
# assert "reasoning_content" not in result.additional_kwargs
|
||||
|
||||
# # Sanity check the old behavior isn't present
|
||||
# assert "<think>" not in result.content and "</think>" not in result.content
|
||||
|
||||
|
||||
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
# async def test_reasoning_ainvoke(model: str) -> None:
|
||||
# """Test invoke with `reasoning=True`"""
|
||||
# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
|
||||
# message = SAMPLE
|
||||
# result = await llm.ainvoke(message)
|
||||
# assert result.content
|
||||
# assert "reasoning_content" in result.additional_kwargs
|
||||
# assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
|
||||
# # Sanity check the old behavior isn't present
|
||||
# assert "<think>" not in result.content and "</think>" not in result.content
|
||||
# assert "<think>" not in result.additional_kwargs["reasoning_content"]
|
||||
# assert "</think>" not in result.additional_kwargs["reasoning_content"]
|
||||
|
||||
|
||||
def test_invoke() -> None:
    """Test that `OllamaLLM.invoke` returns a `str` for a plain-text prompt."""
    run_config = RunnableConfig(tags=["foo"])
    output = OllamaLLM(model=MODEL_NAME).invoke(
        "I'm Pickle Rick", config=run_config
    )
    assert isinstance(output, str)
|
||||
|
||||
|
||||
# TODO
|
||||
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
# def test_invoke_no_reasoning(model: str) -> None:
|
||||
# """Test using invoke with `reasoning=False`"""
|
||||
# llm = OllamaLLM(model=model, num_ctx=2**12)
|
||||
# message = SAMPLE
|
||||
# result = llm.invoke(message)
|
||||
# assert result.content
|
||||
# assert "reasoning_content" not in result.additional_kwargs
|
||||
|
||||
# # Sanity check the old behavior isn't present
|
||||
# assert "<think>" not in result.content and "</think>" not in result.content
|
||||
|
||||
|
||||
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
# def test_reasoning_invoke(model: str) -> None:
|
||||
# """Test invoke with `reasoning=True`"""
|
||||
# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
|
||||
# message = SAMPLE
|
||||
# result = llm.invoke(message)
|
||||
# assert result.content
|
||||
# assert "reasoning_content" in result.additional_kwargs
|
||||
# assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
|
||||
# # Sanity check the old behavior isn't present
|
||||
# assert "<think>" not in result.content and "</think>" not in result.content
|
||||
# assert "<think>" not in result.additional_kwargs["reasoning_content"]
|
||||
# assert "</think>" not in result.additional_kwargs["reasoning_content"]
|
||||
|
||||
Reference in New Issue
Block a user