mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 10:13:29 +00:00
Added support for streaming output response to HuggingFaceTextgenInference LLM class (#4633)
# Added support for streaming output response to HuggingFaceTextgenInference LLM class Current implementation does not support streaming output. Updated to incorporate this feature. Tagging @agola11 for visibility.
This commit is contained in:
parent
435b70da47
commit
c70ae562b4
@ -1,4 +1,5 @@
|
|||||||
"""Wrapper around Huggingface text generation inference API."""
|
"""Wrapper around Huggingface text generation inference API."""
|
||||||
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import Extra, Field, root_validator
|
from pydantic import Extra, Field, root_validator
|
||||||
@ -36,6 +37,7 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
# Basic Example (no streaming)
|
||||||
llm = HuggingFaceTextGenInference(
|
llm = HuggingFaceTextGenInference(
|
||||||
inference_server_url = "http://localhost:8010/",
|
inference_server_url = "http://localhost:8010/",
|
||||||
max_new_tokens = 512,
|
max_new_tokens = 512,
|
||||||
@ -45,6 +47,25 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
temperature = 0.01,
|
temperature = 0.01,
|
||||||
repetition_penalty = 1.03,
|
repetition_penalty = 1.03,
|
||||||
)
|
)
|
||||||
|
print(llm("What is Deep Learning?"))
|
||||||
|
|
||||||
|
# Streaming response example
|
||||||
|
from langchain.callbacks import streaming_stdout
|
||||||
|
|
||||||
|
callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
|
||||||
|
llm = HuggingFaceTextGenInference(
|
||||||
|
inference_server_url = "http://localhost:8010/",
|
||||||
|
max_new_tokens = 512,
|
||||||
|
top_k = 10,
|
||||||
|
top_p = 0.95,
|
||||||
|
typical_p = 0.95,
|
||||||
|
temperature = 0.01,
|
||||||
|
repetition_penalty = 1.03,
|
||||||
|
callbacks = callbacks,
|
||||||
|
stream = True
|
||||||
|
)
|
||||||
|
print(llm("What is Deep Learning?"))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_new_tokens: int = 512
|
max_new_tokens: int = 512
|
||||||
@ -57,6 +78,7 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
inference_server_url: str = ""
|
inference_server_url: str = ""
|
||||||
timeout: int = 120
|
timeout: int = 120
|
||||||
|
stream: bool = False
|
||||||
client: Any
|
client: Any
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -97,22 +119,52 @@ class HuggingFaceTextGenInference(LLM):
|
|||||||
else:
|
else:
|
||||||
stop += self.stop_sequences
|
stop += self.stop_sequences
|
||||||
|
|
||||||
res = self.client.generate(
|
if not self.stream:
|
||||||
prompt,
|
res = self.client.generate(
|
||||||
stop_sequences=stop,
|
prompt,
|
||||||
max_new_tokens=self.max_new_tokens,
|
stop_sequences=stop,
|
||||||
top_k=self.top_k,
|
max_new_tokens=self.max_new_tokens,
|
||||||
top_p=self.top_p,
|
top_k=self.top_k,
|
||||||
typical_p=self.typical_p,
|
top_p=self.top_p,
|
||||||
temperature=self.temperature,
|
typical_p=self.typical_p,
|
||||||
repetition_penalty=self.repetition_penalty,
|
temperature=self.temperature,
|
||||||
seed=self.seed,
|
repetition_penalty=self.repetition_penalty,
|
||||||
)
|
seed=self.seed,
|
||||||
# remove stop sequences from the end of the generated text
|
)
|
||||||
for stop_seq in stop:
|
# remove stop sequences from the end of the generated text
|
||||||
if stop_seq in res.generated_text:
|
for stop_seq in stop:
|
||||||
res.generated_text = res.generated_text[
|
if stop_seq in res.generated_text:
|
||||||
: res.generated_text.index(stop_seq)
|
res.generated_text = res.generated_text[
|
||||||
]
|
: res.generated_text.index(stop_seq)
|
||||||
|
]
|
||||||
return res.generated_text
|
text = res.generated_text
|
||||||
|
else:
|
||||||
|
text_callback = None
|
||||||
|
if run_manager:
|
||||||
|
text_callback = partial(
|
||||||
|
run_manager.on_llm_new_token, verbose=self.verbose
|
||||||
|
)
|
||||||
|
params = {
|
||||||
|
"stop_sequences": stop,
|
||||||
|
"max_new_tokens": self.max_new_tokens,
|
||||||
|
"top_k": self.top_k,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"typical_p": self.typical_p,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"repetition_penalty": self.repetition_penalty,
|
||||||
|
"seed": self.seed,
|
||||||
|
}
|
||||||
|
text = ""
|
||||||
|
for res in self.client.generate_stream(prompt, **params):
|
||||||
|
token = res.token
|
||||||
|
is_stop = False
|
||||||
|
for stop_seq in stop:
|
||||||
|
if stop_seq in token.text:
|
||||||
|
is_stop = True
|
||||||
|
break
|
||||||
|
if is_stop:
|
||||||
|
break
|
||||||
|
if not token.special:
|
||||||
|
if text_callback:
|
||||||
|
text_callback(token.text)
|
||||||
|
return text
|
||||||
|
Loading…
Reference in New Issue
Block a user