mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
core[minor], langchain[minor]: deprecate old Chain and LLM methods (#15499)
This commit is contained in:
parent
fd5fbb507d
commit
00dfbd2a99
@ -15,6 +15,7 @@ from typing import (
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.runnables import Runnable, RunnableSerializable
|
||||
@ -60,17 +61,6 @@ class BaseLanguageModel(
|
||||
"""Abstract base class for interfacing with language models.
|
||||
|
||||
All language model wrappers inherit from BaseLanguageModel.
|
||||
|
||||
Exposes three main methods:
|
||||
- generate_prompt: generate language model outputs for a sequence of prompt
|
||||
values. A prompt value is a model input that can be converted to any language
|
||||
model input format (string or messages).
|
||||
- predict: pass in a single string to a language model and return a string
|
||||
prediction.
|
||||
- predict_messages: pass in a sequence of BaseMessages (corresponding to a single
|
||||
model call) to a language model and return a BaseMessage prediction.
|
||||
|
||||
Each of these has an equivalent asynchronous method.
|
||||
"""
|
||||
|
||||
@property
|
||||
@ -160,11 +150,12 @@ class BaseLanguageModel(
|
||||
prompt and additional model provider-specific output.
|
||||
"""
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
@abstractmethod
|
||||
def predict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
"""Pass a single string input to the model and return a string prediction.
|
||||
"""Pass a single string input to the model and return a string.
|
||||
|
||||
Use this method when passing in raw text. If you want to pass in specific
|
||||
types of chat messages, use predict_messages.
|
||||
@ -180,6 +171,7 @@ class BaseLanguageModel(
|
||||
Top model prediction as a string.
|
||||
"""
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
@abstractmethod
|
||||
def predict_messages(
|
||||
self,
|
||||
@ -188,7 +180,7 @@ class BaseLanguageModel(
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
"""Pass a message sequence to the model and return a message prediction.
|
||||
"""Pass a message sequence to the model and return a message.
|
||||
|
||||
Use this method when passing in chat messages. If you want to pass in raw text,
|
||||
use predict.
|
||||
@ -204,11 +196,12 @@ class BaseLanguageModel(
|
||||
Top model prediction as a message.
|
||||
"""
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
@abstractmethod
|
||||
async def apredict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
"""Asynchronously pass a string to the model and return a string prediction.
|
||||
"""Asynchronously pass a string to the model and return a string.
|
||||
|
||||
Use this method when calling pure text generation models and only the top
|
||||
candidate generation is needed.
|
||||
@ -224,6 +217,7 @@ class BaseLanguageModel(
|
||||
Top model prediction as a string.
|
||||
"""
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
@abstractmethod
|
||||
async def apredict_messages(
|
||||
self,
|
||||
@ -232,7 +226,7 @@ class BaseLanguageModel(
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
"""Asynchronously pass messages to the model and return a message prediction.
|
||||
"""Asynchronously pass messages to the model and return a message.
|
||||
|
||||
Use this method when calling chat models and only the top
|
||||
candidate generation is needed.
|
||||
|
@ -16,6 +16,7 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -108,7 +109,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
"""Callbacks to add to the run trace."""
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""Callback manager to add to the run trace."""
|
||||
"""[DEPRECATED] Callback manager to add to the run trace."""
|
||||
tags: Optional[List[str]] = Field(default=None, exclude=True)
|
||||
"""Tags to add to the run trace."""
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
||||
@ -345,7 +346,30 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
"""Pass a sequence of prompts to the model and return model generations.
|
||||
|
||||
This method should make use of batched calls for models that expose a batched
|
||||
API.
|
||||
|
||||
Use this method when you want to:
|
||||
1. take advantage of batched calls,
|
||||
2. need more output from the model than just the top generated value,
|
||||
3. are building chains that are agnostic to the underlying language model
|
||||
type (e.g., pure text completion models vs chat models).
|
||||
|
||||
Args:
|
||||
messages: List of list of messages.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
callbacks: Callbacks to pass through. Used for executing additional
|
||||
functionality, such as logging or streaming, throughout generation.
|
||||
**kwargs: Arbitrary additional keyword arguments. These are usually passed
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
"""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop}
|
||||
|
||||
@ -407,7 +431,30 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
"""Asynchronously pass a sequence of prompts to a model and return generations.
|
||||
|
||||
This method should make use of batched calls for models that expose a batched
|
||||
API.
|
||||
|
||||
Use this method when you want to:
|
||||
1. take advantage of batched calls,
|
||||
2. need more output from the model than just the top generated value,
|
||||
3. are building chains that are agnostic to the underlying language model
|
||||
type (e.g., pure text completion models vs chat models).
|
||||
|
||||
Args:
|
||||
messages: List of list of messages.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
callbacks: Callbacks to pass through. Used for executing additional
|
||||
functionality, such as logging or streaming, throughout generation.
|
||||
**kwargs: Arbitrary additional keyword arguments. These are usually passed
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
"""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop}
|
||||
|
||||
@ -632,6 +679,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
@ -663,11 +711,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
else:
|
||||
raise ValueError("Unexpected generation type")
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def call_as_llm(
|
||||
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
return self.predict(message, stop=stop, **kwargs)
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def predict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
@ -681,6 +731,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
else:
|
||||
raise ValueError("Cannot use predict when output is not a string.")
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def predict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
@ -694,6 +745,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
_stop = list(stop)
|
||||
return self(messages, stop=_stop, **kwargs)
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
async def apredict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
@ -709,6 +761,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
else:
|
||||
raise ValueError("Cannot use predict when output is not a string.")
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
async def apredict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
|
@ -36,6 +36,7 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -157,14 +158,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
It should take in a prompt and return a string."""
|
||||
|
||||
cache: Optional[bool] = None
|
||||
"""Whether to cache the response."""
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""Callbacks to add to the run trace."""
|
||||
tags: Optional[List[str]] = Field(default=None, exclude=True)
|
||||
"""Tags to add to the run trace."""
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
||||
"""Metadata to add to the run trace."""
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""[DEPRECATED]"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -576,7 +580,30 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
run_name: Optional[Union[str, List[str]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
"""Pass a sequence of prompts to a model and return generations.
|
||||
|
||||
This method should make use of batched calls for models that expose a batched
|
||||
API.
|
||||
|
||||
Use this method when you want to:
|
||||
1. take advantage of batched calls,
|
||||
2. need more output from the model than just the top generated value,
|
||||
3. are building chains that are agnostic to the underlying language model
|
||||
type (e.g., pure text completion models vs chat models).
|
||||
|
||||
Args:
|
||||
prompts: List of string prompts.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
callbacks: Callbacks to pass through. Used for executing additional
|
||||
functionality, such as logging or streaming, throughout generation.
|
||||
**kwargs: Arbitrary additional keyword arguments. These are usually passed
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
"""
|
||||
if not isinstance(prompts, list):
|
||||
raise ValueError(
|
||||
"Argument 'prompts' is expected to be of type List[str], received"
|
||||
@ -754,7 +781,30 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
run_name: Optional[Union[str, List[str]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
"""Asynchronously pass a sequence of prompts to a model and return generations.
|
||||
|
||||
This method should make use of batched calls for models that expose a batched
|
||||
API.
|
||||
|
||||
Use this method when you want to:
|
||||
1. take advantage of batched calls,
|
||||
2. need more output from the model than just the top generated value,
|
||||
3. are building chains that are agnostic to the underlying language model
|
||||
type (e.g., pure text completion models vs chat models).
|
||||
|
||||
Args:
|
||||
prompts: List of string prompts.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
callbacks: Callbacks to pass through. Used for executing additional
|
||||
functionality, such as logging or streaming, throughout generation.
|
||||
**kwargs: Arbitrary additional keyword arguments. These are usually passed
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
"""
|
||||
# Create callback managers
|
||||
if isinstance(callbacks, list) and (
|
||||
isinstance(callbacks[0], (list, BaseCallbackManager))
|
||||
@ -927,6 +977,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
)
|
||||
return result.generations[0][0].text
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def predict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
@ -936,6 +987,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
_stop = list(stop)
|
||||
return self(text, stop=_stop, **kwargs)
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def predict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
@ -951,6 +1003,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
content = self(text, stop=_stop, **kwargs)
|
||||
return AIMessage(content=content)
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
async def apredict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
@ -960,6 +1013,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
_stop = list(stop)
|
||||
return await self._call_async(text, stop=_stop, **kwargs)
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
async def apredict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
|
@ -5,9 +5,10 @@ import logging
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Type, Union, cast
|
||||
|
||||
import yaml
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.memory import BaseMemory
|
||||
from langchain_core.outputs import RunInfo
|
||||
@ -67,6 +68,43 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
chains and cannot return as rich of an output as `__call__`.
|
||||
"""
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
"""Optional memory object. Defaults to None.
|
||||
Memory is a class that gets called at the start
|
||||
and at the end of every chain. At the start, memory loads variables and passes
|
||||
them along in the chain. At the end, it saves any returned variables.
|
||||
There are many different types of memory - please see memory docs
|
||||
for the full catalog."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
"""Optional list of callback handlers (or callback manager). Defaults to None.
|
||||
Callback handlers are called throughout the lifecycle of a call to a chain,
|
||||
starting with on_chain_start, ending with on_chain_end or on_chain_error.
|
||||
Each custom chain can optionally call additional callback methods, see Callback docs
|
||||
for full details."""
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
|
||||
will be printed to the console. Defaults to the global `verbose` value,
|
||||
accessible via `langchain.globals.get_verbose()`."""
|
||||
tags: Optional[List[str]] = None
|
||||
"""Optional list of tags associated with the chain. Defaults to None.
|
||||
These tags will be associated with each call to this chain,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
"""Optional metadata associated with the chain. Defaults to None.
|
||||
This metadata will be associated with each call to this chain,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""[DEPRECATED] Use `callbacks` instead."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
@ -90,14 +128,45 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
config = ensure_config(config)
|
||||
return self(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
callbacks = config.get("callbacks")
|
||||
tags = config.get("tags")
|
||||
metadata = config.get("metadata")
|
||||
run_name = config.get("run_name")
|
||||
include_run_info = kwargs.get("include_run_info", False)
|
||||
return_only_outputs = kwargs.get("return_only_outputs", False)
|
||||
|
||||
inputs = self.prep_inputs(input)
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
self._call(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._call(inputs)
|
||||
)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
@ -106,51 +175,45 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
config = ensure_config(config)
|
||||
return await self.acall(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
callbacks = config.get("callbacks")
|
||||
tags = config.get("tags")
|
||||
metadata = config.get("metadata")
|
||||
run_name = config.get("run_name")
|
||||
include_run_info = kwargs.get("include_run_info", False)
|
||||
return_only_outputs = kwargs.get("return_only_outputs", False)
|
||||
|
||||
inputs = self.prep_inputs(input)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
"""Optional memory object. Defaults to None.
|
||||
Memory is a class that gets called at the start
|
||||
and at the end of every chain. At the start, memory loads variables and passes
|
||||
them along in the chain. At the end, it saves any returned variables.
|
||||
There are many different types of memory - please see memory docs
|
||||
for the full catalog."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
"""Optional list of callback handlers (or callback manager). Defaults to None.
|
||||
Callback handlers are called throughout the lifecycle of a call to a chain,
|
||||
starting with on_chain_start, ending with on_chain_end or on_chain_error.
|
||||
Each custom chain can optionally call additional callback methods, see Callback docs
|
||||
for full details."""
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""Deprecated, use `callbacks` instead."""
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
|
||||
will be printed to the console. Defaults to the global `verbose` value,
|
||||
accessible via `langchain.globals.get_verbose()`."""
|
||||
tags: Optional[List[str]] = None
|
||||
"""Optional list of tags associated with the chain. Defaults to None.
|
||||
These tags will be associated with each call to this chain,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
"""Optional metadata associated with the chain. Defaults to None.
|
||||
This metadata will be associated with each call to this chain,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
await self._acall(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else await self._acall(inputs)
|
||||
)
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
await run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
@ -253,6 +316,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
None, self._call, inputs, run_manager.get_sync() if run_manager else None
|
||||
)
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
@ -289,39 +353,21 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
self._call(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._call(inputs)
|
||||
)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
config = {
|
||||
"callbacks": callbacks,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
"run_name": run_name,
|
||||
}
|
||||
|
||||
return self.invoke(
|
||||
inputs,
|
||||
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
|
||||
return_only_outputs=return_only_outputs,
|
||||
include_run_info=include_run_info,
|
||||
)
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
async def acall(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
@ -358,38 +404,18 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
config = {
|
||||
"callbacks": callbacks,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
"run_name": run_name,
|
||||
}
|
||||
return await self.ainvoke(
|
||||
inputs,
|
||||
name=run_name,
|
||||
cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}),
|
||||
return_only_outputs=return_only_outputs,
|
||||
include_run_info=include_run_info,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
await self._acall(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else await self._acall(inputs)
|
||||
)
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
await run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
def prep_outputs(
|
||||
self,
|
||||
@ -458,6 +484,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
)
|
||||
return self.output_keys[0]
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -528,6 +555,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
async def arun(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -665,6 +693,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
@deprecated("0.1.0", alternative="batch", removal="0.2.0")
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
|
Loading…
Reference in New Issue
Block a user