fix openai structured chain with pydantic (#7622)

The chain should return an instance of the given pydantic class rather than a plain dict.
Bagatur 2023-07-12 23:46:13 -04:00 committed by GitHub
parent ee70d4a0cd
commit 1d4db1327a


@@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.chains import LLMChain
 from langchain.output_parsers.openai_functions import (
     JsonOutputFunctionsParser,
+    PydanticAttrOutputFunctionsParser,
     PydanticOutputFunctionsParser,
 )
 from langchain.prompts import BasePromptTemplate
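
The newly imported PydanticAttrOutputFunctionsParser parses the model's function-call arguments into a pydantic schema and then returns a single attribute of the resulting object, which is how the chain can hand back the caller's model directly. A rough sketch of that idea, assuming pydantic v1 as used by langchain at this commit (illustrative only, not the library's actual implementation):

from typing import Type

from pydantic import BaseModel


class _AttrParserSketch:
    """Illustrative stand-in for the idea behind PydanticAttrOutputFunctionsParser."""

    def __init__(self, pydantic_schema: Type[BaseModel], attr_name: str) -> None:
        self.pydantic_schema = pydantic_schema
        self.attr_name = attr_name

    def parse(self, function_call_arguments: str):
        # Validate the raw JSON arguments against the wrapper schema, then
        # return only the requested attribute (e.g. the user's model instance).
        obj = self.pydantic_schema.parse_raw(function_call_arguments)
        return getattr(obj, self.attr_name)
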
@@ -318,17 +319,26 @@ def create_structured_output_chain(
             chain.run("Harry was a chubby brown beagle who loved chicken")
             # -> Dog(name="Harry", color="brown", fav_food="chicken")
     """  # noqa: E501
-    function: Dict = {
-        "name": "output_formatter",
-        "description": (
-            "Output formatter. Should always be used to format your response to the"
-            " user."
-        ),
-    }
-    parameters = (
-        output_schema if isinstance(output_schema, dict) else output_schema.schema()
-    )
-    function["parameters"] = parameters
+    if isinstance(output_schema, dict):
+        function: Any = {
+            "name": "output_formatter",
+            "description": (
+                "Output formatter. Should always be used to format your response to the"
+                " user."
+            ),
+            "parameters": output_schema,
+        }
+    else:
+
+        class _OutputFormatter(BaseModel):
+            """Output formatter. Should always be used to format your response to the user."""  # noqa: E501
+
+            output: output_schema  # type: ignore
+
+        function = _OutputFormatter
+        output_parser = output_parser or PydanticAttrOutputFunctionsParser(
+            pydantic_schema=_OutputFormatter, attr_name="output"
+        )
     return create_openai_fn_chain(
         [function], llm, prompt, output_parser=output_parser, **kwargs
     )
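
With this change, passing a pydantic model as output_schema makes the chain return an instance of that model instead of a dict of parsed arguments. A minimal usage sketch, adapted from the function's docstring example shown in the hunk above (prompt wording and model name are illustrative, assuming the langchain version at this commit):

from typing import Optional

from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import HumanMessage, SystemMessage
from pydantic import BaseModel, Field


class Dog(BaseModel):
    """Identifying information about a dog."""

    name: str = Field(..., description="The dog's name")
    color: str = Field(..., description="The dog's color")
    fav_food: Optional[str] = Field(None, description="The dog's favorite food")


llm = ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0)
prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessage(
            content="You are a world class algorithm for extracting information in structured formats."
        ),
        HumanMessage(
            content="Use the given format to extract information from the following input:"
        ),
        HumanMessagePromptTemplate.from_template("{input}"),
    ]
)
chain = create_structured_output_chain(Dog, llm, prompt)
chain.run("Harry was a chubby brown beagle who loved chicken")
# -> Dog(name="Harry", color="brown", fav_food="chicken")  (a Dog instance, not a dict)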