Fix HF endpoint returns blank for text-generation (#7386)

Description: Current `_call` function in the
`langchain.llms.HuggingFaceEndpoint` class truncates response when
`task=text-generation`. Same error discussed a few days ago on Hugging
Face: https://huggingface.co/tiiuae/falcon-40b-instruct/discussions/51
Issue: Fixes #7353 
Tag maintainer: @hwchase17 @baskaryan @hinthornw

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Hashem Alsaket 2023-07-11 02:06:05 -05:00 committed by GitHub
parent 4a94f56258
commit 1dd4236177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 4 deletions

View File

@ -137,8 +137,10 @@ class HuggingFaceEndpoint(LLM):
f"Error raised by inference API: {generated_text['error']}" f"Error raised by inference API: {generated_text['error']}"
) )
if self.task == "text-generation": if self.task == "text-generation":
# Text generation return includes the starter text. text = generated_text[0]["generated_text"]
text = generated_text[0]["generated_text"][len(prompt) :] # Remove prompt if included in generated text.
if text.startswith(prompt):
text = text[len(prompt) :]
elif self.task == "text2text-generation": elif self.task == "text2text-generation":
text = generated_text[0]["generated_text"] text = generated_text[0]["generated_text"]
elif self.task == "summarization": elif self.task == "summarization":

View File

@ -2,9 +2,9 @@ import importlib
import os import os
import time import time
import uuid import uuid
import numpy as np
from typing import List from typing import List
import numpy as np
import pinecone import pinecone
import pytest import pytest

View File

@ -1,8 +1,8 @@
"""Test Qdrant functionality.""" """Test Qdrant functionality."""
import tempfile import tempfile
from typing import Callable, Optional from typing import Callable, Optional
import numpy as np
import numpy as np
import pytest import pytest
from qdrant_client.http import models as rest from qdrant_client.http import models as rest