Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-25 04:23:05 +00:00)
[langchain_community.llms.xinference]: Rewrite _stream() method and support stream() method in xinference.py (#29259)
- [ ] **PR title**: [langchain_community.llms.xinference]: Rewrite _stream() method and support stream() method in xinference.py
- [ ] **PR message**: Rewrite the _stream() method so that chain.stream() can be used to return a data stream:

        chain = prompt | llm
        chain.stream(input=user_input)

- [ ] **tests**:

        from langchain_community.llms import Xinference
        from langchain.prompts import PromptTemplate

        llm = Xinference(
            server_url="http://0.0.0.0:9997",  # replace with your xinference server url
            model_uid={model_uid},  # replace model_uid with the model UID returned from launching the model
            stream=True,
        )
        prompt = PromptTemplate(
            input_variables=['country'],
            template="Q: where can we visit in the capital of {country}? A:",
        )
        chain = prompt | llm
        chain.stream(input={'country': 'France'})
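The rewritten `_stream()` is what lets `chain.stream()` above yield output incrementally rather than returning one string. Below is a minimal sketch of consuming that stream; it assumes a running Xinference server at the URL from the example and a launched model, and the concrete UID string is only an illustrative placeholder.

```python
# Minimal consumption sketch (assumes an Xinference server at server_url and a
# launched model; "my-model-uid" is a hypothetical placeholder UID).
from langchain_community.llms import Xinference
from langchain.prompts import PromptTemplate

llm = Xinference(
    server_url="http://0.0.0.0:9997",  # replace with your Xinference server URL
    model_uid="my-model-uid",          # replace with the UID returned when launching the model
    stream=True,
)
prompt = PromptTemplate(
    input_variables=["country"],
    template="Q: where can we visit in the capital of {country}? A:",
)
chain = prompt | llm

# chain.stream() returns an iterator of text chunks for an LLM,
# so tokens can be printed as they arrive from the server.
for chunk in chain.stream({"country": "France"}):
    print(chunk, end="", flush=True)
```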
This commit is contained in:
parent
d4b9404fd6
commit
1cd4d8d101
@@ -1,7 +1,20 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union
+from __future__ import annotations
+
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
 
 if TYPE_CHECKING:
     from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
@@ -73,6 +86,26 @@ class Xinference(LLM):
             generate_config={"max_tokens": 1024, "stream": True},
         )
 
+    Example:
+
+    .. code-block:: python
+
+        from langchain_community.llms import Xinference
+        from langchain.prompts import PromptTemplate
+
+        llm = Xinference(
+            server_url="http://0.0.0.0:9997",
+            model_uid={model_uid},  # replace model_uid with the model UID returned from launching the model
+            stream=True,
+        )
+        prompt = PromptTemplate(
+            input_variables=['country'],
+            template="Q: where can we visit in the capital of {country}? A:",
+        )
+        chain = prompt | llm
+        chain.stream(input={'country': 'France'})
+
+
     To view all the supported builtin models, run:
 
     .. code-block:: bash
@@ -216,3 +249,59 @@ class Xinference(LLM):
                                 token=token, verbose=self.verbose, log_probs=log_probs
                             )
                         yield token
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        generate_config = kwargs.get("generate_config", {})
+        generate_config = {**self.model_kwargs, **generate_config}
+        if stop:
+            generate_config["stop"] = stop
+        for stream_resp in self._create_generate_stream(prompt, generate_config):
+            if stream_resp:
+                chunk = self._stream_response_to_generation_chunk(stream_resp)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )
+                yield chunk
+
+    def _create_generate_stream(
+        self, prompt: str, generate_config: Optional[Dict[str, List[str]]] = None
+    ) -> Iterator[str]:
+        if self.client is None:
+            raise ValueError("Client is not initialized!")
+        model = self.client.get_model(self.model_uid)
+        yield from model.generate(prompt=prompt, generate_config=generate_config)
+
+    @staticmethod
+    def _stream_response_to_generation_chunk(
+        stream_response: str,
+    ) -> GenerationChunk:
+        """Convert a stream response to a generation chunk."""
+        token = ""
+        if isinstance(stream_response, dict):
+            choices = stream_response.get("choices", [])
+            if choices:
+                choice = choices[0]
+                if isinstance(choice, dict):
+                    token = choice.get("text", "")
+
+                    return GenerationChunk(
+                        text=token,
+                        generation_info=dict(
+                            finish_reason=choice.get("finish_reason", None),
+                            logprobs=choice.get("logprobs", None),
+                        ),
+                    )
+                else:
+                    raise TypeError("choice type error!")
+            else:
+                return GenerationChunk(text=token)
+        else:
+            raise TypeError("stream_response type error!")
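The helper at the end of the hunk does the dict-to-chunk translation that `_stream()` relies on. As a rough illustration of the payload shape it expects (the sample dict below is hypothetical, not captured from a real Xinference server), the conversion amounts to:

```python
# Illustrative only: a hypothetical streaming payload shaped the way
# _stream_response_to_generation_chunk expects, and the chunk it would build.
from langchain_core.outputs import GenerationChunk

stream_response = {
    "choices": [
        {"text": "Paris", "finish_reason": None, "logprobs": None},
    ],
}

choice = stream_response["choices"][0]
chunk = GenerationChunk(
    text=choice.get("text", ""),
    generation_info=dict(
        finish_reason=choice.get("finish_reason"),
        logprobs=choice.get("logprobs"),
    ),
)
print(repr(chunk.text))  # 'Paris'
```

Chunks built this way are what `_stream()` yields back to LangChain's streaming machinery, while `run_manager.on_llm_new_token` is notified with each chunk's text along the way.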