ollama: thinking, tool streaming, docs, tests (#31772)

* New `reasoning` (bool) param to support toggling [Ollama
thinking](https://ollama.com/blog/thinking) (#31573, #31700). If
`reasoning=True`, Ollama's `thinking` content will be placed in the
model responses' `additional_kwargs.reasoning_content`.
  * Supported by:
    * ChatOllama (class level, invocation level TODO)
    * OllamaLLM (TODO)
* Added tests to ensure streaming tool calls is successful (#29129)
* Refactored tests that relied on `extract_reasoning()`
* Myriad docs additions and consistency/typo fixes
* Improved type safety in some spots

Closes #29129
Addresses #31573 and #31700
Supersedes #31701
This commit is contained in:
Mason Daugherty
2025-07-07 13:56:41 -04:00
committed by GitHub
parent 0eb10f31c1
commit e686a70ee0
14 changed files with 630 additions and 213 deletions

View File

@@ -6,7 +6,6 @@ from operator import itemgetter
from typing import (
Any,
Callable,
Final,
Literal,
Optional,
Union,
@@ -25,7 +24,6 @@ from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
HumanMessage,
SystemMessage,
@@ -57,9 +55,6 @@ from typing_extensions import Self, is_typeddict
from ._utils import validate_model
DEFAULT_THINK_TOKEN_START: Final[str] = "<think>"
DEFAULT_THINK_TOKEN_END: Final[str] = "</think>"
def _get_usage_metadata_from_generation_info(
generation_info: Optional[Mapping[str, Any]],
@@ -166,13 +161,14 @@ def _get_tool_calls_from_response(
return tool_calls
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
def _lc_tool_call_to_openai_tool_call(tool_call_: ToolCall) -> dict:
"""Convert a LangChain tool call to an OpenAI tool call format."""
return {
"type": "function",
"id": tool_call["id"],
"id": tool_call_["id"],
"function": {
"name": tool_call["name"],
"arguments": tool_call["args"],
"name": tool_call_["name"],
"arguments": tool_call_["args"],
},
}
@@ -211,6 +207,20 @@ class ChatOllama(BaseChatModel):
Key init args — completion params:
model: str
Name of Ollama model to use.
reasoning: Optional[bool]
Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.
- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. Note
however, if the model's default behavior *is* to perform reasoning, think tags
()``<think>`` and ``</think>``) will be present within the main response content
unless you set ``reasoning`` to ``True``.
temperature: float
Sampling temperature. Ranges from 0.0 to 1.0.
num_predict: Optional[int]
@@ -347,21 +357,29 @@ class ChatOllama(BaseChatModel):
'args': {'a': 45, 'b': 67},
'id': '420c3f3b-df10-4188-945f-eb3abdb40622',
'type': 'tool_call'}]
""" # noqa: E501
""" # noqa: E501, pylint: disable=line-too-long
model: str
"""Model name to use."""
reasoning: Optional[bool] = None
"""Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.
- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. Note
however, if the model's default behavior *is* to perform reasoning, think tags
()``<think>`` and ``</think>``) will be present within the main response content
unless you set ``reasoning`` to ``True``."""
validate_model_on_init: bool = False
"""Whether to validate the model exists in Ollama locally on initialization."""
extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
"""Whether to extract the reasoning tokens in think blocks.
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
If a tuple is supplied, they are assumed to be the (start, end) tokens.
If `extract_reasoning=True`, the tokens will default to (<think>, </think>).
"""
mirostat: Optional[int] = None
"""Enable Mirostat sampling for controlling perplexity.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
@@ -448,24 +466,23 @@ class ChatOllama(BaseChatModel):
"""
async_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with client_kwargs before passing to the HTTPX
AsyncClient.
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#asyncclient>`__.
"""Additional kwargs to merge with client_kwargs before
passing to the httpx AsyncClient.
`Full list of params. <https://www.python-httpx.org/api/#asyncclient>`__
"""
sync_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with client_kwargs before passing to the HTTPX Client.
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
"""Additional kwargs to merge with client_kwargs before
passing to the httpx Client.
`Full list of params. <https://www.python-httpx.org/api/#client>`__
"""
_client: Client = PrivateAttr(default=None) # type: ignore
_client: Client = PrivateAttr()
"""
The client to use for making requests.
"""
_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore
_async_client: AsyncClient = PrivateAttr()
"""
The async client to use for making requests.
"""
@@ -480,7 +497,7 @@ class ChatOllama(BaseChatModel):
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
if self.stop is not None:
stop = self.stop
options_dict = kwargs.pop(
@@ -508,6 +525,7 @@ class ChatOllama(BaseChatModel):
"messages": ollama_messages,
"stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model),
"think": kwargs.pop("reasoning", self.reasoning),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
@@ -618,35 +636,13 @@ class ChatOllama(BaseChatModel):
"images": images,
}
if tool_calls:
msg["tool_calls"] = tool_calls # type: ignore
msg["tool_calls"] = tool_calls
if tool_call_id:
msg["tool_call_id"] = tool_call_id
ollama_messages.append(msg)
return ollama_messages
def _extract_reasoning(
self, message_chunk: BaseMessageChunk, is_thinking: bool
) -> tuple[BaseMessageChunk, bool]:
"""Mutate a message chunk to extract reasoning content."""
if not self.extract_reasoning:
return message_chunk, is_thinking
elif self.extract_reasoning is True:
start_token = DEFAULT_THINK_TOKEN_START
end_token = DEFAULT_THINK_TOKEN_END
else:
start_token, end_token = cast(tuple, self.extract_reasoning)
if start_token in message_chunk.content:
is_thinking = True
content = message_chunk.content
if is_thinking:
message_chunk.additional_kwargs["reasoning_content"] = content
message_chunk.content = ""
if end_token in content:
is_thinking = False
return message_chunk, is_thinking
async def _acreate_chat_stream(
self,
messages: list[BaseMessage],
@@ -670,9 +666,11 @@ class ChatOllama(BaseChatModel):
chat_params = self._chat_params(messages, stop, **kwargs)
if chat_params["stream"]:
yield from self._client.chat(**chat_params)
if self._client:
yield from self._client.chat(**chat_params)
else:
yield self._client.chat(**chat_params)
if self._client:
yield self._client.chat(**chat_params)
def _chat_stream_with_aggregation(
self,
@@ -767,22 +765,34 @@ class ChatOllama(BaseChatModel):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
is_thinking = False
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
if stream_resp.get("done") is True:
generation_info = dict(stream_resp)
if "model" in generation_info:
generation_info["model_name"] = generation_info["model"]
_ = generation_info.pop("message", None)
else:
generation_info = None
content = (
stream_resp["message"]["content"]
if "message" in stream_resp and "content" in stream_resp["message"]
else ""
)
additional_kwargs = {}
if (
self.reasoning
and "message" in stream_resp
and (thinking_content := stream_resp["message"].get("thinking"))
):
additional_kwargs["reasoning_content"] = thinking_content
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content=(
stream_resp["message"]["content"]
if "message" in stream_resp
and "content" in stream_resp["message"]
else ""
),
content=content,
additional_kwargs=additional_kwargs,
usage_metadata=_get_usage_metadata_from_generation_info(
stream_resp
),
@@ -790,15 +800,7 @@ class ChatOllama(BaseChatModel):
),
generation_info=generation_info,
)
if chunk.generation_info and (
model := chunk.generation_info.get("model")
):
chunk.generation_info["model_name"] = model # backwards compat
if self.extract_reasoning:
message, is_thinking = self._extract_reasoning(
chunk.message, is_thinking
)
chunk.message = message
yield chunk
def _stream(
@@ -822,22 +824,34 @@ class ChatOllama(BaseChatModel):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
is_thinking = False
async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
if stream_resp.get("done") is True:
generation_info = dict(stream_resp)
if "model" in generation_info:
generation_info["model_name"] = generation_info["model"]
_ = generation_info.pop("message", None)
else:
generation_info = None
content = (
stream_resp["message"]["content"]
if "message" in stream_resp and "content" in stream_resp["message"]
else ""
)
additional_kwargs = {}
if (
self.reasoning
and "message" in stream_resp
and (thinking_content := stream_resp["message"].get("thinking"))
):
additional_kwargs["reasoning_content"] = thinking_content
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content=(
stream_resp["message"]["content"]
if "message" in stream_resp
and "content" in stream_resp["message"]
else ""
),
content=content,
additional_kwargs=additional_kwargs,
usage_metadata=_get_usage_metadata_from_generation_info(
stream_resp
),
@@ -845,15 +859,7 @@ class ChatOllama(BaseChatModel):
),
generation_info=generation_info,
)
if chunk.generation_info and (
model := chunk.generation_info.get("model")
):
chunk.generation_info["model_name"] = model # backwards compat
if self.extract_reasoning:
message, is_thinking = self._extract_reasoning(
chunk.message, is_thinking
)
chunk.message = message
yield chunk
async def _astream(
@@ -950,7 +956,7 @@ class ChatOllama(BaseChatModel):
method: The method for steering model generation, one of:
- "json_schema":
Uses Ollama's structured output API: https://ollama.com/blog/structured-outputs
Uses Ollama's `structured output API <https://ollama.com/blog/structured-outputs>`__
- "function_calling":
Uses Ollama's tool-calling API
- "json_mode":
@@ -1267,5 +1273,4 @@ class ChatOllama(BaseChatModel):
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
return llm | output_parser

View File

@@ -97,7 +97,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
Embed multiple texts:
.. code-block:: python
input_texts = ["Document 1...", "Document 2..."]
input_texts = ["Document 1...", "Document 2..."]
vectors = embed.embed_documents(input_texts)
print(len(vectors))
# The first 3 coordinates for the first vector
@@ -112,7 +112,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
.. code-block:: python
vector = await embed.aembed_query(input_text)
print(vector[:3])
print(vector[:3])
# multiple:
# await embed.aembed_documents(input_texts)
@@ -151,12 +151,12 @@ class OllamaEmbeddings(BaseModel, Embeddings):
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
"""
_client: Client = PrivateAttr(default=None) # type: ignore
_client: Optional[Client] = PrivateAttr(default=None)
"""
The client to use for making requests.
"""
_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore
_async_client: Optional[AsyncClient] = PrivateAttr(default=None)
"""
The async client to use for making requests.
"""
@@ -270,6 +270,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
if not self._client:
raise ValueError(
"Ollama client is not initialized. "
"Please ensure Ollama is running and the model is loaded."
)
embedded_docs = self._client.embed(
self.model, texts, options=self._default_params, keep_alive=self.keep_alive
)["embeddings"]
@@ -281,6 +286,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
if not self._async_client:
raise ValueError(
"Ollama client is not initialized. "
"Please ensure Ollama is running and the model is loaded."
)
embedded_docs = (
await self._async_client.embed(
self.model, texts, keep_alive=self.keep_alive

View File

@@ -36,6 +36,20 @@ class OllamaLLM(BaseLLM):
model: str
"""Model name to use."""
reasoning: Optional[bool] = True
"""Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.
- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. If
the model performs reasoning, the ``<think>`` and ``</think>`` tags will
be present directly within the main response content."""
validate_model_on_init: bool = False
"""Whether to validate the model exists in ollama locally on initialization."""
@@ -56,7 +70,7 @@ class OllamaLLM(BaseLLM):
num_ctx: Optional[int] = None
"""Sets the size of the context window used to generate the
next token. (Default: 2048) """
next token. (Default: 2048)"""
num_gpu: Optional[int] = None
"""The number of GPUs to use. On macOS it defaults to 1 to
@@ -137,12 +151,12 @@ class OllamaLLM(BaseLLM):
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
"""
_client: Client = PrivateAttr(default=None) # type: ignore
_client: Optional[Client] = PrivateAttr(default=None)
"""
The client to use for making requests.
"""
_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore
_async_client: Optional[AsyncClient] = PrivateAttr(default=None)
"""
The async client to use for making requests.
"""
@@ -155,7 +169,7 @@ class OllamaLLM(BaseLLM):
) -> dict[str, Any]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
if self.stop is not None:
stop = self.stop
options_dict = kwargs.pop(
@@ -183,6 +197,7 @@ class OllamaLLM(BaseLLM):
"prompt": prompt,
"stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model),
"think": kwargs.pop("reasoning", self.reasoning),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
@@ -230,10 +245,11 @@ class OllamaLLM(BaseLLM):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
async for part in await self._async_client.generate(
**self._generate_params(prompt, stop=stop, **kwargs)
): # type: ignore
yield part # type: ignore
if self._async_client:
async for part in await self._async_client.generate(
**self._generate_params(prompt, stop=stop, **kwargs)
):
yield part
def _create_generate_stream(
self,
@@ -241,9 +257,10 @@ class OllamaLLM(BaseLLM):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
yield from self._client.generate(
**self._generate_params(prompt, stop=stop, **kwargs)
) # type: ignore
if self._client:
yield from self._client.generate(
**self._generate_params(prompt, stop=stop, **kwargs)
)
async def _astream_with_aggregation(
self,
@@ -356,11 +373,19 @@ class OllamaLLM(BaseLLM):
) -> Iterator[GenerationChunk]:
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if not isinstance(stream_resp, str):
additional_kwargs = {}
if thinking_content := stream_resp.get("thinking"):
additional_kwargs["reasoning_content"] = thinking_content
chunk = GenerationChunk(
text=(stream_resp.get("response", "")),
generation_info=(
dict(stream_resp) if stream_resp.get("done") is True else None
),
generation_info={
"finish_reason": self.stop,
**additional_kwargs,
**(
dict(stream_resp) if stream_resp.get("done") is True else {}
),
},
)
if run_manager:
run_manager.on_llm_new_token(
@@ -378,11 +403,19 @@ class OllamaLLM(BaseLLM):
) -> AsyncIterator[GenerationChunk]:
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if not isinstance(stream_resp, str):
additional_kwargs = {}
if thinking_content := stream_resp.get("thinking"):
additional_kwargs["reasoning_content"] = thinking_content
chunk = GenerationChunk(
text=(stream_resp.get("response", "")),
generation_info=(
dict(stream_resp) if stream_resp.get("done") is True else None
),
generation_info={
"finish_reason": self.stop,
**additional_kwargs,
**(
dict(stream_resp) if stream_resp.get("done") is True else {}
),
},
)
if run_manager:
await run_manager.on_llm_new_token(