community[minor]: Add support for translation in HuggingFacePipeline (#19190)

- [x] **Support for translation**: "community: Add support for
translation in `HuggingFacePipeline`"


- [x] **Add support for translation in `HuggingFacePipeline`**:
- **Description:** Add support for translation in `HuggingFacePipeline`,
which previously supported only text summarization and generation.
    - **Issue:** N/A
    - **Dependencies:** N/A
    - **Twitter handle:** None
This commit is contained in:
Nikhil Kumar 2024-03-16 17:48:13 -07:00 committed by GitHub
parent a1b26dd9b6
commit 635b3372bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,7 +11,12 @@ from langchain_core.pydantic_v1 import Extra
DEFAULT_MODEL_ID = "gpt2"
DEFAULT_TASK = "text-generation"
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
VALID_TASKS = (
"text2text-generation",
"text-generation",
"summarization",
"translation",
)
DEFAULT_BATCH_SIZE = 4
logger = logging.getLogger(__name__)
@ -121,7 +126,7 @@ class HuggingFacePipeline(BaseLLM):
model = AutoModelForCausalLM.from_pretrained(
model_id, **_model_kwargs
)
elif task in ("text2text-generation", "summarization"):
elif task in ("text2text-generation", "summarization", "translation"):
if backend == "openvino":
try:
from optimum.intel.openvino import OVModelForSeq2SeqLM
@ -260,8 +265,6 @@ class HuggingFacePipeline(BaseLLM):
# Process batch of prompts
responses = self.pipeline(
batch_prompts,
stop_sequence=stop,
return_full_text=False,
**pipeline_kwargs,
)
@ -277,6 +280,8 @@ class HuggingFacePipeline(BaseLLM):
text = response["generated_text"]
elif self.pipeline.task == "summarization":
text = response["summary_text"]
elif self.pipeline.task in "translation":
text = response["translation_text"]
else:
raise ValueError(
f"Got invalid task {self.pipeline.task}, "