ColossalAI/applications/Chat/tests/test_checkpoint.py
Wenhao Chen da4f7b855f
[chat] fix bugs and add unit tests (#4213)
* style: rename replay buffer

Experience replay is typically for off policy algorithms.
Using this name in PPO may be misleading.

* fix: fix wrong zero2 default arg

* test: update experience tests

* style: rename zero_pad fn

* fix: defer init in CycledDataLoader

* test: add benchmark test

* style: rename internal fn of generation

* style: rename internal fn of lora

* fix: remove unused loss fn

* fix: remove unused utils fn

* refactor: remove generate_with_actor fn

* fix: fix type annotation

* test: add models tests

* fix: skip llama due to long execution time

* style: modify dataset

* style: apply formatter

* perf: update reward dataset

* fix: fix wrong IGNORE_INDEX in sft dataset

* fix: remove DataCollatorForSupervisedDataset

* test: add dataset tests

* style: apply formatter

* style: rename test_ci to test_train

* feat: add llama in inference

* test: add inference tests

* test: change test scripts directory

* fix: update ci

* fix: fix typo

* fix: skip llama due to oom

* fix: fix file mod

* style: apply formatter

* refactor: remove duplicated llama_gptq

* style: apply formatter

* to: update rm test

* feat: add tokenizer arg

* feat: add download model script

* test: update train tests

* fix: modify gemini load and save pretrained

* test: update checkpoint io test

* to: modify nproc_per_node

* fix: do not remove existing dir

* fix: modify save path

* test: add random choice

* fix: fix sft path

* fix: enlarge nproc_per_node to avoid oom

* fix: add num_retry

* fix: make lora config of rm and critic consistent

* fix: add warning about lora weights

* fix: skip some gpt2 tests

* fix: remove grad ckpt in rm and critic due to errors

* refactor: directly use Actor in train_sft

* test: add more arguments

* fix: disable grad ckpt when using lora

* fix: fix save_pretrained and related tests

* test: enable zero2 tests

* revert: remove useless fn

* style: polish code

* test: modify test args
2023-08-02 10:17:36 +08:00

107 lines
3.7 KiB
Python

import os
import tempfile
from contextlib import nullcontext
import pytest
import torch
import torch.distributed as dist
from coati.models.gpt import GPTActor
from coati.models.utils import calc_action_log_probs
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use, spawn
# Small GPT-2 configuration (128-dim embeddings, 4 layers, 4 attention heads)
# used to build the actor under test.
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
def get_data(batch_size: int,
             seq_len: int = 10,
             *,
             vocab_size: int = 50257,
             device: str = "cuda") -> dict:
    """Build a random batch of token ids with an all-ones attention mask.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of tokens per sequence.
        vocab_size: Exclusive upper bound for the sampled token ids. The
            default is the bound this helper previously hard-coded.
        device: Device on which the tensors are allocated. The default keeps
            the previous behaviour of always allocating on CUDA.

    Returns:
        A dict with two tensors of shape ``(batch_size, seq_len)``:
        ``"input_ids"`` (random ids in ``[0, vocab_size)``) and
        ``"attention_mask"`` (all ones, same dtype and device as the ids).
    """
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)
def train_step(strategy: Strategy,
               actor: GPTActor,
               actor_optim: HybridAdam,
               batch_size: int = 8):
    """Run one forward/backward/optimizer step on a random batch.

    Args:
        strategy: Strategy whose ``backward`` and ``optimizer_step`` drive
            the update.
        actor: Model called on the random ``input_ids`` / ``attention_mask``.
        actor_optim: Optimizer passed to the strategy for the update.
        batch_size: Number of random sequences generated for this step.
    """
    batch = get_data(batch_size)
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]

    # Every position counts as an action, so the number of actions is simply
    # the sequence length taken from the mask's second dimension.
    action_mask = torch.ones_like(attention_mask, dtype=torch.bool)
    num_actions = action_mask.size(1)

    output = actor(input_ids, attention_mask)
    log_probs = calc_action_log_probs(output, input_ids, num_actions)

    # The summed log-probs serve as the scalar loss handed to the strategy.
    loss = log_probs.sum()
    strategy.backward(loss, actor, actor_optim)
    strategy.optimizer_step(actor_optim)
