feature: add guanaco support #121

This commit is contained in:
csunny 2023-05-30 23:02:46 +08:00
parent 16c6986666
commit b0e22eff05
5 changed files with 163 additions and 7 deletions

View File

@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"), "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"), "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"), "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
"proxyllm": "proxyllm", "proxyllm": "proxyllm",
} }
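
For reference, a minimal sketch of how this new entry would be resolved downstream, in the same way the other keys in the dict are consumed (illustrative only; this snippet is not part of the diff):

    from pilot.configs.model_config import LLM_MODEL_CONFIG

    # Illustrative lookup of the on-disk path registered for the new model key.
    model_path = LLM_MODEL_CONFIG["guanaco-33b-merged"]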

View File

@@ -3,7 +3,7 @@
from functools import cache
from typing import List

from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer

from pilot.configs.model_config import DEVICE
@@ -85,8 +85,15 @@ class ChatGLMAdapater(BaseLLMAdaper):
class GuanacoAdapter(BaseLLMAdaper):
    """TODO Support guanaco"""

    def match(self, model_path: str):
        return "guanaco" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        # Guanaco's merged weights use the Llama tokenizer; load the model
        # in 4-bit and pin it to the first GPU.
        tokenizer = LlamaTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
        )
        return model, tokenizer
class CodeGenAdapter(BaseLLMAdaper):
@@ -143,6 +150,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
# TODO: vicuna is supported by default; other models still need testing and evaluation.
# just for testing; remove this later
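
To show how the registration above takes effect, a hedged sketch of adapter dispatch; the get_llm_model_adapter helper and the llm_model_adapters registry name are assumptions, since this diff only shows the register calls:

    # Hypothetical dispatch loop (names assumed, not shown in this diff).
    def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
        for adapter in llm_model_adapters:  # populated by register_llm_model_adapters
            if adapter.match(model_path):
                return adapter
        raise ValueError(f"Invalid model adapter for {model_path}")

    adapter = get_llm_model_adapter("/models/guanaco-33b-merged")
    model, tokenizer = adapter.loader("/models/guanaco-33b-merged", {})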

View File

@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import transformers
from transformers import GenerationConfig

from llm_utils import Iteratorize, Stream


def guanaco_generate_output(model, tokenizer, params, device):
    """Fork from guanaco-lora: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
    prompt = params["prompt"]
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    # Fixed sampling defaults for guanaco generation.
    temperature = 0.5
    top_p = 0.95
    top_k = 45
    max_new_tokens = 128
    stream_output = True

    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )

    generate_params = {
        "input_ids": input_ids,
        "generation_config": generation_config,
        "return_dict_in_generate": True,
        "output_scores": True,
        "max_new_tokens": max_new_tokens,
    }
    if stream_output:
        # Stream the reply 1 token at a time.
        # This is based on the trick of using 'stopping_criteria' to create an iterator,
        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.

        def generate_with_callback(callback=None, **kwargs):
            kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList())
            kwargs["stopping_criteria"].append(Stream(callback_func=callback))
            with torch.no_grad():
                model.generate(**kwargs)

        def generate_with_streaming(**kwargs):
            return Iteratorize(generate_with_callback, kwargs, callback=None)

        with generate_with_streaming(**generate_params) as generator:
            for output in generator:
                # new_tokens = len(output) - len(input_ids[0])
                decoded_output = tokenizer.decode(output)
                if output[-1] in [tokenizer.eos_token_id]:
                    break
                yield decoded_output.split("### Response:")[-1].strip()
        return  # early return for stream_output
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )

    s = generation_output.sequences[0]
    print(f"debug_sequences: {s}")
    output = tokenizer.decode(s)
    print(f"debug_output: {output}")
    yield output.split("### Response:")[-1].strip()
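
A minimal usage sketch for the generator above, assuming model and tokenizer come from GuanacoAdapter.loader and the prompt ends with the "### Response:" marker the code splits on (the exact prompt template is an assumption):

    params = {"prompt": "### Human: Hello, who are you?\n### Response:"}
    for partial in guanaco_generate_output(model, tokenizer, params, "cuda"):
        print(partial)  # each iteration yields the decoded reply so far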

View File

@@ -1,6 +1,11 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import traceback
from queue import Queue
from threading import Thread

import transformers

from typing import List, Optional

from pilot.configs.config import Config
@@ -47,3 +52,65 @@ def create_chat_completion(
    response = None
    # TODO: implement this using the vicuna server API
class Stream(transformers.StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        # Called by model.generate() after each step: forward the tokens
        # generated so far to the callback, and never stop generation itself.
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False
class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).
    """

    def __init__(self, func, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs if kwargs is not None else {}
        self.stop_now = False

        def _callback(val):
            # Raise inside the producer thread to abort generation once the
            # consumer has exited the context manager.
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gentask():
            ret = None
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                pass  # expected when the consumer stops early
            except Exception:
                traceback.print_exc()

            self.q.put(self.sentinel)
            if self.c_callback:
                self.c_callback(ret)

        self.thread = Thread(target=gentask)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True
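
To make Iteratorize's control flow concrete, a self-contained toy example (the countdown producer is purely illustrative, not part of this commit):

    def countdown(callback=None, n=3):
        # A producer that reports results through a callback instead of returning them.
        for i in range(n, 0, -1):
            callback(i)

    with Iteratorize(countdown, {"n": 3}) as it:
        for value in it:
            print(value)  # prints 3, 2, 1 as the background thread produces them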

View File

@@ -91,8 +91,9 @@ class GuanacoChatAdapter(BaseChatAdpter):
return "guanaco" in model_path return "guanaco" in model_path
def get_generate_stream_func(self): def get_generate_stream_func(self):
# TODO from pilot.model.guanaco_llm import guanaco_generate_output
pass
return guanaco_generate_output
class ProxyllmChatAdapter(BaseChatAdpter):
@@ -107,6 +108,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
# Proxy model for testing and development; it's cheap for us for now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)
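
Putting the pieces together, a hedged sketch of how the chat layer would obtain and drive the stream function; the get_llm_chat_adapter lookup and the serving loop are assumptions, as only the adapter itself appears in this diff:

    chat_adapter = get_llm_chat_adapter("/models/guanaco-33b-merged")  # hypothetical lookup
    generate_stream = chat_adapter.get_generate_stream_func()  # -> guanaco_generate_output
    for partial_reply in generate_stream(model, tokenizer, {"prompt": prompt}, "cuda"):
        ...  # push each partial reply to the client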