[shardformer] adapted llama to the new API (#4036)

Frank Lee
2023-06-19 13:53:17 +08:00
parent 74d176c8d8
commit c1d5453e9f
9 changed files with 238 additions and 201 deletions


@@ -1,8 +1,6 @@
 from typing import Any, Callable, Dict, List
 
-import torch
 import torch.nn as nn
-from transformers.pytorch_utils import Conv1D
 
 from colossalai.cluster.process_group_manager import ProcessGroupManager
@@ -41,10 +39,10 @@ class ModelSharder(object):
         """
         self.policy.set_model(self.model)
         self.policy.set_shard_config(self.shard_config)
-        self.preprocess()
-        self.replace_model_class()
-        self.replace_module()
-        self.postprocess()
+        self._preprocess()
+        self._replace_model_class()
+        self._replace_module()
+        self._postprocess()
 
     def reshape_embedding(self) -> None:
         r"""
@@ -57,13 +55,13 @@ class ModelSharder(object):
         self.model.resize_token_embeddings(new_vocab_size)
         self.model_config = self.model.config
 
-    def preprocess(self) -> None:
+    def _preprocess(self) -> None:
         self.model = self.policy.preprocess()
 
-    def postprocess(self) -> None:
+    def _postprocess(self) -> None:
         self.model = self.policy.postprocess()
 
-    def replace_model_class(self) -> None:
+    def _replace_model_class(self,) -> None:
         r"""
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model
@@ -84,7 +82,7 @@ class ModelSharder(object):
                    getattr(new_model_class, key),
                )
 
-    def replace_module(self) -> None:
+    def _replace_module(self,) -> None:
         r"""
         Replace the module according to the policy, and replace the module one by one