def run_test_checkpoint(strategy_name: str,
                        shard: bool):
    """Train, save a checkpoint, reload it, and train again.

    Args:
        strategy_name: One of ``"ddp"``, ``"colossalai_gemini"`` or
            ``"colossalai_zero2"``.
        shard: When true, save to the extension-less paths ``model`` /
            ``optim`` with ``only_rank0=False``; otherwise save to
            ``model.pt`` / ``optim.pt`` with ``only_rank0=True``.

    Raises:
        ValueError: If ``strategy_name`` is not one of the supported names.
    """
    if strategy_name == "ddp":
        strategy = DDPStrategy()
    elif strategy_name == "colossalai_gemini":
        strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
    elif strategy_name == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    else:
        raise ValueError(f"Unsupported strategy '{strategy_name}'")

    with strategy.model_init_context():
        actor = GPTActor(config=GPT_CONFIG).cuda()
    actor_optim = HybridAdam(actor.parameters())
    actor, actor_optim = strategy.prepare((actor, actor_optim))

    # Take one step before saving so there is state to checkpoint.
    train_step(strategy, actor, actor_optim)

    # Only rank 0 creates (and on exit removes) the temporary directory; the
    # other ranks enter a no-op context and receive the path via broadcast.
    ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
    with ctx as dirname:
        dirname_holder = [dirname]
        dist.broadcast_object_list(dirname_holder)
        rank0_dirname = dirname_holder[0]

        model_path = os.path.join(rank0_dirname, "model" if shard else "model.pt")
        strategy.save_model(actor, model_path, only_rank0=not shard)
        optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
        strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
        # Wait until every rank has finished writing before anyone reads.
        dist.barrier()

        strategy.load_model(actor, model_path, strict=False)
        strategy.load_optimizer(actor_optim, optim_path)
        # Keep the directory alive until every rank has finished loading.
        dist.barrier()

    # The reloaded model and optimizer must still be able to take a step.
    train_step(strategy, actor, actor_optim)
def run_dist(rank: int,
             world_size: int,
             port: int,
             strategy_name: str,
             shard: bool):
    """Per-process entry point: export the rendezvous env vars, then run the test.

    Args:
        rank: Rank of this process; also used as its local rank.
        world_size: Total number of participating processes.
        port: Port exported as ``MASTER_PORT``.
        strategy_name: Forwarded to ``run_test_checkpoint``.
        shard: Forwarded to ``run_test_checkpoint``.
    """
    rendezvous_env = {
        "RANK": str(rank),
        "LOCAL_RANK": str(rank),
        "WORLD_SIZE": str(world_size),
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(port),
    }
    os.environ.update(rendezvous_env)
    run_test_checkpoint(strategy_name, shard)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"])
@pytest.mark.parametrize("shard", [False, True])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size: int,
                    strategy_name: str,
                    shard: bool):
    """Checkpoint save/load round trip for each strategy and sharding mode.

    Hands ``run_dist`` to ``spawn`` together with ``world_size``; the
    strategy name and shard flag are forwarded as keyword arguments.
    """
    worker_kwargs = dict(strategy_name=strategy_name, shard=shard)
    spawn(run_dist, world_size, **worker_kwargs)
if __name__ == "__main__":
    # Direct-run shortcut: one fixed configuration (2 processes, Gemini
    # strategy, unsharded checkpoint) instead of the parametrized matrix.
    world_size, strategy_name = 2, "colossalai_gemini"
    test_checkpoint(world_size, strategy_name, shard=False)