ColossalAI/colossalai/shardformer/layer/parallelmodule.py
Commit 4021b9a8a2 by FoolPlayer, 2023-07-04 16:05:01 +08:00
[shardformer] add gpt2 test and layer class refactor (#4041)

* add gpt2 test and layer class refactor
* add dropout in gpt2 policy


#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod
from typing import List, Union

import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init

# Use apex's fused FastLayerNorm when apex is installed; otherwise fall back to None.
Fast_LN = None
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNorm

    Fast_LN = FastLayerNorm
except ImportError:
    pass


class ParallelModule(nn.Module, ABC):

    @abstractmethod
    def from_native_module(module: nn.Module,
                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
        """
        Convert a native PyTorch module to a parallelized module.

        Args:
            module (nn.Module): the module to be converted.
            process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
                If this is a list, the process group at the i-th index of the list will correspond to the process
                group in the i-th axis of the device mesh. Defaults to None, which means the global process group.
        """
        pass
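
As a usage illustration, a concrete subclass would override from_native_module to shard the native module's parameters over the given process group. The sketch below is not part of this file and is not ColossalAI's actual Linear1D implementation; the class name NaiveColumnParallelLinear and its splitting logic are hypothetical, assuming the input module is an nn.Linear sharded along its output dimension.

# Hypothetical example, not part of the original file: a column-sharded
# linear layer built on ParallelModule. Reuses the imports from the file above.
import torch.distributed as dist


class NaiveColumnParallelLinear(ParallelModule):
    """Keeps one contiguous slice of an nn.Linear's output features per rank."""

    def __init__(self, weight, bias, process_group: ProcessGroup):
        super().__init__()
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if bias is not None else None
        self.process_group = process_group

    def forward(self, x):
        # Produces only the local shard of the output; callers would gather
        # shards across self.process_group to rebuild the full output.
        return nn.functional.linear(x, self.weight, self.bias)

    @staticmethod
    def from_native_module(module: nn.Module,
                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None
                           ) -> "NaiveColumnParallelLinear":
        # Assumes a single process group and that out_features divides evenly
        # by the group's world size.
        if isinstance(process_group, list):
            process_group = process_group[0]
        rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)
        shard_size = module.out_features // world_size
        start, end = rank * shard_size, (rank + 1) * shard_size
        weight = module.weight.detach()[start:end].clone()
        bias = module.bias.detach()[start:end].clone() if module.bias is not None else None
        return NaiveColumnParallelLinear(weight, bias, process_group)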