mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 05:01:25 +00:00
feat(core): Upgrade pydantic to 2.x (#1428)
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
DEFAULT_CONTEXT_WINDOW = 3900
|
||||
DEFAULT_NUM_OUTPUTS = 256
|
||||
|
||||
|
||||
class LLMMetadata(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
context_window: int = Field(
|
||||
default=DEFAULT_CONTEXT_WINDOW,
|
||||
description=(
|
||||
|
@@ -1,7 +1,14 @@
|
||||
from typing import get_origin
|
||||
|
||||
import pydantic
|
||||
|
||||
if pydantic.VERSION.startswith("1."):
|
||||
PYDANTIC_VERSION = 1
|
||||
raise NotImplementedError("pydantic 1.x is not supported, please upgrade to 2.x.")
|
||||
else:
|
||||
PYDANTIC_VERSION = 2
|
||||
# pydantic 2.x
|
||||
# Now we upgrade to pydantic 2.x
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@@ -13,33 +20,72 @@ if pydantic.VERSION.startswith("1."):
|
||||
PositiveInt,
|
||||
PrivateAttr,
|
||||
ValidationError,
|
||||
root_validator,
|
||||
validator,
|
||||
)
|
||||
else:
|
||||
PYDANTIC_VERSION = 2
|
||||
# pydantic 2.x
|
||||
from pydantic.v1 import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Extra,
|
||||
Field,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
PositiveFloat,
|
||||
PositiveInt,
|
||||
PrivateAttr,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
model_validator,
|
||||
root_validator,
|
||||
validator,
|
||||
)
|
||||
|
||||
EXTRA_FORBID = "forbid"
|
||||
|
||||
def model_to_json(model, **kwargs):
|
||||
"""Convert a pydantic model to json"""
|
||||
|
||||
def model_to_json(model, **kwargs) -> str:
|
||||
"""Convert a pydantic model to json."""
|
||||
if PYDANTIC_VERSION == 1:
|
||||
return model.json(**kwargs)
|
||||
else:
|
||||
if "ensure_ascii" in kwargs:
|
||||
del kwargs["ensure_ascii"]
|
||||
return model.model_dump_json(**kwargs)
|
||||
|
||||
|
||||
def model_to_dict(model, **kwargs) -> dict:
|
||||
"""Convert a pydantic model to dict."""
|
||||
if PYDANTIC_VERSION == 1:
|
||||
return model.dict(**kwargs)
|
||||
else:
|
||||
return model.model_dump(**kwargs)
|
||||
|
||||
|
||||
def model_fields(model):
|
||||
"""Return the fields of a pydantic model."""
|
||||
if PYDANTIC_VERSION == 1:
|
||||
return model.__fields__
|
||||
else:
|
||||
return model.model_fields
|
||||
|
||||
|
||||
def field_is_required(field) -> bool:
|
||||
"""Return if a field is required."""
|
||||
if PYDANTIC_VERSION == 1:
|
||||
return field.required
|
||||
else:
|
||||
return field.is_required()
|
||||
|
||||
|
||||
def field_outer_type(field):
|
||||
"""Return the outer type of a field."""
|
||||
if PYDANTIC_VERSION == 1:
|
||||
return field.outer_type_
|
||||
else:
|
||||
# https://github.com/pydantic/pydantic/discussions/7217
|
||||
origin = get_origin(field.annotation)
|
||||
if origin is None:
|
||||
return field.annotation
|
||||
return origin
|
||||
|
||||
|
||||
def field_description(field):
|
||||
"""Return the description of a field."""
|
||||
if PYDANTIC_VERSION == 1:
|
||||
return field.field_info.description
|
||||
else:
|
||||
return field.description
|
||||
|
||||
|
||||
def field_default(field):
|
||||
"""Return the default value of a field."""
|
||||
if PYDANTIC_VERSION == 1:
|
||||
return field.field_info.default
|
||||
else:
|
||||
return field.default
|
||||
|
Reference in New Issue
Block a user