[misc] Use dist logger in plugins (#6011)

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Author: Edenzzzz
Date: 2024-08-20 10:32:41 +08:00
Committed by: GitHub
Commit: dcc44aab8d (parent f1c3266a94)
8 changed files with 101 additions and 70 deletions
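The change applies one pattern throughout: replace `warnings.warn`, which fires once per process, with ColossalAI's distributed logger so each message is emitted on rank 0 only. A minimal sketch of that pattern (not part of the diff), assuming the `colossalai.logging.get_dist_logger` import path and the `ranks` argument shown in the hunks below; `report_ignored_param` is a hypothetical helper used only for illustration:

    from colossalai.logging import get_dist_logger

    # One shared distributed logger instead of the stdlib warnings module.
    logger = get_dist_logger()

    def report_ignored_param(name: str) -> None:
        # Before: warnings.warn(...) printed the message on every rank.
        # After: only rank 0 logs it, keeping multi-GPU output readable.
        logger.warning(
            f"Parameter `{name}` is ignored by DDP but requires gradient! "
            "You should handle its optimizer update by yourself!",
            ranks=[0],
        )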


@@ -1,7 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
 import math
-import warnings
 from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
 
 import torch
@@ -136,7 +135,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
         self.verbose = verbose
         self.param_groups_backup = list()
+        self.logger = get_dist_logger()
         # Mapping from integer id to real/fake param tensor, used for checkpointing.
         self.id_to_real_params: Dict[int, Parameter] = dict()
         self.id_to_fake_params: Dict[int, Parameter] = dict()
@@ -148,9 +147,10 @@ class GeminiOptimizer(OptimizerWrapper):
         for name, param in module.named_parameters():
             if is_ddp_ignored(param):
                 if param.requires_grad:
-                    warnings.warn(
+                    self.logger.warning(
                         f"Parameter `{name}` is ignored by DDP but requires gradient! "
-                        "You should handle its optimizer update by yourself!"
+                        "You should handle its optimizer update by yourself!",
+                        ranks=[0],
                     )
             else:
                 ddp_param_list.append(param)
@@ -842,7 +842,9 @@ class GeminiOptimizer(OptimizerWrapper):
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm")
+        self.logger.warning(
+            f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
+        )
 
 
 class GeminiAdamOptimizer(GeminiOptimizer):