Mirror of https://github.com/hwchase17/langchain.git
add replicate stream (#10518)
support direct replicate streaming. cc @cbh123 @tjaffri
commit 9dd4cacae2 (parent 7f3f6097e7)
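With `streaming=True`, tokens are forwarded through the callback manager via `on_llm_new_token` as they arrive from Replicate. A minimal usage sketch, not part of the commit: the model string is a placeholder, and it assumes the `replicate` package is installed and `REPLICATE_API_TOKEN` is set.

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import Replicate

# Placeholder model reference; substitute a real "owner/name:version" string.
llm = Replicate(
    model="<owner>/<model>:<version>",
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)

# Tokens are printed by the handler as they arrive; the aggregated completion
# is also returned by the call.
completion = llm("Tell me a joke about llamas.")
```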
@@ -1,13 +1,17 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.pydantic_v1 import Extra, Field, root_validator
+from langchain.schema.output import GenerationChunk
 from langchain.utils import get_from_dict_or_env
 
+if TYPE_CHECKING:
+    from replicate.prediction import Prediction
+
 logger = logging.getLogger(__name__)
@@ -46,10 +50,10 @@ class Replicate(LLM):
     serializable, is excluded from serialization.
     """
 
-    streaming: bool = Field(default=False)
+    streaming: bool = False
     """Whether to stream the results."""
 
-    stop: Optional[List[str]] = Field(default=[])
+    stop: List[str] = Field(default_factory=list)
     """Stop sequences to early-terminate generation."""
 
     class Config:
@@ -97,7 +101,7 @@ class Replicate(LLM):
         """Get the identifying parameters."""
         return {
             "model": self.model,
-            **{"model_kwargs": self.model_kwargs},
+            "model_kwargs": self.model_kwargs,
         }
 
     @property
@@ -113,6 +117,63 @@ class Replicate(LLM):
         **kwargs: Any,
     ) -> str:
         """Call to replicate endpoint."""
+        if self.streaming:
+            completion: Optional[str] = None
+            for chunk in self._stream(
+                prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if completion is None:
+                    completion = chunk.text
+                else:
+                    completion += chunk.text
+        else:
+            prediction = self._create_prediction(prompt, **kwargs)
+            prediction.wait()
+            if prediction.status == "failed":
+                raise RuntimeError(prediction.error)
+            completion = prediction.output
+        assert completion is not None
+        stop_conditions = stop or self.stop
+        for s in stop_conditions:
+            if s in completion:
+                completion = completion[: completion.find(s)]
+        return completion
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        prediction = self._create_prediction(prompt, **kwargs)
+        stop_conditions = stop or self.stop
+        stop_condition_reached = False
+        current_completion: str = ""
+        for output in prediction.output_iterator():
+            current_completion += output
+            # test for stop conditions, if specified
+            for s in stop_conditions:
+                if s in current_completion:
+                    prediction.cancel()
+                    stop_condition_reached = True
+                    # Potentially some tokens that should still be yielded before ending
+                    # stream.
+                    stop_index = max(output.find(s), 0)
+                    output = output[:stop_index]
+                    if not output:
+                        break
+            if output:
+                yield GenerationChunk(text=output)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        output,
+                        verbose=self.verbose,
+                    )
+            if stop_condition_reached:
+                break
+
+    def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction:
         try:
             import replicate as replicate_python
         except ImportError:
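The core of the new `_stream` method is the stop-sequence handling: once a stop string appears in the accumulated text, the prediction is cancelled and only the part of the current chunk that precedes the stop string is yielded. A standalone sketch of that truncation rule, using a hypothetical `truncate_stream` helper and made-up chunks instead of a live prediction:

```python
from typing import Iterator, List


def truncate_stream(chunks: List[str], stop: List[str]) -> Iterator[str]:
    """Yield chunks until a stop sequence shows up in the accumulated text."""
    current_completion = ""
    stop_condition_reached = False
    for output in chunks:
        current_completion += output
        for s in stop:
            if s in current_completion:
                stop_condition_reached = True
                # Keep any part of this chunk that precedes the stop string;
                # find() returns -1 when the stop string straddles chunks.
                output = output[: max(output.find(s), 0)]
        if output:
            yield output
        if stop_condition_reached:
            break


print(list(truncate_stream(["Hello", " wor", "ld\n\nHuman:"], ["\n\nHuman:"])))
# -> ['Hello', ' wor', 'ld']
```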
@@ -138,29 +199,7 @@ class Replicate(LLM):
 
         self.prompt_key = input_properties[0][0]
 
-        inputs: Dict = {self.prompt_key: prompt, **self.input}
-        prediction = replicate_python.predictions.create(
-            version=self.version_obj, input={**inputs, **kwargs}
+        input_: Dict = {self.prompt_key: prompt, **self.input, **kwargs}
+        return replicate_python.predictions.create(
+            version=self.version_obj, input=input_
         )
-        current_completion: str = ""
-        stop_condition_reached = False
-        for output in prediction.output_iterator():
-            current_completion += output
-
-            # test for stop conditions, if specified
-            if stop:
-                for s in stop:
-                    if s in current_completion:
-                        prediction.cancel()
-                        stop_index = current_completion.find(s)
-                        current_completion = current_completion[:stop_index]
-                        stop_condition_reached = True
-                        break
-
-            if stop_condition_reached:
-                break
-
-            if self.streaming and run_manager:
-                run_manager.on_llm_new_token(output)
-        return current_completion
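`_create_prediction` now merges per-call `**kwargs` into the input dict after `self.input`, so call-time parameters override the model kwargs fixed at construction. An illustrative sketch of that merge, with made-up keys and values:

```python
# Illustrative only: the dict-merge precedence used in _create_prediction.
prompt_key = "prompt"
model_input = {"temperature": 0.5, "max_length": 500}  # self.input, set at construction
call_kwargs = {"temperature": 0.9}                      # per-call **kwargs

input_ = {prompt_key: "Tell me a joke.", **model_input, **call_kwargs}
print(input_)
# {'prompt': 'Tell me a joke.', 'temperature': 0.9, 'max_length': 500}
```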