mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[kernel] fixed repeated loading of kernels (#2549)
* [kernel] fixed repeated loading of kernels * polish code * polish code
This commit is contained in:
@@ -6,6 +6,23 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
|
||||
def print_rank_0(message):
    """
    Print on only one process to avoid spamming.

    The message is printed when this process is rank 0 of an initialized
    ``torch.distributed`` process group. When torch cannot be imported, when the
    distributed package is not available in the installed torch build, or when no
    process group has been initialized, the current process is treated as the
    main rank and the message is printed.

    Args:
        message: the object handed to ``print``.
    """
    try:
        import torch.distributed as dist
    except ImportError:
        # torch is not installed, so there is no rank to query: always print
        is_main_rank = True
    else:
        # is_available() must be checked first: torch builds without distributed
        # support do not expose is_initialized()/get_rank() at all
        if dist.is_available() and dist.is_initialized():
            is_main_rank = dist.get_rank() == 0
        else:
            is_main_rank = True

    if is_main_rank:
        print(message)
|
||||
|
||||
|
||||
class Builder(ABC):
|
||||
"""
|
||||
Builder is the base class to build extensions for PyTorch.
|
||||
@@ -117,7 +134,7 @@ class Builder(ABC):
|
||||
try:
|
||||
op_module = self.import_op()
|
||||
if verbose:
|
||||
print(f"OP {self.prebuilt_import_path} already exists, skip building.")
|
||||
print_rank_0(f"OP {self.prebuilt_import_path} already exists, skip building.")
|
||||
except ImportError:
|
||||
# construct the build directory
|
||||
import torch
|
||||
@@ -130,9 +147,11 @@ class Builder(ABC):
|
||||
Path(build_directory).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if verbose:
|
||||
print("=========================================================================================")
|
||||
print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
|
||||
print("=========================================================================================")
|
||||
print_rank_0(
|
||||
"=========================================================================================")
|
||||
print_rank_0(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
|
||||
print_rank_0(
|
||||
"=========================================================================================")
|
||||
|
||||
# load the kernel
|
||||
op_module = load(name=self.name,
|
||||
@@ -146,7 +165,7 @@ class Builder(ABC):
|
||||
|
||||
build_duration = time.time() - start_build
|
||||
if verbose:
|
||||
print(f"Time to load {self.name} op: {build_duration} seconds")
|
||||
print_rank_0(f"Time to load {self.name} op: {build_duration} seconds")
|
||||
|
||||
return op_module
|
||||
|
||||
|
Reference in New Issue
Block a user