Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

 import pytest
 import torch
 import torch.nn.functional as F

@@ -44,20 +43,19 @@ def forward_inplace(x, weight):
 @parameterize("use_reentrant", [True, False])
 @parameterize("cpu_offload", [True, False])
 def test_activation_checkpointing(cpu_offload, use_reentrant):

     # as seed manager is singleton
     # if we don't reset seeds here,
     # other tests might affect this test
     reset_seeds()

     # We put initialization here to avoid change cuda rng state below
-    inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
-    weight = torch.rand(2, 4, requires_grad=True, device='cuda')
+    inputs = torch.rand(2, 2, requires_grad=True, device="cuda")
+    weight = torch.rand(2, 4, requires_grad=True, device="cuda")

     # Get a copy of input tensors
-    inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda')
+    inputs_ = torch.empty(2, 2, requires_grad=True, device="cuda")
     inputs_.data.copy_(inputs.data)
-    weight_ = torch.empty(2, 4, requires_grad=True, device='cuda')
+    weight_ = torch.empty(2, 4, requires_grad=True, device="cuda")
     weight_.data.copy_(weight.data)

     add_seed(ParallelMode.GLOBAL, 1024)

@@ -83,7 +81,7 @@ def test_activation_checkpointing(cpu_offload, use_reentrant):
     loss = out.sum()
     loss.backward()

-    assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
+    assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match"
     torch.cuda.empty_cache()

     # Extra test for use_reentrant=False
@@ -110,7 +108,7 @@ def test_activation_checkpointing(cpu_offload, use_reentrant):
     loss = out.sum()
     loss.backward()

-    assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
+    assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match"
     torch.cuda.empty_cache()

     # as seed manager is singleton
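For context, the file touched here tests that activation checkpointing reproduces the gradients of a plain forward/backward pass. Below is a minimal sketch of that check using vanilla torch.utils.checkpoint; the forward and run_check names are illustrative, and it stands in for (rather than reproduces) ColossalAI's seed-manager based test.

# Sketch only: checks that a checkpointed forward yields the same input
# gradients as an un-checkpointed reference run. Requires PyTorch >= 1.13
# for the use_reentrant keyword; falls back to CPU if CUDA is unavailable.
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def forward(x, weight):
    # a small linear layer followed by an activation
    return torch.relu(F.linear(x, weight))


def run_check(use_reentrant: bool) -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(1024)
    inputs = torch.rand(2, 2, requires_grad=True, device=device)
    weight = torch.rand(4, 2, requires_grad=True, device=device)

    # reference run without checkpointing, on detached copies of the tensors
    inputs_ = inputs.detach().clone().requires_grad_(True)
    weight_ = weight.detach().clone().requires_grad_(True)
    forward(inputs_, weight_).sum().backward()

    # checkpointed run: the forward pass is recomputed during backward
    out = checkpoint(forward, inputs, weight, use_reentrant=use_reentrant)
    out.sum().backward()

    assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match"
    assert torch.all(weight.grad == weight_.grad), "Gradient of the weight does not match"


if __name__ == "__main__":
    for flag in (True, False):
        run_check(flag)

The copies of the inputs serve the same purpose as inputs_/weight_ in the diff above: they provide an independent baseline so the gradient comparison is between checkpointed and non-checkpointed execution of the same computation.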