Add asynchronous generate interface (#30001)

- [ ] **PR title**: [langchain_community.llms.xinference]: Add asynchronous generate interface

- [ ] **PR message**: The asynchronous generate interface supports both streaming and non-streaming output.

        chain = prompt | llm
        async for chunk in chain.astream(input=user_input):
            yield chunk
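For the non-streaming path, the same chain can be awaited in one shot through LangChain's standard `ainvoke` (a minimal sketch, assuming the same `prompt | llm` chain as above):

        result = await chain.ainvoke(input=user_input)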


- [ ] **Add tests and docs**:

        from langchain_community.llms import Xinference
        from langchain.prompts import PromptTemplate

        llm = Xinference(
            server_url="http://0.0.0.0:9997",  # replace with your xinference server url
            model_uid={model_uid},  # replace with the model UID returned from launching the model
            stream=True,
        )
        prompt = PromptTemplate(
            input_variables=["country"],
            template="Q: where can we visit in the capital of {country}? A:",
        )
        chain = prompt | llm
        async for chunk in chain.astream(input=user_input):
            yield chunk
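The snippet above must run inside a coroutine; a self-contained driver might look like this (the `main` wrapper and the example input are illustrative, not part of the PR):

        import asyncio

        async def main() -> None:
            # the prompt template expects a dict supplying the "country" variable
            async for chunk in chain.astream(input={"country": "France"}):
                print(chunk, end="", flush=True)

        asyncio.run(main())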
TheSongg · 2025-03-01 01:32:44 +08:00 · committed by GitHub · commit 86b364de3b (parent a1897ca621)
@@ -1,8 +1,10 @@
 from __future__ import annotations

+import json
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncIterator,
     Dict,
     Generator,
     Iterator,
@@ -12,7 +14,12 @@ from typing import (
     Union,
 )

-from langchain_core.callbacks import CallbackManagerForLLMRun
+import aiohttp
+import requests
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
@@ -126,6 +133,7 @@ class Xinference(LLM):
         self,
         server_url: Optional[str] = None,
         model_uid: Optional[str] = None,
+        api_key: Optional[str] = None,
         **model_kwargs: Any,
     ):
         try:
@@ -155,7 +163,13 @@ class Xinference(LLM):
         if self.model_uid is None:
             raise ValueError("Please provide the model UID")

-        self.client = RESTfulClient(server_url)
+        self._headers: Dict[str, str] = {}
+        self._cluster_authed = False
+        self._check_cluster_authenticated()
+        if api_key is not None and self._cluster_authed:
+            self._headers["Authorization"] = f"Bearer {api_key}"
+
+        self.client = RESTfulClient(server_url, api_key)

     @property
     def _llm_type(self) -> str:
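With the new `api_key` parameter above, an authenticated cluster can be targeted at construction time; a hypothetical sketch (server URL, UID, and key are placeholders):

        llm = Xinference(
            server_url="http://0.0.0.0:9997",
            model_uid="<model_uid>",
            api_key="<your_api_key>",  # forwarded as a Bearer token when the cluster is authenticated
        )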
@@ -171,6 +185,20 @@ class Xinference(LLM):
             **{"model_kwargs": self.model_kwargs},
         }

+    def _check_cluster_authenticated(self) -> None:
+        url = f"{self.server_url}/v1/cluster/auth"
+        response = requests.get(url)
+        if response.status_code == 404:
+            self._cluster_authed = False
+        else:
+            if response.status_code != 200:
+                raise RuntimeError(
+                    f"Failed to get cluster information, "
+                    f"detail: {response.json()['detail']}"
+                )
+            response_data = response.json()
+            self._cluster_authed = bool(response_data["auth"])
+
     def _call(
         self,
         prompt: str,
@@ -305,3 +333,61 @@ class Xinference(LLM):
             return GenerationChunk(text=token)
         else:
             raise TypeError("stream_response type error!")
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        generate_config = kwargs.get("generate_config", {})
+        generate_config = {**self.model_kwargs, **generate_config}
+        if stop:
+            generate_config["stop"] = stop
+        async for stream_resp in self._acreate_generate_stream(prompt, generate_config):
+            if stream_resp:
+                chunk = self._stream_response_to_generation_chunk(stream_resp)
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )
+                yield chunk
+
+    async def _acreate_generate_stream(
+        self, prompt: str, generate_config: Optional[Dict[str, List[str]]] = None
+    ) -> AsyncIterator[str]:
+        request_body: Dict[str, Any] = {"model": self.model_uid, "prompt": prompt}
+        if generate_config is not None:
+            for key, value in generate_config.items():
+                request_body[key] = value
+
+        stream = bool(generate_config and generate_config.get("stream"))
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url=f"{self.server_url}/v1/completions",
+                json=request_body,
+            ) as response:
+                if response.status != 200:
+                    if response.status == 404:
+                        raise FileNotFoundError(
+                            "astream call failed with status code 404."
+                        )
+                    else:
+                        optional_detail = await response.text()
+                        raise ValueError(
+                            f"astream call failed with status code {response.status}."
+                            f" Details: {optional_detail}"
+                        )
+
+                async for line in response.content:
+                    if not stream:
+                        yield json.loads(line)
+                    else:
+                        json_str = line.decode("utf-8")
+                        if line.startswith(b"data:"):
+                            json_str = json_str[len(b"data:") :].strip()
+                            if not json_str:
+                                continue
+                            yield json.loads(json_str)
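Because `_astream` plugs into LangChain's standard streaming machinery, the new code path can also be exercised without a prompt template, through the public `astream` method. A minimal sketch, assuming a model is already launched on the server (the URL and UID are placeholders; `generate_config={"stream": True}` selects the SSE branch shown above):

        import asyncio

        from langchain_community.llms import Xinference

        llm = Xinference(
            server_url="http://0.0.0.0:9997",
            model_uid="<model_uid>",
        )

        async def main() -> None:
            # llm.astream delegates to the _astream implementation added in this commit
            async for chunk in llm.astream(
                "Q: where can we visit in the capital of France? A:",
                generate_config={"stream": True},
            ):
                print(chunk, end="", flush=True)

        asyncio.run(main())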