From e686a70ee05002a739c24ea8a3537ba10e682b35 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Mon, 7 Jul 2025 13:56:41 -0400 Subject: [PATCH] ollama: thinking, tool streaming, docs, tests (#31772) * New `reasoning` (bool) param to support toggling [Ollama thinking](https://ollama.com/blog/thinking) (#31573, #31700). If `reasoning=True`, Ollama's `thinking` content will be placed in the model responses' `additional_kwargs.reasoning_content`. * Supported by: * ChatOllama (class level, invocation level TODO) * OllamaLLM (TODO) * Added tests to ensure streaming tool calls is successful (#29129) * Refactored tests that relied on `extract_reasoning()` * Myriad docs additions and consistency/typo fixes * Improved type safety in some spots Closes #29129 Addresses #31573 and #31700 Supersedes #31701 --- .../docs/how_to/callbacks_custom_events.ipynb | 6 +- docs/docs/how_to/tool_stream_events.ipynb | 8 +- docs/docs/how_to/tool_streaming.ipynb | 7 + .../langchain_core/language_models/llms.py | 5 +- libs/core/langchain_core/runnables/base.py | 29 ++- .../ollama/langchain_ollama/chat_models.py | 179 +++++++------- .../ollama/langchain_ollama/embeddings.py | 18 +- libs/partners/ollama/langchain_ollama/llms.py | 67 +++-- .../chat_models/test_chat_models_reasoning.py | 234 ++++++++++++------ .../chat_models/test_chat_models_standard.py | 105 +++++++- .../tests/integration_tests/test_llms.py | 170 +++++++++++++ .../tests/unit_tests/test_chat_models.py | 3 +- .../tests/unit_tests/test_embeddings.py | 2 +- .../ollama/tests/unit_tests/test_llms.py | 10 +- 14 files changed, 630 insertions(+), 213 deletions(-) diff --git a/docs/docs/how_to/callbacks_custom_events.ipynb b/docs/docs/how_to/callbacks_custom_events.ipynb index 8fd6b0d0662..36be81b063d 100644 --- a/docs/docs/how_to/callbacks_custom_events.ipynb +++ b/docs/docs/how_to/callbacks_custom_events.ipynb @@ -38,11 +38,11 @@ "\n", "\n", ":::caution COMPATIBILITY\n", - "LangChain cannot automatically propagate configuration, including callbacks necessary for astream_events(), to child runnables if you are running async code in python<=3.10. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n", + "LangChain cannot automatically propagate configuration, including callbacks necessary for astream_events(), to child runnables if you are running async code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n", "\n", - "If you are running python<=3.10, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n", + "If you are running `python<=3.10`, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n", "\n", - "If you are running python>=3.11, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in other Python versions.\n", + "If you are running `python>=3.11`, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in other Python versions.\n", ":::" ] }, diff --git a/docs/docs/how_to/tool_stream_events.ipynb b/docs/docs/how_to/tool_stream_events.ipynb index 2f48cf40e4a..55b3b73cbf4 100644 --- a/docs/docs/how_to/tool_stream_events.ipynb +++ b/docs/docs/how_to/tool_stream_events.ipynb @@ -16,15 +16,15 @@ "\n", ":::\n", "\n", - "If you have [tools](/docs/concepts/tools/) that call [chat models](/docs/concepts/chat_models/), [retrievers](/docs/concepts/retrievers/), or other [runnables](/docs/concepts/runnables/), you may want to access internal events from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n", + "If you have [tools](/docs/concepts/tools/) that call [chat models](/docs/concepts/chat_models/), [retrievers](/docs/concepts/retrievers/), or other [runnables](/docs/concepts/runnables/), you may want to access [internal events](https://python.langchain.com/docs/how_to/streaming/#event-reference) from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n", "\n", ":::caution Compatibility\n", "\n", - "LangChain cannot automatically propagate configuration, including callbacks necessary for `astream_events()`, to child runnables if you are running `async` code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n", + "LangChain cannot automatically propagate configuration, including callbacks necessary for `astream_events()`, to child runnables if you are running `async` code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n", "\n", - "If you are running python<=3.10, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n", + "If you are running `python<=3.10`, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n", "\n", - "If you are running python>=3.11, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in older Python versions.\n", + "If you are running `python>=3.11`, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in older Python versions.\n", "\n", "This guide also requires `langchain-core>=0.2.16`.\n", ":::\n", diff --git a/docs/docs/how_to/tool_streaming.ipynb b/docs/docs/how_to/tool_streaming.ipynb index fc85d3c31fa..aae04df82c1 100644 --- a/docs/docs/how_to/tool_streaming.ipynb +++ b/docs/docs/how_to/tool_streaming.ipynb @@ -224,6 +224,13 @@ "source": [ "print(type(gathered.tool_calls[0][\"args\"]))" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note the key difference: accumulating `tool_call_chunks` captures the raw tool arguments as an unparsed string as they are streamed. In contrast, **accumulating** `tool_calls` demonstrates partial parsing by progressively converting the streamed argument string into a valid, usable dictionary at each step of the process." + ] } ], "metadata": { diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index ff5a14dcd15..c106e0698e0 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -1414,9 +1414,10 @@ class BaseLLM(BaseLanguageModel[str], ABC): ValueError: If the file path is not a string or Path object. Example: - .. code-block:: python - llm.save(file_path="path/llm.yaml") + .. code-block:: python + + llm.save(file_path="path/llm.yaml") """ # Convert file to Path object. save_path = Path(file_path) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 5b093a205b4..56ac81226f9 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -1160,22 +1160,21 @@ class Runnable(ABC, Generic[Input, Output]): A StreamEvent is a dictionary with the following schema: - - ``event``: **str** - Event names are of the - format: on_[runnable_type]_(start|stream|end). + - ``event``: **str** - Event names are of the format: + on_[runnable_type]_(start|stream|end). - ``name``: **str** - The name of the Runnable that generated the event. - - ``run_id``: **str** - randomly generated ID associated with the given execution of - the Runnable that emitted the event. - A child Runnable that gets invoked as part of the execution of a - parent Runnable is assigned its own unique ID. - - ``parent_ids``: **list[str]** - The IDs of the parent runnables that - generated the event. The root Runnable will have an empty list. - The order of the parent IDs is from the root to the immediate parent. - Only available for v2 version of the API. The v1 version of the API - will return an empty list. + - ``run_id``: **str** - randomly generated ID associated with the given + execution of the Runnable that emitted the event. A child Runnable that gets + invoked as part of the execution of a parent Runnable is assigned its own + unique ID. + - ``parent_ids``: **list[str]** - The IDs of the parent runnables that generated + the event. The root Runnable will have an empty list. The order of the parent + IDs is from the root to the immediate parent. Only available for v2 version of + the API. The v1 version of the API will return an empty list. - ``tags``: **Optional[list[str]]** - The tags of the Runnable that generated - the event. - - ``metadata``: **Optional[dict[str, Any]]** - The metadata of the Runnable - that generated the event. + the event. + - ``metadata``: **Optional[dict[str, Any]]** - The metadata of the Runnable that + generated the event. - ``data``: **dict[str, Any]** @@ -1183,7 +1182,7 @@ class Runnable(ABC, Generic[Input, Output]): chains. Metadata fields have been omitted from the table for brevity. Chain definitions have been included after the table. - **ATTENTION** This reference table is for the V2 version of the schema. + .. NOTE:: This reference table is for the V2 version of the schema. +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | event | name | chunk | input | output | diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py index 3341cdde58e..47ec78df011 100644 --- a/libs/partners/ollama/langchain_ollama/chat_models.py +++ b/libs/partners/ollama/langchain_ollama/chat_models.py @@ -6,7 +6,6 @@ from operator import itemgetter from typing import ( Any, Callable, - Final, Literal, Optional, Union, @@ -25,7 +24,6 @@ from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, - BaseMessageChunk, ChatMessage, HumanMessage, SystemMessage, @@ -57,9 +55,6 @@ from typing_extensions import Self, is_typeddict from ._utils import validate_model -DEFAULT_THINK_TOKEN_START: Final[str] = "" -DEFAULT_THINK_TOKEN_END: Final[str] = "" - def _get_usage_metadata_from_generation_info( generation_info: Optional[Mapping[str, Any]], @@ -166,13 +161,14 @@ def _get_tool_calls_from_response( return tool_calls -def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict: +def _lc_tool_call_to_openai_tool_call(tool_call_: ToolCall) -> dict: + """Convert a LangChain tool call to an OpenAI tool call format.""" return { "type": "function", - "id": tool_call["id"], + "id": tool_call_["id"], "function": { - "name": tool_call["name"], - "arguments": tool_call["args"], + "name": tool_call_["name"], + "arguments": tool_call_["args"], }, } @@ -211,6 +207,20 @@ class ChatOllama(BaseChatModel): Key init args — completion params: model: str Name of Ollama model to use. + reasoning: Optional[bool] + Controls the reasoning/thinking mode for + `supported models `__. + + - ``True``: Enables reasoning mode. The model's reasoning process will be + captured and returned separately in the ``additional_kwargs`` of the + response message, under ``reasoning_content``. The main response + content will not include the reasoning tags. + - ``False``: Disables reasoning mode. The model will not perform any reasoning, + and the response will not include any reasoning content. + - ``None`` (Default): The model will use its default reasoning behavior. Note + however, if the model's default behavior *is* to perform reasoning, think tags + ()```` and ````) will be present within the main response content + unless you set ``reasoning`` to ``True``. temperature: float Sampling temperature. Ranges from 0.0 to 1.0. num_predict: Optional[int] @@ -347,21 +357,29 @@ class ChatOllama(BaseChatModel): 'args': {'a': 45, 'b': 67}, 'id': '420c3f3b-df10-4188-945f-eb3abdb40622', 'type': 'tool_call'}] - """ # noqa: E501 + """ # noqa: E501, pylint: disable=line-too-long model: str """Model name to use.""" + reasoning: Optional[bool] = None + """Controls the reasoning/thinking mode for + `supported models `__. + + - ``True``: Enables reasoning mode. The model's reasoning process will be + captured and returned separately in the ``additional_kwargs`` of the + response message, under ``reasoning_content``. The main response + content will not include the reasoning tags. + - ``False``: Disables reasoning mode. The model will not perform any reasoning, + and the response will not include any reasoning content. + - ``None`` (Default): The model will use its default reasoning behavior. Note + however, if the model's default behavior *is* to perform reasoning, think tags + ()```` and ````) will be present within the main response content + unless you set ``reasoning`` to ``True``.""" + validate_model_on_init: bool = False """Whether to validate the model exists in Ollama locally on initialization.""" - extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False - """Whether to extract the reasoning tokens in think blocks. - Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`. - If a tuple is supplied, they are assumed to be the (start, end) tokens. - If `extract_reasoning=True`, the tokens will default to (, ). - """ - mirostat: Optional[int] = None """Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)""" @@ -448,24 +466,23 @@ class ChatOllama(BaseChatModel): """ async_client_kwargs: Optional[dict] = {} - """Additional kwargs to merge with client_kwargs before passing to the HTTPX - AsyncClient. - - For a full list of the params, see the `HTTPX documentation `__. + """Additional kwargs to merge with client_kwargs before + passing to the httpx AsyncClient. + `Full list of params. `__ """ sync_client_kwargs: Optional[dict] = {} - """Additional kwargs to merge with client_kwargs before passing to the HTTPX Client. - - For a full list of the params, see the `HTTPX documentation `__. + """Additional kwargs to merge with client_kwargs before + passing to the httpx Client. + `Full list of params. `__ """ - _client: Client = PrivateAttr(default=None) # type: ignore + _client: Client = PrivateAttr() """ The client to use for making requests. """ - _async_client: AsyncClient = PrivateAttr(default=None) # type: ignore + _async_client: AsyncClient = PrivateAttr() """ The async client to use for making requests. """ @@ -480,7 +497,7 @@ class ChatOllama(BaseChatModel): if self.stop is not None and stop is not None: raise ValueError("`stop` found in both the input and default params.") - elif self.stop is not None: + if self.stop is not None: stop = self.stop options_dict = kwargs.pop( @@ -508,6 +525,7 @@ class ChatOllama(BaseChatModel): "messages": ollama_messages, "stream": kwargs.pop("stream", True), "model": kwargs.pop("model", self.model), + "think": kwargs.pop("reasoning", self.reasoning), "format": kwargs.pop("format", self.format), "options": Options(**options_dict), "keep_alive": kwargs.pop("keep_alive", self.keep_alive), @@ -618,35 +636,13 @@ class ChatOllama(BaseChatModel): "images": images, } if tool_calls: - msg["tool_calls"] = tool_calls # type: ignore + msg["tool_calls"] = tool_calls if tool_call_id: msg["tool_call_id"] = tool_call_id ollama_messages.append(msg) return ollama_messages - def _extract_reasoning( - self, message_chunk: BaseMessageChunk, is_thinking: bool - ) -> tuple[BaseMessageChunk, bool]: - """Mutate a message chunk to extract reasoning content.""" - if not self.extract_reasoning: - return message_chunk, is_thinking - elif self.extract_reasoning is True: - start_token = DEFAULT_THINK_TOKEN_START - end_token = DEFAULT_THINK_TOKEN_END - else: - start_token, end_token = cast(tuple, self.extract_reasoning) - if start_token in message_chunk.content: - is_thinking = True - content = message_chunk.content - if is_thinking: - message_chunk.additional_kwargs["reasoning_content"] = content - message_chunk.content = "" - if end_token in content: - is_thinking = False - - return message_chunk, is_thinking - async def _acreate_chat_stream( self, messages: list[BaseMessage], @@ -670,9 +666,11 @@ class ChatOllama(BaseChatModel): chat_params = self._chat_params(messages, stop, **kwargs) if chat_params["stream"]: - yield from self._client.chat(**chat_params) + if self._client: + yield from self._client.chat(**chat_params) else: - yield self._client.chat(**chat_params) + if self._client: + yield self._client.chat(**chat_params) def _chat_stream_with_aggregation( self, @@ -767,22 +765,34 @@ class ChatOllama(BaseChatModel): stop: Optional[list[str]] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - is_thinking = False for stream_resp in self._create_chat_stream(messages, stop, **kwargs): if not isinstance(stream_resp, str): if stream_resp.get("done") is True: generation_info = dict(stream_resp) + if "model" in generation_info: + generation_info["model_name"] = generation_info["model"] _ = generation_info.pop("message", None) else: generation_info = None + + content = ( + stream_resp["message"]["content"] + if "message" in stream_resp and "content" in stream_resp["message"] + else "" + ) + + additional_kwargs = {} + if ( + self.reasoning + and "message" in stream_resp + and (thinking_content := stream_resp["message"].get("thinking")) + ): + additional_kwargs["reasoning_content"] = thinking_content + chunk = ChatGenerationChunk( message=AIMessageChunk( - content=( - stream_resp["message"]["content"] - if "message" in stream_resp - and "content" in stream_resp["message"] - else "" - ), + content=content, + additional_kwargs=additional_kwargs, usage_metadata=_get_usage_metadata_from_generation_info( stream_resp ), @@ -790,15 +800,7 @@ class ChatOllama(BaseChatModel): ), generation_info=generation_info, ) - if chunk.generation_info and ( - model := chunk.generation_info.get("model") - ): - chunk.generation_info["model_name"] = model # backwards compat - if self.extract_reasoning: - message, is_thinking = self._extract_reasoning( - chunk.message, is_thinking - ) - chunk.message = message + yield chunk def _stream( @@ -822,22 +824,34 @@ class ChatOllama(BaseChatModel): stop: Optional[list[str]] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: - is_thinking = False async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs): if not isinstance(stream_resp, str): if stream_resp.get("done") is True: generation_info = dict(stream_resp) + if "model" in generation_info: + generation_info["model_name"] = generation_info["model"] _ = generation_info.pop("message", None) else: generation_info = None + + content = ( + stream_resp["message"]["content"] + if "message" in stream_resp and "content" in stream_resp["message"] + else "" + ) + + additional_kwargs = {} + if ( + self.reasoning + and "message" in stream_resp + and (thinking_content := stream_resp["message"].get("thinking")) + ): + additional_kwargs["reasoning_content"] = thinking_content + chunk = ChatGenerationChunk( message=AIMessageChunk( - content=( - stream_resp["message"]["content"] - if "message" in stream_resp - and "content" in stream_resp["message"] - else "" - ), + content=content, + additional_kwargs=additional_kwargs, usage_metadata=_get_usage_metadata_from_generation_info( stream_resp ), @@ -845,15 +859,7 @@ class ChatOllama(BaseChatModel): ), generation_info=generation_info, ) - if chunk.generation_info and ( - model := chunk.generation_info.get("model") - ): - chunk.generation_info["model_name"] = model # backwards compat - if self.extract_reasoning: - message, is_thinking = self._extract_reasoning( - chunk.message, is_thinking - ) - chunk.message = message + yield chunk async def _astream( @@ -950,7 +956,7 @@ class ChatOllama(BaseChatModel): method: The method for steering model generation, one of: - "json_schema": - Uses Ollama's structured output API: https://ollama.com/blog/structured-outputs + Uses Ollama's `structured output API `__ - "function_calling": Uses Ollama's tool-calling API - "json_mode": @@ -1267,5 +1273,4 @@ class ChatOllama(BaseChatModel): [parser_none], exception_key="parsing_error" ) return RunnableMap(raw=llm) | parser_with_fallback - else: - return llm | output_parser + return llm | output_parser diff --git a/libs/partners/ollama/langchain_ollama/embeddings.py b/libs/partners/ollama/langchain_ollama/embeddings.py index ea15b71e87c..7e589bd291b 100644 --- a/libs/partners/ollama/langchain_ollama/embeddings.py +++ b/libs/partners/ollama/langchain_ollama/embeddings.py @@ -97,7 +97,7 @@ class OllamaEmbeddings(BaseModel, Embeddings): Embed multiple texts: .. code-block:: python - input_texts = ["Document 1...", "Document 2..."] + input_texts = ["Document 1...", "Document 2..."] vectors = embed.embed_documents(input_texts) print(len(vectors)) # The first 3 coordinates for the first vector @@ -112,7 +112,7 @@ class OllamaEmbeddings(BaseModel, Embeddings): .. code-block:: python vector = await embed.aembed_query(input_text) - print(vector[:3]) + print(vector[:3]) # multiple: # await embed.aembed_documents(input_texts) @@ -151,12 +151,12 @@ class OllamaEmbeddings(BaseModel, Embeddings): For a full list of the params, see the `HTTPX documentation `__. """ - _client: Client = PrivateAttr(default=None) # type: ignore + _client: Optional[Client] = PrivateAttr(default=None) """ The client to use for making requests. """ - _async_client: AsyncClient = PrivateAttr(default=None) # type: ignore + _async_client: Optional[AsyncClient] = PrivateAttr(default=None) """ The async client to use for making requests. """ @@ -270,6 +270,11 @@ class OllamaEmbeddings(BaseModel, Embeddings): def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs.""" + if not self._client: + raise ValueError( + "Ollama client is not initialized. " + "Please ensure Ollama is running and the model is loaded." + ) embedded_docs = self._client.embed( self.model, texts, options=self._default_params, keep_alive=self.keep_alive )["embeddings"] @@ -281,6 +286,11 @@ class OllamaEmbeddings(BaseModel, Embeddings): async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs.""" + if not self._async_client: + raise ValueError( + "Ollama client is not initialized. " + "Please ensure Ollama is running and the model is loaded." + ) embedded_docs = ( await self._async_client.embed( self.model, texts, keep_alive=self.keep_alive diff --git a/libs/partners/ollama/langchain_ollama/llms.py b/libs/partners/ollama/langchain_ollama/llms.py index 2c73c28ef57..f488ea73d1a 100644 --- a/libs/partners/ollama/langchain_ollama/llms.py +++ b/libs/partners/ollama/langchain_ollama/llms.py @@ -36,6 +36,20 @@ class OllamaLLM(BaseLLM): model: str """Model name to use.""" + reasoning: Optional[bool] = True + """Controls the reasoning/thinking mode for + `supported models `__. + + - ``True``: Enables reasoning mode. The model's reasoning process will be + captured and returned separately in the ``additional_kwargs`` of the + response message, under ``reasoning_content``. The main response + content will not include the reasoning tags. + - ``False``: Disables reasoning mode. The model will not perform any reasoning, + and the response will not include any reasoning content. + - ``None`` (Default): The model will use its default reasoning behavior. If + the model performs reasoning, the ```` and ```` tags will + be present directly within the main response content.""" + validate_model_on_init: bool = False """Whether to validate the model exists in ollama locally on initialization.""" @@ -56,7 +70,7 @@ class OllamaLLM(BaseLLM): num_ctx: Optional[int] = None """Sets the size of the context window used to generate the - next token. (Default: 2048) """ + next token. (Default: 2048)""" num_gpu: Optional[int] = None """The number of GPUs to use. On macOS it defaults to 1 to @@ -137,12 +151,12 @@ class OllamaLLM(BaseLLM): For a full list of the params, see the `HTTPX documentation `__. """ - _client: Client = PrivateAttr(default=None) # type: ignore + _client: Optional[Client] = PrivateAttr(default=None) """ The client to use for making requests. """ - _async_client: AsyncClient = PrivateAttr(default=None) # type: ignore + _async_client: Optional[AsyncClient] = PrivateAttr(default=None) """ The async client to use for making requests. """ @@ -155,7 +169,7 @@ class OllamaLLM(BaseLLM): ) -> dict[str, Any]: if self.stop is not None and stop is not None: raise ValueError("`stop` found in both the input and default params.") - elif self.stop is not None: + if self.stop is not None: stop = self.stop options_dict = kwargs.pop( @@ -183,6 +197,7 @@ class OllamaLLM(BaseLLM): "prompt": prompt, "stream": kwargs.pop("stream", True), "model": kwargs.pop("model", self.model), + "think": kwargs.pop("reasoning", self.reasoning), "format": kwargs.pop("format", self.format), "options": Options(**options_dict), "keep_alive": kwargs.pop("keep_alive", self.keep_alive), @@ -230,10 +245,11 @@ class OllamaLLM(BaseLLM): stop: Optional[list[str]] = None, **kwargs: Any, ) -> AsyncIterator[Union[Mapping[str, Any], str]]: - async for part in await self._async_client.generate( - **self._generate_params(prompt, stop=stop, **kwargs) - ): # type: ignore - yield part # type: ignore + if self._async_client: + async for part in await self._async_client.generate( + **self._generate_params(prompt, stop=stop, **kwargs) + ): + yield part def _create_generate_stream( self, @@ -241,9 +257,10 @@ class OllamaLLM(BaseLLM): stop: Optional[list[str]] = None, **kwargs: Any, ) -> Iterator[Union[Mapping[str, Any], str]]: - yield from self._client.generate( - **self._generate_params(prompt, stop=stop, **kwargs) - ) # type: ignore + if self._client: + yield from self._client.generate( + **self._generate_params(prompt, stop=stop, **kwargs) + ) async def _astream_with_aggregation( self, @@ -356,11 +373,19 @@ class OllamaLLM(BaseLLM): ) -> Iterator[GenerationChunk]: for stream_resp in self._create_generate_stream(prompt, stop, **kwargs): if not isinstance(stream_resp, str): + additional_kwargs = {} + if thinking_content := stream_resp.get("thinking"): + additional_kwargs["reasoning_content"] = thinking_content + chunk = GenerationChunk( text=(stream_resp.get("response", "")), - generation_info=( - dict(stream_resp) if stream_resp.get("done") is True else None - ), + generation_info={ + "finish_reason": self.stop, + **additional_kwargs, + **( + dict(stream_resp) if stream_resp.get("done") is True else {} + ), + }, ) if run_manager: run_manager.on_llm_new_token( @@ -378,11 +403,19 @@ class OllamaLLM(BaseLLM): ) -> AsyncIterator[GenerationChunk]: async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs): if not isinstance(stream_resp, str): + additional_kwargs = {} + if thinking_content := stream_resp.get("thinking"): + additional_kwargs["reasoning_content"] = thinking_content + chunk = GenerationChunk( text=(stream_resp.get("response", "")), - generation_info=( - dict(stream_resp) if stream_resp.get("done") is True else None - ), + generation_info={ + "finish_reason": self.stop, + **additional_kwargs, + **( + dict(stream_resp) if stream_resp.get("done") is True else {} + ), + }, ) if run_manager: await run_manager.on_llm_new_token( diff --git a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_reasoning.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_reasoning.py index 5e41c424a66..19e2106e9ce 100644 --- a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_reasoning.py +++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_reasoning.py @@ -1,17 +1,28 @@ """Ollama specific chat model integration tests for reasoning models.""" import pytest -from langchain_core.messages import AIMessageChunk, BaseMessageChunk, HumanMessage -from pydantic import ValidationError +from langchain_core.messages import ( + AIMessageChunk, + BaseMessageChunk, + HumanMessage, +) +from pydantic import BaseModel, Field from langchain_ollama import ChatOllama SAMPLE = "What is 3^3?" +class MathAnswer(BaseModel): + """A mathematical expression and its numerical answer.""" + + expression: str = Field(description="The mathematical expression to evaluate.") + answer: int = Field(description="The numerical answer to the expression.") + + @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_deepseek_messages_stream_no_reasoning(model: str) -> None: - """Test deepseek model without parsing.""" +def test_stream_no_reasoning(model: str) -> None: + """Test streaming with `reasoning=False`""" llm = ChatOllama(model=model, num_ctx=2**12) messages = [ { @@ -28,14 +39,41 @@ def test_deepseek_messages_stream_no_reasoning(model: str) -> None: result += chunk assert isinstance(result, AIMessageChunk) assert result.content - assert "" in result.content and "" in result.content assert "reasoning_content" not in result.additional_kwargs + assert "" not in result.content and "" not in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_deepseek_messages_stream_bool(model: str) -> None: - """Test deepseek model with reasoning bool=True""" - llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True) +async def test_astream_no_reasoning(model: str) -> None: + """Test async streaming with `reasoning=False`""" + llm = ChatOllama(model=model, num_ctx=2**12) + messages = [ + { + "role": "user", + "content": SAMPLE, + } + ] + result = None + async for chunk in llm.astream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk + assert isinstance(result, AIMessageChunk) + assert result.content + assert "reasoning_content" not in result.additional_kwargs + assert "" not in result.content and "" not in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] + + +@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +def test_stream_reasoning_none(model: str) -> None: + """Test streaming with `reasoning=None`""" + llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None) messages = [ { "role": "user", @@ -51,26 +89,41 @@ def test_deepseek_messages_stream_bool(model: str) -> None: result += chunk assert isinstance(result, AIMessageChunk) assert result.content - assert "" not in result.content and "" not in result.content - assert "reasoning_content" in result.additional_kwargs - assert len(result.additional_kwargs["reasoning_content"]) > 0 - assert "" in result.additional_kwargs["reasoning_content"] - assert "" in result.additional_kwargs["reasoning_content"] - clean_content = ( - result.additional_kwargs["reasoning_content"] - .replace("", "") - .replace("", "") - .strip() - ) - assert len(clean_content) > 0 + assert "reasoning_content" not in result.additional_kwargs + assert "" in result.content and "" in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_deepseek_messages_stream_tuple(model: str) -> None: - """Test deepseek model with reasoning with tuple=...""" - llm = ChatOllama( - model=model, num_ctx=2**12, extract_reasoning=("", "") - ) +async def test_astream_reasoning_none(model: str) -> None: + """Test async streaming with `reasoning=None`""" + llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None) + messages = [ + { + "role": "user", + "content": SAMPLE, + } + ] + result = None + async for chunk in llm.astream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk + assert isinstance(result, AIMessageChunk) + assert result.content + assert "reasoning_content" not in result.additional_kwargs + assert "" in result.content and "" in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] + + +@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +def test_reasoning_stream(model: str) -> None: + """Test streaming with `reasoning=True`""" + llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True) messages = [ { "role": "user", @@ -86,77 +139,114 @@ def test_deepseek_messages_stream_tuple(model: str) -> None: result += chunk assert isinstance(result, AIMessageChunk) assert result.content - assert "" not in result.content and "" not in result.content assert "reasoning_content" in result.additional_kwargs assert len(result.additional_kwargs["reasoning_content"]) > 0 - assert "" in result.additional_kwargs["reasoning_content"] - assert "" in result.additional_kwargs["reasoning_content"] - clean_content = ( - result.additional_kwargs["reasoning_content"] - .replace("", "") - .replace("", "") - .strip() - ) - assert len(clean_content) > 0 + assert "" not in result.content and "" not in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_deepseek_messages_invoke_no_reasoning(model: str) -> None: - """Test deepseek model without parsing using invoke.""" +async def test_reasoning_astream(model: str) -> None: + """Test async streaming with `reasoning=True`""" + llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True) + messages = [ + { + "role": "user", + "content": SAMPLE, + } + ] + result = None + async for chunk in llm.astream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk + assert isinstance(result, AIMessageChunk) + assert result.content + assert "reasoning_content" in result.additional_kwargs + assert len(result.additional_kwargs["reasoning_content"]) > 0 + assert "" not in result.content and "" not in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] + + +@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +def test_invoke_no_reasoning(model: str) -> None: + """Test using invoke with `reasoning=False`""" llm = ChatOllama(model=model, num_ctx=2**12) message = HumanMessage(content=SAMPLE) result = llm.invoke([message]) assert result.content - assert "" in result.content and "" in result.content assert "reasoning_content" not in result.additional_kwargs + assert "" not in result.content and "" not in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_deepseek_messages_invoke_bool(model: str) -> None: - """Test deepseek model with reasoning bool=True using invoke""" - llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True) +async def test_ainvoke_no_reasoning(model: str) -> None: + """Test using async invoke with `reasoning=False`""" + llm = ChatOllama(model=model, num_ctx=2**12) + message = HumanMessage(content=SAMPLE) + result = await llm.ainvoke([message]) + assert result.content + assert "reasoning_content" not in result.additional_kwargs + assert "" not in result.content and "" not in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] + + +@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +def test_invoke_reasoning_none(model: str) -> None: + """Test using invoke with `reasoning=None`""" + llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None) message = HumanMessage(content=SAMPLE) result = llm.invoke([message]) assert result.content - assert "" not in result.content and "" not in result.content - assert "reasoning_content" in result.additional_kwargs - assert len(result.additional_kwargs["reasoning_content"]) > 0 - assert "" in result.additional_kwargs["reasoning_content"] - assert "" in result.additional_kwargs["reasoning_content"] - clean_content = ( - result.additional_kwargs["reasoning_content"] - .replace("", "") - .replace("", "") - .strip() - ) - assert len(clean_content) > 0 + assert "reasoning_content" not in result.additional_kwargs + assert "" in result.content and "" in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_deepseek_messages_invoke_tuple(model: str) -> None: - """Test deepseek model with reasoning with tuple=... using invoke""" - llm = ChatOllama( - model=model, num_ctx=2**12, extract_reasoning=("", "") - ) +async def test_ainvoke_reasoning_none(model: str) -> None: + """Test using async invoke with `reasoning=None`""" + llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None) + message = HumanMessage(content=SAMPLE) + result = await llm.ainvoke([message]) + assert result.content + assert "reasoning_content" not in result.additional_kwargs + assert "" in result.content and "" in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] + + +@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +def test_reasoning_invoke(model: str) -> None: + """Test invoke with `reasoning=True`""" + llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True) message = HumanMessage(content=SAMPLE) result = llm.invoke([message]) assert result.content - assert "" not in result.content and "" not in result.content assert "reasoning_content" in result.additional_kwargs assert len(result.additional_kwargs["reasoning_content"]) > 0 - assert "" in result.additional_kwargs["reasoning_content"] - assert "" in result.additional_kwargs["reasoning_content"] - clean_content = ( - result.additional_kwargs["reasoning_content"] - .replace("", "") - .replace("", "") - .strip() - ) - assert len(clean_content) > 0 + assert "" not in result.content and "" not in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_deepseek_invalid(model: str) -> None: - """Test deepseek model with reasoning raises ValidationError""" - with pytest.raises(ValidationError): - _ = ChatOllama(model=model, extract_reasoning={"invalid": "data"}) # type: ignore[arg-type] +async def test_reasoning_ainvoke(model: str) -> None: + """Test invoke with `reasoning=True`""" + llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True) + message = HumanMessage(content=SAMPLE) + result = await llm.ainvoke([message]) + assert result.content + assert "reasoning_content" in result.additional_kwargs + assert len(result.additional_kwargs["reasoning_content"]) > 0 + assert "" not in result.content and "" not in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] diff --git a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py index c3216b8a37a..c8500683178 100644 --- a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py +++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py @@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch import pytest from httpx import ConnectError from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessageChunk, HumanMessage, ToolCallChunk +from langchain_core.tools import tool from langchain_tests.integration_tests import ChatModelIntegrationTests from ollama import ResponseError from pydantic import ValidationError @@ -14,6 +16,15 @@ from langchain_ollama.chat_models import ChatOllama DEFAULT_MODEL_NAME = "llama3.1" +@tool +def get_current_weather(location: str) -> dict: + """Gets the current weather in a given location.""" + if "boston" in location.lower(): + return {"temperature": "15°F", "conditions": "snow"} + else: + return {"temperature": "unknown", "conditions": "unknown"} + + class TestChatOllama(ChatModelIntegrationTests): @property def chat_model_class(self) -> type[ChatOllama]: @@ -29,12 +40,104 @@ class TestChatOllama(ChatModelIntegrationTests): @property def has_tool_choice(self) -> bool: - return False # TODO: update after Ollama implements + # TODO: update after Ollama implements + # https://github.com/ollama/ollama/blob/main/docs/openai.md + return False @property def supports_image_inputs(self) -> bool: return True + def test_tool_streaming(self, model: BaseChatModel) -> None: + """Test that the model can stream tool calls.""" + chat_model_with_tools = model.bind_tools([get_current_weather]) + + prompt = [HumanMessage("What is the weather today in Boston?")] + + # Flags and collectors for validation + tool_chunk_found = False + final_tool_calls = [] + collected_tool_chunks: list[ToolCallChunk] = [] + + # Stream the response and inspect the chunks + for chunk in chat_model_with_tools.stream(prompt): + assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type" + + if chunk.tool_call_chunks: + tool_chunk_found = True + for tc_chunk in chunk.tool_call_chunks: + collected_tool_chunks.append(tc_chunk) + + if chunk.tool_calls: + final_tool_calls.extend(chunk.tool_calls) + + assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks." + assert ( + len(final_tool_calls) == 1 + ), f"Expected 1 final tool call, but got {len(final_tool_calls)}" + + final_tool_call = final_tool_calls[0] + assert final_tool_call["name"] == "get_current_weather" + assert final_tool_call["args"] == {"location": "Boston"} + + assert len(collected_tool_chunks) > 0 + assert collected_tool_chunks[0]["name"] == "get_current_weather" + + # The ID should be consistent across chunks that have it + tool_call_id = collected_tool_chunks[0].get("id") + assert tool_call_id is not None + assert all( + chunk.get("id") == tool_call_id + for chunk in collected_tool_chunks + if chunk.get("id") + ) + assert final_tool_call["id"] == tool_call_id + + async def test_tool_astreaming(self, model: BaseChatModel) -> None: + """Test that the model can stream tool calls.""" + chat_model_with_tools = model.bind_tools([get_current_weather]) + + prompt = [HumanMessage("What is the weather today in Boston?")] + + # Flags and collectors for validation + tool_chunk_found = False + final_tool_calls = [] + collected_tool_chunks: list[ToolCallChunk] = [] + + # Stream the response and inspect the chunks + async for chunk in chat_model_with_tools.astream(prompt): + assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type" + + if chunk.tool_call_chunks: + tool_chunk_found = True + for tc_chunk in chunk.tool_call_chunks: + collected_tool_chunks.append(tc_chunk) + + if chunk.tool_calls: + final_tool_calls.extend(chunk.tool_calls) + + assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks." + assert ( + len(final_tool_calls) == 1 + ), f"Expected 1 final tool call, but got {len(final_tool_calls)}" + + final_tool_call = final_tool_calls[0] + assert final_tool_call["name"] == "get_current_weather" + assert final_tool_call["args"] == {"location": "Boston"} + + assert len(collected_tool_chunks) > 0 + assert collected_tool_chunks[0]["name"] == "get_current_weather" + + # The ID should be consistent across chunks that have it + tool_call_id = collected_tool_chunks[0].get("id") + assert tool_call_id is not None + assert all( + chunk.get("id") == tool_call_id + for chunk in collected_tool_chunks + if chunk.get("id") + ) + assert final_tool_call["id"] == tool_call_id + @pytest.mark.xfail( reason=( "Will sometime encounter AssertionErrors where tool responses are " diff --git a/libs/partners/ollama/tests/integration_tests/test_llms.py b/libs/partners/ollama/tests/integration_tests/test_llms.py index 819711d28dc..afc67ef7118 100644 --- a/libs/partners/ollama/tests/integration_tests/test_llms.py +++ b/libs/partners/ollama/tests/integration_tests/test_llms.py @@ -1,11 +1,15 @@ """Test OllamaLLM llm.""" +import pytest +from langchain_core.messages import AIMessageChunk, BaseMessageChunk from langchain_core.runnables import RunnableConfig from langchain_ollama.llms import OllamaLLM MODEL_NAME = "llama3.1" +SAMPLE = "What is 3^3?" + def test_stream() -> None: """Test streaming tokens from OpenAI.""" @@ -15,6 +19,59 @@ def test_stream() -> None: assert isinstance(token, str) +@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +def test_stream_no_reasoning(model: str) -> None: + """Test streaming with `reasoning=False`""" + llm = OllamaLLM(model=model, num_ctx=2**12) + messages = [ + { + "role": "user", + "content": SAMPLE, + } + ] + result = None + for chunk in llm.stream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk + assert isinstance(result, AIMessageChunk) + assert result.content + assert "reasoning_content" not in result.additional_kwargs + + # Sanity check the old behavior isn't present + assert "" not in result.content and "" not in result.content + + +@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +def test_reasoning_stream(model: str) -> None: + """Test streaming with `reasoning=True`""" + llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True) + messages = [ + { + "role": "user", + "content": SAMPLE, + } + ] + result = None + for chunk in llm.stream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk + assert isinstance(result, AIMessageChunk) + assert result.content + assert "reasoning_content" in result.additional_kwargs + assert len(result.additional_kwargs["reasoning_content"]) > 0 + + # Sanity check the old behavior isn't present + assert "" not in result.content and "" not in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] + + async def test_astream() -> None: """Test streaming tokens from OpenAI.""" llm = OllamaLLM(model=MODEL_NAME) @@ -23,6 +80,59 @@ async def test_astream() -> None: assert isinstance(token, str) +@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +async def test_astream_no_reasoning(model: str) -> None: + """Test async streaming with `reasoning=False`""" + llm = OllamaLLM(model=model, num_ctx=2**12) + messages = [ + { + "role": "user", + "content": SAMPLE, + } + ] + result = None + async for chunk in llm.astream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk + assert isinstance(result, AIMessageChunk) + assert result.content + assert "reasoning_content" not in result.additional_kwargs + + # Sanity check the old behavior isn't present + assert "" not in result.content and "" not in result.content + + +@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +async def test_reasoning_astream(model: str) -> None: + """Test async streaming with `reasoning=True`""" + llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True) + messages = [ + { + "role": "user", + "content": SAMPLE, + } + ] + result = None + async for chunk in llm.astream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk + assert isinstance(result, AIMessageChunk) + assert result.content + assert "reasoning_content" in result.additional_kwargs + assert len(result.additional_kwargs["reasoning_content"]) > 0 + + # Sanity check the old behavior isn't present + assert "" not in result.content and "" not in result.content + assert "" not in result.additional_kwargs["reasoning_content"] + assert "" not in result.additional_kwargs["reasoning_content"] + + async def test_abatch() -> None: """Test streaming tokens from OllamaLLM.""" llm = OllamaLLM(model=MODEL_NAME) @@ -60,8 +170,68 @@ async def test_ainvoke() -> None: assert isinstance(result, str) +# TODO +# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +# async def test_ainvoke_no_reasoning(model: str) -> None: +# """Test using async invoke with `reasoning=False`""" +# llm = OllamaLLM(model=model, num_ctx=2**12) +# message = SAMPLE +# result = await llm.ainvoke(message) +# assert result.content +# assert "reasoning_content" not in result.additional_kwargs + +# # Sanity check the old behavior isn't present +# assert "" not in result.content and "" not in result.content + + +# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +# async def test_reasoning_ainvoke(model: str) -> None: +# """Test invoke with `reasoning=True`""" +# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True) +# message = SAMPLE +# result = await llm.ainvoke(message) +# assert result.content +# assert "reasoning_content" in result.additional_kwargs +# assert len(result.additional_kwargs["reasoning_content"]) > 0 + +# # Sanity check the old behavior isn't present +# assert "" not in result.content and "" not in result.content +# assert "" not in result.additional_kwargs["reasoning_content"] +# assert "" not in result.additional_kwargs["reasoning_content"] + + def test_invoke() -> None: """Test invoke tokens from OllamaLLM.""" llm = OllamaLLM(model=MODEL_NAME) result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"])) assert isinstance(result, str) + + +# TODO +# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +# def test_invoke_no_reasoning(model: str) -> None: +# """Test using invoke with `reasoning=False`""" +# llm = OllamaLLM(model=model, num_ctx=2**12) +# message = SAMPLE +# result = llm.invoke(message) +# assert result.content +# assert "reasoning_content" not in result.additional_kwargs + +# # Sanity check the old behavior isn't present +# assert "" not in result.content and "" not in result.content + + +# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +# def test_reasoning_invoke(model: str) -> None: +# """Test invoke with `reasoning=True`""" +# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True) +# message = SAMPLE +# result = llm.invoke(message) +# assert result.content +# assert "reasoning_content" in result.additional_kwargs +# assert len(result.additional_kwargs["reasoning_content"]) > 0 + +# # Sanity check the old behavior isn't present +# assert "" not in result.content and "" not in result.content +# assert "" not in result.additional_kwargs["reasoning_content"] +# assert "" not in result.additional_kwargs["reasoning_content"] diff --git a/libs/partners/ollama/tests/unit_tests/test_chat_models.py b/libs/partners/ollama/tests/unit_tests/test_chat_models.py index 638e5844711..e7ac5a29957 100644 --- a/libs/partners/ollama/tests/unit_tests/test_chat_models.py +++ b/libs/partners/ollama/tests/unit_tests/test_chat_models.py @@ -23,7 +23,7 @@ class TestChatOllama(ChatModelUnitTests): @property def chat_model_params(self) -> dict: - return {"model": "llama3-groq-tool-use"} + return {"model": MODEL_NAME} def test__parse_arguments_from_tool_call() -> None: @@ -51,7 +51,6 @@ def test_arbitrary_roles_accepted_in_chatmessages( monkeypatch.setattr(Client, "stream", _mock_httpx_client_stream) llm = ChatOllama( - base_url="http://whocares:11434", model=MODEL_NAME, verbose=True, format=None, diff --git a/libs/partners/ollama/tests/unit_tests/test_embeddings.py b/libs/partners/ollama/tests/unit_tests/test_embeddings.py index cbca95a994b..6ceec7c5df9 100644 --- a/libs/partners/ollama/tests/unit_tests/test_embeddings.py +++ b/libs/partners/ollama/tests/unit_tests/test_embeddings.py @@ -10,7 +10,7 @@ MODEL_NAME = "llama3.1" def test_initialization() -> None: """Test embedding model initialization.""" - OllamaEmbeddings(model="llama3", keep_alive=1) + OllamaEmbeddings(model=MODEL_NAME, keep_alive=1) @patch("langchain_ollama.embeddings.validate_model") diff --git a/libs/partners/ollama/tests/unit_tests/test_llms.py b/libs/partners/ollama/tests/unit_tests/test_llms.py index 7f68777c96c..2a0ad896f90 100644 --- a/libs/partners/ollama/tests/unit_tests/test_llms.py +++ b/libs/partners/ollama/tests/unit_tests/test_llms.py @@ -10,25 +10,25 @@ MODEL_NAME = "llama3.1" def test_initialization() -> None: """Test integration initialization.""" - OllamaLLM(model="llama3") + OllamaLLM(model=MODEL_NAME) def test_model_params() -> None: # Test standard tracing params - llm = OllamaLLM(model="llama3") + llm = OllamaLLM(model=MODEL_NAME) ls_params = llm._get_ls_params() assert ls_params == { "ls_provider": "ollama", "ls_model_type": "llm", - "ls_model_name": "llama3", + "ls_model_name": MODEL_NAME, } - llm = OllamaLLM(model="llama3", num_predict=3) + llm = OllamaLLM(model=MODEL_NAME, num_predict=3) ls_params = llm._get_ls_params() assert ls_params == { "ls_provider": "ollama", "ls_model_type": "llm", - "ls_model_name": "llama3", + "ls_model_name": MODEL_NAME, "ls_max_tokens": 3, }