mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 13:27:36 +00:00
community[patch]: added 'conversational' as a valid task for hugginface endopoint models (#15761)
- **Description:** added the conversational task to hugginFace endpoint in order to use models designed for chatbot programming. - **Dependencies:** None --------- Co-authored-by: Alessio Serra (ext.) <alessio.serra@partner.bmw.de> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
4c7755778d
commit
d628a80a5d
@ -8,7 +8,12 @@ from langchain_core.utils import get_from_dict_or_env
|
|||||||
|
|
||||||
from langchain_community.llms.utils import enforce_stop_tokens
|
from langchain_community.llms.utils import enforce_stop_tokens
|
||||||
|
|
||||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
VALID_TASKS = (
|
||||||
|
"text2text-generation",
|
||||||
|
"text-generation",
|
||||||
|
"summarization",
|
||||||
|
"conversational",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceEndpoint(LLM):
|
class HuggingFaceEndpoint(LLM):
|
||||||
@ -144,6 +149,8 @@ class HuggingFaceEndpoint(LLM):
|
|||||||
text = generated_text[0]["generated_text"]
|
text = generated_text[0]["generated_text"]
|
||||||
elif self.task == "summarization":
|
elif self.task == "summarization":
|
||||||
text = generated_text[0]["summary_text"]
|
text = generated_text[0]["summary_text"]
|
||||||
|
elif self.task == "conversational":
|
||||||
|
text = generated_text["response"][1]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Got invalid task {self.task}, "
|
f"Got invalid task {self.task}, "
|
||||||
|
Loading…
Reference in New Issue
Block a user