mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[shardformer]: support gpt-j, falcon, Mistral and add interleaved pipeline for bert (#5088)
* [shardformer] implement policy for all GPT-J models and test * [shardformer] support interleaved pipeline parallel for bert finetune * [shardformer] shardformer support falcon (#4883) * [shardformer]: fix interleaved pipeline for bert model (#5048) * [hotfix]: disable seq parallel for gptj and falcon, and polish code (#5093) * Add Mistral support for Shardformer (#5103) * [shardformer] add tests to mistral (#5105) --------- Co-authored-by: Pengtai Xu <henryxu880@gmail.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com> Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: eric8607242 <e0928021388@gmail.com>
This commit is contained in:
@@ -88,20 +88,24 @@ class GLUEDataBuilder:
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
# TODO: drop_last is set to True for now to avoid error when using PP
|
||||
# as the last batch may not be divisible by the number of microbatches
|
||||
if len(self.eval_splits) == 1:
|
||||
return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
|
||||
return self.plugin.prepare_dataloader(
|
||||
self.dataset["validation"], batch_size=self.eval_batch_size, drop_last=True
|
||||
)
|
||||
elif len(self.eval_splits) > 1:
|
||||
return [
|
||||
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
|
||||
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
|
||||
for x in self.eval_splits
|
||||
]
|
||||
|
||||
def test_dataloader(self):
|
||||
if len(self.eval_splits) == 1:
|
||||
return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
|
||||
return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size, drop_last=True)
|
||||
elif len(self.eval_splits) > 1:
|
||||
return [
|
||||
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
|
||||
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
|
||||
for x in self.eval_splits
|
||||
]
|
||||
|
||||
|
@@ -57,7 +57,9 @@ def evaluate_model(
|
||||
|
||||
def evaluate_subset(dataloader: DataLoader):
|
||||
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
|
||||
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
|
||||
None if not booster.plugin.stage_manager.is_interleave else -1
|
||||
)
|
||||
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
for batch in dataloader:
|
||||
@@ -69,9 +71,10 @@ def evaluate_model(
|
||||
current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
|
||||
current_rank = dist.get_rank()
|
||||
batch = iter([batch])
|
||||
|
||||
outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)
|
||||
|
||||
if is_pp_last_stage:
|
||||
if is_pp_last_device:
|
||||
logits = outputs["outputs"]["logits"]
|
||||
val_loss = outputs["loss"]
|
||||
accum_loss.add_(val_loss)
|
||||
@@ -133,8 +136,10 @@ def train_epoch(
|
||||
coordinator: DistCoordinator,
|
||||
):
|
||||
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
|
||||
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
|
||||
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
|
||||
None if not booster.plugin.stage_manager.is_interleave else -1
|
||||
)
|
||||
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)
|
||||
total_step = len(train_dataloader)
|
||||
|
||||
model.train()
|
||||
@@ -148,7 +153,7 @@ def train_epoch(
|
||||
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
|
||||
)
|
||||
# Backward and optimize
|
||||
if is_pp_last_stage:
|
||||
if is_pp_last_device:
|
||||
loss = outputs["loss"]
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
else:
|
||||
@@ -222,7 +227,9 @@ def main():
|
||||
tp_size=1,
|
||||
pp_size=2,
|
||||
num_microbatches=None,
|
||||
microbatch_size=1,
|
||||
pp_style="interleaved",
|
||||
num_model_chunks=2,
|
||||
microbatch_size=16,
|
||||
enable_all_optimization=True,
|
||||
zero_stage=1,
|
||||
precision="fp16",
|
||||
|
Reference in New Issue
Block a user