mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 14:35:50 +00:00
Fix .dict() for agent/chain (#11436)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
1e59c44d36
commit
2f490be09b
@ -145,10 +145,13 @@ class BaseSingleActionAgent(BaseModel):
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of agent."""
|
||||
_dict = super().dict()
|
||||
try:
|
||||
_type = self._agent_type
|
||||
except NotImplementedError:
|
||||
_type = None
|
||||
if isinstance(_type, AgentType):
|
||||
_dict["_type"] = str(_type.value)
|
||||
else:
|
||||
elif _type is not None:
|
||||
_dict["_type"] = _type
|
||||
return _dict
|
||||
|
||||
@ -175,6 +178,8 @@ class BaseSingleActionAgent(BaseModel):
|
||||
|
||||
# Fetch dictionary to save
|
||||
agent_dict = self.dict()
|
||||
if "_type" not in agent_dict:
|
||||
raise NotImplementedError(f"Agent {self} does not support saving")
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
|
@ -611,7 +611,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
if self.memory is not None:
|
||||
raise ValueError("Saving of memory is not yet supported.")
|
||||
_dict = super().dict(**kwargs)
|
||||
try:
|
||||
_dict["_type"] = self._chain_type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
@ -639,6 +642,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
|
||||
# Fetch dictionary to save
|
||||
chain_dict = self.dict()
|
||||
if "_type" not in chain_dict:
|
||||
raise NotImplementedError(f"Chain {self} does not support saving.")
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
|
@ -31,6 +31,13 @@ class SerializedNotImplemented(BaseSerialized):
|
||||
repr: Optional[str]
|
||||
|
||||
|
||||
def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
|
||||
try:
|
||||
return model.__fields__[key].get_default() != value
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
class Serializable(BaseModel, ABC):
|
||||
"""Serializable base class."""
|
||||
|
||||
@ -81,7 +88,7 @@ class Serializable(BaseModel, ABC):
|
||||
return [
|
||||
(k, v)
|
||||
for k, v in super().__repr_args__()
|
||||
if (k not in self.__fields__ or self.__fields__[k].get_default() != v)
|
||||
if (k not in self.__fields__ or try_neq_default(v, k, self))
|
||||
]
|
||||
|
||||
_lc_kwargs = PrivateAttr(default_factory=dict)
|
||||
|
@ -2,6 +2,7 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
|
||||
from langchain.schema.runnable.base import (
|
||||
Runnable,
|
||||
RunnableBinding,
|
||||
RunnableGenerator,
|
||||
RunnableLambda,
|
||||
RunnableMap,
|
||||
RunnableSequence,
|
||||
@ -12,8 +13,10 @@ from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
|
||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||
from langchain.schema.runnable.utils import ConfigurableField
|
||||
|
||||
__all__ = [
|
||||
"ConfigurableField",
|
||||
"GetLocalVar",
|
||||
"patch_config",
|
||||
"PutLocalVar",
|
||||
@ -24,6 +27,7 @@ __all__ = [
|
||||
"RunnableBinding",
|
||||
"RunnableBranch",
|
||||
"RunnableConfig",
|
||||
"RunnableGenerator",
|
||||
"RunnableLambda",
|
||||
"RunnableMap",
|
||||
"RunnablePassthrough",
|
||||
|
@ -227,7 +227,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
}
|
||||
|
||||
if configurable:
|
||||
return self.default.__class__(**{**self.default.dict(), **configurable})
|
||||
return self.default.__class__(**{**self.default.__dict__, **configurable})
|
||||
else:
|
||||
return self.default
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user