[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -1,7 +1,6 @@
#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
import pytest
import torch
import torch.nn.functional as F
@@ -44,20 +43,19 @@ def forward_inplace(x, weight):
@parameterize("use_reentrant", [True, False])
@parameterize("cpu_offload", [True, False])
def test_activation_checkpointing(cpu_offload, use_reentrant):
# as seed manager is singleton
# if we don't reset seeds here,
# other tests might affect this test
reset_seeds()
# We put initialization here to avoid change cuda rng state below
-inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
-weight = torch.rand(2, 4, requires_grad=True, device='cuda')
+inputs = torch.rand(2, 2, requires_grad=True, device="cuda")
+weight = torch.rand(2, 4, requires_grad=True, device="cuda")
# Get a copy of input tensors
-inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda')
+inputs_ = torch.empty(2, 2, requires_grad=True, device="cuda")
inputs_.data.copy_(inputs.data)
-weight_ = torch.empty(2, 4, requires_grad=True, device='cuda')
+weight_ = torch.empty(2, 4, requires_grad=True, device="cuda")
weight_.data.copy_(weight.data)
add_seed(ParallelMode.GLOBAL, 1024)
@@ -83,7 +81,7 @@ def test_activation_checkpointing(cpu_offload, use_reentrant):
loss = out.sum()
loss.backward()
-assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
+assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match"
torch.cuda.empty_cache()
# Extra test for use_reentrant=False
@@ -110,7 +108,7 @@ def test_activation_checkpointing(cpu_offload, use_reentrant):
loss = out.sum()
loss.backward()
-assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
+assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match"
torch.cuda.empty_cache()
# as seed manager is singleton
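
For reference, the pattern this test exercises can be sketched with plain PyTorch alone, outside the ColossalAI seed manager and parallel context: run the same forward once through torch.utils.checkpoint and once normally, then require the gradients to match exactly. The snippet below is an illustrative, self-contained sketch under those assumptions; the helper name, tensor shapes, and seed are placeholders, not part of this commit.

import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def forward(x, weight):
    # a small linear + activation; recomputed during backward when checkpointed
    return F.relu(x @ weight)


def main():
    torch.manual_seed(1024)
    inputs = torch.rand(2, 2, requires_grad=True)
    weight = torch.rand(2, 4, requires_grad=True)

    # independent copies so the reference backward does not share .grad buffers
    inputs_ = inputs.detach().clone().requires_grad_(True)
    weight_ = weight.detach().clone().requires_grad_(True)

    # checkpointed path: activations are freed and recomputed during backward
    out = checkpoint(forward, inputs, weight, use_reentrant=False)
    out.sum().backward()

    # plain path, used as the reference
    out_ = forward(inputs_, weight_)
    out_.sum().backward()

    assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match"
    assert torch.all(weight.grad == weight_.grad), "Gradient of the weight does not match"


if __name__ == "__main__":
    main()

PyTorch's reentrant and non-reentrant checkpoint implementations re-run the forward differently, which is presumably why the test above parameterizes over use_reentrant in addition to cpu_offload.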