mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:49:17 +00:00
fix replicate output type (#10598)
This commit is contained in:
@@ -131,7 +131,10 @@ class Replicate(LLM):
|
||||
prediction.wait()
|
||||
if prediction.status == "failed":
|
||||
raise RuntimeError(prediction.error)
|
||||
completion = prediction.output
|
||||
if isinstance(prediction.output, str):
|
||||
completion = prediction.output
|
||||
else:
|
||||
completion = "".join(prediction.output)
|
||||
assert completion is not None
|
||||
stop_conditions = stop or self.stop
|
||||
for s in stop_conditions:
|
||||
|
@@ -4,16 +4,15 @@ from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.llms.replicate import Replicate
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
TEST_MODEL_NAME = "replicate/hello-world"
|
||||
TEST_MODEL_VER = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
|
||||
TEST_MODEL = TEST_MODEL_NAME + ":" + TEST_MODEL_VER
|
||||
TEST_MODEL = "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5" # noqa: E501
|
||||
|
||||
|
||||
def test_replicate_call() -> None:
|
||||
"""Test simple non-streaming call to Replicate."""
|
||||
llm = Replicate(model=TEST_MODEL)
|
||||
output = llm("LangChain")
|
||||
assert output == "hello LangChain"
|
||||
output = llm("What is LangChain")
|
||||
assert output
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_replicate_streaming_call() -> None:
|
||||
@@ -22,13 +21,6 @@ def test_replicate_streaming_call() -> None:
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
|
||||
llm = Replicate(streaming=True, callback_manager=callback_manager, model=TEST_MODEL)
|
||||
output = llm("LangChain")
|
||||
assert output == "hello LangChain"
|
||||
assert callback_handler.llm_streams == 15
|
||||
|
||||
|
||||
def test_replicate_stop_sequence() -> None:
|
||||
"""Test call to Replicate with a stop sequence."""
|
||||
llm = Replicate(model=TEST_MODEL)
|
||||
output = llm("one two three", stop=["two"])
|
||||
assert output == "hello one "
|
||||
output = llm("What is LangChain")
|
||||
assert output
|
||||
assert isinstance(output, str)
|
||||
|
Reference in New Issue
Block a user