From 74a64cfbabb1679ea1ace754b1bc49ece3f05651 Mon Sep 17 00:00:00 2001 From: Kenny Date: Tue, 15 Aug 2023 20:01:32 -0400 Subject: [PATCH] expose output key to create_openai_fn_chain (#9155) A quick change to allow the output key of create_openai_fn_chain to optionally be changed. @baskaryan --------- Co-authored-by: Bagatur --- .../langchain/chains/openai_functions/base.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/chains/openai_functions/base.py b/libs/langchain/langchain/chains/openai_functions/base.py index 9eb773eae96..db322041435 100644 --- a/libs/langchain/langchain/chains/openai_functions/base.py +++ b/libs/langchain/langchain/chains/openai_functions/base.py @@ -192,6 +192,7 @@ def create_openai_fn_chain( llm: BaseLanguageModel, prompt: BasePromptTemplate, *, + output_key: str = "function", output_parser: Optional[BaseLLMOutputParser] = None, **kwargs: Any, ) -> LLMChain: @@ -210,6 +211,7 @@ def create_openai_fn_chain( pydantic.BaseModels for arguments. llm: Language model to use, assumed to support the OpenAI function-calling API. prompt: BasePromptTemplate to pass to the model. + output_key: The key to use when returning the output in LLMChain.__call__. output_parser: BaseLLMOutputParser to use for parsing model outputs. By default will be inferred from the function types. If pydantic.BaseModels are passed in, then the OutputParser will try to parse outputs using those. 
Otherwise @@ -274,7 +276,7 @@ def create_openai_fn_chain( prompt=prompt, output_parser=output_parser, llm_kwargs=llm_kwargs, - output_key="function", + output_key=output_key, **kwargs, ) return llm_chain @@ -285,6 +287,7 @@ def create_structured_output_chain( llm: BaseLanguageModel, prompt: BasePromptTemplate, *, + output_key: str = "function", output_parser: Optional[BaseLLMOutputParser] = None, **kwargs: Any, ) -> LLMChain: @@ -297,6 +300,7 @@ def create_structured_output_chain( the schema represents and descriptions for the parameters. llm: Language model to use, assumed to support the OpenAI function-calling API. prompt: BasePromptTemplate to pass to the model. + output_key: The key to use when returning the output in LLMChain.__call__. output_parser: BaseLLMOutputParser to use for parsing model outputs. By default will be inferred from the function types. If pydantic.BaseModels are passed in, then the OutputParser will try to parse outputs using those. Otherwise @@ -354,5 +358,10 @@ def create_structured_output_chain( pydantic_schema=_OutputFormatter, attr_name="output" ) return create_openai_fn_chain( - [function], llm, prompt, output_parser=output_parser, **kwargs + [function], + llm, + prompt, + output_key=output_key, + output_parser=output_parser, + **kwargs, )