mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -3,4 +3,4 @@ from .gradient_store import GradientStore
|
||||
from .parameter_store import ParameterStore
|
||||
from .tensor_bucket import TensorBucket
|
||||
|
||||
__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']
|
||||
__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"]
|
||||
|
@@ -3,7 +3,6 @@ from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class BaseStore:
|
||||
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
self._world_size = dist.get_world_size(group=torch_pg)
|
||||
self._local_rank = dist.get_rank(group=torch_pg)
|
||||
|
@@ -9,7 +9,6 @@ from .base_store import BaseStore
|
||||
|
||||
|
||||
class BucketStore(BaseStore):
|
||||
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
super().__init__(torch_pg)
|
||||
|
||||
@@ -38,8 +37,7 @@ class BucketStore(BaseStore):
|
||||
return self._num_elements_in_bucket
|
||||
|
||||
def reset_num_elements_in_bucket(self):
|
||||
"""Set the number of elements in bucket to zero.
|
||||
"""
|
||||
"""Set the number of elements in bucket to zero."""
|
||||
|
||||
self._num_elements_in_bucket = 0
|
||||
|
||||
@@ -54,7 +52,7 @@ class BucketStore(BaseStore):
|
||||
|
||||
self._param_list.append(param)
|
||||
self._padding_size.append(padding_size)
|
||||
self._num_elements_in_bucket += (param.numel() + padding_size)
|
||||
self._num_elements_in_bucket += param.numel() + padding_size
|
||||
self.current_group_id = group_id
|
||||
|
||||
# number of tensors in current bucket
|
||||
@@ -119,8 +117,7 @@ class BucketStore(BaseStore):
|
||||
return self.grad_to_param_mapping[id(grad)]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the bucket storage after reduction, only release the tensors have been reduced
|
||||
"""
|
||||
"""Reset the bucket storage after reduction, only release the tensors have been reduced"""
|
||||
cur_offset = self.offset_list.pop(0)
|
||||
self._param_list = self._param_list[cur_offset:]
|
||||
self._padding_size = self._padding_size[cur_offset:]
|
||||
|
@@ -1,13 +1,11 @@
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
from torch._utils import _flatten_dense_tensors
|
||||
|
||||
from .base_store import BaseStore
|
||||
|
||||
|
||||
class GradientStore(BaseStore):
|
||||
|
||||
def __init__(self, *args, partition_grad: bool = False):
|
||||
super().__init__(*args)
|
||||
"""
|
||||
|
@@ -5,7 +5,6 @@ from .base_store import BaseStore
|
||||
|
||||
|
||||
class ParameterStore(BaseStore):
|
||||
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
super().__init__(torch_pg)
|
||||
|
||||
|
@@ -2,7 +2,6 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
|
||||
class TensorBucket:
|
||||
|
||||
def __init__(self, size):
|
||||
self._max_size = size
|
||||
self._current_size = 0
|
||||
@@ -26,8 +25,7 @@ class TensorBucket:
|
||||
tensor_size = tensor.numel()
|
||||
|
||||
if not allow_oversize and self.will_exceed_max_size(tensor_size):
|
||||
msg = f"The param bucket max size {self._max_size} is exceeded" \
|
||||
+ f"by tensor (size {tensor_size})"
|
||||
msg = f"The param bucket max size {self._max_size} is exceeded" + f"by tensor (size {tensor_size})"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
self._bucket.append(tensor)
|
||||
|
Reference in New Issue
Block a user