diff --git a/libs/langchain/langchain/chains/ernie_functions/__init__.py b/libs/langchain/langchain/chains/ernie_functions/__init__.py
new file mode 100644
index 00000000000..3efc22d4199
--- /dev/null
+++ b/libs/langchain/langchain/chains/ernie_functions/__init__.py
@@ -0,0 +1,17 @@
+from langchain.chains.ernie_functions.base import (
+    convert_to_ernie_function,
+    create_ernie_fn_chain,
+    create_ernie_fn_runnable,
+    create_structured_output_chain,
+    create_structured_output_runnable,
+    get_ernie_output_parser,
+)
+
+__all__ = [
+    "convert_to_ernie_function",
+    "create_structured_output_chain",
+    "create_ernie_fn_chain",
+    "create_structured_output_runnable",
+    "create_ernie_fn_runnable",
+    "get_ernie_output_parser",
+]
diff --git a/libs/langchain/langchain/chains/ernie_functions/base.py b/libs/langchain/langchain/chains/ernie_functions/base.py
new file mode 100644
index 00000000000..0070531884f
--- /dev/null
+++ b/libs/langchain/langchain/chains/ernie_functions/base.py
@@ -0,0 +1,547 @@
+"""Methods for creating chains that use Ernie function-calling APIs."""
+import inspect
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
+
+from langchain.base_language import BaseLanguageModel
+from langchain.chains import LLMChain
+from langchain.output_parsers.ernie_functions import (
+    JsonOutputFunctionsParser,
+    PydanticAttrOutputFunctionsParser,
+    PydanticOutputFunctionsParser,
+)
+from langchain.prompts import BasePromptTemplate
+from langchain.pydantic_v1 import BaseModel
+from langchain.schema import BaseLLMOutputParser
+from langchain.schema.output_parser import BaseGenerationOutputParser, BaseOutputParser
+from langchain.schema.runnable import Runnable
+from langchain.utils.ernie_functions import convert_pydantic_to_ernie_function
+
+PYTHON_TO_JSON_TYPES = {
+    "str": "string",
+    "int": "number",
+    "float": "number",
+    "bool": "boolean",
+}
+
+
+def _get_python_function_name(function: Callable) -> str:
+    """Get the name of a Python function."""
+    return function.__name__
+
+
+def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
+    """Parse the function and argument descriptions from the docstring of a function.
+
+    Assumes the function docstring follows Google Python style guide.
+    """
+    docstring = inspect.getdoc(function)
+    if docstring:
+        docstring_blocks = docstring.split("\n\n")
+        descriptors = []
+        args_block = None
+        past_descriptors = False
+        for block in docstring_blocks:
+            if block.startswith("Args:"):
+                args_block = block
+                break
+            elif block.startswith("Returns:") or block.startswith("Example:"):
+                # Don't break in case Args come after
+                past_descriptors = True
+            elif not past_descriptors:
+                descriptors.append(block)
+            else:
+                continue
+        description = " ".join(descriptors)
+    else:
+        description = ""
+        args_block = None
+    arg_descriptions = {}
+    if args_block:
+        arg = None
+        for line in args_block.split("\n")[1:]:
+            if ":" in line:
+                arg, desc = line.split(":", maxsplit=1)
+                arg_descriptions[arg.strip()] = desc.strip()
+            elif arg:
+                arg_descriptions[arg.strip()] += " " + line.strip()
+    return description, arg_descriptions
+
+
+def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
+    """Get JsonSchema describing a Python functions arguments.
+
+    Assumes all function arguments are of primitive types (int, float, str, bool) or
+    are subclasses of pydantic.BaseModel.
+    """
+    properties = {}
+    annotations = inspect.getfullargspec(function).annotations
+    for arg, arg_type in annotations.items():
+        if arg == "return":
+            continue
+        if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):
+            # Mypy error:
+            # "type" has no attribute "schema"
+            properties[arg] = arg_type.schema()  # type: ignore[attr-defined]
+        elif getattr(arg_type, "__name__", None) in PYTHON_TO_JSON_TYPES:
+            properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
+        if arg in arg_descriptions:
+            if arg not in properties:
+                properties[arg] = {}
+            properties[arg]["description"] = arg_descriptions[arg]
+    return properties
+
+
+def _get_python_function_required_args(function: Callable) -> List[str]:
+    """Get the required arguments for a Python function."""
+    spec = inspect.getfullargspec(function)
+    required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
+    required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})]
+
+    is_class = type(function) is type
+    if is_class and required and required[0] == "self":
+        required = required[1:]
+    return required
+
+
+def convert_python_function_to_ernie_function(
+    function: Callable,
+) -> Dict[str, Any]:
+    """Convert a Python function to an Ernie function-calling API compatible dict.
+
+    Assumes the Python function has type hints and a docstring with a description. If
+        the docstring has Google Python style argument descriptions, these will be
+        included as well.
+    """
+    description, arg_descriptions = _parse_python_function_docstring(function)
+    return {
+        "name": _get_python_function_name(function),
+        "description": description,
+        "parameters": {
+            "type": "object",
+            "properties": _get_python_function_arguments(function, arg_descriptions),
+            "required": _get_python_function_required_args(function),
+        },
+    }
+
+
+def convert_to_ernie_function(
+    function: Union[Dict[str, Any], Type[BaseModel], Callable]
+) -> Dict[str, Any]:
+    """Convert a raw function/class to an Ernie function.
+
+    Args:
+        function: Either a dictionary, a pydantic.BaseModel class, or a Python function.
+            If a dictionary is passed in, it is assumed to already be a valid Ernie
+            function.
+
+    Returns:
+        A dict version of the passed in function which is compatible with the
+            Ernie function-calling API.
+    """
+    if isinstance(function, dict):
+        return function
+    elif isinstance(function, type) and issubclass(function, BaseModel):
+        return cast(Dict, convert_pydantic_to_ernie_function(function))
+    elif callable(function):
+        return convert_python_function_to_ernie_function(function)
+
+    else:
+        raise ValueError(
+            f"Unsupported function type {type(function)}. Functions must be passed in"
+            f" as Dict, pydantic.BaseModel, or Callable."
+        )
+
+
+def get_ernie_output_parser(
+    functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
+) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
+    """Get the appropriate function output parser given the user functions.
+
+    Args:
+        functions: Sequence where element is a dictionary, a pydantic.BaseModel class,
+            or a Python function. If a dictionary is passed in, it is assumed to
+            already be a valid Ernie function.
+
+    Returns:
+        A PydanticOutputFunctionsParser if functions are Pydantic classes, otherwise
+            a JsonOutputFunctionsParser. If there's only one function and it is
+            not a Pydantic class, then the output parser will automatically extract
+            only the function arguments and not the function name.
+    """
+    function_names = [convert_to_ernie_function(f)["name"] for f in functions]
+    if isinstance(functions[0], type) and issubclass(functions[0], BaseModel):
+        if len(functions) > 1:
+            pydantic_schema: Union[Dict, Type[BaseModel]] = {
+                name: fn for name, fn in zip(function_names, functions)
+            }
+        else:
+            pydantic_schema = functions[0]
+        output_parser: Union[
+            BaseOutputParser, BaseGenerationOutputParser
+        ] = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
+    else:
+        output_parser = JsonOutputFunctionsParser(args_only=len(functions) <= 1)
+    return output_parser
+
+
+def create_ernie_fn_runnable(
+    functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
+    llm: Runnable,
+    prompt: BasePromptTemplate,
+    *,
+    output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
+    **kwargs: Any,
+) -> Runnable:
+    """Create a runnable sequence that uses Ernie functions.
+
+    Args:
+        functions: A sequence of either dictionaries, pydantic.BaseModels classes, or
+            Python functions. If dictionaries are passed in, they are assumed to
+            already be a valid Ernie functions. If only a single
+            function is passed in, then it will be enforced that the model use that
+            function. pydantic.BaseModels and Python functions should have docstrings
+            describing what the function does. For best results, pydantic.BaseModels
+            should have descriptions of the parameters and Python functions should have
+            Google Python style args descriptions in the docstring. Additionally,
+            Python functions should only use primitive types (str, int, float, bool) or
+            pydantic.BaseModels for arguments.
+        llm: Language model to use, assumed to support the Ernie function-calling API.
+        prompt: BasePromptTemplate to pass to the model.
+        output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
+            will be inferred from the function types. If pydantic.BaseModels are passed
+            in, then the OutputParser will try to parse outputs using those. Otherwise
+            model outputs will simply be parsed as JSON. If multiple functions are
+            passed in and they are not pydantic.BaseModels, the chain output will
+            include both the name of the function that was returned and the arguments
+            to pass to the function.
+
+    Returns:
+        A runnable sequence that will pass in the given functions to the model when run.
+
+    Example:
+        .. code-block:: python
+
+                from typing import Optional
+
+                from langchain.chains.ernie_functions import create_ernie_fn_runnable
+                from langchain.chat_models import ErnieBotChat
+                from langchain.prompts import ChatPromptTemplate
+                from langchain.pydantic_v1 import BaseModel, Field
+
+
+                class RecordPerson(BaseModel):
+                    \"\"\"Record some identifying information about a person.\"\"\"
+
+                    name: str = Field(..., description="The person's name")
+                    age: int = Field(..., description="The person's age")
+                    fav_food: Optional[str] = Field(None, description="The person's favorite food")
+
+
+                class RecordDog(BaseModel):
+                    \"\"\"Record some identifying information about a dog.\"\"\"
+
+                    name: str = Field(..., description="The dog's name")
+                    color: str = Field(..., description="The dog's color")
+                    fav_food: Optional[str] = Field(None, description="The dog's favorite food")
+
+
+                llm = ErnieBotChat(model_name="ERNIE-Bot-4")
+                prompt = ChatPromptTemplate.from_messages(
+                    [
+                        ("user", "Make calls to the relevant function to record the entities in the following input: {input}"),
+                        ("assistant", "OK!"),
+                        ("user", "Tip: Make sure to answer in the correct format"),
+                    ]
+                )
+                chain = create_ernie_fn_runnable([RecordPerson, RecordDog], llm, prompt)
+                chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
+                # -> RecordDog(name="Harry", color="brown", fav_food="chicken")
+    """  # noqa: E501
+    if not functions:
+        raise ValueError("Need to pass in at least one function. Received zero.")
+    ernie_functions = [convert_to_ernie_function(f) for f in functions]
+    llm_kwargs: Dict[str, Any] = {"functions": ernie_functions, **kwargs}
+    if len(ernie_functions) == 1:
+        llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]}
+    output_parser = output_parser or get_ernie_output_parser(functions)
+    return prompt | llm.bind(**llm_kwargs) | output_parser
+
+
+def create_structured_output_runnable(
+    output_schema: Union[Dict[str, Any], Type[BaseModel]],
+    llm: Runnable,
+    prompt: BasePromptTemplate,
+    *,
+    output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
+    **kwargs: Any,
+) -> Runnable:
+    """Create a runnable that uses an Ernie function to get a structured output.
+
+    Args:
+        output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
+            is passed in, it's assumed to already be a valid JsonSchema.
+            For best results, pydantic.BaseModels should have docstrings describing what
+            the schema represents and descriptions for the parameters.
+        llm: Language model to use, assumed to support the Ernie function-calling API.
+        prompt: BasePromptTemplate to pass to the model.
+        output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
+            will be inferred from the function types. If pydantic.BaseModels are passed
+            in, then the OutputParser will try to parse outputs using those. Otherwise
+            model outputs will simply be parsed as JSON.
+
+    Returns:
+        A runnable sequence that will pass the given function to the model when run.
+
+    Example:
+        .. code-block:: python
+
+                from typing import Optional
+
+                from langchain.chains.ernie_functions import create_structured_output_runnable
+                from langchain.chat_models import ErnieBotChat
+                from langchain.prompts import ChatPromptTemplate
+                from langchain.pydantic_v1 import BaseModel, Field
+
+                class Dog(BaseModel):
+                    \"\"\"Identifying information about a dog.\"\"\"
+
+                    name: str = Field(..., description="The dog's name")
+                    color: str = Field(..., description="The dog's color")
+                    fav_food: Optional[str] = Field(None, description="The dog's favorite food")
+
+                llm = ErnieBotChat(model_name="ERNIE-Bot-4")
+                prompt = ChatPromptTemplate.from_messages(
+                    [
+                        ("user", "Use the given format to extract information from the following input: {input}"),
+                        ("assistant", "OK!"),
+                        ("user", "Tip: Make sure to answer in the correct format"),
+                    ]
+                )
+                chain = create_structured_output_runnable(Dog, llm, prompt)
+                chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
+                # -> Dog(name="Harry", color="brown", fav_food="chicken")
+    """  # noqa: E501
+    if isinstance(output_schema, dict):
+        function: Any = {
+            "name": "output_formatter",
+            "description": (
+                "Output formatter. Should always be used to format your response to the"
+                " user."
+            ),
+            "parameters": output_schema,
+        }
+    else:
+
+        class _OutputFormatter(BaseModel):
+            """Output formatter. Should always be used to format your response to the user."""  # noqa: E501
+
+            output: output_schema  # type: ignore
+
+        function = _OutputFormatter
+        output_parser = output_parser or PydanticAttrOutputFunctionsParser(
+            pydantic_schema=_OutputFormatter, attr_name="output"
+        )
+    return create_ernie_fn_runnable(
+        [function],
+        llm,
+        prompt,
+        output_parser=output_parser,
+        **kwargs,
+    )
+
+
+""" --- Legacy --- """
+
+
+def create_ernie_fn_chain(
+    functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
+    llm: BaseLanguageModel,
+    prompt: BasePromptTemplate,
+    *,
+    output_key: str = "function",
+    output_parser: Optional[BaseLLMOutputParser] = None,
+    **kwargs: Any,
+) -> LLMChain:
+    """[Legacy] Create an LLM chain that uses Ernie functions.
+
+    Args:
+        functions: A sequence of either dictionaries, pydantic.BaseModels classes, or
+            Python functions. If dictionaries are passed in, they are assumed to
+            already be a valid Ernie functions. If only a single
+            function is passed in, then it will be enforced that the model use that
+            function. pydantic.BaseModels and Python functions should have docstrings
+            describing what the function does. For best results, pydantic.BaseModels
+            should have descriptions of the parameters and Python functions should have
+            Google Python style args descriptions in the docstring. Additionally,
+            Python functions should only use primitive types (str, int, float, bool) or
+            pydantic.BaseModels for arguments.
+        llm: Language model to use, assumed to support the Ernie function-calling API.
+        prompt: BasePromptTemplate to pass to the model.
+        output_key: The key to use when returning the output in LLMChain.__call__.
+        output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
+            will be inferred from the function types. If pydantic.BaseModels are passed
+            in, then the OutputParser will try to parse outputs using those. Otherwise
+            model outputs will simply be parsed as JSON. If multiple functions are
+            passed in and they are not pydantic.BaseModels, the chain output will
+            include both the name of the function that was returned and the arguments
+            to pass to the function.
+
+    Returns:
+        An LLMChain that will pass in the given functions to the model when run.
+
+    Example:
+        .. code-block:: python
+
+                from typing import Optional
+
+                from langchain.chains.ernie_functions import create_ernie_fn_chain
+                from langchain.chat_models import ErnieBotChat
+                from langchain.prompts import ChatPromptTemplate
+
+                from langchain.pydantic_v1 import BaseModel, Field
+
+
+                class RecordPerson(BaseModel):
+                    \"\"\"Record some identifying information about a person.\"\"\"
+
+                    name: str = Field(..., description="The person's name")
+                    age: int = Field(..., description="The person's age")
+                    fav_food: Optional[str] = Field(None, description="The person's favorite food")
+
+
+                class RecordDog(BaseModel):
+                    \"\"\"Record some identifying information about a dog.\"\"\"
+
+                    name: str = Field(..., description="The dog's name")
+                    color: str = Field(..., description="The dog's color")
+                    fav_food: Optional[str] = Field(None, description="The dog's favorite food")
+
+
+                llm = ErnieBotChat(model_name="ERNIE-Bot-4")
+                prompt = ChatPromptTemplate.from_messages(
+                    [
+                        ("user", "Make calls to the relevant function to record the entities in the following input: {input}"),
+                        ("assistant", "OK!"),
+                        ("user", "Tip: Make sure to answer in the correct format"),
+                    ]
+                )
+                chain = create_ernie_fn_chain([RecordPerson, RecordDog], llm, prompt)
+                chain.run("Harry was a chubby brown beagle who loved chicken")
+                # -> RecordDog(name="Harry", color="brown", fav_food="chicken")
+    """  # noqa: E501
+    if not functions:
+        raise ValueError("Need to pass in at least one function. Received zero.")
+    ernie_functions = [convert_to_ernie_function(f) for f in functions]
+    output_parser = output_parser or get_ernie_output_parser(functions)
+    llm_kwargs: Dict[str, Any] = {
+        "functions": ernie_functions,
+    }
+    if len(ernie_functions) == 1:
+        llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]}
+    llm_chain = LLMChain(
+        llm=llm,
+        prompt=prompt,
+        output_parser=output_parser,
+        llm_kwargs=llm_kwargs,
+        output_key=output_key,
+        **kwargs,
+    )
+    return llm_chain
+
+
+def create_structured_output_chain(
+    output_schema: Union[Dict[str, Any], Type[BaseModel]],
+    llm: BaseLanguageModel,
+    prompt: BasePromptTemplate,
+    *,
+    output_key: str = "function",
+    output_parser: Optional[BaseLLMOutputParser] = None,
+    **kwargs: Any,
+) -> LLMChain:
+    """[Legacy] Create an LLMChain that uses an Ernie function to get a structured output.
+
+    Args:
+        output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
+            is passed in, it's assumed to already be a valid JsonSchema.
+            For best results, pydantic.BaseModels should have docstrings describing what
+            the schema represents and descriptions for the parameters.
+        llm: Language model to use, assumed to support the Ernie function-calling API.
+        prompt: BasePromptTemplate to pass to the model.
+        output_key: The key to use when returning the output in LLMChain.__call__.
+        output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
+            will be inferred from the function types. If pydantic.BaseModels are passed
+            in, then the OutputParser will try to parse outputs using those. Otherwise
+            model outputs will simply be parsed as JSON.
+
+    Returns:
+        An LLMChain that will pass the given function to the model.
+
+    Example:
+        .. code-block:: python
+
+                from typing import Optional
+
+                from langchain.chains.ernie_functions import create_structured_output_chain
+                from langchain.chat_models import ErnieBotChat
+                from langchain.prompts import ChatPromptTemplate
+
+                from langchain.pydantic_v1 import BaseModel, Field
+
+                class Dog(BaseModel):
+                    \"\"\"Identifying information about a dog.\"\"\"
+
+                    name: str = Field(..., description="The dog's name")
+                    color: str = Field(..., description="The dog's color")
+                    fav_food: Optional[str] = Field(None, description="The dog's favorite food")
+
+                llm = ErnieBotChat(model_name="ERNIE-Bot-4")
+                prompt = ChatPromptTemplate.from_messages(
+                    [
+                        ("user", "Use the given format to extract information from the following input: {input}"),
+                        ("assistant", "OK!"),
+                        ("user", "Tip: Make sure to answer in the correct format"),
+                    ]
+                )
+                chain = create_structured_output_chain(Dog, llm, prompt)
+                chain.run("Harry was a chubby brown beagle who loved chicken")
+                # -> Dog(name="Harry", color="brown", fav_food="chicken")
+    """  # noqa: E501
+    if isinstance(output_schema, dict):
+        function: Any = {
+            "name": "output_formatter",
+            "description": (
+                "Output formatter. Should always be used to format your response to the"
+                " user."
+            ),
+            "parameters": output_schema,
+        }
+    else:
+
+        class _OutputFormatter(BaseModel):
+            """Output formatter. Should always be used to format your response to the user."""  # noqa: E501
+
+            output: output_schema  # type: ignore
+
+        function = _OutputFormatter
+        output_parser = output_parser or PydanticAttrOutputFunctionsParser(
+            pydantic_schema=_OutputFormatter, attr_name="output"
+        )
+    return create_ernie_fn_chain(
+        [function],
+        llm,
+        prompt,
+        output_key=output_key,
+        output_parser=output_parser,
+        **kwargs,
+    )
diff --git a/libs/langchain/langchain/chat_models/ernie.py b/libs/langchain/langchain/chat_models/ernie.py
index bebfb374947..58e7647d1bd 100644
--- a/libs/langchain/langchain/chat_models/ernie.py
+++ b/libs/langchain/langchain/chat_models/ernie.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import threading
 from typing import Any, Dict, List, Mapping, Optional
