mirror of https://github.com/hwchase17/langchain.git
Add Baichuan model
@@ -21,6 +21,7 @@ from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.anyscale import ChatAnyscale
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.chat_models.baichuan_baichuaninc_endpoint import BaichuanChatEndpoint
from langchain.chat_models.bedrock import BedrockChat
from langchain.chat_models.cohere import ChatCohere
from langchain.chat_models.ernie import ErnieBotChat
@@ -63,4 +64,5 @@ __all__ = [
    "ChatKonko",
    "QianfanChatEndpoint",
    "ChatFireworks",
    "BaichuanChatEndpoint",
]
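With the __init__.py changes above, the new wrapper is exported from langchain.chat_models alongside the other chat model classes. A minimal import sketch, assuming this commit is installed (the credentials and model name are placeholders):

    from langchain.chat_models import BaichuanChatEndpoint

    chat = BaichuanChatEndpoint("your_ak", "your_sk", model="Baichuan2-13B")

The new module added by the commit follows.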
@@ -0,0 +1,161 @@
"""Baichuan chat wrapper."""
from __future__ import annotations

import requests
import json
import time
import hashlib
import logging
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
    return AIMessageChunk(
        content=resp["result"],
        role="assistant",
    )

def convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a message to a dictionary that can be passed to the API."""
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
        if "function_call" in message.additional_kwargs:
            message_dict["functions"] = message.additional_kwargs["function_call"]
            # If function call only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    else:
        raise TypeError(f"Got unknown type {message}")

    return message_dict

def calculate_md5(input_string: str) -> str:
    """Return the hex MD5 digest of ``input_string``, used to sign API requests."""
    md5 = hashlib.md5()
    md5.update(input_string.encode("utf-8"))
    return md5.hexdigest()

class BaichuanChatEndpoint:
    """Baichuan chat model API wrapper.

    Currently only accounts with enterprise registration are supported.

    To use, you should have the environment variables ``Baichuan_AK`` and
    ``Baichuan_SK`` set with your api_key and secret_key, which you can get
    from https://api.baichuan-ai.com. Both are required and are passed to the
    constructor as ``baichuan_ak`` and ``baichuan_sk``.

    Because Baichuan does not provide a pip package yet, this wrapper calls
    the HTTP API directly; it will be iterated on and upgraded in the future.

    Example:
        .. code-block:: python

            from langchain.chat_models import BaichuanChatEndpoint
            from langchain.schema.messages import HumanMessage

            baichuan_chat = BaichuanChatEndpoint("your_ak", "your_sk", "Baichuan2-13B")
            result = baichuan_chat.predict([HumanMessage(content="Hello")])
            print(result.text)

    Args (none of them may be empty):
        baichuan_ak (str): api_key
        baichuan_sk (str): secret_key
        model (str): "Baichuan2-7B" (default), "Baichuan2-13B", or
            "Baichuan2-53B" (commercial).
        streaming (bool): Default False.

    Returns:
        ``predict`` returns the raw HTTP response.
    """

    baichuan_ak: Optional[str] = None
    baichuan_sk: Optional[str] = None

    request_timeout: Optional[int] = 60
    """Request timeout for chat HTTP requests."""

    top_p: Optional[float] = 0.8
    temperature: Optional[float] = 0.95

    endpoint: Optional[str] = None
    """Endpoint of the Baichuan LLM, required if a custom model is used."""

    def __init__(self, baichuan_ak, baichuan_sk, model="Baichuan2-7B", streaming=False):
        self.baichuan_ak = baichuan_ak
        self.baichuan_sk = baichuan_sk
        self.model = "Baichuan2-7B" if model is None else model
        # Default to non-streaming when no explicit value is given.
        self.streaming = False if streaming is None else streaming

    def predict(self, messages: List[BaseMessage]) -> requests.Response:
        if self.streaming:
            url = "https://api.baichuan-ai.com/v1/stream/chat"
        else:
            url = "https://api.baichuan-ai.com/v1/chat"

        data = {
            "model": self.model,
            # Serialize LangChain messages into the role/content dicts the API expects.
            "messages": [convert_message_to_dict(m) for m in messages],
        }

        json_data = json.dumps(data)
        time_stamp = int(time.time())
        # Requests are signed with the MD5 hex digest of
        # secret_key + JSON body + timestamp.
        signature = calculate_md5(self.baichuan_sk + json_data + str(time_stamp))

        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.baichuan_ak,
            "X-BC-Request-Id": "your requestId",
            "X-BC-Timestamp": str(time_stamp),
            "X-BC-Signature": signature,
            "X-BC-Sign-Algo": "MD5",
        }
        response = requests.post(
            url, data=json_data, headers=headers, timeout=self.request_timeout
        )
        return response
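For reference, a minimal usage sketch of the endpoint added above. It assumes this commit is installed and that valid credentials have been obtained from https://api.baichuan-ai.com; the credentials and message text below are placeholders:

    from langchain.chat_models import BaichuanChatEndpoint
    from langchain.schema.messages import HumanMessage

    # Placeholder credentials from the Baichuan console.
    chat = BaichuanChatEndpoint("your_ak", "your_sk", model="Baichuan2-7B")

    # predict() signs the request (MD5 of secret_key + body + timestamp) and
    # returns the raw requests.Response without parsing it.
    response = chat.predict([HumanMessage(content="Hello, Baichuan!")])
    print(response.status_code)
    print(response.text)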