mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 22:53:30 +00:00
Feature/enhance huggingfacepipeline to handle different return type (#11394)
**Description:** Avoid huggingfacepipeline to truncate the response if user setup return_full_text as False within huggingface pipeline. **Dependencies:** : None **Tag maintainer:** Maybe @sam-h-bean ? --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
2aba9ab47e
commit
0b743f005b
@ -202,8 +202,23 @@ class HuggingFacePipeline(BaseLLM):
|
|||||||
response = response[0]
|
response = response[0]
|
||||||
|
|
||||||
if self.pipeline.task == "text-generation":
|
if self.pipeline.task == "text-generation":
|
||||||
# Text generation return includes the starter text
|
try:
|
||||||
text = response["generated_text"][len(batch_prompts[j]) :]
|
from transformers.pipelines.text_generation import ReturnType
|
||||||
|
|
||||||
|
remove_prompt = (
|
||||||
|
self.pipeline._postprocess_params.get("return_type")
|
||||||
|
!= ReturnType.NEW_TEXT
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Unable to extract pipeline return_type. "
|
||||||
|
f"Received error:\n\n{e}"
|
||||||
|
)
|
||||||
|
remove_prompt = True
|
||||||
|
if remove_prompt:
|
||||||
|
text = response["generated_text"][len(batch_prompts[j]) :]
|
||||||
|
else:
|
||||||
|
text = response["generated_text"]
|
||||||
elif self.pipeline.task == "text2text-generation":
|
elif self.pipeline.task == "text2text-generation":
|
||||||
text = response["generated_text"]
|
text = response["generated_text"]
|
||||||
elif self.pipeline.task == "summarization":
|
elif self.pipeline.task == "summarization":
|
||||||
|
Loading…
Reference in New Issue
Block a user