mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
Ensure dict() does not raise not implemented error, which should instead be raised in our custom method save()
This commit is contained in:
parent
392df7b2e3
commit
202acce0c9
@ -274,7 +274,10 @@ class BaseMultiActionAgent(BaseModel):
|
|||||||
def dict(self, **kwargs: Any) -> Dict:
|
def dict(self, **kwargs: Any) -> Dict:
|
||||||
"""Return dictionary representation of agent."""
|
"""Return dictionary representation of agent."""
|
||||||
_dict = super().dict()
|
_dict = super().dict()
|
||||||
_dict["_type"] = str(self._agent_type)
|
try:
|
||||||
|
_dict["_type"] = str(self._agent_type)
|
||||||
|
except NotImplementedError:
|
||||||
|
pass
|
||||||
return _dict
|
return _dict
|
||||||
|
|
||||||
def save(self, file_path: Union[Path, str]) -> None:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
@ -295,11 +298,13 @@ class BaseMultiActionAgent(BaseModel):
|
|||||||
else:
|
else:
|
||||||
save_path = file_path
|
save_path = file_path
|
||||||
|
|
||||||
directory_path = save_path.parent
|
|
||||||
directory_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Fetch dictionary to save
|
# Fetch dictionary to save
|
||||||
agent_dict = self.dict()
|
agent_dict = self.dict()
|
||||||
|
if "_type" not in agent_dict:
|
||||||
|
raise NotImplementedError(f"Agent {self} does not support saving.")
|
||||||
|
|
||||||
|
directory_path = save_path.parent
|
||||||
|
directory_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if save_path.suffix == ".json":
|
if save_path.suffix == ".json":
|
||||||
with open(file_path, "w") as f:
|
with open(file_path, "w") as f:
|
||||||
|
@ -610,8 +610,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
chain.dict(exclude_unset=True)
|
chain.dict(exclude_unset=True)
|
||||||
# -> {"_type": "foo", "verbose": False, ...}
|
# -> {"_type": "foo", "verbose": False, ...}
|
||||||
"""
|
"""
|
||||||
if self.memory is not None:
|
|
||||||
raise ValueError("Saving of memory is not yet supported.")
|
|
||||||
_dict = super().dict(**kwargs)
|
_dict = super().dict(**kwargs)
|
||||||
try:
|
try:
|
||||||
_dict["_type"] = self._chain_type
|
_dict["_type"] = self._chain_type
|
||||||
@ -633,6 +631,14 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
|
|
||||||
chain.save(file_path="path/chain.yaml")
|
chain.save(file_path="path/chain.yaml")
|
||||||
"""
|
"""
|
||||||
|
if self.memory is not None:
|
||||||
|
raise ValueError("Saving of memory is not yet supported.")
|
||||||
|
|
||||||
|
# Fetch dictionary to save
|
||||||
|
chain_dict = self.dict()
|
||||||
|
if "_type" not in chain_dict:
|
||||||
|
raise NotImplementedError(f"Chain {self} does not support saving.")
|
||||||
|
|
||||||
# Convert file to Path object.
|
# Convert file to Path object.
|
||||||
if isinstance(file_path, str):
|
if isinstance(file_path, str):
|
||||||
save_path = Path(file_path)
|
save_path = Path(file_path)
|
||||||
@ -642,11 +648,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
directory_path = save_path.parent
|
directory_path = save_path.parent
|
||||||
directory_path.mkdir(parents=True, exist_ok=True)
|
directory_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Fetch dictionary to save
|
|
||||||
chain_dict = self.dict()
|
|
||||||
if "_type" not in chain_dict:
|
|
||||||
raise NotImplementedError(f"Chain {self} does not support saving.")
|
|
||||||
|
|
||||||
if save_path.suffix == ".json":
|
if save_path.suffix == ".json":
|
||||||
with open(file_path, "w") as f:
|
with open(file_path, "w") as f:
|
||||||
json.dump(chain_dict, f, indent=4)
|
json.dump(chain_dict, f, indent=4)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Prompt template that contains few shot examples."""
|
"""Prompt template that contains few shot examples."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from langchain.prompts.base import (
|
from langchain.prompts.base import (
|
||||||
@ -151,11 +152,10 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
|||||||
"""Return the prompt type key."""
|
"""Return the prompt type key."""
|
||||||
return "few_shot"
|
return "few_shot"
|
||||||
|
|
||||||
def dict(self, **kwargs: Any) -> Dict:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
"""Return a dictionary of the prompt."""
|
|
||||||
if self.example_selector:
|
if self.example_selector:
|
||||||
raise ValueError("Saving an example selector is not currently supported")
|
raise ValueError("Saving an example selector is not currently supported")
|
||||||
return super().dict(**kwargs)
|
return super().save(file_path)
|
||||||
|
|
||||||
|
|
||||||
class FewShotChatMessagePromptTemplate(
|
class FewShotChatMessagePromptTemplate(
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Prompt template that contains few shot examples."""
|
"""Prompt template that contains few shot examples."""
|
||||||
from typing import Any, Dict, List, Optional
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate
|
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate
|
||||||
from langchain.prompts.example_selector.base import BaseExampleSelector
|
from langchain.prompts.example_selector.base import BaseExampleSelector
|
||||||
@ -140,8 +141,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
"""Return the prompt type key."""
|
"""Return the prompt type key."""
|
||||||
return "few_shot_with_templates"
|
return "few_shot_with_templates"
|
||||||
|
|
||||||
def dict(self, **kwargs: Any) -> Dict:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
"""Return a dictionary of the prompt."""
|
|
||||||
if self.example_selector:
|
if self.example_selector:
|
||||||
raise ValueError("Saving an example selector is not currently supported")
|
raise ValueError("Saving an example selector is not currently supported")
|
||||||
return super().dict(**kwargs)
|
return super().save(file_path)
|
||||||
|
@ -298,7 +298,10 @@ class BaseOutputParser(
|
|||||||
def dict(self, **kwargs: Any) -> Dict:
|
def dict(self, **kwargs: Any) -> Dict:
|
||||||
"""Return dictionary representation of output parser."""
|
"""Return dictionary representation of output parser."""
|
||||||
output_parser_dict = super().dict(**kwargs)
|
output_parser_dict = super().dict(**kwargs)
|
||||||
output_parser_dict["_type"] = self._type
|
try:
|
||||||
|
output_parser_dict["_type"] = self._type
|
||||||
|
except NotImplementedError:
|
||||||
|
pass
|
||||||
return output_parser_dict
|
return output_parser_dict
|
||||||
|
|
||||||
|
|
||||||
|
@ -132,7 +132,10 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
|||||||
def dict(self, **kwargs: Any) -> Dict:
|
def dict(self, **kwargs: Any) -> Dict:
|
||||||
"""Return dictionary representation of prompt."""
|
"""Return dictionary representation of prompt."""
|
||||||
prompt_dict = super().dict(**kwargs)
|
prompt_dict = super().dict(**kwargs)
|
||||||
prompt_dict["_type"] = self._prompt_type
|
try:
|
||||||
|
prompt_dict["_type"] = self._prompt_type
|
||||||
|
except NotImplementedError:
|
||||||
|
pass
|
||||||
return prompt_dict
|
return prompt_dict
|
||||||
|
|
||||||
def save(self, file_path: Union[Path, str]) -> None:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
@ -148,6 +151,12 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
|||||||
"""
|
"""
|
||||||
if self.partial_variables:
|
if self.partial_variables:
|
||||||
raise ValueError("Cannot save prompt with partial variables.")
|
raise ValueError("Cannot save prompt with partial variables.")
|
||||||
|
|
||||||
|
# Fetch dictionary to save
|
||||||
|
prompt_dict = self.dict()
|
||||||
|
if "_type" not in prompt_dict:
|
||||||
|
raise NotImplementedError(f"Prompt {self} does not support saving.")
|
||||||
|
|
||||||
# Convert file to Path object.
|
# Convert file to Path object.
|
||||||
if isinstance(file_path, str):
|
if isinstance(file_path, str):
|
||||||
save_path = Path(file_path)
|
save_path = Path(file_path)
|
||||||
@ -157,9 +166,6 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
|||||||
directory_path = save_path.parent
|
directory_path = save_path.parent
|
||||||
directory_path.mkdir(parents=True, exist_ok=True)
|
directory_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Fetch dictionary to save
|
|
||||||
prompt_dict = self.dict()
|
|
||||||
|
|
||||||
if save_path.suffix == ".json":
|
if save_path.suffix == ".json":
|
||||||
with open(file_path, "w") as f:
|
with open(file_path, "w") as f:
|
||||||
json.dump(prompt_dict, f, indent=4)
|
json.dump(prompt_dict, f, indent=4)
|
||||||
|
Loading…
Reference in New Issue
Block a user