From 6fd4ac428398f4f4c053e85bd51122bd01d7e4ee Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 5 Sep 2024 18:26:43 -0400 Subject: [PATCH] core[patch]: Replace @validator with @model_validator in length based example selector (#26124) Resolves another warning from usage of deprecated functionality in pydantic 2 --- .../example_selectors/length_based.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/libs/core/langchain_core/example_selectors/length_based.py b/libs/core/langchain_core/example_selectors/length_based.py index 810d1bc4f41..d9e6b6ea762 100644 --- a/libs/core/langchain_core/example_selectors/length_based.py +++ b/libs/core/langchain_core/example_selectors/length_based.py @@ -3,7 +3,8 @@ import re from typing import Callable, Dict, List -from pydantic import BaseModel, validator +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self from langchain_core.example_selectors.base import BaseExampleSelector from langchain_core.prompts.prompt import PromptTemplate @@ -28,7 +29,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel): max_length: int = 2048 """Max length for the prompt, beyond which examples are cut.""" - example_text_lengths: List[int] = [] #: :meta private: + example_text_lengths: List[int] = Field(default_factory=list) # :meta private: """Length of each example.""" def add_example(self, example: Dict[str, str]) -> None: @@ -52,17 +53,14 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel): self.add_example(example) - @validator("example_text_lengths", always=True) - def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]: - """Calculate text lengths if they don't exist.""" - # Check if text lengths were passed in - if v: - return v - # If they were not, calculate them - example_prompt = values["example_prompt"] - get_text_length = values["get_text_length"] - string_examples = [example_prompt.format(**eg) for eg in values["examples"]] - return [get_text_length(eg) for eg in string_examples] + @model_validator(mode="after") + def post_init(self) -> Self: + """Validate that the examples are formatted correctly.""" + if self.example_text_lengths: + return self + string_examples = [self.example_prompt.format(**eg) for eg in self.examples] + self.example_text_lengths = [self.get_text_length(eg) for eg in string_examples] + return self def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: """Select which examples to use based on the input lengths.