From 635b3372bdd3fbcaa724caa848e1e5f7eef1c2a1 Mon Sep 17 00:00:00 2001 From: Nikhil Kumar <64120577+nikhilkmr300@users.noreply.github.com> Date: Sat, 16 Mar 2024 17:48:13 -0700 Subject: [PATCH] community[minor]: Add support for translation in HuggingFacePipeline (#19190) - [x] **Support for translation**: "community: Add support for translation in `HuggingFacePipeline`" - [x] **Add support for translation in `HuggingFacePipeline`**: - **Description:** Add support for translation in `HuggingFacePipeline`, which earlier used to support only text summarization and generation. - **Issue:** N/A - **Dependencies:** N/A - **Twitter handle:** None --- .../llms/huggingface_pipeline.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/libs/community/langchain_community/llms/huggingface_pipeline.py b/libs/community/langchain_community/llms/huggingface_pipeline.py index f95a994d68f..e989cd16da8 100644 --- a/libs/community/langchain_community/llms/huggingface_pipeline.py +++ b/libs/community/langchain_community/llms/huggingface_pipeline.py @@ -11,7 +11,12 @@ from langchain_core.pydantic_v1 import Extra DEFAULT_MODEL_ID = "gpt2" DEFAULT_TASK = "text-generation" -VALID_TASKS = ("text2text-generation", "text-generation", "summarization") +VALID_TASKS = ( + "text2text-generation", + "text-generation", + "summarization", + "translation", +) DEFAULT_BATCH_SIZE = 4 logger = logging.getLogger(__name__) @@ -121,7 +126,7 @@ class HuggingFacePipeline(BaseLLM): model = AutoModelForCausalLM.from_pretrained( model_id, **_model_kwargs ) - elif task in ("text2text-generation", "summarization"): + elif task in ("text2text-generation", "summarization", "translation"): if backend == "openvino": try: from optimum.intel.openvino import OVModelForSeq2SeqLM @@ -260,8 +265,6 @@ class HuggingFacePipeline(BaseLLM): # Process batch of prompts responses = self.pipeline( batch_prompts, - stop_sequence=stop, - return_full_text=False, **pipeline_kwargs, ) @@ -277,6 +280,8 @@ 
class HuggingFacePipeline(BaseLLM): text = response["generated_text"] elif self.pipeline.task == "summarization": text = response["summary_text"] + elif self.pipeline.task == "translation": + text = response["translation_text"] else: raise ValueError( f"Got invalid task {self.pipeline.task}, "