Mirror of https://github.com/hwchase17/langchain.git, synced 2025-05-31 20:19:43 +00:00
Feature/enhance huggingfacepipeline to handle different return type (#11394)
**Description:** Prevent HuggingFacePipeline from truncating the response when the user sets `return_full_text=False` in the underlying Hugging Face pipeline.
**Dependencies:** None
**Tag maintainer:** Maybe @sam-h-bean?

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
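For context, a minimal sketch of the setup this fix targets; the model choice and generation settings below are illustrative, not part of the commit:

```python
# Minimal sketch of the scenario this change addresses; the model and
# generation settings here are illustrative, not part of the commit.
from transformers import pipeline
from langchain.llms import HuggingFacePipeline

# With return_full_text=False the transformers pipeline already strips
# the prompt from its output, so LangChain must not slice it off again.
generator = pipeline(
    "text-generation",
    model="gpt2",
    max_new_tokens=32,
    return_full_text=False,
)
llm = HuggingFacePipeline(pipeline=generator)

# Before this change, the prompt's length was unconditionally cut from the
# start of the (already prompt-free) output, truncating the completion.
print(llm("Tell me a joke: "))
```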
parent 2aba9ab47e
commit 0b743f005b
```diff
@@ -202,8 +202,23 @@ class HuggingFacePipeline(BaseLLM):
                     response = response[0]
 
                 if self.pipeline.task == "text-generation":
-                    # Text generation return includes the starter text
-                    text = response["generated_text"][len(batch_prompts[j]) :]
+                    try:
+                        from transformers.pipelines.text_generation import ReturnType
+
+                        remove_prompt = (
+                            self.pipeline._postprocess_params.get("return_type")
+                            != ReturnType.NEW_TEXT
+                        )
+                    except Exception as e:
+                        logger.warning(
+                            f"Unable to extract pipeline return_type. "
+                            f"Received error:\n\n{e}"
+                        )
+                        remove_prompt = True
+                    if remove_prompt:
+                        text = response["generated_text"][len(batch_prompts[j]) :]
+                    else:
+                        text = response["generated_text"]
                 elif self.pipeline.task == "text2text-generation":
                     text = response["generated_text"]
                 elif self.pipeline.task == "summarization":
```
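The check relies on transformers resolving `return_full_text` into a `ReturnType` enum stored on the pipeline. Below is a standalone sketch of that detection; it assumes a transformers version that exposes `ReturnType` and keeps the resolved kwargs in the private `_postprocess_params` dict, and the model name is illustrative:

```python
# Standalone sketch of the detection the patch performs. Assumes a
# transformers version that exposes ReturnType and keeps the resolved
# kwargs in the private _postprocess_params dict; the model is illustrative.
from transformers import pipeline
from transformers.pipelines.text_generation import ReturnType

generator = pipeline("text-generation", model="gpt2", return_full_text=False)

# return_full_text=False is resolved to ReturnType.NEW_TEXT, i.e. the
# pipeline output no longer contains the prompt.
return_type = generator._postprocess_params.get("return_type")
print(return_type == ReturnType.NEW_TEXT)  # True for this configuration

# The patch only strips the prompt when the pipeline still returns it:
remove_prompt = return_type != ReturnType.NEW_TEXT
print(remove_prompt)  # False here, so the full generated_text is kept
```

If the import or the private attribute lookup fails, the patch falls back to `remove_prompt = True`, preserving the previous behavior of always stripping the prompt.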