@@ -179,9 +180,20 @@ class ErnieBotChat(BaseChatModel):
         return self._create_chat_result(resp)
 
     def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
-        generations = [
-            ChatGeneration(message=AIMessage(content=response.get("result")))
-        ]
+        if "function_call" in response:
+            function_call = response.get("function_call")
+            fc_str = '{{"function_call": {}}}'.format(json.dumps(function_call))
+            message = AIMessage(
+                content=fc_str,
+                # Also expose the call via additional_kwargs, which is where
+                # OutputFunctionsParser and its pydantic subclasses read it.
+                additional_kwargs={"function_call": function_call},
+            )
+            generations = [ChatGeneration(message=message)]
+        else:
+            generations = [
+                ChatGeneration(message=AIMessage(content=response.get("result")))
+            ]
         token_usage = response.get("usage", {})
         llm_output = {"token_usage": token_usage, "model_name": self.model_name}
         return ChatResult(generations=generations, llm_output=llm_output)
diff --git a/libs/langchain/langchain/output_parsers/ernie_functions.py b/libs/langchain/langchain/output_parsers/ernie_functions.py
new file mode 100644
index 00000000000..b2682c4dc21
--- /dev/null
+++ b/libs/langchain/langchain/output_parsers/ernie_functions.py
@@ -0,0 +1,184 @@
+import copy
+import json
+from typing import Any, Dict, List, Optional, Type, Union
+
+import jsonpatch
+
+from langchain.output_parsers.json import parse_partial_json
+from langchain.pydantic_v1 import BaseModel, root_validator
+from langchain.schema import (
+    ChatGeneration,
+    Generation,
+    OutputParserException,
+)
+from langchain.schema.output_parser import (
+    BaseCumulativeTransformOutputParser,
+    BaseGenerationOutputParser,
+)
+
+
+class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
+    """Parse an output that is one of sets of values."""
+
+    args_only: bool = True
+    """Whether to only return the arguments to the function call."""
+
+    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+        generation = result[0]
+        if not isinstance(generation, ChatGeneration):
+            raise OutputParserException(
+                "This output parser can only be used with a chat generation."
+            )
+        message = generation.message
+        try:
+            func_call = copy.deepcopy(message.additional_kwargs["function_call"])
+        except KeyError as exc:
+            raise OutputParserException(f"Could not parse function call: {exc}")
+
+        if self.args_only:
+            return func_call["arguments"]
+        return func_call
+
+
+class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
+    """Parse an output as the Json object."""
+
+    strict: bool = False
+    """Whether to allow non-JSON-compliant strings.
+
+    See: https://docs.python.org/3/library/json.html#encoders-and-decoders
+
+    Useful when the parsed output may include unicode characters or new lines.
+    """
+
+    args_only: bool = True
+    """Whether to only return the arguments to the function call."""
+
+    @property
+    def _type(self) -> str:
+        return "json_functions"
+
+    def _diff(self, prev: Optional[Any], next: Any) -> Any:
+        return jsonpatch.make_patch(prev, next).patch
+
+    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+        if len(result) != 1:
+            raise OutputParserException(
+                f"Expected exactly one result, but got {len(result)}"
+            )
+        generation = result[0]
+        if not isinstance(generation, ChatGeneration):
+            raise OutputParserException(
+                "This output parser can only be used with a chat generation."
+            )
+        message = generation.message
+        message.additional_kwargs.setdefault("function_call", {})
+        if "function_call" in message.content:
+            function_call = json.loads(str(message.content))
+            if "function_call" in function_call:
+                fc = function_call["function_call"]
+                message.additional_kwargs["function_call"] = fc
+        try:
+            function_call = message.additional_kwargs["function_call"]
+        except KeyError as exc:
+            if partial:
+                return None
+            else:
+                raise OutputParserException(f"Could not parse function call: {exc}")
+        try:
+            if partial:
+                if self.args_only:
+                    return parse_partial_json(
+                        function_call["arguments"], strict=self.strict
+                    )
+                else:
+                    return {
+                        **function_call,
+                        "arguments": parse_partial_json(
+                            function_call["arguments"], strict=self.strict
+                        ),
+                    }
+            else:
+                if self.args_only:
+                    try:
+                        return json.loads(
+                            function_call["arguments"], strict=self.strict
+                        )
+                    except (json.JSONDecodeError, TypeError) as exc:
+                        raise OutputParserException(
+                            f"Could not parse function call data: {exc}"
+                        )
+                else:
+                    try:
+                        return {
+                            **function_call,
+                            "arguments": json.loads(
+                                function_call["arguments"], strict=self.strict
+                            ),
+                        }
+                    except (json.JSONDecodeError, TypeError) as exc:
+                        raise OutputParserException(
+                            f"Could not parse function call data: {exc}"
+                        )
+        except KeyError:
+            return None
+
+    # This method would be called by the default implementation of `parse_result`
+    # but we're overriding that method so it's not needed.
+    def parse(self, text: str) -> Any:
+        raise NotImplementedError()
+
+
+class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
+    """Parse an output as the element of the Json object."""
+
+    key_name: str
+    """The name of the key to return."""
+
+    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+        res = super().parse_result(result, partial=partial)
+        if partial and res is None:
+            return None
+        return res.get(self.key_name) if partial else res[self.key_name]
+
+
+class PydanticOutputFunctionsParser(OutputFunctionsParser):
+    """Parse an output as a pydantic object."""
+
+    pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
+    """The pydantic schema to parse the output with."""
+
+    @root_validator(pre=True)
+    def validate_schema(cls, values: Dict) -> Dict:
+        schema = values["pydantic_schema"]
+        if "args_only" not in values:
+            values["args_only"] = isinstance(schema, type) and issubclass(
+                schema, BaseModel
+            )
+        elif values["args_only"] and isinstance(schema, Dict):
+            raise ValueError(
+                "If multiple pydantic schemas are provided then args_only should be"
+                " False."
+            )
+        return values
+
+    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+        _result = super().parse_result(result)
+        if self.args_only:
+            pydantic_args = self.pydantic_schema.parse_raw(_result)  # type: ignore
+        else:
+            fn_name = _result["name"]
+            _args = _result["arguments"]
+            pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args)  # type: ignore  # noqa: E501
+        return pydantic_args
+
+
+class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
+    """Parse an output as an attribute of a pydantic object."""
+
+    attr_name: str
+    """The name of the attribute to return."""
+
+    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+        result = super().parse_result(result)
+        return getattr(result, self.attr_name)
diff --git a/libs/langchain/langchain/utils/ernie_functions.py b/libs/langchain/langchain/utils/ernie_functions.py
new file mode 100644
index 00000000000..080df2ade25
--- /dev/null
+++ b/libs/langchain/langchain/utils/ernie_functions.py
@@ -0,0 +1,51 @@
+from typing import Literal, Optional, Type, TypedDict
+
+from langchain.pydantic_v1 import BaseModel
+from langchain.utils.json_schema import dereference_refs
+
+
+class FunctionDescription(TypedDict):
+    """Representation of a callable function to the Ernie API."""
+
+    name: str
+    """The name of the function."""
+    description: str
+    """A description of the function."""
+    parameters: dict
+    """The parameters of the function."""
+
+
+class ToolDescription(TypedDict):
+    """Representation of a callable function to the Ernie API."""
+
+    type: Literal["function"]
+    function: FunctionDescription
+
+
+def convert_pydantic_to_ernie_function(
+    model: Type[BaseModel],
+    *,
+    name: Optional[str] = None,
+    description: Optional[str] = None,
+) -> FunctionDescription:
+    """Converts a Pydantic model to a function description for the Ernie API."""
+    schema = dereference_refs(model.schema())
+    schema.pop("definitions", None)
+    return {
+        "name": name or schema["title"],
+        "description": description or schema["description"],
+        "parameters": schema,
+    }
+
+
+def convert_pydantic_to_ernie_tool(
+    model: Type[BaseModel],
+    *,
+    name: Optional[str] = None,
+    description: Optional[str] = None,
+) -> ToolDescription:
+    """Converts a Pydantic model to a tool description for the Ernie API."""
+    function = convert_pydantic_to_ernie_function(
+        model, name=name, description=description
+    )
+    return {"type": "function", "function": function}