Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-29 01:48:57 +00:00
add replicate stream (#10518)
Support direct Replicate streaming. cc @cbh123 @tjaffri
This commit is contained in:
parent 7f3f6097e7
commit 9dd4cacae2
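
To make the change concrete, here is a minimal usage sketch of the new streaming path. It assumes the `replicate` package is installed, `REPLICATE_API_TOKEN` is set, and a langchain version where `LLM.stream()` is available; the model slug is a placeholder, not something taken from this commit.

```python
from langchain.llms import Replicate

# Placeholder slug: substitute a real "owner/model:version" from replicate.com.
llm = Replicate(model="owner/model:version-hash", streaming=True)

# With this change, stream() yields text chunks straight from Replicate's
# prediction output iterator instead of waiting for the full completion.
for chunk in llm.stream("Say something short about llamas."):
    print(chunk, end="", flush=True)
```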
@@ -1,13 +1,17 @@
 from __future__ import annotations

 import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.pydantic_v1 import Extra, Field, root_validator
+from langchain.schema.output import GenerationChunk
 from langchain.utils import get_from_dict_or_env

+if TYPE_CHECKING:
+    from replicate.prediction import Prediction
+
 logger = logging.getLogger(__name__)
@@ -46,10 +50,10 @@ class Replicate(LLM):
         serializable, is excluded from serialization.
     """

-    streaming: bool = Field(default=False)
+    streaming: bool = False
     """Whether to stream the results."""

-    stop: Optional[List[str]] = Field(default=[])
+    stop: List[str] = Field(default_factory=list)
     """Stop sequences to early-terminate generation."""

     class Config:
@@ -97,7 +101,7 @@ class Replicate(LLM):
         """Get the identifying parameters."""
         return {
             "model": self.model,
-            **{"model_kwargs": self.model_kwargs},
+            "model_kwargs": self.model_kwargs,
         }

     @property
@@ -113,6 +117,63 @@ class Replicate(LLM):
         **kwargs: Any,
     ) -> str:
         """Call to replicate endpoint."""
+        if self.streaming:
+            completion: Optional[str] = None
+            for chunk in self._stream(
+                prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if completion is None:
+                    completion = chunk.text
+                else:
+                    completion += chunk.text
+        else:
+            prediction = self._create_prediction(prompt, **kwargs)
+            prediction.wait()
+            if prediction.status == "failed":
+                raise RuntimeError(prediction.error)
+            completion = prediction.output
+        assert completion is not None
+        stop_conditions = stop or self.stop
+        for s in stop_conditions:
+            if s in completion:
+                completion = completion[: completion.find(s)]
+        return completion
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        prediction = self._create_prediction(prompt, **kwargs)
+        stop_conditions = stop or self.stop
+        stop_condition_reached = False
+        current_completion: str = ""
+        for output in prediction.output_iterator():
+            current_completion += output
+            # test for stop conditions, if specified
+            for s in stop_conditions:
+                if s in current_completion:
+                    prediction.cancel()
+                    stop_condition_reached = True
+                    # Potentially some tokens that should still be yielded before ending
+                    # stream.
+                    stop_index = max(output.find(s), 0)
+                    output = output[:stop_index]
+                    if not output:
+                        break
+            if output:
+                yield GenerationChunk(text=output)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        output,
+                        verbose=self.verbose,
+                    )
+            if stop_condition_reached:
+                break
+
+    def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction:
         try:
             import replicate as replicate_python
         except ImportError:
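
One detail of the `_stream` hunk above worth noting: `output.find(s)` is evaluated on the current chunk only, so a stop sequence that started in an earlier chunk returns -1, which `max(..., 0)` clamps to 0 and the whole chunk is dropped. A standalone sketch of that truncation logic (plain Python, not part of the diff):

```python
def truncate_chunk(output: str, stop: str) -> str:
    """Mirrors the truncation in _stream: keep the text before `stop`,
    or drop the entire chunk when `stop` is not found inside this chunk."""
    stop_index = max(output.find(stop), 0)
    return output[:stop_index]

# Stop sequence fully inside the chunk: keep the prefix before it.
assert truncate_chunk("hello END world", "END") == "hello "
# Stop sequence straddled a chunk boundary (not found in this chunk): drop it all.
assert truncate_chunk("D and more text", "END") == ""
```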
@@ -138,29 +199,7 @@ class Replicate(LLM):

         self.prompt_key = input_properties[0][0]

-        inputs: Dict = {self.prompt_key: prompt, **self.input}
-
-        prediction = replicate_python.predictions.create(
-            version=self.version_obj, input={**inputs, **kwargs}
+        input_: Dict = {self.prompt_key: prompt, **self.input, **kwargs}
+        return replicate_python.predictions.create(
+            version=self.version_obj, input=input_
         )
-        current_completion: str = ""
-        stop_condition_reached = False
-        for output in prediction.output_iterator():
-            current_completion += output
-
-            # test for stop conditions, if specified
-            if stop:
-                for s in stop:
-                    if s in current_completion:
-                        prediction.cancel()
-                        stop_index = current_completion.find(s)
-                        current_completion = current_completion[:stop_index]
-                        stop_condition_reached = True
-                        break
-
-            if stop_condition_reached:
-                break
-
-            if self.streaming and run_manager:
-                run_manager.on_llm_new_token(output)
-        return current_completion
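
Finally, a hedged sketch of the blocking call path with `streaming=True` (placeholder model slug again): `_call` now consumes `_stream()` internally, so registered callbacks such as `StreamingStdOutCallbackHandler` receive tokens via `on_llm_new_token` while the full completion string is still returned, and `stop` sequences truncate the result in both modes.

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import Replicate

llm = Replicate(
    model="owner/model:version-hash",  # placeholder, not from this commit
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)

# Tokens print as they arrive; the call still returns the final text,
# truncated at the first stop sequence if one is hit.
completion = llm("Write two sentences about rivers.", stop=["\n\n"])
```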