mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-12-11 22:27:12 +00:00
[example] opt does not depend on Titans (#1811)
This commit is contained in:
32
examples/language/opt/context.py
Normal file
32
examples/language/opt/context.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
|
||||
class barrier_context():
|
||||
"""
|
||||
This context manager is used to allow one process to execute while blocking all
|
||||
other processes in the same process group. This is often useful when downloading is required
|
||||
as we only want to download in one process to prevent file corruption.
|
||||
Args:
|
||||
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
|
||||
parallel_mode (ParallelMode): the parallel mode corresponding to a process group
|
||||
Usage:
|
||||
with barrier_context():
|
||||
dataset = CIFAR10(root='./data', download=True)
|
||||
"""
|
||||
|
||||
def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
|
||||
# the class name is lowercase by convention
|
||||
current_rank = gpc.get_local_rank(parallel_mode=parallel_mode)
|
||||
self.should_block = current_rank != executor_rank
|
||||
self.group = gpc.get_group(parallel_mode=parallel_mode)
|
||||
|
||||
def __enter__(self):
|
||||
if self.should_block:
|
||||
dist.barrier(group=self.group)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
if not self.should_block:
|
||||
dist.barrier(group=self.group)
|
||||
Reference in New Issue
Block a user