mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 05:13:46 +00:00
Batching for hf_pipeline (#10795)
The huggingface pipeline in langchain (used for locally hosted models) does not support batching. If you send in a batch of prompts, it just processes them serially using the base implementation of _generate: https://github.com/docugami/langchain/blob/master/libs/langchain/langchain/llms/base.py#L1004C2-L1004C29 This PR adds support for batching in this pipeline, so that GPUs can be fully saturated. I updated the accompanying notebook to show GPU batch inference. --------- Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
This commit is contained in:
parent
aa6e6db8c7
commit
b7290f01d8
@ -46,7 +46,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": null,
|
||||||
"id": "165ae236-962a-4763-8052-c4836d78a5d2",
|
"id": "165ae236-962a-4763-8052-c4836d78a5d2",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -75,18 +75,10 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": null,
|
||||||
"id": "3acf0069",
|
"id": "3acf0069",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" First, we need to understand what is an electroencephalogram. An electroencephalogram is a recording of brain activity. It is a recording of brain activity that is made by placing electrodes on the scalp. The electrodes are placed\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.prompts import PromptTemplate\n",
|
"from langchain.prompts import PromptTemplate\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -101,6 +93,42 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"print(chain.invoke({\"question\": question}))"
|
"print(chain.invoke({\"question\": question}))"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "dbbc3a37",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Batch GPU Inference\n",
|
||||||
|
"\n",
|
||||||
|
"If running on a device with GPU, you can also run inference on the GPU in batch mode."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "097ba62f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"gpu_llm = HuggingFacePipeline.from_model_id(\n",
|
||||||
|
" model_id=\"bigscience/bloom-1b7\",\n",
|
||||||
|
" task=\"text-generation\",\n",
|
||||||
|
" device=0, # -1 for CPU\n",
|
||||||
|
" batch_size=2, # adjust as needed based on GPU map and model size.\n",
|
||||||
|
" model_kwargs={\"temperature\": 0, \"max_length\": 64},\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"gpu_chain = prompt | gpu_llm.bind(stop=[\"\\n\\n\"])\n",
|
||||||
|
"\n",
|
||||||
|
"questions = []\n",
|
||||||
|
"for i in range(4):\n",
|
||||||
|
" questions.append({\"question\": f\"What is the number {i} in french?\"})\n",
|
||||||
|
"\n",
|
||||||
|
"answers = gpu_chain.batch(questions)\n",
|
||||||
|
"for answer in answers:\n",
|
||||||
|
" print(answer)"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -119,7 +147,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.2"
|
"version": "3.8.10"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -1,20 +1,24 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, List, Mapping, Optional
|
from typing import Any, List, Mapping, Optional
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.llms.utils import enforce_stop_tokens
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
from langchain.pydantic_v1 import Extra
|
from langchain.pydantic_v1 import Extra
|
||||||
|
from langchain.schema import Generation, LLMResult
|
||||||
|
|
||||||
DEFAULT_MODEL_ID = "gpt2"
|
DEFAULT_MODEL_ID = "gpt2"
|
||||||
DEFAULT_TASK = "text-generation"
|
DEFAULT_TASK = "text-generation"
|
||||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||||
|
DEFAULT_BATCH_SIZE = 4
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HuggingFacePipeline(LLM):
|
class HuggingFacePipeline(BaseLLM):
|
||||||
"""HuggingFace Pipeline API.
|
"""HuggingFace Pipeline API.
|
||||||
|
|
||||||
To use, you should have the ``transformers`` python package installed.
|
To use, you should have the ``transformers`` python package installed.
|
||||||
@ -52,6 +56,8 @@ class HuggingFacePipeline(LLM):
|
|||||||
"""Key word arguments passed to the model."""
|
"""Key word arguments passed to the model."""
|
||||||
pipeline_kwargs: Optional[dict] = None
|
pipeline_kwargs: Optional[dict] = None
|
||||||
"""Key word arguments passed to the pipeline."""
|
"""Key word arguments passed to the pipeline."""
|
||||||
|
batch_size: int = DEFAULT_BATCH_SIZE
|
||||||
|
"""Batch size to use when passing multiple documents to generate."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -66,8 +72,9 @@ class HuggingFacePipeline(LLM):
|
|||||||
device: int = -1,
|
device: int = -1,
|
||||||
model_kwargs: Optional[dict] = None,
|
model_kwargs: Optional[dict] = None,
|
||||||
pipeline_kwargs: Optional[dict] = None,
|
pipeline_kwargs: Optional[dict] = None,
|
||||||
|
batch_size: int = DEFAULT_BATCH_SIZE,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> LLM:
|
) -> HuggingFacePipeline:
|
||||||
"""Construct the pipeline object from model_id and task."""
|
"""Construct the pipeline object from model_id and task."""
|
||||||
try:
|
try:
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@ -128,6 +135,7 @@ class HuggingFacePipeline(LLM):
|
|||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
device=device,
|
device=device,
|
||||||
|
batch_size=batch_size,
|
||||||
model_kwargs=_model_kwargs,
|
model_kwargs=_model_kwargs,
|
||||||
**_pipeline_kwargs,
|
**_pipeline_kwargs,
|
||||||
)
|
)
|
||||||
@ -141,6 +149,7 @@ class HuggingFacePipeline(LLM):
|
|||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
model_kwargs=_model_kwargs,
|
model_kwargs=_model_kwargs,
|
||||||
pipeline_kwargs=_pipeline_kwargs,
|
pipeline_kwargs=_pipeline_kwargs,
|
||||||
|
batch_size=batch_size,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -157,28 +166,47 @@ class HuggingFacePipeline(LLM):
|
|||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "huggingface_pipeline"
|
return "huggingface_pipeline"
|
||||||
|
|
||||||
def _call(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompts: List[str],
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> LLMResult:
|
||||||
response = self.pipeline(prompt)
|
# List to hold all results
|
||||||
|
text_generations: List[str] = []
|
||||||
|
|
||||||
|
for i in range(0, len(prompts), self.batch_size):
|
||||||
|
batch_prompts = prompts[i : i + self.batch_size]
|
||||||
|
|
||||||
|
# Process batch of prompts
|
||||||
|
responses = self.pipeline(batch_prompts)
|
||||||
|
|
||||||
|
# Process each response in the batch
|
||||||
|
for j, response in enumerate(responses):
|
||||||
|
if isinstance(response, list):
|
||||||
|
# if model returns multiple generations, pick the top one
|
||||||
|
response = response[0]
|
||||||
|
|
||||||
if self.pipeline.task == "text-generation":
|
if self.pipeline.task == "text-generation":
|
||||||
# Text generation return includes the starter text.
|
# Text generation return includes the starter text
|
||||||
text = response[0]["generated_text"][len(prompt) :]
|
text = response["generated_text"][len(batch_prompts[j]) :]
|
||||||
elif self.pipeline.task == "text2text-generation":
|
elif self.pipeline.task == "text2text-generation":
|
||||||
text = response[0]["generated_text"]
|
text = response["generated_text"]
|
||||||
elif self.pipeline.task == "summarization":
|
elif self.pipeline.task == "summarization":
|
||||||
text = response[0]["summary_text"]
|
text = response["summary_text"]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Got invalid task {self.pipeline.task}, "
|
f"Got invalid task {self.pipeline.task}, "
|
||||||
f"currently only {VALID_TASKS} are supported"
|
f"currently only {VALID_TASKS} are supported"
|
||||||
)
|
)
|
||||||
if stop:
|
if stop:
|
||||||
# This is a bit hacky, but I can't figure out a better way to enforce
|
# Enforce stop tokens
|
||||||
# stop tokens when making calls to huggingface_hub.
|
|
||||||
text = enforce_stop_tokens(text, stop)
|
text = enforce_stop_tokens(text, stop)
|
||||||
return text
|
|
||||||
|
# Append the processed text to results
|
||||||
|
text_generations.append(text)
|
||||||
|
|
||||||
|
return LLMResult(
|
||||||
|
generations=[[Generation(text=text)] for text in text_generations]
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user