[zero] Suggests a minor change to confusing variable names in the ZeRO optimizer. (#3173)

* Fix confusing variable name in zero opt

* Apply lint

* Fix util func

* Fix minor util func

* Fix zero param optimizer name
This commit is contained in:
YH
2023-04-27 19:43:14 +09:00
committed by GitHub
parent 842768a174
commit a22407cc02
3 changed files with 85 additions and 76 deletions

View File

@@ -91,10 +91,18 @@ def get_grad_accumulate_object(tensor):
return grad_acc_obj
def split_half_float_double(tensor_list):
def split_by_dtype(tensor_list):
"""
Splits a list of PyTorch tensors into sublists based on their data type.
:param tensor_list: A list of PyTorch tensors.
:type tensor_list: list[torch.Tensor]
:return: A list of sublists, where each sublist contains tensors of a specific data type.
:rtype: list[list[torch.Tensor]]
"""
dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
buckets = []
for i, dtype in enumerate(dtypes):
for _, dtype in enumerate(dtypes):
bucket = [t for t in tensor_list if t.type() == dtype]
if bucket:
buckets.append(bucket)