core[minor], langchain[minor]: deprecate old Chain and LLM methods (#15499)

This commit is contained in:
Bagatur
2024-01-05 11:58:35 -05:00
committed by GitHub
parent fd5fbb507d
commit 00dfbd2a99
4 changed files with 265 additions and 135 deletions

View File

@@ -5,9 +5,10 @@ import logging
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union, cast
import yaml
from langchain_core._api import deprecated
from langchain_core.load.dump import dumpd
from langchain_core.memory import BaseMemory
from langchain_core.outputs import RunInfo
@@ -67,6 +68,43 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
chains and cannot return as rich of an output as `__call__`.
"""
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.
Memory is a class that gets called at the start
and at the end of every chain. At the start, memory loads variables and passes
them along in the chain. At the end, it saves any returned variables.
There are many different types of memory - please see memory docs
for the full catalog."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Optional list of callback handlers (or callback manager). Defaults to None.
Callback handlers are called throughout the lifecycle of a call to a chain,
starting with on_chain_start, ending with on_chain_end or on_chain_error.
Each custom chain can optionally call additional callback methods, see Callback docs
for full details."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
will be printed to the console. Defaults to the global `verbose` value,
accessible via `langchain.globals.get_verbose()`."""
tags: Optional[List[str]] = None
"""Optional list of tags associated with the chain. Defaults to None.
These tags will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None.
This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Use `callbacks` instead."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@@ -90,14 +128,45 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
**kwargs: Any,
) -> Dict[str, Any]:
config = ensure_config(config)
return self(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name")
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
inputs = self.prep_inputs(input)
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
async def ainvoke(
self,
@@ -106,51 +175,45 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
**kwargs: Any,
) -> Dict[str, Any]:
config = ensure_config(config)
return await self.acall(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name")
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
inputs = self.prep_inputs(input)
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.
Memory is a class that gets called at the start
and at the end of every chain. At the start, memory loads variables and passes
them along in the chain. At the end, it saves any returned variables.
There are many different types of memory - please see memory docs
for the full catalog."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Optional list of callback handlers (or callback manager). Defaults to None.
Callback handlers are called throughout the lifecycle of a call to a chain,
starting with on_chain_start, ending with on_chain_end or on_chain_error.
Each custom chain can optionally call additional callback methods, see Callback docs
for full details."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""Deprecated, use `callbacks` instead."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
will be printed to the console. Defaults to the global `verbose` value,
accessible via `langchain.globals.get_verbose()`."""
tags: Optional[List[str]] = None
"""Optional list of tags associated with the chain. Defaults to None.
These tags will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None.
This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
outputs = (
await self._acall(inputs, run_manager=run_manager)
if new_arg_supported
else await self._acall(inputs)
)
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
@property
def _chain_type(self) -> str:
@@ -253,6 +316,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
None, self._call, inputs, run_manager.get_sync() if run_manager else None
)
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
@@ -289,39 +353,21 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
config = {
"callbacks": callbacks,
"tags": tags,
"metadata": metadata,
"run_name": run_name,
}
return self.invoke(
inputs,
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
return_only_outputs=return_only_outputs,
include_run_info=include_run_info,
)
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def acall(
self,
inputs: Union[Dict[str, Any], Any],
@@ -358,38 +404,18 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
inputs = self.prep_inputs(inputs)
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
dumpd(self),
config = {
"callbacks": callbacks,
"tags": tags,
"metadata": metadata,
"run_name": run_name,
}
return await self.ainvoke(
inputs,
name=run_name,
cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}),
return_only_outputs=return_only_outputs,
include_run_info=include_run_info,
)
try:
outputs = (
await self._acall(inputs, run_manager=run_manager)
if new_arg_supported
else await self._acall(inputs)
)
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
@@ -458,6 +484,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
)
return self.output_keys[0]
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def run(
self,
*args: Any,
@@ -528,6 +555,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def arun(
self,
*args: Any,
@@ -665,6 +693,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
else:
raise ValueError(f"{save_path} must be json or yaml")
@deprecated("0.1.0", alternative="batch", removal="0.2.0")
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]: