[legacy] move communication and nn to legacy and refactor logger (#4671)

* [legacy] move communication to legacy (#4640)

* [legacy] refactor logger and clean up legacy code (#4654)

* [legacy] make logger independent of gpc

* [legacy] make optim independent of registry

* [legacy] move test engine to legacy

* [legacy] move nn to legacy (#4656)

* [legacy] move nn to legacy

* [checkpointio] fix save hf config

* [test] remove useless rpc pp test

* [legacy] fix nn init

* [example] skip tutorial hybrid parallel example

* [devops] test doc check

* [devops] test doc check
Author: Hongxin Liu
Authored: 2023-09-11 16:24:28 +08:00
Committed by: GitHub
Parent: 536397cc95
Commit: 554aa9592e
170 changed files with 781 additions and 758 deletions


@@ -1,12 +1,14 @@
from .activation_checkpoint import checkpoint
from .checkpointing import load_checkpoint, save_checkpoint
from .common import (
    _cast_float,
    clip_grad_norm_fp32,
    conditional_context,
    copy_tensor_parallel_attributes,
    count_zeros_fp32,
    disposable,
    ensure_path_exists,
    free_storage,
    is_ddp_ignored,
    is_dp_rank_0,
    is_model_parallel_parameter,
@@ -72,4 +74,6 @@ __all__ = [
    'disposable',
    'colo_set_cpu_memory_capacity',
    'colo_get_cpu_memory_capacity',
    '_cast_float',
    'free_storage',
]


@@ -470,3 +470,22 @@ def disposable(func: Callable) -> Callable:
            return func(*args, **kwargs)

    return wrapper


def free_storage(data: torch.Tensor) -> None:
    """Free underlying storage of a Tensor."""
    if data.storage().size() > 0:
        # Since we're modifying the Tensor's Storage directly, make sure the Tensor
        # is the sole occupant of the Storage.
        assert data.storage_offset() == 0
        data.storage().resize_(0)


def _cast_float(args, dtype: torch.dtype):
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        args = args.to(dtype)
    elif isinstance(args, (list, tuple)):
        args = type(args)(_cast_float(t, dtype) for t in args)
    elif isinstance(args, dict):
        args = {k: _cast_float(v, dtype) for k, v in args.items()}
    return args
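
The two helpers added above are small utilities: `_cast_float` recursively casts floating-point tensors inside nested containers, and `free_storage` releases a tensor's underlying storage without destroying the tensor object. Below is a minimal usage sketch, assuming both are re-exported from `colossalai.utils` as the updated `__all__` suggests; the tensor names and dtypes are illustrative only.

import torch

# Assumption: the helpers are re-exported from colossalai.utils, per the
# updated __all__ shown above.
from colossalai.utils import _cast_float, free_storage

# _cast_float walks nested lists/tuples/dicts and casts only floating-point
# tensors; integer tensors and non-tensor values pass through unchanged.
batch = {
    'hidden': torch.randn(2, 4, dtype=torch.float32),
    'ids': torch.arange(4),              # integer tensor, left untouched
    'extras': [torch.ones(3), 'tag'],    # non-tensor entries are kept as-is
}
half_batch = _cast_float(batch, torch.float16)
assert half_batch['hidden'].dtype == torch.float16
assert half_batch['ids'].dtype == torch.int64

# free_storage shrinks the tensor's backing storage to zero elements while
# keeping the tensor object alive, e.g. after its contents were offloaded.
buf = torch.empty(1024)
free_storage(buf)
assert buf.storage().size() == 0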


@@ -12,12 +12,10 @@ from torch.utils.data import DataLoader, Dataset, Sampler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.registry import DATA_SAMPLERS

T_co = TypeVar('T_co', covariant=True)


@DATA_SAMPLERS.register_module
class DataParallelSampler(Sampler):
    """A data sampler for distributed data parallelism.