ollama: thinking, tool streaming, docs, tests (#31772)

* New `reasoning` (bool) param to support toggling [Ollama
thinking](https://ollama.com/blog/thinking) (#31573, #31700). If
`reasoning=True`, Ollama's `thinking` content is placed in the response
message's `additional_kwargs` under `reasoning_content` (see the first sketch below).
  * Supported by:
    * ChatOllama (class level, invocation level TODO)
    * OllamaLLM (TODO)
* Added tests to ensure that streaming tool calls works correctly (#29129); see the second sketch below
* Refactored tests that relied on `extract_reasoning()`
* Myriad docs additions and consistency/typo fixes
* Improved type safety in some spots
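A minimal sketch of the new behavior, assuming a local Ollama server with a thinking-capable model pulled (the model name and prompt here are illustrative):

```python
from langchain_ollama import ChatOllama

# Assumes e.g. `ollama pull deepseek-r1:1.5b` has been run locally.
llm = ChatOllama(model="deepseek-r1:1.5b", reasoning=True)

result = llm.invoke("What is 3^3?")
print(result.content)  # final answer, with no <think>/</think> tags
print(result.additional_kwargs["reasoning_content"])  # captured thinking trace
```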

Closes #29129
Addresses #31573 and #31700
Supersedes #31701
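For the tool-streaming tests (#29129), the behavior under test looks roughly like this; the tool and prompt mirror the integration test added in this commit, and the model name is an assumption:

```python
from langchain_core.tools import tool
from langchain_ollama import ChatOllama


@tool
def get_current_weather(location: str) -> dict:
    """Gets the current weather in a given location."""
    return {"temperature": "15°F", "conditions": "snow"}


llm = ChatOllama(model="llama3.1").bind_tools([get_current_weather])

gathered = None
for chunk in llm.stream("What is the weather today in Boston?"):
    # chunk.tool_call_chunks carry raw, partially streamed arguments;
    # accumulating chunks progressively parses them into chunk.tool_calls dicts.
    gathered = chunk if gathered is None else gathered + chunk

if gathered is not None and gathered.tool_calls:
    print(gathered.tool_calls[0]["name"], gathered.tool_calls[0]["args"])
```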
Mason Daugherty, 2025-07-07 13:56:41 -04:00 (committed by GitHub)
parent 0eb10f31c1
commit e686a70ee0
14 changed files with 630 additions and 213 deletions


@ -38,11 +38,11 @@
"\n", "\n",
"\n", "\n",
":::caution COMPATIBILITY\n", ":::caution COMPATIBILITY\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for astream_events(), to child runnables if you are running async code in python&lt;=3.10. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n", "LangChain cannot automatically propagate configuration, including callbacks necessary for astream_events(), to child runnables if you are running async code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"\n", "\n",
"If you are running python&lt;=3.10, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n", "If you are running `python<=3.10`, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n",
"\n", "\n",
"If you are running python>=3.11, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in other Python versions.\n", "If you are running `python>=3.11`, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in other Python versions.\n",
":::" ":::"
] ]
}, },


@ -16,15 +16,15 @@
"\n", "\n",
":::\n", ":::\n",
"\n", "\n",
"If you have [tools](/docs/concepts/tools/) that call [chat models](/docs/concepts/chat_models/), [retrievers](/docs/concepts/retrievers/), or other [runnables](/docs/concepts/runnables/), you may want to access internal events from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n", "If you have [tools](/docs/concepts/tools/) that call [chat models](/docs/concepts/chat_models/), [retrievers](/docs/concepts/retrievers/), or other [runnables](/docs/concepts/runnables/), you may want to access [internal events](https://python.langchain.com/docs/how_to/streaming/#event-reference) from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n",
"\n", "\n",
":::caution Compatibility\n", ":::caution Compatibility\n",
"\n", "\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for `astream_events()`, to child runnables if you are running `async` code in `python&lt;=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n", "LangChain cannot automatically propagate configuration, including callbacks necessary for `astream_events()`, to child runnables if you are running `async` code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"\n", "\n",
"If you are running python&lt;=3.10, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n", "If you are running `python<=3.10`, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n",
"\n", "\n",
"If you are running python>=3.11, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in older Python versions.\n", "If you are running `python>=3.11`, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in older Python versions.\n",
"\n", "\n",
"This guide also requires `langchain-core>=0.2.16`.\n", "This guide also requires `langchain-core>=0.2.16`.\n",
":::\n", ":::\n",


@ -224,6 +224,13 @@
"source": [ "source": [
"print(type(gathered.tool_calls[0][\"args\"]))" "print(type(gathered.tool_calls[0][\"args\"]))"
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note the key difference: accumulating `tool_call_chunks` captures the raw tool arguments as an unparsed string as they are streamed. In contrast, **accumulating** `tool_calls` demonstrates partial parsing by progressively converting the streamed argument string into a valid, usable dictionary at each step of the process."
]
} }
], ],
"metadata": { "metadata": {


@ -1414,6 +1414,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
 ValueError: If the file path is not a string or Path object.
 Example:
 .. code-block:: python
+
 llm.save(file_path="path/llm.yaml")


@ -1160,22 +1160,21 @@ class Runnable(ABC, Generic[Input, Output]):
 A StreamEvent is a dictionary with the following schema:
-- ``event``: **str** - Event names are of the
-  format: on_[runnable_type]_(start|stream|end).
+- ``event``: **str** - Event names are of the format:
+  on_[runnable_type]_(start|stream|end).
 - ``name``: **str** - The name of the Runnable that generated the event.
-- ``run_id``: **str** - randomly generated ID associated with the given execution of
-  the Runnable that emitted the event.
-  A child Runnable that gets invoked as part of the execution of a
-  parent Runnable is assigned its own unique ID.
-- ``parent_ids``: **list[str]** - The IDs of the parent runnables that
-  generated the event. The root Runnable will have an empty list.
-  The order of the parent IDs is from the root to the immediate parent.
-  Only available for v2 version of the API. The v1 version of the API
-  will return an empty list.
+- ``run_id``: **str** - randomly generated ID associated with the given
+  execution of the Runnable that emitted the event. A child Runnable that gets
+  invoked as part of the execution of a parent Runnable is assigned its own
+  unique ID.
+- ``parent_ids``: **list[str]** - The IDs of the parent runnables that generated
+  the event. The root Runnable will have an empty list. The order of the parent
+  IDs is from the root to the immediate parent. Only available for v2 version of
+  the API. The v1 version of the API will return an empty list.
 - ``tags``: **Optional[list[str]]** - The tags of the Runnable that generated
   the event.
-- ``metadata``: **Optional[dict[str, Any]]** - The metadata of the Runnable
-  that generated the event.
+- ``metadata``: **Optional[dict[str, Any]]** - The metadata of the Runnable that
+  generated the event.
 - ``data``: **dict[str, Any]**
@ -1183,7 +1182,7 @@ class Runnable(ABC, Generic[Input, Output]):
 chains. Metadata fields have been omitted from the table for brevity.
 Chain definitions have been included after the table.
-**ATTENTION** This reference table is for the V2 version of the schema.
+.. NOTE:: This reference table is for the V2 version of the schema.
 +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
 | event | name | chunk | input | output |


@ -6,7 +6,6 @@ from operator import itemgetter
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Final,
Literal, Literal,
Optional, Optional,
Union, Union,
@ -25,7 +24,6 @@ from langchain_core.messages import (
AIMessage, AIMessage,
AIMessageChunk, AIMessageChunk,
BaseMessage, BaseMessage,
BaseMessageChunk,
ChatMessage, ChatMessage,
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
@ -57,9 +55,6 @@ from typing_extensions import Self, is_typeddict
from ._utils import validate_model from ._utils import validate_model
DEFAULT_THINK_TOKEN_START: Final[str] = "<think>"
DEFAULT_THINK_TOKEN_END: Final[str] = "</think>"
def _get_usage_metadata_from_generation_info( def _get_usage_metadata_from_generation_info(
generation_info: Optional[Mapping[str, Any]], generation_info: Optional[Mapping[str, Any]],
@ -166,13 +161,14 @@ def _get_tool_calls_from_response(
return tool_calls return tool_calls
-def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
+def _lc_tool_call_to_openai_tool_call(tool_call_: ToolCall) -> dict:
+"""Convert a LangChain tool call to an OpenAI tool call format."""
 return {
 "type": "function",
-"id": tool_call["id"],
+"id": tool_call_["id"],
 "function": {
-"name": tool_call["name"],
-"arguments": tool_call["args"],
+"name": tool_call_["name"],
+"arguments": tool_call_["args"],
 },
 }
@ -211,6 +207,20 @@ class ChatOllama(BaseChatModel):
Key init args completion params: Key init args completion params:
model: str model: str
Name of Ollama model to use. Name of Ollama model to use.
reasoning: Optional[bool]
Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.
- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. Note
however, if the model's default behavior *is* to perform reasoning, think tags
(``<think>`` and ``</think>``) will be present within the main response content
unless you set ``reasoning`` to ``True``.
temperature: float temperature: float
Sampling temperature. Ranges from 0.0 to 1.0. Sampling temperature. Ranges from 0.0 to 1.0.
num_predict: Optional[int] num_predict: Optional[int]
@ -347,21 +357,29 @@ class ChatOllama(BaseChatModel):
'args': {'a': 45, 'b': 67}, 'args': {'a': 45, 'b': 67},
'id': '420c3f3b-df10-4188-945f-eb3abdb40622', 'id': '420c3f3b-df10-4188-945f-eb3abdb40622',
'type': 'tool_call'}] 'type': 'tool_call'}]
""" # noqa: E501 """ # noqa: E501, pylint: disable=line-too-long
model: str model: str
"""Model name to use.""" """Model name to use."""
reasoning: Optional[bool] = None
"""Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.
- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. Note
however, if the model's default behavior *is* to perform reasoning, think tags
(``<think>`` and ``</think>``) will be present within the main response content
unless you set ``reasoning`` to ``True``."""
validate_model_on_init: bool = False validate_model_on_init: bool = False
"""Whether to validate the model exists in Ollama locally on initialization.""" """Whether to validate the model exists in Ollama locally on initialization."""
extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
"""Whether to extract the reasoning tokens in think blocks.
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
If a tuple is supplied, they are assumed to be the (start, end) tokens.
If `extract_reasoning=True`, the tokens will default to (<think>, </think>).
"""
mirostat: Optional[int] = None mirostat: Optional[int] = None
"""Enable Mirostat sampling for controlling perplexity. """Enable Mirostat sampling for controlling perplexity.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)""" (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
@ -448,24 +466,23 @@ class ChatOllama(BaseChatModel):
""" """
async_client_kwargs: Optional[dict] = {} async_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with client_kwargs before passing to the HTTPX """Additional kwargs to merge with client_kwargs before
AsyncClient. passing to the httpx AsyncClient.
`Full list of params. <https://www.python-httpx.org/api/#asyncclient>`__
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#asyncclient>`__.
""" """
sync_client_kwargs: Optional[dict] = {} sync_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with client_kwargs before passing to the HTTPX Client. """Additional kwargs to merge with client_kwargs before
passing to the httpx Client.
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__. `Full list of params. <https://www.python-httpx.org/api/#client>`__
""" """
_client: Client = PrivateAttr(default=None) # type: ignore _client: Client = PrivateAttr()
""" """
The client to use for making requests. The client to use for making requests.
""" """
_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore _async_client: AsyncClient = PrivateAttr()
""" """
The async client to use for making requests. The async client to use for making requests.
""" """
@ -480,7 +497,7 @@ class ChatOllama(BaseChatModel):
 if self.stop is not None and stop is not None:
 raise ValueError("`stop` found in both the input and default params.")
-elif self.stop is not None:
+if self.stop is not None:
 stop = self.stop
options_dict = kwargs.pop( options_dict = kwargs.pop(
@ -508,6 +525,7 @@ class ChatOllama(BaseChatModel):
"messages": ollama_messages, "messages": ollama_messages,
"stream": kwargs.pop("stream", True), "stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model), "model": kwargs.pop("model", self.model),
"think": kwargs.pop("reasoning", self.reasoning),
"format": kwargs.pop("format", self.format), "format": kwargs.pop("format", self.format),
"options": Options(**options_dict), "options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive), "keep_alive": kwargs.pop("keep_alive", self.keep_alive),
@ -618,35 +636,13 @@ class ChatOllama(BaseChatModel):
"images": images, "images": images,
} }
if tool_calls: if tool_calls:
msg["tool_calls"] = tool_calls # type: ignore msg["tool_calls"] = tool_calls
if tool_call_id: if tool_call_id:
msg["tool_call_id"] = tool_call_id msg["tool_call_id"] = tool_call_id
ollama_messages.append(msg) ollama_messages.append(msg)
return ollama_messages return ollama_messages
def _extract_reasoning(
self, message_chunk: BaseMessageChunk, is_thinking: bool
) -> tuple[BaseMessageChunk, bool]:
"""Mutate a message chunk to extract reasoning content."""
if not self.extract_reasoning:
return message_chunk, is_thinking
elif self.extract_reasoning is True:
start_token = DEFAULT_THINK_TOKEN_START
end_token = DEFAULT_THINK_TOKEN_END
else:
start_token, end_token = cast(tuple, self.extract_reasoning)
if start_token in message_chunk.content:
is_thinking = True
content = message_chunk.content
if is_thinking:
message_chunk.additional_kwargs["reasoning_content"] = content
message_chunk.content = ""
if end_token in content:
is_thinking = False
return message_chunk, is_thinking
async def _acreate_chat_stream( async def _acreate_chat_stream(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -670,8 +666,10 @@ class ChatOllama(BaseChatModel):
 chat_params = self._chat_params(messages, stop, **kwargs)
 if chat_params["stream"]:
+if self._client:
 yield from self._client.chat(**chat_params)
 else:
+if self._client:
 yield self._client.chat(**chat_params)
def _chat_stream_with_aggregation( def _chat_stream_with_aggregation(
@ -767,22 +765,34 @@ class ChatOllama(BaseChatModel):
 stop: Optional[list[str]] = None,
 **kwargs: Any,
 ) -> Iterator[ChatGenerationChunk]:
-is_thinking = False
 for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
 if not isinstance(stream_resp, str):
 if stream_resp.get("done") is True:
 generation_info = dict(stream_resp)
+if "model" in generation_info:
+generation_info["model_name"] = generation_info["model"]
 _ = generation_info.pop("message", None)
 else:
 generation_info = None
+content = (
+stream_resp["message"]["content"]
+if "message" in stream_resp and "content" in stream_resp["message"]
+else ""
+)
+additional_kwargs = {}
+if (
+self.reasoning
+and "message" in stream_resp
+and (thinking_content := stream_resp["message"].get("thinking"))
+):
+additional_kwargs["reasoning_content"] = thinking_content
 chunk = ChatGenerationChunk(
 message=AIMessageChunk(
-content=(
-stream_resp["message"]["content"]
-if "message" in stream_resp
-and "content" in stream_resp["message"]
-else ""
-),
+content=content,
+additional_kwargs=additional_kwargs,
 usage_metadata=_get_usage_metadata_from_generation_info(
 stream_resp
 ),
@ -790,15 +800,7 @@ class ChatOllama(BaseChatModel):
 ),
 generation_info=generation_info,
 )
-if chunk.generation_info and (
-model := chunk.generation_info.get("model")
-):
-chunk.generation_info["model_name"] = model # backwards compat
-if self.extract_reasoning:
-message, is_thinking = self._extract_reasoning(
-chunk.message, is_thinking
-)
-chunk.message = message
 yield chunk
 def _stream(
@ -822,22 +824,34 @@ class ChatOllama(BaseChatModel):
 stop: Optional[list[str]] = None,
 **kwargs: Any,
 ) -> AsyncIterator[ChatGenerationChunk]:
-is_thinking = False
 async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
 if not isinstance(stream_resp, str):
 if stream_resp.get("done") is True:
 generation_info = dict(stream_resp)
+if "model" in generation_info:
+generation_info["model_name"] = generation_info["model"]
 _ = generation_info.pop("message", None)
 else:
 generation_info = None
+content = (
+stream_resp["message"]["content"]
+if "message" in stream_resp and "content" in stream_resp["message"]
+else ""
+)
+additional_kwargs = {}
+if (
+self.reasoning
+and "message" in stream_resp
+and (thinking_content := stream_resp["message"].get("thinking"))
+):
+additional_kwargs["reasoning_content"] = thinking_content
 chunk = ChatGenerationChunk(
 message=AIMessageChunk(
-content=(
-stream_resp["message"]["content"]
-if "message" in stream_resp
-and "content" in stream_resp["message"]
-else ""
-),
+content=content,
+additional_kwargs=additional_kwargs,
 usage_metadata=_get_usage_metadata_from_generation_info(
 stream_resp
 ),
@ -845,15 +859,7 @@ class ChatOllama(BaseChatModel):
 ),
 generation_info=generation_info,
 )
-if chunk.generation_info and (
-model := chunk.generation_info.get("model")
-):
-chunk.generation_info["model_name"] = model # backwards compat
-if self.extract_reasoning:
-message, is_thinking = self._extract_reasoning(
-chunk.message, is_thinking
-)
-chunk.message = message
 yield chunk
 async def _astream(
@ -950,7 +956,7 @@ class ChatOllama(BaseChatModel):
 method: The method for steering model generation, one of:
 - "json_schema":
-Uses Ollama's structured output API: https://ollama.com/blog/structured-outputs
+Uses Ollama's `structured output API <https://ollama.com/blog/structured-outputs>`__
 - "function_calling":
 Uses Ollama's tool-calling API
 - "json_mode":
@ -1267,5 +1273,4 @@ class ChatOllama(BaseChatModel):
 [parser_none], exception_key="parsing_error"
 )
 return RunnableMap(raw=llm) | parser_with_fallback
-else:
 return llm | output_parser


@ -151,12 +151,12 @@ class OllamaEmbeddings(BaseModel, Embeddings):
 For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
 """
-_client: Client = PrivateAttr(default=None) # type: ignore
+_client: Optional[Client] = PrivateAttr(default=None)
 """
 The client to use for making requests.
 """
-_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore
+_async_client: Optional[AsyncClient] = PrivateAttr(default=None)
 """
 The async client to use for making requests.
 """
@ -270,6 +270,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
def embed_documents(self, texts: list[str]) -> list[list[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs.""" """Embed search docs."""
if not self._client:
raise ValueError(
"Ollama client is not initialized. "
"Please ensure Ollama is running and the model is loaded."
)
embedded_docs = self._client.embed( embedded_docs = self._client.embed(
self.model, texts, options=self._default_params, keep_alive=self.keep_alive self.model, texts, options=self._default_params, keep_alive=self.keep_alive
)["embeddings"] )["embeddings"]
@ -281,6 +286,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
async def aembed_documents(self, texts: list[str]) -> list[list[float]]: async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs.""" """Embed search docs."""
if not self._async_client:
raise ValueError(
"Ollama client is not initialized. "
"Please ensure Ollama is running and the model is loaded."
)
embedded_docs = ( embedded_docs = (
await self._async_client.embed( await self._async_client.embed(
self.model, texts, keep_alive=self.keep_alive self.model, texts, keep_alive=self.keep_alive


@ -36,6 +36,20 @@ class OllamaLLM(BaseLLM):
model: str model: str
"""Model name to use.""" """Model name to use."""
reasoning: Optional[bool] = True
"""Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.
- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. If
the model performs reasoning, the ``<think>`` and ``</think>`` tags will
be present directly within the main response content."""
validate_model_on_init: bool = False validate_model_on_init: bool = False
"""Whether to validate the model exists in ollama locally on initialization.""" """Whether to validate the model exists in ollama locally on initialization."""
@ -56,7 +70,7 @@ class OllamaLLM(BaseLLM):
num_ctx: Optional[int] = None num_ctx: Optional[int] = None
"""Sets the size of the context window used to generate the """Sets the size of the context window used to generate the
next token. (Default: 2048) """ next token. (Default: 2048)"""
num_gpu: Optional[int] = None num_gpu: Optional[int] = None
"""The number of GPUs to use. On macOS it defaults to 1 to """The number of GPUs to use. On macOS it defaults to 1 to
@ -137,12 +151,12 @@ class OllamaLLM(BaseLLM):
 For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
 """
-_client: Client = PrivateAttr(default=None) # type: ignore
+_client: Optional[Client] = PrivateAttr(default=None)
 """
 The client to use for making requests.
 """
-_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore
+_async_client: Optional[AsyncClient] = PrivateAttr(default=None)
 """
 The async client to use for making requests.
 """
@ -155,7 +169,7 @@ class OllamaLLM(BaseLLM):
 ) -> dict[str, Any]:
 if self.stop is not None and stop is not None:
 raise ValueError("`stop` found in both the input and default params.")
-elif self.stop is not None:
+if self.stop is not None:
 stop = self.stop
options_dict = kwargs.pop( options_dict = kwargs.pop(
@ -183,6 +197,7 @@ class OllamaLLM(BaseLLM):
"prompt": prompt, "prompt": prompt,
"stream": kwargs.pop("stream", True), "stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model), "model": kwargs.pop("model", self.model),
"think": kwargs.pop("reasoning", self.reasoning),
"format": kwargs.pop("format", self.format), "format": kwargs.pop("format", self.format),
"options": Options(**options_dict), "options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive), "keep_alive": kwargs.pop("keep_alive", self.keep_alive),
@ -230,10 +245,11 @@ class OllamaLLM(BaseLLM):
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]: ) -> AsyncIterator[Union[Mapping[str, Any], str]]:
+if self._async_client:
 async for part in await self._async_client.generate(
 **self._generate_params(prompt, stop=stop, **kwargs)
-): # type: ignore
-yield part # type: ignore
+):
+yield part
def _create_generate_stream( def _create_generate_stream(
self, self,
@ -241,9 +257,10 @@ class OllamaLLM(BaseLLM):
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]: ) -> Iterator[Union[Mapping[str, Any], str]]:
+if self._client:
 yield from self._client.generate(
 **self._generate_params(prompt, stop=stop, **kwargs)
-) # type: ignore
+)
async def _astream_with_aggregation( async def _astream_with_aggregation(
self, self,
@ -356,11 +373,19 @@ class OllamaLLM(BaseLLM):
 ) -> Iterator[GenerationChunk]:
 for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
 if not isinstance(stream_resp, str):
+additional_kwargs = {}
+if thinking_content := stream_resp.get("thinking"):
+additional_kwargs["reasoning_content"] = thinking_content
 chunk = GenerationChunk(
 text=(stream_resp.get("response", "")),
-generation_info=(
-dict(stream_resp) if stream_resp.get("done") is True else None
-),
+generation_info={
+"finish_reason": self.stop,
+**additional_kwargs,
+**(
+dict(stream_resp) if stream_resp.get("done") is True else {}
+),
+},
 )
if run_manager: if run_manager:
run_manager.on_llm_new_token( run_manager.on_llm_new_token(
@ -378,11 +403,19 @@ class OllamaLLM(BaseLLM):
) -> AsyncIterator[GenerationChunk]: ) -> AsyncIterator[GenerationChunk]:
 async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
 if not isinstance(stream_resp, str):
+additional_kwargs = {}
+if thinking_content := stream_resp.get("thinking"):
+additional_kwargs["reasoning_content"] = thinking_content
 chunk = GenerationChunk(
 text=(stream_resp.get("response", "")),
-generation_info=(
-dict(stream_resp) if stream_resp.get("done") is True else None
-),
+generation_info={
+"finish_reason": self.stop,
+**additional_kwargs,
+**(
+dict(stream_resp) if stream_resp.get("done") is True else {}
+),
+},
 )
if run_manager: if run_manager:
await run_manager.on_llm_new_token( await run_manager.on_llm_new_token(


@ -1,17 +1,28 @@
"""Ollama specific chat model integration tests for reasoning models.""" """Ollama specific chat model integration tests for reasoning models."""
 import pytest
-from langchain_core.messages import AIMessageChunk, BaseMessageChunk, HumanMessage
-from pydantic import ValidationError
+from langchain_core.messages import (
+AIMessageChunk,
+BaseMessageChunk,
+HumanMessage,
+)
+from pydantic import BaseModel, Field
 from langchain_ollama import ChatOllama
 SAMPLE = "What is 3^3?"
class MathAnswer(BaseModel):
"""A mathematical expression and its numerical answer."""
expression: str = Field(description="The mathematical expression to evaluate.")
answer: int = Field(description="The numerical answer to the expression.")
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
-"""Test deepseek model without parsing."""
+def test_stream_no_reasoning(model: str) -> None:
+"""Test streaming with `reasoning=False`"""
 llm = ChatOllama(model=model, num_ctx=2**12)
messages = [ messages = [
{ {
@ -28,14 +39,41 @@ def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
 result += chunk
 assert isinstance(result, AIMessageChunk)
 assert result.content
-assert "<think>" in result.content and "</think>" in result.content
 assert "reasoning_content" not in result.additional_kwargs
+assert "<think>" not in result.content and "</think>" not in result.content
+assert "<think>" not in result.additional_kwargs["reasoning_content"]
+assert "</think>" not in result.additional_kwargs["reasoning_content"]
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_deepseek_messages_stream_bool(model: str) -> None:
-"""Test deepseek model with reasoning bool=True"""
-llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
+async def test_astream_no_reasoning(model: str) -> None:
+"""Test async streaming with `reasoning=False`"""
+llm = ChatOllama(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_reasoning_none(model: str) -> None:
"""Test streaming with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
messages = [ messages = [
{ {
"role": "user", "role": "user",
@ -51,26 +89,41 @@ def test_deepseek_messages_stream_bool(model: str) -> None:
 result += chunk
 assert isinstance(result, AIMessageChunk)
 assert result.content
-assert "<think>" not in result.content and "</think>" not in result.content
-assert "reasoning_content" in result.additional_kwargs
-assert len(result.additional_kwargs["reasoning_content"]) > 0
-assert "<think>" in result.additional_kwargs["reasoning_content"]
-assert "</think>" in result.additional_kwargs["reasoning_content"]
-clean_content = (
-result.additional_kwargs["reasoning_content"]
-.replace("<think>", "")
-.replace("</think>", "")
-.strip()
-)
-assert len(clean_content) > 0
+assert "reasoning_content" not in result.additional_kwargs
+assert "<think>" in result.content and "</think>" in result.content
+assert "<think>" not in result.additional_kwargs["reasoning_content"]
+assert "</think>" not in result.additional_kwargs["reasoning_content"]
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_deepseek_messages_stream_tuple(model: str) -> None:
-"""Test deepseek model with reasoning with tuple=..."""
-llm = ChatOllama(
-model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
-)
+async def test_astream_reasoning_none(model: str) -> None:
+"""Test async streaming with `reasoning=None`"""
+llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
+messages = [
+{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content and "</think>" in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_stream(model: str) -> None:
"""Test streaming with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
messages = [ messages = [
{ {
"role": "user", "role": "user",
@ -86,77 +139,114 @@ def test_deepseek_messages_stream_tuple(model: str) -> None:
 result += chunk
 assert isinstance(result, AIMessageChunk)
 assert result.content
-assert "<think>" not in result.content and "</think>" not in result.content
 assert "reasoning_content" in result.additional_kwargs
 assert len(result.additional_kwargs["reasoning_content"]) > 0
-assert "<think>" in result.additional_kwargs["reasoning_content"]
-assert "</think>" in result.additional_kwargs["reasoning_content"]
-clean_content = (
-result.additional_kwargs["reasoning_content"]
-.replace("<think>", "")
-.replace("</think>", "")
-.strip()
-)
-assert len(clean_content) > 0
+assert "<think>" not in result.content and "</think>" not in result.content
+assert "<think>" not in result.additional_kwargs["reasoning_content"]
+assert "</think>" not in result.additional_kwargs["reasoning_content"]
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_deepseek_messages_invoke_no_reasoning(model: str) -> None:
-"""Test deepseek model without parsing using invoke."""
+async def test_reasoning_astream(model: str) -> None:
+"""Test async streaming with `reasoning=True`"""
+llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_no_reasoning(model: str) -> None:
"""Test using invoke with `reasoning=False`"""
 llm = ChatOllama(model=model, num_ctx=2**12)
 message = HumanMessage(content=SAMPLE)
 result = llm.invoke([message])
 assert result.content
-assert "<think>" in result.content and "</think>" in result.content
 assert "reasoning_content" not in result.additional_kwargs
+assert "<think>" not in result.content and "</think>" not in result.content
+assert "<think>" not in result.additional_kwargs["reasoning_content"]
+assert "</think>" not in result.additional_kwargs["reasoning_content"]
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_deepseek_messages_invoke_bool(model: str) -> None:
-"""Test deepseek model with reasoning bool=True using invoke"""
-llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
+async def test_ainvoke_no_reasoning(model: str) -> None:
+"""Test using async invoke with `reasoning=False`"""
+llm = ChatOllama(model=model, num_ctx=2**12)
message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_reasoning_none(model: str) -> None:
"""Test using invoke with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
 message = HumanMessage(content=SAMPLE)
 result = llm.invoke([message])
 assert result.content
-assert "<think>" not in result.content and "</think>" not in result.content
-assert "reasoning_content" in result.additional_kwargs
-assert len(result.additional_kwargs["reasoning_content"]) > 0
-assert "<think>" in result.additional_kwargs["reasoning_content"]
-assert "</think>" in result.additional_kwargs["reasoning_content"]
-clean_content = (
-result.additional_kwargs["reasoning_content"]
-.replace("<think>", "")
-.replace("</think>", "")
-.strip()
-)
-assert len(clean_content) > 0
+assert "reasoning_content" not in result.additional_kwargs
+assert "<think>" in result.content and "</think>" in result.content
+assert "<think>" not in result.additional_kwargs["reasoning_content"]
+assert "</think>" not in result.additional_kwargs["reasoning_content"]
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_deepseek_messages_invoke_tuple(model: str) -> None:
-"""Test deepseek model with reasoning with tuple=... using invoke"""
-llm = ChatOllama(
-model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
-)
+async def test_ainvoke_reasoning_none(model: str) -> None:
+"""Test using async invoke with `reasoning=None`"""
+llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
+message = HumanMessage(content=SAMPLE)
+result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content and "</think>" in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_invoke(model: str) -> None:
"""Test invoke with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
 message = HumanMessage(content=SAMPLE)
 result = llm.invoke([message])
 assert result.content
-assert "<think>" not in result.content and "</think>" not in result.content
 assert "reasoning_content" in result.additional_kwargs
 assert len(result.additional_kwargs["reasoning_content"]) > 0
-assert "<think>" in result.additional_kwargs["reasoning_content"]
-assert "</think>" in result.additional_kwargs["reasoning_content"]
-clean_content = (
-result.additional_kwargs["reasoning_content"]
-.replace("<think>", "")
-.replace("</think>", "")
-.strip()
-)
-assert len(clean_content) > 0
+assert "<think>" not in result.content and "</think>" not in result.content
+assert "<think>" not in result.additional_kwargs["reasoning_content"]
+assert "</think>" not in result.additional_kwargs["reasoning_content"]
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_deepseek_invalid(model: str) -> None:
-"""Test deepseek model with reasoning raises ValidationError"""
-with pytest.raises(ValidationError):
-_ = ChatOllama(model=model, extract_reasoning={"invalid": "data"}) # type: ignore[arg-type]
+async def test_reasoning_ainvoke(model: str) -> None:
+"""Test invoke with `reasoning=True`"""
+llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
+message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
 import pytest
 from httpx import ConnectError
 from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessageChunk, HumanMessage, ToolCallChunk
+from langchain_core.tools import tool
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 from ollama import ResponseError
 from pydantic import ValidationError
@ -14,6 +16,15 @@ from langchain_ollama.chat_models import ChatOllama
DEFAULT_MODEL_NAME = "llama3.1" DEFAULT_MODEL_NAME = "llama3.1"
@tool
def get_current_weather(location: str) -> dict:
"""Gets the current weather in a given location."""
if "boston" in location.lower():
return {"temperature": "15°F", "conditions": "snow"}
else:
return {"temperature": "unknown", "conditions": "unknown"}
class TestChatOllama(ChatModelIntegrationTests): class TestChatOllama(ChatModelIntegrationTests):
@property @property
def chat_model_class(self) -> type[ChatOllama]: def chat_model_class(self) -> type[ChatOllama]:
@ -29,12 +40,104 @@ class TestChatOllama(ChatModelIntegrationTests):
 @property
 def has_tool_choice(self) -> bool:
-return False # TODO: update after Ollama implements
+# TODO: update after Ollama implements
+# https://github.com/ollama/ollama/blob/main/docs/openai.md
+return False
@property @property
def supports_image_inputs(self) -> bool: def supports_image_inputs(self) -> bool:
return True return True
def test_tool_streaming(self, model: BaseChatModel) -> None:
"""Test that the model can stream tool calls."""
chat_model_with_tools = model.bind_tools([get_current_weather])
prompt = [HumanMessage("What is the weather today in Boston?")]
# Flags and collectors for validation
tool_chunk_found = False
final_tool_calls = []
collected_tool_chunks: list[ToolCallChunk] = []
# Stream the response and inspect the chunks
for chunk in chat_model_with_tools.stream(prompt):
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
if chunk.tool_call_chunks:
tool_chunk_found = True
for tc_chunk in chunk.tool_call_chunks:
collected_tool_chunks.append(tc_chunk)
if chunk.tool_calls:
final_tool_calls.extend(chunk.tool_calls)
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
assert (
len(final_tool_calls) == 1
), f"Expected 1 final tool call, but got {len(final_tool_calls)}"
final_tool_call = final_tool_calls[0]
assert final_tool_call["name"] == "get_current_weather"
assert final_tool_call["args"] == {"location": "Boston"}
assert len(collected_tool_chunks) > 0
assert collected_tool_chunks[0]["name"] == "get_current_weather"
# The ID should be consistent across chunks that have it
tool_call_id = collected_tool_chunks[0].get("id")
assert tool_call_id is not None
assert all(
chunk.get("id") == tool_call_id
for chunk in collected_tool_chunks
if chunk.get("id")
)
assert final_tool_call["id"] == tool_call_id
async def test_tool_astreaming(self, model: BaseChatModel) -> None:
"""Test that the model can stream tool calls."""
chat_model_with_tools = model.bind_tools([get_current_weather])
prompt = [HumanMessage("What is the weather today in Boston?")]
# Flags and collectors for validation
tool_chunk_found = False
final_tool_calls = []
collected_tool_chunks: list[ToolCallChunk] = []
# Stream the response and inspect the chunks
async for chunk in chat_model_with_tools.astream(prompt):
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
if chunk.tool_call_chunks:
tool_chunk_found = True
for tc_chunk in chunk.tool_call_chunks:
collected_tool_chunks.append(tc_chunk)
if chunk.tool_calls:
final_tool_calls.extend(chunk.tool_calls)
assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
assert (
len(final_tool_calls) == 1
), f"Expected 1 final tool call, but got {len(final_tool_calls)}"
final_tool_call = final_tool_calls[0]
assert final_tool_call["name"] == "get_current_weather"
assert final_tool_call["args"] == {"location": "Boston"}
assert len(collected_tool_chunks) > 0
assert collected_tool_chunks[0]["name"] == "get_current_weather"
# The ID should be consistent across chunks that have it
tool_call_id = collected_tool_chunks[0].get("id")
assert tool_call_id is not None
assert all(
chunk.get("id") == tool_call_id
for chunk in collected_tool_chunks
if chunk.get("id")
)
assert final_tool_call["id"] == tool_call_id
@pytest.mark.xfail( @pytest.mark.xfail(
reason=( reason=(
"Will sometime encounter AssertionErrors where tool responses are " "Will sometime encounter AssertionErrors where tool responses are "


@ -1,11 +1,15 @@
"""Test OllamaLLM llm.""" """Test OllamaLLM llm."""
import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langchain_ollama.llms import OllamaLLM from langchain_ollama.llms import OllamaLLM
MODEL_NAME = "llama3.1" MODEL_NAME = "llama3.1"
SAMPLE = "What is 3^3?"
def test_stream() -> None: def test_stream() -> None:
"""Test streaming tokens from OpenAI.""" """Test streaming tokens from OpenAI."""
@ -15,6 +19,59 @@ def test_stream() -> None:
assert isinstance(token, str) assert isinstance(token, str)
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_no_reasoning(model: str) -> None:
"""Test streaming with `reasoning=False`"""
llm = OllamaLLM(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_stream(model: str) -> None:
"""Test streaming with `reasoning=True`"""
llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
async def test_astream() -> None: async def test_astream() -> None:
"""Test streaming tokens from OpenAI.""" """Test streaming tokens from OpenAI."""
llm = OllamaLLM(model=MODEL_NAME) llm = OllamaLLM(model=MODEL_NAME)
@ -23,6 +80,59 @@ async def test_astream() -> None:
assert isinstance(token, str) assert isinstance(token, str)
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_astream_no_reasoning(model: str) -> None:
"""Test async streaming with `reasoning=False`"""
llm = OllamaLLM(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_reasoning_astream(model: str) -> None:
"""Test async streaming with `reasoning=True`"""
llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
async def test_abatch() -> None: async def test_abatch() -> None:
"""Test streaming tokens from OllamaLLM.""" """Test streaming tokens from OllamaLLM."""
llm = OllamaLLM(model=MODEL_NAME) llm = OllamaLLM(model=MODEL_NAME)
@ -60,8 +170,68 @@ async def test_ainvoke() -> None:
assert isinstance(result, str) assert isinstance(result, str)
# TODO
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# async def test_ainvoke_no_reasoning(model: str) -> None:
# """Test using async invoke with `reasoning=False`"""
# llm = OllamaLLM(model=model, num_ctx=2**12)
# message = SAMPLE
# result = await llm.ainvoke(message)
# assert result.content
# assert "reasoning_content" not in result.additional_kwargs
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# async def test_reasoning_ainvoke(model: str) -> None:
# """Test invoke with `reasoning=True`"""
# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
# message = SAMPLE
# result = await llm.ainvoke(message)
# assert result.content
# assert "reasoning_content" in result.additional_kwargs
# assert len(result.additional_kwargs["reasoning_content"]) > 0
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# assert "<think>" not in result.additional_kwargs["reasoning_content"]
# assert "</think>" not in result.additional_kwargs["reasoning_content"]
def test_invoke() -> None: def test_invoke() -> None:
"""Test invoke tokens from OllamaLLM.""" """Test invoke tokens from OllamaLLM."""
llm = OllamaLLM(model=MODEL_NAME) llm = OllamaLLM(model=MODEL_NAME)
result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"])) result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
assert isinstance(result, str) assert isinstance(result, str)
# TODO
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# def test_invoke_no_reasoning(model: str) -> None:
# """Test using invoke with `reasoning=False`"""
# llm = OllamaLLM(model=model, num_ctx=2**12)
# message = SAMPLE
# result = llm.invoke(message)
# assert result.content
# assert "reasoning_content" not in result.additional_kwargs
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# def test_reasoning_invoke(model: str) -> None:
# """Test invoke with `reasoning=True`"""
# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
# message = SAMPLE
# result = llm.invoke(message)
# assert result.content
# assert "reasoning_content" in result.additional_kwargs
# assert len(result.additional_kwargs["reasoning_content"]) > 0
# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# assert "<think>" not in result.additional_kwargs["reasoning_content"]
# assert "</think>" not in result.additional_kwargs["reasoning_content"]


@ -23,7 +23,7 @@ class TestChatOllama(ChatModelUnitTests):
 @property
 def chat_model_params(self) -> dict:
-return {"model": "llama3-groq-tool-use"}
+return {"model": MODEL_NAME}
def test__parse_arguments_from_tool_call() -> None: def test__parse_arguments_from_tool_call() -> None:
@ -51,7 +51,6 @@ def test_arbitrary_roles_accepted_in_chatmessages(
 monkeypatch.setattr(Client, "stream", _mock_httpx_client_stream)
 llm = ChatOllama(
-base_url="http://whocares:11434",
 model=MODEL_NAME,
 verbose=True,
 format=None,


@ -10,7 +10,7 @@ MODEL_NAME = "llama3.1"
 def test_initialization() -> None:
 """Test embedding model initialization."""
-OllamaEmbeddings(model="llama3", keep_alive=1)
+OllamaEmbeddings(model=MODEL_NAME, keep_alive=1)
 @patch("langchain_ollama.embeddings.validate_model")


@ -10,25 +10,25 @@ MODEL_NAME = "llama3.1"
 def test_initialization() -> None:
 """Test integration initialization."""
-OllamaLLM(model="llama3")
+OllamaLLM(model=MODEL_NAME)
 def test_model_params() -> None:
 # Test standard tracing params
-llm = OllamaLLM(model="llama3")
+llm = OllamaLLM(model=MODEL_NAME)
 ls_params = llm._get_ls_params()
 assert ls_params == {
 "ls_provider": "ollama",
 "ls_model_type": "llm",
-"ls_model_name": "llama3",
+"ls_model_name": MODEL_NAME,
 }
-llm = OllamaLLM(model="llama3", num_predict=3)
+llm = OllamaLLM(model=MODEL_NAME, num_predict=3)
 ls_params = llm._get_ls_params()
 assert ls_params == {
 "ls_provider": "ollama",
 "ls_model_type": "llm",
-"ls_model_name": "llama3",
+"ls_model_name": MODEL_NAME,
 "ls_max_tokens": 3,
 }