mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00
fix(model): Fix remote embedding model error in some case (#587)
This commit is contained in:
commit
465c065ef0
@ -152,6 +152,7 @@ class SystemApp(LifeCycle):
|
||||
@self.app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""ASGI app startup event handler."""
|
||||
# TODO catch exception and shutdown if worker manager start failed
|
||||
asyncio.create_task(self.async_after_start())
|
||||
self.after_start()
|
||||
|
||||
|
@ -18,9 +18,11 @@ class EmbeddingFactory(BaseComponet, ABC):
|
||||
|
||||
|
||||
class DefaultEmbeddingFactory(EmbeddingFactory):
|
||||
def __init__(self, system_app=None, model_name: str = None, **kwargs: Any) -> None:
|
||||
def __init__(
|
||||
self, system_app=None, default_model_name: str = None, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(system_app=system_app)
|
||||
self._default_model_name = model_name
|
||||
self._default_model_name = default_model_name
|
||||
self.kwargs = kwargs
|
||||
|
||||
def init_app(self, system_app):
|
||||
@ -31,9 +33,13 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
|
||||
) -> "Embeddings":
|
||||
if not model_name:
|
||||
model_name = self._default_model_name
|
||||
|
||||
new_kwargs = {k: v for k, v in self.kwargs.items()}
|
||||
new_kwargs["model_name"] = model_name
|
||||
|
||||
if embedding_cls:
|
||||
return embedding_cls(model_name=model_name, **self.kwargs)
|
||||
return embedding_cls(**new_kwargs)
|
||||
else:
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
return HuggingFaceEmbeddings(model_name=model_name, **self.kwargs)
|
||||
return HuggingFaceEmbeddings(**new_kwargs)
|
||||
|
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Dict, List, Type
|
||||
from typing import Dict, List, Type, Optional
|
||||
|
||||
from pilot.configs.model_config import get_device
|
||||
from pilot.model.loader import _get_model_real_path
|
||||
@ -45,21 +45,12 @@ class EmbeddingsModelWorker(ModelWorker):
|
||||
self, command_args: List[str] = None
|
||||
) -> EmbeddingModelParameters:
|
||||
param_cls = self.model_param_class()
|
||||
model_args = EnvArgumentParser()
|
||||
env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
|
||||
model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
|
||||
param_cls,
|
||||
env_prefix=env_prefix,
|
||||
command_args=command_args,
|
||||
return _parse_embedding_params(
|
||||
model_name=self.model_name,
|
||||
model_path=self.model_path,
|
||||
command_args=command_args,
|
||||
param_cls=param_cls,
|
||||
)
|
||||
if not model_params.device:
|
||||
model_params.device = get_device()
|
||||
logger.info(
|
||||
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
|
||||
)
|
||||
return model_params
|
||||
|
||||
def start(
|
||||
self,
|
||||
@ -100,3 +91,26 @@ class EmbeddingsModelWorker(ModelWorker):
|
||||
logger.info(f"Receive embeddings request, model: {model}")
|
||||
input: List[str] = params["input"]
|
||||
return self._embeddings_impl.embed_documents(input)
|
||||
|
||||
|
||||
def _parse_embedding_params(
    model_name: str,
    model_path: str,
    command_args: List[str] = None,
    param_cls: Optional[Type] = EmbeddingModelParameters,
):
    """Parse embedding model parameters into a ``param_cls`` instance.

    Args:
        model_name: Embedding model name; also passed to
            ``EnvArgumentParser.get_env_prefix`` to derive the env prefix.
        model_path: Embedding model path, forwarded to the parser.
        command_args: Optional command-line arguments handed to the parser.
        param_cls: Parameter class the parser should populate.

    Returns:
        The parsed parameters. If the parsed ``device`` is empty, it is
        replaced with the value returned by ``get_device()``.
    """
    prefix = EnvArgumentParser.get_env_prefix(model_name)
    parser = EnvArgumentParser()
    model_params: EmbeddingModelParameters = parser.parse_args_into_dataclass(
        param_cls,
        env_prefix=prefix,
        command_args=command_args,
        model_name=model_name,
        model_path=model_path,
    )

    # Device already chosen explicitly: nothing to fill in.
    if model_params.device:
        return model_params

    model_params.device = get_device()
    logger.info(
        f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
    )
    return model_params
|
||||
|
@ -635,6 +635,7 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
# TODO catch exception and shutdown if worker manager start failed
|
||||
asyncio.create_task(worker_manager.start())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
|
@ -102,6 +102,12 @@ class WebWerverParameters(BaseParameters):
|
||||
"help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. "
|
||||
},
|
||||
)
|
||||
remote_embedding: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to enable remote embedding models. If it is True, you need to start a embedding model through `dbgpt start worker --worker_type text2vec --model_name xxx --model_path xxx`"
|
||||
},
|
||||
)
|
||||
log_level: Optional[str] = field(
|
||||
default="INFO",
|
||||
metadata={
|
||||
|
@ -1,21 +1,65 @@
|
||||
from typing import Any, Type, TYPE_CHECKING
|
||||
|
||||
from pilot.componet import SystemApp
|
||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
import logging
|
||||
from pilot.configs.model_config import get_device
|
||||
from pilot.embedding_engine.embedding_factory import (
|
||||
EmbeddingFactory,
|
||||
DefaultEmbeddingFactory,
|
||||
)
|
||||
from pilot.server.base import WebWerverParameters
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
def initialize_componets(system_app: SystemApp, embedding_model_name: str):
|
||||
from pilot.model.cluster import worker_manager
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_componets(
|
||||
param: WebWerverParameters,
|
||||
system_app: SystemApp,
|
||||
embedding_model_name: str,
|
||||
embedding_model_path: str,
|
||||
):
|
||||
from pilot.model.cluster.controller.controller import controller
|
||||
|
||||
system_app.register(
|
||||
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
|
||||
)
|
||||
system_app.register_instance(controller)
|
||||
|
||||
_initialize_embedding_model(
|
||||
param, system_app, embedding_model_name, embedding_model_path
|
||||
)
|
||||
|
||||
|
||||
def _initialize_embedding_model(
    param: WebWerverParameters,
    system_app: SystemApp,
    embedding_model_name: str,
    embedding_model_path: str,
):
    """Register the embedding factory component on ``system_app``.

    When ``param.remote_embedding`` is truthy, a ``RemoteEmbeddingFactory``
    is registered with the worker manager and the embedding model name.
    Otherwise embedding parameters are parsed locally and a
    ``DefaultEmbeddingFactory`` is registered with the resulting kwargs,
    using the model path as its default model name.

    Args:
        param: Web server parameters; only ``remote_embedding`` is read here.
        system_app: Application object the factory is registered on.
        embedding_model_name: Name of the embedding model.
        embedding_model_path: Path of the embedding model.
    """
    from pilot.model.cluster import worker_manager

    if not param.remote_embedding:
        from pilot.model.parameter import EmbeddingModelParameters
        from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params

        model_params: EmbeddingModelParameters = _parse_embedding_params(
            model_name=embedding_model_name,
            model_path=embedding_model_path,
            param_cls=EmbeddingModelParameters,
        )
        # The local factory is keyed by the model path, not the short name.
        kwargs = model_params.build_kwargs(model_name=embedding_model_path)
        logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}")
        system_app.register(
            DefaultEmbeddingFactory, default_model_name=embedding_model_path, **kwargs
        )
        return

    logger.info("Register remote RemoteEmbeddingFactory")
    system_app.register(
        RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
    )
|
||||
|
||||
|
||||
class RemoteEmbeddingFactory(EmbeddingFactory):
|
||||
def __init__(
|
||||
|
@ -109,20 +109,27 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
||||
# Before start
|
||||
system_app.before_start()
|
||||
|
||||
print(param)
|
||||
|
||||
embedding_model_name = CFG.EMBEDDING_MODEL
|
||||
embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||
|
||||
server_init(param, system_app)
|
||||
model_start_listener = _create_model_start_listener(system_app)
|
||||
initialize_componets(system_app, CFG.EMBEDDING_MODEL)
|
||||
initialize_componets(param, system_app, embedding_model_name, embedding_model_path)
|
||||
|
||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||
if not param.light:
|
||||
print("Model Unified Deployment Mode!")
|
||||
if not param.remote_embedding:
|
||||
embedding_model_name, embedding_model_path = None, None
|
||||
initialize_worker_manager_in_client(
|
||||
app=app,
|
||||
model_name=CFG.LLM_MODEL,
|
||||
model_path=model_path,
|
||||
local_port=param.port,
|
||||
embedding_model_name=CFG.EMBEDDING_MODEL,
|
||||
embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
embedding_model_name=embedding_model_name,
|
||||
embedding_model_path=embedding_model_path,
|
||||
start_listener=model_start_listener,
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user