community: Fix Baichuan Chat. (#15207)

- **Description:** Baichuan Chat (with both Baichuan-Turbo and
Baichuan-Turbo-192K models) has updated their APIs. There are breaking
changes. For example, BAICHUAN_SECRET_KEY is removed in the latest API
but is still required in Langchain. Baichuan's Langchain integration
needs to be updated to the latest version.
  - **Issue:** #15206
  - **Dependencies:** None
  - **Twitter handle:** None

@hwchase17

Co-authored-by: BaiChuanHelper <wintergyc@WinterGYCs-MacBook-Pro.local>
This commit is contained in:
baichuan-assistant 2024-01-24 09:01:57 +08:00 committed by GitHub
parent cfc225ecb3
commit 20fcd49348
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 68 additions and 102 deletions

View File

@ -13,7 +13,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# ChatBaichuan\n", "# Chat with Baichuan-192K\n",
"\n", "\n",
"Baichuan chat models API by Baichuan Intelligent Technology. For more information, see [https://platform.baichuan-ai.com/docs/api](https://platform.baichuan-ai.com/docs/api)" "Baichuan chat models API by Baichuan Intelligent Technology. For more information, see [https://platform.baichuan-ai.com/docs/api](https://platform.baichuan-ai.com/docs/api)"
] ]
@ -44,19 +44,16 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"chat = ChatBaichuan(\n", "chat = ChatBaichuan(baichuan_api_key=\"YOUR_API_KEY\")"
" baichuan_api_key=\"YOUR_API_KEY\", baichuan_secret_key=\"YOUR_SECRET_KEY\"\n",
")"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"or you can set `api_key` and `secret_key` in your environment variables\n", "or you can set `api_key` in your environment variables\n",
"```bash\n", "```bash\n",
"export BAICHUAN_API_KEY=YOUR_API_KEY\n", "export BAICHUAN_API_KEY=YOUR_API_KEY\n",
"export BAICHUAN_SECRET_KEY=YOUR_SECRET_KEY\n",
"```" "```"
] ]
}, },
@ -91,7 +88,7 @@
"collapsed": false "collapsed": false
}, },
"source": [ "source": [
"## For ChatBaichuan with Streaming" "## Chat with Baichuan-192K with Streaming"
] ]
}, },
{ {
@ -108,7 +105,6 @@
"source": [ "source": [
"chat = ChatBaichuan(\n", "chat = ChatBaichuan(\n",
" baichuan_api_key=\"YOUR_API_KEY\",\n", " baichuan_api_key=\"YOUR_API_KEY\",\n",
" baichuan_secret_key=\"YOUR_SECRET_KEY\",\n",
" streaming=True,\n", " streaming=True,\n",
")" ")"
] ]

View File

@ -1,7 +1,5 @@
import hashlib
import json import json
import logging import logging
import time
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
import requests import requests
@ -30,7 +28,7 @@ from langchain_core.utils import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1" DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1/chat/completions"
def _convert_message_to_dict(message: BaseMessage) -> dict: def _convert_message_to_dict(message: BaseMessage) -> dict:
@ -73,14 +71,6 @@ def _convert_delta_to_message_chunk(
return default_class(content=content) return default_class(content=content)
# signature generation
def _signature(secret_key: SecretStr, payload: Dict[str, Any], timestamp: int) -> str:
input_str = secret_key.get_secret_value() + json.dumps(payload) + str(timestamp)
md5 = hashlib.md5()
md5.update(input_str.encode("utf-8"))
return md5.hexdigest()
class ChatBaichuan(BaseChatModel): class ChatBaichuan(BaseChatModel):
"""Baichuan chat models API by Baichuan Intelligent Technology. """Baichuan chat models API by Baichuan Intelligent Technology.
@ -91,7 +81,6 @@ class ChatBaichuan(BaseChatModel):
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> Dict[str, str]:
return { return {
"baichuan_api_key": "BAICHUAN_API_KEY", "baichuan_api_key": "BAICHUAN_API_KEY",
"baichuan_secret_key": "BAICHUAN_SECRET_KEY",
} }
@property @property
@ -103,14 +92,14 @@ class ChatBaichuan(BaseChatModel):
baichuan_api_key: Optional[SecretStr] = None baichuan_api_key: Optional[SecretStr] = None
"""Baichuan API Key""" """Baichuan API Key"""
baichuan_secret_key: Optional[SecretStr] = None baichuan_secret_key: Optional[SecretStr] = None
"""Baichuan Secret Key""" """[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
streaming: bool = False streaming: bool = False
"""Whether to stream the results or not.""" """Whether to stream the results or not."""
request_timeout: int = 60 request_timeout: int = 60
"""request timeout for chat http requests""" """request timeout for chat http requests"""
model = "Baichuan2-Turbo-192K"
model = "Baichuan2-53B" """model name of Baichuan, default is `Baichuan2-Turbo-192K`,
"""model name of Baichuan, default is `Baichuan2-53B`.""" other options include `Baichuan2-Turbo`"""
temperature: float = 0.3 temperature: float = 0.3
"""What sampling temperature to use.""" """What sampling temperature to use."""
top_k: int = 5 top_k: int = 5
@ -168,13 +157,6 @@ class ChatBaichuan(BaseChatModel):
"BAICHUAN_API_KEY", "BAICHUAN_API_KEY",
) )
) )
values["baichuan_secret_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"baichuan_secret_key",
"BAICHUAN_SECRET_KEY",
)
)
return values return values
@ -187,6 +169,7 @@ class ChatBaichuan(BaseChatModel):
"top_p": self.top_p, "top_p": self.top_p,
"top_k": self.top_k, "top_k": self.top_k,
"with_search_enhance": self.with_search_enhance, "with_search_enhance": self.with_search_enhance,
"stream": self.streaming,
} }
return {**normal_params, **self.model_kwargs} return {**normal_params, **self.model_kwargs}
@ -205,12 +188,9 @@ class ChatBaichuan(BaseChatModel):
return generate_from_stream(stream_iter) return generate_from_stream(stream_iter)
res = self._chat(messages, **kwargs) res = self._chat(messages, **kwargs)
if res.status_code != 200:
raise ValueError(f"Error from Baichuan api response: {res}")
response = res.json() response = res.json()
if response.get("code") != 0:
raise ValueError(f"Error from Baichuan api response: {response}")
return self._create_chat_result(response) return self._create_chat_result(response)
def _stream( def _stream(
@ -221,43 +201,49 @@ class ChatBaichuan(BaseChatModel):
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
res = self._chat(messages, **kwargs) res = self._chat(messages, **kwargs)
if res.status_code != 200:
raise ValueError(f"Error from Baichuan api response: {res}")
default_chunk_class = AIMessageChunk default_chunk_class = AIMessageChunk
for chunk in res.iter_lines(): for chunk in res.iter_lines():
chunk = chunk.decode("utf-8").strip("\r\n")
parts = chunk.split("data: ", 1)
chunk = parts[1] if len(parts) > 1 else None
if chunk is None:
continue
if chunk == "[DONE]":
break
response = json.loads(chunk) response = json.loads(chunk)
if response.get("code") != 0: for m in response.get("choices"):
raise ValueError(f"Error from Baichuan api response: {response}") chunk = _convert_delta_to_message_chunk(
m.get("delta"), default_chunk_class
data = response.get("data") )
for m in data.get("messages"):
chunk = _convert_delta_to_message_chunk(m, default_chunk_class)
default_chunk_class = chunk.__class__ default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk) yield ChatGenerationChunk(message=chunk)
if run_manager: if run_manager:
run_manager.on_llm_new_token(chunk.content) run_manager.on_llm_new_token(chunk.content)
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response: def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
if self.baichuan_secret_key is None:
raise ValueError("Baichuan secret key is not set.")
parameters = {**self._default_params, **kwargs} parameters = {**self._default_params, **kwargs}
model = parameters.pop("model") model = parameters.pop("model")
headers = parameters.pop("headers", {}) headers = parameters.pop("headers", {})
temperature = parameters.pop("temperature", 0.3)
top_k = parameters.pop("top_k", 5)
top_p = parameters.pop("top_p", 0.85)
with_search_enhance = parameters.pop("with_search_enhance", False)
stream = parameters.pop("stream", False)
payload = { payload = {
"model": model, "model": model,
"messages": [_convert_message_to_dict(m) for m in messages], "messages": [_convert_message_to_dict(m) for m in messages],
"parameters": parameters, "top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"with_search_enhance": with_search_enhance,
"stream": stream,
} }
timestamp = int(time.time())
url = self.baichuan_api_base url = self.baichuan_api_base
if self.streaming:
url = f"{url}/stream"
url = f"{url}/chat"
api_key = "" api_key = ""
if self.baichuan_api_key: if self.baichuan_api_key:
api_key = self.baichuan_api_key.get_secret_value() api_key = self.baichuan_api_key.get_secret_value()
@ -268,13 +254,6 @@ class ChatBaichuan(BaseChatModel):
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {api_key}", "Authorization": f"Bearer {api_key}",
"X-BC-Timestamp": str(timestamp),
"X-BC-Signature": _signature(
secret_key=self.baichuan_secret_key,
payload=payload,
timestamp=timestamp,
),
"X-BC-Sign-Algo": "MD5",
**headers, **headers,
}, },
json=payload, json=payload,
@ -284,8 +263,8 @@ class ChatBaichuan(BaseChatModel):
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = [] generations = []
for m in response["data"]["messages"]: for c in response["choices"]:
message = _convert_dict_to_message(m) message = _convert_dict_to_message(c["message"])
gen = ChatGeneration(message=message) gen = ChatGeneration(message=message)
generations.append(gen) generations.append(gen)

View File

@ -2,17 +2,36 @@ from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.chat_models.baichuan import ChatBaichuan from langchain_community.chat_models.baichuan import ChatBaichuan
# For testing, run:
# TEST_FILE=tests/integration_tests/chat_models/test_baichuan.py make test
def test_chat_baichuan() -> None:
def test_chat_baichuan_default() -> None:
chat = ChatBaichuan(streaming=True)
message = HumanMessage(content="请完整背诵将进酒背诵5遍")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_chat_baichuan_default_non_streaming() -> None:
chat = ChatBaichuan() chat = ChatBaichuan()
message = HumanMessage(content="请完整背诵将进酒背诵5遍")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_chat_baichuan_turbo() -> None:
chat = ChatBaichuan(model="Baichuan2-Turbo", streaming=True)
message = HumanMessage(content="Hello") message = HumanMessage(content="Hello")
response = chat([message]) response = chat([message])
assert isinstance(response, AIMessage) assert isinstance(response, AIMessage)
assert isinstance(response.content, str) assert isinstance(response.content, str)
def test_chat_baichuan_with_model() -> None: def test_chat_baichuan_turbo_non_streaming() -> None:
chat = ChatBaichuan(model="Baichuan2-13B") chat = ChatBaichuan(model="Baichuan2-Turbo")
message = HumanMessage(content="Hello") message = HumanMessage(content="Hello")
response = chat([message]) response = chat([message])
assert isinstance(response, AIMessage) assert isinstance(response, AIMessage)
@ -20,7 +39,7 @@ def test_chat_baichuan_with_model() -> None:
def test_chat_baichuan_with_temperature() -> None: def test_chat_baichuan_with_temperature() -> None:
chat = ChatBaichuan(model="Baichuan2-13B", temperature=1.0) chat = ChatBaichuan(temperature=1.0)
message = HumanMessage(content="Hello") message = HumanMessage(content="Hello")
response = chat([message]) response = chat([message])
assert isinstance(response, AIMessage) assert isinstance(response, AIMessage)
@ -29,13 +48,15 @@ def test_chat_baichuan_with_temperature() -> None:
def test_chat_baichuan_with_kwargs() -> None: def test_chat_baichuan_with_kwargs() -> None:
chat = ChatBaichuan() chat = ChatBaichuan()
message = HumanMessage(content="Hello") message = HumanMessage(content="百川192K API是什么时候上线的")
response = chat([message], temperature=0.88, top_p=0.7) response = chat([message], temperature=0.88, top_p=0.7, with_search_enhance=True)
print(response)
assert isinstance(response, AIMessage) assert isinstance(response, AIMessage)
assert isinstance(response.content, str) assert isinstance(response.content, str)
def test_extra_kwargs() -> None: def test_extra_kwargs() -> None:
chat = ChatBaichuan(temperature=0.88, top_p=0.7) chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True)
assert chat.temperature == 0.88 assert chat.temperature == 0.88
assert chat.top_p == 0.7 assert chat.top_p == 0.7
assert chat.with_search_enhance is True

View File

@ -18,7 +18,6 @@ from langchain_community.chat_models.baichuan import (
_convert_delta_to_message_chunk, _convert_delta_to_message_chunk,
_convert_dict_to_message, _convert_dict_to_message,
_convert_message_to_dict, _convert_message_to_dict,
_signature,
) )
@ -85,62 +84,33 @@ def test__convert_delta_to_message_human() -> None:
assert result == expected_output assert result == expected_output
def test__signature() -> None:
secret_key = SecretStr("YOUR_SECRET_KEY")
result = _signature(
secret_key=secret_key,
payload={
"model": "Baichuan2-53B",
"messages": [{"role": "user", "content": "Hi"}],
},
timestamp=1697734335,
)
# The signature was generated by the demo provided by Baichuan.
# https://platform.baichuan-ai.com/docs/api#4
expected_output = "24a50b2db1648e25a244c67c5ab57d3f"
assert result == expected_output
def test_baichuan_key_masked_when_passed_from_env( def test_baichuan_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None: ) -> None:
"""Test initialization with an API key provided via an env variable""" """Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key") monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key")
monkeypatch.setenv("BAICHUAN_SECRET_KEY", "test-secret-key")
chat = ChatBaichuan() chat = ChatBaichuan()
print(chat.baichuan_api_key, end="") print(chat.baichuan_api_key, end="")
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == "**********" assert captured.out == "**********"
print(chat.baichuan_secret_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_baichuan_key_masked_when_passed_via_constructor( def test_baichuan_key_masked_when_passed_via_constructor(
capsys: CaptureFixture, capsys: CaptureFixture,
) -> None: ) -> None:
"""Test initialization with an API key provided via the initializer""" """Test initialization with an API key provided via the initializer"""
chat = ChatBaichuan( chat = ChatBaichuan(baichuan_api_key="test-api-key")
baichuan_api_key="test-api-key", baichuan_secret_key="test-secret-key"
)
print(chat.baichuan_api_key, end="") print(chat.baichuan_api_key, end="")
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == "**********" assert captured.out == "**********"
print(chat.baichuan_secret_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_uses_actual_secret_value_from_secret_str() -> None: def test_uses_actual_secret_value_from_secret_str() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`.""" """Test that actual secret is retrieved using `.get_secret_value()`."""
chat = ChatBaichuan( chat = ChatBaichuan(
baichuan_api_key="test-api-key", baichuan_secret_key="test-secret-key" baichuan_api_key="test-api-key",
baichuan_secret_key="test-secret-key", # For backward compatibility
) )
assert cast(SecretStr, chat.baichuan_api_key).get_secret_value() == "test-api-key" assert cast(SecretStr, chat.baichuan_api_key).get_secret_value() == "test-api-key"
assert ( assert (