mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
core[patch]: move some attr/methods to BaseLanguageModel (#18936)
Cleans up some shared code between `BaseLLM` and `BaseChatModel`. One functional difference to make it more consistent (see comment)
This commit is contained in:
parent
4ff6aa5c78
commit
0d888a65cb
@ -7,6 +7,7 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
@ -25,7 +26,7 @@ from langchain_core.messages import (
|
|||||||
get_buffer_string,
|
get_buffer_string,
|
||||||
)
|
)
|
||||||
from langchain_core.prompt_values import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel, Field, validator
|
||||||
from langchain_core.runnables import Runnable, RunnableSerializable
|
from langchain_core.runnables import Runnable, RunnableSerializable
|
||||||
from langchain_core.utils import get_pydantic_field_names
|
from langchain_core.utils import get_pydantic_field_names
|
||||||
|
|
||||||
@ -63,6 +64,12 @@ LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
|
|||||||
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
|
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_verbosity() -> bool:
|
||||||
|
from langchain_core.globals import get_verbose
|
||||||
|
|
||||||
|
return get_verbose()
|
||||||
|
|
||||||
|
|
||||||
class BaseLanguageModel(
|
class BaseLanguageModel(
|
||||||
RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC
|
RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC
|
||||||
):
|
):
|
||||||
@ -71,6 +78,28 @@ class BaseLanguageModel(
|
|||||||
All language model wrappers inherit from BaseLanguageModel.
|
All language model wrappers inherit from BaseLanguageModel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
cache: Optional[bool] = None
|
||||||
|
"""Whether to cache the response."""
|
||||||
|
verbose: bool = Field(default_factory=_get_verbosity)
|
||||||
|
"""Whether to print out response text."""
|
||||||
|
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||||
|
"""Callbacks to add to the run trace."""
|
||||||
|
tags: Optional[List[str]] = Field(default=None, exclude=True)
|
||||||
|
"""Tags to add to the run trace."""
|
||||||
|
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
||||||
|
"""Metadata to add to the run trace."""
|
||||||
|
|
||||||
|
@validator("verbose", pre=True, always=True)
|
||||||
|
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||||
|
"""If verbose is None, set it.
|
||||||
|
|
||||||
|
This allows users to pass in None as verbose to access the global setting.
|
||||||
|
"""
|
||||||
|
if verbose is None:
|
||||||
|
return _get_verbosity()
|
||||||
|
else:
|
||||||
|
return verbose
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def InputType(self) -> TypeAlias:
|
def InputType(self) -> TypeAlias:
|
||||||
"""Get the input type for this runnable."""
|
"""Get the input type for this runnable."""
|
||||||
@ -257,6 +286,11 @@ class BaseLanguageModel(
|
|||||||
Top model prediction as a message.
|
Top model prediction as a message.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {}
|
||||||
|
|
||||||
def get_token_ids(self, text: str) -> List[int]:
|
def get_token_ids(self, text: str) -> List[int]:
|
||||||
"""Return the ordered ids of the tokens in a text.
|
"""Return the ordered ids of the tokens in a text.
|
||||||
|
|
||||||
|
@ -54,12 +54,6 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
|
||||||
def _get_verbosity() -> bool:
|
|
||||||
from langchain_core.globals import get_verbose
|
|
||||||
|
|
||||||
return get_verbose()
|
|
||||||
|
|
||||||
|
|
||||||
def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
|
def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
|
||||||
"""Generate from a stream."""
|
"""Generate from a stream."""
|
||||||
|
|
||||||
@ -125,18 +119,8 @@ def _as_async_iterator(sync_iterator: Callable) -> Callable:
|
|||||||
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||||
"""Base class for Chat models."""
|
"""Base class for Chat models."""
|
||||||
|
|
||||||
cache: Optional[bool] = None
|
|
||||||
"""Whether to cache the response."""
|
|
||||||
verbose: bool = Field(default_factory=_get_verbosity)
|
|
||||||
"""Whether to print out response text."""
|
|
||||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
|
||||||
"""Callbacks to add to the run trace."""
|
|
||||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||||
"""[DEPRECATED] Callback manager to add to the run trace."""
|
"""[DEPRECATED] Callback manager to add to the run trace."""
|
||||||
tags: Optional[List[str]] = Field(default=None, exclude=True)
|
|
||||||
"""Tags to add to the run trace."""
|
|
||||||
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
|
||||||
"""Metadata to add to the run trace."""
|
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||||
@ -816,11 +800,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
_stop = list(stop)
|
_stop = list(stop)
|
||||||
return await self._call_async(messages, stop=_stop, **kwargs)
|
return await self._call_async(messages, stop=_stop, **kwargs)
|
||||||
|
|
||||||
@property
|
|
||||||
def _identifying_params(self) -> Dict[str, Any]:
|
|
||||||
"""Get the identifying parameters."""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Base interface for large language models to expose."""
|
"""Base interface for large language models to expose."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -16,7 +17,6 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
@ -56,19 +56,13 @@ from langchain_core.messages import (
|
|||||||
)
|
)
|
||||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
||||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator, validator
|
from langchain_core.pydantic_v1 import Field, root_validator
|
||||||
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
|
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
|
||||||
from langchain_core.runnables.config import run_in_executor
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _get_verbosity() -> bool:
|
|
||||||
from langchain_core.globals import get_verbose
|
|
||||||
|
|
||||||
return get_verbose()
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def _log_error_once(msg: str) -> None:
|
def _log_error_once(msg: str) -> None:
|
||||||
"""Log an error once."""
|
"""Log an error once."""
|
||||||
@ -200,16 +194,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
|
|
||||||
It should take in a prompt and return a string."""
|
It should take in a prompt and return a string."""
|
||||||
|
|
||||||
cache: Optional[bool] = None
|
|
||||||
"""Whether to cache the response."""
|
|
||||||
verbose: bool = Field(default_factory=_get_verbosity)
|
|
||||||
"""Whether to print out response text."""
|
|
||||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
|
||||||
"""Callbacks to add to the run trace."""
|
|
||||||
tags: Optional[List[str]] = Field(default=None, exclude=True)
|
|
||||||
"""Tags to add to the run trace."""
|
|
||||||
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
|
||||||
"""Metadata to add to the run trace."""
|
|
||||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||||
"""[DEPRECATED]"""
|
"""[DEPRECATED]"""
|
||||||
|
|
||||||
@ -229,17 +213,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
values["callbacks"] = values.pop("callback_manager", None)
|
values["callbacks"] = values.pop("callback_manager", None)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@validator("verbose", pre=True, always=True)
|
|
||||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
|
||||||
"""If verbose is None, set it.
|
|
||||||
|
|
||||||
This allows users to pass in None as verbose to access the global setting.
|
|
||||||
"""
|
|
||||||
if verbose is None:
|
|
||||||
return _get_verbosity()
|
|
||||||
else:
|
|
||||||
return verbose
|
|
||||||
|
|
||||||
# --- Runnable methods ---
|
# --- Runnable methods ---
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -1081,11 +1054,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
content = await self._call_async(text, stop=_stop, **kwargs)
|
content = await self._call_async(text, stop=_stop, **kwargs)
|
||||||
return AIMessage(content=content)
|
return AIMessage(content=content)
|
||||||
|
|
||||||
@property
|
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
|
||||||
"""Get the identifying parameters."""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""Get a string representation of the object for printing."""
|
"""Get a string representation of the object for printing."""
|
||||||
cls_name = f"\033[1m{self.__class__.__name__}\033[0m"
|
cls_name = f"\033[1m{self.__class__.__name__}\033[0m"
|
||||||
|
Loading…
Reference in New Issue
Block a user