[zero] Suggests a minor change to confusing variable names in the ZeRO optimizer. (#3173)

* Fix confusing variable name in zero opt

* Apply lint

* Fix util func

* Fix minor util func

* Fix zero param optimizer name
Author: YH
Date: 2023-04-27 19:43:14 +09:00
Committed by: GitHub
Parent: 842768a174
Commit: a22407cc02
3 changed files with 85 additions and 76 deletions


@@ -11,9 +11,9 @@ class ParameterStore(BaseStore):
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
         # param partitioning data structures
-        self._fp16_param_to_rank = dict()
-        self._rank_groupid_to_fp16_param_list = dict()
-        self._rank_group_id_to_flat_fp16_param = dict()
+        self._param_to_rank = dict()
+        self._rank_group_id_to_param_list = dict()
+        self._rank_group_id_to_flat_param = dict()
 
         # param reduction data structures
         self._is_param_reduced = dict()
@@ -29,7 +29,7 @@ class ParameterStore(BaseStore):
         :type rank: int
         """
-        self._fp16_param_to_rank[tensor] = rank
+        self._param_to_rank[tensor] = rank
 
     def get_param_rank(self, tensor: Tensor) -> int:
         """
@@ -38,7 +38,7 @@ class ParameterStore(BaseStore):
         :param tensor: A :class:`torch.Tensor` object
         :type tensor: torch.Tensor
         """
-        return self._fp16_param_to_rank[tensor]
+        return self._param_to_rank[tensor]
 
     def belongs_to_current_rank(self, tensor) -> bool:
         """
@@ -51,29 +51,29 @@ class ParameterStore(BaseStore):
         :rtype: bool
         """
-        tensor_rank = self._fp16_param_to_rank[tensor]
+        tensor_rank = self._param_to_rank[tensor]
         return tensor_rank == self._local_rank
 
-    def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
-        if rank not in self._rank_groupid_to_fp16_param_list:
-            self._rank_groupid_to_fp16_param_list[rank] = dict()
+    def add_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
+        if rank not in self._rank_group_id_to_param_list:
+            self._rank_group_id_to_param_list[rank] = dict()
 
-        if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
-            self._rank_groupid_to_fp16_param_list[rank][group_id] = []
+        if group_id not in self._rank_group_id_to_param_list[rank]:
+            self._rank_group_id_to_param_list[rank][group_id] = []
 
-        self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list)
+        self._rank_group_id_to_param_list[rank][group_id].extend(tensor_list)
 
-    def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
-        return self._rank_groupid_to_fp16_param_list[rank][group_id]
+    def get_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
+        return self._rank_group_id_to_param_list[rank][group_id]
 
-    def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
-        if rank not in self._rank_group_id_to_flat_fp16_param:
-            self._rank_group_id_to_flat_fp16_param[rank] = dict()
+    def add_flat_param_by_rank_group(self, rank, group_id, tensor) -> None:
+        if rank not in self._rank_group_id_to_flat_param:
+            self._rank_group_id_to_flat_param[rank] = dict()
 
-        self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor
+        self._rank_group_id_to_flat_param[rank][group_id] = tensor
 
-    def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
-        return self._rank_group_id_to_flat_fp16_param[rank][group_id]
+    def get_flat_param_by_rank_group(self, rank, group_id) -> Tensor:
+        return self._rank_group_id_to_flat_param[rank][group_id]
 
     def is_param_reduced(self, tensor):
         return self._is_param_reduced[tensor]
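
For context, a minimal usage sketch of the renamed ParameterStore bookkeeping API. Only the method names come from the diff above; the import path, the single-process gloo setup, and the dummy parameters are illustrative assumptions rather than part of this commit.

import torch
import torch.distributed as dist

# Import path is an assumption: ParameterStore is defined in the ZeRO
# bookkeeping module touched by this commit, but the file path is not shown here.
from colossalai.zero.low_level.bookkeeping import ParameterStore

# Single-process "distributed" setup purely for illustration.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
store = ParameterStore(dist.group.WORLD)

# Register two dummy parameters under rank 0, optimizer param group 0,
# using the renamed (no longer fp16-specific) methods.
params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(2)]
store.add_param_list_by_rank_group(0, 0, params)      # was add_fp16_param_list_by_rank_group
owned = store.get_params_by_rank_group(0, 0)          # was get_fp16_params_by_rank_group

# Store and fetch the flattened view of that group.
flat = torch.cat([p.data.view(-1) for p in owned])
store.add_flat_param_by_rank_group(0, 0, flat)         # was add_flat_fp16_param_by_rank_group
print(store.get_flat_param_by_rank_group(0, 0).shape)  # was get_flat_fp16_param_by_rank_group

dist.destroy_process_group()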