Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-25 06:52:46 +00:00)
Commit f8a0e7fb01
@@ -6,7 +6,7 @@ All notable changes to this project will be documented in this file.

 ### Added

-- Unifed distributed layers
+- Unified distributed layers
 - MoE support
 - DevOps tools such as github action, code review automation, etc.
 - New project official website
@@ -33,4 +33,4 @@ The first beta version of Colossal-AI. Thanks to all contributors for the effort

 ### Added

 - Initial architecture of the system
 - Features such as tensor parallelism, gradient clipping, gradient accumulation
@@ -14,6 +14,7 @@
 [](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
 [](https://colossalai.readthedocs.io/en/latest/?badge=latest)
 [](https://codebeat.co/projects/github-com-hpcaitech-colossalai-main)
+[](https://huggingface.co/hpcai-tech)
 [](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
 [](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png)

@@ -14,8 +14,10 @@
 [](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
 [](https://colossalai.readthedocs.io/en/latest/?badge=latest)
 [](https://codebeat.co/projects/github-com-hpcaitech-colossalai-main)
+[](https://huggingface.co/hpcai-tech)
 [](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
 [](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png)

 | [English](README.md) | [中文](README-zh-Hans.md) |

@@ -1 +1 @@
-Subproject commit 62904e4ff2f3261c5469c773faa3d9307b6f16f4
+Subproject commit 9757a137495a8fc8b12133087cffe3e4a97ed2cb
@@ -1,8 +1,12 @@
 import torch
 from colossalai.registry import OPHOOKS
-from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.utils import get_current_device
+from colossalai.zero.shard_utils import BaseShardStrategy

 from ._base_ophook import BaseOpHook
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from typing import Optional


 @OPHOOKS.register_module
@@ -11,32 +15,50 @@ class ZeroHook(BaseOpHook):
     A hook to process sharded param for ZeRO method.
     """

-    def __init__(self, shard_strategy: BaseShardStrategy):
+    def __init__(self, shard_strategy: BaseShardStrategy, memstarts_collector: Optional[MemStatsCollector]):
         super().__init__()
         self.shard_strategy = shard_strategy
         # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
         self.computing_device = torch.device(f'cuda:{get_current_device()}')

+        self._memstarts_collector = memstarts_collector

     def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        tensor_list = []
+        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.gather([param.col_attr.data])
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.gather(tensor_list)
+        for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
+                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload

+        if self._memstarts_collector:
+            self._memstarts_collector.sample_memstats()

     def post_fwd_exec(self, module: torch.nn.Module, *args):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.shard([param.col_attr.data])
-            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.shard(tensor_list)
+        for param in module.parameters():
+            param.col_attr.remove_torch_payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        tensor_list = []
+        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.gather([param.col_attr.data])
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.gather(tensor_list)
+        for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
+                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload
             # Store local accumulated grad shard
             if param.grad is not None:
@@ -50,12 +72,17 @@ class ZeroHook(BaseOpHook):
                 # The grad here must be locally computed full grad in this backward pass
                 assert param.grad.shape == param.col_attr.data.origin_shape
             param.col_attr.bwd_count += 1
+        if self._memstarts_collector:
+            self._memstarts_collector.sample_memstats()

     def post_bwd_exec(self, module: torch.nn.Module, input):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.shard([param.col_attr.data])
-            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.shard(tensor_list)
+        for param in module.parameters():
+            param.col_attr.remove_torch_payload()

     def pre_iter(self):
         pass
@@ -10,7 +10,7 @@ from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory import report_memory_usage
 from .timer import MultiTimer, Timer
-#from .tensor_detector import TensorDetector
+from .tensor_detector import TensorDetector

 __all__ = [
     'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
@@ -1,60 +1,19 @@
 import torch
-from colossalai.utils.commons.singleton_meta import SingletonMeta
-from colossalai.zero.sharded_param import ShardedTensor
-
-from typing import Union
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer


-def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
-    if isinstance(t, ShardedTensor):
-        target = t.payload
-    else:
-        target = t
-    return target.numel() * target.element_size()
+def col_move_to_cpu(t: torch.Tensor):
+    assert isinstance(t, torch.Tensor)
+    if t.device.type == 'cpu':
+        return

+    ModelDataTracer().delete_tensor(t)
+    t.data = t.data.cpu()

-class ModelDataTracer(metaclass=SingletonMeta):
-    """
-    A singleton to trace model data usage during runtime.
-    """
-
-    def __init__(self) -> None:
-        self._cpu_usage = 0
-        self._cuda_usage = 0
-
-    def trace_tensor(self, t: torch.Tensor):
-        mem_use = col_tensor_mem_usage(t)
-        if t.device.type == 'cpu':
-            self._cpu_usage += mem_use
-        elif t.device.type == 'cuda':
-            self._cuda_usage += mem_use
-        else:
-            raise RuntimeError
-
-    def detach_tensor(self, t: torch.Tensor):
-        mem_use = col_tensor_mem_usage(t)
-        if t.device.type == 'cpu':
-            self._cpu_usage -= mem_use
-        elif t.device.type == 'cuda':
-            self._cuda_usage -= mem_use
-        else:
-            raise RuntimeError
-
-    @property
-    def cpu_usage(self):
-        return self._cpu_usage
-
-    @property
-    def cuda_usage(self):
-        return self._cuda_usage
-
-
-GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
-
-
-def col_allocate_payload(device: torch.device) -> torch.Tensor:
+def col_modeldata_allocate(device: torch.device) -> torch.Tensor:
     pass


-def col_release_payload(t: torch.Tensor):
+def col_modeldata_release(t: torch.Tensor):
     pass
@@ -6,7 +6,7 @@ from colossalai.utils import get_current_device
 import torch


-def _get_cuda_memory_used(device: torch.device) -> int:
+def get_cuda_memory_used(device: torch.device) -> int:
     """
     Get the free memory info of device.
     :param device: device id
@@ -87,7 +87,7 @@ class AsyncMemoryMonitor:
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
-                _get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
+                get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
             )
             sleep(self.interval)
         return max_usage
colossalai/utils/memory_tracer/commons.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from colossalai.zero.sharded_param import ShardedTensor
from typing import Union
import torch


def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
    if isinstance(t, ShardedTensor):
        target = t.payload
    else:
        target = t
    return target.numel() * target.element_size()
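As a sanity check on what `col_tensor_mem_usage` reports, a minimal standalone sketch of the same arithmetic for a plain `torch.Tensor` (the `ShardedTensor` branch simply uses the payload tensor instead):

```python
import torch

# 1024 x 1024 float32 tensor: 1_048_576 elements * 4 bytes each = 4_194_304 bytes (4 MiB).
t = torch.zeros(1024, 1024, dtype=torch.float32)
print(t.numel() * t.element_size())  # 4194304 -- the value col_tensor_mem_usage(t) returns
```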
colossalai/utils/memory_tracer/memstats_collector.py (new file, 81 lines)
@@ -0,0 +1,81 @@
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from .async_memtracer import get_cuda_memory_used
from colossalai.utils import get_current_device

import torch


class SamplingCounter:

    def __init__(self) -> None:
        self._samplint_cnt = 0

    def advance(self):
        self._samplint_cnt += 1

    @property
    def sampling_cnt(self):
        return self._samplint_cnt

    def reset(self):
        self._samplint_cnt = 0


class MemStatsCollector:

    def __init__(self) -> None:
        """
        Collecting Memory Statistics.
        It has two phases.
        1. Collection Phase: collect memory usage statistics
        2. Runtime Phase: do not collect statistics.
        """
        self._sampling_cnter = SamplingCounter()
        self._model_data_cuda = []
        self._overall_cuda = []

        # TODO(jiaruifang) Now no cpu mem stats collecting
        self._model_data_cpu = []
        self._overall_cpu = []

        self._start_flag = False

    def start_collection(self):
        self._start_flag = True

    def finish_collection(self):
        self._start_flag = False

    def sample_memstats(self) -> None:
        """
        Sampling memory statistics.
        Record the current model data CUDA memory usage as well as system CUDA memory usage.
        """
        if self._start_flag:
            sampling_cnt = self._sampling_cnter.sampling_cnt
            assert sampling_cnt == len(self._overall_cuda)
            self._model_data_cuda.append(ModelDataTracer().cuda_usage)
            self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
            self._sampling_cnter.advance()

    def fetch_memstats(self) -> (int, int):
        """
        returns cuda usage of model data and overall cuda usage.
        """
        sampling_cnt = self._sampling_cnter.sampling_cnt
        if len(self._model_data_cuda) < sampling_cnt:
            raise RuntimeError
        return (self._model_data_cuda[sampling_cnt], self._overall_cuda[sampling_cnt])

    def reset_sampling_cnter(self) -> None:
        self._sampling_cnter.reset()

    def clear(self) -> None:
        self._model_data_cuda = []
        self._overall_cuda = []

        self._model_data_cpu = []
        self._overall_cpu = []

        self._start_flag = False
        self._sampling_cnter.reset()
colossalai/utils/memory_tracer/model_data_memtracer.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from colossalai.utils.commons.singleton_meta import SingletonMeta
from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage
import torch


class ModelDataTracer(metaclass=SingletonMeta):
    """
    A singleton to trace model data usage during runtime.
    We have to trigger our API (trace_tensor, detach_tensor) when do model-data memory operation,
    including allocation, releasing and moving.

    NOTE() now the class only trace cuda memory usage
    """

    def __init__(self) -> None:
        self._cuda_usage = 0

    def add_tensor(self, t: torch.Tensor):
        assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
        mem_use = col_tensor_mem_usage(t)
        self._cuda_usage += mem_use

    def delete_tensor(self, t: torch.Tensor):
        assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
        mem_use = col_tensor_mem_usage(t)
        self._cuda_usage -= mem_use

    @property
    def cpu_usage(self):
        return self._cpu_usage

    @property
    def cuda_usage(self):
        return self._cuda_usage
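A minimal sketch of why `ZeroHook` and `ZeroInitContext` can each construct `ModelDataTracer()` and still update one shared counter, assuming `SingletonMeta` behaves as a per-process singleton (per the NOTE above, the tracer keeps a single CUDA counter and does not inspect the tensor's device):

```python
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
import torch

tracer_a = ModelDataTracer()
tracer_b = ModelDataTracer()
assert tracer_a is tracer_b          # SingletonMeta: every call returns the same instance

t = torch.zeros(1024, dtype=torch.float32)   # 1024 * 4 bytes = 4096 bytes
tracer_a.add_tensor(t)
print(tracer_b.cuda_usage)           # 4096, visible through either handle
tracer_b.delete_tensor(t)
print(tracer_a.cuda_usage)           # back to 0
```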
colossalai/utils/memory_tracer/test_memstats_collector.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
import torch


def test_mem_collector():
    collector = MemStatsCollector()

    collector.start_collection()

    a = torch.randn(10).cuda()

    # sampling at time 0
    collector.sample_memstats()

    m_a = torch.randn(10).cuda()
    ModelDataTracer().add_tensor(m_a)
    b = torch.randn(10).cuda()

    # sampling at time 1
    collector.sample_memstats()

    a = b

    # sampling at time 2
    collector.sample_memstats()

    collector.finish_collection()
    collector.reset()

    # do nothing after collection, just advance sampling cnter
    collector.sample_memstats()
    collector.sample_memstats()

    cuda_use, overall_use = collector.fetch_memstats()
    print(cuda_use, overall_use)

    print(collector._model_data_cuda)
    print(collector._overall_cuda)


if __name__ == '__main__':
    test_mem_collector()
@@ -6,20 +6,25 @@ from torch.autograd.profiler import profile
 import torch.distributed as dist
 from torch.distributed import ReduceOp
 from colossalai.utils import get_current_device
-from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwith
+from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth
 from typing import List, Optional


 def _get_code_location(depth: int):
-    ret = ""
-    length = len(inspect.stack())
-    for i in range(3, min(length, depth + 1)):
+    ret = []
+    length = min(len(inspect.stack()), depth + 1)
+    for i in range(3, length):
         upper_frame = inspect.stack()[i]
         function_name = inspect.stack()[i - 1].function
-        info = upper_frame.filename + "(" + str(upper_frame.lineno) + "): " + function_name + "\n"
-        ret += info
+        ret.append(upper_frame.filename)
+        ret.append('(')
+        ret.append(str(upper_frame.lineno))
+        ret.append('): ')
+        ret.append(function_name)
+        if i != length - 1:
+            ret.append('\n')

-    return ret
+    return ''.join(ret)


 torch_all_reduce = dist.all_reduce
@@ -88,41 +93,52 @@ class CommProfiler(BaseProfiler):
         dist.reduce = torch_reduce

     def to_tensorboard(self, writer):
-        writer.add_text(tag="Collective Communication", text_string=self.result_list("\n\n"))
+        writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n"))

     def to_file(self, filename: Path):
         with open(filename, "w") as f:
-            f.write(self.result_list())
+            f.write(self.result_str())

     def show(self):
-        print(self.result_list())
+        print(self.result_str())

-    def result_list(self, sep: str = "\n"):
+    def result_str(self, sep: str = "\n"):
         res = []

-        def append(s: str):
-            res.append(s)
+        def append(s: str = None):
+            if s is not None:
+                res.append(s)
             res.append(sep)

         if self.warn_flag:
             append("Warnning: there exists multiple communication operations in the same time. As a result, "
                    "the profiling result is not accurate.")

+        if self.total_cuda_time == 0:
+            return "No collective communication has been called yet!"
+
         append("Collective communication profiling result:")
         append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
-        append("average bandwith: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time)))
+        append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time)))
         append("total number of calls: {}".format(self.total_count))
-        append("All events:\n----------------------------------------")
+        append("All events:")
+
+        seperation = '-' * 74
+        row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
+
+        append(seperation)
+        append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
+        append(seperation)

         show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
         for location, event in show_list:
             append(location)
-            append("self cuda time: {}".format(_format_time(event.self_cuda_time)))
-            append("{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0))
-            append("self communication volme: {}".format(_format_memory(event.self_comm_vol)))
-            append("average bandwith: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time)))
-            append("number of calls: {}".format(event.self_count))
-            append("----------------------------------------")
+            append(
+                row_format.format('', _format_time(event.self_cuda_time),
+                                  '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0),
+                                  _format_memory(event.self_comm_vol),
+                                  _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count))
+            append()

         return ''.join(res)
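For readers unfamiliar with the centered-column format specifiers used above, a small standalone sketch of what `row_format` produces (the values are made up for illustration):

```python
row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
header = row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls')
row = row_format.format('', '1.2 ms', '35.0%', '64.0 MB', '52.1 GB/s', 8)
print('-' * 74)
print(header)
print('-' * 74)
print(row)
# '^' centers each value in a fixed-width column, so every row lines up under the header.
```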
@@ -1,6 +1,6 @@
 from pathlib import Path
 from torch.autograd.profiler import profile
-from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwith
+from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth
 from typing import List
@@ -24,6 +24,7 @@ def _reduce_location(locations: List[str]) -> str:
     for lo in locations:
         ret.append(lo)
         ret.append("\n")
+    ret = ret[:-1]
     return ''.join(ret)
@@ -48,18 +49,23 @@ class PcieProfiler(BaseProfiler):
     TODO: Merge pcie profiler into communication profiler
     """

-    def __init__(self,
-                 dtype: str = "fp32",
-                 depth: int = 1,
-                 total_count: int = 0,
-                 total_pcie_vol: int = 0,
-                 total_cuda_time: int = 0):
+    def __init__(self, dtype: str = "fp32", depth: int = 1):
         super().__init__(profiler_name="Pcie", priority=10)
         self.depth = depth
         self.data_size = _get_size(dtype)
-        self.total_count = total_count
-        self.total_pcie_vol = total_pcie_vol
-        self.total_cuda_time = total_cuda_time
+        self.h2d_count = 0
+        self.h2d_time = 0
+        self.d2h_count = 0
+        self.d2h_time = 0
+
+        self.ops_record = dict()
+        self.profiler = None
+
+    def reset(self):
+        self.h2d_count = 0
+        self.h2d_time = 0
+        self.d2h_count = 0
+        self.d2h_time = 0

         self.ops_record = dict()
         self.profiler = None
@@ -81,51 +87,62 @@ class PcieProfiler(BaseProfiler):
         for event in events:
             if event.name == "aten::copy_":
                 t_shape = event.input_shapes[0]
-                if len(t_shape) == 0 or event.cuda_time_total == 0:
+                if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0:
                     continue
                 current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total)
-                self.total_count += current_comm_event.count
-                self.total_pcie_vol += current_comm_event.pcie_vol
-                self.total_cuda_time += current_comm_event.cuda_time
                 code_location = _reduce_location(event.stack[:self.depth])
                 if code_location in self.ops_record:
                     self.ops_record[code_location].add(current_comm_event)
                 else:
                     self.ops_record[code_location] = current_comm_event
+            elif 'Memcpy HtoD' in event.name:
+                self.h2d_count += 1
+                self.h2d_time += event.cuda_time_total
+            elif 'Memcpy DtoH' in event.name:
+                self.d2h_count += 1
+                self.d2h_time += event.cuda_time_total

         self.profiler = None

     def to_tensorboard(self, writer):
-        writer.add_text(tag="Data Transmission", text_string=self.result_list("\n\n"))
+        writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n"))

     def to_file(self, filename: Path):
         with open(filename, "w") as f:
-            f.write(self.result_list())
+            f.write(self.result_str())

     def show(self):
-        print(self.result_list())
+        print(self.result_str())

-    def result_list(self, sep: str = "\n"):
+    def result_str(self, sep: str = "\n"):
         res = []

-        def append(s: str):
-            res.append(s)
+        def append(s: str = None):
+            if s is not None:
+                res.append(s)
             res.append(sep)

         append("Pcie profiling result:")
-        append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
-        append("average bandwith: {}".format(_format_bandwith(self.total_pcie_vol, self.total_cuda_time)))
-        append("total number of calls: {}".format(self.total_count))
-        append("All events:\n----------------------------------------")
+        append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time)))
+        append("number of transmission (CPU -> GPU): {}".format(self.h2d_count))
+        append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time)))
+        append("number of transmission (GPU -> CPU): {}".format(self.d2h_count))
+
+        append("Possible data transmission events in PCIE:")
+
+        seperation = '-' * 62
+        row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2
+
+        append(seperation)
+        append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
+        append(seperation)

         show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
         for location, event in show_list:
             append(location)
-            append("cuda time: {}".format(_format_time(event.cuda_time)))
-            append("{:.1f}% of total pcie time".format(event.cuda_time / self.total_cuda_time * 100.0))
-            append("pcie volme: {}".format(_format_memory(event.pcie_vol)))
-            append("average bandwith: {}".format(_format_bandwith(event.pcie_vol, event.cuda_time)))
-            append("number of calls: {}".format(event.count))
-            append("----------------------------------------")
+            append(
+                row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol),
+                                  _format_bandwidth(event.pcie_vol, event.cuda_time), event.count))
+            append()

         return ''.join(res)
@@ -32,7 +32,7 @@ def _format_memory(nbytes):
     return str(nbytes) + ' B'


-def _format_bandwith(volme: float or int, time_us: int):
+def _format_bandwidth(volme: float or int, time_us: int):
     sec_div_mb = (1000.0 / 1024.0)**2
     mb_per_sec = volme / time_us * sec_div_mb
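To make the unit conversion concrete, a worked example of the two lines shown above (the rest of `_format_bandwidth` falls outside this hunk, so the wrapper below is illustrative only): the profilers record volume in bytes and time in microseconds, and `(1000.0 / 1024.0)**2` is exactly `1e6` microseconds-per-second divided by `2**20` bytes-per-MB.

```python
def bandwidth_mb_per_sec(volme, time_us):
    # Same arithmetic as the hunk above: bytes per microsecond -> MB per second.
    sec_div_mb = (1000.0 / 1024.0)**2
    return volme / time_us * sec_div_mb

# Moving 1 MiB (1048576 bytes) in 1000 us is exactly 1000 MB/s:
print(bandwidth_mb_per_sec(1048576, 1000))  # 1000.0
```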
colossalai/utils/tensor_detector/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .tensor_detector import TensorDetector
colossalai/utils/tensor_detector/readme.md (new file, 128 lines)
@@ -0,0 +1,128 @@
# Tensor Detector

This tool helps you detect tensors on both CPU and GPU. Note that there will always be some strange tensors on the CPU, including the rng state of PyTorch.

## Example

An example is worth a thousand words.

The code below defines a simple MLP module, with which we will show you how to use the tool.

```python
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(64, 8),
                                 nn.ReLU(),
                                 nn.Linear(8, 32))
    def forward(self, x):
        return self.mlp(x)
```

And here is how to use the tool.

```python
from colossalai.utils import TensorDetector

# create random data
data = torch.rand(64, requires_grad=True).cuda()
data.retain_grad()
# create the module
model = MLP().cuda()
# create the detector
# by passing the model to the detector, it can distinguish module parameters from common tensors
detector = TensorDetector(include_cpu=False, module=model)
detector.detect()

out = model(data)

detector.detect()

loss = out.sum()
loss.backward()

detector.detect()
```

The comments on the right of the output below were added to aid understanding.

Note that the total `Mem` of all the tensors and parameters is not equal to the total GPU memory allocated that is reported at the bottom of each table. PyTorch's memory management is complicated, and for large models it is impossible to account for every byte exactly.

**The tensors are not printed in exactly the order in which they were created, but the two orders are close.**

```bash
------------------------------------------------------------------------------------------------------------
   Tensor                        device         shape      grad           dtype        Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                        cuda:0         (64,)      True   torch.float32      256 B    # data
+  mlp.0.weight                  cuda:0       (8, 64)      True   torch.float32     2.0 KB
+  mlp.0.bias                    cuda:0          (8,)      True   torch.float32       32 B
+  mlp.2.weight                  cuda:0       (32, 8)      True   torch.float32     1.0 KB
+  mlp.2.bias                    cuda:0         (32,)      True   torch.float32      128 B
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 27
Totle GPU Memery Allocated on cuda:0 is 4.5 KB
------------------------------------------------------------------------------------------------------------


------------------------------------------------------------------------------------------------------------
   Tensor                        device         shape      grad           dtype        Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                        cuda:0          (8,)      True   torch.float32       32 B    # activation
+  Tensor                        cuda:0         (32,)      True   torch.float32      128 B    # output
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 30
Totle GPU Memery Allocated on cuda:0 is 5.5 KB
------------------------------------------------------------------------------------------------------------


------------------------------------------------------------------------------------------------------------
   Tensor                        device         shape      grad           dtype        Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                        cuda:0            ()      True   torch.float32        4 B    # loss
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 32
Totle GPU Memery Allocated on cuda:0 is 6.0 KB
------------------------------------------------------------------------------------------------------------


------------------------------------------------------------------------------------------------------------
   Tensor                        device         shape      grad           dtype        Mem
------------------------------------------------------------------------------------------------------------
+  Tensor (with grad)            cuda:0         (64,)      True   torch.float32      512 B    # data with grad
+  mlp.0.weight (with grad)      cuda:0       (8, 64)      True   torch.float32     4.0 KB    # for use data.retain_grad()
+  mlp.0.bias (with grad)        cuda:0          (8,)      True   torch.float32       64 B
+  mlp.2.weight (with grad)      cuda:0       (32, 8)      True   torch.float32     2.0 KB
+  mlp.2.bias (with grad)        cuda:0         (32,)      True   torch.float32      256 B

-  mlp.0.weight                  cuda:0       (8, 64)      True   torch.float32     2.0 KB
-  mlp.0.bias                    cuda:0          (8,)      True   torch.float32       32 B
-  mlp.2.weight                  cuda:0       (32, 8)      True   torch.float32     1.0 KB
-  mlp.2.bias                    cuda:0         (32,)      True   torch.float32      128 B
-  Tensor                        cuda:0         (64,)      True   torch.float32      256 B
-  Tensor                        cuda:0          (8,)      True   torch.float32       32 B    # deleted activation
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 34
Totle GPU Memery Allocated on cuda:0 is 10.0 KB
------------------------------------------------------------------------------------------------------------


------------------------------------------------------------------------------------------------------------
   Tensor                        device         shape      grad           dtype        Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                        cuda:0         (64,)     False   torch.float32      256 B
+  Tensor                        cuda:0       (8, 64)     False   torch.float32     2.0 KB
+  Tensor                        cuda:0          (8,)     False   torch.float32       32 B
+  Tensor                        cuda:0       (32, 8)     False   torch.float32     1.0 KB
+  Tensor                        cuda:0         (32,)     False   torch.float32      128 B
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 36
Totle GPU Memery Allocated on cuda:0 is 14.0 KB
------------------------------------------------------------------------------------------------------------
```

## Reference

This tool was inspired by https://github.com/Stonesjtu/pytorch_memlab/blob/master/pytorch_memlab/mem_reporter.py
and https://github.com/Oldpan/Pytorch-Memory-Utils
colossalai/utils/tensor_detector/tensor_detector.py (new file, 190 lines)
@@ -0,0 +1,190 @@
import gc
import inspect
import torch
import torch.nn as nn
from typing import Optional
from collections import defaultdict


LINE_WIDTH = 108
LINE = '-' * LINE_WIDTH + '\n'


class TensorDetector():

    def __init__(self,
                 show_info: bool = True,
                 log: str = None,
                 include_cpu: bool = False,
                 module: Optional[nn.Module] = None
                 ):
        """This class is an detector to detect tensor on different devices.

        :param show_info: whether to print the info on screen, default True
        :type show_info: bool
        :param log: the file name to save the log
        :type log: str
        :param include_cpu: whether to detect tensor on cpu, default False
        :type include_cpu: bool
        :param module: when sending an `nn.Module` it, the detector can name the tensors detected better
        :type module: Optional[nn.Module]
        """
        self.show_info = show_info
        self.log = log
        self.include_cpu = include_cpu
        self.tensor_info = defaultdict(list)
        self.saved_tensor_info = defaultdict(list)
        self.order = []
        self.detected = []
        self.devices = []
        self.info = ""

        self.module = module
        if isinstance(module, nn.Module):
            # if module is an instance of nn.Module, we can name the parameter with its real name
            for name, param in module.named_parameters():
                self.tensor_info[id(param)].append(name)
                self.tensor_info[id(param)].append(param.device)
                self.tensor_info[id(param)].append(param.shape)
                self.tensor_info[id(param)].append(param.requires_grad)
                self.tensor_info[id(param)].append(param.dtype)
                self.tensor_info[id(param)].append(self.get_tensor_mem(param))

    def get_tensor_mem(self, tensor):
        # calculate the memory occupied by a tensor
        memory_size = tensor.element_size() * tensor.storage().size()
        if (tensor.is_leaf or tensor.retains_grad) and tensor.grad is not None:
            grad_memory_size = tensor.grad.element_size() * tensor.grad.storage().size()
            memory_size += grad_memory_size
        return self.mem_format(memory_size)

    def mem_format(self, real_memory_size):
        # format the tensor memory into a reasonal magnitude
        if real_memory_size >= 2 ** 30:
            return str(real_memory_size / (2 ** 30)) + ' GB'
        if real_memory_size >= 2 ** 20:
            return str(real_memory_size / (2 ** 20)) + ' MB'
        if real_memory_size >= 2 ** 10:
            return str(real_memory_size / (2 ** 10)) + ' KB'
        return str(real_memory_size) + ' B'

    def collect_tensors_state(self):
        for obj in gc.get_objects():
            if torch.is_tensor(obj):
                # skip cpu tensor when include_cpu is false and the tensor we have collected before
                if (not self.include_cpu) and obj.device == torch.device('cpu'):
                    continue
                self.detected.append(id(obj))
                # skip paramters we had added in __init__ when module is an instance of nn.Module for the first epoch
                if id(obj) not in self.tensor_info:

                    name = type(obj).__name__
                    # after backward, we want to update the records, to show you the change
                    if isinstance(self.module, nn.Module) and name == 'Parameter':
                        if obj.grad is not None:
                            # with grad attached
                            for par_name, param in self.module.named_parameters():
                                if param.requires_grad and param.grad.equal(obj.grad):
                                    name = par_name + ' (with grad)'
                        else:
                            # with no grad attached
                            # there will be no new paramters created during running
                            # so it must be in saved_tensor_info
                            continue
                    # we can also marked common tensors as tensor(with grad)
                    if name == 'Tensor' and (obj.is_leaf or obj.retains_grad):
                        if obj.grad is not None:
                            name = name + ' (with grad)'
                    # in fact, common tensor have no grad
                    # unless you set retain_grad()
                    if id(obj) in self.saved_tensor_info.keys() and name == self.saved_tensor_info[id(obj)][0]:
                        continue

                    self.tensor_info[id(obj)].append(name)
                    self.tensor_info[id(obj)].append(obj.device)
                    self.tensor_info[id(obj)].append(obj.shape)
                    self.tensor_info[id(obj)].append(obj.requires_grad)
                    self.tensor_info[id(obj)].append(obj.dtype)
                    self.tensor_info[id(obj)].append(self.get_tensor_mem(obj))
                # recorded the order we got the tensor
                # by this we can guess the tensor easily
                # it will record every tensor updated this turn
                self.order.append(id(obj))
                # recorded all different devices
                if obj.device not in self.devices:
                    self.devices.append(obj.device)

    def print_tensors_state(self):
        template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}'
        self.info += LINE
        self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem')
        self.info += '\n'
        self.info += LINE

        # if a tensor updates this turn, and was recorded before
        # it should be updated in the saved_tensor_info as well
        outdated = [x for x in self.saved_tensor_info.keys() if x in self.order]
        minus = [x for x in self.saved_tensor_info.keys() if x not in self.detected]
        minus = outdated + minus
        if len(self.order) > 0:
            for tensor_id in self.order:
                self.info += template_format.format('+',
                                                    str(self.tensor_info[tensor_id][0]),
                                                    str(self.tensor_info[tensor_id][1]),
                                                    str(tuple(self.tensor_info[tensor_id][2])),
                                                    str(self.tensor_info[tensor_id][3]),
                                                    str(self.tensor_info[tensor_id][4]),
                                                    str(self.tensor_info[tensor_id][5]))
                self.info += '\n'
        if len(self.order) > 0 and len(minus) > 0:
            self.info += '\n'
        if len(minus) > 0:
            for tensor_id in minus:
                self.info += template_format.format('-',
                                                    str(self.saved_tensor_info[tensor_id][0]),
                                                    str(self.saved_tensor_info[tensor_id][1]),
                                                    str(tuple(self.saved_tensor_info[tensor_id][2])),
                                                    str(self.saved_tensor_info[tensor_id][3]),
                                                    str(self.saved_tensor_info[tensor_id][4]),
                                                    str(self.saved_tensor_info[tensor_id][5]))
                self.info += '\n'
                # deleted the updated tensor
                self.saved_tensor_info.pop(tensor_id)

        # trace where is the detect()
        locate_info = inspect.stack()[2]
        locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno)

        self.info += LINE
        self.info += f"Detect Location: {locate_msg}\n"
        for device in self.devices:
            if device == torch.device('cpu'):
                continue
            gpu_mem_alloc = self.mem_format(torch.cuda.memory_allocated(device))
            self.info += f"Totle GPU Memery Allocated on {device} is {gpu_mem_alloc}\n"
        self.info += LINE
        self.info += '\n\n'
        if self.show_info:
            print(self.info)
        if self.log is not None:
            with open(self.log + '.log', 'a') as f:
                f.write(self.info)

    def detect(self, include_cpu = False):
        self.include_cpu = include_cpu
        self.collect_tensors_state()
        self.print_tensors_state()
        self.saved_tensor_info.update(self.tensor_info)
        self.tensor_info.clear()
        self.order = []
        self.detected = []
        self.info = ""

    def close(self):
        self.saved_tensor_info.clear()
        self.module = None
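A quick standalone restatement of the `mem_format` thresholds above, which also explains the values in the `Mem` column of the readme's sample output:

```python
def mem_format(real_memory_size):
    # Same cascade as TensorDetector.mem_format: pick the largest power-of-two unit that fits.
    if real_memory_size >= 2 ** 30:
        return str(real_memory_size / (2 ** 30)) + ' GB'
    if real_memory_size >= 2 ** 20:
        return str(real_memory_size / (2 ** 20)) + ' MB'
    if real_memory_size >= 2 ** 10:
        return str(real_memory_size / (2 ** 10)) + ' KB'
    return str(real_memory_size) + ' B'

print(mem_format(256))      # '256 B'   e.g. a (64,) float32 tensor
print(mem_format(2048))     # '2.0 KB'  e.g. the (8, 64) weight of mlp.0
print(mem_format(4194304))  # '4.0 MB'
```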
@@ -3,10 +3,11 @@ import functools
 import torch
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer


 # Inserts _post_init_method at the end of init method


 # for all sub classes of torch.nn.Module
 class InsertPostInitMethodToModuleSubClasses(object):

@@ -152,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

         if self.shard_param:
             self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
-            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._data_sharded_tensor.payload)
+            ModelDataTracer().add_tensor(param.col_attr._data_sharded_tensor.payload)
         if param.col_attr.grad and self.shard_grad:
             self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
-            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._grad_sharded_tensor.payload)
+            ModelDataTracer().add_tensor(param.col_attr._grad_sharded_tensor.payload)
@@ -1,4 +1,5 @@
-from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy
-from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
+from .base_shard_strategy import BaseShardStrategy
+from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
+from .tensor_shard_strategy import TensorShardStrategy

-__all__ = ['BaseShardStrategy', 'TensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py (new file, 41 lines)
@@ -0,0 +1,41 @@
from typing import List

import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten

from .tensor_shard_strategy import TensorShardStrategy


class BucketTensorShardStrategy(TensorShardStrategy):

    def gather(self, tensor_list: List[ShardedTensor]):
        tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
        if len(tensor_list) == 0:
            return
        target_device = tensor_list[0].device
        dtype = tensor_list[0].dtype
        buffer_list: List[torch.Tensor] = []
        tensor_numels = [t.payload.numel() for t in tensor_list]
        buffer_size = sum(tensor_numels)
        for i in range(self.world_size):
            if i == self.local_rank:
                buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
                # Release payload here, to decrease peak memory usage
                for t in tensor_list:
                    t.reset_payload(None)
            else:
                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
        dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
        # Move to target device before splitting buffer
        # Ensure we utilize maximum PCIE bandwidth
        buffer_list = [buffer.to(target_device) for buffer in buffer_list]
        offset = 0
        for i, t in enumerate(tensor_list):
            gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
            gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
            t.reset_payload(gathered_payload)
            t.is_sharded = False
            offset += tensor_numels[i]
|
|||||||
from colossalai.zero.sharded_param import ShardedParamV2
|
from colossalai.zero.sharded_param import ShardedParamV2
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
||||||
|
from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
|
||||||
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
|
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
|
||||||
get_gradient_predivide_factor)
|
get_gradient_predivide_factor)
|
||||||
|
|
||||||
@@ -33,7 +34,8 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  offload_config: Optional[dict] = None,
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 shard_param: bool = True):
+                 shard_param: bool = True,
+                 use_memory_tracer: bool = False):
         r"""
         A demo to reconfigure zero1 shared_model.
         Currently do not consider the Optimizer States.
@@ -59,8 +61,16 @@ class ShardedModelV2(nn.Module):
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.data])

+        # Init Memory Statistics Collector
+        self._use_memory_tracer = use_memory_tracer
+        if self._use_memory_tracer:
+            self._memstats_collector = MemStatsCollector()
+        else:
+            self._memstats_collector = None
+        self._iter_cnter = 0
+
         # Register hooks
-        register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy)])
+        register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy, self._memstats_collector)])
         self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
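A hedged usage sketch of the new flag. The constructor's leading arguments are abbreviated since they are not part of this hunk, and the attribute access at the end mirrors the internals shown above rather than a public API; it assumes a CUDA device and an initialized ColossalAI / torch.distributed environment.

```python
# module: an nn.Module; shard_strategy: a BaseShardStrategy instance (both assumed).
model = ShardedModelV2(module, shard_strategy, use_memory_tracer=True)

out = model(data)        # first forward starts MemStatsCollector collection (see forward() below)
out.sum().backward()     # first backward finishes collection and resets the sampling counter

collector = model._memstats_collector
print(collector._model_data_cuda)   # model-data CUDA usage at each ZeroHook sampling point
print(collector._overall_cuda)      # overall CUDA usage at the same sampling points
```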
@@ -84,6 +94,9 @@ class ShardedModelV2(nn.Module):
         return self._cpu_offload

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+        if self._iter_cnter == 0 and self._memstats_collector:
+            # the opeartion will affect the flag in ZeroHook
+            self._memstats_collector.start_collection()
         args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         return outputs
@@ -98,6 +111,12 @@ class ShardedModelV2(nn.Module):

     @torch.no_grad()
     def _final_backward_hook(self) -> None:
+        if self._iter_cnter == 0 and self._memstats_collector:
+            self._memstats_collector.finish_collection()
+        if self._memstats_collector:
+            self._memstats_collector.reset_sampling_cnter()
+        self._iter_cnter += 1
+
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -185,8 +204,10 @@ class ShardedModelV2(nn.Module):
         reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)

         # Maybe offload
+        # TODO() optimize GPU->CPU bandwidth utilization
         if self._cpu_offload:
-            reduced_grad.data = reduced_grad.data.cpu()
+            col_move_to_cpu(reduced_grad)
+            # reduced_grad.data = reduced_grad.data.cpu()

         if param.col_attr.grad is None:
             param.col_attr.grad = reduced_grad.data
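The replaced line moved the reduced gradient to the host with a plain `.cpu()` call; `col_move_to_cpu` is the helper this hunk switches to, and its implementation is not part of the diff shown here. As a rough mental model only (an assumption, not ColossalAI's actual helper), such an in-place move can be pictured like this:

    import torch

    def move_to_cpu_inplace(t: torch.Tensor) -> None:
        # Replace the tensor's payload with a CPU copy while keeping the Python
        # object (and any references to it) intact -- roughly what the old
        # `reduced_grad.data = reduced_grad.data.cpu()` line did.
        if t.device.type != 'cpu':
            t.data = t.data.cpu()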
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional, Union

 import torch
 import torch.distributed as dist
@@ -15,7 +15,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
+from typing import Type, Any
 from ._utils import has_inf_or_nan


@@ -27,8 +27,8 @@ class OptimState(Enum):
 class ShardedOptimizerV2(ColossalaiOptimizer):

     def __init__(self,
-                 optimizer: Optimizer,
                  sharded_model: ShardedModelV2,
+                 optimizer_class: Type[Optimizer],
                  shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
@@ -39,9 +39,34 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
-                 mp_process_group: Optional[ProcessGroup] = None) -> None:
+                 mp_process_group: Optional[ProcessGroup] = None,
+                 **defaults: Any) -> None:
+        """
+        :param sharded_model: A sharded model initialized by class ShardedModelV2
+        :type sharded_model: sharded_model
+
+        :param optimizer_class: A type of Optimizer
+        :type optimizer_class: Type[Optimizer]
+
+        :param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
+        :type shard_strategy: BaseShardStrategy
+
+        :param cpu_offload: is offloading the optimizer states to CPU.
+        :type cpu_offload: bool
+
+        :param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
+        :type shard_strategy: BaseShardStrategy
+        :**defaults: any trailing arguments, which are forwarded to the local optimizer.
+        :type defaults: dict()
+        """
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
-        super().__init__(optimizer)
+        self._optim_defaults = defaults
+        # initialize the M, V as zeros tensors and initialize param fp32 from sharded_model.parameters()
+
+        self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
+
+        super().__init__(self.optimizer)
         self.shard_strategy = shard_strategy
         self.model: ShardedModelV2 = sharded_model
         if cpu_offload and not sharded_model.cpu_offload:
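With this change the caller no longer builds the optimizer itself; the sharded optimizer now constructs it internally from an optimizer class plus trailing keyword arguments. A minimal call-site sketch under the new signature, mirroring the test updates later in this commit (the choice of Adam and the learning rate are only example values, and `zero_model` is assumed to be a ShardedModelV2 built as shown earlier):

    from torch.optim import Adam
    from colossalai.zero.shard_utils import TensorShardStrategy
    from colossalai.zero.sharded_optim import ShardedOptimizerV2

    shard_strategy = TensorShardStrategy()
    sharded_optim = ShardedOptimizerV2(zero_model,        # the sharded model comes first
                                       Adam,              # an optimizer class, not an instance
                                       shard_strategy,
                                       cpu_offload=False,
                                       initial_scale=2**5,
                                       lr=1e-3)           # forwarded to Adam via **defaults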
@@ -65,7 +90,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}

-        for group in optimizer.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
                 is_param_sharded = p.col_attr.data.is_sharded
@@ -118,7 +143,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     # We have to use `copy_payload` instead of `reset_payload`
                     # Since p.data is fp32 and p.col_attr.data is fp16

-                    # TODO() optimize this line
+                    # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                     p.col_attr.data.copy_payload(p.data)

                 if not is_param_sharded:
examples
@@ -1 +1 @@
-Subproject commit d50ef2db51e7d02ed3f7e9de13f9af86b04eaae9
+Subproject commit 5345187ad55e8c80c111e0c5f7ad9b9241e8f913
@@ -74,8 +74,5 @@ def get_training_components():
                                       sequence_length=sequence_length,
                                       is_distrbuted=True)

-    def get_optim(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = None
-    return bert_model_builder, trainloader, testloader, get_optim, criterion
+    return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
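This and the following component hunks change the registry contract: instead of a local builder closure, the fourth return value is now the optimizer class itself, and the caller binds the parameters and hyper-parameters. The caller-side pattern, as updated in the engine and trainer tests later in this commit (the learning rate is only an example value, and the snippet assumes `get_training_components` from above is importable):

    model_builder, trainloader, testloader, optimizer_class, criterion = get_training_components()
    model = model_builder()
    optimizer = optimizer_class(model.parameters(), lr=1e-3)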
@@ -49,8 +49,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()

-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
@@ -43,8 +43,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()

-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
@@ -29,8 +29,5 @@ def get_resnet_training_components():
     trainloader = get_cifar10_dataloader(train=True)
     testloader = get_cifar10_dataloader(train=False)

-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
-
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
@@ -19,11 +19,11 @@ def run_train():
     # FIXME: test bert
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

         model = model_builder(checkpoint=False)
         engine, train_dataloader, *args = colossalai.initialize(model=model,
-                                                                 optimizer=optimizer_builder(model),
+                                                                 optimizer=optimizer_class(model.parameters(), lr=1e-3),
                                                                  criterion=criterion,
                                                                  train_dataloader=train_dataloader)

@@ -84,7 +84,7 @@ def run_engine(rank, world_size, port):

 @pytest.mark.dist
 def test_engine():
-    world_size = 4
+    world_size = 2
     run_func = partial(run_engine, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -25,9 +25,9 @@ def run_trainer_no_pipeline(rank, world_size, port):
     test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
     for name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(name)
-        model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model_builder()
-        optimizer = optimizer_builder(model)
+        optimizer = optimizer_class(model.parameters(), lr=1e-3)
         engine, train_dataloader, *_ = colossalai.initialize(model=model,
                                                              optimizer=optimizer,
                                                              criterion=criterion,
@@ -56,6 +56,7 @@ def test_activation_checkpointing(cpu_offload):

     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
+    torch.cuda.empty_cache()

     # as seed manager is singleton
     # if we don't reset seeds here,
     # other tests will fail if running together with this test
@@ -4,21 +4,20 @@
 from functools import partial

 import colossalai
-from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs

 from common import CONFIG
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer


-def run_dist(rank, world_size, port, init_device):
+def run_dist(rank, world_size, port, init_device, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     for get_components_func in non_distributed_component_funcs:
@@ -26,7 +25,7 @@ def run_dist(rank, world_size, port, init_device):
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
-                             shard_strategy=TensorShardStrategy(),
+                             shard_strategy=shard_strategy(),
                              shard_param=True,
                              model_numel_tensor=model_numel_tensor):
             model = model_builder(checkpoint=True)
@@ -38,23 +37,25 @@ def run_dist(rank, world_size, port, init_device):
             assert param.col_attr.data.payload.device.type == init_device.type, \
                 f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'

-        print(f'cpu usgae {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
-        print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
+        print(f'cuda usgae {ModelDataTracer().cuda_usage}')
         print(f'numel {model_numel_tensor}')
         if init_device.type == 'cuda':
-            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
-        elif init_device.type == 'cpu':
-            assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
+            assert (ModelDataTracer().cuda_usage > 0)


 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
 @pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
-def test_zero_init_context(world_size, init_device):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_init_context(world_size, init_device, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       init_device=init_device,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_zero_init_context(2, torch.device('cpu'))
-    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))
+    # test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
+    test_zero_init_context(4, torch.device('cpu'), BucketTensorShardStrategy)
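The same parametrization pattern recurs in the remaining test files of this commit: the shard strategy class is passed down to the spawned worker and instantiated there, so every test runs once with TensorShardStrategy and once with BucketTensorShardStrategy. Schematically (a generic sketch of the pattern, not any one specific test in this diff):

    from functools import partial

    import pytest
    import torch.multiprocessing as mp
    from colossalai.utils import free_port
    from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy


    def _worker(rank, world_size, port, shard_strategy):
        # The class is instantiated inside the worker process, after launch.
        strategy = shard_strategy()
        ...  # test body using `strategy`


    @pytest.mark.dist
    @pytest.mark.parametrize("world_size", [1, 2])
    @pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
    def test_with_both_strategies(world_size, shard_strategy):
        run_func = partial(_worker, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
        mp.spawn(run_func, nprocs=world_size)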
@@ -3,30 +3,29 @@

 import copy
 from functools import partial
-import pytest

-import torch
-import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
-
 import colossalai
-from colossalai.zero.init_ctx import ZeroInitContext
+import pytest
+import torch
+import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP

 from common import CONFIG, check_grads_padding, run_fwd_bwd
-from colossalai.zero.sharded_model.utils import col_model_deepcopy


-def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
+def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    shard_strategy = TensorShardStrategy()
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
@@ -35,12 +34,12 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):

         if use_zero_init_ctx:
             with ZeroInitContext(convert_fp16=True,
-                                 target_device=torch.device('cpu'),
+                                 target_device=torch.device(f'cpu:0'),
                                  shard_strategy=shard_strategy,
                                  shard_param=True,
                                  rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
                 zero_model = model_builder(checkpoint=True)
-            zero_model = ShardedModelV2(zero_model, shard_strategy)
+            zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)

             model = model_builder(checkpoint=True).half()
             col_model_deepcopy(zero_model, model)
@@ -61,19 +60,24 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):

         check_grads_padding(model, zero_model, loose=True)

+        print('overall cuda ', zero_model._memstats_collector._overall_cuda)
+        print('model cuda ', zero_model._memstats_collector._model_data_cuda)


 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("enable_autocast", [True])
 @pytest.mark.parametrize("use_zero_init_ctx", [True])
-def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast, shard_strategy):
     run_func = partial(run_dist,
                        world_size=world_size,
                        port=free_port(),
                        use_zero_init_ctx=use_zero_init_ctx,
-                       enable_autocast=enable_autocast)
+                       enable_autocast=enable_autocast,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
+    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True, shard_strategy=TensorShardStrategy)
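The two print lines added above read the collector that `use_memory_tracer=True` attaches to the wrapped model. In the same spirit, after the first forward/backward iteration the sampled statistics can be inspected directly on the wrapper; the attribute names below are the ones used in this commit, but they are private and may change:

    # zero_model is a ShardedModelV2 created with use_memory_tracer=True;
    # the collector is None when the tracer is disabled.
    collector = zero_model._memstats_collector
    if collector is not None:
        print('overall cuda ', collector._overall_cuda)
        print('model cuda   ', collector._model_data_cuda)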
@@ -10,20 +10,20 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.test_zero_data_parallel.common import CONFIG, allclose
 from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_zero_data_parallel.common import CONFIG, allclose


-def _run_shard_tensor(rank, world_size, port):
+def _run_shard_tensor(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.origin_shape) == [world_size * 2, 3]
     assert list(t.shape) == [world_size * 2, 3]

-    shard_strategy = TensorShardStrategy(process_group=None)
+    shard_strategy = shard_strategy(process_group=None)

     # test shard strategy
     shard_strategy.shard([t])
@@ -34,8 +34,9 @@ def _run_shard_tensor(rank, world_size, port):

 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_shard_tensor(world_size):
-    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_tensor(world_size, shard_strategy):
+    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


@@ -121,7 +122,7 @@ def test_init_shard_param(world_size):


 if __name__ == '__main__':
-    test_shard_tensor(2)
+    test_shard_tensor(2, TensorShardStrategy)
     test_shard_param(2)
     test_shard_param_v2(2)
     test_init_shard_param(4)
|
@ -1,6 +1,3 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
@@ -10,7 +7,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -38,25 +35,27 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     optimizer.step()


-def run_dist(rank, world_size, port, cpu_offload):
+def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model),
                                     shard_strategy,
                                     offload_config=dict(device='cpu') if cpu_offload else None)
         if dist.get_world_size() > 1:
             model = DDP(model)
-        optim = Adam(model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
-                                           zero_model,
+        lr = 1e-3
+        optim = optimizer_class(model.parameters(), lr=lr)
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           optimizer_class,
                                            shard_strategy,
                                            cpu_offload=cpu_offload,
-                                           initial_scale=2**5)
+                                           initial_scale=2**5,
+                                           lr=lr)
         for i, (data, label) in enumerate(train_dataloader):
             if i > 2:
                 break
@@ -69,10 +68,15 @@ def run_dist(rank, world_size, port, cpu_offload):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("cpu_offload", [True, False])
-def test_sharded_optim_v2(world_size, cpu_offload):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), cpu_offload=cpu_offload)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       cpu_offload=cpu_offload,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2, cpu_offload=True)
+    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)
@@ -11,7 +11,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -47,23 +47,24 @@ def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
     optimizer.step()


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
         model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy, offload_config={'device': 'cpu'})
         if dist.get_world_size() > 1:
             model = DDP(model)
         optim = Adam(model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(CPUAdam(zero_model.parameters(), lr=1e-3),
-                                           zero_model,
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           CPUAdam,
                                            shard_strategy,
                                            initial_scale=2**5,
-                                           cpu_offload=True)
+                                           cpu_offload=True,
+                                           lr=1e-3)
         for i, (data, label) in enumerate(train_dataloader):
             if i > 2:
                 break
@@ -79,10 +80,11 @@ def run_dist(rank, world_size, port):

 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_sharded_optim_v2(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2)
+    test_sharded_optim_v2(world_size=2, shard_strategy=TensorShardStrategy)
@@ -9,22 +9,21 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from tests.components_to_test.registry import non_distributed_component_funcs

 from common import CONFIG


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model_builder()
-        shard_strategy = TensorShardStrategy()
         model = model.half().cuda()
         zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
         zero_state_dict = zero_model.state_dict()
@@ -33,11 +32,12 @@ def run_dist(rank, world_size, port):


 @pytest.mark.dist
-def test_zero_state_dict():
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_state_dict(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_zero_state_dict()
+    test_zero_state_dict(2, TensorShardStrategy)