diff --git a/docs/extras/integrations/llms/huggingface_pipelines.ipynb b/docs/extras/integrations/llms/huggingface_pipelines.ipynb
index 387f61e2b4f..3fd8e0a0a6c 100644
--- a/docs/extras/integrations/llms/huggingface_pipelines.ipynb
+++ b/docs/extras/integrations/llms/huggingface_pipelines.ipynb
@@ -46,7 +46,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "id": "165ae236-962a-4763-8052-c4836d78a5d2",
    "metadata": {
     "tags": []
@@ -75,18 +75,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": null,
    "id": "3acf0069",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      " First, we need to understand what is an electroencephalogram. An electroencephalogram is a recording of brain activity. It is a recording of brain activity that is made by placing electrodes on the scalp. The electrodes are placed\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from langchain.prompts import PromptTemplate\n",
     "\n",
@@ -101,6 +93,42 @@
     "\n",
     "print(chain.invoke({\"question\": question}))"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "dbbc3a37",
+   "metadata": {},
+   "source": [
+    "### Batch GPU Inference\n",
+    "\n",
+    "If you are running on a machine with a GPU, you can also run inference on the GPU in batch mode."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "097ba62f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gpu_llm = HuggingFacePipeline.from_model_id(\n",
+    "    model_id=\"bigscience/bloom-1b7\",\n",
+    "    task=\"text-generation\",\n",
+    "    device=0,  # -1 for CPU\n",
+    "    batch_size=2,  # adjust as needed based on GPU memory and model size.\n",
+    "    model_kwargs={\"temperature\": 0, \"max_length\": 64},\n",
+    ")\n",
+    "\n",
+    "gpu_chain = prompt | gpu_llm.bind(stop=[\"\\n\\n\"])\n",
+    "\n",
+    "questions = []\n",
+    "for i in range(4):\n",
+    "    questions.append({\"question\": f\"What is the number {i} in French?\"})\n",
+    "\n",
+    "answers = gpu_chain.batch(questions)\n",
+    "for answer in answers:\n",
+    "    print(answer)"
+   ]
   }
  ],
  "metadata": {
@@ -119,7 +147,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.2"
+   "version": "3.8.10"
   }
  },
  "nbformat": 4,
diff --git a/libs/langchain/langchain/llms/huggingface_pipeline.py b/libs/langchain/langchain/llms/huggingface_pipeline.py
index 9c32787928d..336b61842ed 100644
--- a/libs/langchain/langchain/llms/huggingface_pipeline.py
+++ b/libs/langchain/langchain/llms/huggingface_pipeline.py
@@ -1,20 +1,24 @@
+from __future__ import annotations
+
 import importlib.util
 import logging
 from typing import Any, List, Mapping, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import Extra
+from langchain.schema import Generation, LLMResult
 
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
 VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+DEFAULT_BATCH_SIZE = 4
 
 logger = logging.getLogger(__name__)
 
 
-class HuggingFacePipeline(LLM):
+class HuggingFacePipeline(BaseLLM):
     """HuggingFace Pipeline API.
 
     To use, you should have the ``transformers`` python package installed.
@@ -52,6 +56,8 @@ class HuggingFacePipeline(LLM):
     """Key word arguments passed to the model."""
     pipeline_kwargs: Optional[dict] = None
     """Key word arguments passed to the pipeline."""
+    batch_size: int = DEFAULT_BATCH_SIZE
+    """Batch size to use when passing multiple documents to generate."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -66,8 +72,9 @@ class HuggingFacePipeline(LLM):
         device: int = -1,
         model_kwargs: Optional[dict] = None,
         pipeline_kwargs: Optional[dict] = None,
+        batch_size: int = DEFAULT_BATCH_SIZE,
         **kwargs: Any,
-    ) -> LLM:
+    ) -> HuggingFacePipeline:
         """Construct the pipeline object from model_id and task."""
         try:
             from transformers import (
@@ -128,6 +135,7 @@ class HuggingFacePipeline(LLM):
             model=model,
             tokenizer=tokenizer,
             device=device,
+            batch_size=batch_size,
             model_kwargs=_model_kwargs,
             **_pipeline_kwargs,
         )
@@ -141,6 +149,7 @@ class HuggingFacePipeline(LLM):
             model_id=model_id,
             model_kwargs=_model_kwargs,
             pipeline_kwargs=_pipeline_kwargs,
+            batch_size=batch_size,
             **kwargs,
         )
 
@@ -157,28 +166,47 @@ class HuggingFacePipeline(LLM):
     def _llm_type(self) -> str:
         return "huggingface_pipeline"
 
-    def _call(
+    def _generate(
         self,
-        prompt: str,
+        prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        response = self.pipeline(prompt)
-        if self.pipeline.task == "text-generation":
-            # Text generation return includes the starter text.
-            text = response[0]["generated_text"][len(prompt) :]
-        elif self.pipeline.task == "text2text-generation":
-            text = response[0]["generated_text"]
-        elif self.pipeline.task == "summarization":
-            text = response[0]["summary_text"]
-        else:
-            raise ValueError(
-                f"Got invalid task {self.pipeline.task}, "
-                f"currently only {VALID_TASKS} are supported"
-            )
-        if stop:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
-            text = enforce_stop_tokens(text, stop)
-        return text
+    ) -> LLMResult:
+        # List to hold all results
+        text_generations: List[str] = []
+
+        for i in range(0, len(prompts), self.batch_size):
+            batch_prompts = prompts[i : i + self.batch_size]
+
+            # Process batch of prompts
+            responses = self.pipeline(batch_prompts)
+
+            # Process each response in the batch
+            for j, response in enumerate(responses):
+                if isinstance(response, list):
+                    # if model returns multiple generations, pick the top one
+                    response = response[0]
+
+                if self.pipeline.task == "text-generation":
+                    # Text generation return includes the starter text
+                    text = response["generated_text"][len(batch_prompts[j]) :]
+                elif self.pipeline.task == "text2text-generation":
+                    text = response["generated_text"]
+                elif self.pipeline.task == "summarization":
+                    text = response["summary_text"]
+                else:
+                    raise ValueError(
+                        f"Got invalid task {self.pipeline.task}, "
+                        f"currently only {VALID_TASKS} are supported"
+                    )
+                if stop:
+                    # Enforce stop tokens
+                    text = enforce_stop_tokens(text, stop)
+
+                # Append the processed text to results
+                text_generations.append(text)
+
+        return LLMResult(
+            generations=[[Generation(text=text)] for text in text_generations]
+        )
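Reviewer note: the sketch below is a minimal, hypothetical end-to-end exercise of the new batched `_generate` path. It is not part of the diff; the model id, `max_new_tokens` value, and batch size are illustrative assumptions. `BaseLLM.generate` accepts a list of prompts, which `_generate` now chunks into pipeline calls of `batch_size` prompts each.

```python
# Hypothetical usage sketch, not part of this PR. The model and parameter
# choices here are illustrative assumptions.
from langchain.llms.huggingface_pipeline import HuggingFacePipeline

llm = HuggingFacePipeline.from_model_id(
    model_id="gpt2",  # any text-generation checkpoint should work
    task="text-generation",
    device=-1,  # CPU; pass a GPU index (e.g. 0) to batch on GPU
    batch_size=2,  # four prompts below -> two pipeline calls of two prompts
    pipeline_kwargs={"max_new_tokens": 16},
)

prompts = [
    "What is the number 0 in French?",
    "What is the number 1 in French?",
    "What is the number 2 in French?",
    "What is the number 3 in French?",
]

# BaseLLM.generate routes through the new batched _generate and returns an
# LLMResult with one list of Generations per input prompt, in order.
result = llm.generate(prompts)
for generations in result.generations:
    print(generations[0].text)
```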
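To sanity-check the chunking logic without downloading a real model, one can stub the `pipeline` attribute directly. `FakePipeline` below is invented for illustration and mimics only the slice of the transformers text-generation interface that `_generate` touches (`task` and `__call__`):

```python
# Hypothetical chunking check; FakePipeline is invented for illustration.
from langchain.llms.huggingface_pipeline import HuggingFacePipeline


class FakePipeline:
    """Mimics the bits of a transformers text-generation pipeline used above."""

    task = "text-generation"

    def __init__(self) -> None:
        self.call_sizes = []  # number of prompts received per __call__

    def __call__(self, prompts):
        self.call_sizes.append(len(prompts))
        # A text-generation pipeline echoes the prompt followed by the
        # completion, wrapped in a list of generations per prompt.
        return [[{"generated_text": p + " -> done"}] for p in prompts]


fake = FakePipeline()
llm = HuggingFacePipeline(pipeline=fake, batch_size=2)

result = llm.generate(["p1", "p2", "p3", "p4", "p5"])

# Five prompts with batch_size=2 should yield pipeline calls of 2, 2 and 1 ...
assert fake.call_sizes == [2, 2, 1]
# ... and the prompt prefix is stripped from each generated_text.
assert [g[0].text for g in result.generations] == [" -> done"] * 5
```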