Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-04 20:28:10 +00:00)

parent e48e562ea5
commit b9f61390e9
@@ -10,21 +10,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "\n",
-      "\n",
-      "Justin Beiber was born in New York City on July 1, 1967. He was the son of the late John Beiber and his wife, Mary.\n",
-      "\n",
-      "Justin was raised in a small town in the Bronx, New York. He attended the University of New York at Buffalo, where he majored in English.\n",
-      "\n",
-      "Justin was a member of the New York Giants from 1967 to 1969. He was a member of the New York Giants from 1969 to 1971.\n",
-      "\n",
-      "Justin was a member of the New York Giants from 1971 to 1972. He was a member of the New York Giants from 1972 to 1974.\n",
-      "\n",
-      "Justin was a member of the New York Giants from 1974 to 1975. He was a member of the New York Giants from 1975 to 1977.\n",
-      "\n",
-      "Justin was a member of the New York Giants from 1977 to 1978. He was a member of the New York Giants from 1978 to 1979.\n",
-      "\n",
-      "Justin was a member of the New York Giants from 1979 to\n"
+      "The Seattle Seahawks won the Super Bowl in 2010. Justin Beiber was born in 2010. The\n"
      ]
     }
    ],
@@ -35,7 +21,7 @@
    "\n",
    "Answer: Let's think step by step.\"\"\"\n",
    "prompt = Prompt(template=template, input_variables=[\"question\"])\n",
-   "llm_chain = LLMChain(prompt=prompt, llm=HuggingFaceHub(repo_id=\"gpt2\", temperature=1e-10))\n",
+   "llm_chain = LLMChain(prompt=prompt, llm=HuggingFaceHub(repo_id=\"google/flan-t5-xl\", model_kwargs={\"temperature\":1e-10}))\n",
    "\n",
    "question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
    "\n",
@@ -67,7 +53,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.7.6"
+  "version": "3.8.7"
  }
 },
 "nbformat": 4,
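For orientation, the notebook cell after this change roughly reconstructs to the sketch below. Only the changed source lines appear in the hunks above, so the imports, the opening of the template, and the final run() call are assumptions based on the langchain API of this era; the prompt variables, repo id, and model_kwargs come straight from the diff.

# Sketch of the updated notebook cell; imports, the start of the template,
# and run() are assumptions, the rest is taken from the diff above.
from langchain import LLMChain, Prompt
from langchain.llms import HuggingFaceHub

template = """Question: {question}

Answer: Let's think step by step."""
prompt = Prompt(template=template, input_variables=["question"])

# Sampling parameters are now forwarded through model_kwargs rather than as
# constructor arguments, and the repo switches from gpt2 to flan-t5-xl.
llm_chain = LLMChain(
    prompt=prompt,
    llm=HuggingFaceHub(repo_id="google/flan-t5-xl", model_kwargs={"temperature": 1e-10}),
)

question = "What NFL team won the Super Bowl in the year Justin Beiber was born?"
print(llm_chain.run(question))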
@@ -1,6 +1,6 @@
 """Wrapper around HuggingFace APIs."""
 import os
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Extra, root_validator
 
@@ -8,6 +8,7 @@ from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 
 DEFAULT_REPO_ID = "gpt2"
+VALID_TASKS = ("text2text-generation", "text-generation")
 
 
 class HuggingFaceHub(BaseModel, LLM):
@@ -29,14 +30,10 @@ class HuggingFaceHub(BaseModel, LLM):
     client: Any  #: :meta private:
     repo_id: str = DEFAULT_REPO_ID
     """Model name to use."""
-    temperature: float = 0.7
-    """What sampling temperature to use."""
-    max_new_tokens: int = 200
-    """The maximum number of tokens to generate in the completion."""
-    top_p: int = 1
-    """Total probability mass of tokens to consider at each step."""
-    num_return_sequences: int = 1
-    """How many completions to generate for each prompt."""
+    task: Optional[str] = None
+    """Task to call the model with. Should be a task that returns `generated_text`."""
+    model_kwargs: Optional[dict] = None
+    """Key word arguments to pass to the model."""
 
     huggingfacehub_api_token: Optional[str] = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
 
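In short, per-parameter fields on the wrapper give way to a single pass-through dict. A minimal before/after sketch (illustrative only, not part of the diff; running it requires HUGGINGFACEHUB_API_TOKEN to be set):

from langchain.llms.huggingface_hub import HuggingFaceHub

# Before this change, sampling settings were dedicated constructor fields:
#     HuggingFaceHub(repo_id="gpt2", temperature=0.7, max_new_tokens=200)
# After it, anything the Hub inference API understands is forwarded via model_kwargs:
hf = HuggingFaceHub(repo_id="gpt2", model_kwargs={"temperature": 0.7, "max_new_tokens": 200})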
@@ -49,7 +46,6 @@ class HuggingFaceHub(BaseModel, LLM):
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         huggingfacehub_api_token = values.get("huggingfacehub_api_token")
-
         if huggingfacehub_api_token is None or huggingfacehub_api_token == "":
             raise ValueError(
                 "Did not find HuggingFace API token, please add an environment variable"
@@ -60,11 +56,17 @@ class HuggingFaceHub(BaseModel, LLM):
             from huggingface_hub.inference_api import InferenceApi
 
             repo_id = values.get("repo_id", DEFAULT_REPO_ID)
-            values["client"] = InferenceApi(
+            client = InferenceApi(
                 repo_id=repo_id,
                 token=huggingfacehub_api_token,
-                task="text-generation",
+                task=values.get("task"),
             )
+            if client.task not in VALID_TASKS:
+                raise ValueError(
+                    f"Got invalid task {client.task}, "
+                    f"currently only {VALID_TASKS} are supported"
+                )
+            values["client"] = client
         except ImportError:
             raise ValueError(
                 "Could not import huggingface_hub python package. "
@@ -72,16 +74,6 @@ class HuggingFaceHub(BaseModel, LLM):
             )
         return values
 
-    @property
-    def _default_params(self) -> Mapping[str, Any]:
-        """Get the default parameters for calling HuggingFace Hub API."""
-        return {
-            "temperature": self.temperature,
-            "max_new_tokens": self.max_new_tokens,
-            "top_p": self.top_p,
-            "num_return_sequences": self.num_return_sequences,
-        }
-
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Call out to HuggingFace Hub's inference endpoint.
 
@@ -97,10 +89,19 @@ class HuggingFaceHub(BaseModel, LLM):
 
                 response = hf("Tell me a joke.")
         """
-        response = self.client(inputs=prompt, params=self._default_params)
+        response = self.client(inputs=prompt, params=self.model_kwargs)
         if "error" in response:
             raise ValueError(f"Error raised by inference API: {response['error']}")
-        text = response[0]["generated_text"][len(prompt) :]
+        if self.client.task == "text-generation":
+            # Text generation return includes the starter text.
+            text = response[0]["generated_text"][len(prompt) :]
+        elif self.client.task == "text2text-generation":
+            text = response[0]["generated_text"]
+        else:
+            raise ValueError(
+                f"Got invalid task {self.client.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
         if stop is not None:
             # This is a bit hacky, but I can't figure out a better way to enforce
             # stop tokens when making calls to huggingface_hub.
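Taken together, the wrapper now records the Hub task, rejects unsupported ones at construction time, and branches on the task in __call__. A minimal usage sketch, assuming HUGGINGFACEHUB_API_TOKEN is set; the prompts are illustrative, and the two repo ids are the ones used elsewhere in this commit.

from langchain.llms.huggingface_hub import HuggingFaceHub

# text-generation models (e.g. gpt2) echo the prompt, so __call__ slices it
# off response[0]["generated_text"] before applying any stop tokens.
gpt2 = HuggingFaceHub(repo_id="gpt2", model_kwargs={"max_new_tokens": 20})
print(gpt2("Say foo:"))

# text2text-generation models (e.g. flan-t5) return only the completion,
# so generated_text is used as-is.
flan = HuggingFaceHub(repo_id="google/flan-t5-xl")
print(flan("The capital of New York is"))

# A repo whose Hub task falls outside VALID_TASKS now fails fast:
# validate_environment raises ValueError instead of calling the wrong pipeline.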
@@ -5,15 +5,22 @@ import pytest
 from langchain.llms.huggingface_hub import HuggingFaceHub
 
 
-def test_huggingface_call() -> None:
-    """Test valid call to HuggingFace."""
-    llm = HuggingFaceHub(max_new_tokens=10)
+def test_huggingface_text_generation() -> None:
+    """Test valid call to HuggingFace text generation model."""
+    llm = HuggingFaceHub(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
     output = llm("Say foo:")
     assert isinstance(output, str)
 
 
+def test_huggingface_text2text_generation() -> None:
+    """Test valid call to HuggingFace text2text model."""
+    llm = HuggingFaceHub(repo_id="google/flan-t5-xl")
+    output = llm("The capital of New York is")
+    assert output == "Albany"
+
+
 def test_huggingface_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
-    llm = HuggingFaceHub(max_new_tokens=-1)
+    llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})
     with pytest.raises(ValueError):
         llm("Say foo:")