From d628a80a5d4a0e1bb9d79e37483d24187b792c2a Mon Sep 17 00:00:00 2001 From: Alessio Serra <48280936+alessioserra@users.noreply.github.com> Date: Wed, 24 Jan 2024 05:04:15 +0100 Subject: [PATCH] community[patch]: added 'conversational' as a valid task for hugginface endopoint models (#15761) - **Description:** added the conversational task to hugginFace endpoint in order to use models designed for chatbot programming. - **Dependencies:** None --------- Co-authored-by: Alessio Serra (ext.) Co-authored-by: Harrison Chase Co-authored-by: Bagatur --- .../langchain_community/llms/huggingface_endpoint.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/llms/huggingface_endpoint.py b/libs/community/langchain_community/llms/huggingface_endpoint.py index d429e2fd935..c14b2e24a80 100644 --- a/libs/community/langchain_community/llms/huggingface_endpoint.py +++ b/libs/community/langchain_community/llms/huggingface_endpoint.py @@ -8,7 +8,12 @@ from langchain_core.utils import get_from_dict_or_env from langchain_community.llms.utils import enforce_stop_tokens -VALID_TASKS = ("text2text-generation", "text-generation", "summarization") +VALID_TASKS = ( + "text2text-generation", + "text-generation", + "summarization", + "conversational", +) class HuggingFaceEndpoint(LLM): @@ -144,6 +149,8 @@ class HuggingFaceEndpoint(LLM): text = generated_text[0]["generated_text"] elif self.task == "summarization": text = generated_text[0]["summary_text"] + elif self.task == "conversational": + text = generated_text["response"][1] else: raise ValueError( f"Got invalid task {self.task}, "