mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
langchain[minor]: Upgrade ambiguous root_validator to @pre_init (#24895)
The @pre_init validator is a temporary solution for base models. It has similar (but not identical) semantics to @root_validator(), but it works strictly as a pre-init validator. It'll work as expected as long as the pydantic model type hints were correct.
This commit is contained in:
parent
5099a9c9b4
commit
69c656aa5f
@ -10,7 +10,8 @@ from langchain_core.callbacks.manager import (
|
|||||||
Callbacks,
|
Callbacks,
|
||||||
)
|
)
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator
|
from langchain_core.pydantic_v1 import Field
|
||||||
|
from langchain_core.utils import pre_init
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
|
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
|
||||||
@ -68,7 +69,7 @@ class _EmbeddingDistanceChainMixin(Chain):
|
|||||||
embeddings: Embeddings = Field(default_factory=_embedding_factory)
|
embeddings: Embeddings = Field(default_factory=_embedding_factory)
|
||||||
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
|
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
|
||||||
|
|
||||||
@root_validator(pre=False)
|
@pre_init
|
||||||
def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Validate that the TikTok library is installed.
|
"""Validate that the TikTok library is installed.
|
||||||
|
|
||||||
|
@ -8,7 +8,8 @@ from langchain_core.callbacks.manager import (
|
|||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
Callbacks,
|
Callbacks,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator
|
from langchain_core.pydantic_v1 import Field
|
||||||
|
from langchain_core.utils import pre_init
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
|
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
|
||||||
@ -63,7 +64,7 @@ class _RapidFuzzChainMixin(Chain):
|
|||||||
"""Whether to normalize the score to a value between 0 and 1.
|
"""Whether to normalize the score to a value between 0 and 1.
|
||||||
Applies only to the Levenshtein and Damerau-Levenshtein distances."""
|
Applies only to the Levenshtein and Damerau-Levenshtein distances."""
|
||||||
|
|
||||||
@root_validator
|
@pre_init
|
||||||
def validate_dependencies(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def validate_dependencies(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Validate that the rapidfuzz library is installed.
|
Validate that the rapidfuzz library is installed.
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
from langchain_core.pydantic_v1 import root_validator
|
from langchain_core.utils import pre_init
|
||||||
|
|
||||||
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
|
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
|
||||||
from langchain.memory.utils import get_prompt_input_key
|
from langchain.memory.utils import get_prompt_input_key
|
||||||
@ -82,7 +82,7 @@ class ConversationStringBufferMemory(BaseMemory):
|
|||||||
input_key: Optional[str] = None
|
input_key: Optional[str] = None
|
||||||
memory_key: str = "history" #: :meta private:
|
memory_key: str = "history" #: :meta private:
|
||||||
|
|
||||||
@root_validator()
|
@pre_init
|
||||||
def validate_chains(cls, values: Dict) -> Dict:
|
def validate_chains(cls, values: Dict) -> Dict:
|
||||||
"""Validate that return messages is not True."""
|
"""Validate that return messages is not True."""
|
||||||
if values.get("return_messages", False):
|
if values.get("return_messages", False):
|
||||||
|
@ -6,7 +6,8 @@ from langchain_core.chat_history import BaseChatMessageHistory
|
|||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
|
||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
from langchain_core.utils import pre_init
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
@ -85,7 +86,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
|
|||||||
buffer = self.buffer
|
buffer = self.buffer
|
||||||
return {self.memory_key: buffer}
|
return {self.memory_key: buffer}
|
||||||
|
|
||||||
@root_validator()
|
@pre_init
|
||||||
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
||||||
"""Validate that prompt input variables are consistent."""
|
"""Validate that prompt input variables are consistent."""
|
||||||
prompt_variables = values["prompt"].input_variables
|
prompt_variables = values["prompt"].input_variables
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
from langchain_core.pydantic_v1 import root_validator
|
from langchain_core.utils import pre_init
|
||||||
|
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
from langchain.memory.summary import SummarizerMixin
|
from langchain.memory.summary import SummarizerMixin
|
||||||
@ -64,7 +64,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
|
|||||||
)
|
)
|
||||||
return {self.memory_key: final_buffer}
|
return {self.memory_key: final_buffer}
|
||||||
|
|
||||||
@root_validator()
|
@pre_init
|
||||||
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
||||||
"""Validate that prompt input variables are consistent."""
|
"""Validate that prompt input variables are consistent."""
|
||||||
prompt_variables = values["prompt"].input_variables
|
prompt_variables = values["prompt"].input_variables
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers import BaseOutputParser
|
||||||
from langchain_core.pydantic_v1 import root_validator
|
from langchain_core.utils import pre_init
|
||||||
|
|
||||||
|
|
||||||
class CombiningOutputParser(BaseOutputParser[Dict[str, Any]]):
|
class CombiningOutputParser(BaseOutputParser[Dict[str, Any]]):
|
||||||
@ -15,7 +15,7 @@ class CombiningOutputParser(BaseOutputParser[Dict[str, Any]]):
|
|||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@root_validator()
|
@pre_init
|
||||||
def validate_parsers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def validate_parsers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Validate the parsers."""
|
"""Validate the parsers."""
|
||||||
parsers = values["parsers"]
|
parsers = values["parsers"]
|
||||||
|
@ -3,7 +3,7 @@ from typing import Dict, List, Type
|
|||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers import BaseOutputParser
|
||||||
from langchain_core.pydantic_v1 import root_validator
|
from langchain_core.utils import pre_init
|
||||||
|
|
||||||
|
|
||||||
class EnumOutputParser(BaseOutputParser[Enum]):
|
class EnumOutputParser(BaseOutputParser[Enum]):
|
||||||
@ -12,7 +12,7 @@ class EnumOutputParser(BaseOutputParser[Enum]):
|
|||||||
enum: Type[Enum]
|
enum: Type[Enum]
|
||||||
"""The enum to parse. Its values must be strings."""
|
"""The enum to parse. Its values must be strings."""
|
||||||
|
|
||||||
@root_validator()
|
@pre_init
|
||||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||||
enum = values["enum"]
|
enum = values["enum"]
|
||||||
if not all(isinstance(e.value, str) for e in enum):
|
if not all(isinstance(e.value, str) for e in enum):
|
||||||
|
@ -4,7 +4,8 @@ import numpy as np
|
|||||||
from langchain_core.callbacks.manager import Callbacks
|
from langchain_core.callbacks.manager import Callbacks
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator
|
from langchain_core.pydantic_v1 import Field
|
||||||
|
from langchain_core.utils import pre_init
|
||||||
|
|
||||||
from langchain.retrievers.document_compressors.base import (
|
from langchain.retrievers.document_compressors.base import (
|
||||||
BaseDocumentCompressor,
|
BaseDocumentCompressor,
|
||||||
@ -45,7 +46,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
|
|||||||
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@root_validator()
|
@pre_init
|
||||||
def validate_params(cls, values: Dict) -> Dict:
|
def validate_params(cls, values: Dict) -> Dict:
|
||||||
"""Validate similarity parameters."""
|
"""Validate similarity parameters."""
|
||||||
if values["k"] is None and values["similarity_threshold"] is None:
|
if values["k"] is None and values["similarity_threshold"] is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user