mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-09 04:08:10 +00:00
Merge remote-tracking branch 'origin/main' into feat_llm_manage
This commit is contained in:
commit
8c4accb09e
@ -46,6 +46,7 @@ class ComponentType(str, Enum):
|
|||||||
WORKER_MANAGER = "dbgpt_worker_manager"
|
WORKER_MANAGER = "dbgpt_worker_manager"
|
||||||
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
|
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
|
||||||
MODEL_CONTROLLER = "dbgpt_model_controller"
|
MODEL_CONTROLLER = "dbgpt_model_controller"
|
||||||
|
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
|
||||||
|
|
||||||
|
|
||||||
class BaseComponent(LifeCycle, ABC):
|
class BaseComponent(LifeCycle, ABC):
|
||||||
|
@ -141,6 +141,7 @@ class Config(metaclass=Singleton):
|
|||||||
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
|
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
|
||||||
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
|
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
|
||||||
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
|
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
|
||||||
|
self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
|
||||||
|
|
||||||
self.LOCAL_DB_MANAGE = None
|
self.LOCAL_DB_MANAGE = None
|
||||||
|
|
||||||
|
@ -4,7 +4,9 @@ import asyncio
|
|||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
|
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
|
||||||
from pilot.common.schema import DBType
|
from pilot.common.schema import DBType
|
||||||
from pilot.component import SystemApp
|
from pilot.component import SystemApp, ComponentType
|
||||||
|
from pilot.utils.executor_utils import ExecutorFactory
|
||||||
|
|
||||||
from pilot.connections.rdbms.conn_mysql import MySQLConnect
|
from pilot.connections.rdbms.conn_mysql import MySQLConnect
|
||||||
from pilot.connections.base import BaseConnect
|
from pilot.connections.base import BaseConnect
|
||||||
|
|
||||||
@ -76,7 +78,11 @@ class ConnectManager:
|
|||||||
+ CFG.LOCAL_DB_HOST
|
+ CFG.LOCAL_DB_HOST
|
||||||
+ ":"
|
+ ":"
|
||||||
+ str(CFG.LOCAL_DB_PORT),
|
+ str(CFG.LOCAL_DB_PORT),
|
||||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
engine_args={
|
||||||
|
"pool_size": CFG.LOCAL_DB_POOL_SIZE,
|
||||||
|
"pool_recycle": 3600,
|
||||||
|
"echo": True,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
# default_mysql = MySQLConnect.from_uri(
|
# default_mysql = MySQLConnect.from_uri(
|
||||||
# "mysql+pymysql://"
|
# "mysql+pymysql://"
|
||||||
@ -208,13 +214,15 @@ class ConnectManager:
|
|||||||
db_info.comment,
|
db_info.comment,
|
||||||
)
|
)
|
||||||
# async embedding
|
# async embedding
|
||||||
thread = threading.Thread(
|
executor = CFG.SYSTEM_APP.get_component(
|
||||||
target=self.db_summary_client.db_summary_embedding(
|
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||||
db_info.db_name, db_info.db_type
|
).create()
|
||||||
)
|
executor.submit(
|
||||||
|
self.db_summary_client.db_summary_embedding,
|
||||||
|
db_info.db_name,
|
||||||
|
db_info.db_type,
|
||||||
)
|
)
|
||||||
thread.start()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError("Add db connect info error!" + str(e))
|
raise ValueError("Add db connect info error!" + str(e))
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -1,17 +1,18 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import APIRouter, FastAPI
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from pilot.component import SystemApp
|
from pilot.component import SystemApp
|
||||||
from pilot.model.base import (
|
from pilot.model.base import (
|
||||||
ModelInstance,
|
ModelInstance,
|
||||||
@ -20,16 +21,16 @@ from pilot.model.base import (
|
|||||||
WorkerApplyType,
|
WorkerApplyType,
|
||||||
WorkerSupportedModel,
|
WorkerSupportedModel,
|
||||||
)
|
)
|
||||||
from pilot.model.cluster.registry import ModelRegistry
|
from pilot.model.cluster.base import *
|
||||||
from pilot.model.llm_utils import list_supported_models
|
|
||||||
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
|
||||||
from pilot.model.cluster.worker_base import ModelWorker
|
|
||||||
from pilot.model.cluster.manager_base import (
|
from pilot.model.cluster.manager_base import (
|
||||||
WorkerManager,
|
WorkerManager,
|
||||||
WorkerRunData,
|
|
||||||
WorkerManagerFactory,
|
WorkerManagerFactory,
|
||||||
|
WorkerRunData,
|
||||||
)
|
)
|
||||||
from pilot.model.cluster.base import *
|
from pilot.model.cluster.registry import ModelRegistry
|
||||||
|
from pilot.model.cluster.worker_base import ModelWorker
|
||||||
|
from pilot.model.llm_utils import list_supported_models
|
||||||
|
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
||||||
from pilot.utils.parameter_utils import (
|
from pilot.utils.parameter_utils import (
|
||||||
EnvArgumentParser,
|
EnvArgumentParser,
|
||||||
ParameterDescription,
|
ParameterDescription,
|
||||||
@ -639,6 +640,10 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not worker_params.controller_addr:
|
if not worker_params.controller_addr:
|
||||||
|
# if we have http_proxy or https_proxy in env, the server can not start
|
||||||
|
# so set it to empty here
|
||||||
|
os.environ["http_proxy"] = ""
|
||||||
|
os.environ["https_proxy"] = ""
|
||||||
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
|
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}"
|
f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}"
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
@ -246,6 +247,11 @@ class ProxyModelParameters(BaseModelParameters):
|
|||||||
proxy_api_key: str = field(
|
proxy_api_key: str = field(
|
||||||
metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
|
metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
|
||||||
)
|
)
|
||||||
|
http_proxy: Optional[str] = field(
|
||||||
|
default=os.environ.get("http_proxy") or os.environ.get("https_proxy"),
|
||||||
|
metadata={"help": "The http or https proxy to use openai"},
|
||||||
|
)
|
||||||
|
|
||||||
proxyllm_backend: Optional[str] = field(
|
proxyllm_backend: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
|
@ -4,6 +4,7 @@ import logging
|
|||||||
from typing import TYPE_CHECKING, Any, Type
|
from typing import TYPE_CHECKING, Any, Type
|
||||||
|
|
||||||
from pilot.component import ComponentType, SystemApp
|
from pilot.component import ComponentType, SystemApp
|
||||||
|
from pilot.utils.executor_utils import DefaultExecutorFactory
|
||||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||||
from pilot.server.base import WebWerverParameters
|
from pilot.server.base import WebWerverParameters
|
||||||
|
|
||||||
@ -22,6 +23,9 @@ def initialize_components(
|
|||||||
):
|
):
|
||||||
from pilot.model.cluster.controller.controller import controller
|
from pilot.model.cluster.controller.controller import controller
|
||||||
|
|
||||||
|
# Register global default executor factory first
|
||||||
|
system_app.register(DefaultExecutorFactory)
|
||||||
|
|
||||||
system_app.register_instance(controller)
|
system_app.register_instance(controller)
|
||||||
|
|
||||||
_initialize_embedding_model(
|
_initialize_embedding_model(
|
||||||
|
@ -124,12 +124,13 @@ def knowledge_init(
|
|||||||
def upload(filename: str):
|
def upload(filename: str):
|
||||||
try:
|
try:
|
||||||
logger.info(f"Begin upload document: {filename} to {space.name}")
|
logger.info(f"Begin upload document: {filename} to {space.name}")
|
||||||
return client.document_upload(
|
doc_id = client.document_upload(
|
||||||
space.name, filename, KnowledgeType.DOCUMENT.value, filename
|
space.name, filename, KnowledgeType.DOCUMENT.value, filename
|
||||||
)
|
)
|
||||||
|
client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id]))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if skip_wrong_doc:
|
if skip_wrong_doc:
|
||||||
logger.warn(f"Warning: {str(e)}")
|
logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@ -144,5 +145,3 @@ def knowledge_init(
|
|||||||
if not doc_ids:
|
if not doc_ids:
|
||||||
logger.warn("Warning: no document to sync")
|
logger.warn("Warning: no document to sync")
|
||||||
return
|
return
|
||||||
logger.info(f"Begin sync document: {doc_ids}")
|
|
||||||
client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))
|
|
||||||
|
@ -9,6 +9,9 @@ from pilot.configs.model_config import (
|
|||||||
EMBEDDING_MODEL_CONFIG,
|
EMBEDDING_MODEL_CONFIG,
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
)
|
)
|
||||||
|
from pilot.component import ComponentType
|
||||||
|
from pilot.utils.executor_utils import ExecutorFactory
|
||||||
|
|
||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
from pilot.server.knowledge.chunk_db import (
|
from pilot.server.knowledge.chunk_db import (
|
||||||
DocumentChunkEntity,
|
DocumentChunkEntity,
|
||||||
@ -227,10 +230,15 @@ class KnowledgeService:
|
|||||||
doc.gmt_modified = datetime.now()
|
doc.gmt_modified = datetime.now()
|
||||||
knowledge_document_dao.update_knowledge_document(doc)
|
knowledge_document_dao.update_knowledge_document(doc)
|
||||||
# async doc embeddings
|
# async doc embeddings
|
||||||
thread = threading.Thread(
|
# thread = threading.Thread(
|
||||||
target=self.async_doc_embedding, args=(client, chunk_docs, doc)
|
# target=self.async_doc_embedding, args=(client, chunk_docs, doc)
|
||||||
)
|
# )
|
||||||
thread.start()
|
# thread.start()
|
||||||
|
executor = CFG.SYSTEM_APP.get_component(
|
||||||
|
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||||
|
).create()
|
||||||
|
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
|
||||||
|
|
||||||
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
||||||
# save chunk details
|
# save chunk details
|
||||||
chunk_entities = [
|
chunk_entities = [
|
||||||
|
26
pilot/utils/executor_utils.py
Normal file
26
pilot/utils/executor_utils.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||||
|
|
||||||
|
from pilot.component import BaseComponent, ComponentType, SystemApp
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutorFactory(BaseComponent, ABC):
    """Abstract component that hands out a ``concurrent.futures.Executor``.

    The component ``name`` is the value of
    ``ComponentType.EXECUTOR_DEFAULT``; concrete subclasses decide which
    executor implementation ``create()`` returns.
    """

    name = ComponentType.EXECUTOR_DEFAULT.value

    @abstractmethod
    def create(self) -> "Executor":
        """Return the executor that callers should submit work to."""
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultExecutorFactory(ExecutorFactory):
    """Executor factory backed by a single ``ThreadPoolExecutor``.

    The pool is built once in ``__init__`` and the same instance is returned
    by every ``create()`` call, so all callers share one set of worker
    threads.
    """

    # The annotation is quoted so it is not evaluated when the ``def`` runs:
    # an eager ``SystemApp | None`` raises TypeError on Python < 3.10.
    def __init__(self, system_app: "SystemApp | None" = None, max_workers=None):
        """Build the shared thread pool.

        Args:
            system_app: Forwarded unchanged to the base component
                initializer; may be ``None``.
            max_workers: Passed straight to ``ThreadPoolExecutor``; ``None``
                lets the executor pick its own default size.
        """
        super().__init__(system_app)
        # Worker threads are named after the component for easier debugging.
        self._executor = ThreadPoolExecutor(
            max_workers=max_workers, thread_name_prefix=self.name
        )

    def init_app(self, system_app: SystemApp):
        """Do nothing; this factory needs no per-application setup."""
        pass

    def create(self) -> Executor:
        """Return the shared thread pool (no new executor is created)."""
        return self._executor
|
6
setup.py
6
setup.py
@ -8,7 +8,6 @@ from enum import Enum
|
|||||||
import urllib.request
|
import urllib.request
|
||||||
from urllib.parse import urlparse, quote
|
from urllib.parse import urlparse, quote
|
||||||
import re
|
import re
|
||||||
from pip._internal.utils.appdirs import user_cache_dir
|
|
||||||
import shutil
|
import shutil
|
||||||
from setuptools import find_packages
|
from setuptools import find_packages
|
||||||
|
|
||||||
@ -67,6 +66,9 @@ def cache_package(package_url: str, package_name: str, is_windows: bool = False)
|
|||||||
safe_url, parsed_url = encode_url(package_url)
|
safe_url, parsed_url = encode_url(package_url)
|
||||||
if BUILD_NO_CACHE:
|
if BUILD_NO_CACHE:
|
||||||
return safe_url
|
return safe_url
|
||||||
|
|
||||||
|
from pip._internal.utils.appdirs import user_cache_dir
|
||||||
|
|
||||||
filename = os.path.basename(parsed_url)
|
filename = os.path.basename(parsed_url)
|
||||||
cache_dir = os.path.join(user_cache_dir("pip"), "http", "wheels", package_name)
|
cache_dir = os.path.join(user_cache_dir("pip"), "http", "wheels", package_name)
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
@ -279,7 +281,7 @@ def core_requires():
|
|||||||
"importlib-resources==5.12.0",
|
"importlib-resources==5.12.0",
|
||||||
"psutil==5.9.4",
|
"psutil==5.9.4",
|
||||||
"python-dotenv==1.0.0",
|
"python-dotenv==1.0.0",
|
||||||
"colorama",
|
"colorama==0.4.10",
|
||||||
"prettytable",
|
"prettytable",
|
||||||
"cachetools",
|
"cachetools",
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user