From 525109e50663b9b1485b46e70d41d2e7aa0e7dfd Mon Sep 17 00:00:00 2001
From: maang-h <55082429+maang-h@users.noreply.github.com>
Date: Thu, 4 Jul 2024 00:10:04 +0800
Subject: [PATCH] feat: Implement ChatBaichuan asynchronous interface (#23589)

- **Description:** Add an interface to `ChatBaichuan` to support asynchronous requests:
  - `_agenerate` method
  - `_astream` method

---------

Co-authored-by: ccurme
---
 .../chat_models/baichuan.py      | 145 +++++++++++++++---
 .../chat_models/test_baichuan.py |  14 ++
 2 files changed, 139 insertions(+), 20 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/baichuan.py b/libs/community/langchain_community/chat_models/baichuan.py
index 91d1f76dfec..ede68a14e66 100644
--- a/libs/community/langchain_community/chat_models/baichuan.py
+++ b/libs/community/langchain_community/chat_models/baichuan.py
@@ -1,11 +1,16 @@
 import json
 import logging
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
+from contextlib import asynccontextmanager
+from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Type
 
 import requests
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
+    agenerate_from_stream,
     generate_from_stream,
 )
 from langchain_core.messages import (
@@ -71,6 +76,27 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content)  # type: ignore[call-arg]
 
 
+@asynccontextmanager
+async def aconnect_httpx_sse(
+    client: Any, method: str, url: str, **kwargs: Any
+) -> AsyncIterator:
+    """Async context manager for connecting to an SSE stream.
+
+    Args:
+        client: The httpx client.
+        method: The HTTP method.
+        url: The URL to connect to.
+        kwargs: Additional keyword arguments to pass to the client.
+
+    Yields:
+        An EventSource object.
+    """
+    from httpx_sse import EventSource
+
+    async with client.stream(method, url, **kwargs) as response:
+        yield EventSource(response)
+
+
 class ChatBaichuan(BaseChatModel):
     """Baichuan chat models API by Baichuan Intelligent Technology.
 
@@ -199,7 +225,7 @@ class ChatBaichuan(BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        res = self._chat(messages, **kwargs)
+        res = self._chat(messages, stream=True, **kwargs)
         if res.status_code != 200:
             raise ValueError(f"Error from Baichuan api response: {res}")
         default_chunk_class = AIMessageChunk
@@ -222,14 +248,96 @@ class ChatBaichuan(BaseChatModel):
                     run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
                 yield cg_chunk
 
-    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
-        parameters = {**self._default_params, **kwargs}
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
 
-        model = parameters.pop("model")
-        headers = parameters.pop("headers", {})
+        headers = self._create_headers_parameters(**kwargs)
+        payload = self._create_payload_parameters(messages, **kwargs)
+
+        import httpx
+
+        async with httpx.AsyncClient(
+            headers=headers, timeout=self.request_timeout
+        ) as client:
+            response = await client.post(self.baichuan_api_base, json=payload)
+            response.raise_for_status()
+        return self._create_chat_result(response.json())
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        headers = self._create_headers_parameters(**kwargs)
+        payload = self._create_payload_parameters(messages, stream=True, **kwargs)
+        import httpx
+
+        async with httpx.AsyncClient(
+            headers=headers, timeout=self.request_timeout
+        ) as client:
+            async with aconnect_httpx_sse(
+                client, "POST", self.baichuan_api_base, json=payload
+            ) as event_source:
+                async for sse in event_source.aiter_sse():
+                    chunk = json.loads(sse.data)
+                    if len(chunk["choices"]) == 0:
+                        continue
+                    choice = chunk["choices"][0]
+                    chunk = _convert_delta_to_message_chunk(
+                        choice["delta"], AIMessageChunk
+                    )
+                    finish_reason = choice.get("finish_reason", None)
+
+                    generation_info = (
+                        {"finish_reason": finish_reason}
+                        if finish_reason is not None
+                        else None
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=chunk, generation_info=generation_info
+                    )
+                    if run_manager:
+                        await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    yield chunk
+                    if finish_reason is not None:
+                        break
+
+    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
+        payload = self._create_payload_parameters(messages, **kwargs)
+        url = self.baichuan_api_base
+        headers = self._create_headers_parameters(**kwargs)
+
+        res = requests.post(
+            url=url,
+            timeout=self.request_timeout,
+            headers=headers,
+            json=payload,
+            stream=self.streaming,
+        )
+        return res
+
+    def _create_payload_parameters(  # type: ignore[no-untyped-def]
+        self, messages: List[BaseMessage], **kwargs
+    ) -> Dict[str, Any]:
+        parameters = {**self._default_params, **kwargs}
         temperature = parameters.pop("temperature", 0.3)
         top_k = parameters.pop("top_k", 5)
         top_p = parameters.pop("top_p", 0.85)
+        model = parameters.pop("model")
         with_search_enhance = parameters.pop("with_search_enhance", False)
         stream = parameters.pop("stream", False)
 
@@ -242,24 +350,21 @@
             "with_search_enhance": with_search_enhance,
             "stream": stream,
         }
+        return payload
 
-        url = self.baichuan_api_base
+    def _create_headers_parameters(self, **kwargs) -> Dict[str, Any]:  # type: ignore[no-untyped-def]
+        parameters = {**self._default_params, **kwargs}
+        default_headers = parameters.pop("headers", {})
         api_key = ""
         if self.baichuan_api_key:
             api_key = self.baichuan_api_key.get_secret_value()
 
-        res = requests.post(
-            url=url,
-            timeout=self.request_timeout,
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {api_key}",
-                **headers,
-            },
-            json=payload,
-            stream=self.streaming,
-        )
-        return res
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {api_key}",
+            **default_headers,
+        }
+        return headers
 
     def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
         generations = []

diff --git a/libs/community/tests/integration_tests/chat_models/test_baichuan.py b/libs/community/tests/integration_tests/chat_models/test_baichuan.py
index 715391c1963..3ffcf2e9dda 100644
--- a/libs/community/tests/integration_tests/chat_models/test_baichuan.py
+++ b/libs/community/tests/integration_tests/chat_models/test_baichuan.py
@@ -62,3 +62,17 @@ def test_extra_kwargs() -> None:
     assert chat.temperature == 0.88
     assert chat.top_p == 0.7
     assert chat.with_search_enhance is True
+
+
+async def test_chat_baichuan_agenerate() -> None:
+    chat = ChatBaichuan()  # type: ignore[call-arg]
+    response = await chat.ainvoke("你好呀")
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+async def test_chat_baichuan_astream() -> None:
+    chat = ChatBaichuan()  # type: ignore[call-arg]
+    async for chunk in chat.astream("今天天气如何?"):
+        assert isinstance(chunk, AIMessage)
+        assert isinstance(chunk.content, str)
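
Usage note (not part of the patch): `_agenerate` and `_astream` are not called directly; they back the standard async entry points that `BaseChatModel` already exposes, so `ainvoke` and `astream` pick up the new code paths automatically. A minimal sketch of exercising the new interface, assuming a valid `BAICHUAN_API_KEY` is set in the environment (the same setup the integration tests above rely on):

```python
import asyncio

from langchain_community.chat_models import ChatBaichuan


async def main() -> None:
    chat = ChatBaichuan()  # type: ignore[call-arg]  # reads BAICHUAN_API_KEY from env

    # ainvoke() is dispatched to the new _agenerate method (single async request).
    response = await chat.ainvoke("你好呀")
    print(response.content)

    # astream() is dispatched to the new _astream method (SSE streaming).
    async for chunk in chat.astream("今天天气如何?"):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```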
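
For reviewers unfamiliar with `httpx-sse`: the `aconnect_httpx_sse` helper added in this patch just wraps an `httpx` streaming request in an `EventSource` so the body can be consumed as server-sent events. A standalone sketch of the same pattern, with a placeholder URL and payload:

```python
import httpx
from httpx_sse import EventSource


async def print_sse_events(url: str, payload: dict) -> None:
    # Same pattern aconnect_httpx_sse encapsulates: open a streaming POST,
    # then parse the response body as server-sent events.
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=payload) as response:
            event_source = EventSource(response)
            async for sse in event_source.aiter_sse():
                print(sse.event, sse.data)  # each SSE frame's event name and data
```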