From 641efcf41c8358655dcf1db2fbe827cfe77feae0 Mon Sep 17 00:00:00 2001 From: Armin Stepanyan Date: Thu, 8 Feb 2024 21:58:31 +0000 Subject: [PATCH] community: add runtime kwargs to HuggingFacePipeline (#17005) This PR enables changing the behaviour of huggingface pipeline between different calls. For example, before this PR there's no way of changing maximum generation length between different invocations of the chain. This is desirable in cases, such as when we want to scale the maximum output size depending on a dynamic prompt size. Usage example: ```python from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline model_id = "gpt2" tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained(model_id) pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) hf = HuggingFacePipeline(pipeline=pipe) hf("Say foo:", pipeline_kwargs={"max_new_tokens": 42}) ``` --------- Co-authored-by: Bagatur --- .../langchain_community/llms/huggingface_pipeline.py | 3 ++- .../llms/test_huggingface_pipeline.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/llms/huggingface_pipeline.py b/libs/community/langchain_community/llms/huggingface_pipeline.py index 9b2e94db326..388ba117c25 100644 --- a/libs/community/langchain_community/llms/huggingface_pipeline.py +++ b/libs/community/langchain_community/llms/huggingface_pipeline.py @@ -195,12 +195,13 @@ class HuggingFacePipeline(BaseLLM): ) -> LLMResult: # List to hold all results text_generations: List[str] = [] + pipeline_kwargs = kwargs.get("pipeline_kwargs", {}) for i in range(0, len(prompts), self.batch_size): batch_prompts = prompts[i : i + self.batch_size] # Process batch of prompts - responses = self.pipeline(batch_prompts) + responses = self.pipeline(batch_prompts, **pipeline_kwargs) # Process each response in the batch for j, response in enumerate(responses): diff --git a/libs/community/tests/integration_tests/llms/test_huggingface_pipeline.py b/libs/community/tests/integration_tests/llms/test_huggingface_pipeline.py index 885f3b3eaf3..aa6a0e1defe 100755 --- a/libs/community/tests/integration_tests/llms/test_huggingface_pipeline.py +++ b/libs/community/tests/integration_tests/llms/test_huggingface_pipeline.py @@ -69,3 +69,14 @@ def test_init_with_pipeline() -> None: llm = HuggingFacePipeline(pipeline=pipe) output = llm("Say foo:") assert isinstance(output, str) + + +def test_huggingface_pipeline_runtime_kwargs() -> None: + """Test pipelines specifying the device map parameter.""" + llm = HuggingFacePipeline.from_model_id( + model_id="gpt2", + task="text-generation", + ) + prompt = "Say foo:" + output = llm(prompt, pipeline_kwargs={"max_new_tokens": 2}) + assert len(output) < 10