mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)
* [autoparallel] first move. * [autoparallel] add solver rotor. * [autoparallel] add ckpt solvers. * [autoparallel] modify codegen. * [fx] fix annotation in test. * [fx] remove check. * [autoparallel] polish docstring. * [fx] refactor MetaTensor.
This commit is contained in:
@@ -1,26 +1,28 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
tracer.py:
|
||||
tracer.py:
|
||||
Implemented a tracer which supports control flow and user-defined meta arguments.
|
||||
The implementation is partly inspired HuggingFace's fx tracer
|
||||
"""
|
||||
import enum
|
||||
import inspect
|
||||
import functools
|
||||
import inspect
|
||||
import operator
|
||||
from contextlib import contextmanager
|
||||
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.fx import Tracer, Node
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.proxy import Proxy, ParameterProxy
|
||||
from torch.fx import Node, Tracer
|
||||
from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods
|
||||
from torch.fx.proxy import ParameterProxy, Proxy
|
||||
|
||||
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
||||
|
||||
from ..proxy import ColoProxy
|
||||
from typing import Optional, Dict, Any
|
||||
from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
|
||||
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
|
||||
from .meta_patch import meta_patched_function, meta_patched_module
|
||||
from torch.fx.graph import magic_methods, reflectable_magic_methods
|
||||
|
||||
__all__ = ['ColoTracer']
|
||||
|
||||
@@ -231,7 +233,7 @@ class ColoTracer(Tracer):
|
||||
|
||||
Args:
|
||||
root (nn.Module): a `nn.Module` object to trace the computation graph
|
||||
meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph.
|
||||
meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph.
|
||||
These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors.
|
||||
concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies.
|
||||
"""
|
||||
@@ -383,7 +385,7 @@ class ColoTracer(Tracer):
|
||||
|
||||
if self.inside_torch_checkpoint_func:
|
||||
# annotate the activation checkpoint module
|
||||
setattr(node, 'activation_checkpoint', self.act_ckpt_region_count)
|
||||
node.meta['activation_checkpoint'] = self.act_ckpt_region_count
|
||||
return node
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user