[autochunk] support transformer (#2526)

This commit is contained in:
oahzxl
2023-01-31 16:00:06 +08:00
committed by GitHub
parent 6e0faa70e0
commit 63199c6687
20 changed files with 1214 additions and 1084 deletions

View File

@@ -3,9 +3,12 @@ from typing import Any, Dict, Iterable, List, Tuple
import torch
import colossalai
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
if CODEGEN_AVAILABLE:
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
if AUTOCHUNK_AVAILABLE:
from torch.fx.graph import (
CodeGen,
PythonCode,
@@ -272,7 +275,7 @@ def emit_code_with_chunk(
node_idx += 1
if CODEGEN_AVAILABLE:
if AUTOCHUNK_AVAILABLE:
class AutoChunkCodeGen(CodeGen):