[pipeline/pytree] add pytree to process args and kwargs | provide data_process_func to process args and kwargs after forward (#1642)

* [pipeline/tuning] improve dispatch performance both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera

* [pipeline/chimera] test chimera | fix bug of initializing

* [pipeline/pytree] add pytree to process args and kwargs | provide  to process args and kwargs after forward
This commit is contained in:
Kirigaya Kazuto
2022-09-29 10:58:58 +08:00
committed by GitHub
parent c27e701cb2
commit 9708638ded
5 changed files with 247 additions and 126 deletions

View File

@@ -22,10 +22,9 @@ def run_master(args):
epoch = args.epoch
device = args.device
stage_num = 4
stage_num = args.world_size
chunk = 1
num_microbatches = 4
actual_stage_num = 4
num_microbatches = args.num_microbatches
use_checkpoint = False
sample_num = 1024
@@ -78,6 +77,4 @@ def run_master(args):
if __name__ == "__main__":
args = parse_args()
args.world_size = 4
args.num_microbatches = 4
rpc_run(args, run_master)