feat: 1. Add system parameters, 2. Align with the QianfanChatEndpoint for function calling (#14275)
- **Description:**
  1. Add a `system` parameter to the ERNIE LLM API to set the role of the LLM.
  2. Add support for the `ERNIE-Bot-turbo-AI` model, according to the documentation at https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Alp0kdm0n.
  3. Align the function-call handling of `ErnieBotChat` with `QianfanChatEndpoint`.

With this PR, `QianfanChatEndpoint()` can use the function-calling ability with `create_ernie_fn_chain()`. For example:

```python
import json

from langchain.chains.ernie_functions import create_ernie_fn_chain
from langchain.chat_models import QianfanChatEndpoint
from langchain.prompts import ChatPromptTemplate


def get_current_news(location: str) -> str:
    """Get the current news based on the location.

    Args:
        location (str): The location to query.

    Returns:
        str: Current news based on the location.
    """
    news_info = {
        "location": location,
        "news": [
            "I have a Book.",
            "It's a nice day, today.",
        ],
    }
    return json.dumps(news_info)


def get_current_weather(location: str, unit: str = "celsius") -> str:
    """Get the current weather in a given location.

    Args:
        location (str): location of the weather.
        unit (str): unit of the temperature.

    Returns:
        str: weather in the given location.
    """
    weather_info = {
        "location": location,
        "temperature": "27",
        "unit": unit,
        "forecast": ["sunny", "windy"],
    }
    return json.dumps(weather_info)


template = ChatPromptTemplate.from_messages(
    [
        ("user", "{user_input}"),
    ]
)

chat = QianfanChatEndpoint(model="ERNIE-Bot-4")
chain = create_ernie_fn_chain(
    [get_current_weather, get_current_news], chat, template, verbose=True
)
res = chain.run("北京今天的新闻是什么?")  # "What is today's news in Beijing?"
print(res)
```

The result of the above code:

```
> Entering new LLMChain chain...
Prompt after formatting:
Human: 北京今天的新闻是什么?

> Finished chain.
{'name': 'get_current_news', 'arguments': {'location': '北京'}}
```

`ErnieBotChat` can now use the `system` parameter to set the role of the LLM:

```python
from langchain.chains import LLMChain
from langchain.chat_models import ErnieBotChat
from langchain.prompts import ChatPromptTemplate

llm = ErnieBotChat(
    model_name="ERNIE-Bot-turbo-AI",
    # System prompt (Chinese): "You are a very capable robot named Xiao
    # Dingdang. Whatever you are asked, you can give an answer."
    system="你是一个能力很强的机器人,你的名字叫 小叮当。无论问你什么问题,你都可以给出答案。",
)
prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{query}"),
    ]
)
chain = LLMChain(llm=llm, prompt=prompt, verbose=True)
res = chain.run(query="你是谁?")  # "Who are you?"
print(res)
```

The result of the above code:

```
> Entering new LLMChain chain...
Prompt after formatting:
Human: 你是谁?

> Finished chain.
我是小叮当,一个智能机器人。我可以为你提供各种服务,包括回答问题、提供信息、进行计算等。如果你需要任何帮助,请随时告诉我,我会尽力为你提供最好的服务。
```

(Translation of the reply: "I am Xiao Dingdang, an intelligent robot. I can provide all kinds of services for you, including answering questions, providing information, and doing calculations. If you need any help, just let me know and I will do my best to serve you.")
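As a follow-up to the function-calling example above: the chain returns a dict of the chosen function name and its parsed arguments, which a caller would typically dispatch to the matching local function. A minimal sketch of that step (not part of this PR; the toy `get_current_news` and hard-coded `res` are illustrative):

```python
# Hypothetical dispatch step for the chain output shown above.
# `res` has the shape {'name': ..., 'arguments': {...}}.
import json


def get_current_news(location: str) -> str:
    """Toy stand-in for the function defined in the example above."""
    return json.dumps({"location": location, "news": ["I have a Book."]})


available_functions = {"get_current_news": get_current_news}

res = {"name": "get_current_news", "arguments": {"location": "北京"}}
answer = available_functions[res["name"]](**res["arguments"])
print(answer)  # {"location": "\u5317\u4eac", "news": ["I have a Book."]}
```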
Parent: fd5be55a7b
Commit: 7205bfdd00
Diff — ERNIE chat model (`ErnieBotChat`):

```diff
@@ -1,4 +1,3 @@
-import json
 import logging
 import threading
 from typing import Any, Dict, List, Mapping, Optional
```
```diff
@@ -46,7 +45,8 @@ class ErnieBotChat(BaseChatModel):
     and will be regenerated after expiration (30 days).
 
     Default model is `ERNIE-Bot-turbo`,
-    currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot`
+    currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot`, `ERNIE-Bot-8K`,
+    `ERNIE-Bot-4`, `ERNIE-Bot-turbo-AI`.
 
     Example:
         .. code-block:: python
```
```diff
@@ -87,6 +87,11 @@ class ErnieBotChat(BaseChatModel):
     """model name of ernie, default is `ERNIE-Bot-turbo`.
     Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`"""
 
+    system: Optional[str] = None
+    """system is mainly used for model character design,
+    for example, you are an AI assistant produced by xxx company.
+    The length of the system is limiting of 1024 characters."""
+
     request_timeout: Optional[int] = 60
     """request timeout for chat http requests"""
 
```
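For reference, a minimal direct-invocation sketch of the new field, without a chain (assumes valid Baidu credentials are configured for `ErnieBotChat`; per the docstring above, the system string is capped at 1024 characters):

```python
from langchain.chat_models import ErnieBotChat
from langchain.schema import HumanMessage

chat = ErnieBotChat(
    model_name="ERNIE-Bot-turbo-AI",
    system="You are a helpful assistant named Xiao Dingdang.",  # <= 1024 chars
)
# Calling the chat model directly with a message list returns an AIMessage.
print(chat([HumanMessage(content="Who are you?")]).content)
```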
```diff
@@ -123,6 +128,7 @@
             "ERNIE-Bot": "completions",
             "ERNIE-Bot-8K": "ernie_bot_8k",
             "ERNIE-Bot-4": "completions_pro",
+            "ERNIE-Bot-turbo-AI": "ai_apaas",
             "BLOOMZ-7B": "bloomz_7b1",
             "Llama-2-7b-chat": "llama_2_7b",
             "Llama-2-13b-chat": "llama_2_13b",
```
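The mapping above selects the path segment of the chat request URL. A sketch of how the new entry resolves (the base URL pattern is an assumption taken from Baidu's Qianfan API docs rather than from this diff; access-token handling is elided):

```python
# Sketch: how the name→endpoint mapping turns into a request URL.
MODEL_TO_ENDPOINT = {
    "ERNIE-Bot": "completions",
    "ERNIE-Bot-8K": "ernie_bot_8k",
    "ERNIE-Bot-4": "completions_pro",
    "ERNIE-Bot-turbo-AI": "ai_apaas",  # new in this PR
}


def chat_url(model_name: str, access_token: str) -> str:
    """Build the chat endpoint URL for a given model name."""
    endpoint = MODEL_TO_ENDPOINT[model_name]
    return (
        "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
        f"/chat/{endpoint}?access_token={access_token}"
    )


print(chat_url("ERNIE-Bot-turbo-AI", "<access-token>"))
```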
```diff
@@ -180,6 +186,7 @@
             "top_p": self.top_p,
             "temperature": self.temperature,
             "penalty_score": self.penalty_score,
+            "system": self.system,
             **kwargs,
         }
         logger.debug(f"Payload for ernie api is {payload}")
```
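With this change the request body carries `system` alongside the sampling parameters. An illustrative payload (the values for `top_p`, `temperature`, and `penalty_score` are assumed class defaults; whether the API ignores a null `system` is an assumption, not verified here):

```python
# Illustrative chat request body after this change.
payload = {
    "messages": [{"role": "user", "content": "你是谁?"}],  # "Who are you?"
    "top_p": 0.8,
    "temperature": 0.95,
    "penalty_score": 1,
    "system": "你是一个能力很强的机器人。",  # "You are a very capable robot."
}
```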
```diff
@@ -195,14 +202,19 @@
 
     def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
         if "function_call" in response:
-            fc_str = '{{"function_call": {}}}'.format(
-                json.dumps(response.get("function_call"))
-            )
-            generations = [ChatGeneration(message=AIMessage(content=fc_str))]
+            additional_kwargs = {
+                "function_call": dict(response.get("function_call", {}))
+            }
         else:
-            generations = [
-                ChatGeneration(message=AIMessage(content=response.get("result")))
-            ]
+            additional_kwargs = {}
+        generations = [
+            ChatGeneration(
+                message=AIMessage(
+                    content=response.get("result"),
+                    additional_kwargs={**additional_kwargs},
+                )
+            )
+        ]
         token_usage = response.get("usage", {})
         llm_output = {"token_usage": token_usage, "model_name": self.model_name}
         return ChatResult(generations=generations, llm_output=llm_output)
```
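The effect of the rewritten `_create_chat_result`: a function call is no longer serialized into the message content string but attached as `additional_kwargs`, the same shape `QianfanChatEndpoint` produces. A self-contained sketch with a fabricated API response:

```python
from langchain.schema import AIMessage

# Fabricated ERNIE API response carrying a function call.
response = {
    "result": "",
    "function_call": {
        "name": "get_current_news",
        "arguments": '{"location": "北京"}',
    },
}

# Same branching as the new _create_chat_result above.
if "function_call" in response:
    additional_kwargs = {"function_call": dict(response.get("function_call", {}))}
else:
    additional_kwargs = {}

message = AIMessage(
    content=response.get("result"),
    additional_kwargs={**additional_kwargs},
)
print(message.additional_kwargs["function_call"]["name"])  # get_current_news
```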
Diff — `JsonOutputFunctionsParser` output parser:

```diff
@@ -72,12 +72,8 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
                 "This output parser can only be used with a chat generation."
             )
         message = generation.message
-        message.additional_kwargs["function_call"] = {}
-        if "function_call" in message.content:
-            function_call = json.loads(str(message.content))
-            if "function_call" in function_call:
-                fc = function_call["function_call"]
-                message.additional_kwargs["function_call"] = fc
+        if "function_call" not in message.additional_kwargs:
+            return None
         try:
             function_call = message.additional_kwargs["function_call"]
         except KeyError as exc:
```
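With the function call available in `additional_kwargs`, the parser no longer needs to scrape and re-parse `message.content`; a message without a function call now yields `None` (useful for plain answers and for partial chunks during streaming) instead of raising. A simplified stand-in for the parser logic, not the library class itself:

```python
import json

from langchain.schema import AIMessage


def parse_function_call(message: AIMessage):
    """Simplified stand-in for the parser's new contract."""
    if "function_call" not in message.additional_kwargs:
        return None  # plain answer or partial chunk: nothing to parse
    function_call = message.additional_kwargs["function_call"]
    # `arguments` arrives as a JSON string and is decoded here:
    return {
        "name": function_call["name"],
        "arguments": json.loads(function_call["arguments"]),
    }


msg = AIMessage(
    content="",
    additional_kwargs={
        "function_call": {
            "name": "get_current_news",
            "arguments": '{"location": "北京"}',
        }
    },
)
print(parse_function_call(msg))
# {'name': 'get_current_news', 'arguments': {'location': '北京'}}
```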