Mirror of https://github.com/hwchase17/langchain.git, synced 2025-05-21 23:17:48 +00:00
feat: Added class to support huggingface text generation inference server (#4447)
[Text Generation Inference](https://github.com/huggingface/text-generation-inference) is a Rust, Python and gRPC server for generating text using LLMs. This pull request adds support for self-hosted Text Generation Inference servers. feature: #4280 --------- Co-authored-by: Your Name <you@example.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
258c319855
commit
cf4c1394a2
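
For orientation, the wrapper added below builds on the `text_generation` client package. A minimal sketch of what talking to a running TGI server looks like with that client directly (an assumption for illustration: a server already listening on localhost:8010; the `Client(...)` and `generate(...)` calls are the same ones the new class uses):

```python
# Sketch: querying a TGI server directly with the text_generation client.
# Assumes a server is already running at http://localhost:8010/.
from text_generation import Client

client = Client("http://localhost:8010/", timeout=120)
response = client.generate("What is Deep Learning?", max_new_tokens=64)
print(response.generated_text)  # the completion text
```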
@@ -0,0 +1,77 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Huggingface TextGen Inference\n",
    "\n",
    "[Text Generation Inference](https://github.com/huggingface/text-generation-inference) is a Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co/) to power LLMs api-inference widgets.\n",
    "\n",
    "This notebook goes over how to use a self-hosted LLM using `Text Generation Inference`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use, you should have the `text_generation` python package installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# !pip3 install text_generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms import HuggingFaceTextGenInference\n",
    "\n",
    "llm = HuggingFaceTextGenInference(\n",
    "    inference_server_url='http://localhost:8010/',\n",
    "    max_new_tokens=512,\n",
    "    top_k=10,\n",
    "    top_p=0.95,\n",
    "    typical_p=0.95,\n",
    "    temperature=0.01,\n",
    "    repetition_penalty=1.03,\n",
    ")\n",
    "llm(\"What did foo say about bar?\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
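
As a follow-up to the final notebook cell, a hedged sketch of wiring the new wrapper into a chain, assuming the `PromptTemplate`/`LLMChain` API available at the time of this commit:

```python
# Sketch (not part of the PR): using the wrapper inside an LLMChain.
from langchain import LLMChain, PromptTemplate
from langchain.llms import HuggingFaceTextGenInference

llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8010/",
    max_new_tokens=512,
    temperature=0.01,
)
prompt = PromptTemplate(
    input_variables=["question"],
    template="Question: {question}\n\nAnswer:",
)
chain = LLMChain(llm=llm, prompt=prompt)
print(chain.run(question="What did foo say about bar?"))
```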
@@ -26,6 +26,7 @@ from langchain.llms import (
     ForefrontAI,
     GooseAI,
     HuggingFaceHub,
+    HuggingFaceTextGenInference,
     LlamaCpp,
     Modal,
     OpenAI,
@@ -114,4 +115,5 @@ __all__ = [
     "QAWithSourcesChain",
     "PALChain",
     "LlamaCpp",
+    "HuggingFaceTextGenInference",
 ]
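
The effect of the export change above (together with the subpackage changes in the hunks that follow) is that the class becomes importable from the package root. A quick sketch, assuming an install that includes this commit:

```python
# Sketch: the new export makes the class available at the package root.
from langchain import HuggingFaceTextGenInference

# Equivalent, via the subpackage (wired up in the hunks below):
# from langchain.llms import HuggingFaceTextGenInference
```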
@@ -16,6 +16,7 @@ from langchain.llms.gpt4all import GPT4All
 from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
 from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
 from langchain.llms.human import HumanInputLLM
 from langchain.llms.llamacpp import LlamaCpp
 from langchain.llms.modal import Modal
@@ -67,6 +68,7 @@ __all__ = [
     "RWKV",
     "PredictionGuard",
     "HumanInputLLM",
+    "HuggingFaceTextGenInference",
 ]
 
 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@@ -99,4 +101,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "stochasticai": StochasticAI,
     "writer": Writer,
     "rwkv": RWKV,
+    "huggingface_textgen_inference": HuggingFaceTextGenInference,
 }
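
The `type_to_cls_dict` entry is what lets serialized LLM configs round-trip by name. A hedged sketch, assuming the `load_llm_from_config` helper in `langchain.llms.loading` of this era resolved the `_type` key against this dict:

```python
# Sketch (assumption: langchain.llms.loading.load_llm_from_config of this
# era resolves "_type" against type_to_cls_dict).
from langchain.llms.loading import load_llm_from_config

config = {
    "_type": "huggingface_textgen_inference",  # key registered above
    "inference_server_url": "http://localhost:8010/",
    "max_new_tokens": 512,
}
llm = load_llm_from_config(config)
```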
langchain/llms/huggingface_text_gen_inference.py (new file, 118 lines)
@@ -0,0 +1,118 @@
"""Wrapper around Huggingface text generation inference API."""
from typing import Any, Dict, List, Optional

from pydantic import Extra, Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class HuggingFaceTextGenInference(LLM):
    """
    HuggingFace text generation inference API.

    This class is a wrapper around the HuggingFace text generation inference API.
    It is used to generate text from a given prompt.

    Attributes:
    - max_new_tokens: The maximum number of tokens to generate.
    - top_k: The number of top-k tokens to consider when generating text.
    - top_p: The cumulative probability threshold for generating text.
    - typical_p: The typical probability threshold for generating text.
    - temperature: The temperature to use when generating text.
    - repetition_penalty: The repetition penalty to use when generating text.
    - stop_sequences: A list of stop sequences to use when generating text.
    - seed: The seed to use when generating text.
    - inference_server_url: The URL of the inference server to use.
    - timeout: The timeout in seconds when connecting to the inference server.
    - client: The client object used to communicate with the inference server.

    Methods:
    - _call: Generates text based on a given prompt and stop sequences.
    - _llm_type: Returns the type of LLM.

    Example:
        .. code-block:: python

            llm = HuggingFaceTextGenInference(
                inference_server_url="http://localhost:8010/",
                max_new_tokens=512,
                top_k=10,
                top_p=0.95,
                typical_p=0.95,
                temperature=0.01,
                repetition_penalty=1.03,
            )
    """

    max_new_tokens: int = 512
    top_k: Optional[int] = None
    top_p: Optional[float] = 0.95
    typical_p: Optional[float] = 0.95
    temperature: float = 0.8
    repetition_penalty: Optional[float] = None
    stop_sequences: List[str] = Field(default_factory=list)
    seed: Optional[int] = None
    inference_server_url: str = ""
    timeout: int = 120
    client: Any

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that python package exists in environment."""
        try:
            import text_generation

            values["client"] = text_generation.Client(
                values["inference_server_url"], timeout=values["timeout"]
            )
        except ImportError:
            raise ValueError(
                "Could not import text_generation python package. "
                "Please install it with `pip install text_generation`."
            )
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "hf_textgen_inference"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        if stop is None:
            stop = self.stop_sequences
        else:
            stop += self.stop_sequences

        res = self.client.generate(
            prompt,
            stop_sequences=stop,
            max_new_tokens=self.max_new_tokens,
            top_k=self.top_k,
            top_p=self.top_p,
            typical_p=self.typical_p,
            temperature=self.temperature,
            repetition_penalty=self.repetition_penalty,
            seed=self.seed,
        )
        # remove any stop sequence (and everything after it) from the
        # end of the generated text
        for stop_seq in stop:
            if stop_seq in res.generated_text:
                res.generated_text = res.generated_text[
                    : res.generated_text.index(stop_seq)
                ]

        return res.generated_text
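
A usage note on `_call`'s stop handling: stop strings passed at call time are merged with the instance-level `stop_sequences`, and any stop string found in the generated text truncates the output at that point. A hedged sketch (assumes a reachable TGI server on localhost:8010):

```python
# Sketch: call-time stop strings merge with instance-level stop_sequences;
# output is cut at any stop string found (assumes a local TGI server).
from langchain.llms import HuggingFaceTextGenInference

llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8010/",
    stop_sequences=["\nHuman:"],
)
text = llm("List three colors:", stop=["\n\n"])
# `text` is truncated before "\n\n" or "\nHuman:" if either appears.
```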