core, openai[patch]: support serialization of pydantic models in messages (#29940)
Resolves https://github.com/langchain-ai/langchain/issues/29003 and https://github.com/langchain-ai/langchain/issues/27264.

Related: https://github.com/langchain-ai/langchain-redis/issues/52

Pydantic models returned by structured output can now round-trip through a serializing cache. With a `SQLiteCache` configured, a repeat of an identical call is served from the cache:

```python
from langchain.chat_models import init_chat_model
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from pydantic import BaseModel

cache = SQLiteCache()
set_llm_cache(cache)


class Temperature(BaseModel):
    value: int
    city: str


llm = init_chat_model("openai:gpt-4o-mini")
structured_llm = llm.with_structured_output(Temperature)
```

```python
# First call, cache miss: 681 ms
response = structured_llm.invoke("What is the average temperature of Rome in May?")
```

```python
# Repeat call, cache hit: 6.98 ms
response = structured_llm.invoke("What is the average temperature of Rome in May?")
```
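To make the fix concrete, here is a minimal standalone sketch (plain pydantic, no LangChain required, and not the PR's code) of the round-trip the patch has to handle: a JSON-backed cache stores the parsed model as a plain dict, so on a cache hit the parser must rebuild the model from that dict, mirroring the `schema(**parsed)` branch in the diff below.

```python
from pydantic import BaseModel


class Temperature(BaseModel):
    value: int
    city: str


original = Temperature(value=18, city="Rome")

# A JSON-backed cache can only store a plain dict, not the model instance.
cached = original.model_dump()
assert cached == {"value": 18, "city": "Rome"}

# On a cache hit the parser sees a dict and reconstructs the model against
# the schema; a fresh (uncached) response already carries the instance.
restored = Temperature(**cached) if isinstance(cached, dict) else cached
assert restored == original
```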
```diff
@@ -9,6 +9,7 @@ import os
 import re
 import sys
 import warnings
+from functools import partial
 from io import BytesIO
 from math import ceil
 from operator import itemgetter
@@ -78,7 +79,12 @@ from langchain_core.output_parsers.openai_tools import (
     parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
+from langchain_core.runnables import (
+    Runnable,
+    RunnableLambda,
+    RunnableMap,
+    RunnablePassthrough,
+)
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.tools import BaseTool
 from langchain_core.utils import get_pydantic_field_names
@@ -1436,9 +1442,9 @@ class BaseChatOpenAI(BaseChatModel):
                 },
             )
             if is_pydantic_schema:
-                output_parser = _oai_structured_outputs_parser.with_types(
-                    output_type=cast(type, schema)
-                )
+                output_parser = RunnableLambda(
+                    partial(_oai_structured_outputs_parser, schema=cast(type, schema))
+                ).with_types(output_type=cast(type, schema))
             else:
                 output_parser = JsonOutputParser()
         else:
@@ -2517,10 +2523,14 @@ def _convert_to_openai_response_format(
     return response_format
 
 
-@chain
-def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
-    if ai_msg.additional_kwargs.get("parsed"):
-        return ai_msg.additional_kwargs["parsed"]
+def _oai_structured_outputs_parser(
+    ai_msg: AIMessage, schema: Type[_BM]
+) -> PydanticBaseModel:
+    if parsed := ai_msg.additional_kwargs.get("parsed"):
+        if isinstance(parsed, dict):
+            return schema(**parsed)
+        else:
+            return parsed
     elif ai_msg.additional_kwargs.get("refusal"):
         raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
     else:
```
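As a usage note on the wiring above (a sketch under stated assumptions, not the library's private code): once the parser takes a `schema` argument, it can no longer be a bare single-argument `@chain` runnable, so the diff binds the schema with `functools.partial` and wraps the result in `RunnableLambda`. The `parse` function below is a hypothetical stand-in for `_oai_structured_outputs_parser`; `Temperature` is the example schema from the PR description.

```python
from functools import partial

from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda
from pydantic import BaseModel


class Temperature(BaseModel):
    value: int
    city: str


def parse(ai_msg: AIMessage, schema: type[BaseModel]) -> BaseModel:
    # Same branch structure as the updated parser in the diff.
    parsed = ai_msg.additional_kwargs.get("parsed")
    if isinstance(parsed, dict):  # deserialized from a cache
        return schema(**parsed)
    return parsed  # fresh response: already a model instance


# Bind the schema first, then wrap as a single-argument runnable.
parser = RunnableLambda(partial(parse, schema=Temperature)).with_types(
    output_type=Temperature
)

msg = AIMessage(
    content="", additional_kwargs={"parsed": {"value": 18, "city": "Rome"}}
)
print(parser.invoke(msg))  # value=18 city='Rome'
```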