Merge branch 'llm_fxp' into dev

# Conflicts:
#	.env.template
#	pilot/out_parser/base.py
yhjun1026 2023-06-01 15:35:01 +08:00
commit 96c516ab55
11 changed files with 158 additions and 9 deletions

View File

@@ -101,8 +101,8 @@ LANGUAGE=en
 #*******************************************************************#
 # ** PROXY_SERVER
 #*******************************************************************#
-PROXY_API_KEY=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
-PROXY_SERVER_URL=http://127.0.0.1:3000/api/openai/v1/chat/completions
+PROXY_API_KEY=
+PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address
 #*******************************************************************#
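
For context (not part of the diff): the two settings above configure an OpenAI-compatible proxy endpoint. A minimal sketch of how they are typically consumed follows; the helper name proxy_chat and the model name are illustrative assumptions, not code from this commit.

import os
import requests

def proxy_chat(messages):
    # read the proxy settings configured via .env
    headers = {"Authorization": "Bearer " + os.getenv("PROXY_API_KEY", "")}
    payload = {"model": "gpt-3.5-turbo", "messages": messages}  # model name is an assumption
    resp = requests.post(
        os.getenv("PROXY_SERVER_URL"), headers=headers, json=payload, timeout=60
    )
    return resp.json()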

View File

@@ -260,7 +260,7 @@ Run the Python interpreter and type the commands:
 This is a complex and innovative tool for databases, and the project is under active development; new features will be released regularly. If you run into any specific problem while using it, please open an issue on the project first. If needed, reach out via the WeChat account below and I will do my best to help. Contributions to the project are also very welcome.
 <p align="center">
-  <img src="./assets/DB_GPT_wechat.png" width="320px" />
+  <img src="./assets/wechat.jpg" width="320px" />
 </p>
 ## Licence

Binary file not shown. (Before: 157 KiB)

Binary file not shown. (Before: 323 KiB, After: 168 KiB)

View File

@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
+    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "proxyllm": "proxyllm",
 }
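
A small usage note (not in the diff): the new entry is resolved like any other key in LLM_MODEL_CONFIG; the model name below is simply the key added above.

from pilot.configs.model_config import LLM_MODEL_CONFIG

model_name = "guanaco-33b-merged"           # e.g. selected via the LLM_MODEL setting
model_path = LLM_MODEL_CONFIG[model_name]   # resolves to <MODEL_PATH>/guanaco-33b-merged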

View File

@@ -3,7 +3,7 @@
 from functools import cache
 from typing import List
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
 from pilot.configs.model_config import DEVICE
@@ -86,7 +86,15 @@ class ChatGLMAdapater(BaseLLMAdaper):
 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""

-    pass
+    def match(self, model_path: str):
+        return "guanaco" in model_path
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        tokenizer = LlamaTokenizer.from_pretrained(model_path)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
+        )
+        return model, tokenizer

 class CodeGenAdapter(BaseLLMAdaper):
@@ -143,6 +151,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
+register_llm_model_adapters(GuanacoAdapter)
 # TODO Default support vicuna, other model need to tests and Evaluate
 # just for test, remove this later
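
As a usage note (not part of this diff), the registration calls above typically feed a lookup helper in the same module. The sketch below assumes the registry list is named llm_model_adapters and the lookup helper get_llm_model_adapter, following this module's naming; the model path is only an example.

def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
    # walk the adapters added by register_llm_model_adapters(...) and return
    # the first one whose match() accepts this model path
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")

adapter = get_llm_model_adapter("/data/models/guanaco-33b-merged")        # example path
model, tokenizer = adapter.loader("/data/models/guanaco-33b-merged", {})  # 4-bit load per GuanacoAdapter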

View File

@@ -0,0 +1,65 @@
import torch
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER


def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
    print(params)
    stop = params.get("stop", "###")
    messages = params["prompt"]

    hist = []
    for i in range(1, len(messages) - 2, 2):
        hist.append(
            (
                messages[i].split(ROLE_USER + ":")[1],
                messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
            )
        )

    query = messages[-2].split(ROLE_USER + ":")[1]
    print("Query Message: ", query)

    input_ids = tokenizer(query, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    stop_token_ids = [0]

    class StopOnTokens(StoppingCriteria):
        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            for stop_id in stop_token_ids:
                if input_ids[0][-1] == stop_id:
                    return True
            return False

    stop = StopOnTokens()

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=512,
        temperature=1.0,
        do_sample=True,
        top_k=1,
        streamer=streamer,
        repetition_penalty=1.7,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    t1 = Thread(target=model.generate, kwargs=generate_kwargs)
    t1.start()

    generator = model.generate(**generate_kwargs)
    for output in generator:
        # new_tokens = len(output) - len(input_ids[0])
        decoded_output = tokenizer.decode(output)
        if output[-1] in [tokenizer.eos_token_id]:
            break

        out = decoded_output.split("### Response:")[-1].strip()
        yield out
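
An illustrative call only (not in the commit): the `params` shape that guanaco_generate_output expects, inferred from the parsing above (a leading system line, role-prefixed turns, and a trailing assistant stub); `model` and `tokenizer` are assumed to be already loaded.

from pilot.conversation import ROLE_ASSISTANT, ROLE_USER

params = {
    "prompt": [
        "A chat between a curious user and an assistant.",  # messages[0] is skipped
        ROLE_USER + ": What is DB-GPT?",                     # messages[-2] becomes the query
        ROLE_ASSISTANT + ":",                                # trailing assistant marker
    ],
    "stop": "###",
}

for partial_output in guanaco_generate_output(model, tokenizer, params, device="cuda"):
    print(partial_output)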

View File

@@ -1,6 +1,11 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+import traceback
+from queue import Queue
+from threading import Thread
+
+import transformers
 from typing import List, Optional
 from pilot.configs.config import Config
@@ -47,3 +52,66 @@ def create_chat_completion(
     response = None
     # TODO impl this use vicuna server api
+
+
+class Stream(transformers.StoppingCriteria):
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, input_ids, scores) -> bool:
+        if self.callback_func is not None:
+            self.callback_func(input_ids[0])
+        return False
+
+
+class Iteratorize:
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+    """
+
+    def __init__(self, func, kwargs={}, callback=None):
+        self.mfunc = func
+        self.c_callback = callback
+        self.q = Queue()
+        self.sentinel = object()
+        self.kwargs = kwargs
+        self.stop_now = False
+
+        def _callback(val):
+            if self.stop_now:
+                raise ValueError
+            self.q.put(val)
+
+        def gentask():
+            try:
+                ret = self.mfunc(callback=_callback, **self.kwargs)
+            except ValueError:
+                pass
+            except:
+                traceback.print_exc()
+                pass
+
+            self.q.put(self.sentinel)
+            if self.c_callback:
+                self.c_callback(ret)
+
+        self.thread = Thread(target=gentask)
+        self.thread.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        obj = self.q.get(True, None)
+        if obj is self.sentinel:
+            raise StopIteration
+        else:
+            return obj
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop_now = True
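
A minimal usage sketch for the two helpers above (this pattern is borrowed from text-generation-webui, not from this diff); `model`, `tokenizer` and `input_ids` are assumed to be an already-loaded HF model, its tokenizer, and an encoded prompt.

import torch

def generate_with_callback(callback=None, **kwargs):
    # Stream() reports each step of generated ids back through `callback`
    kwargs["stopping_criteria"] = transformers.StoppingCriteriaList(
        [Stream(callback_func=callback)]
    )
    with torch.no_grad():
        model.generate(**kwargs)

with Iteratorize(
    generate_with_callback, kwargs={"input_ids": input_ids, "max_new_tokens": 256}
) as generator:
    for output_ids in generator:
        print(tokenizer.decode(output_ids, skip_special_tokens=True))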

View File

@@ -113,7 +113,11 @@ class ModelLoader(metaclass=Singleton):
             or self.device == "mps"
             and tokenizer
         ):
-            model.to(self.device)
+            # 4-bit quantized models do not support this
+            try:
+                model.to(self.device)
+            except ValueError:
+                pass

         if debug:
             print(model)
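
Background for the guard above: a model loaded with load_in_4bit=True is already placed on devices through device_map, and transformers rejects a later .to() call with a ValueError, which is exactly what the try/except swallows. A minimal sketch, assuming bitsandbytes and a CUDA device are available; the local path is hypothetical.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "/data/models/guanaco-33b-merged",   # hypothetical path
    load_in_4bit=True,                   # requires the bitsandbytes package
    device_map={"": 0},                  # weights are already placed on GPU 0
)
try:
    model.to("cuda")                     # raises ValueError for 4-bit/8-bit models
except ValueError:
    pass                                 # already on the right device; nothing to do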

View File

@@ -108,7 +108,7 @@ class BaseOutputParser(ABC):
         if not self.is_stream_out:
             return self._parse_model_nostream_resp(response, self.sep)
         else:
-            return self._parse_model_stream_resp(response, self.sep, skip_echo_len)
+            return self._parse_model_stream_resp(response, self.sep)

     def parse_prompt_response(self, model_out_text) -> T:
         """

View File

@@ -91,8 +91,9 @@ class GuanacoChatAdapter(BaseChatAdpter):
         return "guanaco" in model_path

     def get_generate_stream_func(self):
-        # TODO
-        pass
+        from pilot.model.guanaco_llm import guanaco_generate_output
+
+        return guanaco_generate_output

 class ProxyllmChatAdapter(BaseChatAdpter):
@@ -107,6 +108,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
+register_llm_model_chat_adapter(GuanacoChatAdapter)
 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
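
For completeness, a hedged sketch of how the chat-side registration above is typically consumed, mirroring the model-adapter lookup; the helper name get_llm_chat_adapter is assumed from this module family's naming, and model, tokenizer, and params are placeholders.

chat_adapter = get_llm_chat_adapter("/data/models/guanaco-33b-merged")   # example path
generate_stream_func = chat_adapter.get_generate_stream_func()           # -> guanaco_generate_output
for output in generate_stream_func(model, tokenizer, params, "cuda"):
    print(output)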