
Add summarization task type for HuggingFace APIs ()

# Add summarization task type for HuggingFace APIs

Add the summarization task type for the HuggingFace APIs.
This task type is described in the [HuggingFace Inference
API documentation](https://huggingface.co/docs/api-inference/detailed_parameters#summarization-task).

My project uses LangChain to connect multiple LLMs, including several
HuggingFace models that support the summarization task, so having this
task type supported directly is very convenient.
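
For reference, here is a minimal sketch of calling a summarization model through the new task type with `HuggingFaceHub` (the model id matches the one used in the integration tests; it assumes a valid `HUGGINGFACEHUB_API_TOKEN` is set, and the input text is purely illustrative):

```python
# Minimal sketch, assuming HUGGINGFACEHUB_API_TOKEN is set in the environment.
from langchain.llms import HuggingFaceHub

llm = HuggingFaceHub(
    repo_id="facebook/bart-large-cnn",  # summarization model, as in the tests
    task="summarization",
)

# Illustrative input; any longer passage of text works the same way.
article = (
    "LangChain is a framework for developing applications powered by "
    "language models. It provides integrations with many model providers, "
    "including the HuggingFace Hub, Inference Endpoints, and local "
    "transformers pipelines."
)

summary = llm(article)  # the integration now reads `summary_text` from the API response
print(summary)
```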

Fixes 
whuwxl authored on 2023-05-16 07:26:17 +08:00, committed by GitHub
parent 580861e7f2
commit 3f0357f94a
8 changed files with 62 additions and 12 deletions

@@ -9,7 +9,7 @@ from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 class HuggingFaceEndpoint(LLM):
@@ -37,7 +37,8 @@ class HuggingFaceEndpoint(LLM):
     endpoint_url: str = ""
     """Endpoint URL to use."""
     task: Optional[str] = None
-    """Task to call the model with. Should be a task that returns `generated_text`."""
+    """Task to call the model with.
+    Should be a task that returns `generated_text` or `summary_text`."""
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""
@@ -138,6 +139,8 @@ class HuggingFaceEndpoint(LLM):
             text = generated_text[0]["generated_text"][len(prompt) :]
         elif self.task == "text2text-generation":
             text = generated_text[0]["generated_text"]
+        elif self.task == "summarization":
+            text = generated_text[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.task}, "

@@ -9,7 +9,7 @@ from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 DEFAULT_REPO_ID = "gpt2"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 class HuggingFaceHub(LLM):
@@ -19,7 +19,7 @@ class HuggingFaceHub(LLM):
     environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
     it as a named parameter to the constructor.
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
     Example:
         .. code-block:: python
@@ -32,7 +32,8 @@ class HuggingFaceHub(LLM):
     repo_id: str = DEFAULT_REPO_ID
     """Model name to use."""
     task: Optional[str] = None
-    """Task to call the model with. Should be a task that returns `generated_text`."""
+    """Task to call the model with.
+    Should be a task that returns `generated_text` or `summary_text`."""
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""
@@ -114,6 +115,8 @@ class HuggingFaceHub(LLM):
             text = response[0]["generated_text"][len(prompt) :]
         elif self.client.task == "text2text-generation":
             text = response[0]["generated_text"]
+        elif self.client.task == "summarization":
+            text = response[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.client.task}, "

@@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 logger = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ class HuggingFacePipeline(LLM):
     To use, you should have the ``transformers`` python package installed.
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
     Example using from_model_id:
         .. code-block:: python
@@ -86,7 +86,7 @@ class HuggingFacePipeline(LLM):
         try:
             if task == "text-generation":
                 model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
-            elif task == "text2text-generation":
+            elif task in ("text2text-generation", "summarization"):
                 model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
             else:
                 raise ValueError(
@@ -162,6 +162,8 @@ class HuggingFacePipeline(LLM):
             text = response[0]["generated_text"][len(prompt) :]
         elif self.pipeline.task == "text2text-generation":
             text = response[0]["generated_text"]
+        elif self.pipeline.task == "summarization":
+            text = response[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.pipeline.task}, "

@@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 logger = logging.getLogger(__name__)
@@ -35,6 +35,8 @@ def _generate_text(
         text = response[0]["generated_text"][len(prompt) :]
     elif pipeline.task == "text2text-generation":
         text = response[0]["generated_text"]
+    elif pipeline.task == "summarization":
+        text = response[0]["summary_text"]
     else:
         raise ValueError(
             f"Got invalid task {pipeline.task}, "
@@ -64,7 +66,7 @@ def _load_transformer(
     try:
         if task == "text-generation":
             model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
-        elif task == "text2text-generation":
+        elif task in ("text2text-generation", "summarization"):
             model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
         else:
             raise ValueError(
@@ -119,7 +121,7 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
     To use, you should have the ``runhouse`` python package installed.
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
     Example using from_model_id:
         .. code-block:: python
@@ -153,7 +155,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
     model_id: str = DEFAULT_MODEL_ID
     """Hugging Face model_id to load the model."""
     task: str = DEFAULT_TASK
-    """Hugging Face task (either "text-generation" or "text2text-generation")."""
+    """Hugging Face task ("text-generation", "text2text-generation" or
+    "summarization")."""
     device: int = 0
     """Device to use for inference. -1 for CPU, 0 for GPU, 1 for second GPU, etc."""
     model_kwargs: Optional[dict] = None

@@ -33,6 +33,16 @@ def test_huggingface_endpoint_text2text_generation() -> None:
     assert output == "Albany"
+@unittest.skip(
+    "This test requires an inference endpoint. Tested with Hugging Face endpoints"
+)
+def test_huggingface_endpoint_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFaceEndpoint(endpoint_url="", task="summarization")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
 def test_huggingface_endpoint_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
     llm = HuggingFaceEndpoint(model_kwargs={"max_new_tokens": -1})

@@ -23,6 +23,13 @@ def test_huggingface_text2text_generation() -> None:
     assert output == "Albany"
+def test_huggingface_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFaceHub(repo_id="facebook/bart-large-cnn")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
 def test_huggingface_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
     llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})

@@ -27,6 +27,15 @@ def test_huggingface_pipeline_text2text_generation() -> None:
     assert isinstance(output, str)
+def test_huggingface_pipeline_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="facebook/bart-large-cnn", task="summarization"
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an HuggingFaceHub LLM."""
     llm = HuggingFacePipeline.from_model_id(

@@ -43,6 +43,19 @@ def test_self_hosted_huggingface_pipeline_text2text_generation() -> None:
     assert isinstance(output, str)
+def test_self_hosted_huggingface_pipeline_summarization() -> None:
+    """Test valid call to self-hosted HuggingFace summarization model."""
+    gpu = get_remote_instance()
+    llm = SelfHostedHuggingFaceLLM(
+        model_id="facebook/bart-large-cnn",
+        task="summarization",
+        hardware=gpu,
+        model_reqs=model_reqs,
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
 def load_pipeline() -> Any:
     """Load pipeline for testing."""
     model_id = "gpt2"