[bug] Fix the version check bug in colossalai run when generating the cmd. (#4713)

* Fix the version check bug in colossalai run when generating the cmd.

* polish code
This commit is contained in:
littsk
2023-09-22 10:50:47 +08:00
committed by GitHub
parent 3e05c07bb8
commit 1e0e080837

View File

@@ -156,7 +156,8 @@ def get_launch_command(
torch_version = version.parse(torch.__version__)
assert torch_version.major >= 1
if torch_version.minor < 9:
if torch_version.major == 1 and torch_version.minor < 9:
# torch distributed launch cmd with torch < 1.9
cmd = [
sys.executable,
"-m",
@@ -177,7 +178,8 @@ def get_launch_command(
value = extra_launch_args.pop(key)
default_torchrun_rdzv_args[key] = value
if torch_version.minor < 10:
if torch_version.major == 1 and torch_version.minor == 9:
# torch distributed launch cmd with torch == 1.9
cmd = [
sys.executable,
"-m",
@@ -187,6 +189,7 @@ def get_launch_command(
f"--node_rank={node_rank}",
]
else:
# torch distributed launch cmd with torch > 1.9
cmd = [
"torchrun",
f"--nproc_per_node={nproc_per_node}",