ollama: thinking, tool streaming, docs, tests (#31772)

* New `reasoning` (bool) param to support toggling [Ollama
thinking](https://ollama.com/blog/thinking) (#31573, #31700). If
`reasoning=True`, Ollama's `thinking` content will be placed in the
model responses' `additional_kwargs.reasoning_content` (see the usage sketch after this list).
  * Supported by:
    * ChatOllama (class level, invocation level TODO)
    * OllamaLLM (TODO)
* Added tests to ensure tool call streaming works correctly (#29129)
* Refactored tests that relied on `extract_reasoning()`
* Myriad docs additions and consistency/typo fixes
* Improved type safety in some spots
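
A minimal usage sketch of the new toggle (assumes a local Ollama server with a thinking-capable model such as `deepseek-r1:1.5b` already pulled; the model name is illustrative):

```python
from langchain_ollama import ChatOllama

# reasoning=True maps to Ollama's "think" flag; the thinking text is surfaced
# under additional_kwargs["reasoning_content"] instead of being interleaved
# into the main content.
llm = ChatOllama(model="deepseek-r1:1.5b", reasoning=True)

result = llm.invoke("What is 3^3?")
print(result.content)  # final answer, without <think> tags
print(result.additional_kwargs.get("reasoning_content"))  # reasoning trace
```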

Closes #29129
Addresses #31573 and #31700
Supersedes #31701
Mason Daugherty 2025-07-07 13:56:41 -04:00 committed by GitHub
parent 0eb10f31c1
commit e686a70ee0
14 changed files with 630 additions and 213 deletions

View File

@ -38,11 +38,11 @@
"\n",
"\n",
":::caution COMPATIBILITY\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for astream_events(), to child runnables if you are running async code in python<=3.10. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for astream_events(), to child runnables if you are running async code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"\n",
"If you are running python&lt;=3.10, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n",
"If you are running `python<=3.10`, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n",
"\n",
"If you are running python>=3.11, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in other Python versions.\n",
"If you are running `python>=3.11`, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in other Python versions.\n",
":::"
]
},
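
The `bar` RunnableLambda referenced above is defined elsewhere in this notebook; as a rough, hedged sketch (names and logic here are illustrative, not the notebook's actual cell), manual propagation amounts to passing the received `RunnableConfig` through to the child's `ainvoke`:

```python
from langchain_core.runnables import RunnableConfig, RunnableLambda

async def reverse(s: str) -> str:
    return s[::-1]

async def bar(x: str, config: RunnableConfig) -> str:
    # Explicitly forward the config so callbacks (and therefore astream_events)
    # reach the child runnable on python<=3.10.
    return await RunnableLambda(reverse).ainvoke(x, config=config)

bar_runnable = RunnableLambda(bar)
```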

View File

@ -16,15 +16,15 @@
"\n",
":::\n",
"\n",
"If you have [tools](/docs/concepts/tools/) that call [chat models](/docs/concepts/chat_models/), [retrievers](/docs/concepts/retrievers/), or other [runnables](/docs/concepts/runnables/), you may want to access internal events from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n",
"If you have [tools](/docs/concepts/tools/) that call [chat models](/docs/concepts/chat_models/), [retrievers](/docs/concepts/retrievers/), or other [runnables](/docs/concepts/runnables/), you may want to access [internal events](https://python.langchain.com/docs/how_to/streaming/#event-reference) from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n",
"\n",
":::caution Compatibility\n",
"\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for `astream_events()`, to child runnables if you are running `async` code in `python&lt;=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for `astream_events()`, to child runnables if you are running `async` code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"\n",
"If you are running python&lt;=3.10, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n",
"If you are running `python<=3.10`, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n",
"\n",
"If you are running python>=3.11, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in older Python versions.\n",
"If you are running `python>=3.11`, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in older Python versions.\n",
"\n",
"This guide also requires `langchain-core>=0.2.16`.\n",
":::\n",

View File

@ -224,6 +224,13 @@
"source": [
"print(type(gathered.tool_calls[0][\"args\"]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note the key difference: accumulating `tool_call_chunks` captures the raw tool arguments as an unparsed string as they are streamed. In contrast, **accumulating** `tool_calls` demonstrates partial parsing by progressively converting the streamed argument string into a valid, usable dictionary at each step of the process."
]
}
],
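
A small sketch of the difference described above, assuming `llm_with_tools` is a chat model with tools bound and `query` is a prompt that triggers a tool call (both defined earlier in the notebook):

```python
# Accumulate streamed AIMessageChunks with `+`; the running aggregate exposes
# both views of the same data.
gathered = None
for chunk in llm_with_tools.stream(query):
    gathered = chunk if gathered is None else gathered + chunk
    if gathered.tool_call_chunks:
        # Raw argument string accumulated so far (may be incomplete JSON).
        print("chunks:", gathered.tool_call_chunks[0]["args"])
    if gathered.tool_calls:
        # Partially parsed arguments: always a valid dict, growing as chunks arrive.
        print("parsed:", gathered.tool_calls[0]["args"])
```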
"metadata": {

View File

@ -1414,6 +1414,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
ValueError: If the file path is not a string or Path object.
Example:
.. code-block:: python
llm.save(file_path="path/llm.yaml")

View File

@ -1160,22 +1160,21 @@ class Runnable(ABC, Generic[Input, Output]):
A StreamEvent is a dictionary with the following schema:
- ``event``: **str** - Event names are of the
format: on_[runnable_type]_(start|stream|end).
- ``event``: **str** - Event names are of the format:
on_[runnable_type]_(start|stream|end).
- ``name``: **str** - The name of the Runnable that generated the event.
- ``run_id``: **str** - randomly generated ID associated with the given execution of
the Runnable that emitted the event.
A child Runnable that gets invoked as part of the execution of a
parent Runnable is assigned its own unique ID.
- ``parent_ids``: **list[str]** - The IDs of the parent runnables that
generated the event. The root Runnable will have an empty list.
The order of the parent IDs is from the root to the immediate parent.
Only available for v2 version of the API. The v1 version of the API
will return an empty list.
- ``run_id``: **str** - randomly generated ID associated with the given
execution of the Runnable that emitted the event. A child Runnable that gets
invoked as part of the execution of a parent Runnable is assigned its own
unique ID.
- ``parent_ids``: **list[str]** - The IDs of the parent runnables that generated
the event. The root Runnable will have an empty list. The order of the parent
IDs is from the root to the immediate parent. Only available for v2 version of
the API. The v1 version of the API will return an empty list.
- ``tags``: **Optional[list[str]]** - The tags of the Runnable that generated
the event.
- ``metadata``: **Optional[dict[str, Any]]** - The metadata of the Runnable
that generated the event.
- ``metadata``: **Optional[dict[str, Any]]** - The metadata of the Runnable that
generated the event.
- ``data``: **dict[str, Any]**
@ -1183,7 +1182,7 @@ class Runnable(ABC, Generic[Input, Output]):
chains. Metadata fields have been omitted from the table for brevity.
Chain definitions have been included after the table.
**ATTENTION** This reference table is for the V2 version of the schema.
.. NOTE:: This reference table is for the V2 version of the schema.
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| event | name | chunk | input | output |
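
For orientation, a brief sketch of consuming these events with the v2 schema (assuming `chain` is any runnable that streams chat model output):

```python
async def show_events(chain, text: str) -> None:
    async for event in chain.astream_events(text, version="v2"):
        # Every event carries the fields documented above.
        if event["event"] == "on_chat_model_stream":
            print(event["name"], event["run_id"], event["data"]["chunk"].content)
```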

View File

@ -6,7 +6,6 @@ from operator import itemgetter
from typing import (
Any,
Callable,
Final,
Literal,
Optional,
Union,
@ -25,7 +24,6 @@ from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
HumanMessage,
SystemMessage,
@ -57,9 +55,6 @@ from typing_extensions import Self, is_typeddict
from ._utils import validate_model
DEFAULT_THINK_TOKEN_START: Final[str] = "<think>"
DEFAULT_THINK_TOKEN_END: Final[str] = "</think>"
def _get_usage_metadata_from_generation_info(
generation_info: Optional[Mapping[str, Any]],
@ -166,13 +161,14 @@ def _get_tool_calls_from_response(
return tool_calls
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
def _lc_tool_call_to_openai_tool_call(tool_call_: ToolCall) -> dict:
"""Convert a LangChain tool call to an OpenAI tool call format."""
return {
"type": "function",
"id": tool_call["id"],
"id": tool_call_["id"],
"function": {
"name": tool_call["name"],
"arguments": tool_call["args"],
"name": tool_call_["name"],
"arguments": tool_call_["args"],
},
}
@ -211,6 +207,20 @@ class ChatOllama(BaseChatModel):
Key init args completion params:
model: str
Name of Ollama model to use.
reasoning: Optional[bool]
Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.
- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. Note
however, if the model's default behavior *is* to perform reasoning, think tags
(``<think>`` and ``</think>``) will be present within the main response content
unless you set ``reasoning`` to ``True``.
temperature: float
Sampling temperature. Ranges from 0.0 to 1.0.
num_predict: Optional[int]
@ -347,21 +357,29 @@ class ChatOllama(BaseChatModel):
'args': {'a': 45, 'b': 67},
'id': '420c3f3b-df10-4188-945f-eb3abdb40622',
'type': 'tool_call'}]
""" # noqa: E501
""" # noqa: E501, pylint: disable=line-too-long
model: str
"""Model name to use."""
reasoning: Optional[bool] = None
"""Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.
- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. Note
however, if the model's default behavior *is* to perform reasoning, think tags
(``<think>`` and ``</think>``) will be present within the main response content
unless you set ``reasoning`` to ``True``."""
validate_model_on_init: bool = False
"""Whether to validate the model exists in Ollama locally on initialization."""
extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
"""Whether to extract the reasoning tokens in think blocks.
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
If a tuple is supplied, they are assumed to be the (start, end) tokens.
If `extract_reasoning=True`, the tokens will default to (<think>, </think>).
"""
mirostat: Optional[int] = None
"""Enable Mirostat sampling for controlling perplexity.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
@ -448,24 +466,23 @@ class ChatOllama(BaseChatModel):
"""
async_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with client_kwargs before passing to the HTTPX
AsyncClient.
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#asyncclient>`__.
"""Additional kwargs to merge with client_kwargs before
passing to the httpx AsyncClient.
`Full list of params. <https://www.python-httpx.org/api/#asyncclient>`__
"""
sync_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with client_kwargs before passing to the HTTPX Client.
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
"""Additional kwargs to merge with client_kwargs before
passing to the httpx Client.
`Full list of params. <https://www.python-httpx.org/api/#client>`__
"""
_client: Client = PrivateAttr(default=None) # type: ignore
_client: Client = PrivateAttr()
"""
The client to use for making requests.
"""
_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore
_async_client: AsyncClient = PrivateAttr()
"""
The async client to use for making requests.
"""
@ -480,7 +497,7 @@ class ChatOllama(BaseChatModel):
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
if self.stop is not None:
stop = self.stop
options_dict = kwargs.pop(
@ -508,6 +525,7 @@ class ChatOllama(BaseChatModel):
"messages": ollama_messages,
"stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model),
"think": kwargs.pop("reasoning", self.reasoning),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
@ -618,35 +636,13 @@ class ChatOllama(BaseChatModel):
"images": images,
}
if tool_calls:
msg["tool_calls"] = tool_calls # type: ignore
msg["tool_calls"] = tool_calls
if tool_call_id:
msg["tool_call_id"] = tool_call_id
ollama_messages.append(msg)
return ollama_messages
def _extract_reasoning(
self, message_chunk: BaseMessageChunk, is_thinking: bool
) -> tuple[BaseMessageChunk, bool]:
"""Mutate a message chunk to extract reasoning content."""
if not self.extract_reasoning:
return message_chunk, is_thinking
elif self.extract_reasoning is True:
start_token = DEFAULT_THINK_TOKEN_START
end_token = DEFAULT_THINK_TOKEN_END
else:
start_token, end_token = cast(tuple, self.extract_reasoning)
if start_token in message_chunk.content:
is_thinking = True
content = message_chunk.content
if is_thinking:
message_chunk.additional_kwargs["reasoning_content"] = content
message_chunk.content = ""
if end_token in content:
is_thinking = False
return message_chunk, is_thinking
async def _acreate_chat_stream(
self,
messages: list[BaseMessage],
@ -670,8 +666,10 @@ class ChatOllama(BaseChatModel):
chat_params = self._chat_params(messages, stop, **kwargs)
if chat_params["stream"]:
if self._client:
yield from self._client.chat(**chat_params)
else:
if self._client:
yield self._client.chat(**chat_params)
def _chat_stream_with_aggregation(
@ -767,22 +765,34 @@ class ChatOllama(BaseChatModel):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
is_thinking = False
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
if stream_resp.get("done") is True:
generation_info = dict(stream_resp)
if "model" in generation_info:
generation_info["model_name"] = generation_info["model"]
_ = generation_info.pop("message", None)
else:
generation_info = None
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content = (
stream_resp["message"]["content"]
if "message" in stream_resp
and "content" in stream_resp["message"]
if "message" in stream_resp and "content" in stream_resp["message"]
else ""
),
)
additional_kwargs = {}
if (
self.reasoning
and "message" in stream_resp
and (thinking_content := stream_resp["message"].get("thinking"))
):
additional_kwargs["reasoning_content"] = thinking_content
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
usage_metadata=_get_usage_metadata_from_generation_info(
stream_resp
),
@ -790,15 +800,7 @@ class ChatOllama(BaseChatModel):
),
generation_info=generation_info,
)
if chunk.generation_info and (
model := chunk.generation_info.get("model")
):
chunk.generation_info["model_name"] = model # backwards compat
if self.extract_reasoning:
message, is_thinking = self._extract_reasoning(
chunk.message, is_thinking
)
chunk.message = message
yield chunk
def _stream(
@ -822,22 +824,34 @@ class ChatOllama(BaseChatModel):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
is_thinking = False
async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
if stream_resp.get("done") is True:
generation_info = dict(stream_resp)
if "model" in generation_info:
generation_info["model_name"] = generation_info["model"]
_ = generation_info.pop("message", None)
else:
generation_info = None
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content = (
stream_resp["message"]["content"]
if "message" in stream_resp
and "content" in stream_resp["message"]
if "message" in stream_resp and "content" in stream_resp["message"]
else ""
),
)
additional_kwargs = {}
if (
self.reasoning
and "message" in stream_resp
and (thinking_content := stream_resp["message"].get("thinking"))
):
additional_kwargs["reasoning_content"] = thinking_content
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
usage_metadata=_get_usage_metadata_from_generation_info(
stream_resp
),
@ -845,15 +859,7 @@ class ChatOllama(BaseChatModel):
),
generation_info=generation_info,
)
if chunk.generation_info and (
model := chunk.generation_info.get("model")
):
chunk.generation_info["model_name"] = model # backwards compat
if self.extract_reasoning:
message, is_thinking = self._extract_reasoning(
chunk.message, is_thinking
)
chunk.message = message
yield chunk
async def _astream(
@ -950,7 +956,7 @@ class ChatOllama(BaseChatModel):
method: The method for steering model generation, one of:
- "json_schema":
Uses Ollama's structured output API: https://ollama.com/blog/structured-outputs
Uses Ollama's `structured output API <https://ollama.com/blog/structured-outputs>`__
- "function_calling":
Uses Ollama's tool-calling API
- "json_mode":
@ -1267,5 +1273,4 @@ class ChatOllama(BaseChatModel):
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
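
As an illustration of the `method` options documented above, a hedged sketch using a Pydantic schema (similar to the `MathAnswer` model defined in the integration tests later in this diff; model name illustrative):

```python
from pydantic import BaseModel, Field
from langchain_ollama import ChatOllama

class MathAnswer(BaseModel):
    """A mathematical expression and its numerical answer."""
    expression: str = Field(description="The expression to evaluate.")
    answer: int = Field(description="The numerical answer.")

llm = ChatOllama(model="llama3.1")
# method="json_schema" routes through Ollama's structured output API.
structured_llm = llm.with_structured_output(MathAnswer, method="json_schema")
result = structured_llm.invoke("What is 3^3?")  # -> MathAnswer instance
```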

View File

@ -151,12 +151,12 @@ class OllamaEmbeddings(BaseModel, Embeddings):
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
"""
_client: Client = PrivateAttr(default=None) # type: ignore
_client: Optional[Client] = PrivateAttr(default=None)
"""
The client to use for making requests.
"""
_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore
_async_client: Optional[AsyncClient] = PrivateAttr(default=None)
"""
The async client to use for making requests.
"""
@ -270,6 +270,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
if not self._client:
raise ValueError(
"Ollama client is not initialized. "
"Please ensure Ollama is running and the model is loaded."
)
embedded_docs = self._client.embed(
self.model, texts, options=self._default_params, keep_alive=self.keep_alive
)["embeddings"]
@ -281,6 +286,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
if not self._async_client:
raise ValueError(
"Ollama client is not initialized. "
"Please ensure Ollama is running and the model is loaded."
)
embedded_docs = (
await self._async_client.embed(
self.model, texts, keep_alive=self.keep_alive
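
For context, a short usage sketch of the guarded embedding calls (model name illustrative):

```python
from langchain_ollama import OllamaEmbeddings

embeddings = OllamaEmbeddings(model="llama3.1")
# Raises the ValueError above if the underlying Ollama client was never initialized.
vectors = embeddings.embed_documents(["hello world"])
print(len(vectors), len(vectors[0]))
```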

View File

@ -36,6 +36,20 @@ class OllamaLLM(BaseLLM):
model: str
"""Model name to use."""
reasoning: Optional[bool] = None
"""Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.
- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. If
the model performs reasoning, the ``<think>`` and ``</think>`` tags will
be present directly within the main response content."""
validate_model_on_init: bool = False
"""Whether to validate the model exists in ollama locally on initialization."""
@ -137,12 +151,12 @@ class OllamaLLM(BaseLLM):
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
"""
_client: Client = PrivateAttr(default=None) # type: ignore
_client: Optional[Client] = PrivateAttr(default=None)
"""
The client to use for making requests.
"""
_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore
_async_client: Optional[AsyncClient] = PrivateAttr(default=None)
"""
The async client to use for making requests.
"""
@ -155,7 +169,7 @@ class OllamaLLM(BaseLLM):
) -> dict[str, Any]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
if self.stop is not None:
stop = self.stop
options_dict = kwargs.pop(
@ -183,6 +197,7 @@ class OllamaLLM(BaseLLM):
"prompt": prompt,
"stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model),
"think": kwargs.pop("reasoning", self.reasoning),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
@ -230,10 +245,11 @@ class OllamaLLM(BaseLLM):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
if self._async_client:
async for part in await self._async_client.generate(
**self._generate_params(prompt, stop=stop, **kwargs)
): # type: ignore
yield part # type: ignore
):
yield part
def _create_generate_stream(
self,
@ -241,9 +257,10 @@ class OllamaLLM(BaseLLM):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
if self._client:
yield from self._client.generate(
**self._generate_params(prompt, stop=stop, **kwargs)
) # type: ignore
)
async def _astream_with_aggregation(
self,
@ -356,11 +373,19 @@ class OllamaLLM(BaseLLM):
) -> Iterator[GenerationChunk]:
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if not isinstance(stream_resp, str):
additional_kwargs = {}
if thinking_content := stream_resp.get("thinking"):
additional_kwargs["reasoning_content"] = thinking_content
chunk = GenerationChunk(
text=(stream_resp.get("response", "")),
generation_info=(
dict(stream_resp) if stream_resp.get("done") is True else None
generation_info={
"finish_reason": self.stop,
**additional_kwargs,
**(
dict(stream_resp) if stream_resp.get("done") is True else {}
),
},
)
if run_manager:
run_manager.on_llm_new_token(
@ -378,11 +403,19 @@ class OllamaLLM(BaseLLM):
) -> AsyncIterator[GenerationChunk]:
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if not isinstance(stream_resp, str):
additional_kwargs = {}
if thinking_content := stream_resp.get("thinking"):
additional_kwargs["reasoning_content"] = thinking_content
chunk = GenerationChunk(
text=(stream_resp.get("response", "")),
generation_info=(
dict(stream_resp) if stream_resp.get("done") is True else None
generation_info={
"finish_reason": self.stop,
**additional_kwargs,
**(
dict(stream_resp) if stream_resp.get("done") is True else {}
),
},
)
if run_manager:
await run_manager.on_llm_new_token(
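
A minimal sketch of exercising this streaming path through `OllamaLLM` (model name illustrative; the commit message marks full `OllamaLLM` reasoning support as TODO, so treat this as an assumption about the code shown above rather than a finished API):

```python
from langchain_ollama import OllamaLLM

llm = OllamaLLM(model="deepseek-r1:1.5b", reasoning=True)

# OllamaLLM.stream yields plain text tokens; per the hunk above, any "thinking"
# returned by Ollama is attached to each chunk's generation_info as
# "reasoning_content" rather than being interleaved into the text.
for token in llm.stream("What is 3^3?"):
    print(token, end="", flush=True)
```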

View File

@ -1,17 +1,28 @@
"""Ollama specific chat model integration tests for reasoning models."""
import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk, HumanMessage
from pydantic import ValidationError
from langchain_core.messages import (
AIMessageChunk,
BaseMessageChunk,
HumanMessage,
)
from pydantic import BaseModel, Field
from langchain_ollama import ChatOllama
SAMPLE = "What is 3^3?"
class MathAnswer(BaseModel):
"""A mathematical expression and its numerical answer."""
expression: str = Field(description="The mathematical expression to evaluate.")
answer: int = Field(description="The numerical answer to the expression.")
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
"""Test deepseek model without parsing."""
def test_stream_no_reasoning(model: str) -> None:
"""Test streaming with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
messages = [
{
@ -28,14 +39,41 @@ def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" in result.content and "</think>" in result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_stream_bool(model: str) -> None:
"""Test deepseek model with reasoning bool=True"""
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
async def test_astream_no_reasoning(model: str) -> None:
"""Test async streaming with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_reasoning_none(model: str) -> None:
"""Test streaming with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
messages = [
{
"role": "user",
@ -51,26 +89,41 @@ def test_deepseek_messages_stream_bool(model: str) -> None:
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content and "</think>" in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_stream_tuple(model: str) -> None:
"""Test deepseek model with reasoning with tuple=..."""
llm = ChatOllama(
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
)
async def test_astream_reasoning_none(model: str) -> None:
"""Test async streaming with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content and "</think>" in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_stream(model: str) -> None:
"""Test streaming with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
@ -86,77 +139,114 @@ def test_deepseek_messages_stream_tuple(model: str) -> None:
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_invoke_no_reasoning(model: str) -> None:
"""Test deepseek model without parsing using invoke."""
async def test_reasoning_astream(model: str) -> None:
"""Test async streaming with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_no_reasoning(model: str) -> None:
"""Test using invoke with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
assert result.content
assert "<think>" in result.content and "</think>" in result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_invoke_bool(model: str) -> None:
"""Test deepseek model with reasoning bool=True using invoke"""
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
async def test_ainvoke_no_reasoning(model: str) -> None:
"""Test using async invoke with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_reasoning_none(model: str) -> None:
"""Test using invoke with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content and "</think>" in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_invoke_tuple(model: str) -> None:
"""Test deepseek model with reasoning with tuple=... using invoke"""
llm = ChatOllama(
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
)
async def test_ainvoke_reasoning_none(model: str) -> None:
"""Test using async invoke with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content and "</think>" in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_invoke(model: str) -> None:
"""Test invoke with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_invalid(model: str) -> None:
"""Test deepseek model with reasoning raises ValidationError"""
with pytest.raises(ValidationError):
_ = ChatOllama(model=model, extract_reasoning={"invalid": "data"}) # type: ignore[arg-type]
async def test_reasoning_ainvoke(model: str) -> None:
"""Test invoke with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]

View File

@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
import pytest
from httpx import ConnectError
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolCallChunk
from langchain_core.tools import tool
from langchain_tests.integration_tests import ChatModelIntegrationTests
from ollama import ResponseError
from pydantic import ValidationError
@ -14,6 +16,15 @@ from langchain_ollama.chat_models import ChatOllama
DEFAULT_MODEL_NAME = "llama3.1"
@tool
def get_current_weather(location: str) -> dict:
"""Gets the current weather in a given location."""
if "boston" in location.lower():
return {"temperature": "15°F", "conditions": "snow"}
else:
return {"temperature": "unknown", "conditions": "unknown"}
class TestChatOllama(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> type[ChatOllama]:
@ -29,12 +40,104 @@ class TestChatOllama(ChatModelIntegrationTests):
@property
def has_tool_choice(self) -> bool:
return False # TODO: update after Ollama implements
# TODO: update after Ollama implements
# https://github.com/ollama/ollama/blob/main/docs/openai.md
return False
@property
def supports_image_inputs(self) -> bool:
return True
def test_tool_streaming(self, model: BaseChatModel) -> None:
"""Test that the model can stream tool calls."""
chat_model_with_tools = model.bind_tools([get_current_weather])
prompt = [HumanMessage("What is the weather today in Boston?")]
# Flags and collectors for validation
tool_chunk_found = False
final_tool_calls = []
collected_tool_chunks: list[ToolCallChunk] = []
# Stream the response and inspect the chunks
for chunk in chat_model_with_tools.stream(prompt):
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
if chunk.tool_call_chunks:
tool_chunk_found = True
for tc_chunk in chunk.tool_call_chunks:
collected_tool_chunks.append(tc_chunk)
if chunk.tool_calls:
final_tool_calls.extend(chunk.tool_calls)
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
assert (
len(final_tool_calls) == 1
), f"Expected 1 final tool call, but got {len(final_tool_calls)}"
final_tool_call = final_tool_calls[0]
assert final_tool_call["name"] == "get_current_weather"
assert final_tool_call["args"] == {"location": "Boston"}
assert len(collected_tool_chunks) > 0
assert collected_tool_chunks[0]["name"] == "get_current_weather"
# The ID should be consistent across chunks that have it
tool_call_id = collected_tool_chunks[0].get("id")
assert tool_call_id is not None
assert all(
chunk.get("id") == tool_call_id
for chunk in collected_tool_chunks
if chunk.get("id")
)
assert final_tool_call["id"] == tool_call_id
async def test_tool_astreaming(self, model: BaseChatModel) -> None:
"""Test that the model can stream tool calls."""
chat_model_with_tools = model.bind_tools([get_current_weather])
prompt = [HumanMessage("What is the weather today in Boston?")]
# Flags and collectors for validation
tool_chunk_found = False
final_tool_calls = []
collected_tool_chunks: list[ToolCallChunk] = []
# Stream the response and inspect the chunks
async for chunk in chat_model_with_tools.astream(prompt):
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
if chunk.tool_call_chunks:
tool_chunk_found = True
for tc_chunk in chunk.tool_call_chunks:
collected_tool_chunks.append(tc_chunk)
if chunk.tool_calls:
final_tool_calls.extend(chunk.tool_calls)
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
assert (
len(final_tool_calls) == 1
), f"Expected 1 final tool call, but got {len(final_tool_calls)}"
final_tool_call = final_tool_calls[0]
assert final_tool_call["name"] == "get_current_weather"
assert final_tool_call["args"] == {"location": "Boston"}
assert len(collected_tool_chunks) > 0
assert collected_tool_chunks[0]["name"] == "get_current_weather"
# The ID should be consistent across chunks that have it
tool_call_id = collected_tool_chunks[0].get("id")
assert tool_call_id is not None
assert all(
chunk.get("id") == tool_call_id
for chunk in collected_tool_chunks
if chunk.get("id")
)
assert final_tool_call["id"] == tool_call_id
@pytest.mark.xfail(
reason=(
"Will sometime encounter AssertionErrors where tool responses are "

View File

@ -1,11 +1,15 @@
"""Test OllamaLLM llm."""
import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from langchain_core.runnables import RunnableConfig
from langchain_ollama.llms import OllamaLLM
MODEL_NAME = "llama3.1"
SAMPLE = "What is 3^3?"
def test_stream() -> None:
"""Test streaming tokens from OpenAI."""
@ -15,6 +19,59 @@ def test_stream() -> None:
assert isinstance(token, str)
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_no_reasoning(model: str) -> None:
"""Test streaming with `reasoning=False`"""
llm = OllamaLLM(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_stream(model: str) -> None:
"""Test streaming with `reasoning=True`"""
llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
async def test_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = OllamaLLM(model=MODEL_NAME)
@ -23,6 +80,59 @@ async def test_astream() -> None:
assert isinstance(token, str)
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_astream_no_reasoning(model: str) -> None:
"""Test async streaming with `reasoning=False`"""
llm = OllamaLLM(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_reasoning_astream(model: str) -> None:
"""Test async streaming with `reasoning=True`"""
llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
async def test_abatch() -> None:
"""Test streaming tokens from OllamaLLM."""
llm = OllamaLLM(model=MODEL_NAME)
@ -60,8 +170,68 @@ async def test_ainvoke() -> None:
assert isinstance(result, str)
# TODO
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# async def test_ainvoke_no_reasoning(model: str) -> None:
# """Test using async invoke with `reasoning=False`"""
# llm = OllamaLLM(model=model, num_ctx=2**12)
# message = SAMPLE
# result = await llm.ainvoke(message)
# assert result.content
# assert "reasoning_content" not in result.additional_kwargs
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# async def test_reasoning_ainvoke(model: str) -> None:
# """Test invoke with `reasoning=True`"""
# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
# message = SAMPLE
# result = await llm.ainvoke(message)
# assert result.content
# assert "reasoning_content" in result.additional_kwargs
# assert len(result.additional_kwargs["reasoning_content"]) > 0
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# assert "<think>" not in result.additional_kwargs["reasoning_content"]
# assert "</think>" not in result.additional_kwargs["reasoning_content"]
def test_invoke() -> None:
"""Test invoke tokens from OllamaLLM."""
llm = OllamaLLM(model=MODEL_NAME)
result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
assert isinstance(result, str)
# TODO
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# def test_invoke_no_reasoning(model: str) -> None:
# """Test using invoke with `reasoning=False`"""
# llm = OllamaLLM(model=model, num_ctx=2**12)
# message = SAMPLE
# result = llm.invoke(message)
# assert result.content
# assert "reasoning_content" not in result.additional_kwargs
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# def test_reasoning_invoke(model: str) -> None:
# """Test invoke with `reasoning=True`"""
# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
# message = SAMPLE
# result = llm.invoke(message)
# assert result.content
# assert "reasoning_content" in result.additional_kwargs
# assert len(result.additional_kwargs["reasoning_content"]) > 0
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# assert "<think>" not in result.additional_kwargs["reasoning_content"]
# assert "</think>" not in result.additional_kwargs["reasoning_content"]

View File

@ -23,7 +23,7 @@ class TestChatOllama(ChatModelUnitTests):
@property
def chat_model_params(self) -> dict:
return {"model": "llama3-groq-tool-use"}
return {"model": MODEL_NAME}
def test__parse_arguments_from_tool_call() -> None:
@ -51,7 +51,6 @@ def test_arbitrary_roles_accepted_in_chatmessages(
monkeypatch.setattr(Client, "stream", _mock_httpx_client_stream)
llm = ChatOllama(
base_url="http://whocares:11434",
model=MODEL_NAME,
verbose=True,
format=None,

View File

@ -10,7 +10,7 @@ MODEL_NAME = "llama3.1"
def test_initialization() -> None:
"""Test embedding model initialization."""
OllamaEmbeddings(model="llama3", keep_alive=1)
OllamaEmbeddings(model=MODEL_NAME, keep_alive=1)
@patch("langchain_ollama.embeddings.validate_model")

View File

@ -10,25 +10,25 @@ MODEL_NAME = "llama3.1"
def test_initialization() -> None:
"""Test integration initialization."""
OllamaLLM(model="llama3")
OllamaLLM(model=MODEL_NAME)
def test_model_params() -> None:
# Test standard tracing params
llm = OllamaLLM(model="llama3")
llm = OllamaLLM(model=MODEL_NAME)
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "ollama",
"ls_model_type": "llm",
"ls_model_name": "llama3",
"ls_model_name": MODEL_NAME,
}
llm = OllamaLLM(model="llama3", num_predict=3)
llm = OllamaLLM(model=MODEL_NAME, num_predict=3)
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "ollama",
"ls_model_type": "llm",
"ls_model_name": "llama3",
"ls_model_name": MODEL_NAME,
"ls_max_tokens": 3,
}