ollama[patch]: fix model validation, ensure per-call reasoning can be set, tests (#31927)

* update model validation due to change in [Ollama
client](https://github.com/ollama/ollama) - ensure you are running the
latest version (0.9.6) to use `validate_model_on_init`
* add code example and fix formatting for ChatOllama reasoning
* ensure that setting `reasoning` in invocation kwargs overrides
class-level setting
* tests
This commit is contained in:
Mason Daugherty 2025-07-08 16:39:41 -04:00 committed by GitHub
parent f33a25773e
commit 0002b1dafa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 174 additions and 176 deletions

View File

@ -16,7 +16,9 @@ def validate_model(client: Client, model_name: str) -> None:
"""
try:
response = client.list()
model_names: list[str] = [model["name"] for model in response["models"]]
model_names: list[str] = [model["model"] for model in response["models"]]
if not any(
model_name == m or m.startswith(f"{model_name}:") for m in model_names
):
@ -27,10 +29,7 @@ def validate_model(client: Client, model_name: str) -> None:
)
raise ValueError(msg)
except ConnectError as e:
msg = (
"Connection to Ollama failed. Please make sure Ollama is running "
f"and accessible at {client._client.base_url}. "
)
msg = "Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download" # noqa: E501
raise ValueError(msg) from e
except ResponseError as e:
msg = (

View File

@ -217,17 +217,17 @@ class ChatOllama(BaseChatModel):
`supported models <https://ollama.com/search?c=thinking>`__.
- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. Note
however, if the model's default behavior *is* to perform reasoning, think tags
()``<think>`` and ``</think>``) will be present within the main response content
unless you set ``reasoning`` to ``True``.
however, if the model's default behavior *is* to perform reasoning, think tags
(``<think>`` and ``</think>``) will be present within the main response content
unless you set ``reasoning`` to ``True``.
temperature: float
Sampling temperature. Ranges from 0.0 to 1.0.
Sampling temperature. Ranges from ``0.0`` to ``1.0``.
num_predict: Optional[int]
Max number of tokens to generate.
@ -343,7 +343,6 @@ class ChatOllama(BaseChatModel):
'{"location": "Pune, India", "time_of_day": "morning"}'
Tool Calling:
.. code-block:: python
from langchain_ollama import ChatOllama
@ -362,6 +361,48 @@ class ChatOllama(BaseChatModel):
'args': {'a': 45, 'b': 67},
'id': '420c3f3b-df10-4188-945f-eb3abdb40622',
'type': 'tool_call'}]
Thinking / Reasoning:
You can enable reasoning mode for models that support it by setting
the ``reasoning`` parameter to ``True`` in either the constructor or
the ``invoke``/``stream`` methods. This will enable the model to think
through the problem and return the reasoning process separately in the
``additional_kwargs`` of the response message, under ``reasoning_content``.
If ``reasoning`` is set to ``None``, the model will use its default reasoning
behavior, and any reasoning content will *not* be captured under the
``reasoning_content`` key, but will be present within the main response content
as think tags (``<think>`` and ``</think>``).
.. note::
This feature is only available for `models that support reasoning <https://ollama.com/search?c=thinking>`__.
.. code-block:: python
from langchain_ollama import ChatOllama
llm = ChatOllama(
model = "deepseek-r1:8b",
reasoning= True,
)
user_message = HumanMessage(content="how many r in the word strawberry?")
messages: List[Any] = [user_message]
llm.invoke(messages)
# or, on an invocation basis:
llm.invoke(messages, reasoning=True)
# or llm.stream(messages, reasoning=True)
# If not provided, the invocation will default to the ChatOllama reasoning
# param provided (None by default).
.. code-block:: python
AIMessage(content='The word "strawberry" contains **three \'r\' letters**. Here\'s a breakdown for clarity:\n\n- The spelling of "strawberry" has two parts ... be 3.\n\nTo be thorough, let\'s confirm with an online source or common knowledge.\n\nI can recall that "strawberry" has: s-t-r-a-w-b-e-r-r-y — yes, three r\'s.\n\nPerhaps it\'s misspelled by some, but standard is correct.\n\nSo I think the response should be 3.\n'}, response_metadata={'model': 'deepseek-r1:8b', 'created_at': '2025-07-08T19:33:55.891269Z', 'done': True, 'done_reason': 'stop', 'total_duration': 98232561292, 'load_duration': 28036792, 'prompt_eval_count': 10, 'prompt_eval_duration': 40171834, 'eval_count': 3615, 'eval_duration': 98163832416, 'model_name': 'deepseek-r1:8b'}, id='run--18f8269f-6a35-4a7c-826d-b89d52c753b3-0', usage_metadata={'input_tokens': 10, 'output_tokens': 3615, 'total_tokens': 3625})
""" # noqa: E501, pylint: disable=line-too-long
model: str
@ -777,6 +818,7 @@ class ChatOllama(BaseChatModel):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
reasoning = kwargs.get("reasoning", self.reasoning)
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
if stream_resp.get("done") is True:
@ -795,7 +837,7 @@ class ChatOllama(BaseChatModel):
additional_kwargs = {}
if (
self.reasoning
reasoning
and "message" in stream_resp
and (thinking_content := stream_resp["message"].get("thinking"))
):
@ -836,6 +878,7 @@ class ChatOllama(BaseChatModel):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
reasoning = kwargs.get("reasoning", self.reasoning)
async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
if stream_resp.get("done") is True:
@ -854,7 +897,7 @@ class ChatOllama(BaseChatModel):
additional_kwargs = {}
if (
self.reasoning
reasoning
and "message" in stream_resp
and (thinking_content := stream_resp["message"].get("thinking"))
):

View File

@ -38,7 +38,7 @@ class OllamaLLM(BaseLLM):
model: str
"""Model name to use."""
reasoning: Optional[bool] = True
reasoning: Optional[bool] = None
"""Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.
@ -272,8 +272,11 @@ class OllamaLLM(BaseLLM):
**kwargs: Any,
) -> GenerationChunk:
final_chunk = None
thinking_content = ""
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if not isinstance(stream_resp, str):
if stream_resp.get("thinking"):
thinking_content += stream_resp["thinking"]
chunk = GenerationChunk(
text=stream_resp.get("response", ""),
generation_info=(
@ -294,6 +297,12 @@ class OllamaLLM(BaseLLM):
msg = "No data received from Ollama stream."
raise ValueError(msg)
if thinking_content:
if final_chunk.generation_info:
final_chunk.generation_info["thinking"] = thinking_content
else:
final_chunk.generation_info = {"thinking": thinking_content}
return final_chunk
def _stream_with_aggregation(
@ -305,8 +314,11 @@ class OllamaLLM(BaseLLM):
**kwargs: Any,
) -> GenerationChunk:
final_chunk = None
thinking_content = ""
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if not isinstance(stream_resp, str):
if stream_resp.get("thinking"):
thinking_content += stream_resp["thinking"]
chunk = GenerationChunk(
text=stream_resp.get("response", ""),
generation_info=(
@ -327,6 +339,12 @@ class OllamaLLM(BaseLLM):
msg = "No data received from Ollama stream."
raise ValueError(msg)
if thinking_content:
if final_chunk.generation_info:
final_chunk.generation_info["thinking"] = thinking_content
else:
final_chunk.generation_info = {"thinking": thinking_content}
return final_chunk
def _generate(
@ -374,10 +392,11 @@ class OllamaLLM(BaseLLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
reasoning = kwargs.get("reasoning", self.reasoning)
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if not isinstance(stream_resp, str):
additional_kwargs = {}
if thinking_content := stream_resp.get("thinking"):
if reasoning and (thinking_content := stream_resp.get("thinking")):
additional_kwargs["reasoning_content"] = thinking_content
chunk = GenerationChunk(
@ -404,10 +423,11 @@ class OllamaLLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
reasoning = kwargs.get("reasoning", self.reasoning)
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if not isinstance(stream_resp, str):
additional_kwargs = {}
if thinking_content := stream_resp.get("thinking"):
if reasoning and (thinking_content := stream_resp.get("thinking")):
additional_kwargs["reasoning_content"] = thinking_content
chunk = GenerationChunk(

View File

@ -1,18 +1,17 @@
"""Test OllamaLLM llm."""
import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from langchain_core.outputs import GenerationChunk
from langchain_core.runnables import RunnableConfig
from langchain_ollama.llms import OllamaLLM
MODEL_NAME = "llama3.1"
SAMPLE = "What is 3^3?"
def test_stream() -> None:
"""Test streaming tokens from OpenAI."""
def test_stream_text_tokens() -> None:
"""Test streaming raw string tokens from OllamaLLM."""
llm = OllamaLLM(model=MODEL_NAME)
for token in llm.stream("I'm Pickle Rick"):
@ -20,60 +19,52 @@ def test_stream() -> None:
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_no_reasoning(model: str) -> None:
"""Test streaming with `reasoning=False`"""
def test__stream_no_reasoning(model: str) -> None:
"""Test low-level chunk streaming of a simple prompt with `reasoning=False`."""
llm = OllamaLLM(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
result_chunk = None
for chunk in llm._stream(SAMPLE):
# Should be a GenerationChunk
assert isinstance(chunk, GenerationChunk)
if result_chunk is None:
result_chunk = chunk
else:
result_chunk += chunk
# The final result must be a GenerationChunk with visible content
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
# No separate reasoning_content
assert "reasoning_content" not in result_chunk.generation_info # type: ignore[operator]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_stream(model: str) -> None:
"""Test streaming with `reasoning=True`"""
def test__stream_with_reasoning(model: str) -> None:
"""Test low-level chunk streaming with `reasoning=True`."""
llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
result_chunk = None
for chunk in llm._stream(SAMPLE):
assert isinstance(chunk, GenerationChunk)
if result_chunk is None:
result_chunk = chunk
else:
result_chunk += chunk
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
# Should have extracted reasoning into generation_info
assert "reasoning_content" in result_chunk.generation_info # type: ignore[operator]
assert len(result_chunk.generation_info["reasoning_content"]) > 0 # type: ignore[index]
# And neither the visible nor the hidden portion contains <think> tags
assert "<think>" not in result_chunk.text and "</think>" not in result_chunk.text
assert "<think>" not in result_chunk.generation_info["reasoning_content"] # type: ignore[index]
assert "</think>" not in result_chunk.generation_info["reasoning_content"] # type: ignore[index]
async def test_astream() -> None:
"""Test streaming tokens from OpenAI."""
async def test_astream_text_tokens() -> None:
"""Test async streaming raw string tokens from OllamaLLM."""
llm = OllamaLLM(model=MODEL_NAME)
async for token in llm.astream("I'm Pickle Rick"):
@ -81,60 +72,44 @@ async def test_astream() -> None:
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_astream_no_reasoning(model: str) -> None:
"""Test async streaming with `reasoning=False`"""
async def test__astream_no_reasoning(model: str) -> None:
"""Test low-level async chunk streaming with `reasoning=False`."""
llm = OllamaLLM(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
result_chunk = None
async for chunk in llm._astream(SAMPLE):
assert isinstance(chunk, GenerationChunk)
if result_chunk is None:
result_chunk = chunk
else:
result_chunk += chunk
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
assert "reasoning_content" not in result_chunk.generation_info # type: ignore[operator]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_reasoning_astream(model: str) -> None:
"""Test async streaming with `reasoning=True`"""
async def test__astream_with_reasoning(model: str) -> None:
"""Test low-level async chunk streaming with `reasoning=True`."""
llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
result_chunk = None
async for chunk in llm._astream(SAMPLE):
assert isinstance(chunk, GenerationChunk)
if result_chunk is None:
result_chunk = chunk
else:
result_chunk += chunk
assert isinstance(result_chunk, GenerationChunk)
assert result_chunk.text
assert "reasoning_content" in result_chunk.generation_info # type: ignore[operator]
assert len(result_chunk.generation_info["reasoning_content"]) > 0 # type: ignore[index]
async def test_abatch() -> None:
"""Test streaming tokens from OllamaLLM."""
"""Test batch sync token generation from OllamaLLM."""
llm = OllamaLLM(model=MODEL_NAME)
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@ -143,7 +118,7 @@ async def test_abatch() -> None:
async def test_abatch_tags() -> None:
"""Test batch tokens from OllamaLLM."""
"""Test batch sync token generation with tags."""
llm = OllamaLLM(model=MODEL_NAME)
result = await llm.abatch(
@ -154,7 +129,7 @@ async def test_abatch_tags() -> None:
def test_batch() -> None:
"""Test batch tokens from OllamaLLM."""
"""Test batch token generation from OllamaLLM."""
llm = OllamaLLM(model=MODEL_NAME)
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@ -163,75 +138,15 @@ def test_batch() -> None:
async def test_ainvoke() -> None:
"""Test invoke tokens from OllamaLLM."""
"""Test async invoke returning a string."""
llm = OllamaLLM(model=MODEL_NAME)
result = await llm.ainvoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
assert isinstance(result, str)
# TODO
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# async def test_ainvoke_no_reasoning(model: str) -> None:
# """Test using async invoke with `reasoning=False`"""
# llm = OllamaLLM(model=model, num_ctx=2**12)
# message = SAMPLE
# result = await llm.ainvoke(message)
# assert result.content
# assert "reasoning_content" not in result.additional_kwargs
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# async def test_reasoning_ainvoke(model: str) -> None:
# """Test invoke with `reasoning=True`"""
# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
# message = SAMPLE
# result = await llm.ainvoke(message)
# assert result.content
# assert "reasoning_content" in result.additional_kwargs
# assert len(result.additional_kwargs["reasoning_content"]) > 0
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# assert "<think>" not in result.additional_kwargs["reasoning_content"]
# assert "</think>" not in result.additional_kwargs["reasoning_content"]
def test_invoke() -> None:
"""Test invoke tokens from OllamaLLM."""
"""Test sync invoke returning a string."""
llm = OllamaLLM(model=MODEL_NAME)
result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
assert isinstance(result, str)
# TODO
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# def test_invoke_no_reasoning(model: str) -> None:
# """Test using invoke with `reasoning=False`"""
# llm = OllamaLLM(model=model, num_ctx=2**12)
# message = SAMPLE
# result = llm.invoke(message)
# assert result.content
# assert "reasoning_content" not in result.additional_kwargs
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# def test_reasoning_invoke(model: str) -> None:
# """Test invoke with `reasoning=True`"""
# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
# message = SAMPLE
# result = llm.invoke(message)
# assert result.content
# assert "reasoning_content" in result.additional_kwargs
# assert len(result.additional_kwargs["reasoning_content"]) > 0
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# assert "<think>" not in result.additional_kwargs["reasoning_content"]
# assert "</think>" not in result.additional_kwargs["reasoning_content"]

View File

@ -48,3 +48,24 @@ def test_validate_model_on_init(mock_validate_model: Any) -> None:
# Test that validate_model is NOT called by default
OllamaLLM(model=MODEL_NAME)
mock_validate_model.assert_not_called()
def test_reasoning_aggregation() -> None:
"""Test that reasoning chunks are aggregated into final response."""
llm = OllamaLLM(model=MODEL_NAME, reasoning=True)
prompts = ["some prompt"]
mock_stream = [
{"thinking": "I am thinking.", "done": False},
{"thinking": " Still thinking.", "done": False},
{"response": "Final Answer.", "done": True},
]
with patch.object(llm, "_create_generate_stream") as mock_stream_method:
mock_stream_method.return_value = iter(mock_stream)
result = llm.generate(prompts)
assert result.generations[0][0].generation_info is not None
assert (
result.generations[0][0].generation_info["thinking"]
== "I am thinking. Still thinking."
)