From 773d9f964a34a4aa905286a4a0a0a6ddb9de281d Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 28 Jun 2024 11:20:04 +0800
Subject: [PATCH 01/15] [shardformer]delete xformers (#5859)
* delete xformers
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
colossalai/shardformer/modeling/bert.py | 110 -------------
colossalai/shardformer/modeling/bloom.py | 87 -----------
colossalai/shardformer/modeling/sam.py | 165 --------------------
colossalai/shardformer/policies/bert.py | 12 --
colossalai/shardformer/policies/bloom.py | 13 +-
colossalai/shardformer/policies/sam.py | 20 ---
docs/source/zh-Hans/features/shardformer.md | 12 +-
7 files changed, 7 insertions(+), 412 deletions(-)
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index e7679f0ec..7710b56e7 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1,4 +1,3 @@
-import math
import warnings
from typing import List, Optional, Tuple, Union
@@ -1005,115 +1004,6 @@ class BertPipelineForwards:
return {"hidden_states": hidden_states}
-def get_bert_flash_attention_forward():
- try:
- from xformers.ops import memory_efficient_attention as me_attention
- except:
- raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
- from transformers.models.bert.modeling_bert import BertAttention
-
- def forward(
- self: BertAttention,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.FloatTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
- output_attentions: Optional[bool] = False,
- ) -> Tuple[torch.Tensor]:
- mixed_query_layer = self.query(hidden_states)
-
- # If this is instantiated as a cross-attention module, the keys
- # and values come from an encoder; the attention mask needs to be
- # such that the encoder's padding tokens are not attended to.
- is_cross_attention = encoder_hidden_states is not None
-
- if is_cross_attention and past_key_value is not None:
- # reuse k,v, cross_attentions
- key_layer = past_key_value[0]
- value_layer = past_key_value[1]
- attention_mask = encoder_attention_mask
- elif is_cross_attention:
- key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
- value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
- attention_mask = encoder_attention_mask
- elif past_key_value is not None:
- key_layer = self.transpose_for_scores(self.key(hidden_states))
- value_layer = self.transpose_for_scores(self.value(hidden_states))
- key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
- value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
- else:
- key_layer = self.transpose_for_scores(self.key(hidden_states))
- value_layer = self.transpose_for_scores(self.value(hidden_states))
-
- query_layer = self.transpose_for_scores(mixed_query_layer)
-
- use_cache = past_key_value is not None
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_layer, value_layer)
-
- final_attention_mask = None
- if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
- query_length, key_length = query_layer.shape[2], key_layer.shape[2]
- if use_cache:
- position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
- else:
- position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
- position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
- distance = position_ids_l - position_ids_r
-
- positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
- positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
-
- if self.position_embedding_type == "relative_key":
- relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
- final_attention_mask = relative_position_scores
- elif self.position_embedding_type == "relative_key_query":
- relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
- relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
- final_attention_mask = relative_position_scores_query + relative_position_scores_key
-
- scale = 1 / math.sqrt(self.attention_head_size)
- if attention_mask is not None:
- if final_attention_mask != None:
- final_attention_mask = final_attention_mask * scale + attention_mask
- else:
- final_attention_mask = attention_mask
-
- if final_attention_mask is not None:
- batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
- tgt_len = key_layer.size()[2]
- final_attention_mask = final_attention_mask.expand(
- batch_size, self.num_attention_heads, src_len, tgt_len
- ).contiguous()
-
- query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
- key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
- value_layer = value_layer.permute(0, 2, 1, 3).contiguous()
-
- context_layer = me_attention(
- query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale
- )
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(new_context_layer_shape)
-
- outputs = (context_layer, None)
-
- if self.is_decoder:
- outputs = outputs + (past_key_value,)
- return outputs
-
- return forward
-
-
def get_jit_fused_bert_self_output_forward():
from transformers.models.bert.modeling_bert import BertSelfOutput
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 1f34215c5..154143626 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -714,93 +714,6 @@ class BloomPipelineForwards:
return {"hidden_states": hidden_states}
-def get_bloom_flash_attention_forward(enable_jit_fused=False):
- try:
- from xformers.ops import memory_efficient_attention as me_attention
- except:
- raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
- from transformers.models.bloom.modeling_bloom import BloomAttention
-
- def forward(
- self: BloomAttention,
- hidden_states: torch.Tensor,
- residual: torch.Tensor,
- alibi: torch.Tensor,
- attention_mask: torch.Tensor,
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- head_mask: Optional[torch.Tensor] = None,
- use_cache: bool = False,
- output_attentions: bool = False,
- ):
- fused_qkv = self.query_key_value(hidden_states)
- (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
- batch_size, tgt_len, _, _ = query_layer.size()
-
- _, kv_length, _, _ = key_layer.size()
-
- proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
- query_layer = query_layer.contiguous().view(*proj_shape)
- key_layer = key_layer.contiguous().view(*proj_shape)
- value_layer = value_layer.contiguous().view(*proj_shape)
-
- if layer_past is not None:
- past_key, past_value = layer_past
- # concatenate along seq_length dimension:
- # - key: [batch_size * self.num_heads, head_dim, kv_length]
- # - value: [batch_size * self.num_heads, kv_length, head_dim]
- key_layer = torch.cat((past_key, key_layer), dim=1)
- value_layer = torch.cat((past_value, value_layer), dim=1)
-
- if use_cache is True:
- present = (key_layer, value_layer)
- else:
- present = None
-
- tgt_len = key_layer.size()[1]
-
- attention_numerical_mask = torch.zeros(
- (batch_size, self.num_heads, tgt_len, kv_length),
- dtype=torch.float32,
- device=query_layer.device,
- requires_grad=True,
- )
- attention_numerical_mask = (
- attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
- )
- attention_numerical_mask = torch.masked_fill(
- attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
- )
- attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype)
-
- context_layer = me_attention(
- query_layer,
- key_layer,
- value_layer,
- attn_bias=attention_numerical_mask,
- scale=self.inv_norm_factor,
- p=self.attention_dropout.p,
- )
- context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
- if self.pretraining_tp > 1 and self.slow_but_exact:
- slices = self.hidden_size / self.pretraining_tp
- output_tensor = torch.zeros_like(context_layer)
- for i in range(self.pretraining_tp):
- output_tensor = output_tensor + F.linear(
- context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
- self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
- )
- else:
- output_tensor = self.dense(context_layer)
-
- # TODO to replace with the bias_dropout_add function in jit
- output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
- outputs = (output_tensor, present, None)
-
- return outputs
-
- return forward
-
-
def get_jit_fused_bloom_attention_forward():
from transformers.models.bloom.modeling_bloom import BloomAttention
diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py
index 26e0b224d..49fce0556 100644
--- a/colossalai/shardformer/modeling/sam.py
+++ b/colossalai/shardformer/modeling/sam.py
@@ -1,9 +1,4 @@
-import math
-from typing import Tuple
-
import torch
-import torch.nn.functional as F
-from torch import Tensor
def forward_fn():
@@ -45,163 +40,3 @@ def forward_fn():
return outputs
return forward
-
-
-def get_sam_flash_attention_forward():
- from transformers.models.sam.modeling_sam import SamAttention
-
- try:
- from xformers.ops import memory_efficient_attention as me_attention
- except:
- raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-
- def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
- batch, point_batch_size, n_tokens, channel = hidden_states.shape
- c_per_head = channel // num_attention_heads
- hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
- return hidden_states
-
- def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
- batch, n_tokens, n_heads, c_per_head = hidden_states.shape
- return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
-
- def forward(
- self: SamAttention, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None
- ) -> Tensor:
- # Input projections
- query = self.q_proj(query)
- key = self.k_proj(key)
- value = self.v_proj(value)
-
- point_batch_size = query.shape[1]
- # Separate into heads
- query = _separate_heads(query, self.num_attention_heads)
- key = _separate_heads(key, self.num_attention_heads)
- value = _separate_heads(value, self.num_attention_heads)
-
- # SamAttention
- _, _, _, c_per_head = query.shape
- bias = None
- if attention_similarity is not None:
- bias = attention_similarity
-
- scale = 1.0 / math.sqrt(c_per_head)
- out = me_attention(query, key, value, attn_bias=bias, scale=scale)
-
- out = _recombine_heads(out, point_batch_size)
- out = self.out_proj(out)
-
- return out
-
- return forward
-
-
-def get_sam_vision_flash_attention_forward():
- from transformers.models.sam.modeling_sam import SamVisionAttention
-
- try:
- from xformers.ops import memory_efficient_attention as me_attention
- except:
- raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-
- def add_decomposed_rel_pos(
- query: torch.Tensor,
- rel_pos_h: torch.Tensor,
- rel_pos_w: torch.Tensor,
- q_size: Tuple[int, int],
- k_size: Tuple[int, int],
- ) -> torch.Tensor:
- """
- Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
- https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
-
- Args:
- attn (`torch.Tensor`):
- attention map.
- query (`torch.Tensor`):
- query q in the attention layer with shape (batch_size, query_height * query_width, channel).
- rel_pos_h (`torch.Tensor`):
- relative position embeddings (Lh, channel) for height axis.
- rel_pos_w (`torch.Tensor`):
- relative position embeddings (Lw, channel) for width axis.
- q_size (tuple):
- spatial sequence size of query q with (query_height, query_width).
- k_size (tuple):
- spatial sequence size of key k with (key_height, key_width).
-
- Returns:
- attn (`torch.Tensor`):
- attention map with added relative positional embeddings.
- """
-
- query_height, query_width = q_size
- key_height, key_width = k_size
- relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
- relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)
-
- batch_size, _, nHead, dim = query.shape
- reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
- rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
- rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
- rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
- rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
- return rel_pos
-
- def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
- """
- Get relative positional embeddings according to the relative positions of
- query and key sizes.
-
- Args:
- q_size (int):
- size of the query.
- k_size (int):
- size of key k.
- rel_pos (`torch.Tensor`):
- relative position embeddings (L, channel).
-
- Returns:
- Extracted positional embeddings according to relative positions.
- """
- max_rel_dist = int(2 * max(q_size, k_size) - 1)
- # Interpolate rel pos.
- rel_pos_resized = F.interpolate(
- rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
- size=max_rel_dist,
- mode="linear",
- )
- rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
-
- # Scale the coords with short length if shapes for q and k are different.
- q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
- k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
- relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
-
- return rel_pos_resized[relative_coords.long()]
-
- def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
- batch_size, height, width, _ = hidden_states.shape
- # qkv with shape (3, batch_size, nHead, height * width, channel)
- qkv = (
- self.qkv(hidden_states)
- .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
- .permute(2, 0, 1, 3, 4)
- )
-
- query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)
-
- rel_pos = None
- if self.use_rel_pos:
- rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))
-
- attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)
-
- attn_output = attn_output.reshape(batch_size, height, width, -1)
-
- attn_output = self.proj(attn_output)
-
- outputs = (attn_output, None)
-
- return outputs
-
- return forward
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index c11ed99ac..b84a372a5 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -11,7 +11,6 @@ import colossalai.shardformer.layer as col_nn
from ..modeling.bert import (
BertPipelineForwards,
bert_sequence_parallel_forward_fn,
- get_bert_flash_attention_forward,
get_jit_fused_bert_intermediate_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
@@ -49,7 +48,6 @@ class BertPolicy(Policy):
BertLayer,
BertModel,
BertOutput,
- BertSelfAttention,
BertSelfOutput,
)
@@ -218,16 +216,6 @@ class BertPolicy(Policy):
target_key=BertEmbeddings,
)
- # use flash attention
- if self.shard_config.enable_flash_attention:
- self.append_or_create_method_replacement(
- description={
- "forward": get_bert_flash_attention_forward(),
- },
- policy=policy,
- target_key=BertSelfAttention,
- )
-
# use jit operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 20a75cf90..d80adb84a 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -11,14 +11,13 @@ import colossalai.shardformer.layer as col_nn
from ..modeling.bloom import (
BloomPipelineForwards,
build_bloom_alibi_tensor_fn,
- get_bloom_flash_attention_forward,
get_bloom_sequence_parallel_forward_fn,
get_jit_fused_bloom_attention_forward,
get_jit_fused_bloom_gelu_forward,
get_jit_fused_bloom_mlp_forward,
get_lm_forward_with_dist_cross_entropy,
)
-from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
+from ..modeling.jit import get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -165,16 +164,6 @@ class BloomPolicy(Policy):
target_key=BloomModel,
)
- if self.shard_config.enable_flash_attention:
- self.append_or_create_method_replacement(
- description={
- "forward": get_bloom_flash_attention_forward(),
- "dropout_add": get_dropout_add_func(),
- },
- policy=policy,
- target_key=BloomAttention,
- )
-
# enable jit fused operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index c224d7769..53faf8997 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -1,5 +1,3 @@
-import warnings
-
import colossalai.shardformer.layer as col_nn
from ..modeling.sam import forward_fn
@@ -212,24 +210,6 @@ class SamPolicy(Policy):
target_key=SamTwoWayTransformer,
)
- # use flash attention
- if self.shard_config.enable_flash_attention:
- warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.")
- # self.append_or_create_method_replacement(
- # description={
- # "forward": get_sam_flash_attention_forward(),
- # },
- # policy=policy,
- # target_key=SamAttention,
- # )
- # self.append_or_create_method_replacement(
- # description={
- # "forward": get_sam_vision_flash_attention_forward(),
- # },
- # policy=policy,
- # target_key=SamVisionAttention,
- # )
-
return policy
def postprocess(self):
diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md
index a42c7cc2e..00e1a13d6 100644
--- a/docs/source/zh-Hans/features/shardformer.md
+++ b/docs/source/zh-Hans/features/shardformer.md
@@ -71,8 +71,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
✔️ |
✔️ |
✔️ |
- ✔️ |
- ✔️ |
+ ❌ |
+ ❌ |
✔️ |
✔️ |
✔️ |
@@ -95,8 +95,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
✔️ |
✔️ |
✔️ |
- ✔️ |
- ✔️ |
+ ❌ |
+ ❌ |
✔️ |
✔️ |
✔️ |
@@ -155,8 +155,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
✔️ |
❌ |
❌ |
- ✔️ |
- ✔️ |
+ ❌ |
+ ❌ |
✔️ |
✔️ |
❌ |
From 416580b3142457f1b210147e8611756eef1687ad Mon Sep 17 00:00:00 2001
From: Haze188
Date: Fri, 28 Jun 2024 14:00:08 +0800
Subject: [PATCH 02/15] [MoE/ZeRO] Moe refactor with zero refactor (#5821)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* [moe] remove openmoe-coupled code and rectify mixtral code (#5471)
* [Feature] MoE refactor; Integration with Mixtral (#5682)
* cherry pick from refractor-moe branch
* tests passed
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support ep + zero
---------
Co-authored-by: Edenzzzz
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* add mixtral auto policy & move pipeline forward code to modeling folder
* [moe refactor] modify kernel test without Route Class
* [moe refactor] add moe tensor test path environment variable to github workflow
* fix typos
* fix moe test bug due to the code rebase
* [moe refactor] fix moe zero test, and little bug in low level zero
* fix typo
* add moe tensor path to github workflow
* remove some useless code
* fix typo & unify global variable XX_AXIS logic without using -1
* fix typo & prettify the code
* remove print code & support zero 2 test
* remove useless code
* rename function
* fix typo
* fix typo
* Further improve the test code
* remove print code
* [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test
* [moe refactor] skip some unit test which will be refactored later
* [moe refactor] fix unit import error
* [moe refactor] fix circular import issues
* [moe refactor] remove debug code
* [moe refactor] update github workflow
* [moe/zero] refactor low level optimizer (#5767)
* [zero] refactor low level optimizer
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] MoE refactor with newest version of ZeRO (#5801)
* [zero] remove redundant members in BucketStore (#5802)
* [zero] align api with previous version
* [Moe/Zero] Update MoeHybridParallelPlugin with refactored ZeRO and Fix Zero bug (#5819)
* [moe refactor] update unit test with the refactored ZeRO and remove useless test
* move moe checkpoint to checkpoint folder and exchange global axis to class member
* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug
* fix zero unit test
* Add an assertion to prevent users from using it incorrectly
* [hotfix]Solve the compatibility issue of zero refactor (#5823)
* [moe refactor] update unit test with the refactored ZeRO and remove useless test
* move moe checkpoint to checkpoint folder and exchange global axis to class member
* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug
* fix zero unit test
* Add an assertion to prevent users from using it incorrectly
* Modify function parameter names to resolve compatibility issues
* [zero] fix missing hook removal (#5824)
* [MoE] Resolve .github conflict (#5829)
* [Fix/Example] Fix Llama Inference Loading Data Type (#5763)
* [fix/example] fix llama inference loading dtype
* revise loading dtype of benchmark llama3
* [release] update version (#5752)
* [release] update version
* [devops] update compatibility test
* [devops] update compatibility test
* [devops] update compatibility test
* [devops] update compatibility test
* [test] fix ddp plugin test
* [test] fix gptj and rpc test
* [devops] fix cuda ext compatibility
* [inference] fix flash decoding test
* [inference] fix flash decoding test
* fix (#5765)
* [test] Fix/fix testcase (#5770)
* [fix] branch for fix testcase;
* [fix] fix test_analyzer & test_auto_parallel;
* [fix] remove local change about moe;
* [fix] rm local change moe;
* [Hotfix] Add missing init file in inference.executor (#5774)
* [CI/tests] simplify some test case to reduce testing time (#5755)
* [ci/tests] simplify some test case to reduce testing time
* [ci/tests] continue to remove test case to reduce ci time cost
* restore some test config
* [ci/tests] continue to reduce ci time cost
* [misc] update dockerfile (#5776)
* [misc] update dockerfile
* [misc] update dockerfile
* [devops] fix docker ci (#5780)
* [Inference]Add Streaming LLM (#5745)
* Add Streaming LLM
* add some parameters to llama_generation.py
* verify streamingllm config
* add test_streamingllm.py
* modified according to review comments
* add Citation
* change _block_tables tolist
* [hotfix] fix llama flash attention forward (#5777)
* [misc] Accelerate CI for zero and dist optim (#5758)
* remove fp16 from lamb
* remove d2h copy in checking states
---------
Co-authored-by: Edenzzzz
* [Test/CI] remove test cases to reduce CI duration (#5753)
* [test] smaller gpt2 test case
* [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py
* [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py
* [test] reduce test cases tests/test_zero/test_gemini/test_optim.py
* Revert "[test] smaller gpt2 test case"
Some tests might depend on the size of model (num of chunks)
This reverts commit df705a5210b8901645992adf276e320e48766ebf.
* [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py
* [CI] smaller test model for the two modified cases
* [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there
* [hotfix] fix testcase in test_fx/test_tracer (#5779)
* [fix] branch for fix testcase;
* [fix] fix test_analyzer & test_auto_parallel;
* [fix] remove local change about moe;
* [fix] rm local change moe;
* [fix] fix test_deepfm_model & test_dlrf_model;
* [fix] fix test_hf_albert & test_hf_gpt;
* [gemini] optimize reduce scatter d2h copy (#5760)
* [gemini] optimize reduce scatter d2h copy
* [fix] fix missing reduce variable
* [refactor] remove legacy async reduce scatter code
* [gemini] missing sync
* Revert "[refactor] remove legacy async reduce scatter code"
This reverts commit 58ad76d4665032bbe548d066116d1c572ce98979.
* [gemini] further optimize with async all reduce
* [fix] pass flag from manager to chunk
* Allow building cuda extension without a device. (#5535)
Added FORCE_CUDA environment variable support, to enable building extensions where a GPU device is not present but the CUDA libraries are (see the sketch after this list).
* [misc] fix dist logger (#5782)
* [install]fix setup (#5786)
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] update requirements (#5787)
* [shardformer] fix import (#5788)
* upgrade colossal-chat to support tp_group > 1, add sp for sft
* upgrade ppo dpo rm script
* run pre-commit
* update ci tests, sft ci test cases passed, tp failed in generation for ppo, sp is buggy
* fix training script
* fix ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix transformers version
* remove duplicated test
* fix datasets version
* remove models that require huggingface auth from ci
* remove local data path
* update ci
* remove baichuan from template test due to transformer version conflict
* merge
* Refactor modeling by adding attention backend
Signed-off-by: char-1ee
* Fix tests and naming
Signed-off-by: char-1ee
* Pass inference model shard configs for module init
Signed-off-by: char-1ee
* Clean up
Signed-off-by: char-1ee
* replace the customized dataloader setup with the built-in one
* replace the customized dataloader setup with the built-in one
* Remove flash attention backend
Signed-off-by: char-1ee
* fix readme
* Fix test import
Signed-off-by: char-1ee
* update sft training script
* [Inference]refactor baichuan (#5791)
* refactor baichuan
* remove unused code and add TODO for lazyinit
* [test] fix chatglm test kit (#5793)
* [shardformer] fix modeling of bloom and falcon (#5796)
* [test] fix qwen2 pytest distLarge (#5797)
* [Inference] Fix flash-attn import and add model test (#5794)
* Fix torch int32 dtype
Signed-off-by: char-1ee
* Fix flash-attn import
Signed-off-by: char-1ee
* Add generalized model test
Signed-off-by: char-1ee
* Remove exposed path to model
Signed-off-by: char-1ee
* Add default value for use_flash_attn
Signed-off-by: char-1ee
* Rename model test
Signed-off-by: char-1ee
---------
Signed-off-by: char-1ee
* [Gemini] Use async stream to prefetch and h2d data moving (#5781)
* use async stream to prefetch and h2d data moving
* Remove redundant code
* [gemini] quick fix on possible async operation (#5803)
* [gemini] quick fix on possible async operation
* [gemini] quick fix on possible async operation
* [shardformer] upgrade transformers to 4.39.3 (#5815)
* [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807)
* [shardformer] fix modeling of gpt2 and gptj
* [shardformer] fix whisper modeling
* [misc] update requirements
---------
Co-authored-by: ver217
* [shardformer]upgrade transformers for mistral (#5808)
* upgrade transformers for mistral
* fix
* fix
* [shardformer]upgrade transformers for llama (#5809)
* update transformers
fix
* fix
* fix
* [inference] upgrade transformers (#5810)
* update transformers
fix
* fix
* fix
* fix
* fix
* [gemini] update transformers for gemini (#5814)
---------
Co-authored-by: ver217
* Support 4d parallel + flash attention (#5789)
* support tp + sp + pp
* remove comments
---------
Co-authored-by: Edenzzzz
---------
Signed-off-by: char-1ee
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Edenzzzz
Co-authored-by: Edenzzzz
Co-authored-by: botbw
Co-authored-by: Charles Coulombe
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang
Co-authored-by: char-1ee
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang
* [zero] fix hook bug
* [zero] add low level optimizer back (#5839)
* [zero] fix param & refactor
* [zero] add back original low level opt
* [zero] remove moe related
* [zero] pass zero tests
* [zero] refactor
* [chore] add del func back
* [zero] comments and naming (#5840)
* [zero] modify api (#5843)
* [zero] modify api
* [test] remove _grad_store access in tests
* [test] fix (#5857)
* [CI] skip openmoe CI check
* [CI] fix pre-commit
* [zero] remove redundant member init (#5862)
* [misc] remove useless code, modify the pg mesh implementation
* [misc] remove useless code, modify the pg mesh implementation
* [misc] use tempfile
* resolve conflict with main branch
* [misc] use tempfile in test_moe_checkpoint.py
* [misc] remove useless code, add assertion about sequence parallel, move logger into function
* [misc] remove useless code
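A minimal sketch of the FORCE_CUDA pattern referenced above (#5535); the helper name and the exact environment-variable handling here are illustrative assumptions, not the code added in that commit:

    # Hedged sketch: illustrates the FORCE_CUDA override pattern only.
    import os

    import torch


    def should_build_cuda_ext() -> bool:
        """Decide whether to compile the CUDA extension at build time."""
        # FORCE_CUDA lets users build on hosts that have the CUDA toolkit
        # installed but no visible GPU (e.g. CI runners or docker build
        # machines).
        if os.environ.get("FORCE_CUDA", "0") == "1":
            return True
        # Otherwise only build when a CUDA device is actually available.
        return torch.cuda.is_available()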
---------
Signed-off-by: char-1ee
Co-authored-by: Frank Lee
Co-authored-by: Edenzzzz
Co-authored-by: Edenzzzz
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: botbw
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Charles Coulombe
Co-authored-by: YeAnbang
Co-authored-by: char-1ee
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang
---
.github/workflows/build_on_pr.yml | 3 +-
.github/workflows/build_on_schedule.yml | 3 +-
.../compatiblity_test_on_dispatch.yml | 3 +-
.github/workflows/compatiblity_test_on_pr.yml | 3 +-
.../compatiblity_test_on_schedule.yml | 3 +-
.../ColossalMoE/colossal_moe/__init__.py | 0
.../colossal_moe/models/__init__.py | 0
.../colossal_moe/models/mixtral_layer.py | 92 --
applications/ColossalMoE/infer.py | 4 -
applications/ColossalMoE/infer.sh | 3 +-
.../ColossalMoE/tests/test_moe_checkpoint.py | 146 ----
applications/ColossalMoE/train.py | 6 +-
.../ColossalMoE/{colossal_moe => }/utils.py | 0
.../colossalqa/local/colossalcloud_llm.py | 1 +
.../booster/plugin/hybrid_parallel_plugin.py | 23 +-
.../plugin/moe_hybrid_parallel_plugin.py | 147 ++--
colossalai/checkpoint_io/__init__.py | 9 +-
.../hybrid_parallel_checkpoint_io.py | 14 +-
.../checkpoint_io/moe_checkpoint.py | 319 ++++++-
colossalai/checkpoint_io/utils.py | 1 +
colossalai/cluster/process_group_mesh.py | 12 +-
colossalai/moe/__init__.py | 15 -
colossalai/moe/checkpoint.py | 792 ------------------
colossalai/moe/load_balance.py | 6 +-
colossalai/moe/loss.py | 78 --
colossalai/moe/routers.py | 466 -----------
colossalai/moe/utils.py | 9 +-
colossalai/shardformer/layer/moe/__init__.py | 3 +
.../{ => shardformer/layer}/moe/experts.py | 4 +-
.../{ => shardformer/layer}/moe/layers.py | 23 +-
colossalai/shardformer/layer/moe/routers.py | 161 ++++
.../shardformer/modeling/mixtral.py | 290 ++-----
.../shardformer/policies/auto_policy.py | 10 +-
colossalai/shardformer/policies/mixtral.py | 210 +++++
colossalai/shardformer/shard/shard_config.py | 1 +
colossalai/tensor/moe_tensor/api.py | 11 +-
.../zero/low_level/bookkeeping/__init__.py | 3 +-
.../low_level/bookkeeping/bucket_store.py | 25 +-
.../low_level/bookkeeping/gradient_store.py | 13 +-
.../low_level/bookkeeping/parameter_store.py | 60 --
colossalai/zero/low_level/low_level_optim.py | 735 +++++++---------
.../openmoe/benchmark/benchmark_cai.py | 2 +-
.../openmoe/model/modeling_openmoe.py | 10 +-
.../language/openmoe/model/openmoe_policy.py | 1 +
examples/language/openmoe/test_ci.sh | 60 +-
examples/language/openmoe/train.py | 46 +-
.../test_low_level_zero_checkpoint_io.py | 12 +-
tests/test_moe/moe_utils.py | 38 +-
tests/test_moe/test_grad_handler.py | 4 +-
tests/test_moe/test_kernel.py | 136 ++-
.../test_moe}/test_mixtral_layer.py | 13 +-
tests/test_moe/test_moe_checkpoint.py | 313 ++++---
tests/test_moe/test_moe_ep_tp.py | 10 +-
tests/test_moe/test_moe_group.py | 4 +-
tests/test_moe/test_moe_hybrid_zero.py | 1 +
tests/test_moe/test_moe_load_balance.py | 4 +-
tests/test_moe/test_moe_router.py | 47 --
tests/test_moe/test_moe_zero_fwd_bwd.py | 78 --
tests/test_moe/test_moe_zero_fwd_bwd_optim.py | 132 +++
tests/test_moe/test_moe_zero_optim.py | 83 --
tests/test_optimizer/_utils.py | 2 +-
tests/test_optimizer/test_dist_adafactor.py | 2 +-
tests/test_optimizer/test_dist_came.py | 2 +-
tests/test_optimizer/test_dist_lamb.py | 2 +-
.../test_zero_optimizer.py | 5 +-
.../test_model/test_shard_command.py | 6 +-
.../test_model/test_shard_llama.py | 8 +-
.../test_zero/test_low_level/test_mem_leak.py | 61 ++
.../test_zero/test_low_level/test_zero1_2.py | 67 +-
69 files changed, 1780 insertions(+), 3076 deletions(-)
delete mode 100644 applications/ColossalMoE/colossal_moe/__init__.py
delete mode 100644 applications/ColossalMoE/colossal_moe/models/__init__.py
delete mode 100644 applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
delete mode 100644 applications/ColossalMoE/tests/test_moe_checkpoint.py
rename applications/ColossalMoE/{colossal_moe => }/utils.py (100%)
rename applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py => colossalai/checkpoint_io/moe_checkpoint.py (66%)
delete mode 100644 colossalai/moe/checkpoint.py
delete mode 100644 colossalai/moe/loss.py
delete mode 100644 colossalai/moe/routers.py
create mode 100644 colossalai/shardformer/layer/moe/__init__.py
rename colossalai/{ => shardformer/layer}/moe/experts.py (98%)
rename colossalai/{ => shardformer/layer}/moe/layers.py (96%)
create mode 100644 colossalai/shardformer/layer/moe/routers.py
rename applications/ColossalMoE/colossal_moe/models/mixtral_policy.py => colossalai/shardformer/modeling/mixtral.py (65%)
create mode 100644 colossalai/shardformer/policies/mixtral.py
delete mode 100644 colossalai/zero/low_level/bookkeeping/parameter_store.py
rename {applications/ColossalMoE/tests => tests/test_moe}/test_mixtral_layer.py (81%)
delete mode 100644 tests/test_moe/test_moe_router.py
delete mode 100644 tests/test_moe/test_moe_zero_fwd_bwd.py
create mode 100644 tests/test_moe/test_moe_zero_fwd_bwd_optim.py
delete mode 100644 tests/test_moe/test_moe_zero_optim.py
create mode 100644 tests/test_zero/test_low_level/test_mem_leak.py
diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index adf4501bb..151454239 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -90,7 +90,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
- options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch:/data/scratch
timeout-minutes: 90
defaults:
run:
@@ -165,6 +165,7 @@ jobs:
env:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
+ MOE_TENSOR_PATH: /data/scratch/moe_tensors
- name: Collate artifact
env:
diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml
index e560d0c00..fc6424503 100644
--- a/.github/workflows/build_on_schedule.yml
+++ b/.github/workflows/build_on_schedule.yml
@@ -13,7 +13,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
- options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 90
steps:
- name: Check GPU Availability # ensure all GPUs have enough memory
@@ -69,6 +69,7 @@ jobs:
env:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
+ MOE_TENSOR_PATH: /data/scratch/moe_tensors
- name: Notify Lark
id: message-preparation
diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml
index 9867ef7c6..3eee564c2 100644
--- a/.github/workflows/compatiblity_test_on_dispatch.yml
+++ b/.github/workflows/compatiblity_test_on_dispatch.yml
@@ -50,7 +50,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
steps:
- name: Install dependencies
@@ -92,3 +92,4 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
+ MOE_TENSOR_PATH: /data/scratch/moe_tensors
diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml
index 885d352d5..b418c843e 100644
--- a/.github/workflows/compatiblity_test_on_pr.yml
+++ b/.github/workflows/compatiblity_test_on_pr.yml
@@ -41,7 +41,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
@@ -87,3 +87,4 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
+ MOE_TENSOR_PATH: /data/scratch/moe_tensors
diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml
index 39e1f479c..8d98e775c 100644
--- a/.github/workflows/compatiblity_test_on_schedule.yml
+++ b/.github/workflows/compatiblity_test_on_schedule.yml
@@ -38,7 +38,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
steps:
- name: Install dependencies
@@ -85,6 +85,7 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
+ MOE_TENSOR_PATH: /data/scratch/moe_tensors
- name: Notify Lark
id: message-preparation
diff --git a/applications/ColossalMoE/colossal_moe/__init__.py b/applications/ColossalMoE/colossal_moe/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/applications/ColossalMoE/colossal_moe/models/__init__.py b/applications/ColossalMoE/colossal_moe/models/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
deleted file mode 100644
index a2b78a2bd..000000000
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-from colossalai.lazy import LazyInitContext
-from colossalai.moe import MOE_MANAGER
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
-from colossalai.shardformer.shard.utils import set_tensors_to_none
-from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
-
-
-class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
- def __init__(self, config):
- super().__init__(config)
- self.setup_ep()
-
- def setup_ep(self):
- _, moe_info = MOE_MANAGER.get_info(self.num_experts)
- ep_group = moe_info.ep_group
- self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
- self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
- assert self.num_experts % self.ep_size == 0
- self.ep_group = ep_group
- self.num_experts_per_ep = self.num_experts // self.ep_size
- self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
- held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
- set_tensors_to_none(self.experts, exclude=set(held_experts))
- for p in self.experts.parameters():
- set_moe_tensor_info(p, moe_info)
-
- @staticmethod
- def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
- LazyInitContext.materialize(module)
- module.__class__ = EPMixtralSparseMoeBlock
- module.setup_ep()
- return module
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- batch_size, sequence_length, hidden_dim = hidden_states.shape
- hidden_states = hidden_states.view(-1, hidden_dim)
- # router_logits: (batch * sequence_length, n_experts)
- router_logits = self.gate(hidden_states)
-
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
- # we cast back to the input dtype
- routing_weights = routing_weights.to(hidden_states.dtype)
-
- selected_experts = selected_experts.t().reshape(-1)
- selected_experts_idx = selected_experts.argsort()
- dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
- input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
- output_split_sizes = torch.zeros_like(input_split_sizes)
- dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
-
- input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
- output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
- output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
- # compute expert output
- output_states = MoeInGradScaler.apply(output_states, self.ep_size)
- if output_states.size(0) > 0:
- if self.num_experts_per_ep == 1:
- # no need to split
- expert = self.experts[self.expert_start_idx]
- output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
- output_states = expert.w2(output_states)
- else:
- output_states_splits = output_states.split(output_split_sizes.tolist())
- output_states_list = []
- for i, split_states in enumerate(output_states_splits):
- if split_states.size(0) == 0:
- continue
- expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
- split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
- split_states = expert.w2(split_states)
- output_states_list.append(split_states)
- output_states = torch.cat(output_states_list)
- output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
- dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
- recover_experts_idx = torch.empty_like(selected_experts_idx)
- recover_experts_idx[selected_experts_idx] = torch.arange(
- selected_experts_idx.size(0), device=selected_experts_idx.device
- )
- dispatch_states = dispatch_states[recover_experts_idx]
- k_hidden_states = dispatch_states.chunk(self.top_k)
- output_states = k_hidden_states[0] * routing_weights[:, 0, None]
- for i in range(1, self.top_k):
- output_states += k_hidden_states[i] * routing_weights[:, i, None]
- output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
- return output_states, router_logits
diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py
index 543c434d2..6023e304d 100644
--- a/applications/ColossalMoE/infer.py
+++ b/applications/ColossalMoE/infer.py
@@ -2,8 +2,6 @@ import argparse
import torch
import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@@ -70,8 +68,6 @@ def main():
ep_size=ep_size,
zero_stage=1,
precision=args.precision,
- custom_policy=MixtralForCausalLMPolicy(),
- checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
)
diff --git a/applications/ColossalMoE/infer.sh b/applications/ColossalMoE/infer.sh
index 0487fe9c1..ba4362d74 100644
--- a/applications/ColossalMoE/infer.sh
+++ b/applications/ColossalMoE/infer.sh
@@ -1,5 +1,6 @@
NUM_GPU=2
-MODEL="mistralai/Mixtral-8x7B-v0.1"
+# MODEL="mistralai/Mixtral-8x7B-v0.1"
+MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py
deleted file mode 100644
index 074dbf835..000000000
--- a/applications/ColossalMoE/tests/test_moe_checkpoint.py
+++ /dev/null
@@ -1,146 +0,0 @@
-from copy import deepcopy
-
-import pytest
-import torch
-import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from torch.optim import Adam
-from transformers.models.mixtral.configuration_mixtral import MixtralConfig
-from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.testing.utils import spawn
-
-tokens, n_experts = 7, 4
-hidden_size = 8
-top_k = 2
-
-
-def check_model_equal(model1, model2):
- assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
- for p1, p2 in zip(model1.parameters(), model2.parameters()):
- assert torch.equal(p1.half(), p2.half())
-
-
-def get_optimizer_snapshot(optim):
- state = {id(k): deepcopy(v) for k, v in optim.state.items()}
- param_groups = []
- for group in optim.param_groups:
- params = [id(p) for p in group["params"]]
- new_group = {"params": params}
- for k, v in group.items():
- if k != "params":
- new_group[k] = v
- param_groups.append(new_group)
- return {
- "state": state,
- "param_groups": param_groups,
- }
-
-
-def check_optimizer_snapshot_equal(snapshot1, snapshot2):
- # check param_groups
- assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
- for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
- assert set(group1.keys()) == set(group2.keys())
- for k in group1.keys():
- assert group1[k] == group2[k]
- # check state
- assert set(snapshot1["state"].keys()) == set(
- snapshot2["state"].keys()
- ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
- for pid in snapshot1["state"].keys():
- state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
- assert set(state1.keys()) == set(state2.keys())
- for k in state1.keys():
- if isinstance(state1[k], torch.Tensor):
- assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
- else:
- assert state1[k] == state2[k]
-
-
-def check_mixtral_moe_layer():
- torch.cuda.set_device(dist.get_rank())
- config = MixtralConfig(
- hidden_size=hidden_size,
- intermediate_size=hidden_size * 2,
- num_local_experts=n_experts,
- num_experts_per_tok=top_k,
- num_attention_heads=2,
- num_key_value_heads=2,
- )
- torch.manual_seed(0)
- input_ids = torch.randint(0, 100, (2, tokens)).cuda()
- orig_model = MixtralForCausalLM(config).cuda()
- model = deepcopy(orig_model)
- optimizer = Adam(model.parameters(), lr=1e-3)
- plugin = MoeHybridParallelPlugin(
- tp_size=1,
- pp_size=2,
- ep_size=2,
- custom_policy=MixtralForCausalLMPolicy(),
- checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
- microbatch_size=1,
- zero_stage=1,
- )
- booster = Booster(plugin=plugin)
- model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
- # initialize grads
- data_iter = iter(
- [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
- )
- booster.execute_pipeline(
- data_iter,
- model,
- lambda outputs, inputs: outputs.loss,
- optimizer,
- )
-
- # check save model
- booster.save_model(model, "mixtral_model", shard=True)
- dist.barrier()
- if dist.get_rank() == 0:
- saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
- check_model_equal(orig_model, saved_model)
- saved_model.save_pretrained("mixtral_hf_model")
- dist.barrier()
-
- # check load model
- new_model = MixtralForCausalLM(config).cuda()
- new_optimizer = Adam(new_model.parameters(), lr=1e-3)
- new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
- booster.load_model(new_model, "mixtral_hf_model")
- check_model_equal(model, new_model)
-
- # check save optimizer
- optimizer.step()
- for group in optimizer.param_groups:
- group["lr"] = 0.1
- snapshot = get_optimizer_snapshot(optimizer.unwrap())
- booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
- dist.barrier()
- # reset optimizer state
- for state in optimizer.unwrap().state.values():
- for v in state.values():
- if isinstance(v, torch.Tensor):
- v.zero_()
- booster.load_optimizer(optimizer, "mixtral_optim")
- loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
- check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
-
-
-def run_dist(rank: int, world_size: int, port: int):
- colossalai.launch(rank, world_size, "localhost", port)
- check_mixtral_moe_layer()
-
-
-@pytest.mark.parametrize("world_size", [4])
-def test_mixtral_moe_layer(world_size: int):
- spawn(run_dist, world_size)
-
-
-if __name__ == "__main__":
- test_mixtral_moe_layer(4)
diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
index d2789d644..9cd810e5a 100644
--- a/applications/ColossalMoE/train.py
+++ b/applications/ColossalMoE/train.py
@@ -2,13 +2,11 @@ import argparse
import torch
import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM
+from utils import load_checkpoint, move_to_cuda, save_checkpoint
import colossalai
from colossalai.booster import Booster
@@ -155,12 +153,10 @@ def main():
pp_size=args.pp_size,
ep_size=args.ep_size,
microbatch_size=args.microbatch_size,
- custom_policy=MixtralForCausalLMPolicy(),
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
precision=args.precision,
zero_stage=args.zero_stage,
- checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
)
else:
diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/utils.py
similarity index 100%
rename from applications/ColossalMoE/colossal_moe/utils.py
rename to applications/ColossalMoE/utils.py
diff --git a/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py
index 362977869..ca8d64f22 100644
--- a/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py
+++ b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py
@@ -20,6 +20,7 @@ resp = llm(prompt='What do you call a three-ton kangaroo?', **gen_config)
print(resp) # super-heavyweight awesome-natured yawning Australian creature!
"""
+
import json
from typing import Any, Mapping
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 3bd43f172..a3d6f1e74 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -655,7 +655,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
- self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
@@ -718,7 +717,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
"""Retrieve all working gradients from different parameter groups."""
all_working_grads = []
for group_id in range(self.num_param_groups):
- working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
+ working_grads = self.get_working_grads_by_group_id(group_id)
all_working_grads.extend(working_grads)
return all_working_grads
@@ -726,7 +725,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
"""Identify gradients to be synchronized in the sequence parallelism."""
grads_to_sync = []
for grad in all_working_grads:
- param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+ param_id_for_grad = self.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
grads_to_sync.append(grad)
@@ -739,7 +738,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Get all working gradients and gradients to be synchronized.
all_working_grads = _get_all_working_grads()
grads_to_sync = _get_grads_to_sync(all_working_grads)
- if self._grad_store.require_grad_sync and grads_to_sync is not None:
+ if self.require_grad_sync and grads_to_sync is not None:
# Synchronize sequence parallelism gradients if required.
SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
else:
@@ -763,7 +762,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph)
- if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+ if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
@@ -788,14 +787,14 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad)
- if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+ if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
            # If gradient synchronization is not required, return.
return
- def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
+ def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
@@ -811,7 +810,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
if len(gradients) == 0:
return 0.0
- dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1
+ dp_size = get_world_size(dp_pg) if dp_pg is not None else 1
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
@@ -842,7 +841,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
if tp_size > 1:
- param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+ param_id_for_grad = self.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
if not is_distributed_tensor(param_for_grad):
@@ -856,7 +855,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_shared_param = shared_param[self.stage_manager.stage]
- working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param))
+ working_grad = self.get_working_grad_by_param_id(id(stage_shared_param))
if grad is working_grad:
grad_norm_exponentiated /= len(shared_param)
@@ -867,7 +866,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
)
if dp_size > 1:
# compute norm in dp process group
- dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
+ dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg)
if tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
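A single-process sketch of the norm aggregation behind `_compute_grad_norm` above (which now receives the data-parallel group as an argument instead of reading `self.dp_pg`): each rank sums |g|^p over its local gradient shards, the per-rank sums are combined across the dp/tp/pp groups, and the total norm is the 1/p power of the result. The snippet uses a plain Python sum in place of `dist.all_reduce` and omits the TP-duplicated and pipeline-shared parameter corrections, so it is illustrative only.

```python
# Minimal single-process sketch (illustrative names, no torch.distributed).
import torch


def compute_grad_norm(local_grad_shards_per_rank, norm_type: float = 2.0) -> float:
    # Each "rank" contributes sum(|g|^p) over its local shards.
    per_rank_sums = [
        sum(float(g.abs().pow(norm_type).sum()) for g in shards)
        for shards in local_grad_shards_per_rank
    ]
    total = sum(per_rank_sums)  # stands in for dist.all_reduce(..., op=ReduceOp.SUM)
    return total ** (1.0 / norm_type)


if __name__ == "__main__":
    full_grad = torch.arange(8, dtype=torch.float32)
    sharded = [[full_grad[:4]], [full_grad[4:]]]  # two "ranks", each holding half
    assert abs(compute_grad_norm(sharded) - full_grad.norm(2).item()) < 1e-4
```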
@@ -1309,7 +1308,7 @@ class HybridParallelPlugin(PipelinePluginBase):
# run with gradients accumulation
if model.require_grad_sync == False or (
- isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
+ isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
):
return outputs
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 83888e506..2cfdd000a 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,4 +1,5 @@
import random
+import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
@@ -20,19 +21,19 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
get_param_info,
init_pipeline_optimizer,
)
+from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MOE_MANAGER, MoECheckpointIO
+from colossalai.logging import get_dist_logger
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
-PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
-
-class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
@@ -67,8 +68,20 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
+
+ pg_param_list = {
+ dp_process_group: [],
+ moe_extra_dp_process_group: [],
+ }
+ for param in model.parameters():
+ if is_moe_tensor(param):
+ pg_param_list[moe_extra_dp_process_group].append(param)
+ else:
+ pg_param_list[dp_process_group].append(param)
+
super().__init__(
optimizer=optimizer,
+ pg_to_param_list=pg_param_list,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
@@ -83,9 +96,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
- dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
- moe_extra_dp_process_group=moe_extra_dp_process_group,
)
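The loop above buckets every model parameter into either the MoE data-parallel group or the plain data-parallel group before passing the mapping to ZeRO via `pg_to_param_list`. Below is a minimal single-process sketch of that routing; `is_moe_param` and the string keys are stand-ins for `is_moe_tensor` and real `ProcessGroup` handles.

```python
# Sketch of the parameter bucketing above (hypothetical predicate and keys).
from typing import Callable, Dict, Hashable, List

import torch
import torch.nn as nn


def bucket_params_by_group(
    model: nn.Module,
    is_moe_param: Callable[[torch.Tensor], bool],
    moe_dp_group: Hashable,
    global_dp_group: Hashable,
) -> Dict[Hashable, List[nn.Parameter]]:
    """Route MoE params to the MoE DP group and all other params to the global DP group."""
    pg_param_list: Dict[Hashable, List[nn.Parameter]] = {
        global_dp_group: [],
        moe_dp_group: [],
    }
    for param in model.parameters():
        key = moe_dp_group if is_moe_param(param) else global_dp_group
        pg_param_list[key].append(param)
    return pg_param_list


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    # Pretend the second linear layer holds "expert" (MoE) weights.
    moe_ids = {id(p) for p in model[1].parameters()}
    buckets = bucket_params_by_group(model, lambda p: id(p) in moe_ids, "moe_dp", "global_dp")
    print({k: len(v) for k, v in buckets.items()})  # {'global_dp': 2, 'moe_dp': 2}
```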
@@ -107,8 +118,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
- tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+ tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
            Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
Defaults to 'fp16'.
@@ -144,14 +155,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+        use_ep_inside (bool, optional): Whether to use ep inside dp (intra-node) for moe params.
"""
def __init__(
self,
- tp_size: int,
pp_size: int,
ep_size: int,
- extra_dp_size: int = 1,
+ tp_size: int = 1,
+ sp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
@@ -184,32 +196,22 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpointIO] = None,
) -> None:
- assert (
- dist.get_world_size() % (tp_size * pp_size) == 0
- ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+ world_size = dist.get_world_size()
+ assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
+    assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism is not supported in MoE yet"
- if enable_sequence_parallelism:
- assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
assert (
- dist.get_world_size() % (tp_size * pp_size) == 0
- ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+ world_size % (tp_size * pp_size) == 0
+ ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
- dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
- ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
- self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
- MOE_MANAGER.setup(
- parallel="EP",
- mode="fixed",
- fixed_dp_size=self.real_dp_size,
- fixed_ep_size=ep_size,
- fixed_pp_size=pp_size,
- use_ep_inside=use_ep_inside,
- )
+ world_size % (tp_size * pp_size * ep_size) == 0
+ ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
+
+ self.dp_size = world_size // (tp_size * pp_size)
self.tp_size = tp_size
self.pp_size = pp_size
- self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.ep_size = ep_size
- self.moe_info = MOE_MANAGER.get_info(0)[1]
+ self.sp_size = sp_size
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
@@ -219,43 +221,57 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.checkpoint_io = checkpoint_io
+
+ logger = get_dist_logger()
+
+ # NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param
+ # See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient
        # we change pg mesh to (pp, moe_dp, ep, tp) for better moe performance
- self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
+ assert (
+ self.ep_size <= self.dp_size
+ ), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})."
- # sync moe in outer dp group, and sync other param in global dp group
- if extra_dp_size > 1:
- ep_size = self.dp_size // extra_dp_size
- if use_ep_inside:
- self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size)
- self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1)
- if dist.get_rank() == 0:
- print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}")
- else:
- self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size)
- self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2)
- if dist.get_rank() == 0:
- print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}")
+ self.moe_dp_size = self.dp_size // self.ep_size
+ self.use_ep_inside = use_ep_inside
+ if self.use_ep_inside:
+ logger.info(f"MoE Parallel use ep inside dp.", ranks=[0])
+ self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
+ self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
else:
- self.moe_extra_dp_group = None
+ logger.info(f"MoE Parallel use ep outside dp.", ranks=[0])
+ warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
+ self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
+ self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)
+ self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
+ self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
+ logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0])
+ logger.info(
+ f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0]
+ )
+
+ self.tp_group = self.pg_mesh.get_group_along_axis(
+ self.tp_axis
+ ) # TODO: support custom tp size for mixtral lm head
+ self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
+ self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
+        # TODO: currently MoE only partially supports sequence parallelism
+ self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
+
+ self.custom_policy = custom_policy
self.stage_manager = None
self.schedule = None
- self.custom_policy = custom_policy
+
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
- self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
+ self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
)
- self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
- self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
- self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
- # TODO: Currently moe only support partially sequence parallel
- self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@@ -267,6 +283,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
+ ep_group=self.ep_group,
)
self.amp_config = dict(
initial_scale=initial_scale,
@@ -323,7 +340,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(
- dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
+ dataset,
+ num_replicas=self.dp_size,
+ rank=dist.get_rank(self.global_dp_group),
+ shuffle=shuffle,
)
# Deterministic dataloader
@@ -346,9 +366,20 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
def get_checkpoint_io(self) -> MoECheckpointIO:
if self.checkpoint_io is None:
- self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+ self.checkpoint_io = MoECheckpointIO(
+ self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
+ )
else:
- self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+ self.checkpoint_io = self.checkpoint_io(
+ self.global_dp_group,
+ self.pp_group,
+ self.tp_group,
+ ep_group=self.ep_group,
+ moe_dp_group=self.moe_dp_group,
+ zero_stage=self.zero_stage,
+ )
+ if hasattr(self.checkpoint_io, "moe_info"):
+ self.checkpoint_io.moe_info = self.moe_info
return self.checkpoint_io
def configure(
@@ -366,7 +397,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
module=model,
precision=self.precision,
shard_config=self.shard_config,
- dp_group=self.dp_group,
+ dp_group=self.global_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@@ -392,15 +423,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
- optimizer = HybridParallelZeroOptimizer(
+ optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
- dp_process_group=self.dp_group,
+ dp_process_group=self.global_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
- moe_extra_dp_process_group=self.moe_extra_dp_group,
+ moe_extra_dp_process_group=self.moe_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
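To make the new mesh layout concrete: with `pp_size=1`, `ep_size=2`, `tp_size=1` on 8 ranks, `dp_size` is 8 and `moe_dp_size` is 4, and placing ep inside dp keeps each ep group on adjacent ranks while the moe-dp groups stride across them. The snippet below is a standalone illustration of that grouping in plain Python (not the ColossalAI API); `ProcessGroupMesh` performs the equivalent rank bookkeeping.

```python
# Rough sketch of the 4-axis mesh (pp, moe_dp, ep, tp) used when ep is inside dp.
import itertools


def mesh_groups(pp=1, moe_dp=4, ep=2, tp=1):
    shape = (pp, moe_dp, ep, tp)
    coords = list(itertools.product(*[range(s) for s in shape]))
    rank_of = {c: r for r, c in enumerate(coords)}  # row-major rank layout

    def groups(axis):
        buckets = {}
        for c in coords:
            key = tuple(v for i, v in enumerate(c) if i != axis)
            buckets.setdefault(key, []).append(rank_of[c])
        return sorted(buckets.values())

    return {"ep": groups(2), "moe_dp": groups(1)}


if __name__ == "__main__":
    g = mesh_groups()
    print("ep groups:    ", g["ep"])      # adjacent ranks: [[0, 1], [2, 3], [4, 5], [6, 7]]
    print("moe_dp groups:", g["moe_dp"])  # strided ranks:  [[0, 2, 4, 6], [1, 3, 5, 7]]
```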
diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
index 19b61730b..ef37534fe 100644
--- a/colossalai/checkpoint_io/__init__.py
+++ b/colossalai/checkpoint_io/__init__.py
@@ -2,5 +2,12 @@ from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
+from .moe_checkpoint import MoECheckpointIO
-__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
+__all__ = [
+ "CheckpointIO",
+ "CheckpointIndexFile",
+ "GeneralCheckpointIO",
+ "HybridParallelCheckpointIO",
+ "MoECheckpointIO",
+]
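With this export the MoE checkpoint backend is importable from the core package; the Mixtral example earlier in this patch previously pulled `MixtralMoEHybridParallelCheckpointIO` out of the application package instead. A short usage note, assuming the patched ColossalAI is installed:

```python
# New canonical import path introduced by this patch.
from colossalai.checkpoint_io import MoECheckpointIO

# Previously the class lived in colossalai.moe (with a Mixtral-specific variant
# under applications/ColossalMoE); both are superseded by the import above.
```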
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 7946d9b9c..61c9d1438 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -70,13 +70,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
verbose: bool = True,
) -> None:
super().__init__()
- self.dp_group = dp_group
+ self.global_dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
- self.dp_rank = dist.get_rank(self.dp_group)
+ self.dp_rank = dist.get_rank(self.global_dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
- self.dp_size = dist.get_world_size(dp_group)
+ self.global_dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
self.use_zero = zero_stage > 0
@@ -433,7 +433,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
- dp_group=self.dp_group,
+ dp_group=self.global_dp_group,
tp_group=self.tp_group,
size_per_shard=size_per_shard,
)
@@ -727,7 +727,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state,
working_param,
original_shape=original_shape,
- dp_group=self.dp_group,
+ dp_group=self.global_dp_group,
tp_group=self.tp_group,
use_zero=self.use_zero,
inplace=False,
@@ -932,12 +932,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Shard state along data parallel group when using Zero.
if self.use_zero:
- padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
+ padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
- slice_size = v.numel() // self.dp_size
+ slice_size = v.numel() // self.global_dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)
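The padding arithmetic above, `padding_size = (global_dp_size - numel % global_dp_size) % global_dp_size`, rounds each flattened state tensor up to a multiple of the DP world size so every rank receives an equal slice. A self-contained restatement of that pad-and-slice step, with no distributed state involved:

```python
# Same arithmetic as the ZeRO sharding above, in a standalone function.
import torch


def shard_state_along_dp(v: torch.Tensor, dp_size: int, dp_rank: int) -> torch.Tensor:
    v = v.flatten()
    padding_size = (dp_size - v.numel() % dp_size) % dp_size
    if padding_size > 0:
        v = torch.nn.functional.pad(v, [0, padding_size])
    slice_size = v.numel() // dp_size
    return v.split(slice_size, dim=0)[dp_rank]


if __name__ == "__main__":
    state = torch.arange(10.0)                              # 10 elements, dp_size = 4 -> padded to 12
    shards = [shard_state_along_dp(state, 4, r) for r in range(4)]
    assert [s.numel() for s in shards] == [3, 3, 3, 3]
    assert torch.equal(torch.cat(shards)[:10], state)       # concatenation recovers the original
```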
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py
similarity index 66%
rename from applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
rename to colossalai/checkpoint_io/moe_checkpoint.py
index d08dfd5f8..a0b625008 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
+++ b/colossalai/checkpoint_io/moe_checkpoint.py
@@ -9,6 +9,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import get_global_rank
from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
@@ -19,15 +20,16 @@ from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
+ load_state_dict,
load_states_into_optimizer,
save_config_file,
save_param_groups,
+ save_state_dict,
save_state_dict_shards,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
try:
@@ -36,21 +38,30 @@ except ImportError:
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
-class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
+class MoECheckpointIO(HybridParallelCheckpointIO):
def __init__(
self,
- dp_group: ProcessGroup,
+ global_dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
+ ep_group: ProcessGroup,
+ moe_dp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
- super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
- moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
- self.ep_group = moe_info.ep_group
- self.ep_size = moe_info.ep_size
- self.ep_rank = moe_info.ep_rank
- self.real_dp_rank = moe_info.dp_rank
+ super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
+ self.global_dp_group = global_dp_group
+ self.global_dp_rank = dist.get_rank(global_dp_group)
+ self.global_dp_size = dist.get_world_size(global_dp_group)
+ self.pp_group = pp_group
+ self.tp_group = tp_group
+
+ self.moe_dp_group = moe_dp_group
+ self.moe_dp_size = dist.get_world_size(moe_dp_group)
+ self.moe_dp_rank = dist.get_rank(moe_dp_group)
+ self.ep_group = ep_group
+ self.ep_size = dist.get_world_size(ep_group)
+ self.ep_rank = dist.get_rank(ep_group)
@staticmethod
def _model_sharder(
@@ -134,7 +145,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
Path(checkpoint).mkdir(parents=True, exist_ok=True)
- if self.real_dp_rank != 0:
+ if self.moe_dp_rank != 0:
dist.barrier()
return
@@ -144,7 +155,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
- state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
+ state_dict_shard = MoECheckpointIO._model_sharder(
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
@@ -234,11 +245,12 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
state: OrderedDict,
param: torch.Tensor,
original_shape: torch.Size,
- dp_group: ProcessGroup,
+ global_dp_group: ProcessGroup,
tp_group: ProcessGroup,
use_zero: bool,
inplace: bool,
is_moe_param: bool,
+ moe_dp_group: ProcessGroup = None,
device: torch.device = torch.device("cpu"),
) -> OrderedDict:
"""
@@ -248,7 +260,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
original_shape (torch.Size): The size of parameter before sharding.
- dp_group (ProcessGroup): The process group of data parallel.
+ global_dp_group (ProcessGroup): The process group of data parallel.
tp_group (ProcessGroup): The process group of tensor parallel.
use_zero (bool): Whether Zero is used.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
@@ -257,27 +269,47 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
Returns:
OrderedDict: The complete optimizer state of given parameter.
"""
- dp_size = dist.get_world_size(dp_group)
+ global_dp_size = dist.get_world_size(global_dp_group)
tp_size = dist.get_world_size(tp_group)
+ moe_dp_size = dist.get_world_size(moe_dp_group) if moe_dp_group is not None else 1
current_shape = param.shape
state_ = state if inplace else copy.deepcopy(state)
-
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
+ v = v.cuda()
+
# First gather Zero shards.
- if use_zero and not is_moe_param:
- v = v.cuda()
- gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
- dist.all_gather(gather_tensor, v, group=dp_group)
- v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
+ if use_zero and is_moe_param and moe_dp_size > 1:
+ moe_dp_rank = dist.get_rank(moe_dp_group)
+ dst = get_global_rank(moe_dp_group, 0)
+ if moe_dp_rank == 0:
+ gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
+ dist.gather(v, gather_tensor, group=moe_dp_group, dst=dst)
+ v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
+ else:
+ dist.gather(v, group=moe_dp_group, dst=dst)
+
+ elif use_zero and not is_moe_param and global_dp_size > 1:
+ dp_rank = dist.get_rank(global_dp_group)
+ dst = get_global_rank(global_dp_group, 0)
+ if dp_rank == 0:
+ gather_tensor = [torch.zeros_like(v) for _ in range(global_dp_size)]
+ dist.gather(v, gather_tensor, group=global_dp_group, dst=dst)
+ v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
+ else:
+ dist.gather(v, group=global_dp_group, dst=dst)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
if partition_dim is not None:
- gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
- dist.all_gather(gather_tensor, v, group=tp_group)
- v = torch.cat(gather_tensor, dim=partition_dim)
-
+ tp_rank = dist.get_rank(tp_group)
+ dst = get_global_rank(tp_group, 0)
+ if tp_rank == 0:
+ gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
+ dist.gather(v, gather_tensor, group=tp_group, dst=dst)
+ v = torch.cat(gather_tensor, dim=partition_dim)
+ else:
+ dist.gather(v, group=tp_group, dst=dst)
state_[k] = v.detach().clone().to(device)
return state_
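The rewritten gathers above switch from `dist.all_gather` to `dist.gather` so that only the destination rank materializes the full optimizer state while the other ranks merely send their shard (the real code resolves `dst` with `get_global_rank` because the collective runs inside a sub-group). A runnable two-process sketch of that pattern on the default group, using the CPU gloo backend:

```python
# Sketch of gather-to-one-rank: only rank 0 allocates buffers and rebuilds the tensor.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29533"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    shard = torch.full((4,), float(rank))      # this rank's optimizer-state shard
    if rank == 0:
        buf = [torch.zeros_like(shard) for _ in range(world_size)]
        dist.gather(shard, buf, dst=0)         # receive every shard
        print("rank 0 rebuilt:", torch.cat(buf).tolist())
    else:
        dist.gather(shard, dst=0)              # send only, no receive buffers
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```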
@@ -286,8 +318,9 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
def _optimizer_sharder(
optimizer: OptimizerWrapper,
use_zero: bool,
- dp_group: ProcessGroup,
+ global_dp_group: ProcessGroup,
tp_group: ProcessGroup,
+ moe_dp_group: ProcessGroup,
size_per_shard: int = 1024,
only_moe_param: bool = False,
):
@@ -296,7 +329,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
master_to_working_map = optimizer.get_master_to_working_map()
-
for param, state in optimizer.optim.state.items():
if param is None:
continue
@@ -305,22 +337,23 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
working_param = master_to_working_map[id(param)]
else:
working_param = param
-
param_id = param_info["param2id"][id(working_param)]
original_shape = param_info["param2shape"][id(working_param)]
- state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
+ state_ = MoECheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
original_shape=original_shape,
- dp_group=dp_group,
+ global_dp_group=global_dp_group,
+ moe_dp_group=moe_dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
- is_moe_param=is_moe_tensor(working_param),
+ is_moe_param=is_moe_tensor(working_param), # TODO: Check correctness here
)
if only_moe_param and not is_moe_tensor(working_param):
continue
+
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
yield block, block_size
@@ -359,25 +392,28 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
Path(checkpoint).mkdir(parents=True, exist_ok=True)
- # Devices along the same dp_group share the same copies of states when zero is not used.
- # In this case only let the device with dp_rank == 0 save the model.
- if not self.use_zero and self.real_dp_rank != 0:
+ # If optim states are not sharded, other ranks don't need to participate in gather.
+ if not self.use_zero and self.moe_dp_rank != 0:
dist.barrier()
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
- state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
+ state_dict_shard = MoECheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
- dp_group=self.dp_group,
+ global_dp_group=self.global_dp_group,
tp_group=self.tp_group,
+ moe_dp_group=self.moe_dp_group,
size_per_shard=size_per_shard,
only_moe_param=self.ep_rank != 0,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
- control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
+        # e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2, saving via gather:
+        # rank 0 saves moe & non-moe params; rank 1 only saves moe params;
+        # ranks 2 & 3 save nothing
+ control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
@@ -596,7 +632,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
OrderedDict: The sharded optimizer state of the given parameter.
"""
state_ = state if inplace else copy.deepcopy(state)
-
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
@@ -606,24 +641,218 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
# Shard state along data parallel group when using Zero.
- if self.use_zero and not is_moe_param:
- padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
+ if self.use_zero and not is_moe_param and self.global_dp_size > 1:
+ padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
- slice_size = v.numel() // self.dp_size
- v = v.split(slice_size, dim=0)[self.dp_rank]
+ slice_size = v.numel() // self.global_dp_size
+ v = v.split(slice_size, dim=0)[self.global_dp_rank]
+
+ elif self.use_zero and is_moe_param and self.moe_dp_size > 1:
+ # LowLevelZeRO pads by global dp size for now.
+ # TODO: update both to use moe dp size
+ padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ slice_size = v.numel() // self.moe_dp_size
+ v = v.split(slice_size, dim=0)[self.moe_dp_rank]
state_[k] = v.detach().clone().to(device)
return state_
- def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
- raise NotImplementedError
+ """Migration from MoEHybridParallelCheckpointIO. These functions mostly deals with unsharded saving,
+ and can be savely deleted since large MoE models are often saved in shards.
+ """
+ # Copied from colossalai.moe
+ def pre_save_model(self, model: nn.Module) -> dict:
+ state_dict = model.state_dict()
+ for name, param in model.named_parameters():
+ if ".experts." in name and is_moe_tensor(param):
+ ep_group = param.ep_group
+ ep_rank = dist.get_rank(ep_group)
+ ep_size = dist.get_world_size(ep_group)
+ # TODO: check correctness here
+ # dp_rank = get_dp_rank(param)
+ dp_rank = dist.get_rank(self.global_dp_group)
+ if dp_rank == 0:
+ param = param.data.cuda()
+ if ep_rank == 0:
+ all_param = [torch.zeros_like(param) for _ in range(ep_size)]
+ else:
+ all_param = None
+ # gather param from every ep rank
+ # dist.all_gather(all_param, param, group=ep_group)
+ dist.gather(param, all_param, group=ep_group)
+ if ep_rank == 0:
+ all_param = torch.cat(all_param, dim=0)
+ state_dict[name] = all_param.cpu()
+
+ if self.pp_size > 1:
+ if self.dp_rank == 0:
+ out = [None for _ in range(self.pp_size)]
+ dist.gather_object(state_dict, out, group=self.pp_group)
+ if self.pp_rank == 0:
+ new_state_dict = {}
+ for o in out:
+ new_state_dict.update(o)
+ state_dict = new_state_dict
+ dist.barrier()
+ return state_dict
+
+ def save_unsharded_model(
+ self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool,
+ use_safetensors: bool,
+ ):
+ state_dict = self.pre_save_model(model)
+ if dist.get_rank() == 0:
+ torch.save(state_dict, checkpoint)
+ dist.barrier()
+
+ # Copied from colossalai.moe
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
- raise NotImplementedError
+ """
+ Save optimizer state dict to a file with given path.
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
+ checkpoint (str): Path to save optimizer state_dict.
+ gather_dtensor (bool): Whether to gather_dtensor, not used.
+ """
+ if self.coordinator.is_master():
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
+
+        # Optimizer states of the parameters kept by the local device (i.e. by its pipeline stage)
+ local_states = dict()
+
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+
+ # working param is needed for obtaining correct param_id
+ master_to_working_map = optimizer.get_master_to_working_map()
+ if master_to_working_map is not None and id(param) in master_to_working_map:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+
+ # gather complete state from tp shards & dp shards
+ param_id = optimizer.param_info["param2id"][id(working_param)]
+ local_states[param_id] = self.pre_save_optim(
+ state,
+ working_param,
+ inplace=False,
+ device=torch.device("cuda"),
+ )
+
+ if self.pp_size == 1:
+ # When pipeline is not used, let master rank directly save the collected state_dict.
+ state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
+ if self.coordinator.is_master():
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
+ else:
+ # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
+ states_list = [None for _ in range(self.pp_size)]
+ dist.barrier(self.pp_group)
+ # dist.all_gather_object(states_list, local_states, self.pp_group)
+            dist.gather_object(local_states, states_list, group=self.pp_group)
+
+ # Only the master rank do the saving.
+ if self.coordinator.is_master():
+ state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
+ for _states in states_list:
+ state_dict["state"].update(_states)
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
+ dist.barrier()
+
+ # Copied from colossalai.moe
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
- raise NotImplementedError
+ """
+ Load optimizer from a file with given path.
+
+ Args:
+ optimizer (OptimizerWrapper): The optimizer to be loaded.
+            checkpoint (str): Path to the checkpoint file.
+ """
+
+ def _get_param_id_from_optimizer_param(
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ ):
+ if master_to_working_map is not None and id(param) in master_to_working_map:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ if id(working_param) in optimizer.param_info["param2id"]:
+ return optimizer.param_info["param2id"][id(working_param)]
+ else:
+                return None
+
+ if self.coordinator.is_master():
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+ # Complete optimizer state_dict loaded from checkpoint, need to be processed later.
+ state_dict = load_state_dict(checkpoint)
+
+ # Load param_groups.
+ updated_groups = []
+ saved_groups = state_dict["param_groups"]
+ for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
+ updated_groups.append(new_pg)
+
+ # ep extra group
+ # if MOE_MANAGER.parallel == "EP":
+ if self.ep_size > 1:
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg["params"] = optimizer.optim.param_groups[-1][
+ "params"
+ ] # Only keep the parameters kept by current pipeline stage.
+ for param in new_pg["params"]:
+ param.data = param.data.to(torch.float32)
+ updated_groups.append(new_pg)
+ optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+ # Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
+ master_to_working_map = optimizer.get_master_to_working_map()
+ id_map = {}
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ if param_id is not None:
+ id_map[param_id] = param
+ load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
+
+ # Then shard the loaded optimizer states if using tp/zero.
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+ device = param.device
+ if master_to_working_map is not None and id(param) in master_to_working_map:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ sharded_state = self.pre_load_optim(
+ state,
+ param,
+ current_shape=working_param.shape,
+ original_shape=original_shape,
+ device=device,
+ inplace=True,
+ )
+ optimizer.optim.state[param] = sharded_state
+ sharded_optimizer_loading_epilogue(optimizer.optim)
+ dist.barrier()
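As a sanity check on the expert gathering in `pre_save_model` above: each EP rank holds `num_experts // ep_size` consecutive experts, so concatenating the per-rank shards along dim 0 on the receiving rank rebuilds the full expert tensor. A single-process sketch of that reshaping (the real code moves the shards with `dist.gather` over `ep_group`):

```python
# Single-process illustration of reassembling expert shards along dim 0.
import torch

ep_size, experts_per_rank, d = 2, 2, 4
full = torch.arange(ep_size * experts_per_rank * d, dtype=torch.float32).reshape(-1, d)
shards = list(full.chunk(ep_size, dim=0))   # what each EP rank would hold
rebuilt = torch.cat(shards, dim=0)          # what the receiving rank does after the gather
assert torch.equal(rebuilt, full)
```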
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 20870a3c2..36138f33e 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -242,6 +242,7 @@ def save_state_dict_shards(
shard_filenames = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
+        # Non-master ranks still iterate over the sharder so that they join any gather collectives, but skip saving.
if not is_master:
del shard
continue
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index f0cb78c5f..1319a4529 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -244,19 +244,25 @@ class ProcessGroupMesh:
return target_group
def get_group_along_axis(
- self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+ self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
) -> ProcessGroup:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args:
- axis (int): Axis along which the process groups are created.
+ axis (int or list of int): Axes along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None.
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
"""
- indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
+        if indices_at_axis is None:
+            if isinstance(axis, (list, tuple)):
+                indices_at_axis = [list(range(self._shape[ax])) for ax in axis]
+ else:
+ indices_at_axis = list(range(self._shape[axis]))
+
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
if ranks_in_group not in self._ranks_to_group:
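`get_group_along_axis` now also accepts a list or tuple of axes, which is how the MoE plugin builds the global DP group spanning the dp and ep axes. A minimal usage sketch, assuming the patched ColossalAI and an initialized (here single-process, gloo-backed) default group:

```python
import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh

if __name__ == "__main__":
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1)
    dp_axis, ep_axis = 1, 2                                      # mesh axes: (pp, dp, ep, tp)
    mesh = ProcessGroupMesh(1, 1, 1, 1)                          # toy 1-rank mesh
    ep_group = mesh.get_group_along_axis(ep_axis)                # single axis, as before
    dp_ep_group = mesh.get_group_along_axis((dp_axis, ep_axis))  # new: flatten two axes
    print(dist.get_world_size(dp_ep_group))                      # 1 in this toy mesh
    dist.destroy_process_group()
```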
diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py
index cc33c77f3..0623d19ef 100644
--- a/colossalai/moe/__init__.py
+++ b/colossalai/moe/__init__.py
@@ -1,20 +1,5 @@
-from .checkpoint import MoECheckpointIO
-from .experts import MLPExperts
-from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
-from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
-from .utils import NormalNoiseGenerator, UniformNoiseGenerator
__all__ = [
- "MLPExperts",
- "MoeRouter",
- "Top1Router",
- "Top2Router",
- "TopKRouter",
- "NormalNoiseGenerator",
- "UniformNoiseGenerator",
- "SparseMLP",
- "MoECheckpointIO",
"MOE_MANAGER",
- "apply_load_balance",
]
diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py
deleted file mode 100644
index 59a0ec3f0..000000000
--- a/colossalai/moe/checkpoint.py
+++ /dev/null
@@ -1,792 +0,0 @@
-import copy
-import logging
-import os
-from pathlib import Path
-from shutil import rmtree
-from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from torch.distributed import ProcessGroup
-
-from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
-from colossalai.checkpoint_io.utils import (
- StateDictSharder,
- gather_distributed_param,
- get_model_base_filenames,
- get_optimizer_base_filenames,
- is_safetensors_available,
- load_shard_state_dict,
- load_state_dict,
- load_state_dict_into_model,
- load_states_into_optimizer,
- save_config_file,
- save_param_groups,
- save_state_dict,
- save_state_dict_shards,
- sharded_optimizer_loading_epilogue,
-)
-from colossalai.interface import OptimizerWrapper
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.tensor.moe_tensor.api import (
- get_dp_group,
- get_dp_rank,
- get_dp_size,
- get_ep_group,
- get_ep_rank,
- get_ep_size,
- is_moe_tensor,
-)
-
-
-class MoECheckpointIO(HybridParallelCheckpointIO):
- def __init__(
- self,
- dp_group: ProcessGroup,
- pp_group: ProcessGroup,
- tp_group: ProcessGroup,
- zero_stage: int,
- ) -> None:
- assert zero_stage in [
- 0,
- 1,
- 2,
- ], f"zero_stage should be 0 or 1 or 2, got {zero_stage}"
- super().__init__(dp_group, pp_group, tp_group, zero_stage)
- self.parallel = MOE_MANAGER.parallel
-
- def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
- """
- Preprocess state_dict before loading and slice the state_dict of MOE tensors.
- """
- for name, param in state_dict.items():
- if ".experts." in name:
- if name in dict(model.named_parameters()):
- model_param = dict(model.named_parameters())[name]
- if is_moe_tensor(model_param):
- ep_rank = get_ep_rank(model_param)
- ep_size = get_ep_size(model_param)
- expert_num = param.shape[0] // ep_size
- assert param.shape[0] % ep_size == 0
- param = param[ep_rank * expert_num : (ep_rank + 1) * expert_num]
- state_dict[name] = param
- dist.barrier()
- return state_dict
-
- def _model_sharder(
- self,
- state_dict: nn.Module,
- prefix: str = "",
- keep_vars: bool = False,
- size_per_shard: int = 1024,
- ) -> Iterator[Tuple[OrderedDict, int]]:
- # An internel method that breaks state_dict of model into shards within limited size.
- state_dict_sharder = StateDictSharder(size_per_shard)
-
- for name, param in state_dict.items():
- if param is None:
- continue
- # Gather tensor pieces when using tensor parallel.
- param_ = gather_distributed_param(param, keep_vars=False)
- block, block_size = state_dict_sharder.append_param(prefix + name, param_)
- if block is not None:
- yield block, block_size
-
- # Return the last block in sharder.
- yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
-
- def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None:
- state_dict = torch.load(checkpoint)
- state_dict = self.pre_load_model(model, state_dict)
- model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False)
-
- def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
- """
- Load sharded model with the given path to index file of checkpoint folder.
-
- Args:
- model (nn.Module): The model to be loaded.
- checkpoint_index_file (str): Path to the index file of checkpointing folder.
- strict (bool, optional): For name matching during loading state_dict. Defaults to False.
- This argument should be manually set to False since params on same device might be stored in different files.
- """
-
- # Check whether the checkpoint uses safetensors.
- use_safetensors = False
- if "safetensors" in checkpoint_index_file.name:
- use_safetensors = True
-
- if use_safetensors and not is_safetensors_available():
- raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
-
- # Read checkpoint index file.
- ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
- ckpt_root_path = ckpt_index_file.root_path
- weight_map = ckpt_index_file.weight_map
- strict = False
-
- # Load params & buffers to model.
- # Keep a record of loaded files so that file will not be repeatedly loaded.
- loaded_file = set()
-
- def _load(name: str):
- if name not in weight_map:
- raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
- filename = weight_map[name]
-
- # If this param/buffer has been loaded before, directly return.
- if filename in loaded_file:
- return
-
- file_path = os.path.join(ckpt_root_path, filename)
- state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
- state_dict = self.pre_load_model(model, state_dict)
- missing_keys = []
-
- load_state_dict_into_model(
- model,
- state_dict,
- missing_keys=missing_keys,
- strict=strict,
- load_sub_module=True,
- )
- loaded_file.add(filename)
-
- # Load parameters.
- for name, _ in model.named_parameters():
- _load(name)
-
- if self.verbose:
- logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
-
- def pre_save_model(self, model: nn.Module) -> dict:
- state_dict = model.state_dict()
- for name, param in model.named_parameters():
- if ".experts." in name and is_moe_tensor(param):
- ep_group = get_ep_group(param)
- ep_rank = get_ep_rank(param)
- ep_size = get_ep_size(param)
- dp_rank = get_dp_rank(param)
- if dp_rank == 0:
- param = param.data.cuda()
- all_param = [torch.zeros_like(param) for _ in range(ep_size)]
- # gather param from every ep rank
- dist.all_gather(all_param, param, group=ep_group)
- if ep_rank == 0:
- all_param = torch.cat(all_param, dim=0)
- state_dict[name] = all_param.cpu()
- if self.pp_size > 1:
- if self.dp_rank == 0:
- out = [None for _ in range(self.pp_size)]
- dist.all_gather_object(out, state_dict, group=self.pp_group)
- if self.pp_rank == 0:
- new_state_dict = {}
- for o in out:
- new_state_dict.update(o)
- state_dict = new_state_dict
- dist.barrier()
- return state_dict
-
- def save_unsharded_model(
- self,
- model: nn.Module,
- checkpoint: str,
- gather_dtensor: bool,
- use_safetensors: bool,
- ):
- state_dict = self.pre_save_model(model)
- if dist.get_rank() == 0:
- torch.save(state_dict, checkpoint)
- dist.barrier()
-
- def save_sharded_model(
- self,
- model: nn.Module,
- checkpoint: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024,
- use_safetensors: bool = False,
- ) -> None:
- """
- Save sharded model checkpoint under the given checkpointing path.
- The following files will be created under the path:
- - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- - Multiple files that store state tensors of models.
- The filenames are in the form of "pytorch_model.-000XX.bin"
-
- Args:
- model (nn.Module): Model on local device to be saved.
- checkpoint (str): Checkpointing path which should be a directory path.
- gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
- prefix (str, optional): Perfix of file to save. Defaults to None.
- size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
- use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
- """
- torch.cuda.empty_cache()
- if os.path.isfile(checkpoint):
- logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
- return
-
- Path(checkpoint).mkdir(parents=True, exist_ok=True)
-
- # Then collect the sharded parameters & buffers along tp_group.
- # Only devices with tp_rank == 0 are responsible for model saving.
- state_dict = self.pre_save_model(model)
-
- if dist.get_rank() == 0:
- state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard)
-
- # Devices along the same dp_group share the same copies of model.
- # So only let the device with dp_rank == 0 save the model.
- if self.dp_rank != 0:
- return
-
- weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
- index_file = CheckpointIndexFile(checkpoint)
- control_saving = self.tp_rank == 0
-
- total_size = save_state_dict_shards(
- sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=weights_name,
- is_master=control_saving,
- use_safetensors=use_safetensors,
- )
- if control_saving:
- index_file.append_meta_data("total_size", total_size)
- index_file.write_index_file(save_index_file)
- save_config_file(model, checkpoint)
- if self.verbose:
- logging.info(
- f"The model is split into checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}."
- )
- dist.barrier()
- torch.cuda.empty_cache()
-
- # ========================================================
- # Abstract methods for optimizer loading/saving implementation
- # ========================================================
-
- def pre_load_optim(
- self,
- state: OrderedDict,
- working_param,
- current_shape: torch.Size,
- original_shape: torch.Size,
- device: torch.device,
- inplace: bool,
- ) -> OrderedDict:
- """
- With complete optimizer states of a specific parameter loaded from checkpoint,
- slice out the sharded optimizer states kept by current device.
-
- Args:
- state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
- current_shape (torch.Size): The size of parameter after sharding.
- original_shape (torch.Size): The size of parameter before sharding.
- device (torch.device): The destination device of loaded optimizer states.
- inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
-
- Returns:
- OrderedDict: The sharded optimizer state of the given parameter.
- """
- state_ = state if inplace else copy.deepcopy(state)
- is_moe_tensor_flag = is_moe_tensor(working_param)
- if is_moe_tensor_flag:
- ep_rank = get_ep_rank(working_param)
- ep_size = get_ep_size(working_param)
-
- for k, v in state_.items():
- if isinstance(v, torch.Tensor) and k != "step":
- if is_moe_tensor_flag:
- with torch.no_grad():
- expert_num = v.shape[0] // ep_size
- assert v.shape[0] % ep_size == 0
- v = v[ep_rank * expert_num : (ep_rank + 1) * expert_num]
- else:
- # Shard state along data parallel group when using Zero.
- padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
- with torch.no_grad():
- v = v.flatten()
- if padding_size > 0:
- v = torch.nn.functional.pad(v, [0, padding_size])
- slice_size = v.numel() // self.dp_size
- v = v.split(slice_size, dim=0)[self.dp_rank]
-
- state_[k] = v.detach().clone().to(device)
-
- return state_
-
- def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
- """
- Load sharded optimizer with the given path to index file of checkpoint folder.
-
- Args:
- optimizer (OptimizerWrapper): The optimizer to be loaded.
- checkpoint_index_file (str): Path to the index file of checkpointing folder.
- prefix (str): Not used.
- """
- assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
-
- def _get_param_id_from_optimizer_param(
- param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
- ):
- if master_to_working_map is not None and id(param) in master_to_working_map:
- working_param = master_to_working_map[id(param)]
- elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
- working_param = optimizer.moe_master_to_working_map[id(param)]
- else:
- working_param = param
- return optimizer.param_info["param2id"][id(working_param)]
-
- # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
- # When Zero is used, the mapped parameter objects should be fp32 master parameters.
- # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
- id_map = {}
- master_to_working_map = optimizer.get_master_to_working_map()
- for pg in optimizer.optim.param_groups:
- for param in pg["params"]:
- param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
- id_map[param_id] = param
-
- # Read checkpoint index file.
- ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
- ckpt_root_path = ckpt_index_file.root_path
- weight_map = ckpt_index_file.weight_map
- weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
-
- # Load param_groups
- param_group_path = ckpt_index_file.get_param_group_filename()
- if param_group_path is None:
- raise RuntimeError(
- f"Invalid index file path {checkpoint_index_file} for an optimizer. \
- Lacking param group file under current directory."
- )
- saved_groups = torch.load(param_group_path)
-
- updated_groups = []
- for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
- # obtain updated param group
- new_pg = copy.deepcopy(saved_pg)
- new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
- updated_groups.append(new_pg)
- # ep param group
- if len(optimizer.optim.param_groups) > len(saved_groups):
- new_pg = copy.deepcopy(saved_pg)
- new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
- updated_groups.append(new_pg)
- optimizer.optim.__dict__.update({"param_groups": updated_groups})
-
- # Load saved states to optimizer.
- # Keep a record of loaded files so that file will not be repeatedly loaded.
- loaded_file = set()
- for pg in optimizer.optim.param_groups:
- for param in pg["params"]:
- if param is None:
- continue
- param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
- if param_id not in weight_map:
- continue
- filename = weight_map[param_id]
-
- # If this param's states has been loaded before, directly return.
- if filename in loaded_file:
- continue
-
- file_path = os.path.join(ckpt_root_path, filename)
- state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
-
- # Then shard the loaded optimizer states if using tp/zero.
- for pid, state in list(state_dict.items()):
- if pid in id_map:
- param = id_map[pid]
- if master_to_working_map is not None and id(param) in master_to_working_map:
- working_param = master_to_working_map[id(param)]
- elif (
- hasattr(optimizer, "moe_master_to_working_map")
- and id(param) in optimizer.moe_master_to_working_map
- ):
- working_param = optimizer.moe_master_to_working_map[id(param)]
- else:
- working_param = param
- original_shape = optimizer.param_info["param2shape"][id(working_param)]
- sharded_state = self.pre_load_optim(
- state,
- working_param,
- current_shape=working_param.shape,
- original_shape=original_shape,
- device="cpu",
- inplace=True,
- )
- state_dict[pid] = sharded_state
-
- load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
- loaded_file.add(filename)
-
- sharded_optimizer_loading_epilogue(optimizer.optim)
- if self.verbose and self.coordinator.is_master():
- logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
- dist.barrier()
-
- def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
- """
- Load optimizer from a file with given path.
-
- Args:
- optimizer (OptimizerWrapper): The optimizer to be loaded.
- checkpoint_index_file (str): Path to the checkpoint file.
- """
-
- def _get_param_id_from_optimizer_param(
- param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
- ):
- if master_to_working_map is not None and id(param) in master_to_working_map:
- working_param = master_to_working_map[id(param)]
- else:
- working_param = param
- if id(working_param) in optimizer.param_info["param2id"]:
- return optimizer.param_info["param2id"][id(working_param)]
- else:
- None
-
- if self.coordinator.is_master():
- logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
-
- assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
-
- # Complete optimizer state_dict loaded from checkpoint, need to be processed later.
- state_dict = load_state_dict(checkpoint)
-
- # Load param_groups.
- updated_groups = []
- saved_groups = state_dict["param_groups"]
- for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
- new_pg = copy.deepcopy(saved_pg)
- new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
- updated_groups.append(new_pg)
- # ep extra group
- if MOE_MANAGER.parallel == "EP":
- new_pg = copy.deepcopy(saved_pg)
- new_pg["params"] = optimizer.optim.param_groups[-1][
- "params"
- ] # Only keep the parameters kept by current pipeline stage.
- for param in new_pg["params"]:
- param.data = param.data.to(torch.float32)
- updated_groups.append(new_pg)
- optimizer.optim.__dict__.update({"param_groups": updated_groups})
-
- # Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
- master_to_working_map = optimizer.get_master_to_working_map()
- id_map = {}
- for pg in optimizer.optim.param_groups:
- for param in pg["params"]:
- param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
- if param_id is not None:
- id_map[param_id] = param
- load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
-
- # Then shard the loaded optimizer states if using tp/zero.
- for param, state in optimizer.optim.state.items():
- if param is None:
- continue
- device = param.device
- if master_to_working_map is not None and id(param) in master_to_working_map:
- working_param = master_to_working_map[id(param)]
- else:
- working_param = param
- original_shape = optimizer.param_info["param2shape"][id(working_param)]
- sharded_state = self.pre_load_optim(
- state,
- param,
- current_shape=working_param.shape,
- original_shape=original_shape,
- device=device,
- inplace=True,
- )
- optimizer.optim.state[param] = sharded_state
- sharded_optimizer_loading_epilogue(optimizer.optim)
- dist.barrier()
-
- def pre_save_optim(
- self,
- state: OrderedDict,
- param: torch.Tensor,
- inplace: bool,
- device: torch.device = torch.device("cpu"),
- ) -> OrderedDict:
- """
-        Given a parameter and its optimizer states, gather the complete optimizer state for saving.
-
-        Args:
-            state (OrderedDict): Optimizer states of the given parameter; they may be sharded across the tp/dp group when TP/Zero is used.
-            param (torch.Tensor): The given parameter. It should be the working_param when Zero is used.
-            inplace (bool): If set to True, the values of the argument 'state' are updated in place. Otherwise a copy of state is made.
-            device (torch.device): The destination device of the gathered optimizer states. Defaults to torch.device('cpu').
-
-        Returns:
-            OrderedDict: The complete optimizer state of the given parameter.
- """
- if is_moe_tensor(param):
- moe_dp_group = get_dp_group(param)
- moe_dp_size = get_dp_size(param)
- moe_ep_group = get_ep_group(param)
- moe_ep_size = get_ep_size(param)
- state_ = state if inplace else copy.deepcopy(state)
-
- for k, v in state_.items():
- if isinstance(v, torch.Tensor) and k != "step":
- # moe param
- if is_moe_tensor(param):
- # dp gather
- v = v.cuda()
- gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
- dist.all_gather(gather_tensor, v, group=moe_dp_group)
- v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
- # ep gather
- gather_tensor = [torch.zeros_like(v) for _ in range(moe_ep_size)]
- dist.all_gather(gather_tensor, v, group=moe_ep_group)
- v = torch.cat(gather_tensor, dim=0)
- else:
- # global dp
- v = v.cuda()
- gather_tensor = [torch.zeros_like(v) for _ in range(dist.get_world_size(self.dp_group))]
- dist.all_gather(gather_tensor, v, group=self.dp_group)
- v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
-
- state_[k] = v.detach().clone().to(device)
-
- return state_
-
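The dp/zero gather above follows a fixed recipe: collect the flat shards from every rank, concatenate them, drop the padding, and reshape to the parameter's shape. A single-process stand-in for that reconstruction (the shard list plays the role of the dist.all_gather output):

```
import torch


def unshard_state(shards, param: torch.Tensor) -> torch.Tensor:
    """Rebuild a full optimizer-state tensor from equal-sized flat shards (the tail may be padding)."""
    flat = torch.cat([s.flatten() for s in shards])
    return flat[: param.numel()].reshape_as(param)


param = torch.empty(3, 4)                               # 12 elements
shards = [torch.arange(7.0), torch.arange(7.0, 14.0)]   # 14 gathered elements, 2 of them padding
full = unshard_state(shards, param)
assert full.shape == param.shape
```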
- def _optimizer_sharder(
- self,
- optimizer: OptimizerWrapper,
- size_per_shard: int = 1024,
- ):
-        # An internal method that breaks the optimizer's state_dict into shards of limited size.
-
- state_dict_sharder = StateDictSharder(size_per_shard)
- param_info = optimizer.param_info
- master_to_working_map = optimizer.get_master_to_working_map()
-
- for param, state in optimizer.optim.state.items():
- if param is None:
- continue
-
- if master_to_working_map is not None and id(param) in master_to_working_map:
- working_param = master_to_working_map[id(param)]
- elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
- working_param = optimizer.moe_master_to_working_map[id(param)]
- else:
- working_param = param
-
- param_id = param_info["param2id"][id(working_param)]
- state_ = self.pre_save_optim(
- state,
- working_param,
- inplace=False,
- device=torch.device("cuda"),
- )
-
- block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
- if block is not None:
- yield block, block_size
-
- # Return the last block in sharder.
- yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
-
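StateDictSharder handles the size bookkeeping for the generator above. The underlying pattern is simply "accumulate entries until the running byte count would exceed the limit, then emit a block"; a minimal stand-in (not the actual StateDictSharder API) looks like this:

```
import torch
from typing import Dict, Iterator, Tuple


def shard_states(
    states: Dict[int, Dict[str, torch.Tensor]], max_bytes: int
) -> Iterator[Tuple[Dict[int, Dict[str, torch.Tensor]], int]]:
    """Yield (block, block_size) pairs whose total tensor size stays under max_bytes."""
    block, size = {}, 0
    for param_id, state in states.items():
        state_size = sum(t.numel() * t.element_size() for t in state.values() if isinstance(t, torch.Tensor))
        if block and size + state_size > max_bytes:
            yield block, size
            block, size = {}, 0
        block[param_id] = state
        size += state_size
    if block:  # the last, possibly partial block
        yield block, size
```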
- def save_sharded_optimizer(
- self,
- optimizer: OptimizerWrapper,
- checkpoint: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024,
- ):
- """
- Save sharded optimizer checkpoint under the given checkpointing path.
- The following files will be created under the path:
- - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- - A group file (pytorch_optim_group.bin) recording information of param_groups
- - Multiple files that store state tensors of optimizers.
-          If pipeline parallelism is used, the filenames are in the form of "pytorch_optim-stage-000XX-shard-000XX.bin".
-          If pipeline parallelism is not used, "pytorch_optim-000XX.bin".
-
- Args:
- optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
- checkpoint (str): Path to save optimizer state_dict
- gather_dtensor (bool): Whether to gather_dtensor, not used
-            prefix (str): Prefix of the files to save.
-            size_per_shard (int): Max file size of each shard that stores state tensors.
- """
- torch.cuda.empty_cache()
- assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
- if os.path.isfile(checkpoint):
- logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
- return
-
- Path(checkpoint).mkdir(parents=True, exist_ok=True)
-
- # Devices along the same dp_group share the same copies of states when zero is not used.
-        # In this case, only let the device with dp_rank == 0 save the states.
- if not self.use_zero and self.dp_rank != 0:
- return
-
- # Then collect the sharded states along dp_group(if using zero)/tp_group.
- # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
- state_dict_shard = self._optimizer_sharder(
- optimizer,
- size_per_shard=size_per_shard,
- )
- states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
- index_file = CheckpointIndexFile(checkpoint)
- control_saving = self.dp_rank == 0 and self.tp_rank == 0
- if self.pp_size == 1:
- # When pipeline is not used, save the optimizer shards as in general checkpointIO
- total_size = save_state_dict_shards(
- sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=states_name,
- is_master=control_saving,
- )
-
- if control_saving:
- # Store param groups.
- index_file.append_meta_data("param_groups", param_group_file)
- group_file_path = os.path.join(checkpoint, param_group_file)
- save_param_groups(optimizer.param_info, group_file_path)
- # Store index file.
- index_file.append_meta_data("total_size", total_size)
- index_file.write_index_file(save_index_file)
- if self.verbose and self.coordinator.is_master():
- logging.info(
-                    f"The optimizer is split into checkpoint shards. "
-                    f"You can find where each parameter has been saved in the "
-                    f"index located at {save_index_file}."
- )
-
- else:
- # When pipeline is used, each stage produces its own shard files and index files.
- # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
- # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
-
- final_index_file_path = copy.deepcopy(save_index_file)
- tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
- Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
-
- # Manage filenames of sharded weights and index file for each pipeline stage.
- states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
- save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
- save_index_file = os.path.join("tmp_index_files", save_index_file)
-
- total_size = save_state_dict_shards(
- sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=states_name,
- is_master=control_saving,
- use_pp_format=True,
- )
-
- if control_saving:
- assert (
- self.dp_rank == 0 and self.tp_rank == 0
- ), "The saving process should have both dp_rank and tp_rank as 0."
- index_file.append_meta_data("total_size", total_size)
- index_file.write_index_file(save_index_file)
- else:
- return
-
- dist.barrier(self.pp_group)
-
-            # The global master rank integrates the index files and cleans up the temporary folder.
- if self.pp_rank == 0:
- final_index_file = CheckpointIndexFile(checkpoint)
- final_index_file.append_meta_data("total_size", 0)
-
- for filename in os.listdir(tmp_index_file_folder):
- stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
- final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
- for param_id, state_filename in stage_index_file.weight_map.items():
- final_index_file.append_weight_map(param_id, state_filename)
-
- # Store param groups.
- final_index_file.append_meta_data("param_groups", param_group_file)
- group_file_path = os.path.join(checkpoint, param_group_file)
- save_param_groups(optimizer.param_info, group_file_path)
-
- final_index_file.write_index_file(final_index_file_path)
- rmtree(tmp_index_file_folder)
-
- if self.verbose and self.coordinator.is_master():
- logging.info(
-                        f"The optimizer is split into checkpoint shards. "
-                        f"You can find where each parameter has been saved in the "
-                        f"index located at {final_index_file_path}."
- )
- torch.cuda.empty_cache()
-
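Once the save finishes, the checkpoint directory contains the shard files, the param-group file and one index JSON that maps every parameter id to the shard holding its states. The exact layout is owned by CheckpointIndexFile; the following dict is purely illustrative of its shape (keys and filenames are examples, not guaranteed):

```
# Illustrative only -- produced in reality by CheckpointIndexFile.write_index_file().
example_index = {
    "metadata": {"total_size": 123456789, "param_groups": "pytorch_optim_group.bin"},
    "weight_map": {
        "0": "pytorch_optim-00001.bin",  # param_id -> shard file that stores its optimizer states
        "1": "pytorch_optim-00001.bin",
        "2": "pytorch_optim-00002.bin",
    },
}
```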
- def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
- """
- Save optimizer state dict to a file with given path.
-
- Args:
- optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
- checkpoint (str): Path to save optimizer state_dict.
- gather_dtensor (bool): Whether to gather_dtensor, not used.
- """
- if self.coordinator.is_master():
- logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
-
- assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
-
-        # Optimizer states of the parameters held by the local device (i.e. its pipeline stage).
- local_states = dict()
-
- for param, state in optimizer.optim.state.items():
- if param is None:
- continue
-
- # working param is needed for obtaining correct param_id
- master_to_working_map = optimizer.get_master_to_working_map()
- if master_to_working_map is not None and id(param) in master_to_working_map:
- working_param = master_to_working_map[id(param)]
- else:
- working_param = param
-
- # gather complete state from tp shards & dp shards
- param_id = optimizer.param_info["param2id"][id(working_param)]
- local_states[param_id] = self.pre_save_optim(
- state,
- working_param,
- inplace=False,
- device=torch.device("cuda"),
- )
-
- if self.pp_size == 1:
- # When pipeline is not used, let master rank directly save the collected state_dict.
- state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
- if self.coordinator.is_master():
- save_state_dict(state_dict, checkpoint, use_safetensors=False)
- else:
- # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
- states_list = [None for _ in range(self.pp_size)]
- dist.barrier(self.pp_group)
- dist.all_gather_object(states_list, local_states, self.pp_group)
-
-            # Only the master rank does the saving.
- if self.coordinator.is_master():
- state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
- for _states in states_list:
- state_dict["state"].update(_states)
- save_state_dict(state_dict, checkpoint, use_safetensors=False)
- dist.barrier()
diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py
index 85c12d73f..3dc6c02c7 100644
--- a/colossalai/moe/load_balance.py
+++ b/colossalai/moe/load_balance.py
@@ -7,8 +7,8 @@ from torch import Tensor, nn
from torch.distributed import ProcessGroup
from colossalai.cluster import ProcessGroupMesh
-from colossalai.moe.experts import MLPExperts
from colossalai.moe.manager import MOE_MANAGER
+from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.zero.low_level import LowLevelZeroOptimizer
@@ -292,7 +292,7 @@ class LoadBalancer:
exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
else:
- master_weight_ptr = optim._param_store.working_to_master_param[id(weight)]
+ master_weight_ptr = optim.working_to_master_param[id(weight)]
working_weight_ptr = weight
exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]
@@ -344,7 +344,7 @@ class LoadBalancer:
# gate optim should be obtained first
gate_shape = self.gate.shape
# get master weight and optim
- master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)]
+ master_gate_weight = optim.working_to_master_param[id(self.gate)]
gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
# gather
diff --git a/colossalai/moe/loss.py b/colossalai/moe/loss.py
deleted file mode 100644
index 75624510b..000000000
--- a/colossalai/moe/loss.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import torch.nn as nn
-from torch.nn.modules.loss import _Loss
-
-from colossalai.moe.manager import MOE_MANAGER
-
-
-class MoeCrossEntropyLoss(_Loss):
-    r"""torch.nn.CrossEntropyLoss combined with the MoE auxiliary loss.
-
- Args:
- input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
- target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-        aux_weight (float, optional): Weight of the auxiliary loss in the total loss. Defaults to 0.01.
-
- The ``args`` and ``kwargs`` should include parameters below:
- ::
-
- weight (Tensor, optional)
- size_average (bool, optional)
- ignore_index (int, optional)
- reduce (bool, optional)
- reduction (str, optional)
- label_smoothing (float, optional)
-
- More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
- `Cross_entropy `_.
- """
-
- def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
- super().__init__()
- self.loss = nn.CrossEntropyLoss(*args, **kwargs)
- self.aux_weight = aux_weight
-
- def forward(self, *args):
- """
- The ``args`` should at least include parameters below:
- ::
-
- input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
- target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-
- More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
- `Cross_entropy `_.
- """
- main_loss = self.loss(*args)
- aux_loss = MOE_MANAGER.get_loss()
- return main_loss + self.aux_weight * aux_loss
-
-
-class MoeLoss(_Loss):
-    """A wrapper that adds the MoE auxiliary loss to any loss module.
-
- Args:
- aux_weight (float): Weight of auxiliary loss in total loss.
- loss_fn (``Callable``): Loss function.
- args (list): Args in loss function.
- kwargs (dict): Kwargs in loss function
- """
-
- def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
- super().__init__()
- self.loss_fn = loss_fn(*args, **kwargs)
- self.aux_weight = aux_weight
-
- def forward(self, *args, **kwargs):
- """
- The ``args`` and ``kwargs`` should at least include parameters below:
- ::
-
- input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
- target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-
- Note:
- The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
- """
- main_loss = self.loss_fn(*args, **kwargs)
- aux_loss = MOE_MANAGER.get_loss()
- return main_loss + self.aux_weight * aux_loss
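Both removed wrappers implement the same combination: compute the task loss, then add the router's auxiliary loss scaled by aux_weight. A self-contained sketch of that pattern, decoupled from MOE_MANAGER (the aux value here is a placeholder for whatever the router reported):

```
import torch
import torch.nn as nn


def moe_loss(logits: torch.Tensor, target: torch.Tensor, aux_loss: torch.Tensor, aux_weight: float = 0.01):
    """Task loss plus the weighted auxiliary (load-balancing) loss."""
    main_loss = nn.functional.cross_entropy(logits, target)
    return main_loss + aux_weight * aux_loss


logits = torch.randn(8, 10)
target = torch.randint(0, 10, (8,))
aux = torch.tensor(0.3)  # would normally come from the router via MOE_MANAGER.get_loss()
print(moe_loss(logits, target, aux))
```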
diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py
deleted file mode 100644
index e40674c9b..000000000
--- a/colossalai/moe/routers.py
+++ /dev/null
@@ -1,466 +0,0 @@
-import math
-from abc import ABC
-from typing import Callable, Optional, Tuple
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.distributed import ProcessGroup
-
-from colossalai.accelerator import get_accelerator
-from colossalai.moe._operation import moe_cumsum
-from colossalai.moe.manager import MOE_MANAGER
-
-
-class MoeRouter(nn.Module, ABC):
- """Base class for all MoE routers.
- Args:
- k_value (int): The value of top_k.
- capacity_factor_train (float): Capacity factor in routing of training.
- capacity_factor_eval (float): Capacity factor in routing of evaluation.
-        min_capacity (int): The minimum capacity of each expert.
-        noisy_func (:class:`typing.Callable`, optional): Noisy function applied to the logits.
-        drop_tks (bool, optional): Whether to drop tokens during evaluation.
- """
-
- def __init__(
- self,
- k_value: int,
- capacity_factor_train: float,
- capacity_factor_eval: float,
- min_capacity: int,
- noisy_func: Optional[Callable] = None,
- drop_tks: bool = True,
- use_kernel: bool = False,
- ):
- super().__init__()
- self.k_value = k_value
- self.capacity_factor_train = capacity_factor_train
- self.capacity_factor_eval = capacity_factor_eval
- self.min_capacity = min_capacity
- self.noisy_func = noisy_func
- self.drop_tks = drop_tks
- self._aux_loss = None
- self._z_loss = None
- self.use_kernel = use_kernel
-
- def get_capacity(self, num_tokens, num_experts, ep_group=None):
- if ep_group is not None:
- num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
- dist.all_reduce(num_tokens_tensor, group=ep_group)
- num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
- capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
- capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
- capacity += capacity % 2
- capacity = max(capacity, self.min_capacity)
- assert capacity > 0
- return int(capacity)
-
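The capacity rule above is capacity = floor(k * capacity_factor * num_tokens / num_experts), bumped to an even number and clamped from below by min_capacity. A quick worked example for top-1 routing with factor 1.25, 1024 tokens and 8 experts:

```
import math

k, factor, num_tokens, num_experts, min_capacity = 1, 1.25, 1024, 8, 4
capacity = math.floor(k * factor * num_tokens / num_experts)  # 160
capacity += capacity % 2                                       # already even, stays 160
capacity = max(capacity, min_capacity)                         # 160 tokens per expert
```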
- def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
- """Computes auxiliary load balancing loss as in Switch Transformer.
-
- See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
- implements the loss function presented in equations (4) - (6). It aims to
- penalize those cases where the routing between experts is unbalanced.
-
- Args:
- router_probs: Probability assigned to each expert per token. Shape:
- [num_groups, tokens_per_group, num_experts].
- expert_indices: [num_groups, tokens_per_group, num_selected_experts]
- indices identifying the top num_selected_experts for a given token.
- """
- assert self._aux_loss is None
- if router_probs.dim() == expert_indices.dim() == 2:
- router_probs = router_probs.unsqueeze(0)
- expert_indices = expert_indices.unsqueeze(0)
- assert (
- router_probs.dim() == expert_indices.dim() == 3
-        ), "router_probs and expert_indices must be 3D tensors"
-
- # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
- expert_mask = F.one_hot(expert_indices, num_experts)
- # For a given token, determine if it was routed to a given expert.
- # Shape: [num_groups, tokens_per_group, num_experts]
- expert_mask = expert_mask.max(dim=-2)[0]
-
- tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
- router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
- aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
- self._aux_loss = aux_loss
-
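A tiny single-group numerical check of the quantity computed above (top-1 routing over three experts). It mirrors the N**2 * mean(f_e * P_e) form, which for one group equals the N * sum(f_e * P_e) of Switch Transformer eq. (4), with f_e the fraction of tokens sent to expert e and P_e its mean router probability:

```
import torch
import torch.nn.functional as F

# one group of four tokens, three experts
router_probs = torch.tensor([[[0.7, 0.2, 0.1],
                              [0.1, 0.8, 0.1],
                              [0.6, 0.3, 0.1],
                              [0.2, 0.2, 0.6]]])
expert_indices = router_probs.argmax(dim=-1, keepdim=True)         # (1, 4, 1) top-1 choices
expert_mask = F.one_hot(expert_indices, 3).max(dim=-2)[0].float()  # (1, 4, 3) chosen expert per token
f = expert_mask.mean(dim=-2)        # fraction of tokens per expert, shape (1, 3)
p = router_probs.mean(dim=-2)       # mean router probability per expert, shape (1, 3)
aux = 3 ** 2 * torch.mean(f * p)    # equals 3 * (f * p).sum() for a single group
```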
- def set_z_loss(self, router_logits: torch.Tensor):
- """Compute router z-loss.
-
- The router z-loss was introduced in Designing Effective Sparse Expert Models
- (https://arxiv.org/abs/2202.08906). It encourages router logits to remain
- small in an effort to improve stability.
-
- Args:
- router_logits: [num_groups, tokens_per_group, num_experts] router logits.
- """
- assert self._z_loss is None
- if router_logits.dim() == 2:
- router_logits = router_logits.unsqueeze(0)
- assert router_logits.dim() == 3, "router_logits must be 3D tensor"
- num_groups, tokens_per_group, _ = router_logits.shape
- log_z = torch.logsumexp(router_logits, dim=-1)
- z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
- self._z_loss = z_loss
-
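For reference, the z-loss computed above is just the mean over all tokens of the squared log-partition of the router logits:

```
import torch

router_logits = torch.randn(2, 5, 4)             # (num_groups, tokens_per_group, num_experts)
log_z = torch.logsumexp(router_logits, dim=-1)   # per-token log partition, shape (2, 5)
z_loss = torch.sum(log_z ** 2) / (2 * 5)         # averaged over num_groups * tokens_per_group
```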
- def pop_router_loss(self) -> torch.Tensor:
- assert self._aux_loss is not None
- MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
- self._aux_loss = None
- self._z_loss = None
-
-
-class Top1Router(MoeRouter):
- """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
- and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
-    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More details
-    can be found in Google's Switch Transformer paper.
- Args:
- capacity_factor_train (float, optional): Capacity factor in routing of training.
- capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
-        min_capacity (int, optional): The minimum capacity of each expert.
-        select_policy (str, optional): The token selection policy, either "first" or "random".
-        noisy_func (:class:`typing.Callable`, optional): Noisy function applied to the logits.
-        drop_tks (bool, optional): Whether to drop tokens during evaluation.
- """
-
- def __init__(
- self,
- capacity_factor_train: float = 1.25,
- capacity_factor_eval: float = 2.0,
- min_capacity: int = 4,
- select_policy: str = "first",
- noisy_func: Optional[Callable] = None,
- drop_tks: bool = True,
- ):
- super().__init__(
- k_value=1,
- capacity_factor_train=capacity_factor_train,
- capacity_factor_eval=capacity_factor_eval,
- min_capacity=min_capacity,
- noisy_func=noisy_func,
- drop_tks=drop_tks,
- )
- self.select_policy = select_policy
- assert select_policy in {"first", "random"}
- if select_policy == "random":
- self.uniform = torch.distributions.uniform.Uniform(
- low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
- high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
- ).rsample
-
- def forward(
- self,
- inputs: torch.Tensor,
- use_kernel: bool = False,
- ep_group: Optional[ProcessGroup] = None,
- use_loss: bool = False,
- use_norm: bool = False,
- ) -> Tuple:
- """
- Args:
- inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
-
- Returns:
- 1. use_kernel is False:
- The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
- The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
- 2. use_kernel is True:
- ...
- """
- if self.noisy_func is not None and self.training:
- inputs = self.noisy_func(inputs)
-
- assert inputs.dtype == torch.float
- probs = F.softmax(inputs, dim=-1)
- num_experts = probs.size(-1)
- num_tokens = inputs.size(0)
- capacity = self.get_capacity(num_tokens, num_experts, ep_group)
-
- top1_idx = torch.argmax(inputs, dim=-1)
- mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
-
- # calculate router loss
- self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
- self.set_z_loss(inputs)
- self.pop_router_loss()
-
- if not self.training and not self.drop_tks and ep_group is not None:
- max_num = torch.max(torch.sum(mask, dim=0))
- dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
- capacity = max_num.item()
-
- if self.select_policy == "random":
- rand_mask = mask * self.uniform(mask.shape)
- _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
- mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
- ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
- elif self.select_policy == "first":
- ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
- mask = mask * torch.lt(ranks, capacity)
- else:
-            raise NotImplementedError("This select policy is not supported yet.")
-
- ranks = torch.sum(mask * ranks, dim=-1)
- used_capacity = mask.sum(dim=0)
-
- if use_kernel:
- mask = torch.sum(mask, dim=-1)
- mask = torch.stack([mask], dim=0).to(torch.int32)
- dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
- return used_capacity, probs, mask, dest_idx, num_experts * capacity
- else:
- ranks = F.one_hot(ranks, num_classes=capacity)
- weight = mask * probs.type_as(inputs)
- combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
- sec_mask = combine_weights.bool()
- return used_capacity, combine_weights, sec_mask, probs
-
-
-class Top2Router(MoeRouter):
- """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
-    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More details
-    can be found in the ViT-MoE paper.
-
- Args:
- capacity_factor_train (float, optional): Capacity factor in routing of training.
- capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
-        min_capacity (int, optional): The minimum capacity of each expert.
-        noisy_func (:class:`typing.Callable`, optional): Noisy function applied to the logits.
-        drop_tks (bool, optional): Whether to drop tokens during evaluation.
- """
-
- def __init__(
- self,
- capacity_factor_train: float = 1.25,
- capacity_factor_eval: float = 2.0,
- min_capacity: int = 4,
- noisy_func: Optional[Callable] = None,
- drop_tks: bool = True,
- ):
- super().__init__(
- k_value=2,
- capacity_factor_train=capacity_factor_train,
- capacity_factor_eval=capacity_factor_eval,
- min_capacity=min_capacity,
- noisy_func=noisy_func,
- drop_tks=drop_tks,
- )
-
- def forward(
- self,
- inputs: torch.Tensor,
- use_kernel: bool = False,
- ep_group: Optional[ProcessGroup] = None,
- use_norm: bool = False,
- use_loss: bool = True,
- ) -> Tuple:
- """
- Args:
- inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
-
- Returns:
- 1. use_kernel is False:
- The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
- The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
- 2. use_kernel is True:
- ...
- """
- if self.noisy_func is not None and self.training:
- inputs = self.noisy_func(inputs)
-
- assert inputs.dtype == torch.float
- probs = F.softmax(inputs, dim=-1)
- if use_norm:
- routing_weights, _ = torch.topk(probs, 2, dim=-1)
- probs = probs / routing_weights.sum(dim=-1, keepdim=True)
-
- num_experts = probs.size(-1)
- num_tokens = inputs.size(0)
- capacity = self.get_capacity(num_tokens, num_experts, ep_group)
-
- top1_idx = torch.argmax(probs, dim=-1)
- mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
- logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
- top2_idx = torch.argmax(logits_except1, dim=-1)
- mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
-
- cmask = mask1 + mask2 # loss: [s, e]
- cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
-
- # calculate loss
- if use_loss:
- expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
- self.set_aux_loss(probs, expert_indices, num_experts)
- self.set_z_loss(inputs)
- self.pop_router_loss()
-
- if not self.training and not self.drop_tks and ep_group is not None:
- max_num = torch.max(torch.sum(cmask, dim=0))
- dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
- capacity = max_num.item()
-
- rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
- rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
- rank2 += torch.sum(mask1, dim=-2, keepdim=True)
-
- mask1 *= torch.lt(rank1, capacity)
- mask2 *= torch.lt(rank2, capacity)
- used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0)
-
- rank1 = torch.sum(mask1 * rank1, dim=-1)
- rank2 = torch.sum(mask2 * rank2, dim=-1)
-
- if use_kernel:
- mask1 = torch.sum(mask1, dim=-1)
- mask2 = torch.sum(mask2, dim=-1)
-
- mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
- dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
-
- return used_capacity, probs, mask, dest_idx, num_experts * capacity
- else:
- """
- The following code is equivalent to:
-
- ```
- weight1 = mask1 * probs.type_as(inputs)
- weight2 = mask2 * probs.type_as(inputs)
- rank1_sc = F.one_hot(rank1, num_classes=capacity)
- rank2_sc = F.one_hot(rank2, num_classes=capacity)
-
- cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
- cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
- cb_weight = cb_weight1 + cb_weight2
- sec_mask = cb_weight.bool()
- ```
- """
-
- weight1 = mask1 * probs.type_as(inputs)
- weight2 = mask2 * probs.type_as(inputs)
-
- cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
- sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
- indices = torch.arange(0, inputs.shape[0], device=inputs.device)
- cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
- cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
- sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
- sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]
-
- return used_capacity, cb_weight, sec_mask
-
-
-class TopKRouter(MoeRouter):
- """Masked matmul router using tokens choose top-k experts assignment.
-
- NOTE: this is modified from flaxformer.
- This router uses the same mechanism as in Switch Transformer
- (https://arxiv.org/abs/2101.03961) and V-MoE
- (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
- sorted by router_probs and then routed to their choice of expert until the
- expert's expert_capacity is reached. There is no guarantee that each token is
- processed by an expert, or that each expert receives at least one token.
-
- Attributes:
- num_selected_experts: Maximum number of experts to which each token is
- routed. Tokens may be routed to fewer experts if particular experts are
- oversubscribed / reach capacity.
- """
-
- def __init__(
- self,
- num_selected_experts: int,
- capacity_factor_train: float = 1.25,
- capacity_factor_eval: float = 2.0,
- min_capacity: int = 4,
- noisy_func: Optional[Callable] = None,
- drop_tks: bool = True,
- ):
- super().__init__(
- num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
- )
-
- def forward(
- self,
- router_probs: torch.Tensor,
- expert_capacity: int,
- ) -> Tuple:
- """Computes masks for the top-k experts per token.
-
- Args:
- router_probs: [num_groups, tokens_per_group, num_experts]
- probabilities used to determine the routing of tokens to the experts.
-
- Returns:
- Dispatch and combine arrays for routing with masked matmuls.
- """
- # TODO: FIXME: add parallel group
- num_groups, _, num_experts = router_probs.shape
-
- # Top-k router probability and corresponding expert indices for each token.
- # Shape: [num_groups, tokens_per_group, num_selected_experts].
- expert_gate, expert_index = torch.topk(router_probs, self.k_value)
-
- self.set_aux_loss(router_probs, expert_index, num_experts)
- self.pop_router_loss()
-
- # Make num_selected_experts the leading axis to ensure that top-1 choices
- # have priority over top-2 choices, which have priority over top-3 choices,
- # etc.
- expert_index = torch.transpose(expert_index, 1, 2)
- # Shape: [num_groups, num_selected_experts * tokens_per_group]
- expert_index = expert_index.reshape(num_groups, -1)
-
- # Create mask out of indices.
- # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
- expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
-
- # Experts have a fixed capacity that we cannot exceed. A token's priority
- # within the expert's buffer is given by the masked, cumulative capacity of
- # its target expert.
- # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
- token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
- # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
- token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
- # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
- token_priority = torch.transpose(token_priority, 1, 2)
- # For each token, across all selected experts, select the only non-negative
- # (unmasked) priority. Now, for group G routing to expert E, token T has
- # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
- # is its targeted expert.
- # Shape: [num_groups, tokens_per_group, num_experts].
- token_priority = torch.max(token_priority, dim=2)[0]
-
- # Token T can only be routed to expert E if its priority is positive and
- # less than the expert capacity. One-hot matrix will ignore indices outside
- # the range [0, expert_capacity).
- # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
- valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
- token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
- dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
- valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
- dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
-
- # The combine array will be used for combining expert outputs, scaled by the
- # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
- # expert_capacity].
- combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
-
- return combine_array, dispatch_mask
-
-
-def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
- if not grouped:
- if top_k == 1:
- return Top1Router
- elif top_k == 2:
- return Top2Router
- else:
- raise NotImplementedError("top_k > 2 is not supported yet")
- else:
- return TopKRouter
diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py
index c642f1a44..3d08ab7dd 100644
--- a/colossalai/moe/utils.py
+++ b/colossalai/moe/utils.py
@@ -6,10 +6,11 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
+from torch.distributed.distributed_c10d import get_process_group_ranks
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
-from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
class ForceFP32Parameter(torch.nn.Parameter):
@@ -145,7 +146,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
if not is_moe_tensor(param):
ep_size = 1 # set ep_size to 1 for dp parameters
else:
- ep_size = get_ep_size(param)
+ ep_size = dist.get_world_size(param.ep_group)
if ep_size not in epsize_param_dict:
epsize_param_dict[ep_size] = []
epsize_param_dict[ep_size].append(param)
@@ -170,8 +171,8 @@ def sync_moe_model_param(model: nn.Module):
# When ep_size = world_size, communication is not needed
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
for param in param_dict[ep_size]:
- src_rank = get_dp_group_ranks(param)[0]
- dist.broadcast(param, src=src_rank, group=get_dp_group(param))
+ src_rank = get_process_group_ranks(param.dp_group)[0]
+ dist.broadcast(param, src=src_rank, group=param.dp_group)
def set_moe_args(config: Any, args: dict):
diff --git a/colossalai/shardformer/layer/moe/__init__.py b/colossalai/shardformer/layer/moe/__init__.py
new file mode 100644
index 000000000..6fa015a94
--- /dev/null
+++ b/colossalai/shardformer/layer/moe/__init__.py
@@ -0,0 +1,3 @@
+from .experts import *
+from .layers import *
+from .routers import *
diff --git a/colossalai/moe/experts.py b/colossalai/shardformer/layer/moe/experts.py
similarity index 98%
rename from colossalai/moe/experts.py
rename to colossalai/shardformer/layer/moe/experts.py
index 8e6ea3884..1be7a2754 100644
--- a/colossalai/moe/experts.py
+++ b/colossalai/shardformer/layer/moe/experts.py
@@ -9,7 +9,7 @@ from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
-from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info
+from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
if HAS_TRITON:
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@@ -35,7 +35,7 @@ class MLPExperts(nn.Module):
num_experts: int,
hidden_size: int,
intermediate_size: int,
- expert_parallel: Optional[str] = None,
+ expert_parallel: Optional[str] = "EP",
activation: Optional[Callable] = None,
drop_rate: Optional[float] = 0,
gated: Optional[bool] = False,
diff --git a/colossalai/moe/layers.py b/colossalai/shardformer/layer/moe/layers.py
similarity index 96%
rename from colossalai/moe/layers.py
rename to colossalai/shardformer/layer/moe/layers.py
index 2ac5b186d..e5b0ef97f 100644
--- a/colossalai/moe/layers.py
+++ b/colossalai/shardformer/layer/moe/layers.py
@@ -8,11 +8,9 @@ import torch.nn as nn
import torch.nn.functional as F
from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
-from colossalai.moe.experts import MLPExperts
from colossalai.moe.load_balance import LoadBalancer
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.routers import MoeRouter, get_router_cls
from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
+from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
@@ -23,6 +21,7 @@ class SparseMLP(nn.Module):
dim_model (int): Hidden dimension of training model
num_experts (int): The number experts
top_k (int, optional): The number of experts for dispatchment of each token
+ parallel (str): parallel mode. Should be "EP", "TP" or None
capacity_factor_train (float, optional): Capacity factor in routing during training
capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
min_capacity (int, optional): The minimum number of the capacity of each expert
@@ -51,6 +50,7 @@ class SparseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
+ parallel: str = "EP",
router_loss: bool = True,
router_norm: bool = False,
router_capacity_factor_train: float = 1.25,
@@ -66,7 +66,7 @@ class SparseMLP(nn.Module):
load_balance_group_swap_factor: float = 0.4,
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
- enable_hierarchical_comm: bool = False,
+ enable_hierarchical_comm: bool = True,
return_gate_logits: bool = False,
):
super().__init__()
@@ -77,7 +77,9 @@ class SparseMLP(nn.Module):
self.return_gate_logits = return_gate_logits
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
- self.expert_parallel = MOE_MANAGER.get_parallel()
+ # self.expert_parallel = MOE_MANAGER.get_parallel()
+ assert parallel in ["EP", "TP", None], "parallel mode must be EP, TP or None"
+ self.parallel = parallel
self.router_loss = router_loss
self.router_norm = router_norm
@@ -99,7 +101,7 @@ class SparseMLP(nn.Module):
# moe experts
self.experts = MLPExperts(
num_experts=self.num_experts,
- expert_parallel=self.expert_parallel,
+ expert_parallel=self.parallel,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
activation=mlp_activation,
@@ -108,11 +110,12 @@ class SparseMLP(nn.Module):
)
# get parallel settings
- if self.expert_parallel is not None:
+ if self.parallel is not None:
self.ep_group = get_ep_group(self.experts)
self.ep_size = get_ep_size(self.experts)
self.ep_hierarchical_group = None
if enable_hierarchical_comm:
+ # TODO: move to plugin
self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
get_ep_group_ranks(self.experts)
)
@@ -186,11 +189,11 @@ class SparseMLP(nn.Module):
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# expert_output: (num_groups, num_experts, capacity, hidden_size)
- if self.expert_parallel == "EP":
+ if self.parallel == "EP":
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
- elif self.expert_parallel == "TP":
+ elif self.parallel == "TP":
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
- elif self.expert_parallel is None:
+ elif self.parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError(
diff --git a/colossalai/shardformer/layer/moe/routers.py b/colossalai/shardformer/layer/moe/routers.py
new file mode 100644
index 000000000..1be7a2754
--- /dev/null
+++ b/colossalai/shardformer/layer/moe/routers.py
@@ -0,0 +1,161 @@
+import math
+from typing import Callable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import get_activation
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
+
+if HAS_TRITON:
+ from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
+
+
+class MLPExperts(nn.Module):
+ """
+    MLPExperts is a multi-layer perceptron with sparse expert-parallel layers.
+
+ Args:
+ num_experts (int): The number of experts
+ hidden_size (int): The hidden size of MLP
+ intermediate_size (int): The intermediate size of MLP
+ expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP.
+ activation (optional): The activation function of MLP
+ drop_rate (float, optional): The drop rate of MLP
+ gated (bool, optional): Whether to use gated MLP
+ use_kernel (bool, optional): Whether to use kernel optimization
+ """
+
+ def __init__(
+ self,
+ num_experts: int,
+ hidden_size: int,
+ intermediate_size: int,
+ expert_parallel: Optional[str] = "EP",
+ activation: Optional[Callable] = None,
+ drop_rate: Optional[float] = 0,
+ gated: Optional[bool] = False,
+ use_kernel: Optional[bool] = False,
+ ):
+ super().__init__()
+ assert expert_parallel in ["EP", "TP", None]
+ self.expert_parallel = expert_parallel
+ self.num_total_experts = num_experts
+ self.gated = gated
+ self.use_kernel = use_kernel
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+
+ # get expert parallel info
+ if expert_parallel is not None:
+ self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
+ num_experts, use_tp=True if expert_parallel == "TP" else False
+ )
+ # get settings for different parallel
+ self.ep_size = get_ep_size(self)
+ if expert_parallel == "TP":
+ intermediate_size = intermediate_size // self.ep_size
+ num_experts = self.num_total_experts
+ else:
+ num_experts = self.num_local_experts
+ else:
+ self.num_local_experts = self.num_total_experts
+ self.ep_size = 1
+
+ if gated:
+ self.wi_gate = nn.Parameter(
+ torch.empty(
+ num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
+ )
+ )
+ self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
+ else:
+ self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
+ self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))
+
+ self.act_name = activation
+ self.act = get_activation(activation)
+ self.drop = nn.Dropout(p=drop_rate)
+
+ if expert_parallel is not None:
+ for param in self.parameters():
+ set_moe_tensor_info(param, self.moe_info)
+
+ # init param
+ self.reset_parameters()
+
+ @torch.no_grad()
+ def reset_parameters(self):
+ # expert param should be different
+ if self.expert_parallel is not None:
+ seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True)
+ else:
+ seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
+ with seed_ctx:
+ if self.gated:
+ torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size))
+ torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size))
+ else:
+ torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
+ torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ param_slice: Tuple[slice] = (slice(None),),
+ use_sparse: bool = True,
+ ) -> torch.Tensor:
+ """
+ forward: hidden_size --> intermediate_size --> hidden_size
+
+ Args:
+ x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
+
+ Returns:
+ torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
+ """
+ x = MoeInGradScaler.apply(x, self.ep_size)
+
+ e = x.size(1)
+ h = x.size(-1)
+
+ x = x.transpose(0, 1)
+ inshape = x.shape
+ x = x.reshape(e, -1, h)
+
+ if self.use_kernel and use_sparse:
+ seq_len = x.shape[1]
+ with torch.no_grad():
+ mask = x[:, :, 0] != 0.0
+ mask = torch.sum(mask, dim=-1)
+ x_list = []
+ for i in range(e):
+ x_list.append(x[i, : mask[i]])
+ x = x_list
+
+ if self.gated:
+ x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
+ x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
+ if self.use_kernel and HAS_TRITON and self.act_name == "swiglu":
+ x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)]
+ else:
+ x = [self.act(x_gate[i]) * x_up[i] for i in range(e)]
+ else:
+ x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)]
+ x = [self.act(x[i]) for i in range(e)]
+ x = [self.drop(x[i]) for i in range(e)]
+ x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]
+
+ if self.use_kernel and use_sparse:
+ for i in range(e):
+ x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0)
+
+ x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
+ x = x.reshape(inshape)
+ x = x.transpose(0, 1).contiguous()
+ x = MoeOutGradScaler.apply(x, self.ep_size)
+ return x
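Stripping away the kernels, dropout and gradient scalers, the core of this forward is a per-expert MLP applied to a batch laid out as (experts, tokens, hidden). A dependency-free sketch of that data flow (plain ReLU, dense path only, shapes and values illustrative):

```
import torch

e, cap, h, inter = 4, 8, 16, 32                  # experts, capacity, hidden, intermediate
x = torch.randn(1, e, cap, h)                    # (num_groups, num_experts, capacity, hidden_size)
wi = torch.randn(e, h, inter)
wo = torch.randn(e, inter, h)

y = x.transpose(0, 1).reshape(e, -1, h)          # (experts, group * capacity, hidden)
y = torch.relu(torch.bmm(y, wi))                 # hidden -> intermediate, one matmul per expert
y = torch.bmm(y, wo)                             # intermediate -> hidden
y = y.reshape(e, 1, cap, h).transpose(0, 1)      # back to (num_groups, num_experts, capacity, hidden)
assert y.shape == x.shape
```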
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/colossalai/shardformer/modeling/mixtral.py
similarity index 65%
rename from applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
rename to colossalai/shardformer/modeling/mixtral.py
index c01e02c49..2fbc34302 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -1,222 +1,108 @@
-from functools import partial
-from typing import Callable, Dict, List, Optional, Union
+from typing import List, Optional
import torch
-import torch.nn as nn
-from torch import Tensor
-from torch.nn import CrossEntropyLoss, Module
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.distributed import ProcessGroup
+
+# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.mixtral.modeling_mixtral import (
- MixtralDecoderLayer,
- MixtralForCausalLM,
- MixtralModel,
+ MixtralSparseMoeBlock,
MoeCausalLMOutputWithPast,
- _prepare_4d_causal_attention_mask,
load_balancing_loss_func,
)
-from transformers.utils import logging
+from transformers.utils import is_flash_attn_2_available, logging
+from colossalai.lazy import LazyInitContext
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
-from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.shard import ShardConfig
-
-from .mixtral_layer import EPMixtralSparseMoeBlock
-
-__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
+from colossalai.shardformer.shard.utils import set_tensors_to_none
-class MixtralPolicy(Policy):
- def config_sanity_check(self):
- pass
+class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
+ def __init__(self, config):
+ self.moe_info = None
+ super().__init__(config)
- def preprocess(self):
- if self.shard_config.enable_tensor_parallelism:
- # Resize embedding
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
+ def setup_ep(self, ep_group: ProcessGroup):
+ ep_group = ep_group
+ self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
+ self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+ assert self.num_experts % self.ep_size == 0
+ self.ep_group = ep_group
+ self.num_experts_per_ep = self.num_experts // self.ep_size
+ self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
+ held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+ set_tensors_to_none(self.experts, exclude=set(held_experts))
+ for p in self.experts.parameters():
+ p.ep_group = ep_group
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ @staticmethod
+ def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
+ LazyInitContext.materialize(module)
+ module.__class__ = EPMixtralSparseMoeBlock
+        assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!"
+ module.setup_ep(kwargs["ep_group"])
+ return module
- return self.model
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ # router_logits: (batch * sequence_length, n_experts)
+ router_logits = self.gate(hidden_states)
- def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- policy = {}
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
- if self.shard_config.enable_sequence_parallelism:
- self.shard_config.enable_sequence_parallelism = False
- raise NotImplementedError(
-                "Mixtral doesn't support sequence parallelism now; the sequence parallelism flag will be ignored."
- )
+ selected_experts = selected_experts.t().reshape(-1)
+ selected_experts_idx = selected_experts.argsort()
+ dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
+ input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
+ output_split_sizes = torch.zeros_like(input_split_sizes)
+ dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
- if self.shard_config.enable_tensor_parallelism:
- raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-
- # expert parallel
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="block_sparse_moe",
- target_module=EPMixtralSparseMoeBlock,
- )
- ],
- policy=policy,
- target_key=MixtralDecoderLayer,
- )
-
- # optimization configuration
- if self.shard_config.enable_fused_normalization:
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="input_layernorm",
- target_module=FusedRMSNorm,
- ),
- SubModuleReplacementDescription(
- suffix="post_attention_layernorm",
- target_module=FusedRMSNorm,
- ),
- ],
- policy=policy,
- target_key=MixtralDecoderLayer,
- )
-
- self.append_or_create_submodule_replacement(
- description=SubModuleReplacementDescription(
- suffix="norm",
- target_module=FusedRMSNorm,
- ),
- policy=policy,
- target_key=MixtralModel,
- )
-
- if self.shard_config.enable_flash_attention:
- raise NotImplementedError("Flash attention has already been replaced in mixtral.")
-
- return policy
-
- def postprocess(self):
- return self.model
-
- def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
- """If under pipeline parallel setting, replacing the original forward method of huggingface
- to customized forward method, and add this changing to policy."""
- if self.pipeline_stage_manager:
- stage_manager = self.pipeline_stage_manager
- if self.model.__class__.__name__ == "MixtralModel":
- module = self.model
+ input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+ output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+ output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+ # compute expert output
+ output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+ if output_states.size(0) > 0:
+ if self.num_experts_per_ep == 1:
+ # no need to split
+ expert = self.experts[self.expert_start_idx]
+ output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
+ output_states = expert.w2(output_states)
else:
- module = self.model.model
-
- layers_per_stage = stage_manager.distribute_layers(len(module.layers))
- stage_index = stage_manager.get_stage_index(layers_per_stage)
- method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
- self.append_or_create_method_replacement(
- description=method_replacement, policy=policy, target_key=model_cls
- )
-
- return
-
- def get_held_layers(self) -> List[Module]:
- """Get pipeline layers for current stage."""
- assert self.pipeline_stage_manager is not None
-
- if self.model.__class__.__name__ == "MixtralModel":
- module = self.model
- else:
- module = self.model.model
- stage_manager = self.pipeline_stage_manager
-
- held_layers = []
- layers_per_stage = stage_manager.distribute_layers(len(module.layers))
- if stage_manager.is_first_stage():
- held_layers.append(module.embed_tokens)
- start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
- held_layers.extend(module.layers[start_idx:end_idx])
- if stage_manager.is_last_stage():
- held_layers.append(module.norm)
-
- return held_layers
-
-
-class MixtralModelPolicy(MixtralPolicy):
- def __init__(self) -> None:
- super().__init__()
-
- def module_policy(self):
- policy = super().module_policy()
- if self.pipeline_stage_manager:
- # set None as default
- self.set_pipeline_forward(
- model_cls=MixtralModel,
- new_forward=MixtralPipelineForwards.mixtral_model_forward,
- policy=policy,
- )
- return policy
-
- def get_held_layers(self) -> List[Module]:
- """Get pipeline layers for current stage."""
- held_layers = super().get_held_layers()
- return held_layers
-
- def get_shared_params(self) -> List[Dict[int, Tensor]]:
- """No shared params in llama model"""
- return []
-
-
-class MixtralForCausalLMPolicy(MixtralPolicy):
- def module_policy(self):
- policy = super().module_policy()
-
- if self.shard_config.enable_tensor_parallelism:
- # add a new item for casual lm
- new_item = {
- MixtralForCausalLM: ModulePolicyDescription(
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="lm_head",
- target_module=Linear1D_Col,
- kwargs=dict(gather_output=True),
- )
- ]
- )
- }
- policy.update(new_item)
-
- if self.pipeline_stage_manager:
- # set None as default
- self.set_pipeline_forward(
- model_cls=MixtralForCausalLM,
- new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
- policy=policy,
- )
-
- return policy
-
- def get_held_layers(self) -> List[Module]:
- """Get pipeline layers for current stage."""
- stage_manager = self.pipeline_stage_manager
- held_layers = super().get_held_layers()
- if stage_manager.is_last_stage():
- held_layers.append(self.model.lm_head)
- return held_layers
-
- def get_shared_params(self) -> List[Dict[int, Tensor]]:
- llama_model = self.model.model
- if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
- if (
- id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
- and self.pipeline_stage_manager.num_stages > 1
- ):
- # tie weights
- return [
- {
- 0: llama_model.embed_tokens.weight,
- self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
- }
- ]
- return []
+ output_states_splits = output_states.split(output_split_sizes.tolist())
+ output_states_list = []
+ for i, split_states in enumerate(output_states_splits):
+ if split_states.size(0) == 0:
+ continue
+ expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+ split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
+ split_states = expert.w2(split_states)
+ output_states_list.append(split_states)
+ output_states = torch.cat(output_states_list)
+ output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+ dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+ recover_experts_idx = torch.empty_like(selected_experts_idx)
+ recover_experts_idx[selected_experts_idx] = torch.arange(
+ selected_experts_idx.size(0), device=selected_experts_idx.device
+ )
+ dispatch_states = dispatch_states[recover_experts_idx]
+ k_hidden_states = dispatch_states.chunk(self.top_k)
+ output_states = k_hidden_states[0] * routing_weights[:, 0, None]
+ for i in range(1, self.top_k):
+ output_states += k_hidden_states[i] * routing_weights[:, i, None]
+ output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
+ return output_states, router_logits
class MixtralPipelineForwards:
@@ -332,7 +218,7 @@ class MixtralPipelineForwards:
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
- if self._use_flash_attention_2:
+ if is_flash_attn_2_available():
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
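
The recombination at the end of the expert-parallel forward above (invert the dispatch permutation, chunk by top_k, then take the routing-weighted sum) is easy to get wrong, so here is a small, self-contained sketch of the same pattern with toy tensors; all names and shapes below are illustrative only and not part of the patch.

import torch

top_k, num_tokens, hidden_dim = 2, 3, 4

# rows are identifiable by their index so the inverse permutation can be checked
original = torch.arange(num_tokens * top_k, dtype=torch.float32).unsqueeze(1).repeat(1, hidden_dim)
selected = torch.randperm(num_tokens * top_k)      # permutation used when sorting rows by expert
sorted_rows = original[selected]                   # what arrives back from the experts

# invert the permutation, as done with recover_experts_idx above
recover = torch.empty_like(selected)
recover[selected] = torch.arange(selected.size(0))
assert torch.equal(sorted_rows[recover], original)  # row order restored

# routing-weighted sum over the top_k expert outputs per token
routing_weights = torch.softmax(torch.randn(num_tokens, top_k), dim=-1)
k_hidden_states = sorted_rows[recover].chunk(top_k)
mixed = sum(k_hidden_states[i] * routing_weights[:, i, None] for i in range(top_k))
assert mixed.shape == (num_tokens, hidden_dim)
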
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 99b68aee2..bf139c840 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -176,6 +176,7 @@ _POLICY_LIST = {
"transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation(
file_name="falcon", class_name="FalconForQuestionAnsweringPolicy"
),
+ # mistral
"transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation(
file_name="mistral", class_name="MistralModelPolicy"
),
@@ -185,6 +186,13 @@ _POLICY_LIST = {
"transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation(
file_name="mistral", class_name="MistralForSequenceClassificationPolicy"
),
+ # mixtral
+ "transformers.models.mixtral.modeling_mixtral.MixtralModel": PolicyLocation(
+ file_name="mixtral", class_name="MixtralModelPolicy"
+ ),
+ "transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
+ file_name="mixtral", class_name="MixtralForCausalLMPolicy"
+ ),
# Qwen2
"transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
file_name="qwen2", class_name="Qwen2ModelPolicy"
@@ -195,7 +203,7 @@ _POLICY_LIST = {
"transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
),
- # Command-R
+ # command
"transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation(
file_name="command", class_name="CommandModelPolicy"
),
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
new file mode 100644
index 000000000..f9721c79e
--- /dev/null
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -0,0 +1,210 @@
+from functools import partial
+from typing import Callable, Dict, List, Union
+
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
+
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
+
+
+class MixtralPolicy(Policy):
+ def config_sanity_check(self):
+ pass
+
+ def preprocess(self):
+ if self.shard_config.enable_tensor_parallelism:
+ # Resize embedding
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
+
+ return self.model
+
+ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+ policy = {}
+
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ raise NotImplementedError(
+ "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
+ )
+
+ if self.shard_config.enable_tensor_parallelism:
+ raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
+ if getattr(self.shard_config, "ep_group", None) is None:
+ raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
+
+ # expert parallel
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="block_sparse_moe",
+ target_module=EPMixtralSparseMoeBlock,
+ kwargs={"ep_group": self.shard_config.ep_group},
+ )
+ ],
+ policy=policy,
+ target_key=MixtralDecoderLayer,
+ )
+
+ # optimization configuration
+ if self.shard_config.enable_fused_normalization:
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="input_layernorm",
+ target_module=FusedRMSNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="post_attention_layernorm",
+ target_module=FusedRMSNorm,
+ ),
+ ],
+ policy=policy,
+ target_key=MixtralDecoderLayer,
+ )
+
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="norm",
+ target_module=FusedRMSNorm,
+ ),
+ policy=policy,
+ target_key=MixtralModel,
+ )
+
+ if self.shard_config.enable_flash_attention:
+ raise NotImplementedError("Flash attention has already been replaced in mixtral.")
+
+ return policy
+
+ def postprocess(self):
+ return self.model
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if self.pipeline_stage_manager:
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == "MixtralModel":
+ module = self.model
+ else:
+ module = self.model.model
+
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
+ method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=model_cls
+ )
+
+ return
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == "MixtralModel":
+ module = self.model
+ else:
+ module = self.model.model
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embed_tokens)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.norm)
+
+ return held_layers
+
+
+class MixtralModelPolicy(MixtralPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ if self.pipeline_stage_manager:
+ # replace the forward with the pipeline-parallel forward
+ self.set_pipeline_forward(
+ model_cls=MixtralModel,
+ new_forward=MixtralPipelineForwards.mixtral_model_forward,
+ policy=policy,
+ )
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ held_layers = super().get_held_layers()
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in llama model"""
+ return []
+
+
+class MixtralForCausalLMPolicy(MixtralPolicy):
+ def module_policy(self):
+ policy = super().module_policy()
+ # TODO: assign pg mesh from plugin to all modules
+ if self.shard_config.enable_tensor_parallelism:
+ # add a new item for causal lm
+ new_item = {
+ MixtralForCausalLM: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=Linear1D_Col,
+ kwargs=dict(gather_output=True),
+ )
+ ]
+ )
+ }
+ policy.update(new_item)
+
+ if self.pipeline_stage_manager:
+ # replace the forward with the pipeline-parallel forward
+ self.set_pipeline_forward(
+ model_cls=MixtralForCausalLM,
+ new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
+ policy=policy,
+ )
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.lm_head)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ mixtral_model = self.model.model
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ if id(mixtral_model.embed_tokens.weight) == id(self.model.lm_head.weight):
+ # tie weights
+ return [
+ {
+ 0: mixtral_model.embed_tokens.weight,
+ self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+ }
+ ]
+ return []
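
Since the new MixtralPolicy refuses to run without an expert parallel group, a hedged usage sketch may help: it builds a ShardConfig carrying ep_group (the field added in the next hunk). The 8-rank layout and the chosen flags are assumptions for illustration, not part of the patch.

import torch.distributed as dist
from colossalai.shardformer.shard.shard_config import ShardConfig

def build_mixtral_shard_config() -> ShardConfig:
    # assume torch.distributed is already initialized with 8 ranks,
    # all of which participate in expert parallelism
    ep_group = dist.new_group(ranks=list(range(8)))
    return ShardConfig(
        ep_group=ep_group,                 # consumed by MixtralPolicy.module_policy()
        enable_tensor_parallelism=False,   # TP is not supported by the Mixtral policy in this patch
        enable_fused_normalization=True,   # swaps RMSNorm layers for FusedRMSNorm
    )
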
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 453e8d23e..b64300366 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -46,6 +46,7 @@ class ShardConfig:
make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+ ep_group: Optional[ProcessGroup] = None
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py
index b6843df7a..f52802d47 100644
--- a/colossalai/tensor/moe_tensor/api.py
+++ b/colossalai/tensor/moe_tensor/api.py
@@ -17,10 +17,10 @@ def is_moe_tensor(tensor: torch.Tensor) -> bool:
Returns:
bool: Whether the given tensor is a moe tensor.
"""
- return hasattr(tensor, "moe_info")
+ return hasattr(tensor, "ep_group")
-def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None:
+def set_moe_tensor_ep_group(tensor: torch.Tensor, ep_group: ProcessGroup) -> None:
"""
Set moe info for the given tensor.
@@ -29,7 +29,7 @@ def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None
moe_info (dict): The moe info to be set.
"""
- tensor.__setattr__("moe_info", moe_info)
+ tensor.__setattr__("ep_group", ep_group)
def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo:
@@ -58,7 +58,7 @@ def get_ep_group(tensor: torch.Tensor) -> ProcessGroup:
Returns:
torch.distributed.ProcessGroup: The expert parallel group of the given tensor.
"""
- return tensor.moe_info.ep_group
+ return tensor.ep_group
def get_ep_size(tensor: torch.Tensor) -> int:
@@ -71,7 +71,8 @@ def get_ep_size(tensor: torch.Tensor) -> int:
Returns:
int: The expert parallel size of the given tensor.
"""
- return tensor.moe_info.ep_size
+ assert getattr(tensor, "ep_group") is not None, "The tensor does not have expert parallel group."
+ return dist.get_world_size(tensor.ep_group)
def get_dp_size(tensor: torch.Tensor) -> int:
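
To make the renamed helpers above concrete, here is a hedged, single-process sketch of the new tagging scheme: a parameter is tagged directly with its expert parallel process group, and the ep size is read back from that group. The gloo initialization and the toy parameter are only there so the snippet runs standalone.

import torch
import torch.distributed as dist
from colossalai.tensor.moe_tensor.api import get_ep_size, is_moe_tensor, set_moe_tensor_ep_group

if not dist.is_initialized():
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

param = torch.nn.Parameter(torch.randn(4, 4))
assert not is_moe_tensor(param)            # no "ep_group" attribute yet
set_moe_tensor_ep_group(param, dist.group.WORLD)
assert is_moe_tensor(param)
print(get_ep_size(param))                  # world size of the attached group, 1 here
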
diff --git a/colossalai/zero/low_level/bookkeeping/__init__.py b/colossalai/zero/low_level/bookkeeping/__init__.py
index 427973772..07f6cdb2d 100644
--- a/colossalai/zero/low_level/bookkeeping/__init__.py
+++ b/colossalai/zero/low_level/bookkeeping/__init__.py
@@ -1,6 +1,5 @@
from .bucket_store import BucketStore
from .gradient_store import GradientStore
-from .parameter_store import ParameterStore
from .tensor_bucket import TensorBucket
-__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"]
+__all__ = ["GradientStore", "BucketStore", "TensorBucket"]
diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index 1496603fa..19d20de2b 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -1,12 +1,11 @@
-from typing import Dict, Optional
+from typing import Dict
import torch
-import torch.distributed as dist
from torch import Tensor
from torch._utils import _flatten_dense_tensors
from torch.distributed import ProcessGroup
-from colossalai.accelerator import get_accelerator
+from colossalai.accelerator.api import get_accelerator
from .base_store import BaseStore
@@ -16,29 +15,11 @@ class BucketStore(BaseStore):
self,
torch_pg: ProcessGroup,
reduce_bucket_size: int,
- overlap_communication: bool,
- communication_dtype: Optional[torch.dtype] = None,
- moe_extra_dp_process_group: ProcessGroup = None,
):
super().__init__(torch_pg)
self.reduce_bucket_size = reduce_bucket_size
- # communication params
- self._overlap_communication = overlap_communication
- self._communication_dtype = communication_dtype
- if self._overlap_communication:
- self.comm_stream = get_accelerator().Stream()
- self.zero_local_rank = dist.get_rank(group=self.torch_pg)
- self.zero_world_size = dist.get_world_size(group=self.torch_pg)
- # extra dp
- # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
- # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
- # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
- # And moe working and master param are split by extra dp pg.
- self.moe_extra_dp_pg = moe_extra_dp_process_group
- if self.moe_extra_dp_pg is not None:
- self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
- self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
self.reset_all()
+ self.comm_stream = get_accelerator().Stream()
def reset_all(self) -> None:
# init
diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py
index fc28b7795..e24a67f9d 100644
--- a/colossalai/zero/low_level/bookkeeping/gradient_store.py
+++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
from torch import Tensor
@@ -6,7 +6,7 @@ from .base_store import BaseStore
class GradientStore(BaseStore):
- def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True):
+ def __init__(self, *args, partition_grad: bool = False):
super().__init__(*args)
"""
self._grads_of_params mapping the parameter and its gradient slices
@@ -20,8 +20,6 @@ class GradientStore(BaseStore):
self._grads_of_params = dict()
# stage 2
self._partition_grads = partition_grad
- # grad accumulation
- self.require_grad_sync = require_grad_sync
self._working_index = 0 if partition_grad else self._local_rank
# for zero2, it's `param_id: [grad_local_rank]`
self.grad_to_param_mapping = dict()
@@ -107,8 +105,7 @@ class GradientStore(BaseStore):
for group in self._grads_of_params.values():
if param_id in group.keys():
return group[param_id][self._working_index]
-
- raise KeyError(f"Working gradient for param_id {param_id} not found.")
+ return None
def reset_grads_by_group_id(self, group_id: int):
self._grads_of_params[group_id] = dict()
@@ -116,7 +113,7 @@ class GradientStore(BaseStore):
def reset_all_gradients(self):
self._grads_of_params = dict()
- def get_param_id_for_grad(self, grad: Tensor) -> int:
+ def get_param_id_for_grad(self, grad: Tensor) -> Optional[int]:
"""Return the id of a parameter which the gradient slice belongs to
Args:
@@ -126,4 +123,4 @@ class GradientStore(BaseStore):
int: the id of a parameter which the gradient slice belongs to
"""
- return self.grad_to_param_mapping[id(grad)]
+ return self.grad_to_param_mapping.get(id(grad), None)
diff --git a/colossalai/zero/low_level/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py
deleted file mode 100644
index c03231f5f..000000000
--- a/colossalai/zero/low_level/bookkeeping/parameter_store.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from typing import Dict
-
-from torch import Tensor
-from torch.distributed import ProcessGroup
-
-from .base_store import BaseStore
-
-
-class ParameterStore(BaseStore):
- def __init__(self, torch_pg: ProcessGroup):
- super().__init__(torch_pg)
-
- # record the padding size of each param
- self._padding_map = dict()
-
- # mapping working param and master param
- self.master_to_working_param = dict()
- self.working_to_master_param = dict()
-
- def record_param_padding_size(self, param: Tensor, padding_size: int):
- """Record the padding size of a param
-
- Args:
- param (Tensor): The parameter
- padding_size (int): The padding size of the parameter
- """
-
- self._padding_map[id(param)] = padding_size
-
- def get_param_padding_size(self, param: Tensor) -> int:
- """Return the padding size of the parameter
-
- Args:
- param (Tensor): The parameter
-
- Returns:
- int: the padding size of the parameter
- """
-
- return self._padding_map[id(param)]
-
- def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
- """Mapping master parameter and working parameter
-
- Args:
- master_param (Tensor): The parameter copy in optimizer
- working_param (Tensor): The parameter of the model
- """
-
- self.master_to_working_param[id(master_param)] = working_param
- self.working_to_master_param[id(working_param)] = master_param
-
- def get_padding_map(self) -> Dict[int, Tensor]:
- """Return the padding map
-
- Returns:
- Dict[int, Tensor]: The padding map
- """
-
- return self._padding_map
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index d19e0a002..e06cf0581 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -3,12 +3,12 @@ import copy
from contextlib import contextmanager
from functools import partial
from typing import Dict, Iterator, List, Optional, Tuple
+from weakref import proxy
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor, inf
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
@@ -20,17 +20,16 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
-from colossalai.tensor.moe_tensor.api import is_moe_tensor
-from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
-from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
+from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
+from .bookkeeping import BucketStore, GradientStore, TensorBucket
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(
self,
num_working_param_groups: int,
- grad_store: GradientStore,
+ pg_to_grad_store: Dict[ProcessGroup, GradientStore],
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
@@ -49,13 +48,14 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
max_scale,
)
self.num_working_param_groups = num_working_param_groups
- self.grad_store = grad_store
+ self.pg_to_grad_store = pg_to_grad_store
def check_local_overflow(self) -> bool:
- for group_id in range(self.num_working_param_groups):
- for avg_grad in self.grad_store.get_working_grads_by_group_id(group_id):
- if avg_grad is not None and has_inf_or_nan(avg_grad):
- return True
+ for store in self.pg_to_grad_store.values():
+ for group_id in range(self.num_working_param_groups):
+ for avg_grad in store.get_working_grads_by_group_id(group_id):
+ if avg_grad is not None and has_inf_or_nan(avg_grad):
+ return True
return False
@@ -65,6 +65,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def __init__(
self,
optimizer: Optimizer,
+ pg_to_param_list: Dict[ProcessGroup, List[nn.Parameter]] = None,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
@@ -79,9 +80,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
- dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
+ dp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
- moe_extra_dp_process_group: Optional[ProcessGroup] = None,
master_weights: bool = True, # master weights
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -90,12 +90,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._logger = get_dist_logger()
self._verbose = verbose
+ if dp_process_group is not None and pg_to_param_list is not None:
+ raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
+
+ if pg_to_param_list is None:
+ unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group
+ pg_to_param_list = {unique_dp_group: []}
+ for group in self.optim.param_groups:
+ pg_to_param_list[unique_dp_group].extend(group["params"])
+
+ self.pg_to_param_list = pg_to_param_list
+ param_to_pg = {}
+ for grp, param_list in pg_to_param_list.items():
+ for p in param_list:
+ assert isinstance(p, nn.Parameter), f"got {type(p)}"
+ param_to_pg[p] = grp
+ self.param_to_pg = param_to_pg
+
+ # stage 2
+ self._partition_grads = partition_grad
+
self._cpu_offload = cpu_offload
+ # grad accumulation
+ self.require_grad_sync = True
+
# working and master params for mixed precision training
self._working_param_groups = dict()
self._master_param_groups_of_current_rank = dict()
+ # communication params
+ self._overlap_communication = overlap_communication
+ self._reduce_bucket_size = reduce_bucket_size
+ self._communication_dtype = communication_dtype
+
# gradient clipping
self._clip_grad_norm = clip_grad_norm
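
For context on how a caller is expected to use the new pg_to_param_list argument, here is a hedged sketch that gives MoE (expert parallel) parameters their own data parallel group while dense parameters use the regular one, so each group gets its own grad and bucket stores. The group names, the is_moe_tensor-based split, and the Adam settings are assumptions, not part of the patch.

import torch
import torch.distributed as dist
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer

def build_zero_optimizer(model: torch.nn.Module, dp_group: dist.ProcessGroup, moe_dp_group: dist.ProcessGroup):
    dense_params = [p for p in model.parameters() if not is_moe_tensor(p)]
    moe_params = [p for p in model.parameters() if is_moe_tensor(p)]
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    return LowLevelZeroOptimizer(
        optim,
        pg_to_param_list={dp_group: dense_params, moe_dp_group: moe_params},
        partition_grad=False,      # ZeRO stage 1
        overlap_communication=True,
    )
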
@@ -114,17 +142,27 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
- self._param_store = ParameterStore(dp_process_group)
- self._grad_store = GradientStore(dp_process_group, partition_grad=partition_grad, require_grad_sync=True)
- self._bucket_store = BucketStore(
- dp_process_group, reduce_bucket_size, overlap_communication, communication_dtype, moe_extra_dp_process_group
- )
- # moe param should not be stored in working_groups
- # because they have different parallel strategy
- # so we need to store them separately in param_groups
- # instead of working_groups
- self.working_moe_params = list()
+ # record the padding size of each param
+ self._padding_map = dict()
+
+ # mapping working param and master param
+ self.master_to_working_param = dict()
+ self.working_to_master_param = dict()
+
+ # NOTE: need to guarantee that the order of process groups is the same across all ranks
+ # process_group <---> xxx_store
+ # process_group <---> [param1 param2 ...]
+ # each process group have its own stores
+ # param belonging to one process_group will use corresponding store
+ self.pg_to_grad_store = {
+ pg: GradientStore(pg, partition_grad=self._partition_grads) for pg in self.pg_to_param_list
+ }
+ # param id to grad store, have to use id(param) as key since it is used in stores
+ self.pid_to_grad_store = {id(param): self.pg_to_grad_store[param_to_pg[param]] for param in param_to_pg}
+ self.pg_to_bucket_store = {pg: BucketStore(pg, reduce_bucket_size) for pg in self.pg_to_param_list}
+ # param id to bucket store, have to use id(param) as key since it is used in stores
+ self.pid_to_bucket_store = {id(param): self.pg_to_bucket_store[param_to_pg[param]] for param in param_to_pg}
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@@ -133,11 +171,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
group_params = list()
for param in param_group["params"]:
if param.requires_grad:
- if self._bucket_store.moe_extra_dp_pg is None:
- # skip moe param
- if is_moe_tensor(param):
- self.working_moe_params.append(param)
- continue
group_params.append(param)
# add the working params to working_param_groups for bookkeeping
@@ -151,29 +184,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# managed by this data parallel rank
param_group["params"] = master_param_current_rank
- # if there are moe params, store in addtional group in optim
- if len(self.working_moe_params) > 0:
- self._sync_master_param = False
- param_group = dict()
- # create fp32 master param
- for key, value in self.optim.param_groups[0].items():
- if key != "params":
- param_group[key] = value
- self.master_moe_params = []
- for param in self.working_moe_params:
- self.master_moe_params.append(param.clone().to(torch.float32).detach())
- # create mapping from master to working for optimizer io
- self.moe_master_to_working_map = {}
- for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
- self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
- # add to optim
- param_group["params"] = self.master_moe_params
- self.optim.param_groups.append(param_group)
-
# reduction hook is only used if overlapping communication
# or stage 2 is used
# if it is stage 1 without overlapping, no hook will be attached
- if self._bucket_store._overlap_communication or self._grad_store._partition_grads:
+ self.grad_handles = []
+ if self._overlap_communication or self._partition_grads:
self._attach_reduction_hook()
# initialize mixed precision mixin
@@ -181,7 +196,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self._dtype is torch.float16:
self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(
self.num_param_groups,
- self._grad_store,
+ self.pg_to_grad_store,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
@@ -194,7 +209,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.mixed_precision_mixin = BF16MixedPrecisionMixin()
def __del__(self):
- self.remove_hooks()
+ for hook in self.grad_handles:
+ hook.remove()
@property
def dtype(self):
@@ -221,9 +237,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for param in param_list:
padding_size = (
- self._bucket_store.zero_world_size - param.numel() % self._bucket_store.zero_world_size
- ) % self._bucket_store.zero_world_size
- self._param_store.record_param_padding_size(param, padding_size)
+ self.pid_to_bucket_store[id(param)].world_size
+ - param.numel() % self.pid_to_bucket_store[id(param)].world_size
+ ) % self.pid_to_bucket_store[id(param)].world_size
+ self.record_param_padding_size(param, padding_size)
with torch.no_grad():
if padding_size > 0:
@@ -234,14 +251,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
else:
padding_param = param.data.view(-1)
- if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(param):
- splited_params = padding_param.split(
- padding_param.numel() // self._bucket_store.moe_extra_dp_pg_size
- )
- splited_params = splited_params[self._bucket_store.moe_extra_dp_pg_rank]
- else:
- splited_params = padding_param.split(padding_param.numel() // self._bucket_store.zero_world_size)
- splited_params = splited_params[self._bucket_store.zero_local_rank]
+ splited_params = padding_param.split(
+ padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size
+ )
+ splited_params = splited_params[self.pid_to_bucket_store[id(param)].local_rank]
# use fp32 when master_weights is True
if self._master_weights is True:
@@ -249,9 +262,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
else:
splited_param_current_rank = splited_params
- # Send the splited view to the optimizer to match ZeRO 2 grad shape
params_current_rank.append(splited_param_current_rank)
- self._param_store.link_master_and_working_param(splited_param_current_rank, param)
+ self.link_master_and_working_param(splited_param_current_rank, param)
return params_current_rank
@@ -259,93 +271,45 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# Backward Reduction Hook #
###########################
- @staticmethod
- def grad_handler(
- param: nn.Parameter,
- group_id: int,
- bucket_store: BucketStore,
- param_store: ParameterStore,
- grad_store: GradientStore,
- ):
- # if run with no_sync context, would not sync grad when backward
- if grad_store.require_grad_sync:
- LowLevelZeroOptimizer.add_to_bucket(param, group_id, bucket_store, param_store, grad_store)
-
def _attach_reduction_hook(self):
# we iterate over the working params
# on each param, we register a hook to its AccumulateGrad object
+ self_weakref = proxy(self)
+
+ def _grad_handler(param, group_id):
+ # if run with no_sync context, would not sync grad when backward
+ if self_weakref.require_grad_sync:
+ self_weakref._add_to_bucket(param, group_id)
+
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad:
- param._grad_handle = param.register_post_accumulate_grad_hook(
- partial(
- LowLevelZeroOptimizer.grad_handler,
- group_id=group_id,
- bucket_store=self._bucket_store,
- param_store=self._param_store,
- grad_store=self._grad_store,
- )
+ self.grad_handles.append(
+ param.register_post_accumulate_grad_hook(partial(_grad_handler, group_id=group_id))
)
#######################
# Reduction Functions #
#######################
- @staticmethod
- def run_reduction(bucket_store: BucketStore, grad_store: GradientStore):
- if bucket_store.num_elements_in_bucket() > 0:
+
+ def _run_reduction(self):
+ for bucket_store in self.pg_to_bucket_store.values():
+ if bucket_store.num_elements_in_bucket() <= 0:
+ continue
+
bucket_store.build_grad_in_bucket()
- if bucket_store.moe_extra_dp_pg is None:
- flat_grads = bucket_store.get_flatten_grad()
- flat_grads /= bucket_store.zero_world_size
- else:
- # record moe and non moe param
- moe_list = []
- for param in bucket_store._param_list:
- moe_list.append(is_moe_tensor(param))
- # divide them into different groups
- moe_grad_list = []
- non_moe_grad_list = []
- for grad_list in bucket_store._grad_in_bucket.values():
- non_moe_cur_grad = []
- moe_cur_grad = []
- for i in range(len(grad_list)):
- if moe_list[i] == True:
- moe_cur_grad.append(grad_list[i])
- else:
- non_moe_cur_grad.append(grad_list[i])
- if len(moe_cur_grad) > 0:
- moe_grad_list.append(moe_cur_grad)
- if len(non_moe_cur_grad) > 0:
- non_moe_grad_list.append(non_moe_cur_grad)
-
- if len(non_moe_grad_list) > 0:
- non_moe_flat_grads = []
- for grad_list in non_moe_grad_list:
- non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
- non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
- non_moe_flat_grads /= bucket_store.zero_world_size
-
- if len(moe_grad_list) > 0:
- moe_flat_grads = []
- for grad_list in moe_grad_list:
- moe_flat_grads.append(_flatten_dense_tensors(grad_list))
- moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)
+ flat_grads = bucket_store.get_flatten_grad()
+ flat_grads /= bucket_store.world_size
# ready to add other tensors to bucket
bucket_store.reset_num_elements_in_bucket()
- if bucket_store._overlap_communication:
+ if self._overlap_communication:
stream = bucket_store.comm_stream
# in case of the memory being reused in the default stream
- if bucket_store.moe_extra_dp_pg is None:
- flat_grads.record_stream(stream)
- else:
- if len(non_moe_grad_list) > 0:
- non_moe_flat_grads.record_stream(stream)
- if len(moe_grad_list) > 0:
- moe_flat_grads.record_stream(stream)
+ flat_grads.record_stream(stream)
# waiting for ops in the default stream finishing
stream.wait_stream(get_accelerator().current_stream())
else:
@@ -354,126 +318,43 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
with get_accelerator().stream(stream):
group_id = bucket_store.current_group_id
- if bucket_store.moe_extra_dp_pg is None:
- grad_dtype = flat_grads.dtype
- if bucket_store._communication_dtype is not None:
- flat_grads = flat_grads.to(bucket_store._communication_dtype)
+ grad_dtype = flat_grads.dtype
+ if self._communication_dtype is not None:
+ flat_grads = flat_grads.to(self._communication_dtype)
- if not grad_store._partition_grads:
- if bucket_store.moe_extra_dp_pg is None:
- dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
- if flat_grads.dtype != grad_dtype:
- flat_grads = flat_grads.to(grad_dtype)
-
- flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.zero_world_size)
- grad_in_bucket = bucket_store.get_grad()
- LowLevelZeroOptimizer.update_unpartitoned_grad(
- bucket_store, grad_store, grad_in_bucket.values(), flat_grads_per_rank, group_id
- )
-
- # sync extra zero group
- else:
- # sync non moe param in global dp group
- if len(non_moe_grad_list) > 0:
- dist.all_reduce(non_moe_flat_grads, group=bucket_store.torch_pg)
- flat_grads_per_rank = non_moe_flat_grads.split(
- non_moe_flat_grads.numel() // bucket_store.zero_world_size
- )
- LowLevelZeroOptimizer.update_unpartitoned_grad(
- bucket_store, grad_store, non_moe_grad_list, flat_grads_per_rank, group_id
- )
-
- # sync moe param only in zero group
- if len(moe_grad_list) > 0:
- dist.all_reduce(moe_flat_grads, group=bucket_store.moe_extra_dp_pg)
- flat_grads_per_rank = moe_flat_grads.split(
- moe_flat_grads.numel() // bucket_store.zero_world_size
- )
- LowLevelZeroOptimizer.update_unpartitoned_grad(
- bucket_store, grad_store, moe_grad_list, flat_grads_per_rank, group_id
- )
+ if not self._partition_grads:
+ dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
+ if flat_grads.dtype != grad_dtype:
+ flat_grads = flat_grads.to(grad_dtype)
+ flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.world_size)
+ grad_in_bucket = bucket_store.get_grad()
+ self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
else:
- if bucket_store.moe_extra_dp_pg is None:
- flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.zero_world_size))
- received_grad = torch.zeros_like(flat_grads_list[0])
- dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
+ flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
+ received_grad = torch.zeros_like(flat_grads_list[0])
+ dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
- if received_grad.dtype != grad_dtype:
- received_grad = received_grad.to(grad_dtype)
+ if received_grad.dtype != grad_dtype:
+ received_grad = received_grad.to(grad_dtype)
- grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank]
- LowLevelZeroOptimizer.update_partitoned_grad(
- bucket_store, grad_store, grad_in_bucket_current_rank, received_grad, group_id, 1
- )
- else:
- # categorize moe and non moe param
- grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank]
- moe_grad_in_bucket_current_rank = []
- non_moe_grad_in_bucket_current_rank = []
- for idx, grad in enumerate(grad_in_bucket_current_rank):
- if moe_list[idx] == True:
- moe_grad_in_bucket_current_rank.append(grad)
- else:
- non_moe_grad_in_bucket_current_rank.append(grad)
-
- if len(non_moe_grad_list) > 0:
- flat_grads_list = list(
- non_moe_flat_grads.split(len(non_moe_flat_grads) // bucket_store.zero_world_size)
- )
- received_grad = torch.zeros_like(flat_grads_list[0])
- dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
- LowLevelZeroOptimizer.update_partitoned_grad(
- bucket_store,
- grad_store,
- non_moe_grad_in_bucket_current_rank,
- received_grad,
- group_id,
- 1,
- )
-
- if len(moe_grad_list) > 0:
- flat_grads_list = list(
- moe_flat_grads.split(len(moe_flat_grads) // bucket_store.moe_extra_dp_pg_size)
- )
- received_grad = torch.zeros_like(flat_grads_list[0])
- dist.reduce_scatter(
- received_grad,
- flat_grads_list,
- group=bucket_store.moe_extra_dp_pg,
- )
- param_slice = bucket_store.zero_world_size // bucket_store.moe_extra_dp_pg_size
- received_grad = list(received_grad.split(len(received_grad) // param_slice))
- for split_recieved_grad in received_grad:
- split_recieved_grad = _unflatten_dense_tensors(
- split_recieved_grad, moe_grad_in_bucket_current_rank
- )
- for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank):
- param_id = bucket_store.get_param_id_of_grad(grad)
- LowLevelZeroOptimizer.add_grad(
- grad_store, real_grad, param_slice, group_id, param_id
- )
+ grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.local_rank]
+ self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, received_grad, group_id, 1)
bucket_store.reset()
- @staticmethod
- def update_unpartitoned_grad(
- bucket_store: BucketStore,
- grad_store: GradientStore,
- origin_grad_list: List,
- flat_grad_list: List,
- group_id: int,
+ def _update_unpartitoned_grad(
+ self, bucket_store: BucketStore, origin_grad_list: List, flat_grad_list: List, group_id: int
) -> None:
for rank, grad_list in enumerate(origin_grad_list):
sync_tensor(flat_grad_list[rank], grad_list)
for grad in grad_list:
param_id = bucket_store.get_param_id_of_grad(grad)
- LowLevelZeroOptimizer.add_grad(grad_store, grad, bucket_store.zero_world_size, group_id, param_id, rank)
+ self._add_grad(grad, bucket_store.world_size, group_id, param_id, rank)
- @staticmethod
- def update_partitoned_grad(
+ def _update_partitoned_grad(
+ self,
bucket_store: BucketStore,
- grad_store: GradientStore,
origin_grad_list: List,
flat_grad: torch.Tensor,
group_id: int,
@@ -482,30 +363,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
sync_tensor(flat_grad, origin_grad_list)
for grad in origin_grad_list:
param_id = bucket_store.get_param_id_of_grad(grad)
- LowLevelZeroOptimizer.add_grad(grad_store, grad, partition_num, group_id, param_id)
+ self._add_grad(grad, partition_num, group_id, param_id)
- @staticmethod
- def add_grad(
- grad_store: GradientStore,
+ def _add_grad(
+ self,
grad: torch.Tensor,
partition_num: int,
group_id: int,
param_id: int,
rank: int = 0,
) -> None:
- if len(grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
- grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+ if (
+ len(self.pid_to_grad_store[param_id].get_partitioned_gradients_by_param_id(group_id, param_id))
+ < partition_num
+ ):
+ self.pid_to_grad_store[param_id].append_gradients_by_param_id(grad, group_id, param_id)
else:
- grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
+ self.pid_to_grad_store[param_id].add_gradients_by_param_id(grad, rank, group_id, param_id)
- @staticmethod
- def add_to_bucket(
- param: nn.Parameter,
- group_id: int,
- bucket_store: BucketStore,
- param_store: ParameterStore,
- grad_store: GradientStore,
- ):
+ def _add_to_bucket(self, param, group_id):
param_size = param.numel()
# check if the bucket is full
@@ -513,13 +389,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# or got a grad of param from another group
# after reduction, the bucket will be empty
if (
- bucket_store.num_elements_in_bucket() + param_size > bucket_store.reduce_bucket_size
- or group_id != bucket_store.current_group_id
+ self.pid_to_bucket_store[id(param)].num_elements_in_bucket() + param_size > self._reduce_bucket_size
+ or group_id != self.pid_to_bucket_store[id(param)].current_group_id
):
- LowLevelZeroOptimizer.run_reduction(bucket_store, grad_store)
+ self._run_reduction()
- padding_size = param_store.get_param_padding_size(param)
- bucket_store.add_param_grad(group_id, param, padding_size)
+ padding_size = self.get_param_padding_size(param)
+ self.pid_to_bucket_store[id(param)].add_param_grad(group_id, param, padding_size)
################################
# torch.optim.Optimizer methods
@@ -527,7 +403,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def backward(self, loss, retain_graph=False):
assert not (
- self._grad_store._partition_grads and not self._grad_store.require_grad_sync
+ self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and no_sync are not compatible"
if self.mixed_precision_mixin is not None:
@@ -535,34 +411,39 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
loss.backward(retain_graph=retain_graph)
- if not self._grad_store.require_grad_sync:
+ if not self.require_grad_sync:
return
- self._reduce_grad(self._grad_store._partition_grads)
+ self._reduce_grad(self._partition_grads)
# clear reduced grads
- if self._bucket_store._overlap_communication:
+ if self._overlap_communication:
get_accelerator().synchronize()
- self.zero_grad()
def backward_by_grad(self, tensor, grad):
assert not (
- self._grad_store._partition_grads and not self._grad_store.require_grad_sync
+ self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
torch.autograd.backward(tensor, grad)
- if not self._grad_store.require_grad_sync:
+ if not self.require_grad_sync:
return
- self._reduce_grad(self._grad_store._partition_grads)
+ self._reduce_grad(self._partition_grads)
# clear reduced grads
- if self._bucket_store._overlap_communication:
+ if self._overlap_communication:
get_accelerator().synchronize()
- self.zero_grad()
+ def zero_bucket_stores(self):
+ for bucket_store in self.pg_to_bucket_store.values():
+ bucket_store.reset_all()
+
+ def zero_grad_stores(self):
+ for grad_store in self.pg_to_grad_store.values():
+ grad_store.reset_all_gradients()
def zero_grad(self, set_to_none=True):
"""
@@ -582,7 +463,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if param.grad is not None:
param.grad.detach()
param.grad.zero_()
- self._bucket_store.reset_all()
+ self.zero_grad_stores()
+ self.zero_bucket_stores()
####################
# Update Parameter #
@@ -590,11 +472,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def step(self, closure=None):
assert closure is None, "closure is not supported by step()"
- if not self._grad_store.require_grad_sync:
+ if not self.require_grad_sync:
return
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
- self._grad_store.reset_all_gradients()
if self._verbose:
self._logger.info(f"Found overflow. Skip step")
self.zero_grad()
@@ -609,71 +490,41 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# and should not be updated
real_working_params = dict()
real_master_params = dict()
- grad_index = 0 if self._grad_store._partition_grads else self._bucket_store.zero_local_rank
+
for group_id in range(self.num_param_groups):
master_params = self._master_param_groups_of_current_rank[group_id]
+ working_params = self._working_param_groups[group_id]
real_working_params[group_id] = []
real_master_params[group_id] = []
- for splited_param in master_params:
- working_param = self._param_store.master_to_working_param[id(splited_param)]
+ working_grads = []
+ for working_param, master_param in zip(working_params, master_params):
# if a working param requires grad and has no grad
# it is not 'really' working, e.g. the droped layer
# else the splited grad should be attached to the splited param
- grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
+ grad_store = self.pid_to_grad_store[id(working_param)]
+ grads = grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
+ grad_index = 0 if self._partition_grads else grad_store.local_rank
if len(grads) > 0:
- # moe hybrid zero
- if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
- real_working_params[group_id].append(working_param)
- if self._grad_store._partition_grads:
- grad = grads
- else:
- param_slice = self._bucket_store.zero_world_size // self._bucket_store.moe_extra_dp_pg_size
- grad = grads[
- self._bucket_store.moe_extra_dp_pg_rank
- * param_slice : (self._bucket_store.moe_extra_dp_pg_rank + 1)
- * param_slice
- ]
- grad = flatten(grad)
- else:
- real_working_params[group_id].append(working_param)
- grad = grads[grad_index]
+ real_working_params[group_id].append(working_param)
+ grad = grads[grad_index]
# no need to copy fp32 grad if master_weights is False
if self._master_weights:
- grad = grad.to(splited_param.dtype).to(splited_param.device)
- splited_param.grad = grad
+ grad = grad.to(master_param.dtype).to(master_param.device)
+ master_param.grad = grad
grad_partition_groups.append(grad)
- real_master_params[group_id].append(splited_param)
+ real_master_params[group_id].append(master_param)
# compute norm
- working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
- norm_group = self._compute_grad_norm(gradients=working_grads)
- norm_groups.append(norm_group)
+ norm_group = 0
+ for grad_store in self.pg_to_grad_store.values():
+ working_grads = grad_store.get_working_grads_by_group_id(group_id)
+ norm_group += self._compute_grad_norm(dp_pg=grad_store.torch_pg, gradients=working_grads)
- self._grad_store.reset_grads_by_group_id(group_id)
+ norm_groups.append(norm_group)
# update the params in the optimizer
self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
- # update param for moe ep
- # move grad to master param and compute norm
- if len(self.working_moe_params) > 0:
- moe_grads = []
- for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
- if master_moe_param.grad is not None:
- raise RuntimeError("Moe param should not have grad here")
- grad = working_moe_param.grad
- # no need to copy fp32 grad if master_weights is False
- if self._master_weights:
- grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
- master_moe_param.grad = grad
- working_moe_param.grad = None
- moe_grads.append(grad)
- grad_partition_groups.append(grad)
- norm_group = self._compute_grad_norm(gradients=moe_grads)
- norm_groups.append(norm_group)
- self.optim.param_groups[-1]["params"] = self.master_moe_params
- del moe_grads
-
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
@@ -681,48 +532,34 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# update the parameters
self.optim.step()
- # release moe grad
- if len(self.working_moe_params) > 0:
- for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
- master_moe_param.grad = None
- working_moe_param.data = (
- master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
- )
-
# release the grad
grad_partition_groups = []
for group_id in range(self.num_param_groups):
release_param_grad(self._master_param_groups_of_current_rank[group_id])
- tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
- moe_tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
+ self.pg_to_tensor_bucket = {
+ pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list
+ }
# update working partition updated by the current rank
device = get_accelerator().get_current_device()
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
- for idx, splited_param in enumerate(master_working_param):
+ for idx, master_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
- param_to_gather = splited_param.to(device).to(self._dtype)
- if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
- try:
- moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
- except RuntimeError:
- moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
- moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
- else:
- try:
- tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
- except RuntimeError:
- tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
- tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
+ param_to_gather = master_param.to(device).to(self._dtype)
+ pg = self.param_to_pg[working_param]
+ try:
+ self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
+ except RuntimeError:
+ self.pg_to_tensor_bucket[pg].all_gather(pg)
+ self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
- if not moe_tensor_bucket.is_empty():
- moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
- if not tensor_bucket.is_empty():
- tensor_bucket.all_gather(self._bucket_store.torch_pg)
+ for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
+ if not tensor_bucket.is_empty():
+ tensor_bucket.all_gather(pg)
- def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
+ def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
@@ -745,7 +582,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device=get_accelerator().get_current_device(),
dtype=torch.float,
)
- dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self._bucket_store.torch_pg)
+ dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
total_norm = total_norm_cuda.item()
else:
@@ -763,7 +600,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
torch.distributed.all_reduce(
total_norm_exponentiated_cuda,
op=torch.distributed.ReduceOp.SUM,
- group=self._bucket_store.torch_pg,
+ group=dp_pg,
)
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
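
As a quick sanity check of the norm computation above (each rank contributes the p-th power of its shard's norm, the sum is all-reduced, and the p-th root is taken), the following single-process snippet verifies the identity with plain torch; tensor names are illustrative.

import torch

norm_type = 2
full_grad = torch.randn(10)
shards = full_grad.chunk(2)  # as if each DP rank held one shard
exponentiated_sum = sum(torch.linalg.vector_norm(s, norm_type) ** norm_type for s in shards)
reconstructed = exponentiated_sum ** (1.0 / norm_type)
assert torch.allclose(reconstructed, torch.linalg.vector_norm(full_grad, norm_type))
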
@@ -798,33 +635,27 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad and param.grad is not None:
- LowLevelZeroOptimizer.add_to_bucket(
- param,
- group_id,
- self._bucket_store,
- self._param_store,
- self._grad_store,
- )
+ self._add_to_bucket(param, group_id)
- LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store)
+ self._run_reduction()
def _reduce_grad(self, partition_grad):
# if not overlapping communication (no reduction hook is attached) when zero1
# we need to manually reduce these gradients
- if not partition_grad and not self._bucket_store._overlap_communication:
+ if not partition_grad and not self._overlap_communication:
self._sync_grad()
else:
- LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store)
+ self._run_reduction()
# this context comes from pytorch DDP
@contextmanager
def no_sync(self):
- old_require_grad_sync = self._grad_store.require_grad_sync
- self._grad_store.require_grad_sync = False
+ old_require_grad_sync = self.require_grad_sync
+ self.require_grad_sync = False
try:
yield
finally:
- self._grad_store.require_grad_sync = old_require_grad_sync
+ self.require_grad_sync = old_require_grad_sync
##############
# State Dict #
@@ -863,19 +694,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
zero_state[param] = copy.deepcopy(state)
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
- working_param = self._param_store.master_to_working_param[id(param)]
- if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
- gather_tensor = [
- torch.zeros(v.shape, device=device, dtype=v.dtype)
- for _ in range(self._bucket_store.moe_extra_dp_pg_size)
- ]
- dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg)
- else:
- gather_tensor = [
- torch.zeros(v.shape, device=device, dtype=v.dtype)
- for _ in range(self._bucket_store.zero_world_size)
- ]
- dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.torch_pg)
+ working_param = self.master_to_working_param[id(param)]
+ pg = self.param_to_pg[working_param]
+ gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
+ dist.all_gather(gather_tensor, v.to(device), group=pg)
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@@ -892,26 +714,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
state_dict (dict): A pytorch form state_dict
"""
zero_state_dict = copy.deepcopy(state_dict)
+ idx2master = {}
+ cnt = 0
+ for param_group in self.optim.param_groups:
+ for param in param_group["params"]:
+ idx2master[cnt] = param
+ cnt += 1
for param_idx, state in zero_state_dict["state"].items():
+ pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]]
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
- padding_size = (
- self._bucket_store.zero_world_size - v.numel() % self._bucket_store.zero_world_size
- ) % self._bucket_store.zero_world_size
+ padding_size = (pg.size() - v.numel() % pg.size()) % pg.size()
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
- if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
- v_list = v.split(v.numel() // self._bucket_store.moe_extra_dp_pg_size)
- zero_state_dict["state"][param_idx][k] = (
- v_list[self._bucket_store.moe_extra_dp_pg_rank].detach().clone()
- )
- else:
- v_list = v.split(v.numel() // self._bucket_store.zero_world_size)
- zero_state_dict["state"][param_idx][k] = (
- v_list[self._bucket_store.zero_local_rank].detach().clone()
- )
+ v_list = v.split(v.numel() // pg.size())
+ zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone()
self.optim.load_state_dict(zero_state_dict)
@@ -930,31 +749,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device = get_accelerator().get_current_device()
local_states = self.optim.state_dict()["state"]
+
+ idx2master = {}
+ cnt = 0
+ for param_group in self.optim.param_groups:
+ for param in param_group["params"]:
+ idx2master[cnt] = param
+ cnt += 1
for param_idx, states in local_states.items():
current_block_size = 0
current_block = copy.deepcopy(states)
- # find the working param of current param_id
- for group_id, pg in self._master_param_groups_of_current_rank.items():
- if (group_id + 1) * len(pg) < param_idx:
- continue
- master_param = pg[param_idx - (group_id) * len(pg)]
- working_param = self._param_store.master_to_working_param[id(master_param)]
+ master_param = idx2master[param_idx]
+ working_param = self.master_to_working_param[id(master_param)]
+ pg = self.param_to_pg[working_param]
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step":
- if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
- state_tensor = [
- torch.zeros(v.shape, device=device, dtype=v.dtype)
- for _ in range(self._bucket_store.moe_extra_dp_pg_size)
- ]
- dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg)
- else:
- state_tensor = [
- torch.zeros(v.shape, device=device, dtype=v.dtype)
- for _ in range(self._bucket_store.zero_world_size)
- ]
- dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.torch_pg)
+ state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
+ dist.all_gather(state_tensor, v.to(device), group=pg)
state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@@ -979,46 +792,96 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
"""
for p in model.parameters():
p_id = id(p)
- if p_id in self._param_store.working_to_master_param:
- master_param = self._param_store.working_to_master_param[p_id]
- padding_size = self._param_store.get_param_padding_size(p)
+ pg = self.param_to_pg[p]
+ if p_id in self.working_to_master_param:
+ master_param = self.working_to_master_param[p_id]
+ padding_size = self.get_param_padding_size(p)
working_param = p.data.view(-1)
if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
- if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(p):
- master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
- else:
- master_param.copy_(
- working_param.chunk(self._bucket_store.zero_world_size)[self._bucket_store.zero_local_rank]
- )
- if hasattr(self, "master_moe_params"):
- for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
- master_moe_param.copy_(working_moe_param)
-
- def remove_hooks(self) -> None:
- """remove the registered hooks
-
- Args:
- plugin (LowLevelZeroPlugin): the plugin to bound this method.
- """
- for group_id in range(self.num_param_groups):
- param_group = self._working_param_groups[group_id]
- for param in param_group:
- if param.requires_grad:
- assert hasattr(param, "_grad_handle")
- param._grad_handle.remove()
- delattr(param, "_grad_handle")
+ master_param.copy_(working_param.chunk(pg.size())[pg.rank()])
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
- return self._param_store.working_to_master_param
+ return self.working_to_master_param
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
- if hasattr(self, "moe_master_to_working_map"):
- return {
- **self._param_store.master_to_working_param,
- **self.moe_master_to_working_map,
- }
- return self._param_store.master_to_working_param
+ return self.master_to_working_param
def get_param_padding_map(self) -> Dict[int, torch.Tensor]:
- return self._param_store.get_padding_map()
+ return self._padding_map
+
+ def record_param_padding_size(self, param: Tensor, padding_size: int):
+ """Record the padding size of a param
+
+ Args:
+ param (Tensor): The parameter
+ padding_size (int): The padding size of the parameter
+ """
+
+ self._padding_map[id(param)] = padding_size
+
+ def get_param_padding_size(self, param: Tensor) -> int:
+ """Return the padding size of the parameter
+
+ Args:
+ param (Tensor): The parameter
+
+ Returns:
+ int: the padding size of the parameter
+ """
+
+ return self._padding_map[id(param)]
+
+ def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
+        """Map a master parameter to its corresponding working parameter.
+
+        Args:
+            master_param (Tensor): The master copy of the parameter held by the optimizer
+            working_param (Tensor): The working parameter of the model
+ """
+
+ self.master_to_working_param[id(master_param)] = working_param
+ self.working_to_master_param[id(working_param)] = master_param
+
+ def get_padding_map(self) -> Dict[int, Tensor]:
+ """Return the padding map
+
+ Returns:
+ Dict[int, Tensor]: The padding map
+ """
+
+ return self._padding_map
+
+ def get_param_grad(self, working_param: nn.Parameter) -> Tensor:
+ grad_store = self.pid_to_grad_store[id(working_param)]
+ partial_grad = grad_store.get_working_grad_by_param_id(id(working_param))
+ if partial_grad is None:
+ return None
+ tensor_list = [torch.empty_like(partial_grad) for _ in range(grad_store.world_size)]
+ dist.all_gather(tensor_list, partial_grad, group=grad_store.torch_pg)
+ grad_flat = torch.cat(tensor_list, dim=0)
+ return grad_flat[: working_param.numel()].reshape_as(working_param)
+
+ def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
+ working_grads = []
+ for grad_store in self.pg_to_grad_store.values():
+ working_grads.extend(grad_store.get_working_grads_by_group_id(group_id))
+ return working_grads
+
+ def get_param_id_for_grad(self, grad: Tensor) -> int:
+ param_id = None
+ for grad_store in self.pg_to_grad_store.values():
+ id_maybe_none = grad_store.get_param_id_for_grad(grad)
+ if id_maybe_none is not None:
+ if param_id is not None:
+ raise ValueError("The grad mapping is not unique")
+ param_id = id_maybe_none
+ return param_id
+
+ def get_working_grad_by_param_id(self, param_id: int) -> Tensor:
+ grad_store = self.pid_to_grad_store[param_id]
+ return grad_store.get_working_grad_by_param_id(param_id)
+
+ def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
+ grad_store = self.pid_to_grad_store[param_id]
+ return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)
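
The accessor methods added above replace direct reads of the private `_param_store`/`_grad_store`
objects elsewhere in the codebase. A rough usage sketch of the new surface (not part of this patch;
it assumes `optimizer` is an already-boosted `LowLevelZeroOptimizer` and that every parameter of
`model` is managed by it):

    # Recover, for each working parameter, its padding and its full gradient.
    for p in model.parameters():
        master = optimizer.get_working_to_master_map().get(id(p))
        if master is None:
            continue
        padding = optimizer.get_param_padding_size(p)      # padding recorded at setup time
        full_grad = optimizer.get_param_grad(p)            # all-gathered, reshaped like p, or None
        local_shards = optimizer.get_partitioned_gradients_by_param_id(0, id(p))  # group 0, this rank
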
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
index 22e0c790b..b9ef915c3 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -176,7 +176,7 @@ def main():
use_ep_inside = False
plugin = MoeHybridParallelPlugin(
pp_size=1,
- extra_dp_size=args.extra_dp_size,
+ ep_size=args.ep_size,
use_ep_inside=use_ep_inside,
**hybrid_dict,
)
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 5a9e30dd4..1febacd7d 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -50,9 +50,9 @@ try:
except:
HAS_FLASH_ATTN = False
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe.layers import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation, set_moe_args
+from colossalai.shardformer.layer.moe import SparseMLP
if HAS_TRITON:
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@@ -83,7 +83,7 @@ def set_openmoe_args(
load_balance_group_swap_factor: float = 0.4,
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
- enable_hierarchical_alltoall: bool = False,
+ enable_hierarchical_alltoall: bool = True,
) -> None:
"""
MoE related arguments.
@@ -465,7 +465,7 @@ class OpenMoeDecoderLayer(nn.Module):
load_balance_beam_width=config.load_balance_beam_width,
load_balance_group_swap_factor=config.load_balance_group_swap_factor,
enable_kernel=config.enable_kernel,
- enable_comm_overlap=config.enable_comm_overlap,
+ enable_hierarchical_comm=config.enable_hierarchical_alltoall,
)
self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.extra_mlp = OpenMoeMLP(config)
@@ -903,7 +903,7 @@ class OpenMoeForCausalLM(OpenMoePreTrainedModel):
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
# reset moe loss
- MOE_MANAGER.reset_loss()
+ MOE_MANAGER.reset_loss() # TODO: remove
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1027,7 +1027,7 @@ class OpenMoeForCausalLM(OpenMoePreTrainedModel):
def _calculate_router_loss(self, aux_loss: list = None, z_loss: list = None):
if aux_loss is None or z_loss is None:
- aux_loss, z_loss = MOE_MANAGER.get_loss()
+ aux_loss, z_loss = MOE_MANAGER.get_loss() # TODO: remove
assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval
aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss)
z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss)
diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py
index 8ef07bdb9..f46062128 100644
--- a/examples/language/openmoe/model/openmoe_policy.py
+++ b/examples/language/openmoe/model/openmoe_policy.py
@@ -172,6 +172,7 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
+            # TODO: recursively assign ep group for all modules
new_item = {
OpenMoeForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh
index 960c83adb..9ea232478 100644
--- a/examples/language/openmoe/test_ci.sh
+++ b/examples/language/openmoe/test_ci.sh
@@ -1,37 +1,37 @@
-pip install -r requirements.txt
+# pip install -r requirements.txt
# inference
-python infer.py --model "test"
+# python infer.py --model "test"
# train
-torchrun --standalone --nproc_per_node 4 train.py \
- --num_epoch 1 \
- --model_name "test" \
- --plugin "ep" \
- --batch_size 1
+# torchrun --standalone --nproc_per_node 4 train.py \
+# --num_epoch 1 \
+# --model_name "test" \
+# --plugin "ep" \
+# --batch_size 1
-torchrun --standalone --nproc_per_node 4 train.py \
- --num_epoch 1 \
- --model_name "test" \
- --plugin "ep_zero" \
- --batch_size 1 \
- --zero_stage 1 \
- --extra_dp_size 2 \
+# torchrun --standalone --nproc_per_node 4 train.py \
+# --num_epoch 1 \
+# --model_name "test" \
+# --plugin "ep_zero" \
+# --batch_size 1 \
+# --zero_stage 1 \
+# --extra_dp_size 2 \
-torchrun --standalone --nproc_per_node 4 train.py \
- --num_epoch 1 \
- --model_name "test" \
- --plugin "ep_zero" \
- --batch_size 1 \
- --zero_stage 2 \
- --extra_dp_size 2 \
+# torchrun --standalone --nproc_per_node 4 train.py \
+# --num_epoch 1 \
+# --model_name "test" \
+# --plugin "ep_zero" \
+# --batch_size 1 \
+# --zero_stage 2 \
+# --extra_dp_size 2 \
-torchrun --standalone --nproc_per_node 4 train.py \
- --model_name "test" \
- --plugin "hybrid" \
- --num_epoch 1 \
- --pp_size 2 \
- --dp_size 1 \
- --ep_size 2 \
- --zero_stage 1 \
- --batch_size 1
+# torchrun --standalone --nproc_per_node 4 train.py \
+# --model_name "test" \
+# --plugin "hybrid" \
+# --num_epoch 1 \
+# --pp_size 2 \
+# --dp_size 1 \
+# --ep_size 2 \
+# --zero_stage 1 \
+# --batch_size 1
diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py
index 40f072f13..ff0e4bad6 100644
--- a/examples/language/openmoe/train.py
+++ b/examples/language/openmoe/train.py
@@ -19,10 +19,9 @@ from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
-from colossalai.moe.layers import apply_load_balance
-from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
+from colossalai.shardformer.layer.moe import apply_load_balance
def move_to_cuda(batch, device):
@@ -221,48 +220,49 @@ def main():
"precision": args.precision,
"zero_stage": args.zero_stage,
}
- mgr_dict = {}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(
pp_size=1,
+ ep_size=args.ep_size,
**hybrid_dict,
)
- MOE_MANAGER.setup(
- parallel="EP",
- max_ep_size=dp_size,
- **mgr_dict,
- )
+ # MOE_MANAGER.setup(
+ # parallel="EP",
+ # max_ep_size=dp_size,
+ # **mgr_dict,
+ # )
elif args.plugin == "ep_zero":
dp_size = dist.get_world_size()
use_ep_inside = False
plugin = MoeHybridParallelPlugin(
pp_size=1,
- extra_dp_size=args.extra_dp_size,
+ ep_size=dp_size // args.ep_size,
use_ep_inside=use_ep_inside,
**hybrid_dict,
)
- MOE_MANAGER.setup(
- parallel="EP",
- max_ep_size=dp_size // args.extra_dp_size,
- use_ep_inside=use_ep_inside,
- **mgr_dict,
- )
+ # MOE_MANAGER.setup(
+ # parallel="EP",
+ # max_ep_size=dp_size // args.extra_dp_size,
+ # use_ep_inside=use_ep_inside,
+ # **mgr_dict,
+ # )
elif args.plugin == "hybrid":
dp_size = dist.get_world_size() // args.pp_size
plugin = MoeHybridParallelPlugin(
pp_size=args.pp_size,
+ ep_size=args.ep_size,
microbatch_size=args.microbatch_size,
**hybrid_dict,
)
- MOE_MANAGER.setup(
- parallel="EP",
- mode="fixed",
- fixed_dp_size=args.dp_size,
- fixed_ep_size=args.ep_size,
- fixed_pp_size=args.pp_size,
- **mgr_dict,
- )
+ # MOE_MANAGER.setup(
+ # parallel="EP",
+ # mode="fixed",
+ # fixed_dp_size=args.dp_size,
+ # fixed_ep_size=args.ep_size,
+ # fixed_pp_size=args.pp_size,
+ # **mgr_dict,
+ # )
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
index 24dc4a5d2..ab48944d4 100644
--- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
@@ -59,10 +59,10 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
working_param_id_set = set(id(p) for p in new_model.parameters())
- for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+ for p_id, master_param in new_optimizer.working_to_master_param.items():
assert p_id in working_param_id_set
- working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
- padding = new_optimizer._param_store.get_param_padding_size(working_param)
+ working_param = new_optimizer.master_to_working_param[id(master_param)]
+ padding = new_optimizer.get_param_padding_size(working_param)
padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
assert torch.equal(
@@ -115,10 +115,10 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
working_param_id_set = set(id(p) for p in new_model.parameters())
- for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+ for p_id, master_param in new_optimizer.working_to_master_param.items():
assert p_id in working_param_id_set
- working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
- padding = new_optimizer._param_store.get_param_padding_size(working_param)
+ working_param = new_optimizer.master_to_working_param[id(master_param)]
+ padding = new_optimizer.get_param_padding_size(working_param)
padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
assert torch.equal(
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index 17b790e3e..131932dcb 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -1,48 +1,37 @@
import torch
import torch.distributed as dist
import torch.nn as nn
+from torch.distributed import ProcessGroup
from torch.testing import assert_close
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER
-from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict
-from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
+
+# from colossalai.shardformer.layer.moe import SparseMLP
+from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
def delete_moe_info(model):
for _, param in model.named_parameters():
- if hasattr(param, "moe_info"):
- delattr(param, "moe_info")
+ if hasattr(param, "ep_group"):
+ delattr(param, "ep_group")
class MoeModel(nn.Module):
- def __init__(self, enable_load_balance: bool = False):
- class TestSubModule(nn.Module):
- def __init__(self):
- super().__init__()
- self.moe = SparseMLP(
- num_experts=8, hidden_size=16, intermediate_size=32, enable_load_balance=enable_load_balance
- )
- self.proj = nn.Linear(16, 4)
-
- def forward(self, x):
- x = self.moe(x)
- x = self.proj(x)
- return x
-
+ def __init__(self, ep_group: ProcessGroup = None):
super().__init__()
- self.test_embed = nn.Linear(4, 16)
- self.test_transform = TestSubModule()
+ self.test_embed = nn.Linear(4, 16, bias=False)
+ self.w1 = torch.nn.Parameter(torch.randn(16, 8))
+ if ep_group:
+ set_moe_tensor_ep_group(self.w1, ep_group)
def forward(self, x):
- MOE_MANAGER.reset_loss()
-
x = self.test_embed(x)
- x = self.test_transform(x)
+ x = torch.matmul(x, self.w1)
return x
@@ -116,7 +105,7 @@ def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False)
return y
-def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
@@ -126,7 +115,6 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
- assert local_name in ep_name, print(f"{local_name} != {ep_name}")
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index a88f5f9cc..25e61b091 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -5,8 +5,9 @@ import torch.nn as nn
import colossalai
from colossalai.accelerator import get_accelerator
-from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
+
+# from colossalai.shardformer.layer.moe.layers import SparseMLP
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler
@@ -69,6 +70,7 @@ def run_test(rank, world_size, port):
# MoE grad handler test passed
+@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_grad_handler():
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index 30122d31a..28e6db441 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -1,98 +1,96 @@
+import os
+
import pytest
import torch
-import torch.distributed as dist
-import colossalai
from colossalai.accelerator import get_accelerator
-from colossalai.moe import SparseMLP
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-BATCH_SIZE = 4
+# from colossalai.moe import SparseMLP
+from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum
+
NUM_EXPERTS = 4
+BATCH_SIZE = 4
+SEQ_LEN = 4
+
+MOE_TENSOR_PATH = os.getenv("MOE_TENSOR_PATH")
def check_equal(tensor_a, tensor_b, atol=1e-06):
assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True
-def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, topk=1):
- # Here we do not need TF32, since it brings absolute error on results
- torch.backends.cuda.matmul.allow_tf32 = False
+def run_moe_cumsum():
+ test_mask = torch.tensor(
+ [
+ [0, 1, 0, 0],
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [1, 0, 0, 0],
+ ],
+ dtype=torch.int32,
+ ).to("cuda")
+ out_no_kernel = moe_cumsum(test_mask, use_kernel=False)
+ out_kernel = moe_cumsum(test_mask, use_kernel=True)
+ print(out_no_kernel.dtype, out_kernel.dtype)
+ check_equal(out_no_kernel.to(torch.int32), out_kernel)
- colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- local_rank = dist.get_rank()
- MOE_MANAGER.setup(parallel="EP") # MOE environment initialization
- MOE_MANAGER.reset_loss()
- torch.manual_seed(rs + local_rank) # set each process has different random seed
-
- # get randomized data
+def run_moe_dispatch_combine_fwd_bwd(data_type=torch.float32, hidden_size=128, num_experts=4):
tokens = torch.randn(
BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True
)
- layer = SparseMLP(
- hidden_size=hidden_size,
- intermediate_size=hidden_size * 2,
- num_experts=NUM_EXPERTS,
- router_top_k=topk,
- router_capacity_factor_train=1.0,
- )
- layer = layer.to(get_accelerator().get_current_device())
- if data_type == torch.float16:
- layer = layer.half()
+ # use kernel
+ route_result_list_kernel = torch.load(f"{MOE_TENSOR_PATH}/True_4_{data_type}.pt")
+ # dispatch
+ dispatch_data_kernel = MoeDispatch.apply(tokens, *route_result_list_kernel[1:])
+ dispatch_data_kernel = dispatch_data_kernel.reshape(num_experts, -1, hidden_size)
+ # combine
+ expert_output = dispatch_data_kernel.reshape(-1, hidden_size)
+ ans_kernel = MoeCombine.apply(expert_output, *route_result_list_kernel)
- # use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine
- layer.enable_kernel = False
- old_out = layer(tokens)
- ech = old_out.shape
- grad = torch.randn(ech, device=get_accelerator().get_current_device())
- old_out.backward(grad) # get gradient
+ # no kernel
+ route_result_list_no_kernel = torch.load(f"{MOE_TENSOR_PATH}/False_2_{data_type}.pt")
+ # dispatch
+ sec_mask_f = route_result_list_no_kernel[1].type_as(tokens)
+ dispatch_data_no_kernel = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
+ # combine
+ combine_weights = route_result_list_no_kernel[0].type_as(tokens)
+ combine_weights = combine_weights.view(combine_weights.shape[0], -1)
+ expert_output = expert_output.view(-1, expert_output.shape[-1])
+ ans_no_kernel = torch.matmul(combine_weights, expert_output)
- # save all results
- o_tk_grad = tokens.grad.data.clone()
- o_gt_grad = layer.gate_weight.grad.data.clone()
+ # check fwd
+ if data_type == torch.float32:
+ check_equal(dispatch_data_kernel.reshape(dispatch_data_no_kernel.shape), dispatch_data_no_kernel)
+ else:
+ check_equal(dispatch_data_kernel.reshape(dispatch_data_no_kernel.shape), dispatch_data_no_kernel, 1e-2)
- # reset all gradients
+ if data_type == torch.float32:
+ check_equal(ans_kernel, ans_no_kernel)
+ else:
+ check_equal(ans_kernel, ans_no_kernel, 1e-2)
+
+ # check bwd
+ out_shape = ans_kernel.shape
+ grad = torch.randn(out_shape, device=get_accelerator().get_current_device())
+
+ ans_kernel.backward(grad, retain_graph=True)
+ grad_kernel = tokens.grad.data.clone()
tokens.grad.zero_()
- layer.gate_weight.grad.zero_()
- layer.enable_kernel = True
- new_out = layer(tokens) # get outputs through colossal kernel
+ ans_no_kernel.backward(grad) # get gradient
+ grad_no_kernel = tokens.grad.data.clone()
+ tokens.grad.zero_()
if data_type == torch.float32:
- check_equal(old_out, new_out)
+ check_equal(grad_no_kernel, grad_kernel)
else:
- check_equal(old_out, new_out, 1e-2)
- # forward function passed
-
- new_out.backward(grad) # get new type gradient
- n_tk_grad = tokens.grad.data.clone()
- n_gt_grad = layer.gate_weight.grad.data.clone()
-
- if data_type == torch.float32:
- check_equal(o_tk_grad, n_tk_grad)
- else:
- check_equal(o_tk_grad, o_tk_grad, 1e-2)
- # tokens gradient is correct
-
- if data_type == torch.float32:
- check_equal(o_gt_grad, n_gt_grad, 5e-05)
- else:
- check_equal(o_gt_grad, n_gt_grad, 2e-01)
- # bias gradient is correct
+ check_equal(grad_no_kernel, grad_kernel, 1e-2)
-@pytest.mark.dist
-@pytest.mark.parametrize("rs", [131])
-@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
-@pytest.mark.parametrize("topk", [1, 2])
-@rerun_if_address_is_in_use()
-def test_moe_kernel(rs, hidden_size, data_type, topk):
- spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk)
-
-
-if __name__ == "__main__":
- test_moe_kernel(2, 256, torch.float16, 2)
+def test_moe_kernel(data_type):
+ torch.manual_seed(1024)
+ run_moe_cumsum()
+ run_moe_dispatch_combine_fwd_bwd(data_type=data_type)
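
For reference, the kernel-free path in the test above realizes MoE dispatch and combine as two plain
matmuls against the routing tensors loaded from MOE_TENSOR_PATH. A toy, self-contained illustration
of that matmul equivalence (the shapes and one-hot routing mask below are invented for illustration
only):

    import torch

    B, E, C, H = 4, 4, 2, 8                 # tokens, experts, capacity, hidden (toy sizes)
    tokens = torch.randn(B, H)
    sec_mask = torch.zeros(B, E, C)         # one-hot: token b -> expert (b % E), slot 0
    for b in range(B):
        sec_mask[b, b % E, 0] = 1.0
    # dispatch: gather each token into its (expert, slot) bucket
    dispatched = torch.matmul(sec_mask.permute(1, 2, 0), tokens)       # (E, C, H)
    # combine: weight the (here untouched) expert outputs back per token
    combined = torch.matmul(sec_mask.reshape(B, E * C), dispatched.reshape(E * C, H))  # (B, H)
    assert torch.allclose(combined, tokens)
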
diff --git a/applications/ColossalMoE/tests/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py
similarity index 81%
rename from applications/ColossalMoE/tests/test_mixtral_layer.py
rename to tests/test_moe/test_mixtral_layer.py
index cbb70f195..b7b0322e0 100644
--- a/applications/ColossalMoE/tests/test_mixtral_layer.py
+++ b/tests/test_moe/test_mixtral_layer.py
@@ -3,13 +3,13 @@ from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
-from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
from torch.testing import assert_close
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import colossalai
-from colossalai.moe import MOE_MANAGER
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
@@ -19,8 +19,11 @@ top_k = 2
def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
- MOE_MANAGER.setup(
- parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
+ plugin = MoeHybridParallelPlugin(
+ precision="bf16",
+ tp_size=1,
+ pp_size=1,
+ ep_size=dist.get_world_size(),
)
config = MixtralConfig(
hidden_size=hidden_size,
@@ -33,7 +36,7 @@ def check_mixtral_moe_layer():
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
- model = EPMixtralSparseMoeBlock.from_native_module(model)
+ model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 10e63592a..249dd4b97 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -1,201 +1,176 @@
-import importlib
import os
-import shutil
-import sys
+import tempfile
+from contextlib import nullcontext
+from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
-from transformers.models.llama import LlamaConfig
+from torch.optim import Adam
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
-from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
+from colossalai.checkpoint_io import MoECheckpointIO
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
+from colossalai.testing.utils import spawn
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
- "examples/language/openmoe",
- )
-)
-
-OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
-set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
-OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy
+tokens, n_experts = 7, 4
+hidden_size = 8
+top_k = 2
-def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
- input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_accelerator().get_current_device())
- attention_mask = torch.ones_like(input_ids)
+def check_model_equal(model1, model2):
+ assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
+ for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
+ if not torch.equal(p1.half(), p2.half()):
+ print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
+ raise AssertionError(f"Model parameter {name} is not equal")
+
+
+def get_optimizer_snapshot(optim):
+ state = {id(k): deepcopy(v) for k, v in optim.state.items()}
+ param_groups = []
+ for group in optim.param_groups:
+ params = [id(p) for p in group["params"]]
+ new_group = {"params": params}
+ for k, v in group.items():
+ if k != "params":
+ new_group[k] = v
+ param_groups.append(new_group)
return {
- "input_ids": input_ids,
- "attention_mask": attention_mask,
- "labels": input_ids,
+ "state": state,
+ "param_groups": param_groups,
}
-def run_fwd_bwd(
- model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
-):
- model.train()
- if pipeline:
- train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
- is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
- y = booster.execute_pipeline(
- train_dataloader_iter,
- model,
- lambda x, y: x.loss,
- optimizer,
- return_loss=True,
- )
- # Backward and optimize
- if is_pp_last_stage:
- loss = y["loss"]
- else:
- if criterion:
- y = model(data).logits
- loss = criterion(y)
+def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_group=None):
+ assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
+ for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
+ assert set(group1.keys()) == set(group2.keys())
+ for k in group1.keys():
+ assert group1[k] == group2[k]
+ # check state
+ assert set(snapshot1["state"].keys()) == set(
+ snapshot2["state"].keys()
+ ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
+
+ passed = True
+ count = 0
+ for pid in snapshot1["state"].keys():
+ state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
+ assert set(state1.keys()) == set(state2.keys())
+ bug = False
+ for k in state1.keys():
+ if isinstance(state1[k], torch.Tensor):
+ if not torch.equal(state1[k], state2[k]):
+ bug = True
+ count += 1
+ else:
+ assert state1[k] == state2[k]
+ if bug:
+ passed = False
+
+ if not passed:
+ raise AssertionError(f"A total of {count} optim states are not equal")
+
+
+def check_mixtral_moe_layer():
+ context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
+ with context as f:
+ torch.cuda.set_device(dist.get_rank())
+ if dist.get_rank() == 0:
+ broadcast_objects = [f] # any picklable object
else:
- loss = model(data, label)
- loss = loss.float()
+ broadcast_objects = [None]
+ dist.broadcast_object_list(broadcast_objects, src=0)
- if optimizer is not None:
- optimizer.backward(loss)
- else:
- loss.backward()
- return y
-
-
-def get_config():
- config = LlamaConfig(
- vocab_size=300,
- hidden_size=16,
- intermediate_size=32,
- num_hidden_layers=2,
- num_attention_heads=2,
- head_dim=4,
- dropout_rate=0.0,
- hidden_act="swiglu",
- )
- set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
- return config
-
-
-def get_model(parallel):
- config = get_config()
- model = OpenMoeForCausalLM(config)
- optim = torch.optim.Adam(model.parameters())
-
- if parallel == None:
- plugin = MoeHybridParallelPlugin(
- precision="bf16",
- tp_size=1,
- pp_size=1,
- ep_size=1,
- zero_stage=2,
- custom_policy=OpenMoeForCausalLMPolicy(),
+ config = MixtralConfig(
+ hidden_size=hidden_size,
+ intermediate_size=hidden_size * 2,
+ num_local_experts=n_experts,
+ num_experts_per_tok=top_k,
+ num_attention_heads=2,
+ num_key_value_heads=2,
)
- elif parallel == "ep":
+ torch.manual_seed(0)
+ input_ids = torch.randint(0, 100, (2, tokens)).cuda()
+ orig_model = MixtralForCausalLM(config).cuda()
+ model = deepcopy(orig_model)
+ optimizer = Adam(model.parameters(), lr=1e-3)
plugin = MoeHybridParallelPlugin(
- precision="bf16",
- tp_size=1,
- pp_size=1,
- ep_size=dist.get_world_size(),
- zero_stage=2,
- custom_policy=OpenMoeForCausalLMPolicy(),
- )
- elif parallel == "ep_zero":
- plugin = MoeHybridParallelPlugin(
- precision="bf16",
- tp_size=1,
- pp_size=1,
- ep_size=2,
- zero_stage=2,
- extra_dp_size=2,
- custom_policy=OpenMoeForCausalLMPolicy(),
- )
- elif parallel == "hybrid":
- plugin = MoeHybridParallelPlugin(
- precision="bf16",
- tp_size=1,
pp_size=2,
ep_size=2,
- zero_stage=1,
+ tp_size=1,
+ checkpoint_io=MoECheckpointIO,
microbatch_size=1,
- custom_policy=OpenMoeForCausalLMPolicy(),
+ zero_stage=1,
)
- booster = Booster(plugin=plugin)
- model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
- return model, booster, optim
+ booster = Booster(plugin=plugin)
+ model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
+ # initialize grads
+ data_iter = iter(
+ [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
+ )
+ booster.execute_pipeline(
+ data_iter,
+ model,
+ lambda outputs, inputs: outputs.loss,
+ optimizer,
+ )
+
+ tmpdirname = broadcast_objects[0]
+ model_dir = os.path.join(tmpdirname, "mixtral_model")
+ hf_model_dir = os.path.join(tmpdirname, "mixtral_hf_model")
+ optim_dir = os.path.join(tmpdirname, "mixtral_optim")
+
+ booster.save_model(model, model_dir, shard=True)
+ dist.barrier()
+ if dist.get_rank() == 0:
+ saved_model = MixtralForCausalLM.from_pretrained(model_dir).cuda()
+ check_model_equal(orig_model, saved_model)
+ # check_model_equal(model, saved_model)
+ saved_model.save_pretrained(hf_model_dir)
+ dist.barrier()
+ # check load model
+ new_model = MixtralForCausalLM(config).cuda()
+ new_optimizer = Adam(new_model.parameters(), lr=1e-3)
+ new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
+ booster.load_model(new_model, hf_model_dir)
+ check_model_equal(model, new_model)
+
+ # check save optimizer
+ optimizer.step()
+ for group in optimizer.param_groups:
+ group["lr"] = 0.1
+ snapshot = get_optimizer_snapshot(optimizer.unwrap())
+ booster.save_optimizer(optimizer, optim_dir, shard=True)
+ dist.barrier()
+
+ # reset optimizer state
+ for state in optimizer.unwrap().state.values():
+ for v in state.values():
+ if isinstance(v, torch.Tensor):
+ v.zero_()
+ booster.load_optimizer(optimizer, optim_dir)
+ loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
+ check_optimizer_snapshot_equal(snapshot, loaded_snapshot, None, model)
+ # Ensure rank 0 waits for all other ranks to finish
+ dist.barrier()
-def _test_moe_checkpoint(rank, parallel):
- model1, booster1, optim1 = get_model(parallel)
- model2, booster2, optim2 = get_model(parallel)
- model3, booster3, optim3 = get_model(parallel)
-
- # param ckpt
- # shard
- booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
- booster2.load_model(model2, "./tmp_ckpt1")
- # unshard
- booster1.save_model(model1, "./tmp_ckpt1.pth")
- booster3.load_model(model3, "./tmp_ckpt1.pth")
- # check
- check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
- check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)
-
- # optim ckpt
- criterion = lambda x: x.mean()
- data = torch.randint(0, 4, (2, 4)).cuda()
- label = torch.randint(0, 4, (2,)).cuda()
- if parallel == "hybrid":
- kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
- else:
- kwargs = {}
- run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
- optim1.step()
- optim1.zero_grad()
- # shard
- booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
- dist.barrier()
- booster2.load_optimizer(optim2, "./tmp_ckpt2")
- # unshard
- booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
- booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
- # check
- check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
- check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)
-
- if dist.get_rank() == 0:
- shutil.rmtree("./tmp_ckpt1")
- shutil.rmtree("./tmp_ckpt2")
- os.remove("./tmp_ckpt1.pth")
- os.remove("./tmp_ckpt2.pth")
+def run_dist(rank: int, world_size: int, port: int):
+ colossalai.launch(rank, world_size, "localhost", port)
+ check_mixtral_moe_layer()
-def _run_dist(rank, world_size, port, parallel):
- colossalai.launch(
- config=dict(),
- rank=rank,
- world_size=world_size,
- host="localhost",
- port=port,
- backend="nccl",
- )
- _test_moe_checkpoint(rank, parallel)
-
-
-@pytest.mark.skip(reason="This is tested in ColossalMOE")
-@pytest.mark.dist
+# Test EP + ZeRO + PP
@pytest.mark.parametrize("world_size", [4])
-@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
-@rerun_if_address_is_in_use()
-def test_moe_checkpoint(world_size, parallel):
- spawn(_run_dist, world_size, parallel=parallel)
+def test_mixtral_moe_layer(world_size: int):
+ spawn(run_dist, world_size)
if __name__ == "__main__":
- test_moe_checkpoint(world_size=4, parallel="hybrid")
+ test_mixtral_moe_layer(4)
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index 660fbd358..9bc11033a 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -8,15 +8,16 @@ import torch.distributed as dist
import colossalai
from colossalai.accelerator import get_accelerator
-from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
+
+# from colossalai.shardformer.layer import SparseMLP
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler
-def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+def sync_tp_from_local(tp_model, local_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from local model
Args:
@@ -48,7 +49,7 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_
tp_param.data.copy_(local_param[tuple(tp_slice)].data)
-def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+def sync_tp_from_ep(tp_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
@@ -90,7 +91,7 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag:
tp_param.data.copy_(new_tp_param.data)
-def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
@@ -216,6 +217,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
)
+@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@pytest.mark.parametrize("num_experts", [4, 64])
@pytest.mark.parametrize("batch_size", [16])
diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py
index b7be54d26..89baf1d37 100644
--- a/tests/test_moe/test_moe_group.py
+++ b/tests/test_moe/test_moe_group.py
@@ -4,9 +4,10 @@ import torch.nn as nn
import colossalai
from colossalai.accelerator import get_accelerator
-from colossalai.moe.experts import MLPExperts
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
+
+# from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
HIDDEN_SIZE = 4
@@ -69,6 +70,7 @@ def _run_test(rank, world_size, port, expert_parallel):
run_moe_init(expert_parallel)
+@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@pytest.mark.parametrize("expert_parallel", ["EP", "TP"])
@rerun_if_address_is_in_use()
diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py
index 7932fa8a7..513c4ebda 100644
--- a/tests/test_moe/test_moe_hybrid_zero.py
+++ b/tests/test_moe/test_moe_hybrid_zero.py
@@ -86,6 +86,7 @@ def run_dist(rank, world_size, port):
run_zero_optim_test(rank, world_size, stage=2)
+@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py
index fae189bac..ddd3ea368 100644
--- a/tests/test_moe/test_moe_load_balance.py
+++ b/tests/test_moe/test_moe_load_balance.py
@@ -6,8 +6,9 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
+
+# from colossalai.shardformer.layer.moe import apply_load_balance
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
@@ -176,6 +177,7 @@ def run_dist(rank, world_size, port):
run_hybrid_zero_optim_test(rank, world_size, stage=2)
+@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py
deleted file mode 100644
index 9f6167692..000000000
--- a/tests/test_moe/test_moe_router.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import pytest
-import torch
-
-from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
-
-
-@pytest.mark.parametrize(
- ["router", "num_groups"],
- [
- (Top1Router(), 1),
- (Top2Router(), 1),
- # (TopKRouter(num_selected_experts=3), 4),
- ],
-)
-@pytest.mark.parametrize(
- ["batch_size", "seq_len", "num_experts"],
- [
- (4, 5, 8),
- (3, 4, 4),
- ],
-)
-def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
- x = torch.randn((batch_size * seq_len, num_experts)).cuda()
- if num_groups > 1:
- x = x.expand(num_groups, -1, -1)
-
- router.train()
- if isinstance(router, TopKRouter):
- combine_array, dispatch_mask = router(x, expert_capacity=2)
- else:
- combine_array, dispatch_mask = router(x)[1:3]
- assert combine_array.shape[:-1] == x.shape
- assert dispatch_mask.shape[:-1] == x.shape
- assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
-
- router.eval()
- if isinstance(router, TopKRouter):
- combine_array, dispatch_mask = router(x, expert_capacity=2)
- else:
- combine_array, dispatch_mask = router(x)[1:3]
- assert combine_array.shape[:-1] == x.shape
- assert dispatch_mask.shape[:-1] == x.shape
- assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
-
-
-if __name__ == "__main__":
- test_router_forward(Top2Router(), 4, 4, 4, 1)
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py
deleted file mode 100644
index 3bb08b49e..000000000
--- a/tests/test_moe/test_moe_zero_fwd_bwd.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
-
-
-def run_zero_test(local_rank, stage=1):
- criterion = torch.nn.CrossEntropyLoss()
-
- MOE_MANAGER.__init__()
- MOE_MANAGER.setup(parallel="EP")
- moe_model = MoeModel().bfloat16()
- moe_optimizer = torch.optim.Adam(moe_model.parameters())
- moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
- moe_booster = Booster(plugin=moe_plugin)
- moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
-
- MOE_MANAGER.__init__()
- MOE_MANAGER.setup(parallel=None)
- zero_model = MoeModel().bfloat16()
- delete_moe_info(zero_model)
- zero_optimizer = torch.optim.Adam(zero_model.parameters())
- zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
- zero_booster = Booster(plugin=zero_plugin)
- zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
- sync_local_from_ep(zero_model, moe_model)
-
- data = torch.randn(16, 4).bfloat16().cuda()
- label = torch.randint(0, 4, (16,)).cuda()
-
- zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
- moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
- assert torch.allclose(zero_out, moe_out)
-
- for (moe_name, moe_param), (zero_name, zero_param) in zip(
- moe_model.module.named_parameters(), zero_model.module.named_parameters()
- ):
- assert moe_name == zero_name
- moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
- zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
- if hasattr(moe_param, "moe_info"):
- assert len(moe_grad_list) == 0
- if stage == 1:
- zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
- else:
- zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
- assert torch.allclose(
- moe_param.grad, zero_grad, atol=1e-5
- ), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
- else:
- assert len(moe_grad_list) > 0
- assert len(moe_grad_list) == len(zero_grad_list)
- for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
- assert torch.allclose(moe_grad, zero_grad)
-
-
-def run_dist(rank, world_size, port, stage):
- colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- seed_all(42 + rank)
- run_zero_test(rank, stage=stage)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("stage", [1, 2])
-@rerun_if_address_is_in_use()
-def test_moe_zero_model(world_size, stage):
- spawn(run_dist, world_size, stage=stage)
-
-
-if __name__ == "__main__":
- test_moe_zero_model(world_size=2, stage=1)
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
new file mode 100644
index 000000000..042b3d8ae
--- /dev/null
+++ b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
@@ -0,0 +1,132 @@
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+import colossalai
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from colossalai.zero import LowLevelZeroOptimizer
+from tests.test_moe.moe_utils import loose_close
+
+tokens, n_experts = 7, 4
+hidden_size = 8
+top_k = 2
+
+
+def split_grad(grad, world_size):
+ with torch.no_grad():
+ grad = grad.clone().detach().flatten()
+ padding_size = (world_size - grad.numel() % world_size) % world_size
+ if padding_size > 0:
+ grad = torch.nn.functional.pad(grad, [0, padding_size])
+        split_grads = grad.split(grad.numel() // world_size)
+    return split_grads
+
+
+@parameterize("dtype", [torch.float16, torch.bfloat16])
+@parameterize("master_weights", [True, False])
+@parameterize("stage", [1, 2])
+def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int):
+ rank = torch.distributed.get_rank()
+ torch.cuda.set_device(dist.get_rank())
+ plugin = MoeHybridParallelPlugin(
+ tp_size=1,
+ pp_size=1,
+ ep_size=dist.get_world_size() // 2,
+ )
+
+ seed_all(10086)
+ config = MixtralConfig(
+ hidden_size=hidden_size,
+ intermediate_size=hidden_size * 2,
+ num_local_experts=n_experts,
+ num_experts_per_tok=top_k,
+ )
+
+ orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda()
+
+ ori_model = DDP(orig_model.cuda(), static_graph=True).cuda()
+
+ zero_model = deepcopy(orig_model).to(dtype)
+ zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)
+
+ zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
+ pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
+ for p in zero_model.parameters():
+ if is_moe_tensor(p):
+ pg_param_list[plugin.moe_dp_group].append(p)
+ else:
+ pg_param_list[plugin.global_dp_group].append(p)
+
+ zero_optimizer = LowLevelZeroOptimizer(
+ zero_optimizer,
+ pg_to_param_list=pg_param_list,
+ master_weights=master_weights,
+ initial_scale=1,
+ overlap_communication=False,
+ partition_grad=True,
+ )
+
+ ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)
+
+ # create
+ seed_all(1453 + rank)
+
+ for _ in range(2):
+ # zero-dp forward
+ input_data = torch.rand(1, tokens, hidden_size).cuda()
+ zero_output, zero_logits = zero_model(input_data.to(dtype))
+
+ # torch-ddp forward
+ ori_output, ori_logits = ori_model(input_data.to(dtype))
+ loose_close(zero_output, ori_output, dtype=dtype)
+
+ # zero-dp backward
+ zero_optimizer.backward(zero_output.mean().float())
+
+ # torch-ddp backward
+ ori_output.mean().backward()
+
+ # check grad
+ name_to_p = {n: p for n, p in ori_model.module.named_parameters()}
+ for n, p in zero_model.named_parameters():
+ zero_grad = zero_optimizer.get_param_grad(p)
+ if name_to_p[n].grad is None:
+ assert zero_grad is None
+ continue
+
+ loose_close(zero_grad, name_to_p[n].grad, dtype=dtype)
+
+ # zero-dp step
+ zero_optimizer.step()
+
+ # original model step
+ ori_optimizer.step()
+
+ # check updated param
+ for n, p in zero_model.named_parameters():
+ loose_close(p.data, name_to_p[n].data, dtype=dtype)
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_zero_with_original_model(world_size=world_size)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2, 4])
+@rerun_if_address_is_in_use()
+def test_moe_zero_model(world_size):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_moe_zero_model(world_size=4)
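
The new test above drives `LowLevelZeroOptimizer` through its `pg_to_param_list` argument, routing
expert parameters to the MoE data-parallel group and everything else to the global one. A condensed
sketch of just that wiring (`model` and the two process groups are placeholders; this mirrors the
test rather than adding anything new):

    import torch
    from colossalai.tensor.moe_tensor.api import is_moe_tensor
    from colossalai.zero import LowLevelZeroOptimizer

    def build_moe_zero_optimizer(model, global_dp_group, moe_dp_group):
        pg_to_params = {global_dp_group: [], moe_dp_group: []}
        for p in model.parameters():
            # expert weights shard over the MoE dp group, dense weights over the full dp group
            target = moe_dp_group if is_moe_tensor(p) else global_dp_group
            pg_to_params[target].append(p)
        return LowLevelZeroOptimizer(
            torch.optim.SGD(model.parameters(), lr=1e-3),
            pg_to_param_list=pg_to_params,
            partition_grad=True,
        )
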
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
deleted file mode 100644
index 224c5c3b9..000000000
--- a/tests/test_moe/test_moe_zero_optim.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.tensor.moe_tensor.api import is_moe_tensor
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep
-
-
-def run_zero_test(local_rank, stage=1):
- criterion = torch.nn.CrossEntropyLoss()
-
- MOE_MANAGER.__init__()
- MOE_MANAGER.setup(parallel="EP")
- moe_model = MoeModel().bfloat16()
- moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
- moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
- moe_booster = Booster(plugin=moe_plugin)
- moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
-
- MOE_MANAGER.__init__()
- MOE_MANAGER.setup(parallel=None)
- zero_model = MoeModel().bfloat16()
- delete_moe_info(zero_model)
- sync_local_from_ep(zero_model, moe_model)
- zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
- zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
- zero_booster = Booster(plugin=zero_plugin)
- zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
-
- for (moe_name, moe_param), (zero_name, zero_param) in zip(
- moe_model.named_parameters(), zero_model.named_parameters()
- ):
- if ".experts." in moe_name:
- continue
- assert moe_name == zero_name
- assert torch.allclose(
- moe_param.data, zero_param.data
- ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"
-
- for _ in range(1):
- data = torch.randn(2, 4).bfloat16().cuda()
- label = torch.randint(0, 4, (2,)).cuda()
-
- moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
- zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
- assert torch.allclose(zero_out, moe_out)
- moe_optimizer.step()
- zero_optimizer.step()
-
- for (moe_name, moe_param), (zero_name, zero_param) in zip(
- moe_model.named_parameters(), zero_model.named_parameters()
- ):
- assert moe_name == zero_name
- if is_moe_tensor(moe_param):
- param_size = moe_param.shape[0]
- zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
- loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)
-
- moe_optimizer.zero_grad()
- zero_optimizer.zero_grad()
-
-
-def run_dist(rank, world_size, port, stage):
- colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- seed_all(42 + rank)
- run_zero_test(rank, stage=stage)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("stage", [1, 2])
-@rerun_if_address_is_in_use()
-def test_moe_zero_optim(world_size, stage):
- spawn(run_dist, world_size, stage=stage)
-
-
-if __name__ == "__main__":
- test_moe_zero_optim(world_size=2, stage=1)
diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py
index 313624e83..4046e4118 100644
--- a/tests/test_optimizer/_utils.py
+++ b/tests/test_optimizer/_utils.py
@@ -234,7 +234,7 @@ def check_dist_grad(sharded_optimizer, org_model, sharded_model, weight_layer_fo
if org_name in weight_layer_for_check:
org_grad = org_param.grad
group_id = dist.get_rank(sharded_optimizer.optim.dp_group)
- dist_grad = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(group_id, id(sharded_param))
+ dist_grad = sharded_optimizer.get_partitioned_gradients_by_param_id(group_id, id(sharded_param))
# dist_grad concat then reshape to org_grad shape
if dist_grad:
diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py
index 06c254e56..2da679d7d 100644
--- a/tests/test_optimizer/test_dist_adafactor.py
+++ b/tests/test_optimizer/test_dist_adafactor.py
@@ -316,7 +316,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
dp_process_group=dp_group,
verbose=True,
)
- shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened
+ shard_to_param = dist_optim.master_to_working_param # {id(): param tensor} but flattened
dist_optim.optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
diff --git a/tests/test_optimizer/test_dist_came.py b/tests/test_optimizer/test_dist_came.py
index c767e9684..45fe687b7 100644
--- a/tests/test_optimizer/test_dist_came.py
+++ b/tests/test_optimizer/test_dist_came.py
@@ -200,7 +200,7 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
dp_process_group=dp_group,
verbose=True,
)
- shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened
+ shard_to_param = dist_optim.master_to_working_param # {id(): param tensor} but flattened
dist_optim.optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
diff --git a/tests/test_optimizer/test_dist_lamb.py b/tests/test_optimizer/test_dist_lamb.py
index c1ff78c0c..66e8e49c7 100644
--- a/tests/test_optimizer/test_dist_lamb.py
+++ b/tests/test_optimizer/test_dist_lamb.py
@@ -229,7 +229,7 @@ def run_dist_lamb_fwd_bwd(
dp_process_group=dp_group,
verbose=True,
)
- shard_to_param = optim._param_store.master_to_working_param
+ shard_to_param = optim.master_to_working_param
optim.optim.setup_distributed(tp_group, dp_group, shard_to_param, is_zero=True)
else:
optim.setup_distributed(tp_group)
diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
index be257e818..e37a050e3 100644
--- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
+++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
@@ -32,6 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
+ dp_group = booster.plugin.dp_group
bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
@@ -53,8 +54,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
device = origin_norm.device
norm_groups = []
for group_id in range(sharded_optimizer.num_param_groups):
- working_grads = sharded_optimizer._grad_store.get_working_grads_by_group_id(group_id)
- norm_group = sharded_optimizer._compute_grad_norm(gradients=working_grads)
+ working_grads = sharded_optimizer.get_working_grads_by_group_id(group_id)
+ norm_group = sharded_optimizer._compute_grad_norm(dp_group, gradients=working_grads)
norm_groups.append(norm_group)
total_norm = 0.0
for norm in norm_groups:
diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index b73552cec..4d66692a4 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -62,10 +62,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
- working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
- grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
+ working_p = sharded_optimizer.master_to_working_param[id(p2)]
+ grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = (
- 0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
+ 0 if sharded_optimizer._partition_grads else sharded_optimizer.pid_to_bucket_store[id(p2)].local_rank
)
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 3a8a1357d..8fe18f69b 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -62,10 +62,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
- working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
- grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
+ working_p = sharded_optimizer.master_to_working_param[id(p2)]
+ grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = (
- 0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
+ 0
+ if sharded_optimizer._partition_grads
+ else sharded_optimizer.pid_to_bucket_store[id(working_p)].local_rank
)
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
diff --git a/tests/test_zero/test_low_level/test_mem_leak.py b/tests/test_zero/test_low_level/test_mem_leak.py
new file mode 100644
index 000000000..7fa59ccc5
--- /dev/null
+++ b/tests/test_zero/test_low_level/test_mem_leak.py
@@ -0,0 +1,61 @@
+import pytest
+import torch
+import torch.nn as nn
+
+import colossalai
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.zero import LowLevelZeroOptimizer
+
+
+class MlpModel(nn.Module):
+ def __init__(self):
+ super(MlpModel, self).__init__()
+ self.linear1 = nn.Linear(123, 253)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ return x
+
+
+DEL_CALLED = False
+
+
+class TestLowLevelZeroOptimizer(LowLevelZeroOptimizer):
+ def __del__(self):
+ super().__del__()
+ global DEL_CALLED
+ DEL_CALLED = True
+
+
+def exam_mem_leak(world_size):
+ """
+    Check that __del__ is called once the optimizer goes out of scope,
+    i.e. that no lingering reference keeps it (and its buffers) alive.
+ """
+ # create models
+ zero_model = MlpModel().cuda()
+
+    # we only test ZeRO stage 1 here; stage 2 shares the same __del__ logic,
+    # and stage-1 vs stage-2 numerical consistency is covered separately
+    # in `test_zero1_2.py`
+ zero_optimizer = TestLowLevelZeroOptimizer(torch.optim.SGD(zero_model.parameters(), lr=1))
+
+ del zero_optimizer
+
+ assert DEL_CALLED
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+
+ exam_mem_leak(world_size=world_size)
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_mem_leak():
+ spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_mem_leak()
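Context for the new test above: it relies on __del__ firing as soon as the last
reference to the optimizer is dropped. A standalone illustration of that property,
independent of ColossalAI (CPython reference counting assumed; the class name is
illustrative):

    class Holder:
        deleted = False

        def __del__(self):
            Holder.deleted = True

    h = Holder()
    del h  # CPython drops the last reference immediately, so __del__ runs here
    assert Holder.deleted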
diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py
index 06a29bd1d..8df35bdaa 100644
--- a/tests/test_zero/test_low_level/test_zero1_2.py
+++ b/tests/test_zero/test_low_level/test_zero1_2.py
@@ -91,10 +91,13 @@ def exam_zero_1_2():
zero2_optimizer.backward(zero2_output.mean().float())
# check grad
- z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
- z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
- for z1g, z2g in zip(z1g_list, z2g_list):
- assert torch.equal(z1g, z2g)
+ for p1, p2 in zip(zero1_model.parameters(), zero2_model.parameters()):
+ g1 = zero1_optimizer.get_param_grad(p1)
+ g2 = zero2_optimizer.get_param_grad(p2)
+ if g1 is None or g2 is None:
+ assert g1 is None and g2 is None
+ continue
+ assert torch.allclose(g1, g2)
# step
zero1_optimizer.step()
@@ -102,7 +105,7 @@ def exam_zero_1_2():
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
- assert torch.equal(z1p.data, z2p.data)
+ assert torch.allclose(z1p, z2p)
@parameterize("dtype", [torch.float16, torch.bfloat16])
@@ -120,7 +123,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
seed_all(1453)
# create models
- torch_model = MlpModel().cuda()
+ torch_model = MlpModel().cuda().to(dtype)
zero_model = copy.deepcopy(torch_model).to(dtype)
torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
@@ -142,39 +145,41 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
seed_all(1453 + local_rank)
- # create
- input_data = torch.rand(32, 123).cuda()
- # zero-dp forward
- zero_output = zero_model(input_data.to(dtype))
+ for _ in range(2):
+ # create
+ input_data = torch.rand(32, 123).cuda().to(dtype)
- # torch-ddp forward
- torch_output = torch_model(input_data)
- loose_close(zero_output, torch_output, dtype=dtype)
+ # zero-dp forward
+ zero_output = zero_model(input_data)
- # zero-dp backward
- zero_optimizer.backward(zero_output.mean().float())
+ # torch-ddp forward
+ torch_output = torch_model(input_data)
+ loose_close(zero_output, torch_output, dtype=dtype)
- # torch-ddp backward
- torch_output.mean().backward()
+ # zero-dp backward
+ zero_optimizer.backward(zero_output.mean())
- # check grad
- for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
- if p.grad is not None:
- zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
- torch_grad_list = split_ddp_grad(p.grad, world_size)
- for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
- loose_close(zero_grad, torch_grad, dtype=dtype)
+ # torch-ddp backward
+ torch_output.mean().backward()
- # zero-dp step
- zero_optimizer.step()
+ # check grad
+ for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
+ zero_grad = zero_optimizer.get_param_grad(z1p)
+ if p.grad is None:
+ assert zero_grad is None
+ continue
+ loose_close(p.grad, zero_grad, dtype=dtype)
- # torch ddp step
- torch_optimizer.step()
+ # zero-dp step
+ zero_optimizer.step()
- # check updated param
- for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
- loose_close(p.data, z1p.data, dtype=dtype)
+ # torch ddp step
+ torch_optimizer.step()
+
+ # check updated param
+ for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
+ loose_close(p, z1p, dtype=dtype)
def run_dist(rank, world_size, port):
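The hunks above replace flattened-shard lookups through `_grad_store` with
per-parameter checks via `get_param_grad`, tolerating missing gradients and small
numerical differences. A standalone sketch of the same comparison pattern in plain
PyTorch, with no ZeRO optimizer or distributed launch (model and shapes are arbitrary):

    import copy

    import torch
    import torch.nn as nn

    # two identical models see the same input, so their gradients should agree
    model_a = nn.Linear(4, 4)
    model_b = copy.deepcopy(model_a)
    x = torch.randn(2, 4)
    model_a(x).mean().backward()
    model_b(x).mean().backward()

    for p1, p2 in zip(model_a.parameters(), model_b.parameters()):
        g1, g2 = p1.grad, p2.grad
        if g1 is None or g2 is None:
            # a parameter missing its gradient on one side must miss it on both
            assert g1 is None and g2 is None
            continue
        # allclose instead of bit-exact equality tolerates tiny numerical drift
        assert torch.allclose(g1, g2)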
From 8ab46b4000d36c76cde93c6bb553411e815980fb Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Mon, 1 Jul 2024 16:45:09 +0800
Subject: [PATCH 03/15] [Shardformer] switch Qwen2 modeling to configurable
 gradient checkpointing (#5874)
---
colossalai/shardformer/modeling/qwen2.py | 17 +++++++++++++++--
1 file changed, 15 insertions(+), 2 deletions(-)
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index e0aa5fba4..11c26822f 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -168,13 +168,27 @@ class Qwen2PipelineForwards:
next_decoder_cache = None
start_idx, end_idx = stage_index[0], stage_index[1]
+ num_ckpt_layers = 0
+ if self.gradient_checkpointing and self.training:
+ num_ckpt_layers = end_idx - start_idx
+ # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
+ if shard_config.gradient_checkpoint_config is not None:
+ num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
+ stage=stage_manager.stage,
+ num_stages=stage_manager.num_stages,
+ num_layers=end_idx - start_idx,
+ model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
+ num_model_chunks=stage_manager.num_model_chunks,
+ )
+ assert num_ckpt_layers <= end_idx - start_idx
+
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
- if self.gradient_checkpointing and self.training:
+ if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
@@ -198,7 +212,6 @@ class Qwen2PipelineForwards:
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
if output_attentions:
all_self_attns += (layer_outputs[1],)
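The hunk above decides, per layer, whether to run it under activation checkpointing,
based on how many layers of the current pipeline stage should be recomputed. A toy
version of the same control flow in plain PyTorch (layer type, sizes, and counts are
arbitrary, not the Qwen2 values):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
    start_idx, end_idx = 0, 4   # layer range owned by this pipeline stage
    num_ckpt_layers = 2         # only the first two layers are checkpointed

    x = torch.randn(2, 8, requires_grad=True)
    for idx, layer in enumerate(layers[start_idx:end_idx], start=start_idx):
        if idx - start_idx < num_ckpt_layers:
            x = checkpoint(layer, x, use_reentrant=False)  # recomputed in backward
        else:
            x = layer(x)
    x.sum().backward()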
From 936d0b0f7ba7f9b4e0d53c343bcf6afd10c63de1 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Mon, 1 Jul 2024 17:07:22 +0800
Subject: [PATCH 04/15] [doc] Update llama + sp compatibility; fix dist optim
table
Co-authored-by: Edenzzzz
---
.../en/features/distributed_optimizers.md | 52 +++++++++---------
docs/source/en/features/shardformer.md | 2 +-
.../features/distributed_optimizers.md | 53 +++++++++----------
docs/source/zh-Hans/features/shardformer.md | 2 +-
4 files changed, 53 insertions(+), 56 deletions(-)
diff --git a/docs/source/en/features/distributed_optimizers.md b/docs/source/en/features/distributed_optimizers.md
index f95b23304..279bc8f9d 100644
--- a/docs/source/en/features/distributed_optimizers.md
+++ b/docs/source/en/features/distributed_optimizers.md
@@ -87,44 +87,42 @@ optim = DistGaloreAwamW(
## Plugin compatibility
- Model/Feature |
- Lamb |
- GaLore |
- Adafactor |
- CAME |
+ Optimizer/Plugin |
+ Hybrid Parallel Plugin |
+ Low Level Zero Plugin |
+ Torch DDP Plugin |
+ Gemini Plugin |
+ Moe Hybrid Plugin |
- Hybrid Parallel Plugin |
+ Lamb |
✔️ |
✔️ |
✔️ |
- ✔️ |
-
-
- Low Level Zero Plugin |
- ✔️ |
- ❌ |
- ✔️ |
- ✔️ |
-
-
- Torch DDP Plugin |
- ✔️ |
- ✔️ |
- ✔️ |
- ✔️ |
-
-
- Gemini Plugin |
- ❌ |
- ❌ |
❌ |
❌ |
- Moe Hybrid Plugin |
+ GaLore |
+ ✔️ |
+ ✔️ |
+ ✔️ |
❌ |
❌ |
+
+
+ Adafactor |
+ ✔️ |
+ ✔️ |
+ ✔️ |
+ ❌ |
+ ❌ |
+
+
+ CAME |
+ ✔️ |
+ ✔️ |
+ ✔️ |
❌ |
❌ |
diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md
index 68d310f5c..40b8954b5 100644
--- a/docs/source/en/features/shardformer.md
+++ b/docs/source/en/features/shardformer.md
@@ -55,7 +55,7 @@ Model/Feature Compatibility Matrix:
✔️ |
✔️ |
✔️ |
- ❌ |
+ ✔️ |
❌ |
diff --git a/docs/source/zh-Hans/features/distributed_optimizers.md b/docs/source/zh-Hans/features/distributed_optimizers.md
index 7a7068077..5761f8c55 100644
--- a/docs/source/zh-Hans/features/distributed_optimizers.md
+++ b/docs/source/zh-Hans/features/distributed_optimizers.md
@@ -84,44 +84,42 @@ optim = DistGaloreAwamW(
## 兼容性
- Model/Feature |
- Lamb |
- GaLore |
- Adafactor |
- CAME |
+ Optimizer/Plugin |
+ Hybrid Parallel Plugin |
+ Low Level Zero Plugin |
+ Torch DDP Plugin |
+ Gemini Plugin |
+ Moe Hybrid Plugin |
- Hybrid Parallel Plugin |
+ Lamb |
✔️ |
✔️ |
✔️ |
- ✔️ |
-
-
- Low Level Zero Plugin |
- ✔️ |
- ❌ |
- ✔️ |
- ✔️ |
-
-
- Torch DDP Plugin |
- ✔️ |
- ✔️ |
- ✔️ |
- ✔️ |
-
-
- Gemini Plugin |
- ❌ |
- ❌ |
❌ |
❌ |
- Moe Hybrid Plugin |
+ GaLore |
+ ✔️ |
+ ✔️ |
+ ✔️ |
❌ |
❌ |
+
+
+ Adafactor |
+ ✔️ |
+ ✔️ |
+ ✔️ |
+ ❌ |
+ ❌ |
+
+
+ CAME |
+ ✔️ |
+ ✔️ |
+ ✔️ |
❌ |
❌ |
@@ -130,6 +128,7 @@ optim = DistGaloreAwamW(
+
## API 参考
diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md
index 00e1a13d6..02290f3d6 100644
--- a/docs/source/zh-Hans/features/shardformer.md
+++ b/docs/source/zh-Hans/features/shardformer.md
@@ -51,7 +51,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
✔️ |
✔️ |
✔️ |
- ❌ |
+ ✔️ |
❌ |
From 7c2f79fa98c837ee4f5995d7948371040fa94572 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 1 Jul 2024 17:16:41 +0800
Subject: [PATCH 05/15] [pre-commit.ci] pre-commit autoupdate (#5572)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* [pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/PyCQA/autoflake: v2.2.1 → v2.3.1](https://github.com/PyCQA/autoflake/compare/v2.2.1...v2.3.1)
- [github.com/pycqa/isort: 5.12.0 → 5.13.2](https://github.com/pycqa/isort/compare/5.12.0...5.13.2)
- [github.com/psf/black-pre-commit-mirror: 23.9.1 → 24.4.2](https://github.com/psf/black-pre-commit-mirror/compare/23.9.1...24.4.2)
- [github.com/pre-commit/mirrors-clang-format: v13.0.1 → v18.1.7](https://github.com/pre-commit/mirrors-clang-format/compare/v13.0.1...v18.1.7)
- [github.com/pre-commit/pre-commit-hooks: v4.3.0 → v4.6.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.3.0...v4.6.0)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.pre-commit-config.yaml | 10 +++---
.../ColossalChat/coati/dataset/loader.py | 16 +++++----
.../ColossalChat/coati/models/loss.py | 1 +
.../ColossalChat/coati/models/reward_model.py | 1 +
.../ColossalChat/coati/trainer/utils.py | 1 +
.../colossal_eval/dataset/agieval.py | 14 ++++++--
.../colossal_eval/dataset/ceval.py | 6 +++-
.../colossal_eval/dataset/mtbench.py | 8 +++--
.../colossal_eval/models/huggingface.py | 4 ++-
.../colossalqa/chain/retrieval_qa/base.py | 1 +
.../chain/retrieval_qa/load_chain.py | 1 +
.../colossalqa/chain/retrieval_qa/stuff.py | 1 +
.../data_loader/table_dataloader.py | 1 -
.../ColossalQA/colossalqa/local/llm.py | 1 +
.../ColossalQA/colossalqa/local/utils.py | 1 +
applications/ColossalQA/colossalqa/memory.py | 1 +
.../ColossalQA/colossalqa/mylogging.py | 1 +
.../colossalqa/retrieval_conversation_en.py | 1 +
.../retrieval_conversation_universal.py | 1 +
.../colossalqa/retrieval_conversation_zh.py | 1 +
.../ColossalQA/colossalqa/retriever.py | 1 +
.../text_splitter/chinese_text_splitter.py | 1 +
.../examples/retrieval_conversation_en.py | 1 +
...rieval_conversation_en_customer_service.py | 1 +
.../examples/retrieval_conversation_zh.py | 1 +
...tent_classification_zh_customer_service.py | 1 +
.../meta_profiler/meta_registry/conv.py | 20 ++++++-----
colossalai/inference/batch_bucket.py | 12 +++----
colossalai/inference/config.py | 17 +++++----
colossalai/inference/core/engine.py | 1 -
colossalai/inference/core/rpc_engine.py | 1 -
colossalai/inference/executor/rpc_worker.py | 1 -
.../inference/kv_cache/kvcache_manager.py | 8 +++--
colossalai/inference/utils.py | 1 +
.../initializer_2d.py | 4 +--
colossalai/legacy/inference/async_engine.py | 1 -
.../inference/dynamic_batching/io_struct.py | 12 +++----
.../inference/hybridengine/modeling/_utils.py | 1 +
.../tensor_parallel/batch_infer_state.py | 1 +
.../tensor_parallel/kvcache_manager.py | 1 +
.../tensor_parallel/modeling/_utils.py | 1 +
.../modeling/chatglm2_6b/modeling_chatglm.py | 1 +
.../nvidia_bert_dataset_provider.py | 8 +++--
.../diffusion/ldm/models/diffusion/ddpm.py | 20 ++++++-----
.../models/diffusion/dpm_solver/sampler.py | 1 +
.../modules/diffusionmodules/openaimodel.py | 36 ++++++++++---------
.../ldm/modules/midas/midas/midas_net.py | 1 +
.../modules/midas/midas/midas_net_custom.py | 1 +
.../diffusion/ldm/modules/midas/utils.py | 1 +
.../data/datasets/helpers.cpp | 12 +++----
extensions/csrc/kernel/arm/cpu_adam_arm.h | 2 +-
extensions/csrc/kernel/x86/cpu_adam.h | 2 +-
.../kit/model_zoo/torchvision/torchvision.py | 12 +++----
53 files changed, 157 insertions(+), 100 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9871e1184..f2c408bce 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,34 +1,34 @@
repos:
- repo: https://github.com/PyCQA/autoflake
- rev: v2.2.1
+ rev: v2.3.1
hooks:
- id: autoflake
name: autoflake (python)
args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
- repo: https://github.com/pycqa/isort
- rev: 5.12.0
+ rev: 5.13.2
hooks:
- id: isort
name: sort all imports (python)
- repo: https://github.com/psf/black-pre-commit-mirror
- rev: 23.9.1
+ rev: 24.4.2
hooks:
- id: black
name: black formatter
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format
- rev: v13.0.1
+ rev: v18.1.7
hooks:
- id: clang-format
name: clang formatter
types_or: [c++, c]
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.3.0
+ rev: v4.6.0
hooks:
- id: check-yaml
- id: check-merge-conflict
diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index cea1b2dbb..a0cd17bb4 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -83,15 +83,19 @@ class DataCollatorForSupervisedDataset(object):
# `List[torch.Tensor]`
batch_input_ids = [
- torch.LongTensor(instance["input_ids"][: self.max_length])
- if len(instance["input_ids"]) > self.max_length
- else torch.LongTensor(instance["input_ids"])
+ (
+ torch.LongTensor(instance["input_ids"][: self.max_length])
+ if len(instance["input_ids"]) > self.max_length
+ else torch.LongTensor(instance["input_ids"])
+ )
for instance in instances
]
batch_labels = [
- torch.LongTensor(instance["labels"][: self.max_length])
- if len(instance["labels"]) > self.max_length
- else torch.LongTensor(instance["labels"])
+ (
+ torch.LongTensor(instance["labels"][: self.max_length])
+ if len(instance["labels"]) > self.max_length
+ else torch.LongTensor(instance["labels"])
+ )
for instance in instances
]
if self.tokenizer.padding_side == "right":
diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py
index aaef447a4..e411dded1 100755
--- a/applications/ColossalChat/coati/models/loss.py
+++ b/applications/ColossalChat/coati/models/loss.py
@@ -1,6 +1,7 @@
"""
loss functions
"""
+
from typing import Optional, Tuple
import torch
diff --git a/applications/ColossalChat/coati/models/reward_model.py b/applications/ColossalChat/coati/models/reward_model.py
index 394f3ea90..573b9d889 100755
--- a/applications/ColossalChat/coati/models/reward_model.py
+++ b/applications/ColossalChat/coati/models/reward_model.py
@@ -1,6 +1,7 @@
"""
reward model
"""
+
from typing import Optional
import torch
diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py
index 5ce1e9ef0..3c836b4b4 100755
--- a/applications/ColossalChat/coati/trainer/utils.py
+++ b/applications/ColossalChat/coati/trainer/utils.py
@@ -1,6 +1,7 @@
"""
Training utilities for Coati.
"""
+
from typing import Any
import torch
diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py
index 32f8544e9..d5f230249 100644
--- a/applications/ColossalEval/colossal_eval/dataset/agieval.py
+++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py
@@ -78,7 +78,9 @@ def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict
option_string = "ABCDEFG"
count = len(line["options"])
- input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
+ input = (
+ "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
+ )
all_classes = list(option_string[0:count])
@@ -150,7 +152,15 @@ def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=F
)
elif dataset_name in chinese_qa_datasets:
question_input = (
- "问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label)
+ "问题:"
+ + passage
+ + " "
+ + question
+ + "\n"
+ + "从以下选项中选择:"
+ + " ".join(options)
+ + "\n"
+ + "答案:{}".format(label)
)
elif dataset_name in english_cloze_datasets:
question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer)
diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py
index 2cf09ec4d..915f4d9b0 100644
--- a/applications/ColossalEval/colossal_eval/dataset/ceval.py
+++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py
@@ -57,7 +57,11 @@ ceval_subject_mapping = {
"urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
"accountant": ["Accountant", "注册会计师", "Other"],
"fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
- "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"],
+ "environmental_impact_assessment_engineer": [
+ "Environmental Impact Assessment Engineer",
+ "环境影响评价工程师",
+ "Other",
+ ],
"tax_accountant": ["Tax Accountant", "税务师", "Other"],
"physician": ["Physician", "医师资格", "Other"],
}
diff --git a/applications/ColossalEval/colossal_eval/dataset/mtbench.py b/applications/ColossalEval/colossal_eval/dataset/mtbench.py
index 9e74a4d82..031415567 100644
--- a/applications/ColossalEval/colossal_eval/dataset/mtbench.py
+++ b/applications/ColossalEval/colossal_eval/dataset/mtbench.py
@@ -56,9 +56,11 @@ class MTBenchDataset(BaseDataset):
"instruction": question["turns"],
"input": "",
"output": [],
- "target": [""] * turn_number
- if question["question_id"] not in reference
- else reference[question["question_id"]],
+ "target": (
+ [""] * turn_number
+ if question["question_id"] not in reference
+ else reference[question["question_id"]]
+ ),
}
if category in dataset["test"]:
diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py
index fff697e21..23c399cce 100644
--- a/applications/ColossalEval/colossal_eval/models/huggingface.py
+++ b/applications/ColossalEval/colossal_eval/models/huggingface.py
@@ -77,7 +77,9 @@ class HuggingFaceModel(BaseModel):
self.indices_for_choices[0].append(
self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
)
- self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1])
+ self.indices_for_choices[1].append(
+ self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]
+ )
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
"""
diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py
index 80dbf47de..2f9750de3 100644
--- a/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py
+++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py
@@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
+
from __future__ import annotations
import copy
diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py
index a2b1f81e3..8cb8ef536 100644
--- a/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py
+++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py
@@ -8,6 +8,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
+
import copy
from typing import Any, Mapping, Optional, Protocol
diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py
index bf7ad0ffc..64e476438 100644
--- a/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py
+++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py
@@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
+
import copy
from typing import Any, List
diff --git a/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py b/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py
index 29542466f..0ad66f0ad 100644
--- a/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py
+++ b/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py
@@ -2,7 +2,6 @@
Class for loading table type data. please refer to Pandas-Input/Output for file format details.
"""
-
import glob
import os
diff --git a/applications/ColossalQA/colossalqa/local/llm.py b/applications/ColossalQA/colossalqa/local/llm.py
index 30a456c3d..58a4811d9 100644
--- a/applications/ColossalQA/colossalqa/local/llm.py
+++ b/applications/ColossalQA/colossalqa/local/llm.py
@@ -12,6 +12,7 @@ TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料
logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)
"""
+
from typing import Any, List, Mapping, Optional
import torch
diff --git a/applications/ColossalQA/colossalqa/local/utils.py b/applications/ColossalQA/colossalqa/local/utils.py
index ed90264ca..2cbd474bd 100644
--- a/applications/ColossalQA/colossalqa/local/utils.py
+++ b/applications/ColossalQA/colossalqa/local/utils.py
@@ -1,6 +1,7 @@
"""
Generation utilities
"""
+
import json
from typing import List
diff --git a/applications/ColossalQA/colossalqa/memory.py b/applications/ColossalQA/colossalqa/memory.py
index 7a5512281..d8de544a5 100644
--- a/applications/ColossalQA/colossalqa/memory.py
+++ b/applications/ColossalQA/colossalqa/memory.py
@@ -2,6 +2,7 @@
Implement a memory class for storing conversation history
Support long term and short term memory
"""
+
from typing import Any, Dict, List
from colossalqa.chain.memory.summary import ConversationSummaryMemory
diff --git a/applications/ColossalQA/colossalqa/mylogging.py b/applications/ColossalQA/colossalqa/mylogging.py
index 574c33b41..67e2a83ed 100644
--- a/applications/ColossalQA/colossalqa/mylogging.py
+++ b/applications/ColossalQA/colossalqa/mylogging.py
@@ -1,6 +1,7 @@
"""
Class for logging with extra control for debugging
"""
+
import logging
diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_en.py b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py
index 96bce82b9..cab168075 100644
--- a/applications/ColossalQA/colossalqa/retrieval_conversation_en.py
+++ b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py
@@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
+
from typing import Tuple
from colossalqa.chain.retrieval_qa.base import RetrievalQA
diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py
index 6e77bb2ae..a991b202e 100644
--- a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py
+++ b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py
@@ -1,6 +1,7 @@
"""
Multilingual retrieval based conversation system
"""
+
from typing import List
from colossalqa.data_loader.document_loader import DocumentLoader
diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py
index 4eef41947..6c9b69117 100644
--- a/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py
+++ b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py
@@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
+
from typing import Tuple
from colossalqa.chain.retrieval_qa.base import RetrievalQA
diff --git a/applications/ColossalQA/colossalqa/retriever.py b/applications/ColossalQA/colossalqa/retriever.py
index 6a0c69859..ec4941ddd 100644
--- a/applications/ColossalQA/colossalqa/retriever.py
+++ b/applications/ColossalQA/colossalqa/retriever.py
@@ -1,6 +1,7 @@
"""
Code for custom retriver with incremental update
"""
+
import copy
import hashlib
import os
diff --git a/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py b/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py
index 3815f5ed2..697af484b 100644
--- a/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py
+++ b/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py
@@ -1,6 +1,7 @@
"""
Code for Chinese text splitter
"""
+
from typing import Any, List, Optional
from colossalqa.text_splitter.utils import get_cleaned_paragraph
diff --git a/applications/ColossalQA/examples/retrieval_conversation_en.py b/applications/ColossalQA/examples/retrieval_conversation_en.py
index fe2b9b4db..b7339de93 100644
--- a/applications/ColossalQA/examples/retrieval_conversation_en.py
+++ b/applications/ColossalQA/examples/retrieval_conversation_en.py
@@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
+
import argparse
import os
diff --git a/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py
index d4ba73b94..a0c90e7c5 100644
--- a/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py
+++ b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py
@@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
+
import argparse
import json
import os
diff --git a/applications/ColossalQA/examples/retrieval_conversation_zh.py b/applications/ColossalQA/examples/retrieval_conversation_zh.py
index b143b9baa..96641edf5 100644
--- a/applications/ColossalQA/examples/retrieval_conversation_zh.py
+++ b/applications/ColossalQA/examples/retrieval_conversation_zh.py
@@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
+
import argparse
import os
diff --git a/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py b/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py
index adb654494..865ade5bb 100644
--- a/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py
+++ b/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py
@@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
+
import argparse
import os
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
index 2f630995c..b1e32e885 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
@@ -107,20 +107,22 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
- if has_bias
- else compute_size_in_bytes(weight_tensor),
+ parameter=(
+ compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
+ ),
temp=0,
buffer=0,
)
bwd_memory_cost = MemoryCost(
- activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
- if has_bias
- else compute_size_in_bytes([input_tensor, weight_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
- if has_bias
- else compute_size_in_bytes(weight_tensor),
+ activation=(
+ compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
+ if has_bias
+ else compute_size_in_bytes([input_tensor, weight_tensor])
+ ),
+ parameter=(
+ compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
+ ),
temp=0,
buffer=0,
)
diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py
index 88bde3a3b..581d114d2 100644
--- a/colossalai/inference/batch_bucket.py
+++ b/colossalai/inference/batch_bucket.py
@@ -247,16 +247,16 @@ class BatchBucket:
self._sequences_dict[seq.request_id] = seq
self._sequences_indexes[seq.request_id] = self._current_batch_size + i
# TODO external (rename): modify Sequence.sentence_len to seq_len
- self._sequence_lengths[
- self._current_batch_size : self._current_batch_size + num_seqs_to_add
- ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
+ self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
+ torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
+ )
# NOTE block tables to be updated by kvcache manager
block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
if alloc_block_tables is not None:
# copy block ids from provided block tables
- self._block_tables[
- self._current_batch_size : self._current_batch_size + num_seqs_to_add
- ] = alloc_block_tables
+ self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
+ alloc_block_tables
+ )
elif alloc_block_tables_fn:
alloc_block_tables_fn(
block_tables,
diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index c73ee9df4..e114e8a61 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -1,6 +1,7 @@
"""
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
"""
+
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
@@ -82,9 +83,9 @@ class InputMetaData(RPC_PARAM):
dtype: torch.dtype = torch.float32
use_spec_dec: bool = False
num_tokens_to_verify: int = 0
- batch_token_ids: Optional[
- List[List[int]]
- ] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
+ batch_token_ids: Optional[List[List[int]]] = (
+ None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
+ )
def to_rpc_param(self) -> Dict[str, any]:
return {
@@ -202,9 +203,9 @@ class InferenceConfig(RPC_PARAM):
prompt_template: Optional[str] = None
do_sample: bool = False
beam_width: int = 1 # TODO: beam search is not support for now
- prefill_ratio: Optional[
- float
- ] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
+ prefill_ratio: Optional[float] = (
+ 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
+ )
pad_input: bool = False
early_stopping: Optional[bool] = False
top_k: Optional[int] = 50
@@ -234,7 +235,9 @@ class InferenceConfig(RPC_PARAM):
high_precision: Optional[bool] = False
# cuda_graph
- use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
+ use_cuda_graph: bool = (
+ False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
+ )
max_context_len_to_capture: int = 512
# StreamingLLM (sliding window attention with attention sinks)
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index f0918c88c..8f8aef65e 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -47,7 +47,6 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class InferenceEngine:
-
"""
InferenceEngine which manages the inference process..
diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py
index 87222a744..749360872 100644
--- a/colossalai/inference/core/rpc_engine.py
+++ b/colossalai/inference/core/rpc_engine.py
@@ -34,7 +34,6 @@ def run_server(host, port, event: mp.Event = None):
class RPCInferenceEngine(InferenceEngine):
-
"""
InferenceEngine which manages the inference process..
diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py
index a5199cb74..a4fd20a69 100644
--- a/colossalai/inference/executor/rpc_worker.py
+++ b/colossalai/inference/executor/rpc_worker.py
@@ -42,7 +42,6 @@ logger = get_dist_logger(__name__)
class rpcWorkerService(rpyc.Service):
-
"""
Execute the computation tasks and manage its own kv cache
diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py
index 378eb2ff9..dac5037f4 100644
--- a/colossalai/inference/kv_cache/kvcache_manager.py
+++ b/colossalai/inference/kv_cache/kvcache_manager.py
@@ -279,9 +279,11 @@ class KVCacheManager:
block.add_ref()
self._allocate_on_block(
block,
- block.block_size
- if context_lengths[i] % block.block_size == 0
- else context_lengths[i].item() % block.block_size,
+ (
+ block.block_size
+ if context_lengths[i] % block.block_size == 0
+ else context_lengths[i].item() % block.block_size
+ ),
)
for block_id in alloc_block_ids:
if block_id in alloc_block_ids[last_block_locs]:
diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py
index 8c155e6ca..332e84d37 100644
--- a/colossalai/inference/utils.py
+++ b/colossalai/inference/utils.py
@@ -1,6 +1,7 @@
"""
Utils for model inference
"""
+
import math
import os
import re
diff --git a/colossalai/legacy/context/process_group_initializer/initializer_2d.py b/colossalai/legacy/context/process_group_initializer/initializer_2d.py
index 1c08d4d42..fc51844b6 100644
--- a/colossalai/legacy/context/process_group_initializer/initializer_2d.py
+++ b/colossalai/legacy/context/process_group_initializer/initializer_2d.py
@@ -138,9 +138,7 @@ class Initializer_2D(ProcessGroupInitializer):
self.num_group = self.world_size // self.tensor_parallel_size
self.summa_dim = int(math.sqrt(self.tensor_parallel_size))
- assert (
- self.tensor_parallel_size == self.summa_dim**2
- ), "2D summa dim should equal to tensor parallel size ^ 0.5"
+ assert self.tensor_parallel_size == self.summa_dim**2, "2D summa dim should equal to tensor parallel size ^ 0.5"
_check_summa_env_var(self.summa_dim)
self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)
diff --git a/colossalai/legacy/inference/async_engine.py b/colossalai/legacy/inference/async_engine.py
index d0890ba3e..b4c523669 100644
--- a/colossalai/legacy/inference/async_engine.py
+++ b/colossalai/legacy/inference/async_engine.py
@@ -54,7 +54,6 @@ class RequestTracker:
class Async_Engine:
-
"""
Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
Background loop: inference reqs in waiting list (Listen)
diff --git a/colossalai/legacy/inference/dynamic_batching/io_struct.py b/colossalai/legacy/inference/dynamic_batching/io_struct.py
index fc5ecfe57..abc41cc8e 100644
--- a/colossalai/legacy/inference/dynamic_batching/io_struct.py
+++ b/colossalai/legacy/inference/dynamic_batching/io_struct.py
@@ -118,16 +118,16 @@ class Batch:
class BatchTokenIdOut:
def __init__(self):
- self.reqs_infs: List[
- Tuple[str, int, Dict, bool, bool]
- ] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
+ self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = (
+ []
+ ) # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
class BatchStrOut:
def __init__(self):
- self.reqs_infs: List[
- Tuple[str, str, Dict, bool, bool]
- ] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state]
+ self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = (
+ []
+ ) # [req_id, token_str, gen_metadata, finished_state, abort_state]
class AbortReq:
diff --git a/colossalai/legacy/inference/hybridengine/modeling/_utils.py b/colossalai/legacy/inference/hybridengine/modeling/_utils.py
index 068b64b4f..46d4222c4 100644
--- a/colossalai/legacy/inference/hybridengine/modeling/_utils.py
+++ b/colossalai/legacy/inference/hybridengine/modeling/_utils.py
@@ -1,6 +1,7 @@
"""
Utils for model inference
"""
+
import os
import torch
diff --git a/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py b/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py
index f707a86df..b72610899 100644
--- a/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py
+++ b/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py
@@ -14,6 +14,7 @@ class BatchInferState:
Information to be passed and used for a batch of inputs during
a single model forward
"""
+
batch_size: int
max_len_in_batch: int
diff --git a/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py b/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py
index 91bb96a1f..8c54fda26 100644
--- a/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py
+++ b/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py
@@ -4,6 +4,7 @@ of the ModelTC/lightllm GitHub repository
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
"""
+
import torch
from transformers.utils import logging
diff --git a/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py b/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py
index 068b64b4f..46d4222c4 100644
--- a/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py
+++ b/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py
@@ -1,6 +1,7 @@
"""
Utils for model inference
"""
+
import os
import torch
diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
index bf581300a..6ae4b06e5 100644
--- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
+++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
@@ -33,6 +33,7 @@ This license shall be governed and construed in accordance with the laws of Peop
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
"""
+
""" PyTorch ChatGLM model. """
import copy
diff --git a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py
index 09677a619..4d08d9941 100644
--- a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py
+++ b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py
@@ -52,9 +52,11 @@ class pretraining_dataset(Dataset):
def __getitem__(self, index):
[input_ids, input_mask, segment_ids, masked_lm_labels] = [
- torch.from_numpy(input[index].astype(np.int64))
- if indice < 5
- else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
+ (
+ torch.from_numpy(input[index].astype(np.int64))
+ if indice < 5
+ else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
+ )
for indice, input in enumerate(self.inputs)
]
diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py
index 20e26256e..3cf6aceb5 100644
--- a/examples/images/diffusion/ldm/models/diffusion/ddpm.py
+++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py
@@ -229,9 +229,7 @@ class DDPM(pl.LightningModule):
)
if self.parameterization == "eps":
- lvlb_weights = self.betas**2 / (
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)
- )
+ lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod))
elif self.parameterization == "v":
@@ -1186,9 +1184,11 @@ class LatentDiffusion(DDPM):
if cond is not None:
if isinstance(cond, dict):
cond = {
- key: cond[key][:batch_size]
- if not isinstance(cond[key], list)
- else list(map(lambda x: x[:batch_size], cond[key]))
+ key: (
+ cond[key][:batch_size]
+ if not isinstance(cond[key], list)
+ else list(map(lambda x: x[:batch_size], cond[key]))
+ )
for key in cond
}
else:
@@ -1321,9 +1321,11 @@ class LatentDiffusion(DDPM):
if cond is not None:
if isinstance(cond, dict):
cond = {
- key: cond[key][:batch_size]
- if not isinstance(cond[key], list)
- else list(map(lambda x: x[:batch_size], cond[key]))
+ key: (
+ cond[key][:batch_size]
+ if not isinstance(cond[key], list)
+ else list(map(lambda x: x[:batch_size], cond[key]))
+ )
for key in cond
}
else:
diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py
index 55dac8555..4104fe3b0 100644
--- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py
+++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py
@@ -1,4 +1,5 @@
"""SAMPLING ONLY."""
+
import torch
from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py
index 6c80f3229..afde5dfd4 100644
--- a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py
+++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py
@@ -640,23 +640,25 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- )
- if not use_spatial_transformer
- else SpatialTransformer( # always uses a self-attn
- ch,
- num_heads,
- dim_head,
- depth=transformer_depth,
- context_dim=context_dim,
- disable_self_attn=disable_middle_self_attn,
- use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint,
+ (
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ if not use_spatial_transformer
+ else SpatialTransformer( # always uses a self-attn
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn,
+ use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint,
+ )
),
ResBlock(
ch,
diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py
index 0dd87b596..8c13f39ff 100644
--- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py
+++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py
@@ -2,6 +2,7 @@
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
+
import torch
import torch.nn as nn
diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py
index 4d30744c4..c79581afc 100644
--- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py
+++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py
@@ -2,6 +2,7 @@
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
+
import torch
import torch.nn as nn
diff --git a/examples/images/diffusion/ldm/modules/midas/utils.py b/examples/images/diffusion/ldm/modules/midas/utils.py
index 1428d42b2..f7fc7dcc9 100644
--- a/examples/images/diffusion/ldm/modules/midas/utils.py
+++ b/examples/images/diffusion/ldm/modules/midas/utils.py
@@ -1,4 +1,5 @@
"""Utils for monoDepth."""
+
import re
import sys
diff --git a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp
index 52977e631..fe9968177 100644
--- a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp
+++ b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp
@@ -369,9 +369,9 @@ py::array build_mapping_impl(const py::array_t& docs_,
}
} // for (auto sent_index=sent_index_first; ...
- } // if (num_remain_sent > 1) {
- } // for (int doc=0; doc < num_docs; ++doc) {
- } // for (int epoch=0; epoch < num_epochs; ++epoch) {
+ } // if (num_remain_sent > 1) {
+ } // for (int doc=0; doc < num_docs; ++doc) {
+ } // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
@@ -606,9 +606,9 @@ py::array build_blocks_mapping_impl(
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
- } // if (num_remain_sent > 1) {
- } // for (int doc=0; doc < num_docs; ++doc) {
- } // for (int epoch=0; epoch < num_epochs; ++epoch) {
+ } // if (num_remain_sent > 1) {
+ } // for (int doc=0; doc < num_docs; ++doc) {
+ } // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
diff --git a/extensions/csrc/kernel/arm/cpu_adam_arm.h b/extensions/csrc/kernel/arm/cpu_adam_arm.h
index c731850ed..d48968e21 100644
--- a/extensions/csrc/kernel/arm/cpu_adam_arm.h
+++ b/extensions/csrc/kernel/arm/cpu_adam_arm.h
@@ -4,7 +4,7 @@
#include
-#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
+#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
#define TILE (128 * 1024 * 1024)
#if defined(__aarch64__)
diff --git a/extensions/csrc/kernel/x86/cpu_adam.h b/extensions/csrc/kernel/x86/cpu_adam.h
index db1f26d5f..45e1dde62 100644
--- a/extensions/csrc/kernel/x86/cpu_adam.h
+++ b/extensions/csrc/kernel/x86/cpu_adam.h
@@ -32,7 +32,7 @@ SOFTWARE
#include
#endif
-#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
+#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
diff --git a/tests/kit/model_zoo/torchvision/torchvision.py b/tests/kit/model_zoo/torchvision/torchvision.py
index 57b633e9d..c0524d089 100644
--- a/tests/kit/model_zoo/torchvision/torchvision.py
+++ b/tests/kit/model_zoo/torchvision/torchvision.py
@@ -34,14 +34,14 @@ def swin_s():
# special output transform fn
-google_net_output_transform_fn = (
- lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
+google_net_output_transform_fn = lambda x: (
+ dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
)
-swin_s_output_output_transform_fn = (
- lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
+swin_s_output_output_transform_fn = lambda x: (
+ {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
)
-inception_v3_output_transform_fn = (
- lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
+inception_v3_output_transform_fn = lambda x: (
+ dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
)
model_zoo.register(
From ea94c07b959e8895b713d6dd68b168ea37db6b7b Mon Sep 17 00:00:00 2001
From: Haze188
Date: Tue, 2 Jul 2024 12:42:02 +0800
Subject: [PATCH 06/15] [hotfix] fix the bug where a large tensor exceeds the
maximum capacity of TensorBucket (#5879)
---
colossalai/zero/low_level/low_level_optim.py | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index e06cf0581..bdc91b51f 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = real_working_params[group_id][idx]
param_to_gather = master_param.to(device).to(self._dtype)
pg = self.param_to_pg[working_param]
+ if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
+ buffer_tensor = torch.empty_like(
+ torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
+ )
+ dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
+ working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
+ continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
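The new branch above bypasses the TensorBucket for shards whose gathered size exceeds
the bucket capacity, gathering them directly instead of queueing them. A shape-only
sketch of that dispatch (the collective call itself is elided and left as a comment;
the sizes and names are made up, not the plugin's defaults):

    import torch

    world_size = 4
    bucket_max_size = 256
    shard = torch.randn(300)                       # master shard held by this rank
    working_param = torch.empty(world_size * 300)  # full, unsharded working parameter

    if shard.numel() > bucket_max_size:
        # direct path: gather all shards into one buffer, then write it back
        buffer_tensor = torch.empty(shard.numel() * world_size)
        # dist.all_gather_into_tensor(buffer_tensor, shard, group=pg) would fill it here
        working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
    else:
        pass  # smaller shards keep flowing through the bucketed, fused all-gather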
From eb24fcd914f4c38fb82bc082db84d13d50865572 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Wed, 3 Jul 2024 14:57:57 +0800
Subject: [PATCH 07/15] [Hotfix] Fix OPT gradient checkpointing forward
Co-authored-by: Edenzzzz
---
colossalai/shardformer/modeling/opt.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index f10860fef..b250b4976 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -221,7 +221,7 @@ class OPTPipelineForwards:
past_key_value = past_key_values[idx] if past_key_values is not None else None
if decoder.gradient_checkpointing and decoder.training:
- layer_outputs = self._gradient_checkpointing_func(
+ layer_outputs = self.decoder._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_attention_mask,
From 6cd4c32be4c0ced9a70e228530f383c5f4a580de Mon Sep 17 00:00:00 2001
From: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Date: Wed, 3 Jul 2024 20:02:19 +0800
Subject: [PATCH 08/15] [shardformer] fix the moe (#5883)
---
colossalai/booster/plugin/__init__.py | 10 +++++++-
colossalai/shardformer/policies/mixtral.py | 28 ++++++++++------------
2 files changed, 22 insertions(+), 16 deletions(-)
diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py
index 62f3708fc..7e0e6ffdd 100644
--- a/colossalai/booster/plugin/__init__.py
+++ b/colossalai/booster/plugin/__init__.py
@@ -1,10 +1,18 @@
from .gemini_plugin import GeminiPlugin
from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+ "Plugin",
+ "TorchDDPPlugin",
+ "GeminiPlugin",
+ "LowLevelZeroPlugin",
+ "HybridParallelPlugin",
+ "MoeHybridParallelPlugin",
+]
import torch
from packaging import version
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index f9721c79e..0fb858d78 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -40,21 +40,19 @@ class MixtralPolicy(Policy):
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
- if getattr(self.shard_config, "ep_group", None) is None:
- raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
- # expert parallel
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="block_sparse_moe",
- target_module=EPMixtralSparseMoeBlock,
- kwargs={"ep_group": self.shard_config.ep_group},
- )
- ],
- policy=policy,
- target_key=MixtralDecoderLayer,
- )
+ if getattr(self.shard_config, "ep_group", None) is not None:
+ # expert parallel
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="block_sparse_moe",
+ target_module=EPMixtralSparseMoeBlock,
+ kwargs={"ep_group": self.shard_config.ep_group},
+ )
+ ],
+ policy=policy,
+ target_key=MixtralDecoderLayer,
+ )
# optimization configuration
if self.shard_config.enable_fused_normalization:
From 7afbc81d6292f1a44cb5c2f89571c6c1c6d74691 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Thu, 4 Jul 2024 11:33:23 +0800
Subject: [PATCH 09/15] [quant] fix bitsandbytes version check (#5882)
* [quant] fix bitsandbytes version check
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
colossalai/quantization/bnb.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/colossalai/quantization/bnb.py b/colossalai/quantization/bnb.py
index fa214116a..3601ef62b 100644
--- a/colossalai/quantization/bnb.py
+++ b/colossalai/quantization/bnb.py
@@ -1,17 +1,25 @@
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
+import importlib.metadata
import logging
import torch
import torch.nn as nn
+from packaging.version import Version
from .bnb_config import BnbQuantizationConfig
try:
import bitsandbytes as bnb
- IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
- IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
+ try:
+ # in case lower version of bitsandbytes does not have __version__ attribute
+ BNB_VERSION = Version(bnb.__version__)
+ except AttributeError:
+ BNB_VERSION = Version(importlib.metadata.version("bitsandbytes"))
+
+ IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0")
+ IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2")
except ImportError:
pass
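The switch from raw string comparison to `packaging.version.Version` avoids lexicographic pitfalls, and the `importlib.metadata` fallback covers bitsandbytes builds that lack a `__version__` attribute. An illustrative standalone check (not part of the patch):

```python
from packaging.version import Version

# Lexicographic string comparison gets version ordering wrong:
assert "0.9.0" >= "0.39.0"                           # True as strings, but 0.9.0 < 0.39.0 semantically
assert not (Version("0.9.0") >= Version("0.39.0"))   # packaging compares release segments correctly
assert Version("0.40.1.post1") >= Version("0.39.0")  # suffixes such as .post1 are handled as well
```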
From 7997683aac44cf99529589af4262fba52b29a74b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 4 Jul 2024 13:46:41 +0800
Subject: [PATCH 10/15] [pre-commit.ci] pre-commit autoupdate (#5878)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
updates:
- [github.com/pre-commit/mirrors-clang-format: v18.1.7 → v18.1.8](https://github.com/pre-commit/mirrors-clang-format/compare/v18.1.7...v18.1.8)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.pre-commit-config.yaml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f2c408bce..9088d0e1b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -21,7 +21,7 @@ repos:
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format
- rev: v18.1.7
+ rev: v18.1.8
hooks:
- id: clang-format
name: clang formatter
From 3420921101186ffa6e6f9428bbb4036302230ccd Mon Sep 17 00:00:00 2001
From: Haze188
Date: Fri, 5 Jul 2024 16:13:58 +0800
Subject: [PATCH 11/15] [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
colossalai/cluster/process_group_mesh.py | 2 +-
colossalai/shardformer/modeling/deepseek.py | 429 ++++++++++++++++++
.../shardformer/policies/auto_policy.py | 8 +-
colossalai/shardformer/policies/deepseek.py | 212 +++++++++
colossalai/shardformer/policies/mixtral.py | 6 +-
tests/test_moe/test_deepseek_layer.py | 72 +++
tests/test_moe/test_moe_checkpoint.py | 38 +-
7 files changed, 748 insertions(+), 19 deletions(-)
create mode 100644 colossalai/shardformer/modeling/deepseek.py
create mode 100644 colossalai/shardformer/policies/deepseek.py
create mode 100644 tests/test_moe/test_deepseek_layer.py
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index 1319a4529..b6aff0d72 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -147,7 +147,7 @@ class ProcessGroupMesh:
ProcessGroup: The process group with the given ranks.
"""
ranks_in_group = sorted(ranks_in_group)
- if tuple(ranks_in_group) not in self._group_to_ranks:
+ if tuple(ranks_in_group) not in self._ranks_to_group:
group = dist.new_group(ranks_in_group, backend=backend)
self._ranks_to_group[tuple(ranks_in_group)] = group
self._group_to_ranks[group] = tuple(ranks_in_group)
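This one-line fix makes the cache lookup and the cache insertion use the same mapping (`_ranks_to_group`), so repeated calls with the same ranks reuse the existing `ProcessGroup` instead of creating a new group every time. A hypothetical standalone sketch of the intended caching pattern:

```python
from typing import Callable, Dict, Tuple

_ranks_to_group: Dict[Tuple[int, ...], object] = {}


def get_group_cached(ranks, create_group: Callable[[Tuple[int, ...]], object]) -> object:
    key = tuple(sorted(ranks))
    if key not in _ranks_to_group:        # look up the same mapping we insert into
        _ranks_to_group[key] = create_group(key)
    return _ranks_to_group[key]
```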
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
new file mode 100644
index 000000000..6e79ce144
--- /dev/null
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -0,0 +1,429 @@
+from typing import List, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.utils import is_flash_attn_2_available, logging
+
+from colossalai.lazy import LazyInitContext
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.shard import ShardConfig
+from colossalai.shardformer.shard.utils import set_tensors_to_none
+
+
+# copied from modeling_deepseek.py
+class AddAuxiliaryLoss(torch.autograd.Function):
+ """
+ The trick function of adding auxiliary (aux) loss,
+ which includes the gradient of the aux loss during backpropagation.
+ """
+
+ @staticmethod
+ def forward(ctx, x, loss):
+ assert loss.numel() == 1
+ ctx.dtype = loss.dtype
+ ctx.required_aux_loss = loss.requires_grad
+ return x
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_loss = None
+ if ctx.required_aux_loss:
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
+ return grad_output, grad_loss
+
+
+class EPDeepseekMoE(nn.Module):
+ def __init__(self):
+ super(EPDeepseekMoE, self).__init__()
+
+ def setup_ep(self, ep_group: ProcessGroup):
+ ep_group = ep_group
+ self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
+ self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+ self.num_experts = self.config.n_routed_experts
+ assert self.num_experts % self.ep_size == 0
+ self.ep_group = ep_group
+ self.num_experts_per_ep = self.num_experts // self.ep_size
+ self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
+ held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+ set_tensors_to_none(self.experts, exclude=set(held_experts))
+ for p in self.experts.parameters():
+ p.ep_group = ep_group
+
+ @staticmethod
+ def from_native_module(module: Union["DeepseekMoE", "DeepseekMLP"], *args, **kwargs) -> "EPDeepseekMoE":
+ LazyInitContext.materialize(module)
+ if module.__class__.__name__ == "DeepseekMLP":
+ return module
+ module.__class__ = EPDeepseekMoE
+ assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
+ module.setup_ep(kwargs["ep_group"])
+ return module
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ identity = hidden_states
+ orig_shape = hidden_states.shape
+
+ topk_experts_idx, topk_experts_weight, aux_loss = self.gate(hidden_states)
+
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # [t0, t1, t2 ...]
+ hidden_states = hidden_states.repeat_interleave(
+ self.num_experts_per_tok, dim=0
+ ) # after repeat_interleave: [t0 t0 t1 t1 t2 t2 ... ]
+
+ flat_topk_experts_idx = topk_experts_idx.view(-1) # [e0 e1 e2 ...]
+ # The elements of flat_topk_token_idx are token ids, which are arranged in ascending order of expert ids.
+ flat_topk_token_idx = flat_topk_experts_idx.argsort()
+
+ # Now we adjust the order of the hidden states, also in ascending order of expert id
+ dispatch_states = hidden_states[flat_topk_token_idx]
+ input_split_sizes = flat_topk_experts_idx.bincount(minlength=self.num_experts) # [n0, n1, n2, n3]
+ output_split_sizes = torch.zeros_like(input_split_sizes)
+
+ # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
+ dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
+
+ input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+ output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+ output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+ output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+
+ if output_states.size(0) > 0:
+ if self.num_experts_per_ep == 1:
+ expert = self.experts[self.expert_start_idx]
+ output_states = expert(output_states)
+ else:
+ output_states_splits = output_states.split(output_split_sizes.tolist())
+ output_states_list = []
+ for i, split_states in enumerate(output_states_splits):
+ if split_states.size(0) == 0: # no token is routed to this expert
+ continue
+ expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+ split_states = expert(split_states)
+ output_states_list.append(split_states)
+ output_states = torch.cat(output_states_list)
+ output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+ dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+ recover_token_idx = torch.empty_like(flat_topk_token_idx)
+ recover_token_idx[flat_topk_token_idx] = torch.arange(
+ flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
+ )
+
+ output_hidden_states = dispatch_states[recover_token_idx] # t0 t0 t1 t1 t2 t2
+ output_hidden_states = output_hidden_states.view(-1, self.num_experts_per_tok, orig_shape[-1])
+ output_hidden_states = (output_hidden_states * topk_experts_weight[:, :, None]).sum(dim=-2) # (B*S, h)
+ output_hidden_states = output_hidden_states.view(*orig_shape)
+ output_hidden_states = AddAuxiliaryLoss.apply(output_hidden_states, aux_loss)
+ if self.config.n_shared_experts is not None:
+ output_hidden_states = output_hidden_states + self.shared_experts(identity)
+ return output_hidden_states
+
+
+class DeepseekPipelineForwards:
+ """
+ This class serves as a micro library for forward function substitution of Deepseek models
+ under pipeline setting.
+ """
+
+ @staticmethod
+ def deepseek_model_forward(
+ self: "DeepseekModel",
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ logger = logging.get_logger(__name__)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if stage_manager.is_first_stage():
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = inputs_embeds
+ else:
+ input_shape = hidden_states.shape[:-1]
+ batch_size, seq_length = input_shape
+ device = hidden_states.device
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
+ use_cache = False
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # embed positions, for the first stage, hidden_states is the input embeddings,
+ # for the other stages, hidden_states is the output of the previous stage
+ if is_flash_attn_2_available():
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ hidden_states,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ output_attentions,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_value,
+ output_attentions,
+ use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if stage_manager.is_last_stage():
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ next_cache = next_decoder_cache if use_cache else None
+
+ if stage_manager.is_last_stage():
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ # always return dict for intermediate stages
+ return {
+ "hidden_states": hidden_states,
+ }
+
+ @staticmethod
+ def deepseek_for_causal_lm_forward(
+ self: "DeepseekForCausalLM",
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ logger = logging.get_logger(__name__)
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = DeepseekPipelineForwards.deepseek_model_forward(
+ self.model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ )
+ past_key_values = None
+
+ if stage_manager.is_last_stage():
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=None,
+ hidden_states=outputs[0],
+ attentions=None,
+ )
+ else:
+ out = {}
+ hidden_states = outputs.get("hidden_states")
+ out["hidden_states"] = hidden_states
+ return out
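The heart of `EPDeepseekMoE.forward` is the dispatch bookkeeping: every token is repeated once per selected expert, rows are grouped by expert id with `argsort`, per-expert counts from `bincount` drive the uneven all-to-all exchange, and an inverse permutation restores token order before the weighted top-k combine. Below is a single-process sketch of just that bookkeeping (the `all_to_all_uneven` exchange and the experts themselves are elided; all names and sizes are illustrative):

```python
import torch

tokens, hidden, top_k, n_experts = 5, 4, 2, 4
x = torch.randn(tokens, hidden)
topk_idx = torch.randint(0, n_experts, (tokens, top_k))     # expert id chosen for each (token, slot)
topk_w = torch.softmax(torch.randn(tokens, top_k), dim=-1)  # routing weights from the gate

h = x.repeat_interleave(top_k, dim=0)                       # [t0 t0 t1 t1 t2 t2 ...]
flat_idx = topk_idx.view(-1)
order = flat_idx.argsort()                                  # group rows by ascending expert id
dispatched = h[order]
counts = flat_idx.bincount(minlength=n_experts)             # tokens per expert -> all_to_all split sizes

# ... in the real forward, `dispatched` is exchanged with all_to_all_uneven and fed to the
# locally held experts here; this sketch just passes it through unchanged ...

recover = torch.empty_like(order)
recover[order] = torch.arange(order.numel())                # inverse permutation of the argsort
out = dispatched[recover].view(tokens, top_k, hidden)
out = (out * topk_w[:, :, None]).sum(dim=1)                 # weighted top-k combine
print(out.shape, counts.tolist())                           # torch.Size([5, 4]) and per-expert counts
```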
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index bf139c840..ae9f3603c 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -160,6 +160,13 @@ _POLICY_LIST = {
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
+ # Deepseek
+ "transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation(
+ file_name="deepseek", class_name="DeepseekModelPolicy"
+ ),
+ "transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
+ file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
+ ),
# Falcon
"transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
file_name="falcon", class_name="FalconModelPolicy"
@@ -252,7 +259,6 @@ def get_autopolicy(model: nn.Module) -> Policy:
"""
full_name = _fullname(model)
policy_location = _POLICY_LIST.get(full_name, None)
-
if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
new file mode 100644
index 000000000..8ebda357b
--- /dev/null
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -0,0 +1,212 @@
+import warnings
+from functools import partial
+from typing import Callable, Dict, List, Union
+
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module
+
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
+
+
+class DeepseekPolicy(Policy):
+ def config_sanity_check(self):
+ pass
+
+ def preprocess(self):
+ if self.shard_config.enable_tensor_parallelism:
+ # Resize embedding
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
+
+ return self.model
+
+ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+ policy = {}
+
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ raise NotImplementedError(
+ "Deepseek dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
+ )
+
+ if self.shard_config.enable_tensor_parallelism:
+ raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")
+
+ if getattr(self.shard_config, "ep_group", None) is not None:
+ # expert parallel
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="mlp",
+ target_module=EPDeepseekMoE,
+ kwargs={"ep_group": self.shard_config.ep_group},
+ )
+ ],
+ policy=policy,
+ target_key="DeepseekDecoderLayer",
+ )
+
+ # optimization configuration
+ if self.shard_config.enable_fused_normalization:
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="input_layernorm",
+ target_module=FusedRMSNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="post_attention_layernorm",
+ target_module=FusedRMSNorm,
+ ),
+ ],
+ policy=policy,
+ target_key="DeepseekDecoderLayer",
+ )
+
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="norm",
+ target_module=FusedRMSNorm,
+ ),
+ policy=policy,
+ target_key="DeepseekModel",
+ )
+
+ if self.shard_config.enable_flash_attention:
+ warnings.warn(
+ "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
+ )
+ self.shard_config.enable_flash_attention = False
+
+ return policy
+
+ def postprocess(self):
+ return self.model
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if self.pipeline_stage_manager:
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == "DeepseekModel":
+ module = self.model
+ else:
+ module = self.model.model
+
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
+ method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=model_cls
+ )
+
+ return
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == "DeepseekModel":
+ module = self.model
+ else:
+ module = self.model.model
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embed_tokens)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.norm)
+
+ return held_layers
+
+
+class DeepseekModelPolicy(DeepseekPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(
+ model_cls="DeepseekModel",
+ new_forward=DeepseekPipelineForwards.deepseek_model_forward,
+ policy=policy,
+ )
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ held_layers = super().get_held_layers()
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in llama model"""
+ return []
+
+
+class DeepseekForCausalLMPolicy(DeepseekPolicy):
+ def module_policy(self):
+ policy = super().module_policy()
+ # TODO: assign pg mesh from plugin to all modules
+ if self.shard_config.enable_tensor_parallelism:
+ # add a new item for causal lm
+ new_item = {
+ "DeepseekForCausalLM": ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=Linear1D_Col,
+ kwargs=dict(gather_output=True),
+ )
+ ]
+ )
+ }
+ policy.update(new_item)
+
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(
+ model_cls="DeepseekForCausalLM",
+ new_forward=DeepseekPipelineForwards.deepseek_for_causal_lm_forward,
+ policy=policy,
+ )
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.lm_head)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ deepseek_model = self.model.model
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ if (
+ id(deepseek_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+ and self.pipeline_stage_manager.num_stages > 1
+ ):
+ # tie weights
+ return [
+ {
+ 0: deepseek_model.embed_tokens.weight,
+ self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+ }
+ ]
+ return []
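`get_shared_params` reports the input embedding and the LM head as a shared parameter pair only when they are literally the same tensor, which is how weight tying shows up at runtime; the first and last pipeline stages then know they hold two views of one weight. A small, hypothetical illustration of the identity check, independent of the policy code:

```python
import torch.nn as nn

embed = nn.Embedding(1000, 64)
lm_head = nn.Linear(64, 1000, bias=False)
lm_head.weight = embed.weight                  # weight tying: one Parameter shared by two modules

assert id(embed.weight) == id(lm_head.weight)  # the identity check used to detect shared params
```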
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 0fb858d78..ad93e9469 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -192,16 +192,16 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
- llama_model = self.model.model
+ mixtral_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
- id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+ id(mixtral_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
- 0: llama_model.embed_tokens.weight,
+ 0: mixtral_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
diff --git a/tests/test_moe/test_deepseek_layer.py b/tests/test_moe/test_deepseek_layer.py
new file mode 100644
index 000000000..85cc98695
--- /dev/null
+++ b/tests/test_moe/test_deepseek_layer.py
@@ -0,0 +1,72 @@
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed as dist
+from torch.testing import assert_close
+from transformers import AutoConfig, AutoModel
+
+import colossalai
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
+from colossalai.testing.utils import spawn
+
+tokens, n_experts = 7, 4
+hidden_size = 8
+top_k = 2
+
+
+def check_deepseek_moe_layer():
+ torch.cuda.set_device(dist.get_rank())
+ plugin = MoeHybridParallelPlugin(
+ precision="bf16",
+ tp_size=1,
+ pp_size=1,
+ ep_size=dist.get_world_size(),
+ )
+
+ config = AutoConfig.from_pretrained(
+ "deepseek-ai/deepseek-moe-16b-base",
+ num_hidden_layers=1,
+ n_routed_experts=n_experts,
+ num_experts_per_tok=top_k,
+ hidden_size=hidden_size,
+ intermediate_size=hidden_size * 2,
+ first_k_dense_replace=0,
+ num_attention_heads=2,
+ trust_remote_code=True,
+ )
+ torch.manual_seed(0)
+ # get the moe layer in auto model
+ orig_model = AutoModel.from_config(config, trust_remote_code=True).layers[0].mlp.cuda()
+ x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
+ orig_output = orig_model(x)
+ model = deepcopy(orig_model)
+ model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group)
+ ep_output = model(x)
+ assert_close(orig_output, ep_output)
+ orig_loss = orig_output.mean()
+ orig_loss.backward()
+ ep_loss = ep_output.mean()
+ ep_loss.backward()
+ assert_close(orig_loss, ep_loss)
+ name_to_p = {n: p for n, p in orig_model.named_parameters()}
+ for n, ep_p in model.named_parameters():
+ p = name_to_p[n]
+ if ep_p.grad is not None:
+ assert_close(p.grad, ep_p.grad)
+
+
+def run_dist(rank: int, world_size: int, port: int):
+ colossalai.launch(rank, world_size, "localhost", port)
+ check_deepseek_moe_layer()
+
+
+# @pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.parametrize("world_size", [2])
+def test_deepseek_moe_layer(world_size: int):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_deepseek_moe_layer(2)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 249dd4b97..164301695 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -15,6 +15,7 @@ from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.tensor.moe_tensor.api import is_moe_tensor
+from colossalai.testing import parameterize, spawn
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
@@ -77,7 +78,23 @@ def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_grou
raise AssertionError(f"A total of {count} optim states are not equal")
-def check_mixtral_moe_layer():
+@parameterize(
+ "test_config",
+ [
+ [
+ MixtralConfig(
+ hidden_size=hidden_size,
+ intermediate_size=hidden_size * 2,
+ num_local_experts=n_experts,
+ num_experts_per_tok=top_k,
+ num_attention_heads=2,
+ num_key_value_heads=2,
+ ),
+ MixtralForCausalLM,
+ ],
+ ],
+)
+def check_moe_checkpoint(test_config):
context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
with context as f:
torch.cuda.set_device(dist.get_rank())
@@ -87,17 +104,11 @@ def check_mixtral_moe_layer():
broadcast_objects = [None]
dist.broadcast_object_list(broadcast_objects, src=0)
- config = MixtralConfig(
- hidden_size=hidden_size,
- intermediate_size=hidden_size * 2,
- num_local_experts=n_experts,
- num_experts_per_tok=top_k,
- num_attention_heads=2,
- num_key_value_heads=2,
- )
+ config = test_config[0]
+ model_cls = test_config[1]
torch.manual_seed(0)
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
- orig_model = MixtralForCausalLM(config).cuda()
+ orig_model = model_cls(config).cuda()
model = deepcopy(orig_model)
optimizer = Adam(model.parameters(), lr=1e-3)
plugin = MoeHybridParallelPlugin(
@@ -120,7 +131,6 @@ def check_mixtral_moe_layer():
lambda outputs, inputs: outputs.loss,
optimizer,
)
-
tmpdirname = broadcast_objects[0]
model_dir = os.path.join(tmpdirname, "mixtral_model")
hf_model_dir = os.path.join(tmpdirname, "mixtral_hf_model")
@@ -129,13 +139,13 @@ def check_mixtral_moe_layer():
booster.save_model(model, model_dir, shard=True)
dist.barrier()
if dist.get_rank() == 0:
- saved_model = MixtralForCausalLM.from_pretrained(model_dir).cuda()
+ saved_model = model_cls.from_pretrained(model_dir).cuda()
check_model_equal(orig_model, saved_model)
# check_model_equal(model, saved_model)
saved_model.save_pretrained(hf_model_dir)
dist.barrier()
# check load model
- new_model = MixtralForCausalLM(config).cuda()
+ new_model = model_cls(config).cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
booster.load_model(new_model, hf_model_dir)
@@ -163,7 +173,7 @@ def check_mixtral_moe_layer():
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch(rank, world_size, "localhost", port)
- check_mixtral_moe_layer()
+ check_moe_checkpoint()
# Test EP + ZeRO + PP
From 8ec24b6a4d0e0dbec7da39e43c3c1b2cfcb0395d Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Fri, 5 Jul 2024 20:02:36 +0800
Subject: [PATCH 12/15] [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm
overlap
Co-authored-by: Edenzzzz
---
colossalai/initialize.py | 6 ++++++
colossalai/legacy/nn/layer/parallel_1d/_operation.py | 1 -
colossalai/shardformer/shard/shardformer.py | 4 ----
examples/language/llama/benchmark.py | 2 +-
4 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 71d42312e..4e2eff7ce 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -3,6 +3,12 @@
import os
+# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when overlapping communication and computation,
+# the order of kernel launches on GPUs is the same as on the CPU, so that comm is launched first.
+# see https://github.com/NVIDIA/Megatron-LM/issues/533
+# https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
+os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
index f01da97ba..8b8f04ccf 100644
--- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
@@ -81,7 +81,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
- _ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py
index b54c58273..db03eec41 100644
--- a/colossalai/shardformer/shard/shardformer.py
+++ b/colossalai/shardformer/shard/shardformer.py
@@ -1,4 +1,3 @@
-import os
from typing import Dict, List, Tuple
import torch.distributed as dist
@@ -11,9 +10,6 @@ from ..policies.base_policy import Policy
from .shard_config import ShardConfig
from .sharder import ModelSharder
-# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when communication and computation overlap, the order of core scheduling is correct
-os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
-
class ShardFormer:
"""
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 8a35db1f7..2b7bd50b8 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -292,7 +292,7 @@ def main():
with get_profile_context(
args.profile,
args.ignore_steps,
- len(dataloader) - 1,
+ 1, # avoid creating massive log files
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
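Moving the assignment into `colossalai/initialize.py` (and out of `shardformer.py`) matters because `CUDA_DEVICE_MAX_CONNECTIONS` is read when the CUDA context is created; setting it after CUDA has already been initialized has no effect. A hypothetical minimal script showing the required ordering:

```python
import os

# Must be exported before anything initializes CUDA (e.g. before the first CUDA tensor is created).
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

import torch  # noqa: E402  deliberately imported after the environment variable is set

if torch.cuda.is_available():
    torch.cuda.init()  # the context created here picks up the single-connection setting
```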
From cba20525a81565fc86e13b78973ffa8210a05cd3 Mon Sep 17 00:00:00 2001
From: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Date: Mon, 8 Jul 2024 16:02:07 +0800
Subject: [PATCH 13/15] [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3)
Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
---
colossalai/inference/config.py | 48 +-
colossalai/inference/core/base_engine.py | 90 ++
colossalai/inference/core/diffusion_engine.py | 200 +++++
colossalai/inference/core/engine.py | 800 ++----------------
colossalai/inference/core/llm_engine.py | 758 +++++++++++++++++
colossalai/inference/core/request_handler.py | 51 +-
.../inference/modeling/models/diffusion.py | 54 ++
.../inference/modeling/models/pixart_alpha.py | 220 +++++
.../modeling/models/stablediffusion3.py | 178 ++++
.../inference/modeling/policy/__init__.py | 6 +
.../inference/modeling/policy/pixart_alpha.py | 34 +
.../modeling/policy/stablediffusion3.py | 34 +
colossalai/inference/struct.py | 12 +
colossalai/inference/utils.py | 39 +-
.../stable_diffusion/sd3_generation.py | 75 ++
requirements/requirements.txt | 1 +
16 files changed, 1860 insertions(+), 740 deletions(-)
create mode 100644 colossalai/inference/core/base_engine.py
create mode 100644 colossalai/inference/core/diffusion_engine.py
create mode 100644 colossalai/inference/core/llm_engine.py
create mode 100644 colossalai/inference/modeling/models/diffusion.py
create mode 100644 colossalai/inference/modeling/models/pixart_alpha.py
create mode 100644 colossalai/inference/modeling/models/stablediffusion3.py
create mode 100644 colossalai/inference/modeling/policy/pixart_alpha.py
create mode 100644 colossalai/inference/modeling/policy/stablediffusion3.py
create mode 100644 examples/inference/stable_diffusion/sd3_generation.py
diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index e114e8a61..1beb86874 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers.generation import GenerationConfig
@@ -396,3 +396,49 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
+
+
+@dataclass
+class DiffusionGenerationConfig:
+ """
+ Param for diffusion model forward
+ """
+
+ prompt_2: Optional[Union[str, List[str]]] = None
+ prompt_3: Optional[Union[str, List[str]]] = None
+ height: Optional[int] = None
+ width: Optional[int] = None
+ num_inference_steps: int = None
+ timesteps: List[int] = None
+ guidance_scale: float = None
+ negative_prompt: Optional[Union[str, List[str]]] = (
+ None # NOTE(@lry89757) in pixart default to "", in sd3 default to None
+ )
+ negative_prompt_2: Optional[Union[str, List[str]]] = None
+ negative_prompt_3: Optional[Union[str, List[str]]] = None
+ num_images_per_prompt: Optional[int] = None
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
+ latents: Optional[torch.FloatTensor] = None
+ prompt_embeds: Optional[torch.FloatTensor] = None
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
+ output_type: Optional[str] = None # "pil"
+ return_dict: bool = None
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None
+ clip_skip: Optional[int] = None
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
+ callback_on_step_end_tensor_inputs: List[str] = None
+
+ def to_dict(self) -> Dict[str, Any]:
+ # NOTE(@lry89757) Only include fields that were not left at the default value None
+ result = {}
+ for field in fields(self):
+ value = getattr(self, field.name)
+ if value is not None:
+ result[field.name] = value
+ return result
+
+ @classmethod
+ def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
+ return cls(**kwargs)
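As a usage sketch (values are illustrative, and the patched colossalai is assumed to be importable), `from_kwargs` builds a config from whatever generation options are supplied, and `to_dict` drops every field that was left at its `None` default before the kwargs are forwarded to the underlying pipeline:

```python
from colossalai.inference.config import DiffusionGenerationConfig

gen_cfg = DiffusionGenerationConfig.from_kwargs(height=1024, width=1024, num_inference_steps=30)
assert gen_cfg.to_dict() == {"height": 1024, "width": 1024, "num_inference_steps": 30}
```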
diff --git a/colossalai/inference/core/base_engine.py b/colossalai/inference/core/base_engine.py
new file mode 100644
index 000000000..392dd2990
--- /dev/null
+++ b/colossalai/inference/core/base_engine.py
@@ -0,0 +1,90 @@
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.inference.config import ModelShardInferenceConfig
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
+
+
+class BaseEngine(ABC):
+ @abstractmethod
+ def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
+ pass
+
+ @abstractmethod
+ def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
+ """
+ Init Model for Engine
+ """
+
+ @abstractmethod
+ def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
+ """
+ Generate output for incoming requests
+ """
+
+ @abstractmethod
+ def add_request(self, prompts, request_ids=None, **kwargs):
+ """
+ Add new request to Engine
+ """
+
+ @abstractmethod
+ def step(self):
+ """
+ Perform one new step forward
+ """
+
+ @abstractmethod
+ def _verify_args(self):
+ """
+ Verify the parameters and members of class
+ """
+
+ @torch.inference_mode()
+ def capture_model(self):
+ """
+ Use cuda graph to capture model
+ """
+ raise NotImplementedError("This method should be implemented by subclasses")
+
+ def _shardformer(
+ self,
+ model: nn.Module,
+ model_policy: Policy,
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ stage_manager: PipelineStageManager = None,
+ tp_group: ProcessGroupMesh = None,
+ **kwargs,
+ ) -> nn.Module:
+ """
+ Initialize ShardConfig and replace the model with shardformer.
+
+ Args:
+ model (nn.Module): Path or nn.Module of this model.
+ model_policy (Policy): The policy to shardformer model which is determined by the model type.
+ stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
+ tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
+
+ Returns:
+ nn.Module: The model optimized by Shardformer.
+ """
+
+ shardconfig = ShardConfig(
+ tensor_parallel_process_group=tp_group,
+ pipeline_stage_manager=stage_manager,
+ enable_tensor_parallelism=(self.inference_config.tp_size > 1),
+ enable_fused_normalization=False,
+ enable_all_optimization=False,
+ enable_flash_attention=False,
+ enable_jit_fused=False,
+ enable_sequence_parallelism=False,
+ extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
+ )
+ shardformer = ShardFormer(shard_config=shardconfig)
+ shard_model, _ = shardformer.optimize(model, model_policy)
+ return shard_model
diff --git a/colossalai/inference/core/diffusion_engine.py b/colossalai/inference/core/diffusion_engine.py
new file mode 100644
index 000000000..75b9889bf
--- /dev/null
+++ b/colossalai/inference/core/diffusion_engine.py
@@ -0,0 +1,200 @@
+from itertools import count
+from typing import List, Tuple, Type, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from torch import distributed as dist
+
+from colossalai.accelerator import get_accelerator
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
+from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.policy import model_policy_map
+from colossalai.inference.struct import DiffusionSequence
+from colossalai.inference.utils import get_model_size, get_model_type
+from colossalai.logging import get_dist_logger
+from colossalai.shardformer.policies.base_policy import Policy
+
+from .base_engine import BaseEngine
+from .request_handler import NaiveRequestHandler
+
+PP_AXIS, TP_AXIS = 0, 1
+
+
+class DiffusionEngine(BaseEngine):
+ def __init__(
+ self,
+ model_or_path: DiffusionPipeline | str,
+ inference_config: InferenceConfig = None,
+ verbose: bool = False,
+ model_policy: Policy | type[Policy] = None,
+ ) -> None:
+ self.inference_config = inference_config
+ self.dtype = inference_config.dtype
+ self.high_precision = inference_config.high_precision
+
+ self.verbose = verbose
+ self.logger = get_dist_logger(__name__)
+ self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
+
+ self.model_type = get_model_type(model_or_path=model_or_path)
+
+ self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
+
+ self.request_handler = NaiveRequestHandler()
+
+ self.counter = count()
+
+ self._verify_args()
+
+ def _verify_args(self) -> None:
+ assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"
+
+ def init_model(
+ self,
+ model_or_path: Union[str, nn.Module, DiffusionPipeline],
+ model_policy: Union[Policy, Type[Policy]] = None,
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ """
+ Shard model or/and Load weight
+
+ Args:
+ model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
+ model_policy (Policy): the policy to replace the model.
+ model_inference_config: the configuration for modeling initialization when inference.
+ model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
+ """
+ if isinstance(model_or_path, str):
+ model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
+ policy_map_key = model.__class__.__name__
+ model = DiffusionPipe(model)
+ elif isinstance(model_or_path, DiffusionPipeline):
+ policy_map_key = model_or_path.__class__.__name__
+ model = DiffusionPipe(model_or_path)
+ else:
+ self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!")
+
+ torch.cuda.empty_cache()
+ init_gpu_memory = torch.cuda.mem_get_info()[0]
+
+ self.device = get_accelerator().get_current_device()
+ if self.verbose:
+ self.logger.info(f"the device is {self.device}")
+
+ if self.verbose:
+ self.logger.info(
+ f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
+ )
+
+ if model_policy is None:
+ model_policy = model_policy_map.get(policy_map_key)
+
+ if not isinstance(model_policy, Policy):
+ try:
+ model_policy = model_policy()
+ except Exception as e:
+ raise ValueError(f"Unable to instantiate model policy: {e}")
+
+ assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
+ pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
+ tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
+
+ self.model = self._shardformer(
+ model,
+ model_policy,
+ model_shard_infer_config,
+ None,
+ tp_group=tp_group,
+ )
+
+ self.model = model.to(self.device)
+
+ if self.verbose:
+ self.logger.info(
+ f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
+ )
+
+ free_gpu_memory, _ = torch.cuda.mem_get_info()
+ peak_memory = init_gpu_memory - free_gpu_memory
+ if self.verbose:
+ self.logger.info(
+ f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
+ )
+
+ def generate(
+ self,
+ request_ids: Union[List[int], int] = None,
+ prompts: Union[List[str], str] = None,
+ generation_config: DiffusionGenerationConfig = None,
+ **kwargs,
+ ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
+ """ """
+ gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+ prompts = [prompts] if isinstance(prompts, str) else prompts
+ request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
+
+ with torch.inference_mode():
+ if prompts is not None:
+ self.add_request(
+ request_ids=request_ids,
+ prompts=prompts,
+ **gen_config_dict,
+ **kwargs,
+ )
+
+ output_reqs_list = []
+
+ # If the user provides a generation config, replace the existing one.
+ if generation_config is not None:
+ self.generation_config = generation_config
+ self.generation_config_dict = gen_config_dict
+
+ while self.request_handler.check_unfinished_reqs():
+ output_reqs_list += self.step()
+
+ return output_reqs_list
+
+ def add_request(
+ self,
+ prompts: Union[List[str], str],
+ request_ids: Union[List[int], int] = None,
+ **kwargs,
+ ):
+ if request_ids is not None and not isinstance(request_ids, list):
+ request_ids = [request_ids]
+
+ if not isinstance(prompts, list):
+ prompts = [prompts]
+
+ generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
+ prompts_num = len(prompts)
+ for i in range(prompts_num):
+ if request_ids:
+ assert isinstance(
+ request_ids[0], int
+ ), f"The request_id type must be int, but got {type(request_ids[0])}"
+ assert len(request_ids) == prompts_num
+ request_id = request_ids[i]
+ else:
+ request_id = next(self.counter)
+
+ seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)
+
+ self.request_handler.add_sequence(seq)
+
+ def step(self) -> List[PIL.Image.Image]:
+ """
+ In each step, do the follows:
+ 1. Run RequestHandler.schedule() and get the batch used for inference.
+ 2. run forward to get List[Image]
+ Returns:
+ List[PIL.Image.Image]: Image Generated by one step.
+ """
+
+ input = self.request_handler.schedule()
+ ret = self.model(prompt=input.prompt, **input.generation_config.to_dict())
+ return ret
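A hypothetical end-to-end use of the new engine (the checkpoint id is illustrative, and `InferenceConfig` is assumed to be constructible from its defaults): instantiate `DiffusionEngine` with a diffusers pipeline path, then call `generate`, which queues the prompts and drives `step()` until the request handler reports no unfinished requests:

```python
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.diffusion_engine import DiffusionEngine

engine = DiffusionEngine(
    "PixArt-alpha/PixArt-XL-2-1024-MS",   # illustrative checkpoint, not prescribed by the patch
    inference_config=InferenceConfig(),   # assumed default-constructible
)
images = engine.generate(
    prompts="a watercolor painting of a lighthouse",
    generation_config=DiffusionGenerationConfig(num_inference_steps=20),
)
```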
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 8f8aef65e..5c9bdc321 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -1,57 +1,24 @@
-import time
-from itertools import count
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import List, Tuple, Type, Union
import numpy as np
-import torch
+import PIL.Image
import torch.nn as nn
-from torch import distributed as dist
-from transformers import (
- AutoConfig,
- AutoModelForCausalLM,
- GenerationConfig,
- PreTrainedTokenizer,
- PreTrainedTokenizerFast,
-)
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from diffusers import DiffusionPipeline
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
-from colossalai.accelerator import get_accelerator
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.inference.batch_bucket import BatchBucket
-from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
-from colossalai.inference.graph_runner import CUDAGraphRunner
-from colossalai.inference.modeling.policy import model_policy_map
-from colossalai.inference.sampler import search_tokens
-from colossalai.inference.spec import Drafter, GlideInput
-from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size, has_index_file
-from colossalai.interface import ModelWrapper
-from colossalai.lazy import LazyInitContext
-from colossalai.logging import get_dist_logger
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.utils import ModelType, get_model_type
from colossalai.shardformer.policies.base_policy import Policy
-from .request_handler import RequestHandler
-
__all__ = ["InferenceEngine"]
-PP_AXIS, TP_AXIS = 0, 1
-
-_supported_models = {
- "LlamaForCausalLM": LlamaForCausalLM,
- "BaichuanForCausalLM": AutoModelForCausalLM,
-}
-
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
-
class InferenceEngine:
"""
InferenceEngine which manages the inference process.
Args:
- model_or_path (nn.Module or str): Path or nn.Module of this model.
+ model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model.
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
@@ -60,567 +27,68 @@ class InferenceEngine:
def __init__(
self,
- model_or_path: Union[nn.Module, str],
- tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
- inference_config: InferenceConfig,
+ model_or_path: Union[nn.Module, str, DiffusionPipeline],
+ tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+ inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Union[Policy, Type[Policy]] = None,
) -> None:
- self.inference_config = inference_config
- self.dtype = inference_config.dtype
- self.high_precision = inference_config.high_precision
+ self.__dict__["_initialized"] = False # use __dict__ directly to avoid calling __setattr__
+ self.model_type = get_model_type(model_or_path=model_or_path)
+ self.engine = None
+ if self.model_type == ModelType.LLM:
+ from .llm_engine import LLMEngine
- self.verbose = verbose
- self.logger = get_dist_logger(__name__)
- self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
+ self.engine = LLMEngine(
+ model_or_path=model_or_path,
+ tokenizer=tokenizer,
+ inference_config=inference_config,
+ verbose=verbose,
+ model_policy=model_policy,
+ )
+ elif self.model_type == ModelType.DIFFUSION_MODEL:
+ from .diffusion_engine import DiffusionEngine
- self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
-
- self.generation_config = inference_config.to_generation_config(self.model_config)
- self.generation_config_dict = self.generation_config.to_dict()
-
- self.tokenizer = tokenizer
- self.tokenizer.pad_token = self.tokenizer.eos_token
-
- self.request_handler = RequestHandler(self.inference_config, self.model_config)
- self.k_cache, self.v_cache = self.request_handler.get_kvcache()
- # DISCUSS maybe move this into batch info?
-
- self.counter = count()
-
- self.use_cuda_graph = self.inference_config.use_cuda_graph
- if self.use_cuda_graph:
- self.graph_runners: Dict[int, CUDAGraphRunner] = {}
- self.graph_memory_pool = None # Set during graph capture.
- if verbose:
- self.logger.info("Colossal AI CUDA Graph Capture on")
-
- self.capture_model(self.k_cache, self.v_cache)
-
- # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
- self.use_spec_dec = self.inference_config.use_spec_dec
-
- self.drafter_model = None
- self.drafter = None
- self.use_glide = False
- self.n_spec_tokens = self.inference_config.max_n_spec_tokens
+ self.engine = DiffusionEngine(
+ model_or_path=model_or_path,
+ inference_config=inference_config,
+ verbose=verbose,
+ model_policy=model_policy,
+ )
+ elif self.model_type == ModelType.UNKNOWN:
+            self.logger.error("Model type must be either Diffusion or LLM!")
+ self._initialized = True
self._verify_args()
- def init_model(
- self,
- model_or_path: Union[nn.Module, str],
- model_policy: Union[Policy, Type[Policy]] = None,
- model_shard_infer_config: ModelShardInferenceConfig = None,
- ):
- """
- Shard model or/and Load weight
-
- Args:
- model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
- model_policy (Policy): the policy to replace the model.
- model_inference_config: the configuration for modeling initialization when inference.
- model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
- """
- pretrained_path = None
- if isinstance(model_or_path, str):
- import colossalai.interface.pretrained as pretrained_utils
-
- try:
- hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
- arch = getattr(hf_config, "architectures")[0]
- if arch in _supported_models.keys():
- if arch is "BaichuanForCausalLM":
- self.logger.warning(
- "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
- )
- ctx = LazyInitContext(default_device="cuda")
- with ctx:
- model = _supported_models[arch].from_pretrained(
- model_or_path, trust_remote_code=True, torch_dtype=self.dtype
- )
- pretrained_path = pretrained_utils.get_pretrained_path(model)
- else:
- # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
- raise ValueError(f"Model {arch} is not supported.")
-
- except Exception as e:
- self.logger.error(
- f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
- )
- else:
- model = model_or_path
-
- self.model_config = model.config
-
- torch.cuda.empty_cache()
- init_gpu_memory = torch.cuda.mem_get_info()[0]
-
- self.device = get_accelerator().get_current_device()
- if self.verbose:
- self.logger.info(f"the device is {self.device}")
-
- model = model.to(self.dtype).eval()
-
- if self.verbose:
- self.logger.info(
- f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
- )
-
- if model_policy is None:
- prefix = "nopadding" if not self.inference_config.pad_input else "padding"
- model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
- model_policy = model_policy_map.get(model_policy_key)
-
- if not isinstance(model_policy, Policy):
- try:
- model_policy = model_policy()
- except Exception as e:
- raise ValueError(f"Unable to instantiate model policy: {e}")
-
- assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
- pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
- tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
-
- self.model = self._shardformer(
- model,
- model_policy,
- model_shard_infer_config,
- None,
- tp_group=tp_group,
- )
-
- self.model = ModelWrapper(model).to(self.device)
-
- if self.verbose:
- self.logger.info(
- f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
- )
-
- if pretrained_path:
- from colossalai.inference.core.plugin import InferCheckpoint_io
-
- cpt_io = InferCheckpoint_io()
- if_has_index_file, model_index_file = has_index_file(pretrained_path)
- assert if_has_index_file, "the model path is invalid"
- cpt_io.load_model(self.model, model_index_file)
-
- free_gpu_memory, _ = torch.cuda.mem_get_info()
- peak_memory = init_gpu_memory - free_gpu_memory
- if self.verbose:
- self.logger.info(
- f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
- )
-
- @torch.inference_mode()
- def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
- assert self.use_cuda_graph, "please turn on the cuda graph"
-
- if self.verbose:
- self.logger.info("Colossal AI CUDA Graph Capture begin")
-
- t_capture_begin = time.perf_counter()
-
- block_size = self.inference_config.block_size
- head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
-
- # Prepare dummy inputs. These will be reused for all batch sizes.
- max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
- max_context_len_to_capture = self.inference_config.max_context_len_to_capture
- max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
- input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
- # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
- self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
- self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
- self.graph_block_tables[0, :] = np.arange(
- 0, max_num_blocks
- ) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
- block_tables = torch.from_numpy(self.graph_block_tables).cuda()
- output_tensor = torch.zeros(
- (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
- )
- fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
-
- max_num_seqs = self.inference_config.max_batch_size
- batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
- sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
- # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
- sequence_lengths[0] = torch.tensor(
- self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
- ).cuda()
-
- # NOTE: Capturing the largest batch size first may help reduce the
- # memory usage of CUDA graph.
- for batch_size in reversed(batch_size_capture_list):
- if self.verbose:
- self.logger.info(f"batch size {batch_size} graph capturing")
-
- input_meta_data = InputMetaData(
- block_tables=block_tables[:batch_size],
- sequence_lengths=sequence_lengths[:batch_size],
- fd_inter_tensor=fd_inter_tensor,
- batch_size=batch_size,
- is_prompts=False,
- use_cuda_graph=True,
- high_precision=False,
- kv_seq_len=sequence_lengths[:batch_size].max().item(),
- head_dim=head_dim,
- dtype=self.dtype,
- )
-
- graph_runner = CUDAGraphRunner(self.model)
- graph_runner.capture(
- input_tokens_ids[:batch_size],
- output_tensor[:batch_size],
- input_meta_data,
- k_caches=k_cache,
- v_caches=v_cache,
- memory_pool=self.graph_memory_pool,
- )
- self.graph_memory_pool = graph_runner.graph.pool()
- self.graph_runners[batch_size] = graph_runner
-
- t_capture_end = time.perf_counter()
-
- if self.verbose:
- self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
-
def _verify_args(self) -> None:
"""Verify the input args"""
- if not isinstance(self.inference_config, InferenceConfig):
- raise TypeError("Invalid type of inference config provided.")
- if not isinstance(self.model, nn.Module):
- raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
- if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
- raise TypeError(
- f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
- )
- if isinstance(self.model, ModelWrapper):
- model = self.model.module
- assert (
- model.__class__.__name__ in _supported_models.keys()
- ), f"Model {self.model.__class__.__name__} is not supported."
-
- def _shardformer(
- self,
- model: nn.Module,
- model_policy: Policy,
- model_shard_infer_config: ModelShardInferenceConfig = None,
- stage_manager: PipelineStageManager = None,
- tp_group: ProcessGroupMesh = None,
- ) -> nn.Module:
- """
- Initialize ShardConfig and replace the model with shardformer.
-
- Args:
- model (nn.Module): Path or nn.Module of this model.
- model_policy (Policy): The policy to shardformer model which is determined by the model type.
- stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
- tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
-
- Returns:
- nn.Module: The model optimized by Shardformer.
- """
-
- shardconfig = ShardConfig(
- tensor_parallel_process_group=tp_group,
- pipeline_stage_manager=stage_manager,
- enable_tensor_parallelism=(self.inference_config.tp_size > 1),
- enable_fused_normalization=False,
- enable_all_optimization=False,
- enable_flash_attention=False,
- enable_jit_fused=False,
- enable_sequence_parallelism=False,
- extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
- )
- shardformer = ShardFormer(shard_config=shardconfig)
- shard_model, _ = shardformer.optimize(model, model_policy)
- return shard_model
-
- def enable_spec_dec(
- self,
- drafter_model: nn.Module = None,
- n_spec_tokens: int = None,
- use_glide_drafter: bool = False,
- ) -> None:
- """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
-
- Args:
- drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
- If provided, the previous drafter and drafter model, if exist, will be overwritten.
- n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
- If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
- use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
- If True, the drafter model will be replaced by a glide model.
-
- ```python
- ...
- engine = InferenceEngine(model, tokenizer, inference_config)
-
- engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
- engine.generate(...) # Speculative Decoding
-
- engine.disable_spec_dec()
- engine.generate(...) # Normal generation
-
- engine.enable_spec_dec()
- engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
- engine.clear_spec_dec()
- ```
- """
-
- if drafter_model is None and self.drafter is None:
- raise ValueError("Drafter not initialized. Please provide a Drafter Model")
- if n_spec_tokens is not None:
- assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
- self.n_spec_tokens = n_spec_tokens
- if drafter_model is not None:
- assert isinstance(drafter_model, nn.Module)
- # overwrite the drafter, if exists
- self.clear_spec_dec()
- self.drafter_model = drafter_model
- self.drafter = Drafter(
- self.drafter_model,
- self.tokenizer,
- device=self.device,
- dtype=self.dtype,
- )
-
- # check if the provided drafter model is compatible with GLIDE structure
- # when `use_glide_drafter` is set to True
- if (
- use_glide_drafter
- and hasattr(drafter_model, "model")
- and hasattr(drafter_model.model, "layers")
- and hasattr(drafter_model.model.layers[0], "cross_attn")
- ):
- self.use_glide = use_glide_drafter
- elif use_glide_drafter:
- self.logger.warning(
- f"`use_glide_drafter` is provided as {use_glide_drafter}, "
- f"but the provided drafter model is not compatible with GLIDE structure."
- f"Falling back to use the default drafter model (non-GLIDE)."
- )
- self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
- # using speculative decoding for subsequent generations
- self.use_spec_dec = True
-
- def disable_spec_dec(self) -> None:
- """Disable using speculative decoding for subsequent generations."""
- self.request_handler.unset_spec_dec_mode()
- # set back to the maximum number of tokens to speculate
- self.n_spec_tokens = self.inference_config.max_n_spec_tokens
- self.use_glide = False
- self.use_spec_dec = False
-
- def clear_spec_dec(self) -> None:
- """Clear relatable structures of speculative decoding, if exist."""
- if self.use_spec_dec:
- self.disable_spec_dec()
- if self.drafter_model or self.drafter:
- self.drafter_model = None
- self.drafter = None
- torch.cuda.empty_cache()
- self.use_glide = False
- self.use_spec_dec = False
-
- def steps_spec_dec(self) -> List[Sequence]:
- """
- Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
- with many steps of speculating by a drafter model as well as verifying by a main model.
-
- Returns:
- List[Sequence]: finished sequences generated by one step.
- """
- batch = self.request_handler.schedule() # prefill batch
- assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
-
- input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
-
- if input_meta_data.use_cuda_graph:
- model_executable = self.graph_runners[input_meta_data.batch_size]
- else:
- model_executable = self.model
-
- # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
- # NOTE For glide drafter models, we won't actually apply glide during prefill stage
- drafter_out = self.drafter.speculate(input_token_ids, 1, None)
- next_token_ids_spec = drafter_out.next_tokens
- drafter_past_key_values = drafter_out.past_key_values
-
- # 2. Prefill main model (Verifier) - fill past kv cache for main model
- logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
- next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
- # append new inputs to the batch, temporarily
- batch.append_batch_tokens(next_tokens)
- self.request_handler.allocate_batch_spec_dec(batch, 1)
- already_allocated_kv_len = batch.seq_lengths[0].item()
- input_token_ids = batch.get_1D_inputs_spec_dec(1)
-
- finished_sequences = self.request_handler.update()
-
- while True:
- # HACK Retrieve the running batch
- # Using RequestHandler.schedule here will re-allocate same kv cache for the batch
- batch = self.request_handler.running_bb # running batch
- assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
-
- # 3. Decoding - Drafter model speculates `n` tokens
- glide_input = None
- if self.use_glide:
- glide_input = GlideInput(
- batch.get_block_table_tensor(),
- self.k_cache[-1], # use kv cahces of the last layer
- self.v_cache[-1],
- batch.get_sequence_lengths(),
- n_spec_tokens=self.n_spec_tokens,
- )
-
- drafter_out = self.drafter.speculate(
- input_token_ids,
- self.n_spec_tokens,
- drafter_past_key_values,
- glide_input=glide_input,
- )
- next_token_ids_spec = drafter_out.next_tokens
- drafter_past_key_values = drafter_out.past_key_values
- drafter_spec_length = drafter_out.speculated_length
-
- for next_token_id_spec in next_token_ids_spec:
- self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
- cur_length = batch.seq_lengths[0].item()
- if already_allocated_kv_len < cur_length:
- self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
- already_allocated_kv_len = cur_length
-
- # 4. Decoding - Main model verifies `n` tokens in parallel
- if drafter_spec_length < batch.num_tokens_to_verify:
- batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
- input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
- logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-
- next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
-
- # 5. Compare and process the results
- diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
- n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
-
- # revoke appended tokens for each Sequence in the current batch
- batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
-
- # append the last correct token generated by the main model
- self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
-
- # trim past key values of the drafter model
- drafter_past_key_values = Drafter.trim_kv_cache(
- drafter_past_key_values, drafter_spec_length - n_matches - 1
- )
-
- # prepare inputs for the next round of speculation
- n = 1 if n_matches < drafter_spec_length else 2
- input_token_ids = batch.get_1D_inputs_spec_dec(n)
-
- self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
- finished_sequences = self.request_handler.update()
- if len(finished_sequences) > 0:
- break
-
- # Reset back the number of speculated tokens of the batch,
- # this is used to handle the last round of speculation, in which case the number of speculated tokens
- # by the drafter is less than the number of speculated tokens set to the engine.
- batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
-
- return finished_sequences
+ assert self.engine is not None, "Please init Engine first"
+ assert self._initialized, "Engine must be initialized"
def generate(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
- prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
- return_token_ids: bool = False,
- generation_config: Optional[GenerationConfig] = None,
- ) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
+ *args,
+ **kwargs,
+ ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
"""
Executing the inference step.
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
- prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
- return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
- generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
-
- Returns:
- Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
"""
- gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
- prompts = [prompts] if isinstance(prompts, str) else prompts
- request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
-
- with torch.inference_mode():
- if prompts is not None or prompts_token_ids is not None:
- self.add_request(
- request_ids=request_ids,
- prompts=prompts,
- prompts_token_ids=prompts_token_ids,
- **gen_config_dict,
- )
-
- output_seqs_list = []
- total_tokens_list = []
-
- # intuition: If user provide a generation config, we should replace the existing one.
- if generation_config is not None:
- self.generation_config = generation_config
- self.generation_config_dict = gen_config_dict
-
- if self.use_spec_dec:
- assert self.drafter is not None, "Drafter Model is not initialized."
- while self.request_handler.check_unfinished_seqs():
- output_seqs_list += self.steps_spec_dec()
- else:
- while self.request_handler.check_unfinished_seqs():
- output_seqs_list += self.step()
-
- output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
-
- for seq in output_seqs_list:
- total_tokens_list.append(seq.input_token_id + seq.output_token_id)
-
- output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
-
- if return_token_ids:
- output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
- return output_str, output_tokens_list
- else:
- return output_str
-
- @property
- def has_prompt_template(self) -> bool:
- """ """
- return self.inference_config.prompt_template is not None
-
- def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
- """
- This method will format the input prompt according to the prompt template given to the InferenceConfig.
- """
- assert (
- self.has_prompt_template
- ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
-
- if isinstance(prompts, (list, tuple)):
- return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
- elif isinstance(prompts, str):
- return self.inference_config.prompt_template.format(input_text=prompts)
- else:
- raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
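+        # A minimal usage sketch (hedged; the model/tokenizer/config objects are assumed to be created elsewhere):
+        #   engine = InferenceEngine(model_or_path, tokenizer, inference_config)
+        #   outputs = engine.generate(prompts=["Hello, world"])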
+ assert self.engine is not None, "Please init Engine first"
+ return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs)
def add_request(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
- prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+ *args,
**kwargs,
) -> None:
"""
@@ -630,168 +98,36 @@ class InferenceEngine:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
+            kwargs: For LLMs, generation options such as max_length or max_new_tokens.
+                For diffusion models, options aligned with diffusers, e.g. prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, clip_skip.
"""
+ assert self.engine is not None, "Please init Engine first"
+ self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs)
- # apply the prompt template to the input prompts
+ def step(self):
+ assert self.engine is not None, "Please init Engine first"
+ return self.engine.step()
- if self.has_prompt_template and prompts is not None:
- prompts = self.format_prompt(prompts)
-
- block_size = self.inference_config.block_size
-
- if request_ids is not None and not isinstance(request_ids, list):
- request_ids = [request_ids]
-
- if prompts is not None and not isinstance(prompts, list):
- prompts = [prompts]
-
- if prompts_token_ids is None:
- assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
- prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
- "input_ids"
- ]
-
- # list of torch Tensor
- if isinstance(prompts_token_ids, list):
- if isinstance(prompts_token_ids[0], torch.Tensor):
- prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
- elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
- prompts_token_ids = prompts_token_ids.tolist()
- else:
- raise TypeError(
- f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
- )
-
- assert (
- len(prompts_token_ids[0]) <= self.inference_config.max_input_len
- ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
-
- prompts_num = len(prompts_token_ids)
-
- for i in range(prompts_num):
- if request_ids:
- assert isinstance(
- request_ids[0], int
- ), f"The request_id type must be int, but got {type(request_ids[0])}"
- assert len(request_ids) == prompts_num
- request_id = request_ids[i]
+ def __getattr__(self, name):
+ """
+        Design of __getattr__ / __setattr__:
+        1. InferenceEngine is a wrapper around DiffusionEngine/LLMEngine, so members of the wrapped engine should be accessible as if they were members of InferenceEngine itself.
+        2. During __init__ we still want plain `self.xxx = xxx` assignments rather than writing to self.__dict__ directly, so the `_initialized` flag is set first. After initialization, any attribute not found on InferenceEngine is looked up on self.engine (DiffusionEngine/LLMEngine).
+ """
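+        # For example (illustrative): after init, `engine.tokenizer` is not set on the wrapper and is
+        # resolved from `self.engine` (LLMEngine), while `engine.model_type` is found in the wrapper's own __dict__.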
+ if self.__dict__.get("_initialized", False):
+ if name in self.__dict__:
+ return self.__dict__[name]
else:
- request_id = next(self.counter)
- if prompts == None:
- prompt = None
+ return getattr(self.engine, name)
+ else:
+ return self.__dict__[name]
+
+ def __setattr__(self, name, value):
+ if self.__dict__.get("_initialized", False):
+ if name in self.__dict__:
+ self.__dict__[name] = value
else:
- prompt = prompts[i]
-
- max_length = kwargs.get("max_length", None)
- max_new_tokens = kwargs.get("max_new_tokens", None)
- if max_length is None and max_new_tokens is None:
- max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
- elif max_length is not None:
- max_new_tokens = max_length - len(prompts_token_ids[i])
-
- if not self.inference_config.enable_streamingllm:
- assert (
- self.inference_config.max_output_len >= max_new_tokens
- ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
-
- sequence = Sequence(
- request_id,
- prompt,
- prompts_token_ids[i],
- block_size,
- None,
- self.tokenizer.eos_token_id,
- self.tokenizer.pad_token_id,
- max_output_len=max_new_tokens,
- ignore_eos=self.inference_config.ignore_eos,
- )
- self.request_handler.add_sequence(sequence)
-
- def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
- input_ids = batch.get_1D_inputs()
- sequence_lengths = batch.get_sequence_lengths()
-
- if batch.is_prompts:
- n_tokens = sequence_lengths.sum().item()
+ setattr(self.engine, name, value)
else:
- n_tokens = batch.current_batch_size
- if batch.use_spec_dec:
- n_tokens = batch.num_tokens_to_verify + 1
- assert n_tokens == input_ids.size(0)
- n_tokens = n_tokens * batch.current_batch_size
- output_tensor = torch.zeros(
- (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
- )
-
- batch_token_ids = None
- if (
- self.generation_config.repetition_penalty != 1.0
- or self.generation_config.no_repeat_ngram_size > 0
- or self.generation_config.forced_eos_token_id is not None
- ):
- batch_token_ids = batch.batch_token_ids
-
- # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
- use_cuda_graph = False
- if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
- use_cuda_graph = True
-
- input_meta_data = InputMetaData(
- block_tables=batch.get_block_table_tensor(),
- sequence_lengths=sequence_lengths,
- fd_inter_tensor=batch.fd_inter_tensor,
- batch_size=batch.current_batch_size,
- is_prompts=batch.is_prompts,
- use_cuda_kernel=self.inference_config.use_cuda_kernel,
- use_cuda_graph=use_cuda_graph,
- high_precision=self.high_precision,
- kv_seq_len=sequence_lengths.max().item(),
- head_dim=batch.head_dim,
- dtype=batch.dtype,
- use_spec_dec=batch.use_spec_dec,
- num_tokens_to_verify=batch.num_tokens_to_verify,
- batch_token_ids=batch_token_ids,
- )
-
- return input_ids, output_tensor, input_meta_data
-
- def step(self) -> List[str]:
- """
- In each step, do the follows:
- 1. Run RequestHandler.schedule() and get the batch used for inference.
- 2. Get the input, inputinfo and output placeholder from the batchbucket
- 3. Run model to generate the next token
- 4. Update waiting list and running list in RequestHandler and get finished sequences.
- 5. Decode and return finished sequences.
-
- Returns:
- List[str]: Decoded finished sequences generated by one step.
- """
-
- batch = self.request_handler.schedule()
-
- input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
-
- if input_meta_data.use_cuda_graph:
- model_executable = self.graph_runners[input_meta_data.batch_size]
- else:
- model_executable = self.model
-
- # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
- logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
- if self.inference_config.pad_input:
- logits = logits[:, -1, :]
-
- if self.inference_config.enable_streamingllm:
- updated_block_ids = batch.streamingllm_update_batch(
- self.inference_config.start_token_size, self.inference_config.generated_token_size
- )
- self.request_handler.streamingllm_free_block_tables(updated_block_ids)
-
- next_tokens = search_tokens(
- self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
- )
- self.request_handler.append_next_tokens(next_tokens)
- finished_sequences = self.request_handler.update()
-
- return finished_sequences
+ self.__dict__[name] = value
diff --git a/colossalai/inference/core/llm_engine.py b/colossalai/inference/core/llm_engine.py
new file mode 100644
index 000000000..b973d371d
--- /dev/null
+++ b/colossalai/inference/core/llm_engine.py
@@ -0,0 +1,758 @@
+import time
+from itertools import count
+from typing import Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import distributed as dist
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ GenerationConfig,
+ PreTrainedTokenizer,
+ PreTrainedTokenizerFast,
+)
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+
+from colossalai.accelerator import get_accelerator
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
+from colossalai.inference.graph_runner import CUDAGraphRunner
+from colossalai.inference.modeling.policy import model_policy_map
+from colossalai.inference.sampler import search_tokens
+from colossalai.inference.spec import Drafter, GlideInput
+from colossalai.inference.struct import Sequence
+from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.interface import ModelWrapper
+from colossalai.lazy import LazyInitContext
+from colossalai.logging import get_dist_logger
+from colossalai.shardformer.policies.base_policy import Policy
+
+from .base_engine import BaseEngine
+from .request_handler import RequestHandler
+
+PP_AXIS, TP_AXIS = 0, 1
+
+_supported_models = {
+ "LlamaForCausalLM": LlamaForCausalLM,
+ "BaichuanForCausalLM": AutoModelForCausalLM,
+}
+
+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
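+# NOTE: CUDA graphs are captured for batch sizes 1, 2, 4 and every multiple of 8 up to 256 (see capture_model).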
+
+
+class LLMEngine(BaseEngine):
+ """
+    LLMEngine manages the inference process for LLM models.
+
+    Args:
+        model_or_path (nn.Module or str): The model instance or a path to load it from.
+        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast], optional): The tokenizer used for encoding and decoding.
+        inference_config (InferenceConfig, optional): The configuration information related to inference.
+        verbose (bool): Whether or not to log the generation process.
+        model_policy (Policy or Type[Policy], optional): The policy used to shard the model. It will be determined by the model type if not provided.
+ """
+
+ def __init__(
+ self,
+        model_or_path: Union[nn.Module, str],
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+        inference_config: InferenceConfig = None,
+        verbose: bool = False,
+        model_policy: Union[Policy, Type[Policy]] = None,
+ ) -> None:
+ self.inference_config = inference_config
+ self.dtype = inference_config.dtype
+ self.high_precision = inference_config.high_precision
+
+ self.verbose = verbose
+ self.logger = get_dist_logger(__name__)
+ self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
+
+ self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
+
+ self.generation_config = inference_config.to_generation_config(self.model_config)
+ self.generation_config_dict = self.generation_config.to_dict()
+
+ self.tokenizer = tokenizer
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ self.request_handler = RequestHandler(self.inference_config, self.model_config)
+ self.k_cache, self.v_cache = self.request_handler.get_kvcache()
+ # DISCUSS maybe move this into batch info?
+
+ self.counter = count()
+
+ self.use_cuda_graph = self.inference_config.use_cuda_graph
+ if self.use_cuda_graph:
+ self.graph_runners: Dict[int, CUDAGraphRunner] = {}
+ self.graph_memory_pool = None # Set during graph capture.
+ if verbose:
+ self.logger.info("Colossal AI CUDA Graph Capture on")
+
+ self.capture_model(self.k_cache, self.v_cache)
+
+ # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
+ self.use_spec_dec = self.inference_config.use_spec_dec
+
+ self.drafter_model = None
+ self.drafter = None
+ self.use_glide = False
+ self.n_spec_tokens = self.inference_config.max_n_spec_tokens
+
+ self._verify_args()
+
+ def init_model(
+ self,
+ model_or_path: Union[nn.Module, str],
+ model_policy: Union[Policy, Type[Policy]] = None,
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ """
+        Shard the model and/or load its weights.
+
+ Args:
+            model_or_path (Union[nn.Module, str]): Path to the checkpoint or a model in transformers format.
+            model_policy (Policy): The policy used to replace (shard) the model.
+            model_shard_infer_config (ModelShardInferenceConfig): The configuration used to initialize modules for inference.
+ """
+ pretrained_path = None
+ if isinstance(model_or_path, str):
+ import colossalai.interface.pretrained as pretrained_utils
+
+ try:
+ hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
+ arch = getattr(hf_config, "architectures")[0]
+ if arch in _supported_models.keys():
+ if arch == "BaichuanForCausalLM":
+ self.logger.warning(
+                            "Note: lazy init is used by default, which can speed up model loading. For Baichuan models, the output may differ slightly from transformers."
+ )
+ ctx = LazyInitContext(default_device="cuda")
+ with ctx:
+ model = _supported_models[arch].from_pretrained(
+ model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+ )
+ pretrained_path = pretrained_utils.get_pretrained_path(model)
+ else:
+ # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
+ raise ValueError(f"Model {arch} is not supported.")
+
+ except Exception as e:
+ self.logger.error(
+                    f"An exception occurred while loading the model: {e}. The model should be loaded with transformers.\n"
+ )
+ else:
+ model = model_or_path
+
+ self.model_config = model.config
+
+ torch.cuda.empty_cache()
+ init_gpu_memory = torch.cuda.mem_get_info()[0]
+
+ self.device = get_accelerator().get_current_device()
+ if self.verbose:
+ self.logger.info(f"the device is {self.device}")
+
+ model = model.to(self.dtype).eval()
+
+ if self.verbose:
+ self.logger.info(
+ f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
+ )
+
+ if model_policy is None:
+ prefix = "nopadding" if not self.inference_config.pad_input else "padding"
+ model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
+ model_policy = model_policy_map.get(model_policy_key)
+
+ if not isinstance(model_policy, Policy):
+ try:
+ model_policy = model_policy()
+ except Exception as e:
+ raise ValueError(f"Unable to instantiate model policy: {e}")
+
+ assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
+ pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
+ tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
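+        # A (pp, tp) process-group mesh is built above; only the TP group is needed for sharding here.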
+
+ self.model = self._shardformer(
+ model,
+ model_policy,
+ model_shard_infer_config,
+ None,
+ tp_group=tp_group,
+ )
+
+ self.model = ModelWrapper(model).to(self.device)
+
+ if self.verbose:
+ self.logger.info(
+ f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
+ )
+
+ if pretrained_path:
+ from colossalai.inference.core.plugin import InferCheckpoint_io
+
+ cpt_io = InferCheckpoint_io()
+ if_has_index_file, model_index_file = has_index_file(pretrained_path)
+ assert if_has_index_file, "the model path is invalid"
+ cpt_io.load_model(self.model, model_index_file)
+
+ free_gpu_memory, _ = torch.cuda.mem_get_info()
+ peak_memory = init_gpu_memory - free_gpu_memory
+ if self.verbose:
+ self.logger.info(
+ f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
+ )
+
+ @torch.inference_mode()
+ def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
+        assert self.use_cuda_graph, "CUDA graph must be enabled (use_cuda_graph=True) before capturing the model"
+
+ if self.verbose:
+ self.logger.info("Colossal AI CUDA Graph Capture begin")
+
+ t_capture_begin = time.perf_counter()
+
+ block_size = self.inference_config.block_size
+ head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
+
+ # Prepare dummy inputs. These will be reused for all batch sizes.
+ max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
+ max_context_len_to_capture = self.inference_config.max_context_len_to_capture
+ max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
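+        # i.e. ceil(max_context_len_to_capture / block_size) KV-cache blocks per sequence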
+ input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
+ # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+ self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
+ self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
+ self.graph_block_tables[0, :] = np.arange(
+ 0, max_num_blocks
+        )  # NOTE this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
+ block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+ output_tensor = torch.zeros(
+ (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
+ )
+ fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
+
+ max_num_seqs = self.inference_config.max_batch_size
+ batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
+ sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
+        # NOTE this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
+ sequence_lengths[0] = torch.tensor(
+ self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
+ ).cuda()
+
+ # NOTE: Capturing the largest batch size first may help reduce the
+ # memory usage of CUDA graph.
+ for batch_size in reversed(batch_size_capture_list):
+ if self.verbose:
+ self.logger.info(f"batch size {batch_size} graph capturing")
+
+ input_meta_data = InputMetaData(
+ block_tables=block_tables[:batch_size],
+ sequence_lengths=sequence_lengths[:batch_size],
+ fd_inter_tensor=fd_inter_tensor,
+ batch_size=batch_size,
+ is_prompts=False,
+ use_cuda_graph=True,
+ high_precision=False,
+ kv_seq_len=sequence_lengths[:batch_size].max().item(),
+ head_dim=head_dim,
+ dtype=self.dtype,
+ )
+
+ graph_runner = CUDAGraphRunner(self.model)
+ graph_runner.capture(
+ input_tokens_ids[:batch_size],
+ output_tensor[:batch_size],
+ input_meta_data,
+ k_caches=k_cache,
+ v_caches=v_cache,
+ memory_pool=self.graph_memory_pool,
+ )
+ self.graph_memory_pool = graph_runner.graph.pool()
+ self.graph_runners[batch_size] = graph_runner
+
+ t_capture_end = time.perf_counter()
+
+ if self.verbose:
+ self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
+
+ def _verify_args(self) -> None:
+ """Verify the input args"""
+ if not isinstance(self.inference_config, InferenceConfig):
+ raise TypeError("Invalid type of inference config provided.")
+ if not isinstance(self.model, nn.Module):
+ raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
+ if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
+ raise TypeError(
+ f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
+ )
+ if isinstance(self.model, ModelWrapper):
+ model = self.model.module
+ assert (
+ model.__class__.__name__ in _supported_models.keys()
+ ), f"Model {self.model.__class__.__name__} is not supported."
+
+ def enable_spec_dec(
+ self,
+ drafter_model: nn.Module = None,
+ n_spec_tokens: int = None,
+ use_glide_drafter: bool = False,
+ ) -> None:
+ """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
+
+ Args:
+ drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
+ If provided, the previous drafter and drafter model, if exist, will be overwritten.
+ n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
+ If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
+ use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
+ If True, the drafter model will be replaced by a glide model.
+
+ ```python
+ ...
+ engine = InferenceEngine(model, tokenizer, inference_config)
+
+ engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
+ engine.generate(...) # Speculative Decoding
+
+ engine.disable_spec_dec()
+ engine.generate(...) # Normal generation
+
+ engine.enable_spec_dec()
+ engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
+ engine.clear_spec_dec()
+ ```
+ """
+
+ if drafter_model is None and self.drafter is None:
+ raise ValueError("Drafter not initialized. Please provide a Drafter Model")
+ if n_spec_tokens is not None:
+ assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
+ self.n_spec_tokens = n_spec_tokens
+ if drafter_model is not None:
+ assert isinstance(drafter_model, nn.Module)
+ # overwrite the drafter, if exists
+ self.clear_spec_dec()
+ self.drafter_model = drafter_model
+ self.drafter = Drafter(
+ self.drafter_model,
+ self.tokenizer,
+ device=self.device,
+ dtype=self.dtype,
+ )
+
+ # check if the provided drafter model is compatible with GLIDE structure
+ # when `use_glide_drafter` is set to True
+ if (
+ use_glide_drafter
+ and hasattr(drafter_model, "model")
+ and hasattr(drafter_model.model, "layers")
+ and hasattr(drafter_model.model.layers[0], "cross_attn")
+ ):
+ self.use_glide = use_glide_drafter
+ elif use_glide_drafter:
+ self.logger.warning(
+ f"`use_glide_drafter` is provided as {use_glide_drafter}, "
+ f"but the provided drafter model is not compatible with GLIDE structure."
+ f"Falling back to use the default drafter model (non-GLIDE)."
+ )
+ self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
+ # using speculative decoding for subsequent generations
+ self.use_spec_dec = True
+
+ def disable_spec_dec(self) -> None:
+ """Disable using speculative decoding for subsequent generations."""
+ self.request_handler.unset_spec_dec_mode()
+ # set back to the maximum number of tokens to speculate
+ self.n_spec_tokens = self.inference_config.max_n_spec_tokens
+ self.use_glide = False
+ self.use_spec_dec = False
+
+ def clear_spec_dec(self) -> None:
+ """Clear relatable structures of speculative decoding, if exist."""
+ if self.use_spec_dec:
+ self.disable_spec_dec()
+ if self.drafter_model or self.drafter:
+ self.drafter_model = None
+ self.drafter = None
+ torch.cuda.empty_cache()
+ self.use_glide = False
+ self.use_spec_dec = False
+
+ def steps_spec_dec(self) -> List[Sequence]:
+ """
+        Run speculative decoding steps: retrieve a single batch and run inference with repeated rounds
+        of speculation by the drafter model, each followed by verification by the main model.
+
+ Returns:
+ List[Sequence]: finished sequences generated by one step.
+ """
+ batch = self.request_handler.schedule() # prefill batch
+ assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
+
+ input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
+ if input_meta_data.use_cuda_graph:
+ model_executable = self.graph_runners[input_meta_data.batch_size]
+ else:
+ model_executable = self.model
+
+ # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
+ # NOTE For glide drafter models, we won't actually apply glide during prefill stage
+ drafter_out = self.drafter.speculate(input_token_ids, 1, None)
+ next_token_ids_spec = drafter_out.next_tokens
+ drafter_past_key_values = drafter_out.past_key_values
+
+ # 2. Prefill main model (Verifier) - fill past kv cache for main model
+ logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+ next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
+ # append new inputs to the batch, temporarily
+ batch.append_batch_tokens(next_tokens)
+ self.request_handler.allocate_batch_spec_dec(batch, 1)
+ already_allocated_kv_len = batch.seq_lengths[0].item()
+ input_token_ids = batch.get_1D_inputs_spec_dec(1)
+
+ finished_sequences = self.request_handler.update()
+
+ while True:
+ # HACK Retrieve the running batch
+ # Using RequestHandler.schedule here will re-allocate same kv cache for the batch
+ batch = self.request_handler.running_bb # running batch
+ assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
+
+ # 3. Decoding - Drafter model speculates `n` tokens
+ glide_input = None
+ if self.use_glide:
+ glide_input = GlideInput(
+ batch.get_block_table_tensor(),
+                    self.k_cache[-1],  # use the KV caches of the last layer
+ self.v_cache[-1],
+ batch.get_sequence_lengths(),
+ n_spec_tokens=self.n_spec_tokens,
+ )
+
+ drafter_out = self.drafter.speculate(
+ input_token_ids,
+ self.n_spec_tokens,
+ drafter_past_key_values,
+ glide_input=glide_input,
+ )
+ next_token_ids_spec = drafter_out.next_tokens
+ drafter_past_key_values = drafter_out.past_key_values
+ drafter_spec_length = drafter_out.speculated_length
+
+ for next_token_id_spec in next_token_ids_spec:
+ self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
+ cur_length = batch.seq_lengths[0].item()
+ if already_allocated_kv_len < cur_length:
+ self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
+ already_allocated_kv_len = cur_length
+
+ # 4. Decoding - Main model verifies `n` tokens in parallel
+ if drafter_spec_length < batch.num_tokens_to_verify:
+ batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
+ input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+ logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+
+ next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
+
+ # 5. Compare and process the results
+ diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
+ n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
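+            # n_matches is the number of leading drafted tokens confirmed by the verifier's predictions.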
+
+ # revoke appended tokens for each Sequence in the current batch
+ batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
+
+ # append the last correct token generated by the main model
+ self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
+
+ # trim past key values of the drafter model
+ drafter_past_key_values = Drafter.trim_kv_cache(
+ drafter_past_key_values, drafter_spec_length - n_matches - 1
+ )
+
+ # prepare inputs for the next round of speculation
+ n = 1 if n_matches < drafter_spec_length else 2
+ input_token_ids = batch.get_1D_inputs_spec_dec(n)
+
+ self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
+ finished_sequences = self.request_handler.update()
+ if len(finished_sequences) > 0:
+ break
+
+ # Reset back the number of speculated tokens of the batch,
+ # this is used to handle the last round of speculation, in which case the number of speculated tokens
+ # by the drafter is less than the number of speculated tokens set to the engine.
+ batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
+
+ return finished_sequences
+
+ def generate(
+ self,
+ request_ids: Union[List[int], int] = None,
+ prompts: Union[List[str], str] = None,
+ prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+ return_token_ids: bool = False,
+ generation_config: Optional[GenerationConfig] = None,
+ ) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
+ """
+ Executing the inference step.
+
+ Args:
+ request_ids (List[int], optional): The request ID. Defaults to None.
+            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
+ prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
+ return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
+ generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
+
+ Returns:
+ Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
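+
+        Example (a minimal sketch; `model`, `tokenizer` and `inference_config` are assumed to be created elsewhere):
+
+        ```python
+        engine = LLMEngine(model, tokenizer, inference_config)
+        out_texts = engine.generate(prompts=["Hello, my name is"])
+        out_texts, out_token_ids = engine.generate(prompts=["Hello"], return_token_ids=True)
+        ```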
+ """
+
+ gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+ prompts = [prompts] if isinstance(prompts, str) else prompts
+ request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
+
+ with torch.inference_mode():
+ if prompts is not None or prompts_token_ids is not None:
+ self.add_request(
+ request_ids=request_ids,
+ prompts=prompts,
+ prompts_token_ids=prompts_token_ids,
+ **gen_config_dict,
+ )
+
+ output_seqs_list = []
+ total_tokens_list = []
+
+ # intuition: If user provide a generation config, we should replace the existing one.
+ if generation_config is not None:
+ self.generation_config = generation_config
+ self.generation_config_dict = gen_config_dict
+
+ if self.use_spec_dec:
+ assert self.drafter is not None, "Drafter Model is not initialized."
+ while self.request_handler.check_unfinished_reqs():
+ output_seqs_list += self.steps_spec_dec()
+ else:
+ while self.request_handler.check_unfinished_reqs():
+ output_seqs_list += self.step()
+
+ output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
+
+ for seq in output_seqs_list:
+ total_tokens_list.append(seq.input_token_id + seq.output_token_id)
+
+ output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
+
+ if return_token_ids:
+ output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
+ return output_str, output_tokens_list
+ else:
+ return output_str
+
+ @property
+ def has_prompt_template(self) -> bool:
+        """Whether a prompt template is set in the InferenceConfig."""
+ return self.inference_config.prompt_template is not None
+
+ def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
+ """
+ This method will format the input prompt according to the prompt template given to the InferenceConfig.
+ """
+ assert (
+ self.has_prompt_template
+ ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
+
+ if isinstance(prompts, (list, tuple)):
+ return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
+ elif isinstance(prompts, str):
+ return self.inference_config.prompt_template.format(input_text=prompts)
+ else:
+ raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
+
+ def add_request(
+ self,
+ request_ids: Union[List[int], int] = None,
+ prompts: Union[List[str], str] = None,
+ prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Add requests.
+
+ Args:
+ request_ids (List[int], optional): The request ID. Defaults to None.
+            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
+ prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
+ """
+
+ # apply the prompt template to the input prompts
+
+ if self.has_prompt_template and prompts is not None:
+ prompts = self.format_prompt(prompts)
+
+ block_size = self.inference_config.block_size
+
+ if request_ids is not None and not isinstance(request_ids, list):
+ request_ids = [request_ids]
+
+ if prompts is not None and not isinstance(prompts, list):
+ prompts = [prompts]
+
+ if prompts_token_ids is None:
+            assert prompts, "When prompts_token_ids is None, the input prompts must be provided."
+ prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
+ "input_ids"
+ ]
+
+ # list of torch Tensor
+ if isinstance(prompts_token_ids, list):
+ if isinstance(prompts_token_ids[0], torch.Tensor):
+ prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
+ elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
+ prompts_token_ids = prompts_token_ids.tolist()
+ else:
+ raise TypeError(
+ f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
+ )
+
+ assert (
+ len(prompts_token_ids[0]) <= self.inference_config.max_input_len
+        ), f"The length of input prompts {len(prompts_token_ids[0])} must not exceed max_input_len {self.inference_config.max_input_len}."
+
+ prompts_num = len(prompts_token_ids)
+
+ for i in range(prompts_num):
+ if request_ids:
+ assert isinstance(
+ request_ids[0], int
+ ), f"The request_id type must be int, but got {type(request_ids[0])}"
+ assert len(request_ids) == prompts_num
+ request_id = request_ids[i]
+ else:
+ request_id = next(self.counter)
+            if prompts is None:
+ prompt = None
+ else:
+ prompt = prompts[i]
+
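+            # Resolve the generation budget: use max_new_tokens if given, otherwise derive it from
+            # max_length minus the prompt length, otherwise fall back to the configured defaults.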
+ max_length = kwargs.get("max_length", None)
+ max_new_tokens = kwargs.get("max_new_tokens", None)
+ if max_length is None and max_new_tokens is None:
+ max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
+ elif max_length is not None:
+ max_new_tokens = max_length - len(prompts_token_ids[i])
+
+ if not self.inference_config.enable_streamingllm:
+ assert (
+ self.inference_config.max_output_len >= max_new_tokens
+                ), f"max_new_tokens={max_new_tokens} must not exceed max_output_len={self.inference_config.max_output_len}."
+
+ sequence = Sequence(
+ request_id,
+ prompt,
+ prompts_token_ids[i],
+ block_size,
+ None,
+ self.tokenizer.eos_token_id,
+ self.tokenizer.pad_token_id,
+ max_output_len=max_new_tokens,
+ ignore_eos=self.inference_config.ignore_eos,
+ )
+ self.request_handler.add_sequence(sequence)
+
+ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
+ input_ids = batch.get_1D_inputs()
+ sequence_lengths = batch.get_sequence_lengths()
+
+ if batch.is_prompts:
+ n_tokens = sequence_lengths.sum().item()
+ else:
+ n_tokens = batch.current_batch_size
+ if batch.use_spec_dec:
+ n_tokens = batch.num_tokens_to_verify + 1
+ assert n_tokens == input_ids.size(0)
+ n_tokens = n_tokens * batch.current_batch_size
+ output_tensor = torch.zeros(
+ (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
+ )
+
+ batch_token_ids = None
+ if (
+ self.generation_config.repetition_penalty != 1.0
+ or self.generation_config.no_repeat_ngram_size > 0
+ or self.generation_config.forced_eos_token_id is not None
+ ):
+ batch_token_ids = batch.batch_token_ids
+
+ # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
+ use_cuda_graph = False
+ if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
+ use_cuda_graph = True
+
+ input_meta_data = InputMetaData(
+ block_tables=batch.get_block_table_tensor(),
+ sequence_lengths=sequence_lengths,
+ fd_inter_tensor=batch.fd_inter_tensor,
+ batch_size=batch.current_batch_size,
+ is_prompts=batch.is_prompts,
+ use_cuda_kernel=self.inference_config.use_cuda_kernel,
+ use_cuda_graph=use_cuda_graph,
+ high_precision=self.high_precision,
+ kv_seq_len=sequence_lengths.max().item(),
+ head_dim=batch.head_dim,
+ dtype=batch.dtype,
+ use_spec_dec=batch.use_spec_dec,
+ num_tokens_to_verify=batch.num_tokens_to_verify,
+ batch_token_ids=batch_token_ids,
+ )
+
+ return input_ids, output_tensor, input_meta_data
+
+ def step(self) -> List[str]:
+ """
+        In each step, do the following:
+        1. Run RequestHandler.schedule() and get the batch used for inference.
+        2. Get the input ids, input metadata, and output placeholder from the BatchBucket.
+        3. Run the model to generate the next tokens.
+        4. Update the waiting and running lists in RequestHandler and collect finished sequences.
+        5. Decode and return the finished sequences.
+
+ Returns:
+ List[str]: Decoded finished sequences generated by one step.
+ """
+
+ batch = self.request_handler.schedule()
+
+ input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
+ if input_meta_data.use_cuda_graph:
+ model_executable = self.graph_runners[input_meta_data.batch_size]
+ else:
+ model_executable = self.model
+
+ # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
+ logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+ if self.inference_config.pad_input:
+ logits = logits[:, -1, :]
+
+ if self.inference_config.enable_streamingllm:
+ updated_block_ids = batch.streamingllm_update_batch(
+ self.inference_config.start_token_size, self.inference_config.generated_token_size
+ )
+ self.request_handler.streamingllm_free_block_tables(updated_block_ids)
+
+ next_tokens = search_tokens(
+ self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
+ )
+ self.request_handler.append_next_tokens(next_tokens)
+ finished_sequences = self.request_handler.update()
+
+ return finished_sequences
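For orientation, a minimal driver sketch of how `step()` above is typically iterated; the `add_request` and `request_handler` names are taken from this diff, while the loop itself is an illustrative assumption rather than part of the patch:

    # Hypothetical driver loop; engine/method names assumed from this diff.
    def run_until_done(engine, prompts):
        engine.add_request(prompts=prompts)
        finished = []
        while engine.request_handler.check_unfinished_reqs():
            finished.extend(engine.step())  # schedule -> prepare_input -> forward -> sample -> update
        return finished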
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 512eaea71..393347c31 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
-from colossalai.inference.struct import RequestStatus, Sequence
+from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
@@ -98,7 +98,46 @@ class RunningList:
self._decoding[seq_id] = self._prefill.pop(seq_id)
-class RequestHandler:
+class NaiveRequestHandler:
+ def __init__(self) -> None:
+ self.running_list: List[DiffusionSequence] = []
+        self.waiting_list: List[DiffusionSequence] = []
+
+ def _has_waiting(self) -> bool:
+ return any(lst for lst in self.waiting_list)
+
+ def _has_running(self) -> bool:
+ return any(lst for lst in self.running_list)
+
+ def check_unfinished_reqs(self):
+ return self._has_waiting() or self._has_running()
+
+ def add_sequence(self, seq: DiffusionSequence):
+ """
+ Add the request to waiting list.
+ """
+ assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
+ self.waiting_list.append(seq)
+
+ def _find_sequence(self, request_id: int) -> DiffusionSequence:
+ """
+ Find the request by request_id.
+ """
+        for seq in self.waiting_list + self.running_list:
+            if seq.request_id == request_id:
+                return seq
+
+        return None
+
+ def schedule(self):
+ ret = None
+        if self._has_waiting():
+ ret = self.waiting_list[0]
+ self.waiting_list = self.waiting_list[1:]
+ return ret
+
+
+class RequestHandler(NaiveRequestHandler):
"""
RequestHandler is the core for handling existing requests and updating current batch.
During generation process, we call schedule function each iteration to update current batch.
@@ -176,12 +215,12 @@ class RequestHandler:
generated_token_size=inference_config.generated_token_size,
)
+ def _has_running(self) -> bool:
+ return not self.running_bb.is_empty()
+
def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)
- def _has_waiting(self) -> bool:
- return any(lst for lst in self.waiting_list)
-
def get_kvcache(self):
return self.cache_manager.get_kv_cache()
@@ -318,7 +357,7 @@ class RequestHandler:
if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
seq.mark_finished()
- def check_unfinished_seqs(self) -> bool:
+ def check_unfinished_reqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()
def total_requests_in_batch_bucket(self) -> int:
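The new `NaiveRequestHandler` is a plain FIFO used for diffusion requests: `add_sequence` appends to the waiting list and `schedule()` pops its head. A hedged usage sketch, with the request id, prompt and config as placeholders and the loop body standing in for running the pipeline:

    from colossalai.inference.config import DiffusionGenerationConfig
    from colossalai.inference.core.request_handler import NaiveRequestHandler
    from colossalai.inference.struct import DiffusionSequence

    handler = NaiveRequestHandler()
    handler.add_sequence(
        DiffusionSequence(request_id=0, prompt="a cat", generation_config=DiffusionGenerationConfig())
    )
    while handler.check_unfinished_reqs():
        seq = handler.schedule()           # FIFO: pops the oldest waiting request
        print(seq.request_id, seq.prompt)  # placeholder for invoking the diffusion pipeline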
diff --git a/colossalai/inference/modeling/models/diffusion.py b/colossalai/inference/modeling/models/diffusion.py
new file mode 100644
index 000000000..9dc90733d
--- /dev/null
+++ b/colossalai/inference/modeling/models/diffusion.py
@@ -0,0 +1,54 @@
+import inspect
+import types
+
+import torch
+from torch import nn
+
+
+class DiffusionPipe(nn.Module):
+ """
+    This class wraps a `DiffusionPipeline` object into an `nn.Module` and preserves most of the original attributes, functions and properties.
+ """
+
+ def __init__(self, source_obj) -> None:
+ super(DiffusionPipe, self).__init__()
+
+ for k, v in source_obj.__dict__.items():
+ if isinstance(v, nn.Module):
+ self.add_module(k, v)
+ else:
+ setattr(self, k, v)
+
+        skip_list = ["_execution_device", "to", "device"]  # overridden below or provided by nn.Module
+
+ for name, member in inspect.getmembers(source_obj.__class__):
+ if name in skip_list:
+ continue
+ if not name.startswith("__") and not name.endswith("__"):
+ if isinstance(member, property):
+ setattr(self.__class__, name, member)
+ elif inspect.isfunction(member) or inspect.ismethod(member):
+ bound_method = types.MethodType(member, self)
+ setattr(self, name, bound_method)
+ elif not callable(member) and not isinstance(member, property):
+ setattr(self, name, member)
+ elif name == "__call__":
+ bound_method = types.MethodType(member, self)
+ setattr(self, "_forward", bound_method)
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
+ Accelerate's module hooks.
+ """
+ # return self.device
+ return torch.device("cuda")
+
+ @property
+ def device(self):
+        return next(self.parameters()).device
+
+ def forward(self, *args, **kwargs):
+ return self._forward(*args, **kwargs)
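A hedged sketch of how `DiffusionPipe` is intended to be used; in this patch the wrapping happens inside the inference engine (and the policies below further replace `forward`), so the manual construction here is only illustrative, with the model id taken from the example script later in this series:

    import torch
    from diffusers import StableDiffusion3Pipeline

    from colossalai.inference.modeling.models.diffusion import DiffusionPipe

    pipeline = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    )
    pipe = DiffusionPipe(pipeline)  # registers sub-modules and re-binds the pipeline helpers onto the wrapper
    image = pipe(prompt="a cat")    # nn.Module.__call__ -> forward -> original pipeline __call__ via _forward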
diff --git a/colossalai/inference/modeling/models/pixart_alpha.py b/colossalai/inference/modeling/models/pixart_alpha.py
new file mode 100644
index 000000000..d5774946e
--- /dev/null
+++ b/colossalai/inference/modeling/models/pixart_alpha.py
@@ -0,0 +1,220 @@
+# Code adapted from:
+# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+
+from typing import Callable, List, Optional, Union
+
+import PIL.Image
+import torch
+from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
+ ASPECT_RATIO_256_BIN,
+ ASPECT_RATIO_512_BIN,
+ ASPECT_RATIO_1024_BIN,
+)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+
+from colossalai.logging import get_dist_logger
+
+from .diffusion import DiffusionPipe
+
+logger = get_dist_logger(__name__)
+
+
+@torch.no_grad()
+def pixart_alpha_forward(
+ self: DiffusionPipe,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 20,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ use_resolution_binning: bool = True,
+ max_sequence_length: int = 120,
+ **kwargs,
+) -> PIL.Image:
+ # 1. Check inputs. Raise error if not correct
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 128:
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+ elif self.transformer.config.sample_size == 64:
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
+ elif self.transformer.config.sample_size == 32:
+ aspect_ratio_bin = ASPECT_RATIO_256_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Prepare micro-conditions.
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ if self.transformer.config.sample_size == 128:
+ resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
+ aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
+ resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
+ aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+
+ if do_classifier_free_guidance:
+ resolution = torch.cat([resolution, resolution], dim=0)
+ aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ current_timestep = t
+ if not torch.is_tensor(current_timestep):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ if isinstance(current_timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
+ elif len(current_timestep.shape) == 0:
+ current_timestep = current_timestep[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=current_timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+ else:
+ noise_pred = noise_pred
+
+ # compute previous image: x_t -> x_t-1
+ if num_inference_steps == 1:
+ # For DMD one step sampling: https://arxiv.org/abs/2311.18828
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ output_type = "pil" # TODO(@lry89757) temporarily image, please support more return output
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ if use_resolution_binning:
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ # self.maybe_free_model_hooks()
+
+ return image
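The denoising loop above uses standard classifier-free guidance: the unconditional and conditional noise predictions are computed in one batched forward pass and then recombined. A standalone illustration of just that recombination step (tensor shapes are assumptions for the example):

    import torch

    guidance_scale = 4.5
    noise_pred = torch.randn(2, 4, 64, 64)  # [uncond; cond] stacked along dim 0 (assumed shape)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)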
diff --git a/colossalai/inference/modeling/models/stablediffusion3.py b/colossalai/inference/modeling/models/stablediffusion3.py
new file mode 100644
index 000000000..d1c63a6dc
--- /dev/null
+++ b/colossalai/inference/modeling/models/stablediffusion3.py
@@ -0,0 +1,178 @@
+# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
+
+from .diffusion import DiffusionPipe
+
+
+# TODO(@lry89757) temporarily returns images only; support more output types later
+@torch.no_grad()
+def sd3_forward(
+ self: DiffusionPipe,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+):
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ return image
diff --git a/colossalai/inference/modeling/policy/__init__.py b/colossalai/inference/modeling/policy/__init__.py
index fa0395590..02ffadd9f 100644
--- a/colossalai/inference/modeling/policy/__init__.py
+++ b/colossalai/inference/modeling/policy/__init__.py
@@ -1,16 +1,22 @@
from .glide_llama import GlideLlamaModelPolicy
from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
+from .pixart_alpha import PixArtAlphaInferPolicy
+from .stablediffusion3 import StableDiffusion3InferPolicy
model_policy_map = {
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
"nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
"glide_llama": GlideLlamaModelPolicy,
+ "StableDiffusion3Pipeline": StableDiffusion3InferPolicy,
+ "PixArtAlphaPipeline": PixArtAlphaInferPolicy,
}
__all__ = [
"NoPaddingLlamaModelInferPolicy",
"NoPaddingBaichuanModelInferPolicy",
"GlideLlamaModelPolicy",
+ "StableDiffusion3InferPolicy",
+ "PixArtAlphaInferPolicy",
"model_polic_map",
]
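With the two diffusion policies registered, a policy can be looked up by the pipeline's class name; a minimal sketch, assuming `model` is an already-loaded `StableDiffusion3Pipeline` (whether the engine itself keys the map this way is not shown in this hunk):

    from colossalai.inference.modeling.policy import model_policy_map

    policy_cls = model_policy_map[model.__class__.__name__]  # e.g. "StableDiffusion3Pipeline"
    policy = policy_cls()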
diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py
new file mode 100644
index 000000000..356056ba7
--- /dev/null
+++ b/colossalai/inference/modeling/policy/pixart_alpha.py
@@ -0,0 +1,34 @@
+from torch import nn
+
+from colossalai.inference.config import RPC_PARAM
+from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
+from colossalai.shardformer.policies.base_policy import Policy
+
+
+class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = {}
+ self.append_or_create_method_replacement(
+ description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
+ )
+ return policy
+
+ def preprocess(self) -> nn.Module:
+ return self.model
+
+ def postprocess(self):
+ return self.model
+
+ def config_sanity_check(self):
+ pass
+
+ def to_rpc_param(self) -> str:
+ return __class__.__name__
+
+ @staticmethod
+ def from_rpc_param() -> "PixArtAlphaInferPolicy":
+ return PixArtAlphaInferPolicy()
diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py
new file mode 100644
index 000000000..c9877f7dc
--- /dev/null
+++ b/colossalai/inference/modeling/policy/stablediffusion3.py
@@ -0,0 +1,34 @@
+from torch import nn
+
+from colossalai.inference.config import RPC_PARAM
+from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
+from colossalai.shardformer.policies.base_policy import Policy
+
+
+class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = {}
+ self.append_or_create_method_replacement(
+ description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
+ )
+ return policy
+
+ def preprocess(self) -> nn.Module:
+ return self.model
+
+ def postprocess(self):
+ return self.model
+
+ def config_sanity_check(self):
+ pass
+
+ def to_rpc_param(self) -> str:
+ return __class__.__name__
+
+ @staticmethod
+ def from_rpc_param() -> "StableDiffusion3InferPolicy":
+ return StableDiffusion3InferPolicy()
diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py
index 1a3094a27..65d284296 100644
--- a/colossalai/inference/struct.py
+++ b/colossalai/inference/struct.py
@@ -2,6 +2,7 @@ import enum
from dataclasses import dataclass
from typing import Any, List
+from colossalai.inference.config import DiffusionGenerationConfig
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
@@ -46,6 +47,17 @@ class RequestStatus(enum.Enum):
return status == RequestStatus.WAITING
+@dataclass
+class DiffusionSequence:
+ """
+ parameters for diffusion
+ """
+
+ request_id: int
+ prompt: str
+ generation_config: DiffusionGenerationConfig
+
+
@dataclass
class Sequence:
"""Store information of input sequence.
diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py
index 332e84d37..f2a0fc037 100644
--- a/colossalai/inference/utils.py
+++ b/colossalai/inference/utils.py
@@ -5,10 +5,12 @@ Utils for model inference
import math
import os
import re
+from enum import Enum
from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
import torch
+from diffusers import DiffusionPipeline
from torch import nn
from colossalai.logging import get_dist_logger
@@ -159,3 +161,38 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
except ImportError:
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
return False
+
+
+class ModelType(Enum):
+ DIFFUSION_MODEL = "Diffusion Model"
+ LLM = "Large Language Model (LLM)"
+ UNKNOWN = "Unknown Model Type"
+
+
+def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
+ if isinstance(model_or_path, DiffusionPipeline):
+ return ModelType.DIFFUSION_MODEL
+ elif isinstance(model_or_path, nn.Module):
+ return ModelType.LLM
+ elif isinstance(model_or_path, str):
+ try:
+ from transformers import AutoConfig
+
+ hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+ return ModelType.LLM
+ except:
+ """
+ model type is not `ModelType.LLM`
+ """
+
+ try:
+ from diffusers import DiffusionPipeline
+
+ DiffusionPipeline.load_config(model_or_path)
+ return ModelType.DIFFUSION_MODEL
+ except:
+ """
+ model type is not `ModelType.DIFFUSION_MODEL`
+ """
+ else:
+ return ModelType.UNKNOWN
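A quick sketch of the dispatch above (the model id is a placeholder, and resolving a string requires its config to be reachable; any `nn.Module` is treated as an LLM by this check):

    from torch import nn

    from colossalai.inference.utils import ModelType, get_model_type

    print(get_model_type("stabilityai/stable-diffusion-3-medium-diffusers"))  # ModelType.DIFFUSION_MODEL
    print(get_model_type(nn.Linear(2, 2)))                                    # ModelType.LLM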
diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py
new file mode 100644
index 000000000..fe989eed7
--- /dev/null
+++ b/examples/inference/stable_diffusion/sd3_generation.py
@@ -0,0 +1,75 @@
+import argparse
+
+from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
+from torch import bfloat16, float16, float32
+
+import colossalai
+from colossalai.cluster import DistCoordinator
+from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
+from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
+
+# Default to Stable Diffusion 3; switch the index below to 1 to run PixArt-Alpha instead
+MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
+POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
+
+TORCH_DTYPE_MAP = {
+ "fp16": float16,
+ "fp32": float32,
+ "bf16": bfloat16,
+}
+
+
+def infer(args):
+ # ==============================
+ # Launch colossalai, setup distributed environment
+ # ==============================
+ colossalai.launch_from_torch()
+ coordinator = DistCoordinator()
+
+ # ==============================
+ # Load model and tokenizer
+ # ==============================
+ model_path_or_name = args.model
+ model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
+
+ # ==============================
+ # Initialize InferenceEngine
+ # ==============================
+ coordinator.print_on_master(f"Initializing Inference Engine...")
+ inference_config = InferenceConfig(
+ dtype=args.dtype,
+ max_batch_size=args.max_batch_size,
+ tp_size=args.tp_size,
+ use_cuda_kernel=args.use_cuda_kernel,
+ )
+ engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
+
+ # ==============================
+ # Generation
+ # ==============================
+ coordinator.print_on_master(f"Generating...")
+ out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
+ out.save("cat.jpg")
+ coordinator.print_on_master(out)
+
+
+# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
+# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+
+
+if __name__ == "__main__":
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
+ parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
+ parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt")
+ parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
+ parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
+ parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
+ args = parser.parse_args()
+
+ infer(args)
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 27bbc3769..b54d1cf91 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -23,3 +23,4 @@ rpyc==6.0.0
fastapi
uvicorn==0.29.0
galore_torch
+diffusers==0.29.0
From 66abf1c6e89860b55e2f26a847dd86f8fecfc863 Mon Sep 17 00:00:00 2001
From: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Date: Mon, 8 Jul 2024 22:32:06 +0800
Subject: [PATCH 14/15] [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
colossalai/inference/core/llm_engine.py | 6 +++---
colossalai/inference/utils.py | 2 --
examples/inference/stable_diffusion/test_ci.sh | 2 ++
requirements/requirements-test.txt | 1 -
4 files changed, 5 insertions(+), 6 deletions(-)
create mode 100644 examples/inference/stable_diffusion/test_ci.sh
diff --git a/colossalai/inference/core/llm_engine.py b/colossalai/inference/core/llm_engine.py
index b973d371d..1dbc3ace8 100644
--- a/colossalai/inference/core/llm_engine.py
+++ b/colossalai/inference/core/llm_engine.py
@@ -57,11 +57,11 @@ class LLMEngine(BaseEngine):
def __init__(
self,
- model_or_path: nn.Module | str,
- tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None,
+ model_or_path: Union[nn.Module, str],
+ tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
inference_config: InferenceConfig = None,
verbose: bool = False,
- model_policy: Policy | type[Policy] = None,
+ model_policy: Union[Policy, type[Policy]] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py
index f2a0fc037..d0851e362 100644
--- a/colossalai/inference/utils.py
+++ b/colossalai/inference/utils.py
@@ -186,8 +186,6 @@ def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
"""
try:
- from diffusers import DiffusionPipeline
-
DiffusionPipeline.load_config(model_or_path)
return ModelType.DIFFUSION_MODEL
except:
diff --git a/examples/inference/stable_diffusion/test_ci.sh b/examples/inference/stable_diffusion/test_ci.sh
new file mode 100644
index 000000000..d0189431c
--- /dev/null
+++ b/examples/inference/stable_diffusion/test_ci.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+echo "Skip the test (this test is slow)"
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index e4affc7f5..93a3690fe 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -1,4 +1,3 @@
-diffusers
pytest
coverage==7.2.3
git+https://github.com/hpcaitech/pytest-testmon
From fbf33ecd019ce0e075b76b628e6e8a319cfc43e3 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Tue, 9 Jul 2024 18:05:20 +0800
Subject: [PATCH 15/15] [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.../booster/plugin/hybrid_parallel_plugin.py | 1 +
colossalai/shardformer/layer/__init__.py | 3 +-
colossalai/shardformer/layer/loss.py | 45 +++++++++++-
colossalai/shardformer/modeling/bloom.py | 56 ++++----------
colossalai/shardformer/modeling/command.py | 55 +++-----------
colossalai/shardformer/modeling/gpt2.py | 47 ++----------
colossalai/shardformer/modeling/llama.py | 73 +++++++------------
colossalai/shardformer/modeling/mistral.py | 48 ++----------
colossalai/shardformer/modeling/opt.py | 57 +++------------
colossalai/shardformer/modeling/qwen2.py | 47 ++----------
colossalai/shardformer/policies/llama.py | 8 --
.../test_model/test_shard_llama.py | 31 ++++----
12 files changed, 148 insertions(+), 323 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index a3d6f1e74..485833398 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1205,6 +1205,7 @@ class HybridParallelPlugin(PipelinePluginBase):
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
+ # sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index f17fad1b6..331e49729 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
-from .loss import cross_entropy_1d
+from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
@@ -18,6 +18,7 @@ __all__ = [
"DropoutForParallelInput",
"DropoutForReplicatedInput",
"cross_entropy_1d",
+ "dist_cross_entropy",
"BaseLayerNorm",
"LayerNorm",
"RMSNorm",
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index a6d19edf5..cea2da03f 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -2,8 +2,11 @@ import torch
import torch.distributed as dist
from torch.autograd import Function
from torch.distributed import ProcessGroup
+from torch.nn import CrossEntropyLoss
-__all__ = ["DistCrossEntropy", "cross_entropy_1d"]
+from colossalai.shardformer.shard import ShardConfig
+
+__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
class DistCrossEntropy(Function):
@@ -132,3 +135,43 @@ def cross_entropy_1d(
dtype: torch.dtype = None,
) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
+
+
+def dist_cross_entropy(
+ labels: torch.Tensor,
+ logits: torch.Tensor,
+ shard_config: ShardConfig,
+ out_features: int,
+ vocab_size: int,
+ dtype: torch.dtype,
+) -> torch.Tensor:
+ """
+ Helper to compute cross entropy loss for most shardformer models,
+ compatible with PP, TP and SP.
+ """
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_labels = shift_labels.view(-1)
+ shift_labels = shift_labels.to(shift_logits.device)
+ if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+ # Cross entropy with all-reduce for TP
+ new_vocab_size = logits.shape[-1]
+ shift_logits = shift_logits.view(-1, new_vocab_size)
+ loss = cross_entropy_1d(
+ shift_logits,
+ shift_labels,
+ process_group=shard_config.tensor_parallel_process_group,
+ vocab_size=out_features,
+ dtype=dtype,
+ )
+ else:
+ # NOTE if use TP and not parallel_output, the output is gathered.
+ # see VocabParallelLMHead1D
+ shift_logits = shift_logits.view(-1, vocab_size)
+ loss = loss_fct(shift_logits, shift_labels)
+
+    return loss if labels is not None else None
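Every causal-LM forward in this patch now reduces to a single call of this helper; a representative sketch of the call site (argument names follow the llama call sites later in this patch, with `self` as the causal-LM wrapper):

    # Sketch of the call pattern used throughout this patch.
    loss = dist_cross_entropy(
        labels,                     # [batch, seq_len] target token ids (or None to skip the loss)
        logits,                     # [batch, seq_len, vocab] lm_head output
        shard_config,               # carries the TP process group and parallel_output flag
        self.lm_head.out_features,  # per-rank vocab size when the head is tensor-parallel
        self.config.vocab_size,     # full vocab size for the gathered / non-TP path
        self.model.dtype,
    )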
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 154143626..26ffef6c5 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -28,7 +28,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy
logger = logging.get_logger(__name__)
@@ -359,30 +359,14 @@ class BloomPipelineForwards:
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states).contiguous()
- loss = None
- if labels is not None:
- # move labels to correct device to enable model parallelism
- labels = labels.to(lm_logits.device)
- # Shift so that tokens < n predict n
- shift_logits = lm_logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- batch_size, seq_length, vocab_size = shift_logits.shape
- # Flatten the tokens
- if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
- new_vocab_size = lm_logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- shift_labels = shift_labels.view(-1)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.transformer.dtype,
- )
- else:
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- loss = loss_fct(shift_logits, shift_labels.view(-1))
+ loss = dist_cross_entropy(
+ labels,
+ lm_logits,
+ shard_config,
+ self.lm_head.out_features,
+ self.config.vocab_size,
+ self.transformer.dtype,
+ )
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@@ -1040,24 +1024,10 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
- loss = None
- if labels is not None:
- # move labels to correct device to enable model parallelism
- labels = labels.to(lm_logits.device)
- # Shift so that tokens < n predict n
- shift_logits = lm_logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- new_vocab_size = lm_logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- shift_labels = shift_labels.view(-1)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.transformer.dtype,
- )
+ loss = dist_cross_entropy(
+ labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+ )
+
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index 07a7f6cbf..72f705bc0 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
-from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
@@ -25,7 +24,7 @@ from colossalai.shardformer.layer._operation import (
)
from colossalai.shardformer.shard import ShardConfig
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
class CommandPipelineForwards:
@@ -300,29 +299,9 @@ class CommandPipelineForwards:
logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale
logits = logits.float()
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.model.dtype,
- )
- else:
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- loss = loss_fct(shift_logits, shift_labels)
+ loss = dist_cross_entropy(
+ labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+ )
if not return_dict:
output = (logits,) + outputs[1:]
@@ -658,24 +637,14 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale
logits = logits.float()
-
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.model.dtype,
- )
+ loss = dist_cross_entropy(
+ labels,
+ logits,
+ shard_config,
+ self.lm_head.out_features,
+ self.config.vocab_size,
+ self.model.dtype,
+ )
if not return_dict:
output = (logits,) + outputs[1:]
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index aa75bab11..6ecda91c4 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -25,7 +25,7 @@ from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy
logger = logging.get_logger(__name__)
@@ -372,27 +372,9 @@ class GPT2PipelineForwards:
hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
- loss = None
- if labels is not None:
- # move labels to correct device to enable model parallelism
- labels = labels.to(lm_logits.device)
- # Shift so that tokens < n predict n
- shift_logits = lm_logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, shift_logits.size(-1))
- shift_labels = shift_labels.view(-1)
- if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.transformer.dtype,
- )
- else:
- loss = loss_fct(shift_logits, shift_labels)
+ loss = dist_cross_entropy(
+ labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+ )
if not return_dict:
output = (lm_logits,) + outputs[1:]
@@ -1282,24 +1264,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
-
- loss = None
- if labels is not None:
- # move labels to correct device to enable model parallelism
- labels = labels.to(lm_logits.device)
- # Shift so that tokens < n predict n
- shift_logits = lm_logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- shift_logits = shift_logits.view(-1, shift_logits.size(-1))
- shift_labels = shift_labels.view(-1)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.transformer.dtype,
- )
+ loss = dist_cross_entropy(
+ labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
+ )
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index bf5ce45a8..54ff8e321 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -31,7 +31,7 @@ from colossalai.shardformer.layer._operation import (
)
from colossalai.shardformer.shard import ShardConfig
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
class LlamaPipelineForwards:
@@ -86,13 +86,20 @@ class LlamaPipelineForwards:
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
-
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
+ # Support SP + PP
+ sp_mode = shard_config.sequence_parallelism_mode
+ sp_group = shard_config.sequence_parallel_process_group
+ sp_size = shard_config.sequence_parallel_size
+ if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
+ # For correct positions ids. The states will be gather along the seq dim in the attention layer later.
+ seq_length *= sp_size
+
past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
@@ -101,7 +108,7 @@ class LlamaPipelineForwards:
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
- cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
seq_length_with_past = seq_length + past_seen_tokens
@@ -118,7 +125,6 @@ class LlamaPipelineForwards:
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
-
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
@@ -134,6 +140,13 @@ class LlamaPipelineForwards:
else:
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+ # Support SP + PP
+ if stage_manager.is_first_stage():
+ if sp_mode in ["ring", "split_gather"]:
+ hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
+
if self.gradient_checkpointing and self.training and use_cache:
if use_cache:
logger.warning_once(
@@ -196,6 +209,10 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
+ if sp_mode == "ring" or sp_mode == "split_gather":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
# add hidden states from the last decoder layer
if output_hidden_states:
@@ -304,29 +321,9 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.model.dtype,
- )
- else:
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- loss = loss_fct(shift_logits, shift_labels)
+ loss = dist_cross_entropy(
+ labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+ )
if not return_dict:
output = (logits,) + outputs[1:]
@@ -529,7 +526,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -804,24 +800,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
logits = self.lm_head(hidden_states)
logits = logits.float()
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.model.dtype,
- )
-
+ loss = dist_cross_entropy(
+ labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+ )
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index 310c2d8e2..82e8ef5f9 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -19,7 +19,7 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
logger = logging.get_logger(__name__)
@@ -275,29 +275,9 @@ class MistralForwards:
logits = self.lm_head(hidden_states)
logits = logits.float()
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.model.dtype,
- )
- else:
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- loss = loss_fct(shift_logits, shift_labels)
+ loss = dist_cross_entropy(
+ labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+ )
if not return_dict:
output = (logits,) + outputs[1:]
@@ -708,23 +688,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
logits = self.lm_head(hidden_states)
logits = logits.float()
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.model.dtype,
- )
+ loss = dist_cross_entropy(
+ labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
+ )
if not return_dict:
output = (logits,) + outputs[1:]
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index b250b4976..636b46cc4 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -22,7 +22,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig
-from ..layer import cross_entropy_1d
+from ..layer import dist_cross_entropy
logger = logging.get_logger(__name__)
@@ -330,30 +330,14 @@ class OPTPipelineForwards:
)
if stage_manager.is_last_stage():
logits = self.lm_head(outputs[0]).contiguous()
- loss = None
- if labels is not None:
- # move labels to correct device to enable model parallelism
- labels = labels.to(logits.device)
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
-
- if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- shift_labels = shift_labels.view(-1)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.model.decoder.dtype,
- )
- else:
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+ loss = dist_cross_entropy(
+ labels,
+ logits,
+ shard_config,
+ self.lm_head.out_features,
+ self.config.vocab_size,
+ self.model.decoder.dtype,
+ )
if not return_dict:
output = (logits,) + outputs[1:]
@@ -971,26 +955,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
)
logits = self.lm_head(outputs[0]).contiguous()
-
- loss = None
- if labels is not None:
- # move labels to correct device to enable model parallelism
- labels = labels.to(logits.device)
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=self.lm_head.out_features,
- dtype=self.model.decoder.dtype,
- )
+ loss = dist_cross_entropy(
+ labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
+ )
if not return_dict:
output = (logits,) + outputs[1:]
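On a single device with tensor parallelism disabled, the fallback path of the sketch above reproduces the inlined shift-and-flatten `CrossEntropyLoss` that these forwards used to carry. A toy check, with illustrative shapes and a `SimpleNamespace` standing in for a real `ShardConfig`:

```python
# Toy check of the fallback path of dist_cross_entropy_sketch defined above.
from types import SimpleNamespace

import torch
from torch.nn import CrossEntropyLoss

torch.manual_seed(0)
vocab = 32
logits = torch.randn(2, 8, vocab)          # (batch, seq_len, vocab)
labels = torch.randint(0, vocab, (2, 8))

# Reference: the exact pattern the deleted branches implemented.
ref = CrossEntropyLoss()(
    logits[..., :-1, :].contiguous().view(-1, vocab),
    labels[..., 1:].contiguous().view(-1),
)

cfg = SimpleNamespace(enable_tensor_parallelism=False, parallel_output=False)
loss = dist_cross_entropy_sketch(labels, logits, cfg, vocab, vocab, logits.dtype)
assert torch.allclose(loss, ref)
```

Note that the dtype argument differs by model at the new call sites: Mistral passes `self.model.dtype`, OPT passes `self.model.decoder.dtype`, and Qwen2 passes `logits.dtype`.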
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 11c26822f..0f253730d 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -32,7 +32,7 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
-from ..layer import ColoAttention, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
class Qwen2PipelineForwards:
@@ -317,25 +317,9 @@ class Qwen2PipelineForwards:
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- if shard_config.enable_tensor_parallelism:
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
- )
- else:
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- loss = loss_fct(shift_logits, shift_labels)
+ loss = dist_cross_entropy(
+ labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
+ )
if not return_dict:
output = (logits,) + outputs[1:]
@@ -737,26 +721,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
-
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- if shard_config.enable_tensor_parallelism:
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
- )
- else:
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- loss = loss_fct(shift_logits, shift_labels)
+ loss = dist_cross_entropy(
+ labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
+ )
if not return_dict:
output = (logits,) + outputs[1:]
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 85ec6717d..36491b4b5 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -1,4 +1,3 @@
-import warnings
from functools import partial
from typing import Callable, Dict, List, Union
@@ -66,13 +65,6 @@ class LlamaPolicy(Policy):
else:
norm_cls = RMSNorm
- if self.pipeline_stage_manager is not None:
- self.shard_config.enable_sequence_parallelism = False
- self.shard_config.enable_sequence_overlap = False
- self.shard_config.sequence_parallelism_mode = None
- warnings.warn(
- f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
- )
sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 8fe18f69b..88e54176b 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -59,10 +59,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if (
booster.plugin.zero_stage in [1, 2]
and booster.plugin.shard_config.enable_sequence_parallelism
+ and booster.plugin.shard_config.pipeline_stage_manager is None
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
+ master2working = sharded_optimizer.get_master_to_working_map()
for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
- working_p = sharded_optimizer.master_to_working_param[id(p2)]
+ working_p = master2working[id(p2)]
grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = (
0
@@ -146,6 +148,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
+ { # Ulysses + Flash attention
+ "tp_size": 1,
+ "pp_size": 2,
+ "sp_size": 2,
+ "num_microbatches": 2,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "all_to_all",
+ "enable_flash_attention": True,
+ "use_lazy_init": True,
+ "zero_stage": 0,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
{ # Test ring + Flash attention
"tp_size": 2,
"pp_size": 1,
@@ -159,19 +174,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16",
"initial_scale": 1,
},
- { # Ulysess + Flash attention
- "tp_size": 1,
- "pp_size": 2,
- "sp_size": 2,
- "num_microbatches": 2,
- "enable_sequence_parallelism": True,
- "sequence_parallelism_mode": "all_to_all",
- "enable_flash_attention": True,
- "use_lazy_init": True,
- "zero_stage": 1,
- "precision": "fp16",
- "initial_scale": 1,
- },
{
"tp_size": 1,
"pp_size": 1,
@@ -245,7 +247,6 @@ def run_llama_test(test_config):
except Exception as e:
print(f"Failed config: {test_config}")
raise e
-
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
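The test hunk above also stops indexing the optimizer's `master_to_working_param` dict directly and resolves working parameters through `get_master_to_working_map()` instead. A condensed sketch of that gradient-collection step, with the attribute and method names taken from the hunk (the model and optimizer objects are assumed to come from the surrounding test):

```python
# Condensed sketch of the adapted check in test_shard_llama.py; names come from the
# hunk above, objects are assumed to be built by the surrounding test harness.
def collect_working_grads(llama_model, sharded_optimizer):
    # Public accessor instead of the internal master_to_working_param attribute.
    master2working = sharded_optimizer.get_master_to_working_map()
    grads = []
    for p1, p2 in zip(
        llama_model.parameters(),
        sharded_optimizer._master_param_groups_of_current_rank[0],
    ):
        working_p = master2working[id(p2)]
        grads.append(
            (p1, sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)))
        )
    return grads
```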