Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-26 21:37:40 +00:00)

llms: add chatglm

commit 8e127b3863
parent 370e327bf3
@@ -15,7 +15,7 @@ class BaseChatAdpter:
     def get_generate_stream_func(self):
         """Return the generate stream handler func"""
         pass
 
 
 llm_model_chat_adapters: List[BaseChatAdpter] = []
 
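The commit message says "llms: add chatglm", but the new adapter class itself is not visible in this hunk (which appears to come from pilot/server/chat_adapter.py, the module the import added in the next hunk references); only the BaseChatAdpter context is shown. For illustration, a minimal sketch of how a concrete ChatGLM adapter could subclass BaseChatAdpter and register itself. Every name below other than BaseChatAdpter and llm_model_chat_adapters is an assumption, not taken from this diff.

# Illustrative sketch only, not part of this commit: how a concrete adapter
# could plug into BaseChatAdpter and the llm_model_chat_adapters registry.
# Every name below except BaseChatAdpter and llm_model_chat_adapters is assumed.
class ChatGLMChatAdapter(BaseChatAdpter):
    """Select the ChatGLM streaming generator for chatglm checkpoints."""

    def match(self, model_path: str):
        # Assumed convention: choose the adapter by inspecting the model path.
        return "chatglm" in model_path.lower()

    def get_generate_stream_func(self):
        # Assumed import path for the ChatGLM stream generator.
        from pilot.model.chatglm_llm import chatglm_generate_stream
        return chatglm_generate_stream


# Assumed registration step: adapters are appended to the module-level list
# so that get_llm_chat_adapter(model_path) can walk it and call match().
llm_model_chat_adapters.append(ChatGLMChatAdapter())

Presumably get_llm_chat_adapter(model_path), imported in the next hunk, walks this registry and returns the first adapter whose match() accepts the model path.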
@@ -23,19 +23,67 @@ from pilot.model.inference import generate_output, get_embeddings
 from pilot.model.loader import ModelLoader
 from pilot.configs.model_config import *
 from pilot.configs.config import Config
+from pilot.server.chat_adapter import get_llm_chat_adapter
 
 
 CFG = Config()
 model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
+print(model_path)
-ml = ModelLoader(model_path=model_path)
-model, tokenizer = ml.loader(num_gpus=1, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
 
 
 class ModelWorker:
-    def __init__(self):
-        pass
-
-    # TODO
+    def __init__(self, model_path, model_name, device, num_gpus=1):
+        if model_path.endswith("/"):
+            model_path = model_path[:-1]
+        self.model_name = model_name or model_path.split("/")[-1]
+        self.device = device
+
+        self.ml = ModelLoader(model_path=model_path)
+        self.model, self.tokenizer = self.ml.loader(num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
+
+        if hasattr(self.model.config, "max_sequence_length"):
+            self.context_len = self.model.config.max_sequence_length
+        elif hasattr(self.model.config, "max_position_embeddings"):
+            self.context_len = self.model.config.max_position_embeddings
+        else:
+            self.context_len = 2048
+
+        self.llm_chat_adapter = get_llm_chat_adapter(model_path)
+        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()
+
+    def get_queue_length(self):
+        if model_semaphore is None or model_semaphore._value is None or model_semaphore._waiters is None:
+            return 0
+        else:
+            return CFG.LIMIT_MODEL_CONCURRENCY - model_semaphore._value + len(model_semaphore._waiters)
+
+    def generate_stream_gate(self, params):
+        try:
+            for output in self.generate_stream_func(
+                self.model,
+                self.tokenizer,
+                params,
+                DEVICE,
+                CFG.MAX_POSITION_EMBEDDINGS
+            ):
+                print("output: ", output)
+                ret = {
+                    "text": output,
+                    "error_code": 0,
+                }
+                yield json.dumps(ret).encode() + b"\0"
+        except torch.cuda.CudaError:
+            ret = {
+                "text": "**GPU OutOfMemory, Please Refresh.**",
+                "error_code": 0
+            }
+            yield json.dumps(ret).encode() + b"\0"
+
+    def get_embeddings(self, prompt):
+        return get_embeddings(self.model, self.tokenizer, prompt)
+
 
 app = FastAPI()
 
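ModelWorker.generate_stream_gate calls the adapter-supplied function as generate_stream_func(model, tokenizer, params, device, max_position_embeddings) and wraps every yielded string in a NUL-terminated JSON chunk. A minimal stand-in generator that satisfies that calling contract, for illustration only (the body is a dummy echo, not the real ChatGLM decoding):

# Dummy stand-in for illustration: any generate_stream_func must be callable as
# f(model, tokenizer, params, device, max_position_embeddings) and yield
# partial output strings; ModelWorker wraps each yield in
# {"text": ..., "error_code": 0} and appends b"\0".
def echo_generate_stream(model, tokenizer, params, device, max_position_embeddings):
    prompt = params.get("prompt", "")  # the "prompt" key is an assumption
    partial = ""
    for token in prompt.split():
        partial += token + " "
        yield partial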
@@ -60,41 +108,17 @@ def release_model_semaphore():
     model_semaphore.release()
 
 
-def generate_stream_gate(params):
-    try:
-        for output in generate_stream(
-            model,
-            tokenizer,
-            params,
-            DEVICE,
-            CFG.MAX_POSITION_EMBEDDINGS,
-        ):
-            print("output: ", output)
-            ret = {
-                "text": output,
-                "error_code": 0,
-            }
-            yield json.dumps(ret).encode() + b"\0"
-    except torch.cuda.CudaError:
-        ret = {
-            "text": "**GPU OutOfMemory, Please Refresh.**",
-            "error_code": 0
-        }
-        yield json.dumps(ret).encode() + b"\0"
-
-
 @app.post("/generate_stream")
 async def api_generate_stream(request: Request):
     global model_semaphore, global_counter
     global_counter += 1
     params = await request.json()
-    print(model, tokenizer, params, DEVICE)
 
     if model_semaphore is None:
         model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
     await model_semaphore.acquire()
 
-    generator = generate_stream_gate(params)
+    generator = worker.generate_stream_gate(params)
     background_tasks = BackgroundTasks()
     background_tasks.add_task(release_model_semaphore)
     return StreamingResponse(generator, background=background_tasks)
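A sketch of a client for the /generate_stream endpoint above, assuming the server is reachable at CFG.MODEL_PORT (8000 is used here as a placeholder). The request fields are assumptions; the endpoint forwards whatever JSON it receives to generate_stream_gate, and each response chunk is a JSON object terminated by b"\0".

# Illustrative client only. Port 8000 and the request fields ("prompt",
# "temperature", "max_new_tokens") are assumptions; the endpoint passes the
# JSON body straight to the model's generate_stream_func.
import json

import requests


def stream_generate(prompt, url="http://localhost:8000/generate_stream"):
    payload = {"prompt": prompt, "temperature": 0.7, "max_new_tokens": 256}
    with requests.post(url, json=payload, stream=True) as resp:
        resp.raise_for_status()
        # Each chunk is a JSON object terminated by a NUL byte (b"\0").
        for chunk in resp.iter_lines(delimiter=b"\0"):
            if chunk:
                yield json.loads(chunk.decode("utf-8"))["text"]


# for text in stream_generate("Hello"):
#     print(text)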
@@ -110,7 +134,7 @@ def generate(prompt_request: PromptRequest):
 
     response = []
     rsp_str = ""
-    output = generate_stream_gate(params)
+    output = worker.generate_stream_gate(params)
     for rsp in output:
         # rsp = rsp.decode("utf-8")
         rsp_str = str(rsp, "utf-8")
@@ -124,9 +148,18 @@ def generate(prompt_request: PromptRequest):
 def embeddings(prompt_request: EmbeddingRequest):
     params = {"prompt": prompt_request.prompt}
     print("Received prompt: ", params["prompt"])
-    output = get_embeddings(model, tokenizer, params["prompt"])
+    output = worker.get_embeddings(params["prompt"])
     return {"response": [float(x) for x in output]}
 
 
 if __name__ == "__main__":
+    worker = ModelWorker(
+        model_path=model_path,
+        model_name=CFG.LLM_MODEL,
+        device=DEVICE,
+        num_gpus=1
+    )
 
     uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")
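A sketch of a client for the embeddings handler above. The "/embeddings" route path and the port are assumptions (the route decorator sits outside this hunk); the payload and response shapes follow the handler: {"prompt": ...} in, {"response": [floats]} out.

# Illustrative client only. The "/embeddings" route and port 8000 are
# assumptions; the payload/response shapes follow the handler above.
import requests


def get_embedding(text, base_url="http://localhost:8000"):
    resp = requests.post(f"{base_url}/embeddings", json={"prompt": text})
    resp.raise_for_status()
    return resp.json()["response"]  # list of floats


# vec = get_embedding("What is DB-GPT?")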