Repository: https://github.com/csunny/DB-GPT.git
Commit: feat(model): Support llama.cpp server deploy (#2263)
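The diff below touches DefaultModelWorker and RemoteModelWorker. As context, here is a minimal sketch of the adapter hooks the worker now consults for the non-stream path; the class name and body are hypothetical, only the hook names (support_generate_function, get_generate_function) and the argument order used by the worker are taken from the diff itself.

# Hypothetical sketch, not DB-GPT source code.
from typing import Any, Callable, Dict


class LlamaServerAdapterSketch:
    def support_generate_function(self) -> bool:
        # Tells the worker that a one-shot (non-stream) generate path exists.
        return True

    def get_generate_function(self, model: Any, model_path: str) -> Callable:
        def generate(model: Any, tokenizer: Any, params: Dict, device: str, context_len: int):
            # A real implementation would send params["prompt"] to the
            # llama.cpp server and return the completed output in one call.
            raise NotImplementedError("illustrative stub")

        return generate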
@@ -55,6 +55,7 @@ class DefaultModelWorker(ModelWorker):
         model_type = self.llm_adapter.model_type()
         self.param_cls = self.llm_adapter.model_param_class(model_type)
         self._support_async = self.llm_adapter.support_async()
+        self._support_generate_func = self.llm_adapter.support_generate_function()
 
         logger.info(
             f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
@@ -85,7 +86,7 @@ class DefaultModelWorker(ModelWorker):
             model_path=self.model_path,
             model_type=model_type,
         )
-        if not model_params.device:
+        if hasattr(model_params, "device") and not model_params.device:
             model_params.device = get_device()
             logger.info(
                 f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
@@ -129,7 +130,7 @@ class DefaultModelWorker(ModelWorker):
 
     def stop(self) -> None:
         if not self.model:
-            logger.warn("Model has been stopped!!")
+            logger.warning("Model has been stopped!!")
             return
         del self.model
         del self.tokenizer
@@ -190,9 +191,40 @@ class DefaultModelWorker(ModelWorker):
     def generate(self, params: Dict) -> ModelOutput:
         """Generate non stream result"""
-        output = None
-        for out in self.generate_stream(params):
-            output = out
-        return output
+        if self._support_generate_func:
+            (
+                params,
+                model_context,
+                generate_stream_func,
+                model_span,
+            ) = self._prepare_generate_stream(
+                params,
+                span_operation_name="DefaultModelWorker_call.generate_func",
+                is_stream=False,
+            )
+            previous_response = ""
+            last_metrics = ModelInferenceMetrics.create_metrics()
+            is_first_generate = True
+            output = generate_stream_func(
+                self.model, self.tokenizer, params, get_device(), self.context_len
+            )
+            (
+                model_output,
+                incremental_output,
+                output_str,
+                current_metrics,
+            ) = self._handle_output(
+                output,
+                previous_response,
+                model_context,
+                last_metrics,
+                is_first_generate,
+            )
+            return model_output
+        else:
+            output = None
+            for out in self.generate_stream(params):
+                output = out
+            return output
 
     def count_token(self, prompt: str) -> int:
         return _try_to_count_token(prompt, self.tokenizer, self.model)
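In short, the old generate() always drained the streaming path and kept only the last chunk; the hunk above prefers a single one-shot call when the adapter advertises one. An illustrative restatement with generic names (not DB-GPT APIs):

# Illustrative only: the behavioural difference introduced above.
def last_of_stream(chunks):
    # Previous behaviour: run the full stream and keep only the final chunk.
    output = None
    for output in chunks:
        pass
    return output


def one_shot(generate_func, *args):
    # New behaviour when supported: one call returns the complete result.
    return generate_func(*args)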
@@ -277,12 +309,45 @@ class DefaultModelWorker(ModelWorker):
             span.end(metadata={"error": output.to_dict()})
 
     async def async_generate(self, params: Dict) -> ModelOutput:
-        output = None
-        async for out in self.async_generate_stream(params):
-            output = out
-        return output
+        if self._support_generate_func:
+            (
+                params,
+                model_context,
+                generate_stream_func,
+                model_span,
+            ) = self._prepare_generate_stream(
+                params,
+                span_operation_name="DefaultModelWorker_call.generate_func",
+                is_stream=False,
+            )
+            previous_response = ""
+            last_metrics = ModelInferenceMetrics.create_metrics()
+            is_first_generate = True
+            output = await generate_stream_func(
+                self.model, self.tokenizer, params, get_device(), self.context_len
+            )
+            (
+                model_output,
+                incremental_output,
+                output_str,
+                current_metrics,
+            ) = self._handle_output(
+                output,
+                previous_response,
+                model_context,
+                last_metrics,
+                is_first_generate,
+            )
+            return model_output
+        else:
+            output = None
+            async for out in self.async_generate_stream(params):
+                output = out
+            return output
 
-    def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
+    def _prepare_generate_stream(
+        self, params: Dict, span_operation_name: str, is_stream=True
+    ):
         params, model_context = self.llm_adapter.model_adaptation(
             params,
             self.model_name,
@@ -290,29 +355,48 @@ class DefaultModelWorker(ModelWorker):
             self.tokenizer,
             prompt_template=self.ml.prompt_template,
         )
-        stream_type = ""
         if self.support_async():
-            generate_stream_func = self.llm_adapter.get_async_generate_stream_function(
-                self.model, self.model_path
-            )
-            stream_type = "async "
-            logger.info(
-                "current generate stream function is asynchronous stream function"
-            )
+            if not is_stream and self.llm_adapter.support_generate_function():
+                func = self.llm_adapter.get_generate_function(
+                    self.model, self.model_path
+                )
+                func_type = "async generate"
+                logger.info(
+                    "current generate function is asynchronous generate function"
+                )
+            else:
+                func = self.llm_adapter.get_async_generate_stream_function(
+                    self.model, self.model_path
+                )
+                func_type = "async generate stream"
+                logger.info(
+                    "current generate stream function is asynchronous generate stream function"
+                )
         else:
-            generate_stream_func = self.llm_adapter.get_generate_stream_function(
-                self.model, self.model_path
-            )
+            if not is_stream and self.llm_adapter.support_generate_function():
+                func = self.llm_adapter.get_generate_function(
+                    self.model, self.model_path
+                )
+                func_type = "generate"
+                logger.info(
+                    "current generate function is synchronous generate function"
+                )
+            else:
+                func = self.llm_adapter.get_generate_stream_function(
+                    self.model, self.model_path
+                )
+                func_type = "generate stream"
+                logger.info(
+                    "current generate stream function is synchronous generate stream function"
+                )
         str_prompt = params.get("prompt")
         if not str_prompt:
             str_prompt = params.get("string_prompt")
         print(
-            f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n"
+            f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{func_type} output:\n"
         )
 
-        generate_stream_func_str_name = "{}.{}".format(
-            generate_stream_func.__module__, generate_stream_func.__name__
-        )
+        generate_func_str_name = "{}.{}".format(func.__module__, func.__name__)
 
         span_params = {k: v for k, v in params.items()}
         if "messages" in span_params:
@@ -323,7 +407,7 @@ class DefaultModelWorker(ModelWorker):
         metadata = {
             "is_async_func": self.support_async(),
             "llm_adapter": str(self.llm_adapter),
-            "generate_stream_func": generate_stream_func_str_name,
+            "generate_func": generate_func_str_name,
         }
         metadata.update(span_params)
         metadata.update(model_context)
@@ -331,7 +415,7 @@ class DefaultModelWorker(ModelWorker):
 
         model_span = root_tracer.start_span(span_operation_name, metadata=metadata)
 
-        return params, model_context, generate_stream_func, model_span
+        return params, model_context, func, model_span
 
     def _handle_output(
         self,
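The selection in _prepare_generate_stream above covers four cases (async/sync times stream/non-stream). A condensed, illustrative restatement that flattens the nesting used in the actual diff:

# Condensed, illustrative restatement of the function selection above.
def pick_func(adapter, is_stream: bool, is_async: bool):
    if not is_stream and adapter.support_generate_function():
        # One-shot generate, async or sync depending on the worker.
        return adapter.get_generate_function, ("async generate" if is_async else "generate")
    if is_async:
        return adapter.get_async_generate_stream_function, "async generate stream"
    return adapter.get_generate_stream_function, "generate stream"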
@@ -89,6 +89,8 @@ class RemoteModelWorker(ModelWorker):
                 json=params,
                 timeout=self.timeout,
             )
+            if response.status_code not in [200, 201]:
+                raise Exception(f"Request to {url} failed, error: {response.text}")
             return ModelOutput(**response.json())
 
     def count_token(self, prompt: str) -> int:
@@ -106,6 +108,8 @@ class RemoteModelWorker(ModelWorker):
                 json={"prompt": prompt},
                 timeout=self.timeout,
             )
+            if response.status_code not in [200, 201]:
+                raise Exception(f"Request to {url} failed, error: {response.text}")
             return response.json()
 
     async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
@@ -123,6 +127,8 @@ class RemoteModelWorker(ModelWorker):
                 json=params,
                 timeout=self.timeout,
             )
+            if response.status_code not in [200, 201]:
+                raise Exception(f"Request to {url} failed, error: {response.text}")
             return ModelMetadata.from_dict(response.json())
 
     def get_model_metadata(self, params: Dict) -> ModelMetadata:
@@ -141,6 +147,8 @@ class RemoteModelWorker(ModelWorker):
             json=params,
             timeout=self.timeout,
         )
+        if response.status_code not in [200, 201]:
+            raise Exception(f"Request to {url} failed, error: {response.text}")
         return response.json()
 
     async def async_embeddings(self, params: Dict) -> List[List[float]]:
@@ -156,6 +164,8 @@ class RemoteModelWorker(ModelWorker):
             json=params,
             timeout=self.timeout,
         )
+        if response.status_code not in [200, 201]:
+            raise Exception(f"Request to {url} failed, error: {response.text}")
         return response.json()
 
     def _get_trace_headers(self):
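Each RemoteModelWorker hunk above adds the same guard after an HTTP call. A standalone sketch of that pattern is below; httpx is an assumption here, any client whose response exposes status_code and text works the same way.

# Sketch of the status-code guard added above; assumes an httpx-style response.
import httpx


def check_response(url: str, response: httpx.Response) -> httpx.Response:
    # Fail fast with the server's error text instead of parsing it as JSON.
    if response.status_code not in (200, 201):
        raise Exception(f"Request to {url} failed, error: {response.text}")
    return response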