mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 03:50:42 +00:00
feat(core): Upgrade pydantic to 2.x (#1428)
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union
|
||||
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from dbgpt._private.pydantic import model_to_dict
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
|
||||
from .db_manager import BaseQuery, DatabaseManager, db
|
||||
@@ -165,7 +166,7 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
entry = query.first()
|
||||
if entry is None:
|
||||
raise Exception("Invalid request")
|
||||
for key, value in update_request.dict().items(): # type: ignore
|
||||
for key, value in model_to_dict(update_request).items(): # type: ignore
|
||||
if value is not None:
|
||||
setattr(entry, key, value)
|
||||
session.merge(entry)
|
||||
@@ -272,7 +273,9 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
model_cls = type(self.from_request(query_request))
|
||||
query = session.query(model_cls)
|
||||
query_dict = (
|
||||
query_request if isinstance(query_request, dict) else query_request.dict()
|
||||
query_request
|
||||
if isinstance(query_request, dict)
|
||||
else model_to_dict(query_request)
|
||||
)
|
||||
for key, value in query_dict.items():
|
||||
if value is not None:
|
||||
|
@@ -4,7 +4,7 @@ import pytest
|
||||
from sqlalchemy import Column, Integer, String
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel as PydanticBaseModel
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt._private.pydantic import Field, model_to_dict
|
||||
from dbgpt.storage.metadata.db_manager import (
|
||||
BaseModel,
|
||||
DatabaseManager,
|
||||
@@ -61,7 +61,7 @@ def user_dao(db, User):
|
||||
class UserDao(BaseDao[User, UserRequest, UserResponse]):
|
||||
def from_request(self, request: Union[UserRequest, Dict[str, Any]]) -> User:
|
||||
if isinstance(request, UserRequest):
|
||||
return User(**request.dict())
|
||||
return User(**model_to_dict(request))
|
||||
else:
|
||||
return User(**request)
|
||||
|
||||
@@ -71,7 +71,7 @@ def user_dao(db, User):
|
||||
)
|
||||
|
||||
def from_response(self, response: UserResponse) -> User:
|
||||
return User(**response.dict())
|
||||
return User(**model_to_dict(response))
|
||||
|
||||
def to_response(self, entity: User):
|
||||
return UserResponse(id=entity.id, name=entity.name, age=entity.age)
|
||||
|
@@ -4,9 +4,9 @@ import math
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import Parameter
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
@@ -87,10 +87,7 @@ _COMMON_PARAMETERS = [
|
||||
class VectorStoreConfig(BaseModel):
|
||||
"""Vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
name: str = Field(
|
||||
default="dbgpt_collection",
|
||||
@@ -122,6 +119,10 @@ class VectorStoreConfig(BaseModel):
|
||||
"bigger than 1, please make sure your vector store is thread-safe.",
|
||||
)
|
||||
|
||||
def to_dict(self, **kwargs) -> Dict[str, Any]:
|
||||
"""Convert to dict."""
|
||||
return model_to_dict(self, **kwargs)
|
||||
|
||||
|
||||
class VectorStoreBase(ABC):
|
||||
"""Vector store base class."""
|
||||
|
@@ -6,7 +6,7 @@ from typing import List, Optional
|
||||
from chromadb import PersistentClient
|
||||
from chromadb.config import Settings
|
||||
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.configs.model_config import PILOT_PATH
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
@@ -38,16 +38,13 @@ logger = logging.getLogger(__name__)
|
||||
class ChromaVectorConfig(VectorStoreConfig):
|
||||
"""Chroma vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
persist_path: str = Field(
|
||||
persist_path: Optional[str] = Field(
|
||||
default=os.getenv("CHROMA_PERSIST_PATH", None),
|
||||
description="the persist path of vector store.",
|
||||
)
|
||||
collection_metadata: dict = Field(
|
||||
collection_metadata: Optional[dict] = Field(
|
||||
default=None,
|
||||
description="the index metadata of vector store, if not set, will use the "
|
||||
"default metadata.",
|
||||
@@ -61,7 +58,7 @@ class ChromaStore(VectorStoreBase):
|
||||
"""Create a ChromaStore instance."""
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
chroma_vector_config = vector_store_config.dict(exclude_none=True)
|
||||
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
|
||||
chroma_path = chroma_vector_config.get(
|
||||
"persist_path", os.path.join(PILOT_PATH, "data")
|
||||
)
|
||||
|
@@ -2,7 +2,7 @@
|
||||
from enum import Enum
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FilterOperator(str, Enum):
|
||||
|
@@ -6,7 +6,7 @@ import logging
|
||||
import os
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.storage.vector_store.base import (
|
||||
@@ -96,10 +96,7 @@ logger = logging.getLogger(__name__)
|
||||
class MilvusVectorConfig(VectorStoreConfig):
|
||||
"""Milvus vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
uri: str = Field(
|
||||
default="localhost",
|
||||
@@ -155,7 +152,7 @@ class MilvusStore(VectorStoreBase):
|
||||
from pymilvus import connections
|
||||
|
||||
connect_kwargs = {}
|
||||
milvus_vector_config = vector_store_config.dict()
|
||||
milvus_vector_config = vector_store_config.to_dict()
|
||||
self.uri = milvus_vector_config.get("uri") or os.getenv(
|
||||
"MILVUS_URL", "localhost"
|
||||
)
|
||||
|
@@ -2,7 +2,7 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.storage.vector_store.base import (
|
||||
@@ -39,10 +39,7 @@ logger = logging.getLogger(__name__)
|
||||
class PGVectorConfig(VectorStoreConfig):
|
||||
"""PG vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
connection_string: str = Field(
|
||||
default=None,
|
||||
|
@@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.util.i18n_utils import _
|
||||
@@ -44,10 +44,7 @@ logger = logging.getLogger(__name__)
|
||||
class WeaviateVectorConfig(VectorStoreConfig):
|
||||
"""Weaviate vector store config."""
|
||||
|
||||
class Config:
|
||||
"""Config for BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
weaviate_url: str = Field(
|
||||
default=os.getenv("WEAVIATE_URL", None),
|
||||
|
Reference in New Issue
Block a user