Merge remote-tracking branch 'origin/main' into feat_llm_manage

This commit is contained in:
aries_ckt 2023-09-26 11:16:32 +08:00
commit 8c4accb09e
10 changed files with 88 additions and 28 deletions

View File

@ -46,6 +46,7 @@ class ComponentType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller"
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
class BaseComponent(LifeCycle, ABC):

View File

@ -141,6 +141,7 @@ class Config(metaclass=Singleton):
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
self.LOCAL_DB_MANAGE = None

View File

@ -4,7 +4,9 @@ import asyncio
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.component import SystemApp
from pilot.component import SystemApp, ComponentType
from pilot.utils.executor_utils import ExecutorFactory
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect
@ -76,7 +78,11 @@ class ConnectManager:
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT),
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
engine_args={
"pool_size": CFG.LOCAL_DB_POOL_SIZE,
"pool_recycle": 3600,
"echo": True,
},
)
# default_mysql = MySQLConnect.from_uri(
# "mysql+pymysql://"
@ -208,13 +214,15 @@ class ConnectManager:
db_info.comment,
)
# async embedding
thread = threading.Thread(
target=self.db_summary_client.db_summary_embedding(
db_info.db_name, db_info.db_type
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
executor.submit(
self.db_summary_client.db_summary_embedding,
db_info.db_name,
db_info.db_type,
)
)
thread.start()
except Exception as e:
raise ValueError("Add db connect info error" + str(e))
raise ValueError("Add db connect info error!" + str(e))
return True

View File

@ -1,17 +1,18 @@
import asyncio
import itertools
import json
import os
import sys
import random
import time
import logging
import os
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse
from pilot.component import SystemApp
from pilot.model.base import (
ModelInstance,
@ -20,16 +21,16 @@ from pilot.model.base import (
WorkerApplyType,
WorkerSupportedModel,
)
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.llm_utils import list_supported_models
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.base import *
from pilot.model.cluster.manager_base import (
WorkerManager,
WorkerRunData,
WorkerManagerFactory,
WorkerRunData,
)
from pilot.model.cluster.base import *
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.llm_utils import list_supported_models
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.utils.parameter_utils import (
EnvArgumentParser,
ParameterDescription,
@ -639,6 +640,10 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
)
if not worker_params.controller_addr:
# if we have http_proxy or https_proxy in env, the server can not start
# so set it to empty here
os.environ["http_proxy"] = ""
os.environ["https_proxy"] = ""
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
logger.info(
f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}"

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional
@ -246,6 +247,11 @@ class ProxyModelParameters(BaseModelParameters):
proxy_api_key: str = field(
metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
)
http_proxy: Optional[str] = field(
default=os.environ.get("http_proxy") or os.environ.get("https_proxy"),
metadata={"help": "The http or https proxy to use openai"},
)
proxyllm_backend: Optional[str] = field(
default=None,
metadata={

View File

@ -4,6 +4,7 @@ import logging
from typing import TYPE_CHECKING, Any, Type
from pilot.component import ComponentType, SystemApp
from pilot.utils.executor_utils import DefaultExecutorFactory
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.base import WebWerverParameters
@ -22,6 +23,9 @@ def initialize_components(
):
from pilot.model.cluster.controller.controller import controller
# Register global default executor factory first
system_app.register(DefaultExecutorFactory)
system_app.register_instance(controller)
_initialize_embedding_model(

View File

@ -124,12 +124,13 @@ def knowledge_init(
def upload(filename: str):
try:
logger.info(f"Begin upload document: {filename} to {space.name}")
return client.document_upload(
doc_id = client.document_upload(
space.name, filename, KnowledgeType.DOCUMENT.value, filename
)
client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id]))
except Exception as e:
if skip_wrong_doc:
logger.warn(f"Warning: {str(e)}")
logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
else:
raise e
@ -144,5 +145,3 @@ def knowledge_init(
if not doc_ids:
logger.warn("Warning: no document to sync")
return
logger.info(f"Begin sync document: {doc_ids}")
client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))

View File

@ -9,6 +9,9 @@ from pilot.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from pilot.component import ComponentType
from pilot.utils.executor_utils import ExecutorFactory
from pilot.logs import logger
from pilot.server.knowledge.chunk_db import (
DocumentChunkEntity,
@ -227,10 +230,15 @@ class KnowledgeService:
doc.gmt_modified = datetime.now()
knowledge_document_dao.update_knowledge_document(doc)
# async doc embeddings
thread = threading.Thread(
target=self.async_doc_embedding, args=(client, chunk_docs, doc)
)
thread.start()
# thread = threading.Thread(
# target=self.async_doc_embedding, args=(client, chunk_docs, doc)
# )
# thread.start()
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details
chunk_entities = [

View File

@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor
from pilot.component import BaseComponent, ComponentType, SystemApp
class ExecutorFactory(BaseComponent, ABC):
    """Abstract component that supplies a ``concurrent.futures.Executor``.

    Concrete factories implement :meth:`create`; the component itself is
    identified by the ``ComponentType.EXECUTOR_DEFAULT`` name.
    """

    # Component name: the string value of ComponentType.EXECUTOR_DEFAULT.
    name = ComponentType.EXECUTOR_DEFAULT.value

    @abstractmethod
    def create(self) -> "Executor":
        """Return the executor that callers should submit work to."""
class DefaultExecutorFactory(ExecutorFactory):
    """Executor factory backed by a single ``ThreadPoolExecutor``.

    The pool is built once in ``__init__`` and the same instance is handed
    back by every :meth:`create` call.
    """

    # The annotation is quoted so it is not evaluated when the function is
    # defined: ``SystemApp | None`` on plain classes raises TypeError on
    # Python < 3.10 (PEP 604).
    def __init__(self, system_app: "SystemApp | None" = None, max_workers=None):
        """Build the thread pool.

        Args:
            system_app: Optional application object forwarded to the base
                component initializer.
            max_workers: Passed straight through to ``ThreadPoolExecutor``;
                ``None`` leaves the pool size to the standard-library default.
        """
        super().__init__(system_app)
        # Worker threads are named after the component for easier debugging.
        self._executor = ThreadPoolExecutor(
            max_workers=max_workers, thread_name_prefix=self.name
        )

    def init_app(self, system_app: SystemApp):
        """No-op: this factory does nothing when bound to the application."""
        pass

    def create(self) -> Executor:
        """Return the thread pool created in ``__init__`` (not a new pool)."""
        return self._executor

View File

@ -8,7 +8,6 @@ from enum import Enum
import urllib.request
from urllib.parse import urlparse, quote
import re
from pip._internal.utils.appdirs import user_cache_dir
import shutil
from setuptools import find_packages
@ -67,6 +66,9 @@ def cache_package(package_url: str, package_name: str, is_windows: bool = False)
safe_url, parsed_url = encode_url(package_url)
if BUILD_NO_CACHE:
return safe_url
from pip._internal.utils.appdirs import user_cache_dir
filename = os.path.basename(parsed_url)
cache_dir = os.path.join(user_cache_dir("pip"), "http", "wheels", package_name)
os.makedirs(cache_dir, exist_ok=True)
@ -279,7 +281,7 @@ def core_requires():
"importlib-resources==5.12.0",
"psutil==5.9.4",
"python-dotenv==1.0.0",
"colorama",
"colorama==0.4.6",
"prettytable",
"cachetools",
]