Merge pull request #412 from hpcaitech/develop

merge develop to main
This commit is contained in:
Frank Lee 2022-03-14 22:48:56 +08:00 committed by GitHub
commit f8a0e7fb01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 826 additions and 226 deletions

View File

@ -6,7 +6,7 @@ All notable changes to this project will be documented in this file.
### Added
-- Unifed distributed layers
+- Unified distributed layers
- MoE support
- DevOps tools such as github action, code review automation, etc.
- New project official website
@ -33,4 +33,4 @@ The first beta version of Colossal-AI. Thanks to all contributors for the effort
### Added
- Initial architecture of the system
- Features such as tensor parallelism, gradient clipping, gradient accumulation

View File

@ -14,6 +14,7 @@
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
[![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest)
[![codebeat badge](https://codebeat.co/badges/bfe8f98b-5d61-4256-8ad2-ccd34d9cc156)](https://codebeat.co/projects/github-com-hpcaitech-colossalai-main)
+[![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech)
[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&amp)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
[![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&amp)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png)

View File

@ -14,8 +14,10 @@
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
[![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest)
[![codebeat badge](https://codebeat.co/badges/bfe8f98b-5d61-4256-8ad2-ccd34d9cc156)](https://codebeat.co/projects/github-com-hpcaitech-colossalai-main)
+[![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech)
[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&amp)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
[![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&amp)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png)
| [English](README.md) | [中文](README-zh-Hans.md) |

@ -1 +1 @@
-Subproject commit 62904e4ff2f3261c5469c773faa3d9307b6f16f4
+Subproject commit 9757a137495a8fc8b12133087cffe3e4a97ed2cb

View File

@ -1,8 +1,12 @@
import torch
from colossalai.registry import OPHOOKS
-from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.utils import get_current_device
+from colossalai.zero.shard_utils import BaseShardStrategy
from ._base_ophook import BaseOpHook
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from typing import Optional
@OPHOOKS.register_module
@ -11,32 +15,50 @@ class ZeroHook(BaseOpHook):
A hook to process sharded param for ZeRO method.
"""
-def __init__(self, shard_strategy: BaseShardStrategy):
+def __init__(self, shard_strategy: BaseShardStrategy, memstarts_collector: Optional[MemStatsCollector]):
super().__init__()
self.shard_strategy = shard_strategy
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
self.computing_device = torch.device(f'cuda:{get_current_device()}')
+self._memstarts_collector = memstarts_collector
def pre_fwd_exec(self, module: torch.nn.Module, *args):
+tensor_list = []
+global_model_data_tracer = ModelDataTracer()
for param in module.parameters():
assert hasattr(param, 'col_attr')
-self.shard_strategy.gather([param.col_attr.data])
+tensor_list.append(param.col_attr.data)
+self.shard_strategy.gather(tensor_list)
+for param in module.parameters():
if param.col_attr.data.device != self.computing_device:
param.col_attr.data.to(self.computing_device)
+global_model_data_tracer.add_tensor(param.col_attr.data.payload)
param.data = param.col_attr.data.payload
+if self._memstarts_collector:
+self._memstarts_collector.sample_memstats()
def post_fwd_exec(self, module: torch.nn.Module, *args):
+tensor_list = []
for param in module.parameters():
assert hasattr(param, 'col_attr')
-self.shard_strategy.shard([param.col_attr.data])
-param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+tensor_list.append(param.col_attr.data)
+self.shard_strategy.shard(tensor_list)
+for param in module.parameters():
+param.col_attr.remove_torch_payload()
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+tensor_list = []
+global_model_data_tracer = ModelDataTracer()
for param in module.parameters():
assert hasattr(param, 'col_attr')
-self.shard_strategy.gather([param.col_attr.data])
+tensor_list.append(param.col_attr.data)
+self.shard_strategy.gather(tensor_list)
+for param in module.parameters():
if param.col_attr.data.device != self.computing_device:
param.col_attr.data.to(self.computing_device)
+global_model_data_tracer.add_tensor(param.col_attr.data.payload)
param.data = param.col_attr.data.payload
# Store local accumulated grad shard
if param.grad is not None:
@ -50,12 +72,17 @@ class ZeroHook(BaseOpHook):
# The grad here must be locally computed full grad in this backward pass
assert param.grad.shape == param.col_attr.data.origin_shape
param.col_attr.bwd_count += 1
+if self._memstarts_collector:
+self._memstarts_collector.sample_memstats()
def post_bwd_exec(self, module: torch.nn.Module, input):
+tensor_list = []
for param in module.parameters():
assert hasattr(param, 'col_attr')
-self.shard_strategy.shard([param.col_attr.data])
-param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+tensor_list.append(param.col_attr.data)
+self.shard_strategy.shard(tensor_list)
+for param in module.parameters():
+param.col_attr.remove_torch_payload()
def pre_iter(self):
pass
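Taken together, the hook now gathers all of a module's sharded parameters in a single batched call before forward/backward, re-shards them afterwards, and optionally samples memory statistics. Below is a minimal sketch of wiring the updated hook up by hand; the import paths and the bare `nn.Linear` are illustrative assumptions (in real use the model is built inside `ZeroInitContext` so its parameters carry `col_attr`), while the hook constructor and registration call are taken from this diff.

```python
# A minimal sketch, not the library's own test code. Import paths are assumed;
# only ZeroHook(shard_strategy, memstarts_collector) and
# register_ophooks_recursively(module, [hook]) come from this diff.
import torch.nn as nn

from colossalai.engine.ophooks import register_ophooks_recursively  # path assumed
from colossalai.engine.ophooks.zero_hook import ZeroHook             # path assumed
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import TensorShardStrategy

model = nn.Linear(8, 8)                 # in practice: built inside ZeroInitContext
collector = MemStatsCollector()         # or None to disable memory sampling
hook = ZeroHook(TensorShardStrategy(), collector)

# Each wrapped module now gathers its sharded parameters before forward/backward
# and re-shards them afterwards, sampling memory stats when a collector is given.
register_ophooks_recursively(model, [hook])
```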

View File

@ -10,7 +10,7 @@ from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient
from .memory import report_memory_usage
from .timer import MultiTimer, Timer
-#from .tensor_detector import TensorDetector
+from .tensor_detector import TensorDetector
__all__ = [
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',

View File

@ -1,60 +1,19 @@
import torch import torch
from colossalai.utils.commons.singleton_meta import SingletonMeta from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.zero.sharded_param import ShardedTensor
from typing import Union
def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int: def col_move_to_cpu(t: torch.Tensor):
if isinstance(t, ShardedTensor): assert isinstance(t, torch.Tensor)
target = t.payload if t.device.type == 'cpu':
else: return
target = t
return target.numel() * target.element_size() ModelDataTracer().delete_tensor(t)
t.data = t.data.cpu()
class ModelDataTracer(metaclass=SingletonMeta): def col_modeldata_allocate(device: torch.device) -> torch.Tensor:
"""
A singleton to trace model data usage during runtime.
"""
def __init__(self) -> None:
self._cpu_usage = 0
self._cuda_usage = 0
def trace_tensor(self, t: torch.Tensor):
mem_use = col_tensor_mem_usage(t)
if t.device.type == 'cpu':
self._cpu_usage += mem_use
elif t.device.type == 'cuda':
self._cuda_usage += mem_use
else:
raise RuntimeError
def detach_tensor(self, t: torch.Tensor):
mem_use = col_tensor_mem_usage(t)
if t.device.type == 'cpu':
self._cpu_usage -= mem_use
elif t.device.type == 'cuda':
self._cuda_usage -= mem_use
else:
raise RuntimeError
@property
def cpu_usage(self):
return self._cpu_usage
@property
def cuda_usage(self):
return self._cuda_usage
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
def col_allocate_payload(device: torch.device) -> torch.Tensor:
pass pass
def col_release_payload(t: torch.Tensor): def col_modeldata_release(t: torch.Tensor):
pass pass

View File

@ -6,7 +6,7 @@ from colossalai.utils import get_current_device
import torch
-def _get_cuda_memory_used(device: torch.device) -> int:
+def get_cuda_memory_used(device: torch.device) -> int:
"""
Get the free memory info of device.
:param device: device id
@ -87,7 +87,7 @@ class AsyncMemoryMonitor:
while self.keep_measuring:
max_usage = max(
max_usage,
-_get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
+get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
)
sleep(self.interval)
return max_usage

View File

@ -0,0 +1,11 @@
from colossalai.zero.sharded_param import ShardedTensor
from typing import Union
import torch
def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
if isinstance(t, ShardedTensor):
target = t.payload
else:
target = t
return target.numel() * target.element_size()
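As a quick sanity check of the helper above: the reported size is simply `numel * element_size` of the tensor (or of a `ShardedTensor`'s payload), so a 1024-element fp32 tensor reports 4096 bytes.

```python
# Illustration only; the import path is taken from this diff.
import torch
from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage

t = torch.zeros(1024, dtype=torch.float32)
assert col_tensor_mem_usage(t) == 1024 * 4   # 4 bytes per fp32 element
```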

View File

@ -0,0 +1,81 @@
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from .async_memtracer import get_cuda_memory_used
from colossalai.utils import get_current_device
import torch
class SamplingCounter:
def __init__(self) -> None:
self._samplint_cnt = 0
def advance(self):
self._samplint_cnt += 1
@property
def sampling_cnt(self):
return self._samplint_cnt
def reset(self):
self._samplint_cnt = 0
class MemStatsCollector:
def __init__(self) -> None:
"""
Collecting Memory Statistics.
It has two phases.
1. Collection Phase: collect memory usage statistics
2. Runtime Phase: do not collect statistics.
"""
self._sampling_cnter = SamplingCounter()
self._model_data_cuda = []
self._overall_cuda = []
# TODO(jiaruifang) Now no cpu mem stats collecting
self._model_data_cpu = []
self._overall_cpu = []
self._start_flag = False
def start_collection(self):
self._start_flag = True
def finish_collection(self):
self._start_flag = False
def sample_memstats(self) -> None:
"""
Sampling memory statistics.
Record the current model data CUDA memory usage as well as system CUDA memory usage.
"""
if self._start_flag:
sampling_cnt = self._sampling_cnter.sampling_cnt
assert sampling_cnt == len(self._overall_cuda)
self._model_data_cuda.append(ModelDataTracer().cuda_usage)
self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
self._sampling_cnter.advance()
def fetch_memstats(self) -> (int, int):
"""
returns cuda usage of model data and overall cuda usage.
"""
sampling_cnt = self._sampling_cnter.sampling_cnt
if len(self._model_data_cuda) < sampling_cnt:
raise RuntimeError
return (self._model_data_cuda[sampling_cnt], self._overall_cuda[sampling_cnt])
def reset_sampling_cnter(self) -> None:
self._sampling_cnter.reset()
def clear(self) -> None:
self._model_data_cuda = []
self._overall_cuda = []
self._model_data_cpu = []
self._overall_cpu = []
self._start_flag = False
self._sampling_cnter.reset()
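A short sketch of the two-phase workflow described in the docstring; in the ZeRO code path `sample_memstats()` is called from `ZeroHook`, but it can also be driven by hand. Sampling reads CUDA memory, so this assumes a CUDA device is available.

```python
# Hand-driven sketch of the collector's two phases (requires a CUDA device).
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector

collector = MemStatsCollector()

# Phase 1: collection. Each sample records model-data CUDA usage and overall CUDA usage.
collector.start_collection()
collector.sample_memstats()
collector.sample_memstats()
collector.finish_collection()

# Phase 2: runtime. Samples are no longer recorded; reset the counter and replay the stats.
collector.reset_sampling_cnter()
model_data_cuda, overall_cuda = collector.fetch_memstats()
print(model_data_cuda, overall_cuda)
```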

View File

@ -0,0 +1,34 @@
from colossalai.utils.commons.singleton_meta import SingletonMeta
from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage
import torch
class ModelDataTracer(metaclass=SingletonMeta):
"""
A singleton to trace model data usage during runtime.
We have to trigger our API (trace_tensor, detach_tensor) when do model-data memory operation,
including allocation, releasing and moving.
NOTE() now the class only trace cuda memory usage
"""
def __init__(self) -> None:
self._cuda_usage = 0
def add_tensor(self, t: torch.Tensor):
assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
mem_use = col_tensor_mem_usage(t)
self._cuda_usage += mem_use
def delete_tensor(self, t: torch.Tensor):
assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
mem_use = col_tensor_mem_usage(t)
self._cuda_usage -= mem_use
@property
def cpu_usage(self):
return self._cpu_usage
@property
def cuda_usage(self):
return self._cuda_usage
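A small sketch of how the singleton is meant to be used: every call to `ModelDataTracer()` returns the same instance, and `add_tensor` / `delete_tensor` only adjust the bookkeeping counter by `numel * element_size`; they do not move or free the tensor.

```python
# Illustration of the singleton bookkeeping; the values in comments are what
# the accounting implies, not output captured from the library.
import torch
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer

tracer = ModelDataTracer()
assert tracer is ModelDataTracer()            # same instance via SingletonMeta

t = torch.zeros(256, dtype=torch.float16)     # 256 * 2 bytes = 512 bytes
tracer.add_tensor(t)                          # cuda_usage grows by 512
tracer.delete_tensor(t)                       # cuda_usage shrinks by 512
print(tracer.cuda_usage)
```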

View File

@ -0,0 +1,43 @@
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
import torch
def test_mem_collector():
collector = MemStatsCollector()
collector.start_collection()
a = torch.randn(10).cuda()
# sampling at time 0
collector.sample_memstats()
m_a = torch.randn(10).cuda()
ModelDataTracer().add_tensor(m_a)
b = torch.randn(10).cuda()
# sampling at time 1
collector.sample_memstats()
a = b
# sampling at time 2
collector.sample_memstats()
collector.finish_collection()
collector.reset()
# do nothing after collection, just advance sampling cnter
collector.sample_memstats()
collector.sample_memstats()
cuda_use, overall_use = collector.fetch_memstats()
print(cuda_use, overall_use)
print(collector._model_data_cuda)
print(collector._overall_cuda)
if __name__ == '__main__':
test_mem_collector()

View File

@ -6,20 +6,25 @@ from torch.autograd.profiler import profile
import torch.distributed as dist
from torch.distributed import ReduceOp
from colossalai.utils import get_current_device
-from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwith
+from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth
from typing import List, Optional
def _get_code_location(depth: int):
-ret = ""
-length = len(inspect.stack())
-for i in range(3, min(length, depth + 1)):
+ret = []
+length = min(len(inspect.stack()), depth + 1)
+for i in range(3, length):
upper_frame = inspect.stack()[i]
function_name = inspect.stack()[i - 1].function
-info = upper_frame.filename + "(" + str(upper_frame.lineno) + "): " + function_name + "\n"
-ret += info
-return ret
+ret.append(upper_frame.filename)
+ret.append('(')
+ret.append(str(upper_frame.lineno))
+ret.append('): ')
+ret.append(function_name)
+if i != length - 1:
+ret.append('\n')
+return ''.join(ret)
torch_all_reduce = dist.all_reduce torch_all_reduce = dist.all_reduce
@ -88,41 +93,52 @@ class CommProfiler(BaseProfiler):
dist.reduce = torch_reduce dist.reduce = torch_reduce
def to_tensorboard(self, writer): def to_tensorboard(self, writer):
writer.add_text(tag="Collective Communication", text_string=self.result_list("\n\n")) writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n"))
def to_file(self, filename: Path): def to_file(self, filename: Path):
with open(filename, "w") as f: with open(filename, "w") as f:
f.write(self.result_list()) f.write(self.result_str())
def show(self): def show(self):
print(self.result_list()) print(self.result_str())
def result_list(self, sep: str = "\n"): def result_str(self, sep: str = "\n"):
res = [] res = []
def append(s: str): def append(s: str = None):
res.append(s) if s is not None:
res.append(s)
res.append(sep) res.append(sep)
if self.warn_flag: if self.warn_flag:
append("Warnning: there exists multiple communication operations in the same time. As a result, " append("Warnning: there exists multiple communication operations in the same time. As a result, "
"the profiling result is not accurate.") "the profiling result is not accurate.")
if self.total_cuda_time == 0:
return "No collective communication has been called yet!"
append("Collective communication profiling result:") append("Collective communication profiling result:")
append("total cuda time: {}".format(_format_time(self.total_cuda_time))) append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
append("average bandwith: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time))) append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time)))
append("total number of calls: {}".format(self.total_count)) append("total number of calls: {}".format(self.total_count))
append("All events:\n----------------------------------------") append("All events:")
seperation = '-' * 74
row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
append(seperation)
append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
append(seperation)
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
for location, event in show_list: for location, event in show_list:
append(location) append(location)
append("self cuda time: {}".format(_format_time(event.self_cuda_time))) append(
append("{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0)) row_format.format('', _format_time(event.self_cuda_time),
append("self communication volme: {}".format(_format_memory(event.self_comm_vol))) '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0),
append("average bandwith: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time))) _format_memory(event.self_comm_vol),
append("number of calls: {}".format(event.self_count)) _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count))
append("----------------------------------------") append()
return ''.join(res) return ''.join(res)

View File

@ -1,6 +1,6 @@
from pathlib import Path
from torch.autograd.profiler import profile
-from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwith
+from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth
from typing import List
@ -24,6 +24,7 @@ def _reduce_location(locations: List[str]) -> str:
for lo in locations:
ret.append(lo)
ret.append("\n")
+ret = ret[:-1]
return ''.join(ret)
@ -48,18 +49,23 @@ class PcieProfiler(BaseProfiler):
TODO: Merge pcie profiler into communication profiler TODO: Merge pcie profiler into communication profiler
""" """
def __init__(self, def __init__(self, dtype: str = "fp32", depth: int = 1):
dtype: str = "fp32",
depth: int = 1,
total_count: int = 0,
total_pcie_vol: int = 0,
total_cuda_time: int = 0):
super().__init__(profiler_name="Pcie", priority=10) super().__init__(profiler_name="Pcie", priority=10)
self.depth = depth self.depth = depth
self.data_size = _get_size(dtype) self.data_size = _get_size(dtype)
self.total_count = total_count self.h2d_count = 0
self.total_pcie_vol = total_pcie_vol self.h2d_time = 0
self.total_cuda_time = total_cuda_time self.d2h_count = 0
self.d2h_time = 0
self.ops_record = dict()
self.profiler = None
def reset(self):
self.h2d_count = 0
self.h2d_time = 0
self.d2h_count = 0
self.d2h_time = 0
self.ops_record = dict() self.ops_record = dict()
self.profiler = None self.profiler = None
@ -81,51 +87,62 @@ class PcieProfiler(BaseProfiler):
for event in events: for event in events:
if event.name == "aten::copy_": if event.name == "aten::copy_":
t_shape = event.input_shapes[0] t_shape = event.input_shapes[0]
if len(t_shape) == 0 or event.cuda_time_total == 0: if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0:
continue continue
current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total) current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total)
self.total_count += current_comm_event.count
self.total_pcie_vol += current_comm_event.pcie_vol
self.total_cuda_time += current_comm_event.cuda_time
code_location = _reduce_location(event.stack[:self.depth]) code_location = _reduce_location(event.stack[:self.depth])
if code_location in self.ops_record: if code_location in self.ops_record:
self.ops_record[code_location].add(current_comm_event) self.ops_record[code_location].add(current_comm_event)
else: else:
self.ops_record[code_location] = current_comm_event self.ops_record[code_location] = current_comm_event
elif 'Memcpy HtoD' in event.name:
self.h2d_count += 1
self.h2d_time += event.cuda_time_total
elif 'Memcpy DtoH' in event.name:
self.d2h_count += 1
self.d2h_time += event.cuda_time_total
self.profiler = None self.profiler = None
def to_tensorboard(self, writer): def to_tensorboard(self, writer):
writer.add_text(tag="Data Transmission", text_string=self.result_list("\n\n")) writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n"))
def to_file(self, filename: Path): def to_file(self, filename: Path):
with open(filename, "w") as f: with open(filename, "w") as f:
f.write(self.result_list()) f.write(self.result_str())
def show(self): def show(self):
print(self.result_list()) print(self.result_str())
def result_list(self, sep: str = "\n"): def result_str(self, sep: str = "\n"):
res = [] res = []
def append(s: str): def append(s: str = None):
res.append(s) if s is not None:
res.append(s)
res.append(sep) res.append(sep)
append("Pcie profiling result:") append("Pcie profiling result:")
append("total cuda time: {}".format(_format_time(self.total_cuda_time))) append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time)))
append("average bandwith: {}".format(_format_bandwith(self.total_pcie_vol, self.total_cuda_time))) append("number of transmission (CPU -> GPU): {}".format(self.h2d_count))
append("total number of calls: {}".format(self.total_count)) append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time)))
append("All events:\n----------------------------------------") append("number of transmission (GPU -> CPU): {}".format(self.d2h_count))
append("Possible data transmission events in PCIE:")
seperation = '-' * 62
row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2
append(seperation)
append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
append(seperation)
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time) show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
for location, event in show_list: for location, event in show_list:
append(location) append(location)
append("cuda time: {}".format(_format_time(event.cuda_time))) append(
append("{:.1f}% of total pcie time".format(event.cuda_time / self.total_cuda_time * 100.0)) row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol),
append("pcie volme: {}".format(_format_memory(event.pcie_vol))) _format_bandwidth(event.pcie_vol, event.cuda_time), event.count))
append("average bandwith: {}".format(_format_bandwith(event.pcie_vol, event.cuda_time))) append()
append("number of calls: {}".format(event.count))
append("----------------------------------------")
return ''.join(res) return ''.join(res)

View File

@ -32,7 +32,7 @@ def _format_memory(nbytes):
return str(nbytes) + ' B'
-def _format_bandwith(volme: float or int, time_us: int):
+def _format_bandwidth(volme: float or int, time_us: int):
sec_div_mb = (1000.0 / 1024.0)**2
mb_per_sec = volme / time_us * sec_div_mb
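For reference, the renamed `_format_bandwidth` takes a volume in bytes and a time in microseconds and converts to MiB/s through the `(1000/1024)**2` factor. A worked example of the arithmetic:

```python
# Worked example of the conversion used above (plain arithmetic, no library calls).
volume_bytes = 2 ** 30                      # 1 GiB transferred
time_us = 1_000_000                         # over one second
sec_div_mb = (1000.0 / 1024.0) ** 2
mb_per_sec = volume_bytes / time_us * sec_div_mb
print(mb_per_sec)                           # ~1024.0 MB/s
```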

View File

@ -0,0 +1 @@
from .tensor_detector import TensorDetector

View File

@ -0,0 +1,128 @@
# Tensor Detector
This tool helps you detect tensors on both CPU and GPU. Note that there will always be a few unexpected tensors on the CPU, such as PyTorch's RNG state.
## Example
An example is worth a thousand words.
The code below defines a simple MLP module, with which we will show you how to use the tool.
```python
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.mlp = nn.Sequential(nn.Linear(64, 8),
nn.ReLU(),
nn.Linear(8, 32))
def forward(self, x):
return self.mlp(x)
```
And here is how to use the tool.
```python
from colossalai.utils import TensorDetector
# create random data
data = torch.rand(64, requires_grad=True).cuda()
data.retain_grad()
# create the module
model = MLP().cuda()
# create the detector
# by passing the model to the detector, it can distinguish module parameters from common tensors
detector = TensorDetector(include_cpu=False, module=model)
detector.detect()
out = model(data)
detector.detect()
loss = out.sum()
loss.backward()
detector.detect()
```
I have added some comments on the right of the output to aid understanding.
Note that the total `Mem` of all tensors and parameters does not equal `Total GPU Memory Allocated`. PyTorch's memory management is complicated, and for large-scale models it is impossible to account for every byte precisely.
**The order in which tensors are printed is not exactly the order in which they were created, but it is very close.**
```bash
------------------------------------------------------------------------------------------------------------
Tensor device shape grad dtype Mem
------------------------------------------------------------------------------------------------------------
+ Tensor cuda:0 (64,) True torch.float32 256 B # data
+ mlp.0.weight cuda:0 (8, 64) True torch.float32 2.0 KB
+ mlp.0.bias cuda:0 (8,) True torch.float32 32 B
+ mlp.2.weight cuda:0 (32, 8) True torch.float32 1.0 KB
+ mlp.2.bias cuda:0 (32,) True torch.float32 128 B
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 27
Totle GPU Memery Allocated on cuda:0 is 4.5 KB
------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------
Tensor device shape grad dtype Mem
------------------------------------------------------------------------------------------------------------
+ Tensor cuda:0 (8,) True torch.float32 32 B # activation
+ Tensor cuda:0 (32,) True torch.float32 128 B # output
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 30
Totle GPU Memery Allocated on cuda:0 is 5.5 KB
------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------
Tensor device shape grad dtype Mem
------------------------------------------------------------------------------------------------------------
+ Tensor cuda:0 () True torch.float32 4 B # loss
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 32
Totle GPU Memery Allocated on cuda:0 is 6.0 KB
------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------
Tensor device shape grad dtype Mem
------------------------------------------------------------------------------------------------------------
+ Tensor (with grad) cuda:0 (64,) True torch.float32 512 B # data with grad
+ mlp.0.weight (with grad) cuda:0 (8, 64) True torch.float32 4.0 KB # for use data.retain_grad()
+ mlp.0.bias (with grad) cuda:0 (8,) True torch.float32 64 B
+ mlp.2.weight (with grad) cuda:0 (32, 8) True torch.float32 2.0 KB
+ mlp.2.bias (with grad) cuda:0 (32,) True torch.float32 256 B
- mlp.0.weight cuda:0 (8, 64) True torch.float32 2.0 KB
- mlp.0.bias cuda:0 (8,) True torch.float32 32 B
- mlp.2.weight cuda:0 (32, 8) True torch.float32 1.0 KB
- mlp.2.bias cuda:0 (32,) True torch.float32 128 B
- Tensor cuda:0 (64,) True torch.float32 256 B
- Tensor cuda:0 (8,) True torch.float32 32 B # deleted activation
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 34
Totle GPU Memery Allocated on cuda:0 is 10.0 KB
------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------
Tensor device shape grad dtype Mem
------------------------------------------------------------------------------------------------------------
+ Tensor cuda:0 (64,) False torch.float32 256 B
+ Tensor cuda:0 (8, 64) False torch.float32 2.0 KB
+ Tensor cuda:0 (8,) False torch.float32 32 B
+ Tensor cuda:0 (32, 8) False torch.float32 1.0 KB
+ Tensor cuda:0 (32,) False torch.float32 128 B
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 36
Totle GPU Memery Allocated on cuda:0 is 14.0 KB
------------------------------------------------------------------------------------------------------------
```
## Reference
This tool was inspired by https://github.com/Stonesjtu/pytorch_memlab/blob/master/pytorch_memlab/mem_reporter.py
and https://github.com/Oldpan/Pytorch-Memory-Utils

View File

@ -0,0 +1,190 @@
import gc
import inspect
import torch
import torch.nn as nn
from typing import Optional
from collections import defaultdict
LINE_WIDTH = 108
LINE = '-' * LINE_WIDTH + '\n'
class TensorDetector():
def __init__(self,
show_info: bool = True,
log: str = None,
include_cpu: bool = False,
module: Optional[nn.Module] = None
):
"""This class is an detector to detect tensor on different devices.
:param show_info: whether to print the info on screen, default True
:type show_info: bool
:param log: the file name to save the log
:type log: str
:param include_cpu: whether to detect tensor on cpu, default False
:type include_cpu: bool
:param module: when sending an `nn.Module` it, the detector can name the tensors detected better
:type module: Optional[nn.Module]
"""
self.show_info = show_info
self.log = log
self.include_cpu = include_cpu
self.tensor_info = defaultdict(list)
self.saved_tensor_info = defaultdict(list)
self.order = []
self.detected = []
self.devices = []
self.info = ""
self.module = module
if isinstance(module, nn.Module):
# if module is an instance of nn.Module, we can name the parameter with its real name
for name, param in module.named_parameters():
self.tensor_info[id(param)].append(name)
self.tensor_info[id(param)].append(param.device)
self.tensor_info[id(param)].append(param.shape)
self.tensor_info[id(param)].append(param.requires_grad)
self.tensor_info[id(param)].append(param.dtype)
self.tensor_info[id(param)].append(self.get_tensor_mem(param))
def get_tensor_mem(self, tensor):
# calculate the memory occupied by a tensor
memory_size = tensor.element_size() * tensor.storage().size()
if (tensor.is_leaf or tensor.retains_grad) and tensor.grad is not None:
grad_memory_size = tensor.grad.element_size() * tensor.grad.storage().size()
memory_size += grad_memory_size
return self.mem_format(memory_size)
def mem_format(self, real_memory_size):
# format the tensor memory into a reasonal magnitude
if real_memory_size >= 2 ** 30:
return str(real_memory_size / (2 ** 30)) + ' GB'
if real_memory_size >= 2 ** 20:
return str(real_memory_size / (2 ** 20)) + ' MB'
if real_memory_size >= 2 ** 10:
return str(real_memory_size / (2 ** 10)) + ' KB'
return str(real_memory_size) + ' B'
def collect_tensors_state(self):
for obj in gc.get_objects():
if torch.is_tensor(obj):
# skip cpu tensor when include_cpu is false and the tensor we have collected before
if (not self.include_cpu) and obj.device == torch.device('cpu'):
continue
self.detected.append(id(obj))
# skip paramters we had added in __init__ when module is an instance of nn.Module for the first epoch
if id(obj) not in self.tensor_info:
name = type(obj).__name__
# after backward, we want to update the records, to show you the change
if isinstance(self.module, nn.Module) and name == 'Parameter':
if obj.grad is not None:
# with grad attached
for par_name, param in self.module.named_parameters():
if param.requires_grad and param.grad.equal(obj.grad):
name = par_name + ' (with grad)'
else:
# with no grad attached
# there will be no new paramters created during running
# so it must be in saved_tensor_info
continue
# we can also marked common tensors as tensor(with grad)
if name == 'Tensor' and (obj.is_leaf or obj.retains_grad):
if obj.grad is not None:
name = name + ' (with grad)'
# in fact, common tensor have no grad
# unless you set retain_grad()
if id(obj) in self.saved_tensor_info.keys() and name == self.saved_tensor_info[id(obj)][0]:
continue
self.tensor_info[id(obj)].append(name)
self.tensor_info[id(obj)].append(obj.device)
self.tensor_info[id(obj)].append(obj.shape)
self.tensor_info[id(obj)].append(obj.requires_grad)
self.tensor_info[id(obj)].append(obj.dtype)
self.tensor_info[id(obj)].append(self.get_tensor_mem(obj))
# recorded the order we got the tensor
# by this we can guess the tensor easily
# it will record every tensor updated this turn
self.order.append(id(obj))
# recorded all different devices
if obj.device not in self.devices:
self.devices.append(obj.device)
def print_tensors_state(self):
template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}'
self.info += LINE
self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem')
self.info += '\n'
self.info += LINE
# if a tensor updates this turn, and was recorded before
# it should be updated in the saved_tensor_info as well
outdated = [x for x in self.saved_tensor_info.keys() if x in self.order]
minus = [x for x in self.saved_tensor_info.keys() if x not in self.detected]
minus = outdated + minus
if len(self.order) > 0:
for tensor_id in self.order:
self.info += template_format.format('+',
str(self.tensor_info[tensor_id][0]),
str(self.tensor_info[tensor_id][1]),
str(tuple(self.tensor_info[tensor_id][2])),
str(self.tensor_info[tensor_id][3]),
str(self.tensor_info[tensor_id][4]),
str(self.tensor_info[tensor_id][5]))
self.info += '\n'
if len(self.order) > 0 and len(minus) > 0:
self.info += '\n'
if len(minus) > 0:
for tensor_id in minus:
self.info += template_format.format('-',
str(self.saved_tensor_info[tensor_id][0]),
str(self.saved_tensor_info[tensor_id][1]),
str(tuple(self.saved_tensor_info[tensor_id][2])),
str(self.saved_tensor_info[tensor_id][3]),
str(self.saved_tensor_info[tensor_id][4]),
str(self.saved_tensor_info[tensor_id][5]))
self.info += '\n'
# deleted the updated tensor
self.saved_tensor_info.pop(tensor_id)
# trace where is the detect()
locate_info = inspect.stack()[2]
locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno)
self.info += LINE
self.info += f"Detect Location: {locate_msg}\n"
for device in self.devices:
if device == torch.device('cpu'):
continue
gpu_mem_alloc = self.mem_format(torch.cuda.memory_allocated(device))
self.info += f"Totle GPU Memery Allocated on {device} is {gpu_mem_alloc}\n"
self.info += LINE
self.info += '\n\n'
if self.show_info:
print(self.info)
if self.log is not None:
with open(self.log + '.log', 'a') as f:
f.write(self.info)
def detect(self, include_cpu = False):
self.include_cpu = include_cpu
self.collect_tensors_state()
self.print_tensors_state()
self.saved_tensor_info.update(self.tensor_info)
self.tensor_info.clear()
self.order = []
self.detected = []
self.info = ""
def close(self):
self.saved_tensor_info.clear()
self.module = None

View File

@ -3,10 +3,11 @@ import functools
import torch
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
# Inserts _post_init_method at the end of init method
# for all sub classes of torch.nn.Module
class InsertPostInitMethodToModuleSubClasses(object):
@ -152,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if self.shard_param:
self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
-GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._data_sharded_tensor.payload)
+ModelDataTracer().add_tensor(param.col_attr._data_sharded_tensor.payload)
if param.col_attr.grad and self.shard_grad:
self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
-GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._grad_sharded_tensor.payload)
+ModelDataTracer().add_tensor(param.col_attr._grad_sharded_tensor.payload)

View File

@ -1,4 +1,5 @@
-from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy
-from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
+from .base_shard_strategy import BaseShardStrategy
+from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
+from .tensor_shard_strategy import TensorShardStrategy
-__all__ = ['BaseShardStrategy', 'TensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']

View File

@ -0,0 +1,41 @@
from typing import List
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten
from .tensor_shard_strategy import TensorShardStrategy
class BucketTensorShardStrategy(TensorShardStrategy):
def gather(self, tensor_list: List[ShardedTensor]):
tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
if len(tensor_list) == 0:
return
target_device = tensor_list[0].device
dtype = tensor_list[0].dtype
buffer_list: List[torch.Tensor] = []
tensor_numels = [t.payload.numel() for t in tensor_list]
buffer_size = sum(tensor_numels)
for i in range(self.world_size):
if i == self.local_rank:
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
# Release payload here, to decrease peak memory usage
for t in tensor_list:
t.reset_payload(None)
else:
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
# Move to target device before splitting buffer
# Ensure we utilize maximum PCIE bandwidth
buffer_list = [buffer.to(target_device) for buffer in buffer_list]
offset = 0
for i, t in enumerate(tensor_list):
gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
t.reset_payload(gathered_payload)
t.is_sharded = False
offset += tensor_numels[i]
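The point of the bucketed strategy above is to flatten all shards into one buffer per rank and issue a single `all_gather` instead of one collective per tensor. The sketch below shows that pattern with plain `torch.distributed` tensors rather than `ShardedTensor`s; it assumes an already-initialized process group and equally sized shards on every rank.

```python
# Standalone sketch of the bucketing idea, using plain tensors. Assumes
# torch.distributed is already initialized (e.g. via colossalai.launch) and
# that every rank holds shards of identical total size.
from typing import List

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten


def bucketed_all_gather(shards: List[torch.Tensor]) -> List[torch.Tensor]:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    numels = [s.numel() for s in shards]

    # One flat send buffer for this rank, one receive buffer per rank.
    send = flatten(shards)
    buffers = [torch.empty_like(send) for _ in range(world_size)]
    buffers[rank] = send
    dist.all_gather(buffers, send)

    # Split each rank's buffer back into per-tensor pieces and concatenate across ranks.
    gathered, offset = [], 0
    for n in numels:
        gathered.append(torch.cat([buf[offset:offset + n] for buf in buffers]))
        offset += n
    return gathered
```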

View File

@ -17,7 +17,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
get_gradient_predivide_factor)
@ -33,7 +34,8 @@ class ShardedModelV2(nn.Module):
fp32_reduce_scatter: bool = False,
offload_config: Optional[dict] = None,
gradient_predivide_factor: Optional[float] = 1.0,
-shard_param: bool = True):
+shard_param: bool = True,
+use_memory_tracer: bool = False):
r"""
A demo to reconfigure zero1 shared_model.
Currently do not consider the Optimizer States.
@ -59,8 +61,16 @@ class ShardedModelV2(nn.Module):
if self.shard_param:
self.shard_strategy.shard([param.col_attr.data])
+# Init Memory Statistics Collector
+self._use_memory_tracer = use_memory_tracer
+if self._use_memory_tracer:
+self._memstats_collector = MemStatsCollector()
+else:
+self._memstats_collector = None
+self._iter_cnter = 0
# Register hooks
-register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy)])
+register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy, self._memstats_collector)])
self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
@ -84,6 +94,9 @@ class ShardedModelV2(nn.Module):
return self._cpu_offload
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+if self._iter_cnter == 0 and self._memstats_collector:
+# the opeartion will affect the flag in ZeroHook
+self._memstats_collector.start_collection()
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
outputs = self.module(*args, **kwargs)
return outputs
@ -98,6 +111,12 @@ class ShardedModelV2(nn.Module):
@torch.no_grad()
def _final_backward_hook(self) -> None:
+if self._iter_cnter == 0 and self._memstats_collector:
+self._memstats_collector.finish_collection()
+if self._memstats_collector:
+self._memstats_collector.reset_sampling_cnter()
+self._iter_cnter += 1
if self._require_backward_grad_sync:
# Flush any unreduced buckets in the post_backward stream.
with torch.cuda.stream(self.comm_stream):
@ -185,8 +204,10 @@ class ShardedModelV2(nn.Module):
reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)
# Maybe offload
+# TODO() optimize GPU->CPU bandwidth utilization
if self._cpu_offload:
-reduced_grad.data = reduced_grad.data.cpu()
+col_move_to_cpu(reduced_grad)
+# reduced_grad.data = reduced_grad.data.cpu()
if param.col_attr.grad is None:
param.col_attr.grad = reduced_grad.data
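Putting the new flag together, here is a hedged sketch of enabling the tracer. The positional signature of `ShardedModelV2` (module followed by shard strategy) and the import paths are assumptions inferred from the attributes used above; only `use_memory_tracer` itself comes from this diff, and the snippet is meant to run inside an initialized Colossal-AI distributed context.

```python
# Hedged sketch; ShardedModelV2's exact signature and import path are assumed.
import torch
import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2   # import path assumed

shard_strategy = BucketTensorShardStrategy()

# ZeroInitContext attaches the col_attr sharded attributes the hooks rely on.
with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cuda'),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = nn.Linear(64, 64)

# With use_memory_tracer=True the model creates a MemStatsCollector and hands it
# to ZeroHook; statistics are sampled during the first forward/backward iteration.
sharded_model = ShardedModelV2(model, shard_strategy, use_memory_tracer=True)
```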

View File

@ -1,5 +1,5 @@
from enum import Enum
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional, Union
import torch
import torch.distributed as dist
@ -15,7 +15,7 @@ from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
+from typing import Type, Any
from ._utils import has_inf_or_nan
@ -27,8 +27,8 @@ class OptimState(Enum):
class ShardedOptimizerV2(ColossalaiOptimizer):
def __init__(self,
-optimizer: Optimizer,
sharded_model: ShardedModelV2,
+optimizer_class: Type[Optimizer],
shard_strategy: BaseShardStrategy,
cpu_offload: bool = False,
initial_scale: float = 2**32,
@ -39,9 +39,34 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
hysteresis: float = 2,
max_scale: int = 2**32,
dp_process_group: Optional[ProcessGroup] = None,
-mp_process_group: Optional[ProcessGroup] = None) -> None:
+mp_process_group: Optional[ProcessGroup] = None,
+**defaults: Any) -> None:
"""
:param sharded_model: A sharded model initialized by class ShardedModelV2
:type sharded_model: sharded_model
:param optimizer_class: A type of Optimizer
:type optimizer_class: Type[Optimizer]
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:param cpu_offload: is offloading the optimizer states to CPU.
:type cpu_offload: bool
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:**defaults: any trailing arguments, which are forwarded to the local optimizer.
:type defaults: dict()
"""
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
-super().__init__(optimizer)
+self._optim_defaults = defaults
+# initialize the M, V as zeros tensors and initialize param fp32 from sharded_model.parameters()
+self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
+super().__init__(self.optimizer)
self.shard_strategy = shard_strategy
self.model: ShardedModelV2 = sharded_model
if cpu_offload and not sharded_model.cpu_offload:
@ -65,7 +90,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Store fp32 param shards
self.master_params: Dict[Parameter, Tensor] = {}
-for group in optimizer.param_groups:
+for group in self.optimizer.param_groups:
for p in group['params']:
assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
is_param_sharded = p.col_attr.data.is_sharded
@ -118,7 +143,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# We have to use `copy_payload` instead of `reset_payload`
# Since p.data is fp32 and p.col_attr.data is fp16
-# TODO() optimize this line
+# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.col_attr.data.copy_payload(p.data)
if not is_param_sharded:
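With this change the caller no longer constructs the optimizer itself: it passes the optimizer class plus its keyword defaults, and `ShardedOptimizerV2` instantiates it over `sharded_model.parameters()`. A hedged sketch of the new call pattern, reusing the `sharded_model` and `shard_strategy` objects from the `ShardedModelV2` setup (the import path is an assumption):

```python
# Sketch of the new construction pattern; lr is forwarded to torch.optim.Adam
# through **defaults.
import torch
from colossalai.zero.sharded_optim import ShardedOptimizerV2   # import path assumed

optim = ShardedOptimizerV2(sharded_model,
                           torch.optim.Adam,
                           shard_strategy,
                           cpu_offload=False,
                           lr=1e-3)
```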

@ -1 +1 @@
-Subproject commit d50ef2db51e7d02ed3f7e9de13f9af86b04eaae9
+Subproject commit 5345187ad55e8c80c111e0c5f7ad9b9241e8f913

View File

@ -74,8 +74,5 @@ def get_training_components():
sequence_length=sequence_length,
is_distrbuted=True)
-def get_optim(model):
-return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = None
-return bert_model_builder, trainloader, testloader, get_optim, criterion
+return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -49,8 +49,5 @@ def get_training_components():
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
-def optim_builder(model):
-return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
-return model_builder, trainloader, testloader, optim_builder, criterion
+return model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -43,8 +43,5 @@ def get_training_components():
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
-def optim_builder(model):
-return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
-return model_builder, trainloader, testloader, optim_builder, criterion
+return model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -29,8 +29,5 @@ def get_resnet_training_components():
trainloader = get_cifar10_dataloader(train=True)
testloader = get_cifar10_dataloader(train=False)
-def optim_builder(model):
-return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
-return model_builder, trainloader, testloader, optim_builder, criterion
+return model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -19,11 +19,11 @@ def run_train():
# FIXME: test bert
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
-model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
+model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
model = model_builder(checkpoint=False)
engine, train_dataloader, *args = colossalai.initialize(model=model,
-optimizer=optimizer_builder(model),
+optimizer=optimizer_class(model.parameters(), lr=1e-3),
criterion=criterion,
train_dataloader=train_dataloader)
@ -84,7 +84,7 @@ def run_engine(rank, world_size, port):
@pytest.mark.dist
def test_engine():
-world_size = 4
+world_size = 2
run_func = partial(run_engine, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -25,9 +25,9 @@ def run_trainer_no_pipeline(rank, world_size, port):
test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
for name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(name)
-model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()
+model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder()
-optimizer = optimizer_builder(model)
+optimizer = optimizer_class(model.parameters(), lr=1e-3)
engine, train_dataloader, *_ = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,

View File

@ -56,6 +56,7 @@ def test_activation_checkpointing(cpu_offload):
assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match' assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
torch.cuda.empty_cache() torch.cuda.empty_cache()
# as seed manager is singleton # as seed manager is singleton
# if we don't reset seeds here, # if we don't reset seeds here,
# other tests will fail if running together with this test # other tests will fail if running together with this test

View File

@@ -4,21 +4,20 @@
 from functools import partial
 import colossalai
-from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
-def run_dist(rank, world_size, port, init_device):
+def run_dist(rank, world_size, port, init_device, shard_strategy):
 colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 for get_components_func in non_distributed_component_funcs:
@@ -26,7 +25,7 @@ def run_dist(rank, world_size, port, init_device):
 model_numel_tensor = torch.zeros(1, dtype=torch.int)
 with ZeroInitContext(convert_fp16=True,
 target_device=init_device,
-shard_strategy=TensorShardStrategy(),
+shard_strategy=shard_strategy(),
 shard_param=True,
 model_numel_tensor=model_numel_tensor):
 model = model_builder(checkpoint=True)
@@ -38,23 +37,25 @@ def run_dist(rank, world_size, port, init_device):
 assert param.col_attr.data.payload.device.type == init_device.type, \
 f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
-print(f'cpu usgae {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
-print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
+print(f'cuda usgae {ModelDataTracer().cuda_usage}')
 print(f'numel {model_numel_tensor}')
 if init_device.type == 'cuda':
-assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
-elif init_device.type == 'cpu':
-assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
+assert (ModelDataTracer().cuda_usage > 0)
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
 @pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
-def test_zero_init_context(world_size, init_device):
-run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_init_context(world_size, init_device, shard_strategy):
+run_func = partial(run_dist,
+world_size=world_size,
+port=free_port(),
+init_device=init_device,
+shard_strategy=shard_strategy)
 mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
-test_zero_init_context(2, torch.device('cpu'))
-test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))
+# test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
+test_zero_init_context(4, torch.device('cpu'), BucketTensorShardStrategy)
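The shard strategy is now injected as a class through `pytest.mark.parametrize` and instantiated inside the worker, and model-data usage is read from the `ModelDataTracer` singleton instead of the old `GLOBAL_MODEL_DATA_TRACER`. A condensed sketch of that pattern, assuming a process group has already been launched and using a plain `torch.nn.Linear` in place of the registry models:

```python
import pytest
import torch
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy


@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def test_init_with_strategy(shard_strategy):
    # the parametrized value is a class; the context receives a fresh instance
    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device('cpu'),
                         shard_strategy=shard_strategy(),
                         shard_param=True,
                         model_numel_tensor=torch.zeros(1, dtype=torch.int)):
        model = torch.nn.Linear(8, 8)

    # ModelDataTracer is a singleton, so any instantiation reads the same counters
    print(ModelDataTracer().cuda_usage)
```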

View File

@@ -3,30 +3,29 @@
 import copy
 from functools import partial
-import pytest
-import torch
-import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
-from colossalai.zero.init_ctx import ZeroInitContext
+import pytest
+import torch
+import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-TensorShardStrategy
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP
 from common import CONFIG, check_grads_padding, run_fwd_bwd
-from colossalai.zero.sharded_model.utils import col_model_deepcopy
-def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
+def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
 colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-shard_strategy = TensorShardStrategy()
+shard_strategy = shard_strategy()
 for model_name in test_models:
 get_components_func = non_distributed_component_funcs.get_callable(model_name)
 model_builder, train_dataloader, _, _, criterion = get_components_func()
@@ -35,12 +34,12 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
 if use_zero_init_ctx:
 with ZeroInitContext(convert_fp16=True,
-target_device=torch.device('cpu'),
+target_device=torch.device(f'cpu:0'),
 shard_strategy=shard_strategy,
 shard_param=True,
 rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
 zero_model = model_builder(checkpoint=True)
-zero_model = ShardedModelV2(zero_model, shard_strategy)
+zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
 model = model_builder(checkpoint=True).half()
 col_model_deepcopy(zero_model, model)
@@ -61,19 +60,24 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
 check_grads_padding(model, zero_model, loose=True)
+print('overall cuda ', zero_model._memstats_collector._overall_cuda)
+print('model cuda ', zero_model._memstats_collector._model_data_cuda)
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("enable_autocast", [True])
 @pytest.mark.parametrize("use_zero_init_ctx", [True])
-def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast, shard_strategy):
 run_func = partial(run_dist,
 world_size=world_size,
 port=free_port(),
 use_zero_init_ctx=use_zero_init_ctx,
-enable_autocast=enable_autocast)
+enable_autocast=enable_autocast,
+shard_strategy=shard_strategy)
 mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
-test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
+test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True, shard_strategy=TensorShardStrategy)
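`ShardedModelV2` can now be built with `use_memory_tracer=True`, and the test prints the collected statistics after the forward/backward passes. A short sketch of how those numbers are read, assuming `module` is a model built under `ZeroInitContext` as in the hunk above; note that `_memstats_collector` is an internal attribute, so this mirrors what the test prints rather than a public API:

```python
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

# module: a model constructed under ZeroInitContext (see the diff above)
zero_model = ShardedModelV2(module, TensorShardStrategy(), use_memory_tracer=True)

# ... run a few forward/backward iterations ...

print('overall cuda ', zero_model._memstats_collector._overall_cuda)
print('model cuda ', zero_model._memstats_collector._model_data_cuda)
```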

View File

@@ -10,20 +10,20 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.test_zero_data_parallel.common import CONFIG, allclose
 from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_zero_data_parallel.common import CONFIG, allclose
-def _run_shard_tensor(rank, world_size, port):
+def _run_shard_tensor(rank, world_size, port, shard_strategy):
 colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
 assert list(t.origin_shape) == [world_size * 2, 3]
 assert list(t.shape) == [world_size * 2, 3]
-shard_strategy = TensorShardStrategy(process_group=None)
+shard_strategy = shard_strategy(process_group=None)
 # test shard strategy
 shard_strategy.shard([t])
@@ -34,8 +34,9 @@ def _run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_shard_tensor(world_size):
-run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_tensor(world_size, shard_strategy):
+run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
 mp.spawn(run_func, nprocs=world_size)
@@ -121,7 +122,7 @@ def test_init_shard_param(world_size):
 if __name__ == '__main__':
-test_shard_tensor(2)
+test_shard_tensor(2, TensorShardStrategy)
 test_shard_param(2)
 test_shard_param_v2(2)
 test_init_shard_param(4)
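Both strategies expose the same constructor and `shard` interface, which is what lets the test swap them through parametrization. A tiny sketch of sharding a `ShardedTensor`, assuming `colossalai.launch` has already created the default process group:

```python
import torch
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.sharded_param import ShardedTensor

t = ShardedTensor(tensor=torch.randn(4, 3))
assert list(t.origin_shape) == [4, 3]

strategy = BucketTensorShardStrategy(process_group=None)  # TensorShardStrategy takes the same argument
strategy.shard([t])  # after this, t.shape describes only the local shard
```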

View File

@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
 import copy
 from functools import partial
@@ -10,7 +7,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -38,25 +35,27 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
 optimizer.step()
-def run_dist(rank, world_size, port, cpu_offload):
+def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
 colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+shard_strategy = shard_strategy()
 for model_name in test_models:
 get_components_func = non_distributed_component_funcs.get_callable(model_name)
-shard_strategy = TensorShardStrategy()
-model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+model, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
 model = model(checkpoint=True).cuda()
 zero_model = ShardedModelV2(copy.deepcopy(model),
 shard_strategy,
 offload_config=dict(device='cpu') if cpu_offload else None)
 if dist.get_world_size() > 1:
 model = DDP(model)
-optim = Adam(model.parameters(), lr=1e-3)
-sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
-zero_model,
+lr = 1e-3
+optim = optimizer_class(model.parameters(), lr=lr)
+sharded_optim = ShardedOptimizerV2(zero_model,
+optimizer_class,
 shard_strategy,
 cpu_offload=cpu_offload,
-initial_scale=2**5)
+initial_scale=2**5,
+lr=lr)
 for i, (data, label) in enumerate(train_dataloader):
 if i > 2:
 break
@@ -69,10 +68,15 @@ def run_dist(rank, world_size, port, cpu_offload):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("cpu_offload", [True, False])
-def test_sharded_optim_v2(world_size, cpu_offload):
-run_func = partial(run_dist, world_size=world_size, port=free_port(), cpu_offload=cpu_offload)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
+run_func = partial(run_dist,
+world_size=world_size,
+port=free_port(),
+cpu_offload=cpu_offload,
+shard_strategy=shard_strategy)
 mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
-test_sharded_optim_v2(world_size=2, cpu_offload=True)
+test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)
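The most visible API change in this hunk is the `ShardedOptimizerV2` constructor: it now receives the sharded model first, then the optimizer class and the shard strategy, and forwards any remaining keyword arguments (such as `lr`) to the optimizer it builds internally. A before/after sketch, assuming `zero_model`, `shard_strategy` and `cpu_offload` are set up as in the test above:

```python
from torch.optim import Adam
from colossalai.zero.sharded_optim import ShardedOptimizerV2

# old: the caller built the inner optimizer itself
# sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
#                                    zero_model,
#                                    shard_strategy,
#                                    cpu_offload=cpu_offload,
#                                    initial_scale=2**5)

# new: pass the optimizer class; extra kwargs like lr go to its constructor
sharded_optim = ShardedOptimizerV2(zero_model,
                                   Adam,
                                   shard_strategy,
                                   cpu_offload=cpu_offload,
                                   initial_scale=2**5,
                                   lr=1e-3)
```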

View File

@@ -11,7 +11,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -47,23 +47,24 @@ def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
 optimizer.step()
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
 colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+shard_strategy = shard_strategy()
 for model_name in test_models:
 get_components_func = non_distributed_component_funcs.get_callable(model_name)
-shard_strategy = TensorShardStrategy()
 model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
 model = model(checkpoint=True).cuda()
 zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy, offload_config={'device': 'cpu'})
 if dist.get_world_size() > 1:
 model = DDP(model)
 optim = Adam(model.parameters(), lr=1e-3)
-sharded_optim = ShardedOptimizerV2(CPUAdam(zero_model.parameters(), lr=1e-3),
-zero_model,
+sharded_optim = ShardedOptimizerV2(zero_model,
+CPUAdam,
 shard_strategy,
 initial_scale=2**5,
-cpu_offload=True)
+cpu_offload=True,
+lr=1e-3)
 for i, (data, label) in enumerate(train_dataloader):
 if i > 2:
 break
@@ -79,10 +80,11 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_sharded_optim_v2(world_size):
-run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, shard_strategy):
+run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
 mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
-test_sharded_optim_v2(world_size=2)
+test_sharded_optim_v2(world_size=2, shard_strategy=TensorShardStrategy)

View File

@@ -9,22 +9,21 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
 colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 test_models = ['repeated_computed_layers', 'resnet18']
+shard_strategy = shard_strategy()
 for model_name in test_models:
 get_components_func = non_distributed_component_funcs.get_callable(model_name)
 model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
 model = model_builder()
-shard_strategy = TensorShardStrategy()
 model = model.half().cuda()
 zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
 zero_state_dict = zero_model.state_dict()
@@ -33,11 +32,12 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-def test_zero_state_dict():
-world_size = 2
-run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_state_dict(world_size, shard_strategy):
+run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
 mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
-test_zero_state_dict()
+test_zero_state_dict(2, TensorShardStrategy)
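The state-dict test follows the same parametrization pattern; the underlying check is that a model wrapped in `ShardedModelV2` can still produce a full `state_dict`. A rough sketch of that check, assuming `model` is a half-precision CUDA module and a process group is already launched; the exact comparison used by the test file is not shown in this hunk:

```python
from copy import deepcopy

import torch
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

zero_model = ShardedModelV2(deepcopy(model), TensorShardStrategy())
zero_state_dict = zero_model.state_dict()

# the gathered state dict should cover every parameter of the original module
for name, param in model.state_dict().items():
    assert name in zero_state_dict
    assert zero_state_dict[name].shape == param.shape
```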