mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-07 20:15:40 +00:00
Add asynchronous generate interface (#30001)
- [ ] **PR title**: [langchain_community.llms.xinference]: Add asynchronous generate interface - [ ] **PR message**: The asynchronous generate interface support stream data and non-stream data. chain = prompt | llm async for chunk in chain.astream(input=user_input): yield chunk - [ ] **Add tests and docs**: from langchain_community.llms import Xinference from langchain.prompts import PromptTemplate llm = Xinference( server_url="http://0.0.0.0:9997", # replace your xinference server url model_uid={model_uid} # replace model_uid with the model UID return from launching the model stream = True ) prompt = PromptTemplate(input=['country'], template="Q: where can we visit in the capital of {country}? A:") chain = prompt | llm async for chunk in chain.astream(input=user_input): yield chunk
This commit is contained in:
parent
a1897ca621
commit
86b364de3b
@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
@ -12,7 +14,12 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
|
||||
@ -126,6 +133,7 @@ class Xinference(LLM):
|
||||
self,
|
||||
server_url: Optional[str] = None,
|
||||
model_uid: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
**model_kwargs: Any,
|
||||
):
|
||||
try:
|
||||
@ -155,7 +163,13 @@ class Xinference(LLM):
|
||||
if self.model_uid is None:
|
||||
raise ValueError("Please provide the model UID")
|
||||
|
||||
self.client = RESTfulClient(server_url)
|
||||
self._headers: Dict[str, str] = {}
|
||||
self._cluster_authed = False
|
||||
self._check_cluster_authenticated()
|
||||
if api_key is not None and self._cluster_authed:
|
||||
self._headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
self.client = RESTfulClient(server_url, api_key)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
@ -171,6 +185,20 @@ class Xinference(LLM):
|
||||
**{"model_kwargs": self.model_kwargs},
|
||||
}
|
||||
|
||||
def _check_cluster_authenticated(self) -> None:
|
||||
url = f"{self.server_url}/v1/cluster/auth"
|
||||
response = requests.get(url)
|
||||
if response.status_code == 404:
|
||||
self._cluster_authed = False
|
||||
else:
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Failed to get cluster information, "
|
||||
f"detail: {response.json()['detail']}"
|
||||
)
|
||||
response_data = response.json()
|
||||
self._cluster_authed = bool(response_data["auth"])
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -305,3 +333,61 @@ class Xinference(LLM):
|
||||
return GenerationChunk(text=token)
|
||||
else:
|
||||
raise TypeError("stream_response type error!")
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
generate_config = kwargs.get("generate_config", {})
|
||||
generate_config = {**self.model_kwargs, **generate_config}
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
async for stream_resp in self._acreate_generate_stream(prompt, generate_config):
|
||||
if stream_resp:
|
||||
chunk = self._stream_response_to_generation_chunk(stream_resp)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
yield chunk
|
||||
|
||||
async def _acreate_generate_stream(
|
||||
self, prompt: str, generate_config: Optional[Dict[str, List[str]]] = None
|
||||
) -> AsyncIterator[str]:
|
||||
request_body: Dict[str, Any] = {"model": self.model_uid, "prompt": prompt}
|
||||
if generate_config is not None:
|
||||
for key, value in generate_config.items():
|
||||
request_body[key] = value
|
||||
|
||||
stream = bool(generate_config and generate_config.get("stream"))
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url=f"{self.server_url}/v1/completions",
|
||||
json=request_body,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
if response.status == 404:
|
||||
raise FileNotFoundError(
|
||||
"astream call failed with status code 404."
|
||||
)
|
||||
else:
|
||||
optional_detail = response.text
|
||||
raise ValueError(
|
||||
f"astream call failed with status code {response.status}."
|
||||
f" Details: {optional_detail}"
|
||||
)
|
||||
|
||||
async for line in response.content:
|
||||
if not stream:
|
||||
yield json.loads(line)
|
||||
else:
|
||||
json_str = line.decode("utf-8")
|
||||
if line.startswith(b"data:"):
|
||||
json_str = json_str[len(b"data:") :].strip()
|
||||
if not json_str:
|
||||
continue
|
||||
yield json.loads(json_str)
|
||||
|
Loading…
Reference in New Issue
Block a user