feat(core): Upgrade pydantic to 2.x (#1428)

This commit is contained in:
Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions

View File

@@ -4,7 +4,7 @@ import json
import uuid
from typing import Any, Dict
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
class Document(BaseModel):
@@ -64,7 +64,7 @@ class Chunk(Document):
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
"""Convert Chunk to dict."""
data = self.dict(**kwargs)
data = model_to_dict(self, **kwargs)
data["class_name"] = self.class_name()
return data

View File

@@ -10,7 +10,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, Union
from cachetools import TTLCache
from dbgpt._private.pydantic import BaseModel
from dbgpt._private.pydantic import BaseModel, model_to_dict
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
@@ -312,7 +312,7 @@ class ModelRequest:
if isinstance(context, dict):
context_dict = context
elif isinstance(context, BaseModel):
context_dict = context.dict()
context_dict = model_to_dict(context)
if context_dict and "stream" not in context_dict:
context_dict["stream"] = stream
if context_dict:

View File

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
from dbgpt.core.interface.storage import (
InMemoryStorage,
ResourceIdentifier,
@@ -42,7 +42,7 @@ class BaseMessage(BaseModel, ABC):
"""
return {
"type": self.type,
"data": self.dict(),
"data": model_to_dict(self),
"index": self.index,
"round_index": self.round_index,
}
@@ -264,7 +264,7 @@ class ModelMessage(BaseModel):
Returns:
List[Dict[str, str]]: The dict list
"""
return list(map(lambda m: m.dict(), messages))
return list(map(lambda m: model_to_dict(m), messages))
@staticmethod
def build_human_message(content: str) -> "ModelMessage":

View File

@@ -2,7 +2,7 @@
from abc import ABC
from typing import Any, Dict, List, Optional, Union
from dbgpt._private.pydantic import root_validator
from dbgpt._private.pydantic import model_validator
from dbgpt.core import (
ModelMessage,
ModelMessageRoleType,
@@ -71,9 +71,12 @@ from dbgpt.util.i18n_utils import _
class CommonChatPromptTemplate(ChatPromptTemplate):
"""The common chat prompt template."""
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the messages."""
if not isinstance(values, dict):
return values
if "system_message" not in values:
values["system_message"] = "You are a helpful AI Assistant."
if "human_message" not in values:

View File

@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
from string import Formatter
from typing import Any, Callable, Dict, List, Optional, Set, Union
from dbgpt._private.pydantic import BaseModel, root_validator
from dbgpt._private.pydantic import BaseModel, ConfigDict, model_validator
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
from dbgpt.core.interface.storage import (
InMemoryStorage,
@@ -51,6 +51,8 @@ class BasePromptTemplate(BaseModel):
class PromptTemplate(BasePromptTemplate):
"""Prompt template."""
model_config = ConfigDict(arbitrary_types_allowed=True)
template: str
"""The prompt template."""
@@ -69,11 +71,6 @@ class PromptTemplate(BasePromptTemplate):
template_define: Optional[str] = None
"""this template define"""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
@@ -239,9 +236,12 @@ class ChatPromptTemplate(BasePromptTemplate):
raise ValueError(f"Unsupported message type: {type(message)}")
return result_messages
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre-fill the messages."""
if not isinstance(values, dict):
return values
input_variables = values.get("input_variables", {})
messages = values.get("messages", [])
if not input_variables:

View File

@@ -9,7 +9,7 @@ from typing import Dict, Optional, Type
class Serializable(ABC):
"""The serializable abstract class."""
serializer: Optional["Serializer"] = None
_serializer: Optional["Serializer"] = None
@abstractmethod
def to_dict(self) -> Dict:
@@ -21,11 +21,12 @@ class Serializable(ABC):
Returns:
bytes: The byte array after serialization
"""
if self.serializer is None:
if self._serializer is None:
raise ValueError(
"Serializer is not set. Please set the serializer before serialization."
"Serializer is not set. Please set the serializer before "
"serialization."
)
return self.serializer.serialize(self)
return self._serializer.serialize(self)
def set_serializer(self, serializer: "Serializer") -> None:
"""Set the serializer for current serializable object.
@@ -33,7 +34,7 @@ class Serializable(ABC):
Args:
serializer (Serializer): The serializer to set
"""
self.serializer = serializer
self._serializer = serializer
class Serializer(ABC):

View File

@@ -426,7 +426,7 @@ class InMemoryStorage(StorageInterface[T, T]):
"""
if not data:
raise StorageError("Data cannot be None")
if not data.serializer:
if not data._serializer:
data.set_serializer(self.serializer)
if data.identifier.str_identifier in self._data:
@@ -439,7 +439,7 @@ class InMemoryStorage(StorageInterface[T, T]):
"""Update the data to the storage."""
if not data:
raise StorageError("Data cannot be None")
if not data.serializer:
if not data._serializer:
data.set_serializer(self.serializer)
self._data[data.identifier.str_identifier] = data.serialize()