Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 04:33:04 +00:00)
[shardformer] adapted llama to the new API (#4036)
@@ -1,8 +1,6 @@
 from typing import Any, Callable, Dict, List

 import torch
 import torch.nn as nn
 from transformers.pytorch_utils import Conv1D

 from colossalai.cluster.process_group_manager import ProcessGroupManager
@@ -41,10 +39,10 @@ class ModelSharder(object):
         """
         self.policy.set_model(self.model)
         self.policy.set_shard_config(self.shard_config)
-        self.preprocess()
-        self.replace_model_class()
-        self.replace_module()
-        self.postprocess()
+        self._preprocess()
+        self._replace_model_class()
+        self._replace_module()
+        self._postprocess()

     def reshape_embedding(self) -> None:
         r"""
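For context, a minimal runnable sketch (not part of the diff) of the template-method flow that shard() follows after this rename: a stable public entry point drives four underscore-prefixed private steps in a fixed order. TinySharder and its step bodies are hypothetical stand-ins for illustration, not ColossalAI's actual ModelSharder.

    import torch.nn as nn

    class TinySharder:
        def __init__(self, model: nn.Module):
            self.model = model

        def shard(self) -> None:
            # Public API stays stable; the steps below are private helpers.
            self._preprocess()
            self._replace_model_class()
            self._replace_module()
            self._postprocess()

        def _preprocess(self) -> None:
            print("preprocess: e.g. resize embeddings to a divisible vocab size")

        def _replace_model_class(self) -> None:
            print("replace_model_class: e.g. swap in a distributed forward()")

        def _replace_module(self) -> None:
            print("replace_module: e.g. substitute parallelized submodules")

        def _postprocess(self) -> None:
            print("postprocess: e.g. re-bind shared parameters")

    TinySharder(nn.Linear(4, 4)).shard()

Leading underscores signal that callers should go through shard() rather than invoking individual steps, which is the point of this rename.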
@@ -57,13 +55,13 @@ class ModelSharder(object):
         self.model.resize_token_embeddings(new_vocab_size)
         self.model_config = self.model.config

-    def preprocess(self) -> None:
+    def _preprocess(self) -> None:
         self.model = self.policy.preprocess()

-    def postprocess(self) -> None:
+    def _postprocess(self) -> None:
         self.model = self.policy.postprocess()

-    def replace_model_class(self) -> None:
+    def _replace_model_class(self,) -> None:
         r"""
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model
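The two one-line methods above delegate entirely to the policy and store whatever it returns back on the sharder. A hedged sketch of that delegation contract follows; DummyPolicy and its hook bodies are assumed illustrative names, not the real colossalai Policy API.

    import torch.nn as nn

    class DummyPolicy:
        def set_model(self, model: nn.Module) -> None:
            self.model = model

        def preprocess(self) -> nn.Module:
            # e.g. a policy could pad or resize weights here before sharding
            return self.model

        def postprocess(self) -> nn.Module:
            # e.g. a policy could re-tie shared weights here after sharding
            return self.model

    policy = DummyPolicy()
    policy.set_model(nn.Linear(8, 8))
    model = policy.preprocess()   # what _preprocess() stores back on the sharder
    model = policy.postprocess()  # what _postprocess() stores back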
@@ -84,7 +82,7 @@ class ModelSharder(object):
             getattr(new_model_class, key),
         )

-    def replace_module(self) -> None:
+    def _replace_module(self,) -> None:
         r"""
         Replace the module according to the policy, and replace the module one by one
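To make "replace the module one by one" concrete, here is a self-contained toy sketch of the underlying pattern: recursively walk the model tree and re-attach each replacement on its parent with setattr. The output-slicing replacement is an assumption standing in for a real tensor-parallel linear layer; it is not the code this diff touches.

    import torch.nn as nn

    def replace_linears(model: nn.Module, shard_factor: int) -> None:
        for name, child in list(model.named_children()):
            if isinstance(child, nn.Linear):
                # Toy "sharded" replacement: keep 1/shard_factor of the outputs.
                new_child = nn.Linear(child.in_features,
                                      child.out_features // shard_factor,
                                      bias=child.bias is not None)
                setattr(model, name, new_child)
            else:
                replace_linears(child, shard_factor)  # recurse into submodules

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    replace_linears(model, shard_factor=2)
    print(model)  # Linear(16 -> 8), ReLU, Linear(16 -> 2)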