From 1e0e080837478e95bc2d835c58ccd025a0013c00 Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Fri, 22 Sep 2023 10:50:47 +0800 Subject: [PATCH] [bug] Fix the version check bug in colossalai run when generating the cmd. (#4713) * Fix the version check bug in colossalai run when generating the cmd. * polish code --- colossalai/cli/launcher/run.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index 7ca8ee903..88f70f02e 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -156,7 +156,8 @@ def get_launch_command( torch_version = version.parse(torch.__version__) assert torch_version.major >= 1 - if torch_version.minor < 9: + if torch_version.major == 1 and torch_version.minor < 9: + # torch distributed launch cmd with torch < 1.9 cmd = [ sys.executable, "-m", @@ -177,7 +178,8 @@ def get_launch_command( value = extra_launch_args.pop(key) default_torchrun_rdzv_args[key] = value - if torch_version.minor < 10: + if torch_version.major == 1 and torch_version.minor == 9: + # torch distributed launch cmd with torch == 1.9 cmd = [ sys.executable, "-m", @@ -187,6 +189,7 @@ def get_launch_command( f"--node_rank={node_rank}", ] else: + # torch distributed launch cmd with torch > 1.9 cmd = [ "torchrun", f"--nproc_per_node={nproc_per_node}",