mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-03 05:34:01 +00:00
langchain[minor]: Upgrade ambiguous root_validator to @pre_init (#24895)
The @pre_init validator is a temporary solution for base models. It has similar (but not identical) semantics to @root_validator(), but it works strictly as a pre-init validator. It'll work as expected as long as the pydantic model type hints were correct.
This commit is contained in:
parent
5099a9c9b4
commit
69c656aa5f
@ -10,7 +10,8 @@ from langchain_core.callbacks.manager import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
|
||||
@ -68,7 +69,7 @@ class _EmbeddingDistanceChainMixin(Chain):
|
||||
embeddings: Embeddings = Field(default_factory=_embedding_factory)
|
||||
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
|
||||
|
||||
@root_validator(pre=False)
|
||||
@pre_init
|
||||
def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate that the TikTok library is installed.
|
||||
|
||||
|
@ -8,7 +8,8 @@ from langchain_core.callbacks.manager import (
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
|
||||
@ -63,7 +64,7 @@ class _RapidFuzzChainMixin(Chain):
|
||||
"""Whether to normalize the score to a value between 0 and 1.
|
||||
Applies only to the Levenshtein and Damerau-Levenshtein distances."""
|
||||
|
||||
@root_validator
|
||||
@pre_init
|
||||
def validate_dependencies(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate that the rapidfuzz library is installed.
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
@ -82,7 +82,7 @@ class ConversationStringBufferMemory(BaseMemory):
|
||||
input_key: Optional[str] = None
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
@root_validator()
|
||||
@pre_init
|
||||
def validate_chains(cls, values: Dict) -> Dict:
|
||||
"""Validate that return messages is not True."""
|
||||
if values.get("return_messages", False):
|
||||
|
@ -6,7 +6,8 @@ from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
@ -85,7 +86,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
|
||||
buffer = self.buffer
|
||||
return {self.memory_key: buffer}
|
||||
|
||||
@root_validator()
|
||||
@pre_init
|
||||
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
prompt_variables = values["prompt"].input_variables
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
@ -64,7 +64,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
|
||||
)
|
||||
return {self.memory_key: final_buffer}
|
||||
|
||||
@root_validator()
|
||||
@pre_init
|
||||
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
prompt_variables = values["prompt"].input_variables
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
|
||||
class CombiningOutputParser(BaseOutputParser[Dict[str, Any]]):
|
||||
@ -15,7 +15,7 @@ class CombiningOutputParser(BaseOutputParser[Dict[str, Any]]):
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@root_validator()
|
||||
@pre_init
|
||||
def validate_parsers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate the parsers."""
|
||||
parsers = values["parsers"]
|
||||
|
@ -3,7 +3,7 @@ from typing import Dict, List, Type
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
|
||||
class EnumOutputParser(BaseOutputParser[Enum]):
|
||||
@ -12,7 +12,7 @@ class EnumOutputParser(BaseOutputParser[Enum]):
|
||||
enum: Type[Enum]
|
||||
"""The enum to parse. Its values must be strings."""
|
||||
|
||||
@root_validator()
|
||||
@pre_init
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
enum = values["enum"]
|
||||
if not all(isinstance(e.value, str) for e in enum):
|
||||
|
@ -4,7 +4,8 @@ import numpy as np
|
||||
from langchain_core.callbacks.manager import Callbacks
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
from langchain.retrievers.document_compressors.base import (
|
||||
BaseDocumentCompressor,
|
||||
@ -45,7 +46,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator()
|
||||
@pre_init
|
||||
def validate_params(cls, values: Dict) -> Dict:
|
||||
"""Validate similarity parameters."""
|
||||
if values["k"] is None and values["similarity_threshold"] is None:
|
||||
|
Loading…
Reference in New Issue
Block a user