ollama: add reasoning model support (e.g. deepseek) (#29689)

# Description
This PR adds reasoning model support for `langchain-ollama` by
extracting reasoning token blocks, like those used in deepseek. It was
inspired by
[ollama-deep-researcher](https://github.com/langchain-ai/ollama-deep-researcher),
specifically the parsing of [thinking
blocks](6d1aaf2139/src/assistant/graph.py (L91)):
```python
  # TODO: This is a hack to remove the <think> tags w/ Deepseek models 
  # It appears very challenging to prompt them out of the responses 
  while "<think>" in running_summary and "</think>" in running_summary:
      start = running_summary.find("<think>")
      end = running_summary.find("</think>") + len("</think>")
      running_summary = running_summary[:start] + running_summary[end:]
```

As that comment notes, it is very hard to prompt the reasoning block away,
and we actually want the model to reason, since that improves the quality of
its answers. This implementation therefore extracts the thinking block into
separate metadata, so the client still receives a clean message from
`ChatOllama` (and can use the reasoning content separately when desired).
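
For reference, with the new `extract_reasoning` flag the client-side usage looks roughly like this (a minimal sketch, assuming a local Ollama server with `deepseek-r1:1.5b` pulled):

```python
from langchain_core.messages import HumanMessage

from langchain_ollama import ChatOllama

# assumes a local Ollama server with the deepseek-r1:1.5b model available
llm = ChatOllama(model="deepseek-r1:1.5b", extract_reasoning=True)

result = llm.invoke([HumanMessage(content="What is 3^3?")])

print(result.content)  # the answer, with the <think> block removed
print(result.additional_kwargs["reasoning_content"])  # the raw <think>...</think> block
```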

This implementation takes the same approach as
[ChatDeepseek](5d581ba22c/libs/partners/deepseek/langchain_deepseek/chat_models.py (L215)),
which stores the reasoning content in
`chunk.additional_kwargs["reasoning_content"]`:
```python
  if hasattr(response.choices[0].message, "reasoning_content"):  # type: ignore
      rtn.generations[0].message.additional_kwargs["reasoning_content"] = (
          response.choices[0].message.reasoning_content  # type: ignore
      )
```

This should probably be handled upstream in ollama + ollama-python, but
this seems like a reasonably effective solution in the meantime. Here is a
standalone example of what is happening:

```python
from typing import Any, AsyncIterator

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.runnables import RunnableConfig


async def deepseek_message_astream(
    llm: BaseChatModel,
    messages: list[BaseMessage],
    config: RunnableConfig | None = None,
    *,
    model_target: str = "deepseek-r1",
    **kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
    """Stream responses from Deepseek models, filtering out <think> tags.

    Args:
        llm: The language model to stream from
        messages: The messages to send to the model

    Yields:
        Filtered chunks from the model response
    """
    # if the model is not deepseek based, pass the stream through unfiltered
    if (llm.name and model_target not in llm.name) or (
        hasattr(llm, "model") and model_target not in llm.model
    ):
        async for chunk in llm.astream(messages, config=config, **kwargs):
            yield chunk
        return

    # Buffer the streamed content; while inside a <think>...</think> block,
    # chunk content is moved to reasoning_content, and the buffer is reset
    # once the closing tag has been seen.
    buffer = ""
    async for chunk in llm.astream(messages, config=config, **kwargs):
        # accumulate the streamed content
        buffer += chunk.content if hasattr(chunk, "content") else str(chunk)

        # Process buffer to remove <think> tags
        if "<think>" in buffer or "</think>" in buffer:
            if hasattr(chunk, "tool_calls") and chunk.tool_calls:
                raise NotImplementedError("tool calls during reasoning should be removed?")
            if "<think>" in chunk.content or "</think>" in chunk.content:
                continue
            chunk.additional_kwargs["reasoning_content"] = chunk.content
            chunk.content = ""
        # upon block completion, reset the buffer
        if "<think>" in buffer and "</think>" in buffer:
            buffer = ""
        yield chunk

```
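
The helper can be consumed like any other async stream, for example (a sketch that assumes the `deepseek_message_astream` function above and a local `deepseek-r1:1.5b` model):

```python
import asyncio

from langchain_core.messages import HumanMessage

from langchain_ollama import ChatOllama


async def main() -> None:
    llm = ChatOllama(model="deepseek-r1:1.5b")
    stream = deepseek_message_astream(llm, [HumanMessage(content="What is 3^3?")])
    async for chunk in stream:
        # visible tokens arrive in chunk.content; reasoning tokens were moved
        # into chunk.additional_kwargs["reasoning_content"]
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```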

# Issue
Integrating reasoning models (e.g. deepseek-r1) into existing LangChain-based
workflows is hard because the thinking blocks are included in the message
contents. To avoid this, we can align the `ChatOllama` integration with
`ChatDeepseek` and return the reasoning content inside
`message.additional_kwargs["reasoning_content"]` instead.
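
That is what this PR does: `ChatOllama` gains an `extract_reasoning` option, either a bool or a custom (start, end) token tuple. A sketch mirroring the streaming integration tests added here (assumes `deepseek-r1:1.5b` is available locally):

```python
from langchain_ollama import ChatOllama

# extract_reasoning=True uses the ("<think>", "</think>") defaults;
# a custom (start, end) token tuple can be supplied instead
llm = ChatOllama(model="deepseek-r1:1.5b", extract_reasoning=("<think>", "</think>"))

full = None
for chunk in llm.stream("What is 3^3?"):
    # chunks add together; reasoning lands in additional_kwargs, not in content
    full = chunk if full is None else full + chunk

answer = full.content
reasoning = full.additional_kwargs.get("reasoning_content", "")
```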

# Dependencies
None

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
3 changed files with 266 additions and 68 deletions


@@ -7,12 +7,14 @@ from typing import (
AsyncIterator,
Callable,
Dict,
Final,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
@@ -30,6 +32,7 @@ from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
ToolCall,
@@ -57,6 +60,9 @@ from pydantic.json_schema import JsonSchemaValue
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self, is_typeddict
DEFAULT_THINK_TOKEN_START: Final[str] = "<think>"
DEFAULT_THINK_TOKEN_END: Final[str] = "</think>"
def _get_usage_metadata_from_generation_info(
generation_info: Optional[Mapping[str, Any]],
@@ -335,6 +341,13 @@ class ChatOllama(BaseChatModel):
model: str
"""Model name to use."""
extract_reasoning: Optional[Union[bool, Tuple[str, str]]] = False
"""Whether to extract the reasoning tokens in think blocks.
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
If a tuple is supplied, they are assumed to be the (start, end) tokens.
If `extract_reasoning=True`, the tokens will default to (<think>, </think>).
"""
mirostat: Optional[int] = None
"""Enable Mirostat sampling for controlling perplexity.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
@@ -568,6 +581,28 @@ class ChatOllama(BaseChatModel):
return ollama_messages
def _extract_reasoning(
self, message_chunk: BaseMessageChunk, is_thinking: bool
) -> Tuple[BaseMessageChunk, bool]:
"""Mutate a message chunk to extract reasoning content."""
if not self.extract_reasoning:
return message_chunk, is_thinking
elif self.extract_reasoning is True:
start_token = DEFAULT_THINK_TOKEN_START
end_token = DEFAULT_THINK_TOKEN_END
else:
start_token, end_token = cast(tuple, self.extract_reasoning)
if start_token in message_chunk.content:
is_thinking = True
content = message_chunk.content
if is_thinking:
message_chunk.additional_kwargs["reasoning_content"] = content
message_chunk.content = ""
if end_token in content:
is_thinking = False
return message_chunk, is_thinking
async def _acreate_chat_stream(
self,
messages: List[BaseMessage],
@@ -604,35 +639,17 @@ class ChatOllama(BaseChatModel):
**kwargs: Any,
) -> ChatGenerationChunk:
final_chunk = None
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content=(
stream_resp["message"]["content"]
if "message" in stream_resp
and "content" in stream_resp["message"]
else ""
),
usage_metadata=_get_usage_metadata_from_generation_info(
stream_resp
),
tool_calls=_get_tool_calls_from_response(stream_resp),
),
generation_info=(
dict(stream_resp) if stream_resp.get("done") is True else None
),
for chunk in self._iterate_over_stream(messages, stop, **kwargs):
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=verbose,
)
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=verbose,
)
if final_chunk is None:
raise ValueError("No data received from Ollama stream.")
@@ -647,35 +664,17 @@ class ChatOllama(BaseChatModel):
**kwargs: Any,
) -> ChatGenerationChunk:
final_chunk = None
async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content=(
stream_resp["message"]["content"]
if "message" in stream_resp
and "content" in stream_resp["message"]
else ""
),
usage_metadata=_get_usage_metadata_from_generation_info(
stream_resp
),
tool_calls=_get_tool_calls_from_response(stream_resp),
),
generation_info=(
dict(stream_resp) if stream_resp.get("done") is True else None
),
async for chunk in self._aiterate_over_stream(messages, stop, **kwargs):
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=verbose,
)
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=verbose,
)
if final_chunk is None:
raise ValueError("No data received from Ollama stream.")
@@ -712,18 +711,19 @@ class ChatOllama(BaseChatModel):
content=final_chunk.text,
usage_metadata=cast(AIMessageChunk, final_chunk.message).usage_metadata,
tool_calls=cast(AIMessageChunk, final_chunk.message).tool_calls,
additional_kwargs=final_chunk.message.additional_kwargs,
),
generation_info=generation_info,
)
return ChatResult(generations=[chat_generation])
def _stream(
def _iterate_over_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
is_thinking = False
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
chunk = ChatGenerationChunk(
@@ -743,20 +743,35 @@ class ChatOllama(BaseChatModel):
dict(stream_resp) if stream_resp.get("done") is True else None
),
)
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
if self.extract_reasoning:
message, is_thinking = self._extract_reasoning(
chunk.message, is_thinking
)
chunk.message = message
yield chunk
async def _astream(
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
for chunk in self._iterate_over_stream(messages, stop, **kwargs):
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
yield chunk
async def _aiterate_over_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
is_thinking = False
async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
chunk = ChatGenerationChunk(
@@ -776,13 +791,28 @@ class ChatOllama(BaseChatModel):
dict(stream_resp) if stream_resp.get("done") is True else None
),
)
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
if self.extract_reasoning:
message, is_thinking = self._extract_reasoning(
chunk.message, is_thinking
)
chunk.message = message
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
async for chunk in self._aiterate_over_stream(messages, stop, **kwargs):
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
yield chunk
async def _agenerate(
self,
messages: List[BaseMessage],
@@ -799,6 +829,7 @@ class ChatOllama(BaseChatModel):
content=final_chunk.text,
usage_metadata=cast(AIMessageChunk, final_chunk.message).usage_metadata,
tool_calls=cast(AIMessageChunk, final_chunk.message).tool_calls,
additional_kwargs=final_chunk.message.additional_kwargs,
),
generation_info=generation_info,
)
@@ -1083,6 +1114,7 @@ class ChatOllama(BaseChatModel):
# 'parsing_error': None
# }
""" # noqa: E501, D301
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)


@@ -0,0 +1,162 @@
"""Ollama specific chat model integration tests for reasoning models."""
import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk, HumanMessage
from pydantic import ValidationError
from langchain_ollama import ChatOllama
SAMPLE = "What is 3^3?"
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
"""Test deepseek model without parsing."""
llm = ChatOllama(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" in result.content and "</think>" in result.content
assert "reasoning_content" not in result.additional_kwargs
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_stream_bool(model: str) -> None:
"""Test deepseek model with reasoning bool=True"""
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_stream_tuple(model: str) -> None:
"""Test deepseek model with reasoning with tuple=..."""
llm = ChatOllama(
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_invoke_no_reasoning(model: str) -> None:
"""Test deepseek model without parsing using invoke."""
llm = ChatOllama(model=model, num_ctx=2**12)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
assert result.content
assert "<think>" in result.content and "</think>" in result.content
assert "reasoning_content" not in result.additional_kwargs
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_invoke_bool(model: str) -> None:
"""Test deepseek model with reasoning bool=True using invoke"""
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_invoke_tuple(model: str) -> None:
"""Test deepseek model with reasoning with tuple=... using invoke"""
llm = ChatOllama(
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_invalid(model: str) -> None:
"""Test deepseek model with reasoning raises ValidationError"""
with pytest.raises(ValidationError):
_ = ChatOllama(model=model, extract_reasoning={"invalid": "data"}) # type: ignore[arg-type]


@@ -23,3 +23,7 @@ class TestChatOllama(ChatModelIntegrationTests):
@property
def supports_json_mode(self) -> bool:
return True
@property
def has_tool_choice(self) -> bool:
return False