feat: add ERNIE-Bot-4 Function Calling (#13320)

- **Description:** ERNIE-Bot-Chat-4 Large Language Model adds the
ability of `Function Calling` by passing parameters through the
`functions` parameter in the request. To simplify function calling for
ERNIE-Bot-Chat-4, the `create_ernie_fn_chain()` function has been added.
The definition and usage of the `create_ernie_fn_chain()` function is
similar to that of the `create_openai_fn_chain()` function.

Examples as the follows:

```
import json

from langchain.chains.ernie_functions import (
    create_ernie_fn_chain,
)
from langchain.chat_models import ErnieBotChat
from langchain.prompts import ChatPromptTemplate

def get_current_news(location: str) -> str:
    """Get the current news based on the location.'

    Args:
        location (str): The location to query.
    
    Returs:
        str: Current news based on the location.
    """

    news_info = {
        "location": location,
        "news": [
            "I have a Book.",
            "It's a nice day, today."
        ]
    }

    return json.dumps(news_info)

def get_current_weather(location: str, unit: str="celsius") -> str:
    """Get the current weather in a given location

    Args:
        location (str): location of the weather.
        unit (str): unit of the tempuature.
    
    Returns:
        str: weather in the given location.
    """

    weather_info = {
        "location": location,
        "temperature": "27",
        "unit": unit,
        "forecast": ["sunny", "windy"],
    }
    return json.dumps(weather_info)

llm = ErnieBotChat(model_name="ERNIE-Bot-4")
prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{query}"),
    ]
)

chain = create_ernie_fn_chain([get_current_weather, get_current_news], llm, prompt, verbose=True)
res = chain.run("北京今天的新闻是什么?")
print(res)
```

The running results of the above program are shown below:
```
> Entering new LLMChain chain...
Prompt after formatting:
Human: 北京今天的新闻是什么?



> Finished chain.
{'name': 'get_current_news', 'thoughts': '用户想要知道北京今天的新闻。我可以使用get_current_news工具来获取这些信息。', 'arguments': {'location': '北京'}}
```
This commit is contained in:
Wang Wei 2023-11-20 14:36:12 +08:00 committed by GitHub
parent 10418ab0c1
commit fe7b40cb2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 809 additions and 3 deletions

View File

@ -0,0 +1,17 @@
from langchain.chains.ernie_functions.base import (
convert_to_ernie_function,
create_ernie_fn_chain,
create_ernie_fn_runnable,
create_structured_output_chain,
create_structured_output_runnable,
get_ernie_output_parser,
)
__all__ = [
"convert_to_ernie_function",
"create_structured_output_chain",
"create_ernie_fn_chain",
"create_structured_output_runnable",
"create_ernie_fn_runnable",
"get_ernie_output_parser",
]

View File

@ -0,0 +1,547 @@
"""Methods for creating chains that use Ernie function-calling APIs."""
import inspect
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
from langchain.output_parsers.ernie_functions import (
JsonOutputFunctionsParser,
PydanticAttrOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain.prompts import BasePromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema import BaseLLMOutputParser
from langchain.schema.output_parser import BaseGenerationOutputParser, BaseOutputParser
from langchain.schema.runnable import Runnable
from langchain.utils.ernie_functions import convert_pydantic_to_ernie_function
PYTHON_TO_JSON_TYPES = {
"str": "string",
"int": "number",
"float": "number",
"bool": "boolean",
}
def _get_python_function_name(function: Callable) -> str:
"""Get the name of a Python function."""
return function.__name__
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
"""
docstring = inspect.getdoc(function)
if docstring:
docstring_blocks = docstring.split("\n\n")
descriptors = []
args_block = None
past_descriptors = False
for block in docstring_blocks:
if block.startswith("Args:"):
args_block = block
break
elif block.startswith("Returns:") or block.startswith("Example:"):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
descriptors.append(block)
else:
continue
description = " ".join(descriptors)
else:
description = ""
args_block = None
arg_descriptions = {}
if args_block:
arg = None
for line in args_block.split("\n")[1:]:
if ":" in line:
arg, desc = line.split(":")
arg_descriptions[arg.strip()] = desc.strip()
elif arg:
arg_descriptions[arg.strip()] += " " + line.strip()
return description, arg_descriptions
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
"""Get JsonSchema describing a Python functions arguments.
Assumes all function arguments are of primitive types (int, float, str, bool) or
are subclasses of pydantic.BaseModel.
"""
properties = {}
annotations = inspect.getfullargspec(function).annotations
for arg, arg_type in annotations.items():
if arg == "return":
continue
if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):
# Mypy error:
# "type" has no attribute "schema"
properties[arg] = arg_type.schema() # type: ignore[attr-defined]
elif arg_type.__name__ in PYTHON_TO_JSON_TYPES:
properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
if arg in arg_descriptions:
if arg not in properties:
properties[arg] = {}
properties[arg]["description"] = arg_descriptions[arg]
return properties
def _get_python_function_required_args(function: Callable) -> List[str]:
"""Get the required arguments for a Python function."""
spec = inspect.getfullargspec(function)
required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})]
is_class = type(function) is type
if is_class and required[0] == "self":
required = required[1:]
return required
def convert_python_function_to_ernie_function(
function: Callable,
) -> Dict[str, Any]:
"""Convert a Python function to an Ernie function-calling API compatible dict.
Assumes the Python function has type hints and a docstring with a description. If
the docstring has Google Python style argument descriptions, these will be
included as well.
"""
description, arg_descriptions = _parse_python_function_docstring(function)
return {
"name": _get_python_function_name(function),
"description": description,
"parameters": {
"type": "object",
"properties": _get_python_function_arguments(function, arg_descriptions),
"required": _get_python_function_required_args(function),
},
}
def convert_to_ernie_function(
function: Union[Dict[str, Any], Type[BaseModel], Callable]
) -> Dict[str, Any]:
"""Convert a raw function/class to an Ernie function.
Args:
function: Either a dictionary, a pydantic.BaseModel class, or a Python function.
If a dictionary is passed in, it is assumed to already be a valid Ernie
function.
Returns:
A dict version of the passed in function which is compatible with the
Ernie function-calling API.
"""
if isinstance(function, dict):
return function
elif isinstance(function, type) and issubclass(function, BaseModel):
return cast(Dict, convert_pydantic_to_ernie_function(function))
elif callable(function):
return convert_python_function_to_ernie_function(function)
else:
raise ValueError(
f"Unsupported function type {type(function)}. Functions must be passed in"
f" as Dict, pydantic.BaseModel, or Callable."
)
def get_ernie_output_parser(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
"""Get the appropriate function output parser given the user functions.
Args:
functions: Sequence where element is a dictionary, a pydantic.BaseModel class,
or a Python function. If a dictionary is passed in, it is assumed to
already be a valid Ernie function.
Returns:
A PydanticOutputFunctionsParser if functions are Pydantic classes, otherwise
a JsonOutputFunctionsParser. If there's only one function and it is
not a Pydantic class, then the output parser will automatically extract
only the function arguments and not the function name.
"""
function_names = [convert_to_ernie_function(f)["name"] for f in functions]
if isinstance(functions[0], type) and issubclass(functions[0], BaseModel):
if len(functions) > 1:
pydantic_schema: Union[Dict, Type[BaseModel]] = {
name: fn for name, fn in zip(function_names, functions)
}
else:
pydantic_schema = functions[0]
output_parser: Union[
BaseOutputParser, BaseGenerationOutputParser
] = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
else:
output_parser = JsonOutputFunctionsParser(args_only=len(functions) <= 1)
return output_parser
def create_ernie_fn_runnable(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
llm: Runnable,
prompt: BasePromptTemplate,
*,
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
**kwargs: Any,
) -> Runnable:
"""Create a runnable sequence that uses Ernie functions.
Args:
functions: A sequence of either dictionaries, pydantic.BaseModels classes, or
Python functions. If dictionaries are passed in, they are assumed to
already be a valid Ernie functions. If only a single
function is passed in, then it will be enforced that the model use that
function. pydantic.BaseModels and Python functions should have docstrings
describing what the function does. For best results, pydantic.BaseModels
should have descriptions of the parameters and Python functions should have
Google Python style args descriptions in the docstring. Additionally,
Python functions should only use primitive types (str, int, float, bool) or
pydantic.BaseModels for arguments.
llm: Language model to use, assumed to support the Ernie function-calling API.
prompt: BasePromptTemplate to pass to the model.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
model outputs will simply be parsed as JSON. If multiple functions are
passed in and they are not pydantic.BaseModels, the chain output will
include both the name of the function that was returned and the arguments
to pass to the function.
Returns:
A runnable sequence that will pass in the given functions to the model when run.
Example:
.. code-block:: python
from typing import Optional
from langchain.chains.ernie_functions import create_ernie_fn_chain
from langchain.chat_models import ErnieBotChat
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
class RecordPerson(BaseModel):
\"\"\"Record some identifying information about a person.\"\"\"
name: str = Field(..., description="The person's name")
age: int = Field(..., description="The person's age")
fav_food: Optional[str] = Field(None, description="The person's favorite food")
class RecordDog(BaseModel):
\"\"\"Record some identifying information about a dog.\"\"\"
name: str = Field(..., description="The dog's name")
color: str = Field(..., description="The dog's color")
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
prompt = ChatPromptTemplate.from_messages(
[
("user", "Make calls to the relevant function to record the entities in the following input: {input}"),
("assistant", "OK!"),
("user", "Tip: Make sure to answer in the correct format"),
]
)
chain = create_ernie_fn_runnable([RecordPerson, RecordDog], llm, prompt)
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501
if not functions:
raise ValueError("Need to pass in at least one function. Received zero.")
ernie_functions = [convert_to_ernie_function(f) for f in functions]
llm_kwargs: Dict[str, Any] = {"functions": ernie_functions, **kwargs}
if len(ernie_functions) == 1:
llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]}
output_parser = output_parser or get_ernie_output_parser(functions)
return prompt | llm.bind(**llm_kwargs) | output_parser
def create_structured_output_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
llm: Runnable,
prompt: BasePromptTemplate,
*,
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
**kwargs: Any,
) -> Runnable:
"""Create a runnable that uses an Ernie function to get a structured output.
Args:
output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
is passed in, it's assumed to already be a valid JsonSchema.
For best results, pydantic.BaseModels should have docstrings describing what
the schema represents and descriptions for the parameters.
llm: Language model to use, assumed to support the Ernie function-calling API.
prompt: BasePromptTemplate to pass to the model.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
model outputs will simply be parsed as JSON.
Returns:
A runnable sequence that will pass the given function to the model when run.
Example:
.. code-block:: python
from typing import Optional
from langchain.chains.ernie_functions import create_structured_output_chain
from langchain.chat_models import ErnieBotChat
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
class Dog(BaseModel):
\"\"\"Identifying information about a dog.\"\"\"
name: str = Field(..., description="The dog's name")
color: str = Field(..., description="The dog's color")
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
prompt = ChatPromptTemplate.from_messages(
[
("user", "Use the given format to extract information from the following input: {input}"),
("assistant", "OK!"),
("user", "Tip: Make sure to answer in the correct format"),
]
)
chain = create_structured_output_chain(Dog, llm, prompt)
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
# -> Dog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501
if isinstance(output_schema, dict):
function: Any = {
"name": "output_formatter",
"description": (
"Output formatter. Should always be used to format your response to the"
" user."
),
"parameters": output_schema,
}
else:
class _OutputFormatter(BaseModel):
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
output: output_schema # type: ignore
function = _OutputFormatter
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
pydantic_schema=_OutputFormatter, attr_name="output"
)
return create_ernie_fn_runnable(
[function],
llm,
prompt,
output_parser=output_parser,
**kwargs,
)
""" --- Legacy --- """
def create_ernie_fn_chain(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,
output_key: str = "function",
output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any,
) -> LLMChain:
"""[Legacy] Create an LLM chain that uses Ernie functions.
Args:
functions: A sequence of either dictionaries, pydantic.BaseModels classes, or
Python functions. If dictionaries are passed in, they are assumed to
already be a valid Ernie functions. If only a single
function is passed in, then it will be enforced that the model use that
function. pydantic.BaseModels and Python functions should have docstrings
describing what the function does. For best results, pydantic.BaseModels
should have descriptions of the parameters and Python functions should have
Google Python style args descriptions in the docstring. Additionally,
Python functions should only use primitive types (str, int, float, bool) or
pydantic.BaseModels for arguments.
llm: Language model to use, assumed to support the Ernie function-calling API.
prompt: BasePromptTemplate to pass to the model.
output_key: The key to use when returning the output in LLMChain.__call__.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
model outputs will simply be parsed as JSON. If multiple functions are
passed in and they are not pydantic.BaseModels, the chain output will
include both the name of the function that was returned and the arguments
to pass to the function.
Returns:
An LLMChain that will pass in the given functions to the model when run.
Example:
.. code-block:: python
from typing import Optional
from langchain.chains.ernie_functions import create_ernie_fn_chain
from langchain.chat_models import ErnieBotChat
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
class RecordPerson(BaseModel):
\"\"\"Record some identifying information about a person.\"\"\"
name: str = Field(..., description="The person's name")
age: int = Field(..., description="The person's age")
fav_food: Optional[str] = Field(None, description="The person's favorite food")
class RecordDog(BaseModel):
\"\"\"Record some identifying information about a dog.\"\"\"
name: str = Field(..., description="The dog's name")
color: str = Field(..., description="The dog's color")
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
prompt = ChatPromptTemplate.from_messages(
[
("user", "Make calls to the relevant function to record the entities in the following input: {input}"),
("assistant", "OK!"),
("user", "Tip: Make sure to answer in the correct format"),
]
)
chain = create_ernie_fn_chain([RecordPerson, RecordDog], llm, prompt)
chain.run("Harry was a chubby brown beagle who loved chicken")
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501
if not functions:
raise ValueError("Need to pass in at least one function. Received zero.")
ernie_functions = [convert_to_ernie_function(f) for f in functions]
output_parser = output_parser or get_ernie_output_parser(functions)
llm_kwargs: Dict[str, Any] = {
"functions": ernie_functions,
}
if len(ernie_functions) == 1:
llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]}
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
output_parser=output_parser,
llm_kwargs=llm_kwargs,
output_key=output_key,
**kwargs,
)
return llm_chain
def create_structured_output_chain(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,
output_key: str = "function",
output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any,
) -> LLMChain:
"""[Legacy] Create an LLMChain that uses an Ernie function to get a structured output.
Args:
output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
is passed in, it's assumed to already be a valid JsonSchema.
For best results, pydantic.BaseModels should have docstrings describing what
the schema represents and descriptions for the parameters.
llm: Language model to use, assumed to support the Ernie function-calling API.
prompt: BasePromptTemplate to pass to the model.
output_key: The key to use when returning the output in LLMChain.__call__.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
model outputs will simply be parsed as JSON.
Returns:
An LLMChain that will pass the given function to the model.
Example:
.. code-block:: python
from typing import Optional
from langchain.chains.ernie_functions import create_structured_output_chain
from langchain.chat_models import ErnieBotChat
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
class Dog(BaseModel):
\"\"\"Identifying information about a dog.\"\"\"
name: str = Field(..., description="The dog's name")
color: str = Field(..., description="The dog's color")
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
prompt = ChatPromptTemplate.from_messages(
[
("user", "Use the given format to extract information from the following input: {input}"),
("assistant", "OK!"),
("user", "Tip: Make sure to answer in the correct format"),
]
)
chain = create_structured_output_chain(Dog, llm, prompt)
chain.run("Harry was a chubby brown beagle who loved chicken")
# -> Dog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501
if isinstance(output_schema, dict):
function: Any = {
"name": "output_formatter",
"description": (
"Output formatter. Should always be used to format your response to the"
" user."
),
"parameters": output_schema,
}
else:
class _OutputFormatter(BaseModel):
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
output: output_schema # type: ignore
function = _OutputFormatter
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
pydantic_schema=_OutputFormatter, attr_name="output"
)
return create_ernie_fn_chain(
[function],
llm,
prompt,
output_key=output_key,
output_parser=output_parser,
**kwargs,
)

View File

@ -1,3 +1,4 @@
import json
import logging
import threading
from typing import Any, Dict, List, Mapping, Optional
@ -179,9 +180,15 @@ class ErnieBotChat(BaseChatModel):
return self._create_chat_result(resp)
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = [
ChatGeneration(message=AIMessage(content=response.get("result")))
]
if "function_call" in response:
fc_str = '{{"function_call": {}}}'.format(
json.dumps(response.get("function_call"))
)
generations = [ChatGeneration(message=AIMessage(content=fc_str))]
else:
generations = [
ChatGeneration(message=AIMessage(content=response.get("result")))
]
token_usage = response.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
return ChatResult(generations=generations, llm_output=llm_output)

View File

@ -0,0 +1,184 @@
import copy
import json
from typing import Any, Dict, List, Optional, Type, Union
import jsonpatch
from langchain.output_parsers.json import parse_partial_json
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema import (
ChatGeneration,
Generation,
OutputParserException,
)
from langchain.schema.output_parser import (
BaseCumulativeTransformOutputParser,
BaseGenerationOutputParser,
)
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
"""Parse an output that is one of sets of values."""
args_only: bool = True
"""Whether to only return the arguments to the function call."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
"This output parser can only be used with a chat generation."
)
message = generation.message
try:
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
except KeyError as exc:
raise OutputParserException(f"Could not parse function call: {exc}")
if self.args_only:
return func_call["arguments"]
return func_call
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse an output as the Json object."""
strict: bool = False
"""Whether to allow non-JSON-compliant strings.
See: https://docs.python.org/3/library/json.html#encoders-and-decoders
Useful when the parsed output may include unicode characters or new lines.
"""
args_only: bool = True
"""Whether to only return the arguments to the function call."""
@property
def _type(self) -> str:
return "json_functions"
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
if len(result) != 1:
raise OutputParserException(
f"Expected exactly one result, but got {len(result)}"
)
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
"This output parser can only be used with a chat generation."
)
message = generation.message
message.additional_kwargs["function_call"] = {}
if "function_call" in message.content:
function_call = json.loads(str(message.content))
if "function_call" in function_call:
fc = function_call["function_call"]
message.additional_kwargs["function_call"] = fc
try:
function_call = message.additional_kwargs["function_call"]
except KeyError as exc:
if partial:
return None
else:
raise OutputParserException(f"Could not parse function call: {exc}")
try:
if partial:
if self.args_only:
return parse_partial_json(
function_call["arguments"], strict=self.strict
)
else:
return {
**function_call,
"arguments": parse_partial_json(
function_call["arguments"], strict=self.strict
),
}
else:
if self.args_only:
try:
return json.loads(
function_call["arguments"], strict=self.strict
)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
else:
try:
return {
**function_call,
"arguments": json.loads(
function_call["arguments"], strict=self.strict
),
}
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
except KeyError:
return None
# This method would be called by the default implementation of `parse_result`
# but we're overriding that method so it's not needed.
def parse(self, text: str) -> Any:
raise NotImplementedError()
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
"""Parse an output as the element of the Json object."""
key_name: str
"""The name of the key to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
res = super().parse_result(result, partial=partial)
if partial and res is None:
return None
return res.get(self.key_name) if partial else res[self.key_name]
class PydanticOutputFunctionsParser(OutputFunctionsParser):
"""Parse an output as a pydantic object."""
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
"""The pydantic schema to parse the output with."""
@root_validator(pre=True)
def validate_schema(cls, values: Dict) -> Dict:
schema = values["pydantic_schema"]
if "args_only" not in values:
values["args_only"] = isinstance(schema, type) and issubclass(
schema, BaseModel
)
elif values["args_only"] and isinstance(schema, Dict):
raise ValueError(
"If multiple pydantic schemas are provided then args_only should be"
" False."
)
return values
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
else:
fn_name = _result["name"]
_args = _result["arguments"]
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore # noqa: E501
return pydantic_args
class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
"""Parse an output as an attribute of a pydantic object."""
attr_name: str
"""The name of the attribute to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
result = super().parse_result(result)
return getattr(result, self.attr_name)

View File

@ -0,0 +1,51 @@
from typing import Literal, Optional, Type, TypedDict
from langchain.pydantic_v1 import BaseModel
from langchain.utils.json_schema import dereference_refs
class FunctionDescription(TypedDict):
"""Representation of a callable function to the Ernie API."""
name: str
"""The name of the function."""
description: str
"""A description of the function."""
parameters: dict
"""The parameters of the function."""
class ToolDescription(TypedDict):
"""Representation of a callable function to the Ernie API."""
type: Literal["function"]
function: FunctionDescription
def convert_pydantic_to_ernie_function(
model: Type[BaseModel],
*,
name: Optional[str] = None,
description: Optional[str] = None,
) -> FunctionDescription:
"""Converts a Pydantic model to a function description for the Ernie API."""
schema = dereference_refs(model.schema())
schema.pop("definitions", None)
return {
"name": name or schema["title"],
"description": description or schema["description"],
"parameters": schema,
}
def convert_pydantic_to_ernie_tool(
model: Type[BaseModel],
*,
name: Optional[str] = None,
description: Optional[str] = None,
) -> ToolDescription:
"""Converts a Pydantic model to a function description for the Ernie API."""
function = convert_pydantic_to_ernie_function(
model, name=name, description=description
)
return {"type": "function", "function": function}