[chat] refactor model save/load logic (#3654)

* [chat] strategy refactor unwrap model

* [chat] strategy refactor save model

* [chat] add docstr

* [chat] refactor trainer save model

* [chat] fix strategy typing

* [chat] refactor trainer save model

* [chat] update readme

* [chat] fix unit test
Hongxin Liu
2023-04-27 18:41:49 +08:00
committed by GitHub
parent 6ef7011462
commit 842768a174
14 changed files with 155 additions and 181 deletions

View File

@@ -66,6 +66,7 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--lr 2e-5 \
--max_datasets_size 512 \
--max_epochs 1 \
--grad_checkpoint
```
### Arg List
- --strategy: the strategy to use for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive'
@@ -78,6 +79,7 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
- --batch_size: batch size while training, type=int, default=4
- --lora_rank: the rank of the low-rank adaptation matrices, type=int, default=0
- --log_interval: how many steps between logs, type=int, default=100
- --grad_checkpoint: enable gradient checkpointing, type=bool, default=False
## Stage2 - Training reward model
@@ -152,7 +154,7 @@ torchrun --standalone --nproc_per_node=4 train_prompts.py \
--rm_path /your/rm/model/path
```
Prompt dataset: the instruction dataset mentioned in the above figure, which includes the instructions; e.g. you can use [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild.
Pretrain dataset: the pretrain dataset including the instructions and the corresponding responses; e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) from the stage 1 supervised instruction tuning.
### Arg List
@@ -254,29 +256,6 @@ class CoatiActor(Actor):
super().__init__(model, lora_rank, lora_train_bias)
```
### LM model
```
from typing import Optional

from ..base import LM
from transformers.models.coati import CoatiModel


class GPTLM(LM):

    def __init__(self,
                 pretrained: Optional[str] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = CoatiModel.from_pretrained(pretrained)
        else:
            model = build_model()    # load your own model if it is not supported in transformers
        super().__init__(model, lora_rank, lora_train_bias)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
```
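
A short usage sketch of the custom LM defined above; the pretrained path, LoRA rank, and tensor contents are illustrative assumptions, not part of the original example:
```
import torch

# Build the custom LM; a pretrained path loads CoatiModel, otherwise build_model() is used.
model = GPTLM(pretrained='your/pretrained/model/path', lora_rank=16)

input_ids = torch.tensor([[10, 20, 30]])
attention_mask = torch.ones_like(input_ids)
output = model(input_ids, attention_mask=attention_mask)    # forwards to the wrapped model
```
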
### Reward model
```
from ..base import RewardModel

View File

@@ -194,7 +194,7 @@ def main(args):
update_timesteps=args.update_timesteps)
# save model checkpoint after fitting
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
strategy.save_model(actor, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(actor_optim,
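
This is the core of the refactor: checkpoint saving moves off the trainer and onto the strategy, so each strategy can unwrap its own model wrapper before writing to disk. Below is a minimal sketch of what a strategy-level save_model could look like, assuming the strategy knows how to unwrap the model and that only_rank0 gates the write to rank 0; the class and helper names are illustrative, not the repository's actual implementation.
```
import torch
import torch.distributed as dist
import torch.nn as nn


class NaiveStrategy:
    """Illustrative sketch only, not the repository's actual strategy class."""

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        # A naive strategy has no wrapper; DDP/ZeRO strategies would strip theirs here.
        return model

    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
        if only_rank0 and dist.is_initialized() and dist.get_rank() != 0:
            return    # non-zero ranks skip the write when only_rank0 is set
        torch.save(self.unwrap_model(model).state_dict(), path)
```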

View File

@@ -124,11 +124,23 @@ def train(args):
raise ValueError(f'Unsupported dataset "{args.dataset}"')
if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
train_sampler = DistributedSampler(train_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
valid_sampler = DistributedSampler(valid_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
valid_sampler = DistributedSampler(valid_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
eval_sampler = DistributedSampler(eval_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
eval_sampler = DistributedSampler(eval_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
else:
train_sampler = None
@@ -141,13 +153,19 @@ def train(args):
batch_size=args.batch_size,
pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, shuffle=(valid_sampler is None),
valid_dataloader = DataLoader(valid_dataset,
shuffle=(valid_sampler is None),
sampler=valid_sampler,
batch_size=args.batch_size, pin_memory=True)
batch_size=args.batch_size,
pin_memory=True)
eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None),
sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True)
eval_dataloader = DataLoader(eval_dataset,
shuffle=(eval_sampler is None),
sampler=eval_sampler,
batch_size=args.batch_size,
pin_memory=True)
(model, optim) = strategy.prepare((model, optim))
trainer = RewardModelTrainer(model=model,
strategy=strategy,
optim=optim,
@@ -155,12 +173,11 @@ def train(args):
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
eval_dataloader=eval_dataloader,
batch_size=args.batch_size,
max_epochs=args.max_epochs)
trainer.fit()
# save model checkpoint after fitting on only rank0
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(trainer.optimizer,

View File

@@ -152,6 +152,7 @@ def train(args):
else:
eval_dataloader = None
(model, optim) = strategy.prepare((model, optim))
trainer = SFTTrainer(model=model,
strategy=strategy,
optim=optim,
@@ -163,7 +164,7 @@ def train(args):
trainer.fit(logger=logger, use_wandb=args.use_wandb)
# save model checkpoint after fitting on only rank0
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(trainer.optimizer,
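
For SFT, saving switches to strategy.save_pretrained, which also takes the tokenizer, suggesting the checkpoint is written in the Hugging Face pretrained-directory format rather than as a bare state dict. A hedged sketch of what such a save might do, written as a standalone helper; only the argument names are taken from the diff, the body and unwrapping behavior are assumptions:
```
from typing import Optional

import torch.distributed as dist
from transformers import PreTrainedModel, PreTrainedTokenizerBase


def save_pretrained(model: PreTrainedModel,
                    path: str,
                    only_rank0: bool = True,
                    tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    # Hypothetical stand-in for strategy.save_pretrained; model unwrapping is omitted here.
    if only_rank0 and dist.is_initialized() and dist.get_rank() != 0:
        return
    model.save_pretrained(path)          # config + weights in HF directory format
    if tokenizer is not None:
        tokenizer.save_pretrained(path)  # tokenizer files land next to the weights
```
A directory saved this way can then presumably be reloaded with AutoModelForCausalLM.from_pretrained(path) and AutoTokenizer.from_pretrained(path), which is the practical payoff of saving in the pretrained format instead of a raw state dict.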