From 9c4e6124b67efd83c4142f03f289fdd46fbdd716 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Mon, 4 Aug 2025 14:18:43 -0400 Subject: [PATCH] updates --- .../ollama/langchain_ollama/chat_models.py | 63 +- .../ollama/langchain_ollama/chat_models_v1.py | 647 ++++++++++++++++-- 2 files changed, 608 insertions(+), 102 deletions(-) diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py index b9bff6b1652..446856b8d9f 100644 --- a/libs/partners/ollama/langchain_ollama/chat_models.py +++ b/libs/partners/ollama/langchain_ollama/chat_models.py @@ -449,26 +449,26 @@ class ChatOllama(BaseChatModel): mirostat: Optional[int] = None """Enable Mirostat sampling for controlling perplexity. - (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)""" + (default: ``0``, ``0`` = disabled, ``1`` = Mirostat, ``2`` = Mirostat 2.0)""" mirostat_eta: Optional[float] = None """Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make - the algorithm more responsive. (Default: 0.1)""" + the algorithm more responsive. (Default: ``0.1``)""" mirostat_tau: Optional[float] = None """Controls the balance between coherence and diversity of the output. A lower value will result in more focused and - coherent text. (Default: 5.0)""" + coherent text. (Default: ``5.0``)""" num_ctx: Optional[int] = None """Sets the size of the context window used to generate the - next token. (Default: 2048) """ + next token. (Default: ``2048``) """ num_gpu: Optional[int] = None - """The number of GPUs to use. On macOS it defaults to 1 to - enable metal support, 0 to disable.""" + """The number of GPUs to use. On macOS it defaults to ``1`` to + enable metal support, ``0`` to disable.""" num_thread: Optional[int] = None """Sets the number of threads to use during computation. @@ -478,20 +478,20 @@ class ChatOllama(BaseChatModel): num_predict: Optional[int] = None """Maximum number of tokens to predict when generating text. - (Default: 128, -1 = infinite generation, -2 = fill context)""" + (Default: ``128``, ``-1`` = infinite generation, ``-2`` = fill context)""" repeat_last_n: Optional[int] = None """Sets how far back for the model to look back to prevent - repetition. (Default: 64, 0 = disabled, -1 = num_ctx)""" + repetition. (Default: ``64``, ``0`` = disabled, ``-1`` = num_ctx)""" repeat_penalty: Optional[float] = None - """Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) - will penalize repetitions more strongly, while a lower value (e.g., 0.9) - will be more lenient. (Default: 1.1)""" + """Sets how strongly to penalize repetitions. A higher value (e.g., ``1.5``) + will penalize repetitions more strongly, while a lower value (e.g., ``0.9``) + will be more lenient. (Default: ``1.1``)""" temperature: Optional[float] = None """The temperature of the model. Increasing the temperature will - make the model answer more creatively. (Default: 0.8)""" + make the model answer more creatively. (Default: ``0.8``)""" seed: Optional[int] = None """Sets the random number seed to use for generation. Setting this @@ -503,21 +503,21 @@ class ChatOllama(BaseChatModel): tfs_z: Optional[float] = None """Tail free sampling is used to reduce the impact of less probable - tokens from the output. A higher value (e.g., 2.0) will reduce the - impact more, while a value of 1.0 disables this setting. (default: 1)""" + tokens from the output. 
A higher value (e.g., ``2.0``) will reduce the + impact more, while a value of ``1.0`` disables this setting. (default: ``1``)""" top_k: Optional[int] = None - """Reduces the probability of generating nonsense. A higher value (e.g. 100) - will give more diverse answers, while a lower value (e.g. 10) - will be more conservative. (Default: 40)""" + """Reduces the probability of generating nonsense. A higher value (e.g. ``100``) + will give more diverse answers, while a lower value (e.g. ``10``) + will be more conservative. (Default: ``40``)""" top_p: Optional[float] = None - """Works together with top-k. A higher value (e.g., 0.95) will lead - to more diverse text, while a lower value (e.g., 0.5) will - generate more focused and conservative text. (Default: 0.9)""" + """Works together with top-k. A higher value (e.g., ``0.95``) will lead + to more diverse text, while a lower value (e.g., ``0.5``) will + generate more focused and conservative text. (Default: ``0.9``)""" format: Optional[Union[Literal["", "json"], JsonSchemaValue]] = None - """Specify the format of the output (options: "json", JSON schema).""" + """Specify the format of the output (options: ``'json'``, JSON schema).""" keep_alive: Optional[Union[int, str]] = None """How long the model will stay loaded into memory.""" @@ -527,32 +527,35 @@ class ChatOllama(BaseChatModel): client_kwargs: Optional[dict] = {} """Additional kwargs to pass to the httpx clients. + These arguments are passed to both synchronous and async clients. - Use sync_client_kwargs and async_client_kwargs to pass different arguments + + Use ``sync_client_kwargs`` and ``async_client_kwargs`` to pass different arguments to synchronous and asynchronous clients. + """ async_client_kwargs: Optional[dict] = {} - """Additional kwargs to merge with client_kwargs before + """Additional kwargs to merge with ``client_kwargs`` before passing to the httpx AsyncClient. + `Full list of params. `__ + """ sync_client_kwargs: Optional[dict] = {} - """Additional kwargs to merge with client_kwargs before + """Additional kwargs to merge with ``client_kwargs`` before passing to the httpx Client. + `Full list of params. `__ + """ _client: Client = PrivateAttr() - """ - The client to use for making requests. - """ + """The client to use for making requests.""" _async_client: AsyncClient = PrivateAttr() - """ - The async client to use for making requests. - """ + """The async client to use for making requests.""" def _chat_params( self, diff --git a/libs/partners/ollama/langchain_ollama/chat_models_v1.py b/libs/partners/ollama/langchain_ollama/chat_models_v1.py index 28024f5150a..a53cad68d55 100644 --- a/libs/partners/ollama/langchain_ollama/chat_models_v1.py +++ b/libs/partners/ollama/langchain_ollama/chat_models_v1.py @@ -8,26 +8,45 @@ content blocks and always returns AIMessageV1 format responses. 
from __future__ import annotations +import ast +import json import logging from collections.abc import AsyncIterator, Iterator, Sequence -from typing import Any, Callable, Literal, Optional, Union +from operator import itemgetter +from typing import Any, Callable, Literal, Optional, Union, cast from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun +from langchain_core.exceptions import OutputParserException from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import LangSmithParams -from langchain_core.language_models.v1.chat_models import BaseChatModelV1 +from langchain_core.language_models.v1.chat_models import ( + BaseChatModelV1, + agenerate_from_stream, + generate_from_stream, +) from langchain_core.messages.ai import UsageMetadata from langchain_core.messages.v1 import AIMessage as AIMessageV1 from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 from langchain_core.messages.v1 import MessageV1 -from langchain_core.runnables import Runnable +from langchain_core.output_parsers import ( + JsonOutputKeyToolsParser, + JsonOutputParser, + PydanticOutputParser, + PydanticToolsParser, +) +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool -from langchain_core.utils.function_calling import convert_to_openai_tool +from langchain_core.utils.function_calling import ( + convert_to_json_schema, + convert_to_openai_tool, +) +from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass from ollama import AsyncClient, Client, Options -from pydantic import PrivateAttr, model_validator +from pydantic import BaseModel, PrivateAttr, model_validator from pydantic.json_schema import JsonSchemaValue -from typing_extensions import Self +from pydantic.v1 import BaseModel as BaseModelV1 +from typing_extensions import Self, is_typeddict from ._compat import ( _convert_chunk_to_v1, @@ -54,8 +73,236 @@ def _get_usage_metadata_from_response( return None -class BaseChatOllamaV1(BaseChatModelV1): - """Base class for Ollama v1 chat models.""" +def _parse_json_string( + json_string: str, + *, + raw_tool_call: dict[str, Any], + skip: bool, +) -> Any: + """Attempt to parse a JSON string for tool calling. + + It first tries to use the standard ``json.loads``. If that fails, it falls + back to ``ast.literal_eval`` to safely parse Python literals, which is more + robust against models using single quotes or containing apostrophes. + + Args: + json_string: JSON string to parse. + raw_tool_call: Raw tool call to include in error message. + skip: Whether to ignore parsing errors and return the value anyways. + + Returns: + The parsed JSON string or Python literal. + + Raises: + OutputParserException: If the string is invalid and ``skip=False``. + + """ + try: + return json.loads(json_string) + except json.JSONDecodeError: + try: + # Use ast.literal_eval to safely parse Python-style dicts + # (e.g. with single quotes) + return ast.literal_eval(json_string) + except (SyntaxError, ValueError) as e: + # If both fail, and we're not skipping, raise an informative error. + if skip: + return json_string + msg = ( + f"Function {raw_tool_call['function']['name']} arguments:\n\n" + f"{raw_tool_call['function']['arguments']}" + "\n\nare not valid JSON or a Python literal. 
" + f"Received error: {e}" + ) + raise OutputParserException(msg) from e + except TypeError as e: + if skip: + return json_string + msg = ( + f"Function {raw_tool_call['function']['name']} arguments:\n\n" + f"{raw_tool_call['function']['arguments']}\n\nare not a string or a " + f"dictionary. Received TypeError {e}" + ) + raise OutputParserException(msg) from e + + +def _parse_arguments_from_tool_call( + raw_tool_call: dict[str, Any], +) -> Optional[dict[str, Any]]: + """Parse arguments by trying to parse any shallowly nested string-encoded JSON. + + Band-aid fix for issue in Ollama with inconsistent tool call argument structure. + Should be removed/changed if fixed upstream. + + See https://github.com/ollama/ollama/issues/6155 + + """ + if "function" not in raw_tool_call: + return None + arguments = raw_tool_call["function"]["arguments"] + parsed_arguments: dict = {} + if isinstance(arguments, dict): + for key, value in arguments.items(): + if isinstance(value, str): + parsed_value = _parse_json_string( + value, skip=True, raw_tool_call=raw_tool_call + ) + if isinstance(parsed_value, (dict, list)): + parsed_arguments[key] = parsed_value + else: + parsed_arguments[key] = value + else: + parsed_arguments[key] = value + else: + parsed_arguments = _parse_json_string( + arguments, skip=False, raw_tool_call=raw_tool_call + ) + return parsed_arguments + + +# Missing: _get_tool_calls_from_response + + +# Missing: _lc_tool_call_to_openai_tool_call + + +# Missing: _get_image_from_data_content_block + + +def _is_pydantic_class(obj: Any) -> bool: + return isinstance(obj, type) and is_basemodel_subclass(obj) + + +class ChatOllamaV1(BaseChatModelV1): + r"""Ollama chat model with native v1 message/content block support. + + This implementation provides native support for structured content blocks + and always returns AIMessageV1 format responses with structured content. + + .. dropdown:: Setup + :open: + + Install ``langchain-ollama`` and download any models you want to use from ollama. + + .. code-block:: bash + + ollama pull mistral:v0.3 + pip install -U langchain-ollama + + Key init args — completion params: + model: str + Name of Ollama model to use. + reasoning: Optional[bool] + Controls the reasoning/thinking mode for + `supported models `__. + + - ``True``: Enables reasoning mode. The model's reasoning process will be + captured and returned as a ``ReasoningContentBlock`` in the response + message content. The main response content will not include the reasoning tags. + - ``False``: Disables reasoning mode. The model will not perform any reasoning, + and the response will not include any reasoning content. + - ``None`` (Default): The model will use its default reasoning behavior. Note + however, if the model's default behavior *is* to perform reasoning, think tags + (```` and ````) will be present within the main response ``TextContentBlock``s + unless you set ``reasoning`` to ``True``. + temperature: float + Sampling temperature. Ranges from ``0.0`` to ``1.0``. + num_predict: Optional[int] + Max number of tokens to generate. + + See full list of supported init args and their descriptions in the params section. + + Instantiate: + .. code-block:: python + + from langchain_ollama import ChatOllamaV1 + + llm = ChatOllamaV1( + model = "llama3", + temperature = 0.8, + num_predict = 256, + # other params ... + ) + + Invoke: + .. 
code-block:: python + + from langchain_core.messages.v1 import HumanMessage + from langchain_core.messages.content_blocks import TextContentBlock + + messages = [ + HumanMessage(content=[ + TextContentBlock(type="text", text="Hello!") + ]) + ] + llm.invoke(messages) + + .. code-block:: python + + AIMessageV1(content=[{'type': 'text', 'text': 'Hello! How can I help you today?'}], response_metadata={'model': 'llama3', 'created_at': '2024-07-04T03:37:50.182604Z', 'done_reason': 'stop', 'done': True, 'total_duration': 3576619666, 'load_duration': 788524916, 'prompt_eval_count': 32, 'prompt_eval_duration': 128125000, 'eval_count': 71, 'eval_duration': 2656556000}, id='run-ba48f958-6402-41a5-b461-5e250a4ebd36-0') + + Stream: + .. code-block:: python + + from langchain_core.messages.v1 import HumanMessage + from langchain_core.messages.content_blocks import TextContentBlock + + messages = [ + HumanMessage(content=[ + TextContentBlock(type="text", text="Return the words Hello World!") + ]) + ] + for chunk in llm.stream(messages): + print(chunk.content, end="") + + .. code-block:: python + + [{'type': 'text', 'text': 'Hello'}] + [{'type': 'text', 'text': ' World'}] + [{'type': 'text', 'text': '!'}] + + Multi-modal input: + .. code-block:: python + + from langchain_core.messages.content_blocks import ImageContentBlock + + response = llm.invoke([ + HumanMessage(content=[ + TextContentBlock(type="text", text="Describe this image:"), + ImageContentBlock( + type="image", + mime_type="image/jpeg", + base64="base64_encoded_image", + ) + ]) + ]) + + Tool Calling: + .. code-block:: python + + from langchain_ollama import ChatOllamaV1 + from pydantic import BaseModel, Field + + class Multiply(BaseModel): + a: int = Field(..., description="First integer") + b: int = Field(..., description="Second integer") + + llm_with_tools = llm.bind_tools([Multiply]) + ans = llm_with_tools.invoke([ + HumanMessage(content=[ + TextContentBlock(type="text", text="What is 45*67") + ]) + ]) + ans.tool_calls + + .. code-block:: python + + [{'name': 'Multiply', + 'args': {'a': 45, 'b': 67}, + 'id': '420c3f3b-df10-4188-945f-eb3abdb40622', + 'type': 'tool_call'}] + + """ # noqa: E501, pylint: disable=line-too-long model: str """Model name to use.""" @@ -79,52 +326,76 @@ class BaseChatOllamaV1(BaseChatModelV1): # Ollama-specific parameters mirostat: Optional[int] = None - """Enable Mirostat sampling for controlling perplexity.""" + """Enable Mirostat sampling for controlling perplexity. + (default: ``0``, ``0`` = disabled, ``1`` = Mirostat, ``2`` = Mirostat 2.0)""" mirostat_eta: Optional[float] = None - """Influences how quickly the algorithm responds to feedback.""" + """Influences how quickly the algorithm responds to feedback + from the generated text. A lower learning rate will result in + slower adjustments, while a higher learning rate will make + the algorithm more responsive. (Default: ``0.1``)""" mirostat_tau: Optional[float] = None - """Controls the balance between coherence and diversity.""" + """Controls the balance between coherence and diversity + of the output. A lower value will result in more focused and + coherent text. (Default: ``5.0``)""" num_ctx: Optional[int] = None - """Sets the size of the context window.""" + """Sets the size of the context window used to generate the + next token. (Default: ``2048``) """ num_gpu: Optional[int] = None - """The number of GPUs to use.""" + """The number of GPUs to use. 
On macOS it defaults to ``1`` to + enable metal support, ``0`` to disable.""" num_thread: Optional[int] = None - """Sets the number of threads to use during computation.""" + """Sets the number of threads to use during computation. + By default, Ollama will detect this for optimal performance. + It is recommended to set this value to the number of physical + CPU cores your system has (as opposed to the logical number of cores).""" num_predict: Optional[int] = None - """Maximum number of tokens to predict.""" + """Maximum number of tokens to predict when generating text. + (Default: ``128``, ``-1`` = infinite generation, ``-2`` = fill context)""" repeat_last_n: Optional[int] = None - """Sets how far back for the model to look back to prevent repetition.""" + """Sets how far back for the model to look back to prevent + repetition. (Default: ``64``, ``0`` = disabled, ``-1`` = ``num_ctx``)""" repeat_penalty: Optional[float] = None - """Sets how strongly to penalize repetitions.""" + """Sets how strongly to penalize repetitions. A higher value (e.g., ``1.5``) + will penalize repetitions more strongly, while a lower value (e.g., ``0.9``) + will be more lenient. (Default: ``1.1``)""" temperature: Optional[float] = None - """The temperature of the model.""" + """The temperature of the model. Increasing the temperature will + make the model answer more creatively. (Default: ``0.8``)""" seed: Optional[int] = None - """Sets the random number seed to use for generation.""" + """Sets the random number seed to use for generation. Setting this + to a specific number will make the model generate the same text for + the same prompt.""" stop: Optional[list[str]] = None """Sets the stop tokens to use.""" tfs_z: Optional[float] = None - """Tail free sampling parameter.""" + """Tail free sampling is used to reduce the impact of less probable + tokens from the output. A higher value (e.g., ``2.0``) will reduce the + impact more, while a value of ``1.0`` disables this setting. (default: ``1``)""" top_k: Optional[int] = None - """Reduces the probability of generating nonsense.""" + """Reduces the probability of generating nonsense. A higher value (e.g. ``100``) + will give more diverse answers, while a lower value (e.g. ``10``) + will be more conservative. (Default: ``40``)""" top_p: Optional[float] = None - """Works together with top-k.""" + """Works together with top-k. A higher value (e.g., ``0.95``) will lead + to more diverse text, while a lower value (e.g., ``0.5``) will + generate more focused and conservative text. (Default: ``0.9``)""" format: Optional[Union[Literal["", "json"], JsonSchemaValue]] = None - """Specify the format of the output.""" + """Specify the format of the output (options: ``'json'``, JSON schema).""" keep_alive: Optional[Union[int, str]] = None """How long the model will stay loaded into memory.""" @@ -133,16 +404,36 @@ class BaseChatOllamaV1(BaseChatModelV1): """Base url the model is hosted under.""" client_kwargs: Optional[dict] = {} - """Additional kwargs to pass to the httpx clients.""" + """Additional kwargs to pass to the httpx clients. + + These arguments are passed to both synchronous and async clients. + + Use ``sync_client_kwargs`` and ``async_client_kwargs`` to pass different arguments + to synchronous and asynchronous clients. + + """ async_client_kwargs: Optional[dict] = {} - """Additional kwargs for the async httpx client.""" + """Additional kwargs to merge with ``client_kwargs`` before + passing to the httpx AsyncClient. + + `Full list of params. 
`__ + + """ sync_client_kwargs: Optional[dict] = {} - """Additional kwargs for the sync httpx client.""" + """Additional kwargs to merge with ``client_kwargs`` before + passing to the httpx Client. + + `Full list of params. `__ + + """ _client: Client = PrivateAttr() + """The client to use for making requests.""" + _async_client: AsyncClient = PrivateAttr() + """The async client to use for making requests.""" @model_validator(mode="after") def _set_clients(self) -> Self: @@ -210,52 +501,6 @@ class BaseChatOllamaV1(BaseChatModelV1): """Return type of chat model.""" return "chat-ollama-v1" - -class ChatOllamaV1(BaseChatOllamaV1): - """Ollama chat model with native v1 content block support. - - This implementation provides native support for structured content blocks - and always returns AIMessageV1 format responses. - - Examples: - Basic text conversation: - - .. code-block:: python - - from langchain_ollama import ChatOllamaV1 - from langchain_core.messages.v1 import HumanMessage - from langchain_core.messages.content_blocks import TextContentBlock - - llm = ChatOllamaV1(model="llama3") - response = llm.invoke([ - HumanMessage(content=[ - TextContentBlock(type="text", text="Hello!") - ]) - ]) - - # Response is always structured - print(response.content) - # [{"type": "text", "text": "Hello! How can I help?"}] - - Multi-modal input: - - .. code-block:: python - - from langchain_core.messages.content_blocks import ImageContentBlock - - response = llm.invoke([ - HumanMessage(content=[ - TextContentBlock(type="text", text="Describe this image:"), - ImageContentBlock( - type="image", - mime_type="image/jpeg", - data="base64_encoded_image", - source_type="base64" - ) - ]) - ]) - """ - def _chat_params( self, messages: list[MessageV1], @@ -427,6 +672,97 @@ class ChatOllamaV1(BaseChatOllamaV1): ) yield chunk + def _invoke( + self, + messages: list[MessageV1], + stop: Optional[list[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AIMessageV1: + """Invoke the model with v1 messages and return a complete response. + + Args: + messages: List of v1 format messages. + stop: List of stop sequences. + run_manager: Callback manager for the run. + kwargs: Additional parameters. + + Returns: + Complete AI message response. + """ + stream_iter = self._generate_stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) + + async def _ainvoke( + self, + messages: list[MessageV1], + stop: Optional[list[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AIMessageV1: + """Async invoke the model with v1 messages and return a complete response. + + Args: + messages: List of v1 format messages. + stop: List of stop sequences. + run_manager: Async callback manager for the run. + kwargs: Additional parameters. + + Returns: + Complete AI message response. + """ + stream_iter = self._agenerate_stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return await agenerate_from_stream(stream_iter) + + def _stream( + self, + messages: list[MessageV1], + stop: Optional[list[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[AIMessageChunkV1]: + """Stream response chunks using the v1 format. + + Args: + messages: List of v1 format messages. + stop: List of stop sequences. + run_manager: Callback manager for the run. + kwargs: Additional parameters. + + Yields: + AI message chunks in v1 format. 
+ """ + yield from self._generate_stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + + async def _astream( + self, + messages: list[MessageV1], + stop: Optional[list[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[AIMessageChunkV1]: + """Async stream response chunks using the v1 format. + + Args: + messages: List of v1 format messages. + stop: List of stop sequences. + run_manager: Async callback manager for the run. + kwargs: Additional parameters. + + Yields: + AI message chunks in v1 format. + """ + async for chunk in self._agenerate_stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ): + yield chunk + def bind_tools( self, tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]], @@ -443,3 +779,170 @@ class ChatOllamaV1(BaseChatOllamaV1): """ formatted_tools = [convert_to_openai_tool(tool) for tool in tools] return super().bind(tools=formatted_tools, **kwargs) + + def with_structured_output( + self, + schema: Union[dict, type], + *, + method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema", + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: + """Model wrapper that returns outputs formatted to match the given schema. + + Args: + schema: The output schema. Can be passed in as: + + - a Pydantic class, + - a JSON schema + - a TypedDict class + - an OpenAI function/tool schema. + + If ``schema`` is a Pydantic class then the model output will be a + Pydantic instance of that class, and the model-generated fields will be + validated by the Pydantic class. Otherwise the model output will be a + dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool` + for more on how to properly specify types and descriptions of + schema fields when specifying a Pydantic or TypedDict class. + + method: The method for steering model generation, one of: + + - ``'json_schema'``: + Uses Ollama's `structured output API `__ + - ``'function_calling'``: + Uses Ollama's tool-calling API + - ``'json_mode'``: + Specifies ``format='json'``. Note that if using JSON mode then you + must include instructions for formatting the output into the + desired schema into the model call. + + include_raw: + If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a ``BaseMessage``) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys ``'raw'``, ``'parsed'``, and ``'parsing_error'``. + + kwargs: Additional keyword args aren't supported. + + Returns: + A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. + + If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict. + + If ``include_raw`` is True, then Runnable outputs a dict with keys: + + - ``'raw'``: ``BaseMessage`` + - ``'parsed'``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. + - ``'parsing_error'``: Optional[BaseException] + + .. versionchanged:: 0.2.2 + + Added support for structured output API via ``format`` parameter. + + .. 
versionchanged:: 0.3.0 + + Updated default ``method`` to ``'json_schema'``. + + """ # noqa: E501 + _ = kwargs.pop("strict", None) + if kwargs: + msg = f"Received unsupported arguments {kwargs}" + raise ValueError(msg) + is_pydantic_schema = _is_pydantic_class(schema) + if method == "function_calling": + if schema is None: + msg = ( + "schema must be specified when method is not 'json_mode'. " + "Received None." + ) + raise ValueError(msg) + formatted_tool = convert_to_openai_tool(schema) + tool_name = formatted_tool["function"]["name"] + llm = self.bind_tools( + [schema], + tool_choice=tool_name, + ls_structured_output_format={ + "kwargs": {"method": method}, + "schema": formatted_tool, + }, + ) + if is_pydantic_schema: + output_parser: Runnable = PydanticToolsParser( + tools=[schema], # type: ignore[list-item] + first_tool_only=True, + ) + else: + output_parser = JsonOutputKeyToolsParser( + key_name=tool_name, first_tool_only=True + ) + elif method == "json_mode": + llm = self.bind( + format="json", + ls_structured_output_format={ + "kwargs": {"method": method}, + "schema": schema, + }, + ) + output_parser = ( + PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] + if is_pydantic_schema + else JsonOutputParser() + ) + elif method == "json_schema": + if schema is None: + msg = ( + "schema must be specified when method is not 'json_mode'. " + "Received None." + ) + raise ValueError(msg) + if is_pydantic_schema: + schema = cast(TypeBaseModel, schema) + if issubclass(schema, BaseModelV1): + response_format = schema.schema() + else: + response_format = schema.model_json_schema() + llm = self.bind( + format=response_format, + ls_structured_output_format={ + "kwargs": {"method": method}, + "schema": schema, + }, + ) + output_parser = PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] + else: + if is_typeddict(schema): + response_format = convert_to_json_schema(schema) + if "required" not in response_format: + response_format["required"] = list( + response_format["properties"].keys() + ) + else: + # is JSON schema + response_format = cast(dict, schema) + llm = self.bind( + format=response_format, + ls_structured_output_format={ + "kwargs": {"method": method}, + "schema": response_format, + }, + ) + output_parser = JsonOutputParser() + else: + msg = ( + f"Unrecognized method argument. Expected one of 'function_calling', " + f"'json_schema', or 'json_mode'. Received: '{method}'" + ) + raise ValueError(msg) + + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + return llm | output_parser
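A minimal usage sketch for the ``with_structured_output`` path added in this patch. It is not part of the diff: the model name, prompt, and expected field values are placeholders, and it assumes a local Ollama server with the model already pulled.

.. code-block:: python

    from langchain_core.messages.content_blocks import TextContentBlock
    from langchain_core.messages.v1 import HumanMessage
    from pydantic import BaseModel, Field

    from langchain_ollama import ChatOllamaV1


    class Multiply(BaseModel):
        """Arguments for a multiplication."""

        a: int = Field(..., description="First integer")
        b: int = Field(..., description="Second integer")


    llm = ChatOllamaV1(model="llama3")  # model name is a placeholder

    # The default method='json_schema' passes the Pydantic model's JSON schema
    # as Ollama's `format` parameter, and the output parser returns a
    # validated Multiply instance.
    structured_llm = llm.with_structured_output(Multiply)

    result = structured_llm.invoke(
        [
            HumanMessage(
                content=[TextContentBlock(type="text", text="What is 45 times 67?")]
            )
        ]
    )
    # Expected (model permitting): Multiply(a=45, b=67)

With ``method='json_mode'`` the schema is not forwarded to the model, so the prompt itself must spell out the desired JSON shape; ``include_raw=True`` instead returns a dict with ``'raw'``, ``'parsed'``, and ``'parsing_error'`` keys.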
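The single-quote fallback in ``_parse_json_string`` can be sanity-checked directly. A rough illustration only: the helper is private, so importing it from ``langchain_ollama.chat_models_v1`` relies on an implementation detail that may change.

.. code-block:: python

    from langchain_ollama.chat_models_v1 import _parse_json_string

    raw_tool_call = {
        "function": {"name": "Multiply", "arguments": "{'a': 45, 'b': 67}"}
    }

    # json.loads rejects the single-quoted payload, so the helper falls back
    # to ast.literal_eval and still returns a dict.
    args = _parse_json_string(
        "{'a': 45, 'b': 67}",
        raw_tool_call=raw_tool_call,
        skip=False,
    )
    assert args == {"a": 45, "b": 67}

With ``skip=True``, genuinely unparseable input is returned unchanged instead of raising ``OutputParserException``.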
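The ``reasoning`` flag described in the ``ChatOllamaV1`` docstring can be exercised as below. This is a sketch, not part of the patch: the model name is a placeholder for any reasoning-capable Ollama model, and it assumes ``ReasoningContentBlock`` entries carry ``type='reasoning'``.

.. code-block:: python

    from langchain_core.messages.content_blocks import TextContentBlock
    from langchain_core.messages.v1 import HumanMessage

    from langchain_ollama import ChatOllamaV1

    llm = ChatOllamaV1(model="deepseek-r1:8b", reasoning=True)  # placeholder model

    msg = llm.invoke(
        [HumanMessage(content=[TextContentBlock(type="text", text="What is 17 * 23?")])]
    )

    # With reasoning=True the thinking trace is returned as separate content
    # blocks, so the visible answer and the reasoning can be split apart.
    reasoning_blocks = [b for b in msg.content if b.get("type") == "reasoning"]
    text_blocks = [b for b in msg.content if b.get("type") == "text"]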