Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-15 22:19:38 +00:00
[booster] refactor all dp fashion plugins (#3684)
* [booster] add dp plugin base
* [booster] inherit dp plugin base
* [booster] refactor unit tests
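The change consolidates the behavior shared by the data-parallel (DP) plugins into a single base class that each plugin inherits. A minimal sketch of that inheritance pattern, assuming a hypothetical DPPluginBase name and reusing the prepare_train_dataloader signature visible in the diff below; the actual ColossalAI implementation may differ:

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


class DPPluginBase:
    """Hypothetical base class holding DP behavior shared by all DP-fashion plugins."""

    def prepare_train_dataloader(self, dataset, batch_size: int) -> DataLoader:
        # Shard the dataset across ranks so each rank draws batches
        # from a disjoint slice of the data.
        sampler = DistributedSampler(dataset,
                                     num_replicas=dist.get_world_size(),
                                     rank=dist.get_rank())
        return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


class GeminiPlugin(DPPluginBase):
    # Gemini-specific model and optimizer wrapping would live here; the
    # dataloader sharding above is inherited unchanged.
    pass

With the sharding logic defined once in the base class, the per-plugin dataloader tests can collapse into a single shared suite, which is presumably what "[booster] refactor unit tests" refers to.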
@@ -117,34 +117,9 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
     assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
 
 
-def check_dataloader_sharding():
-    plugin = GeminiPlugin()
-
-    # create a custom dataset with values 0 to 9
-    dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
-    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
-
-    # get the first batch of data
-    batch = next(iter(train_dataloader))[0].cuda()
-    is_rank_0 = dist.get_rank() == 0
-
-    if is_rank_0:
-        batch_to_compare = batch.clone()
-    else:
-        batch_to_compare = batch
-    # pass the rank 1 value to rank 0
-    dist.broadcast(batch_to_compare, src=1)
-
-    # compare on rank 0
-    if is_rank_0:
-        assert not torch.equal(batch,
-                               batch_to_compare), 'Same number was found across ranks but expected it to be different'
-
-
 def run_dist(rank, world_size, port, early_stop: bool = True):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
-    check_dataloader_sharding()
     check_gemini_plugin(early_stop=early_stop)
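The deleted check_dataloader_sharding test asserted that different ranks receive different batches from the plugin-prepared dataloader, by broadcasting rank 1's first batch and comparing it against rank 0's. The same sharding behavior can be reproduced single-process with PyTorch's DistributedSampler, which is assumed here to be what prepare_train_dataloader uses underneath; a self-contained sketch:

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Same toy dataset as the removed test: the integers 0..9.
dataset = TensorDataset(torch.arange(0, 10))

# Simulate two ranks by constructing each rank's sampler explicitly,
# so no process group needs to be initialized.
for rank in (0, 1):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    first_batch = next(iter(loader))[0]
    print(f'rank {rank} first batch: {first_batch.tolist()}')

# Prints:
#   rank 0 first batch: [0, 2]
#   rank 1 first batch: [1, 3]
# Each rank sees different values, matching the broadcast-and-compare
# assertion in the removed test.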