mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
feat: add ERNIE-Bot-4 Function Calling (#13320)
- **Description:** ERNIE-Bot-Chat-4 Large Language Model adds the ability of `Function Calling` by passing parameters through the `functions` parameter in the request. To simplify function calling for ERNIE-Bot-Chat-4, the `create_ernie_fn_chain()` function has been added. The definition and usage of the `create_ernie_fn_chain()` function is similar to that of the `create_openai_fn_chain()` function. Examples as the follows: ``` import json from langchain.chains.ernie_functions import ( create_ernie_fn_chain, ) from langchain.chat_models import ErnieBotChat from langchain.prompts import ChatPromptTemplate def get_current_news(location: str) -> str: """Get the current news based on the location.' Args: location (str): The location to query. Returs: str: Current news based on the location. """ news_info = { "location": location, "news": [ "I have a Book.", "It's a nice day, today." ] } return json.dumps(news_info) def get_current_weather(location: str, unit: str="celsius") -> str: """Get the current weather in a given location Args: location (str): location of the weather. unit (str): unit of the tempuature. Returns: str: weather in the given location. """ weather_info = { "location": location, "temperature": "27", "unit": unit, "forecast": ["sunny", "windy"], } return json.dumps(weather_info) llm = ErnieBotChat(model_name="ERNIE-Bot-4") prompt = ChatPromptTemplate.from_messages( [ ("human", "{query}"), ] ) chain = create_ernie_fn_chain([get_current_weather, get_current_news], llm, prompt, verbose=True) res = chain.run("北京今天的新闻是什么?") print(res) ``` The running results of the above program are shown below: ``` > Entering new LLMChain chain... Prompt after formatting: Human: 北京今天的新闻是什么? > Finished chain. {'name': 'get_current_news', 'thoughts': '用户想要知道北京今天的新闻。我可以使用get_current_news工具来获取这些信息。', 'arguments': {'location': '北京'}} ```
This commit is contained in:
parent
10418ab0c1
commit
fe7b40cb2a
17
libs/langchain/langchain/chains/ernie_functions/__init__.py
Normal file
17
libs/langchain/langchain/chains/ernie_functions/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from langchain.chains.ernie_functions.base import (
|
||||
convert_to_ernie_function,
|
||||
create_ernie_fn_chain,
|
||||
create_ernie_fn_runnable,
|
||||
create_structured_output_chain,
|
||||
create_structured_output_runnable,
|
||||
get_ernie_output_parser,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"convert_to_ernie_function",
|
||||
"create_structured_output_chain",
|
||||
"create_ernie_fn_chain",
|
||||
"create_structured_output_runnable",
|
||||
"create_ernie_fn_runnable",
|
||||
"get_ernie_output_parser",
|
||||
]
|
547
libs/langchain/langchain/chains/ernie_functions/base.py
Normal file
547
libs/langchain/langchain/chains/ernie_functions/base.py
Normal file
@ -0,0 +1,547 @@
|
||||
"""Methods for creating chains that use Ernie function-calling APIs."""
|
||||
import inspect
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.output_parsers.ernie_functions import (
|
||||
JsonOutputFunctionsParser,
|
||||
PydanticAttrOutputFunctionsParser,
|
||||
PydanticOutputFunctionsParser,
|
||||
)
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema import BaseLLMOutputParser
|
||||
from langchain.schema.output_parser import BaseGenerationOutputParser, BaseOutputParser
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.utils.ernie_functions import convert_pydantic_to_ernie_function
|
||||
|
||||
PYTHON_TO_JSON_TYPES = {
|
||||
"str": "string",
|
||||
"int": "number",
|
||||
"float": "number",
|
||||
"bool": "boolean",
|
||||
}
|
||||
|
||||
|
||||
def _get_python_function_name(function: Callable) -> str:
|
||||
"""Get the name of a Python function."""
|
||||
return function.__name__
|
||||
|
||||
|
||||
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
|
||||
Assumes the function docstring follows Google Python style guide.
|
||||
"""
|
||||
docstring = inspect.getdoc(function)
|
||||
if docstring:
|
||||
docstring_blocks = docstring.split("\n\n")
|
||||
descriptors = []
|
||||
args_block = None
|
||||
past_descriptors = False
|
||||
for block in docstring_blocks:
|
||||
if block.startswith("Args:"):
|
||||
args_block = block
|
||||
break
|
||||
elif block.startswith("Returns:") or block.startswith("Example:"):
|
||||
# Don't break in case Args come after
|
||||
past_descriptors = True
|
||||
elif not past_descriptors:
|
||||
descriptors.append(block)
|
||||
else:
|
||||
continue
|
||||
description = " ".join(descriptors)
|
||||
else:
|
||||
description = ""
|
||||
args_block = None
|
||||
arg_descriptions = {}
|
||||
if args_block:
|
||||
arg = None
|
||||
for line in args_block.split("\n")[1:]:
|
||||
if ":" in line:
|
||||
arg, desc = line.split(":")
|
||||
arg_descriptions[arg.strip()] = desc.strip()
|
||||
elif arg:
|
||||
arg_descriptions[arg.strip()] += " " + line.strip()
|
||||
return description, arg_descriptions
|
||||
|
||||
|
||||
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
|
||||
"""Get JsonSchema describing a Python functions arguments.
|
||||
|
||||
Assumes all function arguments are of primitive types (int, float, str, bool) or
|
||||
are subclasses of pydantic.BaseModel.
|
||||
"""
|
||||
properties = {}
|
||||
annotations = inspect.getfullargspec(function).annotations
|
||||
for arg, arg_type in annotations.items():
|
||||
if arg == "return":
|
||||
continue
|
||||
if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):
|
||||
# Mypy error:
|
||||
# "type" has no attribute "schema"
|
||||
properties[arg] = arg_type.schema() # type: ignore[attr-defined]
|
||||
elif arg_type.__name__ in PYTHON_TO_JSON_TYPES:
|
||||
properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
|
||||
if arg in arg_descriptions:
|
||||
if arg not in properties:
|
||||
properties[arg] = {}
|
||||
properties[arg]["description"] = arg_descriptions[arg]
|
||||
return properties
|
||||
|
||||
|
||||
def _get_python_function_required_args(function: Callable) -> List[str]:
|
||||
"""Get the required arguments for a Python function."""
|
||||
spec = inspect.getfullargspec(function)
|
||||
required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
|
||||
required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})]
|
||||
|
||||
is_class = type(function) is type
|
||||
if is_class and required[0] == "self":
|
||||
required = required[1:]
|
||||
return required
|
||||
|
||||
|
||||
def convert_python_function_to_ernie_function(
|
||||
function: Callable,
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a Python function to an Ernie function-calling API compatible dict.
|
||||
|
||||
Assumes the Python function has type hints and a docstring with a description. If
|
||||
the docstring has Google Python style argument descriptions, these will be
|
||||
included as well.
|
||||
"""
|
||||
description, arg_descriptions = _parse_python_function_docstring(function)
|
||||
return {
|
||||
"name": _get_python_function_name(function),
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": _get_python_function_arguments(function, arg_descriptions),
|
||||
"required": _get_python_function_required_args(function),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def convert_to_ernie_function(
|
||||
function: Union[Dict[str, Any], Type[BaseModel], Callable]
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a raw function/class to an Ernie function.
|
||||
|
||||
Args:
|
||||
function: Either a dictionary, a pydantic.BaseModel class, or a Python function.
|
||||
If a dictionary is passed in, it is assumed to already be a valid Ernie
|
||||
function.
|
||||
|
||||
Returns:
|
||||
A dict version of the passed in function which is compatible with the
|
||||
Ernie function-calling API.
|
||||
"""
|
||||
if isinstance(function, dict):
|
||||
return function
|
||||
elif isinstance(function, type) and issubclass(function, BaseModel):
|
||||
return cast(Dict, convert_pydantic_to_ernie_function(function))
|
||||
elif callable(function):
|
||||
return convert_python_function_to_ernie_function(function)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported function type {type(function)}. Functions must be passed in"
|
||||
f" as Dict, pydantic.BaseModel, or Callable."
|
||||
)
|
||||
|
||||
|
||||
def get_ernie_output_parser(
|
||||
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
|
||||
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
|
||||
"""Get the appropriate function output parser given the user functions.
|
||||
|
||||
Args:
|
||||
functions: Sequence where element is a dictionary, a pydantic.BaseModel class,
|
||||
or a Python function. If a dictionary is passed in, it is assumed to
|
||||
already be a valid Ernie function.
|
||||
|
||||
Returns:
|
||||
A PydanticOutputFunctionsParser if functions are Pydantic classes, otherwise
|
||||
a JsonOutputFunctionsParser. If there's only one function and it is
|
||||
not a Pydantic class, then the output parser will automatically extract
|
||||
only the function arguments and not the function name.
|
||||
"""
|
||||
function_names = [convert_to_ernie_function(f)["name"] for f in functions]
|
||||
if isinstance(functions[0], type) and issubclass(functions[0], BaseModel):
|
||||
if len(functions) > 1:
|
||||
pydantic_schema: Union[Dict, Type[BaseModel]] = {
|
||||
name: fn for name, fn in zip(function_names, functions)
|
||||
}
|
||||
else:
|
||||
pydantic_schema = functions[0]
|
||||
output_parser: Union[
|
||||
BaseOutputParser, BaseGenerationOutputParser
|
||||
] = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
|
||||
else:
|
||||
output_parser = JsonOutputFunctionsParser(args_only=len(functions) <= 1)
|
||||
return output_parser
|
||||
|
||||
|
||||
def create_ernie_fn_runnable(
|
||||
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
|
||||
llm: Runnable,
|
||||
prompt: BasePromptTemplate,
|
||||
*,
|
||||
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable:
|
||||
"""Create a runnable sequence that uses Ernie functions.
|
||||
|
||||
Args:
|
||||
functions: A sequence of either dictionaries, pydantic.BaseModels classes, or
|
||||
Python functions. If dictionaries are passed in, they are assumed to
|
||||
already be a valid Ernie functions. If only a single
|
||||
function is passed in, then it will be enforced that the model use that
|
||||
function. pydantic.BaseModels and Python functions should have docstrings
|
||||
describing what the function does. For best results, pydantic.BaseModels
|
||||
should have descriptions of the parameters and Python functions should have
|
||||
Google Python style args descriptions in the docstring. Additionally,
|
||||
Python functions should only use primitive types (str, int, float, bool) or
|
||||
pydantic.BaseModels for arguments.
|
||||
llm: Language model to use, assumed to support the Ernie function-calling API.
|
||||
prompt: BasePromptTemplate to pass to the model.
|
||||
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
|
||||
will be inferred from the function types. If pydantic.BaseModels are passed
|
||||
in, then the OutputParser will try to parse outputs using those. Otherwise
|
||||
model outputs will simply be parsed as JSON. If multiple functions are
|
||||
passed in and they are not pydantic.BaseModels, the chain output will
|
||||
include both the name of the function that was returned and the arguments
|
||||
to pass to the function.
|
||||
|
||||
Returns:
|
||||
A runnable sequence that will pass in the given functions to the model when run.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain.chains.ernie_functions import create_ernie_fn_chain
|
||||
from langchain.chat_models import ErnieBotChat
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
|
||||
|
||||
class RecordPerson(BaseModel):
|
||||
\"\"\"Record some identifying information about a person.\"\"\"
|
||||
|
||||
name: str = Field(..., description="The person's name")
|
||||
age: int = Field(..., description="The person's age")
|
||||
fav_food: Optional[str] = Field(None, description="The person's favorite food")
|
||||
|
||||
|
||||
class RecordDog(BaseModel):
|
||||
\"\"\"Record some identifying information about a dog.\"\"\"
|
||||
|
||||
name: str = Field(..., description="The dog's name")
|
||||
color: str = Field(..., description="The dog's color")
|
||||
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
|
||||
|
||||
|
||||
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("user", "Make calls to the relevant function to record the entities in the following input: {input}"),
|
||||
("assistant", "OK!"),
|
||||
("user", "Tip: Make sure to answer in the correct format"),
|
||||
]
|
||||
)
|
||||
chain = create_ernie_fn_runnable([RecordPerson, RecordDog], llm, prompt)
|
||||
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
|
||||
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
|
||||
""" # noqa: E501
|
||||
if not functions:
|
||||
raise ValueError("Need to pass in at least one function. Received zero.")
|
||||
ernie_functions = [convert_to_ernie_function(f) for f in functions]
|
||||
llm_kwargs: Dict[str, Any] = {"functions": ernie_functions, **kwargs}
|
||||
if len(ernie_functions) == 1:
|
||||
llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]}
|
||||
output_parser = output_parser or get_ernie_output_parser(functions)
|
||||
return prompt | llm.bind(**llm_kwargs) | output_parser
|
||||
|
||||
|
||||
def create_structured_output_runnable(
|
||||
output_schema: Union[Dict[str, Any], Type[BaseModel]],
|
||||
llm: Runnable,
|
||||
prompt: BasePromptTemplate,
|
||||
*,
|
||||
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable:
|
||||
"""Create a runnable that uses an Ernie function to get a structured output.
|
||||
|
||||
Args:
|
||||
output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
|
||||
is passed in, it's assumed to already be a valid JsonSchema.
|
||||
For best results, pydantic.BaseModels should have docstrings describing what
|
||||
the schema represents and descriptions for the parameters.
|
||||
llm: Language model to use, assumed to support the Ernie function-calling API.
|
||||
prompt: BasePromptTemplate to pass to the model.
|
||||
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
|
||||
will be inferred from the function types. If pydantic.BaseModels are passed
|
||||
in, then the OutputParser will try to parse outputs using those. Otherwise
|
||||
model outputs will simply be parsed as JSON.
|
||||
|
||||
Returns:
|
||||
A runnable sequence that will pass the given function to the model when run.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain.chains.ernie_functions import create_structured_output_chain
|
||||
from langchain.chat_models import ErnieBotChat
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
|
||||
class Dog(BaseModel):
|
||||
\"\"\"Identifying information about a dog.\"\"\"
|
||||
|
||||
name: str = Field(..., description="The dog's name")
|
||||
color: str = Field(..., description="The dog's color")
|
||||
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
|
||||
|
||||
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("user", "Use the given format to extract information from the following input: {input}"),
|
||||
("assistant", "OK!"),
|
||||
("user", "Tip: Make sure to answer in the correct format"),
|
||||
]
|
||||
)
|
||||
chain = create_structured_output_chain(Dog, llm, prompt)
|
||||
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
|
||||
# -> Dog(name="Harry", color="brown", fav_food="chicken")
|
||||
""" # noqa: E501
|
||||
if isinstance(output_schema, dict):
|
||||
function: Any = {
|
||||
"name": "output_formatter",
|
||||
"description": (
|
||||
"Output formatter. Should always be used to format your response to the"
|
||||
" user."
|
||||
),
|
||||
"parameters": output_schema,
|
||||
}
|
||||
else:
|
||||
|
||||
class _OutputFormatter(BaseModel):
|
||||
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
|
||||
|
||||
output: output_schema # type: ignore
|
||||
|
||||
function = _OutputFormatter
|
||||
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
|
||||
pydantic_schema=_OutputFormatter, attr_name="output"
|
||||
)
|
||||
return create_ernie_fn_runnable(
|
||||
[function],
|
||||
llm,
|
||||
prompt,
|
||||
output_parser=output_parser,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
""" --- Legacy --- """
|
||||
|
||||
|
||||
def create_ernie_fn_chain(
|
||||
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate,
|
||||
*,
|
||||
output_key: str = "function",
|
||||
output_parser: Optional[BaseLLMOutputParser] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMChain:
|
||||
"""[Legacy] Create an LLM chain that uses Ernie functions.
|
||||
|
||||
Args:
|
||||
functions: A sequence of either dictionaries, pydantic.BaseModels classes, or
|
||||
Python functions. If dictionaries are passed in, they are assumed to
|
||||
already be a valid Ernie functions. If only a single
|
||||
function is passed in, then it will be enforced that the model use that
|
||||
function. pydantic.BaseModels and Python functions should have docstrings
|
||||
describing what the function does. For best results, pydantic.BaseModels
|
||||
should have descriptions of the parameters and Python functions should have
|
||||
Google Python style args descriptions in the docstring. Additionally,
|
||||
Python functions should only use primitive types (str, int, float, bool) or
|
||||
pydantic.BaseModels for arguments.
|
||||
llm: Language model to use, assumed to support the Ernie function-calling API.
|
||||
prompt: BasePromptTemplate to pass to the model.
|
||||
output_key: The key to use when returning the output in LLMChain.__call__.
|
||||
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
|
||||
will be inferred from the function types. If pydantic.BaseModels are passed
|
||||
in, then the OutputParser will try to parse outputs using those. Otherwise
|
||||
model outputs will simply be parsed as JSON. If multiple functions are
|
||||
passed in and they are not pydantic.BaseModels, the chain output will
|
||||
include both the name of the function that was returned and the arguments
|
||||
to pass to the function.
|
||||
|
||||
Returns:
|
||||
An LLMChain that will pass in the given functions to the model when run.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain.chains.ernie_functions import create_ernie_fn_chain
|
||||
from langchain.chat_models import ErnieBotChat
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
|
||||
|
||||
class RecordPerson(BaseModel):
|
||||
\"\"\"Record some identifying information about a person.\"\"\"
|
||||
|
||||
name: str = Field(..., description="The person's name")
|
||||
age: int = Field(..., description="The person's age")
|
||||
fav_food: Optional[str] = Field(None, description="The person's favorite food")
|
||||
|
||||
|
||||
class RecordDog(BaseModel):
|
||||
\"\"\"Record some identifying information about a dog.\"\"\"
|
||||
|
||||
name: str = Field(..., description="The dog's name")
|
||||
color: str = Field(..., description="The dog's color")
|
||||
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
|
||||
|
||||
|
||||
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("user", "Make calls to the relevant function to record the entities in the following input: {input}"),
|
||||
("assistant", "OK!"),
|
||||
("user", "Tip: Make sure to answer in the correct format"),
|
||||
]
|
||||
)
|
||||
chain = create_ernie_fn_chain([RecordPerson, RecordDog], llm, prompt)
|
||||
chain.run("Harry was a chubby brown beagle who loved chicken")
|
||||
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
|
||||
""" # noqa: E501
|
||||
if not functions:
|
||||
raise ValueError("Need to pass in at least one function. Received zero.")
|
||||
ernie_functions = [convert_to_ernie_function(f) for f in functions]
|
||||
output_parser = output_parser or get_ernie_output_parser(functions)
|
||||
llm_kwargs: Dict[str, Any] = {
|
||||
"functions": ernie_functions,
|
||||
}
|
||||
if len(ernie_functions) == 1:
|
||||
llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]}
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
output_parser=output_parser,
|
||||
llm_kwargs=llm_kwargs,
|
||||
output_key=output_key,
|
||||
**kwargs,
|
||||
)
|
||||
return llm_chain
|
||||
|
||||
|
||||
def create_structured_output_chain(
|
||||
output_schema: Union[Dict[str, Any], Type[BaseModel]],
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate,
|
||||
*,
|
||||
output_key: str = "function",
|
||||
output_parser: Optional[BaseLLMOutputParser] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMChain:
|
||||
"""[Legacy] Create an LLMChain that uses an Ernie function to get a structured output.
|
||||
|
||||
Args:
|
||||
output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
|
||||
is passed in, it's assumed to already be a valid JsonSchema.
|
||||
For best results, pydantic.BaseModels should have docstrings describing what
|
||||
the schema represents and descriptions for the parameters.
|
||||
llm: Language model to use, assumed to support the Ernie function-calling API.
|
||||
prompt: BasePromptTemplate to pass to the model.
|
||||
output_key: The key to use when returning the output in LLMChain.__call__.
|
||||
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
|
||||
will be inferred from the function types. If pydantic.BaseModels are passed
|
||||
in, then the OutputParser will try to parse outputs using those. Otherwise
|
||||
model outputs will simply be parsed as JSON.
|
||||
|
||||
Returns:
|
||||
An LLMChain that will pass the given function to the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain.chains.ernie_functions import create_structured_output_chain
|
||||
from langchain.chat_models import ErnieBotChat
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
|
||||
class Dog(BaseModel):
|
||||
\"\"\"Identifying information about a dog.\"\"\"
|
||||
|
||||
name: str = Field(..., description="The dog's name")
|
||||
color: str = Field(..., description="The dog's color")
|
||||
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
|
||||
|
||||
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("user", "Use the given format to extract information from the following input: {input}"),
|
||||
("assistant", "OK!"),
|
||||
("user", "Tip: Make sure to answer in the correct format"),
|
||||
]
|
||||
)
|
||||
chain = create_structured_output_chain(Dog, llm, prompt)
|
||||
chain.run("Harry was a chubby brown beagle who loved chicken")
|
||||
# -> Dog(name="Harry", color="brown", fav_food="chicken")
|
||||
""" # noqa: E501
|
||||
if isinstance(output_schema, dict):
|
||||
function: Any = {
|
||||
"name": "output_formatter",
|
||||
"description": (
|
||||
"Output formatter. Should always be used to format your response to the"
|
||||
" user."
|
||||
),
|
||||
"parameters": output_schema,
|
||||
}
|
||||
else:
|
||||
|
||||
class _OutputFormatter(BaseModel):
|
||||
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
|
||||
|
||||
output: output_schema # type: ignore
|
||||
|
||||
function = _OutputFormatter
|
||||
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
|
||||
pydantic_schema=_OutputFormatter, attr_name="output"
|
||||
)
|
||||
return create_ernie_fn_chain(
|
||||
[function],
|
||||
llm,
|
||||
prompt,
|
||||
output_key=output_key,
|
||||
output_parser=output_parser,
|
||||
**kwargs,
|
||||
)
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
@ -179,9 +180,15 @@ class ErnieBotChat(BaseChatModel):
|
||||
return self._create_chat_result(resp)
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = [
|
||||
ChatGeneration(message=AIMessage(content=response.get("result")))
|
||||
]
|
||||
if "function_call" in response:
|
||||
fc_str = '{{"function_call": {}}}'.format(
|
||||
json.dumps(response.get("function_call"))
|
||||
)
|
||||
generations = [ChatGeneration(message=AIMessage(content=fc_str))]
|
||||
else:
|
||||
generations = [
|
||||
ChatGeneration(message=AIMessage(content=response.get("result")))
|
||||
]
|
||||
token_usage = response.get("usage", {})
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
184
libs/langchain/langchain/output_parsers/ernie_functions.py
Normal file
184
libs/langchain/langchain/output_parsers/ernie_functions.py
Normal file
@ -0,0 +1,184 @@
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import jsonpatch
|
||||
|
||||
from langchain.output_parsers.json import parse_partial_json
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema import (
|
||||
ChatGeneration,
|
||||
Generation,
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.output_parser import (
|
||||
BaseCumulativeTransformOutputParser,
|
||||
BaseGenerationOutputParser,
|
||||
)
|
||||
|
||||
|
||||
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
"""Parse an output that is one of sets of values."""
|
||||
|
||||
args_only: bool = True
|
||||
"""Whether to only return the arguments to the function call."""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
raise OutputParserException(
|
||||
"This output parser can only be used with a chat generation."
|
||||
)
|
||||
message = generation.message
|
||||
try:
|
||||
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
|
||||
except KeyError as exc:
|
||||
raise OutputParserException(f"Could not parse function call: {exc}")
|
||||
|
||||
if self.args_only:
|
||||
return func_call["arguments"]
|
||||
return func_call
|
||||
|
||||
|
||||
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
"""Parse an output as the Json object."""
|
||||
|
||||
strict: bool = False
|
||||
"""Whether to allow non-JSON-compliant strings.
|
||||
|
||||
See: https://docs.python.org/3/library/json.html#encoders-and-decoders
|
||||
|
||||
Useful when the parsed output may include unicode characters or new lines.
|
||||
"""
|
||||
|
||||
args_only: bool = True
|
||||
"""Whether to only return the arguments to the function call."""
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "json_functions"
|
||||
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
if len(result) != 1:
|
||||
raise OutputParserException(
|
||||
f"Expected exactly one result, but got {len(result)}"
|
||||
)
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
raise OutputParserException(
|
||||
"This output parser can only be used with a chat generation."
|
||||
)
|
||||
message = generation.message
|
||||
message.additional_kwargs["function_call"] = {}
|
||||
if "function_call" in message.content:
|
||||
function_call = json.loads(str(message.content))
|
||||
if "function_call" in function_call:
|
||||
fc = function_call["function_call"]
|
||||
message.additional_kwargs["function_call"] = fc
|
||||
try:
|
||||
function_call = message.additional_kwargs["function_call"]
|
||||
except KeyError as exc:
|
||||
if partial:
|
||||
return None
|
||||
else:
|
||||
raise OutputParserException(f"Could not parse function call: {exc}")
|
||||
try:
|
||||
if partial:
|
||||
if self.args_only:
|
||||
return parse_partial_json(
|
||||
function_call["arguments"], strict=self.strict
|
||||
)
|
||||
else:
|
||||
return {
|
||||
**function_call,
|
||||
"arguments": parse_partial_json(
|
||||
function_call["arguments"], strict=self.strict
|
||||
),
|
||||
}
|
||||
else:
|
||||
if self.args_only:
|
||||
try:
|
||||
return json.loads(
|
||||
function_call["arguments"], strict=self.strict
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call data: {exc}"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
return {
|
||||
**function_call,
|
||||
"arguments": json.loads(
|
||||
function_call["arguments"], strict=self.strict
|
||||
),
|
||||
}
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call data: {exc}"
|
||||
)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
# This method would be called by the default implementation of `parse_result`
|
||||
# but we're overriding that method so it's not needed.
|
||||
def parse(self, text: str) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||
"""Parse an output as the element of the Json object."""
|
||||
|
||||
key_name: str
|
||||
"""The name of the key to return."""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
res = super().parse_result(result, partial=partial)
|
||||
if partial and res is None:
|
||||
return None
|
||||
return res.get(self.key_name) if partial else res[self.key_name]
|
||||
|
||||
|
||||
class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
"""Parse an output as a pydantic object."""
|
||||
|
||||
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
|
||||
"""The pydantic schema to parse the output with."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_schema(cls, values: Dict) -> Dict:
|
||||
schema = values["pydantic_schema"]
|
||||
if "args_only" not in values:
|
||||
values["args_only"] = isinstance(schema, type) and issubclass(
|
||||
schema, BaseModel
|
||||
)
|
||||
elif values["args_only"] and isinstance(schema, Dict):
|
||||
raise ValueError(
|
||||
"If multiple pydantic schemas are provided then args_only should be"
|
||||
" False."
|
||||
)
|
||||
return values
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
_result = super().parse_result(result)
|
||||
if self.args_only:
|
||||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
||||
else:
|
||||
fn_name = _result["name"]
|
||||
_args = _result["arguments"]
|
||||
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore # noqa: E501
|
||||
return pydantic_args
|
||||
|
||||
|
||||
class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
||||
"""Parse an output as an attribute of a pydantic object."""
|
||||
|
||||
attr_name: str
|
||||
"""The name of the attribute to return."""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
result = super().parse_result(result)
|
||||
return getattr(result, self.attr_name)
|
51
libs/langchain/langchain/utils/ernie_functions.py
Normal file
51
libs/langchain/langchain/utils/ernie_functions.py
Normal file
@ -0,0 +1,51 @@
|
||||
from typing import Literal, Optional, Type, TypedDict
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.utils.json_schema import dereference_refs
|
||||
|
||||
|
||||
class FunctionDescription(TypedDict):
|
||||
"""Representation of a callable function to the Ernie API."""
|
||||
|
||||
name: str
|
||||
"""The name of the function."""
|
||||
description: str
|
||||
"""A description of the function."""
|
||||
parameters: dict
|
||||
"""The parameters of the function."""
|
||||
|
||||
|
||||
class ToolDescription(TypedDict):
|
||||
"""Representation of a callable function to the Ernie API."""
|
||||
|
||||
type: Literal["function"]
|
||||
function: FunctionDescription
|
||||
|
||||
|
||||
def convert_pydantic_to_ernie_function(
|
||||
model: Type[BaseModel],
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> FunctionDescription:
|
||||
"""Converts a Pydantic model to a function description for the Ernie API."""
|
||||
schema = dereference_refs(model.schema())
|
||||
schema.pop("definitions", None)
|
||||
return {
|
||||
"name": name or schema["title"],
|
||||
"description": description or schema["description"],
|
||||
"parameters": schema,
|
||||
}
|
||||
|
||||
|
||||
def convert_pydantic_to_ernie_tool(
|
||||
model: Type[BaseModel],
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> ToolDescription:
|
||||
"""Converts a Pydantic model to a function description for the Ernie API."""
|
||||
function = convert_pydantic_to_ernie_function(
|
||||
model, name=name, description=description
|
||||
)
|
||||
return {"type": "function", "function": function}
|
Loading…
Reference in New Issue
Block a user