mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
ollama: add reasoning model support (e.g. deepseek) (#29689)
# Description This PR adds reasoning model support for `langchain-ollama` by extracting reasoning token blocks, like those used in deepseek. It was inspired by [ollama-deep-researcher](https://github.com/langchain-ai/ollama-deep-researcher), specifically the parsing of [thinking blocks](6d1aaf2139/src/assistant/graph.py (L91)
): ```python # TODO: This is a hack to remove the <think> tags w/ Deepseek models # It appears very challenging to prompt them out of the responses while "<think>" in running_summary and "</think>" in running_summary: start = running_summary.find("<think>") end = running_summary.find("</think>") + len("</think>") running_summary = running_summary[:start] + running_summary[end:] ``` This notes that it is very hard to remove the reasoning block from prompting, but we actually want the model to reason in order to increase model performance. This implementation extracts the thinking block, so the client can still expect a proper message to be returned by `ChatOllama` (and use the reasoning content separately when desired). This implementation takes the same approach as [ChatDeepseek](5d581ba22c/libs/partners/deepseek/langchain_deepseek/chat_models.py (L215)
), which adds the reasoning content to chunk.additional_kwargs.reasoning_content; ```python if hasattr(response.choices[0].message, "reasoning_content"): # type: ignore rtn.generations[0].message.additional_kwargs["reasoning_content"] = ( response.choices[0].message.reasoning_content # type: ignore ) ``` This should probably be handled upstream in ollama + ollama-python, but this seems like a reasonably effective solution. This is a standalone example of what is happening; ```python async def deepseek_message_astream( llm: BaseChatModel, messages: list[BaseMessage], config: RunnableConfig | None = None, *, model_target: str = "deepseek-r1", **kwargs: Any, ) -> AsyncIterator[BaseMessageChunk]: """Stream responses from Deepseek models, filtering out <think> tags. Args: llm: The language model to stream from messages: The messages to send to the model Yields: Filtered chunks from the model response """ # check if the model is deepseek based if (llm.name and model_target not in llm.name) or (hasattr(llm, "model") and model_target not in llm.model): async for chunk in llm.astream(messages, config=config, **kwargs): yield chunk return # Yield with a buffer, upon completing the <think></think> tags, move them to the reasoning content and start over buffer = "" async for chunk in llm.astream(messages, config=config, **kwargs): # start or append if not buffer: buffer = chunk.content else: buffer += chunk.content if hasattr(chunk, "content") else chunk # Process buffer to remove <think> tags if "<think>" in buffer or "</think>" in buffer: if hasattr(chunk, "tool_calls") and chunk.tool_calls: raise NotImplementedError("tool calls during reasoning should be removed?") if "<think>" in chunk.content or "</think>" in chunk.content: continue chunk.additional_kwargs["reasoning_content"] = chunk.content chunk.content = "" # upon block completion, reset the buffer if "<think>" in buffer and "</think>" in buffer: buffer = "" yield chunk ``` # Issue Integrating reasoning models (e.g. deepseek-r1) into existing LangChain based workflows is hard due to the thinking blocks that are included in the message contents. To avoid this, we could match the `ChatOllama` integration with `ChatDeepseek` to return the reasoning content inside `message.additional_arguments.reasoning_content` instead. # Dependenices None --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
d8145dda95
commit
5700646cc5
@ -7,12 +7,14 @@ from typing import (
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Final,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
@ -30,6 +32,7 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
@ -57,6 +60,9 @@ from pydantic.json_schema import JsonSchemaValue
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from typing_extensions import Self, is_typeddict
|
||||
|
||||
DEFAULT_THINK_TOKEN_START: Final[str] = "<think>"
|
||||
DEFAULT_THINK_TOKEN_END: Final[str] = "</think>"
|
||||
|
||||
|
||||
def _get_usage_metadata_from_generation_info(
|
||||
generation_info: Optional[Mapping[str, Any]],
|
||||
@ -335,6 +341,13 @@ class ChatOllama(BaseChatModel):
|
||||
model: str
|
||||
"""Model name to use."""
|
||||
|
||||
extract_reasoning: Optional[Union[bool, Tuple[str, str]]] = False
|
||||
"""Whether to extract the reasoning tokens in think blocks.
|
||||
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
|
||||
If a tuple is supplied, they are assumed to be the (start, end) tokens.
|
||||
If `extract_reasoning=True`, the tokens will default to (<think>, </think>).
|
||||
"""
|
||||
|
||||
mirostat: Optional[int] = None
|
||||
"""Enable Mirostat sampling for controlling perplexity.
|
||||
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
|
||||
@ -568,6 +581,28 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
return ollama_messages
|
||||
|
||||
def _extract_reasoning(
|
||||
self, message_chunk: BaseMessageChunk, is_thinking: bool
|
||||
) -> Tuple[BaseMessageChunk, bool]:
|
||||
"""Mutate a message chunk to extract reasoning content."""
|
||||
if not self.extract_reasoning:
|
||||
return message_chunk, is_thinking
|
||||
elif self.extract_reasoning is True:
|
||||
start_token = DEFAULT_THINK_TOKEN_START
|
||||
end_token = DEFAULT_THINK_TOKEN_END
|
||||
else:
|
||||
start_token, end_token = cast(tuple, self.extract_reasoning)
|
||||
if start_token in message_chunk.content:
|
||||
is_thinking = True
|
||||
content = message_chunk.content
|
||||
if is_thinking:
|
||||
message_chunk.additional_kwargs["reasoning_content"] = content
|
||||
message_chunk.content = ""
|
||||
if end_token in content:
|
||||
is_thinking = False
|
||||
|
||||
return message_chunk, is_thinking
|
||||
|
||||
async def _acreate_chat_stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
@ -604,35 +639,17 @@ class ChatOllama(BaseChatModel):
|
||||
**kwargs: Any,
|
||||
) -> ChatGenerationChunk:
|
||||
final_chunk = None
|
||||
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
|
||||
if not isinstance(stream_resp, str):
|
||||
chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content=(
|
||||
stream_resp["message"]["content"]
|
||||
if "message" in stream_resp
|
||||
and "content" in stream_resp["message"]
|
||||
else ""
|
||||
),
|
||||
usage_metadata=_get_usage_metadata_from_generation_info(
|
||||
stream_resp
|
||||
),
|
||||
tool_calls=_get_tool_calls_from_response(stream_resp),
|
||||
),
|
||||
generation_info=(
|
||||
dict(stream_resp) if stream_resp.get("done") is True else None
|
||||
),
|
||||
for chunk in self._iterate_over_stream(messages, stop, **kwargs):
|
||||
if final_chunk is None:
|
||||
final_chunk = chunk
|
||||
else:
|
||||
final_chunk += chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
chunk=chunk,
|
||||
verbose=verbose,
|
||||
)
|
||||
if final_chunk is None:
|
||||
final_chunk = chunk
|
||||
else:
|
||||
final_chunk += chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
chunk=chunk,
|
||||
verbose=verbose,
|
||||
)
|
||||
if final_chunk is None:
|
||||
raise ValueError("No data received from Ollama stream.")
|
||||
|
||||
@ -647,35 +664,17 @@ class ChatOllama(BaseChatModel):
|
||||
**kwargs: Any,
|
||||
) -> ChatGenerationChunk:
|
||||
final_chunk = None
|
||||
async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
|
||||
if not isinstance(stream_resp, str):
|
||||
chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content=(
|
||||
stream_resp["message"]["content"]
|
||||
if "message" in stream_resp
|
||||
and "content" in stream_resp["message"]
|
||||
else ""
|
||||
),
|
||||
usage_metadata=_get_usage_metadata_from_generation_info(
|
||||
stream_resp
|
||||
),
|
||||
tool_calls=_get_tool_calls_from_response(stream_resp),
|
||||
),
|
||||
generation_info=(
|
||||
dict(stream_resp) if stream_resp.get("done") is True else None
|
||||
),
|
||||
async for chunk in self._aiterate_over_stream(messages, stop, **kwargs):
|
||||
if final_chunk is None:
|
||||
final_chunk = chunk
|
||||
else:
|
||||
final_chunk += chunk
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
chunk=chunk,
|
||||
verbose=verbose,
|
||||
)
|
||||
if final_chunk is None:
|
||||
final_chunk = chunk
|
||||
else:
|
||||
final_chunk += chunk
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
chunk=chunk,
|
||||
verbose=verbose,
|
||||
)
|
||||
if final_chunk is None:
|
||||
raise ValueError("No data received from Ollama stream.")
|
||||
|
||||
@ -712,18 +711,19 @@ class ChatOllama(BaseChatModel):
|
||||
content=final_chunk.text,
|
||||
usage_metadata=cast(AIMessageChunk, final_chunk.message).usage_metadata,
|
||||
tool_calls=cast(AIMessageChunk, final_chunk.message).tool_calls,
|
||||
additional_kwargs=final_chunk.message.additional_kwargs,
|
||||
),
|
||||
generation_info=generation_info,
|
||||
)
|
||||
return ChatResult(generations=[chat_generation])
|
||||
|
||||
def _stream(
|
||||
def _iterate_over_stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
is_thinking = False
|
||||
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
|
||||
if not isinstance(stream_resp, str):
|
||||
chunk = ChatGenerationChunk(
|
||||
@ -743,20 +743,35 @@ class ChatOllama(BaseChatModel):
|
||||
dict(stream_resp) if stream_resp.get("done") is True else None
|
||||
),
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
verbose=self.verbose,
|
||||
if self.extract_reasoning:
|
||||
message, is_thinking = self._extract_reasoning(
|
||||
chunk.message, is_thinking
|
||||
)
|
||||
chunk.message = message
|
||||
yield chunk
|
||||
|
||||
async def _astream(
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
for chunk in self._iterate_over_stream(messages, stop, **kwargs):
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
yield chunk
|
||||
|
||||
async def _aiterate_over_stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
is_thinking = False
|
||||
async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
|
||||
if not isinstance(stream_resp, str):
|
||||
chunk = ChatGenerationChunk(
|
||||
@ -776,13 +791,28 @@ class ChatOllama(BaseChatModel):
|
||||
dict(stream_resp) if stream_resp.get("done") is True else None
|
||||
),
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
verbose=self.verbose,
|
||||
if self.extract_reasoning:
|
||||
message, is_thinking = self._extract_reasoning(
|
||||
chunk.message, is_thinking
|
||||
)
|
||||
chunk.message = message
|
||||
yield chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
async for chunk in self._aiterate_over_stream(messages, stop, **kwargs):
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
yield chunk
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
@ -799,6 +829,7 @@ class ChatOllama(BaseChatModel):
|
||||
content=final_chunk.text,
|
||||
usage_metadata=cast(AIMessageChunk, final_chunk.message).usage_metadata,
|
||||
tool_calls=cast(AIMessageChunk, final_chunk.message).tool_calls,
|
||||
additional_kwargs=final_chunk.message.additional_kwargs,
|
||||
),
|
||||
generation_info=generation_info,
|
||||
)
|
||||
@ -1083,6 +1114,7 @@ class ChatOllama(BaseChatModel):
|
||||
# 'parsing_error': None
|
||||
# }
|
||||
""" # noqa: E501, D301
|
||||
_ = kwargs.pop("strict", None)
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = _is_pydantic_class(schema)
|
||||
|
@ -0,0 +1,162 @@
|
||||
"""Ollama specific chat model integration tests for reasoning models."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessageChunk, HumanMessage
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
SAMPLE = "What is 3^3?"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
|
||||
"""Test deepseek model without parsing."""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SAMPLE,
|
||||
}
|
||||
]
|
||||
result = None
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" in result.content and "</think>" in result.content
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_stream_bool(model: str) -> None:
|
||||
"""Test deepseek model with reasoning bool=True"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SAMPLE,
|
||||
}
|
||||
]
|
||||
result = None
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" in result.additional_kwargs["reasoning_content"]
|
||||
clean_content = (
|
||||
result.additional_kwargs["reasoning_content"]
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.strip()
|
||||
)
|
||||
assert len(clean_content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_stream_tuple(model: str) -> None:
|
||||
"""Test deepseek model with reasoning with tuple=..."""
|
||||
llm = ChatOllama(
|
||||
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SAMPLE,
|
||||
}
|
||||
]
|
||||
result = None
|
||||
for chunk in llm.stream(messages):
|
||||
assert isinstance(chunk, BaseMessageChunk)
|
||||
if result is None:
|
||||
result = chunk
|
||||
continue
|
||||
result += chunk
|
||||
assert isinstance(result, AIMessageChunk)
|
||||
assert result.content
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" in result.additional_kwargs["reasoning_content"]
|
||||
clean_content = (
|
||||
result.additional_kwargs["reasoning_content"]
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.strip()
|
||||
)
|
||||
assert len(clean_content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_invoke_no_reasoning(model: str) -> None:
|
||||
"""Test deepseek model without parsing using invoke."""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = llm.invoke([message])
|
||||
assert result.content
|
||||
assert "<think>" in result.content and "</think>" in result.content
|
||||
assert "reasoning_content" not in result.additional_kwargs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_invoke_bool(model: str) -> None:
|
||||
"""Test deepseek model with reasoning bool=True using invoke"""
|
||||
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = llm.invoke([message])
|
||||
assert result.content
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" in result.additional_kwargs["reasoning_content"]
|
||||
clean_content = (
|
||||
result.additional_kwargs["reasoning_content"]
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.strip()
|
||||
)
|
||||
assert len(clean_content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_messages_invoke_tuple(model: str) -> None:
|
||||
"""Test deepseek model with reasoning with tuple=... using invoke"""
|
||||
llm = ChatOllama(
|
||||
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
|
||||
)
|
||||
message = HumanMessage(content=SAMPLE)
|
||||
result = llm.invoke([message])
|
||||
assert result.content
|
||||
assert "<think>" not in result.content and "</think>" not in result.content
|
||||
assert "reasoning_content" in result.additional_kwargs
|
||||
assert len(result.additional_kwargs["reasoning_content"]) > 0
|
||||
assert "<think>" in result.additional_kwargs["reasoning_content"]
|
||||
assert "</think>" in result.additional_kwargs["reasoning_content"]
|
||||
clean_content = (
|
||||
result.additional_kwargs["reasoning_content"]
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.strip()
|
||||
)
|
||||
assert len(clean_content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
|
||||
def test_deepseek_invalid(model: str) -> None:
|
||||
"""Test deepseek model with reasoning raises ValidationError"""
|
||||
with pytest.raises(ValidationError):
|
||||
_ = ChatOllama(model=model, extract_reasoning={"invalid": "data"}) # type: ignore[arg-type]
|
@ -23,3 +23,7 @@ class TestChatOllama(ChatModelIntegrationTests):
|
||||
@property
|
||||
def supports_json_mode(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def has_tool_choice(self) -> bool:
|
||||
return False
|
||||
|
Loading…
Reference in New Issue
Block a user