core[minor], langchain[minor]: deprecate old Chain and LLM methods (#15499)

Bagatur 2024-01-05 11:58:35 -05:00 committed by GitHub
parent fd5fbb507d
commit 00dfbd2a99
4 changed files with 265 additions and 135 deletions

libs/core/langchain_core/language_models/base.py

@@ -15,6 +15,7 @@ from typing import (
from typing_extensions import TypeAlias
from langchain_core._api import deprecated
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import Runnable, RunnableSerializable
@@ -60,17 +61,6 @@ class BaseLanguageModel(
"""Abstract base class for interfacing with language models.
All language model wrappers inherit from BaseLanguageModel.
Exposes three main methods:
- generate_prompt: generate language model outputs for a sequence of prompt
values. A prompt value is a model input that can be converted to any language
model input format (string or messages).
- predict: pass in a single string to a language model and return a string
prediction.
- predict_messages: pass in a sequence of BaseMessages (corresponding to a single
model call) to a language model and return a BaseMessage prediction.
Each of these has an equivalent asynchronous method.
"""
@property
@@ -160,11 +150,12 @@ class BaseLanguageModel(
prompt and additional model provider-specific output.
"""
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
@abstractmethod
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
"""Pass a single string input to the model and return a string prediction.
"""Pass a single string input to the model and return a string.
Use this method when passing in raw text. If you want to pass in specific
types of chat messages, use predict_messages.
@@ -180,6 +171,7 @@ class BaseLanguageModel(
Top model prediction as a string.
"""
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
@abstractmethod
def predict_messages(
self,
@@ -188,7 +180,7 @@ class BaseLanguageModel(
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
"""Pass a message sequence to the model and return a message prediction.
"""Pass a message sequence to the model and return a message.
Use this method when passing in chat messages. If you want to pass in raw text,
use predict.
@@ -204,11 +196,12 @@ class BaseLanguageModel(
Top model prediction as a message.
"""
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
@abstractmethod
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
"""Asynchronously pass a string to the model and return a string prediction.
"""Asynchronously pass a string to the model and return a string.
Use this method when calling pure text generation models and only the top
candidate generation is needed.
@@ -224,6 +217,7 @@ class BaseLanguageModel(
Top model prediction as a string.
"""
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
@abstractmethod
async def apredict_messages(
self,
@@ -232,7 +226,7 @@ class BaseLanguageModel(
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
"""Asynchronously pass messages to the model and return a message prediction.
"""Asynchronously pass messages to the model and return a message.
Use this method when calling chat models and only the top
candidate generation is needed.
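The async pair migrates the same way, with `apredict` and `apredict_messages` becoming `ainvoke` (same assumed stub as above):

import asyncio

from langchain_core.language_models.fake_chat_models import FakeListChatModel  # assumed path

async def main() -> None:
    model = FakeListChatModel(responses=["hello"])
    # Deprecated: await model.apredict("Say hello")
    message = await model.ainvoke("Say hello")
    print(message.content)

asyncio.run(main())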

libs/core/langchain_core/language_models/chat_models.py

@@ -16,6 +16,7 @@ from typing import (
cast,
)
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
@@ -108,7 +109,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""Callback manager to add to the run trace."""
"""[DEPRECATED] Callback manager to add to the run trace."""
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
@@ -345,7 +346,30 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_name: Optional[str] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
"""Pass a sequence of prompts to the model and return model generations.
This method should make use of batched calls for models that expose a batched
API.
Use this method when you:
1. want to take advantage of batched calls,
2. need more output from the model than just the top generated value,
3. are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
Args:
messages: List of lists of messages.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop}
@@ -407,7 +431,30 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_name: Optional[str] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
"""Asynchronously pass a sequence of prompts to a model and return generations.
This method should make use of batched calls for models that expose a batched
API.
Use this method when you:
1. want to take advantage of batched calls,
2. need more output from the model than just the top generated value,
3. are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
Args:
messages: List of lists of messages.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop}
@@ -632,6 +679,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
) -> AsyncIterator[ChatGenerationChunk]:
raise NotImplementedError()
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def __call__(
self,
messages: List[BaseMessage],
@@ -663,11 +711,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
raise ValueError("Unexpected generation type")
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def call_as_llm(
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
return self.predict(message, stop=stop, **kwargs)
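Both deprecated call styles collapse into `invoke` (stub assumed as before):

from langchain_core.language_models.fake_chat_models import FakeListChatModel  # assumed path
from langchain_core.messages import HumanMessage

chat = FakeListChatModel(responses=["pong", "pong"])

# Deprecated: chat([HumanMessage(content="ping")]) returned a BaseMessage,
# and chat.call_as_llm("ping") returned a str.
reply = chat.invoke([HumanMessage(content="ping")])  # AIMessage
text = chat.invoke("ping").content  # string content of the reply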
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -681,6 +731,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
raise ValueError("Cannot use predict when output is not a string.")
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def predict_messages(
self,
messages: List[BaseMessage],
@@ -694,6 +745,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
_stop = list(stop)
return self(messages, stop=_stop, **kwargs)
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -709,6 +761,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
raise ValueError("Cannot use predict when output is not a string.")
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def apredict_messages(
self,
messages: List[BaseMessage],

libs/core/langchain_core/language_models/llms.py

@@ -36,6 +36,7 @@ from tenacity import (
wait_exponential,
)
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
@@ -157,14 +158,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
It should take in a prompt and return a string."""
cache: Optional[bool] = None
"""Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED]"""
class Config:
"""Configuration for this pydantic object."""
@@ -576,7 +580,30 @@ class BaseLLM(BaseLanguageModel[str], ABC):
run_name: Optional[Union[str, List[str]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
"""Pass a sequence of prompts to a model and return generations.
This method should make use of batched calls for models that expose a batched
API.
Use this method when you:
1. want to take advantage of batched calls,
2. need more output from the model than just the top generated value,
3. are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
Args:
prompts: List of string prompts.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
if not isinstance(prompts, list):
raise ValueError(
"Argument 'prompts' is expected to be of type List[str], received"
@@ -754,7 +781,30 @@ class BaseLLM(BaseLanguageModel[str], ABC):
run_name: Optional[Union[str, List[str]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
"""Asynchronously pass a sequence of prompts to a model and return generations.
This method should make use of batched calls for models that expose a batched
API.
Use this method when you:
1. want to take advantage of batched calls,
2. need more output from the model than just the top generated value,
3. are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
Args:
prompts: List of string prompts.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
# Create callback managers
if isinstance(callbacks, list) and (
isinstance(callbacks[0], (list, BaseCallbackManager))
@@ -927,6 +977,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
return result.generations[0][0].text
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -936,6 +987,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
_stop = list(stop)
return self(text, stop=_stop, **kwargs)
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def predict_messages(
self,
messages: List[BaseMessage],
@@ -951,6 +1003,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
content = self(text, stop=_stop, **kwargs)
return AIMessage(content=content)
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -960,6 +1013,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
_stop = list(stop)
return await self._call_async(text, stop=_stop, **kwargs)
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def apredict_messages(
self,
messages: List[BaseMessage],
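For completion models the callable and predict styles migrate to `invoke`/`ainvoke` just as they do for chat models (same assumed stub):

from langchain_core.language_models.fake import FakeListLLM  # assumed path

llm = FakeListLLM(responses=["a joke"])

# Deprecated: llm("Tell me a joke") and llm.predict("Tell me a joke")
print(llm.invoke("Tell me a joke"))  # plain str
# In a coroutine, `await llm.ainvoke(...)` replaces `await llm.apredict(...)`.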

libs/langchain/langchain/chains/base.py

@@ -5,9 +5,10 @@ import logging
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union, cast
import yaml
from langchain_core._api import deprecated
from langchain_core.load.dump import dumpd
from langchain_core.memory import BaseMemory
from langchain_core.outputs import RunInfo
@@ -67,6 +68,43 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
chains and cannot return as rich an output as `__call__`.
"""
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.
Memory is a class that gets called at the start
and at the end of every chain. At the start, memory loads variables and passes
them along in the chain. At the end, it saves any returned variables.
There are many different types of memory - please see memory docs
for the full catalog."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Optional list of callback handlers (or callback manager). Defaults to None.
Callback handlers are called throughout the lifecycle of a call to a chain,
starting with on_chain_start, ending with on_chain_end or on_chain_error.
Each custom chain can optionally call additional callback methods; see the Callback docs
for full details."""
verbose: bool = Field(default_factory=_get_verbosity)
Whether or not to run in verbose mode. In verbose mode, some intermediate logs
will be printed to the console. Defaults to the global `verbose` value,
accessible via `langchain.globals.get_verbose()`."""
tags: Optional[List[str]] = None
"""Optional list of tags associated with the chain. Defaults to None.
These tags will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these, e.g., to identify a specific instance of a chain with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None.
This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these, e.g., to identify a specific instance of a chain with its use case.
"""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Use `callbacks` instead."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
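Instance-level `callbacks`, `tags`, and `metadata` still apply to every run; per-call values now travel in the `RunnableConfig` passed to `invoke`. A sketch (`my_chain` and its `question` input key are hypothetical):

my_chain.invoke(
    {"question": "What is 2 + 2?"},  # hypothetical input key
    config={
        "tags": ["prod"],  # combined with my_chain.tags
        "metadata": {"user_id": "u-123"},  # combined with my_chain.metadata
        "run_name": "qa-run",
    },
)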
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@@ -90,14 +128,45 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
**kwargs: Any,
) -> Dict[str, Any]:
config = ensure_config(config)
return self(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name")
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
inputs = self.prep_inputs(input)
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
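To exercise the inlined path end to end, a minimal custom chain (illustrative only; EchoChain and its keys are made up):

from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain_core.callbacks import CallbackManagerForChainRun

class EchoChain(Chain):
    """Toy chain that echoes its input."""

    @property
    def input_keys(self) -> List[str]:
        return ["text"]

    @property
    def output_keys(self) -> List[str]:
        return ["echo"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        return {"echo": inputs["text"]}

out = EchoChain().invoke({"text": "hi"}, include_run_info=True)
# out holds the inputs, the outputs, and a RunInfo entry under the "__run" key.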
async def ainvoke(
self,
@@ -106,51 +175,45 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
**kwargs: Any,
) -> Dict[str, Any]:
config = ensure_config(config)
return await self.acall(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name")
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
inputs = self.prep_inputs(input)
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.
Memory is a class that gets called at the start
and at the end of every chain. At the start, memory loads variables and passes
them along in the chain. At the end, it saves any returned variables.
There are many different types of memory - please see memory docs
for the full catalog."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Optional list of callback handlers (or callback manager). Defaults to None.
Callback handlers are called throughout the lifecycle of a call to a chain,
starting with on_chain_start, ending with on_chain_end or on_chain_error.
Each custom chain can optionally call additional callback methods, see Callback docs
for full details."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""Deprecated, use `callbacks` instead."""
verbose: bool = Field(default_factory=_get_verbosity)
Whether or not to run in verbose mode. In verbose mode, some intermediate logs
will be printed to the console. Defaults to the global `verbose` value,
accessible via `langchain.globals.get_verbose()`."""
tags: Optional[List[str]] = None
"""Optional list of tags associated with the chain. Defaults to None.
These tags will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these, e.g., to identify a specific instance of a chain with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None.
This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these, e.g., to identify a specific instance of a chain with its use case.
"""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
outputs = (
await self._acall(inputs, run_manager=run_manager)
if new_arg_supported
else await self._acall(inputs)
)
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
@property
def _chain_type(self) -> str:
@@ -253,6 +316,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
None, self._call, inputs, run_manager.get_sync() if run_manager else None
)
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
@@ -289,39 +353,21 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
config = {
"callbacks": callbacks,
"tags": tags,
"metadata": metadata,
"run_name": run_name,
}
return self.invoke(
inputs,
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
return_only_outputs=return_only_outputs,
include_run_info=include_run_info,
)
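Reusing the EchoChain sketch above, the `__call__`/`acall` migration is mechanical:

chain = EchoChain()

# Deprecated:
#     chain({"text": "hi"})
#     chain({"text": "hi"}, return_only_outputs=True)
#     await chain.acall({"text": "hi"})

chain.invoke({"text": "hi"})
chain.invoke({"text": "hi"}, return_only_outputs=True)
# In a coroutine: await chain.ainvoke({"text": "hi"})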
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def acall(
self,
inputs: Union[Dict[str, Any], Any],
@@ -358,38 +404,18 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
inputs = self.prep_inputs(inputs)
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
dumpd(self),
config = {
"callbacks": callbacks,
"tags": tags,
"metadata": metadata,
"run_name": run_name,
}
return await self.ainvoke(
inputs,
name=run_name,
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
return_only_outputs=return_only_outputs,
include_run_info=include_run_info,
)
try:
outputs = (
await self._acall(inputs, run_manager=run_manager)
if new_arg_supported
else await self._acall(inputs)
)
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
@@ -458,6 +484,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
)
return self.output_keys[0]
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def run(
self,
*args: Any,
@@ -528,6 +555,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
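`run` returned the single output value directly, so migrated call sites index into `invoke`'s output dict (again using the hypothetical EchoChain):

chain = EchoChain()

# Deprecated: chain.run("hi") and chain.run(text="hi")
echoed = chain.invoke({"text": "hi"})["echo"]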
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def arun(
self,
*args: Any,
@@ -665,6 +693,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
else:
raise ValueError(f"{save_path} must be json or yaml")
@deprecated("0.1.0", alternative="batch", removal="0.2.0")
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
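`apply` was a thin loop over inputs; `batch` replaces it and can fan the calls out over a thread pool. With the hypothetical EchoChain:

chain = EchoChain()

# Deprecated: chain.apply([{"text": "a"}, {"text": "b"}])
results = chain.batch([{"text": "a"}, {"text": "b"}])  # list of output dicts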