[booster] refactor all dp fashion plugins (#3684)

* [booster] add dp plugin base

* [booster] inherit dp plugin base

* [booster] refactor unit tests
Author: Hongxin Liu
Date: 2023-05-05 19:36:10 +08:00, committed by GitHub
Parent: b49020c1b1
Commit: d0915f54f4
8 changed files with 190 additions and 308 deletions
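
The commit message above describes the refactor at a high level: the dataloader logic that each data-parallel plugin (e.g. GeminiPlugin, TorchDDPPlugin) previously duplicated is hoisted into a shared base class, and the per-plugin unit tests are consolidated accordingly. The following is a rough sketch of that idea, not the commit's actual code; the class name `DPPluginBase` and the `prepare_dataloader` signature are assumptions inferred from the commit message:

# Minimal sketch of a shared data-parallel plugin base (assumed names, see above).
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


class DPPluginBase:
    """Common behavior for all "dp fashion" booster plugins."""

    def __init__(self) -> None:
        # assumes the process group was already initialized, e.g. via colossalai.launch
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

    def prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False,
                           pin_memory=False, num_workers=0, **kwargs):
        # DistributedSampler hands each rank a disjoint slice of the dataset,
        # which is exactly the property the sharding test in the diff verifies
        sampler = DistributedSampler(dataset,
                                     num_replicas=self.world_size,
                                     rank=self.rank,
                                     shuffle=shuffle)
        return DataLoader(dataset,
                          batch_size=batch_size,
                          sampler=sampler,
                          drop_last=drop_last,
                          pin_memory=pin_memory,
                          num_workers=num_workers,
                          **kwargs)

Each concrete plugin then inherits this method instead of carrying its own copy, which is why the Gemini-specific sharding test in the diff below is deleted.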

@@ -117,34 +117,9 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
    assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])


def check_dataloader_sharding():
    plugin = GeminiPlugin()

    # create a custom dataset holding the values 0 to 9
    dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)

    # get the first batch of data
    batch = next(iter(train_dataloader))[0].cuda()

    is_rank_0 = dist.get_rank() == 0
    if is_rank_0:
        batch_to_compare = batch.clone()
    else:
        batch_to_compare = batch
    # broadcast rank 1's batch so that rank 0 can compare against it
    dist.broadcast(batch_to_compare, src=1)

    # compare on rank 0: a sharded dataloader must yield different data per rank
    if is_rank_0:
        assert not torch.equal(batch, batch_to_compare), \
            'Expected each rank to receive different data, but the batches are identical'


def run_dist(rank, world_size, port, early_stop: bool = True):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_dataloader_sharding()
    check_gemini_plugin(early_stop=early_stop)
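
With the sharding behavior owned by the base class, the removed check no longer needs to live in the Gemini-specific test: it can be written once against the shared dataloader API and run for every dp plugin. The sketch below shows that consolidation; it assumes the hypothetical `prepare_dataloader` name from the earlier sketch, an already-launched distributed environment, and at least two GPU ranks.

# Sketch of a consolidated sharding test (assumed API, see lead-in above).
import torch
import torch.distributed as dist


def check_dp_plugin_sharding(plugin):
    # dataset holding the values 0..9; with batch_size=2, every rank's first
    # batch should contain elements no other rank sees
    dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
    batch = next(iter(train_dataloader))[0].cuda()

    # rank 0 keeps its own batch and receives rank 1's copy via broadcast
    batch_to_compare = batch.clone() if dist.get_rank() == 0 else batch
    dist.broadcast(batch_to_compare, src=1)
    if dist.get_rank() == 0:
        assert not torch.equal(batch, batch_to_compare), \
            'expected each rank to receive a different shard'

Running this helper once per plugin instance (GeminiPlugin, TorchDDPPlugin, and so on) exercises the same behavior as the deleted per-plugin copy while keeping a single source of truth for the test.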