ollama: thinking, tool streaming, docs, tests (#31772)
* New `reasoning` (bool) param to support toggling [Ollama thinking](https://ollama.com/blog/thinking) (#31573, #31700). If `reasoning=True`, Ollama's `thinking` content is placed in the model response's `additional_kwargs.reasoning_content`.
  * Supported by:
    * ChatOllama (class level; invocation level TODO)
    * OllamaLLM (TODO)
* Added tests to ensure streaming tool calls succeeds (#29129)
* Refactored tests that relied on `extract_reasoning()`
* Myriad docs additions and consistency/typo fixes
* Improved type safety in some spots

Closes #29129
Addresses #31573 and #31700
Supersedes #31701
This commit is contained in: parent 0eb10f31c1, commit e686a70ee0
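Not part of the commit message itself: a minimal usage sketch of the class-level toggle described above, assuming a locally pulled `deepseek-r1:1.5b` model (the same model the integration tests below use).

```python
from langchain_ollama import ChatOllama

# Class-level toggle: thinking output is separated from the visible content.
llm = ChatOllama(model="deepseek-r1:1.5b", reasoning=True)

result = llm.invoke("What is 3^3?")
print(result.content)                                     # final answer, no <think> tags
print(result.additional_kwargs.get("reasoning_content"))  # the model's thinking trace
```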
@@ -38,11 +38,11 @@
"\n",
"\n",
":::caution COMPATIBILITY\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for astream_events(), to child runnables if you are running async code in python<=3.10. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for astream_events(), to child runnables if you are running async code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"\n",
"If you are running python<=3.10, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n",
"If you are running `python<=3.10`, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n",
"\n",
"If you are running python>=3.11, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in other Python versions.\n",
"If you are running `python>=3.11`, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in other Python versions.\n",
":::"
]
},
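The cell above refers to a `bar` RunnableLambda that is not shown in this hunk. As a rough sketch (not taken from the diff), manual propagation amounts to wrapping an async function that accepts the `RunnableConfig` and forwards it to the child runnable:

```python
from langchain_core.runnables import RunnableConfig, RunnableLambda

# A child runnable whose callbacks/events we want to surface.
reverse = RunnableLambda(lambda s: s[::-1])

async def bar(text: str, config: RunnableConfig) -> str:
    # On python<=3.10 the config must be forwarded explicitly in async code.
    return await reverse.ainvoke(text, config=config)

bar_runnable = RunnableLambda(bar)
```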
@@ -16,15 +16,15 @@
"\n",
":::\n",
"\n",
"If you have [tools](/docs/concepts/tools/) that call [chat models](/docs/concepts/chat_models/), [retrievers](/docs/concepts/retrievers/), or other [runnables](/docs/concepts/runnables/), you may want to access internal events from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n",
"If you have [tools](/docs/concepts/tools/) that call [chat models](/docs/concepts/chat_models/), [retrievers](/docs/concepts/retrievers/), or other [runnables](/docs/concepts/runnables/), you may want to access [internal events](https://python.langchain.com/docs/how_to/streaming/#event-reference) from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n",
"\n",
":::caution Compatibility\n",
"\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for `astream_events()`, to child runnables if you are running `async` code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for `astream_events()`, to child runnables if you are running `async` code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"\n",
"If you are running python<=3.10, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n",
"If you are running `python<=3.10`, you will need to manually propagate the `RunnableConfig` object to the child runnable in async environments. For an example of how to manually propagate the config, see the implementation of the `bar` RunnableLambda below.\n",
"\n",
"If you are running python>=3.11, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in older Python versions.\n",
"If you are running `python>=3.11`, the `RunnableConfig` will automatically propagate to child runnables in async environment. However, it is still a good idea to propagate the `RunnableConfig` manually if your code may run in older Python versions.\n",
"\n",
"This guide also requires `langchain-core>=0.2.16`.\n",
":::\n",
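For orientation (a sketch under the assumption of a runnable like `bar_runnable` above, not content from the diff), consuming internal events with the v2 `astream_events()` API looks roughly like this:

```python
async def show_events() -> None:
    async for event in bar_runnable.astream_events("hello", version="v2"):
        # Event names follow on_[runnable_type]_(start|stream|end).
        print(event["event"], event["name"], event["run_id"])
```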
@@ -224,6 +224,13 @@
"source": [
"print(type(gathered.tool_calls[0][\"args\"]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note the key difference: accumulating `tool_call_chunks` captures the raw tool arguments as an unparsed string as they are streamed. In contrast, **accumulating** `tool_calls` demonstrates partial parsing by progressively converting the streamed argument string into a valid, usable dictionary at each step of the process."
]
}
],
"metadata": {
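To make the cell above concrete (a sketch, not part of the diff; the weather tool mirrors the one used in the integration tests further down), accumulating streamed chunks with `+` lets you compare the raw `tool_call_chunks` against the progressively parsed `tool_calls`:

```python
from langchain_core.tools import tool
from langchain_ollama import ChatOllama

@tool
def get_current_weather(location: str) -> dict:
    """Gets the current weather in a given location."""
    return {"temperature": "15°F", "conditions": "snow"}

llm = ChatOllama(model="llama3.1").bind_tools([get_current_weather])

gathered = None
for chunk in llm.stream("What is the weather today in Boston?"):
    # AIMessageChunk supports `+`, which merges tool call chunks as they arrive.
    gathered = chunk if gathered is None else gathered + chunk

print(gathered.tool_call_chunks[0]["args"])  # raw, unparsed string
print(gathered.tool_calls[0]["args"])        # progressively parsed dict
```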
@@ -1414,6 +1414,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
ValueError: If the file path is not a string or Path object.

Example:

.. code-block:: python

llm.save(file_path="path/llm.yaml")

@@ -1160,22 +1160,21 @@ class Runnable(ABC, Generic[Input, Output]):

A StreamEvent is a dictionary with the following schema:

- ``event``: **str** - Event names are of the
format: on_[runnable_type]_(start|stream|end).
- ``event``: **str** - Event names are of the format:
on_[runnable_type]_(start|stream|end).
- ``name``: **str** - The name of the Runnable that generated the event.
- ``run_id``: **str** - randomly generated ID associated with the given execution of
the Runnable that emitted the event.
A child Runnable that gets invoked as part of the execution of a
parent Runnable is assigned its own unique ID.
- ``parent_ids``: **list[str]** - The IDs of the parent runnables that
generated the event. The root Runnable will have an empty list.
The order of the parent IDs is from the root to the immediate parent.
Only available for v2 version of the API. The v1 version of the API
will return an empty list.
- ``run_id``: **str** - randomly generated ID associated with the given
execution of the Runnable that emitted the event. A child Runnable that gets
invoked as part of the execution of a parent Runnable is assigned its own
unique ID.
- ``parent_ids``: **list[str]** - The IDs of the parent runnables that generated
the event. The root Runnable will have an empty list. The order of the parent
IDs is from the root to the immediate parent. Only available for v2 version of
the API. The v1 version of the API will return an empty list.
- ``tags``: **Optional[list[str]]** - The tags of the Runnable that generated
the event.
- ``metadata``: **Optional[dict[str, Any]]** - The metadata of the Runnable
that generated the event.
- ``metadata``: **Optional[dict[str, Any]]** - The metadata of the Runnable that
generated the event.
- ``data``: **dict[str, Any]**

@@ -1183,7 +1182,7 @@ class Runnable(ABC, Generic[Input, Output]):
chains. Metadata fields have been omitted from the table for brevity.
Chain definitions have been included after the table.

**ATTENTION** This reference table is for the V2 version of the schema.
.. NOTE:: This reference table is for the V2 version of the schema.

+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| event | name | chunk | input | output |
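An illustrative rendering of the `StreamEvent` fields the docstring above lists (all field values here are made up for illustration):

```python
# Shape of a single v2 StreamEvent; values are hypothetical.
example_event = {
    "event": "on_chat_model_stream",   # on_[runnable_type]_(start|stream|end)
    "name": "ChatOllama",              # name of the Runnable that emitted the event
    "run_id": "1f2a...",               # unique per execution of that Runnable
    "parent_ids": ["9c4b..."],         # root -> immediate parent (v2 only; empty for the root)
    "tags": ["my_chain"],
    "metadata": {},
    "data": {"chunk": "..."},
}
```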
@@ -6,7 +6,6 @@ from operator import itemgetter
from typing import (
Any,
Callable,
Final,
Literal,
Optional,
Union,
@@ -25,7 +24,6 @@ from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
HumanMessage,
SystemMessage,
@@ -57,9 +55,6 @@ from typing_extensions import Self, is_typeddict

from ._utils import validate_model

DEFAULT_THINK_TOKEN_START: Final[str] = "<think>"
DEFAULT_THINK_TOKEN_END: Final[str] = "</think>"


def _get_usage_metadata_from_generation_info(
generation_info: Optional[Mapping[str, Any]],
@@ -166,13 +161,14 @@ def _get_tool_calls_from_response(
return tool_calls


def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
def _lc_tool_call_to_openai_tool_call(tool_call_: ToolCall) -> dict:
"""Convert a LangChain tool call to an OpenAI tool call format."""
return {
"type": "function",
"id": tool_call["id"],
"id": tool_call_["id"],
"function": {
"name": tool_call["name"],
"arguments": tool_call["args"],
"name": tool_call_["name"],
"arguments": tool_call_["args"],
},
}
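For clarity (an illustration, not code from the diff; the tool name and id below are hypothetical), this helper performs the following shape conversion:

```python
lc_tool_call = {
    "name": "get_current_weather",
    "args": {"location": "Boston"},
    "id": "call_0",
    "type": "tool_call",
}

# _lc_tool_call_to_openai_tool_call(lc_tool_call) returns roughly:
openai_tool_call = {
    "type": "function",
    "id": "call_0",
    "function": {"name": "get_current_weather", "arguments": {"location": "Boston"}},
}
```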
@@ -211,6 +207,20 @@ class ChatOllama(BaseChatModel):
Key init args — completion params:
model: str
Name of Ollama model to use.
reasoning: Optional[bool]
Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.

- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. Note
however, if the model's default behavior *is* to perform reasoning, think tags
(``<think>`` and ``</think>``) will be present within the main response content
unless you set ``reasoning`` to ``True``.
temperature: float
Sampling temperature. Ranges from 0.0 to 1.0.
num_predict: Optional[int]
@@ -347,21 +357,29 @@ class ChatOllama(BaseChatModel):
'args': {'a': 45, 'b': 67},
'id': '420c3f3b-df10-4188-945f-eb3abdb40622',
'type': 'tool_call'}]
""" # noqa: E501
""" # noqa: E501, pylint: disable=line-too-long

model: str
"""Model name to use."""

reasoning: Optional[bool] = None
"""Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.

- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. Note
however, if the model's default behavior *is* to perform reasoning, think tags
(``<think>`` and ``</think>``) will be present within the main response content
unless you set ``reasoning`` to ``True``."""

validate_model_on_init: bool = False
"""Whether to validate the model exists in Ollama locally on initialization."""

extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
"""Whether to extract the reasoning tokens in think blocks.
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
If a tuple is supplied, they are assumed to be the (start, end) tokens.
If `extract_reasoning=True`, the tokens will default to (<think>, </think>).
"""

mirostat: Optional[int] = None
"""Enable Mirostat sampling for controlling perplexity.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
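A streaming counterpart to the earlier invoke sketch (again an illustration, not part of the diff), showing where the new ``reasoning`` field puts the thinking tokens; this mirrors the integration tests added below:

```python
from langchain_ollama import ChatOllama

llm = ChatOllama(model="deepseek-r1:1.5b", reasoning=True)

full = None
for chunk in llm.stream("What is 3^3?"):
    # Thinking tokens arrive in additional_kwargs, separate from the visible content.
    full = chunk if full is None else full + chunk

print(full.additional_kwargs["reasoning_content"])  # thinking trace, no <think> tags
print(full.content)                                 # final answer only
```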
@@ -448,24 +466,23 @@ class ChatOllama(BaseChatModel):
"""

async_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with client_kwargs before passing to the HTTPX
AsyncClient.

For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#asyncclient>`__.
"""Additional kwargs to merge with client_kwargs before
passing to the httpx AsyncClient.
`Full list of params. <https://www.python-httpx.org/api/#asyncclient>`__
"""

sync_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with client_kwargs before passing to the HTTPX Client.

For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
"""Additional kwargs to merge with client_kwargs before
passing to the httpx Client.
`Full list of params. <https://www.python-httpx.org/api/#client>`__
"""

_client: Client = PrivateAttr(default=None) # type: ignore
_client: Client = PrivateAttr()
"""
The client to use for making requests.
"""

_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore
_async_client: AsyncClient = PrivateAttr()
"""
The async client to use for making requests.
"""
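An example (not from the diff; the httpx options shown are assumptions) of how these client kwargs are typically supplied:

```python
from langchain_ollama import ChatOllama

llm = ChatOllama(
    model="llama3.1",
    client_kwargs={"timeout": 60},                           # shared httpx options
    sync_client_kwargs={"headers": {"X-Source": "sync"}},    # merged into the Client only
    async_client_kwargs={"headers": {"X-Source": "async"}},  # merged into the AsyncClient only
)
```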
@@ -480,7 +497,7 @@ class ChatOllama(BaseChatModel):

if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
if self.stop is not None:
stop = self.stop

options_dict = kwargs.pop(
@@ -508,6 +525,7 @@ class ChatOllama(BaseChatModel):
"messages": ollama_messages,
"stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model),
"think": kwargs.pop("reasoning", self.reasoning),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
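For context: ``reasoning`` is forwarded to the Ollama client as the ``think`` option. Based on the stream-handling code below (an assumption about the exact payload, not text from the diff), the streamed chunks it consumes look roughly like:

```python
# Approximate shape of one streamed Ollama chat chunk when think is enabled;
# _iterate_over_stream reads message["content"] and message.get("thinking") from it.
stream_resp = {
    "model": "deepseek-r1:1.5b",
    "message": {"role": "assistant", "content": "", "thinking": "First, cube 3..."},
    "done": False,
}
```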
@@ -618,35 +636,13 @@ class ChatOllama(BaseChatModel):
"images": images,
}
if tool_calls:
msg["tool_calls"] = tool_calls # type: ignore
msg["tool_calls"] = tool_calls
if tool_call_id:
msg["tool_call_id"] = tool_call_id
ollama_messages.append(msg)

return ollama_messages


def _extract_reasoning(
self, message_chunk: BaseMessageChunk, is_thinking: bool
) -> tuple[BaseMessageChunk, bool]:
"""Mutate a message chunk to extract reasoning content."""
if not self.extract_reasoning:
return message_chunk, is_thinking
elif self.extract_reasoning is True:
start_token = DEFAULT_THINK_TOKEN_START
end_token = DEFAULT_THINK_TOKEN_END
else:
start_token, end_token = cast(tuple, self.extract_reasoning)
if start_token in message_chunk.content:
is_thinking = True
content = message_chunk.content
if is_thinking:
message_chunk.additional_kwargs["reasoning_content"] = content
message_chunk.content = ""
if end_token in content:
is_thinking = False

return message_chunk, is_thinking

async def _acreate_chat_stream(
self,
messages: list[BaseMessage],
@@ -670,8 +666,10 @@ class ChatOllama(BaseChatModel):
chat_params = self._chat_params(messages, stop, **kwargs)

if chat_params["stream"]:
if self._client:
yield from self._client.chat(**chat_params)
else:
if self._client:
yield self._client.chat(**chat_params)

def _chat_stream_with_aggregation(
@@ -767,22 +765,34 @@ class ChatOllama(BaseChatModel):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
is_thinking = False
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
if stream_resp.get("done") is True:
generation_info = dict(stream_resp)
if "model" in generation_info:
generation_info["model_name"] = generation_info["model"]
_ = generation_info.pop("message", None)
else:
generation_info = None
chunk = ChatGenerationChunk(
message=AIMessageChunk(

content = (
stream_resp["message"]["content"]
if "message" in stream_resp
and "content" in stream_resp["message"]
if "message" in stream_resp and "content" in stream_resp["message"]
else ""
),
)

additional_kwargs = {}
if (
self.reasoning
and "message" in stream_resp
and (thinking_content := stream_resp["message"].get("thinking"))
):
additional_kwargs["reasoning_content"] = thinking_content

chunk = ChatGenerationChunk(
message=AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
usage_metadata=_get_usage_metadata_from_generation_info(
stream_resp
),
@@ -790,15 +800,7 @@ class ChatOllama(BaseChatModel):
),
generation_info=generation_info,
)
if chunk.generation_info and (
model := chunk.generation_info.get("model")
):
chunk.generation_info["model_name"] = model # backwards compat
if self.extract_reasoning:
message, is_thinking = self._extract_reasoning(
chunk.message, is_thinking
)
chunk.message = message

yield chunk

def _stream(
@@ -822,22 +824,34 @@ class ChatOllama(BaseChatModel):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
is_thinking = False
async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
if not isinstance(stream_resp, str):
if stream_resp.get("done") is True:
generation_info = dict(stream_resp)
if "model" in generation_info:
generation_info["model_name"] = generation_info["model"]
_ = generation_info.pop("message", None)
else:
generation_info = None
chunk = ChatGenerationChunk(
message=AIMessageChunk(

content = (
stream_resp["message"]["content"]
if "message" in stream_resp
and "content" in stream_resp["message"]
if "message" in stream_resp and "content" in stream_resp["message"]
else ""
),
)

additional_kwargs = {}
if (
self.reasoning
and "message" in stream_resp
and (thinking_content := stream_resp["message"].get("thinking"))
):
additional_kwargs["reasoning_content"] = thinking_content

chunk = ChatGenerationChunk(
message=AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
usage_metadata=_get_usage_metadata_from_generation_info(
stream_resp
),
@@ -845,15 +859,7 @@ class ChatOllama(BaseChatModel):
),
generation_info=generation_info,
)
if chunk.generation_info and (
model := chunk.generation_info.get("model")
):
chunk.generation_info["model_name"] = model # backwards compat
if self.extract_reasoning:
message, is_thinking = self._extract_reasoning(
chunk.message, is_thinking
)
chunk.message = message

yield chunk

async def _astream(
@@ -950,7 +956,7 @@ class ChatOllama(BaseChatModel):
method: The method for steering model generation, one of:

- "json_schema":
Uses Ollama's structured output API: https://ollama.com/blog/structured-outputs
Uses Ollama's `structured output API <https://ollama.com/blog/structured-outputs>`__
- "function_calling":
Uses Ollama's tool-calling API
- "json_mode":
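A usage sketch for `with_structured_output` (not part of the diff), reusing the `MathAnswer` schema that the integration tests below define:

```python
from pydantic import BaseModel, Field
from langchain_ollama import ChatOllama

class MathAnswer(BaseModel):
    """A mathematical expression and its numerical answer."""

    expression: str = Field(description="The mathematical expression to evaluate.")
    answer: int = Field(description="The numerical answer to the expression.")

llm = ChatOllama(model="llama3.1")
structured_llm = llm.with_structured_output(MathAnswer, method="json_schema")

result = structured_llm.invoke("What is 3^3?")
print(result.expression, result.answer)
```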
@@ -1267,5 +1273,4 @@ class ChatOllama(BaseChatModel):
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser

@@ -151,12 +151,12 @@ class OllamaEmbeddings(BaseModel, Embeddings):
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
"""

_client: Client = PrivateAttr(default=None) # type: ignore
_client: Optional[Client] = PrivateAttr(default=None)
"""
The client to use for making requests.
"""

_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore
_async_client: Optional[AsyncClient] = PrivateAttr(default=None)
"""
The async client to use for making requests.
"""
@@ -270,6 +270,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):

def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
if not self._client:
raise ValueError(
"Ollama client is not initialized. "
"Please ensure Ollama is running and the model is loaded."
)
embedded_docs = self._client.embed(
self.model, texts, options=self._default_params, keep_alive=self.keep_alive
)["embeddings"]
@@ -281,6 +286,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):

async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
if not self._async_client:
raise ValueError(
"Ollama client is not initialized. "
"Please ensure Ollama is running and the model is loaded."
)
embedded_docs = (
await self._async_client.embed(
self.model, texts, keep_alive=self.keep_alive
@@ -36,6 +36,20 @@ class OllamaLLM(BaseLLM):
model: str
"""Model name to use."""

reasoning: Optional[bool] = True
"""Controls the reasoning/thinking mode for
`supported models <https://ollama.com/search?c=thinking>`__.

- ``True``: Enables reasoning mode. The model's reasoning process will be
captured and returned separately in the ``additional_kwargs`` of the
response message, under ``reasoning_content``. The main response
content will not include the reasoning tags.
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
and the response will not include any reasoning content.
- ``None`` (Default): The model will use its default reasoning behavior. If
the model performs reasoning, the ``<think>`` and ``</think>`` tags will
be present directly within the main response content."""

validate_model_on_init: bool = False
"""Whether to validate the model exists in ollama locally on initialization."""

@@ -137,12 +151,12 @@ class OllamaLLM(BaseLLM):
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
"""

_client: Client = PrivateAttr(default=None) # type: ignore
_client: Optional[Client] = PrivateAttr(default=None)
"""
The client to use for making requests.
"""

_async_client: AsyncClient = PrivateAttr(default=None) # type: ignore
_async_client: Optional[AsyncClient] = PrivateAttr(default=None)
"""
The async client to use for making requests.
"""
@@ -155,7 +169,7 @@ class OllamaLLM(BaseLLM):
) -> dict[str, Any]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
if self.stop is not None:
stop = self.stop

options_dict = kwargs.pop(
@@ -183,6 +197,7 @@ class OllamaLLM(BaseLLM):
"prompt": prompt,
"stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model),
"think": kwargs.pop("reasoning", self.reasoning),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
@@ -230,10 +245,11 @@ class OllamaLLM(BaseLLM):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
if self._async_client:
async for part in await self._async_client.generate(
**self._generate_params(prompt, stop=stop, **kwargs)
): # type: ignore
yield part # type: ignore
):
yield part

def _create_generate_stream(
self,
@@ -241,9 +257,10 @@ class OllamaLLM(BaseLLM):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
if self._client:
yield from self._client.generate(
**self._generate_params(prompt, stop=stop, **kwargs)
) # type: ignore
)

async def _astream_with_aggregation(
self,
@@ -356,11 +373,19 @@ class OllamaLLM(BaseLLM):
) -> Iterator[GenerationChunk]:
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if not isinstance(stream_resp, str):
additional_kwargs = {}
if thinking_content := stream_resp.get("thinking"):
additional_kwargs["reasoning_content"] = thinking_content

chunk = GenerationChunk(
text=(stream_resp.get("response", "")),
generation_info=(
dict(stream_resp) if stream_resp.get("done") is True else None
generation_info={
"finish_reason": self.stop,
**additional_kwargs,
**(
dict(stream_resp) if stream_resp.get("done") is True else {}
),
},
)
if run_manager:
run_manager.on_llm_new_token(
@@ -378,11 +403,19 @@ class OllamaLLM(BaseLLM):
) -> AsyncIterator[GenerationChunk]:
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if not isinstance(stream_resp, str):
additional_kwargs = {}
if thinking_content := stream_resp.get("thinking"):
additional_kwargs["reasoning_content"] = thinking_content

chunk = GenerationChunk(
text=(stream_resp.get("response", "")),
generation_info=(
dict(stream_resp) if stream_resp.get("done") is True else None
generation_info={
"finish_reason": self.stop,
**additional_kwargs,
**(
dict(stream_resp) if stream_resp.get("done") is True else {}
),
},
)
if run_manager:
await run_manager.on_llm_new_token(
@@ -1,17 +1,28 @@
"""Ollama specific chat model integration tests for reasoning models."""

import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk, HumanMessage
from pydantic import ValidationError
from langchain_core.messages import (
AIMessageChunk,
BaseMessageChunk,
HumanMessage,
)
from pydantic import BaseModel, Field

from langchain_ollama import ChatOllama

SAMPLE = "What is 3^3?"


class MathAnswer(BaseModel):
"""A mathematical expression and its numerical answer."""

expression: str = Field(description="The mathematical expression to evaluate.")
answer: int = Field(description="The numerical answer to the expression.")


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
"""Test deepseek model without parsing."""
def test_stream_no_reasoning(model: str) -> None:
"""Test streaming with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
messages = [
{
@@ -28,14 +39,41 @@ def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" in result.content and "</think>" in result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_stream_bool(model: str) -> None:
"""Test deepseek model with reasoning bool=True"""
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
async def test_astream_no_reasoning(model: str) -> None:
"""Test async streaming with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_reasoning_none(model: str) -> None:
"""Test streaming with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
messages = [
{
"role": "user",
@@ -51,26 +89,41 @@ def test_deepseek_messages_stream_bool(model: str) -> None:
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content and "</think>" in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_stream_tuple(model: str) -> None:
"""Test deepseek model with reasoning with tuple=..."""
llm = ChatOllama(
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
)
async def test_astream_reasoning_none(model: str) -> None:
"""Test async streaming with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content and "</think>" in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_stream(model: str) -> None:
"""Test streaming with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
@@ -86,77 +139,114 @@ def test_deepseek_messages_stream_tuple(model: str) -> None:
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_invoke_no_reasoning(model: str) -> None:
"""Test deepseek model without parsing using invoke."""
async def test_reasoning_astream(model: str) -> None:
"""Test async streaming with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_no_reasoning(model: str) -> None:
"""Test using invoke with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
assert result.content
assert "<think>" in result.content and "</think>" in result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_invoke_bool(model: str) -> None:
"""Test deepseek model with reasoning bool=True using invoke"""
llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
async def test_ainvoke_no_reasoning(model: str) -> None:
"""Test using async invoke with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_reasoning_none(model: str) -> None:
"""Test using invoke with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content and "</think>" in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_messages_invoke_tuple(model: str) -> None:
"""Test deepseek model with reasoning with tuple=... using invoke"""
llm = ChatOllama(
model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
)
async def test_ainvoke_reasoning_none(model: str) -> None:
"""Test using async invoke with `reasoning=None`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None)
message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" not in result.additional_kwargs
assert "<think>" in result.content and "</think>" in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_invoke(model: str) -> None:
"""Test invoke with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
assert result.content
assert "<think>" not in result.content and "</think>" not in result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" in result.additional_kwargs["reasoning_content"]
assert "</think>" in result.additional_kwargs["reasoning_content"]
clean_content = (
result.additional_kwargs["reasoning_content"]
.replace("<think>", "")
.replace("</think>", "")
.strip()
)
assert len(clean_content) > 0
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_deepseek_invalid(model: str) -> None:
"""Test deepseek model with reasoning raises ValidationError"""
with pytest.raises(ValidationError):
_ = ChatOllama(model=model, extract_reasoning={"invalid": "data"}) # type: ignore[arg-type]
async def test_reasoning_ainvoke(model: str) -> None:
"""Test invoke with `reasoning=True`"""
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True)
message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
import pytest
from httpx import ConnectError
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolCallChunk
from langchain_core.tools import tool
from langchain_tests.integration_tests import ChatModelIntegrationTests
from ollama import ResponseError
from pydantic import ValidationError
@@ -14,6 +16,15 @@ from langchain_ollama.chat_models import ChatOllama
DEFAULT_MODEL_NAME = "llama3.1"


@tool
def get_current_weather(location: str) -> dict:
"""Gets the current weather in a given location."""
if "boston" in location.lower():
return {"temperature": "15°F", "conditions": "snow"}
else:
return {"temperature": "unknown", "conditions": "unknown"}


class TestChatOllama(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> type[ChatOllama]:
@@ -29,12 +40,104 @@ class TestChatOllama(ChatModelIntegrationTests):

@property
def has_tool_choice(self) -> bool:
return False # TODO: update after Ollama implements
# TODO: update after Ollama implements
# https://github.com/ollama/ollama/blob/main/docs/openai.md
return False

@property
def supports_image_inputs(self) -> bool:
return True

def test_tool_streaming(self, model: BaseChatModel) -> None:
"""Test that the model can stream tool calls."""
chat_model_with_tools = model.bind_tools([get_current_weather])

prompt = [HumanMessage("What is the weather today in Boston?")]

# Flags and collectors for validation
tool_chunk_found = False
final_tool_calls = []
collected_tool_chunks: list[ToolCallChunk] = []

# Stream the response and inspect the chunks
for chunk in chat_model_with_tools.stream(prompt):
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"

if chunk.tool_call_chunks:
tool_chunk_found = True
for tc_chunk in chunk.tool_call_chunks:
collected_tool_chunks.append(tc_chunk)

if chunk.tool_calls:
final_tool_calls.extend(chunk.tool_calls)

assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
assert (
len(final_tool_calls) == 1
), f"Expected 1 final tool call, but got {len(final_tool_calls)}"

final_tool_call = final_tool_calls[0]
assert final_tool_call["name"] == "get_current_weather"
assert final_tool_call["args"] == {"location": "Boston"}

assert len(collected_tool_chunks) > 0
assert collected_tool_chunks[0]["name"] == "get_current_weather"

# The ID should be consistent across chunks that have it
tool_call_id = collected_tool_chunks[0].get("id")
assert tool_call_id is not None
assert all(
chunk.get("id") == tool_call_id
for chunk in collected_tool_chunks
if chunk.get("id")
)
assert final_tool_call["id"] == tool_call_id

async def test_tool_astreaming(self, model: BaseChatModel) -> None:
"""Test that the model can stream tool calls."""
chat_model_with_tools = model.bind_tools([get_current_weather])

prompt = [HumanMessage("What is the weather today in Boston?")]

# Flags and collectors for validation
tool_chunk_found = False
final_tool_calls = []
collected_tool_chunks: list[ToolCallChunk] = []

# Stream the response and inspect the chunks
async for chunk in chat_model_with_tools.astream(prompt):
assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"

if chunk.tool_call_chunks:
tool_chunk_found = True
for tc_chunk in chunk.tool_call_chunks:
collected_tool_chunks.append(tc_chunk)

if chunk.tool_calls:
final_tool_calls.extend(chunk.tool_calls)

assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
assert (
len(final_tool_calls) == 1
), f"Expected 1 final tool call, but got {len(final_tool_calls)}"

final_tool_call = final_tool_calls[0]
assert final_tool_call["name"] == "get_current_weather"
assert final_tool_call["args"] == {"location": "Boston"}

assert len(collected_tool_chunks) > 0
assert collected_tool_chunks[0]["name"] == "get_current_weather"

# The ID should be consistent across chunks that have it
tool_call_id = collected_tool_chunks[0].get("id")
assert tool_call_id is not None
assert all(
chunk.get("id") == tool_call_id
for chunk in collected_tool_chunks
if chunk.get("id")
)
assert final_tool_call["id"] == tool_call_id

@pytest.mark.xfail(
reason=(
"Will sometime encounter AssertionErrors where tool responses are "
@@ -1,11 +1,15 @@
"""Test OllamaLLM llm."""

import pytest
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from langchain_core.runnables import RunnableConfig

from langchain_ollama.llms import OllamaLLM

MODEL_NAME = "llama3.1"

SAMPLE = "What is 3^3?"


def test_stream() -> None:
"""Test streaming tokens from OpenAI."""
@@ -15,6 +19,59 @@ def test_stream() -> None:
assert isinstance(token, str)


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_no_reasoning(model: str) -> None:
"""Test streaming with `reasoning=False`"""
llm = OllamaLLM(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs

# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_reasoning_stream(model: str) -> None:
"""Test streaming with `reasoning=True`"""
llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
for chunk in llm.stream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0

# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


async def test_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = OllamaLLM(model=MODEL_NAME)
@@ -23,6 +80,59 @@ async def test_astream() -> None:
assert isinstance(token, str)


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_astream_no_reasoning(model: str) -> None:
"""Test async streaming with `reasoning=False`"""
llm = OllamaLLM(model=model, num_ctx=2**12)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" not in result.additional_kwargs

# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_reasoning_astream(model: str) -> None:
"""Test async streaming with `reasoning=True`"""
llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
messages = [
{
"role": "user",
"content": SAMPLE,
}
]
result = None
async for chunk in llm.astream(messages):
assert isinstance(chunk, BaseMessageChunk)
if result is None:
result = chunk
continue
result += chunk
assert isinstance(result, AIMessageChunk)
assert result.content
assert "reasoning_content" in result.additional_kwargs
assert len(result.additional_kwargs["reasoning_content"]) > 0

# Sanity check the old behavior isn't present
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]


async def test_abatch() -> None:
"""Test streaming tokens from OllamaLLM."""
llm = OllamaLLM(model=MODEL_NAME)
@@ -60,8 +170,68 @@ async def test_ainvoke() -> None:
assert isinstance(result, str)


# TODO
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# async def test_ainvoke_no_reasoning(model: str) -> None:
# """Test using async invoke with `reasoning=False`"""
# llm = OllamaLLM(model=model, num_ctx=2**12)
# message = SAMPLE
# result = await llm.ainvoke(message)
# assert result.content
# assert "reasoning_content" not in result.additional_kwargs

# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content


# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# async def test_reasoning_ainvoke(model: str) -> None:
# """Test invoke with `reasoning=True`"""
# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
# message = SAMPLE
# result = await llm.ainvoke(message)
# assert result.content
# assert "reasoning_content" in result.additional_kwargs
# assert len(result.additional_kwargs["reasoning_content"]) > 0

# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# assert "<think>" not in result.additional_kwargs["reasoning_content"]
# assert "</think>" not in result.additional_kwargs["reasoning_content"]


def test_invoke() -> None:
"""Test invoke tokens from OllamaLLM."""
llm = OllamaLLM(model=MODEL_NAME)
result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
assert isinstance(result, str)


# TODO
# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# def test_invoke_no_reasoning(model: str) -> None:
# """Test using invoke with `reasoning=False`"""
# llm = OllamaLLM(model=model, num_ctx=2**12)
# message = SAMPLE
# result = llm.invoke(message)
# assert result.content
# assert "reasoning_content" not in result.additional_kwargs

# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content


# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
# def test_reasoning_invoke(model: str) -> None:
# """Test invoke with `reasoning=True`"""
# llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
# message = SAMPLE
# result = llm.invoke(message)
# assert result.content
# assert "reasoning_content" in result.additional_kwargs
# assert len(result.additional_kwargs["reasoning_content"]) > 0

# # Sanity check the old behavior isn't present
# assert "<think>" not in result.content and "</think>" not in result.content
# assert "<think>" not in result.additional_kwargs["reasoning_content"]
# assert "</think>" not in result.additional_kwargs["reasoning_content"]
@@ -23,7 +23,7 @@ class TestChatOllama(ChatModelUnitTests):

@property
def chat_model_params(self) -> dict:
return {"model": "llama3-groq-tool-use"}
return {"model": MODEL_NAME}


def test__parse_arguments_from_tool_call() -> None:
@@ -51,7 +51,6 @@ def test_arbitrary_roles_accepted_in_chatmessages(
monkeypatch.setattr(Client, "stream", _mock_httpx_client_stream)

llm = ChatOllama(
base_url="http://whocares:11434",
model=MODEL_NAME,
verbose=True,
format=None,
@@ -10,7 +10,7 @@ MODEL_NAME = "llama3.1"

def test_initialization() -> None:
"""Test embedding model initialization."""
OllamaEmbeddings(model="llama3", keep_alive=1)
OllamaEmbeddings(model=MODEL_NAME, keep_alive=1)


@patch("langchain_ollama.embeddings.validate_model")
@@ -10,25 +10,25 @@ MODEL_NAME = "llama3.1"

def test_initialization() -> None:
"""Test integration initialization."""
OllamaLLM(model="llama3")
OllamaLLM(model=MODEL_NAME)


def test_model_params() -> None:
# Test standard tracing params
llm = OllamaLLM(model="llama3")
llm = OllamaLLM(model=MODEL_NAME)
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "ollama",
"ls_model_type": "llm",
"ls_model_name": "llama3",
"ls_model_name": MODEL_NAME,
}

llm = OllamaLLM(model="llama3", num_predict=3)
llm = OllamaLLM(model=MODEL_NAME, num_predict=3)
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "ollama",
"ls_model_type": "llm",
"ls_model_name": "llama3",
"ls_model_name": MODEL_NAME,
"ls_max_tokens": 3,
}