refactor kernel (#142)

This commit is contained in:
ver217
2022-01-13 16:47:17 +08:00
committed by GitHub
parent 4a3d3446b0
commit f68eddfb3d
24 changed files with 334 additions and 414 deletions

View File

@@ -90,8 +90,7 @@ class FusedSGD(Optimizer):
[0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
self.multi_tensor_sgd = colossal_C.multi_tensor_sgd
else:
raise RuntimeError(
'apex.optimizers.FusedSGD requires cuda extensions')
raise RuntimeError('FusedSGD requires cuda extensions')
def __setstate__(self, state):
super(FusedSGD, self).__setstate__(state)