Batching for hf_pipeline (#10795)

The HuggingFace pipeline LLM in LangChain (used for locally hosted models) does not support batching: if you send in a batch of prompts, they are simply processed serially by the base implementation of `_generate`:
https://github.com/docugami/langchain/blob/master/libs/langchain/langchain/llms/base.py#L1004C2-L1004C29

This PR adds support for batching in this pipeline, so that GPUs can be
fully saturated. I updated the accompanying notebook to show GPU batch
inference.
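
For reference, a minimal usage sketch of the batched path (not taken from the PR or the notebook; the model id, device, prompt list, and `max_new_tokens` value are illustrative assumptions):

```python
# Hypothetical example: construct the pipeline with the new batch_size
# parameter and run a batch of prompts through generate(). Model id,
# device, prompts, and pipeline kwargs are illustrative only.
from langchain.llms.huggingface_pipeline import HuggingFacePipeline

llm = HuggingFacePipeline.from_model_id(
    model_id="google/flan-t5-small",        # any locally hosted model works
    task="text2text-generation",
    device=0,                               # first CUDA device; use -1 for CPU
    batch_size=8,                           # new in this PR: prompts per pipeline call
    pipeline_kwargs={"max_new_tokens": 32},
)

prompts = [f"Answer briefly: what is {i} + {i}?" for i in range(16)]
result = llm.generate(prompts)              # runs in batches of 8, not one-by-one
print(result.generations[0][0].text)
```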

---------

Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
Commit b7290f01d8 (parent aa6e6db8c7) by Taqi Jaffri, 2023-09-25 10:23:11 -07:00, committed by GitHub. 2 changed files with 92 additions and 36 deletions.


@@ -1,20 +1,24 @@
 from __future__ import annotations
 import importlib.util
 import logging
 from typing import Any, List, Mapping, Optional
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import Extra
+from langchain.schema import Generation, LLMResult
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
 VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+DEFAULT_BATCH_SIZE = 4
 logger = logging.getLogger(__name__)
-class HuggingFacePipeline(LLM):
+class HuggingFacePipeline(BaseLLM):
     """HuggingFace Pipeline API.
     To use, you should have the ``transformers`` python package installed.
@@ -52,6 +56,8 @@ class HuggingFacePipeline(LLM):
"""Key word arguments passed to the model."""
pipeline_kwargs: Optional[dict] = None
"""Key word arguments passed to the pipeline."""
batch_size: int = DEFAULT_BATCH_SIZE
"""Batch size to use when passing multiple documents to generate."""
class Config:
"""Configuration for this pydantic object."""
@@ -66,8 +72,9 @@ class HuggingFacePipeline(LLM):
         device: int = -1,
         model_kwargs: Optional[dict] = None,
         pipeline_kwargs: Optional[dict] = None,
+        batch_size: int = DEFAULT_BATCH_SIZE,
         **kwargs: Any,
-    ) -> LLM:
+    ) -> HuggingFacePipeline:
         """Construct the pipeline object from model_id and task."""
         try:
             from transformers import (
@@ -128,6 +135,7 @@ class HuggingFacePipeline(LLM):
             model=model,
             tokenizer=tokenizer,
             device=device,
+            batch_size=batch_size,
             model_kwargs=_model_kwargs,
             **_pipeline_kwargs,
         )
@@ -141,6 +149,7 @@ class HuggingFacePipeline(LLM):
             model_id=model_id,
             model_kwargs=_model_kwargs,
             pipeline_kwargs=_pipeline_kwargs,
+            batch_size=batch_size,
             **kwargs,
         )
@@ -157,28 +166,47 @@ class HuggingFacePipeline(LLM):
     def _llm_type(self) -> str:
         return "huggingface_pipeline"
-    def _call(
+    def _generate(
         self,
-        prompt: str,
+        prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        response = self.pipeline(prompt)
-        if self.pipeline.task == "text-generation":
-            # Text generation return includes the starter text.
-            text = response[0]["generated_text"][len(prompt) :]
-        elif self.pipeline.task == "text2text-generation":
-            text = response[0]["generated_text"]
-        elif self.pipeline.task == "summarization":
-            text = response[0]["summary_text"]
-        else:
-            raise ValueError(
-                f"Got invalid task {self.pipeline.task}, "
-                f"currently only {VALID_TASKS} are supported"
-            )
-        if stop:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
-            text = enforce_stop_tokens(text, stop)
-        return text
+    ) -> LLMResult:
+        # List to hold all results
+        text_generations: List[str] = []
+        for i in range(0, len(prompts), self.batch_size):
+            batch_prompts = prompts[i : i + self.batch_size]
+            # Process batch of prompts
+            responses = self.pipeline(batch_prompts)
+            # Process each response in the batch
+            for j, response in enumerate(responses):
+                if isinstance(response, list):
+                    # if model returns multiple generations, pick the top one
+                    response = response[0]
+                if self.pipeline.task == "text-generation":
+                    # Text generation return includes the starter text
+                    text = response["generated_text"][len(batch_prompts[j]) :]
+                elif self.pipeline.task == "text2text-generation":
+                    text = response["generated_text"]
+                elif self.pipeline.task == "summarization":
+                    text = response["summary_text"]
+                else:
+                    raise ValueError(
+                        f"Got invalid task {self.pipeline.task}, "
+                        f"currently only {VALID_TASKS} are supported"
+                    )
+                if stop:
+                    # Enforce stop tokens
+                    text = enforce_stop_tokens(text, stop)
+                # Append the processed text to results
+                text_generations.append(text)
+        return LLMResult(
+            generations=[[Generation(text=text)] for text in text_generations]
+        )
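
Worth noting about the new `_generate`: `batch_size` now plays two roles. It is forwarded to the `transformers` pipeline (hunk at `@@ -128,6 +135,7 @@` above), so the underlying pipeline can pad and run each chunk through the model together, and it is used here to chunk the incoming prompt list before each pipeline call. The returned `LLMResult` stays aligned with the input order: one inner list per prompt, each holding a single `Generation`. A small consumption sketch (model id and prompts are illustrative assumptions, not from the PR):

```python
# Hypothetical consumption sketch: LLMResult.generations is aligned with the
# input prompts, one inner list per prompt, each holding one Generation
# produced by the batched pipeline.
from langchain.llms.huggingface_pipeline import HuggingFacePipeline

llm = HuggingFacePipeline.from_model_id(
    model_id="google/flan-t5-small",  # illustrative local model
    task="text2text-generation",
    batch_size=4,
)

prompts = ["Translate to German: good morning", "Translate to German: thank you"]
result = llm.generate(prompts)

for prompt, gens in zip(prompts, result.generations):
    print(f"{prompt!r} -> {gens[0].text!r}")
```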