[zero] add L2 gradient clipping for ZeRO (#2112)

* [zero] add L2 gradient clipping

* [testing] add MlpModel

* [zero] add unit test for grad clipping

* fix atol
This commit is contained in:
HELSON
2022-12-09 18:09:17 +08:00
committed by GitHub
parent 70a8556946
commit 63fbba3c19
5 changed files with 194 additions and 11 deletions

View File

@@ -302,7 +302,11 @@ class ZeroDDP(ColoDDP):
chunk.chunk_total.div_(chunk.pg_size)
else:
chunk.cuda_shard.div_(chunk.pg_size)
# check overflow elements
self.overflow_counter += chunk.has_inf_or_nan
# record l2 norm for gradient clipping
if chunk.l2_norm_flag:
chunk.set_l2_norm()
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad