From c72e522e967672179089bdc683716c29e21b5906 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 8 Aug 2024 13:27:27 -0400 Subject: [PATCH] langchain[patch]: Upgrade pydantic extra (#25186) Upgrade to using a literal for specifying the extra which is the recommended approach in pydantic 2. This works correctly also in pydantic v1. ```python from pydantic.v1 import BaseModel class Foo(BaseModel, extra="forbid"): x: int Foo(x=5, y=1) ``` And ```python from pydantic.v1 import BaseModel class Foo(BaseModel): x: int class Config: extra = "forbid" Foo(x=5, y=1) ``` ## Enum -> literal using grit pattern: ``` engine marzano(0.1) language python or { `extra=Extra.allow` => `extra="allow"`, `extra=Extra.forbid` => `extra="forbid"`, `extra=Extra.ignore` => `extra="ignore"` } ``` Resorted attributes in config and removed doc-string in case we will need to deal with going back and forth between pydantic v1 and v2 during the 0.3 release. (This will reduce merge conflicts.) ## Sort attributes in Config: ``` engine marzano(0.1) language python function sort($values) js { return $values.text.split(',').sort().join("\n"); } class_definition($name, $body) as $C where { $name <: `Config`, $body <: block($statements), $values = [], $statements <: some bubble($values) assignment() as $A where { $values += $A }, $body => sort($values), } ``` --- libs/langchain/langchain/agents/agent.py | 4 ---- .../agents/agent_toolkits/vectorstore/toolkit.py | 6 ------ libs/langchain/langchain/chains/base.py | 2 -- .../langchain/chains/combine_documents/map_reduce.py | 6 ++---- .../langchain/chains/combine_documents/map_rerank.py | 6 ++---- .../langchain/chains/combine_documents/reduce.py | 5 +---- .../langchain/chains/combine_documents/refine.py | 6 ++---- .../langchain/chains/combine_documents/stuff.py | 6 ++---- libs/langchain/langchain/chains/conversation/base.py | 6 ++---- .../langchain/chains/conversational_retrieval/base.py | 8 +++----- .../langchain/chains/elasticsearch_database/base.py | 6 ++---- libs/langchain/langchain/chains/hyde/base.py | 5 +---- libs/langchain/langchain/chains/llm.py | 6 ++---- libs/langchain/langchain/chains/llm_checker/base.py | 6 ++---- libs/langchain/langchain/chains/llm_math/base.py | 6 ++---- .../langchain/chains/llm_summarization_checker/base.py | 6 ++---- libs/langchain/langchain/chains/mapreduce.py | 5 +---- libs/langchain/langchain/chains/natbot/base.py | 6 ++---- .../langchain/langchain/chains/qa_with_sources/base.py | 6 ++---- .../langchain/chains/query_constructor/schema.py | 2 -- libs/langchain/langchain/chains/retrieval_qa/base.py | 8 +++----- libs/langchain/langchain/chains/router/base.py | 5 +---- .../langchain/chains/router/embedding_router.py | 5 +---- libs/langchain/langchain/chains/sequential.py | 10 +++------- .../evaluation/agents/trajectory_eval_chain.py | 6 ++---- .../langchain/evaluation/comparison/eval_chain.py | 6 ++---- .../langchain/evaluation/criteria/eval_chain.py | 6 ++---- .../langchain/evaluation/embedding_distance/base.py | 2 -- libs/langchain/langchain/evaluation/qa/eval_chain.py | 9 ++------- .../langchain/evaluation/scoring/eval_chain.py | 6 ++---- libs/langchain/langchain/indexes/vectorstore.py | 10 +++------- libs/langchain/langchain/memory/entity.py | 2 -- .../langchain/retrievers/contextual_compression.py | 2 -- .../langchain/retrievers/document_compressors/base.py | 2 -- .../retrievers/document_compressors/cohere_rerank.py | 6 ++---- .../document_compressors/cross_encoder_rerank.py | 5 +---- .../document_compressors/embeddings_filter.py | 2 -- libs/langchain/langchain/retrievers/self_query/base.py | 4 +--- .../langchain/retrievers/time_weighted_retriever.py | 2 -- libs/langchain/tests/unit_tests/load/test_dump.py | 2 -- 40 files changed, 55 insertions(+), 154 deletions(-) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index ce84c188386..987f15b49d4 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -420,8 +420,6 @@ class RunnableAgent(BaseSingleActionAgent): """ class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True @property @@ -530,8 +528,6 @@ class RunnableMultiActionAgent(BaseMultiActionAgent): """ class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True @property diff --git a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py index 19ad3497a49..c256ced143b 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py @@ -16,8 +16,6 @@ class VectorStoreInfo(BaseModel): description: str class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True @@ -28,8 +26,6 @@ class VectorStoreToolkit(BaseToolkit): llm: BaseLanguageModel class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True def get_tools(self) -> List[BaseTool]: @@ -71,8 +67,6 @@ class VectorStoreRouterToolkit(BaseToolkit): llm: BaseLanguageModel class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True def get_tools(self) -> List[BaseTool]: diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 9be0b804d97..4771b83925d 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -97,8 +97,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): """[DEPRECATED] Use `callbacks` instead.""" class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True def get_input_schema( diff --git a/libs/langchain/langchain/chains/combine_documents/map_reduce.py b/libs/langchain/langchain/chains/combine_documents/map_reduce.py index 9e4981a3405..229ed45740c 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/map_reduce.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type from langchain_core.callbacks import Callbacks from langchain_core.documents import Document -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import create_model @@ -127,10 +127,8 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): return _output_keys class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @root_validator(pre=True) def get_reduce_chain(cls, values: Dict) -> Dict: diff --git a/libs/langchain/langchain/chains/combine_documents/map_rerank.py b/libs/langchain/langchain/chains/combine_documents/map_rerank.py index 214f47bc460..e05592caf11 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_rerank.py +++ b/libs/langchain/langchain/chains/combine_documents/map_rerank.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast from langchain_core.callbacks import Callbacks from langchain_core.documents import Document -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import create_model @@ -75,10 +75,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): Intermediate steps include the results of calling llm_chain on each document.""" class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" def get_output_schema( self, config: Optional[RunnableConfig] = None diff --git a/libs/langchain/langchain/chains/combine_documents/reduce.py b/libs/langchain/langchain/chains/combine_documents/reduce.py index a55c067d168..7b2cd6c89a3 100644 --- a/libs/langchain/langchain/chains/combine_documents/reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/reduce.py @@ -6,7 +6,6 @@ from typing import Any, Callable, List, Optional, Protocol, Tuple from langchain_core.callbacks import Callbacks from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Extra from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -206,10 +205,8 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): Otherwise, after it reaches the max number, it will throw an error""" class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def _collapse_chain(self) -> BaseCombineDocumentsChain: diff --git a/libs/langchain/langchain/chains/combine_documents/refine.py b/libs/langchain/langchain/chains/combine_documents/refine.py index ff0efdd95ff..cf2f5d9e92f 100644 --- a/libs/langchain/langchain/chains/combine_documents/refine.py +++ b/libs/langchain/langchain/chains/combine_documents/refine.py @@ -8,7 +8,7 @@ from langchain_core.callbacks import Callbacks from langchain_core.documents import Document from langchain_core.prompts import BasePromptTemplate, format_document from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.pydantic_v1 import Field, root_validator from langchain.chains.combine_documents.base import ( BaseCombineDocumentsChain, @@ -99,10 +99,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain): return _output_keys class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @root_validator(pre=True) def get_return_intermediate_steps(cls, values: Dict) -> Dict: diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index 3157ebf5b5b..544adc50b7a 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -7,7 +7,7 @@ from langchain_core.documents import Document from langchain_core.language_models import LanguageModelLike from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.prompts import BasePromptTemplate, format_document -from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.runnables import Runnable, RunnablePassthrough from langchain.chains.combine_documents.base import ( @@ -147,10 +147,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): """The string with which to join the formatted documents""" class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @root_validator(pre=True) def get_default_document_variable_name(cls, values: Dict) -> Dict: diff --git a/libs/langchain/langchain/chains/conversation/base.py b/libs/langchain/langchain/chains/conversation/base.py index f9113c7415d..2f873affac4 100644 --- a/libs/langchain/langchain/chains/conversation/base.py +++ b/libs/langchain/langchain/chains/conversation/base.py @@ -5,7 +5,7 @@ from typing import Dict, List from langchain_core._api import deprecated from langchain_core.memory import BaseMemory from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.pydantic_v1 import Field, root_validator from langchain.chains.conversation.prompt import PROMPT from langchain.chains.llm import LLMChain @@ -111,10 +111,8 @@ class ConversationChain(LLMChain): output_key: str = "response" #: :meta private: class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @classmethod def is_lc_serializable(cls) -> bool: diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index f2ab3acf403..8111978268e 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -18,7 +18,7 @@ from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import BaseMessage from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import RunnableConfig from langchain_core.vectorstores import VectorStore @@ -97,11 +97,9 @@ class BaseConversationalRetrievalChain(Chain): are found for the question. """ class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True allow_population_by_field_name = True + arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/elasticsearch_database/base.py b/libs/langchain/langchain/chains/elasticsearch_database/base.py index 4ea74f1cd37..2e4b97c8dee 100644 --- a/libs/langchain/langchain/chains/elasticsearch_database/base.py +++ b/libs/langchain/langchain/chains/elasticsearch_database/base.py @@ -9,7 +9,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseLLMOutputParser from langchain_core.output_parsers.json import SimpleJsonOutputParser from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import root_validator from langchain.chains.base import Chain from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT @@ -52,10 +52,8 @@ class ElasticsearchDatabaseChain(Chain): """Whether or not to return the intermediate steps along with the final answer.""" class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @root_validator(pre=False, skip_on_failure=True) def validate_indices(cls, values: dict) -> dict: diff --git a/libs/langchain/langchain/chains/hyde/base.py b/libs/langchain/langchain/chains/hyde/base.py index d435c4e12a5..851e76c1599 100644 --- a/libs/langchain/langchain/chains/hyde/base.py +++ b/libs/langchain/langchain/chains/hyde/base.py @@ -12,7 +12,6 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Extra from langchain.chains.base import Chain from langchain.chains.hyde.prompts import PROMPT_MAP @@ -29,10 +28,8 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): llm_chain: LLMChain class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py index d36f39126f2..feb3468fc44 100644 --- a/libs/langchain/langchain/chains/llm.py +++ b/libs/langchain/langchain/chains/llm.py @@ -23,7 +23,7 @@ from langchain_core.output_parsers import BaseLLMOutputParser, StrOutputParser from langchain_core.outputs import ChatGeneration, Generation, LLMResult from langchain_core.prompt_values import PromptValue from langchain_core.prompts import BasePromptTemplate, PromptTemplate -from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.pydantic_v1 import Field from langchain_core.runnables import ( Runnable, RunnableBinding, @@ -96,10 +96,8 @@ class LLMChain(Chain): llm_kwargs: dict = Field(default_factory=dict) class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/llm_checker/base.py b/libs/langchain/langchain/chains/llm_checker/base.py index 7a84ec0f861..2e2fa61d725 100644 --- a/libs/langchain/langchain/chains/llm_checker/base.py +++ b/libs/langchain/langchain/chains/llm_checker/base.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import root_validator from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -91,10 +91,8 @@ class LLMCheckerChain(Chain): output_key: str = "result" #: :meta private: class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: diff --git a/libs/langchain/langchain/chains/llm_math/base.py b/libs/langchain/langchain/chains/llm_math/base.py index 872409cb477..0733b0079b3 100644 --- a/libs/langchain/langchain/chains/llm_math/base.py +++ b/libs/langchain/langchain/chains/llm_math/base.py @@ -13,7 +13,7 @@ from langchain_core.callbacks import ( ) from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import root_validator from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -40,10 +40,8 @@ class LLMMathChain(Chain): output_key: str = "answer" #: :meta private: class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: diff --git a/libs/langchain/langchain/chains/llm_summarization_checker/base.py b/libs/langchain/langchain/chains/llm_summarization_checker/base.py index f63e383bbcc..da310232cf7 100644 --- a/libs/langchain/langchain/chains/llm_summarization_checker/base.py +++ b/libs/langchain/langchain/chains/llm_summarization_checker/base.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import root_validator from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -96,10 +96,8 @@ class LLMSummarizationCheckerChain(Chain): """Maximum number of times to check the assertions. Default to double-checking.""" class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: diff --git a/libs/langchain/langchain/chains/mapreduce.py b/libs/langchain/langchain/chains/mapreduce.py index f607a62d505..359133f0de2 100644 --- a/libs/langchain/langchain/chains/mapreduce.py +++ b/libs/langchain/langchain/chains/mapreduce.py @@ -12,7 +12,6 @@ from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Extra from langchain_text_splitters import TextSplitter from langchain.chains import ReduceDocumentsChain @@ -68,10 +67,8 @@ class MapReduceChain(Chain): ) class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/natbot/base.py b/libs/langchain/langchain/chains/natbot/base.py index 5d14cc68187..910e03f7d4f 100644 --- a/libs/langchain/langchain/chains/natbot/base.py +++ b/libs/langchain/langchain/chains/natbot/base.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import root_validator from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -48,10 +48,8 @@ class NatBotChain(Chain): output_key: str = "command" #: :meta private: class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: diff --git a/libs/langchain/langchain/chains/qa_with_sources/base.py b/libs/langchain/langchain/chains/qa_with_sources/base.py index ea444896daf..c7324a7e32f 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/base.py +++ b/libs/langchain/langchain/chains/qa_with_sources/base.py @@ -14,7 +14,7 @@ from langchain_core.callbacks import ( from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import root_validator from langchain.chains import ReduceDocumentsChain from langchain.chains.base import Chain @@ -88,10 +88,8 @@ class BaseQAWithSourcesChain(Chain, ABC): return cls(combine_documents_chain=combine_documents_chain, **kwargs) class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/query_constructor/schema.py b/libs/langchain/langchain/chains/query_constructor/schema.py index 6171b3742f2..585addc7919 100644 --- a/libs/langchain/langchain/chains/query_constructor/schema.py +++ b/libs/langchain/langchain/chains/query_constructor/schema.py @@ -9,7 +9,5 @@ class AttributeInfo(BaseModel): type: str class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True frozen = True diff --git a/libs/langchain/langchain/chains/retrieval_qa/base.py b/libs/langchain/langchain/chains/retrieval_qa/base.py index 5118cc0d93c..0b25dc00b0b 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/base.py +++ b/libs/langchain/langchain/chains/retrieval_qa/base.py @@ -16,7 +16,7 @@ from langchain_core.callbacks import ( from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.retrievers import BaseRetriever from langchain_core.vectorstores import VectorStore @@ -39,11 +39,9 @@ class BaseRetrievalQA(Chain): """Return the source documents or not.""" class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True allow_population_by_field_name = True + arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/router/base.py b/libs/langchain/langchain/chains/router/base.py index 93e127aa7bf..d0b680dd952 100644 --- a/libs/langchain/langchain/chains/router/base.py +++ b/libs/langchain/langchain/chains/router/base.py @@ -10,7 +10,6 @@ from langchain_core.callbacks import ( CallbackManagerForChainRun, Callbacks, ) -from langchain_core.pydantic_v1 import Extra from langchain.chains.base import Chain @@ -62,10 +61,8 @@ class MultiRouteChain(Chain): Defaults to False.""" class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/router/embedding_router.py b/libs/langchain/langchain/chains/router/embedding_router.py index b7fb59f8522..a1bc126a49f 100644 --- a/libs/langchain/langchain/chains/router/embedding_router.py +++ b/libs/langchain/langchain/chains/router/embedding_router.py @@ -8,7 +8,6 @@ from langchain_core.callbacks import ( ) from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import Extra from langchain_core.vectorstores import VectorStore from langchain.chains.router.base import RouterChain @@ -21,10 +20,8 @@ class EmbeddingRouterChain(RouterChain): routing_keys: List[str] = ["query"] class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/sequential.py b/libs/langchain/langchain/chains/sequential.py index 49b5ad06efb..e75300f7cbf 100644 --- a/libs/langchain/langchain/chains/sequential.py +++ b/libs/langchain/langchain/chains/sequential.py @@ -6,7 +6,7 @@ from langchain_core.callbacks import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import root_validator from langchain_core.utils.input import get_color_mapping from langchain.chains.base import Chain @@ -21,10 +21,8 @@ class SequentialChain(Chain): return_all: bool = False class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: @@ -132,10 +130,8 @@ class SimpleSequentialChain(Chain): output_key: str = "output" #: :meta private: class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py index 9e89ea1ffb3..791b95f64cf 100644 --- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py +++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py @@ -28,7 +28,7 @@ from langchain_core.exceptions import OutputParserException from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.output_parsers import BaseOutputParser -from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool from langchain.chains.llm import LLMChain @@ -157,9 +157,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain): """DEPRECATED. Reasoning always returned.""" class Config: - """Configuration for the QAEvalChain.""" - - extra = Extra.ignore + extra = "ignore" @property def requires_reference(self) -> bool: diff --git a/libs/langchain/langchain/evaluation/comparison/eval_chain.py b/libs/langchain/langchain/evaluation/comparison/eval_chain.py index 3ca17e4bfc8..d76f836ad87 100644 --- a/libs/langchain/langchain/evaluation/comparison/eval_chain.py +++ b/libs/langchain/langchain/evaluation/comparison/eval_chain.py @@ -10,7 +10,7 @@ from langchain_core.callbacks.manager import Callbacks from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.pydantic_v1 import Field from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain @@ -192,9 +192,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain): return False class Config: - """Configuration for the PairwiseStringEvalChain.""" - - extra = Extra.ignore + extra = "ignore" @property def requires_reference(self) -> bool: diff --git a/libs/langchain/langchain/evaluation/criteria/eval_chain.py b/libs/langchain/langchain/evaluation/criteria/eval_chain.py index 00f91885025..34fc656cbc7 100644 --- a/libs/langchain/langchain/evaluation/criteria/eval_chain.py +++ b/libs/langchain/langchain/evaluation/criteria/eval_chain.py @@ -8,7 +8,7 @@ from langchain_core.callbacks.manager import Callbacks from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.pydantic_v1 import Field from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain @@ -237,9 +237,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain): return False class Config: - """Configuration for the QAEvalChain.""" - - extra = Extra.ignore + extra = "ignore" @property def requires_reference(self) -> bool: diff --git a/libs/langchain/langchain/evaluation/embedding_distance/base.py b/libs/langchain/langchain/evaluation/embedding_distance/base.py index d9bd705efe3..d983c72fbf0 100644 --- a/libs/langchain/langchain/evaluation/embedding_distance/base.py +++ b/libs/langchain/langchain/evaluation/embedding_distance/base.py @@ -114,8 +114,6 @@ class _EmbeddingDistanceChainMixin(Chain): return values class Config: - """Permit embeddings to go unvalidated.""" - arbitrary_types_allowed: bool = True @property diff --git a/libs/langchain/langchain/evaluation/qa/eval_chain.py b/libs/langchain/langchain/evaluation/qa/eval_chain.py index 988655be82a..0204d8fe901 100644 --- a/libs/langchain/langchain/evaluation/qa/eval_chain.py +++ b/libs/langchain/langchain/evaluation/qa/eval_chain.py @@ -9,7 +9,6 @@ from typing import Any, List, Optional, Sequence, Tuple from langchain_core.callbacks.manager import Callbacks from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import Extra from langchain.chains.llm import LLMChain from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT @@ -74,9 +73,7 @@ class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain): output_key: str = "results" #: :meta private: class Config: - """Configuration for the QAEvalChain.""" - - extra = Extra.ignore + extra = "ignore" @classmethod def is_lc_serializable(cls) -> bool: @@ -224,9 +221,7 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain): return True class Config: - """Configuration for the QAEvalChain.""" - - extra = Extra.ignore + extra = "ignore" @classmethod def _validate_input_vars(cls, prompt: PromptTemplate) -> None: diff --git a/libs/langchain/langchain/evaluation/scoring/eval_chain.py b/libs/langchain/langchain/evaluation/scoring/eval_chain.py index e882b187ac5..3b800f8ffc0 100644 --- a/libs/langchain/langchain/evaluation/scoring/eval_chain.py +++ b/libs/langchain/langchain/evaluation/scoring/eval_chain.py @@ -10,7 +10,7 @@ from langchain_core.callbacks.manager import Callbacks from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.pydantic_v1 import Field from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain @@ -180,9 +180,7 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain): """The name of the criterion being evaluated.""" class Config: - """Configuration for the ScoreStringEvalChain.""" - - extra = Extra.ignore + extra = "ignore" @classmethod def is_lc_serializable(cls) -> bool: diff --git a/libs/langchain/langchain/indexes/vectorstore.py b/libs/langchain/langchain/indexes/vectorstore.py index 55e773ebdcd..deb408dff2d 100644 --- a/libs/langchain/langchain/indexes/vectorstore.py +++ b/libs/langchain/langchain/indexes/vectorstore.py @@ -4,7 +4,7 @@ from langchain_core.document_loaders import BaseLoader from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel -from langchain_core.pydantic_v1 import BaseModel, Extra, Field +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.vectorstores import VectorStore from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter @@ -22,10 +22,8 @@ class VectorStoreIndexWrapper(BaseModel): vectorstore: VectorStore class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" def query( self, @@ -145,10 +143,8 @@ class VectorstoreIndexCreator(BaseModel): vectorstore_kwargs: dict = Field(default_factory=dict) class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" def from_loaders(self, loaders: List[BaseLoader]) -> VectorStoreIndexWrapper: """Create a vectorstore index from loaders.""" diff --git a/libs/langchain/langchain/memory/entity.py b/libs/langchain/langchain/memory/entity.py index a726b86076d..57fdb75537e 100644 --- a/libs/langchain/langchain/memory/entity.py +++ b/libs/langchain/langchain/memory/entity.py @@ -246,8 +246,6 @@ class SQLiteEntityStore(BaseEntityStore): conn: Any = None class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True def __init__( diff --git a/libs/langchain/langchain/retrievers/contextual_compression.py b/libs/langchain/langchain/retrievers/contextual_compression.py index 9f5c09e3d1c..c73180b889d 100644 --- a/libs/langchain/langchain/retrievers/contextual_compression.py +++ b/libs/langchain/langchain/retrievers/contextual_compression.py @@ -22,8 +22,6 @@ class ContextualCompressionRetriever(BaseRetriever): """Base Retriever to use for getting relevant documents.""" class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True def _get_relevant_documents( diff --git a/libs/langchain/langchain/retrievers/document_compressors/base.py b/libs/langchain/langchain/retrievers/document_compressors/base.py index b8b01de5dcc..a68515af8b0 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/base.py +++ b/libs/langchain/langchain/retrievers/document_compressors/base.py @@ -16,8 +16,6 @@ class DocumentCompressorPipeline(BaseDocumentCompressor): """List of document filters that are chained together and run in sequence.""" class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True def compress_documents( diff --git a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py index 50b80439189..d2c169aca44 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union from langchain_core._api.deprecation import deprecated from langchain_core.callbacks.manager import Callbacks from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import root_validator from langchain_core.utils import get_from_dict_or_env from langchain.retrievers.document_compressors.base import BaseDocumentCompressor @@ -31,10 +31,8 @@ class CohereRerank(BaseDocumentCompressor): """Identifier for the application making the request.""" class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: diff --git a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py index 245722b364f..d1b683f2d9b 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py @@ -5,7 +5,6 @@ from typing import Optional, Sequence from langchain_core.callbacks import Callbacks from langchain_core.documents import BaseDocumentCompressor, Document -from langchain_core.pydantic_v1 import Extra from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder @@ -20,10 +19,8 @@ class CrossEncoderReranker(BaseDocumentCompressor): """Number of documents to return.""" class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid arbitrary_types_allowed = True + extra = "forbid" def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py index 058a7ff08bd..d29a0e7ac5f 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py @@ -42,8 +42,6 @@ class EmbeddingsFilter(BaseDocumentCompressor): to None.""" class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True @pre_init diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index 5d73756f5ff..3c5ca646d50 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -215,10 +215,8 @@ class SelfQueryRetriever(BaseRetriever): """Use original query instead of the revised new query from LLM""" class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True allow_population_by_field_name = True + arbitrary_types_allowed = True @root_validator(pre=True) def validate_translator(cls, values: Dict) -> Dict: diff --git a/libs/langchain/langchain/retrievers/time_weighted_retriever.py b/libs/langchain/langchain/retrievers/time_weighted_retriever.py index 7864cc3b097..0acf17edce8 100644 --- a/libs/langchain/langchain/retrievers/time_weighted_retriever.py +++ b/libs/langchain/langchain/retrievers/time_weighted_retriever.py @@ -47,8 +47,6 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): """ class Config: - """Configuration for this pydantic object.""" - arbitrary_types_allowed = True def _document_get_date(self, field: str, document: Document) -> datetime.datetime: diff --git a/libs/langchain/tests/unit_tests/load/test_dump.py b/libs/langchain/tests/unit_tests/load/test_dump.py index 21db03e2283..0ac05f7df21 100644 --- a/libs/langchain/tests/unit_tests/load/test_dump.py +++ b/libs/langchain/tests/unit_tests/load/test_dump.py @@ -85,8 +85,6 @@ class TestClass(Serializable): my_other_secret: str = Field() class Config: - """Configuration for this pydantic object.""" - allow_population_by_field_name = True @root_validator(pre=True)