[plugin] support get_grad_norm (#6115)

This commit is contained in:
Hongxin Liu
2024-11-05 18:12:47 +08:00
committed by GitHub
parent 13ffa08cfa
commit a15ab139ad
8 changed files with 40 additions and 2 deletions

View File

@@ -76,6 +76,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True)
optimizer.step()
grad_norm = optimizer.get_grad_norm()
assert grad_norm is None or isinstance(grad_norm, float)
except Exception as e:
return repr(e)