mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 17:46:42 +00:00
[zero] Suggests a minor change to confusing variable names in the ZeRO optimizer. (#3173)
* Fix confusing variable name in zero opt * Apply lint * Fix util func * Fix minor util func * Fix zero param optimizer name
This commit is contained in:
@@ -91,10 +91,18 @@ def get_grad_accumulate_object(tensor):
|
||||
return grad_acc_obj
|
||||
|
||||
|
||||
def split_half_float_double(tensor_list):
|
||||
def split_by_dtype(tensor_list):
|
||||
"""
|
||||
Splits a list of PyTorch tensors into sublists based on their data type.
|
||||
|
||||
:param tensor_list: A list of PyTorch tensors.
|
||||
:type tensor_list: list[torch.Tensor]
|
||||
:return: A list of sublists, where each sublist contains tensors of a specific data type.
|
||||
:rtype: list[list[torch.Tensor]]
|
||||
"""
|
||||
dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
|
||||
buckets = []
|
||||
for i, dtype in enumerate(dtypes):
|
||||
for _, dtype in enumerate(dtypes):
|
||||
bucket = [t for t in tensor_list if t.type() == dtype]
|
||||
if bucket:
|
||||
buckets.append(bucket)
|
||||
|
Reference in New Issue
Block a user