Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-11 18:01:05 +00:00)

[example] update gpt readme with performance (#2206)
@@ -179,13 +179,17 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
     cai_version = colossalai.__version__
     from colossalai.gemini import ChunkManager, GeminiManager
     if version.parse(cai_version) > version.parse("0.1.10"):
         from colossalai.nn.parallel import GeminiDDP
         model = GeminiDDP(model,
                           device=get_current_device(),
                           placement_policy=placememt_policy,
                           pin_memory=True,
-                          search_range_mb=32)
+                          hidden_dim=4096,
+                          search_range_mb=64)
+        if placememt_policy == 'const':
+            model.gemini_manager._placement_policy.set_const_memory_boundary(10 * 1024)
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
         chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
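
The helper in this hunk picks a wrapping path based on the installed ColossalAI release. A standalone sketch of that version gate, assuming `version` here is `packaging.version` (the usual choice); this is illustration only, not part of the commit:

# Illustrative only, not part of the commit: the version gate used by
# gemini_zero_dpp, assuming `version` is packaging.version.
from packaging import version

cai_version = "0.1.11"    # stand-in for colossalai.__version__
if version.parse(cai_version) > version.parse("0.1.10"):
    path = "GeminiDDP from colossalai.nn.parallel"            # new-style wrapping
elif version.parse("0.1.9") <= version.parse(cai_version) <= version.parse("0.1.10"):
    path = "ChunkManager + GeminiManager"                     # legacy wrapping
print(path)    # -> GeminiDDP from colossalai.nn.parallel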
@@ -206,9 +210,10 @@ def main():
     if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
         raise TypeError(f"{args.distplan} is error")
 
-    BATCH_SIZE = 8
+    BATCH_SIZE = 64
     SEQ_LEN = 1024
     VOCAB_SIZE = 50257
 
     NUM_STEPS = 10
 
     disable_existing_loggers()
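
The functional change in this hunk is the larger batch: with SEQ_LEN = 1024, each batch now covers 65,536 tokens instead of 8,192. A quick check, not part of the diff:

# Quick arithmetic, not part of the diff: tokens per batch before and after
# the batch-size change.
SEQ_LEN = 1024
for batch_size in (8, 64):
    print(f"{batch_size} x {SEQ_LEN} = {batch_size * SEQ_LEN} tokens per batch")
# 8 x 1024 = 8192 tokens per batch
# 64 x 1024 = 65536 tokens per batch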
@@ -227,22 +232,21 @@ def main():
         default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
 
         # build GPT model
-        with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
-            model = gpt2_medium(checkpoint=True)
+        with ColoInitContext(device=get_current_device(), default_dist_spec=default_dist_spec, default_pg=default_pg):
+            model = gpt2_10b(checkpoint=True)
 
         pg = default_pg
         # Tensor Parallelism (TP)
         tensor_parallelize(model, pg)
 
         # Gemini + ZeRO DP, Note it must be used after TP
         model = gemini_zero_dpp(model, pg, args.placement)
 
-        # build optimizer
+        # build highly optimized cpu optimizer
         optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
-        # optimizer = HybridAdam(model.parameters(), lr=1e-3)
-        # optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     else:
-        model = gpt2_medium(checkpoint=True).cuda()
+        model = gpt2_10b(checkpoint=True).cuda()
 
     if args.distplan.startswith("torch"):
         model = DDP(model)
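
Taken together, the hunks switch the example from gpt2_medium to gpt2_10b, build it directly on the accelerator, and keep the TP-then-Gemini ordering. A condensed sketch of that path follows; it is not part of the commit, assumes the helpers gpt2_10b, tensor_parallelize and gemini_zero_dpp defined earlier in the same example script plus an already-initialized distributed environment, and uses import paths from the 0.2.x-era ColossalAI API, which may differ in later releases.

# Condensed sketch, not part of the commit: the "colossalai" distplan path after
# this change. gpt2_10b, tensor_parallelize and gemini_zero_dpp are assumed to be
# the helpers defined earlier in the same example script.
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


def build_model_and_optimizer(args, default_pg: ProcessGroup):
    # optional shard-on-init spec, as in the diff
    default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None

    # build the 10B-parameter GPT-2 directly on the accelerator (previously CPU + gpt2_medium)
    with ColoInitContext(device=get_current_device(),
                         default_dist_spec=default_dist_spec,
                         default_pg=default_pg):
        model = gpt2_10b(checkpoint=True)

    # Tensor Parallelism first, then Gemini + ZeRO DP (the diff notes this order is required)
    tensor_parallelize(model, default_pg)
    model = gemini_zero_dpp(model, default_pg, args.placement)

    # heterogeneous Adam bound to the Gemini-wrapped model
    optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
    return model, optimizer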