mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 14:35:50 +00:00
Fix .dict() for agent/chain (#11436)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
1e59c44d36
commit
2f490be09b
@ -145,10 +145,13 @@ class BaseSingleActionAgent(BaseModel):
|
|||||||
def dict(self, **kwargs: Any) -> Dict:
|
def dict(self, **kwargs: Any) -> Dict:
|
||||||
"""Return dictionary representation of agent."""
|
"""Return dictionary representation of agent."""
|
||||||
_dict = super().dict()
|
_dict = super().dict()
|
||||||
_type = self._agent_type
|
try:
|
||||||
|
_type = self._agent_type
|
||||||
|
except NotImplementedError:
|
||||||
|
_type = None
|
||||||
if isinstance(_type, AgentType):
|
if isinstance(_type, AgentType):
|
||||||
_dict["_type"] = str(_type.value)
|
_dict["_type"] = str(_type.value)
|
||||||
else:
|
elif _type is not None:
|
||||||
_dict["_type"] = _type
|
_dict["_type"] = _type
|
||||||
return _dict
|
return _dict
|
||||||
|
|
||||||
@ -175,6 +178,8 @@ class BaseSingleActionAgent(BaseModel):
|
|||||||
|
|
||||||
# Fetch dictionary to save
|
# Fetch dictionary to save
|
||||||
agent_dict = self.dict()
|
agent_dict = self.dict()
|
||||||
|
if "_type" not in agent_dict:
|
||||||
|
raise NotImplementedError(f"Agent {self} does not support saving")
|
||||||
|
|
||||||
if save_path.suffix == ".json":
|
if save_path.suffix == ".json":
|
||||||
with open(file_path, "w") as f:
|
with open(file_path, "w") as f:
|
||||||
|
@ -611,7 +611,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
if self.memory is not None:
|
if self.memory is not None:
|
||||||
raise ValueError("Saving of memory is not yet supported.")
|
raise ValueError("Saving of memory is not yet supported.")
|
||||||
_dict = super().dict(**kwargs)
|
_dict = super().dict(**kwargs)
|
||||||
_dict["_type"] = self._chain_type
|
try:
|
||||||
|
_dict["_type"] = self._chain_type
|
||||||
|
except NotImplementedError:
|
||||||
|
pass
|
||||||
return _dict
|
return _dict
|
||||||
|
|
||||||
def save(self, file_path: Union[Path, str]) -> None:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
@ -639,6 +642,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
|
|
||||||
# Fetch dictionary to save
|
# Fetch dictionary to save
|
||||||
chain_dict = self.dict()
|
chain_dict = self.dict()
|
||||||
|
if "_type" not in chain_dict:
|
||||||
|
raise NotImplementedError(f"Chain {self} does not support saving.")
|
||||||
|
|
||||||
if save_path.suffix == ".json":
|
if save_path.suffix == ".json":
|
||||||
with open(file_path, "w") as f:
|
with open(file_path, "w") as f:
|
||||||
|
@ -31,6 +31,13 @@ class SerializedNotImplemented(BaseSerialized):
|
|||||||
repr: Optional[str]
|
repr: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
|
||||||
|
try:
|
||||||
|
return model.__fields__[key].get_default() != value
|
||||||
|
except Exception:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class Serializable(BaseModel, ABC):
|
class Serializable(BaseModel, ABC):
|
||||||
"""Serializable base class."""
|
"""Serializable base class."""
|
||||||
|
|
||||||
@ -81,7 +88,7 @@ class Serializable(BaseModel, ABC):
|
|||||||
return [
|
return [
|
||||||
(k, v)
|
(k, v)
|
||||||
for k, v in super().__repr_args__()
|
for k, v in super().__repr_args__()
|
||||||
if (k not in self.__fields__ or self.__fields__[k].get_default() != v)
|
if (k not in self.__fields__ or try_neq_default(v, k, self))
|
||||||
]
|
]
|
||||||
|
|
||||||
_lc_kwargs = PrivateAttr(default_factory=dict)
|
_lc_kwargs = PrivateAttr(default_factory=dict)
|
||||||
|
@ -2,6 +2,7 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
|
|||||||
from langchain.schema.runnable.base import (
|
from langchain.schema.runnable.base import (
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableBinding,
|
RunnableBinding,
|
||||||
|
RunnableGenerator,
|
||||||
RunnableLambda,
|
RunnableLambda,
|
||||||
RunnableMap,
|
RunnableMap,
|
||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
@ -12,8 +13,10 @@ from langchain.schema.runnable.config import RunnableConfig, patch_config
|
|||||||
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
|
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
|
||||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||||
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||||
|
from langchain.schema.runnable.utils import ConfigurableField
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"ConfigurableField",
|
||||||
"GetLocalVar",
|
"GetLocalVar",
|
||||||
"patch_config",
|
"patch_config",
|
||||||
"PutLocalVar",
|
"PutLocalVar",
|
||||||
@ -24,6 +27,7 @@ __all__ = [
|
|||||||
"RunnableBinding",
|
"RunnableBinding",
|
||||||
"RunnableBranch",
|
"RunnableBranch",
|
||||||
"RunnableConfig",
|
"RunnableConfig",
|
||||||
|
"RunnableGenerator",
|
||||||
"RunnableLambda",
|
"RunnableLambda",
|
||||||
"RunnableMap",
|
"RunnableMap",
|
||||||
"RunnablePassthrough",
|
"RunnablePassthrough",
|
||||||
|
@ -227,7 +227,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if configurable:
|
if configurable:
|
||||||
return self.default.__class__(**{**self.default.dict(), **configurable})
|
return self.default.__class__(**{**self.default.__dict__, **configurable})
|
||||||
else:
|
else:
|
||||||
return self.default
|
return self.default
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user