Batching for hf_pipeline (#10795)

The Hugging Face pipeline LLM in LangChain (used for locally hosted models)
does not support batching. If you send in a batch of prompts, it simply
processes them one at a time using the base implementation of _generate:
https://github.com/docugami/langchain/blob/master/libs/langchain/langchain/llms/base.py#L1004C2-L1004C29

This PR adds support for batching in this pipeline, so that GPUs can be
fully saturated. I updated the accompanying notebook to show GPU batch
inference.
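
For context, here is a minimal caller-side sketch of what this enables, mirroring the updated notebook. The model, device, and batch settings are the ones used in the notebook; the prompt template text below is illustrative, not the notebook's exact template.

from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate

# Build the pipeline with an explicit batch_size so prompts are grouped on the GPU.
gpu_llm = HuggingFacePipeline.from_model_id(
    model_id="bigscience/bloom-1b7",
    task="text-generation",
    device=0,  # -1 for CPU
    batch_size=2,  # adjust based on GPU memory and model size
    model_kwargs={"temperature": 0, "max_length": 64},
)

# Illustrative prompt; the notebook defines its own template.
prompt = PromptTemplate.from_template("Question: {question}\n\nAnswer:")
gpu_chain = prompt | gpu_llm.bind(stop=["\n\n"])

questions = [{"question": f"What is the number {i} in french?"} for i in range(4)]
for answer in gpu_chain.batch(questions):
    print(answer)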

---------

Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
Authored by Taqi Jaffri on 2023-09-25 10:23:11 -07:00; committed via GitHub.
parent aa6e6db8c7
commit b7290f01d8
2 changed files with 92 additions and 36 deletions

Changed file 1 of 2: HuggingFacePipeline example notebook

@@ -46,7 +46,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "id": "165ae236-962a-4763-8052-c4836d78a5d2",
    "metadata": {
     "tags": []
@@ -75,18 +75,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": null,
    "id": "3acf0069",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      " First, we need to understand what is an electroencephalogram. An electroencephalogram is a recording of brain activity. It is a recording of brain activity that is made by placing electrodes on the scalp. The electrodes are placed\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from langchain.prompts import PromptTemplate\n",
     "\n",
@@ -101,6 +93,42 @@
     "\n",
     "print(chain.invoke({\"question\": question}))"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "dbbc3a37",
+   "metadata": {},
+   "source": [
+    "### Batch GPU Inference\n",
+    "\n",
+    "If running on a device with GPU, you can also run inference on the GPU in batch mode."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "097ba62f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gpu_llm = HuggingFacePipeline.from_model_id(\n",
+    "    model_id=\"bigscience/bloom-1b7\",\n",
+    "    task=\"text-generation\",\n",
+    "    device=0,  # -1 for CPU\n",
+    "    batch_size=2,  # adjust as needed based on GPU map and model size.\n",
+    "    model_kwargs={\"temperature\": 0, \"max_length\": 64},\n",
+    ")\n",
+    "\n",
+    "gpu_chain = prompt | gpu_llm.bind(stop=[\"\\n\\n\"])\n",
+    "\n",
+    "questions = []\n",
+    "for i in range(4):\n",
+    "    questions.append({\"question\": f\"What is the number {i} in french?\"})\n",
+    "\n",
+    "answers = gpu_chain.batch(questions)\n",
+    "for answer in answers:\n",
+    "    print(answer)"
+   ]
   }
  ],
  "metadata": {
@@ -119,7 +147,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.2"
+   "version": "3.8.10"
   }
  },
 "nbformat": 4,

Changed file 2 of 2: HuggingFacePipeline LLM implementation

@@ -1,20 +1,24 @@
+from __future__ import annotations
+
 import importlib.util
 import logging
 from typing import Any, List, Mapping, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import Extra
+from langchain.schema import Generation, LLMResult
 
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
 VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+DEFAULT_BATCH_SIZE = 4
 
 logger = logging.getLogger(__name__)
 
 
-class HuggingFacePipeline(LLM):
+class HuggingFacePipeline(BaseLLM):
     """HuggingFace Pipeline API.
 
     To use, you should have the ``transformers`` python package installed.
@@ -52,6 +56,8 @@ class HuggingFacePipeline(LLM):
     """Key word arguments passed to the model."""
     pipeline_kwargs: Optional[dict] = None
     """Key word arguments passed to the pipeline."""
+    batch_size: int = DEFAULT_BATCH_SIZE
+    """Batch size to use when passing multiple documents to generate."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -66,8 +72,9 @@ class HuggingFacePipeline(LLM):
         device: int = -1,
         model_kwargs: Optional[dict] = None,
         pipeline_kwargs: Optional[dict] = None,
+        batch_size: int = DEFAULT_BATCH_SIZE,
         **kwargs: Any,
-    ) -> LLM:
+    ) -> HuggingFacePipeline:
         """Construct the pipeline object from model_id and task."""
         try:
             from transformers import (
@@ -128,6 +135,7 @@ class HuggingFacePipeline(LLM):
             model=model,
             tokenizer=tokenizer,
             device=device,
+            batch_size=batch_size,
             model_kwargs=_model_kwargs,
             **_pipeline_kwargs,
         )
@@ -141,6 +149,7 @@ class HuggingFacePipeline(LLM):
             model_id=model_id,
             model_kwargs=_model_kwargs,
             pipeline_kwargs=_pipeline_kwargs,
+            batch_size=batch_size,
             **kwargs,
         )
 
@@ -157,28 +166,47 @@ class HuggingFacePipeline(LLM):
     def _llm_type(self) -> str:
         return "huggingface_pipeline"
 
-    def _call(
+    def _generate(
         self,
-        prompt: str,
+        prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        response = self.pipeline(prompt)
-        if self.pipeline.task == "text-generation":
-            # Text generation return includes the starter text.
-            text = response[0]["generated_text"][len(prompt) :]
-        elif self.pipeline.task == "text2text-generation":
-            text = response[0]["generated_text"]
-        elif self.pipeline.task == "summarization":
-            text = response[0]["summary_text"]
-        else:
-            raise ValueError(
-                f"Got invalid task {self.pipeline.task}, "
-                f"currently only {VALID_TASKS} are supported"
-            )
-        if stop:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
-            text = enforce_stop_tokens(text, stop)
-        return text
+    ) -> LLMResult:
+        # List to hold all results
+        text_generations: List[str] = []
+
+        for i in range(0, len(prompts), self.batch_size):
+            batch_prompts = prompts[i : i + self.batch_size]
+
+            # Process batch of prompts
+            responses = self.pipeline(batch_prompts)
+
+            # Process each response in the batch
+            for j, response in enumerate(responses):
+                if isinstance(response, list):
+                    # if model returns multiple generations, pick the top one
+                    response = response[0]
+
+                if self.pipeline.task == "text-generation":
+                    # Text generation return includes the starter text
+                    text = response["generated_text"][len(batch_prompts[j]) :]
+                elif self.pipeline.task == "text2text-generation":
+                    text = response["generated_text"]
+                elif self.pipeline.task == "summarization":
+                    text = response["summary_text"]
+                else:
+                    raise ValueError(
+                        f"Got invalid task {self.pipeline.task}, "
+                        f"currently only {VALID_TASKS} are supported"
+                    )
+                if stop:
+                    # Enforce stop tokens
+                    text = enforce_stop_tokens(text, stop)
+
+                # Append the processed text to results
+                text_generations.append(text)
+
+        return LLMResult(
+            generations=[[Generation(text=text)] for text in text_generations]
+        )
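
As a minimal sketch (not part of the diff above) of how callers see this change: a list of prompts passed to generate() is now chunked into pipeline batches of batch_size, and the results come back as a single LLMResult. The model id and prompts below are just placeholders.

from langchain.llms import HuggingFacePipeline

# Placeholder local model; any supported text-generation model works the same way.
llm = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
    batch_size=4,  # prompts are grouped into batches of this size
)

result = llm.generate(["Hello, my name is", "The capital of France is"])
for generations in result.generations:
    # One inner list per input prompt; each holds a single Generation.
    print(generations[0].text)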