Compare commits


5 Commits

Author          SHA1        Message                                            Date
Harrison Chase  6c9133e851  cr                                                 2023-03-10 07:14:43 -08:00
Harrison Chase  51d81b1ce6  Merge branch 'master' into harrison/inference-api  2023-03-10 07:14:25 -08:00
Harrison Chase  daff73b3f5  cr                                                 2023-03-08 22:49:27 -08:00
Harrison Chase  cb0c1b2eaa  cr                                                 2023-03-08 22:15:11 -08:00
Harrison Chase  00c9fd239d  inference api and streaming                        2023-03-08 22:15:00 -08:00
4 changed files with 217 additions and 1 deletion

File 1 of 4: new example notebook (Hugging Face Text Generation)

@@ -0,0 +1,143 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "16ae872d",
   "metadata": {},
   "source": [
    "# Hugging Face Text Generation\n",
    "\n",
    "This notebook shows how to use the Hugging Face Text Generation endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aed972a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms import HFTextGeneration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "09e693b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = \"...\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6ce9486b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HFTextGeneration(repo_id=\"google/flan-ul2\", model_kwargs={\"temperature\":0.1, \"max_new_tokens\":200})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2af51b5e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man\""
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(\"write me a song about music\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d47e1e10",
   "metadata": {},
   "source": [
    "## Streaming\n",
    "This section shows how to use streaming with the HF text generation endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3ffe8d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.callbacks.base import CallbackManager\n",
    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5a84bdf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HFTextGeneration(repo_id=\"bigscience/bloomz\", streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6d460fe9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "? Rayleigh scattering"
     ]
    }
   ],
   "source": [
    "resp = model(\"why is the sky blue\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ed74872",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
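The wrapper class introduced in this diff is a thin layer over the text_generation client, so the notebook's blocking and streaming cells reduce to the direct calls below. A minimal sketch mirroring the _call implementation in the last file of this diff; it assumes a valid HUGGINGFACEHUB_API_TOKEN is set and reuses the notebook's repos and parameters.

# Sketch: what HFTextGeneration does under the hood (see the new module
# at the end of this diff). Assumes HUGGINGFACEHUB_API_TOKEN is set.
import os

from text_generation import InferenceAPIClient

client = InferenceAPIClient(
    repo_id="google/flan-ul2",
    token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
)

# Blocking path: one request, whole completion returned at once.
response = client.generate(
    "write me a song about music", temperature=0.1, max_new_tokens=200
)
print(response.generated_text)

# Streaming path: tokens arrive one by one; special tokens are skipped,
# exactly as in HFTextGeneration._call.
for stream_response in client.generate_stream(
    "why is the sky blue", max_new_tokens=200
):
    if not stream_response.token.special:
        print(stream_response.token.text, end="", flush=True)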

File 2 of 4: existing example notebook, metadata-only change (recorded kernel version)

@@ -213,7 +213,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.9" "version": "3.9.1"
} }
}, },
"nbformat": 4, "nbformat": 4,

File 3 of 4: langchain/llms/__init__.py

@@ -14,6 +14,7 @@ from langchain.llms.gooseai import GooseAI
 from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
 from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain.llms.huggingface_text_generation import HFTextGeneration
 from langchain.llms.modal import Modal
 from langchain.llms.nlpcloud import NLPCloud
 from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat
@@ -41,6 +42,7 @@ __all__ = [
     "HuggingFaceEndpoint",
     "HuggingFaceHub",
     "HuggingFacePipeline",
+    "HFTextGeneration",
     "AI21",
     "AzureOpenAI",
     "SelfHostedPipeline",

File 4 of 4: langchain/llms/huggingface_text_generation.py (new)

@@ -0,0 +1,71 @@
from typing import Any, Dict, List, Optional

from pydantic import root_validator, Field

from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env


class HFTextGeneration(LLM):
    """LLM wrapper around the Hugging Face text-generation Inference API."""

    repo_id: str
    token: str
    streaming: bool = False
    model_kwargs: dict = Field(default_factory=dict)

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in the environment."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        try:
            from text_generation import InferenceAPIClient  # noqa: F401
        except ImportError:
            raise ValueError(
                "Could not import text_generation python package. "
                "Please install it with `pip install text_generation`."
            )
        values["token"] = huggingfacehub_api_token
        return values

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        from text_generation import InferenceAPIClient

        client = InferenceAPIClient(
            repo_id=self.repo_id,
            token=self.token,
        )
        if not self.streaming:
            return client.generate(prompt, **self.model_kwargs).generated_text
        # Streaming: forward each non-special token to the callback manager
        # as it arrives, while accumulating the full completion.
        text = ""
        for response in client.generate_stream(prompt, **self.model_kwargs):
            if not response.token.special:
                self.callback_manager.on_llm_new_token(
                    response.token.text, verbose=self.verbose
                )
                text += response.token.text
        return text

    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        from text_generation import InferenceAPIAsyncClient

        client = InferenceAPIAsyncClient(
            repo_id=self.repo_id,
            token=self.token,
        )
        if not self.streaming:
            response = await client.generate(prompt, **self.model_kwargs)
            return response.generated_text
        # Async streaming mirrors the sync path above.
        text = ""
        async for response in client.generate_stream(prompt, **self.model_kwargs):
            if not response.token.special:
                await self.callback_manager.on_llm_new_token(
                    response.token.text, verbose=self.verbose
                )
                text += response.token.text
        return text

    @property
    def _llm_type(self) -> str:
        # Left unimplemented in this draft of the PR.
        raise NotImplementedError
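The notebook only exercises the synchronous path; _acall above also wires the wrapper into langchain's async interface. A hedged sketch of that path, assuming the LLM base class in this version routes agenerate() through _acall; that wiring is not part of this diff.

# Hedged sketch of the async path via _acall above. Assumes the LLM base
# class exposes _acall through agenerate() in this version of langchain.
import asyncio

from langchain.llms import HFTextGeneration

async def main() -> None:
    model = HFTextGeneration(
        repo_id="google/flan-ul2",
        model_kwargs={"max_new_tokens": 200},
    )
    result = await model.agenerate(["why is the sky blue"])
    print(result.generations[0][0].text)

asyncio.run(main())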