fix replicate output type (#10598)

This commit is contained in:
Bagatur
2023-09-14 14:02:01 -07:00
committed by GitHub
parent 7608f85f13
commit 77a165e0d9
3 changed files with 44 additions and 49 deletions

View File

@@ -131,7 +131,10 @@ class Replicate(LLM):
prediction.wait()
if prediction.status == "failed":
raise RuntimeError(prediction.error)
completion = prediction.output
if isinstance(prediction.output, str):
completion = prediction.output
else:
completion = "".join(prediction.output)
assert completion is not None
stop_conditions = stop or self.stop
for s in stop_conditions:

View File

@@ -4,16 +4,15 @@ from langchain.callbacks.manager import CallbackManager
from langchain.llms.replicate import Replicate
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
TEST_MODEL_NAME = "replicate/hello-world"
TEST_MODEL_VER = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
TEST_MODEL = TEST_MODEL_NAME + ":" + TEST_MODEL_VER
TEST_MODEL = "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5" # noqa: E501
def test_replicate_call() -> None:
"""Test simple non-streaming call to Replicate."""
llm = Replicate(model=TEST_MODEL)
output = llm("LangChain")
assert output == "hello LangChain"
output = llm("What is LangChain")
assert output
assert isinstance(output, str)
def test_replicate_streaming_call() -> None:
@@ -22,13 +21,6 @@ def test_replicate_streaming_call() -> None:
callback_manager = CallbackManager([callback_handler])
llm = Replicate(streaming=True, callback_manager=callback_manager, model=TEST_MODEL)
output = llm("LangChain")
assert output == "hello LangChain"
assert callback_handler.llm_streams == 15
def test_replicate_stop_sequence() -> None:
"""Test call to Replicate with a stop sequence."""
llm = Replicate(model=TEST_MODEL)
output = llm("one two three", stop=["two"])
assert output == "hello one "
output = llm("What is LangChain")
assert output
assert isinstance(output, str)