[bug] Fix the version check bug in colossalai run when generating the cmd. (#4713)

* Fix the version check bug in colossalai run when generating the cmd.

* polish code
This commit is contained in:
littsk 2023-09-22 10:50:47 +08:00 committed by GitHub
parent 3e05c07bb8
commit 1e0e080837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -156,7 +156,8 @@ def get_launch_command(
torch_version = version.parse(torch.__version__)
assert torch_version.major >= 1
if torch_version.minor < 9:
if torch_version.major == 1 and torch_version.minor < 9:
# torch distributed launch cmd with torch < 1.9
cmd = [
sys.executable,
"-m",
@ -177,7 +178,8 @@ def get_launch_command(
value = extra_launch_args.pop(key)
default_torchrun_rdzv_args[key] = value
if torch_version.minor < 10:
if torch_version.major == 1 and torch_version.minor == 9:
# torch distributed launch cmd with torch == 1.9
cmd = [
sys.executable,
"-m",
@ -187,6 +189,7 @@ def get_launch_command(
f"--node_rank={node_rank}",
]
else:
# torch distributed launch cmd with torch > 1.9
cmd = [
"torchrun",
f"--nproc_per_node={nproc_per_node}",