mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-21 09:23:49 +00:00
[feat] update test; rm comments;
This commit is contained in:
parent
a7b767b071
commit
6d18d38d5c
@ -28,7 +28,8 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
|||||||
from colossalai.interface.optimizer import DistributedOptim
|
from colossalai.interface.optimizer import DistributedOptim
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
||||||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
|
||||||
|
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||||
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
|
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
|
||||||
@ -1092,8 +1093,10 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
self.custom_policy = custom_policy
|
self.custom_policy = custom_policy
|
||||||
assert zero_stage in (0, 1, 2)
|
assert zero_stage in (0, 1, 2)
|
||||||
if self.pp_size > 1:
|
if self.pp_size > 1:
|
||||||
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
|
||||||
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
assert (
|
||||||
|
pp_style == "interleaved" or pp_style == "zbv"
|
||||||
|
) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
||||||
assert (
|
assert (
|
||||||
num_microbatches is not None or microbatch_size is not None
|
num_microbatches is not None or microbatch_size is not None
|
||||||
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
||||||
@ -1103,7 +1106,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
self.stage_manager = PipelineStageManager(
|
self.stage_manager = PipelineStageManager(
|
||||||
self.pg_mesh,
|
self.pg_mesh,
|
||||||
pipeline_axis=self.pp_axis,
|
pipeline_axis=self.pp_axis,
|
||||||
enable_interleave=pp_style == "interleaved",
|
enable_interleave=(pp_style == "interleaved") or (pp_style == "zbv"),
|
||||||
num_model_chunks=num_model_chunks,
|
num_model_chunks=num_model_chunks,
|
||||||
num_layers_per_stage=num_layers_per_stage,
|
num_layers_per_stage=num_layers_per_stage,
|
||||||
)
|
)
|
||||||
@ -1125,6 +1128,31 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
microbatch_size=microbatch_size,
|
microbatch_size=microbatch_size,
|
||||||
enable_metadata_cache=enable_metadata_cache,
|
enable_metadata_cache=enable_metadata_cache,
|
||||||
)
|
)
|
||||||
|
elif pp_style == "zbv":
|
||||||
|
h, a, s = 4096, 32, 1024
|
||||||
|
mem_f = 34 * h + 5 * a * s
|
||||||
|
mem_w = -32 * h
|
||||||
|
mem_b = -mem_w - mem_f
|
||||||
|
zbv_schedule = PipelineGraph(
|
||||||
|
n_stage=self.pp_size,
|
||||||
|
n_micro=num_microbatches,
|
||||||
|
f_cost=1,
|
||||||
|
b_cost=1,
|
||||||
|
w_cost=1,
|
||||||
|
c_cost=1,
|
||||||
|
f_mem=mem_f,
|
||||||
|
b_mem=mem_b,
|
||||||
|
w_mem=mem_w,
|
||||||
|
).get_v_schedule()
|
||||||
|
self.schedule = ZeroBubbleVPipeScheduler(
|
||||||
|
schedule=zbv_schedule,
|
||||||
|
stage_manager=self.stage_manager,
|
||||||
|
num_model_chunks=num_model_chunks,
|
||||||
|
num_microbatch=num_microbatches,
|
||||||
|
microbatch_size=microbatch_size,
|
||||||
|
enable_metadata_cache=enable_metadata_cache,
|
||||||
|
overlap_p2p=overlap_p2p,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if sequence_parallelism_mode == "ring_attn":
|
if sequence_parallelism_mode == "ring_attn":
|
||||||
|
@ -353,7 +353,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
|
|
||||||
# bwd chunk1 is left V;
|
# bwd chunk1 is left V;
|
||||||
else:
|
else:
|
||||||
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} self.send_backward_buffer {self.send_backward_buffer}")
|
|
||||||
################
|
################
|
||||||
# chunk = 1 && is_last_stage
|
# chunk = 1 && is_last_stage
|
||||||
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
|
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
|
||||||
@ -409,7 +408,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
accum_loss.add_(loss.detach())
|
accum_loss.add_(loss.detach())
|
||||||
if outputs is not None:
|
if outputs is not None:
|
||||||
outputs.append(tree_map(detach, output_obj))
|
outputs.append(tree_map(detach, output_obj))
|
||||||
# print(f"accum_loss {accum_loss}; outputs {len(outputs)}; model_chunk_id {model_chunk_id}")
|
|
||||||
return loss
|
return loss
|
||||||
else:
|
else:
|
||||||
return output_obj
|
return output_obj
|
||||||
@ -537,11 +535,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
Returns:
|
Returns:
|
||||||
Nothing.
|
Nothing.
|
||||||
"""
|
"""
|
||||||
|
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
||||||
# Step1: recv fwd
|
# Step1: recv fwd
|
||||||
if model_chunk_id == 0:
|
if model_chunk_id == 0:
|
||||||
# is first stage; get input from func param
|
# is first stage; get input from func param
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
input_obj = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
input_obj = micro_batch
|
||||||
else:
|
else:
|
||||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||||
else:
|
else:
|
||||||
@ -619,8 +618,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
else:
|
else:
|
||||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||||
|
|
||||||
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n")
|
|
||||||
|
|
||||||
# get input and output object from buffer;
|
# get input and output object from buffer;
|
||||||
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
||||||
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
||||||
@ -643,7 +640,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
output_obj=output_obj,
|
output_obj=output_obj,
|
||||||
output_obj_grad=output_tensor_grad,
|
output_obj_grad=output_tensor_grad,
|
||||||
)
|
)
|
||||||
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}")
|
|
||||||
|
|
||||||
# Step3: send bwd
|
# Step3: send bwd
|
||||||
if model_chunk_id == 0:
|
if model_chunk_id == 0:
|
||||||
@ -748,9 +744,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
"""
|
"""
|
||||||
# prepare batch
|
# prepare batch
|
||||||
self.load_batch(data_iter)
|
self.load_batch(data_iter)
|
||||||
print(
|
|
||||||
f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# prepare accum loss & output
|
# prepare accum loss & output
|
||||||
accum_loss = None
|
accum_loss = None
|
||||||
@ -762,12 +755,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
|
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
|
||||||
|
|
||||||
# while we still have schedules_node in self.schedules
|
# while we still have schedules_node in self.schedules
|
||||||
for it in range(len(self.schedules)):
|
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
|
||||||
scheduled_node = self.schedules[it]
|
for it in range(len(schedule)):
|
||||||
|
scheduled_node = schedule[it]
|
||||||
print(
|
|
||||||
f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};"
|
|
||||||
)
|
|
||||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||||
# communication
|
# communication
|
||||||
communication_func = self.communication_map[scheduled_node.type]
|
communication_func = self.communication_map[scheduled_node.type]
|
||||||
|
@ -2,7 +2,8 @@ from .albert import *
|
|||||||
from .bert import *
|
from .bert import *
|
||||||
from .blip2 import *
|
from .blip2 import *
|
||||||
from .bloom import *
|
from .bloom import *
|
||||||
from .chatglm2 import *
|
|
||||||
|
# from .chatglm2 import *
|
||||||
from .command import *
|
from .command import *
|
||||||
from .deepseek import *
|
from .deepseek import *
|
||||||
from .falcon import *
|
from .falcon import *
|
||||||
|
@ -10,10 +10,11 @@ from torch.testing import assert_close
|
|||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.cluster import ProcessGroupMesh
|
from colossalai.cluster import ProcessGroupMesh
|
||||||
from colossalai.interface import OptimizerWrapper
|
from colossalai.interface import OptimizerWrapper
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
|
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
|
||||||
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
class MlpModel(nn.Module):
|
class MlpModel(nn.Module):
|
||||||
@ -38,19 +39,31 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
|||||||
|
|
||||||
|
|
||||||
# 1) Test manual v_schedule with multiple microbatch
|
# 1) Test manual v_schedule with multiple microbatch
|
||||||
def run_fwd_bwd_iter_input(
|
@parameterize(
|
||||||
rank: int,
|
"test_config",
|
||||||
world_size: int,
|
[
|
||||||
port: int,
|
{
|
||||||
):
|
"batch_size": 4,
|
||||||
|
"tp_size": 1,
|
||||||
|
"pp_size": 4,
|
||||||
|
"num_microbatches": 4,
|
||||||
|
"zero_stage": 1,
|
||||||
|
"precision": "bf16",
|
||||||
|
"num_model_chunk": 4,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_fwd_bwd_iter_input(test_config):
|
||||||
# init dist
|
# init dist
|
||||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
pp_size = world_size
|
pp_size = test_config["pp_size"]
|
||||||
pg_mesh = ProcessGroupMesh(pp_size)
|
pg_mesh = ProcessGroupMesh(pp_size)
|
||||||
num_microbatch = 4
|
num_microbatch = test_config["num_microbatches"]
|
||||||
|
num_model_chunk = test_config["num_model_chunk"]
|
||||||
# stage_manager
|
# stage_manager
|
||||||
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)
|
stage_manager = PipelineStageManager(
|
||||||
|
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
|
||||||
|
)
|
||||||
|
|
||||||
# schedule list
|
# schedule list
|
||||||
zbv_schedule = [
|
zbv_schedule = [
|
||||||
@ -373,7 +386,7 @@ def run_fwd_bwd_iter_input(
|
|||||||
]
|
]
|
||||||
|
|
||||||
scheduler = ZeroBubbleVPipeScheduler(
|
scheduler = ZeroBubbleVPipeScheduler(
|
||||||
schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ?
|
schedule=zbv_schedule, # hint: send whole schedule or local schedule only ?
|
||||||
stage_manager=stage_manager,
|
stage_manager=stage_manager,
|
||||||
num_model_chunks=pp_size,
|
num_model_chunks=pp_size,
|
||||||
num_microbatch=num_microbatch,
|
num_microbatch=num_microbatch,
|
||||||
@ -419,20 +432,26 @@ def run_fwd_bwd_iter_input(
|
|||||||
for idx, sub_model in enumerate(model.layers):
|
for idx, sub_model in enumerate(model.layers):
|
||||||
if idx == 3 or idx == 4:
|
if idx == 3 or idx == 4:
|
||||||
local_chunk.append(sub_model)
|
local_chunk.append(sub_model)
|
||||||
|
# init optimizer
|
||||||
|
optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
|
||||||
|
optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
scheduler.run_forward_backward(
|
result = scheduler.forward_backward_step(
|
||||||
model_chunk=local_chunk,
|
model_chunk=local_chunk,
|
||||||
data_iter=iter(data_iter),
|
data_iter=iter(data_iter),
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
optimizer=None,
|
optimizer=optimizer_pp,
|
||||||
return_loss=None,
|
return_loss=True,
|
||||||
return_outputs=None,
|
return_outputs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
optimizer_pp.step()
|
||||||
|
|
||||||
##########################
|
##########################
|
||||||
# Fwd bwd for base
|
# Fwd bwd for base
|
||||||
##########################
|
##########################
|
||||||
@ -440,6 +459,7 @@ def run_fwd_bwd_iter_input(
|
|||||||
output_base = model_base(input_base[0])
|
output_base = model_base(input_base[0])
|
||||||
loss_base = criterion(output_base)
|
loss_base = criterion(output_base)
|
||||||
loss_base.backward()
|
loss_base.backward()
|
||||||
|
optimizer_base.step()
|
||||||
print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
|
print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
|
||||||
|
|
||||||
##########################
|
##########################
|
||||||
@ -475,21 +495,28 @@ def run_fwd_bwd_iter_input(
|
|||||||
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
|
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
|
||||||
|
|
||||||
|
|
||||||
# 2) Test v_schedule generated by graph with multiple microbatch
|
# 2) add optimizer base 1)
|
||||||
def run_fwd_bwd_with_vschedule(
|
@parameterize(
|
||||||
rank: int,
|
"test_config",
|
||||||
world_size: int,
|
[
|
||||||
port: int,
|
{
|
||||||
num_microbatch: int,
|
"batch_size": 4,
|
||||||
batch_size: int,
|
"tp_size": 1,
|
||||||
num_model_chunk: int,
|
"pp_size": 4,
|
||||||
):
|
"num_microbatches": 4,
|
||||||
|
"zero_stage": 1,
|
||||||
|
"precision": "bf16",
|
||||||
|
"num_model_chunk": 4,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
# init dist
|
# init dist
|
||||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
pp_size = world_size
|
pp_size = test_config["pp_size"]
|
||||||
pg_mesh = ProcessGroupMesh(pp_size)
|
pg_mesh = ProcessGroupMesh(pp_size)
|
||||||
num_microbatch = num_microbatch
|
num_microbatch = test_config["num_microbatches"]
|
||||||
|
num_model_chunk = test_config["num_model_chunk"]
|
||||||
# stage_manager
|
# stage_manager
|
||||||
stage_manager = PipelineStageManager(
|
stage_manager = PipelineStageManager(
|
||||||
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
|
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
|
||||||
@ -500,149 +527,7 @@ def run_fwd_bwd_with_vschedule(
|
|||||||
mem_w = -32 * h
|
mem_w = -32 * h
|
||||||
mem_b = -mem_w - mem_f
|
mem_b = -mem_w - mem_f
|
||||||
graph = PipelineGraph(
|
graph = PipelineGraph(
|
||||||
n_stage=world_size,
|
n_stage=pp_size,
|
||||||
n_micro=num_microbatch,
|
|
||||||
f_cost=6,
|
|
||||||
b_cost=6,
|
|
||||||
w_cost=6,
|
|
||||||
c_cost=6,
|
|
||||||
f_mem=mem_f,
|
|
||||||
b_mem=mem_b,
|
|
||||||
w_mem=mem_w,
|
|
||||||
# max_mem=mem_f * (p * 2 + m_offset),
|
|
||||||
)
|
|
||||||
|
|
||||||
zbv_schedule = graph.get_v_schedule()
|
|
||||||
|
|
||||||
scheduler = ZeroBubbleVPipeScheduler(
|
|
||||||
schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ?
|
|
||||||
stage_manager=stage_manager,
|
|
||||||
num_model_chunks=num_model_chunk,
|
|
||||||
num_microbatch=num_microbatch,
|
|
||||||
overlap_p2p=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def criterion(x, *args, **kwargs):
|
|
||||||
return (x * x).mean()
|
|
||||||
|
|
||||||
# init model and input
|
|
||||||
batch_size = batch_size
|
|
||||||
num_layers = 8
|
|
||||||
assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
|
|
||||||
in_dim = out_dim = 8
|
|
||||||
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
|
|
||||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
|
||||||
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
|
|
||||||
|
|
||||||
input_base = [t.clone() for t in data_iter]
|
|
||||||
model_base = deepcopy(model)
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
# layer 0 & 7 to chunk 0 on rank0
|
|
||||||
local_chunk = torch.nn.ModuleList().to(rank)
|
|
||||||
for idx, sub_model in enumerate(model.layers):
|
|
||||||
if idx == 0 or idx == 7:
|
|
||||||
local_chunk.append(sub_model)
|
|
||||||
elif rank == 1:
|
|
||||||
# layer 1 & 6 to chunk 1 on rank1
|
|
||||||
local_chunk = torch.nn.ModuleList().to(rank)
|
|
||||||
for idx, sub_model in enumerate(model.layers):
|
|
||||||
if idx == 1 or idx == 6:
|
|
||||||
local_chunk.append(sub_model)
|
|
||||||
elif rank == 2:
|
|
||||||
# layer 2 & 5 to chunk 2 on rank2
|
|
||||||
local_chunk = torch.nn.ModuleList().to(rank)
|
|
||||||
for idx, sub_model in enumerate(model.layers):
|
|
||||||
if idx == 2 or idx == 5:
|
|
||||||
local_chunk.append(sub_model)
|
|
||||||
else:
|
|
||||||
# layer 3 & 4 to chunk 3 on rank3
|
|
||||||
local_chunk = torch.nn.Sequential().to(rank)
|
|
||||||
for idx, sub_model in enumerate(model.layers):
|
|
||||||
if idx == 3 or idx == 4:
|
|
||||||
local_chunk.append(sub_model)
|
|
||||||
print(
|
|
||||||
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
scheduler.run_forward_backward(
|
|
||||||
model_chunk=local_chunk,
|
|
||||||
data_iter=iter(data_iter),
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=None,
|
|
||||||
return_loss=None,
|
|
||||||
return_outputs=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
##########################
|
|
||||||
# Fwd bwd for base
|
|
||||||
##########################
|
|
||||||
# fwd & bwd
|
|
||||||
output_base = model_base(input_base[0])
|
|
||||||
loss_base = criterion(output_base)
|
|
||||||
loss_base.backward()
|
|
||||||
print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
|
|
||||||
|
|
||||||
##########################
|
|
||||||
# assert weight
|
|
||||||
##########################
|
|
||||||
if rank == 0:
|
|
||||||
# layer 0
|
|
||||||
assert_close(local_chunk[0].weight, model_base.layers[0].weight)
|
|
||||||
assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
|
|
||||||
# layer 7
|
|
||||||
assert_close(local_chunk[1].weight, model_base.layers[7].weight)
|
|
||||||
assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
|
|
||||||
if rank == 1:
|
|
||||||
# layer 1
|
|
||||||
assert_close(local_chunk[0].weight, model_base.layers[1].weight)
|
|
||||||
assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
|
|
||||||
# layer 6
|
|
||||||
assert_close(local_chunk[1].weight, model_base.layers[6].weight)
|
|
||||||
assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
|
|
||||||
if rank == 2:
|
|
||||||
# layer 2
|
|
||||||
assert_close(local_chunk[0].weight, model_base.layers[2].weight)
|
|
||||||
assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
|
|
||||||
# layer 5
|
|
||||||
assert_close(local_chunk[1].weight, model_base.layers[5].weight)
|
|
||||||
assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
|
|
||||||
if rank == 3:
|
|
||||||
# layer 3
|
|
||||||
assert_close(local_chunk[0].weight, model_base.layers[3].weight)
|
|
||||||
assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
|
|
||||||
# layer 4
|
|
||||||
assert_close(local_chunk[1].weight, model_base.layers[4].weight)
|
|
||||||
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
|
|
||||||
|
|
||||||
|
|
||||||
# 3) add optimizer base 2)
|
|
||||||
def run_fwd_bwd_vschedule_with_optim(
|
|
||||||
rank: int,
|
|
||||||
world_size: int,
|
|
||||||
port: int,
|
|
||||||
num_microbatch: int,
|
|
||||||
batch_size: int,
|
|
||||||
num_model_chunk: int,
|
|
||||||
):
|
|
||||||
# init dist
|
|
||||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
|
||||||
rank = dist.get_rank()
|
|
||||||
pp_size = world_size
|
|
||||||
pg_mesh = ProcessGroupMesh(pp_size)
|
|
||||||
num_microbatch = num_microbatch
|
|
||||||
# stage_manager
|
|
||||||
stage_manager = PipelineStageManager(
|
|
||||||
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
|
|
||||||
)
|
|
||||||
|
|
||||||
h, a, s = 4096, 32, 1024
|
|
||||||
mem_f = 34 * h + 5 * a * s
|
|
||||||
mem_w = -32 * h
|
|
||||||
mem_b = -mem_w - mem_f
|
|
||||||
graph = PipelineGraph(
|
|
||||||
n_stage=world_size,
|
|
||||||
n_micro=num_microbatch,
|
n_micro=num_microbatch,
|
||||||
f_cost=1,
|
f_cost=1,
|
||||||
b_cost=1,
|
b_cost=1,
|
||||||
@ -657,7 +542,7 @@ def run_fwd_bwd_vschedule_with_optim(
|
|||||||
zbv_schedule = graph.get_v_schedule()
|
zbv_schedule = graph.get_v_schedule()
|
||||||
|
|
||||||
scheduler = ZeroBubbleVPipeScheduler(
|
scheduler = ZeroBubbleVPipeScheduler(
|
||||||
schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ?
|
schedule=zbv_schedule, # hint: send whole schedule or local schedule only ?
|
||||||
stage_manager=stage_manager,
|
stage_manager=stage_manager,
|
||||||
num_model_chunks=num_model_chunk,
|
num_model_chunks=num_model_chunk,
|
||||||
num_microbatch=num_microbatch,
|
num_microbatch=num_microbatch,
|
||||||
@ -669,7 +554,7 @@ def run_fwd_bwd_vschedule_with_optim(
|
|||||||
return (x * x).mean()
|
return (x * x).mean()
|
||||||
|
|
||||||
# init model and input
|
# init model and input
|
||||||
batch_size = batch_size
|
batch_size = test_config["batch_size"]
|
||||||
num_layers = 8
|
num_layers = 8
|
||||||
assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
|
assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
|
||||||
in_dim = out_dim = 16
|
in_dim = out_dim = 16
|
||||||
@ -793,8 +678,27 @@ def run_fwd_bwd_vschedule_with_optim(
|
|||||||
assert val_base[:2] == val_pp
|
assert val_base[:2] == val_pp
|
||||||
|
|
||||||
|
|
||||||
# 4) support Hybrid base 3)
|
# TODO:4) support Hybrid base 3)
|
||||||
def run_with_hybrid(
|
@parameterize(
|
||||||
|
"test_config",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"batch_size": 4,
|
||||||
|
"tp_size": 1,
|
||||||
|
"pp_size": 4,
|
||||||
|
"num_microbatches": 4,
|
||||||
|
"zero_stage": 1,
|
||||||
|
"precision": "bf16",
|
||||||
|
"num_model_chunk": 4,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_with_hybridplugin(test_config):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# TODO:5) support MoEHybrid base 3)
|
||||||
|
def run_with_moehybridplugin(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
port: int,
|
port: int,
|
||||||
@ -805,35 +709,26 @@ def run_with_hybrid(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# 5) support MoE base 3)
|
# TODO:6) support booster & Hybrid base 4)
|
||||||
|
|
||||||
# 6) support booster & Hybrid base 4)
|
# TODO:7) support booster & MoEHybrid base 4)
|
||||||
|
|
||||||
# 6) support booster & MoE base 4)
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
disable_existing_loggers()
|
||||||
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
|
run_fwd_bwd_iter_input()
|
||||||
|
run_fwd_bwd_vschedule_with_optim()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize("num_microbatch", [4])
|
|
||||||
@pytest.mark.parametrize("batch_size", [4])
|
|
||||||
@pytest.mark.parametrize("num_model_chunk", [4])
|
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
|
def test_pp():
|
||||||
# spawn(
|
|
||||||
# run_fwd_bwd_with_vschedule,
|
|
||||||
# nprocs=4,
|
|
||||||
# num_microbatch=num_microbatch,
|
|
||||||
# batch_size=batch_size,
|
|
||||||
# num_model_chunk=num_model_chunk,
|
|
||||||
# )
|
|
||||||
|
|
||||||
spawn(
|
spawn(
|
||||||
run_fwd_bwd_vschedule_with_optim,
|
run_dist,
|
||||||
nprocs=4,
|
nprocs=4,
|
||||||
num_microbatch=num_microbatch,
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_model_chunk=num_model_chunk,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)
|
test_pp()
|
||||||
|
Loading…
Reference in New Issue
Block a user