[fix] Fix the loss being None in the llama and mixtral benchmarks under the zbv schedule; update the mixtral and llama policy and modeling.

This commit is contained in:
duanjunwen
2024-10-11 07:32:43 +00:00
parent e234dfa236
commit 0ca16d5cbe
5 changed files with 134 additions and 430 deletions

View File

@@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.shardformer import PipelineGradientCheckpointConfig
warnings.filterwarnings("ignore")
@@ -91,7 +92,7 @@ def main():
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument(
@@ -137,6 +138,11 @@ def main():
# ==============================
# Initialize Booster
# ==============================
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
use_empty_init = True
if args.plugin == "gemini":
plugin = GeminiPlugin(
@@ -210,6 +216,23 @@ def main():
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "3d":
if args.pp_style == "zbv":
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
mem_w = -32 * config.hidden_size
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=args.pp,
n_micro=args.batch_size // args.mbs,
f_cost=1000,
b_cost=1000,
w_cost=1000,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
).get_v_schedule()
else:
scheduler_nodes = None
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
@@ -227,6 +250,7 @@ def main():
overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
@@ -256,10 +280,6 @@ def main():
# ==============================
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
@@ -334,8 +354,12 @@ def main():
return_loss=True,
)
loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
if args.pp_style == "zbv":
if dist.get_rank() == 0:
print(f"Step {step} loss: {loss}")
else:
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()