From 7205bfdd00eb964014ab4e63d8d00ec6834f88e9 Mon Sep 17 00:00:00 2001 From: Wang Wei Date: Wed, 6 Dec 2023 09:28:31 +0800 Subject: [PATCH] feat: 1. Add system parameters, 2. Align with the QianfanChatEndpoint for function calling (#14275) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - **Description:** 1. Add system parameters to the ERNIE LLM API to set the role of the LLM. 2. Add support for the ERNIE-Bot-turbo-AI model according to the document https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Alp0kdm0n. 3. For the function call of ErnieBotChat, align with the QianfanChatEndpoint. With this PR, the `QianfanChatEndpoint()` can use the `function calling` ability with `create_ernie_fn_chain()`. The example is as the following: ``` from langchain.prompts import ChatPromptTemplate import json from langchain.prompts.chat import ( ChatPromptTemplate, ) from langchain.chat_models import QianfanChatEndpoint from langchain.chains.ernie_functions import ( create_ernie_fn_chain, ) def get_current_news(location: str) -> str: """Get the current news based on the location. Args: location (str): The location to query. Returns: str: Current news based on the location. """ news_info = { "location": location, "news": [ "I have a Book.", "It's a nice day, today." ] } return json.dumps(news_info) def get_current_weather(location: str, unit: str="celsius") -> str: """Get the current weather in a given location Args: location (str): location of the weather. unit (str): unit of the temperature. Returns: str: weather in the given location. 
""" weather_info = { "location": location, "temperature": "27", "unit": unit, "forecast": ["sunny", "windy"], } return json.dumps(weather_info) template = ChatPromptTemplate.from_messages([ ("user", "{user_input}"), ]) chat = QianfanChatEndpoint(model="ERNIE-Bot-4") chain = create_ernie_fn_chain([get_current_weather, get_current_news], chat, template, verbose=True) res = chain.run("北京今天的新闻是什么?") print(res) ``` The result of the above code: ``` > Entering new LLMChain chain... Prompt after formatting: Human: 北京今天的新闻是什么? > Finished chain. {'name': 'get_current_news', 'arguments': {'location': '北京'}} ``` For the `ErnieBotChat`, now can use the `system` parameter to set the role of the LLM. ``` from langchain.prompts import ChatPromptTemplate from langchain.chains import LLMChain from langchain.chat_models import ErnieBotChat llm = ErnieBotChat(model_name="ERNIE-Bot-turbo-AI", system="你是一个能力很强的机器人,你的名字叫 小叮当。无论问你什么问题,你都可以给出答案。") prompt = ChatPromptTemplate.from_messages( [ ("human", "{query}"), ] ) chain = LLMChain(llm=llm, prompt=prompt, verbose=True) res = chain.run(query="你是谁?") print(res) ``` The result of the above code: ``` > Entering new LLMChain chain... Prompt after formatting: Human: 你是谁? > Finished chain. 我是小叮当,一个智能机器人。我可以为你提供各种服务,包括回答问题、提供信息、进行计算等。如果你需要任何帮助,请随时告诉我,我会尽力为你提供最好的服务。 ``` --- libs/langchain/langchain/chat_models/ernie.py | 30 +++++++++++++------ .../output_parsers/ernie_functions.py | 8 ++--- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/libs/langchain/langchain/chat_models/ernie.py b/libs/langchain/langchain/chat_models/ernie.py index 7687c1241e2..a46a15bbdb8 100644 --- a/libs/langchain/langchain/chat_models/ernie.py +++ b/libs/langchain/langchain/chat_models/ernie.py @@ -1,4 +1,3 @@ -import json import logging import threading from typing import Any, Dict, List, Mapping, Optional @@ -46,7 +45,8 @@ class ErnieBotChat(BaseChatModel): and will be regenerated after expiration (30 days). 
Default model is `ERNIE-Bot-turbo`, - currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot` + currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot`, `ERNIE-Bot-8K`, + `ERNIE-Bot-4`, `ERNIE-Bot-turbo-AI`. Example: .. code-block:: python @@ -87,6 +87,11 @@ class ErnieBotChat(BaseChatModel): """model name of ernie, default is `ERNIE-Bot-turbo`. Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`""" + system: Optional[str] = None + """system is mainly used for model character design, + for example, you are an AI assistant produced by xxx company. + The length of the system is limiting of 1024 characters.""" + request_timeout: Optional[int] = 60 """request timeout for chat http requests""" @@ -123,6 +128,7 @@ class ErnieBotChat(BaseChatModel): "ERNIE-Bot": "completions", "ERNIE-Bot-8K": "ernie_bot_8k", "ERNIE-Bot-4": "completions_pro", + "ERNIE-Bot-turbo-AI": "ai_apaas", "BLOOMZ-7B": "bloomz_7b1", "Llama-2-7b-chat": "llama_2_7b", "Llama-2-13b-chat": "llama_2_13b", @@ -180,6 +186,7 @@ class ErnieBotChat(BaseChatModel): "top_p": self.top_p, "temperature": self.temperature, "penalty_score": self.penalty_score, + "system": self.system, **kwargs, } logger.debug(f"Payload for ernie api is {payload}") @@ -195,14 +202,19 @@ class ErnieBotChat(BaseChatModel): def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: if "function_call" in response: - fc_str = '{{"function_call": {}}}'.format( - json.dumps(response.get("function_call")) - ) - generations = [ChatGeneration(message=AIMessage(content=fc_str))] + additional_kwargs = { + "function_call": dict(response.get("function_call", {})) + } else: - generations = [ - ChatGeneration(message=AIMessage(content=response.get("result"))) - ] + additional_kwargs = {} + generations = [ + ChatGeneration( + message=AIMessage( + content=response.get("result"), + additional_kwargs={**additional_kwargs}, + ) + ) + ] token_usage = response.get("usage", {}) llm_output = {"token_usage": token_usage, "model_name": 
self.model_name} return ChatResult(generations=generations, llm_output=llm_output) diff --git a/libs/langchain/langchain/output_parsers/ernie_functions.py b/libs/langchain/langchain/output_parsers/ernie_functions.py index b2682c4dc21..dd5e4585345 100644 --- a/libs/langchain/langchain/output_parsers/ernie_functions.py +++ b/libs/langchain/langchain/output_parsers/ernie_functions.py @@ -72,12 +72,8 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): "This output parser can only be used with a chat generation." ) message = generation.message - message.additional_kwargs["function_call"] = {} - if "function_call" in message.content: - function_call = json.loads(str(message.content)) - if "function_call" in function_call: - fc = function_call["function_call"] - message.additional_kwargs["function_call"] = fc + if "function_call" not in message.additional_kwargs: + return None try: function_call = message.additional_kwargs["function_call"] except KeyError as exc: