Mirror of https://github.com/hwchase17/langchain.git, synced 2026-02-21 06:33:41 +00:00
Compare commits (5 commits): langchain-...harrison/i...
| Author | SHA1 | Date |
|---|---|---|
| | 6c9133e851 | |
| | 51d81b1ce6 | |
| | daff73b3f5 | |
| | cb0c1b2eaa | |
| | 00c9fd239d | |
143 docs/modules/llms/integrations/hf_text_generation.ipynb Normal file
@@ -0,0 +1,143 @@
```json
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "16ae872d",
   "metadata": {},
   "source": [
    "# Hugging Face Text Generation\n",
    "\n",
    "This notebook shows how to use the Hugging Face Text Generation endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aed972a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms import HFTextGeneration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "09e693b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = \"...\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6ce9486b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HFTextGeneration(repo_id=\"google/flan-ul2\", model_kwargs={\"temperature\":0.1, \"max_new_tokens\":200})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2af51b5e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man\""
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(\"write me a song about music\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d47e1e10",
   "metadata": {},
   "source": [
    "## Streaming\n",
    "This section shows how to use streaming with the HF text generation endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3ffe8d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.callbacks.base import CallbackManager\n",
    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5a84bdf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HFTextGeneration(repo_id=\"bigscience/bloomz\", streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6d460fe9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "? Rayleigh scattering"
     ]
    }
   ],
   "source": [
    "resp = model(\"why is the sky blue\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ed74872",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
```
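For reference, here is a minimal sketch (not part of the PR) of the call the notebook's wrapper makes under the hood, based on the `HFTextGeneration` source further down. It assumes the `text_generation` package is installed (`pip install text_generation`) and a valid `HUGGINGFACEHUB_API_TOKEN` is set; the model and kwargs are copied from the notebook's third cell.

```python
# Direct use of the Inference API client that HFTextGeneration delegates to.
import os

from text_generation import InferenceAPIClient

client = InferenceAPIClient(
    repo_id="google/flan-ul2",
    token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
)
response = client.generate(
    "write me a song about music",
    temperature=0.1,
    max_new_tokens=200,
)
print(response.generated_text)
```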
Kernel metadata change in an existing notebook (file name not shown in this view):

```diff
@@ -213,7 +213,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.9"
+   "version": "3.9.1"
   }
  },
  "nbformat": 4,
```
langchain/llms/__init__.py

```diff
@@ -14,6 +14,7 @@ from langchain.llms.gooseai import GooseAI
 from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
 from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain.llms.huggingface_text_generation import HFTextGeneration
 from langchain.llms.modal import Modal
 from langchain.llms.nlpcloud import NLPCloud
 from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat
@@ -41,6 +42,7 @@ __all__ = [
     "HuggingFaceEndpoint",
     "HuggingFaceHub",
     "HuggingFacePipeline",
+    "HFTextGeneration",
     "AI21",
     "AzureOpenAI",
     "SelfHostedPipeline",
```
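With the import and `__all__` entry in place, the new class is exposed at the package root, which is what the notebook's first cell relies on:

```python
# Sanity check of the package-level export (assumes this PR's branch is installed).
from langchain.llms import HFTextGeneration

print(HFTextGeneration)
```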
71 langchain/llms/huggingface_text_generation.py Normal file
@@ -0,0 +1,71 @@
```python
from typing import Any, Dict, List, Optional

from pydantic import root_validator, Field

from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env


class HFTextGeneration(LLM):
    repo_id: str
    token: str
    streaming: bool = False
    model_kwargs: dict = Field(default_factory=dict)

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in the environment."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        try:
            from text_generation import InferenceAPIClient
        except ImportError:
            raise ValueError(
                "Could not import text_generation python package. "
                "Please install it with `pip install text_generation`."
            )
        values["token"] = huggingfacehub_api_token
        return values

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        from text_generation import InferenceAPIClient

        client = InferenceAPIClient(
            repo_id=self.repo_id,
            token=self.token,
        )
        if not self.streaming:
            return client.generate(prompt, **self.model_kwargs).generated_text
        if self.streaming:
            text = ""
            for response in client.generate_stream(prompt, **self.model_kwargs):
                if not response.token.special:
                    self.callback_manager.on_llm_new_token(
                        response.token.text, verbose=self.verbose
                    )
                    text += response.token.text
            return text

    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        from text_generation import InferenceAPIAsyncClient

        client = InferenceAPIAsyncClient(
            repo_id=self.repo_id,
            token=self.token,
        )
        if not self.streaming:
            response = await client.generate(prompt, **self.model_kwargs)
            return response.generated_text
        if self.streaming:
            text = ""
            async for response in client.generate_stream(prompt, **self.model_kwargs):
                if not response.token.special:
                    await self.callback_manager.on_llm_new_token(
                        response.token.text, verbose=self.verbose
                    )
                    text += response.token.text
            return text

    @property
    def _llm_type(self) -> str:
        raise NotImplementedError
```
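As a usage note, here is a minimal sketch (not part of the PR, and assuming the same `text_generation` client API the file above relies on) of driving the streaming path directly with the async client; the model name and kwargs are illustrative:

```python
# Stream tokens from the Hugging Face Inference API the same way _acall does,
# but without the LangChain wrapper.
import asyncio
import os

from text_generation import InferenceAPIAsyncClient


async def main() -> None:
    client = InferenceAPIAsyncClient(
        repo_id="bigscience/bloomz",
        token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
    )
    text = ""
    async for response in client.generate_stream(
        "why is the sky blue", max_new_tokens=200
    ):
        if not response.token.special:  # skip EOS/control tokens
            text += response.token.text
    print(text)


asyncio.run(main())
```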