mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692) * [legacy] remove cli of benchmark and update optim (#4690) * [legacy] remove cli of benchmark and update optim * [doc] fix cli doc test * [legacy] fix engine clip grad norm * [legacy] remove outdated colo tensor (#4694) * [legacy] remove outdated colo tensor * [test] fix test import * [legacy] move outdated zero to legacy (#4696) * [legacy] clean up utils (#4700) * [legacy] clean up utils * [example] update examples * [legacy] clean up amp * [legacy] fix amp module * [legacy] clean up gpc (#4742) * [legacy] clean up context * [legacy] clean core, constants and global vars * [legacy] refactor initialize * [example] fix examples ci * [example] fix examples ci * [legacy] fix tests * [example] fix gpt example * [example] fix examples ci * [devops] fix ci installation * [example] fix examples ci
This commit is contained in:
@@ -1,18 +1,22 @@
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
|
||||
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
|
||||
from colossalai.pipeline.middleware.adaptor import get_fx_topology
|
||||
import random
|
||||
import numpy as np
|
||||
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
|
||||
from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
|
||||
from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
|
||||
|
||||
MANUAL_SEED = 0
|
||||
random.seed(MANUAL_SEED)
|
||||
np.random.seed(MANUAL_SEED)
|
||||
torch.manual_seed(MANUAL_SEED)
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
||||
def __init__(self, config={}):
|
||||
super().__init__()
|
||||
dim = config['dim']
|
||||
@@ -27,6 +31,7 @@ class MLP(torch.nn.Module):
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def split_model_and_get_DAG(model, data_gen):
|
||||
model.eval()
|
||||
|
||||
@@ -46,7 +51,7 @@ def split_model_and_get_DAG(model, data_gen):
|
||||
# apply transform passes
|
||||
annotated_model = balanced_split_pass(gm, 2)
|
||||
top_module, split_submodules = split_with_split_nodes_pass(annotated_model)
|
||||
|
||||
|
||||
topo = get_fx_topology(top_module)
|
||||
for submodule in split_submodules:
|
||||
if isinstance(submodule, torch.fx.GraphModule):
|
||||
@@ -54,6 +59,7 @@ def split_model_and_get_DAG(model, data_gen):
|
||||
|
||||
return top_module, split_submodules[0]._topo
|
||||
|
||||
|
||||
def check_input(top_module, input_partition: Partition):
|
||||
partition_output = input_partition.get_output_vals()
|
||||
arg_pos = 0
|
||||
@@ -63,13 +69,14 @@ def check_input(top_module, input_partition: Partition):
|
||||
to_partition_and_offset = cur_checkee.get()
|
||||
assert len(to_partition_and_offset) == len(node.users.keys())
|
||||
arg_pos += 1
|
||||
|
||||
|
||||
assert arg_pos == len(partition_output)
|
||||
|
||||
|
||||
|
||||
def check_submod(top_module, part_id, mid_partition: Partition):
|
||||
partition_input = mid_partition.get_input_vals()
|
||||
partition_output = mid_partition.get_output_vals()
|
||||
|
||||
|
||||
cnt = 1
|
||||
cur_node = None
|
||||
for node in top_module.graph.nodes:
|
||||
@@ -78,15 +85,15 @@ def check_submod(top_module, part_id, mid_partition: Partition):
|
||||
if cnt == part_id:
|
||||
cur_node = node
|
||||
break
|
||||
|
||||
|
||||
assert len(partition_input) == len(cur_node.args)
|
||||
assert len(partition_output) == len(cur_node.users)
|
||||
|
||||
def check_topo(top_module, topo: Topo):
|
||||
|
||||
def check_topo(top_module, topo: Topo):
|
||||
input_partition = topo.get_input_partition()
|
||||
mid_partitions = topo.get_mid_partitions()
|
||||
|
||||
|
||||
check_input(top_module, input_partition)
|
||||
for part_id, submod in mid_partitions.items():
|
||||
check_submod(top_module, part_id, submod)
|
||||
|
Reference in New Issue
Block a user