[kernel] fixed repeated loading of kernels (#2549)

* [kernel] fixed repeated loading of kernels

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-02-03 09:47:13 +08:00
committed by GitHub
parent 8438c35a5f
commit dd14783f75
4 changed files with 59 additions and 46 deletions

View File

@@ -6,6 +6,23 @@ from pathlib import Path
from typing import List
def print_rank_0(message):
    """
    Print ``message`` on a single process only, to avoid spamming the output.

    The message is printed when any of the following holds:

    * ``torch.distributed`` cannot be imported;
    * the distributed package is not available in this PyTorch build;
    * the default process group has not been initialized;
    * the calling process has rank 0.

    Args:
        message: the object to print; it is handed to ``print`` unchanged.
    """
    try:
        import torch.distributed as dist
    except ImportError:
        # No torch.distributed at all, so there are no ranks to distinguish.
        is_main_rank = True
    else:
        # is_available() has to be consulted first: PyTorch documents that
        # when the distributed package is unavailable, torch.distributed
        # exposes no other APIs, so calling is_initialized() directly would
        # fail with AttributeError instead of falling back to printing.
        if dist.is_available() and dist.is_initialized():
            is_main_rank = dist.get_rank() == 0
        else:
            is_main_rank = True

    if is_main_rank:
        print(message)
class Builder(ABC):
"""
Builder is the base class to build extensions for PyTorch.
@@ -117,7 +134,7 @@ class Builder(ABC):
try:
op_module = self.import_op()
if verbose:
print(f"OP {self.prebuilt_import_path} already exists, skip building.")
print_rank_0(f"OP {self.prebuilt_import_path} already exists, skip building.")
except ImportError:
# construct the build directory
import torch
@@ -130,9 +147,11 @@ class Builder(ABC):
Path(build_directory).mkdir(parents=True, exist_ok=True)
if verbose:
print("=========================================================================================")
print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
print("=========================================================================================")
print_rank_0(
"=========================================================================================")
print_rank_0(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
print_rank_0(
"=========================================================================================")
# load the kernel
op_module = load(name=self.name,
@@ -146,7 +165,7 @@ class Builder(ABC):
build_duration = time.time() - start_build
if verbose:
print(f"Time to load {self.name} op: {build_duration} seconds")
print_rank_0(f"Time to load {self.name} op: {build_duration} seconds")
return op_module