Ensure dict() does not raise not implemented error, which should instead be raised in our custom method save()

This commit is contained in:
Nuno Campos 2023-10-18 09:44:41 +01:00
parent 392df7b2e3
commit 202acce0c9
6 changed files with 38 additions and 23 deletions

View File

@ -274,7 +274,10 @@ class BaseMultiActionAgent(BaseModel):
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
_dict = super().dict() _dict = super().dict()
_dict["_type"] = str(self._agent_type) try:
_dict["_type"] = str(self._agent_type)
except NotImplementedError:
pass
return _dict return _dict
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
@ -295,11 +298,13 @@ class BaseMultiActionAgent(BaseModel):
else: else:
save_path = file_path save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save # Fetch dictionary to save
agent_dict = self.dict() agent_dict = self.dict()
if "_type" not in agent_dict:
raise NotImplementedError(f"Agent {self} does not support saving.")
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json": if save_path.suffix == ".json":
with open(file_path, "w") as f: with open(file_path, "w") as f:

View File

@ -610,8 +610,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
chain.dict(exclude_unset=True) chain.dict(exclude_unset=True)
# -> {"_type": "foo", "verbose": False, ...} # -> {"_type": "foo", "verbose": False, ...}
""" """
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict(**kwargs) _dict = super().dict(**kwargs)
try: try:
_dict["_type"] = self._chain_type _dict["_type"] = self._chain_type
@ -633,6 +631,14 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
chain.save(file_path="path/chain.yaml") chain.save(file_path="path/chain.yaml")
""" """
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
# Fetch dictionary to save
chain_dict = self.dict()
if "_type" not in chain_dict:
raise NotImplementedError(f"Chain {self} does not support saving.")
# Convert file to Path object. # Convert file to Path object.
if isinstance(file_path, str): if isinstance(file_path, str):
save_path = Path(file_path) save_path = Path(file_path)
@ -642,11 +648,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
directory_path = save_path.parent directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True) directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
chain_dict = self.dict()
if "_type" not in chain_dict:
raise NotImplementedError(f"Chain {self} does not support saving.")
if save_path.suffix == ".json": if save_path.suffix == ".json":
with open(file_path, "w") as f: with open(file_path, "w") as f:
json.dump(chain_dict, f, indent=4) json.dump(chain_dict, f, indent=4)

View File

@ -1,6 +1,7 @@
"""Prompt template that contains few shot examples.""" """Prompt template that contains few shot examples."""
from __future__ import annotations from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from langchain.prompts.base import ( from langchain.prompts.base import (
@ -151,11 +152,10 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Return the prompt type key.""" """Return the prompt type key."""
return "few_shot" return "few_shot"
def dict(self, **kwargs: Any) -> Dict: def save(self, file_path: Union[Path, str]) -> None:
"""Return a dictionary of the prompt."""
if self.example_selector: if self.example_selector:
raise ValueError("Saving an example selector is not currently supported") raise ValueError("Saving an example selector is not currently supported")
return super().dict(**kwargs) return super().save(file_path)
class FewShotChatMessagePromptTemplate( class FewShotChatMessagePromptTemplate(

View File

@ -1,5 +1,6 @@
"""Prompt template that contains few shot examples.""" """Prompt template that contains few shot examples."""
from typing import Any, Dict, List, Optional from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector from langchain.prompts.example_selector.base import BaseExampleSelector
@ -140,8 +141,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
"""Return the prompt type key.""" """Return the prompt type key."""
return "few_shot_with_templates" return "few_shot_with_templates"
def dict(self, **kwargs: Any) -> Dict: def save(self, file_path: Union[Path, str]) -> None:
"""Return a dictionary of the prompt."""
if self.example_selector: if self.example_selector:
raise ValueError("Saving an example selector is not currently supported") raise ValueError("Saving an example selector is not currently supported")
return super().dict(**kwargs) return super().save(file_path)

View File

@ -298,7 +298,10 @@ class BaseOutputParser(
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser.""" """Return dictionary representation of output parser."""
output_parser_dict = super().dict(**kwargs) output_parser_dict = super().dict(**kwargs)
output_parser_dict["_type"] = self._type try:
output_parser_dict["_type"] = self._type
except NotImplementedError:
pass
return output_parser_dict return output_parser_dict

View File

@ -132,7 +132,10 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt.""" """Return dictionary representation of prompt."""
prompt_dict = super().dict(**kwargs) prompt_dict = super().dict(**kwargs)
prompt_dict["_type"] = self._prompt_type try:
prompt_dict["_type"] = self._prompt_type
except NotImplementedError:
pass
return prompt_dict return prompt_dict
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
@ -148,6 +151,12 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
""" """
if self.partial_variables: if self.partial_variables:
raise ValueError("Cannot save prompt with partial variables.") raise ValueError("Cannot save prompt with partial variables.")
# Fetch dictionary to save
prompt_dict = self.dict()
if "_type" not in prompt_dict:
raise NotImplementedError(f"Prompt {self} does not support saving.")
# Convert file to Path object. # Convert file to Path object.
if isinstance(file_path, str): if isinstance(file_path, str):
save_path = Path(file_path) save_path = Path(file_path)
@ -157,9 +166,6 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
directory_path = save_path.parent directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True) directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
prompt_dict = self.dict()
if save_path.suffix == ".json": if save_path.suffix == ".json":
with open(file_path, "w") as f: with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4) json.dump(prompt_dict, f, indent=4)