Features: multi LLM model support. (#78)

Features: multi-LLM support.
     - model_adapter for loading multiple models
     - chat_adapter for chatting with the models
Aries-ckt committed 2023-05-21 17:28:11 +08:00 (via GitHub)
commit e847a3fc7a
15 changed files with 463 additions and 72 deletions

.gitignore

@@ -23,6 +23,7 @@ lib/
 lib64/
 parts/
 sdist/
+models
 var/
 wheels/
 models/


@@ -29,6 +29,10 @@ Currently, we have released multiple key features, which are listed below to dem
 - Unified vector storage/indexing of knowledge base
 - Support for unstructured data such as PDF, Markdown, CSV, and WebURL
+- Multi LLMs Support
+  - Supports multiple large language models, currently supporting Vicuna (7b, 13b), ChatGLM-6b (int4, int8)
+  - TODO: codegen2, codet5p
 ## Demo
@@ -175,6 +179,10 @@ Notice: the webserver need to connect llmserver, so you need change the .env f
 We provide a user interface for Gradio, which allows you to use DB-GPT through our user interface. Additionally, we have prepared several reference articles (written in Chinese) that introduce the code and principles related to our project.
 - [LLM Practical In Action Series (1) — Combined Langchain-Vicuna Application Practical](https://medium.com/@cfqcsunny/llm-practical-in-action-series-1-combined-langchain-vicuna-application-practical-701cd0413c9f)
+### Multi LLMs Usage
+To use multiple models, modify the LLM_MODEL parameter in the .env configuration file to switch between the models.
 ## Acknowledgement
 The achievements of this project are thanks to the technical community, especially the following projects:
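As a companion to the new "Multi LLMs Usage" section above, here is a minimal sketch of how the LLM_MODEL value ends up selecting a model path on the server side. It reuses Config and LLM_MODEL_CONFIG from this commit; treating CFG.LLM_MODEL as the value read from .env is an assumption based on the README wording.

```python
# Sketch: resolving the model chosen via LLM_MODEL (assumption: Config reads it from .env).
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG

CFG = Config()
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]  # e.g. LLM_MODEL=chatglm-6b -> <MODEL_PATH>/chatglm-6b
print(f"Serving {CFG.LLM_MODEL} from {model_path}")
```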


@@ -26,6 +26,10 @@ DB-GPT is an open-source GPT experimental project based on databases, using local
 - Unified vector storage/indexing of knowledge base
 - Support for unstructured data including PDF, Markdown, CSV, and WebURL
+- Multi LLMs Support
+  - Supports multiple large language models, currently Vicuna (7b, 13b) and ChatGLM-6b (int4, int8)
+  - TODO: codet5p, codegen2
 ## Demo
 The demo runs on an RTX 4090 GPU, [YouTube link](https://www.youtube.com/watch?v=1PWI6F89LPo)
@@ -178,6 +182,10 @@ $ python webserver.py
 2. [LLM Practical In Action Series (2) — DB-GPT Deployment Guide on Alibaba Cloud](https://zhuanlan.zhihu.com/p/629467580)
 3. [LLM Practical In Action Series (3) — Principles and Usage of the DB-GPT Plugin Model](https://zhuanlan.zhihu.com/p/629623125)
+### Multi LLMs Usage
+In the .env configuration file, modify the LLM_MODEL parameter to switch between models.
 ## Acknowledgement
 The achievements of this project are thanks to the technical community, especially the following projects:


@@ -5,14 +5,21 @@ import requests
 import json
 import time
 import uuid
+import os
+import sys
 from urllib.parse import urljoin
 
 import gradio as gr
 
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(ROOT_PATH)
+
 from pilot.configs.config import Config
 from pilot.conversation import conv_qa_prompt_template, conv_templates
 from langchain.prompts import PromptTemplate
 
-vicuna_stream_path = "generate_stream"
+llmstream_stream_path = "generate_stream"
 
 CFG = Config()

@@ -21,38 +28,44 @@ def generate(query):
     template_name = "conv_one_shot"
     state = conv_templates[template_name].copy()
-    pt = PromptTemplate(
-        template=conv_qa_prompt_template,
-        input_variables=["context", "question"]
-    )
-    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
-                       question=query)
-    print(result)
-    state.append_message(state.roles[0], result)
+    # pt = PromptTemplate(
+    #     template=conv_qa_prompt_template,
+    #     input_variables=["context", "question"]
+    # )
+    # result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+    #                    question=query)
+    # print(result)
+    state.append_message(state.roles[0], query)
     state.append_message(state.roles[1], None)
     prompt = state.get_prompt()
     params = {
-        "model": "vicuna-13b",
+        "model": "chatglm-6b",
         "prompt": prompt,
-        "temperature": 0.7,
+        "temperature": 1.0,
         "max_new_tokens": 1024,
         "stop": "###"
     }
     response = requests.post(
-        url=urljoin(CFG.MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(CFG.MODEL_SERVER, llmstream_stream_path), data=json.dumps(params)
     )
 
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
 
     for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"][skip_echo_len:].strip()
+                if "vicuna" in CFG.LLM_MODEL:
+                    output = data["text"][skip_echo_len:].strip()
+                else:
+                    output = data["text"].strip()
                 state.messages[-1][-1] = output + ""
                 yield(output)


@@ -16,12 +16,17 @@ DATA_DIR = os.path.join(PILOT_PATH, "data")
 nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
 
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 LLM_MODEL_CONFIG = {
     "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
     "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
+    "vicuna-7b": os.path.join(MODEL_PATH, "vicuna-7b"),
     "text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
-    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
+    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
+    "codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
+    "codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"),
+    "chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
+    "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
 }
 
 # Load model config


@@ -15,6 +15,9 @@ DB_SETTINGS = {
     "port": CFG.LOCAL_DB_PORT
 }
 
+ROLE_USER = "USER"
+ROLE_ASSISTANT = "Assistant"
+
 class SeparatorStyle(Enum):
     SINGLE = auto()
     TWO = auto()


@@ -9,6 +9,8 @@ from transformers import (
     AutoModel
 )
 
+from pilot.configs.model_config import DEVICE
+
 class BaseLLMAdaper:
     """The Base class for multi model, in our project.
     We will support those model, which performance resemble ChatGPT """

@@ -61,13 +63,29 @@ class ChatGLMAdapater(BaseLLMAdaper):
     """LLM Adatpter for THUDM/chatglm-6b"""
     def match(self, model_path: str):
         return "chatglm" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        model = AutoModel.from_pretrained(
-            model_path, trust_remote_code=True, **from_pretrained_kwargs
-        ).half().cuda()
-        return model, tokenizer
+        if DEVICE != "cuda":
+            model = AutoModel.from_pretrained(
+                model_path, trust_remote_code=True, **from_pretrained_kwargs
+            ).float()
+            return model, tokenizer
+        else:
+            model = AutoModel.from_pretrained(
+                model_path, trust_remote_code=True, **from_pretrained_kwargs
+            ).half().cuda()
+            return model, tokenizer
+
+class CodeGenAdapter(BaseLLMAdaper):
+    pass
+
+class StarCoderAdapter(BaseLLMAdaper):
+    pass
+
+class T5CodeAdapter(BaseLLMAdaper):
+    pass
+
 class KoalaLLMAdapter(BaseLLMAdaper):
     """Koala LLM Adapter which Based LLaMA """

@@ -91,6 +109,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
 register_llm_model_adapters(VicunaLLMAdapater)
+register_llm_model_adapters(ChatGLMAdapater)
 # TODO Default support vicuna, other model need to tests and Evaluate
 register_llm_model_adapters(BaseLLMAdaper)
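For orientation, here is a minimal sketch of how a registered adapter is resolved for a given model path. get_llm_model_adapter is the lookup imported by the model loader in this commit; the assumption is that it returns the first registered adapter whose match() accepts the path, and the path below is a made-up example.

```python
# Hypothetical usage of the model adapter registry added in this commit.
from pilot.model.adapter import get_llm_model_adapter

model_path = "/data/models/chatglm-6b"           # example path, not from the repo
llm_adapter = get_llm_model_adapter(model_path)  # ChatGLMAdapater matches "chatglm"

# The ChatGLM adapter loads in float() on CPU/MPS and half().cuda() on CUDA, as shown above.
model, tokenizer = llm_adapter.loader(model_path, {})
```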


@@ -1,3 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-


@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+
+from pilot.conversation import ROLE_USER, ROLE_ASSISTANT
+
+@torch.inference_mode()
+def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2):
+    """Generate text using chatglm model's chat api"""
+    prompt = params["prompt"]
+    temperature = float(params.get("temperature", 1.0))
+    top_p = float(params.get("top_p", 1.0))
+    stop = params.get("stop", "###")
+    echo = params.get("echo", False)
+
+    generate_kwargs = {
+        "do_sample": True if temperature > 1e-5 else False,
+        "top_p": top_p,
+        "repetition_penalty": 1.0,
+        "logits_processor": None
+    }
+
+    if temperature > 1e-5:
+        generate_kwargs["temperature"] = temperature
+
+    # TODO, Fix this
+    hist = []
+    messages = prompt.split(stop)
+
+    # Add history chat to hist for model.
+    for i in range(1, len(messages) - 2, 2):
+        hist.append((messages[i].split(ROLE_USER + ":")[1], messages[i+1].split(ROLE_ASSISTANT + ":")[1]))
+
+    query = messages[-2].split(ROLE_USER + ":")[1]
+    print("Query Message: ", query)
+    output = ""
+    i = 0
+
+    for i, (response, new_hist) in enumerate(model.stream_chat(tokenizer, query, hist, **generate_kwargs)):
+        if echo:
+            output = query + " " + response
+        else:
+            output = response
+
+        yield output
+
+    yield output
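To make the history parsing above easier to follow, here is a small, self-contained illustration of how a "###"-separated prompt is split into (user, assistant) history pairs plus the final query. The prompt text is fabricated for the example; only the splitting logic mirrors the code above.

```python
# Fabricated prompt in the "###"-separated format the conversation templates produce.
ROLE_USER, ROLE_ASSISTANT = "USER", "Assistant"
prompt = (
    "A chat between a user and an assistant.###"
    "USER: hello###Assistant: hi there###"
    "USER: what is DB-GPT?###Assistant:"
)

messages = prompt.split("###")
hist = []
for i in range(1, len(messages) - 2, 2):
    hist.append((messages[i].split(ROLE_USER + ":")[1],
                 messages[i + 1].split(ROLE_ASSISTANT + ":")[1]))
query = messages[-2].split(ROLE_USER + ":")[1]

# hist  == [(" hello", " hi there")]   -> passed to model.stream_chat as history
# query == " what is DB-GPT?"          -> the latest user turn
```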


@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+import transformers
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2].clone()
+    x2 = x[..., x.shape[-1] // 2 :].clone()
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
+    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
+        self.head_dim
+    )
+
+    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+        raise ValueError(
+            f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+            f" {attn_weights.size()}"
+        )
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.max(
+            attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+        )
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+        query_states.dtype
+    )
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        raise ValueError(
+            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
+def replace_llama_attn_with_non_inplace_operations():
+    """Avoid bugs in mps backend by not using in-place operations."""
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
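As a quick sanity check on the helper above: rotate_half swaps the two halves of the last dimension and negates what becomes the first half. A tiny standalone illustration (not part of the commit):

```python
import torch

def rotate_half(x):
    # Same operation as the patched helper above.
    x1 = x[..., : x.shape[-1] // 2].clone()
    x2 = x[..., x.shape[-1] // 2 :].clone()
    return torch.cat((-x2, x1), dim=-1)

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))  # tensor([-3., -4.,  1.,  2.])
```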


@@ -2,11 +2,39 @@
 # -*- coding: utf-8 -*-
 
 import torch
+import sys
 import warnings
 from pilot.singleton import Singleton
+from typing import Optional
 
 from pilot.model.compression import compress_module
 from pilot.model.adapter import get_llm_model_adapter
+from pilot.utils import get_gpu_memory
+from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
+
+
+def raise_warning_for_incompatible_cpu_offloading_configuration(
+    device: str, load_8bit: bool, cpu_offloading: bool
+):
+    if cpu_offloading:
+        if not load_8bit:
+            warnings.warn(
+                "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n"
+                "Use '--load-8bit' to enable 8-bit-quantization\n"
+                "Continuing without cpu-offloading enabled\n"
+            )
+            return False
+        if not "linux" in sys.platform:
+            warnings.warn(
+                "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n"
+                "Continuing without cpu-offloading enabled\n"
+            )
+            return False
+        if device != "cuda":
+            warnings.warn(
+                "CPU-offloading is only enabled when using CUDA-devices\n"
+                "Continuing without cpu-offloading enabled\n"
+            )
+            return False
+    return cpu_offloading
+
+
 class ModelLoader(metaclass=Singleton):

@@ -30,26 +58,37 @@ class ModelLoader(metaclass=Singleton):
         }
 
     # TODO multi gpu support
-    def loader(self, num_gpus, load_8bit=False, debug=False):
+    def loader(self, num_gpus, load_8bit=False, debug=False, cpu_offloading=False, max_gpu_memory: Optional[str]=None):
         if self.device == "cpu":
-            kwargs = {}
+            kwargs = {"torch_dtype": torch.float32}
         elif self.device == "cuda":
             kwargs = {"torch_dtype": torch.float16}
-            if num_gpus == "auto":
-                kwargs["device_map"] = "auto"
-            else:
-                num_gpus = int(num_gpus)
-                if num_gpus != 1:
-                    kwargs.update({
-                        "device_map": "auto",
-                        "max_memory": {i: "13GiB" for i in range(num_gpus)},
-                    })
+            num_gpus = int(num_gpus)
+            if num_gpus != 1:
+                kwargs["device_map"] = "auto"
+                if max_gpu_memory is None:
+                    kwargs["device_map"] = "sequential"
+                    available_gpu_memory = get_gpu_memory(num_gpus)
+                    kwargs["max_memory"] = {
+                        i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
+                        for i in range(num_gpus)
+                    }
+                else:
+                    kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
+        elif self.device == "mps":
+            kwargs = kwargs = {"torch_dtype": torch.float16}
+            replace_llama_attn_with_non_inplace_operations()
         else:
+            # Todo Support mps for practise
             raise ValueError(f"Invalid device: {self.device}")
 
+        # TODO when cpu loading, need use quantization config
+
         llm_adapter = get_llm_model_adapter(self.model_path)
         model, tokenizer = llm_adapter.loader(self.model_path, kwargs)

@@ -61,7 +100,7 @@ class ModelLoader(metaclass=Singleton):
             else:
                 compress_module(model, self.device)
 
-        if (self.device == "cuda" and num_gpus == 1):
+        if (self.device == "cuda" and num_gpus == 1 and not cpu_offloading) or self.device == "mps":
             model.to(self.device)
 
         if debug:
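To illustrate the multi-GPU branch above, here is a small hypothetical walk-through of the kwargs that loader() builds when max_gpu_memory is not supplied. The memory figures are invented, and get_gpu_memory is assumed to report per-GPU free memory in GiB.

```python
import torch

# Assumed example: two GPUs, with get_gpu_memory(2) reporting about 23 GiB free on each.
available_gpu_memory = [23.0, 23.0]
num_gpus = 2

kwargs = {"torch_dtype": torch.float16}
if num_gpus != 1:
    kwargs["device_map"] = "sequential"  # used when max_gpu_memory is None
    kwargs["max_memory"] = {
        i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"  # keep roughly 15% headroom
        for i in range(num_gpus)
    }

print(kwargs["max_memory"])  # {0: '19GiB', 1: '19GiB'}
```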


@@ -0,0 +1,82 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from typing import List
+from functools import cache
+from pilot.model.inference import generate_stream
+
+
+class BaseChatAdpter:
+    """The Base class for chat with llm models. it will match the model,
+    and fetch output from model"""
+
+    def match(self, model_path: str):
+        return True
+
+    def get_generate_stream_func(self):
+        """Return the generate stream handler func"""
+        pass
+
+
+llm_model_chat_adapters: List[BaseChatAdpter] = []
+
+
+def register_llm_model_chat_adapter(cls):
+    """Register a chat adapter"""
+    llm_model_chat_adapters.append(cls())
+
+
+@cache
+def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
+    """Get a chat generate func for a model"""
+    for adapter in llm_model_chat_adapters:
+        if adapter.match(model_path):
+            return adapter
+    raise ValueError(f"Invalid model for chat adapter {model_path}")
+
+
+class VicunaChatAdapter(BaseChatAdpter):
+    """Model chat Adapter for vicuna"""
+
+    def match(self, model_path: str):
+        return "vicuna" in model_path
+
+    def get_generate_stream_func(self):
+        return generate_stream
+
+
+class ChatGLMChatAdapter(BaseChatAdpter):
+    """Model chat Adapter for ChatGLM"""
+
+    def match(self, model_path: str):
+        return "chatglm" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.chatglm_llm import chatglm_generate_stream
+        return chatglm_generate_stream
+
+
+class CodeT5ChatAdapter(BaseChatAdpter):
+    """Model chat adapter for CodeT5"""
+
+    def match(self, model_path: str):
+        return "codet5" in model_path
+
+    def get_generate_stream_func(self):
+        # TODO
+        pass
+
+
+class CodeGenChatAdapter(BaseChatAdpter):
+    """Model chat adapter for CodeGen"""
+
+    def match(self, model_path: str):
+        return "codegen" in model_path
+
+    def get_generate_stream_func(self):
+        # TODO
+        pass
+
+
+register_llm_model_chat_adapter(VicunaChatAdapter)
+register_llm_model_chat_adapter(ChatGLMChatAdapter)
+register_llm_model_chat_adapter(BaseChatAdpter)
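For context, here is roughly how the new chat adapter is wired together with the model loader, mirroring what the model worker in this commit does. Running it requires a real model on disk, and the prompt is a made-up example.

```python
# Sketch: end-to-end wiring of loader + chat adapter, as done by the model worker.
from pilot.configs.config import Config
from pilot.configs.model_config import DEVICE, LLM_MODEL_CONFIG
from pilot.model.loader import ModelLoader
from pilot.server.chat_adapter import get_llm_chat_adapter

CFG = Config()
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]

ml = ModelLoader(model_path=model_path)
model, tokenizer = ml.loader(num_gpus=1, load_8bit=False, debug=False)

chat_adapter = get_llm_chat_adapter(model_path)                # e.g. ChatGLMChatAdapter for chatglm-6b
generate_stream_func = chat_adapter.get_generate_stream_func()

params = {"prompt": "USER: hello###Assistant:", "temperature": 1.0, "max_new_tokens": 1024, "stop": "###"}
for output in generate_stream_func(model, tokenizer, params, DEVICE, 2048):
    print(output)  # streamed partial responses
```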


@@ -23,20 +23,65 @@ from pilot.model.inference import generate_output, get_embeddings
 from pilot.model.loader import ModelLoader
 from pilot.configs.model_config import *
 from pilot.configs.config import Config
+from pilot.server.chat_adapter import get_llm_chat_adapter
 
 CFG = Config()
 
-model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
-ml = ModelLoader(model_path=model_path)
-model, tokenizer = ml.loader(num_gpus=1, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
-#model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
 
 class ModelWorker:
-    def __init__(self):
-        pass
-    # TODO
+    def __init__(self, model_path, model_name, device, num_gpus=1):
+        if model_path.endswith("/"):
+            model_path = model_path[:-1]
+        self.model_name = model_name or model_path.split("/")[-1]
+        self.device = device
+
+        self.ml = ModelLoader(model_path=model_path)
+        self.model, self.tokenizer = self.ml.loader(num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
+
+        if hasattr(self.model.config, "max_sequence_length"):
+            self.context_len = self.model.config.max_sequence_length
+        elif hasattr(self.model.config, "max_position_embeddings"):
+            self.context_len = self.model.config.max_position_embeddings
+        else:
+            self.context_len = 2048
+
+        self.llm_chat_adapter = get_llm_chat_adapter(model_path)
+        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()
+
+    def get_queue_length(self):
+        if model_semaphore is None or model_semaphore._value is None or model_semaphore._waiters is None:
+            return 0
+        else:
+            CFG.LIMIT_MODEL_CONCURRENCY - model_semaphore._value + len(model_semaphore._waiters)
+
+    def generate_stream_gate(self, params):
+        try:
+            for output in self.generate_stream_func(
+                self.model,
+                self.tokenizer,
+                params,
+                DEVICE,
+                CFG.MAX_POSITION_EMBEDDINGS
+            ):
+                print("output: ", output)
+                ret = {
+                    "text": output,
+                    "error_code": 0,
+                }
+                yield json.dumps(ret).encode() + b"\0"
+        except torch.cuda.CudaError:
+            ret = {
+                "text": "**GPU OutOfMemory, Please Refresh.**",
+                "error_code": 0
+            }
+            yield json.dumps(ret).encode() + b"\0"
+
+    def get_embeddings(self, prompt):
+        return get_embeddings(self.model, self.tokenizer, prompt)
+
 
 app = FastAPI()

@@ -61,41 +106,17 @@ def release_model_semaphore():
     model_semaphore.release()
 
-def generate_stream_gate(params):
-    try:
-        for output in generate_stream(
-            model,
-            tokenizer,
-            params,
-            DEVICE,
-            CFG.MAX_POSITION_EMBEDDINGS,
-        ):
-            print("output: ", output)
-            ret = {
-                "text": output,
-                "error_code": 0,
-            }
-            yield json.dumps(ret).encode() + b"\0"
-    except torch.cuda.CudaError:
-        ret = {
-            "text": "**GPU OutOfMemory, Please Refresh.**",
-            "error_code": 0
-        }
-        yield json.dumps(ret).encode() + b"\0"
-
 @app.post("/generate_stream")
 async def api_generate_stream(request: Request):
     global model_semaphore, global_counter
     global_counter += 1
     params = await request.json()
-    print(model, tokenizer, params, DEVICE)
 
     if model_semaphore is None:
         model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
     await model_semaphore.acquire()
 
-    generator = generate_stream_gate(params)
+    generator = worker.generate_stream_gate(params)
     background_tasks = BackgroundTasks()
     background_tasks.add_task(release_model_semaphore)
     return StreamingResponse(generator, background=background_tasks)

@@ -111,7 +132,7 @@ def generate(prompt_request: PromptRequest):
     response = []
     rsp_str = ""
-    output = generate_stream_gate(params)
+    output = worker.generate_stream_gate(params)
     for rsp in output:
         # rsp = rsp.decode("utf-8")
         rsp_str = str(rsp, "utf-8")

@@ -125,9 +146,21 @@ def generate(prompt_request: PromptRequest):
 def embeddings(prompt_request: EmbeddingRequest):
     params = {"prompt": prompt_request.prompt}
     print("Received prompt: ", params["prompt"])
-    output = get_embeddings(model, tokenizer, params["prompt"])
+    output = worker.get_embeddings(params["prompt"])
     return {"response": [float(x) for x in output]}
 
 if __name__ == "__main__":
+    model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
+    print(model_path, DEVICE)
+
+    worker = ModelWorker(
+        model_path=model_path,
+        model_name=CFG.LLM_MODEL,
+        device=DEVICE,
+        num_gpus=1
+    )
     uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")


@@ -364,8 +364,16 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode())
+            """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
+            """
             if data["error_code"] == 0:
-                output = data["text"][skip_echo_len:].strip()
+                if "vicuna" in CFG.LLM_MODEL:
+                    output = data["text"][skip_echo_len:].strip()
+                else:
+                    output = data["text"].strip()
                 output = post_process_code(output)
                 state.messages[-1][-1] = output + ""
                 yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5


@@ -42,6 +42,7 @@ tenacity==8.2.2
 peft
 pycocoevalcap
 sentence-transformers
+cpm_kernels
 umap-learn
 notebook
 gradio==3.23