mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 02:33:19 +00:00
fix openai structured chain with pydantic (#7622)
should return pydantic class
This commit is contained in:
parent
ee70d4a0cd
commit
1d4db1327a
@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel
|
|||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
from langchain.output_parsers.openai_functions import (
|
from langchain.output_parsers.openai_functions import (
|
||||||
JsonOutputFunctionsParser,
|
JsonOutputFunctionsParser,
|
||||||
|
PydanticAttrOutputFunctionsParser,
|
||||||
PydanticOutputFunctionsParser,
|
PydanticOutputFunctionsParser,
|
||||||
)
|
)
|
||||||
from langchain.prompts import BasePromptTemplate
|
from langchain.prompts import BasePromptTemplate
|
||||||
@ -318,17 +319,26 @@ def create_structured_output_chain(
|
|||||||
chain.run("Harry was a chubby brown beagle who loved chicken")
|
chain.run("Harry was a chubby brown beagle who loved chicken")
|
||||||
# -> Dog(name="Harry", color="brown", fav_food="chicken")
|
# -> Dog(name="Harry", color="brown", fav_food="chicken")
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
function: Dict = {
|
if isinstance(output_schema, dict):
|
||||||
"name": "output_formatter",
|
function: Any = {
|
||||||
"description": (
|
"name": "output_formatter",
|
||||||
"Output formatter. Should always be used to format your response to the"
|
"description": (
|
||||||
" user."
|
"Output formatter. Should always be used to format your response to the"
|
||||||
),
|
" user."
|
||||||
}
|
),
|
||||||
parameters = (
|
"parameters": output_schema,
|
||||||
output_schema if isinstance(output_schema, dict) else output_schema.schema()
|
}
|
||||||
)
|
else:
|
||||||
function["parameters"] = parameters
|
|
||||||
|
class _OutputFormatter(BaseModel):
|
||||||
|
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
|
||||||
|
|
||||||
|
output: output_schema # type: ignore
|
||||||
|
|
||||||
|
function = _OutputFormatter
|
||||||
|
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
|
||||||
|
pydantic_schema=_OutputFormatter, attr_name="output"
|
||||||
|
)
|
||||||
return create_openai_fn_chain(
|
return create_openai_fn_chain(
|
||||||
[function], llm, prompt, output_parser=output_parser, **kwargs
|
[function], llm, prompt, output_parser=output_parser, **kwargs
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user