diff --git a/libs/community/langchain_community/llms/huggingface_pipeline.py b/libs/community/langchain_community/llms/huggingface_pipeline.py
index 9b2e94db326..388ba117c25 100644
--- a/libs/community/langchain_community/llms/huggingface_pipeline.py
+++ b/libs/community/langchain_community/llms/huggingface_pipeline.py
@@ -195,12 +195,13 @@ class HuggingFacePipeline(BaseLLM):
     ) -> LLMResult:
         # List to hold all results
         text_generations: List[str] = []
+        pipeline_kwargs = kwargs.get("pipeline_kwargs", {})
 
         for i in range(0, len(prompts), self.batch_size):
             batch_prompts = prompts[i : i + self.batch_size]
 
             # Process batch of prompts
-            responses = self.pipeline(batch_prompts)
+            responses = self.pipeline(batch_prompts, **pipeline_kwargs)
 
             # Process each response in the batch
             for j, response in enumerate(responses):
diff --git a/libs/community/tests/integration_tests/llms/test_huggingface_pipeline.py b/libs/community/tests/integration_tests/llms/test_huggingface_pipeline.py
index 885f3b3eaf3..aa6a0e1defe 100755
--- a/libs/community/tests/integration_tests/llms/test_huggingface_pipeline.py
+++ b/libs/community/tests/integration_tests/llms/test_huggingface_pipeline.py
@@ -69,3 +69,14 @@ def test_init_with_pipeline() -> None:
     llm = HuggingFacePipeline(pipeline=pipe)
     output = llm("Say foo:")
     assert isinstance(output, str)
+
+
+def test_huggingface_pipeline_runtime_kwargs() -> None:
+    """Test passing pipeline kwargs to the pipeline at runtime."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="gpt2",
+        task="text-generation",
+    )
+    prompt = "Say foo:"
+    output = llm(prompt, pipeline_kwargs={"max_new_tokens": 2})
+    assert len(output) < 10
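
For reference, a minimal usage sketch of what this diff enables (not part of the patch; it assumes a local `transformers`/`torch` install and access to the `gpt2` checkpoint, mirroring the new integration test):

```python
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

# Build the LLM once; with this change, generation parameters can also be
# overridden per call instead of only at construction time.
llm = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
)

# `pipeline_kwargs` is read from **kwargs in _generate and forwarded as
# keyword arguments to the underlying transformers pipeline call,
# so it applies to this invocation only.
output = llm("Say foo:", pipeline_kwargs={"max_new_tokens": 2})
print(output)
```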