[example] update gpt readme with performance (#2206)

This commit is contained in:
Jiarui Fang
2022-12-27 17:39:53 +08:00
committed by GitHub
parent 1cb532ffec
commit 29868a9ec1
3 changed files with 47 additions and 10 deletions

View File

@@ -179,13 +179,17 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
cai_version = colossalai.__version__
from colossalai.gemini import ChunkManager, GeminiManager
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placememt_policy,
pin_memory=True,
search_range_mb=32)
hidden_dim=4096,
search_range_mb=64)
if placememt_policy == 'const':
model.gemini_manager._placement_policy.set_const_memory_boundary(10 * 1024)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
@@ -206,9 +210,10 @@ def main():
if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
raise TypeError(f"{args.distplan} is error")
BATCH_SIZE = 8
BATCH_SIZE = 64
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
disable_existing_loggers()
@@ -227,22 +232,21 @@ def main():
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
# build GPT model
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
model = gpt2_medium(checkpoint=True)
with ColoInitContext(device=get_current_device(), default_dist_spec=default_dist_spec, default_pg=default_pg):
model = gpt2_10b(checkpoint=True)
pg = default_pg
# Tensor Parallelism (TP)
tensor_parallelize(model, pg)
# Gemini + ZeRO DP, Note it must be used after TP
model = gemini_zero_dpp(model, pg, args.placement)
# build optimizer
# build highly optimized cpu optimizer
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
# optimizer = HybridAdam(model.parameters(), lr=1e-3)
# optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
else:
model = gpt2_medium(checkpoint=True).cuda()
model = gpt2_10b(checkpoint=True).cuda()
if args.distplan.startswith("torch"):
model = DDP(model)