From 46d6bf0330a8553130a6c323251c78b94218c3c0 Mon Sep 17 00:00:00 2001
From: Lance Martin <122662504+rlancemartin@users.noreply.github.com>
Date: Tue, 18 Mar 2025 09:44:22 -0700
Subject: [PATCH] ollama[minor]: update default method for structured output
 (#30273)

From function calling to Ollama's [dedicated structured output
feature](https://ollama.com/blog/structured-outputs).

---------

Co-authored-by: Chester Curme
---
 .../ollama/langchain_ollama/chat_models.py    | 38 ++++++++++---------
 libs/partners/ollama/pyproject.toml           |  2 +-
 .../chat_models/test_chat_models_standard.py  | 18 ---------
 libs/partners/ollama/uv.lock                  |  9 ++---
 4 files changed, 26 insertions(+), 41 deletions(-)

diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py
index e449969025a..94595a1bde5 100644
--- a/libs/partners/ollama/langchain_ollama/chat_models.py
+++ b/libs/partners/ollama/langchain_ollama/chat_models.py
@@ -47,13 +47,14 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import (
-    _convert_any_typed_dicts_to_pydantic as convert_any_typed_dicts_to_pydantic,
+    convert_to_json_schema,
+    convert_to_openai_tool,
 )
-from langchain_core.utils.function_calling import convert_to_openai_tool
 from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
 from ollama import AsyncClient, Client, Message, Options
 from pydantic import BaseModel, PrivateAttr, model_validator
 from pydantic.json_schema import JsonSchemaValue
+from pydantic.v1 import BaseModel as BaseModelV1
 from typing_extensions import Self, is_typeddict
 
 
@@ -831,9 +832,7 @@ class ChatOllama(BaseChatModel):
         self,
         schema: Union[Dict, type],
         *,
-        method: Literal[
-            "function_calling", "json_mode", "json_schema"
-        ] = "function_calling",
+        method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema",
         include_raw: bool = False,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
@@ -857,10 +856,10 @@ class ChatOllama(BaseChatModel):
             method: The method for steering model generation, one of:
 
-                - "function_calling":
-                    Uses Ollama's tool-calling API
                 - "json_schema":
                     Uses Ollama's structured output API:
                     https://ollama.com/blog/structured-outputs
+                - "function_calling":
+                    Uses Ollama's tool-calling API
                 - "json_mode":
                     Specifies ``format="json"``. Note that if using JSON mode then you
                     must include instructions for formatting the output into the
@@ -891,7 +890,11 @@ class ChatOllama(BaseChatModel):
 
                 Added support for structured output API via ``format`` parameter.
 
-        .. dropdown:: Example: schema=Pydantic class, method="function_calling", include_raw=False
+        .. versionchanged:: 0.3.0
+
+            Updated default ``method`` to ``"json_schema"``.
+
+        .. dropdown:: Example: schema=Pydantic class, method="json_schema", include_raw=False
 
             .. code-block:: python
 
@@ -924,7 +927,7 @@ class ChatOllama(BaseChatModel):
                 #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
                 # )
 
-        .. dropdown:: Example: schema=Pydantic class, method="function_calling", include_raw=True
+        .. dropdown:: Example: schema=Pydantic class, method="json_schema", include_raw=True
 
             .. code-block:: python
 
@@ -953,7 +956,7 @@ class ChatOllama(BaseChatModel):
                 #     'parsing_error': None
                 # }
 
-        .. dropdown:: Example: schema=Pydantic class, method="json_schema", include_raw=False
+        .. dropdown:: Example: schema=Pydantic class, method="function_calling", include_raw=False
 
             .. code-block:: python
 
@@ -974,7 +977,7 @@ class ChatOllama(BaseChatModel):
 
                 llm = ChatOllama(model="llama3.1", temperature=0)
                 structured_llm = llm.with_structured_output(
-                    AnswerWithJustification, method="json_schema"
+                    AnswerWithJustification, method="function_calling"
                 )
 
                 structured_llm.invoke(
@@ -1125,8 +1128,12 @@ class ChatOllama(BaseChatModel):
             )
         if is_pydantic_schema:
             schema = cast(TypeBaseModel, schema)
+            if issubclass(schema, BaseModelV1):
+                response_format = schema.schema()
+            else:
+                response_format = schema.model_json_schema()
             llm = self.bind(
-                format=schema.model_json_schema(),
+                format=response_format,
                 ls_structured_output_format={
                     "kwargs": {"method": method},
                     "schema": schema,
@@ -1135,17 +1142,14 @@ class ChatOllama(BaseChatModel):
             output_parser = PydanticOutputParser(pydantic_object=schema)
         else:
             if is_typeddict(schema):
-                schema = cast(type, schema)
-                response_format = convert_any_typed_dicts_to_pydantic(
-                    schema, visited={}
-                ).schema()  # type: ignore[attr-defined]
+                response_format = convert_to_json_schema(schema)
                 if "required" not in response_format:
                     response_format["required"] = list(
                         response_format["properties"].keys()
                     )
             else:
                 # is JSON schema
-                response_format = schema
+                response_format = cast(dict, schema)
             llm = self.bind(
                 format=response_format,
                 ls_structured_output_format={
diff --git a/libs/partners/ollama/pyproject.toml b/libs/partners/ollama/pyproject.toml
index c159c6eb3d1..0ff67a2abbc 100644
--- a/libs/partners/ollama/pyproject.toml
+++ b/libs/partners/ollama/pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "pdm.backend"
 authors = []
 license = { text = "MIT" }
 requires-python = "<4.0,>=3.9"
-dependencies = ["ollama<1,>=0.4.4", "langchain-core<1.0.0,>=0.3.33"]
+dependencies = ["ollama<1,>=0.4.4", "langchain-core<1.0.0,>=0.3.45"]
 name = "langchain-ollama"
 version = "0.2.3"
 description = "An integration package connecting Ollama and LangChain"
diff --git a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
index 5f990f2251b..bc39b1319df 100644
--- a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
+++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
@@ -2,8 +2,6 @@
 
 from typing import Type
 
-import pytest
-from langchain_core.language_models import BaseChatModel
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 
 from langchain_ollama.chat_models import ChatOllama
@@ -25,19 +23,3 @@ class TestChatOllama(ChatModelIntegrationTests):
     @property
     def supports_json_mode(self) -> bool:
         return True
-
-    @pytest.mark.xfail(
-        reason=(
-            "Fails with 'AssertionError'. Ollama does not support 'tool_choice' yet."
-        )
-    )
-    def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
-        super().test_structured_output(model, schema_type)
-
-    @pytest.mark.xfail(
-        reason=(
-            "Fails with 'AssertionError'. Ollama does not support 'tool_choice' yet."
-        )
-    )
-    def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
-        super().test_structured_output_pydantic_2_v1(model)
diff --git a/libs/partners/ollama/uv.lock b/libs/partners/ollama/uv.lock
index f2207769418..3b8d772a1e3 100644
--- a/libs/partners/ollama/uv.lock
+++ b/libs/partners/ollama/uv.lock
@@ -287,7 +287,7 @@ wheels = [
 
 [[package]]
 name = "langchain-core"
-version = "0.3.35"
+version = "0.3.45"
 source = { editable = "../../core" }
 dependencies = [
     { name = "jsonpatch" },
@@ -319,7 +319,7 @@ dev = [
 ]
 lint = [{ name = "ruff", specifier = ">=0.9.2,<1.0.0" }]
 test = [
-    { name = "blockbuster", specifier = "~=1.5.11" },
+    { name = "blockbuster", specifier = "~=1.5.18" },
     { name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
     { name = "grandalf", specifier = ">=0.8,<1.0" },
     { name = "langchain-tests", directory = "../../standard-tests" },
@@ -403,7 +403,7 @@ typing = [
 
 [[package]]
 name = "langchain-tests"
-version = "0.3.11"
+version = "0.3.14"
 source = { editable = "../../standard-tests" }
 dependencies = [
     { name = "httpx" },
@@ -420,8 +420,7 @@ dependencies = [
 requires-dist = [
     { name = "httpx", specifier = ">=0.25.0,<1" },
     { name = "langchain-core", editable = "../../core" },
-    { name = "numpy", marker = "python_full_version < '3.12'", specifier = ">=1.24.0,<2.0.0" },
-    { name = "numpy", marker = "python_full_version >= '3.12'", specifier = ">=1.26.2,<3" },
+    { name = "numpy", specifier = ">=1.26.2,<3" },
     { name = "pytest", specifier = ">=7,<9" },
     { name = "pytest-asyncio", specifier = ">=0.20,<1" },
     { name = "pytest-socket", specifier = ">=0.6.0,<1" },
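A minimal sketch of the behavior after this patch, adapted from the docstring
examples in the diff above. It assumes a running local Ollama server with the
``llama3.1`` model pulled; the prompt string is illustrative only.

.. code-block:: python

    from langchain_ollama import ChatOllama
    from pydantic import BaseModel, Field


    class AnswerWithJustification(BaseModel):
        """An answer to the user question along with justification."""

        answer: str
        justification: str = Field(description="A justification for the answer.")


    llm = ChatOllama(model="llama3.1", temperature=0)

    # `method` now defaults to "json_schema": the Pydantic model's JSON schema
    # is passed to Ollama's structured output API via the `format` parameter
    # instead of being wrapped in a synthesized tool call.
    structured_llm = llm.with_structured_output(AnswerWithJustification)

    result = structured_llm.invoke(
        "What weighs more, a pound of bricks or a pound of feathers?"
    )
    # `result` is an AnswerWithJustification instance; pass
    # method="function_calling" to restore the previous tool-calling behavior.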