mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-06 10:54:29 +00:00
Merge branch 'llm_fxp' into ty_test
This commit is contained in:
commit
85176111b9
@ -21,7 +21,7 @@ LLM_MODEL=vicuna-13b
|
|||||||
MODEL_SERVER=http://127.0.0.1:8000
|
MODEL_SERVER=http://127.0.0.1:8000
|
||||||
LIMIT_MODEL_CONCURRENCY=5
|
LIMIT_MODEL_CONCURRENCY=5
|
||||||
MAX_POSITION_EMBEDDINGS=4096
|
MAX_POSITION_EMBEDDINGS=4096
|
||||||
|
QUANTIZE_QLORA=True
|
||||||
## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
|
## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
|
||||||
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
|
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
|
||||||
# SMART_LLM_MODEL=vicuna-13b
|
# SMART_LLM_MODEL=vicuna-13b
|
||||||
|
@ -147,6 +147,9 @@ class Config(metaclass=Singleton):
|
|||||||
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
|
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
|
||||||
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
|
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
|
||||||
|
|
||||||
|
# QLoRA
|
||||||
|
self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
|
||||||
|
|
||||||
### EMBEDDING Configuration
|
### EMBEDDING Configuration
|
||||||
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
|
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
|
||||||
self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500))
|
self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500))
|
||||||
|
@ -35,6 +35,8 @@ LLM_MODEL_CONFIG = {
|
|||||||
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
|
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
|
||||||
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
|
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
|
||||||
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
|
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
|
||||||
|
"falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
|
||||||
|
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
|
||||||
"proxyllm": "proxyllm",
|
"proxyllm": "proxyllm",
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,7 +44,6 @@ LLM_MODEL_CONFIG = {
|
|||||||
ISLOAD_8BIT = True
|
ISLOAD_8BIT = True
|
||||||
ISDEBUG = False
|
ISDEBUG = False
|
||||||
|
|
||||||
|
|
||||||
VECTOR_SEARCH_TOP_K = 10
|
VECTOR_SEARCH_TOP_K = 10
|
||||||
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
|
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(
|
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(
|
||||||
|
@ -1,11 +1,26 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from functools import cache
|
|
||||||
|
import torch
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from functools import cache
|
||||||
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
|
from transformers import (
|
||||||
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
LlamaTokenizer,
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
)
|
||||||
from pilot.configs.model_config import DEVICE
|
from pilot.configs.model_config import DEVICE
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
|
bnb_config = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
bnb_4bit_compute_dtype="bfloat16",
|
||||||
|
bnb_4bit_use_double_quant=False,
|
||||||
|
)
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
class BaseLLMAdaper:
|
class BaseLLMAdaper:
|
||||||
@ -97,16 +112,44 @@ class GuanacoAdapter(BaseLLMAdaper):
|
|||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
class GuanacoAdapter(BaseLLMAdaper):
|
class FalconAdapater(BaseLLMAdaper):
|
||||||
|
"""falcon Adapter"""
|
||||||
|
|
||||||
|
def match(self, model_path: str):
|
||||||
|
return "falcon" in model_path
|
||||||
|
|
||||||
|
def loader(self, model_path: str, from_pretrained_kwagrs: dict):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
||||||
|
|
||||||
|
if CFG.QLoRA:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
load_in_4bit=True, # quantize
|
||||||
|
quantization_config=bnb_config,
|
||||||
|
device_map={"": 0},
|
||||||
|
trust_remote_code=True,
|
||||||
|
**from_pretrained_kwagrs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
trust_remote_code=True,
|
||||||
|
device_map={"": 0},
|
||||||
|
**from_pretrained_kwagrs,
|
||||||
|
)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class GorillaAdapter(BaseLLMAdaper):
|
||||||
"""TODO Support guanaco"""
|
"""TODO Support guanaco"""
|
||||||
|
|
||||||
def match(self, model_path: str):
|
def match(self, model_path: str):
|
||||||
return "guanaco" in model_path
|
return "gorilla" in model_path
|
||||||
|
|
||||||
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(model_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
|
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
|
||||||
)
|
)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
@ -166,6 +209,8 @@ class ProxyllmAdapter(BaseLLMAdaper):
|
|||||||
register_llm_model_adapters(VicunaLLMAdapater)
|
register_llm_model_adapters(VicunaLLMAdapater)
|
||||||
register_llm_model_adapters(ChatGLMAdapater)
|
register_llm_model_adapters(ChatGLMAdapater)
|
||||||
register_llm_model_adapters(GuanacoAdapter)
|
register_llm_model_adapters(GuanacoAdapter)
|
||||||
|
register_llm_model_adapters(FalconAdapater)
|
||||||
|
register_llm_model_adapters(GorillaAdapter)
|
||||||
# TODO Default support vicuna, other model need to tests and Evaluate
|
# TODO Default support vicuna, other model need to tests and Evaluate
|
||||||
|
|
||||||
# just for test_py, remove this later
|
# just for test_py, remove this later
|
||||||
|
@ -51,7 +51,10 @@ def chatglm_generate_stream(
|
|||||||
# else:
|
# else:
|
||||||
# once_conversation.append(f"""###system:{message} """)
|
# once_conversation.append(f"""###system:{message} """)
|
||||||
|
|
||||||
|
try:
|
||||||
query = messages[-2].split("human:")[1]
|
query = messages[-2].split("human:")[1]
|
||||||
|
except IndexError:
|
||||||
|
query = messages[-3].split("human:")[1]
|
||||||
print("Query Message: ", query)
|
print("Query Message: ", query)
|
||||||
# output = ""
|
# output = ""
|
||||||
# i = 0
|
# i = 0
|
||||||
|
54
pilot/model/llm_out/falcon_llm.py
Normal file
54
pilot/model/llm_out/falcon_llm.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import torch
|
||||||
|
import copy
|
||||||
|
from threading import Thread
|
||||||
|
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
|
||||||
|
|
||||||
|
|
||||||
|
def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||||
|
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
|
||||||
|
tokenizer.bos_token_id = 1
|
||||||
|
print(params)
|
||||||
|
stop = params.get("stop", "###")
|
||||||
|
prompt = params["prompt"]
|
||||||
|
query = prompt
|
||||||
|
print("Query Message: ", query)
|
||||||
|
|
||||||
|
input_ids = tokenizer(query, return_tensors="pt").input_ids
|
||||||
|
input_ids = input_ids.to(model.device)
|
||||||
|
|
||||||
|
streamer = TextIteratorStreamer(
|
||||||
|
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer.bos_token_id = 1
|
||||||
|
stop_token_ids = [0]
|
||||||
|
|
||||||
|
class StopOnTokens(StoppingCriteria):
|
||||||
|
def __call__(
|
||||||
|
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||||
|
) -> bool:
|
||||||
|
for stop_id in stop_token_ids:
|
||||||
|
if input_ids[0][-1] == stop_id:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
stop = StopOnTokens()
|
||||||
|
|
||||||
|
generate_kwargs = dict(
|
||||||
|
input_ids=input_ids,
|
||||||
|
max_new_tokens=512,
|
||||||
|
temperature=1.0,
|
||||||
|
do_sample=True,
|
||||||
|
top_k=1,
|
||||||
|
streamer=streamer,
|
||||||
|
repetition_penalty=1.7,
|
||||||
|
stopping_criteria=StoppingCriteriaList([stop]),
|
||||||
|
)
|
||||||
|
|
||||||
|
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
out = ""
|
||||||
|
for new_text in streamer:
|
||||||
|
out += new_text
|
||||||
|
yield out
|
62
pilot/model/llm_out/gorilla_llm.py
Normal file
62
pilot/model/llm_out/gorilla_llm.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def generate_stream(
|
||||||
|
model, tokenizer, params, device, context_len=42048, stream_interval=2
|
||||||
|
):
|
||||||
|
"""Fork from https://github.com/ShishirPatil/gorilla/blob/main/inference/serve/gorilla_cli.py"""
|
||||||
|
prompt = params["prompt"]
|
||||||
|
l_prompt = len(prompt)
|
||||||
|
max_new_tokens = int(params.get("max_new_tokens", 1024))
|
||||||
|
stop_str = params.get("stop", None)
|
||||||
|
|
||||||
|
input_ids = tokenizer(prompt).input_ids
|
||||||
|
output_ids = list(input_ids)
|
||||||
|
input_echo_len = len(input_ids)
|
||||||
|
max_src_len = context_len - max_new_tokens - 8
|
||||||
|
input_ids = input_ids[-max_src_len:]
|
||||||
|
past_key_values = out = None
|
||||||
|
|
||||||
|
for i in range(max_new_tokens):
|
||||||
|
if i == 0:
|
||||||
|
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
|
||||||
|
logits = out.logits
|
||||||
|
past_key_values = out.past_key_values
|
||||||
|
else:
|
||||||
|
out = model(
|
||||||
|
input_ids=torch.as_tensor([[token]], device=device),
|
||||||
|
use_cache=True,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
logits = out.logits
|
||||||
|
past_key_values = out.past_key_values
|
||||||
|
|
||||||
|
last_token_logits = logits[0][-1]
|
||||||
|
|
||||||
|
probs = torch.softmax(last_token_logits, dim=-1)
|
||||||
|
token = int(torch.multinomial(probs, num_samples=1))
|
||||||
|
output_ids.append(token)
|
||||||
|
|
||||||
|
if token == tokenizer.eos_token_id:
|
||||||
|
stopped = True
|
||||||
|
else:
|
||||||
|
stopped = False
|
||||||
|
|
||||||
|
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
|
||||||
|
tmp_output_ids = output_ids[input_echo_len:]
|
||||||
|
output = tokenizer.decode(
|
||||||
|
tmp_output_ids,
|
||||||
|
skip_special_tokens=True,
|
||||||
|
spaces_between_special_tokens=False,
|
||||||
|
)
|
||||||
|
pos = output.rfind(stop_str, l_prompt)
|
||||||
|
if pos != -1:
|
||||||
|
output = output[:pos]
|
||||||
|
stopped = True
|
||||||
|
yield output
|
||||||
|
|
||||||
|
if stopped:
|
||||||
|
break
|
||||||
|
|
||||||
|
del past_key_values
|
@ -76,7 +76,13 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
|
|||||||
for line in res.iter_lines():
|
for line in res.iter_lines():
|
||||||
if line:
|
if line:
|
||||||
decoded_line = line.decode("utf-8")
|
decoded_line = line.decode("utf-8")
|
||||||
|
try:
|
||||||
json_line = json.loads(decoded_line)
|
json_line = json.loads(decoded_line)
|
||||||
print(json_line)
|
print(json_line)
|
||||||
text += json_line["choices"][0]["message"]["content"]
|
text += json_line["choices"][0]["message"]["content"]
|
||||||
yield text
|
yield text
|
||||||
|
except Exception as e:
|
||||||
|
text += decoded_line
|
||||||
|
yield json.loads(text)["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,6 +61,8 @@ class BaseOutputParser(ABC):
|
|||||||
|
|
||||||
# stream out output
|
# stream out output
|
||||||
output = data["text"][11:].replace("<s>", "").strip()
|
output = data["text"][11:].replace("<s>", "").strip()
|
||||||
|
|
||||||
|
# TODO gorilla and falcon output
|
||||||
else:
|
else:
|
||||||
output = data["text"].strip()
|
output = data["text"].strip()
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pilot.model.llm_out.vicuna_base_llm import generate_stream
|
from pilot.model.llm_out.vicuna_base_llm import generate_stream
|
||||||
|
|
||||||
|
|
||||||
@ -96,6 +95,18 @@ class GuanacoChatAdapter(BaseChatAdpter):
|
|||||||
return guanaco_generate_stream
|
return guanaco_generate_stream
|
||||||
|
|
||||||
|
|
||||||
|
class FalconChatAdapter(BaseChatAdpter):
|
||||||
|
"""Model chat adapter for Guanaco"""
|
||||||
|
|
||||||
|
def match(self, model_path: str):
|
||||||
|
return "falcon" in model_path
|
||||||
|
|
||||||
|
def get_generate_stream_func(self):
|
||||||
|
from pilot.model.llm_out.falcon_llm import falcon_generate_output
|
||||||
|
|
||||||
|
return falcon_generate_output
|
||||||
|
|
||||||
|
|
||||||
class ProxyllmChatAdapter(BaseChatAdpter):
|
class ProxyllmChatAdapter(BaseChatAdpter):
|
||||||
def match(self, model_path: str):
|
def match(self, model_path: str):
|
||||||
return "proxyllm" in model_path
|
return "proxyllm" in model_path
|
||||||
@ -106,10 +117,21 @@ class ProxyllmChatAdapter(BaseChatAdpter):
|
|||||||
return proxyllm_generate_stream
|
return proxyllm_generate_stream
|
||||||
|
|
||||||
|
|
||||||
|
class GorillaChatAdapter(BaseChatAdpter):
|
||||||
|
def match(self, model_path: str):
|
||||||
|
return "gorilla" in model_path
|
||||||
|
|
||||||
|
def get_generate_stream_func(self):
|
||||||
|
from pilot.model.llm_out.gorilla_llm import generate_stream
|
||||||
|
|
||||||
|
return generate_stream
|
||||||
|
|
||||||
|
|
||||||
register_llm_model_chat_adapter(VicunaChatAdapter)
|
register_llm_model_chat_adapter(VicunaChatAdapter)
|
||||||
register_llm_model_chat_adapter(ChatGLMChatAdapter)
|
register_llm_model_chat_adapter(ChatGLMChatAdapter)
|
||||||
register_llm_model_chat_adapter(GuanacoChatAdapter)
|
register_llm_model_chat_adapter(GuanacoChatAdapter)
|
||||||
|
register_llm_model_chat_adapter(FalconChatAdapter)
|
||||||
|
register_llm_model_chat_adapter(GorillaChatAdapter)
|
||||||
|
|
||||||
# Proxy model for test and develop, it's cheap for us now.
|
# Proxy model for test and develop, it's cheap for us now.
|
||||||
register_llm_model_chat_adapter(ProxyllmChatAdapter)
|
register_llm_model_chat_adapter(ProxyllmChatAdapter)
|
||||||
|
@ -28,7 +28,9 @@ class PDFEmbedding(SourceEmbedding):
|
|||||||
# textsplitter = CHNDocumentSplitter(
|
# textsplitter = CHNDocumentSplitter(
|
||||||
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
||||||
# )
|
# )
|
||||||
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=1000, chunk_overlap=200)
|
textsplitter = SpacyTextSplitter(
|
||||||
|
pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
|
||||||
|
)
|
||||||
return loader.load_and_split(textsplitter)
|
return loader.load_and_split(textsplitter)
|
||||||
|
|
||||||
@register
|
@register
|
||||||
|
Loading…
Reference in New Issue
Block a user