core[patch]: Replace @validator with @model_validator in length based example selector (#26124)

Resolves another warning from usage of deprecated functionality in
pydantic 2
This commit is contained in:
Eugene Yurtsev
2024-09-05 18:26:43 -04:00
committed by GitHub
parent f4e7cb394f
commit 6fd4ac4283

View File

@@ -3,7 +3,8 @@
import re import re
from typing import Callable, Dict, List from typing import Callable, Dict, List
from pydantic import BaseModel, validator from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from langchain_core.example_selectors.base import BaseExampleSelector from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
@@ -28,7 +29,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
max_length: int = 2048 max_length: int = 2048
"""Max length for the prompt, beyond which examples are cut.""" """Max length for the prompt, beyond which examples are cut."""
example_text_lengths: List[int] = [] #: :meta private: example_text_lengths: List[int] = Field(default_factory=list) # :meta private:
"""Length of each example.""" """Length of each example."""
def add_example(self, example: Dict[str, str]) -> None: def add_example(self, example: Dict[str, str]) -> None:
@@ -52,17 +53,14 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
self.add_example(example) self.add_example(example)
@validator("example_text_lengths", always=True) @model_validator(mode="after")
def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]: def post_init(self) -> Self:
"""Calculate text lengths if they don't exist.""" """Validate that the examples are formatted correctly."""
# Check if text lengths were passed in if self.example_text_lengths:
if v: return self
return v string_examples = [self.example_prompt.format(**eg) for eg in self.examples]
# If they were not, calculate them self.example_text_lengths = [self.get_text_length(eg) for eg in string_examples]
example_prompt = values["example_prompt"] return self
get_text_length = values["get_text_length"]
string_examples = [example_prompt.format(**eg) for eg in values["examples"]]
return [get_text_length(eg) for eg in string_examples]
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on the input lengths. """Select which examples to use based on the input lengths.