[chat] fix bugs and add unit tests (#4213)

* style: rename replay buffer

Experience replay is typically for off policy algorithms.
Use this name in PPO maybe misleading.

* fix: fix wrong zero2 default arg

* test: update experience tests

* style: rename zero_pad fn

* fix: defer init in CycledDataLoader

* test: add benchmark test

* style: rename internal fn of generation

* style: rename internal fn of lora

* fix: remove unused loss fn

* fix: remove unused utils fn

* refactor: remove generate_with_actor fn

* fix: fix type annotation

* test: add models tests

* fix: skip llama due to long execution time

* style: modify dataset

* style: apply formatter

* perf: update reward dataset

* fix: fix wrong IGNORE_INDEX in sft dataset

* fix: remove DataCollatorForSupervisedDataset

* test: add dataset tests

* style: apply formatter

* style: rename test_ci to test_train

* feat: add llama in inference

* test: add inference tests

* test: change test scripts directory

* fix: update ci

* fix: fix typo

* fix: skip llama due to oom

* fix: fix file mod

* style: apply formatter

* refactor: remove duplicated llama_gptq

* style: apply formatter

* to: update rm test

* feat: add tokenizer arg

* feat: add download model script

* test: update train tests

* fix: modify gemini load and save pretrained

* test: update checkpoint io test

* to: modify nproc_per_node

* fix: do not remove existing dir

* fix: modify save path

* test: add random choice

* fix: fix sft path

* fix: enlarge nproc_per_node to avoid oom

* fix: add num_retry

* fix: make lora config of rm and critic consistent

* fix: add warning about lora weights

* fix: skip some gpt2 tests

* fix: remove grad ckpt in rm and critic due to errors

* refactor: directly use Actor in train_sft

* test: add more arguments

* fix: disable grad ckpt when using lora

* fix: fix save_pretrained and related tests

* test: enable zero2 tests

* revert: remove useless fn

* style: polish code

* test: modify test args
This commit is contained in:
Wenhao Chen
2023-08-02 10:17:36 +08:00
committed by GitHub
parent 16bf4c0221
commit da4f7b855f
62 changed files with 1404 additions and 1202 deletions

View File

@@ -4,8 +4,8 @@ import argparse
from time import time
import torch
from llama_gptq import load_quant
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
from coati.quant import llama_load_quant, low_resource_init
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
def generate_prompt(instruction, input=None):
@@ -106,7 +106,10 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
if args.quant == '4bit':
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
with low_resource_init():
config = LlamaConfig.from_pretrained(args.pretrained)
model = LlamaForCausalLM(config)
model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
model.cuda()
else:
model = LlamaForCausalLM.from_pretrained(

View File

@@ -1,5 +0,0 @@
from .loader import load_quant
__all__ = [
'load_quant',
]

View File

@@ -1,41 +0,0 @@
import torch
import torch.nn as nn
import transformers
from transformers import LlamaConfig, LlamaForCausalLM
from .model_utils import find_layers
from .quant import make_quant
def load_quant(pretrained: str, checkpoint: str, wbits: int, groupsize: int):
config = LlamaConfig.from_pretrained(pretrained)
def noop(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = LlamaForCausalLM(config)
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
for name in ['lm_head']:
if name in layers:
del layers[name]
make_quant(model, layers, wbits, groupsize)
print(f'Loading model with {wbits} bits...')
if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint))
else:
model.load_state_dict(torch.load(checkpoint))
model.seqlen = 2048
print('Done.')
return model

View File

@@ -1,13 +0,0 @@
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
import torch
import torch.nn as nn
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
return res

View File

@@ -1,283 +0,0 @@
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/quant.py
import math
import numpy as np
import torch
import torch.nn as nn
def quantize(x, scale, zero, maxq):
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)
class Quantizer(nn.Module):
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer('maxq', torch.tensor(0))
self.register_buffer('scale', torch.zeros(shape))
self.register_buffer('zero', torch.zeros(shape))
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
self.mse = mse
self.norm = norm
self.grid = grid
self.maxshrink = maxshrink
def find_params(self, x, weight=False):
dev = x.device
self.maxq = self.maxq.to(dev)
shape = x.shape
if self.perchannel:
if weight:
x = x.flatten(1)
else:
if len(shape) == 4:
x = x.permute([1, 0, 2, 3])
x = x.flatten(1)
if len(shape) == 3:
x = x.reshape((-1, shape[-1])).t()
if len(shape) == 2:
x = x.t()
else:
x = x.flatten().unsqueeze(0)
tmp = torch.zeros(x.shape[0], device=dev)
xmin = torch.minimum(x.min(1)[0], tmp)
xmax = torch.maximum(x.max(1)[0], tmp)
if self.sym:
xmax = torch.maximum(torch.abs(xmin), xmax)
tmp = xmin < 0
if torch.any(tmp):
xmin[tmp] = -xmax[tmp]
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1
self.scale = (xmax - xmin) / self.maxq
if self.sym:
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
else:
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float('inf'), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
q -= x
q.abs_()
q.pow_(self.norm)
err = torch.sum(q, 1)
tmp = err < best
if torch.any(tmp):
best[tmp] = err[tmp]
self.scale[tmp] = scale1[tmp]
self.zero[tmp] = zero1[tmp]
if not self.perchannel:
if weight:
tmp = shape[0]
else:
tmp = shape[1] if len(shape) != 3 else shape[2]
self.scale = self.scale.repeat(tmp)
self.zero = self.zero.repeat(tmp)
if weight:
shape = [-1] + [1] * (len(shape) - 1)
self.scale = self.scale.reshape(shape)
self.zero = self.zero.reshape(shape)
return
if len(shape) == 4:
self.scale = self.scale.reshape((1, -1, 1, 1))
self.zero = self.zero.reshape((1, -1, 1, 1))
if len(shape) == 3:
self.scale = self.scale.reshape((1, 1, -1))
self.zero = self.zero.reshape((1, 1, -1))
if len(shape) == 2:
self.scale = self.scale.unsqueeze(0)
self.zero = self.zero.unsqueeze(0)
def quantize(self, x):
if self.ready():
return quantize(x, self.scale, self.zero, self.maxq)
return x
def enabled(self):
return self.maxq > 0
def ready(self):
return torch.all(self.scale != 0)
try:
import quant_cuda
except:
print('CUDA extension not installed.')
# Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module):
def __init__(self, bits, groupsize, infeatures, outfeatures):
super().__init__()
if bits not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2, int(math.log2(groupsize)))):
raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)")
groupsize = groupsize if groupsize != -1 else infeatures
self.groupsize = groupsize
self.register_buffer(
'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
dtype=torch.int))
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
self.register_buffer('bias', torch.zeros(outfeatures))
self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
self._initialized_quant_state = False
def pack(self, linear, scales, zeros):
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone()
if linear.bias is not None:
self.bias = linear.bias.clone()
intweight = []
for idx in range(self.infeatures):
g_idx = idx // self.groupsize
intweight.append(
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros((intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
elif self.bits == 3:
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i))
i += 10
qweight[row] |= intweight[i] << 30
row += 1
qweight[row] |= (intweight[i] >> 2) & 1
i += 1
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i) + 1)
i += 10
qweight[row] |= intweight[i] << 31
row += 1
qweight[row] |= (intweight[i] >> 1) & 0x3
i += 1
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i) + 2)
i += 10
row += 1
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
elif self.bits == 3:
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
i += 10
qzeros[:, col] |= zeros[:, i] << 30
col += 1
qzeros[:, col] |= (zeros[:, i] >> 2) & 1
i += 1
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
i += 10
qzeros[:, col] |= zeros[:, i] << 31
col += 1
qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
i += 1
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
i += 10
col += 1
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
intermediate_dtype = torch.float32
if not self._initialized_quant_state:
# Do we even have a bias? Check for at least one non-zero element.
if self.bias is not None and bool(torch.any(self.bias != 0)):
# Then make sure it's the right type.
self.bias.data = self.bias.data.to(intermediate_dtype)
else:
self.bias = None
outshape = list(x.shape)
outshape[-1] = self.outfeatures
x = x.reshape(-1, x.shape[-1])
if self.bias is None:
y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
else:
y = self.bias.clone().repeat(x.shape[0], 1)
output_dtype = x.dtype
x = x.to(intermediate_dtype)
if self.bits == 2:
quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
elif self.bits == 3:
quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
elif self.bits == 4:
quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
elif self.bits == 8:
quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
y = y.to(output_dtype)
return y.reshape(outshape)
def make_quant(module, names, bits, groupsize, name=''):
if isinstance(module, QuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
name1 = name + '.' + attr if name != '' else attr
if name1 in names:
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
for name1, child in module.named_children():
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)

