[chatgpt] support instruct training (#3216)

This commit is contained in:
Fazzie-Maqianli 2023-03-23 16:46:20 +08:00 committed by GitHub
parent cd142fbefa
commit 4fd4bd9d9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 313 additions and 39 deletions

View File

@@ -1,5 +1,5 @@
 from .reward_dataset import RmStaticDataset, HhRlhfDataset
 from .utils import is_rank_0
-from .sft_dataset import SFTDataset
+from .sft_dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
 
-__all__ = ['RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset']
+__all__ = ['RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'AlpacaDataset', 'AlpacaDataCollator']

View File

@@ -1,12 +1,46 @@
-from typing import Callable
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from dataclasses import dataclass, field
+from typing import Callable, Dict, Sequence
 
 import random
 from torch.utils.data import Dataset
 import torch.distributed as dist
 from tqdm import tqdm
 import torch
-from .utils import is_rank_0
+from .utils import is_rank_0, jload
+import transformers
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+IGNORE_INDEX = -100
+PROMPT_DICT = {
+    "prompt_input": (
+        "Below is an instruction that describes a task, paired with an input that provides further context. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+    ),
+    "prompt_no_input": (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Response:"
+    ),
+}
+
 
 class SFTDataset(Dataset):
     """
@@ -38,3 +72,87 @@ class SFTDataset(Dataset):
     def __getitem__(self, idx):
         return self.prompts[idx]
+
+
+def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+    """Tokenize a list of strings."""
+    tokenized_list = [
+        tokenizer(
+            text,
+            return_tensors="pt",
+            padding="longest",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+        ) for text in strings
+    ]
+    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+    input_ids_lens = labels_lens = [
+        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
+    ]
+    return dict(
+        input_ids=input_ids,
+        labels=labels,
+        input_ids_lens=input_ids_lens,
+        labels_lens=labels_lens,
+    )
+
+
+def preprocess(
+    sources: Sequence[str],
+    targets: Sequence[str],
+    tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+    """Preprocess the data by tokenizing."""
+    examples = [s + t for s, t in zip(sources, targets)]
+    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
+    input_ids = examples_tokenized["input_ids"]
+    labels = copy.deepcopy(input_ids)
+    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
+        label[:source_len] = IGNORE_INDEX
+    return dict(input_ids=input_ids, labels=labels)
+
+
+class AlpacaDataset(Dataset):
+    """Dataset for supervised fine-tuning."""
+
+    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
+        super(AlpacaDataset, self).__init__()
+        logger.info("Loading data...")
+        list_data_dict = jload(data_path)
+
+        logger.info("Formatting inputs...")
+        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+        sources = [
+            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
+            for example in list_data_dict
+        ]
+        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+        logger.info("Tokenizing inputs... This may take some time...")
+        data_dict = preprocess(sources, targets, tokenizer)
+
+        self.input_ids = data_dict["input_ids"]
+        self.labels = data_dict["labels"]
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+
+@dataclass
+class AlpacaDataCollator(object):
+    """Collate examples for supervised fine-tuning."""
+
+    tokenizer: transformers.PreTrainedTokenizer
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
+        input_ids = torch.nn.utils.rnn.pad_sequence(
+            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
+        return dict(
+            input_ids=input_ids,
+            labels=labels,
+            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+        )
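The AlpacaDataset/AlpacaDataCollator pair above follows the Stanford Alpaca recipe: prompts and responses are tokenized together and the prompt portion of the labels is masked with IGNORE_INDEX, so loss is computed only on the response. A minimal usage sketch, not part of the diff; the tokenizer checkpoint and JSON path are placeholders, and the pad-token fallback is an assumption for tokenizers that ship without one:

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from chatgpt.dataset import AlpacaDataset, AlpacaDataCollator

tokenizer = AutoTokenizer.from_pretrained("path/to/llama-7b", use_fast=False)  # placeholder checkpoint
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # the collator pads with tokenizer.pad_token_id

dataset = AlpacaDataset(data_path="alpaca_data.json", tokenizer=tokenizer)  # placeholder path
collator = AlpacaDataCollator(tokenizer=tokenizer)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collator)

batch = next(iter(loader))
print(batch["input_ids"].shape, batch["labels"].shape, batch["attention_mask"].shape)
```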

View File

@@ -1,5 +1,20 @@
+import io
+import json
+
 import torch.distributed as dist
 
+
 def is_rank_0() -> bool:
     return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def _make_r_io_base(f, mode: str):
+    if not isinstance(f, io.IOBase):
+        f = open(f, mode=mode)
+    return f
+
+
+def jload(f, mode="r"):
+    """Load a .json file into a dictionary."""
+    f = _make_r_io_base(f, mode)
+    jdict = json.load(f)
+    f.close()
+    return jdict
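jload is the small helper AlpacaDataset uses to read the instruction file; a sketch of standalone use (the file name is a placeholder, and for alpaca-style data the loaded object is a list of records rather than a single dict):

```python
from chatgpt.dataset.utils import jload

records = jload("alpaca_data.json")  # placeholder path; each record has "instruction", "input", "output"
print(len(records), sorted(records[0].keys()))
```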

View File

@@ -1,5 +1,6 @@
 from .llama_actor import LlamaActor
 from .llama_critic import LlamaCritic
 from .llama_rm import LlamaRM
+from .llama_lm import LlamaLM
 
-__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
+__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']

View File

@@ -0,0 +1,38 @@
+from typing import Optional
+
+from transformers import LlamaConfig, LlamaForCausalLM
+
+from ..base import LM
+
+
+class LlamaLM(LM):
+    """
+    Llama language model.
+
+    Args:
+        pretrained (str): Pretrained model name or path.
+        config (LlamaConfig): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): LoRA rank.
+        lora_train_bias (str): LoRA bias training mode.
+    """
+
+    def __init__(self,
+                 pretrained: Optional[str] = None,
+                 config: Optional[LlamaConfig] = None,
+                 checkpoint: bool = False,
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none') -> None:
+        if pretrained is not None:
+            model = LlamaForCausalLM.from_pretrained(pretrained)
+        elif config is not None:
+            model = LlamaForCausalLM(config)
+        else:
+            model = LlamaForCausalLM(LlamaConfig())
+        if checkpoint:
+            model.gradient_checkpointing_enable()
+        super().__init__(model, lora_rank, lora_train_bias)
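A sketch of how the new LlamaLM wrapper might be instantiated; the checkpoint path and the tiny config below are placeholders for illustration:

```python
from transformers import LlamaConfig

from chatgpt.models.llama import LlamaLM

# From a checkpoint, with gradient checkpointing and LoRA rank 8:
model = LlamaLM(pretrained="path/to/llama-7b", checkpoint=True, lora_rank=8)

# Or from a config alone (random weights), e.g. for a quick smoke test:
tiny_config = LlamaConfig(hidden_size=128, intermediate_size=256, num_hidden_layers=2, num_attention_heads=4)
tiny_model = LlamaLM(config=tiny_config)
```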

View File

@@ -2,7 +2,6 @@ from abc import ABC
 from typing import Optional
 
 import loralib as lora
 import torch
-from chatgpt.dataset import SFTDataset
 from chatgpt.models.loss import GPTLMLoss
 from torch.optim import Adam, Optimizer
 from torch.utils.data import DataLoader
@@ -22,8 +21,8 @@ class SFTTrainer(ABC):
         model (torch.nn.Module): the model to train
         strategy (Strategy): the strategy to use for training
         optim(Optimizer): the optimizer to use for training
-        train_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for training
-        eval_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for evaluation
+        train_dataloader: the dataloader to use for training
+        eval_dataloader: the dataloader to use for evaluation
         batch_size (int, defaults to 1): the batch size while training
         max_epochs (int, defaults to 2): the number of epochs to train
         optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
@@ -34,8 +33,8 @@ class SFTTrainer(ABC):
                  model,
                  strategy: Strategy,
                  optim: Optimizer,
-                 train_dataset: SFTDataset,
-                 eval_dataset: SFTDataset,
+                 train_dataloader: DataLoader,
+                 eval_dataloader: DataLoader = None,
                  sampler: Optional[DistributedSampler] = None,
                  batch_size: int = 1,
                  max_epochs: int = 2,
@@ -43,13 +42,10 @@ class SFTTrainer(ABC):
         super().__init__()
         self.strategy = strategy
         self.epochs = max_epochs
-        self.train_dataset = train_dataset
-        self.eval_dataset = eval_dataset
         self.sampler = sampler
 
-        self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None),
-                                           sampler=sampler, batch_size=batch_size)
-        self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size)
+        self.train_dataloader = train_dataloader
+        self.eval_dataloader = eval_dataloader
 
         self.model = strategy.setup_model(model)
         if "DDP" in str(self.strategy):
@@ -79,23 +75,25 @@ class SFTTrainer(ABC):
                     logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
 
             # eval
-            self.model.eval()
-            with torch.no_grad():
-                loss_sum = 0
-                num_seen = 0
-                for batch in self.eval_dataloader:
-                    prompt_ids = batch["input_ids"]
-                    p_mask = batch["attention_mask"]
-                    prompt_ids = prompt_ids.squeeze(1).cuda()
-                    p_mask = p_mask.squeeze(1).cuda()
-
-                    prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
-
-                    loss = self.loss_fn(prompt_logits, prompt_ids)
-                    loss_sum += loss.item()
-                    num_seen += prompt_ids.size(0)
-
-                loss_mean = loss_sum / num_seen
-                if dist.get_rank() == 0:
-                    logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
+            if self.eval_dataloader is not None:
+                self.model.eval()
+                with torch.no_grad():
+                    loss_sum = 0
+                    num_seen = 0
+                    for batch in self.eval_dataloader:
+                        prompt_ids = batch["input_ids"]
+                        p_mask = batch["attention_mask"]
+                        prompt_ids = prompt_ids.squeeze(1).cuda()
+                        p_mask = p_mask.squeeze(1).cuda()
+
+                        prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
+
+                        loss = self.loss_fn(prompt_logits, prompt_ids)
+                        loss_sum += loss.item()
+                        num_seen += prompt_ids.size(0)
+
+                    loss_mean = loss_sum / num_seen
+                    if dist.get_rank() == 0:
+                        logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
 
             epoch_bar.update()
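The trainer now consumes prebuilt dataloaders and skips evaluation when eval_dataloader is None. A hedged sketch of the new calling convention, assuming model, strategy, optim, train_dataset, and sampler are already set up as in examples/train_sft.py below:

```python
from torch.utils.data import DataLoader

# Build the training dataloader outside the trainer, as the updated constructor expects.
train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None), sampler=sampler, batch_size=4)

trainer = SFTTrainer(model=model,
                     strategy=strategy,
                     optim=optim,
                     train_dataloader=train_dataloader,
                     eval_dataloader=None,  # evaluation loop is skipped when this is None
                     sampler=sampler,
                     batch_size=4,
                     max_epochs=2)
```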

View File

@@ -0,0 +1,3 @@
+from .tokenizer_utils import smart_tokenizer_and_embedding_resize, prepare_llama_tokenizer_and_embedding
+
+__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']

View File

@@ -0,0 +1,74 @@
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+import transformers
+
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = "</s>"
+DEFAULT_BOS_TOKEN = "</s>"
+DEFAULT_UNK_TOKEN = "</s>"
+
+
+def prepare_llama_tokenizer_and_embedding(
+    tokenizer: transformers.PreTrainedTokenizer,
+    model: transformers.PreTrainedModel,
+    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
+):
+    """prepare llama tokenizer and embedding."""
+
+    if tokenizer.pad_token is None:
+        smart_tokenizer_and_embedding_resize(
+            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
+            tokenizer=tokenizer,
+            model=model,
+        )
+
+    tokenizer.add_special_tokens(
+        {
+            "eos_token": DEFAULT_EOS_TOKEN,
+            "bos_token": DEFAULT_BOS_TOKEN,
+            "unk_token": DEFAULT_UNK_TOKEN,
+        }
+    )
+
+    return tokenizer
+
+
+def smart_tokenizer_and_embedding_resize(
+    tokenizer: transformers.PreTrainedTokenizer,
+    model: transformers.PreTrainedModel,
+    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
+):
+    """Resize tokenizer and embedding.
+
+    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+    """
+    if tokenizer.pad_token is None:
+        num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+        model.resize_token_embeddings(len(tokenizer))
+
+        if num_new_tokens > 0:
+            input_embeddings = model.get_input_embeddings().weight.data
+            output_embeddings = model.get_output_embeddings().weight.data
+
+            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+            input_embeddings[-num_new_tokens:] = input_embeddings_avg
+            output_embeddings[-num_new_tokens:] = output_embeddings_avg
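A minimal sketch of the tokenizer/embedding preparation: it adds "[PAD]" when the LLaMA tokenizer has no pad token, resizes the embeddings, and initializes the new rows with the mean of the existing ones. The checkpoint path is a placeholder, and using a bare LlamaForCausalLM here (rather than the wrapped model the training script passes) is an assumption for illustration:

```python
from transformers import AutoTokenizer, LlamaForCausalLM

from chatgpt.utils import prepare_llama_tokenizer_and_embedding

tokenizer = AutoTokenizer.from_pretrained("path/to/llama-7b", padding_side="right", use_fast=False)
model = LlamaForCausalLM.from_pretrained("path/to/llama-7b")

tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
assert tokenizer.pad_token is not None
print(len(tokenizer), model.get_input_embeddings().weight.shape[0])  # vocab and embedding sizes now agree
```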

View File

@@ -4,15 +4,18 @@ import loralib as lora
 import torch
 import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler
-from chatgpt.dataset import SFTDataset
+from chatgpt.dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
 from chatgpt.models.base import RewardModel
 from chatgpt.models.bloom import BLOOMLM
 from chatgpt.models.gpt import GPTLM
 from chatgpt.models.opt import OPTLM
+from chatgpt.models.llama import LlamaLM
 from chatgpt.trainer import SFTTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from chatgpt.utils import prepare_llama_tokenizer_and_embedding
 from datasets import load_dataset
 from torch.optim import Adam
+from torch.utils.data import DataLoader
 from transformers import AutoTokenizer, BloomTokenizerFast
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
@@ -41,6 +44,8 @@ def train(args):
         model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
     elif args.model == 'gpt2':
         model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+    elif args.model == 'llama':
+        model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -53,9 +58,19 @@ def train(args):
         tokenizer.pad_token = tokenizer.eos_token
     elif args.model == 'opt':
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    elif args.model == 'llama':
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.pretrain,
+            padding_side="right",
+            use_fast=False,
+        )
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
-    tokenizer.pad_token = tokenizer.eos_token
+
+    if args.model == 'llama':
+        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
+    else:
+        tokenizer.pad_token = tokenizer.eos_token
 
     max_len = 512
@@ -67,11 +82,19 @@ def train(args):
     logger = get_dist_logger()
 
-    train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
-    eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
-
-    train_dataset = SFTDataset(train_data, tokenizer, max_len)
-    eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
+    # configure dataset
+    if args.dataset == 'yizhongw/self_instruct':
+        train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
+        eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
+
+        train_dataset = SFTDataset(train_data, tokenizer, max_len)
+        eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
+    elif 'alpaca' in args.dataset:
+        train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
+        eval_dataset = None
+        data_collator = AlpacaDataCollator(tokenizer=tokenizer)
 
     if dist.is_initialized() and dist.get_world_size() > 1:
         sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
@@ -79,11 +102,15 @@ def train(args):
     else:
         sampler = None
 
+    train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None), sampler=sampler, batch_size=args.batch_size)
+    if eval_dataset is not None:
+        eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size)
+
     trainer = SFTTrainer(model=model,
                          strategy=strategy,
                          optim=optim,
-                         train_dataset=train_dataset,
-                         eval_dataset=eval_dataset,
+                         train_dataloader=train_dataloader,
+                         eval_dataloader=eval_dataloader,
                          sampler=sampler,
                          batch_size=args.batch_size,
                          max_epochs=args.max_epochs)
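The collator built in the alpaca branch still has to reach the DataLoader via collate_fn, and eval_dataloader is only bound when eval_dataset exists. One possible wiring, as a sketch using the script's own names; the collate_fn argument and the eval_dataloader = None fallback are assumptions layered on top of the diff:

```python
# Sketch only: pass the alpaca collator through, and fall back to None for eval.
train_dataloader = DataLoader(train_dataset,
                              shuffle=(sampler is None),
                              sampler=sampler,
                              batch_size=args.batch_size,
                              collate_fn=data_collator if 'alpaca' in args.dataset else None)
eval_dataloader = (DataLoader(eval_dataset, batch_size=args.batch_size)
                   if eval_dataset is not None else None)
```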