[fix] fix llama & mixtral benchmark zbv loss None bug; update mixtral & llama policy and modeling
@@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph
 from colossalai.shardformer import PipelineGradientCheckpointConfig
 
 warnings.filterwarnings("ignore")
@@ -91,7 +92,7 @@ def main():
     parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
 
-    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
     parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
     parser.add_argument("--profile", action="store_true", help="Profile the code")
     parser.add_argument(
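
Note: "zbv" selects ColossalAI's zero-bubble V-schedule for pipeline parallelism; when this style is chosen, the schedule is built ahead of time via the PipelineGraph import added above (see the "3d" plugin hunk below).
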
@@ -137,6 +138,11 @@ def main():
     # ==============================
     # Initialize Booster
     # ==============================
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+
     use_empty_init = True
     if args.plugin == "gemini":
         plugin = GeminiPlugin(
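
The model-config lookup is hoisted above booster initialization because the zbv branch below needs config.hidden_size and config.num_attention_heads before the plugin is constructed; the now-redundant lookup later in main() is deleted in a subsequent hunk.
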
@@ -210,6 +216,23 @@ def main():
             fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "3d":
+        if args.pp_style == "zbv":
+            mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
+            mem_w = -32 * config.hidden_size
+            mem_b = -mem_w - mem_f
+            scheduler_nodes = PipelineGraph(
+                n_stage=args.pp,
+                n_micro=args.batch_size // args.mbs,
+                f_cost=1000,
+                b_cost=1000,
+                w_cost=1000,
+                c_cost=1,
+                f_mem=mem_f,
+                b_mem=mem_b,
+                w_mem=mem_w,
+            ).get_v_schedule()
+        else:
+            scheduler_nodes = None
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
             pp_size=args.pp,
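
For reference, a minimal standalone sketch of what the zbv branch above computes. The config values, stage count, and micro-batch count are illustrative assumptions, not values from this commit; only the formulas and the PipelineGraph call mirror the diff.

from colossalai.pipeline.schedule.v_schedule import PipelineGraph

# Illustrative model/runtime parameters (assumptions, not from this commit).
hidden_size = 4096
num_attention_heads = 32
max_length = 2048
pp_size = 4   # pipeline stages
n_micro = 8   # batch_size // mbs

# Rough per-microbatch activation-memory deltas: forward (F) allocates mem_f,
# the weight-gradient pass (W) frees 32 * hidden_size, and backward (B)
# frees the remainder so that F + B + W nets to zero.
mem_f = 34 * hidden_size + 5 * num_attention_heads * max_length
mem_w = -32 * hidden_size
mem_b = -mem_w - mem_f

# Equal relative costs for the F/B/W passes and a nominal communication cost,
# as in the diff; get_v_schedule() returns the per-stage node sequence that is
# later passed to HybridParallelPlugin(scheduler_nodes=...).
scheduler_nodes = PipelineGraph(
    n_stage=pp_size,
    n_micro=n_micro,
    f_cost=1000,
    b_cost=1000,
    w_cost=1000,
    c_cost=1,
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
).get_v_schedule()
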
@@ -227,6 +250,7 @@ def main():
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
+            scheduler_nodes=scheduler_nodes,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
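
Because scheduler_nodes is None for the 1f1b and interleaved styles, passing it through unconditionally leaves existing runs unchanged; HybridParallelPlugin only consumes the precomputed V-schedule nodes when the zbv branch produced them.
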
@@ -256,10 +280,6 @@ def main():
     # ==============================
     dp_size = getattr(plugin, "dp_size", coordinator.world_size)
 
-    if args.config in MODEL_CONFIGS:
-        config = MODEL_CONFIGS[args.config]
-    else:
-        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
     torch.cuda.manual_seed(42)
     dataset = RandomDataset(
         num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
@@ -334,8 +354,12 @@ def main():
                 return_loss=True,
             )
             loss = outputs["loss"]
-            if dist.get_rank() == dist.get_world_size() - 1:
-                print(f"Step {step} loss: {loss}")
+            if args.pp_style == "zbv":
+                if dist.get_rank() == 0:
+                    print(f"Step {step} loss: {loss}")
+            else:
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    print(f"Step {step} loss: {loss}")
             optimizer.step()
             optimizer.zero_grad()
 
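
This is the loss-None fix named in the commit title: under the V-schedule the pipeline folds back on itself, so the final model chunk, and with it the loss, lands on rank 0 rather than the last pipeline rank, and the old last-rank print reported None. A one-line sketch of the rule the new branch implements (loss_printing_rank is a hypothetical helper, not part of the commit):

import torch.distributed as dist

def loss_printing_rank(pp_style: str) -> int:
    # zbv: the V-schedule places the last model chunk (and the loss) on rank 0;
    # 1f1b/interleaved: the loss lives on the last pipeline rank.
    return 0 if pp_style == "zbv" else dist.get_world_size() - 1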