[zero] add tensor placement policies (#743)
* add tensor placement policies
* polish comments
* polish comments
* update moe unit tests
@@ -13,14 +13,12 @@ MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(siz
 _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           fp32_reduce_scatter=False,
-                          offload_config=None,
+                          tensor_placement_policy='cuda',
                           gradient_predivide_factor=1.0,
-                          use_memory_tracer=False,
                           shard_strategy=TensorShardStrategy(),
                           reuse_fp16_shard=False)

-_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
-                              initial_scale=2**5,
+_ZERO_OPTIMIZER_CONFIG = dict(initial_scale=2**5,
                               min_scale=1,
                               growth_factor=2,
                               backoff_factor=0.5,
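This hunk replaces the old `offload_config` dict with a single `tensor_placement_policy` string and drops the standalone `use_memory_tracer` flag from the model config. As a rough illustration of how the two config styles correspond, here is a hypothetical migration helper (`migrate_zero_model_config` is invented for this sketch and is not part of ColossalAI):

```python
def migrate_zero_model_config(old_cfg: dict) -> dict:
    """Translate a pre-#743 ZeRO model config to the policy-based form.

    Invented helper for illustration only; not part of ColossalAI.
    """
    new_cfg = dict(old_cfg)
    offload = new_cfg.pop('offload_config', None)
    new_cfg.pop('use_memory_tracer', None)  # no longer a standalone flag
    # 'cpu' mirrors offload_config=dict(device='cpu'); 'cuda' keeps tensors on GPU
    if offload is not None and offload.get('device') == 'cpu':
        new_cfg['tensor_placement_policy'] = 'cpu'
    else:
        new_cfg['tensor_placement_policy'] = 'cuda'
    return new_cfg


old = dict(reduce_scatter_bucket_size_mb=25,
           offload_config=dict(device='cpu'),
           use_memory_tracer=False)
assert migrate_zero_model_config(old)['tensor_placement_policy'] == 'cpu'
```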
@@ -37,16 +37,12 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
         zero_model = ShardedModelV2(
             zero_model,
             shard_strategy,
-            offload_config=dict(device='cpu') if cpu_offload else None,
-            use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+            tensor_placement_policy='cpu' if cpu_offload else 'cuda',
             reuse_fp16_shard=True,
         )

     sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3)
-    sharded_optim = ShardedOptimizerV2(zero_model,
-                                       sharded_optim,
-                                       cpu_offload=cpu_offload,
-                                       gpu_margin_mem_ratio=gpu_margin_mem_ratio)
+    sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio)

     for i, (data, label) in enumerate(train_dataloader):
         if i > 1:
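The intent of this hunk: device placement is declared once on `ShardedModelV2`, and `ShardedOptimizerV2` no longer takes a redundant `cpu_offload` flag. A minimal stand-in sketch of that single-source-of-truth design (the `Tiny*` classes below are toy stand-ins, not the real ColossalAI classes):

```python
import torch


class TinyShardedModel:
    """Stand-in for ShardedModelV2: owns the placement decision."""

    def __init__(self, module: torch.nn.Module, tensor_placement_policy: str = 'cuda'):
        assert tensor_placement_policy in ('cpu', 'cuda', 'auto')
        self.module = module
        self.tensor_placement_policy = tensor_placement_policy


class TinyShardedOptimizer:
    """Stand-in for ShardedOptimizerV2: derives offload from the model."""

    def __init__(self, sharded_model: TinyShardedModel, optim: torch.optim.Optimizer,
                 gpu_margin_mem_ratio: float = 0.0):
        # No cpu_offload argument: the model's policy is the single source of truth.
        self.cpu_offload = sharded_model.tensor_placement_policy == 'cpu'
        self.optim = optim
        self.gpu_margin_mem_ratio = gpu_margin_mem_ratio


net = torch.nn.Linear(4, 4)
model = TinyShardedModel(net, tensor_placement_policy='cpu')
opt = TinyShardedOptimizer(model, torch.optim.Adam(net.parameters()))
assert opt.cpu_offload  # inferred from the model, never passed twice
```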
@@ -33,7 +33,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
                           shard_strategy=shard_strategy,
                           shard_param=True):
         zero_model = model_builder(checkpoint=True)
-    zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
+    zero_model = ShardedModelV2(zero_model, shard_strategy)

     model = model_builder(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
@@ -64,8 +64,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
         zero_model = ShardedModelV2(
             zero_model,
             shard_strategy,
-            offload_config=dict(device='cpu') if cpu_offload else None,
-            use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+            tensor_placement_policy='cpu' if cpu_offload else 'cuda',
             reuse_fp16_shard=use_cpuadam,
         )

@@ -79,7 +78,6 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
     sharded_optim = ShardedOptimizerV2(zero_model,
                                        sharded_optim,
-                                       cpu_offload=cpu_offload,
                                        initial_scale=2**5,
                                        gpu_margin_mem_ratio=gpu_margin_mem_ratio)

@@ -14,6 +14,7 @@ from colossalai.testing import rerun_on_exception
 from torch.nn.parameter import Parameter
 from typing import List
 from functools import partial
+from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy


 class Net(torch.nn.Module):
@@ -37,7 +38,8 @@ def run_stm():
         p.colo_attr = ShardedParamV2(p, set_data_none=True)
     GLOBAL_MODEL_DATA_TRACER.register_model(model)
     mem_collector = MemStatsCollector()
-    stateful_tensor_mgr = StatefulTensorMgr(mem_collector)
+    tensor_placement_policy = AutoTensorPlacementPolicy(mem_stats_collector=mem_collector)
+    stateful_tensor_mgr = StatefulTensorMgr(tensor_placement_policy)
     for p in model.parameters():
         stateful_tensor_mgr.register_stateful_param(p.colo_attr)
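Here `StatefulTensorMgr` now consumes a policy object (`AutoTensorPlacementPolicy` wrapping the `MemStatsCollector`) rather than the raw collector, which makes the eviction strategy pluggable. A minimal sketch of that strategy-pattern shape, with invented names (`PlacementPolicy`, `evict`, and the 1 MB-per-tensor accounting are illustrative, not the ColossalAI API):

```python
from abc import ABC, abstractmethod
from typing import List


class PlacementPolicy(ABC):
    """Decides which stateful tensors to move off the GPU before compute."""

    @abstractmethod
    def evict(self, hold_cuda_tensors: List[str], cuda_demand_mb: float) -> List[str]:
        ...


class CudaPolicy(PlacementPolicy):
    def evict(self, hold_cuda_tensors, cuda_demand_mb):
        return []  # everything stays on the GPU


class CpuPolicy(PlacementPolicy):
    def evict(self, hold_cuda_tensors, cuda_demand_mb):
        return list(hold_cuda_tensors)  # everything is offloaded


class AutoPolicy(PlacementPolicy):
    """Evicts only enough tensors to satisfy upcoming demand, the role
    AutoTensorPlacementPolicy plays using MemStatsCollector statistics."""

    def __init__(self, capacity_mb: float):
        self.capacity_mb = capacity_mb

    def evict(self, hold_cuda_tensors, cuda_demand_mb):
        evicted, freed = [], 0.0
        for name in hold_cuda_tensors:
            if cuda_demand_mb <= self.capacity_mb + freed:
                break
            evicted.append(name)
            freed += 1.0  # toy accounting: pretend each tensor frees 1 MB
        return evicted


policy = AutoPolicy(capacity_mb=2.0)
print(policy.evict(['fc1.w', 'fc2.w', 'fc3.w'], cuda_demand_mb=4.0))  # ['fc1.w', 'fc2.w']
```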