mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
ollama: add reasoning model support (e.g. deepseek) (#29689)
# Description This PR adds reasoning model support for `langchain-ollama` by extracting reasoning token blocks, like those used in deepseek. It was inspired by [ollama-deep-researcher](https://github.com/langchain-ai/ollama-deep-researcher), specifically the parsing of [thinking blocks](6d1aaf2139/src/assistant/graph.py (L91)): ```python # TODO: This is a hack to remove the <think> tags w/ Deepseek models # It appears very challenging to prompt them out of the responses while "<think>" in running_summary and "</think>" in running_summary: start = running_summary.find("<think>") end = running_summary.find("</think>") + len("</think>") running_summary = running_summary[:start] + running_summary[end:] ``` This notes that it is very hard to remove the reasoning block from prompting, but we actually want the model to reason in order to increase model performance. This implementation extracts the thinking block, so the client can still expect a proper message to be returned by `ChatOllama` (and use the reasoning content separately when desired). This implementation takes the same approach as [ChatDeepseek](5d581ba22c/libs/partners/deepseek/langchain_deepseek/chat_models.py (L215)), which adds the reasoning content to chunk.additional_kwargs.reasoning_content; ```python if hasattr(response.choices[0].message, "reasoning_content"): # type: ignore rtn.generations[0].message.additional_kwargs["reasoning_content"] = ( response.choices[0].message.reasoning_content # type: ignore ) ``` This should probably be handled upstream in ollama + ollama-python, but this seems like a reasonably effective solution. This is a standalone example of what is happening; ```python async def deepseek_message_astream( llm: BaseChatModel, messages: list[BaseMessage], config: RunnableConfig | None = None, *, model_target: str = "deepseek-r1", **kwargs: Any, ) -> AsyncIterator[BaseMessageChunk]: """Stream responses from Deepseek models, filtering out <think> tags. Args: llm: The language model to stream from messages: The messages to send to the model Yields: Filtered chunks from the model response """ # check if the model is deepseek based if (llm.name and model_target not in llm.name) or (hasattr(llm, "model") and model_target not in llm.model): async for chunk in llm.astream(messages, config=config, **kwargs): yield chunk return # Yield with a buffer, upon completing the <think></think> tags, move them to the reasoning content and start over buffer = "" async for chunk in llm.astream(messages, config=config, **kwargs): # start or append if not buffer: buffer = chunk.content else: buffer += chunk.content if hasattr(chunk, "content") else chunk # Process buffer to remove <think> tags if "<think>" in buffer or "</think>" in buffer: if hasattr(chunk, "tool_calls") and chunk.tool_calls: raise NotImplementedError("tool calls during reasoning should be removed?") if "<think>" in chunk.content or "</think>" in chunk.content: continue chunk.additional_kwargs["reasoning_content"] = chunk.content chunk.content = "" # upon block completion, reset the buffer if "<think>" in buffer and "</think>" in buffer: buffer = "" yield chunk ``` # Issue Integrating reasoning models (e.g. deepseek-r1) into existing LangChain based workflows is hard due to the thinking blocks that are included in the message contents. To avoid this, we could match the `ChatOllama` integration with `ChatDeepseek` to return the reasoning content inside `message.additional_arguments.reasoning_content` instead. # Dependenices None --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
@@ -0,0 +1,162 @@
|
||||
"""Ollama specific chat model integration tests for reasoning models."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessageChunk, HumanMessage
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
SAMPLE = "What is 3^3?"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
|
||||
"""Test deepseek model without parsing."""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SAMPLE,
|
||||
}
|
||||
]
|
||||
result = None
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" in result.content and "</think>" in result.content
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_stream_bool(model: str) -> None:
|
||||
"""Test deepseek model with reasoning bool=True"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SAMPLE,
|
||||
}
|
||||
]
|
||||
result = None
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" in result.additional_kwargs["reasoning_content"]
|
||||
clean_content = (
|
||||
result.additional_kwargs["reasoning_content"]
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.strip()
|
||||
)
|
||||
assert len(clean_content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_stream_tuple(model: str) -> None:
|
||||
"""Test deepseek model with reasoning with tuple=..."""
|
||||
llm = ChatOllama(
|
||||
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SAMPLE,
|
||||
}
|
||||
]
|
||||
result = None
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" in result.additional_kwargs["reasoning_content"]
|
||||
clean_content = (
|
||||
result.additional_kwargs["reasoning_content"]
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.strip()
|
||||
)
|
||||
assert len(clean_content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_invoke_no_reasoning(model: str) -> None:
|
||||
"""Test deepseek model without parsing using invoke."""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = llm.invoke([message])
|
||||
assert result.content
|
||||
assert "<think>" in result.content and "</think>" in result.content
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_invoke_bool(model: str) -> None:
|
||||
"""Test deepseek model with reasoning bool=True using invoke"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = llm.invoke([message])
|
||||
assert result.content
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" in result.additional_kwargs["reasoning_content"]
|
||||
clean_content = (
|
||||
result.additional_kwargs["reasoning_content"]
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.strip()
|
||||
)
|
||||
assert len(clean_content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_invoke_tuple(model: str) -> None:
|
||||
"""Test deepseek model with reasoning with tuple=... using invoke"""
|
||||
llm = ChatOllama(
|
||||
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
|
||||
)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = llm.invoke([message])
|
||||
assert result.content
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" in result.additional_kwargs["reasoning_content"]
|
||||
clean_content = (
|
||||
result.additional_kwargs["reasoning_content"]
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.strip()
|
||||
)
|
||||
assert len(clean_content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_invalid(model: str) -> None:
|
||||
"""Test deepseek model with reasoning raises ValidationError"""
|
||||
with pytest.raises(ValidationError):
|
||||
_ = ChatOllama(model=model, extract_reasoning={"invalid": "data"}) # type: ignore[arg-type]
|
||||
@@ -23,3 +23,7 @@ class TestChatOllama(ChatModelIntegrationTests):
|
||||
@property
|
||||
def supports_json_mode(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def has_tool_choice(self) -> bool:
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user