Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 12:01:39 +00:00)
[shardformer] refactored the shardformer layer structure (#4053)
colossalai/shardformer/layer/parallel_module.py (new file, 27 lines added)
@@ -0,0 +1,27 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod
from typing import List, Union

import torch.nn as nn
from torch.distributed import ProcessGroup

__all__ = ['ParallelModule']


class ParallelModule(nn.Module, ABC):

    @abstractmethod
    def from_native_module(module: nn.Module,
                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
        """
        Convert a native PyTorch module to a parallelized module.

        Args:
            module (nn.Module): the module to be converted.
            process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
                If this is a list, the process group at the ith index of the list will correspond to the process group
                in the ith axis of the device mesh. Defaults to None, which means the global process group.
        """
        pass
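For context, a concrete parallel layer is expected to subclass ParallelModule and implement from_native_module as the factory that replaces a plain PyTorch layer with its sharded counterpart. The sketch below is illustrative only and is not part of this commit: ParallelLinear1D, its column-sharding logic, and the assumption that torch.distributed has already been initialized are hypothetical; the actual shardformer layer implementations are not shown in this diff.

from typing import List, Optional, Union

import torch.nn as nn
import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.shardformer.layer.parallel_module import ParallelModule


class ParallelLinear1D(ParallelModule):
    # Hypothetical column-parallel linear layer used only to illustrate the
    # from_native_module contract defined in the file above.

    def __init__(self, weight: nn.Parameter, bias: Optional[nn.Parameter],
                 process_group: Optional[ProcessGroup]) -> None:
        super().__init__()
        self.weight = weight
        self.bias = bias
        self.process_group = process_group

    def forward(self, x):
        # Each rank computes only its slice of the output features.
        return nn.functional.linear(x, self.weight, self.bias)

    @staticmethod
    def from_native_module(module: nn.Module,
                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None
                           ) -> "ParallelLinear1D":
        # Sketch assumptions: torch.distributed is initialized, and the linear
        # layer's out_features is divisible by the process group size.
        assert isinstance(module, nn.Linear), "this sketch only converts nn.Linear"
        # A list of process groups maps to device-mesh axes; this sketch only
        # uses the first axis.
        pg = process_group[0] if isinstance(process_group, list) else process_group
        rank = dist.get_rank(pg)
        world_size = dist.get_world_size(pg)
        # Shard the weight (and bias, if present) along the output dimension.
        weight = nn.Parameter(module.weight.detach().chunk(world_size, dim=0)[rank].clone())
        bias = None
        if module.bias is not None:
            bias = nn.Parameter(module.bias.detach().chunk(world_size, dim=0)[rank].clone())
        return ParallelLinear1D(weight, bias, pg)

Note that the abstract signature in this commit declares from_native_module without self, so concrete layers would typically expose it as a staticmethod or classmethod, as the sketch does.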