mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 13:00:34 +00:00
Add asynchronous generate interface (#30001)
- [ ] **PR title**: [langchain_community.llms.xinference]: Add asynchronous generate interface - [ ] **PR message**: The asynchronous generate interface supports both streaming and non-streaming data. chain = prompt | llm async for chunk in chain.astream(input=user_input): yield chunk - [ ] **Add tests and docs**: from langchain_community.llms import Xinference from langchain.prompts import PromptTemplate llm = Xinference( server_url="http://0.0.0.0:9997", # replace with your xinference server url model_uid={model_uid} # replace model_uid with the model UID returned from launching the model stream = True ) prompt = PromptTemplate(input=['country'], template="Q: where can we visit in the capital of {country}? A:") chain = prompt | llm async for chunk in chain.astream(input=user_input): yield chunk
This commit is contained in:
parent
a1897ca621
commit
86b364de3b
@ -1,8 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
Iterator,
|
Iterator,
|
||||||
@ -12,7 +14,12 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
import aiohttp
|
||||||
|
import requests
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
from langchain_core.language_models.llms import LLM
|
from langchain_core.language_models.llms import LLM
|
||||||
from langchain_core.outputs import GenerationChunk
|
from langchain_core.outputs import GenerationChunk
|
||||||
|
|
||||||
@ -126,6 +133,7 @@ class Xinference(LLM):
|
|||||||
self,
|
self,
|
||||||
server_url: Optional[str] = None,
|
server_url: Optional[str] = None,
|
||||||
model_uid: Optional[str] = None,
|
model_uid: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
**model_kwargs: Any,
|
**model_kwargs: Any,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
@ -155,7 +163,13 @@ class Xinference(LLM):
|
|||||||
if self.model_uid is None:
|
if self.model_uid is None:
|
||||||
raise ValueError("Please provide the model UID")
|
raise ValueError("Please provide the model UID")
|
||||||
|
|
||||||
self.client = RESTfulClient(server_url)
|
self._headers: Dict[str, str] = {}
|
||||||
|
self._cluster_authed = False
|
||||||
|
self._check_cluster_authenticated()
|
||||||
|
if api_key is not None and self._cluster_authed:
|
||||||
|
self._headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
self.client = RESTfulClient(server_url, api_key)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
@ -171,6 +185,20 @@ class Xinference(LLM):
|
|||||||
**{"model_kwargs": self.model_kwargs},
|
**{"model_kwargs": self.model_kwargs},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _check_cluster_authenticated(self) -> None:
|
||||||
|
url = f"{self.server_url}/v1/cluster/auth"
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code == 404:
|
||||||
|
self._cluster_authed = False
|
||||||
|
else:
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to get cluster information, "
|
||||||
|
f"detail: {response.json()['detail']}"
|
||||||
|
)
|
||||||
|
response_data = response.json()
|
||||||
|
self._cluster_authed = bool(response_data["auth"])
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -305,3 +333,61 @@ class Xinference(LLM):
|
|||||||
return GenerationChunk(text=token)
|
return GenerationChunk(text=token)
|
||||||
else:
|
else:
|
||||||
raise TypeError("stream_response type error!")
|
raise TypeError("stream_response type error!")
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[GenerationChunk]:
|
||||||
|
generate_config = kwargs.get("generate_config", {})
|
||||||
|
generate_config = {**self.model_kwargs, **generate_config}
|
||||||
|
if stop:
|
||||||
|
generate_config["stop"] = stop
|
||||||
|
async for stream_resp in self._acreate_generate_stream(prompt, generate_config):
|
||||||
|
if stream_resp:
|
||||||
|
chunk = self._stream_response_to_generation_chunk(stream_resp)
|
||||||
|
if run_manager:
|
||||||
|
await run_manager.on_llm_new_token(
|
||||||
|
chunk.text,
|
||||||
|
verbose=self.verbose,
|
||||||
|
)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def _acreate_generate_stream(
|
||||||
|
self, prompt: str, generate_config: Optional[Dict[str, List[str]]] = None
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
request_body: Dict[str, Any] = {"model": self.model_uid, "prompt": prompt}
|
||||||
|
if generate_config is not None:
|
||||||
|
for key, value in generate_config.items():
|
||||||
|
request_body[key] = value
|
||||||
|
|
||||||
|
stream = bool(generate_config and generate_config.get("stream"))
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
url=f"{self.server_url}/v1/completions",
|
||||||
|
json=request_body,
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
if response.status == 404:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"astream call failed with status code 404."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optional_detail = response.text
|
||||||
|
raise ValueError(
|
||||||
|
f"astream call failed with status code {response.status}."
|
||||||
|
f" Details: {optional_detail}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async for line in response.content:
|
||||||
|
if not stream:
|
||||||
|
yield json.loads(line)
|
||||||
|
else:
|
||||||
|
json_str = line.decode("utf-8")
|
||||||
|
if line.startswith(b"data:"):
|
||||||
|
json_str = json_str[len(b"data:") :].strip()
|
||||||
|
if not json_str:
|
||||||
|
continue
|
||||||
|
yield json.loads(json_str)
|
||||||
|
Loading…
Reference in New Issue
Block a user