[legacy] clean up legacy code (#4743)

* [legacy] remove outdated codes of pipeline (#4692)

* [legacy] remove cli of benchmark and update optim (#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
This commit is contained in:
Hongxin Liu
2023-09-18 16:31:06 +08:00
committed by GitHub
parent 32e7f99416
commit b5f9e37c70
342 changed files with 2919 additions and 4182 deletions

View File

@@ -1,9 +1,11 @@
import operator
import torch
import torch.nn as nn
import operator
from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import ShardSpec
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
from colossalai.legacy.tensor import ProcessGroup
from colossalai.legacy.tensor.compute_spec import ComputePattern, ComputeSpec
from colossalai.legacy.tensor.distspec import ShardSpec
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
@@ -13,7 +15,7 @@ ELEMENTWISE_FUNC_OP = [
def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
"""weight_split
"""weight_split
split a nn.Parameter
Args:
@@ -60,9 +62,9 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
"""
This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers.
This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers.
"""
#TODO: Needs to handle special cases, like x = linear(x) + linear(x)
# TODO: Needs to handle special cases, like x = linear(x) + linear(x)
graph = graph_module.graph
world_size = process_group.world_size()