diff --git a/langchain/chains/openai_functions/base.py b/langchain/chains/openai_functions/base.py index 5ee1bfd3ef3..aa5598e3f01 100644 --- a/langchain/chains/openai_functions/base.py +++ b/langchain/chains/openai_functions/base.py @@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel from langchain.chains import LLMChain from langchain.output_parsers.openai_functions import ( JsonOutputFunctionsParser, + PydanticAttrOutputFunctionsParser, PydanticOutputFunctionsParser, ) from langchain.prompts import BasePromptTemplate @@ -318,17 +319,26 @@ def create_structured_output_chain( chain.run("Harry was a chubby brown beagle who loved chicken") # -> Dog(name="Harry", color="brown", fav_food="chicken") """ # noqa: E501 - function: Dict = { - "name": "output_formatter", - "description": ( - "Output formatter. Should always be used to format your response to the" - " user." - ), - } - parameters = ( - output_schema if isinstance(output_schema, dict) else output_schema.schema() - ) - function["parameters"] = parameters + if isinstance(output_schema, dict): + function: Any = { + "name": "output_formatter", + "description": ( + "Output formatter. Should always be used to format your response to the" + " user." + ), + "parameters": output_schema, + } + else: + + class _OutputFormatter(BaseModel): + """Output formatter. Should always be used to format your response to the user.""" # noqa: E501 + + output: output_schema # type: ignore + + function = _OutputFormatter + output_parser = output_parser or PydanticAttrOutputFunctionsParser( + pydantic_schema=_OutputFormatter, attr_name="output" + ) return create_openai_fn_chain( [function], llm, prompt, output_parser=output_parser, **kwargs )