From 1dd4236177c3e21c26773d1948759129c59520e7 Mon Sep 17 00:00:00 2001 From: Hashem Alsaket Date: Tue, 11 Jul 2023 02:06:05 -0500 Subject: [PATCH] Fix HF endpoint returns blank for text-generation (#7386) Description: Current `_call` function in the `langchain.llms.HuggingFaceEndpoint` class truncates response when `task=text-generation`. Same error discussed a few days ago on Hugging Face: https://huggingface.co/tiiuae/falcon-40b-instruct/discussions/51 Issue: Fixes #7353 Tag maintainer: @hwchase17 @baskaryan @hinthornw --------- Co-authored-by: Bagatur --- langchain/llms/huggingface_endpoint.py | 6 ++++-- tests/integration_tests/vectorstores/test_pinecone.py | 2 +- tests/integration_tests/vectorstores/test_qdrant.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/langchain/llms/huggingface_endpoint.py b/langchain/llms/huggingface_endpoint.py index ff83b0baeaa..34aa98d4ab7 100644 --- a/langchain/llms/huggingface_endpoint.py +++ b/langchain/llms/huggingface_endpoint.py @@ -137,8 +137,10 @@ class HuggingFaceEndpoint(LLM): f"Error raised by inference API: {generated_text['error']}" ) if self.task == "text-generation": - # Text generation return includes the starter text. - text = generated_text[0]["generated_text"][len(prompt) :] + text = generated_text[0]["generated_text"] + # Remove prompt if included in generated text. + if text.startswith(prompt): + text = text[len(prompt) :] elif self.task == "text2text-generation": text = generated_text[0]["generated_text"] elif self.task == "summarization": diff --git a/tests/integration_tests/vectorstores/test_pinecone.py b/tests/integration_tests/vectorstores/test_pinecone.py index 238f46f6ccb..d1c6fa0c56f 100644 --- a/tests/integration_tests/vectorstores/test_pinecone.py +++ b/tests/integration_tests/vectorstores/test_pinecone.py @@ -2,9 +2,9 @@ import importlib import os import time import uuid -import numpy as np from typing import List +import numpy as np import pinecone import pytest diff --git a/tests/integration_tests/vectorstores/test_qdrant.py b/tests/integration_tests/vectorstores/test_qdrant.py index de779dadf85..e46a6e36e64 100644 --- a/tests/integration_tests/vectorstores/test_qdrant.py +++ b/tests/integration_tests/vectorstores/test_qdrant.py @@ -1,8 +1,8 @@ """Test Qdrant functionality.""" import tempfile from typing import Callable, Optional -import numpy as np +import numpy as np import pytest from qdrant_client.http import models as rest