refactor kernel (#142)

This commit is contained in:
ver217
2022-01-13 16:47:17 +08:00
committed by GitHub
parent 4a3d3446b0
commit f68eddfb3d
24 changed files with 334 additions and 414 deletions

View File

@@ -73,8 +73,7 @@ class FusedLAMB(torch.optim.Optimizer):
[0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
self.multi_tensor_lamb = colossal_C.multi_tensor_lamb
else:
raise RuntimeError(
'apex.optimizers.FusedLAMB requires cuda extensions')
raise RuntimeError('FusedLAMB requires cuda extensions')
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none