add text2text generation (#93)

fixes issue #90
Harrison Chase 2022-11-08 18:08:46 -08:00 committed by GitHub
parent e48e562ea5
commit b9f61390e9
3 changed files with 39 additions and 45 deletions

View File

@@ -10,21 +10,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "\n",
-      "\n",
-      "Justin Beiber was born in New York City on July 1, 1967. He was the son of the late John Beiber and his wife, Mary.\n",
-      "\n",
-      "Justin was raised in a small town in the Bronx, New York. He attended the University of New York at Buffalo, where he majored in English.\n",
-      "\n",
-      "Justin was a member of the New York Giants from 1967 to 1969. He was a member of the New York Giants from 1969 to 1971.\n",
-      "\n",
-      "Justin was a member of the New York Giants from 1971 to 1972. He was a member of the New York Giants from 1972 to 1974.\n",
-      "\n",
-      "Justin was a member of the New York Giants from 1974 to 1975. He was a member of the New York Giants from 1975 to 1977.\n",
-      "\n",
-      "Justin was a member of the New York Giants from 1977 to 1978. He was a member of the New York Giants from 1978 to 1979.\n",
-      "\n",
-      "Justin was a member of the New York Giants from 1979 to\n"
+      "The Seattle Seahawks won the Super Bowl in 2010. Justin Beiber was born in 2010. The\n"
      ]
     }
    ],
@@ -35,7 +21,7 @@
    "\n",
    "Answer: Let's think step by step.\"\"\"\n",
    "prompt = Prompt(template=template, input_variables=[\"question\"])\n",
-   "llm_chain = LLMChain(prompt=prompt, llm=HuggingFaceHub(repo_id=\"gpt2\", temperature=1e-10))\n",
+   "llm_chain = LLMChain(prompt=prompt, llm=HuggingFaceHub(repo_id=\"google/flan-t5-xl\", model_kwargs={\"temperature\":1e-10}))\n",
    "\n",
    "question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
    "\n",
@@ -67,7 +53,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.7.6"
+  "version": "3.8.7"
  }
 },
 "nbformat": 4,

View File

@@ -1,6 +1,6 @@
 """Wrapper around HuggingFace APIs."""
 import os
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Extra, root_validator

@@ -8,6 +8,7 @@ from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens

 DEFAULT_REPO_ID = "gpt2"
+VALID_TASKS = ("text2text-generation", "text-generation")


 class HuggingFaceHub(BaseModel, LLM):
@@ -29,14 +30,10 @@ class HuggingFaceHub(BaseModel, LLM):
     client: Any  #: :meta private:
     repo_id: str = DEFAULT_REPO_ID
     """Model name to use."""
-    temperature: float = 0.7
-    """What sampling temperature to use."""
-    max_new_tokens: int = 200
-    """The maximum number of tokens to generate in the completion."""
-    top_p: int = 1
-    """Total probability mass of tokens to consider at each step."""
-    num_return_sequences: int = 1
-    """How many completions to generate for each prompt."""
+    task: Optional[str] = None
+    """Task to call the model with. Should be a task that returns `generated_text`."""
+    model_kwargs: Optional[dict] = None
+    """Key word arguments to pass to the model."""

     huggingfacehub_api_token: Optional[str] = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

@@ -49,7 +46,6 @@ class HuggingFaceHub(BaseModel, LLM):
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         huggingfacehub_api_token = values.get("huggingfacehub_api_token")
-
         if huggingfacehub_api_token is None or huggingfacehub_api_token == "":
             raise ValueError(
                 "Did not find HuggingFace API token, please add an environment variable"
@@ -60,11 +56,17 @@ class HuggingFaceHub(BaseModel, LLM):
             from huggingface_hub.inference_api import InferenceApi

             repo_id = values.get("repo_id", DEFAULT_REPO_ID)
-            values["client"] = InferenceApi(
+            client = InferenceApi(
                 repo_id=repo_id,
                 token=huggingfacehub_api_token,
-                task="text-generation",
+                task=values.get("task"),
             )
+            if client.task not in VALID_TASKS:
+                raise ValueError(
+                    f"Got invalid task {client.task}, "
+                    f"currently only {VALID_TASKS} are supported"
+                )
+            values["client"] = client
         except ImportError:
             raise ValueError(
                 "Could not import huggingface_hub python package. "
@@ -72,16 +74,6 @@ class HuggingFaceHub(BaseModel, LLM):
             )
         return values

-    @property
-    def _default_params(self) -> Mapping[str, Any]:
-        """Get the default parameters for calling HuggingFace Hub API."""
-        return {
-            "temperature": self.temperature,
-            "max_new_tokens": self.max_new_tokens,
-            "top_p": self.top_p,
-            "num_return_sequences": self.num_return_sequences,
-        }
-
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Call out to HuggingFace Hub's inference endpoint.

@@ -97,10 +89,19 @@ class HuggingFaceHub(BaseModel, LLM):
                 response = hf("Tell me a joke.")
         """
-        response = self.client(inputs=prompt, params=self._default_params)
+        response = self.client(inputs=prompt, params=self.model_kwargs)
         if "error" in response:
             raise ValueError(f"Error raised by inference API: {response['error']}")
-        text = response[0]["generated_text"][len(prompt) :]
+        if self.client.task == "text-generation":
+            # Text generation return includes the starter text.
+            text = response[0]["generated_text"][len(prompt) :]
+        elif self.client.task == "text2text-generation":
+            text = response[0]["generated_text"]
+        else:
+            raise ValueError(
+                f"Got invalid task {self.client.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
         if stop is not None:
             # This is a bit hacky, but I can't figure out a better way to enforce
             # stop tokens when making calls to huggingface_hub.
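
Taken together, the reworked wrapper has a small surface: repo_id selects the model, task optionally pins the pipeline (otherwise the Inference API resolves it from the repo, which is what the client.task check above relies on), and model_kwargs is forwarded untouched as the request parameters. A minimal sketch of the two supported code paths, using only repo ids and kwargs that appear in this commit's notebook and tests (it requires HUGGINGFACEHUB_API_TOKEN to be set, per the validator above):

# Minimal usage sketch based on the diff above.
from langchain.llms.huggingface_hub import HuggingFaceHub

# text-generation: the API echoes the prompt, so __call__ slices it off.
gpt2 = HuggingFaceHub(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
print(gpt2("Say foo:"))

# text2text-generation: the response is the completion alone and is
# returned as-is. The task is resolved from the repo when not passed.
flan = HuggingFaceHub(repo_id="google/flan-t5-xl")
print(flan("The capital of New York is"))  # expected: "Albany"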

View File

@@ -5,15 +5,22 @@ import pytest
 from langchain.llms.huggingface_hub import HuggingFaceHub


-def test_huggingface_call() -> None:
-    """Test valid call to HuggingFace."""
-    llm = HuggingFaceHub(max_new_tokens=10)
+def test_huggingface_text_generation() -> None:
+    """Test valid call to HuggingFace text generation model."""
+    llm = HuggingFaceHub(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
     output = llm("Say foo:")
     assert isinstance(output, str)


+def test_huggingface_text2text_generation() -> None:
+    """Test valid call to HuggingFace text2text model."""
+    llm = HuggingFaceHub(repo_id="google/flan-t5-xl")
+    output = llm("The capital of New York is")
+    assert output == "Albany"
+
+
 def test_huggingface_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
-    llm = HuggingFaceHub(max_new_tokens=-1)
+    llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})
     with pytest.raises(ValueError):
         llm("Say foo:")