mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-05 18:33:52 +00:00
fix(model): Fix remote embedding model error in some case (#587)
This commit is contained in:
commit
465c065ef0
@ -152,6 +152,7 @@ class SystemApp(LifeCycle):
|
|||||||
@self.app.on_event("startup")
|
@self.app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
"""ASGI app startup event handler."""
|
"""ASGI app startup event handler."""
|
||||||
|
# TODO catch exception and shutdown if worker manager start failed
|
||||||
asyncio.create_task(self.async_after_start())
|
asyncio.create_task(self.async_after_start())
|
||||||
self.after_start()
|
self.after_start()
|
||||||
|
|
||||||
|
@ -18,9 +18,11 @@ class EmbeddingFactory(BaseComponet, ABC):
|
|||||||
|
|
||||||
|
|
||||||
class DefaultEmbeddingFactory(EmbeddingFactory):
|
class DefaultEmbeddingFactory(EmbeddingFactory):
|
||||||
def __init__(self, system_app=None, model_name: str = None, **kwargs: Any) -> None:
|
def __init__(
|
||||||
|
self, system_app=None, default_model_name: str = None, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
super().__init__(system_app=system_app)
|
super().__init__(system_app=system_app)
|
||||||
self._default_model_name = model_name
|
self._default_model_name = default_model_name
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def init_app(self, system_app):
|
def init_app(self, system_app):
|
||||||
@ -31,9 +33,13 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
|
|||||||
) -> "Embeddings":
|
) -> "Embeddings":
|
||||||
if not model_name:
|
if not model_name:
|
||||||
model_name = self._default_model_name
|
model_name = self._default_model_name
|
||||||
|
|
||||||
|
new_kwargs = {k: v for k, v in self.kwargs.items()}
|
||||||
|
new_kwargs["model_name"] = model_name
|
||||||
|
|
||||||
if embedding_cls:
|
if embedding_cls:
|
||||||
return embedding_cls(model_name=model_name, **self.kwargs)
|
return embedding_cls(**new_kwargs)
|
||||||
else:
|
else:
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
return HuggingFaceEmbeddings(model_name=model_name, **self.kwargs)
|
return HuggingFaceEmbeddings(**new_kwargs)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Type
|
from typing import Dict, List, Type, Optional
|
||||||
|
|
||||||
from pilot.configs.model_config import get_device
|
from pilot.configs.model_config import get_device
|
||||||
from pilot.model.loader import _get_model_real_path
|
from pilot.model.loader import _get_model_real_path
|
||||||
@ -45,21 +45,12 @@ class EmbeddingsModelWorker(ModelWorker):
|
|||||||
self, command_args: List[str] = None
|
self, command_args: List[str] = None
|
||||||
) -> EmbeddingModelParameters:
|
) -> EmbeddingModelParameters:
|
||||||
param_cls = self.model_param_class()
|
param_cls = self.model_param_class()
|
||||||
model_args = EnvArgumentParser()
|
return _parse_embedding_params(
|
||||||
env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
|
|
||||||
model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
|
|
||||||
param_cls,
|
|
||||||
env_prefix=env_prefix,
|
|
||||||
command_args=command_args,
|
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
model_path=self.model_path,
|
model_path=self.model_path,
|
||||||
|
command_args=command_args,
|
||||||
|
param_cls=param_cls,
|
||||||
)
|
)
|
||||||
if not model_params.device:
|
|
||||||
model_params.device = get_device()
|
|
||||||
logger.info(
|
|
||||||
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
|
|
||||||
)
|
|
||||||
return model_params
|
|
||||||
|
|
||||||
def start(
|
def start(
|
||||||
self,
|
self,
|
||||||
@ -100,3 +91,26 @@ class EmbeddingsModelWorker(ModelWorker):
|
|||||||
logger.info(f"Receive embeddings request, model: {model}")
|
logger.info(f"Receive embeddings request, model: {model}")
|
||||||
input: List[str] = params["input"]
|
input: List[str] = params["input"]
|
||||||
return self._embeddings_impl.embed_documents(input)
|
return self._embeddings_impl.embed_documents(input)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_embedding_params(
|
||||||
|
model_name: str,
|
||||||
|
model_path: str,
|
||||||
|
command_args: List[str] = None,
|
||||||
|
param_cls: Optional[Type] = EmbeddingModelParameters,
|
||||||
|
):
|
||||||
|
model_args = EnvArgumentParser()
|
||||||
|
env_prefix = EnvArgumentParser.get_env_prefix(model_name)
|
||||||
|
model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
|
||||||
|
param_cls,
|
||||||
|
env_prefix=env_prefix,
|
||||||
|
command_args=command_args,
|
||||||
|
model_name=model_name,
|
||||||
|
model_path=model_path,
|
||||||
|
)
|
||||||
|
if not model_params.device:
|
||||||
|
model_params.device = get_device()
|
||||||
|
logger.info(
|
||||||
|
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
|
||||||
|
)
|
||||||
|
return model_params
|
||||||
|
@ -635,6 +635,7 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
|
|||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
|
# TODO catch exception and shutdown if worker manager start failed
|
||||||
asyncio.create_task(worker_manager.start())
|
asyncio.create_task(worker_manager.start())
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
@app.on_event("shutdown")
|
||||||
|
@ -102,6 +102,12 @@ class WebWerverParameters(BaseParameters):
|
|||||||
"help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. "
|
"help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. "
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
remote_embedding: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Whether to enable remote embedding models. If it is True, you need to start a embedding model through `dbgpt start worker --worker_type text2vec --model_name xxx --model_path xxx`"
|
||||||
|
},
|
||||||
|
)
|
||||||
log_level: Optional[str] = field(
|
log_level: Optional[str] = field(
|
||||||
default="INFO",
|
default="INFO",
|
||||||
metadata={
|
metadata={
|
||||||
|
@ -1,21 +1,65 @@
|
|||||||
from typing import Any, Type, TYPE_CHECKING
|
from typing import Any, Type, TYPE_CHECKING
|
||||||
|
|
||||||
from pilot.componet import SystemApp
|
from pilot.componet import SystemApp
|
||||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
import logging
|
||||||
|
from pilot.configs.model_config import get_device
|
||||||
|
from pilot.embedding_engine.embedding_factory import (
|
||||||
|
EmbeddingFactory,
|
||||||
|
DefaultEmbeddingFactory,
|
||||||
|
)
|
||||||
|
from pilot.server.base import WebWerverParameters
|
||||||
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
|
|
||||||
|
|
||||||
def initialize_componets(system_app: SystemApp, embedding_model_name: str):
|
logger = logging.getLogger(__name__)
|
||||||
from pilot.model.cluster import worker_manager
|
|
||||||
|
|
||||||
|
def initialize_componets(
|
||||||
|
param: WebWerverParameters,
|
||||||
|
system_app: SystemApp,
|
||||||
|
embedding_model_name: str,
|
||||||
|
embedding_model_path: str,
|
||||||
|
):
|
||||||
from pilot.model.cluster.controller.controller import controller
|
from pilot.model.cluster.controller.controller import controller
|
||||||
|
|
||||||
system_app.register(
|
|
||||||
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
|
|
||||||
)
|
|
||||||
system_app.register_instance(controller)
|
system_app.register_instance(controller)
|
||||||
|
|
||||||
|
_initialize_embedding_model(
|
||||||
|
param, system_app, embedding_model_name, embedding_model_path
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _initialize_embedding_model(
|
||||||
|
param: WebWerverParameters,
|
||||||
|
system_app: SystemApp,
|
||||||
|
embedding_model_name: str,
|
||||||
|
embedding_model_path: str,
|
||||||
|
):
|
||||||
|
from pilot.model.cluster import worker_manager
|
||||||
|
|
||||||
|
if param.remote_embedding:
|
||||||
|
logger.info("Register remote RemoteEmbeddingFactory")
|
||||||
|
system_app.register(
|
||||||
|
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from pilot.model.parameter import EmbeddingModelParameters
|
||||||
|
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
|
||||||
|
|
||||||
|
model_params: EmbeddingModelParameters = _parse_embedding_params(
|
||||||
|
model_name=embedding_model_name,
|
||||||
|
model_path=embedding_model_path,
|
||||||
|
param_cls=EmbeddingModelParameters,
|
||||||
|
)
|
||||||
|
kwargs = model_params.build_kwargs(model_name=embedding_model_path)
|
||||||
|
logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}")
|
||||||
|
system_app.register(
|
||||||
|
DefaultEmbeddingFactory, default_model_name=embedding_model_path, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RemoteEmbeddingFactory(EmbeddingFactory):
|
class RemoteEmbeddingFactory(EmbeddingFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -109,20 +109,27 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
|||||||
# Before start
|
# Before start
|
||||||
system_app.before_start()
|
system_app.before_start()
|
||||||
|
|
||||||
|
print(param)
|
||||||
|
|
||||||
|
embedding_model_name = CFG.EMBEDDING_MODEL
|
||||||
|
embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||||
|
|
||||||
server_init(param, system_app)
|
server_init(param, system_app)
|
||||||
model_start_listener = _create_model_start_listener(system_app)
|
model_start_listener = _create_model_start_listener(system_app)
|
||||||
initialize_componets(system_app, CFG.EMBEDDING_MODEL)
|
initialize_componets(param, system_app, embedding_model_name, embedding_model_path)
|
||||||
|
|
||||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||||
if not param.light:
|
if not param.light:
|
||||||
print("Model Unified Deployment Mode!")
|
print("Model Unified Deployment Mode!")
|
||||||
|
if not param.remote_embedding:
|
||||||
|
embedding_model_name, embedding_model_path = None, None
|
||||||
initialize_worker_manager_in_client(
|
initialize_worker_manager_in_client(
|
||||||
app=app,
|
app=app,
|
||||||
model_name=CFG.LLM_MODEL,
|
model_name=CFG.LLM_MODEL,
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
local_port=param.port,
|
local_port=param.port,
|
||||||
embedding_model_name=CFG.EMBEDDING_MODEL,
|
embedding_model_name=embedding_model_name,
|
||||||
embedding_model_path=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
embedding_model_path=embedding_model_path,
|
||||||
start_listener=model_start_listener,
|
start_listener=model_start_listener,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user