feat(model): Support llama.cpp server deploy (#2263)

Authored by Fangyin Cheng on 2025-01-02 16:50:53 +08:00, committed by GitHub
parent 576da34e92
commit 0b2af2e9a2
14 changed files with 823 additions and 44 deletions


@@ -55,6 +55,7 @@ class DefaultModelWorker(ModelWorker):
         model_type = self.llm_adapter.model_type()
         self.param_cls = self.llm_adapter.model_param_class(model_type)
         self._support_async = self.llm_adapter.support_async()
+        self._support_generate_func = self.llm_adapter.support_generate_function()
         logger.info(
             f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
@@ -85,7 +86,7 @@ class DefaultModelWorker(ModelWorker):
             model_path=self.model_path,
             model_type=model_type,
         )
-        if not model_params.device:
+        if hasattr(model_params, "device") and not model_params.device:
             model_params.device = get_device()
             logger.info(
                 f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
@@ -129,7 +130,7 @@ class DefaultModelWorker(ModelWorker):
     def stop(self) -> None:
         if not self.model:
-            logger.warn("Model has been stopped!!")
+            logger.warning("Model has been stopped!!")
             return
         del self.model
         del self.tokenizer
@@ -190,9 +191,40 @@ class DefaultModelWorker(ModelWorker):
     def generate(self, params: Dict) -> ModelOutput:
         """Generate non stream result"""
-        output = None
-        for out in self.generate_stream(params):
-            output = out
-        return output
+        if self._support_generate_func:
+            (
+                params,
+                model_context,
+                generate_stream_func,
+                model_span,
+            ) = self._prepare_generate_stream(
+                params,
+                span_operation_name="DefaultModelWorker_call.generate_func",
+                is_stream=False,
+            )
+            previous_response = ""
+            last_metrics = ModelInferenceMetrics.create_metrics()
+            is_first_generate = True
+            output = generate_stream_func(
+                self.model, self.tokenizer, params, get_device(), self.context_len
+            )
+            (
+                model_output,
+                incremental_output,
+                output_str,
+                current_metrics,
+            ) = self._handle_output(
+                output,
+                previous_response,
+                model_context,
+                last_metrics,
+                is_first_generate,
+            )
+            return model_output
+        else:
+            output = None
+            for out in self.generate_stream(params):
+                output = out
+            return output

     def count_token(self, prompt: str) -> int:
         return _try_to_count_token(prompt, self.tokenizer, self.model)
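For adapters that support it, `generate` now calls a one-shot function instead of draining the stream and keeping the last chunk. A sketch of what such a function could look like against a llama.cpp server, following the server's public /completion endpoint; the address, payload fields, and return shape are assumptions, not taken from this diff:

import httpx

def generate(model, tokenizer, params, device, context_len):
    # Hypothetical one-shot generate with the signature the worker uses:
    #   (model, tokenizer, params, device, context_len) -> output
    payload = {
        "prompt": params["prompt"],
        "n_predict": params.get("max_new_tokens", 512),
        "temperature": params.get("temperature", 0.7),
    }
    resp = httpx.post("http://127.0.0.1:8080/completion", json=payload, timeout=600)
    resp.raise_for_status()
    # The worker feeds this result to _handle_output, so the real function
    # must return whatever shape _handle_output expects; a raw string here.
    return resp.json()["content"]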
@@ -277,12 +309,45 @@ class DefaultModelWorker(ModelWorker):
            span.end(metadata={"error": output.to_dict()})

     async def async_generate(self, params: Dict) -> ModelOutput:
-        output = None
-        async for out in self.async_generate_stream(params):
-            output = out
-        return output
+        if self._support_generate_func:
+            (
+                params,
+                model_context,
+                generate_stream_func,
+                model_span,
+            ) = self._prepare_generate_stream(
+                params,
+                span_operation_name="DefaultModelWorker_call.generate_func",
+                is_stream=False,
+            )
+            previous_response = ""
+            last_metrics = ModelInferenceMetrics.create_metrics()
+            is_first_generate = True
+            output = await generate_stream_func(
+                self.model, self.tokenizer, params, get_device(), self.context_len
+            )
+            (
+                model_output,
+                incremental_output,
+                output_str,
+                current_metrics,
+            ) = self._handle_output(
+                output,
+                previous_response,
+                model_context,
+                last_metrics,
+                is_first_generate,
+            )
+            return model_output
+        else:
+            output = None
+            async for out in self.async_generate_stream(params):
+                output = out
+            return output

-    def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
+    def _prepare_generate_stream(
+        self, params: Dict, span_operation_name: str, is_stream=True
+    ):
         params, model_context = self.llm_adapter.model_adaptation(
             params,
             self.model_name,
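Note that the async path awaits the returned callable (`output = await generate_stream_func(...)`), so in an async worker the function returned by `get_generate_function` must be a coroutine function. A minimal async sketch under the same assumptions as the synchronous example above:

import httpx

async def async_generate(model, tokenizer, params, device, context_len):
    # Hypothetical coroutine counterpart; endpoint and payload assumed.
    async with httpx.AsyncClient(timeout=600) as client:
        resp = await client.post(
            "http://127.0.0.1:8080/completion",
            json={"prompt": params["prompt"]},
        )
        resp.raise_for_status()
        return resp.json()["content"]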
@@ -290,29 +355,48 @@ class DefaultModelWorker(ModelWorker):
             self.tokenizer,
             prompt_template=self.ml.prompt_template,
         )
-        stream_type = ""
         if self.support_async():
-            generate_stream_func = self.llm_adapter.get_async_generate_stream_function(
-                self.model, self.model_path
-            )
-            stream_type = "async "
-            logger.info(
-                "current generate stream function is asynchronous stream function"
-            )
+            if not is_stream and self.llm_adapter.support_generate_function():
+                func = self.llm_adapter.get_generate_function(
+                    self.model, self.model_path
+                )
+                func_type = "async generate"
+                logger.info(
+                    "current generate function is asynchronous generate function"
+                )
+            else:
+                func = self.llm_adapter.get_async_generate_stream_function(
+                    self.model, self.model_path
+                )
+                func_type = "async generate stream"
+                logger.info(
+                    "current generate stream function is asynchronous generate stream function"
+                )
         else:
-            generate_stream_func = self.llm_adapter.get_generate_stream_function(
-                self.model, self.model_path
-            )
+            if not is_stream and self.llm_adapter.support_generate_function():
+                func = self.llm_adapter.get_generate_function(
+                    self.model, self.model_path
+                )
+                func_type = "generate"
+                logger.info(
+                    "current generate function is synchronous generate function"
+                )
+            else:
+                func = self.llm_adapter.get_generate_stream_function(
+                    self.model, self.model_path
+                )
+                func_type = "generate stream"
+                logger.info(
+                    "current generate stream function is synchronous generate stream function"
+                )

         str_prompt = params.get("prompt")
         if not str_prompt:
             str_prompt = params.get("string_prompt")
         print(
-            f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n"
+            f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{func_type} output:\n"
         )

-        generate_stream_func_str_name = "{}.{}".format(
-            generate_stream_func.__module__, generate_stream_func.__name__
-        )
+        generate_func_str_name = "{}.{}".format(func.__module__, func.__name__)

         span_params = {k: v for k, v in params.items()}
         if "messages" in span_params:
@@ -323,7 +407,7 @@ class DefaultModelWorker(ModelWorker):
metadata = {
"is_async_func": self.support_async(),
"llm_adapter": str(self.llm_adapter),
"generate_stream_func": generate_stream_func_str_name,
"generate_func": generate_func_str_name,
}
metadata.update(span_params)
metadata.update(model_context)
@@ -331,7 +415,7 @@ class DefaultModelWorker(ModelWorker):
model_span = root_tracer.start_span(span_operation_name, metadata=metadata)
return params, model_context, generate_stream_func, model_span
return params, model_context, func, model_span
def _handle_output(
self,


@@ -89,6 +89,8 @@ class RemoteModelWorker(ModelWorker):
                 json=params,
                 timeout=self.timeout,
             )
+            if response.status_code not in [200, 201]:
+                raise Exception(f"Request to {url} failed, error: {response.text}")
             return ModelOutput(**response.json())

     def count_token(self, prompt: str) -> int:
@@ -106,6 +108,8 @@ class RemoteModelWorker(ModelWorker):
             json={"prompt": prompt},
             timeout=self.timeout,
         )
+        if response.status_code not in [200, 201]:
+            raise Exception(f"Request to {url} failed, error: {response.text}")
         return response.json()

     async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
@@ -123,6 +127,8 @@ class RemoteModelWorker(ModelWorker):
                 json=params,
                 timeout=self.timeout,
             )
+            if response.status_code not in [200, 201]:
+                raise Exception(f"Request to {url} failed, error: {response.text}")
             return ModelMetadata.from_dict(response.json())

     def get_model_metadata(self, params: Dict) -> ModelMetadata:
@@ -141,6 +147,8 @@ class RemoteModelWorker(ModelWorker):
             json=params,
             timeout=self.timeout,
         )
+        if response.status_code not in [200, 201]:
+            raise Exception(f"Request to {url} failed, error: {response.text}")
         return response.json()

     async def async_embeddings(self, params: Dict) -> List[List[float]]:
@@ -156,6 +164,8 @@ class RemoteModelWorker(ModelWorker):
                 json=params,
                 timeout=self.timeout,
             )
+            if response.status_code not in [200, 201]:
+                raise Exception(f"Request to {url} failed, error: {response.text}")
             return response.json()

     def _get_trace_headers(self):
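The same two-line status check is repeated after every request in RemoteModelWorker; a small helper would keep the behavior in one place. A possible refactor sketch, not part of this commit:

import httpx

def _check_response(url: str, response: httpx.Response) -> None:
    # Mirrors the checks added above: accept 200/201, otherwise surface
    # the server's error body in the exception message.
    if response.status_code not in (200, 201):
        raise Exception(f"Request to {url} failed, error: {response.text}")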