mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 04:50:37 +00:00
feat: Implement ChatBaichuan asynchronous interface (#23589)
- **Description:** Add interface to `ChatBaichuan` to support asynchronous requests - `_agenerate` method - `_astream` method --------- Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
8842a0d986
commit
525109e506
@ -1,11 +1,16 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Type
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
from langchain_core.language_models.chat_models import (
|
from langchain_core.language_models.chat_models import (
|
||||||
BaseChatModel,
|
BaseChatModel,
|
||||||
|
agenerate_from_stream,
|
||||||
generate_from_stream,
|
generate_from_stream,
|
||||||
)
|
)
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
@ -71,6 +76,27 @@ def _convert_delta_to_message_chunk(
|
|||||||
return default_class(content=content) # type: ignore[call-arg]
|
return default_class(content=content) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def aconnect_httpx_sse(
|
||||||
|
client: Any, method: str, url: str, **kwargs: Any
|
||||||
|
) -> AsyncIterator:
|
||||||
|
"""Async context manager for connecting to an SSE stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: The httpx client.
|
||||||
|
method: The HTTP method.
|
||||||
|
url: The URL to connect to.
|
||||||
|
kwargs: Additional keyword arguments to pass to the client.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An EventSource object.
|
||||||
|
"""
|
||||||
|
from httpx_sse import EventSource
|
||||||
|
|
||||||
|
async with client.stream(method, url, **kwargs) as response:
|
||||||
|
yield EventSource(response)
|
||||||
|
|
||||||
|
|
||||||
class ChatBaichuan(BaseChatModel):
|
class ChatBaichuan(BaseChatModel):
|
||||||
"""Baichuan chat models API by Baichuan Intelligent Technology.
|
"""Baichuan chat models API by Baichuan Intelligent Technology.
|
||||||
|
|
||||||
@ -199,7 +225,7 @@ class ChatBaichuan(BaseChatModel):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[ChatGenerationChunk]:
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
res = self._chat(messages, **kwargs)
|
res = self._chat(messages, stream=True, **kwargs)
|
||||||
if res.status_code != 200:
|
if res.status_code != 200:
|
||||||
raise ValueError(f"Error from Baichuan api response: {res}")
|
raise ValueError(f"Error from Baichuan api response: {res}")
|
||||||
default_chunk_class = AIMessageChunk
|
default_chunk_class = AIMessageChunk
|
||||||
@ -222,14 +248,96 @@ class ChatBaichuan(BaseChatModel):
|
|||||||
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
|
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
|
||||||
yield cg_chunk
|
yield cg_chunk
|
||||||
|
|
||||||
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
|
async def _agenerate(
|
||||||
parameters = {**self._default_params, **kwargs}
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
should_stream = stream if stream is not None else self.streaming
|
||||||
|
if should_stream:
|
||||||
|
stream_iter = self._astream(
|
||||||
|
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
|
)
|
||||||
|
return await agenerate_from_stream(stream_iter)
|
||||||
|
|
||||||
model = parameters.pop("model")
|
headers = self._create_headers_parameters(**kwargs)
|
||||||
headers = parameters.pop("headers", {})
|
payload = self._create_payload_parameters(messages, **kwargs)
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
headers=headers, timeout=self.request_timeout
|
||||||
|
) as client:
|
||||||
|
response = await client.post(self.baichuan_api_base, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
return self._create_chat_result(response.json())
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
|
headers = self._create_headers_parameters(**kwargs)
|
||||||
|
payload = self._create_payload_parameters(messages, stream=True, **kwargs)
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
headers=headers, timeout=self.request_timeout
|
||||||
|
) as client:
|
||||||
|
async with aconnect_httpx_sse(
|
||||||
|
client, "POST", self.baichuan_api_base, json=payload
|
||||||
|
) as event_source:
|
||||||
|
async for sse in event_source.aiter_sse():
|
||||||
|
chunk = json.loads(sse.data)
|
||||||
|
if len(chunk["choices"]) == 0:
|
||||||
|
continue
|
||||||
|
choice = chunk["choices"][0]
|
||||||
|
chunk = _convert_delta_to_message_chunk(
|
||||||
|
choice["delta"], AIMessageChunk
|
||||||
|
)
|
||||||
|
finish_reason = choice.get("finish_reason", None)
|
||||||
|
|
||||||
|
generation_info = (
|
||||||
|
{"finish_reason": finish_reason}
|
||||||
|
if finish_reason is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
chunk = ChatGenerationChunk(
|
||||||
|
message=chunk, generation_info=generation_info
|
||||||
|
)
|
||||||
|
if run_manager:
|
||||||
|
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||||
|
yield chunk
|
||||||
|
if finish_reason is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
|
||||||
|
payload = self._create_payload_parameters(messages, **kwargs)
|
||||||
|
url = self.baichuan_api_base
|
||||||
|
headers = self._create_headers_parameters(**kwargs)
|
||||||
|
|
||||||
|
res = requests.post(
|
||||||
|
url=url,
|
||||||
|
timeout=self.request_timeout,
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
stream=self.streaming,
|
||||||
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _create_payload_parameters( # type: ignore[no-untyped-def]
|
||||||
|
self, messages: List[BaseMessage], **kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
parameters = {**self._default_params, **kwargs}
|
||||||
temperature = parameters.pop("temperature", 0.3)
|
temperature = parameters.pop("temperature", 0.3)
|
||||||
top_k = parameters.pop("top_k", 5)
|
top_k = parameters.pop("top_k", 5)
|
||||||
top_p = parameters.pop("top_p", 0.85)
|
top_p = parameters.pop("top_p", 0.85)
|
||||||
|
model = parameters.pop("model")
|
||||||
with_search_enhance = parameters.pop("with_search_enhance", False)
|
with_search_enhance = parameters.pop("with_search_enhance", False)
|
||||||
stream = parameters.pop("stream", False)
|
stream = parameters.pop("stream", False)
|
||||||
|
|
||||||
@ -242,24 +350,21 @@ class ChatBaichuan(BaseChatModel):
|
|||||||
"with_search_enhance": with_search_enhance,
|
"with_search_enhance": with_search_enhance,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
}
|
}
|
||||||
|
return payload
|
||||||
|
|
||||||
url = self.baichuan_api_base
|
def _create_headers_parameters(self, **kwargs) -> Dict[str, Any]: # type: ignore[no-untyped-def]
|
||||||
|
parameters = {**self._default_params, **kwargs}
|
||||||
|
default_headers = parameters.pop("headers", {})
|
||||||
api_key = ""
|
api_key = ""
|
||||||
if self.baichuan_api_key:
|
if self.baichuan_api_key:
|
||||||
api_key = self.baichuan_api_key.get_secret_value()
|
api_key = self.baichuan_api_key.get_secret_value()
|
||||||
|
|
||||||
res = requests.post(
|
headers = {
|
||||||
url=url,
|
"Content-Type": "application/json",
|
||||||
timeout=self.request_timeout,
|
"Authorization": f"Bearer {api_key}",
|
||||||
headers={
|
**default_headers,
|
||||||
"Content-Type": "application/json",
|
}
|
||||||
"Authorization": f"Bearer {api_key}",
|
return headers
|
||||||
**headers,
|
|
||||||
},
|
|
||||||
json=payload,
|
|
||||||
stream=self.streaming,
|
|
||||||
)
|
|
||||||
return res
|
|
||||||
|
|
||||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||||
generations = []
|
generations = []
|
||||||
|
@ -62,3 +62,17 @@ def test_extra_kwargs() -> None:
|
|||||||
assert chat.temperature == 0.88
|
assert chat.temperature == 0.88
|
||||||
assert chat.top_p == 0.7
|
assert chat.top_p == 0.7
|
||||||
assert chat.with_search_enhance is True
|
assert chat.with_search_enhance is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_baichuan_agenerate() -> None:
|
||||||
|
chat = ChatBaichuan() # type: ignore[call-arg]
|
||||||
|
response = await chat.ainvoke("你好呀")
|
||||||
|
assert isinstance(response, AIMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_baichuan_astream() -> None:
|
||||||
|
chat = ChatBaichuan() # type: ignore[call-arg]
|
||||||
|
async for chunk in chat.astream("今天天气如何?"):
|
||||||
|
assert isinstance(chunk, AIMessage)
|
||||||
|
assert isinstance(chunk.content, str)
|
||||||
|
Loading…
Reference in New Issue
Block a user