[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)

* [autoparallel] first move.

* [autoparallel] add solver rotor.

* [autoparallel] add ckpt solvers.

* [autoparallel] modify codegen.

* [fx] fix annotation in test.

* [fx] remove check.

* [autoparallel] polish docstring.

* [fx] refactor MetaTensor.
Author: Super Daniel
Date: 2022-11-01 10:43:15 +08:00
Committed by: GitHub
Parent commit: 2b859502d5
Commit: 1e88811c7a
16 changed files with 1025 additions and 119 deletions
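The MetaTensor refactor and the checkpoint (ckpt) solvers listed in the commit message both lean on shape-only execution: running a model on PyTorch's `meta` device produces output shapes without allocating storage or executing real kernels. A minimal sketch of that underlying idea in plain PyTorch (illustrative only; it does not use this repository's MetaTensor class):

import torch
import torch.nn as nn

# Move a model to the 'meta' device: parameters keep shape/dtype but have no storage.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).to('meta')

# A meta input tensor likewise carries only metadata.
x = torch.empty(8, 1024, device='meta')

# The forward pass propagates shapes symbolically; no real FLOPs or memory are used.
out = model(x)
print(out.shape)    # torch.Size([8, 1024])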

colossalai/fx/tracer/tracer.py

@@ -1,26 +1,28 @@
#!/usr/bin/env python
"""
tracer.py:
    Implemented a tracer which supports control flow and user-defined meta arguments.
    The implementation is partly inspired HuggingFace's fx tracer
"""
import enum
-import inspect
import functools
+import inspect
import operator
from contextlib import contextmanager
-from colossalai.fx.tracer.meta_patch import meta_patched_module
+from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from torch import Tensor
-from torch.fx import Tracer, Node
-from torch.fx.graph import Graph
-from torch.fx.proxy import Proxy, ParameterProxy
+from torch.fx import Node, Tracer
+from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods
+from torch.fx.proxy import ParameterProxy, Proxy
+from colossalai.fx.tracer.meta_patch import meta_patched_module
from ..proxy import ColoProxy
-from typing import Optional, Dict, Any
-from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
+from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
from .meta_patch import meta_patched_function, meta_patched_module
-from torch.fx.graph import magic_methods, reflectable_magic_methods

__all__ = ['ColoTracer']
@@ -231,7 +233,7 @@ class ColoTracer(Tracer):
Args:
    root (nn.Module): a `nn.Module` object to trace the computation graph
    meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph.
        These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors.
    concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies.
"""
@@ -383,7 +385,7 @@ class ColoTracer(Tracer):
if self.inside_torch_checkpoint_func:
    # annotate the activation checkpoint module
-   setattr(node, 'activation_checkpoint', self.act_ckpt_region_count)
+   node.meta['activation_checkpoint'] = self.act_ckpt_region_count
return node
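The last hunk moves the checkpoint-region annotation from an ad-hoc attribute on the Node object into node.meta, torch.fx's conventional dictionary for per-node metadata. A downstream pass can then collect the regions, for example with a helper like the one below (a sketch; only the 'activation_checkpoint' key comes from the diff, the helper name is made up):

from collections import defaultdict
from torch.fx import Graph

def collect_checkpoint_regions(graph: Graph):
    """Group node names by the checkpoint-region id stored in node.meta (illustrative helper)."""
    regions = defaultdict(list)
    for node in graph.nodes:
        ckpt_id = node.meta.get('activation_checkpoint')
        if ckpt_id is not None:
            regions[ckpt_id].append(node.name)
    return dict(regions)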