langchain[minor]: Upgrade ambiguous root_validator to @pre_init (#24895)

The @pre_init validator is a temporary solution for base models. It has
similar (but not identical) semantics to @root_validator(), but it works
strictly as a pre-init validator.

It'll work as expected as long as the pydantic model type hints were
correct.
This commit is contained in:
Eugene Yurtsev 2024-07-31 14:46:47 -04:00 committed by GitHub
parent 5099a9c9b4
commit 69c656aa5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 20 additions and 16 deletions

View File

@ -10,7 +10,8 @@ from langchain_core.callbacks.manager import (
Callbacks,
)
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.pydantic_v1 import Field
from langchain_core.utils import pre_init
from langchain.chains.base import Chain
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
@ -68,7 +69,7 @@ class _EmbeddingDistanceChainMixin(Chain):
embeddings: Embeddings = Field(default_factory=_embedding_factory)
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
@root_validator(pre=False)
@pre_init
def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that the TikTok library is installed.

View File

@ -8,7 +8,8 @@ from langchain_core.callbacks.manager import (
CallbackManagerForChainRun,
Callbacks,
)
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.pydantic_v1 import Field
from langchain_core.utils import pre_init
from langchain.chains.base import Chain
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
@ -63,7 +64,7 @@ class _RapidFuzzChainMixin(Chain):
"""Whether to normalize the score to a value between 0 and 1.
Applies only to the Levenshtein and Damerau-Levenshtein distances."""
@root_validator
@pre_init
def validate_dependencies(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate that the rapidfuzz library is installed.

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import pre_init
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
from langchain.memory.utils import get_prompt_input_key
@ -82,7 +82,7 @@ class ConversationStringBufferMemory(BaseMemory):
input_key: Optional[str] = None
memory_key: str = "history" #: :meta private:
@root_validator()
@pre_init
def validate_chains(cls, values: Dict) -> Dict:
"""Validate that return messages is not True."""
if values.get("return_messages", False):

View File

@ -6,7 +6,8 @@ from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import pre_init
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
@ -85,7 +86,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
buffer = self.buffer
return {self.memory_key: buffer}
@root_validator()
@pre_init
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
prompt_variables = values["prompt"].input_variables

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Union
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import pre_init
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.summary import SummarizerMixin
@ -64,7 +64,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
)
return {self.memory_key: final_buffer}
@root_validator()
@pre_init
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
prompt_variables = values["prompt"].input_variables

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Dict, List
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import pre_init
class CombiningOutputParser(BaseOutputParser[Dict[str, Any]]):
@ -15,7 +15,7 @@ class CombiningOutputParser(BaseOutputParser[Dict[str, Any]]):
def is_lc_serializable(cls) -> bool:
return True
@root_validator()
@pre_init
def validate_parsers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate the parsers."""
parsers = values["parsers"]

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Type
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import pre_init
class EnumOutputParser(BaseOutputParser[Enum]):
@ -12,7 +12,7 @@ class EnumOutputParser(BaseOutputParser[Enum]):
enum: Type[Enum]
"""The enum to parse. Its values must be strings."""
@root_validator()
@pre_init
def raise_deprecation(cls, values: Dict) -> Dict:
enum = values["enum"]
if not all(isinstance(e.value, str) for e in enum):

View File

@ -4,7 +4,8 @@ import numpy as np
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.pydantic_v1 import Field
from langchain_core.utils import pre_init
from langchain.retrievers.document_compressors.base import (
BaseDocumentCompressor,
@ -45,7 +46,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
arbitrary_types_allowed = True
@root_validator()
@pre_init
def validate_params(cls, values: Dict) -> Dict:
"""Validate similarity parameters."""
if values["k"] is None and values["similarity_threshold"] is None: