feat(core): Upgrade pydantic to 2.x (#1428)

This commit is contained in:
Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions

View File

@@ -1,10 +1,12 @@
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
DEFAULT_CONTEXT_WINDOW = 3900
DEFAULT_NUM_OUTPUTS = 256
class LLMMetadata(BaseModel):
model_config = ConfigDict(protected_namespaces=())
context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description=(

View File

@@ -1,7 +1,14 @@
from typing import get_origin
import pydantic
if pydantic.VERSION.startswith("1."):
PYDANTIC_VERSION = 1
raise NotImplementedError("pydantic 1.x is not supported, please upgrade to 2.x.")
else:
PYDANTIC_VERSION = 2
# pydantic 2.x
# Now we upgrade to pydantic 2.x
from pydantic import (
BaseModel,
ConfigDict,
@@ -13,33 +20,72 @@ if pydantic.VERSION.startswith("1."):
PositiveInt,
PrivateAttr,
ValidationError,
root_validator,
validator,
)
else:
PYDANTIC_VERSION = 2
# pydantic 2.x
from pydantic.v1 import (
BaseModel,
ConfigDict,
Extra,
Field,
NonNegativeFloat,
NonNegativeInt,
PositiveFloat,
PositiveInt,
PrivateAttr,
ValidationError,
field_validator,
model_validator,
root_validator,
validator,
)
EXTRA_FORBID = "forbid"
def model_to_json(model, **kwargs):
"""Convert a pydantic model to json"""
def model_to_json(model, **kwargs) -> str:
"""Convert a pydantic model to json."""
if PYDANTIC_VERSION == 1:
return model.json(**kwargs)
else:
if "ensure_ascii" in kwargs:
del kwargs["ensure_ascii"]
return model.model_dump_json(**kwargs)
def model_to_dict(model, **kwargs) -> dict:
"""Convert a pydantic model to dict."""
if PYDANTIC_VERSION == 1:
return model.dict(**kwargs)
else:
return model.model_dump(**kwargs)
def model_fields(model):
"""Return the fields of a pydantic model."""
if PYDANTIC_VERSION == 1:
return model.__fields__
else:
return model.model_fields
def field_is_required(field) -> bool:
"""Return if a field is required."""
if PYDANTIC_VERSION == 1:
return field.required
else:
return field.is_required()
def field_outer_type(field):
"""Return the outer type of a field."""
if PYDANTIC_VERSION == 1:
return field.outer_type_
else:
# https://github.com/pydantic/pydantic/discussions/7217
origin = get_origin(field.annotation)
if origin is None:
return field.annotation
return origin
def field_description(field):
"""Return the description of a field."""
if PYDANTIC_VERSION == 1:
return field.field_info.description
else:
return field.description
def field_default(field):
"""Return the default value of a field."""
if PYDANTIC_VERSION == 1:
return field.field_info.default
else:
return field.default