From b9f65e503833ae0cabd1e3b6b22ca17a1ea9ca68 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 8 Aug 2024 15:05:54 -0400 Subject: [PATCH] experimental[patch]: Migrate pydantic extra to literals (#25194) Migrate pydantic extra to literals Upgrade to using a literal for specifying the extra which is the recommended approach in pydantic 2. This works correctly also in pydantic v1. ```python from pydantic.v1 import BaseModel class Foo(BaseModel, extra="forbid"): x: int Foo(x=5, y=1) ``` And ```python from pydantic.v1 import BaseModel class Foo(BaseModel): x: int class Config: extra = "forbid" Foo(x=5, y=1) ``` ## Enum -> literal using grit pattern: ``` engine marzano(0.1) language python or { `extra=Extra.allow` => `extra="allow"`, `extra=Extra.forbid` => `extra="forbid"`, `extra=Extra.ignore` => `extra="ignore"` } ``` Resorted attributes in config and removed doc-string in case we will need to deal with going back and forth between pydantic v1 and v2 during the 0.3 release. (This will reduce merge conflicts.) ## Sort attributes in Config: ``` engine marzano(0.1) language python function sort($values) js { return $values.text.split(',').sort().join("\n"); } class_definition($name, $body) as $C where { $name <: `Config`, $body <: block($statements), $values = [], $statements <: some bubble($values) assignment() as $A where { $values += $A }, $body => sort($values), } ``` --- .../autonomous_agents/baby_agi/baby_agi.py | 2 -- .../generative_agents/generative_agent.py | 2 -- libs/experimental/langchain_experimental/llm_bash/base.py | 6 ++---- .../langchain_experimental/llm_symbolic_math/base.py | 5 +---- libs/experimental/langchain_experimental/pal_chain/base.py | 6 ++---- libs/experimental/langchain_experimental/rl_chain/base.py | 6 ++---- libs/experimental/langchain_experimental/smart_llm/base.py | 4 ++-- libs/experimental/langchain_experimental/sql/base.py | 6 ++---- libs/experimental/langchain_experimental/tot/base.py | 5 +---- .../langchain_experimental/video_captioning/base.py | 3 +-- 10 files changed, 13 insertions(+), 32 deletions(-) diff --git a/libs/experimental/langchain_experimental/autonomous_agents/baby_agi/baby_agi.py b/libs/experimental/langchain_experimental/autonomous_agents/baby_agi/baby_agi.py index e559362453f..ade9537dd6d 100644 --- a/libs/experimental/langchain_experimental/autonomous_agents/baby_agi/baby_agi.py +++ b/libs/experimental/langchain_experimental/autonomous_agents/baby_agi/baby_agi.py @@ -52,8 +52,6 @@ class BabyAGI(Chain, BaseModel): # type: ignore[misc] max_iterations: Optional[int] = None class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True def add_task(self, task: Dict) -> None: diff --git a/libs/experimental/langchain_experimental/generative_agents/generative_agent.py b/libs/experimental/langchain_experimental/generative_agents/generative_agent.py index db677ec24cc..95247e68242 100644 --- a/libs/experimental/langchain_experimental/generative_agents/generative_agent.py +++ b/libs/experimental/langchain_experimental/generative_agents/generative_agent.py @@ -36,8 +36,6 @@ class GenerativeAgent(BaseModel): """Summary of the events in the plan that the agent took.""" class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True # LLM-related methods diff --git a/libs/experimental/langchain_experimental/llm_bash/base.py b/libs/experimental/langchain_experimental/llm_bash/base.py index 3541631a7ec..9b54a747c38 100644 --- a/libs/experimental/langchain_experimental/llm_bash/base.py +++ b/libs/experimental/langchain_experimental/llm_bash/base.py @@ -14,7 +14,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_experimental.llm_bash.bash import BashProcess from langchain_experimental.llm_bash.prompt import PROMPT -from langchain_experimental.pydantic_v1 import Extra, Field, root_validator +from langchain_experimental.pydantic_v1 import Field, root_validator logger = logging.getLogger(__name__) @@ -40,10 +40,8 @@ class LLMBashChain(Chain): bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private: class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: diff --git a/libs/experimental/langchain_experimental/llm_symbolic_math/base.py b/libs/experimental/langchain_experimental/llm_symbolic_math/base.py index 8c5989e1154..8c671038be1 100644 --- a/libs/experimental/langchain_experimental/llm_symbolic_math/base.py +++ b/libs/experimental/langchain_experimental/llm_symbolic_math/base.py @@ -15,7 +15,6 @@ from langchain_core.callbacks.manager import ( from langchain_core.prompts.base import BasePromptTemplate from langchain_experimental.llm_symbolic_math.prompt import PROMPT -from langchain_experimental.pydantic_v1 import Extra class LLMSymbolicMathChain(Chain): @@ -38,10 +37,8 @@ class LLMSymbolicMathChain(Chain): output_key: str = "answer" #: :meta private: class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/experimental/langchain_experimental/pal_chain/base.py b/libs/experimental/langchain_experimental/pal_chain/base.py index fe47c1263b2..77109118af5 100644 --- a/libs/experimental/langchain_experimental/pal_chain/base.py +++ b/libs/experimental/langchain_experimental/pal_chain/base.py @@ -17,7 +17,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_experimental.pal_chain.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain_experimental.pal_chain.math_prompt import MATH_PROMPT -from langchain_experimental.pydantic_v1 import Extra, Field, root_validator +from langchain_experimental.pydantic_v1 import Field, root_validator from langchain_experimental.utilities import PythonREPL COMMAND_EXECUTION_FUNCTIONS = [ @@ -169,10 +169,8 @@ class PALChain(Chain): return values class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/experimental/langchain_experimental/rl_chain/base.py b/libs/experimental/langchain_experimental/rl_chain/base.py index e1f2dedd230..7b15f00fc1f 100644 --- a/libs/experimental/langchain_experimental/rl_chain/base.py +++ b/libs/experimental/langchain_experimental/rl_chain/base.py @@ -26,7 +26,7 @@ from langchain_core.prompts import ( SystemMessagePromptTemplate, ) -from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_experimental.pydantic_v1 import BaseModel, root_validator from langchain_experimental.rl_chain.metrics import ( MetricsTrackerAverage, MetricsTrackerRollingWindow, @@ -417,10 +417,8 @@ class RLChain(Chain, Generic[TEvent]): self.metrics = MetricsTrackerAverage(step=metrics_step) class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/experimental/langchain_experimental/smart_llm/base.py b/libs/experimental/langchain_experimental/smart_llm/base.py index 765e55eaf36..6e25027a2c9 100644 --- a/libs/experimental/langchain_experimental/smart_llm/base.py +++ b/libs/experimental/langchain_experimental/smart_llm/base.py @@ -15,7 +15,7 @@ from langchain_core.prompts.chat import ( HumanMessagePromptTemplate, ) -from langchain_experimental.pydantic_v1 import Extra, root_validator +from langchain_experimental.pydantic_v1 import root_validator class SmartLLMChain(Chain): @@ -84,7 +84,7 @@ class SmartLLMChain(Chain): history: SmartLLMChainHistory = SmartLLMChainHistory() class Config: - extra = Extra.forbid + extra = "forbid" # TODO: move away from `root_validator` since it is deprecated in pydantic v2 # and causes mypy type-checking failures (hence the `type: ignore`) diff --git a/libs/experimental/langchain_experimental/sql/base.py b/libs/experimental/langchain_experimental/sql/base.py index 708055d8915..cae029f7092 100644 --- a/libs/experimental/langchain_experimental/sql/base.py +++ b/libs/experimental/langchain_experimental/sql/base.py @@ -15,7 +15,7 @@ from langchain_core.callbacks.manager import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts.prompt import PromptTemplate -from langchain_experimental.pydantic_v1 import Extra, Field, root_validator +from langchain_experimental.pydantic_v1 import Field, root_validator INTERMEDIATE_STEPS_KEY = "intermediate_steps" SQL_QUERY = "SQLQuery:" @@ -67,10 +67,8 @@ class SQLDatabaseChain(Chain): """The prompt template that should be used by the query checker""" class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: diff --git a/libs/experimental/langchain_experimental/tot/base.py b/libs/experimental/langchain_experimental/tot/base.py index f1f280a1714..7de31381960 100644 --- a/libs/experimental/langchain_experimental/tot/base.py +++ b/libs/experimental/langchain_experimental/tot/base.py @@ -10,7 +10,6 @@ from langchain_core.callbacks.manager import ( CallbackManagerForChainRun, ) -from langchain_experimental.pydantic_v1 import Extra from langchain_experimental.tot.checker import ToTChecker from langchain_experimental.tot.controller import ToTController from langchain_experimental.tot.memory import ToTDFSMemory @@ -44,10 +43,8 @@ class ToTChain(Chain): verbose_llm: bool = False class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @classmethod def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain: diff --git a/libs/experimental/langchain_experimental/video_captioning/base.py b/libs/experimental/langchain_experimental/video_captioning/base.py index 5ae2c02dfa9..2cbbdbb8422 100644 --- a/libs/experimental/langchain_experimental/video_captioning/base.py +++ b/libs/experimental/langchain_experimental/video_captioning/base.py @@ -4,7 +4,6 @@ from langchain.chains.base import Chain from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import Extra from langchain_experimental.video_captioning.services.audio_service import ( AudioProcessor, @@ -38,8 +37,8 @@ class VideoCaptioningChain(Chain): use_unclustered_video_models: bool = False class Config: - extra = Extra.allow arbitrary_types_allowed = True + extra = "allow" @property def input_keys(self) -> List[str]: