mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-08 20:39:44 +00:00
feat(core): Upgrade pydantic to 2.x (#1428)
This commit is contained in:
@@ -4,7 +4,7 @@ import json
|
||||
import uuid
|
||||
from typing import Any, Dict
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
@@ -64,7 +64,7 @@ class Chunk(Document):
|
||||
|
||||
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Convert Chunk to dict."""
|
||||
data = self.dict(**kwargs)
|
||||
data = model_to_dict(self, **kwargs)
|
||||
data["class_name"] = self.class_name()
|
||||
return data
|
||||
|
||||
|
@@ -10,7 +10,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt._private.pydantic import BaseModel, model_to_dict
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
from dbgpt.util import BaseParameters
|
||||
from dbgpt.util.annotations import PublicAPI
|
||||
@@ -312,7 +312,7 @@ class ModelRequest:
|
||||
if isinstance(context, dict):
|
||||
context_dict = context
|
||||
elif isinstance(context, BaseModel):
|
||||
context_dict = context.dict()
|
||||
context_dict = model_to_dict(context)
|
||||
if context_dict and "stream" not in context_dict:
|
||||
context_dict["stream"] = stream
|
||||
if context_dict:
|
||||
|
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
||||
from dbgpt.core.interface.storage import (
|
||||
InMemoryStorage,
|
||||
ResourceIdentifier,
|
||||
@@ -42,7 +42,7 @@ class BaseMessage(BaseModel, ABC):
|
||||
"""
|
||||
return {
|
||||
"type": self.type,
|
||||
"data": self.dict(),
|
||||
"data": model_to_dict(self),
|
||||
"index": self.index,
|
||||
"round_index": self.round_index,
|
||||
}
|
||||
@@ -264,7 +264,7 @@ class ModelMessage(BaseModel):
|
||||
Returns:
|
||||
List[Dict[str, str]]: The dict list
|
||||
"""
|
||||
return list(map(lambda m: m.dict(), messages))
|
||||
return list(map(lambda m: model_to_dict(m), messages))
|
||||
|
||||
@staticmethod
|
||||
def build_human_message(content: str) -> "ModelMessage":
|
||||
|
@@ -2,7 +2,7 @@
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import root_validator
|
||||
from dbgpt._private.pydantic import model_validator
|
||||
from dbgpt.core import (
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
@@ -71,9 +71,12 @@ from dbgpt.util.i18n_utils import _
|
||||
class CommonChatPromptTemplate(ChatPromptTemplate):
|
||||
"""The common chat prompt template."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the messages."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
if "system_message" not in values:
|
||||
values["system_message"] = "You are a helpful AI Assistant."
|
||||
if "human_message" not in values:
|
||||
|
@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, root_validator
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, model_validator
|
||||
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
|
||||
from dbgpt.core.interface.storage import (
|
||||
InMemoryStorage,
|
||||
@@ -51,6 +51,8 @@ class BasePromptTemplate(BaseModel):
|
||||
class PromptTemplate(BasePromptTemplate):
|
||||
"""Prompt template."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
template: str
|
||||
"""The prompt template."""
|
||||
|
||||
@@ -69,11 +71,6 @@ class PromptTemplate(BasePromptTemplate):
|
||||
template_define: Optional[str] = None
|
||||
"""this template define"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
@@ -239,9 +236,12 @@ class ChatPromptTemplate(BasePromptTemplate):
|
||||
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||
return result_messages
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre-fill the messages."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
input_variables = values.get("input_variables", {})
|
||||
messages = values.get("messages", [])
|
||||
if not input_variables:
|
||||
|
@@ -9,7 +9,7 @@ from typing import Dict, Optional, Type
|
||||
class Serializable(ABC):
|
||||
"""The serializable abstract class."""
|
||||
|
||||
serializer: Optional["Serializer"] = None
|
||||
_serializer: Optional["Serializer"] = None
|
||||
|
||||
@abstractmethod
|
||||
def to_dict(self) -> Dict:
|
||||
@@ -21,11 +21,12 @@ class Serializable(ABC):
|
||||
Returns:
|
||||
bytes: The byte array after serialization
|
||||
"""
|
||||
if self.serializer is None:
|
||||
if self._serializer is None:
|
||||
raise ValueError(
|
||||
"Serializer is not set. Please set the serializer before serialization."
|
||||
"Serializer is not set. Please set the serializer before "
|
||||
"serialization."
|
||||
)
|
||||
return self.serializer.serialize(self)
|
||||
return self._serializer.serialize(self)
|
||||
|
||||
def set_serializer(self, serializer: "Serializer") -> None:
|
||||
"""Set the serializer for current serializable object.
|
||||
@@ -33,7 +34,7 @@ class Serializable(ABC):
|
||||
Args:
|
||||
serializer (Serializer): The serializer to set
|
||||
"""
|
||||
self.serializer = serializer
|
||||
self._serializer = serializer
|
||||
|
||||
|
||||
class Serializer(ABC):
|
||||
|
@@ -426,7 +426,7 @@ class InMemoryStorage(StorageInterface[T, T]):
|
||||
"""
|
||||
if not data:
|
||||
raise StorageError("Data cannot be None")
|
||||
if not data.serializer:
|
||||
if not data._serializer:
|
||||
data.set_serializer(self.serializer)
|
||||
|
||||
if data.identifier.str_identifier in self._data:
|
||||
@@ -439,7 +439,7 @@ class InMemoryStorage(StorageInterface[T, T]):
|
||||
"""Update the data to the storage."""
|
||||
if not data:
|
||||
raise StorageError("Data cannot be None")
|
||||
if not data.serializer:
|
||||
if not data._serializer:
|
||||
data.set_serializer(self.serializer)
|
||||
self._data[data.identifier.str_identifier] = data.serialize()
|
||||
|
||||
|
Reference in New Issue
Block a user