Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 05:20:33 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -15,8 +15,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
 class DPPluginWrapper(DPPluginBase):
-    """This is a wrapper class for testing DP plugin initialization and dataloader creation.
-    """
+    """This is a wrapper class for testing DP plugin initialization and dataloader creation."""
 
     def configure(
         self,
@@ -73,13 +72,14 @@ def check_dataloader_sharding():
 
     # compare on rank 0
     if is_rank_0:
-        assert not torch.equal(batch,
-                               batch_to_compare), 'Same number was found across ranks but expected it to be different'
+        assert not torch.equal(
+            batch, batch_to_compare
+        ), "Same number was found across ranks but expected it to be different"
 
 
 def run_dist(rank, world_size, port):
     # init dist env
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
     check_dataloader_sharding()
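For context, tests like this one are typically driven through the multi-process helpers imported in the hunk header (`rerun_if_address_is_in_use` and `spawn` from `colossalai.testing`). Below is a minimal sketch of that wiring. The test name, the body of `check_dataloader_sharding`, and the exact `spawn(fn, nprocs)` calling convention are illustrative assumptions, not taken from this diff; only the `colossalai.launch` call and the imports come from the code above.

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_dataloader_sharding():
    # Placeholder for the check exercised in the diff: build a dataloader through
    # the DP plugin, fetch a batch on each rank, and assert that rank 0's batch
    # differs from the batch gathered from another rank.
    ...


def run_dist(rank, world_size, port):
    # Initialise the distributed environment for this rank, mirroring the
    # launch call shown in the hunk above.
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
    check_dataloader_sharding()


@rerun_if_address_is_in_use()
def test_dp_plugin_dataloader():
    # Assumption: spawn(fn, nprocs) starts nprocs worker processes and passes each
    # one (rank, world_size, port); two ranks suffice to detect duplicated batches.
    spawn(run_dist, 2)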