experimental[patch]: Migrate pydantic extra to literals (#25194)

Migrate pydantic extra to literals

Upgrade to using a literal for specifying the extra which is the
recommended approach in pydantic 2.

This works correctly also in pydantic v1.

```python
from pydantic.v1 import BaseModel

class Foo(BaseModel, extra="forbid"):
    x: int

Foo(x=5, y=1)
```

And 


```python
from pydantic.v1 import BaseModel

class Foo(BaseModel):
    x: int

    class Config:
      extra = "forbid"

Foo(x=5, y=1)
```


## Enum -> literal using grit pattern:

```
engine marzano(0.1)
language python
or {
    `extra=Extra.allow` => `extra="allow"`,
    `extra=Extra.forbid` => `extra="forbid"`,
    `extra=Extra.ignore` => `extra="ignore"`
}
```

Resorted attributes in config and removed doc-string in case we will
need to deal with going back and forth between pydantic v1 and v2 during
the 0.3 release. (This will reduce merge conflicts.)


## Sort attributes in Config:

```
engine marzano(0.1)
language python


function sort($values) js {
    return $values.text.split(',').sort().join("\n");
}


class_definition($name, $body) as $C where {
    $name <: `Config`,
    $body <: block($statements),
    $values = [],
    $statements <: some bubble($values) assignment() as $A where {
        $values += $A
    },
    $body => sort($values),
}

```
This commit is contained in:
Eugene Yurtsev 2024-08-08 15:05:54 -04:00 committed by GitHub
parent 30fb345342
commit b9f65e5038
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 13 additions and 32 deletions

View File

@ -52,8 +52,6 @@ class BabyAGI(Chain, BaseModel): # type: ignore[misc]
max_iterations: Optional[int] = None max_iterations: Optional[int] = None
class Config: class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True arbitrary_types_allowed = True
def add_task(self, task: Dict) -> None: def add_task(self, task: Dict) -> None:

View File

@ -36,8 +36,6 @@ class GenerativeAgent(BaseModel):
"""Summary of the events in the plan that the agent took.""" """Summary of the events in the plan that the agent took."""
class Config: class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True arbitrary_types_allowed = True
# LLM-related methods # LLM-related methods

View File

@ -14,7 +14,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_experimental.llm_bash.bash import BashProcess from langchain_experimental.llm_bash.bash import BashProcess
from langchain_experimental.llm_bash.prompt import PROMPT from langchain_experimental.llm_bash.prompt import PROMPT
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator from langchain_experimental.pydantic_v1 import Field, root_validator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,10 +40,8 @@ class LLMBashChain(Chain):
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private: bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
class Config: class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
extra = "forbid"
@root_validator(pre=True) @root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict: def raise_deprecation(cls, values: Dict) -> Dict:

View File

@ -15,7 +15,6 @@ from langchain_core.callbacks.manager import (
from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts.base import BasePromptTemplate
from langchain_experimental.llm_symbolic_math.prompt import PROMPT from langchain_experimental.llm_symbolic_math.prompt import PROMPT
from langchain_experimental.pydantic_v1 import Extra
class LLMSymbolicMathChain(Chain): class LLMSymbolicMathChain(Chain):
@ -38,10 +37,8 @@ class LLMSymbolicMathChain(Chain):
output_key: str = "answer" #: :meta private: output_key: str = "answer" #: :meta private:
class Config: class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
extra = "forbid"
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:

View File

@ -17,7 +17,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_experimental.pal_chain.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain_experimental.pal_chain.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain_experimental.pal_chain.math_prompt import MATH_PROMPT from langchain_experimental.pal_chain.math_prompt import MATH_PROMPT
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator from langchain_experimental.pydantic_v1 import Field, root_validator
from langchain_experimental.utilities import PythonREPL from langchain_experimental.utilities import PythonREPL
COMMAND_EXECUTION_FUNCTIONS = [ COMMAND_EXECUTION_FUNCTIONS = [
@ -169,10 +169,8 @@ class PALChain(Chain):
return values return values
class Config: class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
extra = "forbid"
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:

View File

@ -26,7 +26,7 @@ from langchain_core.prompts import (
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
) )
from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator from langchain_experimental.pydantic_v1 import BaseModel, root_validator
from langchain_experimental.rl_chain.metrics import ( from langchain_experimental.rl_chain.metrics import (
MetricsTrackerAverage, MetricsTrackerAverage,
MetricsTrackerRollingWindow, MetricsTrackerRollingWindow,
@ -417,10 +417,8 @@ class RLChain(Chain, Generic[TEvent]):
self.metrics = MetricsTrackerAverage(step=metrics_step) self.metrics = MetricsTrackerAverage(step=metrics_step)
class Config: class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
extra = "forbid"
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:

View File

@ -15,7 +15,7 @@ from langchain_core.prompts.chat import (
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
) )
from langchain_experimental.pydantic_v1 import Extra, root_validator from langchain_experimental.pydantic_v1 import root_validator
class SmartLLMChain(Chain): class SmartLLMChain(Chain):
@ -84,7 +84,7 @@ class SmartLLMChain(Chain):
history: SmartLLMChainHistory = SmartLLMChainHistory() history: SmartLLMChainHistory = SmartLLMChainHistory()
class Config: class Config:
extra = Extra.forbid extra = "forbid"
# TODO: move away from `root_validator` since it is deprecated in pydantic v2 # TODO: move away from `root_validator` since it is deprecated in pydantic v2
# and causes mypy type-checking failures (hence the `type: ignore`) # and causes mypy type-checking failures (hence the `type: ignore`)

View File

@ -15,7 +15,7 @@ from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator from langchain_experimental.pydantic_v1 import Field, root_validator
INTERMEDIATE_STEPS_KEY = "intermediate_steps" INTERMEDIATE_STEPS_KEY = "intermediate_steps"
SQL_QUERY = "SQLQuery:" SQL_QUERY = "SQLQuery:"
@ -67,10 +67,8 @@ class SQLDatabaseChain(Chain):
"""The prompt template that should be used by the query checker""" """The prompt template that should be used by the query checker"""
class Config: class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
extra = "forbid"
@root_validator(pre=True) @root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict: def raise_deprecation(cls, values: Dict) -> Dict:

View File

@ -10,7 +10,6 @@ from langchain_core.callbacks.manager import (
CallbackManagerForChainRun, CallbackManagerForChainRun,
) )
from langchain_experimental.pydantic_v1 import Extra
from langchain_experimental.tot.checker import ToTChecker from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.controller import ToTController from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory from langchain_experimental.tot.memory import ToTDFSMemory
@ -44,10 +43,8 @@ class ToTChain(Chain):
verbose_llm: bool = False verbose_llm: bool = False
class Config: class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
extra = "forbid"
@classmethod @classmethod
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain: def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:

View File

@ -4,7 +4,6 @@ from langchain.chains.base import Chain
from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Extra
from langchain_experimental.video_captioning.services.audio_service import ( from langchain_experimental.video_captioning.services.audio_service import (
AudioProcessor, AudioProcessor,
@ -38,8 +37,8 @@ class VideoCaptioningChain(Chain):
use_unclustered_video_models: bool = False use_unclustered_video_models: bool = False
class Config: class Config:
extra = Extra.allow
arbitrary_types_allowed = True arbitrary_types_allowed = True
extra = "allow"
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]: