mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[autoparallel] user-friendly API for CheckpointSolver. (#1879)
Merge for SC tutorial
This commit is contained in:
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
||||
@@ -17,13 +18,17 @@ def _copy_output(src: Graph, dst: Graph):
|
||||
n_dst.meta = n_src.meta
|
||||
|
||||
|
||||
def _get_param_size(module: torch.nn.Module):
|
||||
"""Get the size of the parameters in the module"""
|
||||
return sum([p.numel() * torch.tensor([], dtype=p.dtype).element_size() for p in module.parameters()])
|
||||
|
||||
|
||||
class CheckpointSolverBase(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
memory_budget: float = -1.0,
|
||||
parameter_size: float = 0,
|
||||
free_memory: float = -1.0,
|
||||
requires_linearize: bool = False,
|
||||
cnode: List[str] = None,
|
||||
):
|
||||
@@ -37,8 +42,7 @@ class CheckpointSolverBase(ABC):
|
||||
|
||||
Args:
|
||||
graph (Graph): The computing graph to be optimized.
|
||||
memory_budget (float): Memory constraint for the solution.
|
||||
parameter_size (float): The size of parameter of this model. Use `parameter_size(model)` to estimate.
|
||||
free_memory (float): Memory constraint for the solution.
|
||||
requires_linearize (bool): Whether the graph needs to be linearized.
|
||||
cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
|
||||
|
||||
@@ -58,8 +62,8 @@ class CheckpointSolverBase(ABC):
|
||||
raise RuntimeError(
|
||||
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
|
||||
|
||||
self.memory_budget = memory_budget
|
||||
self.parameter_size = parameter_size
|
||||
self.free_memory = free_memory
|
||||
self.parameter_size = _get_param_size(self.graph.owning_module)
|
||||
self.cnode = cnode
|
||||
self.requires_linearize = requires_linearize
|
||||
if self.requires_linearize:
|
||||
|
Reference in New Issue
Block a user