[legacy] clean up legacy code (#4743)

* [legacy] remove outdated pipeline code (#4692)

* [legacy] remove benchmark cli and update optim (#4690)

* [legacy] remove benchmark cli and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
Author: Hongxin Liu
Date: 2023-09-18 16:31:06 +08:00
Committed by: GitHub
Parent: 32e7f99416
Commit: b5f9e37c70
342 changed files with 2919 additions and 4182 deletions
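The hunks below show the recurring pattern of this commit: modules such as the pipeline middleware move under the colossalai.legacy namespace, and call sites update their imports. A minimal before/after sketch of that migration, using only the paths visible in the diff:

    # Old import paths (pre-commit), removed in the diff below:
    # from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
    # from colossalai.pipeline.middleware.adaptor import get_fx_topology

    # New import paths (post-commit), added in the diff below:
    from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
    from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology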


@@ -1,18 +1,22 @@
 import random
+
 import numpy as np
 import torch
 from torch.fx import GraphModule
-from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
+
 from colossalai.fx import ColoTracer
-from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
-from colossalai.pipeline.middleware.adaptor import get_fx_topology
+from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
+from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
+from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
 
 MANUAL_SEED = 0
 random.seed(MANUAL_SEED)
 np.random.seed(MANUAL_SEED)
 torch.manual_seed(MANUAL_SEED)
 
+
 class MLP(torch.nn.Module):
+
     def __init__(self, config={}):
         super().__init__()
         dim = config['dim']
@@ -27,6 +31,7 @@ class MLP(torch.nn.Module):
             x = layer(x)
         return x
 
+
 def split_model_and_get_DAG(model, data_gen):
     model.eval()
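The elided body of split_model_and_get_DAG evidently traces the model into the gm GraphModule consumed by the split passes in the next hunk. A rough sketch of that step, assuming ColoTracer's meta_args tracing interface (the variable names kwargs, tracer and gm are assumptions, not part of this diff):

    # Assumed shape of the elided tracing step: trace the model on 'meta'
    # tensors so no real data is needed, then wrap the traced graph so the
    # split passes can consume it.
    kwargs = data_gen()
    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args={k: v.to('meta') for k, v in kwargs.items()})
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()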
@@ -46,7 +51,7 @@ def split_model_and_get_DAG(model, data_gen):
     # apply transform passes
     annotated_model = balanced_split_pass(gm, 2)
     top_module, split_submodules = split_with_split_nodes_pass(annotated_model)
+
     topo = get_fx_topology(top_module)
     for submodule in split_submodules:
         if isinstance(submodule, torch.fx.GraphModule):
@@ -54,6 +59,7 @@ def split_model_and_get_DAG(model, data_gen):
     return top_module, split_submodules[0]._topo
 
+
 def check_input(top_module, input_partition: Partition):
     partition_output = input_partition.get_output_vals()
     arg_pos = 0
@@ -63,13 +69,14 @@ def check_input(top_module, input_partition: Partition):
         to_partition_and_offset = cur_checkee.get()
         assert len(to_partition_and_offset) == len(node.users.keys())
         arg_pos += 1
 
     assert arg_pos == len(partition_output)
 
+
 def check_submod(top_module, part_id, mid_partition: Partition):
     partition_input = mid_partition.get_input_vals()
     partition_output = mid_partition.get_output_vals()
 
     cnt = 1
     cur_node = None
     for node in top_module.graph.nodes:
@@ -78,15 +85,15 @@ def check_submod(top_module, part_id, mid_partition: Partition):
         if cnt == part_id:
             cur_node = node
             break
 
     assert len(partition_input) == len(cur_node.args)
     assert len(partition_output) == len(cur_node.users)
 
-def check_topo(top_module, topo: Topo):
+
+def check_topo(top_module, topo: Topo):
     input_partition = topo.get_input_partition()
     mid_partitions = topo.get_mid_partitions()
 
     check_input(top_module, input_partition)
     for part_id, submod in mid_partitions.items():
         check_submod(top_module, part_id, submod)
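Taken together, the helpers form a small smoke test: split the model into two stages, extract the Topo via get_fx_topology, and verify it against the traced graph. A hypothetical driver, assuming MLP.forward takes a single tensor argument named x and that data_gen returns keyword-argument tensors (both assumptions, since the forward body is elided here):

    # Hypothetical wiring of the helpers above; the input name 'x',
    # the batch size and the hidden dim are illustrative values only.
    def data_gen():
        return {'x': torch.rand(4, 16)}

    model = MLP(config={'dim': 16})
    top_module, topo = split_model_and_get_DAG(model, data_gen)
    check_topo(top_module, topo)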