diff --git a/libs/community/langchain_community/llms/xinference.py b/libs/community/langchain_community/llms/xinference.py
index 06bc1c9b30e..f3f2aa734b8 100644
--- a/libs/community/langchain_community/llms/xinference.py
+++ b/libs/community/langchain_community/llms/xinference.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import json
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncIterator,
     Dict,
     Generator,
     Iterator,
@@ -12,7 +14,12 @@ from typing import (
     Union,
 )
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
+import aiohttp
+import requests
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 
@@ -126,6 +133,7 @@ class Xinference(LLM):
         self,
         server_url: Optional[str] = None,
         model_uid: Optional[str] = None,
+        api_key: Optional[str] = None,
         **model_kwargs: Any,
     ):
         try:
@@ -155,7 +163,13 @@ class Xinference(LLM):
         if self.model_uid is None:
             raise ValueError("Please provide the model UID")
 
-        self.client = RESTfulClient(server_url)
+        self._headers: Dict[str, str] = {}
+        self._cluster_authed = False
+        self._check_cluster_authenticated()
+        if api_key is not None and self._cluster_authed:
+            self._headers["Authorization"] = f"Bearer {api_key}"
+
+        self.client = RESTfulClient(server_url, api_key)
 
     @property
     def _llm_type(self) -> str:
@@ -171,6 +185,21 @@ class Xinference(LLM):
             **{"model_kwargs": self.model_kwargs},
         }
 
+    def _check_cluster_authenticated(self) -> None:
+        url = f"{self.server_url}/v1/cluster/auth"
+        response = requests.get(url)
+        # A 404 means the auth endpoint is absent, i.e. auth is disabled.
+        if response.status_code == 404:
+            self._cluster_authed = False
+        else:
+            if response.status_code != 200:
+                raise RuntimeError(
+                    "Failed to get cluster information, "
+                    f"detail: {response.json()['detail']}"
+                )
+            response_data = response.json()
+            self._cluster_authed = bool(response_data["auth"])
+
     def _call(
         self,
         prompt: str,
@@ -305,3 +333,66 @@ class Xinference(LLM):
             return GenerationChunk(text=token)
         else:
             raise TypeError("stream_response type error!")
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        generate_config = kwargs.get("generate_config", {})
+        generate_config = {**self.model_kwargs, **generate_config}
+        if stop:
+            generate_config["stop"] = stop
+        async for stream_resp in self._acreate_generate_stream(
+            prompt, generate_config
+        ):
+            if stream_resp:
+                chunk = self._stream_response_to_generation_chunk(stream_resp)
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )
+                yield chunk
+
+    async def _acreate_generate_stream(
+        self, prompt: str, generate_config: Optional[Dict[str, List[str]]] = None
+    ) -> AsyncIterator[Any]:
+        request_body: Dict[str, Any] = {"model": self.model_uid, "prompt": prompt}
+        if generate_config is not None:
+            for key, value in generate_config.items():
+                request_body[key] = value
+
+        stream = bool(generate_config and generate_config.get("stream"))
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url=f"{self.server_url}/v1/completions",
+                json=request_body,
+                headers=self._headers,
+            ) as response:
+                if response.status != 200:
+                    if response.status == 404:
+                        raise FileNotFoundError(
+                            "astream call failed with status code 404."
+                        )
+                    else:
+                        optional_detail = await response.text()
+                        raise ValueError(
+                            f"astream call failed with status code {response.status}."
+                            f" Details: {optional_detail}"
+                        )
+
+                async for line in response.content:
+                    if not stream:
+                        # Non-streaming responses arrive as a single JSON body.
+                        yield json.loads(line)
+                    else:
+                        # Streaming responses are SSE lines prefixed with "data:".
+                        json_str = line.decode("utf-8")
+                        if line.startswith(b"data:"):
+                            json_str = json_str[len(b"data:") :].strip()
+                            if not json_str:
+                                continue
+                            yield json.loads(json_str)