[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -6,17 +6,18 @@ class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
raise ValueError("Decay must be between 0 and 1")
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
else torch.tensor(-1, dtype=torch.int))
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int)
)
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
@@ -24,7 +25,7 @@ class LitEma(nn.Module):
def reset_num_updates(self):
del self.num_updates
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay