mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-07 20:15:40 +00:00
community[minor]: Add support for translation in HuggingFacePipeline (#19190)
- [x] **Support for translation**: "community: Add support for translation in `HuggingFacePipeline`" - [x] **Add support for translation in `HuggingFacePipeline`**: - **Description:** Add support for translation in `HuggingFacePipeline`, which earlier used to support only text summarization and generation. - **Issue:** N/A - **Dependencies:** N/A - **Twitter handle:** None
This commit is contained in:
parent
a1b26dd9b6
commit
635b3372bd
@ -11,7 +11,12 @@ from langchain_core.pydantic_v1 import Extra
|
||||
|
||||
DEFAULT_MODEL_ID = "gpt2"
|
||||
DEFAULT_TASK = "text-generation"
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
VALID_TASKS = (
|
||||
"text2text-generation",
|
||||
"text-generation",
|
||||
"summarization",
|
||||
"translation",
|
||||
)
|
||||
DEFAULT_BATCH_SIZE = 4
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -121,7 +126,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, **_model_kwargs
|
||||
)
|
||||
elif task in ("text2text-generation", "summarization"):
|
||||
elif task in ("text2text-generation", "summarization", "translation"):
|
||||
if backend == "openvino":
|
||||
try:
|
||||
from optimum.intel.openvino import OVModelForSeq2SeqLM
|
||||
@ -260,8 +265,6 @@ class HuggingFacePipeline(BaseLLM):
|
||||
# Process batch of prompts
|
||||
responses = self.pipeline(
|
||||
batch_prompts,
|
||||
stop_sequence=stop,
|
||||
return_full_text=False,
|
||||
**pipeline_kwargs,
|
||||
)
|
||||
|
||||
@ -277,6 +280,8 @@ class HuggingFacePipeline(BaseLLM):
|
||||
text = response["generated_text"]
|
||||
elif self.pipeline.task == "summarization":
|
||||
text = response["summary_text"]
|
||||
elif self.pipeline.task in "translation":
|
||||
text = response["translation_text"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got invalid task {self.pipeline.task}, "
|
||||
|
Loading…
Reference in New Issue
Block a user