mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
experimental[patch]: Migrate pydantic extra to literals (#25194)
Migrate pydantic extra to literals Upgrade to using a literal for specifying the extra which is the recommended approach in pydantic 2. This works correctly also in pydantic v1. ```python from pydantic.v1 import BaseModel class Foo(BaseModel, extra="forbid"): x: int Foo(x=5, y=1) ``` And ```python from pydantic.v1 import BaseModel class Foo(BaseModel): x: int class Config: extra = "forbid" Foo(x=5, y=1) ``` ## Enum -> literal using grit pattern: ``` engine marzano(0.1) language python or { `extra=Extra.allow` => `extra="allow"`, `extra=Extra.forbid` => `extra="forbid"`, `extra=Extra.ignore` => `extra="ignore"` } ``` Resorted attributes in config and removed doc-string in case we will need to deal with going back and forth between pydantic v1 and v2 during the 0.3 release. (This will reduce merge conflicts.) ## Sort attributes in Config: ``` engine marzano(0.1) language python function sort($values) js { return $values.text.split(',').sort().join("\n"); } class_definition($name, $body) as $C where { $name <: `Config`, $body <: block($statements), $values = [], $statements <: some bubble($values) assignment() as $A where { $values += $A }, $body => sort($values), } ```
This commit is contained in:
parent
30fb345342
commit
b9f65e5038
@ -52,8 +52,6 @@ class BabyAGI(Chain, BaseModel): # type: ignore[misc]
|
|||||||
max_iterations: Optional[int] = None
|
max_iterations: Optional[int] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def add_task(self, task: Dict) -> None:
|
def add_task(self, task: Dict) -> None:
|
||||||
|
@ -36,8 +36,6 @@ class GenerativeAgent(BaseModel):
|
|||||||
"""Summary of the events in the plan that the agent took."""
|
"""Summary of the events in the plan that the agent took."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
# LLM-related methods
|
# LLM-related methods
|
||||||
|
@ -14,7 +14,7 @@ from langchain_core.language_models import BaseLanguageModel
|
|||||||
|
|
||||||
from langchain_experimental.llm_bash.bash import BashProcess
|
from langchain_experimental.llm_bash.bash import BashProcess
|
||||||
from langchain_experimental.llm_bash.prompt import PROMPT
|
from langchain_experimental.llm_bash.prompt import PROMPT
|
||||||
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
|
from langchain_experimental.pydantic_v1 import Field, root_validator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -40,10 +40,8 @@ class LLMBashChain(Chain):
|
|||||||
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
|
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
extra = "forbid"
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||||
|
@ -15,7 +15,6 @@ from langchain_core.callbacks.manager import (
|
|||||||
from langchain_core.prompts.base import BasePromptTemplate
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
from langchain_experimental.llm_symbolic_math.prompt import PROMPT
|
from langchain_experimental.llm_symbolic_math.prompt import PROMPT
|
||||||
from langchain_experimental.pydantic_v1 import Extra
|
|
||||||
|
|
||||||
|
|
||||||
class LLMSymbolicMathChain(Chain):
|
class LLMSymbolicMathChain(Chain):
|
||||||
@ -38,10 +37,8 @@ class LLMSymbolicMathChain(Chain):
|
|||||||
output_key: str = "answer" #: :meta private:
|
output_key: str = "answer" #: :meta private:
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
extra = "forbid"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
|
@ -17,7 +17,7 @@ from langchain_core.language_models import BaseLanguageModel
|
|||||||
|
|
||||||
from langchain_experimental.pal_chain.colored_object_prompt import COLORED_OBJECT_PROMPT
|
from langchain_experimental.pal_chain.colored_object_prompt import COLORED_OBJECT_PROMPT
|
||||||
from langchain_experimental.pal_chain.math_prompt import MATH_PROMPT
|
from langchain_experimental.pal_chain.math_prompt import MATH_PROMPT
|
||||||
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
|
from langchain_experimental.pydantic_v1 import Field, root_validator
|
||||||
from langchain_experimental.utilities import PythonREPL
|
from langchain_experimental.utilities import PythonREPL
|
||||||
|
|
||||||
COMMAND_EXECUTION_FUNCTIONS = [
|
COMMAND_EXECUTION_FUNCTIONS = [
|
||||||
@ -169,10 +169,8 @@ class PALChain(Chain):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
extra = "forbid"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
|
@ -26,7 +26,7 @@ from langchain_core.prompts import (
|
|||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator
|
from langchain_experimental.pydantic_v1 import BaseModel, root_validator
|
||||||
from langchain_experimental.rl_chain.metrics import (
|
from langchain_experimental.rl_chain.metrics import (
|
||||||
MetricsTrackerAverage,
|
MetricsTrackerAverage,
|
||||||
MetricsTrackerRollingWindow,
|
MetricsTrackerRollingWindow,
|
||||||
@ -417,10 +417,8 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
self.metrics = MetricsTrackerAverage(step=metrics_step)
|
self.metrics = MetricsTrackerAverage(step=metrics_step)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
extra = "forbid"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
|
@ -15,7 +15,7 @@ from langchain_core.prompts.chat import (
|
|||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_experimental.pydantic_v1 import Extra, root_validator
|
from langchain_experimental.pydantic_v1 import root_validator
|
||||||
|
|
||||||
|
|
||||||
class SmartLLMChain(Chain):
|
class SmartLLMChain(Chain):
|
||||||
@ -84,7 +84,7 @@ class SmartLLMChain(Chain):
|
|||||||
history: SmartLLMChainHistory = SmartLLMChainHistory()
|
history: SmartLLMChainHistory = SmartLLMChainHistory()
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
extra = Extra.forbid
|
extra = "forbid"
|
||||||
|
|
||||||
# TODO: move away from `root_validator` since it is deprecated in pydantic v2
|
# TODO: move away from `root_validator` since it is deprecated in pydantic v2
|
||||||
# and causes mypy type-checking failures (hence the `type: ignore`)
|
# and causes mypy type-checking failures (hence the `type: ignore`)
|
||||||
|
@ -15,7 +15,7 @@ from langchain_core.callbacks.manager import CallbackManagerForChainRun
|
|||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
|
|
||||||
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
|
from langchain_experimental.pydantic_v1 import Field, root_validator
|
||||||
|
|
||||||
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||||
SQL_QUERY = "SQLQuery:"
|
SQL_QUERY = "SQLQuery:"
|
||||||
@ -67,10 +67,8 @@ class SQLDatabaseChain(Chain):
|
|||||||
"""The prompt template that should be used by the query checker"""
|
"""The prompt template that should be used by the query checker"""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
extra = "forbid"
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||||
|
@ -10,7 +10,6 @@ from langchain_core.callbacks.manager import (
|
|||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_experimental.pydantic_v1 import Extra
|
|
||||||
from langchain_experimental.tot.checker import ToTChecker
|
from langchain_experimental.tot.checker import ToTChecker
|
||||||
from langchain_experimental.tot.controller import ToTController
|
from langchain_experimental.tot.controller import ToTController
|
||||||
from langchain_experimental.tot.memory import ToTDFSMemory
|
from langchain_experimental.tot.memory import ToTDFSMemory
|
||||||
@ -44,10 +43,8 @@ class ToTChain(Chain):
|
|||||||
verbose_llm: bool = False
|
verbose_llm: bool = False
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
extra = "forbid"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
|
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
|
||||||
|
@ -4,7 +4,6 @@ from langchain.chains.base import Chain
|
|||||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_core.pydantic_v1 import Extra
|
|
||||||
|
|
||||||
from langchain_experimental.video_captioning.services.audio_service import (
|
from langchain_experimental.video_captioning.services.audio_service import (
|
||||||
AudioProcessor,
|
AudioProcessor,
|
||||||
@ -38,8 +37,8 @@ class VideoCaptioningChain(Chain):
|
|||||||
use_unclustered_video_models: bool = False
|
use_unclustered_video_models: bool = False
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
extra = Extra.allow
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
extra = "allow"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
|
Loading…
Reference in New Issue
Block a user