feature: add guanaco support #121

This commit is contained in:
csunny 2023-05-30 23:02:46 +08:00
parent 16c6986666
commit b0e22eff05
5 changed files with 163 additions and 7 deletions

View File

@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
"proxyllm": "proxyllm",
}

View File

@ -3,7 +3,7 @@
from functools import cache
from typing import List
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from pilot.configs.model_config import DEVICE
@ -85,8 +85,15 @@ class ChatGLMAdapater(BaseLLMAdaper):
class GuanacoAdapter(BaseLLMAdaper):
"""TODO Support guanaco"""
def match(self, model_path: str):
return "guanaco" in model_path
pass
def loader(self, model_path: str, from_pretrained_kwargs: dict):
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
)
return model, tokenizer
class CodeGenAdapter(BaseLLMAdaper):
@ -143,6 +150,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
# TODO Default support vicuna, other model need to tests and Evaluate
# just for test, remove this later

View File

@ -0,0 +1,78 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import transformers
from transformers import GenerationConfig
from llm_utils import Iteratorize, Stream
def guanaco_generate_output(model, tokenizer, params, device):
"""Fork from fastchat: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
prompt = params["prompt"]
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
temperature=0.5,
top_p=0.95,
top_k=45,
max_new_tokens=128,
stream_output=True
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
generate_params = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"max_new_tokens": max_new_tokens,
}
if stream_output:
# Stream the reply 1 token at a time.
# This is based on the trick of using 'stopping_criteria' to create an iterator,
# from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
def generate_with_callback(callback=None, **kwargs):
kwargs.setdefault(
"stopping_criteria", transformers.StoppingCriteriaList()
)
kwargs["stopping_criteria"].append(
Stream(callback_func=callback)
)
with torch.no_grad():
model.generate(**kwargs)
def generate_with_streaming(**kwargs):
return Iteratorize(
generate_with_callback, kwargs, callback=None
)
with generate_with_streaming(**generate_params) as generator:
for output in generator:
# new_tokens = len(output) - len(input_ids[0])
decoded_output = tokenizer.decode(output)
if output[-1] in [tokenizer.eos_token_id]:
break
yield decoded_output.split("### Response:")[-1].strip()
return # early return for stream_output
with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens,
)
s = generation_output.sequences[0]
print(f"debug_sequences,{s}",s)
output = tokenizer.decode(s)
print(f"debug_output,{output}",output)
yield output.split("### Response:")[-1].strip()

View File

@ -1,6 +1,11 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import traceback
from queue import Queue
from threading import Thread
import transformers
from typing import List, Optional
from pilot.configs.config import Config
@ -47,3 +52,65 @@ def create_chat_completion(
response = None
# TODO impl this use vicuna server api
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}, callback=None):
self.mfunc = func
self.c_callback = callback
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
self.stop_now = False
def _callback(val):
if self.stop_now:
raise ValueError
self.q.put(val)
def gentask():
try:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
except:
traceback.print_exc()
pass
self.q.put(self.sentinel)
if self.c_callback:
self.c_callback(ret)
self.thread = Thread(target=gentask)
self.thread.start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True, None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_now = True

View File

@ -85,14 +85,15 @@ class CodeGenChatAdapter(BaseChatAdpter):
class GuanacoChatAdapter(BaseChatAdpter):
"""Model chat adapter for Guanaco"""
"""Model chat adapter for Guanaco """
def match(self, model_path: str):
return "guanaco" in model_path
def get_generate_stream_func(self):
# TODO
pass
from pilot.model.guanaco_llm import guanaco_generate_output
return guanaco_generate_output
class ProxyllmChatAdapter(BaseChatAdpter):
@ -107,6 +108,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
# Proxy model for test and develop, it's cheap for us now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)