mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 05:43:55 +00:00
feat: Implement ChatBaichuan asynchronous interface (#23589)
- **Description:** Add interface to `ChatBaichuan` to support asynchronous requests - `_agenerate` method - `_astream` method --------- Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
8842a0d986
commit
525109e506
@ -1,11 +1,16 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Type
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
@ -71,6 +76,27 @@ def _convert_delta_to_message_chunk(
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def aconnect_httpx_sse(
|
||||
client: Any, method: str, url: str, **kwargs: Any
|
||||
) -> AsyncIterator:
|
||||
"""Async context manager for connecting to an SSE stream.
|
||||
|
||||
Args:
|
||||
client: The httpx client.
|
||||
method: The HTTP method.
|
||||
url: The URL to connect to.
|
||||
kwargs: Additional keyword arguments to pass to the client.
|
||||
|
||||
Yields:
|
||||
An EventSource object.
|
||||
"""
|
||||
from httpx_sse import EventSource
|
||||
|
||||
async with client.stream(method, url, **kwargs) as response:
|
||||
yield EventSource(response)
|
||||
|
||||
|
||||
class ChatBaichuan(BaseChatModel):
|
||||
"""Baichuan chat models API by Baichuan Intelligent Technology.
|
||||
|
||||
@ -199,7 +225,7 @@ class ChatBaichuan(BaseChatModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
res = self._chat(messages, **kwargs)
|
||||
res = self._chat(messages, stream=True, **kwargs)
|
||||
if res.status_code != 200:
|
||||
raise ValueError(f"Error from Baichuan api response: {res}")
|
||||
default_chunk_class = AIMessageChunk
|
||||
@ -222,14 +248,96 @@ class ChatBaichuan(BaseChatModel):
|
||||
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
|
||||
parameters = {**self._default_params, **kwargs}
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
model = parameters.pop("model")
|
||||
headers = parameters.pop("headers", {})
|
||||
headers = self._create_headers_parameters(**kwargs)
|
||||
payload = self._create_payload_parameters(messages, **kwargs)
|
||||
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers=headers, timeout=self.request_timeout
|
||||
) as client:
|
||||
response = await client.post(self.baichuan_api_base, json=payload)
|
||||
response.raise_for_status()
|
||||
return self._create_chat_result(response.json())
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
headers = self._create_headers_parameters(**kwargs)
|
||||
payload = self._create_payload_parameters(messages, stream=True, **kwargs)
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers=headers, timeout=self.request_timeout
|
||||
) as client:
|
||||
async with aconnect_httpx_sse(
|
||||
client, "POST", self.baichuan_api_base, json=payload
|
||||
) as event_source:
|
||||
async for sse in event_source.aiter_sse():
|
||||
chunk = json.loads(sse.data)
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"], AIMessageChunk
|
||||
)
|
||||
finish_reason = choice.get("finish_reason", None)
|
||||
|
||||
generation_info = (
|
||||
{"finish_reason": finish_reason}
|
||||
if finish_reason is not None
|
||||
else None
|
||||
)
|
||||
chunk = ChatGenerationChunk(
|
||||
message=chunk, generation_info=generation_info
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
yield chunk
|
||||
if finish_reason is not None:
|
||||
break
|
||||
|
||||
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
|
||||
payload = self._create_payload_parameters(messages, **kwargs)
|
||||
url = self.baichuan_api_base
|
||||
headers = self._create_headers_parameters(**kwargs)
|
||||
|
||||
res = requests.post(
|
||||
url=url,
|
||||
timeout=self.request_timeout,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
stream=self.streaming,
|
||||
)
|
||||
return res
|
||||
|
||||
def _create_payload_parameters( # type: ignore[no-untyped-def]
|
||||
self, messages: List[BaseMessage], **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
parameters = {**self._default_params, **kwargs}
|
||||
temperature = parameters.pop("temperature", 0.3)
|
||||
top_k = parameters.pop("top_k", 5)
|
||||
top_p = parameters.pop("top_p", 0.85)
|
||||
model = parameters.pop("model")
|
||||
with_search_enhance = parameters.pop("with_search_enhance", False)
|
||||
stream = parameters.pop("stream", False)
|
||||
|
||||
@ -242,24 +350,21 @@ class ChatBaichuan(BaseChatModel):
|
||||
"with_search_enhance": with_search_enhance,
|
||||
"stream": stream,
|
||||
}
|
||||
return payload
|
||||
|
||||
url = self.baichuan_api_base
|
||||
def _create_headers_parameters(self, **kwargs) -> Dict[str, Any]: # type: ignore[no-untyped-def]
|
||||
parameters = {**self._default_params, **kwargs}
|
||||
default_headers = parameters.pop("headers", {})
|
||||
api_key = ""
|
||||
if self.baichuan_api_key:
|
||||
api_key = self.baichuan_api_key.get_secret_value()
|
||||
|
||||
res = requests.post(
|
||||
url=url,
|
||||
timeout=self.request_timeout,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
**headers,
|
||||
},
|
||||
json=payload,
|
||||
stream=self.streaming,
|
||||
)
|
||||
return res
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
**default_headers,
|
||||
}
|
||||
return headers
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
|
@ -62,3 +62,17 @@ def test_extra_kwargs() -> None:
|
||||
assert chat.temperature == 0.88
|
||||
assert chat.top_p == 0.7
|
||||
assert chat.with_search_enhance is True
|
||||
|
||||
|
||||
async def test_chat_baichuan_agenerate() -> None:
|
||||
chat = ChatBaichuan() # type: ignore[call-arg]
|
||||
response = await chat.ainvoke("你好呀")
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
async def test_chat_baichuan_astream() -> None:
|
||||
chat = ChatBaichuan() # type: ignore[call-arg]
|
||||
async for chunk in chat.astream("今天天气如何?"):
|
||||
assert isinstance(chunk, AIMessage)
|
||||
assert isinstance(chunk.content, str)
|
||||
|
Loading…
Reference in New Issue
Block a user