community[patch]: added 'conversational' as a valid task for hugginface endopoint models (#15761)

- **Description:** added the conversational task to hugginFace endpoint
in order to use models designed for chatbot programming.
  - **Dependencies:** None

---------

Co-authored-by: Alessio Serra (ext.) <alessio.serra@partner.bmw.de>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Alessio Serra 2024-01-24 05:04:15 +01:00 committed by GitHub
parent 4c7755778d
commit d628a80a5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,7 +8,12 @@ from langchain_core.utils import get_from_dict_or_env
from langchain_community.llms.utils import enforce_stop_tokens from langchain_community.llms.utils import enforce_stop_tokens
VALID_TASKS = ("text2text-generation", "text-generation", "summarization") VALID_TASKS = (
"text2text-generation",
"text-generation",
"summarization",
"conversational",
)
class HuggingFaceEndpoint(LLM): class HuggingFaceEndpoint(LLM):
@ -144,6 +149,8 @@ class HuggingFaceEndpoint(LLM):
text = generated_text[0]["generated_text"] text = generated_text[0]["generated_text"]
elif self.task == "summarization": elif self.task == "summarization":
text = generated_text[0]["summary_text"] text = generated_text[0]["summary_text"]
elif self.task == "conversational":
text = generated_text["response"][1]
else: else:
raise ValueError( raise ValueError(
f"Got invalid task {self.task}, " f"Got invalid task {self.task}, "