feature: gorilla support (#173)

csunny 2023-06-08 21:33:22 +08:00
parent bea3f7c8d2
commit 716460460f
7 changed files with 44 additions and 26 deletions

View File

@@ -146,7 +146,7 @@ class Config(metaclass=Singleton):
        self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
        self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)

        # QLoRA
        self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")

        ### EMBEDDING Configuration
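A note on the flag above: `os.getenv` returns strings, so `QUANTIZE_QLORA` arrives as the text "True"/"False" rather than a boolean, and any non-empty value is truthy in the `if CFG.QLoRA:` check used by the Falcon adapter further down. A minimal, purely illustrative way to parse it explicitly (the accepted spellings are an assumption, not part of this commit):

```python
import os

# os.getenv returns a string; even "False" would be truthy in `if CFG.QLoRA:`,
# so parse the value explicitly. The accepted spellings here are illustrative.
use_qlora = os.getenv("QUANTIZE_QLORA", "True").lower() in ("true", "1", "yes")
```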

View File

@@ -4,13 +4,25 @@
import torch
from typing import List
from functools import cache

-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BitsAndBytesConfig
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaTokenizer,
+    BitsAndBytesConfig,
+)

from pilot.configs.model_config import DEVICE
from pilot.configs.config import Config

-bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16", bnb_4bit_use_double_quant=False)
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype="bfloat16",
+    bnb_4bit_use_double_quant=False,
+)

CFG = Config()


class BaseLLMAdaper:
    """The Base class for multi model, in our project.
    We will support those model, which performance resemble ChatGPT"""
@@ -98,8 +110,8 @@ class GuanacoAdapter(BaseLLMAdaper):
            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
        )
        return model, tokenizer


class FalconAdapater(BaseLLMAdaper):
    """falcon Adapter"""
@@ -111,23 +123,23 @@ class FalconAdapater(BaseLLMAdaper):
        if CFG.QLoRA:
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
-                load_in_4bit=True, #quantize
+                load_in_4bit=True,  # quantize
                quantization_config=bnb_config,
                device_map={"": 0},
                trust_remote_code=True,
-                **from_pretrained_kwagrs
+                **from_pretrained_kwagrs,
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                trust_remote_code=True,
                device_map={"": 0},
-                **from_pretrained_kwagrs
+                **from_pretrained_kwagrs,
            )
        return model, tokenizer


class GorillaAdapter(BaseLLMAdaper):
    """TODO Support guanaco"""

View File

@@ -19,7 +19,7 @@ def falcon_generate_output(model, tokenizer, params, device, context_len=2048):
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    tokenizer.bos_token_id = 1
    stop_token_ids = [0]
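For readers unfamiliar with the streamer used by `falcon_generate_output`: `TextIteratorStreamer` yields decoded text while `model.generate` runs in a worker thread. A self-contained sketch of that standard pattern; the model id, prompt, and generation settings are illustrative, not taken from this commit:

```python
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "tiiuae/falcon-7b-instruct"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

streamer = TextIteratorStreamer(
    tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
inputs = tokenizer("Hello, Falcon!", return_tensors="pt").to(model.device)

# generate() blocks, so it runs in a thread while the streamer is consumed here.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=128),
)
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()
```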

View File

@@ -1,5 +1,6 @@
import torch


@torch.inference_mode()
def generate_stream(
    model, tokenizer, params, device, context_len=42048, stream_interval=2
@@ -22,7 +23,7 @@ def generate_stream(
            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
@@ -37,7 +38,6 @@ def generate_stream(
            token = int(torch.multinomial(probs, num_samples=1))

        output_ids.append(token)
        if token == tokenizer.eos_token_id:
            stopped = True
        else:
@@ -45,7 +45,11 @@ def generate_stream(
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            tmp_output_ids = output_ids[input_echo_len:]
-            output = tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False,)
+            output = tokenizer.decode(
+                tmp_output_ids,
+                skip_special_tokens=True,
+                spaces_between_special_tokens=False,
+            )
            pos = output.rfind(stop_str, l_prompt)
            if pos != -1:
                output = output[:pos]
@@ -55,4 +59,4 @@ def generate_stream(
        if stopped:
            break

    del past_key_values
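One part of the loop above that is easy to misread is the stop-string handling: `output.rfind(stop_str, l_prompt)` returns the highest index of `stop_str` at or after the given start position (or -1), and the output is then truncated at that index. A tiny illustrative trace with made-up values (the strings are not from this code):

```python
# Demonstrates str.rfind(sub, start) plus truncation as used above.
output = "USER: hi ASSISTANT: hello USER:"
stop_str = "USER:"
start = 9                                # e.g. l_prompt: skip this many leading chars

pos = output.rfind(stop_str, start)      # highest match at index >= start, else -1
if pos != -1:
    output = output[:pos]
print(output)                            # -> "USER: hi ASSISTANT: hello "
```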

View File

@@ -62,7 +62,7 @@ class BaseOutputParser(ABC):
            # stream out output
            output = data["text"][11:].replace("<s>", "").strip()

        # TODO gorilla and falcon output
        else:
            output = data["text"].strip()

View File

@@ -94,7 +94,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
        return guanaco_generate_stream


class FalconChatAdapter(BaseChatAdpter):
    """Model chat adapter for Guanaco"""
@@ -105,7 +105,8 @@ class FalconChatAdapter(BaseChatAdpter):
        from pilot.model.llm_out.falcon_llm import falcon_generate_output

        return falcon_generate_output


class ProxyllmChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "proxyllm" in model_path
@@ -116,8 +117,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
        return proxyllm_generate_stream


class GorillaChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "gorilla" in model_path

View File

@@ -28,7 +28,9 @@ class PDFEmbedding(SourceEmbedding):
        # textsplitter = CHNDocumentSplitter(
        #     pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
        # )
-        textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=1000, chunk_overlap=200)
+        textsplitter = SpacyTextSplitter(
+            pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
+        )
        return loader.load_and_split(textsplitter)

    @register
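A small operational note on the splitter configured above, assuming the langchain `SpacyTextSplitter` that this module appears to use: the `zh_core_web_sm` pipeline is a separate spaCy download (`python -m spacy download zh_core_web_sm`), and the chunking can be sanity-checked outside the embedding flow. Sketch only; the sample text is made up:

```python
from langchain.text_splitter import SpacyTextSplitter

# Requires the spaCy Chinese pipeline: python -m spacy download zh_core_web_sm
textsplitter = SpacyTextSplitter(
    pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
)
# Illustrative Chinese sample ("A database is a system for managing data."), repeated.
sample = "数据库是管理数据的系统。" * 300
chunks = textsplitter.split_text(sample)
print(len(chunks), max(len(c) for c in chunks))
```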