diff --git a/libs/langchain/langchain/chat_models/ernie.py b/libs/langchain/langchain/chat_models/ernie.py index 7687c1241e2..a46a15bbdb8 100644 --- a/libs/langchain/langchain/chat_models/ernie.py +++ b/libs/langchain/langchain/chat_models/ernie.py @@ -1,4 +1,3 @@ -import json import logging import threading from typing import Any, Dict, List, Mapping, Optional @@ -46,7 +45,8 @@ class ErnieBotChat(BaseChatModel): and will be regenerated after expiration (30 days). Default model is `ERNIE-Bot-turbo`, - currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot` + currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot`, `ERNIE-Bot-8K`, + `ERNIE-Bot-4`, `ERNIE-Bot-turbo-AI`. Example: .. code-block:: python @@ -87,6 +87,11 @@ class ErnieBotChat(BaseChatModel): """model name of ernie, default is `ERNIE-Bot-turbo`. Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`""" + system: Optional[str] = None + """system is mainly used for model character design, + for example, you are an AI assistant produced by xxx company. + The length of the system is limiting of 1024 characters.""" + request_timeout: Optional[int] = 60 """request timeout for chat http requests""" @@ -123,6 +128,7 @@ class ErnieBotChat(BaseChatModel): "ERNIE-Bot": "completions", "ERNIE-Bot-8K": "ernie_bot_8k", "ERNIE-Bot-4": "completions_pro", + "ERNIE-Bot-turbo-AI": "ai_apaas", "BLOOMZ-7B": "bloomz_7b1", "Llama-2-7b-chat": "llama_2_7b", "Llama-2-13b-chat": "llama_2_13b", @@ -180,6 +186,7 @@ class ErnieBotChat(BaseChatModel): "top_p": self.top_p, "temperature": self.temperature, "penalty_score": self.penalty_score, + "system": self.system, **kwargs, } logger.debug(f"Payload for ernie api is {payload}") @@ -195,14 +202,19 @@ class ErnieBotChat(BaseChatModel): def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: if "function_call" in response: - fc_str = '{{"function_call": {}}}'.format( - json.dumps(response.get("function_call")) - ) - generations = [ChatGeneration(message=AIMessage(content=fc_str))] + additional_kwargs = { + "function_call": dict(response.get("function_call", {})) + } else: - generations = [ - ChatGeneration(message=AIMessage(content=response.get("result"))) - ] + additional_kwargs = {} + generations = [ + ChatGeneration( + message=AIMessage( + content=response.get("result"), + additional_kwargs={**additional_kwargs}, + ) + ) + ] token_usage = response.get("usage", {}) llm_output = {"token_usage": token_usage, "model_name": self.model_name} return ChatResult(generations=generations, llm_output=llm_output) diff --git a/libs/langchain/langchain/output_parsers/ernie_functions.py b/libs/langchain/langchain/output_parsers/ernie_functions.py index b2682c4dc21..dd5e4585345 100644 --- a/libs/langchain/langchain/output_parsers/ernie_functions.py +++ b/libs/langchain/langchain/output_parsers/ernie_functions.py @@ -72,12 +72,8 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): "This output parser can only be used with a chat generation." ) message = generation.message - message.additional_kwargs["function_call"] = {} - if "function_call" in message.content: - function_call = json.loads(str(message.content)) - if "function_call" in function_call: - fc = function_call["function_call"] - message.additional_kwargs["function_call"] = fc + if "function_call" not in message.additional_kwargs: + return None try: function_call = message.additional_kwargs["function_call"] except KeyError as exc: