mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 01:54:44 +00:00
feat(core): Upgrade pydantic to 2.x (#1428)
This commit is contained in:
@@ -159,9 +159,9 @@ def setup_dev_environment(
|
||||
|
||||
start_http = _check_has_http_trigger(dags)
|
||||
if start_http:
|
||||
from fastapi import FastAPI
|
||||
from dbgpt.util.fastapi import create_app
|
||||
|
||||
app = FastAPI()
|
||||
app = create_app()
|
||||
else:
|
||||
app = None
|
||||
system_app = SystemApp(app)
|
||||
|
@@ -4,6 +4,7 @@ DAGManager will load DAGs from dag_dirs, and register the trigger nodes
|
||||
to TriggerManager.
|
||||
"""
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
@@ -29,6 +30,7 @@ class DAGManager(BaseComponent):
|
||||
from ..trigger.trigger_manager import DefaultTriggerManager
|
||||
|
||||
super().__init__(system_app)
|
||||
self.lock = threading.Lock()
|
||||
self.dag_loader = LocalFileDAGLoader(dag_dirs)
|
||||
self.system_app = system_app
|
||||
self.dag_map: Dict[str, DAG] = {}
|
||||
@@ -61,39 +63,54 @@ class DAGManager(BaseComponent):
|
||||
|
||||
def register_dag(self, dag: DAG, alias_name: Optional[str] = None):
|
||||
"""Register a DAG."""
|
||||
dag_id = dag.dag_id
|
||||
if dag_id in self.dag_map:
|
||||
raise ValueError(f"Register DAG error, DAG ID {dag_id} has already exist")
|
||||
self.dag_map[dag_id] = dag
|
||||
if alias_name:
|
||||
self.dag_alias_map[alias_name] = dag_id
|
||||
with self.lock:
|
||||
dag_id = dag.dag_id
|
||||
if dag_id in self.dag_map:
|
||||
raise ValueError(
|
||||
f"Register DAG error, DAG ID {dag_id} has already exist"
|
||||
)
|
||||
self.dag_map[dag_id] = dag
|
||||
if alias_name:
|
||||
self.dag_alias_map[alias_name] = dag_id
|
||||
|
||||
if self._trigger_manager:
|
||||
for trigger in dag.trigger_nodes:
|
||||
self._trigger_manager.register_trigger(trigger, self.system_app)
|
||||
self._trigger_manager.after_register()
|
||||
else:
|
||||
logger.warning("No trigger manager, not register dag trigger")
|
||||
if self._trigger_manager:
|
||||
for trigger in dag.trigger_nodes:
|
||||
self._trigger_manager.register_trigger(trigger, self.system_app)
|
||||
self._trigger_manager.after_register()
|
||||
else:
|
||||
logger.warning("No trigger manager, not register dag trigger")
|
||||
|
||||
def unregister_dag(self, dag_id: str):
|
||||
"""Unregister a DAG."""
|
||||
if dag_id not in self.dag_map:
|
||||
raise ValueError(f"Unregister DAG error, DAG ID {dag_id} does not exist")
|
||||
dag = self.dag_map[dag_id]
|
||||
# Clear the alias map
|
||||
for alias_name, _dag_id in self.dag_alias_map.items():
|
||||
if _dag_id == dag_id:
|
||||
with self.lock:
|
||||
if dag_id not in self.dag_map:
|
||||
raise ValueError(
|
||||
f"Unregister DAG error, DAG ID {dag_id} does not exist"
|
||||
)
|
||||
dag = self.dag_map[dag_id]
|
||||
|
||||
# Collect aliases to remove
|
||||
# TODO(fangyinc): It can be faster if we maintain a reverse map
|
||||
aliases_to_remove = [
|
||||
alias_name
|
||||
for alias_name, _dag_id in self.dag_alias_map.items()
|
||||
if _dag_id == dag_id
|
||||
]
|
||||
# Remove collected aliases
|
||||
for alias_name in aliases_to_remove:
|
||||
del self.dag_alias_map[alias_name]
|
||||
|
||||
if self._trigger_manager:
|
||||
for trigger in dag.trigger_nodes:
|
||||
self._trigger_manager.unregister_trigger(trigger, self.system_app)
|
||||
del self.dag_map[dag_id]
|
||||
if self._trigger_manager:
|
||||
for trigger in dag.trigger_nodes:
|
||||
self._trigger_manager.unregister_trigger(trigger, self.system_app)
|
||||
# Finally remove the DAG from the map
|
||||
del self.dag_map[dag_id]
|
||||
|
||||
def get_dag(
|
||||
self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
|
||||
) -> Optional[DAG]:
|
||||
"""Get a DAG by dag_id or alias_name."""
|
||||
# Not lock, because it is read only and need to be fast
|
||||
if dag_id and dag_id in self.dag_map:
|
||||
return self.dag_map[dag_id]
|
||||
if alias_name in self.dag_alias_map:
|
||||
|
@@ -7,7 +7,13 @@ from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, ValidationError, root_validator
|
||||
from dbgpt._private.pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
ValidationError,
|
||||
model_to_dict,
|
||||
model_validator,
|
||||
)
|
||||
from dbgpt.core.awel.util.parameter_util import BaseDynamicOptions, OptionValue
|
||||
from dbgpt.core.interface.serialization import Serializable
|
||||
|
||||
@@ -281,7 +287,7 @@ class TypeMetadata(BaseModel):
|
||||
|
||||
def new(self: TM) -> TM:
|
||||
"""Copy the metadata."""
|
||||
return self.__class__(**self.dict())
|
||||
return self.__class__(**self.model_dump(exclude_defaults=True))
|
||||
|
||||
|
||||
class Parameter(TypeMetadata, Serializable):
|
||||
@@ -332,12 +338,15 @@ class Parameter(TypeMetadata, Serializable):
|
||||
None, description="The value of the parameter(Saved in the dag file)"
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata.
|
||||
|
||||
Transform the value to the real type.
|
||||
"""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
type_cls = values.get("type_cls")
|
||||
to_handle_values = {
|
||||
"value": values.get("value"),
|
||||
@@ -443,7 +452,7 @@ class Parameter(TypeMetadata, Serializable):
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert current metadata to json dict."""
|
||||
dict_value = self.dict(exclude={"options"})
|
||||
dict_value = model_to_dict(self, exclude={"options"})
|
||||
if not self.options:
|
||||
dict_value["options"] = None
|
||||
elif isinstance(self.options, BaseDynamicOptions):
|
||||
@@ -535,7 +544,7 @@ class BaseResource(Serializable, BaseModel):
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert current metadata to json dict."""
|
||||
return self.dict()
|
||||
return model_to_dict(self)
|
||||
|
||||
|
||||
class Resource(BaseResource, TypeMetadata):
|
||||
@@ -693,9 +702,12 @@ class BaseMetadata(BaseResource):
|
||||
)
|
||||
return runnable_parameters
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
if "category_label" not in values:
|
||||
category = values["category"]
|
||||
if isinstance(category, str):
|
||||
@@ -713,7 +725,7 @@ class BaseMetadata(BaseResource):
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert current metadata to json dict."""
|
||||
dict_value = self.dict(exclude={"parameters"})
|
||||
dict_value = model_to_dict(self, exclude={"parameters"})
|
||||
dict_value["parameters"] = [
|
||||
parameter.to_dict() for parameter in self.parameters
|
||||
]
|
||||
@@ -738,9 +750,12 @@ class ResourceMetadata(BaseMetadata, TypeMetadata):
|
||||
],
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
if "flow_type" not in values:
|
||||
values["flow_type"] = "resource"
|
||||
if "id" not in values:
|
||||
@@ -846,9 +861,12 @@ class ViewMetadata(BaseMetadata):
|
||||
examples=["dbgpt.model.operators.LLMOperator"],
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
if "flow_type" not in values:
|
||||
values["flow_type"] = "operator"
|
||||
if "id" not in values:
|
||||
|
@@ -6,7 +6,13 @@ from contextlib import suppress
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, root_validator, validator
|
||||
from dbgpt._private.pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
field_validator,
|
||||
model_to_dict,
|
||||
model_validator,
|
||||
)
|
||||
from dbgpt.core.awel.dag.base import DAG, DAGNode
|
||||
|
||||
from .base import (
|
||||
@@ -73,7 +79,8 @@ class FlowNodeData(BaseModel):
|
||||
..., description="Absolute position of the node"
|
||||
)
|
||||
|
||||
@validator("data", pre=True)
|
||||
@field_validator("data", mode="before")
|
||||
@classmethod
|
||||
def parse_data(cls, value: Any):
|
||||
"""Parse the data."""
|
||||
if isinstance(value, dict):
|
||||
@@ -123,9 +130,12 @@ class FlowEdgeData(BaseModel):
|
||||
examples=["buttonedge"],
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
if (
|
||||
"source_order" not in values
|
||||
and "source_handle" in values
|
||||
@@ -315,9 +325,12 @@ class FlowPanel(BaseModel):
|
||||
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
label = values.get("label")
|
||||
name = values.get("name")
|
||||
flow_category = str(values.get("flow_category", ""))
|
||||
@@ -329,6 +342,10 @@ class FlowPanel(BaseModel):
|
||||
values["name"] = name
|
||||
return values
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dict."""
|
||||
return model_to_dict(self)
|
||||
|
||||
|
||||
class FlowFactory:
|
||||
"""Flow factory."""
|
||||
|
@@ -15,7 +15,14 @@ from typing import (
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
field_is_required,
|
||||
field_outer_type,
|
||||
model_fields,
|
||||
model_to_dict,
|
||||
)
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
from ..dag.base import DAG
|
||||
@@ -61,7 +68,7 @@ class AWELHttpError(RuntimeError):
|
||||
|
||||
def _default_streaming_predict_func(body: "CommonRequestType") -> bool:
|
||||
if isinstance(body, BaseModel):
|
||||
body = body.dict()
|
||||
body = model_to_dict(body)
|
||||
elif isinstance(body, str):
|
||||
try:
|
||||
body = json.loads(body)
|
||||
@@ -254,7 +261,7 @@ class CommonLLMHttpRequestBody(BaseHttpBody):
|
||||
"or in full each time. "
|
||||
"If this parameter is not provided, the default is full return.",
|
||||
)
|
||||
enable_vis: str = Field(
|
||||
enable_vis: bool = Field(
|
||||
default=True, description="response content whether to output vis label"
|
||||
)
|
||||
extra: Optional[Dict[str, Any]] = Field(
|
||||
@@ -574,18 +581,20 @@ class HttpTrigger(Trigger):
|
||||
if isinstance(req_body_cls, type) and issubclass(
|
||||
req_body_cls, BaseModel
|
||||
):
|
||||
fields = req_body_cls.__fields__ # type: ignore
|
||||
fields = model_fields(req_body_cls) # type: ignore
|
||||
parameters = []
|
||||
for field_name, field in fields.items():
|
||||
default_value = (
|
||||
Parameter.empty if field.required else field.default
|
||||
Parameter.empty
|
||||
if field_is_required(field)
|
||||
else field.default
|
||||
)
|
||||
parameters.append(
|
||||
Parameter(
|
||||
name=field_name,
|
||||
kind=Parameter.KEYWORD_ONLY,
|
||||
default=default_value,
|
||||
annotation=field.outer_type_,
|
||||
annotation=field_outer_type(field),
|
||||
)
|
||||
)
|
||||
elif req_body_cls == Dict[str, Any] or req_body_cls == dict:
|
||||
@@ -1029,7 +1038,7 @@ class RequestBodyToDictOperator(MapOperator[CommonLLMHttpRequestBody, Dict[str,
|
||||
|
||||
async def map(self, request_body: CommonLLMHttpRequestBody) -> Dict[str, Any]:
|
||||
"""Map the request body to response body."""
|
||||
dict_value = request_body.dict()
|
||||
dict_value = model_to_dict(request_body)
|
||||
if not self._key:
|
||||
return dict_value
|
||||
else:
|
||||
@@ -1138,7 +1147,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]):
|
||||
|
||||
async def map(self, request_body: CommonLLMHttpRequestBody) -> str:
|
||||
"""Map the request body to response body."""
|
||||
dict_value = request_body.dict()
|
||||
dict_value = model_to_dict(request_body)
|
||||
if not self._key or self._key not in dict_value:
|
||||
raise ValueError(
|
||||
f"Prefix key {self._key} is not a valid key of the request body"
|
||||
|
@@ -4,7 +4,7 @@ import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, root_validator
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_validator
|
||||
from dbgpt.core.interface.serialization import Serializable
|
||||
|
||||
_DEFAULT_DYNAMIC_REGISTRY = {}
|
||||
@@ -44,9 +44,12 @@ class FunctionDynamicOptions(BaseDynamicOptions):
|
||||
"""Return the option values of the parameter."""
|
||||
return self.func()
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the function id."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
func = values.get("func")
|
||||
if func is None:
|
||||
raise ValueError(
|
||||
|
@@ -4,7 +4,7 @@ import json
|
||||
import uuid
|
||||
from typing import Any, Dict
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
@@ -64,7 +64,7 @@ class Chunk(Document):
|
||||
|
||||
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Convert Chunk to dict."""
|
||||
data = self.dict(**kwargs)
|
||||
data = model_to_dict(self, **kwargs)
|
||||
data["class_name"] = self.class_name()
|
||||
return data
|
||||
|
||||
|
@@ -10,7 +10,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt._private.pydantic import BaseModel, model_to_dict
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
from dbgpt.util import BaseParameters
|
||||
from dbgpt.util.annotations import PublicAPI
|
||||
@@ -312,7 +312,7 @@ class ModelRequest:
|
||||
if isinstance(context, dict):
|
||||
context_dict = context
|
||||
elif isinstance(context, BaseModel):
|
||||
context_dict = context.dict()
|
||||
context_dict = model_to_dict(context)
|
||||
if context_dict and "stream" not in context_dict:
|
||||
context_dict["stream"] = stream
|
||||
if context_dict:
|
||||
|
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
||||
from dbgpt.core.interface.storage import (
|
||||
InMemoryStorage,
|
||||
ResourceIdentifier,
|
||||
@@ -42,7 +42,7 @@ class BaseMessage(BaseModel, ABC):
|
||||
"""
|
||||
return {
|
||||
"type": self.type,
|
||||
"data": self.dict(),
|
||||
"data": model_to_dict(self),
|
||||
"index": self.index,
|
||||
"round_index": self.round_index,
|
||||
}
|
||||
@@ -264,7 +264,7 @@ class ModelMessage(BaseModel):
|
||||
Returns:
|
||||
List[Dict[str, str]]: The dict list
|
||||
"""
|
||||
return list(map(lambda m: m.dict(), messages))
|
||||
return list(map(lambda m: model_to_dict(m), messages))
|
||||
|
||||
@staticmethod
|
||||
def build_human_message(content: str) -> "ModelMessage":
|
||||
|
@@ -2,7 +2,7 @@
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import root_validator
|
||||
from dbgpt._private.pydantic import model_validator
|
||||
from dbgpt.core import (
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
@@ -71,9 +71,12 @@ from dbgpt.util.i18n_utils import _
|
||||
class CommonChatPromptTemplate(ChatPromptTemplate):
|
||||
"""The common chat prompt template."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the messages."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
if "system_message" not in values:
|
||||
values["system_message"] = "You are a helpful AI Assistant."
|
||||
if "human_message" not in values:
|
||||
|
@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, root_validator
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, model_validator
|
||||
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
|
||||
from dbgpt.core.interface.storage import (
|
||||
InMemoryStorage,
|
||||
@@ -51,6 +51,8 @@ class BasePromptTemplate(BaseModel):
|
||||
class PromptTemplate(BasePromptTemplate):
|
||||
"""Prompt template."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
template: str
|
||||
"""The prompt template."""
|
||||
|
||||
@@ -69,11 +71,6 @@ class PromptTemplate(BasePromptTemplate):
|
||||
template_define: Optional[str] = None
|
||||
"""this template define"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
@@ -239,9 +236,12 @@ class ChatPromptTemplate(BasePromptTemplate):
|
||||
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||
return result_messages
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre-fill the messages."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
input_variables = values.get("input_variables", {})
|
||||
messages = values.get("messages", [])
|
||||
if not input_variables:
|
||||
|
@@ -9,7 +9,7 @@ from typing import Dict, Optional, Type
|
||||
class Serializable(ABC):
|
||||
"""The serializable abstract class."""
|
||||
|
||||
serializer: Optional["Serializer"] = None
|
||||
_serializer: Optional["Serializer"] = None
|
||||
|
||||
@abstractmethod
|
||||
def to_dict(self) -> Dict:
|
||||
@@ -21,11 +21,12 @@ class Serializable(ABC):
|
||||
Returns:
|
||||
bytes: The byte array after serialization
|
||||
"""
|
||||
if self.serializer is None:
|
||||
if self._serializer is None:
|
||||
raise ValueError(
|
||||
"Serializer is not set. Please set the serializer before serialization."
|
||||
"Serializer is not set. Please set the serializer before "
|
||||
"serialization."
|
||||
)
|
||||
return self.serializer.serialize(self)
|
||||
return self._serializer.serialize(self)
|
||||
|
||||
def set_serializer(self, serializer: "Serializer") -> None:
|
||||
"""Set the serializer for current serializable object.
|
||||
@@ -33,7 +34,7 @@ class Serializable(ABC):
|
||||
Args:
|
||||
serializer (Serializer): The serializer to set
|
||||
"""
|
||||
self.serializer = serializer
|
||||
self._serializer = serializer
|
||||
|
||||
|
||||
class Serializer(ABC):
|
||||
|
@@ -426,7 +426,7 @@ class InMemoryStorage(StorageInterface[T, T]):
|
||||
"""
|
||||
if not data:
|
||||
raise StorageError("Data cannot be None")
|
||||
if not data.serializer:
|
||||
if not data._serializer:
|
||||
data.set_serializer(self.serializer)
|
||||
|
||||
if data.identifier.str_identifier in self._data:
|
||||
@@ -439,7 +439,7 @@ class InMemoryStorage(StorageInterface[T, T]):
|
||||
"""Update the data to the storage."""
|
||||
if not data:
|
||||
raise StorageError("Data cannot be None")
|
||||
if not data.serializer:
|
||||
if not data._serializer:
|
||||
data.set_serializer(self.serializer)
|
||||
self._data[data.identifier.str_identifier] = data.serialize()
|
||||
|
||||
|
@@ -2,9 +2,10 @@
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Generic, List, Literal, Optional, TypeVar
|
||||
from enum import IntEnum
|
||||
from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -41,6 +42,28 @@ class Result(BaseModel, Generic[T]):
|
||||
"""
|
||||
return Result(success=False, err_code=err_code, err_msg=msg, data=None)
|
||||
|
||||
def to_dict(self, **kwargs) -> Dict[str, Any]:
|
||||
"""Convert to dict."""
|
||||
return model_to_dict(self, **kwargs)
|
||||
|
||||
|
||||
class APIChatCompletionRequest(BaseModel):
|
||||
"""Chat completion request entity."""
|
||||
|
||||
model: str = Field(..., description="Model name")
|
||||
messages: Union[str, List[Dict[str, str]]] = Field(..., description="Messages")
|
||||
temperature: Optional[float] = Field(0.7, description="Temperature")
|
||||
top_p: Optional[float] = Field(1.0, description="Top p")
|
||||
top_k: Optional[int] = Field(-1, description="Top k")
|
||||
n: Optional[int] = Field(1, description="Number of completions")
|
||||
max_tokens: Optional[int] = Field(None, description="Max tokens")
|
||||
stop: Optional[Union[str, List[str]]] = Field(None, description="Stop")
|
||||
stream: Optional[bool] = Field(False, description="Stream")
|
||||
user: Optional[str] = Field(None, description="User")
|
||||
repetition_penalty: Optional[float] = Field(1.0, description="Repetition penalty")
|
||||
frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty")
|
||||
presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
"""Delta message entity for chat completion response."""
|
||||
@@ -122,3 +145,97 @@ class ErrorResponse(BaseModel):
|
||||
object: str = Field("error", description="Object type")
|
||||
message: str = Field(..., description="Error message")
|
||||
code: int = Field(..., description="Error code")
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
"""Embeddings request entity."""
|
||||
|
||||
model: Optional[str] = Field(None, description="Model name")
|
||||
engine: Optional[str] = Field(None, description="Engine name")
|
||||
input: Union[str, List[Any]] = Field(..., description="Input data")
|
||||
user: Optional[str] = Field(None, description="User name")
|
||||
encoding_format: Optional[str] = Field(None, description="Encoding format")
|
||||
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
"""Embeddings response entity."""
|
||||
|
||||
object: str = Field("list", description="Object type")
|
||||
data: List[Dict[str, Any]] = Field(..., description="Data list")
|
||||
model: str = Field(..., description="Model name")
|
||||
usage: UsageInfo = Field(..., description="Usage info")
|
||||
|
||||
|
||||
class ModelPermission(BaseModel):
|
||||
"""Model permission entity."""
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: f"modelperm-{str(uuid.uuid1())}",
|
||||
description="Permission ID",
|
||||
)
|
||||
object: str = Field("model_permission", description="Object type")
|
||||
created: int = Field(
|
||||
default_factory=lambda: int(time.time()), description="Created time"
|
||||
)
|
||||
allow_create_engine: bool = Field(False, description="Allow create engine")
|
||||
allow_sampling: bool = Field(True, description="Allow sampling")
|
||||
allow_logprobs: bool = Field(True, description="Allow logprobs")
|
||||
allow_search_indices: bool = Field(True, description="Allow search indices")
|
||||
allow_view: bool = Field(True, description="Allow view")
|
||||
allow_fine_tuning: bool = Field(False, description="Allow fine tuning")
|
||||
organization: str = Field("*", description="Organization")
|
||||
group: Optional[str] = Field(None, description="Group")
|
||||
is_blocking: bool = Field(False, description="Is blocking")
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
"""Model card entity."""
|
||||
|
||||
id: str = Field(..., description="Model ID")
|
||||
object: str = Field("model", description="Object type")
|
||||
created: int = Field(
|
||||
default_factory=lambda: int(time.time()), description="Created time"
|
||||
)
|
||||
owned_by: str = Field("DB-GPT", description="Owned by")
|
||||
root: Optional[str] = Field(None, description="Root")
|
||||
parent: Optional[str] = Field(None, description="Parent")
|
||||
permission: List[ModelPermission] = Field(
|
||||
default_factory=list, description="Permission"
|
||||
)
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
"""Model list entity."""
|
||||
|
||||
object: str = Field("list", description="Object type")
|
||||
data: List[ModelCard] = Field(default_factory=list, description="Model list data")
|
||||
|
||||
|
||||
class ErrorCode(IntEnum):
|
||||
"""Error code enumeration.
|
||||
|
||||
https://platform.openai.com/docs/guides/error-codes/api-errors.
|
||||
|
||||
Adapted from fastchat.constants.
|
||||
"""
|
||||
|
||||
VALIDATION_TYPE_ERROR = 40001
|
||||
|
||||
INVALID_AUTH_KEY = 40101
|
||||
INCORRECT_AUTH_KEY = 40102
|
||||
NO_PERMISSION = 40103
|
||||
|
||||
INVALID_MODEL = 40301
|
||||
PARAM_OUT_OF_RANGE = 40302
|
||||
CONTEXT_OVERFLOW = 40303
|
||||
|
||||
RATE_LIMIT = 42901
|
||||
QUOTA_EXCEEDED = 42902
|
||||
ENGINE_OVERLOADED = 42903
|
||||
|
||||
INTERNAL_ERROR = 50001
|
||||
CUDA_OUT_OF_MEMORY = 50002
|
||||
GRADIO_REQUEST_ERROR = 50003
|
||||
GRADIO_STREAM_UNKNOWN_ERROR = 50004
|
||||
CONTROLLER_NO_WORKER = 50005
|
||||
CONTROLLER_WORKER_TIMEOUT = 50006
|
||||
|
Reference in New Issue
Block a user