View File

@@ -5,8 +5,7 @@ from locust import HttpUser, task
samples = [[
dict(
instruction='Who is the best player in the history of NBA?',
response=
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
),
dict(instruction='continue this talk', response=''),
], [

View File

@@ -1,19 +1,19 @@
import argparse
import os
from threading import Lock
from typing import Dict, Generator, List, Optional
from typing import Generator, List, Optional
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from coati.quant import llama_load_quant, low_resource_init
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from llama_gptq import load_quant
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
@@ -56,7 +56,7 @@ app.add_middleware(
def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
#TODO(ver217): streaming generation does not support repetition_penalty now
# TODO(ver217): streaming generation does not support repetition_penalty now
model_kwargs = {
'max_generate_tokens': max_new_tokens,
'early_stopping': True,
@@ -162,7 +162,10 @@ if __name__ == '__main__':
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
if args.quant == '4bit':
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
with low_resource_init():
config = LlamaConfig.from_pretrained(args.pretrained)
model = LlamaForCausalLM(config)
model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
model.cuda()
else:
model = LlamaForCausalLM.from_pretrained(

View File

@@ -10,37 +10,34 @@ samples = [
([
Dialogue(
instruction='Who is the best player in the history of NBA?',
response=
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
),
Dialogue(instruction='continue this talk', response=''),
], 128,
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
),
([
Dialogue(
instruction='Who is the best player in the history of NBA?',
response=
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
),
Dialogue(instruction='continue this talk', response=''),
], 200,
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
),
([
Dialogue(
instruction='Who is the best player in the history of NBA?',
response=
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
),
Dialogue(instruction='continue this talk', response=''),
], 211,
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
),
([
Dialogue(instruction='Who is the best player in the history of NBA?', response=''),
], 128,
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
),
]

View File

@@ -1,9 +1,9 @@
import json
import re
from threading import Lock
from typing import Any, Callable, Generator, List, Optional
import json
import jieba
import jieba
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -127,7 +127,7 @@ STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
class ChatPromptProcessor:
SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]):
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
self.tokenizer = tokenizer
self.context = context
self.max_len = max_len
@@ -182,6 +182,7 @@ class ChatPromptProcessor:
intersection = set(jieba.cut(text.lower())) & self.censored_words
return len(intersection) > 0
class LockedIterator:
def __init__(self, it, lock: Lock) -> None:
@@ -195,6 +196,7 @@ class LockedIterator:
with self.lock:
return next(self.it)
def load_json(path: str):
with open(path) as f:
return json.load(f)
return json.load(f)