mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -7,12 +7,14 @@ from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer.layer import cross_entropy_1d
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
|
||||
CONFIG = dict(
|
||||
parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
|
||||
)
|
||||
|
||||
|
||||
def check_dist_crossentropy(rank, world_size, port, ignore_index):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
|
||||
|
||||
# prepare data
|
||||
pred = torch.randn(2, 4, 8, requires_grad=True)
|
||||
@@ -25,10 +27,11 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index):
|
||||
org_loss = F.cross_entropy(org_pred, org_labels)
|
||||
|
||||
dist_pred = pred.chunk(world_size, -1)[rank]
|
||||
dist_loss = cross_entropy_1d(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index)
|
||||
dist_loss = cross_entropy_1d(dist_pred.to("cuda"), labels.to("cuda"), ignore_index=ignore_index)
|
||||
|
||||
assert torch.allclose(org_loss, dist_loss,
|
||||
atol=1e-5), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
|
||||
assert torch.allclose(
|
||||
org_loss, dist_loss, atol=1e-5
|
||||
), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@@ -38,5 +41,5 @@ def test_dist_crossentropy():
|
||||
spawn(check_dist_crossentropy, 2, ignore_index=ignore_index)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_dist_crossentropy()
|
||||
|
Reference in New Issue
Block a user