diff --git a/libs/langchain/langchain/chains/openai_functions/base.py b/libs/langchain/langchain/chains/openai_functions/base.py index 14259dff23e..6f162cc0524 100644 --- a/libs/langchain/langchain/chains/openai_functions/base.py +++ b/libs/langchain/langchain/chains/openai_functions/base.py @@ -204,6 +204,7 @@ def create_openai_fn_runnable( llm: Runnable, prompt: BasePromptTemplate, *, + enforce_single_function_usage: bool = True, output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None, **kwargs: Any, ) -> Runnable: @@ -222,6 +223,9 @@ def create_openai_fn_runnable( pydantic.BaseModels for arguments. llm: Language model to use, assumed to support the OpenAI function-calling API. prompt: BasePromptTemplate to pass to the model. + enforce_single_function_usage: only used if a single function is passed in. If + True, then the model will be forced to use the given function. If False, + then the model will be given the option to use the given function or not. output_parser: BaseLLMOutputParser to use for parsing model outputs. By default will be inferred from the function types. If pydantic.BaseModels are passed in, then the OutputParser will try to parse outputs using those. Otherwise @@ -276,7 +280,7 @@ def create_openai_fn_runnable( raise ValueError("Need to pass in at least one function. Received zero.") openai_functions = [convert_to_openai_function(f) for f in functions] llm_kwargs: Dict[str, Any] = {"functions": openai_functions, **kwargs} - if len(openai_functions) == 1: + if len(openai_functions) == 1 and enforce_single_function_usage: llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]} output_parser = output_parser or get_openai_output_parser(functions) return prompt | llm.bind(**llm_kwargs) | output_parser @@ -373,6 +377,7 @@ def create_openai_fn_chain( llm: BaseLanguageModel, prompt: BasePromptTemplate, *, + enforce_single_function_usage: bool = True, output_key: str = "function", output_parser: Optional[BaseLLMOutputParser] = None, **kwargs: Any, @@ -392,6 +397,9 @@ def create_openai_fn_chain( pydantic.BaseModels for arguments. llm: Language model to use, assumed to support the OpenAI function-calling API. prompt: BasePromptTemplate to pass to the model. + enforce_single_function_usage: only used if a single function is passed in. If + True, then the model will be forced to use the given function. If False, + then the model will be given the option to use the given function or not. output_key: The key to use when returning the output in LLMChain.__call__. output_parser: BaseLLMOutputParser to use for parsing model outputs. By default will be inferred from the function types. If pydantic.BaseModels are passed @@ -451,7 +459,7 @@ def create_openai_fn_chain( llm_kwargs: Dict[str, Any] = { "functions": openai_functions, } - if len(openai_functions) == 1: + if len(openai_functions) == 1 and enforce_single_function_usage: llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]} llm_chain = LLMChain( llm=llm,