fix move fp32 shards (#1604)

ver217 2022-09-16 17:33:16 +08:00 committed by GitHub
parent eac1b79371
commit c9e8ce67b8


@@ -288,6 +288,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         fp32_shards_used_cuda_margin_mem = 0
         for group in self.optim.param_groups:
             for p in group['params']:
+                if p.colo_attr.saved_grad.is_null():
+                    continue
                 shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                     colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
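The added guard skips parameters whose saved gradient shard is null (e.g., parameters that received no gradient this step), so their fp32 master shards no longer consume the CUDA margin budget. Below is a minimal standalone sketch of the same margin-budget placement idea; the function name move_fp32_shards_within_margin and the master_shards/grads dictionaries are hypothetical stand-ins for ColossalAI's ShardedOptimizerV2 internals, not the actual implementation.

import torch

def move_fp32_shards_within_margin(master_shards, grads, cuda_margin_bytes):
    """Move fp32 master shards to GPU only while they fit in the CUDA margin."""
    used = 0
    for name, shard in master_shards.items():
        # Analogous to the added check p.colo_attr.saved_grad.is_null():
        # a shard with no gradient will not be updated, so do not move it.
        if grads.get(name) is None:
            continue
        shard_mem = shard.numel() * shard.element_size()
        if used + shard_mem < cuda_margin_bytes:
            master_shards[name] = shard.to(torch.cuda.current_device())
            used += shard_mem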