feat: Implement ChatBaichuan asynchronous interface (#23589)

- **Description:** Add asynchronous request support to `ChatBaichuan`
by implementing two methods (see the usage sketch below):
    - `_agenerate` method
    - `_astream` method
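
With these methods in place, `ChatBaichuan` plugs into LangChain's standard async entry points. A minimal usage sketch, assuming a valid key is exposed via the `BAICHUAN_API_KEY` environment variable (as the integration tests below imply):

```python
import asyncio

from langchain_community.chat_models import ChatBaichuan


async def main() -> None:
    chat = ChatBaichuan()  # picks up BAICHUAN_API_KEY from the environment

    # Non-streaming call: routed through the new _agenerate method.
    response = await chat.ainvoke("Hello!")
    print(response.content)

    # Streaming call: routed through the new _astream method.
    async for chunk in chat.astream("Tell me a short story"):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```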

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
maang-h authored 2024-07-04 00:10:04 +08:00, committed by GitHub
parent 8842a0d986, commit 525109e506
2 changed files with 139 additions and 20 deletions

File 1 of 2: the `ChatBaichuan` chat model module.

@@ -1,11 +1,16 @@
 import json
 import logging
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
+from contextlib import asynccontextmanager
+from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Type
 
 import requests
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
+    agenerate_from_stream,
     generate_from_stream,
 )
 from langchain_core.messages import (
@@ -71,6 +76,27 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content)  # type: ignore[call-arg]
 
 
+@asynccontextmanager
+async def aconnect_httpx_sse(
+    client: Any, method: str, url: str, **kwargs: Any
+) -> AsyncIterator:
+    """Async context manager for connecting to an SSE stream.
+
+    Args:
+        client: The httpx client.
+        method: The HTTP method.
+        url: The URL to connect to.
+        kwargs: Additional keyword arguments to pass to the client.
+
+    Yields:
+        An EventSource object.
+    """
+    from httpx_sse import EventSource
+
+    async with client.stream(method, url, **kwargs) as response:
+        yield EventSource(response)
+
+
 class ChatBaichuan(BaseChatModel):
     """Baichuan chat models API by Baichuan Intelligent Technology.
@@ -199,7 +225,7 @@ class ChatBaichuan(BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        res = self._chat(messages, **kwargs)
+        res = self._chat(messages, stream=True, **kwargs)
         if res.status_code != 200:
             raise ValueError(f"Error from Baichuan api response: {res}")
         default_chunk_class = AIMessageChunk
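
The one-line change in this hunk also tightens the synchronous path: `_stream` previously let the payload's `stream` flag come from the instance defaults, so a `.stream()` call on a model constructed with `streaming=False` appears to have received a single non-streamed response. Passing `stream=True` explicitly pins the request to SSE output, e.g.:

```python
chat = ChatBaichuan(streaming=False)

# .stream() now always requests a streamed (SSE) response, regardless of
# the instance-level `streaming` flag, because _stream passes stream=True.
for chunk in chat.stream("Tell me a joke"):
    print(chunk.content, end="", flush=True)
```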
@@ -222,14 +248,96 @@ class ChatBaichuan(BaseChatModel):
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
             yield cg_chunk
 
-    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
-        parameters = {**self._default_params, **kwargs}
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
 
-        model = parameters.pop("model")
-        headers = parameters.pop("headers", {})
+        headers = self._create_headers_parameters(**kwargs)
+        payload = self._create_payload_parameters(messages, **kwargs)
+
+        import httpx
+
+        async with httpx.AsyncClient(
+            headers=headers, timeout=self.request_timeout
+        ) as client:
+            response = await client.post(self.baichuan_api_base, json=payload)
+            response.raise_for_status()
+        return self._create_chat_result(response.json())
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        headers = self._create_headers_parameters(**kwargs)
+        payload = self._create_payload_parameters(messages, stream=True, **kwargs)
+
+        import httpx
+
+        async with httpx.AsyncClient(
+            headers=headers, timeout=self.request_timeout
+        ) as client:
+            async with aconnect_httpx_sse(
+                client, "POST", self.baichuan_api_base, json=payload
+            ) as event_source:
+                async for sse in event_source.aiter_sse():
+                    chunk = json.loads(sse.data)
+                    if len(chunk["choices"]) == 0:
+                        continue
+                    choice = chunk["choices"][0]
+                    chunk = _convert_delta_to_message_chunk(
+                        choice["delta"], AIMessageChunk
+                    )
+                    finish_reason = choice.get("finish_reason", None)
+                    generation_info = (
+                        {"finish_reason": finish_reason}
+                        if finish_reason is not None
+                        else None
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=chunk, generation_info=generation_info
+                    )
+                    if run_manager:
+                        await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    yield chunk
+                    if finish_reason is not None:
+                        break
+
+    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
+        payload = self._create_payload_parameters(messages, **kwargs)
+        url = self.baichuan_api_base
+        headers = self._create_headers_parameters(**kwargs)
+        res = requests.post(
+            url=url,
+            timeout=self.request_timeout,
+            headers=headers,
+            json=payload,
+            stream=self.streaming,
+        )
+        return res
+
+    def _create_payload_parameters(  # type: ignore[no-untyped-def]
+        self, messages: List[BaseMessage], **kwargs
+    ) -> Dict[str, Any]:
+        parameters = {**self._default_params, **kwargs}
         temperature = parameters.pop("temperature", 0.3)
         top_k = parameters.pop("top_k", 5)
         top_p = parameters.pop("top_p", 0.85)
+        model = parameters.pop("model")
         with_search_enhance = parameters.pop("with_search_enhance", False)
         stream = parameters.pop("stream", False)
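
For reference, `_create_payload_parameters` assembles a Baichuan chat-completion request body. An illustrative payload built from the defaults visible above (the model name is hypothetical, and `messages` is shown already converted to role/content dicts):

```python
payload = {
    "model": "Baichuan2-Turbo",  # hypothetical; actually popped from parameters
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.3,  # defaults as popped above
    "top_k": 5,
    "top_p": 0.85,
    "with_search_enhance": False,
    "stream": True,  # True when requested by _stream / _astream
}
```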
@@ -242,24 +350,21 @@ class ChatBaichuan(BaseChatModel):
             "with_search_enhance": with_search_enhance,
             "stream": stream,
         }
+        return payload
 
-        url = self.baichuan_api_base
+    def _create_headers_parameters(self, **kwargs) -> Dict[str, Any]:  # type: ignore[no-untyped-def]
+        parameters = {**self._default_params, **kwargs}
+        default_headers = parameters.pop("headers", {})
         api_key = ""
         if self.baichuan_api_key:
             api_key = self.baichuan_api_key.get_secret_value()
 
-        res = requests.post(
-            url=url,
-            timeout=self.request_timeout,
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {api_key}",
-                **headers,
-            },
-            json=payload,
-            stream=self.streaming,
-        )
-        return res
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {api_key}",
+            **default_headers,
+        }
+        return headers
 
     def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
         generations = []
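
Two design notes on this file. First, the refactor extracts `_create_payload_parameters` and `_create_headers_parameters` so the sync (`_chat`) and async (`_agenerate`, `_astream`) paths share one request-building code path. Second, `httpx` is imported lazily inside the async methods, which keeps it an optional dependency for users who only need the `requests`-based sync path; a common hardening of that pattern (a sketch, not part of this commit) is:

```python
try:
    import httpx
except ImportError as e:
    raise ImportError(
        "httpx is required for the async Baichuan interface. "
        "Install it with `pip install httpx httpx-sse`."
    ) from e
```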

File 2 of 2: the `ChatBaichuan` integration tests.

@@ -62,3 +62,17 @@ def test_extra_kwargs() -> None:
     assert chat.temperature == 0.88
     assert chat.top_p == 0.7
     assert chat.with_search_enhance is True
+
+
+async def test_chat_baichuan_agenerate() -> None:
+    chat = ChatBaichuan()  # type: ignore[call-arg]
+    response = await chat.ainvoke("你好呀")
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+async def test_chat_baichuan_astream() -> None:
+    chat = ChatBaichuan()  # type: ignore[call-arg]
+    async for chunk in chat.astream("今天天气如何?"):
+        assert isinstance(chunk, AIMessage)
+        assert isinstance(chunk.content, str)
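
Both additions are live integration tests: constructing `ChatBaichuan()` requires real credentials (the `# type: ignore[call-arg]` hints that the key is expected from the environment, presumably `BAICHUAN_API_KEY`), and the coroutines need an async-capable runner such as pytest with asyncio support. The Chinese prompts translate to "Hi there" and "How is the weather today?". A minimal hand-run driver (import path hypothetical):

```python
import asyncio

# Hypothetical import; adjust to the test module's actual location.
from test_baichuan import (
    test_chat_baichuan_agenerate,
    test_chat_baichuan_astream,
)


async def main() -> None:
    await test_chat_baichuan_agenerate()
    await test_chat_baichuan_astream()


asyncio.run(main())
```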