mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[chat] refactor model save/load logic (#3654)
* [chat] strategy refactor unwrap model
* [chat] strategy refactor save model
* [chat] add docstr
* [chat] refactor trainer save model
* [chat] fix strategy typing
* [chat] refactor trainer save model
* [chat] update readme
* [chat] fix unit test
@@ -66,6 +66,7 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
    --lr 2e-5 \
    --max_datasets_size 512 \
    --max_epochs 1 \
    --grad_checkpoint
```
### Arg List
- --strategy: the strategy to use for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive'
@@ -78,6 +79,7 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
- --batch_size: batch size while training, type=int, default=4
- --lora_rank: low-rank adaptation matrices rank, type=int, default=0 (a minimal LoRA sketch follows this list)
- --log_interval: how many steps to log, type=int, default=100
- --grad_checkpoint: enable gradient checkpointing, type=bool, default=False
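For readers unfamiliar with the --lora_rank option above, here is a minimal, self-contained sketch of what a low-rank adaptation update looks like on a single linear layer. The class name, dimensions, and scaling rule are illustrative assumptions, not Coati's actual LoRA implementation.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Illustrative LoRA layer: y = base(x) + (x @ A^T @ B^T) * scaling."""

    def __init__(self, in_features: int, out_features: int, lora_rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        for p in self.base.parameters():
            p.requires_grad_(False)  # the pretrained weights stay frozen
        # Only these two small matrices (rank x in, out x rank) are trained.
        self.lora_A = nn.Parameter(torch.randn(lora_rank, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, lora_rank))
        self.scaling = alpha / lora_rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling


layer = LoRALinear(1024, 1024, lora_rank=8)
print(layer(torch.randn(4, 1024)).shape)  # torch.Size([4, 1024])
```

With the default lora_rank=0, LoRA is presumably disabled and the full model is trained; a positive rank trades some capacity for a much smaller set of trainable parameters.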
## Stage2 - Training reward model

@@ -152,7 +154,7 @@ torchrun --standalone --nproc_per_node=4 train_prompts.py \
    --rm_path /your/rm/model/path
```
Prompt dataset: the instruction dataset mentioned in the figure above, which contains the prompts/instructions; e.g., you can use [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) from InstructionWild.
Pretrain dataset: the pretraining dataset, which contains both the instruction and the corresponding response; e.g., you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) from the stage 1 supervised instruction tuning.
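As a rough illustration of how such a JSONL prompt file can be read, here is a plain-Python sketch; the 'instruction' field name is an assumption about the file schema, not something guaranteed by InstructionWild or the Coati dataset classes.

```python
import json


def load_jsonl_prompts(path: str) -> list[dict]:
    """Read a JSON-Lines prompt file: one JSON object per line."""
    prompts = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                prompts.append(json.loads(line))
    return prompts


# Hypothetical usage; the 'instruction' key is an assumption about the schema.
prompts = load_jsonl_prompts("seed_prompts_en.jsonl")
print(len(prompts), prompts[0].get("instruction", prompts[0]))
```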
### Arg List
@@ -254,29 +256,6 @@ class CoatiActor(Actor):
        super().__init__(model, lora_rank, lora_train_bias)
```

-### LM model
-
-```
-from ..base import LM
-from transformers.models.coati import CoatiModel
-
-class GPTLM(LM):
-
-    def __init__(self,
-                 pretrained: Optional[str] = None,
-                 checkpoint: bool = False,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
-        if pretrained is not None:
-            model = CoatiModel.from_pretrained(pretrained)
-        else:
-            model = build_model()  # load your own model if it is not supported in transformers
-
-        super().__init__(model, lora_rank, lora_train_bias)
-
-    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
-        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
-```
### Reward model
```
from ..base import RewardModel
@@ -194,7 +194,7 @@ def main(args):
                update_timesteps=args.update_timesteps)

    # save model checkpoint after fitting
-   trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
+   strategy.save_model(actor, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(actor_optim,
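The refactor moves checkpoint saving from the trainer onto the strategy, which knows how to unwrap the distributed model before writing it. As a rough sketch of the kind of interface this implies (only save_model, save_optimizer, and the unwrap step come from the diff and commit message; every signature detail here is an assumption, not the actual ColossalAI code):

```python
from abc import ABC, abstractmethod

import torch.nn as nn
import torch.optim as optim


class SaveStrategy(ABC):
    """Hypothetical sketch of a strategy-owned save interface (not the real ColossalAI class)."""

    @abstractmethod
    def unwrap_model(self, model: nn.Module) -> nn.Module:
        """Return the underlying module (e.g. strip DDP/ZeRO wrappers) before saving."""

    @abstractmethod
    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
        """Save model weights; when only_rank0 is True, only rank 0 writes the file."""

    @abstractmethod
    def save_optimizer(self, optimizer: optim.Optimizer, path: str, only_rank0: bool = False) -> None:
        """Save optimizer state; each rank typically keeps its own shard."""
```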
@@ -124,11 +124,23 @@ def train(args):
        raise ValueError(f'Unsupported dataset "{args.dataset}"')

    if dist.is_initialized() and dist.get_world_size() > 1:
-       train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+       train_sampler = DistributedSampler(train_dataset,
+                                          shuffle=True,
+                                          seed=42,
+                                          drop_last=True,
+                                          rank=dist.get_rank(),
+                                          num_replicas=dist.get_world_size())
-       valid_sampler = DistributedSampler(valid_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+       valid_sampler = DistributedSampler(valid_dataset,
+                                          shuffle=True,
+                                          seed=42,
+                                          drop_last=True,
+                                          rank=dist.get_rank(),
+                                          num_replicas=dist.get_world_size())
-       eval_sampler = DistributedSampler(eval_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+       eval_sampler = DistributedSampler(eval_dataset,
+                                         shuffle=True,
+                                         seed=42,
+                                         drop_last=True,
+                                         rank=dist.get_rank(),
+                                         num_replicas=dist.get_world_size())
    else:
        train_sampler = None
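For context, the sketch below shows the standard way DistributedSampler is combined with DataLoader and set_epoch in a training loop. It uses a toy TensorDataset and is not taken from the Coati scripts.

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100).float())

if dist.is_initialized() and dist.get_world_size() > 1:
    # Each rank sees a disjoint shard; the fixed seed keeps the shuffle consistent across ranks.
    sampler = DistributedSampler(dataset,
                                 shuffle=True,
                                 seed=42,
                                 drop_last=True,
                                 rank=dist.get_rank(),
                                 num_replicas=dist.get_world_size())
else:
    sampler = None

# shuffle must be False when a sampler is given, hence the (sampler is None) guard.
loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler, batch_size=8, pin_memory=True)

for epoch in range(2):
    if sampler is not None:
        sampler.set_epoch(epoch)  # reshuffle differently every epoch
    for (batch,) in loader:
        pass  # training step would go here
```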
@@ -141,13 +153,19 @@ def train(args):
                                  batch_size=args.batch_size,
                                  pin_memory=True)

-   valid_dataloader = DataLoader(valid_dataset, shuffle=(valid_sampler is None),
+   valid_dataloader = DataLoader(valid_dataset,
+                                 shuffle=(valid_sampler is None),
                                  sampler=valid_sampler,
-                                 batch_size=args.batch_size, pin_memory=True)
+                                 batch_size=args.batch_size,
+                                 pin_memory=True)

-   eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None),
-                                sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True)
+   eval_dataloader = DataLoader(eval_dataset,
+                                shuffle=(eval_sampler is None),
+                                sampler=eval_sampler,
+                                batch_size=args.batch_size,
+                                pin_memory=True)

    (model, optim) = strategy.prepare((model, optim))
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
@@ -155,12 +173,11 @@ def train(args):
                                 train_dataloader=train_dataloader,
                                 valid_dataloader=valid_dataloader,
                                 eval_dataloader=eval_dataloader,
                                 batch_size=args.batch_size,
                                 max_epochs=args.max_epochs)

    trainer.fit()
    # save model checkpoint after fitting on only rank0
-   trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
+   strategy.save_model(model, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
@@ -152,6 +152,7 @@ def train(args):
    else:
        eval_dataloader = None

    (model, optim) = strategy.prepare((model, optim))
    trainer = SFTTrainer(model=model,
                         strategy=strategy,
                         optim=optim,
@@ -163,7 +164,7 @@ def train(args):
    trainer.fit(logger=logger, use_wandb=args.use_wandb)

    # save model checkpoint after fitting on only rank0
-   trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
+   strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
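strategy.save_pretrained receives the tokenizer so that the exported checkpoint can be reloaded as a regular Hugging Face model directory. Below is a minimal sketch of that rank-0 export pattern using plain transformers calls; it stands in for, and is not, the ColossalAI strategy internals.

```python
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer


def save_pretrained_rank0(model, tokenizer, path: str) -> None:
    """Write model weights and tokenizer files to `path` from rank 0 only."""
    if dist.is_initialized() and dist.get_rank() != 0:
        return  # other ranks skip the write
    model.save_pretrained(path)      # config.json + weights
    tokenizer.save_pretrained(path)  # tokenizer files alongside the model


# Hypothetical usage with a small model.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
save_pretrained_rank0(model, tokenizer, "./sft_checkpoint")
```

Loading the result back is then just AutoModelForCausalLM.from_pretrained("./sft_checkpoint").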