support instruct training (#3230)
This commit is contained in:
parent 9bc702ab48
commit bd39877da4
@@ -119,10 +119,15 @@ def preprocess(
 class AlpacaDataset(Dataset):
     """Dataset for supervised fine-tuning."""

-    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
+    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_length: int=None):
         super(AlpacaDataset, self).__init__()
         logger.info("Loading data...")
         list_data_dict = jload(data_path)
+        logger.info(f"Loaded {len(list_data_dict)} examples.")
+
+        if max_length is not None:
+            logger.info(f"Truncating data to max length {max_length}...")
+            list_data_dict = [example for example in list_data_dict if len(example["input"]) <= max_length]

         logger.info("Formatting inputs...")
         prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
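
Note: a minimal usage sketch of the new max_length filter, assuming the AlpacaDataset above is importable and an Alpaca-style JSON file of {"instruction", "input", "output"} records; the tokenizer checkpoint and file path below are placeholders, not part of this commit.

    import transformers
    # from <examples module> import AlpacaDataset   # wherever the class above lives

    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    dataset = AlpacaDataset(data_path="alpaca_data.json",
                            tokenizer=tokenizer,
                            max_length=512)  # drops records whose "input" string is longer than 512 characters
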
@@ -60,3 +60,6 @@ class Actor(LoRAModule):
         logits = output['logits']
         log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
         return log_probs[:, -num_actions:]
+
+    def get_base_model(self):
+        return self.model
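
Note: the added get_base_model() is a plain accessor to the wrapped transformers model, which callers need for APIs that only exist on the base model (the tokenizer-resize change later in this commit uses it). A self-contained toy sketch of the pattern; TinyWrapper is illustrative only:

    import torch.nn as nn

    class TinyWrapper(nn.Module):
        """Toy stand-in for Actor/LoRAModule: holds an inner model and exposes it."""
        def __init__(self, model: nn.Module):
            super().__init__()
            self.model = model

        def get_base_model(self):
            # Return the wrapped model so callers can use APIs that only exist on it,
            # e.g. resize_token_embeddings on a transformers.PreTrainedModel.
            return self.model

    inner = nn.Linear(4, 4)
    wrapper = TinyWrapper(inner)
    assert wrapper.get_base_model() is inner
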
@@ -36,3 +36,5 @@ class LlamaLM(LM):

         super().__init__(model, lora_rank, lora_train_bias)

+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
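
Note: the added forward simply delegates to the wrapped causal LM and passes labels through, so the HuggingFace model computes the (internally shifted) cross-entropy loss itself. A hedged illustration with a small public checkpoint; sshleifer/tiny-gpt2 is only a stand-in, not what this commit targets:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    enc = tok("Supervised fine-tuning example.", return_tensors="pt")
    with torch.no_grad():
        out = lm(input_ids=enc["input_ids"],
                 attention_mask=enc["attention_mask"],
                 labels=enc["input_ids"])  # labels == inputs gives the plain LM loss
    print(float(out.loss), tuple(out.logits.shape))
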
@@ -61,13 +61,15 @@ class SFTTrainer(ABC):
             # train
             self.model.train()
             for batch_id, batch in enumerate(self.train_dataloader):
-                prompt_ids = batch["input_ids"]
-                p_mask = batch["attention_mask"]
-                labels = batch["labels"]
-                prompt_ids = prompt_ids.squeeze(1).cuda()
-                p_mask = p_mask.squeeze(1).cuda()
+                prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
+                p_mask = batch["attention_mask"].to(torch.cuda.current_device())
+                labels = batch["labels"].to(torch.cuda.current_device())
+                # prompt_ids = prompt_ids.squeeze(1).cuda()
+                # p_mask = p_mask.squeeze(1).cuda()
                 # prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
-                loss, prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
+                outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
+                loss = outputs.loss
+                prompt_logits = outputs.logits

                 # loss = self.loss_fn(prompt_logits, labels)
                 self.strategy.backward(loss, self.model, self.optimizer)
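
Note: the reworked training step moves each batch tensor to the current CUDA device and reads loss/logits from the model's output object instead of unpacking a tuple. A self-contained sketch of that contract under a toy model; ToyLM and the random batch are illustrative, not the repo's code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from transformers.modeling_outputs import CausalLMOutput

    class ToyLM(nn.Module):
        """Toy stand-in following the HF convention: given labels, return loss + logits."""
        def __init__(self, vocab=100, hidden=16):
            super().__init__()
            self.emb = nn.Embedding(vocab, hidden)
            self.head = nn.Linear(hidden, vocab)

        def forward(self, input_ids, attention_mask=None, labels=None):
            logits = self.head(self.emb(input_ids))
            loss = None
            if labels is not None:
                # Real HF causal LMs shift labels internally; skipped here for brevity.
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
            return CausalLMOutput(loss=loss, logits=logits)

    device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
    model = ToyLM().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    batch = {
        "input_ids": torch.randint(0, 100, (2, 8)),
        "attention_mask": torch.ones(2, 8, dtype=torch.long),
        "labels": torch.randint(0, 100, (2, 8)),
    }
    batch = {k: v.to(device) for k, v in batch.items()}  # same move-to-device idea as the diff

    outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
    outputs.loss.backward()        # the trainer delegates this to strategy.backward(...)
    optimizer.step()
    optimizer.zero_grad()
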
@@ -83,13 +85,16 @@ class SFTTrainer(ABC):
             loss_sum = 0
             num_seen = 0
             for batch in self.eval_dataloader:
-                prompt_ids = batch["input_ids"]
-                p_mask = batch["attention_mask"]
-                prompt_ids = prompt_ids.squeeze(1).cuda()
-                p_mask = p_mask.squeeze(1).cuda()
+                prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
+                p_mask = batch["attention_mask"].to(torch.cuda.current_device())
+                labels = batch["labels"].to(torch.cuda.current_device())
+                # prompt_ids = prompt_ids.squeeze(1).cuda()
+                # p_mask = p_mask.squeeze(1).cuda()

-                prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
-                loss = self.loss_fn(prompt_logits, prompt_ids)
+                outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
+                loss = outputs.loss
+                # prompt_logits = outputs.logits
+
                 loss_sum += loss.item()
                 num_seen += prompt_ids.size(0)

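
Note: continuing the ToyLM sketch above (model and device are reused from it), the evaluation loop now also passes labels so outputs.loss is defined, then averages it over examples; this mirrors the hunk's bookkeeping as a sketch only:

    import torch

    eval_batches = [{
        "input_ids": torch.randint(0, 100, (2, 8)),
        "attention_mask": torch.ones(2, 8, dtype=torch.long),
        "labels": torch.randint(0, 100, (2, 8)),
    } for _ in range(4)]

    model.eval()
    loss_sum, num_seen = 0.0, 0
    with torch.no_grad():
        for batch in eval_batches:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
            loss_sum += outputs.loss.item()
            num_seen += batch["input_ids"].size(0)
    print("mean eval loss:", loss_sum / max(num_seen, 1))
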
@@ -9,7 +9,6 @@ from chatgpt.models.base import Actor
 from chatgpt.models.lora import LoraLinear
 from torch.optim import Optimizer

-
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

@@ -16,6 +16,8 @@ from typing import Dict

 import transformers

+from ..models.llama.llama_lm import LlamaLM
+
 DEFAULT_PAD_TOKEN = "[PAD]"
 DEFAULT_EOS_TOKEN = "</s>"
 DEFAULT_BOS_TOKEN = "</s>"
@@ -60,6 +62,10 @@ def smart_tokenizer_and_embedding_resize(

     if tokenizer.pad_token is None:
         num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+
+    if isinstance(model, LlamaLM):
+        model = model.get_base_model()
+
     model.resize_token_embeddings(len(tokenizer))

     if num_new_tokens > 0:
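
Note: a hedged sketch of the add-pad-token-then-resize flow this hunk extends to the wrapped LlamaLM (unwrap with get_base_model() before resizing); the checkpoint below is a small stand-in, and the special-token dict is an assumption matching DEFAULT_PAD_TOKEN above:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    if tokenizer.pad_token is None:
        num_new_tokens = tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    else:
        num_new_tokens = 0

    # If the model were wrapped (e.g. the repo's LlamaLM), resize the wrapped
    # transformers model instead: model = model.get_base_model()
    model.resize_token_embeddings(len(tokenizer))
    print(num_new_tokens, model.get_input_embeddings().weight.shape)
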
@@ -93,25 +93,27 @@ def train(args):
     elif 'alpaca' in args.dataset:
         train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
         eval_dataset = None
-        eval_dataset
         data_collator = AlpacaDataCollator(tokenizer=tokenizer)

     if dist.is_initialized() and dist.get_world_size() > 1:
-        sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
-        logger.info("Using Distributed Sampler")
+        train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
+        if eval_dataset is not None:
+            eval_sampler = DistributedSampler(eval_dataset, shuffle=False, seed=42, drop_last=False)
     else:
-        sampler = None
+        train_sampler = None
+        eval_sampler = None

-    train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None), sampler=sampler, batch_size=args.batch_size)
+    train_dataloader = DataLoader(train_dataset, shuffle=(train_sampler is None), sampler=train_sampler, batch_size=args.batch_size, collate_fn=data_collator)
     if eval_dataset is not None:
-        eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size)
+        eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, collate_fn=data_collator)
+    else:
+        eval_dataloader = None

     trainer = SFTTrainer(model=model,
                          strategy=strategy,
                          optim=optim,
                          train_dataloader=train_dataloader,
                          eval_dataloader=eval_dataloader,
-                         sampler=sampler,
                          batch_size=args.batch_size,
                          max_epochs=args.max_epochs)
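
Note: a sketch of the sampler/dataloader wiring this hunk moves to, separate train/eval samplers under torch.distributed plus the collator passed as collate_fn; TensorDataset and default_collate stand in for AlpacaDataset and AlpacaDataCollator, and set_epoch is an extra step commonly paired with DistributedSampler, not something this diff adds:

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
    from torch.utils.data.dataloader import default_collate

    train_dataset = TensorDataset(torch.arange(64).view(64, 1))

    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        train_sampler = None

    train_dataloader = DataLoader(train_dataset,
                                  shuffle=(train_sampler is None),  # sampler and shuffle are mutually exclusive
                                  sampler=train_sampler,
                                  batch_size=8,
                                  collate_fn=default_collate)

    for epoch in range(2):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)  # reshuffle shards each epoch
        for batch in train_dataloader:
            pass
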
@@ -128,7 +130,7 @@ if __name__ == '__main__':
     parser.add_argument('--strategy',
                         choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                         default='naive')
-    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
+    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='yizhongw/self_instruct')
     parser.add_argument('--save_path', type=str, default='sft_ckpt.pth')
@@ -17,4 +17,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 8

 #torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2 --log_interval 10
 #torchrun --standalone --nproc_per_node=8 train_sft.py --model 'gpt2' --strategy colossalai_zero2 --batch_size 1 --log_interval 10
-torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2 --log_interval 10
+torchrun --standalone --nproc_per_node=8 train_sft.py \
+    --pretrain "/data/personal/nus-mql/LLAMA-7B" \
+    --model 'llama' \
+    --strategy colossalai_zero2 \
+    --log_interval 10 \
+    --save_path /data/personal/nus-mql/Coati-7B \
+    --dataset /data/personal/nus-mql/stanford_alpaca/alpaca_data.json
@@ -1 +1 @@
-0.1.0
+1.0.0