mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 23:12:38 +00:00
Batching for hf_pipeline (#10795)
The huggingface pipeline in langchain (used for locally hosted models) does not support batching. If you send in a batch of prompts, it just processes them serially using the base implementation of _generate: https://github.com/docugami/langchain/blob/master/libs/langchain/langchain/llms/base.py#L1004C2-L1004C29 This PR adds support for batching in this pipeline, so that GPUs can be fully saturated. I updated the accompanying notebook to show GPU batch inference. --------- Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
This commit is contained in:
@@ -1,20 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import Extra
|
||||
from langchain.schema import Generation, LLMResult
|
||||
|
||||
DEFAULT_MODEL_ID = "gpt2"
|
||||
DEFAULT_TASK = "text-generation"
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
DEFAULT_BATCH_SIZE = 4
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingFacePipeline(LLM):
|
||||
class HuggingFacePipeline(BaseLLM):
|
||||
"""HuggingFace Pipeline API.
|
||||
|
||||
To use, you should have the ``transformers`` python package installed.
|
||||
@@ -52,6 +56,8 @@ class HuggingFacePipeline(LLM):
|
||||
"""Key word arguments passed to the model."""
|
||||
pipeline_kwargs: Optional[dict] = None
|
||||
"""Key word arguments passed to the pipeline."""
|
||||
batch_size: int = DEFAULT_BATCH_SIZE
|
||||
"""Batch size to use when passing multiple documents to generate."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -66,8 +72,9 @@ class HuggingFacePipeline(LLM):
|
||||
device: int = -1,
|
||||
model_kwargs: Optional[dict] = None,
|
||||
pipeline_kwargs: Optional[dict] = None,
|
||||
batch_size: int = DEFAULT_BATCH_SIZE,
|
||||
**kwargs: Any,
|
||||
) -> LLM:
|
||||
) -> HuggingFacePipeline:
|
||||
"""Construct the pipeline object from model_id and task."""
|
||||
try:
|
||||
from transformers import (
|
||||
@@ -128,6 +135,7 @@ class HuggingFacePipeline(LLM):
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
model_kwargs=_model_kwargs,
|
||||
**_pipeline_kwargs,
|
||||
)
|
||||
@@ -141,6 +149,7 @@ class HuggingFacePipeline(LLM):
|
||||
model_id=model_id,
|
||||
model_kwargs=_model_kwargs,
|
||||
pipeline_kwargs=_pipeline_kwargs,
|
||||
batch_size=batch_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -157,28 +166,47 @@ class HuggingFacePipeline(LLM):
|
||||
def _llm_type(self) -> str:
|
||||
return "huggingface_pipeline"
|
||||
|
||||
def _call(
|
||||
def _generate(
|
||||
self,
|
||||
prompt: str,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
response = self.pipeline(prompt)
|
||||
if self.pipeline.task == "text-generation":
|
||||
# Text generation return includes the starter text.
|
||||
text = response[0]["generated_text"][len(prompt) :]
|
||||
elif self.pipeline.task == "text2text-generation":
|
||||
text = response[0]["generated_text"]
|
||||
elif self.pipeline.task == "summarization":
|
||||
text = response[0]["summary_text"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got invalid task {self.pipeline.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
if stop:
|
||||
# This is a bit hacky, but I can't figure out a better way to enforce
|
||||
# stop tokens when making calls to huggingface_hub.
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return text
|
||||
) -> LLMResult:
|
||||
# List to hold all results
|
||||
text_generations: List[str] = []
|
||||
|
||||
for i in range(0, len(prompts), self.batch_size):
|
||||
batch_prompts = prompts[i : i + self.batch_size]
|
||||
|
||||
# Process batch of prompts
|
||||
responses = self.pipeline(batch_prompts)
|
||||
|
||||
# Process each response in the batch
|
||||
for j, response in enumerate(responses):
|
||||
if isinstance(response, list):
|
||||
# if model returns multiple generations, pick the top one
|
||||
response = response[0]
|
||||
|
||||
if self.pipeline.task == "text-generation":
|
||||
# Text generation return includes the starter text
|
||||
text = response["generated_text"][len(batch_prompts[j]) :]
|
||||
elif self.pipeline.task == "text2text-generation":
|
||||
text = response["generated_text"]
|
||||
elif self.pipeline.task == "summarization":
|
||||
text = response["summary_text"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got invalid task {self.pipeline.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
if stop:
|
||||
# Enforce stop tokens
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
# Append the processed text to results
|
||||
text_generations.append(text)
|
||||
|
||||
return LLMResult(
|
||||
generations=[[Generation(text=text)] for text in text_generations]
|
||||
)
|
||||
|
Reference in New Issue
Block a user