diff --git a/docs/modules/models/llms/integrations/huggingface_textgen_inference.ipynb b/docs/modules/models/llms/integrations/huggingface_textgen_inference.ipynb
new file mode 100644
index 00000000000..3d27a831b7d
--- /dev/null
+++ b/docs/modules/models/llms/integrations/huggingface_textgen_inference.ipynb
@@ -0,0 +1,79 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Huggingface TextGen Inference\n",
+    "\n",
+    "[Text Generation Inference](https://github.com/huggingface/text-generation-inference) is a Rust, Python and gRPC server for text generation inference. It is used in production at [HuggingFace](https://huggingface.co/) to power LLM api-inference widgets.\n",
+    "\n",
+    "This notebook goes over how to use a self-hosted LLM with `Text Generation Inference`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To use, you should have the `text_generation` python package installed."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# !pip3 install text_generation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms import HuggingFaceTextGenInference\n",
+    "\n",
+    "llm = HuggingFaceTextGenInference(\n",
+    "    inference_server_url='http://localhost:8010/',\n",
+    "    max_new_tokens=512,\n",
+    "    top_k=10,\n",
+    "    top_p=0.95,\n",
+    "    typical_p=0.95,\n",
+    "    temperature=0.01,\n",
+    "    repetition_penalty=1.03,\n",
+    ")\n",
+    "llm(\"What did foo say about bar?\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.3"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/langchain/__init__.py b/langchain/__init__.py
index c2513b2847b..b9cc9469bd5 100644
--- a/langchain/__init__.py
+++ b/langchain/__init__.py
@@ -26,6 +26,7 @@ from langchain.llms import (
     ForefrontAI,
     GooseAI,
     HuggingFaceHub,
+    HuggingFaceTextGenInference,
     LlamaCpp,
     Modal,
     OpenAI,
@@ -114,4 +115,5 @@ __all__ = [
     "QAWithSourcesChain",
     "PALChain",
     "LlamaCpp",
+    "HuggingFaceTextGenInference",
 ]
diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index 2dae19c92e0..a7f5980cfdd 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -16,6 +16,7 @@ from langchain.llms.gpt4all import GPT4All
 from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
 from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
 from langchain.llms.human import HumanInputLLM
 from langchain.llms.llamacpp import LlamaCpp
 from langchain.llms.modal import Modal
@@ -67,6 +68,7 @@ __all__ = [
     "RWKV",
     "PredictionGuard",
     "HumanInputLLM",
+    "HuggingFaceTextGenInference",
 ]
 
 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@@ -99,4 +101,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "stochasticai": StochasticAI,
     "writer": Writer,
     "rwkv": RWKV,
+    "huggingface_textgen_inference": HuggingFaceTextGenInference,
 }
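The `type_to_cls_dict` entry above is what lets serialized LLM configs resolve back to the new class. A minimal sketch of that round trip, assuming `load_llm_from_config` from `langchain.llms.loading` (the loader that looks `_type` up in `type_to_cls_dict`); the config values are illustrative:

```python
from langchain.llms.loading import load_llm_from_config

# Illustrative config; "_type" must match the key registered in
# type_to_cls_dict above. The validator only constructs a client object,
# so the inference server does not need to be running yet (the
# text_generation package must be installed, though).
config = {
    "_type": "huggingface_textgen_inference",
    "inference_server_url": "http://localhost:8010/",
    "max_new_tokens": 512,
}
llm = load_llm_from_config(config)  # -> HuggingFaceTextGenInference instance
```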
diff --git a/langchain/llms/huggingface_text_gen_inference.py b/langchain/llms/huggingface_text_gen_inference.py
new file mode 100644
index 00000000000..a2489865fab
--- /dev/null
+++ b/langchain/llms/huggingface_text_gen_inference.py
@@ -0,0 +1,116 @@
+"""Wrapper around Huggingface text generation inference API."""
+from typing import Any, Dict, List, Optional
+
+from pydantic import Extra, Field, root_validator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+
+
+class HuggingFaceTextGenInference(LLM):
+    """Wrapper around the HuggingFace text generation inference API.
+
+    It is used to generate text from a given prompt.
+
+    Attributes:
+        max_new_tokens: The maximum number of tokens to generate.
+        top_k: The number of top-k tokens to consider when generating text.
+        top_p: The cumulative probability threshold for generating text.
+        typical_p: The typical probability threshold for generating text.
+        temperature: The temperature to use when generating text.
+        repetition_penalty: The repetition penalty to use when generating text.
+        stop_sequences: A list of stop sequences to use when generating text.
+        seed: The seed to use when generating text.
+        inference_server_url: The URL of the inference server to use.
+        timeout: The timeout value in seconds to use while connecting to the
+            inference server.
+        client: The client object used to communicate with the inference server.
+
+    Methods:
+        _call: Generates text based on a given prompt and stop sequences.
+        _llm_type: Returns the type of LLM.
+
+    Example:
+        .. code-block:: python
+
+            llm = HuggingFaceTextGenInference(
+                inference_server_url="http://localhost:8010/",
+                max_new_tokens=512,
+                top_k=10,
+                top_p=0.95,
+                typical_p=0.95,
+                temperature=0.01,
+                repetition_penalty=1.03,
+            )
+    """
+
+    max_new_tokens: int = 512
+    top_k: Optional[int] = None
+    top_p: Optional[float] = 0.95
+    typical_p: Optional[float] = 0.95
+    temperature: float = 0.8
+    repetition_penalty: Optional[float] = None
+    stop_sequences: List[str] = Field(default_factory=list)
+    seed: Optional[int] = None
+    inference_server_url: str = ""
+    timeout: int = 120
+    client: Any
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the text_generation python package is installed."""
+        try:
+            import text_generation
+
+            values["client"] = text_generation.Client(
+                values["inference_server_url"], timeout=values["timeout"]
+            )
+        except ImportError:
+            raise ValueError(
+                "Could not import text_generation python package. "
+                "Please install it with `pip install text_generation`."
+            )
+        return values
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        # must match the key registered in langchain.llms.type_to_cls_dict
+        # so that serialized configs round-trip back to this class
+        return "huggingface_textgen_inference"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
+        if stop is None:
+            stop = self.stop_sequences
+        else:
+            # concatenate rather than `stop += ...` so the caller's
+            # list is not mutated
+            stop = stop + self.stop_sequences
+
+        res = self.client.generate(
+            prompt,
+            stop_sequences=stop,
+            max_new_tokens=self.max_new_tokens,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            typical_p=self.typical_p,
+            temperature=self.temperature,
+            repetition_penalty=self.repetition_penalty,
+            seed=self.seed,
+        )
+        # truncate the generation at the first stop sequence it contains
+        for stop_seq in stop:
+            if stop_seq in res.generated_text:
+                res.generated_text = res.generated_text[
+                    : res.generated_text.index(stop_seq)
+                ]
+
+        return res.generated_text
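End to end, the pieces above compose like this. A hedged usage sketch, not part of the diff: it assumes a text-generation-inference server is already listening at `http://localhost:8010/`, that the `text_generation` package is installed, and the prompt template and stop sequence are illustrative.

```python
from langchain import LLMChain, PromptTemplate
from langchain.llms import HuggingFaceTextGenInference

llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8010/",
    max_new_tokens=512,
    temperature=0.01,
    # _call appends these to any per-call stop list and truncates the
    # generation at the first one it finds
    stop_sequences=["\nUser:"],
)

prompt = PromptTemplate(
    input_variables=["question"],
    template="Answer the question concisely.\n\nQuestion: {question}\nAnswer:",
)
chain = LLMChain(llm=llm, prompt=prompt)
print(chain.run(question="What did foo say about bar?"))
```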