From 773d9f964a34a4aa905286a4a0a0a6ddb9de281d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 28 Jun 2024 11:20:04 +0800 Subject: [PATCH 01/15] [shardformer]delete xformers (#5859) * delete xformers * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/shardformer/modeling/bert.py | 110 ------------- colossalai/shardformer/modeling/bloom.py | 87 ----------- colossalai/shardformer/modeling/sam.py | 165 -------------------- colossalai/shardformer/policies/bert.py | 12 -- colossalai/shardformer/policies/bloom.py | 13 +- colossalai/shardformer/policies/sam.py | 20 --- docs/source/zh-Hans/features/shardformer.md | 12 +- 7 files changed, 7 insertions(+), 412 deletions(-) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index e7679f0ec..7710b56e7 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1,4 +1,3 @@ -import math import warnings from typing import List, Optional, Tuple, Union @@ -1005,115 +1004,6 @@ class BertPipelineForwards: return {"hidden_states": hidden_states} -def get_bert_flash_attention_forward(): - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - from transformers.models.bert.modeling_bert import BertAttention - - def forward( - self: BertAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - final_attention_mask = None - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - final_attention_mask = relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - final_attention_mask = relative_position_scores_query + relative_position_scores_key - - scale = 1 / math.sqrt(self.attention_head_size) - if attention_mask is not None: - if final_attention_mask != None: - final_attention_mask = final_attention_mask * scale + attention_mask - else: - final_attention_mask = attention_mask - - if final_attention_mask is not None: - batch_size, src_len = query_layer.size()[0], query_layer.size()[2] - tgt_len = key_layer.size()[2] - final_attention_mask = final_attention_mask.expand( - batch_size, self.num_attention_heads, src_len, tgt_len - ).contiguous() - - query_layer = query_layer.permute(0, 2, 1, 3).contiguous() - key_layer = key_layer.permute(0, 2, 1, 3).contiguous() - value_layer = value_layer.permute(0, 2, 1, 3).contiguous() - - context_layer = me_attention( - query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale - ) - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, None) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - return forward - - def get_jit_fused_bert_self_output_forward(): from transformers.models.bert.modeling_bert import BertSelfOutput diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 1f34215c5..154143626 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -714,93 +714,6 @@ class BloomPipelineForwards: return {"hidden_states": hidden_states} -def get_bloom_flash_attention_forward(enable_jit_fused=False): - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise 
ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - from transformers.models.bloom.modeling_bloom import BloomAttention - - def forward( - self: BloomAttention, - hidden_states: torch.Tensor, - residual: torch.Tensor, - alibi: torch.Tensor, - attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - fused_qkv = self.query_key_value(hidden_states) - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, tgt_len, _, _ = query_layer.size() - - _, kv_length, _, _ = key_layer.size() - - proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim) - query_layer = query_layer.contiguous().view(*proj_shape) - key_layer = key_layer.contiguous().view(*proj_shape) - value_layer = value_layer.contiguous().view(*proj_shape) - - if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension: - # - key: [batch_size * self.num_heads, head_dim, kv_length] - # - value: [batch_size * self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=1) - value_layer = torch.cat((past_value, value_layer), dim=1) - - if use_cache is True: - present = (key_layer, value_layer) - else: - present = None - - tgt_len = key_layer.size()[1] - - attention_numerical_mask = torch.zeros( - (batch_size, self.num_heads, tgt_len, kv_length), - dtype=torch.float32, - device=query_layer.device, - requires_grad=True, - ) - attention_numerical_mask = ( - attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta - ) - attention_numerical_mask = torch.masked_fill( - attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min - ) - attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype) - - context_layer = me_attention( - query_layer, - key_layer, - value_layer, - attn_bias=attention_numerical_mask, - scale=self.inv_norm_factor, - p=self.attention_dropout.p, - ) - context_layer = context_layer.reshape(-1, kv_length, self.hidden_size) - if self.pretraining_tp > 1 and self.slow_but_exact: - slices = self.hidden_size / self.pretraining_tp - output_tensor = torch.zeros_like(context_layer) - for i in range(self.pretraining_tp): - output_tensor = output_tensor + F.linear( - context_layer[:, :, int(i * slices) : int((i + 1) * slices)], - self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], - ) - else: - output_tensor = self.dense(context_layer) - - # TODO to replace with the bias_dropout_add function in jit - output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training) - outputs = (output_tensor, present, None) - - return outputs - - return forward - - def get_jit_fused_bloom_attention_forward(): from transformers.models.bloom.modeling_bloom import BloomAttention diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py index 26e0b224d..49fce0556 100644 --- a/colossalai/shardformer/modeling/sam.py +++ b/colossalai/shardformer/modeling/sam.py @@ -1,9 +1,4 @@ -import math -from typing import Tuple - import torch -import torch.nn.functional as F -from torch import Tensor def forward_fn(): @@ -45,163 +40,3 @@ def forward_fn(): return outputs return forward - - -def get_sam_flash_attention_forward(): - from transformers.models.sam.modeling_sam import SamAttention - - try: - from xformers.ops import 
memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - - def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor: - batch, point_batch_size, n_tokens, channel = hidden_states.shape - c_per_head = channel // num_attention_heads - hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) - return hidden_states - - def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor: - batch, n_tokens, n_heads, c_per_head = hidden_states.shape - return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) - - def forward( - self: SamAttention, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None - ) -> Tensor: - # Input projections - query = self.q_proj(query) - key = self.k_proj(key) - value = self.v_proj(value) - - point_batch_size = query.shape[1] - # Separate into heads - query = _separate_heads(query, self.num_attention_heads) - key = _separate_heads(key, self.num_attention_heads) - value = _separate_heads(value, self.num_attention_heads) - - # SamAttention - _, _, _, c_per_head = query.shape - bias = None - if attention_similarity is not None: - bias = attention_similarity - - scale = 1.0 / math.sqrt(c_per_head) - out = me_attention(query, key, value, attn_bias=bias, scale=scale) - - out = _recombine_heads(out, point_batch_size) - out = self.out_proj(out) - - return out - - return forward - - -def get_sam_vision_flash_attention_forward(): - from transformers.models.sam.modeling_sam import SamVisionAttention - - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - - def add_decomposed_rel_pos( - query: torch.Tensor, - rel_pos_h: torch.Tensor, - rel_pos_w: torch.Tensor, - q_size: Tuple[int, int], - k_size: Tuple[int, int], - ) -> torch.Tensor: - """ - Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. - https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py - - Args: - attn (`torch.Tensor`): - attention map. - query (`torch.Tensor`): - query q in the attention layer with shape (batch_size, query_height * query_width, channel). - rel_pos_h (`torch.Tensor`): - relative position embeddings (Lh, channel) for height axis. - rel_pos_w (`torch.Tensor`): - relative position embeddings (Lw, channel) for width axis. - q_size (tuple): - spatial sequence size of query q with (query_height, query_width). - k_size (tuple): - spatial sequence size of key k with (key_height, key_width). - - Returns: - attn (`torch.Tensor`): - attention map with added relative positional embeddings. 
- """ - - query_height, query_width = q_size - key_height, key_width = k_size - relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h) - relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w) - - batch_size, _, nHead, dim = query.shape - reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim) - rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) - rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) - rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] - rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width) - return rel_pos - - def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: - """ - Get relative positional embeddings according to the relative positions of - query and key sizes. - - Args: - q_size (int): - size of the query. - k_size (int): - size of key k. - rel_pos (`torch.Tensor`): - relative position embeddings (L, channel). - - Returns: - Extracted positional embeddings according to relative positions. - """ - max_rel_dist = int(2 * max(q_size, k_size) - 1) - # Interpolate rel pos. - rel_pos_resized = F.interpolate( - rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), - size=max_rel_dist, - mode="linear", - ) - rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) - - # Scale the coords with short length if shapes for q and k are different. - q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) - k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) - relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) - - return rel_pos_resized[relative_coords.long()] - - def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: - batch_size, height, width, _ = hidden_states.shape - # qkv with shape (3, batch_size, nHead, height * width, channel) - qkv = ( - self.qkv(hidden_states) - .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) - .permute(2, 0, 1, 3, 4) - ) - - query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0) - - rel_pos = None - if self.use_rel_pos: - rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)) - - attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale) - - attn_output = attn_output.reshape(batch_size, height, width, -1) - - attn_output = self.proj(attn_output) - - outputs = (attn_output, None) - - return outputs - - return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index c11ed99ac..b84a372a5 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -11,7 +11,6 @@ import colossalai.shardformer.layer as col_nn from ..modeling.bert import ( BertPipelineForwards, bert_sequence_parallel_forward_fn, - get_bert_flash_attention_forward, get_jit_fused_bert_intermediate_forward, get_jit_fused_bert_output_forward, get_jit_fused_bert_self_output_forward, @@ -49,7 +48,6 @@ class BertPolicy(Policy): BertLayer, BertModel, BertOutput, - BertSelfAttention, BertSelfOutput, ) @@ -218,16 +216,6 @@ class BertPolicy(Policy): target_key=BertEmbeddings, ) - # use flash attention - if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement( - description={ 
- "forward": get_bert_flash_attention_forward(), - }, - policy=policy, - target_key=BertSelfAttention, - ) - # use jit operator if self.shard_config.enable_jit_fused: self.append_or_create_method_replacement( diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 20a75cf90..d80adb84a 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -11,14 +11,13 @@ import colossalai.shardformer.layer as col_nn from ..modeling.bloom import ( BloomPipelineForwards, build_bloom_alibi_tensor_fn, - get_bloom_flash_attention_forward, get_bloom_sequence_parallel_forward_fn, get_jit_fused_bloom_attention_forward, get_jit_fused_bloom_gelu_forward, get_jit_fused_bloom_mlp_forward, get_lm_forward_with_dist_cross_entropy, ) -from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func +from ..modeling.jit import get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -165,16 +164,6 @@ class BloomPolicy(Policy): target_key=BloomModel, ) - if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement( - description={ - "forward": get_bloom_flash_attention_forward(), - "dropout_add": get_dropout_add_func(), - }, - policy=policy, - target_key=BloomAttention, - ) - # enable jit fused operator if self.shard_config.enable_jit_fused: self.append_or_create_method_replacement( diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index c224d7769..53faf8997 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -1,5 +1,3 @@ -import warnings - import colossalai.shardformer.layer as col_nn from ..modeling.sam import forward_fn @@ -212,24 +210,6 @@ class SamPolicy(Policy): target_key=SamTwoWayTransformer, ) - # use flash attention - if self.shard_config.enable_flash_attention: - warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.") - # self.append_or_create_method_replacement( - # description={ - # "forward": get_sam_flash_attention_forward(), - # }, - # policy=policy, - # target_key=SamAttention, - # ) - # self.append_or_create_method_replacement( - # description={ - # "forward": get_sam_vision_flash_attention_forward(), - # }, - # policy=policy, - # target_key=SamVisionAttention, - # ) - return policy def postprocess(self): diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index a42c7cc2e..00e1a13d6 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -71,8 +71,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. ✔️ ✔️ ✔️ - ✔️ - ✔️ + ❌ + ❌ ✔️ ✔️ ✔️ @@ -95,8 +95,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. ✔️ ✔️ ✔️ - ✔️ - ✔️ + ❌ + ❌ ✔️ ✔️ ✔️ @@ -155,8 +155,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. 
✔️ ❌ ❌ - ✔️ - ✔️ + ❌ + ❌ ✔️ ✔️ ❌ From 416580b3142457f1b210147e8611756eef1687ad Mon Sep 17 00:00:00 2001 From: Haze188 Date: Fri, 28 Jun 2024 14:00:08 +0800 Subject: [PATCH 02/15] [MoE/ZeRO] Moe refactor with zero refactor (#5821) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [moe] removed openmoe-coupled code and rectify mixstral code (#5471) * [Feauture] MoE refractor; Intergration with Mixtral (#5682) * cherry pick from refractor-moe branch * tests passed * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support ep + zero --------- Co-authored-by: Edenzzzz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * add mixtral auto policy & move pipeline forward code to modeling folder * [moe refactor] modify kernel test without Route Class * [moe refactor] add moe tensor test path environment variable to github workflow * fix typos * fix moe test bug due to the code rebase * [moe refactor] fix moe zero test, and little bug in low level zero * fix typo * add moe tensor path to github workflow * remove some useless code * fix typo & unify global variable XX_AXIS logic without using -1 * fix typo & prettifier the code * remove print code & support zero 2 test * remove useless code * reanme function * fix typo * fix typo * Further improve the test code * remove print code * [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test * [moe refactor] skip some unit test which will be refactored later * [moe refactor] fix unit import error * [moe refactor] fix circular import issues * [moe refactor] remove debug code * [moe refactor] update github workflow * [moe/zero] refactor low level optimizer (#5767) * [zero] refactor low level optimizer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] MoE refactor with newest version of ZeRO (#5801) * [zero] remove redundant members in BucketStore (#5802) * [zero] align api with previous version * [Moe/Zero] Update MoeHybridParallelPlugin with refactored ZeRO and Fix Zero bug (#5819) * [moe refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * [hotfix]Solve the compatibility issue of zero refactor (#5823) * [moe refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * Modify function parameter names to resolve compatibility issues * [zero] fix missing hook removal (#5824) * [MoE] Resolve .github conflict (#5829) * [Fix/Example] Fix Llama Inference Loading Data Type (#5763) * [fix/example] fix llama inference loading dtype * revise loading dtype of benchmark llama3 * [release] update version (#5752) * [release] update version * [devops] update compatibility test * [devops] update compatibility test * [devops] update compatibility test * [devops] update 
compatibility test * [test] fix ddp plugin test * [test] fix gptj and rpc test * [devops] fix cuda ext compatibility * [inference] fix flash decoding test * [inference] fix flash decoding test * fix (#5765) * [test] Fix/fix testcase (#5770) * [fix] branch for fix testcase; * [fix] fix test_analyzer & test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [Hotfix] Add missing init file in inference.executor (#5774) * [CI/tests] simplify some test case to reduce testing time (#5755) * [ci/tests] simplify some test case to reduce testing time * [ci/tests] continue to remove test case to reduce ci time cost * restore some test config * [ci/tests] continue to reduce ci time cost * [misc] update dockerfile (#5776) * [misc] update dockerfile * [misc] update dockerfile * [devops] fix docker ci (#5780) * [Inference]Add Streaming LLM (#5745) * Add Streaming LLM * add some parameters to llama_generation.py * verify streamingllm config * add test_streamingllm.py * modified according to the opinions of review * add Citation * change _block_tables tolist * [hotfix] fix llama flash attention forward (#5777) * [misc] Accelerate CI for zero and dist optim (#5758) * remove fp16 from lamb * remove d2h copy in checking states --------- Co-authored-by: Edenzzzz * [Test/CI] remove test cases to reduce CI duration (#5753) * [test] smaller gpt2 test case * [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py * [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py * [test] reduce test cases tests/test_zero/test_gemini/test_optim.py * Revert "[test] smaller gpt2 test case" Some tests might depend on the size of model (num of chunks) This reverts commit df705a5210b8901645992adf276e320e48766ebf. * [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py * [CI] smaller test model for two mwo the two modifid cases * [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there * [hotfix] fix testcase in test_fx/test_tracer (#5779) * [fix] branch for fix testcase; * [fix] fix test_analyzer & test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [fix] fix test_deepfm_model & test_dlrf_model; * [fix] fix test_hf_albert & test_hf_gpt; * [gemini] optimize reduce scatter d2h copy (#5760) * [gemini] optimize reduce scatter d2h copy * [fix] fix missing reduce variable * [refactor] remove legacy async reduce scatter code * [gemini] missing sync * Revert "[refactor] remove legacy async reduce scatter code" This reverts commit 58ad76d4665032bbe548d066116d1c572ce98979. * [gemini] further optimize with async all reduce * [fix] pass flag from manager to chunk * Allow building cuda extension without a device. (#5535) Added FORCE_CUDA environment variable support, to enable building extensions where a GPU device is not present but cuda libraries are. 
* [misc] fix dist logger (#5782) * [install]fix setup (#5786) * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [misc] update requirements (#5787) * [shardformer] fix import (#5788) * upgrade colossal-chat support tp_group>1, add sp for sft * upgrade ppo dpo rm script * run pre-commit * moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy * fix training script * fix ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix transformers version * remove duplicated test * fix datasets version * remove models that require huggingface auth from ci * remove local data path * update ci * remove baichuan from template test due to transformer version conflict * merge * Refactor modeling by adding attention backend Signed-off-by: char-1ee * Fix tests and naming Signed-off-by: char-1ee * Pass inference model shard configs for module init Signed-off-by: char-1ee * Clean up Signed-off-by: char-1ee * replace the customized dataloader setup with the build-in one * replace the customized dataloader setup with the build-in one * Remove flash attention backend Signed-off-by: char-1ee * fix readme * Fix test import Signed-off-by: char-1ee * update sft trainning script * [Inference]refactor baichuan (#5791) * refactor baichuan * remove unused code and add TODO for lazyinit * [test] fix chatglm test kit (#5793) * [shardformer] fix modeling of bloom and falcon (#5796) * [test] fix qwen2 pytest distLarge (#5797) * [Inference] Fix flash-attn import and add model test (#5794) * Fix torch int32 dtype Signed-off-by: char-1ee * Fix flash-attn import Signed-off-by: char-1ee * Add generalized model test Signed-off-by: char-1ee * Remove exposed path to model Signed-off-by: char-1ee * Add default value for use_flash_attn Signed-off-by: char-1ee * Rename model test Signed-off-by: char-1ee --------- Signed-off-by: char-1ee * [Gemini] Use async stream to prefetch and h2d data moving (#5781) * use async stream to prefetch and h2d data moving * Remove redundant code * [gemini] quick fix on possible async operation (#5803) * [gemini] quick fix on possible async operation * [gemini] quick fix on possible async operation * [shardformer] upgrade transformers to 4.39.3 (#5815) * [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807) * [shardformer] fix modeling of gpt2 and gptj * [shardformer] fix whisper modeling * [misc] update requirements --------- Co-authored-by: ver217 * [shardformer]upgrade transformers for mistral (#5808) * upgrade transformers for mistral * fix * fix * [shardformer]upgrade transformers for llama (#5809) * update transformers fix * fix * fix * [inference] upgrade transformers (#5810) * update transformers fix * fix * fix * fix * fix * [gemini] update transformers for gemini (#5814) --------- Co-authored-by: ver217 * Support 4d parallel + flash attention (#5789) * support tp + sp + pp * remove comments --------- Co-authored-by: Edenzzzz --------- Signed-off-by: char-1ee Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Hongxin Liu Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Edenzzzz Co-authored-by: Edenzzzz Co-authored-by: botbw Co-authored-by: Charles Coulombe Co-authored-by: 
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang Co-authored-by: char-1ee Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang * [zero] fix hook bug * [zero] add low level optimizer back (#5839) * [zero] fix param & refactor * [zero] add back original low level opt * [zero] remove moe related * [zero] pass zero tests * [zero] refactor * [chore] add del func back * [zero] comments and naming (#5840) * [zero] modify api (#5843) * [zero] modify api * [test] remove _grad_store access in tests * [test] fix (#5857) * [CI] skip openmoe CI check * [CI] fox pre-commit * [zero] remove redundant memebr init (#5862) * [misc] remove useless code, modify the pg mesh implementation * [misc] remove useless code, modify the pg mesh implementation * [misc] use tempfile * resolve conflict with main branch * [misc] use tempfile in test_moe_checkpoint.py * [misc] remove useless code, add assertion about sequence parallel, move logger into function * [misc] remove useless code --------- Signed-off-by: char-1ee Co-authored-by: Frank Lee Co-authored-by: Edenzzzz Co-authored-by: Edenzzzz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: botbw Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Hongxin Liu Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Charles Coulombe Co-authored-by: YeAnbang Co-authored-by: char-1ee Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang --- .github/workflows/build_on_pr.yml | 3 +- .github/workflows/build_on_schedule.yml | 3 +- .../compatiblity_test_on_dispatch.yml | 3 +- .github/workflows/compatiblity_test_on_pr.yml | 3 +- .../compatiblity_test_on_schedule.yml | 3 +- .../ColossalMoE/colossal_moe/__init__.py | 0 .../colossal_moe/models/__init__.py | 0 .../colossal_moe/models/mixtral_layer.py | 92 -- applications/ColossalMoE/infer.py | 4 - applications/ColossalMoE/infer.sh | 3 +- .../ColossalMoE/tests/test_moe_checkpoint.py | 146 ---- applications/ColossalMoE/train.py | 6 +- .../ColossalMoE/{colossal_moe => }/utils.py | 0 .../colossalqa/local/colossalcloud_llm.py | 1 + .../booster/plugin/hybrid_parallel_plugin.py | 23 +- .../plugin/moe_hybrid_parallel_plugin.py | 147 ++-- colossalai/checkpoint_io/__init__.py | 9 +- .../hybrid_parallel_checkpoint_io.py | 14 +- .../checkpoint_io/moe_checkpoint.py | 319 ++++++- colossalai/checkpoint_io/utils.py | 1 + colossalai/cluster/process_group_mesh.py | 12 +- colossalai/moe/__init__.py | 15 - colossalai/moe/checkpoint.py | 792 ------------------ colossalai/moe/load_balance.py | 6 +- colossalai/moe/loss.py | 78 -- colossalai/moe/routers.py | 466 ----------- colossalai/moe/utils.py | 9 +- colossalai/shardformer/layer/moe/__init__.py | 3 + .../{ => shardformer/layer}/moe/experts.py | 4 +- .../{ => shardformer/layer}/moe/layers.py | 23 +- colossalai/shardformer/layer/moe/routers.py | 161 ++++ .../shardformer/modeling/mixtral.py | 290 ++----- .../shardformer/policies/auto_policy.py | 10 +- colossalai/shardformer/policies/mixtral.py | 210 +++++ colossalai/shardformer/shard/shard_config.py | 1 + colossalai/tensor/moe_tensor/api.py | 11 +- 
.../zero/low_level/bookkeeping/__init__.py | 3 +- .../low_level/bookkeeping/bucket_store.py | 25 +- .../low_level/bookkeeping/gradient_store.py | 13 +- .../low_level/bookkeeping/parameter_store.py | 60 -- colossalai/zero/low_level/low_level_optim.py | 735 +++++++--------- .../openmoe/benchmark/benchmark_cai.py | 2 +- .../openmoe/model/modeling_openmoe.py | 10 +- .../language/openmoe/model/openmoe_policy.py | 1 + examples/language/openmoe/test_ci.sh | 60 +- examples/language/openmoe/train.py | 46 +- .../test_low_level_zero_checkpoint_io.py | 12 +- tests/test_moe/moe_utils.py | 38 +- tests/test_moe/test_grad_handler.py | 4 +- tests/test_moe/test_kernel.py | 136 ++- .../test_moe}/test_mixtral_layer.py | 13 +- tests/test_moe/test_moe_checkpoint.py | 313 ++++--- tests/test_moe/test_moe_ep_tp.py | 10 +- tests/test_moe/test_moe_group.py | 4 +- tests/test_moe/test_moe_hybrid_zero.py | 1 + tests/test_moe/test_moe_load_balance.py | 4 +- tests/test_moe/test_moe_router.py | 47 -- tests/test_moe/test_moe_zero_fwd_bwd.py | 78 -- tests/test_moe/test_moe_zero_fwd_bwd_optim.py | 132 +++ tests/test_moe/test_moe_zero_optim.py | 83 -- tests/test_optimizer/_utils.py | 2 +- tests/test_optimizer/test_dist_adafactor.py | 2 +- tests/test_optimizer/test_dist_came.py | 2 +- tests/test_optimizer/test_dist_lamb.py | 2 +- .../test_zero_optimizer.py | 5 +- .../test_model/test_shard_command.py | 6 +- .../test_model/test_shard_llama.py | 8 +- .../test_zero/test_low_level/test_mem_leak.py | 61 ++ .../test_zero/test_low_level/test_zero1_2.py | 67 +- 69 files changed, 1780 insertions(+), 3076 deletions(-) delete mode 100644 applications/ColossalMoE/colossal_moe/__init__.py delete mode 100644 applications/ColossalMoE/colossal_moe/models/__init__.py delete mode 100644 applications/ColossalMoE/colossal_moe/models/mixtral_layer.py delete mode 100644 applications/ColossalMoE/tests/test_moe_checkpoint.py rename applications/ColossalMoE/{colossal_moe => }/utils.py (100%) rename applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py => colossalai/checkpoint_io/moe_checkpoint.py (66%) delete mode 100644 colossalai/moe/checkpoint.py delete mode 100644 colossalai/moe/loss.py delete mode 100644 colossalai/moe/routers.py create mode 100644 colossalai/shardformer/layer/moe/__init__.py rename colossalai/{ => shardformer/layer}/moe/experts.py (98%) rename colossalai/{ => shardformer/layer}/moe/layers.py (96%) create mode 100644 colossalai/shardformer/layer/moe/routers.py rename applications/ColossalMoE/colossal_moe/models/mixtral_policy.py => colossalai/shardformer/modeling/mixtral.py (65%) create mode 100644 colossalai/shardformer/policies/mixtral.py delete mode 100644 colossalai/zero/low_level/bookkeeping/parameter_store.py rename {applications/ColossalMoE/tests => tests/test_moe}/test_mixtral_layer.py (81%) delete mode 100644 tests/test_moe/test_moe_router.py delete mode 100644 tests/test_moe/test_moe_zero_fwd_bwd.py create mode 100644 tests/test_moe/test_moe_zero_fwd_bwd_optim.py delete mode 100644 tests/test_moe/test_moe_zero_optim.py create mode 100644 tests/test_zero/test_low_level/test_mem_leak.py diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index adf4501bb..151454239 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -90,7 +90,7 @@ jobs: runs-on: [self-hosted, gpu] container: image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 - options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny + options: --gpus all --rm -v 
/dev/shm -v /data/scratch:/data/scratch timeout-minutes: 90 defaults: run: @@ -165,6 +165,7 @@ jobs: env: LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LLAMA_PATH: /data/scratch/llama-tiny + MOE_TENSOR_PATH: /data/scratch/moe_tensors - name: Collate artifact env: diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index e560d0c00..fc6424503 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -13,7 +13,7 @@ jobs: runs-on: [self-hosted, gpu] container: image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 - options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny + options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/ timeout-minutes: 90 steps: - name: Check GPU Availability # ensure all GPUs have enough memory @@ -69,6 +69,7 @@ jobs: env: LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LLAMA_PATH: /data/scratch/llama-tiny + MOE_TENSOR_PATH: /data/scratch/moe_tensors - name: Notify Lark id: message-preparation diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 9867ef7c6..3eee564c2 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -50,7 +50,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny + options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/ timeout-minutes: 200 steps: - name: Install dependencies @@ -92,3 +92,4 @@ jobs: DATA: /data/scratch/cifar-10 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LLAMA_PATH: /data/scratch/llama-tiny + MOE_TENSOR_PATH: /data/scratch/moe_tensors diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index 885d352d5..b418c843e 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -41,7 +41,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny + options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/ timeout-minutes: 200 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }} @@ -87,3 +87,4 @@ jobs: DATA: /data/scratch/cifar-10 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LLAMA_PATH: /data/scratch/llama-tiny + MOE_TENSOR_PATH: /data/scratch/moe_tensors diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 39e1f479c..8d98e775c 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -38,7 +38,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny + options: 
--gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/ timeout-minutes: 200 steps: - name: Install dependencies @@ -85,6 +85,7 @@ jobs: DATA: /data/scratch/cifar-10 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LLAMA_PATH: /data/scratch/llama-tiny + MOE_TENSOR_PATH: /data/scratch/moe_tensors - name: Notify Lark id: message-preparation diff --git a/applications/ColossalMoE/colossal_moe/__init__.py b/applications/ColossalMoE/colossal_moe/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/applications/ColossalMoE/colossal_moe/models/__init__.py b/applications/ColossalMoE/colossal_moe/models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py deleted file mode 100644 index a2b78a2bd..000000000 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -import torch.distributed as dist -import torch.nn.functional as F -from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock - -from colossalai.lazy import LazyInitContext -from colossalai.moe import MOE_MANAGER -from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven -from colossalai.shardformer.shard.utils import set_tensors_to_none -from colossalai.tensor.moe_tensor.api import set_moe_tensor_info - - -class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): - def __init__(self, config): - super().__init__(config) - self.setup_ep() - - def setup_ep(self): - _, moe_info = MOE_MANAGER.get_info(self.num_experts) - ep_group = moe_info.ep_group - self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1 - self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0 - assert self.num_experts % self.ep_size == 0 - self.ep_group = ep_group - self.num_experts_per_ep = self.num_experts // self.ep_size - self.expert_start_idx = self.ep_rank * self.num_experts_per_ep - held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] - set_tensors_to_none(self.experts, exclude=set(held_experts)) - for p in self.experts.parameters(): - set_moe_tensor_info(p, moe_info) - - @staticmethod - def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock": - LazyInitContext.materialize(module) - module.__class__ = EPMixtralSparseMoeBlock - module.setup_ep() - return module - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - selected_experts = selected_experts.t().reshape(-1) - selected_experts_idx = selected_experts.argsort() - dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx] - input_split_sizes = selected_experts.bincount(minlength=self.num_experts) - output_split_sizes = torch.zeros_like(input_split_sizes) - dist.all_to_all_single(output_split_sizes, 
input_split_sizes, group=self.ep_group) - - input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() - output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() - output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) - # compute expert output - output_states = MoeInGradScaler.apply(output_states, self.ep_size) - if output_states.size(0) > 0: - if self.num_experts_per_ep == 1: - # no need to split - expert = self.experts[self.expert_start_idx] - output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states) - output_states = expert.w2(output_states) - else: - output_states_splits = output_states.split(output_split_sizes.tolist()) - output_states_list = [] - for i, split_states in enumerate(output_states_splits): - if split_states.size(0) == 0: - continue - expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] - split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states) - split_states = expert.w2(split_states) - output_states_list.append(split_states) - output_states = torch.cat(output_states_list) - output_states = MoeOutGradScaler.apply(output_states, self.ep_size) - dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) - recover_experts_idx = torch.empty_like(selected_experts_idx) - recover_experts_idx[selected_experts_idx] = torch.arange( - selected_experts_idx.size(0), device=selected_experts_idx.device - ) - dispatch_states = dispatch_states[recover_experts_idx] - k_hidden_states = dispatch_states.chunk(self.top_k) - output_states = k_hidden_states[0] * routing_weights[:, 0, None] - for i in range(1, self.top_k): - output_states += k_hidden_states[i] * routing_weights[:, i, None] - output_states = output_states.reshape(batch_size, sequence_length, hidden_dim) - return output_states, router_logits diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py index 543c434d2..6023e304d 100644 --- a/applications/ColossalMoE/infer.py +++ b/applications/ColossalMoE/infer.py @@ -2,8 +2,6 @@ import argparse import torch import torch.distributed as dist -from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO -from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy from transformers import AutoTokenizer from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM @@ -70,8 +68,6 @@ def main(): ep_size=ep_size, zero_stage=1, precision=args.precision, - custom_policy=MixtralForCausalLMPolicy(), - checkpoint_io=MixtralMoEHybridParallelCheckpointIO, enable_fused_normalization=args.use_layernorm_kernel, enable_jit_fused=args.use_kernel, ) diff --git a/applications/ColossalMoE/infer.sh b/applications/ColossalMoE/infer.sh index 0487fe9c1..ba4362d74 100644 --- a/applications/ColossalMoE/infer.sh +++ b/applications/ColossalMoE/infer.sh @@ -1,5 +1,6 @@ NUM_GPU=2 -MODEL="mistralai/Mixtral-8x7B-v0.1" +# MODEL="mistralai/Mixtral-8x7B-v0.1" +MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1" # ep torchrun --standalone --nproc_per_node $NUM_GPU infer.py \ diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py deleted file mode 100644 index 074dbf835..000000000 --- a/applications/ColossalMoE/tests/test_moe_checkpoint.py +++ /dev/null @@ -1,146 +0,0 @@ -from copy import deepcopy - -import pytest -import torch -import 
torch.distributed as dist -from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO -from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy -from torch.optim import Adam -from transformers.models.mixtral.configuration_mixtral import MixtralConfig -from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.testing.utils import spawn - -tokens, n_experts = 7, 4 -hidden_size = 8 -top_k = 2 - - -def check_model_equal(model1, model2): - assert set(model1.state_dict().keys()) == set(model2.state_dict().keys()) - for p1, p2 in zip(model1.parameters(), model2.parameters()): - assert torch.equal(p1.half(), p2.half()) - - -def get_optimizer_snapshot(optim): - state = {id(k): deepcopy(v) for k, v in optim.state.items()} - param_groups = [] - for group in optim.param_groups: - params = [id(p) for p in group["params"]] - new_group = {"params": params} - for k, v in group.items(): - if k != "params": - new_group[k] = v - param_groups.append(new_group) - return { - "state": state, - "param_groups": param_groups, - } - - -def check_optimizer_snapshot_equal(snapshot1, snapshot2): - # check param_groups - assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"]) - for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]): - assert set(group1.keys()) == set(group2.keys()) - for k in group1.keys(): - assert group1[k] == group2[k] - # check state - assert set(snapshot1["state"].keys()) == set( - snapshot2["state"].keys() - ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}" - for pid in snapshot1["state"].keys(): - state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid] - assert set(state1.keys()) == set(state2.keys()) - for k in state1.keys(): - if isinstance(state1[k], torch.Tensor): - assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}" - else: - assert state1[k] == state2[k] - - -def check_mixtral_moe_layer(): - torch.cuda.set_device(dist.get_rank()) - config = MixtralConfig( - hidden_size=hidden_size, - intermediate_size=hidden_size * 2, - num_local_experts=n_experts, - num_experts_per_tok=top_k, - num_attention_heads=2, - num_key_value_heads=2, - ) - torch.manual_seed(0) - input_ids = torch.randint(0, 100, (2, tokens)).cuda() - orig_model = MixtralForCausalLM(config).cuda() - model = deepcopy(orig_model) - optimizer = Adam(model.parameters(), lr=1e-3) - plugin = MoeHybridParallelPlugin( - tp_size=1, - pp_size=2, - ep_size=2, - custom_policy=MixtralForCausalLMPolicy(), - checkpoint_io=MixtralMoEHybridParallelCheckpointIO, - microbatch_size=1, - zero_stage=1, - ) - booster = Booster(plugin=plugin) - model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer) - # initialize grads - data_iter = iter( - [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}] - ) - booster.execute_pipeline( - data_iter, - model, - lambda outputs, inputs: outputs.loss, - optimizer, - ) - - # check save model - booster.save_model(model, "mixtral_model", shard=True) - dist.barrier() - if dist.get_rank() == 0: - saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda() - check_model_equal(orig_model, saved_model) - saved_model.save_pretrained("mixtral_hf_model") - dist.barrier() - - # check load model - new_model = MixtralForCausalLM(config).cuda() - 
new_optimizer = Adam(new_model.parameters(), lr=1e-3) - new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer) - booster.load_model(new_model, "mixtral_hf_model") - check_model_equal(model, new_model) - - # check save optimizer - optimizer.step() - for group in optimizer.param_groups: - group["lr"] = 0.1 - snapshot = get_optimizer_snapshot(optimizer.unwrap()) - booster.save_optimizer(optimizer, "mixtral_optim", shard=True) - dist.barrier() - # reset optimizer state - for state in optimizer.unwrap().state.values(): - for v in state.values(): - if isinstance(v, torch.Tensor): - v.zero_() - booster.load_optimizer(optimizer, "mixtral_optim") - loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap()) - check_optimizer_snapshot_equal(snapshot, loaded_snapshot) - - -def run_dist(rank: int, world_size: int, port: int): - colossalai.launch(rank, world_size, "localhost", port) - check_mixtral_moe_layer() - - -@pytest.mark.parametrize("world_size", [4]) -def test_mixtral_moe_layer(world_size: int): - spawn(run_dist, world_size) - - -if __name__ == "__main__": - test_mixtral_moe_layer(4) diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py index d2789d644..9cd810e5a 100644 --- a/applications/ColossalMoE/train.py +++ b/applications/ColossalMoE/train.py @@ -2,13 +2,11 @@ import argparse import torch import torch.distributed as dist -from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO -from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy -from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint from torch.utils.data import Dataset from tqdm import tqdm from transformers import AutoTokenizer from transformers.models.mixtral import MixtralForCausalLM +from utils import load_checkpoint, move_to_cuda, save_checkpoint import colossalai from colossalai.booster import Booster @@ -155,12 +153,10 @@ def main(): pp_size=args.pp_size, ep_size=args.ep_size, microbatch_size=args.microbatch_size, - custom_policy=MixtralForCausalLMPolicy(), enable_fused_normalization=args.use_layernorm_kernel, enable_jit_fused=args.use_kernel, precision=args.precision, zero_stage=args.zero_stage, - checkpoint_io=MixtralMoEHybridParallelCheckpointIO, ) else: diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/utils.py similarity index 100% rename from applications/ColossalMoE/colossal_moe/utils.py rename to applications/ColossalMoE/utils.py diff --git a/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py index 362977869..ca8d64f22 100644 --- a/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py +++ b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py @@ -20,6 +20,7 @@ resp = llm(prompt='What do you call a three-ton kangaroo?', **gen_config) print(resp) # super-heavyweight awesome-natured yawning Australian creature! 
""" + import json from typing import Any, Mapping diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 3bd43f172..a3d6f1e74 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -655,7 +655,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): self.param_info = param_info self.stage_manager = model.stage_manager self.shared_params = model.shared_params - self.dp_pg = dp_process_group self.tp_pg = tp_process_group self.pp_pg = pp_process_group if use_pipeline: @@ -718,7 +717,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): """Retrieve all working gradients from different parameter groups.""" all_working_grads = [] for group_id in range(self.num_param_groups): - working_grads = self._grad_store.get_working_grads_by_group_id(group_id) + working_grads = self.get_working_grads_by_group_id(group_id) all_working_grads.extend(working_grads) return all_working_grads @@ -726,7 +725,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): """Identify gradients to be synchronized in the sequence parallelism.""" grads_to_sync = [] for grad in all_working_grads: - param_id_for_grad = self._grad_store.get_param_id_for_grad(grad) + param_id_for_grad = self.get_param_id_for_grad(grad) param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad): grads_to_sync.append(grad) @@ -739,7 +738,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # Get all working gradients and gradients to be synchronized. all_working_grads = _get_all_working_grads() grads_to_sync = _get_grads_to_sync(all_working_grads) - if self._grad_store.require_grad_sync and grads_to_sync is not None: + if self.require_grad_sync and grads_to_sync is not None: # Synchronize sequence parallelism gradients if required. SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) else: @@ -763,7 +762,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # Call the superclass backward method to compute gradients. super().backward(loss, retain_graph) - if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. self._sync_sp_grads() else: @@ -788,14 +787,14 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # Call the superclass backward_by_grad method to compute gradients. super().backward_by_grad(tensor, grad) - if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. self._sync_sp_grads() else: # If gradient synchronization is is not required, return. return - def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: + def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float: r""" Compute and return the gradient norm for gradient clipping. 
@@ -811,7 +810,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): if len(gradients) == 0: return 0.0 - dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1 + dp_size = get_world_size(dp_pg) if dp_pg is not None else 1 tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 norm_type = float(norm_type) @@ -842,7 +841,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # However, we still perform the 'all_reduce' operation for the sake of good coding practices. # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' if tp_size > 1: - param_id_for_grad = self._grad_store.get_param_id_for_grad(grad) + param_id_for_grad = self.get_param_id_for_grad(grad) param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value if not is_distributed_tensor(param_for_grad): @@ -856,7 +855,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): for shared_param in self.shared_params: if self.stage_manager.stage in shared_param: stage_shared_param = shared_param[self.stage_manager.stage] - working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param)) + working_grad = self.get_working_grad_by_param_id(id(stage_shared_param)) if grad is working_grad: grad_norm_exponentiated /= len(shared_param) @@ -867,7 +866,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): ) if dp_size > 1: # compute norm in dp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg) + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg) if tp_size > 1: # compute norm in tp process group dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) @@ -1309,7 +1308,7 @@ class HybridParallelPlugin(PipelinePluginBase): # run with gradients accumulation if model.require_grad_sync == False or ( - isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False ): return outputs diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 83888e506..2cfdd000a 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,4 +1,5 @@ import random +import warnings from types import MethodType from typing import Callable, Optional, OrderedDict, Tuple @@ -20,19 +21,19 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import ( get_param_info, init_pipeline_optimizer, ) +from colossalai.checkpoint_io import MoECheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.moe import MOE_MANAGER, MoECheckpointIO +from colossalai.logging import get_dist_logger from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.zero.low_level import LowLevelZeroOptimizer -PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 - -class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): +class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer): def 
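The _compute_grad_norm changes above lean on the fact that a p-norm over sharded gradients can be assembled from per-shard partial sums: each rank contributes sum(|g|^p), the partial sums are all-reduced over the dp/tp/pp groups, and the p-th root is taken once at the end. A minimal single-process sketch of that identity (the chunks stand in for what each ZeRO rank holds locally; no process groups are involved here):

    import torch

    norm_type = 2.0
    full_grad = torch.randn(10_000)

    # Pretend the flat gradient is ZeRO-sharded across 4 ranks.
    shards = full_grad.chunk(4)

    # Each "rank" computes only its exponentiated local norm ...
    partial = [s.norm(norm_type) ** norm_type for s in shards]

    # ... the all-reduce over the dp group is just a sum, and the root is applied once.
    total_norm = sum(partial) ** (1.0 / norm_type)

    assert torch.allclose(total_norm, full_grad.norm(norm_type))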
__init__( self, optimizer: Optimizer, @@ -67,8 +68,20 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): self.pp_pg = pp_process_group if use_pipeline: init_pipeline_optimizer(optimizer, model) + + pg_param_list = { + dp_process_group: [], + moe_extra_dp_process_group: [], + } + for param in model.parameters(): + if is_moe_tensor(param): + pg_param_list[moe_extra_dp_process_group].append(param) + else: + pg_param_list[dp_process_group].append(param) + super().__init__( optimizer=optimizer, + pg_to_param_list=pg_param_list, initial_scale=initial_scale, min_scale=min_scale, growth_factor=growth_factor, @@ -83,9 +96,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): overlap_communication=overlap_communication, partition_grad=partition_grad, cpu_offload=cpu_offload, - dp_process_group=dp_process_group, forced_dtype=forced_dtype, - moe_extra_dp_process_group=moe_extra_dp_process_group, ) @@ -107,8 +118,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) Args: - tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. precision (str, optional): Specifies the precision of parameters during training. Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. Defaults to 'fp16'. @@ -144,14 +155,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. + use_ep_inside (bool, Optional): Whether to use ep inside dp (intra-node) for moe params. 
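Instead of a single dp_process_group, the optimizer above now hands LowLevelZeroOptimizer a pg_to_param_list mapping, so MoE parameters are sharded over the smaller moe dp group and all other parameters over the global dp group. A rough sketch of that bucketing; is_moe_tensor is replaced by a simple name check because no process groups or MoE tensor metadata exist in this toy setting:

    import torch.nn as nn

    def bucket_params_by_group(model: nn.Module, dp_group, moe_dp_group):
        """Mimic pg_param_list: expert params go to the moe dp group, the rest to global dp."""
        buckets = {dp_group: [], moe_dp_group: []}
        for name, param in model.named_parameters():
            # Stand-in for is_moe_tensor(param); the real code checks the tensor's ep metadata.
            is_moe = ".experts." in name
            buckets[moe_dp_group if is_moe else dp_group].append(param)
        return buckets

    # Group handles are just hashable keys here; in the plugin they are torch.distributed ProcessGroups.
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    print({k: len(v) for k, v in bucket_params_by_group(model, "global_dp", "moe_dp").items()})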
""" def __init__( self, - tp_size: int, pp_size: int, ep_size: int, - extra_dp_size: int = 1, + tp_size: int = 1, + sp_size: int = 1, precision: str = "fp16", zero_stage: int = 0, enable_all_optimization: bool = False, @@ -184,32 +196,22 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): custom_policy: Policy = None, checkpoint_io: Optional[MoECheckpointIO] = None, ) -> None: - assert ( - dist.get_world_size() % (tp_size * pp_size) == 0 - ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + world_size = dist.get_world_size() + assert tp_size == 1, "Tensor parallel is not supported in MoE yet" + assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism it not supported in MoE yet" - if enable_sequence_parallelism: - assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" assert ( - dist.get_world_size() % (tp_size * pp_size) == 0 - ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + world_size % (tp_size * pp_size) == 0 + ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}" assert ( - dist.get_world_size() % (tp_size * pp_size * ep_size) == 0 - ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}" - self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size) - MOE_MANAGER.setup( - parallel="EP", - mode="fixed", - fixed_dp_size=self.real_dp_size, - fixed_ep_size=ep_size, - fixed_pp_size=pp_size, - use_ep_inside=use_ep_inside, - ) + world_size % (tp_size * pp_size * ep_size) == 0 + ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}" + + self.dp_size = world_size // (tp_size * pp_size) self.tp_size = tp_size self.pp_size = pp_size - self.dp_size = dist.get_world_size() // (tp_size * pp_size) self.ep_size = ep_size - self.moe_info = MOE_MANAGER.get_info(0)[1] + self.sp_size = sp_size self.precision = precision self.zero_stage = zero_stage self.cpu_offload = cpu_offload @@ -219,43 +221,57 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism self.checkpoint_io = checkpoint_io + + logger = get_dist_logger() + + # NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param + # See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient # we change pg mesh to (pp, dp, tp) for better moe performance - self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) + assert ( + self.ep_size <= self.dp_size + ), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})." 
- # sync moe in outer dp group, and sync other param in global dp group - if extra_dp_size > 1: - ep_size = self.dp_size // extra_dp_size - if use_ep_inside: - self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size) - self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1) - if dist.get_rank() == 0: - print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}") - else: - self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size) - self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2) - if dist.get_rank() == 0: - print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}") + self.moe_dp_size = self.dp_size // self.ep_size + self.use_ep_inside = use_ep_inside + if self.use_ep_inside: + logger.info(f"MoE Parallel use ep inside dp.", ranks=[0]) + self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3 + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size) else: - self.moe_extra_dp_group = None + logger.info(f"MoE Parallel use ep outside dp.", ranks=[0]) + warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.") + self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3 + self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size) + self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) + self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis) + logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0]) + logger.info( + f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0] + ) + + self.tp_group = self.pg_mesh.get_group_along_axis( + self.tp_axis + ) # TODO: support custom tp size for mixtral lm head + self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis)) + self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) + # TODO: Currently moe only support partially sequence parallel + self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + + self.custom_policy = custom_policy self.stage_manager = None self.schedule = None - self.custom_policy = custom_policy + assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" - self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) + self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis) self.schedule = OneForwardOneBackwardSchedule( self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size ) - self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) - self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) - self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) - # TODO: Currently moe only support partially sequence parallel - self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, @@ -267,6 +283,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): enable_jit_fused=self.enable_jit_fused, enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, + ep_group=self.ep_group, ) self.amp_config = dict( initial_scale=initial_scale, @@ -323,7 
+340,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): """ _kwargs = kwargs.copy() sampler = DistributedSampler( - dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle + dataset, + num_replicas=self.dp_size, + rank=dist.get_rank(self.global_dp_group), + shuffle=shuffle, ) # Deterministic dataloader @@ -346,9 +366,20 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): def get_checkpoint_io(self) -> MoECheckpointIO: if self.checkpoint_io is None: - self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + self.checkpoint_io = MoECheckpointIO( + self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage + ) else: - self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + self.checkpoint_io = self.checkpoint_io( + self.global_dp_group, + self.pp_group, + self.tp_group, + ep_group=self.ep_group, + moe_dp_group=self.moe_dp_group, + zero_stage=self.zero_stage, + ) + if hasattr(self.checkpoint_io, "moe_info"): + self.checkpoint_io.moe_info = self.moe_info return self.checkpoint_io def configure( @@ -366,7 +397,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): module=model, precision=self.precision, shard_config=self.shard_config, - dp_group=self.dp_group, + dp_group=self.global_dp_group, tp_group=self.tp_group, sp_group=self.sp_group, use_ddp=use_ddp, @@ -392,15 +423,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): else: assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." - optimizer = HybridParallelZeroOptimizer( + optimizer = MoeHybridParallelZeroOptimizer( optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info, - dp_process_group=self.dp_group, + dp_process_group=self.global_dp_group, tp_process_group=self.tp_group, pp_process_group=self.pp_group, - moe_extra_dp_process_group=self.moe_extra_dp_group, + moe_extra_dp_process_group=self.moe_dp_group, verbose=True, clip_grad_norm=self.max_norm, **self.zero_config, diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index 19b61730b..ef37534fe 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -2,5 +2,12 @@ from .checkpoint_io_base import CheckpointIO from .general_checkpoint_io import GeneralCheckpointIO from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO from .index_file import CheckpointIndexFile +from .moe_checkpoint import MoECheckpointIO -__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"] +__all__ = [ + "CheckpointIO", + "CheckpointIndexFile", + "GeneralCheckpointIO", + "HybridParallelCheckpointIO", + "MoECheckpointIO", +] diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 7946d9b9c..61c9d1438 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -70,13 +70,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): verbose: bool = True, ) -> None: super().__init__() - self.dp_group = dp_group + self.global_dp_group = dp_group self.pp_group = pp_group self.tp_group = tp_group - self.dp_rank = dist.get_rank(self.dp_group) + self.dp_rank = 
dist.get_rank(self.global_dp_group) self.tp_rank = dist.get_rank(self.tp_group) self.pp_rank = dist.get_rank(self.pp_group) - self.dp_size = dist.get_world_size(dp_group) + self.global_dp_size = dist.get_world_size(dp_group) self.pp_size = dist.get_world_size(pp_group) self.tp_size = dist.get_world_size(tp_group) self.use_zero = zero_stage > 0 @@ -433,7 +433,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder( optimizer, use_zero=self.use_zero, - dp_group=self.dp_group, + dp_group=self.global_dp_group, tp_group=self.tp_group, size_per_shard=size_per_shard, ) @@ -727,7 +727,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): state, working_param, original_shape=original_shape, - dp_group=self.dp_group, + dp_group=self.global_dp_group, tp_group=self.tp_group, use_zero=self.use_zero, inplace=False, @@ -932,12 +932,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): # Shard state along data parallel group when using Zero. if self.use_zero: - padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size with torch.no_grad(): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) - slice_size = v.numel() // self.dp_size + slice_size = v.numel() // self.global_dp_size v = v.split(slice_size, dim=0)[self.dp_rank] state_[k] = v.detach().clone().to(device) diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py similarity index 66% rename from applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py rename to colossalai/checkpoint_io/moe_checkpoint.py index d08dfd5f8..a0b625008 100644 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py +++ b/colossalai/checkpoint_io/moe_checkpoint.py @@ -9,6 +9,7 @@ import torch import torch.distributed as dist import torch.nn as nn from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import get_global_rank from colossalai.checkpoint_io import CheckpointIndexFile from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO @@ -19,15 +20,16 @@ from colossalai.checkpoint_io.utils import ( get_model_base_filenames, get_optimizer_base_filenames, load_shard_state_dict, + load_state_dict, load_states_into_optimizer, save_config_file, save_param_groups, + save_state_dict, save_state_dict_shards, search_tp_partition_dim, sharded_optimizer_loading_epilogue, ) from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.moe import MOE_MANAGER from colossalai.tensor.moe_tensor.api import is_moe_tensor try: @@ -36,21 +38,30 @@ except ImportError: _EXTRA_STATE_KEY_SUFFIX = "_extra_state" -class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): +class MoECheckpointIO(HybridParallelCheckpointIO): def __init__( self, - dp_group: ProcessGroup, + global_dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup, + ep_group: ProcessGroup, + moe_dp_group: ProcessGroup, zero_stage: int, verbose: bool = True, ) -> None: - super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose) - moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size] - self.ep_group = moe_info.ep_group - self.ep_size = moe_info.ep_size - self.ep_rank = moe_info.ep_rank - self.real_dp_rank = moe_info.dp_rank + super().__init__(global_dp_group, pp_group, tp_group, 
zero_stage, verbose) + self.global_dp_group = global_dp_group + self.global_dp_rank = dist.get_rank(global_dp_group) + self.global_dp_size = dist.get_world_size(global_dp_group) + self.pp_group = pp_group + self.tp_group = tp_group + + self.moe_dp_group = moe_dp_group + self.moe_dp_size = dist.get_world_size(moe_dp_group) + self.moe_dp_rank = dist.get_rank(moe_dp_group) + self.ep_group = ep_group + self.ep_size = dist.get_world_size(ep_group) + self.ep_rank = dist.get_rank(ep_group) @staticmethod def _model_sharder( @@ -134,7 +145,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): Path(checkpoint).mkdir(parents=True, exist_ok=True) - if self.real_dp_rank != 0: + if self.moe_dp_rank != 0: dist.barrier() return @@ -144,7 +155,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): # Then collect the sharded parameters & buffers along tp_group. # Only devices with tp_rank == 0 are responsible for model saving. - state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder( + state_dict_shard = MoECheckpointIO._model_sharder( model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern ) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) @@ -234,11 +245,12 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, - dp_group: ProcessGroup, + global_dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool, inplace: bool, is_moe_param: bool, + moe_dp_group: ProcessGroup = None, device: torch.device = torch.device("cpu"), ) -> OrderedDict: """ @@ -248,7 +260,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. param (torch.Tensor): The given parameter. It should be working_param when using Zero. original_shape (torch.Size): The size of parameter before sharding. - dp_group (ProcessGroup): The process group of data parallel. + global_dp_group (ProcessGroup): The process group of data parallel. tp_group (ProcessGroup): The process group of tensor parallel. use_zero (bool): Whether Zero is used. inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. @@ -257,27 +269,47 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): Returns: OrderedDict: The complete optimizer state of given parameter. """ - dp_size = dist.get_world_size(dp_group) + global_dp_size = dist.get_world_size(global_dp_group) tp_size = dist.get_world_size(tp_group) + moe_dp_size = dist.get_world_size(moe_dp_group) if moe_dp_group is not None else 1 current_shape = param.shape state_ = state if inplace else copy.deepcopy(state) - for k, v in state_.items(): if isinstance(v, torch.Tensor) and k != "step": + v = v.cuda() + # First gather Zero shards. 
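The gather branches that follow all finish with the same reconstruction step: ZeRO flattens and pads each state tensor before splitting it across dp ranks, so once the shards are gathered on the destination rank they are stacked, trimmed to param.numel(), and reshaped back to the parameter's shape. A single-process sketch of that round trip (the split stands in for what dist.gather collects on rank 0):

    import torch

    param = torch.randn(5, 3)            # 15 elements
    dp_size = 4

    # Shard the way ZeRO does: flatten, pad to a multiple of dp_size, split evenly.
    flat = param.flatten()
    padding = (dp_size - flat.numel() % dp_size) % dp_size
    flat = torch.nn.functional.pad(flat, [0, padding])
    shards = list(flat.split(flat.numel() // dp_size))

    # What dist.gather(..., dst=0) would hand rank 0, reconstructed as in the code below.
    restored = torch.stack(shards).view(-1)[: param.numel()].reshape_as(param)
    assert torch.equal(restored, param)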
- if use_zero and not is_moe_param: - v = v.cuda() - gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] - dist.all_gather(gather_tensor, v, group=dp_group) - v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) + if use_zero and is_moe_param and moe_dp_size > 1: + moe_dp_rank = dist.get_rank(moe_dp_group) + dst = get_global_rank(moe_dp_group, 0) + if moe_dp_rank == 0: + gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)] + dist.gather(v, gather_tensor, group=moe_dp_group, dst=dst) + v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) + else: + dist.gather(v, group=moe_dp_group, dst=dst) + + elif use_zero and not is_moe_param and global_dp_size > 1: + dp_rank = dist.get_rank(global_dp_group) + dst = get_global_rank(global_dp_group, 0) + if dp_rank == 0: + gather_tensor = [torch.zeros_like(v) for _ in range(global_dp_size)] + dist.gather(v, gather_tensor, group=global_dp_group, dst=dst) + v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) + else: + dist.gather(v, group=global_dp_group, dst=dst) # Then gather TP shards. partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) if partition_dim is not None: - gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] - dist.all_gather(gather_tensor, v, group=tp_group) - v = torch.cat(gather_tensor, dim=partition_dim) - + tp_rank = dist.get_rank(tp_group) + dst = get_global_rank(tp_group, 0) + if tp_rank == 0: + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.gather(v, gather_tensor, group=tp_group, dst=dst) + v = torch.cat(gather_tensor, dim=partition_dim) + else: + dist.gather(v, group=tp_group, dst=dst) state_[k] = v.detach().clone().to(device) return state_ @@ -286,8 +318,9 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): def _optimizer_sharder( optimizer: OptimizerWrapper, use_zero: bool, - dp_group: ProcessGroup, + global_dp_group: ProcessGroup, tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, size_per_shard: int = 1024, only_moe_param: bool = False, ): @@ -296,7 +329,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): state_dict_sharder = StateDictSharder(size_per_shard) param_info = optimizer.param_info master_to_working_map = optimizer.get_master_to_working_map() - for param, state in optimizer.optim.state.items(): if param is None: continue @@ -305,22 +337,23 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): working_param = master_to_working_map[id(param)] else: working_param = param - param_id = param_info["param2id"][id(working_param)] original_shape = param_info["param2shape"][id(working_param)] - state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state( + state_ = MoECheckpointIO.gather_from_sharded_optimizer_state( state, working_param, original_shape=original_shape, - dp_group=dp_group, + global_dp_group=global_dp_group, + moe_dp_group=moe_dp_group, tp_group=tp_group, use_zero=use_zero, inplace=False, - is_moe_param=is_moe_tensor(working_param), + is_moe_param=is_moe_tensor(working_param), # TODO: Check correctness here ) if only_moe_param and not is_moe_tensor(working_param): continue + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) if block is not None: yield block, block_size @@ -359,25 +392,28 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): Path(checkpoint).mkdir(parents=True, exist_ok=True) - # Devices along the same dp_group 
share the same copies of states when zero is not used. - # In this case only let the device with dp_rank == 0 save the model. - if not self.use_zero and self.real_dp_rank != 0: + # If optim states are not sharded, other ranks don't need to participate in gather. + if not self.use_zero and self.moe_dp_rank != 0: dist.barrier() return # Then collect the sharded states along dp_group(if using zero)/tp_group. # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. - state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder( + state_dict_shard = MoECheckpointIO._optimizer_sharder( optimizer, use_zero=self.use_zero, - dp_group=self.dp_group, + global_dp_group=self.global_dp_group, tp_group=self.tp_group, + moe_dp_group=self.moe_dp_group, size_per_shard=size_per_shard, only_moe_param=self.ep_rank != 0, ) states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) index_file = CheckpointIndexFile(checkpoint) - control_saving = self.real_dp_rank == 0 and self.tp_rank == 0 + # e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2 and use gather + # rank 0 saves moe & non-moe params; rank 1 only saves moe params + # rank 3 & 4 save nothing + control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0 if self.pp_size == 1 and self.ep_size == 1: # When pipeline is not used, save the optimizer shards as in general checkpointIO @@ -596,7 +632,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): OrderedDict: The sharded optimizer state of the given parameter. """ state_ = state if inplace else copy.deepcopy(state) - for k, v in state_.items(): if isinstance(v, torch.Tensor) and k != "step": # Shard state along tensor parallel group. @@ -606,24 +641,218 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): v = v.split(slice_size, dim=partition_dim)[self.tp_rank] # Shard state along data parallel group when using Zero. - if self.use_zero and not is_moe_param: - padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + if self.use_zero and not is_moe_param and self.global_dp_size > 1: + padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size with torch.no_grad(): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) - slice_size = v.numel() // self.dp_size - v = v.split(slice_size, dim=0)[self.dp_rank] + slice_size = v.numel() // self.global_dp_size + v = v.split(slice_size, dim=0)[self.global_dp_rank] + + elif self.use_zero and is_moe_param and self.moe_dp_size > 1: + # LowLevelZeRO pads by global dp size for now. + # TODO: update both to use moe dp size + padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.moe_dp_size + v = v.split(slice_size, dim=0)[self.moe_dp_rank] state_[k] = v.detach().clone().to(device) return state_ - def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): - raise NotImplementedError + """Migration from MoEHybridParallelCheckpointIO. These functions mostly deals with unsharded saving, + and can be savely deleted since large MoE models are often saved in shards. 
+ """ + # Copied from colossalai.moe + def pre_save_model(self, model: nn.Module) -> dict: + state_dict = model.state_dict() + for name, param in model.named_parameters(): + if ".experts." in name and is_moe_tensor(param): + ep_group = param.ep_group + ep_rank = dist.get_rank(ep_group) + ep_size = dist.get_world_size(ep_group) + # TODO: check correctness here + # dp_rank = get_dp_rank(param) + dp_rank = dist.get_rank(self.global_dp_group) + if dp_rank == 0: + param = param.data.cuda() + if ep_rank == 0: + all_param = [torch.zeros_like(param) for _ in range(ep_size)] + else: + all_param = None + # gather param from every ep rank + # dist.all_gather(all_param, param, group=ep_group) + dist.gather(param, all_param, group=ep_group) + if ep_rank == 0: + all_param = torch.cat(all_param, dim=0) + state_dict[name] = all_param.cpu() + + if self.pp_size > 1: + if self.dp_rank == 0: + out = [None for _ in range(self.pp_size)] + dist.gather_object(state_dict, out, group=self.pp_group) + if self.pp_rank == 0: + new_state_dict = {} + for o in out: + new_state_dict.update(o) + state_dict = new_state_dict + dist.barrier() + return state_dict + + def save_unsharded_model( + self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool, + use_safetensors: bool, + ): + state_dict = self.pre_save_model(model) + if dist.get_rank() == 0: + torch.save(state_dict, checkpoint) + dist.barrier() + + # Copied from colossalai.moe def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): - raise NotImplementedError + """ + Save optimizer state dict to a file with given path. + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict. + checkpoint (str): Path to save optimizer state_dict. + gather_dtensor (bool): Whether to gather_dtensor, not used. + """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + + # optimizer states of parameters kept by local device('s pipeline stage) + local_states = dict() + + for param, state in optimizer.optim.state.items(): + if param is None: + continue + + # working param is needed for obtaining correct param_id + master_to_working_map = optimizer.get_master_to_working_map() + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + # gather complete state from tp shards & dp shards + param_id = optimizer.param_info["param2id"][id(working_param)] + local_states[param_id] = self.pre_save_optim( + state, + working_param, + inplace=False, + device=torch.device("cuda"), + ) + + if self.pp_size == 1: + # When pipeline is not used, let master rank directly save the collected state_dict. + state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states} + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors=False) + else: + # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. + states_list = [None for _ in range(self.pp_size)] + dist.barrier(self.pp_group) + # dist.all_gather_object(states_list, local_states, self.pp_group) + dist.gather_object(local_states, states_list, self.pp_group) + + # Only the master rank do the saving. 
+ if self.coordinator.is_master(): + state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()} + for _states in states_list: + state_dict["state"].update(_states) + save_state_dict(state_dict, checkpoint, use_safetensors=False) + dist.barrier() + + # Copied from colossalai.moe def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False): - raise NotImplementedError + """ + Load optimizer from a file with given path. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the checkpoint file. + """ + + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + else: + working_param = param + if id(working_param) in optimizer.param_info["param2id"]: + return optimizer.param_info["param2id"][id(working_param)] + else: + None + + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + + # Complete optimizer state_dict loaded from checkpoint, need to be processed later. + state_dict = load_state_dict(checkpoint) + + # Load param_groups. + updated_groups = [] + saved_groups = state_dict["param_groups"] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage. + updated_groups.append(new_pg) + + # ep extra group + # if MOE_MANAGER.parallel == "EP": + if self.ep_size > 1: + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = optimizer.optim.param_groups[-1][ + "params" + ] # Only keep the parameters kept by current pipeline stage. + for param in new_pg["params"]: + param.data = param.data.to(torch.float32) + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) + + # Load saved states to optimizer. First discard those states not belonging to current pipeline stage. + master_to_working_map = optimizer.get_master_to_working_map() + id_map = {} + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + if param_id is not None: + id_map[param_id] = param + load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True) + + # Then shard the loaded optimizer states if using tp/zero. 
+ for param, state in optimizer.optim.state.items(): + if param is None: + continue + device = param.device + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.pre_load_optim( + state, + param, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True, + ) + optimizer.optim.state[param] = sharded_state + sharded_optimizer_loading_epilogue(optimizer.optim) + dist.barrier() diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 20870a3c2..36138f33e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -242,6 +242,7 @@ def save_state_dict_shards( shard_filenames = [] for idx, shard_pair in enumerate(sharded_state_dict): shard, current_size = shard_pair + # Just loop over the sharder and gather to other ranks if not master if not is_master: del shard continue diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index f0cb78c5f..1319a4529 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -244,19 +244,25 @@ class ProcessGroupMesh: return target_group def get_group_along_axis( - self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None + self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None ) -> ProcessGroup: """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. Args: - axis (int): Axis along which the process groups are created. + axis (int or list of int): Axes along which the process groups are created. indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. backend (Optional[str], optional): Backend of the process group. Defaults to None. Returns: ProcessGroup: The process group along the given axis which the current process belongs to. 
""" - indices_at_axis = indices_at_axis or list(range(self._shape[axis])) + indices_at_axis = indices_at_axis + if indices_at_axis is None: + if isinstance(axis, (list, tuple)): + indices_at_axis = list(list(range(self._shape[ax])) for ax in axis) + else: + indices_at_axis = list(range(self._shape[axis])) + coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis) ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) if ranks_in_group not in self._ranks_to_group: diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index cc33c77f3..0623d19ef 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,20 +1,5 @@ -from .checkpoint import MoECheckpointIO -from .experts import MLPExperts -from .layers import SparseMLP, apply_load_balance from .manager import MOE_MANAGER -from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter -from .utils import NormalNoiseGenerator, UniformNoiseGenerator __all__ = [ - "MLPExperts", - "MoeRouter", - "Top1Router", - "Top2Router", - "TopKRouter", - "NormalNoiseGenerator", - "UniformNoiseGenerator", - "SparseMLP", - "MoECheckpointIO", "MOE_MANAGER", - "apply_load_balance", ] diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py deleted file mode 100644 index 59a0ec3f0..000000000 --- a/colossalai/moe/checkpoint.py +++ /dev/null @@ -1,792 +0,0 @@ -import copy -import logging -import os -from pathlib import Path -from shutil import rmtree -from typing import Dict, Iterator, Optional, OrderedDict, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.distributed import ProcessGroup - -from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO -from colossalai.checkpoint_io.utils import ( - StateDictSharder, - gather_distributed_param, - get_model_base_filenames, - get_optimizer_base_filenames, - is_safetensors_available, - load_shard_state_dict, - load_state_dict, - load_state_dict_into_model, - load_states_into_optimizer, - save_config_file, - save_param_groups, - save_state_dict, - save_state_dict_shards, - sharded_optimizer_loading_epilogue, -) -from colossalai.interface import OptimizerWrapper -from colossalai.moe.manager import MOE_MANAGER -from colossalai.tensor.moe_tensor.api import ( - get_dp_group, - get_dp_rank, - get_dp_size, - get_ep_group, - get_ep_rank, - get_ep_size, - is_moe_tensor, -) - - -class MoECheckpointIO(HybridParallelCheckpointIO): - def __init__( - self, - dp_group: ProcessGroup, - pp_group: ProcessGroup, - tp_group: ProcessGroup, - zero_stage: int, - ) -> None: - assert zero_stage in [ - 0, - 1, - 2, - ], f"zero_stage should be 0 or 1 or 2, got {zero_stage}" - super().__init__(dp_group, pp_group, tp_group, zero_stage) - self.parallel = MOE_MANAGER.parallel - - def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict: - """ - Preprocess state_dict before loading and slice the state_dict of MOE tensors. - """ - for name, param in state_dict.items(): - if ".experts." 
in name: - if name in dict(model.named_parameters()): - model_param = dict(model.named_parameters())[name] - if is_moe_tensor(model_param): - ep_rank = get_ep_rank(model_param) - ep_size = get_ep_size(model_param) - expert_num = param.shape[0] // ep_size - assert param.shape[0] % ep_size == 0 - param = param[ep_rank * expert_num : (ep_rank + 1) * expert_num] - state_dict[name] = param - dist.barrier() - return state_dict - - def _model_sharder( - self, - state_dict: nn.Module, - prefix: str = "", - keep_vars: bool = False, - size_per_shard: int = 1024, - ) -> Iterator[Tuple[OrderedDict, int]]: - # An internel method that breaks state_dict of model into shards within limited size. - state_dict_sharder = StateDictSharder(size_per_shard) - - for name, param in state_dict.items(): - if param is None: - continue - # Gather tensor pieces when using tensor parallel. - param_ = gather_distributed_param(param, keep_vars=False) - block, block_size = state_dict_sharder.append_param(prefix + name, param_) - if block is not None: - yield block, block_size - - # Return the last block in sharder. - yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None: - state_dict = torch.load(checkpoint) - state_dict = self.pre_load_model(model, state_dict) - model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False) - - def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): - """ - Load sharded model with the given path to index file of checkpoint folder. - - Args: - model (nn.Module): The model to be loaded. - checkpoint_index_file (str): Path to the index file of checkpointing folder. - strict (bool, optional): For name matching during loading state_dict. Defaults to False. - This argument should be manually set to False since params on same device might be stored in different files. - """ - - # Check whether the checkpoint uses safetensors. - use_safetensors = False - if "safetensors" in checkpoint_index_file.name: - use_safetensors = True - - if use_safetensors and not is_safetensors_available(): - raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") - - # Read checkpoint index file. - ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) - ckpt_root_path = ckpt_index_file.root_path - weight_map = ckpt_index_file.weight_map - strict = False - - # Load params & buffers to model. - # Keep a record of loaded files so that file will not be repeatedly loaded. - loaded_file = set() - - def _load(name: str): - if name not in weight_map: - raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") - filename = weight_map[name] - - # If this param/buffer has been loaded before, directly return. - if filename in loaded_file: - return - - file_path = os.path.join(ckpt_root_path, filename) - state_dict = load_shard_state_dict(Path(file_path), use_safetensors) - state_dict = self.pre_load_model(model, state_dict) - missing_keys = [] - - load_state_dict_into_model( - model, - state_dict, - missing_keys=missing_keys, - strict=strict, - load_sub_module=True, - ) - loaded_file.add(filename) - - # Load parameters. 
- for name, _ in model.named_parameters(): - _load(name) - - if self.verbose: - logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") - - def pre_save_model(self, model: nn.Module) -> dict: - state_dict = model.state_dict() - for name, param in model.named_parameters(): - if ".experts." in name and is_moe_tensor(param): - ep_group = get_ep_group(param) - ep_rank = get_ep_rank(param) - ep_size = get_ep_size(param) - dp_rank = get_dp_rank(param) - if dp_rank == 0: - param = param.data.cuda() - all_param = [torch.zeros_like(param) for _ in range(ep_size)] - # gather param from every ep rank - dist.all_gather(all_param, param, group=ep_group) - if ep_rank == 0: - all_param = torch.cat(all_param, dim=0) - state_dict[name] = all_param.cpu() - if self.pp_size > 1: - if self.dp_rank == 0: - out = [None for _ in range(self.pp_size)] - dist.all_gather_object(out, state_dict, group=self.pp_group) - if self.pp_rank == 0: - new_state_dict = {} - for o in out: - new_state_dict.update(o) - state_dict = new_state_dict - dist.barrier() - return state_dict - - def save_unsharded_model( - self, - model: nn.Module, - checkpoint: str, - gather_dtensor: bool, - use_safetensors: bool, - ): - state_dict = self.pre_save_model(model) - if dist.get_rank() == 0: - torch.save(state_dict, checkpoint) - dist.barrier() - - def save_sharded_model( - self, - model: nn.Module, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024, - use_safetensors: bool = False, - ) -> None: - """ - Save sharded model checkpoint under the given checkpointing path. - The following files will be created under the path: - - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. - - Multiple files that store state tensors of models. - The filenames are in the form of "pytorch_model.-000XX.bin" - - Args: - model (nn.Module): Model on local device to be saved. - checkpoint (str): Checkpointing path which should be a directory path. - gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. - prefix (str, optional): Perfix of file to save. Defaults to None. - size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. - use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. - """ - torch.cuda.empty_cache() - if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") - return - - Path(checkpoint).mkdir(parents=True, exist_ok=True) - - # Then collect the sharded parameters & buffers along tp_group. - # Only devices with tp_rank == 0 are responsible for model saving. - state_dict = self.pre_save_model(model) - - if dist.get_rank() == 0: - state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard) - - # Devices along the same dp_group share the same copies of model. - # So only let the device with dp_rank == 0 save the model. 
- if self.dp_rank != 0: - return - - weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) - index_file = CheckpointIndexFile(checkpoint) - control_saving = self.tp_rank == 0 - - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors, - ) - if control_saving: - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model, checkpoint) - if self.verbose: - logging.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) - dist.barrier() - torch.cuda.empty_cache() - - # ======================================================== - # Abstract methods for optimizer loading/saving implementation - # ======================================================== - - def pre_load_optim( - self, - state: OrderedDict, - working_param, - current_shape: torch.Size, - original_shape: torch.Size, - device: torch.device, - inplace: bool, - ) -> OrderedDict: - """ - With complete optimizer states of a specific parameter loaded from checkpoint, - slice out the sharded optimizer states kept by current device. - - Args: - state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. - current_shape (torch.Size): The size of parameter after sharding. - original_shape (torch.Size): The size of parameter before sharding. - device (torch.device): The destination device of loaded optimizer states. - inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. - - Returns: - OrderedDict: The sharded optimizer state of the given parameter. - """ - state_ = state if inplace else copy.deepcopy(state) - is_moe_tensor_flag = is_moe_tensor(working_param) - if is_moe_tensor_flag: - ep_rank = get_ep_rank(working_param) - ep_size = get_ep_size(working_param) - - for k, v in state_.items(): - if isinstance(v, torch.Tensor) and k != "step": - if is_moe_tensor_flag: - with torch.no_grad(): - expert_num = v.shape[0] // ep_size - assert v.shape[0] % ep_size == 0 - v = v[ep_rank * expert_num : (ep_rank + 1) * expert_num] - else: - # Shard state along data parallel group when using Zero. - padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size - with torch.no_grad(): - v = v.flatten() - if padding_size > 0: - v = torch.nn.functional.pad(v, [0, padding_size]) - slice_size = v.numel() // self.dp_size - v = v.split(slice_size, dim=0)[self.dp_rank] - - state_[k] = v.detach().clone().to(device) - - return state_ - - def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): - """ - Load sharded optimizer with the given path to index file of checkpoint folder. - - Args: - optimizer (OptimizerWrapper): The optimizer to be loaded. - checkpoint_index_file (str): Path to the index file of checkpointing folder. - prefix (str): Not used. - """ - assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" 
- - def _get_param_id_from_optimizer_param( - param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None - ): - if master_to_working_map is not None and id(param) in master_to_working_map: - working_param = master_to_working_map[id(param)] - elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map: - working_param = optimizer.moe_master_to_working_map[id(param)] - else: - working_param = param - return optimizer.param_info["param2id"][id(working_param)] - - # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. - # When Zero is used, the mapped parameter objects should be fp32 master parameters. - # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. - id_map = {} - master_to_working_map = optimizer.get_master_to_working_map() - for pg in optimizer.optim.param_groups: - for param in pg["params"]: - param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer) - id_map[param_id] = param - - # Read checkpoint index file. - ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) - ckpt_root_path = ckpt_index_file.root_path - weight_map = ckpt_index_file.weight_map - weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int - - # Load param_groups - param_group_path = ckpt_index_file.get_param_group_filename() - if param_group_path is None: - raise RuntimeError( - f"Invalid index file path {checkpoint_index_file} for an optimizer. \ - Lacking param group file under current directory." - ) - saved_groups = torch.load(param_group_path) - - updated_groups = [] - for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): - # obtain updated param group - new_pg = copy.deepcopy(saved_pg) - new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change. - updated_groups.append(new_pg) - # ep param group - if len(optimizer.optim.param_groups) > len(saved_groups): - new_pg = copy.deepcopy(saved_pg) - new_pg["params"] = optimizer.optim.param_groups[-1]["params"] - updated_groups.append(new_pg) - optimizer.optim.__dict__.update({"param_groups": updated_groups}) - - # Load saved states to optimizer. - # Keep a record of loaded files so that file will not be repeatedly loaded. - loaded_file = set() - for pg in optimizer.optim.param_groups: - for param in pg["params"]: - if param is None: - continue - param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer) - if param_id not in weight_map: - continue - filename = weight_map[param_id] - - # If this param's states has been loaded before, directly return. - if filename in loaded_file: - continue - - file_path = os.path.join(ckpt_root_path, filename) - state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) - - # Then shard the loaded optimizer states if using tp/zero. 
- for pid, state in list(state_dict.items()): - if pid in id_map: - param = id_map[pid] - if master_to_working_map is not None and id(param) in master_to_working_map: - working_param = master_to_working_map[id(param)] - elif ( - hasattr(optimizer, "moe_master_to_working_map") - and id(param) in optimizer.moe_master_to_working_map - ): - working_param = optimizer.moe_master_to_working_map[id(param)] - else: - working_param = param - original_shape = optimizer.param_info["param2shape"][id(working_param)] - sharded_state = self.pre_load_optim( - state, - working_param, - current_shape=working_param.shape, - original_shape=original_shape, - device="cpu", - inplace=True, - ) - state_dict[pid] = sharded_state - - load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) - loaded_file.add(filename) - - sharded_optimizer_loading_epilogue(optimizer.optim) - if self.verbose and self.coordinator.is_master(): - logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") - dist.barrier() - - def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): - """ - Load optimizer from a file with given path. - - Args: - optimizer (OptimizerWrapper): The optimizer to be loaded. - checkpoint_index_file (str): Path to the checkpoint file. - """ - - def _get_param_id_from_optimizer_param( - param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None - ): - if master_to_working_map is not None and id(param) in master_to_working_map: - working_param = master_to_working_map[id(param)] - else: - working_param = param - if id(working_param) in optimizer.param_info["param2id"]: - return optimizer.param_info["param2id"][id(working_param)] - else: - None - - if self.coordinator.is_master(): - logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") - - assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" - - # Complete optimizer state_dict loaded from checkpoint, need to be processed later. - state_dict = load_state_dict(checkpoint) - - # Load param_groups. - updated_groups = [] - saved_groups = state_dict["param_groups"] - for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): - new_pg = copy.deepcopy(saved_pg) - new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage. - updated_groups.append(new_pg) - # ep extra group - if MOE_MANAGER.parallel == "EP": - new_pg = copy.deepcopy(saved_pg) - new_pg["params"] = optimizer.optim.param_groups[-1][ - "params" - ] # Only keep the parameters kept by current pipeline stage. - for param in new_pg["params"]: - param.data = param.data.to(torch.float32) - updated_groups.append(new_pg) - optimizer.optim.__dict__.update({"param_groups": updated_groups}) - - # Load saved states to optimizer. First discard those states not belonging to current pipeline stage. - master_to_working_map = optimizer.get_master_to_working_map() - id_map = {} - for pg in optimizer.optim.param_groups: - for param in pg["params"]: - param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) - if param_id is not None: - id_map[param_id] = param - load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True) - - # Then shard the loaded optimizer states if using tp/zero. 
- for param, state in optimizer.optim.state.items(): - if param is None: - continue - device = param.device - if master_to_working_map is not None and id(param) in master_to_working_map: - working_param = master_to_working_map[id(param)] - else: - working_param = param - original_shape = optimizer.param_info["param2shape"][id(working_param)] - sharded_state = self.pre_load_optim( - state, - param, - current_shape=working_param.shape, - original_shape=original_shape, - device=device, - inplace=True, - ) - optimizer.optim.state[param] = sharded_state - sharded_optimizer_loading_epilogue(optimizer.optim) - dist.barrier() - - def pre_save_optim( - self, - state: OrderedDict, - param: torch.Tensor, - inplace: bool, - device: torch.device = torch.device("cpu"), - ) -> OrderedDict: - """ - With given parameter and its optimizer states, gather the complete optimizer state for saving. - - Args: - state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. - param (torch.Tensor): The given parameter. It should be working_param when using Zero. - original_shape (torch.Size): The size of parameter before sharding. - dp_group (ProcessGroup): The process group of data parallel. - tp_group (ProcessGroup): The process group of tensor parallel. - use_zero (bool): Whether Zero is used. - inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. - device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu'). - - Returns: - OrderedDict: The complete optimizer state of given parameter. - """ - if is_moe_tensor(param): - moe_dp_group = get_dp_group(param) - moe_dp_size = get_dp_size(param) - moe_ep_group = get_ep_group(param) - moe_ep_size = get_ep_size(param) - state_ = state if inplace else copy.deepcopy(state) - - for k, v in state_.items(): - if isinstance(v, torch.Tensor) and k != "step": - # moe param - if is_moe_tensor(param): - # dp gather - v = v.cuda() - gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)] - dist.all_gather(gather_tensor, v, group=moe_dp_group) - v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) - # ep gather - gather_tensor = [torch.zeros_like(v) for _ in range(moe_ep_size)] - dist.all_gather(gather_tensor, v, group=moe_ep_group) - v = torch.cat(gather_tensor, dim=0) - else: - # global dp - v = v.cuda() - gather_tensor = [torch.zeros_like(v) for _ in range(dist.get_world_size(self.dp_group))] - dist.all_gather(gather_tensor, v, group=self.dp_group) - v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) - - state_[k] = v.detach().clone().to(device) - - return state_ - - def _optimizer_sharder( - self, - optimizer: OptimizerWrapper, - size_per_shard: int = 1024, - ): - # An internel method that breaks state_dict of optimizer into shards within limited size. 
- - state_dict_sharder = StateDictSharder(size_per_shard) - param_info = optimizer.param_info - master_to_working_map = optimizer.get_master_to_working_map() - - for param, state in optimizer.optim.state.items(): - if param is None: - continue - - if master_to_working_map is not None and id(param) in master_to_working_map: - working_param = master_to_working_map[id(param)] - elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map: - working_param = optimizer.moe_master_to_working_map[id(param)] - else: - working_param = param - - param_id = param_info["param2id"][id(working_param)] - state_ = self.pre_save_optim( - state, - working_param, - inplace=False, - device=torch.device("cuda"), - ) - - block, block_size = state_dict_sharder.append_optim_state(param_id, state_) - if block is not None: - yield block, block_size - - # Return the last block in sharder. - yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - - def save_sharded_optimizer( - self, - optimizer: OptimizerWrapper, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024, - ): - """ - Save sharded optimizer checkpoint under the given checkpointing path. - The following files will be created under the path: - - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names - - A group file (pytorch_optim_group.bin) recording information of param_groups - - Multiple files that store state tensors of optimizers. - If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". - If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" - - Args: - optimizer (OptimizerWrapper): Optimizer to save sharded state_dict - checkpoint (str): Path to save optimizer state_dict - gather_dtensor (bool): Whether to gather_dtensor, not used - prefix (str): Perfix of file to save - size_per_shard (int): Max file size of each file shard that store state tensors - """ - torch.cuda.empty_cache() - assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" - if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") - return - - Path(checkpoint).mkdir(parents=True, exist_ok=True) - - # Devices along the same dp_group share the same copies of states when zero is not used. - # In this case only let the device with dp_rank == 0 save the model. - if not self.use_zero and self.dp_rank != 0: - return - - # Then collect the sharded states along dp_group(if using zero)/tp_group. - # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. - state_dict_shard = self._optimizer_sharder( - optimizer, - size_per_shard=size_per_shard, - ) - states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) - index_file = CheckpointIndexFile(checkpoint) - control_saving = self.dp_rank == 0 and self.tp_rank == 0 - if self.pp_size == 1: - # When pipeline is not used, save the optimizer shards as in general checkpointIO - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=control_saving, - ) - - if control_saving: - # Store param groups. 
- index_file.append_meta_data("param_groups", param_group_file) - group_file_path = os.path.join(checkpoint, param_group_file) - save_param_groups(optimizer.param_info, group_file_path) - # Store index file. - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - if self.verbose and self.coordinator.is_master(): - logging.info( - f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) - - else: - # When pipeline is used, each stage produces its own shard files and index files. - # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ - # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. - - final_index_file_path = copy.deepcopy(save_index_file) - tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") - Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) - - # Manage filenames of sharded weights and index file for each pipeline stage. - states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") - save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") - save_index_file = os.path.join("tmp_index_files", save_index_file) - - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=control_saving, - use_pp_format=True, - ) - - if control_saving: - assert ( - self.dp_rank == 0 and self.tp_rank == 0 - ), "The saving process should have both dp_rank and tp_rank as 0." - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - else: - return - - dist.barrier(self.pp_group) - - # The global master rank integrates the index files and clean the folder. - if self.pp_rank == 0: - final_index_file = CheckpointIndexFile(checkpoint) - final_index_file.append_meta_data("total_size", 0) - - for filename in os.listdir(tmp_index_file_folder): - stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) - final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] - for param_id, state_filename in stage_index_file.weight_map.items(): - final_index_file.append_weight_map(param_id, state_filename) - - # Store param groups. - final_index_file.append_meta_data("param_groups", param_group_file) - group_file_path = os.path.join(checkpoint, param_group_file) - save_param_groups(optimizer.param_info, group_file_path) - - final_index_file.write_index_file(final_index_file_path) - rmtree(tmp_index_file_folder) - - if self.verbose and self.coordinator.is_master(): - logging.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {final_index_file_path}." - ) - torch.cuda.empty_cache() - - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): - """ - Save optimizer state dict to a file with given path. - - Args: - optimizer (OptimizerWrapper): Optimizer to save sharded state_dict. - checkpoint (str): Path to save optimizer state_dict. - gather_dtensor (bool): Whether to gather_dtensor, not used. 
- """ - if self.coordinator.is_master(): - logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") - - assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" - - # optimizer states of parameters kept by local device('s pipeline stage) - local_states = dict() - - for param, state in optimizer.optim.state.items(): - if param is None: - continue - - # working param is needed for obtaining correct param_id - master_to_working_map = optimizer.get_master_to_working_map() - if master_to_working_map is not None and id(param) in master_to_working_map: - working_param = master_to_working_map[id(param)] - else: - working_param = param - - # gather complete state from tp shards & dp shards - param_id = optimizer.param_info["param2id"][id(working_param)] - local_states[param_id] = self.pre_save_optim( - state, - working_param, - inplace=False, - device=torch.device("cuda"), - ) - - if self.pp_size == 1: - # When pipeline is not used, let master rank directly save the collected state_dict. - state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states} - if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors=False) - else: - # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. - states_list = [None for _ in range(self.pp_size)] - dist.barrier(self.pp_group) - dist.all_gather_object(states_list, local_states, self.pp_group) - - # Only the master rank do the saving. - if self.coordinator.is_master(): - state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()} - for _states in states_list: - state_dict["state"].update(_states) - save_state_dict(state_dict, checkpoint, use_safetensors=False) - dist.barrier() diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py index 85c12d73f..3dc6c02c7 100644 --- a/colossalai/moe/load_balance.py +++ b/colossalai/moe/load_balance.py @@ -7,8 +7,8 @@ from torch import Tensor, nn from torch.distributed import ProcessGroup from colossalai.cluster import ProcessGroupMesh -from colossalai.moe.experts import MLPExperts from colossalai.moe.manager import MOE_MANAGER +from colossalai.shardformer.layer.moe import MLPExperts from colossalai.zero.low_level import LowLevelZeroOptimizer @@ -292,7 +292,7 @@ class LoadBalancer: exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"] exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"] else: - master_weight_ptr = optim._param_store.working_to_master_param[id(weight)] + master_weight_ptr = optim.working_to_master_param[id(weight)] working_weight_ptr = weight exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"] exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"] @@ -344,7 +344,7 @@ class LoadBalancer: # gate optim should be obtained first gate_shape = self.gate.shape # get master weight and optim - master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)] + master_gate_weight = optim.working_to_master_param[id(self.gate)] gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"] gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"] # gather diff --git a/colossalai/moe/loss.py b/colossalai/moe/loss.py deleted file mode 100644 index 75624510b..000000000 --- a/colossalai/moe/loss.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch.nn as nn -from torch.nn.modules.loss import _Loss - -from colossalai.moe.manager import 
MOE_MANAGER - - -class MoeCrossEntropyLoss(_Loss): - r"""torch.nn.CrossEntropyLoss added with auxiliary loss. - - Args: - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. - aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01. - - The ``args`` and ``kwargs`` should include parameters below: - :: - - weight (Tensor, optional) - size_average (bool, optional) - ignore_index (int, optional) - reduce (bool, optional) - reduction (str, optional) - label_smoothing (float, optional) - - More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in - `Cross_entropy `_. - """ - - def __init__(self, aux_weight: float = 0.01, *args, **kwargs): - super().__init__() - self.loss = nn.CrossEntropyLoss(*args, **kwargs) - self.aux_weight = aux_weight - - def forward(self, *args): - """ - The ``args`` should at least include parameters below: - :: - - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. - - More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in - `Cross_entropy `_. - """ - main_loss = self.loss(*args) - aux_loss = MOE_MANAGER.get_loss() - return main_loss + self.aux_weight * aux_loss - - -class MoeLoss(_Loss): - """A wrapper class for any loss module to add with auxiliary loss. - - Args: - aux_weight (float): Weight of auxiliary loss in total loss. - loss_fn (``Callable``): Loss function. - args (list): Args in loss function. - kwargs (dict): Kwargs in loss function - """ - - def __init__(self, aux_weight: float, loss_fn, *args, **kwargs): - super().__init__() - self.loss_fn = loss_fn(*args, **kwargs) - self.aux_weight = aux_weight - - def forward(self, *args, **kwargs): - """ - The ``args`` and ``kwargs`` should at least include parameters below: - :: - - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. - - Note: - The ``args`` and ``kwargs`` may include different parameters varying with different loss function. - """ - main_loss = self.loss_fn(*args, **kwargs) - aux_loss = MOE_MANAGER.get_loss() - return main_loss + self.aux_weight * aux_loss diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py deleted file mode 100644 index e40674c9b..000000000 --- a/colossalai/moe/routers.py +++ /dev/null @@ -1,466 +0,0 @@ -import math -from abc import ABC -from typing import Callable, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from torch.distributed import ProcessGroup - -from colossalai.accelerator import get_accelerator -from colossalai.moe._operation import moe_cumsum -from colossalai.moe.manager import MOE_MANAGER - - -class MoeRouter(nn.Module, ABC): - """Base class for all MoE routers. - Args: - k_value (int): The value of top_k. - capacity_factor_train (float): Capacity factor in routing of training. - capacity_factor_eval (float): Capacity factor in routing of evaluation. - min_capacity (int): The minimum number of the capacity of each expert. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. 
- drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__( - self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - use_kernel: bool = False, - ): - super().__init__() - self.k_value = k_value - self.capacity_factor_train = capacity_factor_train - self.capacity_factor_eval = capacity_factor_eval - self.min_capacity = min_capacity - self.noisy_func = noisy_func - self.drop_tks = drop_tks - self._aux_loss = None - self._z_loss = None - self.use_kernel = use_kernel - - def get_capacity(self, num_tokens, num_experts, ep_group=None): - if ep_group is not None: - num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device()) - dist.all_reduce(num_tokens_tensor, group=ep_group) - num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group) - capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts) - capacity += capacity % 2 - capacity = max(capacity, self.min_capacity) - assert capacity > 0 - return int(capacity) - - def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None: - """Computes auxiliary load balancing loss as in Switch Transformer. - - See Switch Transformer (https://arxiv.org/abs/2101.03961). This function - implements the loss function presented in equations (4) - (6). It aims to - penalize those cases where the routing between experts is unbalanced. - - Args: - router_probs: Probability assigned to each expert per token. Shape: - [num_groups, tokens_per_group, num_experts]. - expert_indices: [num_groups, tokens_per_group, num_selected_experts] - indices identifying the top num_selected_experts for a given token. - """ - assert self._aux_loss is None - if router_probs.dim() == expert_indices.dim() == 2: - router_probs = router_probs.unsqueeze(0) - expert_indices = expert_indices.unsqueeze(0) - assert ( - router_probs.dim() == expert_indices.dim() == 3 - ), "router_probs must be 3D tensor and expert_indices must be 4D tensor" - - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - expert_mask = F.one_hot(expert_indices, num_experts) - # For a given token, determine if it was routed to a given expert. - # Shape: [num_groups, tokens_per_group, num_experts] - expert_mask = expert_mask.max(dim=-2)[0] - - tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2) - router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2) - aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) - self._aux_loss = aux_loss - - def set_z_loss(self, router_logits: torch.Tensor): - """Compute router z-loss. - - The router z-loss was introduced in Designing Effective Sparse Expert Models - (https://arxiv.org/abs/2202.08906). It encourages router logits to remain - small in an effort to improve stability. - - Args: - router_logits: [num_groups, tokens_per_group, num_experts] router logits. 
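set_aux_loss above implements the Switch Transformer load-balancing loss (equations 4-6 of the paper): the fraction of tokens routed to each expert times the mean router probability for that expert, scaled by num_experts**2. A small standalone example with made-up router outputs; a perfectly balanced top-1 assignment under a uniform router gives a loss of exactly 1.0.

```
import torch
import torch.nn.functional as F

num_experts, tokens = 4, 8
router_probs = torch.full((1, tokens, num_experts), 1.0 / num_experts)      # uniform router
expert_indices = (torch.arange(tokens) % num_experts).view(1, tokens, 1)    # round-robin top-1

expert_mask = F.one_hot(expert_indices, num_experts).max(dim=-2)[0].float() # (1, tokens, experts)
tokens_per_expert = expert_mask.mean(dim=-2)     # fraction of tokens routed to each expert
prob_per_expert = router_probs.mean(dim=-2)      # mean router probability per expert
aux_loss = num_experts**2 * torch.mean(tokens_per_expert * prob_per_expert)
print(aux_loss.item())                           # 1.0 for this perfectly balanced example
```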
- """ - assert self._z_loss is None - if router_logits.dim() == 2: - router_logits = router_logits.unsqueeze(0) - assert router_logits.dim() == 3, "router_logits must be 3D tensor" - num_groups, tokens_per_group, _ = router_logits.shape - log_z = torch.logsumexp(router_logits, dim=-1) - z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group) - self._z_loss = z_loss - - def pop_router_loss(self) -> torch.Tensor: - assert self._aux_loss is not None - MOE_MANAGER.add_loss(self._aux_loss, self._z_loss) - self._aux_loss = None - self._z_loss = None - - -class Top1Router(MoeRouter): - """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) - and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed - function can be found in the paper about Switch Transformer of Google. - - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert. - select_policy (str, optional): The policy about tokens selection. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__( - self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - ): - super().__init__( - k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks, - ) - self.select_policy = select_policy - assert select_policy in {"first", "random"} - if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_accelerator().get_current_device()), - high=torch.tensor(1.0, device=get_accelerator().get_current_device()), - ).rsample - - def forward( - self, - inputs: torch.Tensor, - use_kernel: bool = False, - ep_group: Optional[ProcessGroup] = None, - use_loss: bool = False, - use_norm: bool = False, - ) -> Tuple: - """ - Args: - inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). - - Returns: - 1. use_kernel is False: - The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). - The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). - 2. use_kernel is True: - ... 
- """ - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - probs = F.softmax(inputs, dim=-1) - num_experts = probs.size(-1) - num_tokens = inputs.size(0) - capacity = self.get_capacity(num_tokens, num_experts, ep_group) - - top1_idx = torch.argmax(inputs, dim=-1) - mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - - # calculate router loss - self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts) - self.set_z_loss(inputs) - self.pop_router_loss() - - if not self.training and not self.drop_tks and ep_group is not None: - max_num = torch.max(torch.sum(mask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - if self.select_policy == "random": - rand_mask = mask * self.uniform(mask.shape) - _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) - mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) - ranks = moe_cumsum(mask, use_kernel=self.use_kernel) - elif self.select_policy == "first": - ranks = moe_cumsum(mask, use_kernel=self.use_kernel) - mask = mask * torch.lt(ranks, capacity) - else: - raise NotImplementedError("Not support such select policy yet.") - - ranks = torch.sum(mask * ranks, dim=-1) - used_capacity = mask.sum(dim=0) - - if use_kernel: - mask = torch.sum(mask, dim=-1) - mask = torch.stack([mask], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) - return used_capacity, probs, mask, dest_idx, num_experts * capacity - else: - ranks = F.one_hot(ranks, num_classes=capacity) - weight = mask * probs.type_as(inputs) - combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) - sec_mask = combine_weights.bool() - return used_capacity, combine_weights, sec_mask, probs - - -class Top2Router(MoeRouter): - """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) - and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed - function can be found in the paper about ViT-MoE. - - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation. - """ - - def __init__( - self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - ): - super().__init__( - k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks, - ) - - def forward( - self, - inputs: torch.Tensor, - use_kernel: bool = False, - ep_group: Optional[ProcessGroup] = None, - use_norm: bool = False, - use_loss: bool = True, - ) -> Tuple: - """ - Args: - inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). - - Returns: - 1. use_kernel is False: - The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). - The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). - 2. use_kernel is True: - ... 
- """ - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - probs = F.softmax(inputs, dim=-1) - if use_norm: - routing_weights, _ = torch.topk(probs, 2, dim=-1) - probs = probs / routing_weights.sum(dim=-1, keepdim=True) - - num_experts = probs.size(-1) - num_tokens = inputs.size(0) - capacity = self.get_capacity(num_tokens, num_experts, ep_group) - - top1_idx = torch.argmax(probs, dim=-1) - mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - logits_except1 = probs.masked_fill(mask1.bool(), float("-inf")) - top2_idx = torch.argmax(logits_except1, dim=-1) - mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - - cmask = mask1 + mask2 # loss: [s, e] - cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 - - # calculate loss - if use_loss: - expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) - self.set_aux_loss(probs, expert_indices, num_experts) - self.set_z_loss(inputs) - self.pop_router_loss() - - if not self.training and not self.drop_tks and ep_group is not None: - max_num = torch.max(torch.sum(cmask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] - rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel) - rank2 += torch.sum(mask1, dim=-2, keepdim=True) - - mask1 *= torch.lt(rank1, capacity) - mask2 *= torch.lt(rank2, capacity) - used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0) - - rank1 = torch.sum(mask1 * rank1, dim=-1) - rank2 = torch.sum(mask2 * rank2, dim=-1) - - if use_kernel: - mask1 = torch.sum(mask1, dim=-1) - mask2 = torch.sum(mask2, dim=-1) - - mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - - return used_capacity, probs, mask, dest_idx, num_experts * capacity - else: - """ - The following code is equivalent to: - - ``` - weight1 = mask1 * probs.type_as(inputs) - weight2 = mask2 * probs.type_as(inputs) - rank1_sc = F.one_hot(rank1, num_classes=capacity) - rank2_sc = F.one_hot(rank2, num_classes=capacity) - - cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) - cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) - cb_weight = cb_weight1 + cb_weight2 - sec_mask = cb_weight.bool() - ``` - """ - - weight1 = mask1 * probs.type_as(inputs) - weight2 = mask2 * probs.type_as(inputs) - - cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device) - sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool) - indices = torch.arange(0, inputs.shape[0], device=inputs.device) - cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]] - cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]] - sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]] - sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]] - - return used_capacity, cb_weight, sec_mask - - -class TopKRouter(MoeRouter): - """Masked matmul router using tokens choose top-k experts assignment. - - NOTE: this is modified from flaxformer. - This router uses the same mechanism as in Switch Transformer - (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. 
Items are - sorted by router_probs and then routed to their choice of expert until the - expert's expert_capacity is reached. There is no guarantee that each token is - processed by an expert, or that each expert receives at least one token. - - Attributes: - num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are - oversubscribed / reach capacity. - """ - - def __init__( - self, - num_selected_experts: int, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - ): - super().__init__( - num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks - ) - - def forward( - self, - router_probs: torch.Tensor, - expert_capacity: int, - ) -> Tuple: - """Computes masks for the top-k experts per token. - - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - - Returns: - Dispatch and combine arrays for routing with masked matmuls. - """ - # TODO: FIXME: add parallel group - num_groups, _, num_experts = router_probs.shape - - # Top-k router probability and corresponding expert indices for each token. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - expert_gate, expert_index = torch.topk(router_probs, self.k_value) - - self.set_aux_loss(router_probs, expert_index, num_experts) - self.pop_router_loss() - - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 choices, - # etc. - expert_index = torch.transpose(expert_index, 1, 2) - # Shape: [num_groups, num_selected_experts * tokens_per_group] - expert_index = expert_index.reshape(num_groups, -1) - - # Create mask out of indices. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1 - # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts)) - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - token_priority = torch.transpose(token_priority, 1, 2) - # For each token, across all selected experts, select the only non-negative - # (unmasked) priority. Now, for group G routing to expert E, token T has - # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E - # is its targeted expert. - # Shape: [num_groups, tokens_per_group, num_experts]. - token_priority = torch.max(token_priority, dim=2)[0] - - # Token T can only be routed to expert E if its priority is positive and - # less than the expert capacity. One-hot matrix will ignore indices outside - # the range [0, expert_capacity). - # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. 
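The masked cumulative sum described in the comments above gives each token a 0-based slot ("priority") inside its target expert, with -1 marking experts the token was not routed to. A tiny illustration with hypothetical indices:

```
import torch
import torch.nn.functional as F

num_experts = 2
expert_index = torch.tensor([[0, 0, 1, 0]])            # 1 group, 4 routing choices
expert_mask = F.one_hot(expert_index, num_experts)     # (1, 4, 2)
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
print(token_priority[0])
# tensor([[ 0, -1],
#         [ 1, -1],
#         [-1,  0],
#         [ 2, -1]])
```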
- valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity) - token_priority = torch.masked_fill(token_priority, ~valid_mask, 0) - dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool) - valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity) - dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0) - - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, - # expert_capacity]. - combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) - - return combine_array, dispatch_mask - - -def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter: - if not grouped: - if top_k == 1: - return Top1Router - elif top_k == 2: - return Top2Router - else: - raise NotImplementedError("top_k > 2 is not supported yet") - else: - return TopKRouter diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index c642f1a44..3d08ab7dd 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -6,10 +6,11 @@ import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from torch.distributed.distributed_c10d import get_process_group_ranks from colossalai.accelerator import get_accelerator from colossalai.moe.manager import MOE_MANAGER -from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor +from colossalai.tensor.moe_tensor.api import is_moe_tensor class ForceFP32Parameter(torch.nn.Parameter): @@ -145,7 +146,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]] if not is_moe_tensor(param): ep_size = 1 # set ep_size to 1 for dp parameters else: - ep_size = get_ep_size(param) + ep_size = dist.get_world_size(param.ep_group) if ep_size not in epsize_param_dict: epsize_param_dict[ep_size] = [] epsize_param_dict[ep_size].append(param) @@ -170,8 +171,8 @@ def sync_moe_model_param(model: nn.Module): # When ep_size = world_size, communication is not needed if ep_size != 1 and ep_size != MOE_MANAGER.world_size: for param in param_dict[ep_size]: - src_rank = get_dp_group_ranks(param)[0] - dist.broadcast(param, src=src_rank, group=get_dp_group(param)) + src_rank = get_process_group_ranks(param.dp_group)[0] + dist.broadcast(param, src=src_rank, group=param.dp_group) def set_moe_args(config: Any, args: dict): diff --git a/colossalai/shardformer/layer/moe/__init__.py b/colossalai/shardformer/layer/moe/__init__.py new file mode 100644 index 000000000..6fa015a94 --- /dev/null +++ b/colossalai/shardformer/layer/moe/__init__.py @@ -0,0 +1,3 @@ +from .experts import * +from .layers import * +from .routers import * diff --git a/colossalai/moe/experts.py b/colossalai/shardformer/layer/moe/experts.py similarity index 98% rename from colossalai/moe/experts.py rename to colossalai/shardformer/layer/moe/experts.py index 8e6ea3884..1be7a2754 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/shardformer/layer/moe/experts.py @@ -9,7 +9,7 @@ from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_activation from colossalai.shardformer.layer.utils import Randomizer -from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info +from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size if HAS_TRITON: from colossalai.kernel.triton.llama_act_combine_kernel import 
LlamaActCombine @@ -35,7 +35,7 @@ class MLPExperts(nn.Module): num_experts: int, hidden_size: int, intermediate_size: int, - expert_parallel: Optional[str] = None, + expert_parallel: Optional[str] = "EP", activation: Optional[Callable] = None, drop_rate: Optional[float] = 0, gated: Optional[bool] = False, diff --git a/colossalai/moe/layers.py b/colossalai/shardformer/layer/moe/layers.py similarity index 96% rename from colossalai/moe/layers.py rename to colossalai/shardformer/layer/moe/layers.py index 2ac5b186d..e5b0ef97f 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/shardformer/layer/moe/layers.py @@ -8,11 +8,9 @@ import torch.nn as nn import torch.nn.functional as F from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter -from colossalai.moe.experts import MLPExperts from colossalai.moe.load_balance import LoadBalancer -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.routers import MoeRouter, get_router_cls from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator +from colossalai.shardformer.layer.moe import MLPExperts from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size @@ -23,6 +21,7 @@ class SparseMLP(nn.Module): dim_model (int): Hidden dimension of training model num_experts (int): The number experts top_k (int, optional): The number of experts for dispatchment of each token + parallel (str): parallel mode. Should be "EP", "TP" or None capacity_factor_train (float, optional): Capacity factor in routing during training capacity_factor_eval (float, optional): Capacity factor in routing during evaluation min_capacity (int, optional): The minimum number of the capacity of each expert @@ -51,6 +50,7 @@ class SparseMLP(nn.Module): hidden_size: int, intermediate_size: int, router_top_k: int = 1, + parallel: str = "EP", router_loss: bool = True, router_norm: bool = False, router_capacity_factor_train: float = 1.25, @@ -66,7 +66,7 @@ class SparseMLP(nn.Module): load_balance_group_swap_factor: float = 0.4, enable_kernel: bool = False, enable_comm_overlap: bool = False, - enable_hierarchical_comm: bool = False, + enable_hierarchical_comm: bool = True, return_gate_logits: bool = False, ): super().__init__() @@ -77,7 +77,9 @@ class SparseMLP(nn.Module): self.return_gate_logits = return_gate_logits self.enable_kernel = enable_kernel self.enable_comm_overlap = enable_comm_overlap - self.expert_parallel = MOE_MANAGER.get_parallel() + # self.expert_parallel = MOE_MANAGER.get_parallel() + assert parallel in ["EP", "TP", None], "parallel mode must be EP, TP or None" + self.parallel = parallel self.router_loss = router_loss self.router_norm = router_norm @@ -99,7 +101,7 @@ class SparseMLP(nn.Module): # moe experts self.experts = MLPExperts( num_experts=self.num_experts, - expert_parallel=self.expert_parallel, + expert_parallel=self.parallel, hidden_size=self.hidden_size, intermediate_size=self.intermediate_size, activation=mlp_activation, @@ -108,11 +110,12 @@ class SparseMLP(nn.Module): ) # get parallel settings - if self.expert_parallel is not None: + if self.parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) self.ep_hierarchical_group = None if enable_hierarchical_comm: + # TODO: move to plugin self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group( get_ep_group_ranks(self.experts) ) @@ -186,11 +189,11 @@ class SparseMLP(nn.Module): dispatch_data = 
torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) # expert_output: (num_groups, num_experts, capacity, hidden_size) - if self.expert_parallel == "EP": + if self.parallel == "EP": expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) - elif self.expert_parallel == "TP": + elif self.parallel == "TP": expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) - elif self.expert_parallel is None: + elif self.parallel is None: expert_output = self._local_process(dispatch_data) else: raise NotImplementedError( diff --git a/colossalai/shardformer/layer/moe/routers.py b/colossalai/shardformer/layer/moe/routers.py new file mode 100644 index 000000000..1be7a2754 --- /dev/null +++ b/colossalai/shardformer/layer/moe/routers.py @@ -0,0 +1,161 @@ +import math +from typing import Callable, Optional, Tuple + +import torch +import torch.nn as nn + +from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON +from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import get_activation +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size + +if HAS_TRITON: + from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine + + +class MLPExperts(nn.Module): + """ + SparseMLP is a multi-layer perceptron with sparse expert parallel layers. + + Args: + num_experts (int): The number of experts + hidden_size (int): The hidden size of MLP + intermediate_size (int): The intermediate size of MLP + expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP. + activation (optional): The activation function of MLP + drop_rate (float, optional): The drop rate of MLP + gated (bool, optional): Whether to use gated MLP + use_kernel (bool, optional): Whether to use kernel optimization + """ + + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + expert_parallel: Optional[str] = "EP", + activation: Optional[Callable] = None, + drop_rate: Optional[float] = 0, + gated: Optional[bool] = False, + use_kernel: Optional[bool] = False, + ): + super().__init__() + assert expert_parallel in ["EP", "TP", None] + self.expert_parallel = expert_parallel + self.num_total_experts = num_experts + self.gated = gated + self.use_kernel = use_kernel + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + + # get expert parallel info + if expert_parallel is not None: + self.num_local_experts, self.moe_info = MOE_MANAGER.get_info( + num_experts, use_tp=True if expert_parallel == "TP" else False + ) + # get settings for different parallel + self.ep_size = get_ep_size(self) + if expert_parallel == "TP": + intermediate_size = intermediate_size // self.ep_size + num_experts = self.num_total_experts + else: + num_experts = self.num_local_experts + else: + self.num_local_experts = self.num_total_experts + self.ep_size = 1 + + if gated: + self.wi_gate = nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size + ) + ) + self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) + else: + self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) + self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size)) + + self.act_name = activation + self.act = 
get_activation(activation) + self.drop = nn.Dropout(p=drop_rate) + + if expert_parallel is not None: + for param in self.parameters(): + set_moe_tensor_info(param, self.moe_info) + + # init param + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self): + # expert param should be different + if self.expert_parallel is not None: + seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True) + else: + seed_ctx = Randomizer(42).fork_rng(enable_cpu=True) + with seed_ctx: + if self.gated: + torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size)) + torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size)) + else: + torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size)) + torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size)) + + def forward( + self, + x: torch.Tensor, + param_slice: Tuple[slice] = (slice(None),), + use_sparse: bool = True, + ) -> torch.Tensor: + """ + forward: hidden_size --> intermediate_size --> hidden_size + + Args: + x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size) + + Returns: + torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size) + """ + x = MoeInGradScaler.apply(x, self.ep_size) + + e = x.size(1) + h = x.size(-1) + + x = x.transpose(0, 1) + inshape = x.shape + x = x.reshape(e, -1, h) + + if self.use_kernel and use_sparse: + seq_len = x.shape[1] + with torch.no_grad(): + mask = x[:, :, 0] != 0.0 + mask = torch.sum(mask, dim=-1) + x_list = [] + for i in range(e): + x_list.append(x[i, : mask[i]]) + x = x_list + + if self.gated: + x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)] + x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)] + if self.use_kernel and HAS_TRITON and self.act_name == "swiglu": + x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)] + else: + x = [self.act(x_gate[i]) * x_up[i] for i in range(e)] + else: + x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)] + x = [self.act(x[i]) for i in range(e)] + x = [self.drop(x[i]) for i in range(e)] + x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)] + + if self.use_kernel and use_sparse: + for i in range(e): + x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0) + + x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0) + x = x.reshape(inshape) + x = x.transpose(0, 1).contiguous() + x = MoeOutGradScaler.apply(x, self.ep_size) + return x diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/colossalai/shardformer/modeling/mixtral.py similarity index 65% rename from applications/ColossalMoE/colossal_moe/models/mixtral_policy.py rename to colossalai/shardformer/modeling/mixtral.py index c01e02c49..2fbc34302 100644 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -1,222 +1,108 @@ -from functools import partial -from typing import Callable, Dict, List, Optional, Union +from typing import List, Optional import torch -import torch.nn as nn -from torch import Tensor -from torch.nn import CrossEntropyLoss, Module +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed import ProcessGroup + +# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo +from torch.nn import CrossEntropyLoss +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from 
transformers.models.mixtral.modeling_mixtral import ( - MixtralDecoderLayer, - MixtralForCausalLM, - MixtralModel, + MixtralSparseMoeBlock, MoeCausalLMOutputWithPast, - _prepare_4d_causal_attention_mask, load_balancing_loss_func, ) -from transformers.utils import logging +from transformers.utils import is_flash_attn_2_available, logging +from colossalai.lazy import LazyInitContext +from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from colossalai.shardformer.shard import ShardConfig - -from .mixtral_layer import EPMixtralSparseMoeBlock - -__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"] +from colossalai.shardformer.shard.utils import set_tensors_to_none -class MixtralPolicy(Policy): - def config_sanity_check(self): - pass +class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): + def __init__(self, config): + self.moe_info = None + super().__init__(config) - def preprocess(self): - if self.shard_config.enable_tensor_parallelism: - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size + def setup_ep(self, ep_group: ProcessGroup): + ep_group = ep_group + self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1 + self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0 + assert self.num_experts % self.ep_size == 0 + self.ep_group = ep_group + self.num_experts_per_ep = self.num_experts // self.ep_size + self.expert_start_idx = self.ep_rank * self.num_experts_per_ep + held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] + set_tensors_to_none(self.experts, exclude=set(held_experts)) + for p in self.experts.parameters(): + p.ep_group = ep_group - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + @staticmethod + def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock": + LazyInitContext.materialize(module) + module.__class__ = EPMixtralSparseMoeBlock + # if "ep_group" in kwargs: + assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!" + module.setup_ep(kwargs["ep_group"]) + return module - return self.model + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) - def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - policy = {} + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) - if self.shard_config.enable_sequence_parallelism: - self.shard_config.enable_sequence_parallelism = False - raise NotImplementedError( - "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." 
- ) + selected_experts = selected_experts.t().reshape(-1) + selected_experts_idx = selected_experts.argsort() + dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx] + input_split_sizes = selected_experts.bincount(minlength=self.num_experts) + output_split_sizes = torch.zeros_like(input_split_sizes) + dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) - if self.shard_config.enable_tensor_parallelism: - raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") - - # expert parallel - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="block_sparse_moe", - target_module=EPMixtralSparseMoeBlock, - ) - ], - policy=policy, - target_key=MixtralDecoderLayer, - ) - - # optimization configuration - if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", - target_module=FusedRMSNorm, - ), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", - target_module=FusedRMSNorm, - ), - ], - policy=policy, - target_key=MixtralDecoderLayer, - ) - - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="norm", - target_module=FusedRMSNorm, - ), - policy=policy, - target_key=MixtralModel, - ) - - if self.shard_config.enable_flash_attention: - raise NotImplementedError("Flash attention has already been replaced in mixtral.") - - return policy - - def postprocess(self): - return self.model - - def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "MixtralModel": - module = self.model + input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) + # compute expert output + output_states = MoeInGradScaler.apply(output_states, self.ep_size) + if output_states.size(0) > 0: + if self.num_experts_per_ep == 1: + # no need to split + expert = self.experts[self.expert_start_idx] + output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states) + output_states = expert.w2(output_states) else: - module = self.model.model - - layers_per_stage = stage_manager.distribute_layers(len(module.layers)) - stage_index = stage_manager.get_stage_index(layers_per_stage) - method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=model_cls - ) - - return - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - assert self.pipeline_stage_manager is not None - - if self.model.__class__.__name__ == "MixtralModel": - module = self.model - else: - module = self.model.model - stage_manager = self.pipeline_stage_manager - - held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.layers)) - if stage_manager.is_first_stage(): - 
held_layers.append(module.embed_tokens) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.norm) - - return held_layers - - -class MixtralModelPolicy(MixtralPolicy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - policy = super().module_policy() - if self.pipeline_stage_manager: - # set None as default - self.set_pipeline_forward( - model_cls=MixtralModel, - new_forward=MixtralPipelineForwards.mixtral_model_forward, - policy=policy, - ) - return policy - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - held_layers = super().get_held_layers() - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in llama model""" - return [] - - -class MixtralForCausalLMPolicy(MixtralPolicy): - def module_policy(self): - policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm - new_item = { - MixtralForCausalLM: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True), - ) - ] - ) - } - policy.update(new_item) - - if self.pipeline_stage_manager: - # set None as default - self.set_pipeline_forward( - model_cls=MixtralForCausalLM, - new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward, - policy=policy, - ) - - return policy - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - stage_manager = self.pipeline_stage_manager - held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - llama_model = self.model.model - if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: - if ( - id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) - and self.pipeline_stage_manager.num_stages > 1 - ): - # tie weights - return [ - { - 0: llama_model.embed_tokens.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, - } - ] - return [] + output_states_splits = output_states.split(output_split_sizes.tolist()) + output_states_list = [] + for i, split_states in enumerate(output_states_splits): + if split_states.size(0) == 0: + continue + expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] + split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states) + split_states = expert.w2(split_states) + output_states_list.append(split_states) + output_states = torch.cat(output_states_list) + output_states = MoeOutGradScaler.apply(output_states, self.ep_size) + dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + recover_experts_idx = torch.empty_like(selected_experts_idx) + recover_experts_idx[selected_experts_idx] = torch.arange( + selected_experts_idx.size(0), device=selected_experts_idx.device + ) + dispatch_states = dispatch_states[recover_experts_idx] + k_hidden_states = dispatch_states.chunk(self.top_k) + output_states = k_hidden_states[0] * routing_weights[:, 0, None] + for i in range(1, self.top_k): + output_states += k_hidden_states[i] * routing_weights[:, i, None] + output_states = output_states.reshape(batch_size, sequence_length, hidden_dim) + 
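The EP forward above sorts token copies by their target expert, derives the all-to-all split sizes from a bincount grouped by expert-parallel rank, and later inverts the argsort to restore the original token order. A single-process sketch of just that bookkeeping (no torch.distributed; ep_size and the expert choices are made up):

```
import torch

num_experts, ep_size = 4, 2
experts_per_ep = num_experts // ep_size
selected_experts = torch.tensor([1, 3, 0, 2, 2, 2])      # flattened top-k expert choices
order = selected_experts.argsort()                        # group token copies by expert
input_split_sizes = selected_experts.bincount(minlength=num_experts)
per_rank_send = input_split_sizes.view(ep_size, experts_per_ep).sum(dim=-1)
print(per_rank_send)                                      # tensor([2, 4]) -> all_to_all send sizes

recover = torch.empty_like(order)
recover[order] = torch.arange(order.numel())              # inverse permutation of the argsort
assert torch.equal(selected_experts[order][recover], selected_experts)
```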
return output_states, router_logits class MixtralPipelineForwards: @@ -332,7 +218,7 @@ class MixtralPipelineForwards: # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage - if self._use_flash_attention_2: + if is_flash_attn_2_available(): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 99b68aee2..bf139c840 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -176,6 +176,7 @@ _POLICY_LIST = { "transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation( file_name="falcon", class_name="FalconForQuestionAnsweringPolicy" ), + # mistral "transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation( file_name="mistral", class_name="MistralModelPolicy" ), @@ -185,6 +186,13 @@ _POLICY_LIST = { "transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation( file_name="mistral", class_name="MistralForSequenceClassificationPolicy" ), + # mixtral + "transformers.models.mixtral.modeling_mixtral.MixtralModel": PolicyLocation( + file_name="mixtral", class_name="MixtralModelPolicy" + ), + "transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation( + file_name="mixtral", class_name="MixtralForCausalLMPolicy" + ), # Qwen2 "transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation( file_name="qwen2", class_name="Qwen2ModelPolicy" @@ -195,7 +203,7 @@ _POLICY_LIST = { "transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation( file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy" ), - # Command-R + # command "transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation( file_name="command", class_name="CommandModelPolicy" ), diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py new file mode 100644 index 000000000..f9721c79e --- /dev/null +++ b/colossalai/shardformer/policies/mixtral.py @@ -0,0 +1,210 @@ +from functools import partial +from typing import Callable, Dict, List, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module +from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel + +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"] + + +class MixtralPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + policy = {} + + if self.shard_config.enable_sequence_parallelism: + 
self.shard_config.enable_sequence_parallelism = False + raise NotImplementedError( + "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + ) + + if self.shard_config.enable_tensor_parallelism: + raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") + if getattr(self.shard_config, "ep_group", None) is None: + raise ValueError("You must pass in ep_group via shard_config for expert parallel!") + + # expert parallel + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="block_sparse_moe", + target_module=EPMixtralSparseMoeBlock, + kwargs={"ep_group": self.shard_config.ep_group}, + ) + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=MixtralModel, + ) + + if self.shard_config.enable_flash_attention: + raise NotImplementedError("Flash attention has already been replaced in mixtral.") + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "MixtralModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "MixtralModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class MixtralModelPolicy(MixtralPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MixtralModel, + new_forward=MixtralPipelineForwards.mixtral_model_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline 
layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class MixtralForCausalLMPolicy(MixtralPolicy): + def module_policy(self): + policy = super().module_policy() + # TODO: assign pg mesh from plugin to all modules + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + MixtralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True), + ) + ] + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MixtralForCausalLM, + new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 453e8d23e..b64300366 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -46,6 +46,7 @@ class ShardConfig: make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) + ep_group: Optional[ProcessGroup] = None # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py index b6843df7a..f52802d47 100644 --- a/colossalai/tensor/moe_tensor/api.py +++ b/colossalai/tensor/moe_tensor/api.py @@ -17,10 +17,10 @@ def is_moe_tensor(tensor: torch.Tensor) -> bool: Returns: bool: Whether the given tensor is a moe tensor. """ - return hasattr(tensor, "moe_info") + return hasattr(tensor, "ep_group") -def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None: +def set_moe_tensor_ep_group(tensor: torch.Tensor, ep_group: ProcessGroup) -> None: """ Set moe info for the given tensor. @@ -29,7 +29,7 @@ def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None moe_info (dict): The moe info to be set. """ - tensor.__setattr__("moe_info", moe_info) + tensor.__setattr__("ep_group", ep_group) def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo: @@ -58,7 +58,7 @@ def get_ep_group(tensor: torch.Tensor) -> ProcessGroup: Returns: torch.distributed.ProcessGroup: The expert parallel group of the given tensor. 
""" - return tensor.moe_info.ep_group + return tensor.ep_group def get_ep_size(tensor: torch.Tensor) -> int: @@ -71,7 +71,8 @@ def get_ep_size(tensor: torch.Tensor) -> int: Returns: int: The expert parallel size of the given tensor. """ - return tensor.moe_info.ep_size + assert getattr(tensor, "ep_group") is not None, "The tensor does not have expert parallel group." + return dist.get_world_size(tensor.ep_group) def get_dp_size(tensor: torch.Tensor) -> int: diff --git a/colossalai/zero/low_level/bookkeeping/__init__.py b/colossalai/zero/low_level/bookkeeping/__init__.py index 427973772..07f6cdb2d 100644 --- a/colossalai/zero/low_level/bookkeeping/__init__.py +++ b/colossalai/zero/low_level/bookkeeping/__init__.py @@ -1,6 +1,5 @@ from .bucket_store import BucketStore from .gradient_store import GradientStore -from .parameter_store import ParameterStore from .tensor_bucket import TensorBucket -__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"] +__all__ = ["GradientStore", "BucketStore", "TensorBucket"] diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 1496603fa..19d20de2b 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -1,12 +1,11 @@ -from typing import Dict, Optional +from typing import Dict import torch -import torch.distributed as dist from torch import Tensor from torch._utils import _flatten_dense_tensors from torch.distributed import ProcessGroup -from colossalai.accelerator import get_accelerator +from colossalai.accelerator.api import get_accelerator from .base_store import BaseStore @@ -16,29 +15,11 @@ class BucketStore(BaseStore): self, torch_pg: ProcessGroup, reduce_bucket_size: int, - overlap_communication: bool, - communication_dtype: Optional[torch.dtype] = None, - moe_extra_dp_process_group: ProcessGroup = None, ): super().__init__(torch_pg) self.reduce_bucket_size = reduce_bucket_size - # communication params - self._overlap_communication = overlap_communication - self._communication_dtype = communication_dtype - if self._overlap_communication: - self.comm_stream = get_accelerator().Stream() - self.zero_local_rank = dist.get_rank(group=self.torch_pg) - self.zero_world_size = dist.get_world_size(group=self.torch_pg) - # extra dp - # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size. - # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg. - # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step. - # And moe working and master param are split by extra dp pg. 
- self.moe_extra_dp_pg = moe_extra_dp_process_group - if self.moe_extra_dp_pg is not None: - self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg) - self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg) self.reset_all() + self.comm_stream = get_accelerator().Stream() def reset_all(self) -> None: # init diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index fc28b7795..e24a67f9d 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from torch import Tensor @@ -6,7 +6,7 @@ from .base_store import BaseStore class GradientStore(BaseStore): - def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True): + def __init__(self, *args, partition_grad: bool = False): super().__init__(*args) """ self._grads_of_params mapping the parameter and its gradient slices @@ -20,8 +20,6 @@ class GradientStore(BaseStore): self._grads_of_params = dict() # stage 2 self._partition_grads = partition_grad - # grad accumulation - self.require_grad_sync = require_grad_sync self._working_index = 0 if partition_grad else self._local_rank # for zero2, it's `param_id: [grad_local_rank]` self.grad_to_param_mapping = dict() @@ -107,8 +105,7 @@ class GradientStore(BaseStore): for group in self._grads_of_params.values(): if param_id in group.keys(): return group[param_id][self._working_index] - - raise KeyError(f"Working gradient for param_id {param_id} not found.") + return None def reset_grads_by_group_id(self, group_id: int): self._grads_of_params[group_id] = dict() @@ -116,7 +113,7 @@ class GradientStore(BaseStore): def reset_all_gradients(self): self._grads_of_params = dict() - def get_param_id_for_grad(self, grad: Tensor) -> int: + def get_param_id_for_grad(self, grad: Tensor) -> Optional[int]: """Return the id of a parameter which the gradient slice belongs to Args: @@ -126,4 +123,4 @@ class GradientStore(BaseStore): int: the id of a parameter which the gradient slice belongs to """ - return self.grad_to_param_mapping[id(grad)] + return self.grad_to_param_mapping.get(id(grad), None) diff --git a/colossalai/zero/low_level/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py deleted file mode 100644 index c03231f5f..000000000 --- a/colossalai/zero/low_level/bookkeeping/parameter_store.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import Dict - -from torch import Tensor -from torch.distributed import ProcessGroup - -from .base_store import BaseStore - - -class ParameterStore(BaseStore): - def __init__(self, torch_pg: ProcessGroup): - super().__init__(torch_pg) - - # record the padding size of each param - self._padding_map = dict() - - # mapping working param and master param - self.master_to_working_param = dict() - self.working_to_master_param = dict() - - def record_param_padding_size(self, param: Tensor, padding_size: int): - """Record the padding size of a param - - Args: - param (Tensor): The parameter - padding_size (int): The padding size of the parameter - """ - - self._padding_map[id(param)] = padding_size - - def get_param_padding_size(self, param: Tensor) -> int: - """Return the padding size of the parameter - - Args: - param (Tensor): The parameter - - Returns: - int: the padding size of the parameter - """ - - return self._padding_map[id(param)] - - def 
link_master_and_working_param(self, master_param: Tensor, working_param: Tensor): - """Mapping master parameter and working parameter - - Args: - master_param (Tensor): The parameter copy in optimizer - working_param (Tensor): The parameter of the model - """ - - self.master_to_working_param[id(master_param)] = working_param - self.working_to_master_param[id(working_param)] = master_param - - def get_padding_map(self) -> Dict[int, Tensor]: - """Return the padding map - - Returns: - Dict[int, Tensor]: The padding map - """ - - return self._padding_map diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index d19e0a002..e06cf0581 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -3,12 +3,12 @@ import copy from contextlib import contextmanager from functools import partial from typing import Dict, Iterator, List, Optional, Tuple +from weakref import proxy import torch import torch.distributed as dist import torch.nn as nn from torch import Tensor, inf -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup from torch.optim import Optimizer @@ -20,17 +20,16 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import ( ) from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.tensor.moe_tensor.api import is_moe_tensor -from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor -from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket +from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor +from .bookkeeping import BucketStore, GradientStore, TensorBucket class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): def __init__( self, num_working_param_groups: int, - grad_store: GradientStore, + pg_to_grad_store: Dict[ProcessGroup, GradientStore], initial_scale: float = 2**16, min_scale: float = 1, growth_factor: float = 2, @@ -49,13 +48,14 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): max_scale, ) self.num_working_param_groups = num_working_param_groups - self.grad_store = grad_store + self.pg_to_grad_store = pg_to_grad_store def check_local_overflow(self) -> bool: - for group_id in range(self.num_working_param_groups): - for avg_grad in self.grad_store.get_working_grads_by_group_id(group_id): - if avg_grad is not None and has_inf_or_nan(avg_grad): - return True + for store in self.pg_to_grad_store.values(): + for group_id in range(self.num_working_param_groups): + for avg_grad in store.get_working_grads_by_group_id(group_id): + if avg_grad is not None and has_inf_or_nan(avg_grad): + return True return False @@ -65,6 +65,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def __init__( self, optimizer: Optimizer, + pg_to_param_list: Dict[ProcessGroup, List[nn.Parameter]] = None, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2.0, @@ -79,9 +80,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper): overlap_communication: bool = False, partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload - dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + dp_process_group: Optional[ProcessGroup] = None, forced_dtype: Optional[torch.dtype] = None, - moe_extra_dp_process_group: Optional[ProcessGroup] = None, master_weights: bool = True, 
# master weights ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) @@ -90,12 +90,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self._logger = get_dist_logger() self._verbose = verbose + if dp_process_group is not None and pg_to_param_list is not None: + raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.") + + if pg_to_param_list is None: + unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group + pg_to_param_list = {unique_dp_group: []} + for group in self.optim.param_groups: + pg_to_param_list[unique_dp_group].extend(group["params"]) + + self.pg_to_param_list = pg_to_param_list + param_to_pg = {} + for grp, param_list in pg_to_param_list.items(): + for p in param_list: + assert isinstance(p, nn.Parameter), f"got {type(p)}" + param_to_pg[p] = grp + self.param_to_pg = param_to_pg + + # stage 2 + self._partition_grads = partition_grad + self._cpu_offload = cpu_offload + # grad accumulation + self.require_grad_sync = True + # working and master params for mixed precision training self._working_param_groups = dict() self._master_param_groups_of_current_rank = dict() + # communication params + self._overlap_communication = overlap_communication + self._reduce_bucket_size = reduce_bucket_size + self._communication_dtype = communication_dtype + # gradient clipping self._clip_grad_norm = clip_grad_norm @@ -114,17 +142,27 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # ParameterStore will manage the tensor buffers used for zero # it will not manage the tensors used by mixed precision training - self._param_store = ParameterStore(dp_process_group) - self._grad_store = GradientStore(dp_process_group, partition_grad=partition_grad, require_grad_sync=True) - self._bucket_store = BucketStore( - dp_process_group, reduce_bucket_size, overlap_communication, communication_dtype, moe_extra_dp_process_group - ) - # moe param should not be stored in working_groups - # because they have different parallel strategy - # so we need to store them separately in param_groups - # instead of working_groups - self.working_moe_params = list() + # record the padding size of each param + self._padding_map = dict() + + # mapping working param and master param + self.master_to_working_param = dict() + self.working_to_master_param = dict() + + # NOTE need to gurantee the order of process group is the same accross all ranks + # process_group <---> xxx_store + # process_group <---> [param1 param2 ...] 
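        # (editor's illustration, not part of the patch — assumes a two-group
        #  setup with a global dp_group and a smaller moe_dp_group for experts)
        #   pg_to_param_list  = {dp_group: [w0, w1], moe_dp_group: [w_expert]}
        #   param_to_pg       = {w0: dp_group, w1: dp_group, w_expert: moe_dp_group}
        #   pg_to_grad_store  = {dp_group: GradientStore(dp_group),
        #                        moe_dp_group: GradientStore(moe_dp_group)}
        #   pid_to_grad_store = {id(w0): pg_to_grad_store[dp_group],
        #                        id(w1): pg_to_grad_store[dp_group],
        #                        id(w_expert): pg_to_grad_store[moe_dp_group]}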
+ # each process group have its own stores + # param belonging to one process_group will use corresponding store + self.pg_to_grad_store = { + pg: GradientStore(pg, partition_grad=self._partition_grads) for pg in self.pg_to_param_list + } + # param id to grad store, have to use id(param) as key since it is used in stores + self.pid_to_grad_store = {id(param): self.pg_to_grad_store[param_to_pg[param]] for param in param_to_pg} + self.pg_to_bucket_store = {pg: BucketStore(pg, reduce_bucket_size) for pg in self.pg_to_param_list} + # param id to bucket store, have to use id(param) as key since it is used in stores + self.pid_to_bucket_store = {id(param): self.pg_to_bucket_store[param_to_pg[param]] for param in param_to_pg} # iterate over the param group in the optimizer # partition these param groups for data parallel training @@ -133,11 +171,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper): group_params = list() for param in param_group["params"]: if param.requires_grad: - if self._bucket_store.moe_extra_dp_pg is None: - # skip moe param - if is_moe_tensor(param): - self.working_moe_params.append(param) - continue group_params.append(param) # add the working params to working_param_groups for bookkeeping @@ -151,29 +184,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # managed by this data parallel rank param_group["params"] = master_param_current_rank - # if there are moe params, store in addtional group in optim - if len(self.working_moe_params) > 0: - self._sync_master_param = False - param_group = dict() - # create fp32 master param - for key, value in self.optim.param_groups[0].items(): - if key != "params": - param_group[key] = value - self.master_moe_params = [] - for param in self.working_moe_params: - self.master_moe_params.append(param.clone().to(torch.float32).detach()) - # create mapping from master to working for optimizer io - self.moe_master_to_working_map = {} - for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): - self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param - # add to optim - param_group["params"] = self.master_moe_params - self.optim.param_groups.append(param_group) - # reduction hook is only used if overlapping communication # or stage 2 is used # if it is stage 1 without overlapping, no hook will be attached - if self._bucket_store._overlap_communication or self._grad_store._partition_grads: + self.grad_handles = [] + if self._overlap_communication or self._partition_grads: self._attach_reduction_hook() # initialize mixed precision mixin @@ -181,7 +196,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if self._dtype is torch.float16: self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin( self.num_param_groups, - self._grad_store, + self.pg_to_grad_store, initial_scale=initial_scale, min_scale=min_scale, growth_factor=growth_factor, @@ -194,7 +209,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self.mixed_precision_mixin = BF16MixedPrecisionMixin() def __del__(self): - self.remove_hooks() + for hook in self.grad_handles: + hook.remove() @property def dtype(self): @@ -221,9 +237,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for param in param_list: padding_size = ( - self._bucket_store.zero_world_size - param.numel() % self._bucket_store.zero_world_size - ) % self._bucket_store.zero_world_size - self._param_store.record_param_padding_size(param, padding_size) + self.pid_to_bucket_store[id(param)].world_size + - param.numel() % 
self.pid_to_bucket_store[id(param)].world_size + ) % self.pid_to_bucket_store[id(param)].world_size + self.record_param_padding_size(param, padding_size) with torch.no_grad(): if padding_size > 0: @@ -234,14 +251,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): else: padding_param = param.data.view(-1) - if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(param): - splited_params = padding_param.split( - padding_param.numel() // self._bucket_store.moe_extra_dp_pg_size - ) - splited_params = splited_params[self._bucket_store.moe_extra_dp_pg_rank] - else: - splited_params = padding_param.split(padding_param.numel() // self._bucket_store.zero_world_size) - splited_params = splited_params[self._bucket_store.zero_local_rank] + splited_params = padding_param.split( + padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size + ) + splited_params = splited_params[self.pid_to_bucket_store[id(param)].local_rank] # use fp32 when master_weights is True if self._master_weights is True: @@ -249,9 +262,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper): else: splited_param_current_rank = splited_params - # Send the splited view to the optimizer to match ZeRO 2 grad shape params_current_rank.append(splited_param_current_rank) - self._param_store.link_master_and_working_param(splited_param_current_rank, param) + self.link_master_and_working_param(splited_param_current_rank, param) return params_current_rank @@ -259,93 +271,45 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # Backward Reduction Hook # ########################### - @staticmethod - def grad_handler( - param: nn.Parameter, - group_id: int, - bucket_store: BucketStore, - param_store: ParameterStore, - grad_store: GradientStore, - ): - # if run with no_sync context, would not sync grad when backward - if grad_store.require_grad_sync: - LowLevelZeroOptimizer.add_to_bucket(param, group_id, bucket_store, param_store, grad_store) - def _attach_reduction_hook(self): # we iterate over the working params # on each param, we register a hook to its AccumulateGrad object + self_weakref = proxy(self) + + def _grad_handler(param, group_id): + # if run with no_sync context, would not sync grad when backward + if self_weakref.require_grad_sync: + self_weakref._add_to_bucket(param, group_id) + for group_id in range(self.num_param_groups): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad: - param._grad_handle = param.register_post_accumulate_grad_hook( - partial( - LowLevelZeroOptimizer.grad_handler, - group_id=group_id, - bucket_store=self._bucket_store, - param_store=self._param_store, - grad_store=self._grad_store, - ) + self.grad_handles.append( + param.register_post_accumulate_grad_hook(partial(_grad_handler, group_id=group_id)) ) ####################### # Reduction Functions # ####################### - @staticmethod - def run_reduction(bucket_store: BucketStore, grad_store: GradientStore): - if bucket_store.num_elements_in_bucket() > 0: + + def _run_reduction(self): + for bucket_store in self.pg_to_bucket_store.values(): + if bucket_store.num_elements_in_bucket() <= 0: + continue + bucket_store.build_grad_in_bucket() - if bucket_store.moe_extra_dp_pg is None: - flat_grads = bucket_store.get_flatten_grad() - flat_grads /= bucket_store.zero_world_size - else: - # record moe and non moe param - moe_list = [] - for param in bucket_store._param_list: - moe_list.append(is_moe_tensor(param)) - # divide them into different groups - moe_grad_list = [] - non_moe_grad_list = [] - 
for grad_list in bucket_store._grad_in_bucket.values(): - non_moe_cur_grad = [] - moe_cur_grad = [] - for i in range(len(grad_list)): - if moe_list[i] == True: - moe_cur_grad.append(grad_list[i]) - else: - non_moe_cur_grad.append(grad_list[i]) - if len(moe_cur_grad) > 0: - moe_grad_list.append(moe_cur_grad) - if len(non_moe_cur_grad) > 0: - non_moe_grad_list.append(non_moe_cur_grad) - - if len(non_moe_grad_list) > 0: - non_moe_flat_grads = [] - for grad_list in non_moe_grad_list: - non_moe_flat_grads.append(_flatten_dense_tensors(grad_list)) - non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads) - non_moe_flat_grads /= bucket_store.zero_world_size - - if len(moe_grad_list) > 0: - moe_flat_grads = [] - for grad_list in moe_grad_list: - moe_flat_grads.append(_flatten_dense_tensors(grad_list)) - moe_flat_grads = _flatten_dense_tensors(moe_flat_grads) + flat_grads = bucket_store.get_flatten_grad() + flat_grads /= bucket_store.world_size # ready to add other tensors to bucket bucket_store.reset_num_elements_in_bucket() - if bucket_store._overlap_communication: + if self._overlap_communication: stream = bucket_store.comm_stream # in case of the memory being reused in the default stream - if bucket_store.moe_extra_dp_pg is None: - flat_grads.record_stream(stream) - else: - if len(non_moe_grad_list) > 0: - non_moe_flat_grads.record_stream(stream) - if len(moe_grad_list) > 0: - moe_flat_grads.record_stream(stream) + flat_grads.record_stream(stream) # waiting for ops in the default stream finishing stream.wait_stream(get_accelerator().current_stream()) else: @@ -354,126 +318,43 @@ class LowLevelZeroOptimizer(OptimizerWrapper): with get_accelerator().stream(stream): group_id = bucket_store.current_group_id - if bucket_store.moe_extra_dp_pg is None: - grad_dtype = flat_grads.dtype - if bucket_store._communication_dtype is not None: - flat_grads = flat_grads.to(bucket_store._communication_dtype) + grad_dtype = flat_grads.dtype + if self._communication_dtype is not None: + flat_grads = flat_grads.to(self._communication_dtype) - if not grad_store._partition_grads: - if bucket_store.moe_extra_dp_pg is None: - dist.all_reduce(flat_grads, group=bucket_store.torch_pg) - if flat_grads.dtype != grad_dtype: - flat_grads = flat_grads.to(grad_dtype) - - flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.zero_world_size) - grad_in_bucket = bucket_store.get_grad() - LowLevelZeroOptimizer.update_unpartitoned_grad( - bucket_store, grad_store, grad_in_bucket.values(), flat_grads_per_rank, group_id - ) - - # sync extra zero group - else: - # sync non moe param in global dp group - if len(non_moe_grad_list) > 0: - dist.all_reduce(non_moe_flat_grads, group=bucket_store.torch_pg) - flat_grads_per_rank = non_moe_flat_grads.split( - non_moe_flat_grads.numel() // bucket_store.zero_world_size - ) - LowLevelZeroOptimizer.update_unpartitoned_grad( - bucket_store, grad_store, non_moe_grad_list, flat_grads_per_rank, group_id - ) - - # sync moe param only in zero group - if len(moe_grad_list) > 0: - dist.all_reduce(moe_flat_grads, group=bucket_store.moe_extra_dp_pg) - flat_grads_per_rank = moe_flat_grads.split( - moe_flat_grads.numel() // bucket_store.zero_world_size - ) - LowLevelZeroOptimizer.update_unpartitoned_grad( - bucket_store, grad_store, moe_grad_list, flat_grads_per_rank, group_id - ) + if not self._partition_grads: + dist.all_reduce(flat_grads, group=bucket_store.torch_pg) + if flat_grads.dtype != grad_dtype: + flat_grads = flat_grads.to(grad_dtype) + flat_grads_per_rank = 
flat_grads.split(flat_grads.numel() // bucket_store.world_size) + grad_in_bucket = bucket_store.get_grad() + self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id) else: - if bucket_store.moe_extra_dp_pg is None: - flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.zero_world_size)) - received_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) + flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg) - if received_grad.dtype != grad_dtype: - received_grad = received_grad.to(grad_dtype) + if recieved_grad.dtype != grad_dtype: + recieved_grad = recieved_grad.to(grad_dtype) - grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank] - LowLevelZeroOptimizer.update_partitoned_grad( - bucket_store, grad_store, grad_in_bucket_current_rank, received_grad, group_id, 1 - ) - else: - # categorize moe and non moe param - grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank] - moe_grad_in_bucket_current_rank = [] - non_moe_grad_in_bucket_current_rank = [] - for idx, grad in enumerate(grad_in_bucket_current_rank): - if moe_list[idx] == True: - moe_grad_in_bucket_current_rank.append(grad) - else: - non_moe_grad_in_bucket_current_rank.append(grad) - - if len(non_moe_grad_list) > 0: - flat_grads_list = list( - non_moe_flat_grads.split(len(non_moe_flat_grads) // bucket_store.zero_world_size) - ) - received_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) - LowLevelZeroOptimizer.update_partitoned_grad( - bucket_store, - grad_store, - non_moe_grad_in_bucket_current_rank, - received_grad, - group_id, - 1, - ) - - if len(moe_grad_list) > 0: - flat_grads_list = list( - moe_flat_grads.split(len(moe_flat_grads) // bucket_store.moe_extra_dp_pg_size) - ) - received_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter( - received_grad, - flat_grads_list, - group=bucket_store.moe_extra_dp_pg, - ) - param_slice = bucket_store.zero_world_size // bucket_store.moe_extra_dp_pg_size - received_grad = list(received_grad.split(len(received_grad) // param_slice)) - for split_recieved_grad in received_grad: - split_recieved_grad = _unflatten_dense_tensors( - split_recieved_grad, moe_grad_in_bucket_current_rank - ) - for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank): - param_id = bucket_store.get_param_id_of_grad(grad) - LowLevelZeroOptimizer.add_grad( - grad_store, real_grad, param_slice, group_id, param_id - ) + grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.local_rank] + self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1) bucket_store.reset() - @staticmethod - def update_unpartitoned_grad( - bucket_store: BucketStore, - grad_store: GradientStore, - origin_grad_list: List, - flat_grad_list: List, - group_id: int, + def _update_unpartitoned_grad( + self, bucket_store: BucketStore, origin_grad_list: List, flat_grad_list: List, group_id: int ) -> None: for rank, grad_list in enumerate(origin_grad_list): sync_tensor(flat_grad_list[rank], grad_list) for grad in grad_list: param_id = bucket_store.get_param_id_of_grad(grad) - LowLevelZeroOptimizer.add_grad(grad_store, grad, 
bucket_store.zero_world_size, group_id, param_id, rank) + self._add_grad(grad, bucket_store.world_size, group_id, param_id, rank) - @staticmethod - def update_partitoned_grad( + def _update_partitoned_grad( + self, bucket_store: BucketStore, - grad_store: GradientStore, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, @@ -482,30 +363,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper): sync_tensor(flat_grad, origin_grad_list) for grad in origin_grad_list: param_id = bucket_store.get_param_id_of_grad(grad) - LowLevelZeroOptimizer.add_grad(grad_store, grad, partition_num, group_id, param_id) + self._add_grad(grad, partition_num, group_id, param_id) - @staticmethod - def add_grad( - grad_store: GradientStore, + def _add_grad( + self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0, ) -> None: - if len(grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: - grad_store.append_gradients_by_param_id(grad, group_id, param_id) + if ( + len(self.pid_to_grad_store[param_id].get_partitioned_gradients_by_param_id(group_id, param_id)) + < partition_num + ): + self.pid_to_grad_store[param_id].append_gradients_by_param_id(grad, group_id, param_id) else: - grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + self.pid_to_grad_store[param_id].add_gradients_by_param_id(grad, rank, group_id, param_id) - @staticmethod - def add_to_bucket( - param: nn.Parameter, - group_id: int, - bucket_store: BucketStore, - param_store: ParameterStore, - grad_store: GradientStore, - ): + def _add_to_bucket(self, param, group_id): param_size = param.numel() # check if the bucket is full @@ -513,13 +389,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # or got a grad of param from another group # after reduction, the bucket will be empty if ( - bucket_store.num_elements_in_bucket() + param_size > bucket_store.reduce_bucket_size - or group_id != bucket_store.current_group_id + self.pid_to_bucket_store[id(param)].num_elements_in_bucket() + param_size > self._reduce_bucket_size + or group_id != self.pid_to_bucket_store[id(param)].current_group_id ): - LowLevelZeroOptimizer.run_reduction(bucket_store, grad_store) + self._run_reduction() - padding_size = param_store.get_param_padding_size(param) - bucket_store.add_param_grad(group_id, param, padding_size) + padding_size = self.get_param_padding_size(param) + self.pid_to_bucket_store[id(param)].add_param_grad(group_id, param, padding_size) ################################ # torch.optim.Optimizer methods @@ -527,7 +403,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def backward(self, loss, retain_graph=False): assert not ( - self._grad_store._partition_grads and not self._grad_store.require_grad_sync + self._partition_grads and not self.require_grad_sync ), "ZeRO2(partition_grads) and no_sync are not compatible" if self.mixed_precision_mixin is not None: @@ -535,34 +411,39 @@ class LowLevelZeroOptimizer(OptimizerWrapper): loss.backward(retain_graph=retain_graph) - if not self._grad_store.require_grad_sync: + if not self.require_grad_sync: return - self._reduce_grad(self._grad_store._partition_grads) + self._reduce_grad(self._partition_grads) # clear reduced grads - if self._bucket_store._overlap_communication: + if self._overlap_communication: get_accelerator().synchronize() - self.zero_grad() def backward_by_grad(self, tensor, grad): assert not ( - self._grad_store._partition_grads and not self._grad_store.require_grad_sync + self._partition_grads and not 
self.require_grad_sync ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) torch.autograd.backward(tensor, grad) - if not self._grad_store.require_grad_sync: + if not self.require_grad_sync: return - self._reduce_grad(self._grad_store._partition_grads) + self._reduce_grad(self._partition_grads) # clear reduced grads - if self._bucket_store._overlap_communication: + if self._overlap_communication: get_accelerator().synchronize() - self.zero_grad() + def zero_bucket_stores(self): + for bucket_store in self.pg_to_bucket_store.values(): + bucket_store.reset_all() + + def zero_grad_stores(self): + for grad_store in self.pg_to_grad_store.values(): + grad_store.reset_all_gradients() def zero_grad(self, set_to_none=True): """ @@ -582,7 +463,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if param.grad is not None: param.grad.detach() param.grad.zero_() - self._bucket_store.reset_all() + self.zero_grad_stores() + self.zero_bucket_stores() #################### # Update Parameter # @@ -590,11 +472,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def step(self, closure=None): assert closure is None, "closure is not supported by step()" - if not self._grad_store.require_grad_sync: + if not self.require_grad_sync: return if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): - self._grad_store.reset_all_gradients() if self._verbose: self._logger.info(f"Found overflow. Skip step") self.zero_grad() @@ -609,71 +490,41 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # and should not be updated real_working_params = dict() real_master_params = dict() - grad_index = 0 if self._grad_store._partition_grads else self._bucket_store.zero_local_rank + for group_id in range(self.num_param_groups): master_params = self._master_param_groups_of_current_rank[group_id] + working_params = self._working_param_groups[group_id] real_working_params[group_id] = [] real_master_params[group_id] = [] - for splited_param in master_params: - working_param = self._param_store.master_to_working_param[id(splited_param)] + working_grads = [] + for working_param, master_param in zip(working_params, master_params): # if a working param requires grad and has no grad # it is not 'really' working, e.g. 
the droped layer # else the splited grad should be attached to the splited param - grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) + grad_store = self.pid_to_grad_store[id(working_param)] + grads = grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) + grad_index = 0 if self._partition_grads else grad_store.local_rank if len(grads) > 0: - # moe hybrid zero - if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param): - real_working_params[group_id].append(working_param) - if self._grad_store._partition_grads: - grad = grads - else: - param_slice = self._bucket_store.zero_world_size // self._bucket_store.moe_extra_dp_pg_size - grad = grads[ - self._bucket_store.moe_extra_dp_pg_rank - * param_slice : (self._bucket_store.moe_extra_dp_pg_rank + 1) - * param_slice - ] - grad = flatten(grad) - else: - real_working_params[group_id].append(working_param) - grad = grads[grad_index] + real_working_params[group_id].append(working_param) + grad = grads[grad_index] # no need to copy fp32 grad if master_weights is False if self._master_weights: - grad = grad.to(splited_param.dtype).to(splited_param.device) - splited_param.grad = grad + grad = grad.to(master_param.dtype).to(master_param.device) + master_param.grad = grad grad_partition_groups.append(grad) - real_master_params[group_id].append(splited_param) + real_master_params[group_id].append(master_param) # compute norm - working_grads = self._grad_store.get_working_grads_by_group_id(group_id) - norm_group = self._compute_grad_norm(gradients=working_grads) - norm_groups.append(norm_group) + norm_group = 0 + for grad_store in self.pg_to_grad_store.values(): + working_grads = grad_store.get_working_grads_by_group_id(group_id) + norm_group += self._compute_grad_norm(dp_pg=grad_store.torch_pg, gradients=working_grads) - self._grad_store.reset_grads_by_group_id(group_id) + norm_groups.append(norm_group) # update the params in the optimizer self.optim.param_groups[group_id]["params"] = real_master_params[group_id] - # update param for moe ep - # move grad to master param and compute norm - if len(self.working_moe_params) > 0: - moe_grads = [] - for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): - if master_moe_param.grad is not None: - raise RuntimeError("Moe param should not have grad here") - grad = working_moe_param.grad - # no need to copy fp32 grad if master_weights is False - if self._master_weights: - grad = grad.to(master_moe_param.dtype).to(master_moe_param.device) - master_moe_param.grad = grad - working_moe_param.grad = None - moe_grads.append(grad) - grad_partition_groups.append(grad) - norm_group = self._compute_grad_norm(gradients=moe_grads) - norm_groups.append(norm_group) - self.optim.param_groups[-1]["params"] = self.master_moe_params - del moe_grads - # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) self._unscale_and_clip_grads(grad_partition_groups, global_norm) @@ -681,48 +532,34 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # update the parameters self.optim.step() - # release moe grad - if len(self.working_moe_params) > 0: - for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): - master_moe_param.grad = None - working_moe_param.data = ( - master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach() - ) - # release the grad grad_partition_groups = [] for group_id in 
range(self.num_param_groups): release_param_grad(self._master_param_groups_of_current_rank[group_id]) - tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size) - moe_tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size) + self.pg_to_tensor_bucket = { + pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list + } # update working partition updated by the current rank device = get_accelerator().get_current_device() for group_id in range(self.num_param_groups): master_working_param = self.optim.param_groups[group_id]["params"] - for idx, splited_param in enumerate(master_working_param): + for idx, master_param in enumerate(master_working_param): working_param = real_working_params[group_id][idx] - param_to_gather = splited_param.to(device).to(self._dtype) - if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param): - try: - moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param) - except RuntimeError: - moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg) - moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param) - else: - try: - tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param) - except RuntimeError: - tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg) - tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param) + param_to_gather = master_param.to(device).to(self._dtype) + pg = self.param_to_pg[working_param] + try: + self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) + except RuntimeError: + self.pg_to_tensor_bucket[pg].all_gather(pg) + self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] - if not moe_tensor_bucket.is_empty(): - moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg) - if not tensor_bucket.is_empty(): - tensor_bucket.all_gather(self._bucket_store.torch_pg) + for pg, tensor_bucket in self.pg_to_tensor_bucket.items(): + if not tensor_bucket.is_empty(): + tensor_bucket.all_gather(pg) - def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: + def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float: r""" Compute and return the gradient norm for gradient clipping. 
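# --- editor's sketch (annotation, not part of the patch) -------------------
# How the refactor above is expected to be driven: instead of passing a
# dedicated moe_extra_dp_process_group, the caller partitions parameters by
# their reduction group via `pg_to_param_list`; the optimizer then keeps one
# GradientStore/BucketStore per group and sums the per-group norms returned
# by `_compute_grad_norm(dp_pg=...)`. `moe_dp_group` below is an assumed,
# caller-provided group (e.g. set up by MoeHybridParallelPlugin), and expert
# params are assumed to have been tagged via set_moe_tensor_ep_group during
# sharding.
import torch.distributed as dist
from torch.optim import Optimizer

from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer


def wrap_with_zero(optim: Optimizer, moe_dp_group) -> LowLevelZeroOptimizer:
    pg_to_param_list = {dist.group.WORLD: [], moe_dp_group: []}
    for group in optim.param_groups:
        for p in group["params"]:
            # expert params (is_moe_tensor == True) reduce over the moe dp
            # group, everything else over the global dp group
            pg = moe_dp_group if is_moe_tensor(p) else dist.group.WORLD
            pg_to_param_list[pg].append(p)
    # `dp_process_group` and `pg_to_param_list` are mutually exclusive arguments
    return LowLevelZeroOptimizer(optim, pg_to_param_list=pg_to_param_list, partition_grad=False)
# ----------------------------------------------------------------------------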
@@ -745,7 +582,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): device=get_accelerator().get_current_device(), dtype=torch.float, ) - dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self._bucket_store.torch_pg) + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg) total_norm = total_norm_cuda.item() else: @@ -763,7 +600,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): torch.distributed.all_reduce( total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, - group=self._bucket_store.torch_pg, + group=dp_pg, ) total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) @@ -798,33 +635,27 @@ class LowLevelZeroOptimizer(OptimizerWrapper): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad and param.grad is not None: - LowLevelZeroOptimizer.add_to_bucket( - param, - group_id, - self._bucket_store, - self._param_store, - self._grad_store, - ) + self._add_to_bucket(param, group_id) - LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store) + self._run_reduction() def _reduce_grad(self, partition_grad): # if not overlapping communication (no reduction hook is attached) when zero1 # we need to manually reduce these gradients - if not partition_grad and not self._bucket_store._overlap_communication: + if not partition_grad and not self._overlap_communication: self._sync_grad() else: - LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store) + self._run_reduction() # this context comes from pytorch DDP @contextmanager def no_sync(self): - old_require_grad_sync = self._grad_store.require_grad_sync - self._grad_store.require_grad_sync = False + old_require_grad_sync = self.require_grad_sync + self.require_grad_sync = False try: yield finally: - self._grad_store.require_grad_sync = old_require_grad_sync + self.require_grad_sync = old_require_grad_sync ############## # State Dict # @@ -863,19 +694,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): zero_state[param] = copy.deepcopy(state) for k, v in state.items(): if isinstance(v, torch.Tensor) and k != "step": - working_param = self._param_store.master_to_working_param[id(param)] - if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v): - gather_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) - for _ in range(self._bucket_store.moe_extra_dp_pg_size) - ] - dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg) - else: - gather_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) - for _ in range(self._bucket_store.zero_world_size) - ] - dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.torch_pg) + working_param = self.master_to_working_param[id(param)] + pg = self.param_to_pg[working_param] + gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] + dist.all_gather(gather_tensor, v.to(device), group=pg) param_state = ( torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() ) @@ -892,26 +714,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper): state_dict (dict): A pytorch form state_dict """ zero_state_dict = copy.deepcopy(state_dict) + idx2master = {} + cnt = 0 + for param_group in self.optim.param_groups: + for param in param_group["params"]: + idx2master[cnt] = param + cnt += 1 for param_idx, state in zero_state_dict["state"].items(): + pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]] for k, v 
in state.items(): if isinstance(v, torch.Tensor) and k != "step": - padding_size = ( - self._bucket_store.zero_world_size - v.numel() % self._bucket_store.zero_world_size - ) % self._bucket_store.zero_world_size + padding_size = (pg.size() - v.numel() % pg.size()) % pg.size() with torch.no_grad(): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) - if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v): - v_list = v.split(v.numel() // self._bucket_store.moe_extra_dp_pg_size) - zero_state_dict["state"][param_idx][k] = ( - v_list[self._bucket_store.moe_extra_dp_pg_rank].detach().clone() - ) - else: - v_list = v.split(v.numel() // self._bucket_store.zero_world_size) - zero_state_dict["state"][param_idx][k] = ( - v_list[self._bucket_store.zero_local_rank].detach().clone() - ) + v_list = v.split(v.numel() // pg.size()) + zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone() self.optim.load_state_dict(zero_state_dict) @@ -930,31 +749,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper): device = get_accelerator().get_current_device() local_states = self.optim.state_dict()["state"] + + idx2master = {} + cnt = 0 + for param_group in self.optim.param_groups: + for param in param_group["params"]: + idx2master[cnt] = param + cnt += 1 for param_idx, states in local_states.items(): current_block_size = 0 current_block = copy.deepcopy(states) - # find the working param of current param_id - for group_id, pg in self._master_param_groups_of_current_rank.items(): - if (group_id + 1) * len(pg) < param_idx: - continue - master_param = pg[param_idx - (group_id) * len(pg)] - working_param = self._param_store.master_to_working_param[id(master_param)] + master_param = idx2master[param_idx] + working_param = self.master_to_working_param[id(master_param)] + pg = self.param_to_pg[working_param] for k, v in states.items(): if isinstance(v, torch.Tensor) and k != "step": - if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v): - state_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) - for _ in range(self._bucket_store.moe_extra_dp_pg_size) - ] - dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg) - else: - state_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) - for _ in range(self._bucket_store.zero_world_size) - ] - dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.torch_pg) + state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] + dist.all_gather(state_tensor, v.to(device), group=pg) state_tensor = ( torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() ) @@ -979,46 +792,96 @@ class LowLevelZeroOptimizer(OptimizerWrapper): """ for p in model.parameters(): p_id = id(p) - if p_id in self._param_store.working_to_master_param: - master_param = self._param_store.working_to_master_param[p_id] - padding_size = self._param_store.get_param_padding_size(p) + pg = self.param_to_pg[p] + if p_id in self.working_to_master_param: + master_param = self.working_to_master_param[p_id] + padding_size = self.get_param_padding_size(p) working_param = p.data.view(-1) if padding_size > 0: working_param = torch.nn.functional.pad(working_param, [0, padding_size]) - if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(p): - master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank]) - else: - master_param.copy_( - 
working_param.chunk(self._bucket_store.zero_world_size)[self._bucket_store.zero_local_rank] - ) - if hasattr(self, "master_moe_params"): - for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): - master_moe_param.copy_(working_moe_param) - - def remove_hooks(self) -> None: - """remove the registered hooks - - Args: - plugin (LowLevelZeroPlugin): the plugin to bound this method. - """ - for group_id in range(self.num_param_groups): - param_group = self._working_param_groups[group_id] - for param in param_group: - if param.requires_grad: - assert hasattr(param, "_grad_handle") - param._grad_handle.remove() - delattr(param, "_grad_handle") + master_param.copy_(working_param.chunk(pg.size())[pg.rank()]) def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: - return self._param_store.working_to_master_param + return self.working_to_master_param def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: - if hasattr(self, "moe_master_to_working_map"): - return { - **self._param_store.master_to_working_param, - **self.moe_master_to_working_map, - } - return self._param_store.master_to_working_param + return self.master_to_working_param def get_param_padding_map(self) -> Dict[int, torch.Tensor]: - return self._param_store.get_padding_map() + return self._padding_map + + def record_param_padding_size(self, param: Tensor, padding_size: int): + """Record the padding size of a param + + Args: + param (Tensor): The parameter + padding_size (int): The padding size of the parameter + """ + + self._padding_map[id(param)] = padding_size + + def get_param_padding_size(self, param: Tensor) -> int: + """Return the padding size of the parameter + + Args: + param (Tensor): The parameter + + Returns: + int: the padding size of the parameter + """ + + return self._padding_map[id(param)] + + def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor): + """Mapping master parameter and working parameter + + Args: + master_param (Tensor): The parameter copy in optimizer + working_param (Tensor): The parameter of the model + """ + + self.master_to_working_param[id(master_param)] = working_param + self.working_to_master_param[id(working_param)] = master_param + + def get_padding_map(self) -> Dict[int, Tensor]: + """Return the padding map + + Returns: + Dict[int, Tensor]: The padding map + """ + + return self._padding_map + + def get_param_grad(self, working_param: nn.Parameter) -> Tensor: + grad_store = self.pid_to_grad_store[id(working_param)] + partial_grad = grad_store.get_working_grad_by_param_id(id(working_param)) + if partial_grad is None: + return None + tensor_list = [torch.empty_like(partial_grad) for _ in range(grad_store.world_size)] + dist.all_gather(tensor_list, partial_grad, group=grad_store.torch_pg) + grad_flat = torch.cat(tensor_list, dim=0) + return grad_flat[: working_param.numel()].reshape_as(working_param) + + def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]: + working_grads = [] + for grad_store in self.pg_to_grad_store.values(): + working_grads.extend(grad_store.get_working_grads_by_group_id(group_id)) + return working_grads + + def get_param_id_for_grad(self, grad: Tensor) -> int: + param_id = None + for grad_store in self.pg_to_grad_store.values(): + id_maybe_none = grad_store.get_param_id_for_grad(grad) + if id_maybe_none is not None: + if param_id is not None: + raise ValueError("The grad mapping is not unique") + param_id = id_maybe_none + return param_id + + def 
get_working_grad_by_param_id(self, param_id: int) -> Tensor: + grad_store = self.pid_to_grad_store[param_id] + return grad_store.get_working_grad_by_param_id(param_id) + + def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: + grad_store = self.pid_to_grad_store[param_id] + return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id) diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 22e0c790b..b9ef915c3 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -176,7 +176,7 @@ def main(): use_ep_inside = False plugin = MoeHybridParallelPlugin( pp_size=1, - extra_dp_size=args.extra_dp_size, + ep_size=args.ep_size, use_ep_inside=use_ep_inside, **hybrid_dict, ) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 5a9e30dd4..1febacd7d 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -50,9 +50,9 @@ try: except: HAS_FLASH_ATTN = False from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON -from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_activation, set_moe_args +from colossalai.shardformer.layer.moe import SparseMLP if HAS_TRITON: from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine @@ -83,7 +83,7 @@ def set_openmoe_args( load_balance_group_swap_factor: float = 0.4, enable_kernel: bool = False, enable_comm_overlap: bool = False, - enable_hierarchical_alltoall: bool = False, + enable_hierarchical_alltoall: bool = True, ) -> None: """ MoE related arguments. @@ -465,7 +465,7 @@ class OpenMoeDecoderLayer(nn.Module): load_balance_beam_width=config.load_balance_beam_width, load_balance_group_swap_factor=config.load_balance_group_swap_factor, enable_kernel=config.enable_kernel, - enable_comm_overlap=config.enable_comm_overlap, + enable_hierarchical_comm=config.enable_hierarchical_alltoall, ) self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.extra_mlp = OpenMoeMLP(config) @@ -903,7 +903,7 @@ class OpenMoeForCausalLM(OpenMoePreTrainedModel): "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" # reset moe loss - MOE_MANAGER.reset_loss() + MOE_MANAGER.reset_loss() # TODO: remove output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1027,7 +1027,7 @@ class OpenMoeForCausalLM(OpenMoePreTrainedModel): def _calculate_router_loss(self, aux_loss: list = None, z_loss: list = None): if aux_loss is None or z_loss is None: - aux_loss, z_loss = MOE_MANAGER.get_loss() + aux_loss, z_loss = MOE_MANAGER.get_loss() # TODO: remove assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss) z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss) diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index 8ef07bdb9..f46062128 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -172,6 +172,7 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy): if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm + # TODO: recursively assign ep group foe all modules new_item = { OpenMoeForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 960c83adb..9ea232478 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -1,37 +1,37 @@ -pip install -r requirements.txt +# pip install -r requirements.txt # inference -python infer.py --model "test" +# python infer.py --model "test" # train -torchrun --standalone --nproc_per_node 4 train.py \ - --num_epoch 1 \ - --model_name "test" \ - --plugin "ep" \ - --batch_size 1 +# torchrun --standalone --nproc_per_node 4 train.py \ +# --num_epoch 1 \ +# --model_name "test" \ +# --plugin "ep" \ +# --batch_size 1 -torchrun --standalone --nproc_per_node 4 train.py \ - --num_epoch 1 \ - --model_name "test" \ - --plugin "ep_zero" \ - --batch_size 1 \ - --zero_stage 1 \ - --extra_dp_size 2 \ +# torchrun --standalone --nproc_per_node 4 train.py \ +# --num_epoch 1 \ +# --model_name "test" \ +# --plugin "ep_zero" \ +# --batch_size 1 \ +# --zero_stage 1 \ +# --extra_dp_size 2 \ -torchrun --standalone --nproc_per_node 4 train.py \ - --num_epoch 1 \ - --model_name "test" \ - --plugin "ep_zero" \ - --batch_size 1 \ - --zero_stage 2 \ - --extra_dp_size 2 \ +# torchrun --standalone --nproc_per_node 4 train.py \ +# --num_epoch 1 \ +# --model_name "test" \ +# --plugin "ep_zero" \ +# --batch_size 1 \ +# --zero_stage 2 \ +# --extra_dp_size 2 \ -torchrun --standalone --nproc_per_node 4 train.py \ - --model_name "test" \ - --plugin "hybrid" \ - --num_epoch 1 \ - --pp_size 2 \ - --dp_size 1 \ - --ep_size 2 \ - --zero_stage 1 \ - --batch_size 1 +# torchrun --standalone --nproc_per_node 4 train.py \ +# --model_name "test" \ +# --plugin "hybrid" \ +# --num_epoch 1 \ +# --pp_size 2 \ +# --dp_size 1 \ +# --ep_size 2 \ +# --zero_stage 1 \ +# --batch_size 1 diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 40f072f13..ff0e4bad6 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -19,10 +19,9 @@ from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator -from 
colossalai.moe.layers import apply_load_balance -from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam +from colossalai.shardformer.layer.moe import apply_load_balance def move_to_cuda(batch, device): @@ -221,48 +220,49 @@ def main(): "precision": args.precision, "zero_stage": args.zero_stage, } - mgr_dict = {} if args.plugin == "ep": dp_size = dist.get_world_size() plugin = MoeHybridParallelPlugin( pp_size=1, + ep_size=args.ep_size, **hybrid_dict, ) - MOE_MANAGER.setup( - parallel="EP", - max_ep_size=dp_size, - **mgr_dict, - ) + # MOE_MANAGER.setup( + # parallel="EP", + # max_ep_size=dp_size, + # **mgr_dict, + # ) elif args.plugin == "ep_zero": dp_size = dist.get_world_size() use_ep_inside = False plugin = MoeHybridParallelPlugin( pp_size=1, - extra_dp_size=args.extra_dp_size, + ep_size=dp_size // args.ep_size, use_ep_inside=use_ep_inside, **hybrid_dict, ) - MOE_MANAGER.setup( - parallel="EP", - max_ep_size=dp_size // args.extra_dp_size, - use_ep_inside=use_ep_inside, - **mgr_dict, - ) + # MOE_MANAGER.setup( + # parallel="EP", + # max_ep_size=dp_size // args.extra_dp_size, + # use_ep_inside=use_ep_inside, + # **mgr_dict, + # ) elif args.plugin == "hybrid": dp_size = dist.get_world_size() // args.pp_size plugin = MoeHybridParallelPlugin( pp_size=args.pp_size, + ep_size=args.ep_size, microbatch_size=args.microbatch_size, **hybrid_dict, ) - MOE_MANAGER.setup( - parallel="EP", - mode="fixed", - fixed_dp_size=args.dp_size, - fixed_ep_size=args.ep_size, - fixed_pp_size=args.pp_size, - **mgr_dict, - ) + # MOE_MANAGER.setup( + # parallel="EP", + # mode="fixed", + # fixed_dp_size=args.dp_size, + # fixed_ep_size=args.ep_size, + # fixed_pp_size=args.pp_size, + # **mgr_dict, + # ) else: raise ValueError(f"Invalid plugin {args.plugin}") coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 24dc4a5d2..ab48944d4 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -59,10 +59,10 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): # check master weight assert isinstance(new_optimizer, LowLevelZeroOptimizer) working_param_id_set = set(id(p) for p in new_model.parameters()) - for p_id, master_param in new_optimizer._param_store.working_to_master_param.items(): + for p_id, master_param in new_optimizer.working_to_master_param.items(): assert p_id in working_param_id_set - working_param = new_optimizer._param_store.master_to_working_param[id(master_param)] - padding = new_optimizer._param_store.get_param_padding_size(working_param) + working_param = new_optimizer.master_to_working_param[id(master_param)] + padding = new_optimizer.get_param_padding_size(working_param) padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding)) working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()] assert torch.equal( @@ -115,10 +115,10 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo # check master weight assert isinstance(new_optimizer, LowLevelZeroOptimizer) working_param_id_set = set(id(p) for p in new_model.parameters()) - for p_id, master_param in new_optimizer._param_store.working_to_master_param.items(): + for p_id, master_param in new_optimizer.working_to_master_param.items(): 
assert p_id in working_param_id_set - working_param = new_optimizer._param_store.master_to_working_param[id(master_param)] - padding = new_optimizer._param_store.get_param_padding_size(working_param) + working_param = new_optimizer.master_to_working_param[id(master_param)] + padding = new_optimizer.get_param_padding_size(working_param) padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding)) working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()] assert torch.equal( diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 17b790e3e..131932dcb 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -1,48 +1,37 @@ import torch import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup from torch.testing import assert_close from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce from colossalai.legacy.registry import GRADIENT_HANDLER -from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_moe_epsize_param_dict -from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size + +# from colossalai.shardformer.layer.moe import SparseMLP +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group def delete_moe_info(model): for _, param in model.named_parameters(): - if hasattr(param, "moe_info"): - delattr(param, "moe_info") + if hasattr(param, "ep_group"): + delattr(param, "ep_group") class MoeModel(nn.Module): - def __init__(self, enable_load_balance: bool = False): - class TestSubModule(nn.Module): - def __init__(self): - super().__init__() - self.moe = SparseMLP( - num_experts=8, hidden_size=16, intermediate_size=32, enable_load_balance=enable_load_balance - ) - self.proj = nn.Linear(16, 4) - - def forward(self, x): - x = self.moe(x) - x = self.proj(x) - return x - + def __init__(self, ep_group: ProcessGroup = None): super().__init__() - self.test_embed = nn.Linear(4, 16) - self.test_transform = TestSubModule() + self.test_embed = nn.Linear(4, 16, bias=False) + self.w1 = torch.nn.Parameter(torch.randn(16, 8)) + if ep_group: + set_moe_tensor_ep_group(self.w1, ep_group) def forward(self, x): - MOE_MANAGER.reset_loss() - x = self.test_embed(x) - x = self.test_transform(x) + x = torch.matmul(x, self.w1) return x @@ -116,7 +105,7 @@ def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False) return y -def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: +def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None: """Sync the parameters of tp model from ep model Args: @@ -126,7 +115,6 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_ for (local_name, local_param), (ep_name, ep_param) in zip( local_model.named_parameters(), ep_model.named_parameters() ): - assert local_name in ep_name, print(f"{local_name} != {ep_name}") if "experts" not in local_name: if assert_grad_flag: assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}" diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index a88f5f9cc..25e61b091 100644 --- a/tests/test_moe/test_grad_handler.py +++ 
b/tests/test_moe/test_grad_handler.py @@ -5,8 +5,9 @@ import torch.nn as nn import colossalai from colossalai.accelerator import get_accelerator -from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER + +# from colossalai.shardformer.layer.moe.layers import SparseMLP from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from tests.test_moe.moe_utils import MoeGradientHandler @@ -69,6 +70,7 @@ def run_test(rank, world_size, port): # MoE grad handler test passed +@pytest.mark.skip(reason="moe need to be refactored") @pytest.mark.dist @rerun_if_address_is_in_use() def test_grad_handler(): diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 30122d31a..28e6db441 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -1,98 +1,96 @@ +import os + import pytest import torch -import torch.distributed as dist -import colossalai from colossalai.accelerator import get_accelerator -from colossalai.moe import SparseMLP -from colossalai.moe.manager import MOE_MANAGER -from colossalai.testing import rerun_if_address_is_in_use, spawn -BATCH_SIZE = 4 +# from colossalai.moe import SparseMLP +from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum + NUM_EXPERTS = 4 +BATCH_SIZE = 4 +SEQ_LEN = 4 + +MOE_TENSOR_PATH = os.getenv("MOE_TENSOR_PATH") def check_equal(tensor_a, tensor_b, atol=1e-06): assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True -def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, topk=1): - # Here we do not need TF32, since it brings absolute error on results - torch.backends.cuda.matmul.allow_tf32 = False +def run_moe_cumsum(): + test_mask = torch.tensor( + [ + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 1, 0, 0], + [1, 0, 0, 0], + ], + dtype=torch.int32, + ).to("cuda") + out_no_kernel = moe_cumsum(test_mask, use_kernel=False) + out_kernel = moe_cumsum(test_mask, use_kernel=True) + print(out_no_kernel.dtype, out_kernel.dtype) + check_equal(out_no_kernel.to(torch.int32), out_kernel) - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - local_rank = dist.get_rank() - MOE_MANAGER.setup(parallel="EP") # MOE environment initialization - MOE_MANAGER.reset_loss() - torch.manual_seed(rs + local_rank) # set each process has different random seed - - # get randomized data +def run_moe_dispatch_combine_fwd_bwd(data_type=torch.float32, hidden_size=128, num_experts=4): tokens = torch.randn( BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True ) - layer = SparseMLP( - hidden_size=hidden_size, - intermediate_size=hidden_size * 2, - num_experts=NUM_EXPERTS, - router_top_k=topk, - router_capacity_factor_train=1.0, - ) - layer = layer.to(get_accelerator().get_current_device()) - if data_type == torch.float16: - layer = layer.half() + # use kernel + route_result_list_kernel = torch.load(f"{MOE_TENSOR_PATH}/True_4_{data_type}.pt") + # dispatch + dispatch_data_kernel = MoeDispatch.apply(tokens, *route_result_list_kernel[1:]) + dispatch_data_kernel = dispatch_data_kernel.reshape(num_experts, -1, hidden_size) + # combine + expert_output = dispatch_data_kernel.reshape(-1, hidden_size) + ans_kernel = MoeCombine.apply(expert_output, *route_result_list_kernel) - # use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine - layer.enable_kernel = False - old_out = layer(tokens) - ech = old_out.shape - grad = 
torch.randn(ech, device=get_accelerator().get_current_device()) - old_out.backward(grad) # get gradient + # no kernel + route_result_list_no_kernel = torch.load(f"{MOE_TENSOR_PATH}/False_2_{data_type}.pt") + # dispatch + sec_mask_f = route_result_list_no_kernel[1].type_as(tokens) + dispatch_data_no_kernel = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) + # combine + combine_weights = route_result_list_no_kernel[0].type_as(tokens) + combine_weights = combine_weights.view(combine_weights.shape[0], -1) + expert_output = expert_output.view(-1, expert_output.shape[-1]) + ans_no_kernel = torch.matmul(combine_weights, expert_output) - # save all results - o_tk_grad = tokens.grad.data.clone() - o_gt_grad = layer.gate_weight.grad.data.clone() + # check fwd + if data_type == torch.float32: + check_equal(dispatch_data_kernel.reshape(dispatch_data_no_kernel.shape), dispatch_data_no_kernel) + else: + check_equal(dispatch_data_kernel.reshape(dispatch_data_no_kernel.shape), dispatch_data_no_kernel, 1e-2) - # reset all gradients + if data_type == torch.float32: + check_equal(ans_kernel, ans_no_kernel) + else: + check_equal(ans_kernel, ans_no_kernel, 1e-2) + + # check bwd + out_shape = ans_kernel.shape + grad = torch.randn(out_shape, device=get_accelerator().get_current_device()) + + ans_kernel.backward(grad, retain_graph=True) + grad_kernel = tokens.grad.data.clone() tokens.grad.zero_() - layer.gate_weight.grad.zero_() - layer.enable_kernel = True - new_out = layer(tokens) # get outputs through colossal kernel + ans_no_kernel.backward(grad) # get gradient + grad_no_kernel = tokens.grad.data.clone() + tokens.grad.zero_() if data_type == torch.float32: - check_equal(old_out, new_out) + check_equal(grad_no_kernel, grad_kernel) else: - check_equal(old_out, new_out, 1e-2) - # forward function passed - - new_out.backward(grad) # get new type gradient - n_tk_grad = tokens.grad.data.clone() - n_gt_grad = layer.gate_weight.grad.data.clone() - - if data_type == torch.float32: - check_equal(o_tk_grad, n_tk_grad) - else: - check_equal(o_tk_grad, o_tk_grad, 1e-2) - # tokens gradient is correct - - if data_type == torch.float32: - check_equal(o_gt_grad, n_gt_grad, 5e-05) - else: - check_equal(o_gt_grad, n_gt_grad, 2e-01) - # bias gradient is correct + check_equal(grad_no_kernel, grad_kernel, 1e-2) -@pytest.mark.dist -@pytest.mark.parametrize("rs", [131]) -@pytest.mark.parametrize("hidden_size", [32, 144]) @pytest.mark.parametrize("data_type", [torch.float32, torch.float16]) -@pytest.mark.parametrize("topk", [1, 2]) -@rerun_if_address_is_in_use() -def test_moe_kernel(rs, hidden_size, data_type, topk): - spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk) - - -if __name__ == "__main__": - test_moe_kernel(2, 256, torch.float16, 2) +def test_moe_kernel(data_type): + torch.manual_seed(1024) + run_moe_cumsum() + run_moe_dispatch_combine_fwd_bwd(data_type=data_type) diff --git a/applications/ColossalMoE/tests/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py similarity index 81% rename from applications/ColossalMoE/tests/test_mixtral_layer.py rename to tests/test_moe/test_mixtral_layer.py index cbb70f195..b7b0322e0 100644 --- a/applications/ColossalMoE/tests/test_mixtral_layer.py +++ b/tests/test_moe/test_mixtral_layer.py @@ -3,13 +3,13 @@ from copy import deepcopy import pytest import torch import torch.distributed as dist -from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock from torch.testing import assert_close from 
transformers.models.mixtral.configuration_mixtral import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import colossalai -from colossalai.moe import MOE_MANAGER +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock from colossalai.testing.utils import spawn tokens, n_experts = 7, 4 @@ -19,8 +19,11 @@ top_k = 2 def check_mixtral_moe_layer(): torch.cuda.set_device(dist.get_rank()) - MOE_MANAGER.setup( - parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1 + plugin = MoeHybridParallelPlugin( + precision="bf16", + tp_size=1, + pp_size=1, + ep_size=dist.get_world_size(), ) config = MixtralConfig( hidden_size=hidden_size, @@ -33,7 +36,7 @@ def check_mixtral_moe_layer(): x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda() orig_output, orig_logits = orig_model(x) model = deepcopy(orig_model) - model = EPMixtralSparseMoeBlock.from_native_module(model) + model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group) ep_output, ep_logits = model(x) assert_close(orig_logits, ep_logits) assert_close(orig_output, ep_output) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 10e63592a..249dd4b97 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -1,201 +1,176 @@ -import importlib import os -import shutil -import sys +import tempfile +from contextlib import nullcontext +from copy import deepcopy import pytest import torch import torch.distributed as dist -from transformers.models.llama import LlamaConfig +from torch.optim import Adam +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM import colossalai -from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn +from colossalai.checkpoint_io import MoECheckpointIO +from colossalai.tensor.moe_tensor.api import is_moe_tensor +from colossalai.testing.utils import spawn -sys.path.append( - os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(__file__))), - "examples/language/openmoe", - ) -) - -OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM -set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args -OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 -def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20): - input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_accelerator().get_current_device()) - attention_mask = torch.ones_like(input_ids) +def check_model_equal(model1, model2): + assert set(model1.state_dict().keys()) == set(model2.state_dict().keys()) + for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())): + if not torch.equal(p1.half(), p2.half()): + print(f"Model parameter {name} is not equal. 
is_moe_tensor: {is_moe_tensor(p1)}") + raise AssertionError(f"Model parameter {name} is not equal") + + +def get_optimizer_snapshot(optim): + state = {id(k): deepcopy(v) for k, v in optim.state.items()} + param_groups = [] + for group in optim.param_groups: + params = [id(p) for p in group["params"]] + new_group = {"params": params} + for k, v in group.items(): + if k != "params": + new_group[k] = v + param_groups.append(new_group) return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "labels": input_ids, + "state": state, + "param_groups": param_groups, } -def run_fwd_bwd( - model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None -): - model.train() - if pipeline: - train_dataloader_iter = DummyDataloader(data_gen_fn, length=1) - is_pp_last_stage = booster.plugin.stage_manager.is_last_stage() - y = booster.execute_pipeline( - train_dataloader_iter, - model, - lambda x, y: x.loss, - optimizer, - return_loss=True, - ) - # Backward and optimize - if is_pp_last_stage: - loss = y["loss"] - else: - if criterion: - y = model(data).logits - loss = criterion(y) +def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_group=None): + assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"]) + for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]): + assert set(group1.keys()) == set(group2.keys()) + for k in group1.keys(): + assert group1[k] == group2[k] + # check state + assert set(snapshot1["state"].keys()) == set( + snapshot2["state"].keys() + ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}" + + passed = True + count = 0 + for pid in snapshot1["state"].keys(): + state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid] + assert set(state1.keys()) == set(state2.keys()) + bug = False + for k in state1.keys(): + if isinstance(state1[k], torch.Tensor): + if not torch.equal(state1[k], state2[k]): + bug = True + count += 1 + else: + assert state1[k] == state2[k] + if bug: + passed = False + + if not passed: + raise AssertionError(f"A total of {count} optim states are not equal") + + +def check_mixtral_moe_layer(): + context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext() + with context as f: + torch.cuda.set_device(dist.get_rank()) + if dist.get_rank() == 0: + broadcast_objects = [f] # any picklable object else: - loss = model(data, label) - loss = loss.float() + broadcast_objects = [None] + dist.broadcast_object_list(broadcast_objects, src=0) - if optimizer is not None: - optimizer.backward(loss) - else: - loss.backward() - return y - - -def get_config(): - config = LlamaConfig( - vocab_size=300, - hidden_size=16, - intermediate_size=32, - num_hidden_layers=2, - num_attention_heads=2, - head_dim=4, - dropout_rate=0.0, - hidden_act="swiglu", - ) - set_openmoe_args(config, num_experts=8, moe_layer_interval=1) - return config - - -def get_model(parallel): - config = get_config() - model = OpenMoeForCausalLM(config) - optim = torch.optim.Adam(model.parameters()) - - if parallel == None: - plugin = MoeHybridParallelPlugin( - precision="bf16", - tp_size=1, - pp_size=1, - ep_size=1, - zero_stage=2, - custom_policy=OpenMoeForCausalLMPolicy(), + config = MixtralConfig( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + num_attention_heads=2, + num_key_value_heads=2, ) - elif parallel == "ep": + torch.manual_seed(0) + input_ids = torch.randint(0, 100, (2, 
tokens)).cuda() + orig_model = MixtralForCausalLM(config).cuda() + model = deepcopy(orig_model) + optimizer = Adam(model.parameters(), lr=1e-3) plugin = MoeHybridParallelPlugin( - precision="bf16", - tp_size=1, - pp_size=1, - ep_size=dist.get_world_size(), - zero_stage=2, - custom_policy=OpenMoeForCausalLMPolicy(), - ) - elif parallel == "ep_zero": - plugin = MoeHybridParallelPlugin( - precision="bf16", - tp_size=1, - pp_size=1, - ep_size=2, - zero_stage=2, - extra_dp_size=2, - custom_policy=OpenMoeForCausalLMPolicy(), - ) - elif parallel == "hybrid": - plugin = MoeHybridParallelPlugin( - precision="bf16", - tp_size=1, pp_size=2, ep_size=2, - zero_stage=1, + tp_size=1, + checkpoint_io=MoECheckpointIO, microbatch_size=1, - custom_policy=OpenMoeForCausalLMPolicy(), + zero_stage=1, ) - booster = Booster(plugin=plugin) - model, optim, _, _, _ = booster.boost(model=model, optimizer=optim) - return model, booster, optim + booster = Booster(plugin=plugin) + model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer) + # initialize grads + data_iter = iter( + [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}] + ) + booster.execute_pipeline( + data_iter, + model, + lambda outputs, inputs: outputs.loss, + optimizer, + ) + + tmpdirname = broadcast_objects[0] + model_dir = os.path.join(tmpdirname, "mixtral_model") + hf_model_dir = os.path.join(tmpdirname, "mixtral_hf_model") + optim_dir = os.path.join(tmpdirname, "mixtral_optim") + + booster.save_model(model, model_dir, shard=True) + dist.barrier() + if dist.get_rank() == 0: + saved_model = MixtralForCausalLM.from_pretrained(model_dir).cuda() + check_model_equal(orig_model, saved_model) + # check_model_equal(model, saved_model) + saved_model.save_pretrained(hf_model_dir) + dist.barrier() + # check load model + new_model = MixtralForCausalLM(config).cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer) + booster.load_model(new_model, hf_model_dir) + check_model_equal(model, new_model) + + # check save optimizer + optimizer.step() + for group in optimizer.param_groups: + group["lr"] = 0.1 + snapshot = get_optimizer_snapshot(optimizer.unwrap()) + booster.save_optimizer(optimizer, optim_dir, shard=True) + dist.barrier() + + # reset optimizer state + for state in optimizer.unwrap().state.values(): + for v in state.values(): + if isinstance(v, torch.Tensor): + v.zero_() + booster.load_optimizer(optimizer, optim_dir) + loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap()) + check_optimizer_snapshot_equal(snapshot, loaded_snapshot, None, model) + # Ensure rank 0 waits for all other ranks to finish + dist.barrier() -def _test_moe_checkpoint(rank, parallel): - model1, booster1, optim1 = get_model(parallel) - model2, booster2, optim2 = get_model(parallel) - model3, booster3, optim3 = get_model(parallel) - - # param ckpt - # shard - booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1) - booster2.load_model(model2, "./tmp_ckpt1") - # unshard - booster1.save_model(model1, "./tmp_ckpt1.pth") - booster3.load_model(model3, "./tmp_ckpt1.pth") - # check - check_state_dict_equal(model1.state_dict(), model2.state_dict(), False) - check_state_dict_equal(model1.state_dict(), model3.state_dict(), False) - - # optim ckpt - criterion = lambda x: x.mean() - data = torch.randint(0, 4, (2, 4)).cuda() - label = torch.randint(0, 4, (2,)).cuda() - if parallel == "hybrid": - kwargs = {"pipeline": 
True, "booster": booster1, "plugin": booster1.plugin} - else: - kwargs = {} - run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs) - optim1.step() - optim1.zero_grad() - # shard - booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1) - dist.barrier() - booster2.load_optimizer(optim2, "./tmp_ckpt2") - # unshard - booster1.save_optimizer(optim1, "./tmp_ckpt2.pth") - booster3.load_optimizer(optim3, "./tmp_ckpt2.pth") - # check - check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False) - check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False) - - if dist.get_rank() == 0: - shutil.rmtree("./tmp_ckpt1") - shutil.rmtree("./tmp_ckpt2") - os.remove("./tmp_ckpt1.pth") - os.remove("./tmp_ckpt2.pth") +def run_dist(rank: int, world_size: int, port: int): + colossalai.launch(rank, world_size, "localhost", port) + check_mixtral_moe_layer() -def _run_dist(rank, world_size, port, parallel): - colossalai.launch( - config=dict(), - rank=rank, - world_size=world_size, - host="localhost", - port=port, - backend="nccl", - ) - _test_moe_checkpoint(rank, parallel) - - -@pytest.mark.skip(reason="This is tested in ColossalMOE") -@pytest.mark.dist +# Test EP + ZeRO + PP @pytest.mark.parametrize("world_size", [4]) -@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"]) -@rerun_if_address_is_in_use() -def test_moe_checkpoint(world_size, parallel): - spawn(_run_dist, world_size, parallel=parallel) +def test_mixtral_moe_layer(world_size: int): + spawn(run_dist, world_size) if __name__ == "__main__": - test_moe_checkpoint(world_size=4, parallel="hybrid") + test_mixtral_moe_layer(4) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 660fbd358..9bc11033a 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -8,15 +8,16 @@ import torch.distributed as dist import colossalai from colossalai.accelerator import get_accelerator -from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param + +# from colossalai.shardformer.layer import SparseMLP from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from tests.test_moe.moe_utils import MoeGradientHandler -def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_flag: bool = False) -> None: +def sync_tp_from_local(tp_model, local_model, assert_grad_flag: bool = False) -> None: """Sync the parameters of tp model from local model Args: @@ -48,7 +49,7 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_ tp_param.data.copy_(local_param[tuple(tp_slice)].data) -def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: +def sync_tp_from_ep(tp_model, ep_model, assert_grad_flag: bool = False) -> None: """Sync the parameters of tp model from ep model Args: @@ -90,7 +91,7 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: tp_param.data.copy_(new_tp_param.data) -def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: +def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None: """Sync the parameters of tp model from ep model Args: @@ -216,6 +217,7 @@ def run_test(rank: int, world_size: int, port: int, 
num_experts: int, batch_size ) +@pytest.mark.skip(reason="moe need to be refactored") @pytest.mark.dist @pytest.mark.parametrize("num_experts", [4, 64]) @pytest.mark.parametrize("batch_size", [16]) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index b7be54d26..89baf1d37 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -4,9 +4,10 @@ import torch.nn as nn import colossalai from colossalai.accelerator import get_accelerator -from colossalai.moe.experts import MLPExperts from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param + +# from colossalai.shardformer.layer.moe import MLPExperts from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn HIDDEN_SIZE = 4 @@ -69,6 +70,7 @@ def _run_test(rank, world_size, port, expert_parallel): run_moe_init(expert_parallel) +@pytest.mark.skip(reason="moe need to be refactored") @pytest.mark.dist @pytest.mark.parametrize("expert_parallel", ["EP", "TP"]) @rerun_if_address_is_in_use() diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py index 7932fa8a7..513c4ebda 100644 --- a/tests/test_moe/test_moe_hybrid_zero.py +++ b/tests/test_moe/test_moe_hybrid_zero.py @@ -86,6 +86,7 @@ def run_dist(rank, world_size, port): run_zero_optim_test(rank, world_size, stage=2) +@pytest.mark.skip(reason="moe need to be refactored") @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index fae189bac..ddd3ea368 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -6,8 +6,9 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER + +# from colossalai.shardformer.layer.moe import apply_load_balance from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.testing import rerun_if_address_is_in_use, spawn from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel @@ -176,6 +177,7 @@ def run_dist(rank, world_size, port): run_hybrid_zero_optim_test(rank, world_size, stage=2) +@pytest.mark.skip(reason="moe need to be refactored") @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py deleted file mode 100644 index 9f6167692..000000000 --- a/tests/test_moe/test_moe_router.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -import torch - -from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter - - -@pytest.mark.parametrize( - ["router", "num_groups"], - [ - (Top1Router(), 1), - (Top2Router(), 1), - # (TopKRouter(num_selected_experts=3), 4), - ], -) -@pytest.mark.parametrize( - ["batch_size", "seq_len", "num_experts"], - [ - (4, 5, 8), - (3, 4, 4), - ], -) -def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int): - x = torch.randn((batch_size * seq_len, num_experts)).cuda() - if num_groups > 1: - x = x.expand(num_groups, -1, -1) - - router.train() - if isinstance(router, TopKRouter): - combine_array, dispatch_mask = router(x, expert_capacity=2) - else: - combine_array, 
dispatch_mask = router(x)[1:3] - assert combine_array.shape[:-1] == x.shape - assert dispatch_mask.shape[:-1] == x.shape - assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) - - router.eval() - if isinstance(router, TopKRouter): - combine_array, dispatch_mask = router(x, expert_capacity=2) - else: - combine_array, dispatch_mask = router(x)[1:3] - assert combine_array.shape[:-1] == x.shape - assert dispatch_mask.shape[:-1] == x.shape - assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) - - -if __name__ == "__main__": - test_router_forward(Top2Router(), 4, 4, 4, 1) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py deleted file mode 100644 index 3bb08b49e..000000000 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ /dev/null @@ -1,78 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.moe.manager import MOE_MANAGER -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.testing.random import seed_all -from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep - - -def run_zero_test(local_rank, stage=1): - criterion = torch.nn.CrossEntropyLoss() - - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel="EP") - moe_model = MoeModel().bfloat16() - moe_optimizer = torch.optim.Adam(moe_model.parameters()) - moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") - moe_booster = Booster(plugin=moe_plugin) - moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) - - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel=None) - zero_model = MoeModel().bfloat16() - delete_moe_info(zero_model) - zero_optimizer = torch.optim.Adam(zero_model.parameters()) - zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") - zero_booster = Booster(plugin=zero_plugin) - zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) - sync_local_from_ep(zero_model, moe_model) - - data = torch.randn(16, 4).bfloat16().cuda() - label = torch.randint(0, 4, (16,)).cuda() - - zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) - moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) - assert torch.allclose(zero_out, moe_out) - - for (moe_name, moe_param), (zero_name, zero_param) in zip( - moe_model.module.named_parameters(), zero_model.module.named_parameters() - ): - assert moe_name == zero_name - moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param)) - zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) - if hasattr(moe_param, "moe_info"): - assert len(moe_grad_list) == 0 - if stage == 1: - zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape) - else: - zero_grad = zero_grad_list[0].view(moe_param.grad.shape) - assert torch.allclose( - moe_param.grad, zero_grad, atol=1e-5 - ), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}" - else: - assert len(moe_grad_list) > 0 - assert len(moe_grad_list) == len(zero_grad_list) - for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list): - assert torch.allclose(moe_grad, zero_grad) - - -def run_dist(rank, world_size, port, stage): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", 
port=port, backend="nccl") - seed_all(42 + rank) - run_zero_test(rank, stage=stage) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2]) -@pytest.mark.parametrize("stage", [1, 2]) -@rerun_if_address_is_in_use() -def test_moe_zero_model(world_size, stage): - spawn(run_dist, world_size, stage=stage) - - -if __name__ == "__main__": - test_moe_zero_model(world_size=2, stage=1) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py new file mode 100644 index 000000000..042b3d8ae --- /dev/null +++ b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py @@ -0,0 +1,132 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +import colossalai +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock +from colossalai.tensor.moe_tensor.api import is_moe_tensor +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all +from colossalai.zero import LowLevelZeroOptimizer +from tests.test_moe.moe_utils import loose_close + +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 + + +def split_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad + + +@parameterize("dtype", [torch.float16, torch.bfloat16]) +@parameterize("master_weights", [True, False]) +@parameterize("stage", [1, 2]) +def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int): + rank = torch.distributed.get_rank() + torch.cuda.set_device(dist.get_rank()) + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + ep_size=dist.get_world_size() // 2, + ) + + seed_all(10086) + config = MixtralConfig( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + ) + + orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda() + + ori_model = DDP(orig_model.cuda(), static_graph=True).cuda() + + zero_model = deepcopy(orig_model).to(dtype) + zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group) + + zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) + pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []} + for p in zero_model.parameters(): + if is_moe_tensor(p): + pg_param_list[plugin.moe_dp_group].append(p) + else: + pg_param_list[plugin.global_dp_group].append(p) + + zero_optimizer = LowLevelZeroOptimizer( + zero_optimizer, + pg_to_param_list=pg_param_list, + master_weights=master_weights, + initial_scale=1, + overlap_communication=False, + partition_grad=True, + ) + + ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1) + + # create + seed_all(1453 + rank) + + for _ in range(2): + # zero-dp forward + input_data = torch.rand(1, tokens, hidden_size).cuda() + zero_output, zero_logits = zero_model(input_data.to(dtype)) + + # torch-ddp forward + ori_output, ori_logits = ori_model(input_data.to(dtype)) + 
loose_close(zero_output, ori_output, dtype=dtype) + + # zero-dp backward + zero_optimizer.backward(zero_output.mean().float()) + + # torch-ddp backward + ori_output.mean().backward() + + # check grad + name_to_p = {n: p for n, p in ori_model.module.named_parameters()} + for n, p in zero_model.named_parameters(): + zero_grad = zero_optimizer.get_param_grad(p) + if name_to_p[n].grad is None: + assert zero_grad is None + continue + + loose_close(zero_grad, name_to_p[n].grad, dtype=dtype) + + # zero-dp step + zero_optimizer.step() + + # original model step + ori_optimizer.step() + + # check updated param + for n, p in zero_model.named_parameters(): + loose_close(p.data, name_to_p[n].data, dtype=dtype) + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_zero_with_original_model(world_size=world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_moe_zero_model(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_moe_zero_model(world_size=4) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py deleted file mode 100644 index 224c5c3b9..000000000 --- a/tests/test_moe/test_moe_zero_optim.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.moe.manager import MOE_MANAGER -from colossalai.tensor.moe_tensor.api import is_moe_tensor -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.testing.random import seed_all -from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep - - -def run_zero_test(local_rank, stage=1): - criterion = torch.nn.CrossEntropyLoss() - - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel="EP") - moe_model = MoeModel().bfloat16() - moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0) - moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") - moe_booster = Booster(plugin=moe_plugin) - moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) - - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel=None) - zero_model = MoeModel().bfloat16() - delete_moe_info(zero_model) - sync_local_from_ep(zero_model, moe_model) - zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0) - zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") - zero_booster = Booster(plugin=zero_plugin) - zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) - - for (moe_name, moe_param), (zero_name, zero_param) in zip( - moe_model.named_parameters(), zero_model.named_parameters() - ): - if ".experts." 
in moe_name: - continue - assert moe_name == zero_name - assert torch.allclose( - moe_param.data, zero_param.data - ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}" - - for _ in range(1): - data = torch.randn(2, 4).bfloat16().cuda() - label = torch.randint(0, 4, (2,)).cuda() - - moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) - zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) - assert torch.allclose(zero_out, moe_out) - moe_optimizer.step() - zero_optimizer.step() - - for (moe_name, moe_param), (zero_name, zero_param) in zip( - moe_model.named_parameters(), zero_model.named_parameters() - ): - assert moe_name == zero_name - if is_moe_tensor(moe_param): - param_size = moe_param.shape[0] - zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size] - loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype) - - moe_optimizer.zero_grad() - zero_optimizer.zero_grad() - - -def run_dist(rank, world_size, port, stage): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - seed_all(42 + rank) - run_zero_test(rank, stage=stage) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2]) -@pytest.mark.parametrize("stage", [1, 2]) -@rerun_if_address_is_in_use() -def test_moe_zero_optim(world_size, stage): - spawn(run_dist, world_size, stage=stage) - - -if __name__ == "__main__": - test_moe_zero_optim(world_size=2, stage=1) diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py index 313624e83..4046e4118 100644 --- a/tests/test_optimizer/_utils.py +++ b/tests/test_optimizer/_utils.py @@ -234,7 +234,7 @@ def check_dist_grad(sharded_optimizer, org_model, sharded_model, weight_layer_fo if org_name in weight_layer_for_check: org_grad = org_param.grad group_id = dist.get_rank(sharded_optimizer.optim.dp_group) - dist_grad = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(group_id, id(sharded_param)) + dist_grad = sharded_optimizer.get_partitioned_gradients_by_param_id(group_id, id(sharded_param)) # dist_grad concat then reshape to org_grad shape if dist_grad: diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 06c254e56..2da679d7d 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -316,7 +316,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): dp_process_group=dp_group, verbose=True, ) - shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + shard_to_param = dist_optim.master_to_working_param # {id(): param tensor} but flattened dist_optim.optim.setup_distributed( tp_group=tp_group, dp_group=dp_group, diff --git a/tests/test_optimizer/test_dist_came.py b/tests/test_optimizer/test_dist_came.py index c767e9684..45fe687b7 100644 --- a/tests/test_optimizer/test_dist_came.py +++ b/tests/test_optimizer/test_dist_came.py @@ -200,7 +200,7 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): dp_process_group=dp_group, verbose=True, ) - shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + shard_to_param = dist_optim.master_to_working_param # {id(): param tensor} but flattened dist_optim.optim.setup_distributed( tp_group=tp_group, dp_group=dp_group, diff --git a/tests/test_optimizer/test_dist_lamb.py b/tests/test_optimizer/test_dist_lamb.py 
index c1ff78c0c..66e8e49c7 100644 --- a/tests/test_optimizer/test_dist_lamb.py +++ b/tests/test_optimizer/test_dist_lamb.py @@ -229,7 +229,7 @@ def run_dist_lamb_fwd_bwd( dp_process_group=dp_group, verbose=True, ) - shard_to_param = optim._param_store.master_to_working_param + shard_to_param = optim.master_to_working_param optim.optim.setup_distributed(tp_group, dp_group, shard_to_param, is_zero=True) else: optim.setup_distributed(tp_group) diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py index be257e818..e37a050e3 100644 --- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py @@ -32,6 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + dp_group = booster.plugin.dp_group bert = unwrap_model(org_model, "BertModel", "bert") sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") @@ -53,8 +54,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, device = origin_norm.device norm_groups = [] for group_id in range(sharded_optimizer.num_param_groups): - working_grads = sharded_optimizer._grad_store.get_working_grads_by_group_id(group_id) - norm_group = sharded_optimizer._compute_grad_norm(gradients=working_grads) + working_grads = sharded_optimizer.get_working_grads_by_group_id(group_id) + norm_group = sharded_optimizer._compute_grad_norm(dp_group, gradients=working_grads) norm_groups.append(norm_group) total_norm = 0.0 for norm in norm_groups: diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index b73552cec..4d66692a4 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -62,10 +62,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): - working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)] - grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p)) + working_p = sharded_optimizer.master_to_working_param[id(p2)] + grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( - 0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank + 0 if sharded_optimizer._partition_grads else sharded_optimizer.pid_to_bucket_store[id(p2)].local_rank ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 3a8a1357d..8fe18f69b 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -62,10 +62,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): - 
working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)] - grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p)) + working_p = sharded_optimizer.master_to_working_param[id(p2)] + grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( - 0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank + 0 + if sharded_optimizer._partition_grads + else sharded_optimizer.pid_to_bucket_store[id(working_p)].local_rank ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] diff --git a/tests/test_zero/test_low_level/test_mem_leak.py b/tests/test_zero/test_low_level/test_mem_leak.py new file mode 100644 index 000000000..7fa59ccc5 --- /dev/null +++ b/tests/test_zero/test_low_level/test_mem_leak.py @@ -0,0 +1,61 @@ +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.zero import LowLevelZeroOptimizer + + +class MlpModel(nn.Module): + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(123, 253) + + def forward(self, x): + x = self.linear1(x) + return x + + +DEL_CALLED = False + + +class TestLowLevelZeroOptimizer(LowLevelZeroOptimizer): + def __del__(self): + super().__del__() + global DEL_CALLED + DEL_CALLED = True + + +def exam_mem_leak(world_size): + """ + In this test, we test whether del will be called after the optimizer + is out of scope. + """ + # create models + zero_model = MlpModel().cuda() + + # we only test stage 1 here + # in `check_sharded_param_consistency.py`, we will test whether + # level 1 and 2 will produce exactly the same results + zero_optimizer = TestLowLevelZeroOptimizer(torch.optim.SGD(zero_model.parameters(), lr=1)) + + del zero_optimizer + + assert DEL_CALLED + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + + exam_mem_leak(world_size=world_size) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_zero_1_2(): + spawn(run_dist, 2) + + +if __name__ == "__main__": + test_zero_1_2() diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index 06a29bd1d..8df35bdaa 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -91,10 +91,13 @@ def exam_zero_1_2(): zero2_optimizer.backward(zero2_output.mean().float()) # check grad - z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0) - z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0) - for z1g, z2g in zip(z1g_list, z2g_list): - assert torch.equal(z1g, z2g) + for p1, p2 in zip(zero1_model.parameters(), zero2_model.parameters()): + g1 = zero1_optimizer.get_param_grad(p1) + g2 = zero2_optimizer.get_param_grad(p2) + if g1 is None or g2 is None: + assert g1 is None and g2 is None + continue + assert torch.allclose(g1, g2) # step zero1_optimizer.step() @@ -102,7 +105,7 @@ def exam_zero_1_2(): # check updated param for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): - assert torch.equal(z1p.data, z2p.data) + assert torch.allclose(z1p, z2p) @parameterize("dtype", [torch.float16, torch.bfloat16]) @@ -120,7 +123,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): seed_all(1453) # create models - torch_model = MlpModel().cuda() + 
torch_model = MlpModel().cuda().to(dtype) zero_model = copy.deepcopy(torch_model).to(dtype) torch_model = DDP(torch_model.cuda(), static_graph=True).cuda() @@ -142,39 +145,41 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) seed_all(1453 + local_rank) - # create - input_data = torch.rand(32, 123).cuda() - # zero-dp forward - zero_output = zero_model(input_data.to(dtype)) + for _ in range(2): + # create + input_data = torch.rand(32, 123).cuda().to(dtype) - # torch-ddp forward - torch_output = torch_model(input_data) - loose_close(zero_output, torch_output, dtype=dtype) + # zero-dp forward + zero_output = zero_model(input_data) - # zero-dp backward - zero_optimizer.backward(zero_output.mean().float()) + # torch-ddp forward + torch_output = torch_model(input_data) + loose_close(zero_output, torch_output, dtype=dtype) - # torch-ddp backward - torch_output.mean().backward() + # zero-dp backward + zero_optimizer.backward(zero_output.mean()) - # check grad - for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - if p.grad is not None: - zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p)) - torch_grad_list = split_ddp_grad(p.grad, world_size) - for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): - loose_close(zero_grad, torch_grad, dtype=dtype) + # torch-ddp backward + torch_output.mean().backward() - # zero-dp step - zero_optimizer.step() + # check grad + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + zero_grad = zero_optimizer.get_param_grad(z1p) + if p.grad is None: + assert zero_grad is None + continue + loose_close(p.grad, zero_grad, dtype=dtype) - # torch ddp step - torch_optimizer.step() + # zero-dp step + zero_optimizer.step() - # check updated param - for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - loose_close(p.data, z1p.data, dtype=dtype) + # torch ddp step + torch_optimizer.step() + + # check updated param + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + loose_close(p, z1p, dtype=dtype) def run_dist(rank, world_size, port): From 8ab46b4000d36c76cde93c6bb553411e815980fb Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 1 Jul 2024 16:45:09 +0800 Subject: [PATCH 03/15] [Shardformer] change qwen2 modeling into gradient checkpointing style (#5874) --- colossalai/shardformer/modeling/qwen2.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index e0aa5fba4..11c26822f 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -168,13 +168,27 @@ class Qwen2PipelineForwards: next_decoder_cache = None start_idx, end_idx = stage_index[0], stage_index[1] + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_stages=stage_manager.num_stages, + num_layers=end_idx - start_idx, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave 
else 0), + num_model_chunks=stage_manager.num_model_chunks, + ) + assert num_ckpt_layers <= end_idx - start_idx + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: + if idx - start_idx < num_ckpt_layers: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, @@ -198,7 +212,6 @@ class Qwen2PipelineForwards: if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - if output_attentions: all_self_attns += (layer_outputs[1],) From 936d0b0f7ba7f9b4e0d53c343bcf6afd10c63de1 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 1 Jul 2024 17:07:22 +0800 Subject: [PATCH 04/15] [doc] Update llama + sp compatibility; fix dist optim table Co-authored-by: Edenzzzz --- .../en/features/distributed_optimizers.md | 52 +++++++++--------- docs/source/en/features/shardformer.md | 2 +- .../features/distributed_optimizers.md | 53 +++++++++---------- docs/source/zh-Hans/features/shardformer.md | 2 +- 4 files changed, 53 insertions(+), 56 deletions(-) diff --git a/docs/source/en/features/distributed_optimizers.md b/docs/source/en/features/distributed_optimizers.md index f95b23304..279bc8f9d 100644 --- a/docs/source/en/features/distributed_optimizers.md +++ b/docs/source/en/features/distributed_optimizers.md @@ -87,44 +87,42 @@ optim = DistGaloreAwamW( ## Plugin compatibility - - - - - + + + + + + - + - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 68d310f5c..40b8954b5 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -55,7 +55,7 @@ Model/Feature Compatibility Matrix: - + diff --git a/docs/source/zh-Hans/features/distributed_optimizers.md b/docs/source/zh-Hans/features/distributed_optimizers.md index 7a7068077..5761f8c55 100644 --- a/docs/source/zh-Hans/features/distributed_optimizers.md +++ b/docs/source/zh-Hans/features/distributed_optimizers.md @@ -84,44 +84,42 @@ optim = DistGaloreAwamW( ## 兼容性
Model/FeatureLambGaLoreAdafactorCAMEOptimizer/PluginHybrid Parallel PluginLow Level Zero PluginTorch DDP PluginGemini PluginMoe Hybrid Plugin
Hybrid Parallel
Plugin
Lamb ✔️ ✔️ ✔️✔️
Low Level Zero
Plugin
✔️✔️✔️
Torch DDP
Plugin
✔️✔️✔️✔️
Gemini
Plugin
Moe Hybrid
Plugin
GaLore✔️✔️✔️
Adafactor✔️✔️✔️
CAME✔️✔️✔️
✔️ ✔️ ✔️✔️
- - - - - + + + + + + - + - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + @@ -130,6 +128,7 @@ optim = DistGaloreAwamW(
Model/FeatureLambGaLoreAdafactorCAMEOptimizer/PluginHybrid Parallel PluginLow Level Zero PluginTorch DDP PluginGemini PluginMoe Hybrid Plugin
Hybrid Parallel
Plugin
Lamb ✔️ ✔️ ✔️✔️
Low Level Zero
Plugin
✔️✔️✔️
Torch DDP
Plugin
✔️✔️✔️✔️
Gemini
Plugin
Moe Hybrid
Plugin
GaLore✔️✔️✔️
Adafactor✔️✔️✔️
CAME✔️✔️✔️
+ ## API 参考 diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index 00e1a13d6..02290f3d6 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -51,7 +51,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. ✔️ ✔️ ✔️ - ❌ + ✔️ ❌ From 7c2f79fa98c837ee4f5995d7948371040fa94572 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Jul 2024 17:16:41 +0800 Subject: [PATCH 05/15] [pre-commit.ci] pre-commit autoupdate (#5572) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [pre-commit.ci] pre-commit autoupdate updates: - [github.com/PyCQA/autoflake: v2.2.1 → v2.3.1](https://github.com/PyCQA/autoflake/compare/v2.2.1...v2.3.1) - [github.com/pycqa/isort: 5.12.0 → 5.13.2](https://github.com/pycqa/isort/compare/5.12.0...5.13.2) - [github.com/psf/black-pre-commit-mirror: 23.9.1 → 24.4.2](https://github.com/psf/black-pre-commit-mirror/compare/23.9.1...24.4.2) - [github.com/pre-commit/mirrors-clang-format: v13.0.1 → v18.1.7](https://github.com/pre-commit/mirrors-clang-format/compare/v13.0.1...v18.1.7) - [github.com/pre-commit/pre-commit-hooks: v4.3.0 → v4.6.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.3.0...v4.6.0) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 10 +++--- .../ColossalChat/coati/dataset/loader.py | 16 +++++---- .../ColossalChat/coati/models/loss.py | 1 + .../ColossalChat/coati/models/reward_model.py | 1 + .../ColossalChat/coati/trainer/utils.py | 1 + .../colossal_eval/dataset/agieval.py | 14 ++++++-- .../colossal_eval/dataset/ceval.py | 6 +++- .../colossal_eval/dataset/mtbench.py | 8 +++-- .../colossal_eval/models/huggingface.py | 4 ++- .../colossalqa/chain/retrieval_qa/base.py | 1 + .../chain/retrieval_qa/load_chain.py | 1 + .../colossalqa/chain/retrieval_qa/stuff.py | 1 + .../data_loader/table_dataloader.py | 1 - .../ColossalQA/colossalqa/local/llm.py | 1 + .../ColossalQA/colossalqa/local/utils.py | 1 + applications/ColossalQA/colossalqa/memory.py | 1 + .../ColossalQA/colossalqa/mylogging.py | 1 + .../colossalqa/retrieval_conversation_en.py | 1 + .../retrieval_conversation_universal.py | 1 + .../colossalqa/retrieval_conversation_zh.py | 1 + .../ColossalQA/colossalqa/retriever.py | 1 + .../text_splitter/chinese_text_splitter.py | 1 + .../examples/retrieval_conversation_en.py | 1 + ...rieval_conversation_en_customer_service.py | 1 + .../examples/retrieval_conversation_zh.py | 1 + ...tent_classification_zh_customer_service.py | 1 + .../meta_profiler/meta_registry/conv.py | 20 ++++++----- colossalai/inference/batch_bucket.py | 12 +++---- colossalai/inference/config.py | 17 +++++---- colossalai/inference/core/engine.py | 1 - colossalai/inference/core/rpc_engine.py | 1 - colossalai/inference/executor/rpc_worker.py | 1 - .../inference/kv_cache/kvcache_manager.py | 8 +++-- colossalai/inference/utils.py | 1 + .../initializer_2d.py | 4 +-- colossalai/legacy/inference/async_engine.py | 1 - .../inference/dynamic_batching/io_struct.py | 12 +++---- .../inference/hybridengine/modeling/_utils.py | 1 + .../tensor_parallel/batch_infer_state.py | 1 + .../tensor_parallel/kvcache_manager.py | 1 + .../tensor_parallel/modeling/_utils.py | 1 + 
.../modeling/chatglm2_6b/modeling_chatglm.py | 1 + .../nvidia_bert_dataset_provider.py | 8 +++-- .../diffusion/ldm/models/diffusion/ddpm.py | 20 ++++++----- .../models/diffusion/dpm_solver/sampler.py | 1 + .../modules/diffusionmodules/openaimodel.py | 36 ++++++++++--------- .../ldm/modules/midas/midas/midas_net.py | 1 + .../modules/midas/midas/midas_net_custom.py | 1 + .../diffusion/ldm/modules/midas/utils.py | 1 + .../data/datasets/helpers.cpp | 12 +++---- extensions/csrc/kernel/arm/cpu_adam_arm.h | 2 +- extensions/csrc/kernel/x86/cpu_adam.h | 2 +- .../kit/model_zoo/torchvision/torchvision.py | 12 +++---- 53 files changed, 157 insertions(+), 100 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9871e1184..f2c408bce 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,34 +1,34 @@ repos: - repo: https://github.com/PyCQA/autoflake - rev: v2.2.1 + rev: v2.3.1 hooks: - id: autoflake name: autoflake (python) args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports'] - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort name: sort all imports (python) - repo: https://github.com/psf/black-pre-commit-mirror - rev: 23.9.1 + rev: 24.4.2 hooks: - id: black name: black formatter args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310'] - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v13.0.1 + rev: v18.1.7 hooks: - id: clang-format name: clang formatter types_or: [c++, c] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.6.0 hooks: - id: check-yaml - id: check-merge-conflict diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index cea1b2dbb..a0cd17bb4 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -83,15 +83,19 @@ class DataCollatorForSupervisedDataset(object): # `List[torch.Tensor]` batch_input_ids = [ - torch.LongTensor(instance["input_ids"][: self.max_length]) - if len(instance["input_ids"]) > self.max_length - else torch.LongTensor(instance["input_ids"]) + ( + torch.LongTensor(instance["input_ids"][: self.max_length]) + if len(instance["input_ids"]) > self.max_length + else torch.LongTensor(instance["input_ids"]) + ) for instance in instances ] batch_labels = [ - torch.LongTensor(instance["labels"][: self.max_length]) - if len(instance["labels"]) > self.max_length - else torch.LongTensor(instance["labels"]) + ( + torch.LongTensor(instance["labels"][: self.max_length]) + if len(instance["labels"]) > self.max_length + else torch.LongTensor(instance["labels"]) + ) for instance in instances ] if self.tokenizer.padding_side == "right": diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py index aaef447a4..e411dded1 100755 --- a/applications/ColossalChat/coati/models/loss.py +++ b/applications/ColossalChat/coati/models/loss.py @@ -1,6 +1,7 @@ """ loss functions """ + from typing import Optional, Tuple import torch diff --git a/applications/ColossalChat/coati/models/reward_model.py b/applications/ColossalChat/coati/models/reward_model.py index 394f3ea90..573b9d889 100755 --- a/applications/ColossalChat/coati/models/reward_model.py +++ b/applications/ColossalChat/coati/models/reward_model.py @@ -1,6 +1,7 @@ """ reward model """ + from typing import Optional import 
torch diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py index 5ce1e9ef0..3c836b4b4 100755 --- a/applications/ColossalChat/coati/trainer/utils.py +++ b/applications/ColossalChat/coati/trainer/utils.py @@ -1,6 +1,7 @@ """ Training utilities for Coati. """ + from typing import Any import torch diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py index 32f8544e9..d5f230249 100644 --- a/applications/ColossalEval/colossal_eval/dataset/agieval.py +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -78,7 +78,9 @@ def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict option_string = "ABCDEFG" count = len(line["options"]) - input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:" + input = ( + "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:" + ) all_classes = list(option_string[0:count]) @@ -150,7 +152,15 @@ def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=F ) elif dataset_name in chinese_qa_datasets: question_input = ( - "问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label) + "问题:" + + passage + + " " + + question + + "\n" + + "从以下选项中选择:" + + " ".join(options) + + "\n" + + "答案:{}".format(label) ) elif dataset_name in english_cloze_datasets: question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer) diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py index 2cf09ec4d..915f4d9b0 100644 --- a/applications/ColossalEval/colossal_eval/dataset/ceval.py +++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py @@ -57,7 +57,11 @@ ceval_subject_mapping = { "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"], "accountant": ["Accountant", "注册会计师", "Other"], "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"], - "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"], + "environmental_impact_assessment_engineer": [ + "Environmental Impact Assessment Engineer", + "环境影响评价工程师", + "Other", + ], "tax_accountant": ["Tax Accountant", "税务师", "Other"], "physician": ["Physician", "医师资格", "Other"], } diff --git a/applications/ColossalEval/colossal_eval/dataset/mtbench.py b/applications/ColossalEval/colossal_eval/dataset/mtbench.py index 9e74a4d82..031415567 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mtbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/mtbench.py @@ -56,9 +56,11 @@ class MTBenchDataset(BaseDataset): "instruction": question["turns"], "input": "", "output": [], - "target": [""] * turn_number - if question["question_id"] not in reference - else reference[question["question_id"]], + "target": ( + [""] * turn_number + if question["question_id"] not in reference + else reference[question["question_id"]] + ), } if category in dataset["test"]: diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index fff697e21..23c399cce 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -77,7 +77,9 @@ class HuggingFaceModel(BaseModel): self.indices_for_choices[0].append( 
self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1] ) - self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]) + self.indices_for_choices[1].append( + self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1] + ) def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict): """ diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py index 80dbf47de..2f9750de3 100644 --- a/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py +++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py @@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at https://github.com/langchain-ai/langchain The original code is licensed under the MIT license. """ + from __future__ import annotations import copy diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py index a2b1f81e3..8cb8ef536 100644 --- a/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py +++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py @@ -8,6 +8,7 @@ This code is based on LangChain Ai's langchain, which can be found at https://github.com/langchain-ai/langchain The original code is licensed under the MIT license. """ + import copy from typing import Any, Mapping, Optional, Protocol diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py index bf7ad0ffc..64e476438 100644 --- a/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py +++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py @@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at https://github.com/langchain-ai/langchain The original code is licensed under the MIT license. """ + import copy from typing import Any, List diff --git a/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py b/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py index 29542466f..0ad66f0ad 100644 --- a/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py +++ b/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py @@ -2,7 +2,6 @@ Class for loading table type data. please refer to Pandas-Input/Output for file format details. 
""" - import glob import os diff --git a/applications/ColossalQA/colossalqa/local/llm.py b/applications/ColossalQA/colossalqa/local/llm.py index 30a456c3d..58a4811d9 100644 --- a/applications/ColossalQA/colossalqa/local/llm.py +++ b/applications/ColossalQA/colossalqa/local/llm.py @@ -12,6 +12,7 @@ TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料 logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True) """ + from typing import Any, List, Mapping, Optional import torch diff --git a/applications/ColossalQA/colossalqa/local/utils.py b/applications/ColossalQA/colossalqa/local/utils.py index ed90264ca..2cbd474bd 100644 --- a/applications/ColossalQA/colossalqa/local/utils.py +++ b/applications/ColossalQA/colossalqa/local/utils.py @@ -1,6 +1,7 @@ """ Generation utilities """ + import json from typing import List diff --git a/applications/ColossalQA/colossalqa/memory.py b/applications/ColossalQA/colossalqa/memory.py index 7a5512281..d8de544a5 100644 --- a/applications/ColossalQA/colossalqa/memory.py +++ b/applications/ColossalQA/colossalqa/memory.py @@ -2,6 +2,7 @@ Implement a memory class for storing conversation history Support long term and short term memory """ + from typing import Any, Dict, List from colossalqa.chain.memory.summary import ConversationSummaryMemory diff --git a/applications/ColossalQA/colossalqa/mylogging.py b/applications/ColossalQA/colossalqa/mylogging.py index 574c33b41..67e2a83ed 100644 --- a/applications/ColossalQA/colossalqa/mylogging.py +++ b/applications/ColossalQA/colossalqa/mylogging.py @@ -1,6 +1,7 @@ """ Class for logging with extra control for debugging """ + import logging diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_en.py b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py index 96bce82b9..cab168075 100644 --- a/applications/ColossalQA/colossalqa/retrieval_conversation_en.py +++ b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py @@ -1,6 +1,7 @@ """ Script for Chinese retrieval based conversation system backed by ChatGLM """ + from typing import Tuple from colossalqa.chain.retrieval_qa.base import RetrievalQA diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py index 6e77bb2ae..a991b202e 100644 --- a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py +++ b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py @@ -1,6 +1,7 @@ """ Multilingual retrieval based conversation system """ + from typing import List from colossalqa.data_loader.document_loader import DocumentLoader diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py index 4eef41947..6c9b69117 100644 --- a/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py +++ b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py @@ -1,6 +1,7 @@ """ Script for Chinese retrieval based conversation system backed by ChatGLM """ + from typing import Tuple from colossalqa.chain.retrieval_qa.base import RetrievalQA diff --git a/applications/ColossalQA/colossalqa/retriever.py b/applications/ColossalQA/colossalqa/retriever.py index 6a0c69859..ec4941ddd 100644 --- a/applications/ColossalQA/colossalqa/retriever.py +++ b/applications/ColossalQA/colossalqa/retriever.py @@ -1,6 +1,7 @@ """ Code for custom retriver with incremental update """ + import copy import hashlib import os diff --git 
a/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py b/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py index 3815f5ed2..697af484b 100644 --- a/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py +++ b/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py @@ -1,6 +1,7 @@ """ Code for Chinese text splitter """ + from typing import Any, List, Optional from colossalqa.text_splitter.utils import get_cleaned_paragraph diff --git a/applications/ColossalQA/examples/retrieval_conversation_en.py b/applications/ColossalQA/examples/retrieval_conversation_en.py index fe2b9b4db..b7339de93 100644 --- a/applications/ColossalQA/examples/retrieval_conversation_en.py +++ b/applications/ColossalQA/examples/retrieval_conversation_en.py @@ -1,6 +1,7 @@ """ Script for English retrieval based conversation system backed by LLaMa2 """ + import argparse import os diff --git a/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py index d4ba73b94..a0c90e7c5 100644 --- a/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py +++ b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py @@ -1,6 +1,7 @@ """ Script for English retrieval based conversation system backed by LLaMa2 """ + import argparse import json import os diff --git a/applications/ColossalQA/examples/retrieval_conversation_zh.py b/applications/ColossalQA/examples/retrieval_conversation_zh.py index b143b9baa..96641edf5 100644 --- a/applications/ColossalQA/examples/retrieval_conversation_zh.py +++ b/applications/ColossalQA/examples/retrieval_conversation_zh.py @@ -1,6 +1,7 @@ """ Script for Chinese retrieval based conversation system backed by ChatGLM """ + import argparse import os diff --git a/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py b/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py index adb654494..865ade5bb 100644 --- a/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py +++ b/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py @@ -1,6 +1,7 @@ """ Script for English retrieval based conversation system backed by LLaMa2 """ + import argparse import os diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py index 2f630995c..b1e32e885 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py @@ -107,20 +107,22 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward fwd_memory_cost = MemoryCost( activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) - if has_bias - else compute_size_in_bytes(weight_tensor), + parameter=( + compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor) + ), temp=0, buffer=0, ) bwd_memory_cost = MemoryCost( - activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) - if has_bias - else compute_size_in_bytes([input_tensor, weight_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) - if 
has_bias - else compute_size_in_bytes(weight_tensor), + activation=( + compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) + if has_bias + else compute_size_in_bytes([input_tensor, weight_tensor]) + ), + parameter=( + compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor) + ), temp=0, buffer=0, ) diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index 88bde3a3b..581d114d2 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -247,16 +247,16 @@ class BatchBucket: self._sequences_dict[seq.request_id] = seq self._sequences_indexes[seq.request_id] = self._current_batch_size + i # TODO external (rename): modify Sequence.sentence_len to seq_len - self._sequence_lengths[ - self._current_batch_size : self._current_batch_size + num_seqs_to_add - ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32) + self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = ( + torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32) + ) # NOTE block tables to be updated by kvcache manager block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] if alloc_block_tables is not None: # copy block ids from provided block tables - self._block_tables[ - self._current_batch_size : self._current_batch_size + num_seqs_to_add - ] = alloc_block_tables + self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = ( + alloc_block_tables + ) elif alloc_block_tables_fn: alloc_block_tables_fn( block_tables, diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index c73ee9df4..e114e8a61 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -1,6 +1,7 @@ """ Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference. 
""" + import logging from abc import ABC, abstractmethod from dataclasses import dataclass, fields @@ -82,9 +83,9 @@ class InputMetaData(RPC_PARAM): dtype: torch.dtype = torch.float32 use_spec_dec: bool = False num_tokens_to_verify: int = 0 - batch_token_ids: Optional[ - List[List[int]] - ] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process + batch_token_ids: Optional[List[List[int]]] = ( + None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process + ) def to_rpc_param(self) -> Dict[str, any]: return { @@ -202,9 +203,9 @@ class InferenceConfig(RPC_PARAM): prompt_template: Optional[str] = None do_sample: bool = False beam_width: int = 1 # TODO: beam search is not support for now - prefill_ratio: Optional[ - float - ] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio + prefill_ratio: Optional[float] = ( + 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio + ) pad_input: bool = False early_stopping: Optional[bool] = False top_k: Optional[int] = 50 @@ -234,7 +235,9 @@ class InferenceConfig(RPC_PARAM): high_precision: Optional[bool] = False # cuda_graph - use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference + use_cuda_graph: bool = ( + False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference + ) max_context_len_to_capture: int = 512 # StreamingLLM (sliding window attention with attention sinks) diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index f0918c88c..8f8aef65e 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -47,7 +47,6 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] class InferenceEngine: - """ InferenceEngine which manages the inference process.. diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py index 87222a744..749360872 100644 --- a/colossalai/inference/core/rpc_engine.py +++ b/colossalai/inference/core/rpc_engine.py @@ -34,7 +34,6 @@ def run_server(host, port, event: mp.Event = None): class RPCInferenceEngine(InferenceEngine): - """ InferenceEngine which manages the inference process.. 
diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py index a5199cb74..a4fd20a69 100644 --- a/colossalai/inference/executor/rpc_worker.py +++ b/colossalai/inference/executor/rpc_worker.py @@ -42,7 +42,6 @@ logger = get_dist_logger(__name__) class rpcWorkerService(rpyc.Service): - """ Execute the computation tasks and manage its own kv cache diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 378eb2ff9..dac5037f4 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -279,9 +279,11 @@ class KVCacheManager: block.add_ref() self._allocate_on_block( block, - block.block_size - if context_lengths[i] % block.block_size == 0 - else context_lengths[i].item() % block.block_size, + ( + block.block_size + if context_lengths[i] % block.block_size == 0 + else context_lengths[i].item() % block.block_size + ), ) for block_id in alloc_block_ids: if block_id in alloc_block_ids[last_block_locs]: diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 8c155e6ca..332e84d37 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -1,6 +1,7 @@ """ Utils for model inference """ + import math import os import re diff --git a/colossalai/legacy/context/process_group_initializer/initializer_2d.py b/colossalai/legacy/context/process_group_initializer/initializer_2d.py index 1c08d4d42..fc51844b6 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_2d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_2d.py @@ -138,9 +138,7 @@ class Initializer_2D(ProcessGroupInitializer): self.num_group = self.world_size // self.tensor_parallel_size self.summa_dim = int(math.sqrt(self.tensor_parallel_size)) - assert ( - self.tensor_parallel_size == self.summa_dim**2 - ), "2D summa dim should equal to tensor parallel size ^ 0.5" + assert self.tensor_parallel_size == self.summa_dim**2, "2D summa dim should equal to tensor parallel size ^ 0.5" _check_summa_env_var(self.summa_dim) self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs) diff --git a/colossalai/legacy/inference/async_engine.py b/colossalai/legacy/inference/async_engine.py index d0890ba3e..b4c523669 100644 --- a/colossalai/legacy/inference/async_engine.py +++ b/colossalai/legacy/inference/async_engine.py @@ -54,7 +54,6 @@ class RequestTracker: class Async_Engine: - """ Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager Background loop: inference reqs in waiting list (Listen) diff --git a/colossalai/legacy/inference/dynamic_batching/io_struct.py b/colossalai/legacy/inference/dynamic_batching/io_struct.py index fc5ecfe57..abc41cc8e 100644 --- a/colossalai/legacy/inference/dynamic_batching/io_struct.py +++ b/colossalai/legacy/inference/dynamic_batching/io_struct.py @@ -118,16 +118,16 @@ class Batch: class BatchTokenIdOut: def __init__(self): - self.reqs_infs: List[ - Tuple[str, int, Dict, bool, bool] - ] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] + self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = ( + [] + ) # [req_id, new_token_id, gen_metadata, finished_state, abort_state] class BatchStrOut: def __init__(self): - self.reqs_infs: List[ - Tuple[str, str, Dict, bool, bool] - ] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] + self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = ( + [] + ) # 
[req_id, token_str, gen_metadata, finished_state, abort_state] class AbortReq: diff --git a/colossalai/legacy/inference/hybridengine/modeling/_utils.py b/colossalai/legacy/inference/hybridengine/modeling/_utils.py index 068b64b4f..46d4222c4 100644 --- a/colossalai/legacy/inference/hybridengine/modeling/_utils.py +++ b/colossalai/legacy/inference/hybridengine/modeling/_utils.py @@ -1,6 +1,7 @@ """ Utils for model inference """ + import os import torch diff --git a/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py b/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py index f707a86df..b72610899 100644 --- a/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py @@ -14,6 +14,7 @@ class BatchInferState: Information to be passed and used for a batch of inputs during a single model forward """ + batch_size: int max_len_in_batch: int diff --git a/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py b/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py index 91bb96a1f..8c54fda26 100644 --- a/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py @@ -4,6 +4,7 @@ of the ModelTC/lightllm GitHub repository https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design. """ + import torch from transformers.utils import logging diff --git a/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py b/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py index 068b64b4f..46d4222c4 100644 --- a/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py +++ b/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py @@ -1,6 +1,7 @@ """ Utils for model inference """ + import os import torch diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index bf581300a..6ae4b06e5 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -33,6 +33,7 @@ This license shall be governed and construed in accordance with the laws of Peop Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. """ + """ PyTorch ChatGLM model. 
""" import copy diff --git a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py index 09677a619..4d08d9941 100644 --- a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py +++ b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py @@ -52,9 +52,11 @@ class pretraining_dataset(Dataset): def __getitem__(self, index): [input_ids, input_mask, segment_ids, masked_lm_labels] = [ - torch.from_numpy(input[index].astype(np.int64)) - if indice < 5 - else torch.from_numpy(np.asarray(input[index].astype(np.int64))) + ( + torch.from_numpy(input[index].astype(np.int64)) + if indice < 5 + else torch.from_numpy(np.asarray(input[index].astype(np.int64))) + ) for indice, input in enumerate(self.inputs) ] diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py index 20e26256e..3cf6aceb5 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddpm.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -229,9 +229,7 @@ class DDPM(pl.LightningModule): ) if self.parameterization == "eps": - lvlb_weights = self.betas**2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod) - ) + lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) elif self.parameterization == "x0": lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) elif self.parameterization == "v": @@ -1186,9 +1184,11 @@ class LatentDiffusion(DDPM): if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) + key: ( + cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + ) for key in cond } else: @@ -1321,9 +1321,11 @@ class LatentDiffusion(DDPM): if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) + key: ( + cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + ) for key in cond } else: diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py index 55dac8555..4104fe3b0 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py @@ -1,4 +1,5 @@ """SAMPLING ONLY.""" + import torch from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py index 6c80f3229..afde5dfd4 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -640,23 +640,25 @@ class UNetModel(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( # always uses a self-attn - ch, - num_heads, 
- dim_head, - depth=transformer_depth, - context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, - use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, + ( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + ) ), ResBlock( ch, diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py index 0dd87b596..8c13f39ff 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py @@ -2,6 +2,7 @@ This file contains code that is adapted from https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py """ + import torch import torch.nn as nn diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py index 4d30744c4..c79581afc 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py @@ -2,6 +2,7 @@ This file contains code that is adapted from https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py """ + import torch import torch.nn as nn diff --git a/examples/images/diffusion/ldm/modules/midas/utils.py b/examples/images/diffusion/ldm/modules/midas/utils.py index 1428d42b2..f7fc7dcc9 100644 --- a/examples/images/diffusion/ldm/modules/midas/utils.py +++ b/examples/images/diffusion/ldm/modules/midas/utils.py @@ -1,4 +1,5 @@ """Utils for monoDepth.""" + import re import sys diff --git a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp index 52977e631..fe9968177 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp +++ b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp @@ -369,9 +369,9 @@ py::array build_mapping_impl(const py::array_t& docs_, } } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { if (!second) { if (verbose) { @@ -606,9 +606,9 @@ py::array build_blocks_mapping_impl( num_sent = 0; } } // for (auto sent_index=sent_index_first; ... 
- } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { if (!second) { if (verbose) { diff --git a/extensions/csrc/kernel/arm/cpu_adam_arm.h b/extensions/csrc/kernel/arm/cpu_adam_arm.h index c731850ed..d48968e21 100644 --- a/extensions/csrc/kernel/arm/cpu_adam_arm.h +++ b/extensions/csrc/kernel/arm/cpu_adam_arm.h @@ -4,7 +4,7 @@ #include -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #define TILE (128 * 1024 * 1024) #if defined(__aarch64__) diff --git a/extensions/csrc/kernel/x86/cpu_adam.h b/extensions/csrc/kernel/x86/cpu_adam.h index db1f26d5f..45e1dde62 100644 --- a/extensions/csrc/kernel/x86/cpu_adam.h +++ b/extensions/csrc/kernel/x86/cpu_adam.h @@ -32,7 +32,7 @@ SOFTWARE #include #endif -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #define TILE (128 * 1024 * 1024) #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) diff --git a/tests/kit/model_zoo/torchvision/torchvision.py b/tests/kit/model_zoo/torchvision/torchvision.py index 57b633e9d..c0524d089 100644 --- a/tests/kit/model_zoo/torchvision/torchvision.py +++ b/tests/kit/model_zoo/torchvision/torchvision.py @@ -34,14 +34,14 @@ def swin_s(): # special output transform fn -google_net_output_transform_fn = ( - lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x) +google_net_output_transform_fn = lambda x: ( + dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x) ) -swin_s_output_output_transform_fn = ( - lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) +swin_s_output_output_transform_fn = lambda x: ( + {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) ) -inception_v3_output_transform_fn = ( - lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x) +inception_v3_output_transform_fn = lambda x: ( + dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x) ) model_zoo.register( From ea94c07b959e8895b713d6dd68b168ea37db6b7b Mon Sep 17 00:00:00 2001 From: Haze188 Date: Tue, 2 Jul 2024 12:42:02 +0800 Subject: [PATCH 06/15] [hotfix] fix the bug that large tensor exceed the maximum capacity of TensorBucket (#5879) --- colossalai/zero/low_level/low_level_optim.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index e06cf0581..bdc91b51f 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper): working_param = real_working_params[group_id][idx] param_to_gather = master_param.to(device).to(self._dtype) pg = self.param_to_pg[working_param] + if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: + buffer_tensor = torch.empty_like( + torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))]) + ) + dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg) + working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param)) + continue try: 
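            # The branch above handles params whose gathered size exceeds the bucket
            # limit: all shards are all-gathered into a temporary buffer, the leading
            # working_param.numel() elements are copied back into the working param,
            # and the bucketed path below is skipped via `continue`.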
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) except RuntimeError: From eb24fcd914f4c38fb82bc082db84d13d50865572 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 3 Jul 2024 14:57:57 +0800 Subject: [PATCH 07/15] [Hotfix] Fix OPT gradient checkpointing forward Co-authored-by: Edenzzzz --- colossalai/shardformer/modeling/opt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index f10860fef..b250b4976 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -221,7 +221,7 @@ class OPTPipelineForwards: past_key_value = past_key_values[idx] if past_key_values is not None else None if decoder.gradient_checkpointing and decoder.training: - layer_outputs = self._gradient_checkpointing_func( + layer_outputs = self.decoder._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_attention_mask, From 6cd4c32be4c0ced9a70e228530f383c5f4a580de Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Jul 2024 20:02:19 +0800 Subject: [PATCH 08/15] [shardformer] fix the moe (#5883) --- colossalai/booster/plugin/__init__.py | 10 +++++++- colossalai/shardformer/policies/mixtral.py | 28 ++++++++++------------ 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py index 62f3708fc..7e0e6ffdd 100644 --- a/colossalai/booster/plugin/__init__.py +++ b/colossalai/booster/plugin/__init__.py @@ -1,10 +1,18 @@ from .gemini_plugin import GeminiPlugin from .hybrid_parallel_plugin import HybridParallelPlugin from .low_level_zero_plugin import LowLevelZeroPlugin +from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from .plugin_base import Plugin from .torch_ddp_plugin import TorchDDPPlugin -__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"] +__all__ = [ + "Plugin", + "TorchDDPPlugin", + "GeminiPlugin", + "LowLevelZeroPlugin", + "HybridParallelPlugin", + "MoeHybridParallelPlugin", +] import torch from packaging import version diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index f9721c79e..0fb858d78 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -40,21 +40,19 @@ class MixtralPolicy(Policy): if self.shard_config.enable_tensor_parallelism: raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") - if getattr(self.shard_config, "ep_group", None) is None: - raise ValueError("You must pass in ep_group via shard_config for expert parallel!") - - # expert parallel - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="block_sparse_moe", - target_module=EPMixtralSparseMoeBlock, - kwargs={"ep_group": self.shard_config.ep_group}, - ) - ], - policy=policy, - target_key=MixtralDecoderLayer, - ) + if getattr(self.shard_config, "ep_group", None) is not None: + # expert parallel + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="block_sparse_moe", + target_module=EPMixtralSparseMoeBlock, + kwargs={"ep_group": self.shard_config.ep_group}, + ) + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) # optimization configuration if self.shard_config.enable_fused_normalization: From 
7afbc81d6292f1a44cb5c2f89571c6c1c6d74691 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 4 Jul 2024 11:33:23 +0800 Subject: [PATCH 09/15] [quant] fix bitsandbytes version check (#5882) * [quant] fix bitsandbytes version check * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/quantization/bnb.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/colossalai/quantization/bnb.py b/colossalai/quantization/bnb.py index fa214116a..3601ef62b 100644 --- a/colossalai/quantization/bnb.py +++ b/colossalai/quantization/bnb.py @@ -1,17 +1,25 @@ # adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py +import importlib.metadata import logging import torch import torch.nn as nn +from packaging.version import Version from .bnb_config import BnbQuantizationConfig try: import bitsandbytes as bnb - IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0" - IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2" + try: + # in case lower version of bitsandbytes does not have __version__ attribute + BNB_VERSION = Version(bnb.__version__) + except AttributeError: + BNB_VERSION = Version(importlib.metadata.version("bitsandbytes")) + + IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0") + IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2") except ImportError: pass From 7997683aac44cf99529589af4262fba52b29a74b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Jul 2024 13:46:41 +0800 Subject: [PATCH 10/15] [pre-commit.ci] pre-commit autoupdate (#5878) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/mirrors-clang-format: v18.1.7 → v18.1.8](https://github.com/pre-commit/mirrors-clang-format/compare/v18.1.7...v18.1.8) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2c408bce..9088d0e1b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310'] - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.7 + rev: v18.1.8 hooks: - id: clang-format name: clang formatter From 3420921101186ffa6e6f9428bbb4036302230ccd Mon Sep 17 00:00:00 2001 From: Haze188 Date: Fri, 5 Jul 2024 16:13:58 +0800 Subject: [PATCH 11/15] [shardformer] DeepseekMoE support (#5871) * [Feature] deepseek moe expert parallel implement * [misc] fix typo, remove redundant file (#5867) * [misc] fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] deepseek support & unit test * [misc] remove debug code & useless print * [misc] fix typos (#5872) * [Feature] remove modeling file, use auto config. 
(#5884) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [Deepseek] remove redundant code (#5888) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [Feature/deepseek] resolve comment. (#5889) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [misc] mv module replacement into if branch * [misc] add some warning message and modify some code in unit test * [misc] fix typos --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/cluster/process_group_mesh.py | 2 +- colossalai/shardformer/modeling/deepseek.py | 429 ++++++++++++++++++ .../shardformer/policies/auto_policy.py | 8 +- colossalai/shardformer/policies/deepseek.py | 212 +++++++++ colossalai/shardformer/policies/mixtral.py | 6 +- tests/test_moe/test_deepseek_layer.py | 72 +++ tests/test_moe/test_moe_checkpoint.py | 38 +- 7 files changed, 748 insertions(+), 19 deletions(-) create mode 100644 colossalai/shardformer/modeling/deepseek.py create mode 100644 colossalai/shardformer/policies/deepseek.py create mode 100644 tests/test_moe/test_deepseek_layer.py diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 1319a4529..b6aff0d72 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -147,7 +147,7 @@ class ProcessGroupMesh: ProcessGroup: The process group with the given ranks. """ ranks_in_group = sorted(ranks_in_group) - if tuple(ranks_in_group) not in self._group_to_ranks: + if tuple(ranks_in_group) not in self._ranks_to_group: group = dist.new_group(ranks_in_group, backend=backend) self._ranks_to_group[tuple(ranks_in_group)] = group self._group_to_ranks[group] = tuple(ranks_in_group) diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py new file mode 100644 index 000000000..6e79ce144 --- /dev/null +++ b/colossalai/shardformer/modeling/deepseek.py @@ -0,0 +1,429 @@ +from typing import List, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup + +# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo +from torch.nn import CrossEntropyLoss +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils import is_flash_attn_2_available, logging + +from colossalai.lazy import LazyInitContext +from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.shard import ShardConfig +from colossalai.shardformer.shard.utils import set_tensors_to_none + + +# copied from modeling_deepseek.py +class AddAuxiliaryLoss(torch.autograd.Function): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. 
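+    In forward, x is returned unchanged; in backward, grad_output is passed through
+    to x and a ones(1) gradient is emitted for the aux loss when it requires grad,
+    so the router's load-balancing loss reaches the parameters without changing the
+    activations.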
+ """ + + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + + +class EPDeepseekMoE(nn.Module): + def __init__(self): + super(EPDeepseekMoE, self).__init__() + + def setup_ep(self, ep_group: ProcessGroup): + ep_group = ep_group + self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1 + self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0 + self.num_experts = self.config.n_routed_experts + assert self.num_experts % self.ep_size == 0 + self.ep_group = ep_group + self.num_experts_per_ep = self.num_experts // self.ep_size + self.expert_start_idx = self.ep_rank * self.num_experts_per_ep + held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] + set_tensors_to_none(self.experts, exclude=set(held_experts)) + for p in self.experts.parameters(): + p.ep_group = ep_group + + @staticmethod + def from_native_module(module: Union["DeepseekMoE", "DeepseekMLP"], *args, **kwargs) -> "EPDeepseekMoE": + LazyInitContext.materialize(module) + if module.__class__.__name__ == "DeepseekMLP": + return module + module.__class__ = EPDeepseekMoE + assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!" + module.setup_ep(kwargs["ep_group"]) + return module + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + identity = hidden_states + orig_shape = hidden_states.shape + + topk_experts_idx, topk_experts_weight, aux_loss = self.gate(hidden_states) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # [t0, t1, t2 ...] + hidden_states = hidden_states.repeat_interleave( + self.num_experts_per_tok, dim=0 + ) # after repeat_interleave: [t0 t0 t1 t1 t2 t2 ... ] + + flat_topk_experts_idx = topk_experts_idx.view(-1) # [e0 e1 e2 ...] + # The elements of flat_topk_token_idx are token ids, which are arranged in ascending order of expert ids. 
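+        # Worked example (2 tokens, top-2 routing): if topk_experts_idx = [[1, 0], [2, 1]],
+        # then flat_topk_experts_idx = [1, 0, 2, 1] and argsort() returns [1, 0, 3, 2];
+        # indexing the repeat_interleaved hidden_states with it reorders the rows so they
+        # are grouped by expert id (0, 1, 1, 2), the layout the all-to-all dispatch expects.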
+ flat_topk_token_idx = flat_topk_experts_idx.argsort() + + # Now we adjust the order of the hidden states, also in ascending order of expert id + dispatch_states = hidden_states[flat_topk_token_idx] + input_split_sizes = flat_topk_experts_idx.bincount(minlength=self.num_experts) # [n0, n1, n2, n3] + output_split_sizes = torch.zeros_like(input_split_sizes) + + # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3] + dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) + + input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) + output_states = MoeInGradScaler.apply(output_states, self.ep_size) + + if output_states.size(0) > 0: + if self.num_experts_per_ep == 1: + expert = self.experts[self.expert_start_idx] + output_states = expert(output_states) + else: + output_states_splits = output_states.split(output_split_sizes.tolist()) + output_states_list = [] + for i, split_states in enumerate(output_states_splits): + if split_states.size(0) == 0: # no token routed to this experts + continue + expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] + split_states = expert(split_states) + output_states_list.append(split_states) + output_states = torch.cat(output_states_list) + output_states = MoeOutGradScaler.apply(output_states, self.ep_size) + dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + recover_token_idx = torch.empty_like(flat_topk_token_idx) + recover_token_idx[flat_topk_token_idx] = torch.arange( + flat_topk_token_idx.size(0), device=flat_topk_token_idx.device + ) + + output_hidden_states = dispatch_states[recover_token_idx] # t0 t0 t1 t1 t2 t2 + output_hidden_states = output_hidden_states.view(-1, self.num_experts_per_tok, orig_shape[-1]) + output_hidden_states = (output_hidden_states * topk_experts_weight[:, :, None]).sum(dim=-2) # (B*S, h) + output_hidden_states = output_hidden_states.view(*orig_shape) + output_hidden_states = AddAuxiliaryLoss.apply(output_hidden_states, aux_loss) + if self.config.n_shared_experts is not None: + output_hidden_states = output_hidden_states + self.shared_experts(identity) + return output_hidden_states + + +class DeepseekPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + """ + + @staticmethod + def deepseek_model_forward( + self: "DeepseekModel", + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
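A note on the stage slicing used further down in this forward (`self.layers[start_idx:end_idx]`): `stage_index` comes from the pipeline stage manager, whose exact partitioning logic is not part of this patch. A rough, hypothetical sketch of an even layer split, just to make the indexing concrete:

```python
from typing import List, Tuple

def distribute_layers_evenly(num_layers: int, num_stages: int) -> List[int]:
    """Hypothetical even split: earlier stages absorb the remainder layers."""
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if s < rem else 0) for s in range(num_stages)]

def stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    """Half-open [start, end) range of decoder layers owned by `stage`."""
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

layers_per_stage = distribute_layers_evenly(num_layers=30, num_stages=4)      # [8, 8, 7, 7]
print([stage_index(layers_per_stage, s) for s in range(4)])                   # [(0, 8), (8, 16), (16, 23), (23, 30)]
```

Each stage then runs only its own slice and hands `hidden_states` to the next stage, which is why the non-last stages below return `{"hidden_states": hidden_states}` instead of a full output object.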
+ if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if is_flash_attn_2_available(): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + output_attentions, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + if stage_manager.is_last_stage(): + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + # always return dict for imediate stage + return { + "hidden_states": hidden_states, + } + + @staticmethod + def deepseek_for_causal_lm_forward( + self: "DeepseekForCausalLM", + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: 
Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
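Before the flag handling that follows: on the last pipeline stage this forward computes the usual next-token objective, where the logit at position i is scored against the label at position i+1, so both tensors are shifted by one before the cross entropy. A self-contained sketch of exactly that shift (shapes are illustrative):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Position i predicts token i+1: drop the last logit and the first label.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss()(shift_logits, shift_labels)  # ignore_index=-100 by default
print(loss.item())
```

Labels set to `-100` are ignored by `CrossEntropyLoss` by default, which is how masked or padded positions drop out of the loss.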
+ if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = DeepseekPipelineForwards.deepseek_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=None, + hidden_states=outputs[0], + attentions=None, + ) + else: + out = {} + hidden_states = outputs.get("hidden_states") + out["hidden_states"] = hidden_states + return out diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index bf139c840..ae9f3603c 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -160,6 +160,13 @@ _POLICY_LIST = { "transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" ), + # Deepseek + "transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation( + file_name="deepseek", class_name="DeepseekModelPolicy" + ), + "transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation( + file_name="deepseek", class_name="DeepseekForCausalLMPolicy" + ), # Falcon "transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation( file_name="falcon", class_name="FalconModelPolicy" @@ -252,7 +259,6 @@ def get_autopolicy(model: nn.Module) -> Policy: """ full_name = _fullname(model) policy_location = _POLICY_LIST.get(full_name, None) - if policy_location is None: raise NotImplementedError( f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. 
Supported models are {list(_POLICY_LIST.keys())}" diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py new file mode 100644 index 000000000..8ebda357b --- /dev/null +++ b/colossalai/shardformer/policies/deepseek.py @@ -0,0 +1,212 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"] + + +class DeepseekPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + raise NotImplementedError( + "Deepseek dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + ) + + if self.shard_config.enable_tensor_parallelism: + raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.") + + if getattr(self.shard_config, "ep_group", None) is not None: + # expert parallel + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="mlp", + target_module=EPDeepseekMoE, + kwargs={"ep_group": self.shard_config.ep_group}, + ) + ], + policy=policy, + target_key="DeepseekDecoderLayer", + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + ], + policy=policy, + target_key="DeepseekDecoderLayer", + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key="DeepseekModel", + ) + + if self.shard_config.enable_flash_attention: + warnings.warn( + "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False." 
+ ) + self.shard_config.enable_flash_attention = False + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "DeepseekModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "DeepseekModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class DeepseekModelPolicy(DeepseekPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls="DeepseekModel", + new_forward=DeepseekPipelineForwards.deepseek_model_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class DeepseekForCausalLMPolicy(DeepseekPolicy): + def module_policy(self): + policy = super().module_policy() + # TODO: assign pg mesh from plugin to all modules + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + "DeepseekForCausalLM": ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True), + ) + ] + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls="DeepseekForCausalLM", + new_forward=DeepseekPipelineForwards.deepseek_for_causal_lm_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + deepseek_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(deepseek_model.embed_tokens.weight) == 
id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: deepseek_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 0fb858d78..ad93e9469 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -192,16 +192,16 @@ class MixtralForCausalLMPolicy(MixtralPolicy): return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - llama_model = self.model.model + mixtral_model = self.model.model if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if ( - id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + id(mixtral_model.embed_tokens.weight) == id(self.model.lm_head.weight) and self.pipeline_stage_manager.num_stages > 1 ): # tie weights return [ { - 0: llama_model.embed_tokens.weight, + 0: mixtral_model.embed_tokens.weight, self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, } ] diff --git a/tests/test_moe/test_deepseek_layer.py b/tests/test_moe/test_deepseek_layer.py new file mode 100644 index 000000000..85cc98695 --- /dev/null +++ b/tests/test_moe/test_deepseek_layer.py @@ -0,0 +1,72 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from torch.testing import assert_close +from transformers import AutoConfig, AutoModel + +import colossalai +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE +from colossalai.testing.utils import spawn + +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 + + +def check_deepseek_moe_layer(): + torch.cuda.set_device(dist.get_rank()) + plugin = MoeHybridParallelPlugin( + precision="bf16", + tp_size=1, + pp_size=1, + ep_size=dist.get_world_size(), + ) + + config = AutoConfig.from_pretrained( + "deepseek-ai/deepseek-moe-16b-base", + num_hidden_layers=1, + n_routed_experts=n_experts, + num_experts_per_tok=top_k, + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + first_k_dense_replace=0, + num_attention_heads=2, + trust_remote_code=True, + ) + torch.manual_seed(0) + # get the moe layer in auto model + orig_model = AutoModel.from_config(config, trust_remote_code=True).layers[0].mlp.cuda() + x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda() + orig_output = orig_model(x) + model = deepcopy(orig_model) + model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group) + ep_output = model(x) + assert_close(orig_output, ep_output) + orig_loss = orig_output.mean() + orig_loss.backward() + ep_loss = ep_output.mean() + ep_loss.backward() + assert_close(orig_loss, ep_loss) + name_to_p = {n: p for n, p in orig_model.named_parameters()} + for n, ep_p in model.named_parameters(): + p = name_to_p[n] + if ep_p.grad is not None: + assert_close(p.grad, ep_p.grad) + + +def run_dist(rank: int, world_size: int, port: int): + colossalai.launch(rank, world_size, "localhost", port) + check_deepseek_moe_layer() + + +# @pytest.mark.parametrize("world_size", [2, 4]) +@pytest.mark.parametrize("world_size", [2]) +def test_deepseek_moe_layer(world_size: int): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_deepseek_moe_layer(2) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 
249dd4b97..164301695 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -15,6 +15,7 @@ from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.checkpoint_io import MoECheckpointIO from colossalai.tensor.moe_tensor.api import is_moe_tensor +from colossalai.testing import parameterize, spawn from colossalai.testing.utils import spawn tokens, n_experts = 7, 4 @@ -77,7 +78,23 @@ def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_grou raise AssertionError(f"A total of {count} optim states are not equal") -def check_mixtral_moe_layer(): +@parameterize( + "test_config", + [ + [ + MixtralConfig( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + num_attention_heads=2, + num_key_value_heads=2, + ), + MixtralForCausalLM, + ], + ], +) +def check_moe_checkpoint(test_config): context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext() with context as f: torch.cuda.set_device(dist.get_rank()) @@ -87,17 +104,11 @@ def check_mixtral_moe_layer(): broadcast_objects = [None] dist.broadcast_object_list(broadcast_objects, src=0) - config = MixtralConfig( - hidden_size=hidden_size, - intermediate_size=hidden_size * 2, - num_local_experts=n_experts, - num_experts_per_tok=top_k, - num_attention_heads=2, - num_key_value_heads=2, - ) + config = test_config[0] + model_cls = test_config[1] torch.manual_seed(0) input_ids = torch.randint(0, 100, (2, tokens)).cuda() - orig_model = MixtralForCausalLM(config).cuda() + orig_model = model_cls(config).cuda() model = deepcopy(orig_model) optimizer = Adam(model.parameters(), lr=1e-3) plugin = MoeHybridParallelPlugin( @@ -120,7 +131,6 @@ def check_mixtral_moe_layer(): lambda outputs, inputs: outputs.loss, optimizer, ) - tmpdirname = broadcast_objects[0] model_dir = os.path.join(tmpdirname, "mixtral_model") hf_model_dir = os.path.join(tmpdirname, "mixtral_hf_model") @@ -129,13 +139,13 @@ def check_mixtral_moe_layer(): booster.save_model(model, model_dir, shard=True) dist.barrier() if dist.get_rank() == 0: - saved_model = MixtralForCausalLM.from_pretrained(model_dir).cuda() + saved_model = model_cls.from_pretrained(model_dir).cuda() check_model_equal(orig_model, saved_model) # check_model_equal(model, saved_model) saved_model.save_pretrained(hf_model_dir) dist.barrier() # check load model - new_model = MixtralForCausalLM(config).cuda() + new_model = model_cls(config).cuda() new_optimizer = Adam(new_model.parameters(), lr=1e-3) new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer) booster.load_model(new_model, hf_model_dir) @@ -163,7 +173,7 @@ def check_mixtral_moe_layer(): def run_dist(rank: int, world_size: int, port: int): colossalai.launch(rank, world_size, "localhost", port) - check_mixtral_moe_layer() + check_moe_checkpoint() # Test EP + ZeRO + PP From 8ec24b6a4d0e0dbec7da39e43c3c1b2cfcb0395d Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 5 Jul 2024 20:02:36 +0800 Subject: [PATCH 12/15] [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap Co-authored-by: Edenzzzz --- colossalai/initialize.py | 6 ++++++ colossalai/legacy/nn/layer/parallel_1d/_operation.py | 1 - colossalai/shardformer/shard/shardformer.py | 4 ---- examples/language/llama/benchmark.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 
71d42312e..4e2eff7ce 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -3,6 +3,12 @@ import os +# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when overlapping communication and computation, +# the order of of kernel launches on GPUs are the same as on the CPU so that comm is launched first. +# see https://github.com/NVIDIA/Megatron-LM/issues/533 +# https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16 +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + import torch.distributed as dist from colossalai.accelerator import get_accelerator diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py index f01da97ba..8b8f04ccf 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py @@ -81,7 +81,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function): handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # all-reduce scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index b54c58273..db03eec41 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -1,4 +1,3 @@ -import os from typing import Dict, List, Tuple import torch.distributed as dist @@ -11,9 +10,6 @@ from ..policies.base_policy import Policy from .shard_config import ShardConfig from .sharder import ModelSharder -# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when communication and computation overlap, the order of core scheduling is correct -os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - class ShardFormer: """ diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 8a35db1f7..2b7bd50b8 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -292,7 +292,7 @@ def main(): with get_profile_context( args.profile, args.ignore_steps, - len(dataloader) - 1, + 1, # avoid creating massive log files save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", ) as prof: if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: From cba20525a81565fc86e13b78973ffa8210a05cd3 Mon Sep 17 00:00:00 2001 From: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Date: Mon, 8 Jul 2024 16:02:07 +0800 Subject: [PATCH 13/15] [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838) * Diffusion Model Inference support * Stable Diffusion 3 Support * pixartalpha support --- colossalai/inference/config.py | 48 +- colossalai/inference/core/base_engine.py | 90 ++ colossalai/inference/core/diffusion_engine.py | 200 +++++ colossalai/inference/core/engine.py | 800 ++---------------- colossalai/inference/core/llm_engine.py | 758 +++++++++++++++++ colossalai/inference/core/request_handler.py | 51 +- .../inference/modeling/models/diffusion.py | 54 ++ .../inference/modeling/models/pixart_alpha.py | 220 +++++ .../modeling/models/stablediffusion3.py | 178 ++++ .../inference/modeling/policy/__init__.py | 6 + .../inference/modeling/policy/pixart_alpha.py | 34 + .../modeling/policy/stablediffusion3.py | 34 + 
colossalai/inference/struct.py | 12 + colossalai/inference/utils.py | 39 +- .../stable_diffusion/sd3_generation.py | 75 ++ requirements/requirements.txt | 1 + 16 files changed, 1860 insertions(+), 740 deletions(-) create mode 100644 colossalai/inference/core/base_engine.py create mode 100644 colossalai/inference/core/diffusion_engine.py create mode 100644 colossalai/inference/core/llm_engine.py create mode 100644 colossalai/inference/modeling/models/diffusion.py create mode 100644 colossalai/inference/modeling/models/pixart_alpha.py create mode 100644 colossalai/inference/modeling/models/stablediffusion3.py create mode 100644 colossalai/inference/modeling/policy/pixart_alpha.py create mode 100644 colossalai/inference/modeling/policy/stablediffusion3.py create mode 100644 examples/inference/stable_diffusion/sd3_generation.py diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index e114e8a61..1beb86874 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified import logging from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers.generation import GenerationConfig @@ -396,3 +396,49 @@ class ModelShardInferenceConfig: use_cuda_kernel: bool = False use_spec_dec: bool = False use_flash_attn: bool = False + + +@dataclass +class DiffusionGenerationConfig: + """ + Param for diffusion model forward + """ + + prompt_2: Optional[Union[str, List[str]]] = None + prompt_3: Optional[Union[str, List[str]]] = None + height: Optional[int] = None + width: Optional[int] = None + num_inference_steps: int = None + timesteps: List[int] = None + guidance_scale: float = None + negative_prompt: Optional[Union[str, List[str]]] = ( + None # NOTE(@lry89757) in pixart default to "", in sd3 default to None + ) + negative_prompt_2: Optional[Union[str, List[str]]] = None + negative_prompt_3: Optional[Union[str, List[str]]] = None + num_images_per_prompt: Optional[int] = None + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None + latents: Optional[torch.FloatTensor] = None + prompt_embeds: Optional[torch.FloatTensor] = None + negative_prompt_embeds: Optional[torch.FloatTensor] = None + pooled_prompt_embeds: Optional[torch.FloatTensor] = None + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None + output_type: Optional[str] = None # "pil" + return_dict: bool = None + joint_attention_kwargs: Optional[Dict[str, Any]] = None + clip_skip: Optional[int] = None + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None + callback_on_step_end_tensor_inputs: List[str] = None + + def to_dict(self) -> Dict[str, Any]: + # NOTE(@lry89757) Only return the dict that not the default value None + result = {} + for field in fields(self): + value = getattr(self, field.name) + if value is not None: + result[field.name] = value + return result + + @classmethod + def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig": + return cls(**kwargs) diff --git a/colossalai/inference/core/base_engine.py b/colossalai/inference/core/base_engine.py new file mode 100644 index 000000000..392dd2990 --- /dev/null +++ b/colossalai/inference/core/base_engine.py @@ -0,0 +1,90 @@ +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from colossalai.cluster import 
ProcessGroupMesh +from colossalai.inference.config import ModelShardInferenceConfig +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy + + +class BaseEngine(ABC): + @abstractmethod + def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None): + pass + + @abstractmethod + def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None): + """ + Init Model for Engine + """ + + @abstractmethod + def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs): + """ + Generate ouptput for coming requests + """ + + @abstractmethod + def add_request(self, prompts, request_ids=None, **kwargs): + """ + Add new request to Engine + """ + + @abstractmethod + def step(self): + """ + Perform one new step forward + """ + + @abstractmethod + def _verify_args(self): + """ + Verify the parameters and members of class + """ + + @torch.inference_mode() + def capture_model(self): + """ + Use cuda graph to capture model + """ + return NotImplementedError("This method should be implemented by subclasses") + + def _shardformer( + self, + model: nn.Module, + model_policy: Policy, + model_shard_infer_config: ModelShardInferenceConfig = None, + stage_manager: PipelineStageManager = None, + tp_group: ProcessGroupMesh = None, + **kwargs, + ) -> nn.Module: + """ + Initialize ShardConfig and replace the model with shardformer. + + Args: + model (nn.Module): Path or nn.Module of this model. + model_policy (Policy): The policy to shardformer model which is determined by the model type. + stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None. + tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. + + Returns: + nn.Module: The model optimized by Shardformer. 
+ """ + + shardconfig = ShardConfig( + tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, + enable_tensor_parallelism=(self.inference_config.tp_size > 1), + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs}, + ) + shardformer = ShardFormer(shard_config=shardconfig) + shard_model, _ = shardformer.optimize(model, model_policy) + return shard_model diff --git a/colossalai/inference/core/diffusion_engine.py b/colossalai/inference/core/diffusion_engine.py new file mode 100644 index 000000000..75b9889bf --- /dev/null +++ b/colossalai/inference/core/diffusion_engine.py @@ -0,0 +1,200 @@ +from itertools import count +from typing import List, Tuple, Type, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn as nn +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from torch import distributed as dist + +from colossalai.accelerator import get_accelerator +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig +from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.policy import model_policy_map +from colossalai.inference.struct import DiffusionSequence +from colossalai.inference.utils import get_model_size, get_model_type +from colossalai.logging import get_dist_logger +from colossalai.shardformer.policies.base_policy import Policy + +from .base_engine import BaseEngine +from .request_handler import NaiveRequestHandler + +PP_AXIS, TP_AXIS = 0, 1 + + +class DiffusionEngine(BaseEngine): + def __init__( + self, + model_or_path: DiffusionPipeline | str, + inference_config: InferenceConfig = None, + verbose: bool = False, + model_policy: Policy | type[Policy] = None, + ) -> None: + self.inference_config = inference_config + self.dtype = inference_config.dtype + self.high_precision = inference_config.high_precision + + self.verbose = verbose + self.logger = get_dist_logger(__name__) + self.model_shard_infer_config = inference_config.to_model_shard_inference_config() + + self.model_type = get_model_type(model_or_path=model_or_path) + + self.init_model(model_or_path, model_policy, self.model_shard_infer_config) + + self.request_handler = NaiveRequestHandler() + + self.counter = count() + + self._verify_args() + + def _verify_args(self) -> None: + assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe" + + def init_model( + self, + model_or_path: Union[str, nn.Module, DiffusionPipeline], + model_policy: Union[Policy, Type[Policy]] = None, + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + """ + Shard model or/and Load weight + + Args: + model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. + model_policy (Policy): the policy to replace the model. + model_inference_config: the configuration for modeling initialization when inference. + model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. 
+ """ + if isinstance(model_or_path, str): + model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype) + policy_map_key = model.__class__.__name__ + model = DiffusionPipe(model) + elif isinstance(model_or_path, DiffusionPipeline): + policy_map_key = model_or_path.__class__.__name__ + model = DiffusionPipe(model_or_path) + else: + self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!") + + torch.cuda.empty_cache() + init_gpu_memory = torch.cuda.mem_get_info()[0] + + self.device = get_accelerator().get_current_device() + if self.verbose: + self.logger.info(f"the device is {self.device}") + + if self.verbose: + self.logger.info( + f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" + ) + + if model_policy is None: + model_policy = model_policy_map.get(policy_map_key) + + if not isinstance(model_policy, Policy): + try: + model_policy = model_policy() + except Exception as e: + raise ValueError(f"Unable to instantiate model policy: {e}") + + assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" + pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) + tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + + self.model = self._shardformer( + model, + model_policy, + model_shard_infer_config, + None, + tp_group=tp_group, + ) + + self.model = model.to(self.device) + + if self.verbose: + self.logger.info( + f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" + ) + + free_gpu_memory, _ = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + if self.verbose: + self.logger.info( + f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" + ) + + def generate( + self, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, + generation_config: DiffusionGenerationConfig = None, + **kwargs, + ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]: + """ """ + gen_config_dict = generation_config.to_dict() if generation_config is not None else {} + prompts = [prompts] if isinstance(prompts, str) else prompts + request_ids = [request_ids] if isinstance(request_ids, int) else request_ids + + with torch.inference_mode(): + if prompts is not None: + self.add_request( + request_ids=request_ids, + prompts=prompts, + **gen_config_dict, + **kwargs, + ) + + output_reqs_list = [] + + # intuition: If user provide a generation config, we should replace the existing one. 
+ if generation_config is not None: + self.generation_config = generation_config + self.generation_config_dict = gen_config_dict + + while self.request_handler.check_unfinished_reqs(): + output_reqs_list += self.step() + + return output_reqs_list + + def add_request( + self, + prompts: Union[List[str], str], + request_ids: Union[List[int], int] = None, + **kwargs, + ): + if request_ids is not None and not isinstance(request_ids, list): + request_ids = [request_ids] + + if not isinstance(prompts, list): + prompts = [prompts] + + generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs) + prompts_num = len(prompts) + for i in range(prompts_num): + if request_ids: + assert isinstance( + request_ids[0], int + ), f"The request_id type must be int, but got {type(request_ids[0])}" + assert len(request_ids) == prompts_num + request_id = request_ids[i] + else: + request_id = next(self.counter) + + seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config) + + self.request_handler.add_sequence(seq) + + def step(self) -> List[PIL.Image.Image]: + """ + In each step, do the follows: + 1. Run RequestHandler.schedule() and get the batch used for inference. + 2. run forward to get List[Image] + Returns: + List[PIL.Image.Image]: Image Generated by one step. + """ + + input = self.request_handler.schedule() + ret = self.model(prompt=input.prompt, **input.generation_config.to_dict()) + return ret diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 8f8aef65e..5c9bdc321 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -1,57 +1,24 @@ -import time -from itertools import count -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import List, Tuple, Type, Union import numpy as np -import torch +import PIL.Image import torch.nn as nn -from torch import distributed as dist -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - GenerationConfig, - PreTrainedTokenizer, - PreTrainedTokenizerFast, -) -from transformers.models.llama.modeling_llama import LlamaForCausalLM +from diffusers import DiffusionPipeline +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from colossalai.accelerator import get_accelerator -from colossalai.cluster import ProcessGroupMesh -from colossalai.inference.batch_bucket import BatchBucket -from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig -from colossalai.inference.graph_runner import CUDAGraphRunner -from colossalai.inference.modeling.policy import model_policy_map -from colossalai.inference.sampler import search_tokens -from colossalai.inference.spec import Drafter, GlideInput -from colossalai.inference.struct import Sequence -from colossalai.inference.utils import get_model_size, has_index_file -from colossalai.interface import ModelWrapper -from colossalai.lazy import LazyInitContext -from colossalai.logging import get_dist_logger -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.inference.config import InferenceConfig +from colossalai.inference.utils import ModelType, get_model_type from colossalai.shardformer.policies.base_policy import Policy -from .request_handler import RequestHandler - __all__ = ["InferenceEngine"] -PP_AXIS, TP_AXIS = 0, 1 - -_supported_models = { - "LlamaForCausalLM": LlamaForCausalLM, - "BaichuanForCausalLM": AutoModelForCausalLM, 
-} - -_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] - class InferenceEngine: """ InferenceEngine which manages the inference process.. Args: - model_or_path (nn.Module or str): Path or nn.Module of this model. + model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model. tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. verbose (bool): Determine whether or not to log the generation process. @@ -60,567 +27,68 @@ class InferenceEngine: def __init__( self, - model_or_path: Union[nn.Module, str], - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - inference_config: InferenceConfig, + model_or_path: Union[nn.Module, str, DiffusionPipeline], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, + inference_config: InferenceConfig = None, verbose: bool = False, model_policy: Union[Policy, Type[Policy]] = None, ) -> None: - self.inference_config = inference_config - self.dtype = inference_config.dtype - self.high_precision = inference_config.high_precision + self.__dict__["_initialized"] = False # use __dict__ directly to avoid calling __setattr__ + self.model_type = get_model_type(model_or_path=model_or_path) + self.engine = None + if self.model_type == ModelType.LLM: + from .llm_engine import LLMEngine - self.verbose = verbose - self.logger = get_dist_logger(__name__) - self.model_shard_infer_config = inference_config.to_model_shard_inference_config() + self.engine = LLMEngine( + model_or_path=model_or_path, + tokenizer=tokenizer, + inference_config=inference_config, + verbose=verbose, + model_policy=model_policy, + ) + elif self.model_type == ModelType.DIFFUSION_MODEL: + from .diffusion_engine import DiffusionEngine - self.init_model(model_or_path, model_policy, self.model_shard_infer_config) - - self.generation_config = inference_config.to_generation_config(self.model_config) - self.generation_config_dict = self.generation_config.to_dict() - - self.tokenizer = tokenizer - self.tokenizer.pad_token = self.tokenizer.eos_token - - self.request_handler = RequestHandler(self.inference_config, self.model_config) - self.k_cache, self.v_cache = self.request_handler.get_kvcache() - # DISCUSS maybe move this into batch info? - - self.counter = count() - - self.use_cuda_graph = self.inference_config.use_cuda_graph - if self.use_cuda_graph: - self.graph_runners: Dict[int, CUDAGraphRunner] = {} - self.graph_memory_pool = None # Set during graph capture. 
- if verbose: - self.logger.info("Colossal AI CUDA Graph Capture on") - - self.capture_model(self.k_cache, self.v_cache) - - # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` - self.use_spec_dec = self.inference_config.use_spec_dec - - self.drafter_model = None - self.drafter = None - self.use_glide = False - self.n_spec_tokens = self.inference_config.max_n_spec_tokens + self.engine = DiffusionEngine( + model_or_path=model_or_path, + inference_config=inference_config, + verbose=verbose, + model_policy=model_policy, + ) + elif self.model_type == ModelType.UNKNOWN: + self.logger.error(f"Model Type either Difffusion or LLM!") + self._initialized = True self._verify_args() - def init_model( - self, - model_or_path: Union[nn.Module, str], - model_policy: Union[Policy, Type[Policy]] = None, - model_shard_infer_config: ModelShardInferenceConfig = None, - ): - """ - Shard model or/and Load weight - - Args: - model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. - model_policy (Policy): the policy to replace the model. - model_inference_config: the configuration for modeling initialization when inference. - model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. - """ - pretrained_path = None - if isinstance(model_or_path, str): - import colossalai.interface.pretrained as pretrained_utils - - try: - hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype) - arch = getattr(hf_config, "architectures")[0] - if arch in _supported_models.keys(): - if arch is "BaichuanForCausalLM": - self.logger.warning( - "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers" - ) - ctx = LazyInitContext(default_device="cuda") - with ctx: - model = _supported_models[arch].from_pretrained( - model_or_path, trust_remote_code=True, torch_dtype=self.dtype - ) - pretrained_path = pretrained_utils.get_pretrained_path(model) - else: - # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate - raise ValueError(f"Model {arch} is not supported.") - - except Exception as e: - self.logger.error( - f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" - ) - else: - model = model_or_path - - self.model_config = model.config - - torch.cuda.empty_cache() - init_gpu_memory = torch.cuda.mem_get_info()[0] - - self.device = get_accelerator().get_current_device() - if self.verbose: - self.logger.info(f"the device is {self.device}") - - model = model.to(self.dtype).eval() - - if self.verbose: - self.logger.info( - f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" - ) - - if model_policy is None: - prefix = "nopadding" if not self.inference_config.pad_input else "padding" - model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}" - model_policy = model_policy_map.get(model_policy_key) - - if not isinstance(model_policy, Policy): - try: - model_policy = model_policy() - except Exception as e: - raise ValueError(f"Unable to instantiate model policy: {e}") - - assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" - pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) - tp_group = pg_mesh.get_group_along_axis(TP_AXIS) - - self.model = 
self._shardformer( - model, - model_policy, - model_shard_infer_config, - None, - tp_group=tp_group, - ) - - self.model = ModelWrapper(model).to(self.device) - - if self.verbose: - self.logger.info( - f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" - ) - - if pretrained_path: - from colossalai.inference.core.plugin import InferCheckpoint_io - - cpt_io = InferCheckpoint_io() - if_has_index_file, model_index_file = has_index_file(pretrained_path) - assert if_has_index_file, "the model path is invalid" - cpt_io.load_model(self.model, model_index_file) - - free_gpu_memory, _ = torch.cuda.mem_get_info() - peak_memory = init_gpu_memory - free_gpu_memory - if self.verbose: - self.logger.info( - f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" - ) - - @torch.inference_mode() - def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]): - assert self.use_cuda_graph, "please turn on the cuda graph" - - if self.verbose: - self.logger.info("Colossal AI CUDA Graph Capture begin") - - t_capture_begin = time.perf_counter() - - block_size = self.inference_config.block_size - head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads - - # Prepare dummy inputs. These will be reused for all batch sizes. - max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - max_context_len_to_capture = self.inference_config.max_context_len_to_capture - max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size - input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda() - # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) - self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32) - self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE)) - self.graph_block_tables[0, :] = np.arange( - 0, max_num_blocks - ) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len - block_tables = torch.from_numpy(self.graph_block_tables).cuda() - output_tensor = torch.zeros( - (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device - ) - fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor - - max_num_seqs = self.inference_config.max_batch_size - batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs] - sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda() - # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len - sequence_lengths[0] = torch.tensor( - self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32 - ).cuda() - - # NOTE: Capturing the largest batch size first may help reduce the - # memory usage of CUDA graph. 
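These capture lines are removed here only because the LLM path apparently migrates to the new `llm_engine.py` (see the diffstat at the top of this patch). For reference, the PyTorch mechanism behind one-graph-per-batch-size capture with a shared memory pool looks roughly like the sketch below; the linear layer and shapes are stand-ins for the sharded model:

```python
import torch

# Stand-in model and static buffers; the engine captures the sharded LLM forward instead.
model = torch.nn.Linear(64, 64).cuda().eval()
static_inputs = {bs: torch.zeros(bs, 64, device="cuda") for bs in (8, 4, 2, 1)}
graphs, static_outputs, memory_pool = {}, {}, None

with torch.inference_mode():
    # Warm up on a side stream so lazy allocations happen before capture (documented pattern).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        model(static_inputs[8])
    torch.cuda.current_stream().wait_stream(s)

    for bs in (8, 4, 2, 1):  # largest batch first, mirroring the engine
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, pool=memory_pool):
            static_outputs[bs] = model(static_inputs[bs])
        memory_pool = g.pool()  # reuse the same allocator pool for every graph
        graphs[bs] = g

    # Replay: copy fresh data into the static input buffer, then replay the recorded graph.
    static_inputs[4].copy_(torch.randn(4, 64, device="cuda"))
    graphs[4].replay()
    print(static_outputs[4].shape)
```

Replaying a captured graph skips per-kernel launch overhead, which is why the engine keys its `graph_runners` dict by batch size.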
- for batch_size in reversed(batch_size_capture_list): - if self.verbose: - self.logger.info(f"batch size {batch_size} graph capturing") - - input_meta_data = InputMetaData( - block_tables=block_tables[:batch_size], - sequence_lengths=sequence_lengths[:batch_size], - fd_inter_tensor=fd_inter_tensor, - batch_size=batch_size, - is_prompts=False, - use_cuda_graph=True, - high_precision=False, - kv_seq_len=sequence_lengths[:batch_size].max().item(), - head_dim=head_dim, - dtype=self.dtype, - ) - - graph_runner = CUDAGraphRunner(self.model) - graph_runner.capture( - input_tokens_ids[:batch_size], - output_tensor[:batch_size], - input_meta_data, - k_caches=k_cache, - v_caches=v_cache, - memory_pool=self.graph_memory_pool, - ) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[batch_size] = graph_runner - - t_capture_end = time.perf_counter() - - if self.verbose: - self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s") - def _verify_args(self) -> None: """Verify the input args""" - if not isinstance(self.inference_config, InferenceConfig): - raise TypeError("Invalid type of inference config provided.") - if not isinstance(self.model, nn.Module): - raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}") - if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): - raise TypeError( - f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" - ) - if isinstance(self.model, ModelWrapper): - model = self.model.module - assert ( - model.__class__.__name__ in _supported_models.keys() - ), f"Model {self.model.__class__.__name__} is not supported." - - def _shardformer( - self, - model: nn.Module, - model_policy: Policy, - model_shard_infer_config: ModelShardInferenceConfig = None, - stage_manager: PipelineStageManager = None, - tp_group: ProcessGroupMesh = None, - ) -> nn.Module: - """ - Initialize ShardConfig and replace the model with shardformer. - - Args: - model (nn.Module): Path or nn.Module of this model. - model_policy (Policy): The policy to shardformer model which is determined by the model type. - stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None. - tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. - - Returns: - nn.Module: The model optimized by Shardformer. - """ - - shardconfig = ShardConfig( - tensor_parallel_process_group=tp_group, - pipeline_stage_manager=stage_manager, - enable_tensor_parallelism=(self.inference_config.tp_size > 1), - enable_fused_normalization=False, - enable_all_optimization=False, - enable_flash_attention=False, - enable_jit_fused=False, - enable_sequence_parallelism=False, - extra_kwargs={"model_shard_infer_config": model_shard_infer_config}, - ) - shardformer = ShardFormer(shard_config=shardconfig) - shard_model, _ = shardformer.optimize(model, model_policy) - return shard_model - - def enable_spec_dec( - self, - drafter_model: nn.Module = None, - n_spec_tokens: int = None, - use_glide_drafter: bool = False, - ) -> None: - """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations. - - Args: - drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. - If provided, the previous drafter and drafter model, if exist, will be overwritten. - n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying. 
- If not provided, `max_n_spec_tokens` in InferenceConfig will be used. - use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False. - If True, the drafter model will be replaced by a glide model. - - ```python - ... - engine = InferenceEngine(model, tokenizer, inference_config) - - engine.enable_spec_dec(drafter_model, n_spec_tokens=5) - engine.generate(...) # Speculative Decoding - - engine.disable_spec_dec() - engine.generate(...) # Normal generation - - engine.enable_spec_dec() - engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens - engine.clear_spec_dec() - ``` - """ - - if drafter_model is None and self.drafter is None: - raise ValueError("Drafter not initialized. Please provide a Drafter Model") - if n_spec_tokens is not None: - assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens - self.n_spec_tokens = n_spec_tokens - if drafter_model is not None: - assert isinstance(drafter_model, nn.Module) - # overwrite the drafter, if exists - self.clear_spec_dec() - self.drafter_model = drafter_model - self.drafter = Drafter( - self.drafter_model, - self.tokenizer, - device=self.device, - dtype=self.dtype, - ) - - # check if the provided drafter model is compatible with GLIDE structure - # when `use_glide_drafter` is set to True - if ( - use_glide_drafter - and hasattr(drafter_model, "model") - and hasattr(drafter_model.model, "layers") - and hasattr(drafter_model.model.layers[0], "cross_attn") - ): - self.use_glide = use_glide_drafter - elif use_glide_drafter: - self.logger.warning( - f"`use_glide_drafter` is provided as {use_glide_drafter}, " - f"but the provided drafter model is not compatible with GLIDE structure." - f"Falling back to use the default drafter model (non-GLIDE)." - ) - self.request_handler.set_spec_dec_mode(self.n_spec_tokens) - # using speculative decoding for subsequent generations - self.use_spec_dec = True - - def disable_spec_dec(self) -> None: - """Disable using speculative decoding for subsequent generations.""" - self.request_handler.unset_spec_dec_mode() - # set back to the maximum number of tokens to speculate - self.n_spec_tokens = self.inference_config.max_n_spec_tokens - self.use_glide = False - self.use_spec_dec = False - - def clear_spec_dec(self) -> None: - """Clear relatable structures of speculative decoding, if exist.""" - if self.use_spec_dec: - self.disable_spec_dec() - if self.drafter_model or self.drafter: - self.drafter_model = None - self.drafter = None - torch.cuda.empty_cache() - self.use_glide = False - self.use_spec_dec = False - - def steps_spec_dec(self) -> List[Sequence]: - """ - Run Speculative Decoding steps. This is like retrieving a single batch and launch inference - with many steps of speculating by a drafter model as well as verifying by a main model. - - Returns: - List[Sequence]: finished sequences generated by one step. - """ - batch = self.request_handler.schedule() # prefill batch - assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." - - input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) - - if input_meta_data.use_cuda_graph: - model_executable = self.graph_runners[input_meta_data.batch_size] - else: - model_executable = self.model - - # 1. 
Prefill small model (Drafter) - fill past kv cache for drafter model - # NOTE For glide drafter models, we won't actually apply glide during prefill stage - drafter_out = self.drafter.speculate(input_token_ids, 1, None) - next_token_ids_spec = drafter_out.next_tokens - drafter_past_key_values = drafter_out.past_key_values - - # 2. Prefill main model (Verifier) - fill past kv cache for main model - logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) - # append new inputs to the batch, temporarily - batch.append_batch_tokens(next_tokens) - self.request_handler.allocate_batch_spec_dec(batch, 1) - already_allocated_kv_len = batch.seq_lengths[0].item() - input_token_ids = batch.get_1D_inputs_spec_dec(1) - - finished_sequences = self.request_handler.update() - - while True: - # HACK Retrieve the running batch - # Using RequestHandler.schedule here will re-allocate same kv cache for the batch - batch = self.request_handler.running_bb # running batch - assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." - - # 3. Decoding - Drafter model speculates `n` tokens - glide_input = None - if self.use_glide: - glide_input = GlideInput( - batch.get_block_table_tensor(), - self.k_cache[-1], # use kv cahces of the last layer - self.v_cache[-1], - batch.get_sequence_lengths(), - n_spec_tokens=self.n_spec_tokens, - ) - - drafter_out = self.drafter.speculate( - input_token_ids, - self.n_spec_tokens, - drafter_past_key_values, - glide_input=glide_input, - ) - next_token_ids_spec = drafter_out.next_tokens - drafter_past_key_values = drafter_out.past_key_values - drafter_spec_length = drafter_out.speculated_length - - for next_token_id_spec in next_token_ids_spec: - self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0)) - cur_length = batch.seq_lengths[0].item() - if already_allocated_kv_len < cur_length: - self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len) - already_allocated_kv_len = cur_length - - # 4. Decoding - Main model verifies `n` tokens in parallel - if drafter_spec_length < batch.num_tokens_to_verify: - batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length) - input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) - logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - - next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) - - # 5. 
Compare and process the results - diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) - n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item() - - # revoke appended tokens for each Sequence in the current batch - batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens - - # append the last correct token generated by the main model - self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0)) - - # trim past key values of the drafter model - drafter_past_key_values = Drafter.trim_kv_cache( - drafter_past_key_values, drafter_spec_length - n_matches - 1 - ) - - # prepare inputs for the next round of speculation - n = 1 if n_matches < drafter_spec_length else 2 - input_token_ids = batch.get_1D_inputs_spec_dec(n) - - self.request_handler.update_batch_finished(batch, generation_config=self.generation_config) - finished_sequences = self.request_handler.update() - if len(finished_sequences) > 0: - break - - # Reset back the number of speculated tokens of the batch, - # this is used to handle the last round of speculation, in which case the number of speculated tokens - # by the drafter is less than the number of speculated tokens set to the engine. - batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens) - - return finished_sequences + assert self.engine is not None, "Please init Engine first" + assert self._initialized, "Engine must be initialized" def generate( self, request_ids: Union[List[int], int] = None, prompts: Union[List[str], str] = None, - prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, - return_token_ids: bool = False, - generation_config: Optional[GenerationConfig] = None, - ) -> Union[List[str], Tuple[List[str], List[List[int]]]]: + *args, + **kwargs, + ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]: """ Executing the inference step. Args: request_ids (List[int], optional): The request ID. Defaults to None. prompts (Union[List[str], optional): Input prompts. Defaults to None. - prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None. - return_token_ids (bool, optional): Whether to return output token ids. Defaults to False. - generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None. - - Returns: - Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation. """ - gen_config_dict = generation_config.to_dict() if generation_config is not None else {} - prompts = [prompts] if isinstance(prompts, str) else prompts - request_ids = [request_ids] if isinstance(request_ids, int) else request_ids - - with torch.inference_mode(): - if prompts is not None or prompts_token_ids is not None: - self.add_request( - request_ids=request_ids, - prompts=prompts, - prompts_token_ids=prompts_token_ids, - **gen_config_dict, - ) - - output_seqs_list = [] - total_tokens_list = [] - - # intuition: If user provide a generation config, we should replace the existing one. - if generation_config is not None: - self.generation_config = generation_config - self.generation_config_dict = gen_config_dict - - if self.use_spec_dec: - assert self.drafter is not None, "Drafter Model is not initialized." 
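Editor's note: the wrapper's `generate` above now only checks that a backend engine exists and forwards everything to `self.engine.generate`, with the return annotation widened to cover decoded text, lists of `PIL.Image`, and arrays. The snippet below is a minimal usage sketch of that facade; the checkpoint name and generation settings are illustrative assumptions, not values taken from this patch, and an LLM backend is assumed to be selected.

```python
# Hypothetical usage of the refactored facade; checkpoint and config values are assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

model_path = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

engine = InferenceEngine(model, tokenizer, InferenceConfig())

# The facade forwards to the underlying LLMEngine (text in, decoded strings out);
# a diffusion backend would instead return a list of PIL images per prompt.
outputs = engine.generate(
    prompts=["What is tensor parallelism?"],
    generation_config=GenerationConfig(max_new_tokens=64),
)
print(outputs[0])
```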
- while self.request_handler.check_unfinished_seqs(): - output_seqs_list += self.steps_spec_dec() - else: - while self.request_handler.check_unfinished_seqs(): - output_seqs_list += self.step() - - output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) - - for seq in output_seqs_list: - total_tokens_list.append(seq.input_token_id + seq.output_token_id) - - output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True) - - if return_token_ids: - output_tokens_list = [seq.output_token_id for seq in output_seqs_list] - return output_str, output_tokens_list - else: - return output_str - - @property - def has_prompt_template(self) -> bool: - """ """ - return self.inference_config.prompt_template is not None - - def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]: - """ - This method will format the input prompt according to the prompt template given to the InferenceConfig. - """ - assert ( - self.has_prompt_template - ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig." - - if isinstance(prompts, (list, tuple)): - return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts] - elif isinstance(prompts, str): - return self.inference_config.prompt_template.format(input_text=prompts) - else: - raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.") + assert self.engine is not None, "Please init Engine first" + return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs) def add_request( self, request_ids: Union[List[int], int] = None, prompts: Union[List[str], str] = None, - prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + *args, **kwargs, ) -> None: """ @@ -630,168 +98,36 @@ class InferenceEngine: request_ids (List[int], optional): The request ID. Defaults to None. prompts (Union[List[str], optional): Input prompts. Defaults to None. prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. + kwargs: for LLM, it could be max_length, max_new_tokens, etc + for diffusion, it could be prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, clip_skip, which aligns with diffusers """ + assert self.engine is not None, "Please init Engine first" + self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs) - # apply the prompt template to the input prompts + def step(self): + assert self.engine is not None, "Please init Engine first" + return self.engine.step() - if self.has_prompt_template and prompts is not None: - prompts = self.format_prompt(prompts) - - block_size = self.inference_config.block_size - - if request_ids is not None and not isinstance(request_ids, list): - request_ids = [request_ids] - - if prompts is not None and not isinstance(prompts, list): - prompts = [prompts] - - if prompts_token_ids is None: - assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided." 
- prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ - "input_ids" - ] - - # list of torch Tensor - if isinstance(prompts_token_ids, list): - if isinstance(prompts_token_ids[0], torch.Tensor): - prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids] - elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): - prompts_token_ids = prompts_token_ids.tolist() - else: - raise TypeError( - f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}." - ) - - assert ( - len(prompts_token_ids[0]) <= self.inference_config.max_input_len - ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}." - - prompts_num = len(prompts_token_ids) - - for i in range(prompts_num): - if request_ids: - assert isinstance( - request_ids[0], int - ), f"The request_id type must be int, but got {type(request_ids[0])}" - assert len(request_ids) == prompts_num - request_id = request_ids[i] + def __getattr__(self, name): + """ + The Design logic of getattr, setattr: + 1. Since InferenceEngine is a wrapper for DiffusionEngine/LLMEngine, we hope to invoke all the member of DiffusionEngine/LLMEngine like we just call the member of InferenceEngine. + 2. When we call the __init__ of InferenceEngine, we don't want to setattr using self.__dict__["xxx"] = xxx, we want to use origin ways like self.xxx = xxx + So we set the attribute `_initialized`. And after initialized, if we couldn't get the member from InferenceEngine, we will try to get the member from self.engine(DiffusionEngine/LLMEngine) + """ + if self.__dict__.get("_initialized", False): + if name in self.__dict__: + return self.__dict__[name] else: - request_id = next(self.counter) - if prompts == None: - prompt = None + return getattr(self.engine, name) + else: + return self.__dict__[name] + + def __setattr__(self, name, value): + if self.__dict__.get("_initialized", False): + if name in self.__dict__: + self.__dict__[name] = value else: - prompt = prompts[i] - - max_length = kwargs.get("max_length", None) - max_new_tokens = kwargs.get("max_new_tokens", None) - if max_length is None and max_new_tokens is None: - max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len - elif max_length is not None: - max_new_tokens = max_length - len(prompts_token_ids[i]) - - if not self.inference_config.enable_streamingllm: - assert ( - self.inference_config.max_output_len >= max_new_tokens - ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." 
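Editor's note: the `__getattr__`/`__setattr__` pair added above makes the wrapper transparent: once `_initialized` is set, any attribute the wrapper does not own is read from or written to the backend engine. A stripped-down, self-contained sketch of the same pattern (class names here are illustrative, not part of the patch):

```python
class _Backend:
    """Stand-in for LLMEngine/DiffusionEngine."""

    def __init__(self) -> None:
        self.temperature = 0.7

    def step(self) -> str:
        return "one decoding step"


class Facade:
    def __init__(self) -> None:
        # Plain attribute writes still work until the flag is flipped on.
        self.engine = _Backend()
        self._initialized = True

    def __getattr__(self, name):
        # Only called when normal lookup fails: fall through to the backend.
        return getattr(self.__dict__["engine"], name)

    def __setattr__(self, name, value):
        if self.__dict__.get("_initialized", False) and name not in self.__dict__:
            setattr(self.engine, name, value)  # forward unknown attributes
        else:
            self.__dict__[name] = value


facade = Facade()
print(facade.step())       # resolved on the backend
facade.temperature = 0.1   # forwarded to the backend
print(facade.temperature)  # 0.1, read back through __getattr__
```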
- - sequence = Sequence( - request_id, - prompt, - prompts_token_ids[i], - block_size, - None, - self.tokenizer.eos_token_id, - self.tokenizer.pad_token_id, - max_output_len=max_new_tokens, - ignore_eos=self.inference_config.ignore_eos, - ) - self.request_handler.add_sequence(sequence) - - def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: - input_ids = batch.get_1D_inputs() - sequence_lengths = batch.get_sequence_lengths() - - if batch.is_prompts: - n_tokens = sequence_lengths.sum().item() + setattr(self.engine, name, value) else: - n_tokens = batch.current_batch_size - if batch.use_spec_dec: - n_tokens = batch.num_tokens_to_verify + 1 - assert n_tokens == input_ids.size(0) - n_tokens = n_tokens * batch.current_batch_size - output_tensor = torch.zeros( - (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device - ) - - batch_token_ids = None - if ( - self.generation_config.repetition_penalty != 1.0 - or self.generation_config.no_repeat_ngram_size > 0 - or self.generation_config.forced_eos_token_id is not None - ): - batch_token_ids = batch.batch_token_ids - - # only when we have the graph for specific decoding batch size can we use the cuda graph for inference - use_cuda_graph = False - if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): - use_cuda_graph = True - - input_meta_data = InputMetaData( - block_tables=batch.get_block_table_tensor(), - sequence_lengths=sequence_lengths, - fd_inter_tensor=batch.fd_inter_tensor, - batch_size=batch.current_batch_size, - is_prompts=batch.is_prompts, - use_cuda_kernel=self.inference_config.use_cuda_kernel, - use_cuda_graph=use_cuda_graph, - high_precision=self.high_precision, - kv_seq_len=sequence_lengths.max().item(), - head_dim=batch.head_dim, - dtype=batch.dtype, - use_spec_dec=batch.use_spec_dec, - num_tokens_to_verify=batch.num_tokens_to_verify, - batch_token_ids=batch_token_ids, - ) - - return input_ids, output_tensor, input_meta_data - - def step(self) -> List[str]: - """ - In each step, do the follows: - 1. Run RequestHandler.schedule() and get the batch used for inference. - 2. Get the input, inputinfo and output placeholder from the batchbucket - 3. Run model to generate the next token - 4. Update waiting list and running list in RequestHandler and get finished sequences. - 5. Decode and return finished sequences. - - Returns: - List[str]: Decoded finished sequences generated by one step. - """ - - batch = self.request_handler.schedule() - - input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) - - if input_meta_data.use_cuda_graph: - model_executable = self.graph_runners[input_meta_data.batch_size] - else: - model_executable = self.model - - # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. 
- logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - if self.inference_config.pad_input: - logits = logits[:, -1, :] - - if self.inference_config.enable_streamingllm: - updated_block_ids = batch.streamingllm_update_batch( - self.inference_config.start_token_size, self.inference_config.generated_token_size - ) - self.request_handler.streamingllm_free_block_tables(updated_block_ids) - - next_tokens = search_tokens( - self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids - ) - self.request_handler.append_next_tokens(next_tokens) - finished_sequences = self.request_handler.update() - - return finished_sequences + self.__dict__[name] = value diff --git a/colossalai/inference/core/llm_engine.py b/colossalai/inference/core/llm_engine.py new file mode 100644 index 000000000..b973d371d --- /dev/null +++ b/colossalai/inference/core/llm_engine.py @@ -0,0 +1,758 @@ +import time +from itertools import count +from typing import Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn as nn +from torch import distributed as dist +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + GenerationConfig, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +from colossalai.accelerator import get_accelerator +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig +from colossalai.inference.graph_runner import CUDAGraphRunner +from colossalai.inference.modeling.policy import model_policy_map +from colossalai.inference.sampler import search_tokens +from colossalai.inference.spec import Drafter, GlideInput +from colossalai.inference.struct import Sequence +from colossalai.inference.utils import get_model_size, has_index_file +from colossalai.interface import ModelWrapper +from colossalai.lazy import LazyInitContext +from colossalai.logging import get_dist_logger +from colossalai.shardformer.policies.base_policy import Policy + +from .base_engine import BaseEngine +from .request_handler import RequestHandler + +PP_AXIS, TP_AXIS = 0, 1 + +_supported_models = { + "LlamaForCausalLM": LlamaForCausalLM, + "BaichuanForCausalLM": AutoModelForCausalLM, +} + +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] + + +class LLMEngine(BaseEngine): + """ + InferenceEngine which manages the inference process.. + + Args: + model_or_path (nn.Module or str): Path or nn.Module of this model. + tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. + inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. + verbose (bool): Determine whether or not to log the generation process. + model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided. 
+ """ + + def __init__( + self, + model_or_path: nn.Module | str, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None, + inference_config: InferenceConfig = None, + verbose: bool = False, + model_policy: Policy | type[Policy] = None, + ) -> None: + self.inference_config = inference_config + self.dtype = inference_config.dtype + self.high_precision = inference_config.high_precision + + self.verbose = verbose + self.logger = get_dist_logger(__name__) + self.model_shard_infer_config = inference_config.to_model_shard_inference_config() + + self.init_model(model_or_path, model_policy, self.model_shard_infer_config) + + self.generation_config = inference_config.to_generation_config(self.model_config) + self.generation_config_dict = self.generation_config.to_dict() + + self.tokenizer = tokenizer + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.request_handler = RequestHandler(self.inference_config, self.model_config) + self.k_cache, self.v_cache = self.request_handler.get_kvcache() + # DISCUSS maybe move this into batch info? + + self.counter = count() + + self.use_cuda_graph = self.inference_config.use_cuda_graph + if self.use_cuda_graph: + self.graph_runners: Dict[int, CUDAGraphRunner] = {} + self.graph_memory_pool = None # Set during graph capture. + if verbose: + self.logger.info("Colossal AI CUDA Graph Capture on") + + self.capture_model(self.k_cache, self.v_cache) + + # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` + self.use_spec_dec = self.inference_config.use_spec_dec + + self.drafter_model = None + self.drafter = None + self.use_glide = False + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + + self._verify_args() + + def init_model( + self, + model_or_path: Union[nn.Module, str], + model_policy: Union[Policy, Type[Policy]] = None, + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + """ + Shard model or/and Load weight + + Args: + model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. + model_policy (Policy): the policy to replace the model. + model_inference_config: the configuration for modeling initialization when inference. + model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. + """ + pretrained_path = None + if isinstance(model_or_path, str): + import colossalai.interface.pretrained as pretrained_utils + + try: + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype) + arch = getattr(hf_config, "architectures")[0] + if arch in _supported_models.keys(): + if arch == "BaichuanForCausalLM": + self.logger.warning( + "Attention ! We use lazy init by default, which could be faster for model loading. 
For baichuan model, the output maybe have a slight difference with transformers" + ) + ctx = LazyInitContext(default_device="cuda") + with ctx: + model = _supported_models[arch].from_pretrained( + model_or_path, trust_remote_code=True, torch_dtype=self.dtype + ) + pretrained_path = pretrained_utils.get_pretrained_path(model) + else: + # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate + raise ValueError(f"Model {arch} is not supported.") + + except Exception as e: + self.logger.error( + f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" + ) + else: + model = model_or_path + + self.model_config = model.config + + torch.cuda.empty_cache() + init_gpu_memory = torch.cuda.mem_get_info()[0] + + self.device = get_accelerator().get_current_device() + if self.verbose: + self.logger.info(f"the device is {self.device}") + + model = model.to(self.dtype).eval() + + if self.verbose: + self.logger.info( + f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" + ) + + if model_policy is None: + prefix = "nopadding" if not self.inference_config.pad_input else "padding" + model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}" + model_policy = model_policy_map.get(model_policy_key) + + if not isinstance(model_policy, Policy): + try: + model_policy = model_policy() + except Exception as e: + raise ValueError(f"Unable to instantiate model policy: {e}") + + assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" + pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) + tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + + self.model = self._shardformer( + model, + model_policy, + model_shard_infer_config, + None, + tp_group=tp_group, + ) + + self.model = ModelWrapper(model).to(self.device) + + if self.verbose: + self.logger.info( + f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" + ) + + if pretrained_path: + from colossalai.inference.core.plugin import InferCheckpoint_io + + cpt_io = InferCheckpoint_io() + if_has_index_file, model_index_file = has_index_file(pretrained_path) + assert if_has_index_file, "the model path is invalid" + cpt_io.load_model(self.model, model_index_file) + + free_gpu_memory, _ = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + if self.verbose: + self.logger.info( + f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" + ) + + @torch.inference_mode() + def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]): + assert self.use_cuda_graph, "please turn on the cuda graph" + + if self.verbose: + self.logger.info("Colossal AI CUDA Graph Capture begin") + + t_capture_begin = time.perf_counter() + + block_size = self.inference_config.block_size + head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads + + # Prepare dummy inputs. These will be reused for all batch sizes. 
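Editor's note: when `model_policy` is not given, `init_model` above builds a lookup key from the padding mode and the HuggingFace `model_type` and resolves it through `model_policy_map`. A small sketch of that resolution step; the checkpoint name is an assumption, and only the key construction mirrors the code above.

```python
# Sketch of the default policy lookup; the checkpoint is an illustrative assumption.
from transformers import AutoConfig

from colossalai.inference.modeling.policy import model_policy_map

hf_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
pad_input = False  # stands in for InferenceConfig.pad_input

prefix = "nopadding" if not pad_input else "padding"
model_policy_key = f"{prefix}_{getattr(hf_config, 'model_type', None)}"
policy_cls = model_policy_map.get(model_policy_key)
print(model_policy_key, policy_cls)  # nopadding_llama, NoPaddingLlamaModelInferPolicy
```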
+ max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + max_context_len_to_capture = self.inference_config.max_context_len_to_capture + max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size + input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda() + # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32) + self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE)) + self.graph_block_tables[0, :] = np.arange( + 0, max_num_blocks + ) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len + block_tables = torch.from_numpy(self.graph_block_tables).cuda() + output_tensor = torch.zeros( + (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device + ) + fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor + + max_num_seqs = self.inference_config.max_batch_size + batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs] + sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda() + # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len + sequence_lengths[0] = torch.tensor( + self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32 + ).cuda() + + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(batch_size_capture_list): + if self.verbose: + self.logger.info(f"batch size {batch_size} graph capturing") + + input_meta_data = InputMetaData( + block_tables=block_tables[:batch_size], + sequence_lengths=sequence_lengths[:batch_size], + fd_inter_tensor=fd_inter_tensor, + batch_size=batch_size, + is_prompts=False, + use_cuda_graph=True, + high_precision=False, + kv_seq_len=sequence_lengths[:batch_size].max().item(), + head_dim=head_dim, + dtype=self.dtype, + ) + + graph_runner = CUDAGraphRunner(self.model) + graph_runner.capture( + input_tokens_ids[:batch_size], + output_tensor[:batch_size], + input_meta_data, + k_caches=k_cache, + v_caches=v_cache, + memory_pool=self.graph_memory_pool, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner + + t_capture_end = time.perf_counter() + + if self.verbose: + self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s") + + def _verify_args(self) -> None: + """Verify the input args""" + if not isinstance(self.inference_config, InferenceConfig): + raise TypeError("Invalid type of inference config provided.") + if not isinstance(self.model, nn.Module): + raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}") + if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): + raise TypeError( + f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" + ) + if isinstance(self.model, ModelWrapper): + model = self.model.module + assert ( + model.__class__.__name__ in _supported_models.keys() + ), f"Model {self.model.__class__.__name__} is not supported." 
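Editor's note: the capture loop above records one `CUDAGraphRunner` per entry of `_BATCH_SIZES_TO_CAPTURE` (largest first, all sharing one memory pool), and decoding later replays the runner whose key equals the current batch size. The sketch below shows the same capture-and-replay idea with raw `torch.cuda` graphs; the toy linear layer and batch sizes are assumptions, not the engine's real model.

```python
import torch

# Toy stand-in for the captured decode step; a CUDA device is required.
assert torch.cuda.is_available()
model = torch.nn.Linear(64, 64).cuda().eval()

batch_sizes = [1, 2, 4, 8]
static_inputs, static_outputs, graphs = {}, {}, {}
memory_pool = None

with torch.inference_mode():
    # Warm up on a side stream, as recommended before CUDA graph capture.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        model(torch.zeros(max(batch_sizes), 64, device="cuda"))
    torch.cuda.current_stream().wait_stream(side_stream)

    for bs in reversed(batch_sizes):  # largest first, mirroring the loop above
        static_inputs[bs] = torch.zeros(bs, 64, device="cuda")
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=memory_pool):
            static_outputs[bs] = model(static_inputs[bs])
        memory_pool = graph.pool()  # share a single memory pool across captures
        graphs[bs] = graph

    # Replay: copy fresh data into the captured input buffer, then launch the graph.
    static_inputs[4].copy_(torch.randn(4, 64, device="cuda"))
    graphs[4].replay()
    print(static_outputs[4].shape)  # torch.Size([4, 64])
```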
+ + def enable_spec_dec( + self, + drafter_model: nn.Module = None, + n_spec_tokens: int = None, + use_glide_drafter: bool = False, + ) -> None: + """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations. + + Args: + drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. + If provided, the previous drafter and drafter model, if exist, will be overwritten. + n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying. + If not provided, `max_n_spec_tokens` in InferenceConfig will be used. + use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False. + If True, the drafter model will be replaced by a glide model. + + ```python + ... + engine = InferenceEngine(model, tokenizer, inference_config) + + engine.enable_spec_dec(drafter_model, n_spec_tokens=5) + engine.generate(...) # Speculative Decoding + + engine.disable_spec_dec() + engine.generate(...) # Normal generation + + engine.enable_spec_dec() + engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens + engine.clear_spec_dec() + ``` + """ + + if drafter_model is None and self.drafter is None: + raise ValueError("Drafter not initialized. Please provide a Drafter Model") + if n_spec_tokens is not None: + assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens + self.n_spec_tokens = n_spec_tokens + if drafter_model is not None: + assert isinstance(drafter_model, nn.Module) + # overwrite the drafter, if exists + self.clear_spec_dec() + self.drafter_model = drafter_model + self.drafter = Drafter( + self.drafter_model, + self.tokenizer, + device=self.device, + dtype=self.dtype, + ) + + # check if the provided drafter model is compatible with GLIDE structure + # when `use_glide_drafter` is set to True + if ( + use_glide_drafter + and hasattr(drafter_model, "model") + and hasattr(drafter_model.model, "layers") + and hasattr(drafter_model.model.layers[0], "cross_attn") + ): + self.use_glide = use_glide_drafter + elif use_glide_drafter: + self.logger.warning( + f"`use_glide_drafter` is provided as {use_glide_drafter}, " + f"but the provided drafter model is not compatible with GLIDE structure." + f"Falling back to use the default drafter model (non-GLIDE)." + ) + self.request_handler.set_spec_dec_mode(self.n_spec_tokens) + # using speculative decoding for subsequent generations + self.use_spec_dec = True + + def disable_spec_dec(self) -> None: + """Disable using speculative decoding for subsequent generations.""" + self.request_handler.unset_spec_dec_mode() + # set back to the maximum number of tokens to speculate + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + self.use_glide = False + self.use_spec_dec = False + + def clear_spec_dec(self) -> None: + """Clear relatable structures of speculative decoding, if exist.""" + if self.use_spec_dec: + self.disable_spec_dec() + if self.drafter_model or self.drafter: + self.drafter_model = None + self.drafter = None + torch.cuda.empty_cache() + self.use_glide = False + self.use_spec_dec = False + + def steps_spec_dec(self) -> List[Sequence]: + """ + Run Speculative Decoding steps. This is like retrieving a single batch and launch inference + with many steps of speculating by a drafter model as well as verifying by a main model. + + Returns: + List[Sequence]: finished sequences generated by one step. 
+ """ + batch = self.request_handler.schedule() # prefill batch + assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." + + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + + if input_meta_data.use_cuda_graph: + model_executable = self.graph_runners[input_meta_data.batch_size] + else: + model_executable = self.model + + # 1. Prefill small model (Drafter) - fill past kv cache for drafter model + # NOTE For glide drafter models, we won't actually apply glide during prefill stage + drafter_out = self.drafter.speculate(input_token_ids, 1, None) + next_token_ids_spec = drafter_out.next_tokens + drafter_past_key_values = drafter_out.past_key_values + + # 2. Prefill main model (Verifier) - fill past kv cache for main model + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) + # append new inputs to the batch, temporarily + batch.append_batch_tokens(next_tokens) + self.request_handler.allocate_batch_spec_dec(batch, 1) + already_allocated_kv_len = batch.seq_lengths[0].item() + input_token_ids = batch.get_1D_inputs_spec_dec(1) + + finished_sequences = self.request_handler.update() + + while True: + # HACK Retrieve the running batch + # Using RequestHandler.schedule here will re-allocate same kv cache for the batch + batch = self.request_handler.running_bb # running batch + assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." + + # 3. Decoding - Drafter model speculates `n` tokens + glide_input = None + if self.use_glide: + glide_input = GlideInput( + batch.get_block_table_tensor(), + self.k_cache[-1], # use kv cahces of the last layer + self.v_cache[-1], + batch.get_sequence_lengths(), + n_spec_tokens=self.n_spec_tokens, + ) + + drafter_out = self.drafter.speculate( + input_token_ids, + self.n_spec_tokens, + drafter_past_key_values, + glide_input=glide_input, + ) + next_token_ids_spec = drafter_out.next_tokens + drafter_past_key_values = drafter_out.past_key_values + drafter_spec_length = drafter_out.speculated_length + + for next_token_id_spec in next_token_ids_spec: + self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0)) + cur_length = batch.seq_lengths[0].item() + if already_allocated_kv_len < cur_length: + self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len) + already_allocated_kv_len = cur_length + + # 4. Decoding - Main model verifies `n` tokens in parallel + if drafter_spec_length < batch.num_tokens_to_verify: + batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length) + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + + next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) + + # 5. 
Compare and process the results + diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) + n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item() + + # revoke appended tokens for each Sequence in the current batch + batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens + + # append the last correct token generated by the main model + self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0)) + + # trim past key values of the drafter model + drafter_past_key_values = Drafter.trim_kv_cache( + drafter_past_key_values, drafter_spec_length - n_matches - 1 + ) + + # prepare inputs for the next round of speculation + n = 1 if n_matches < drafter_spec_length else 2 + input_token_ids = batch.get_1D_inputs_spec_dec(n) + + self.request_handler.update_batch_finished(batch, generation_config=self.generation_config) + finished_sequences = self.request_handler.update() + if len(finished_sequences) > 0: + break + + # Reset back the number of speculated tokens of the batch, + # this is used to handle the last round of speculation, in which case the number of speculated tokens + # by the drafter is less than the number of speculated tokens set to the engine. + batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens) + + return finished_sequences + + def generate( + self, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, + prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + return_token_ids: bool = False, + generation_config: Optional[GenerationConfig] = None, + ) -> Union[List[str], Tuple[List[str], List[List[int]]]]: + """ + Executing the inference step. + + Args: + request_ids (List[int], optional): The request ID. Defaults to None. + prompts (Union[List[str], optional): Input prompts. Defaults to None. + prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None. + return_token_ids (bool, optional): Whether to return output token ids. Defaults to False. + generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None. + + Returns: + Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation. + """ + + gen_config_dict = generation_config.to_dict() if generation_config is not None else {} + prompts = [prompts] if isinstance(prompts, str) else prompts + request_ids = [request_ids] if isinstance(request_ids, int) else request_ids + + with torch.inference_mode(): + if prompts is not None or prompts_token_ids is not None: + self.add_request( + request_ids=request_ids, + prompts=prompts, + prompts_token_ids=prompts_token_ids, + **gen_config_dict, + ) + + output_seqs_list = [] + total_tokens_list = [] + + # intuition: If user provide a generation config, we should replace the existing one. + if generation_config is not None: + self.generation_config = generation_config + self.generation_config_dict = gen_config_dict + + if self.use_spec_dec: + assert self.drafter is not None, "Drafter Model is not initialized." 
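Editor's note: step 5 above is the acceptance rule of speculative decoding: drafted tokens are kept up to the first position where the verifier disagrees, and the verifier's own token at that position is appended instead. A self-contained illustration of exactly that comparison on toy tensors (the token ids are made up):

```python
import torch

# Drafted tokens from the small model, and the verifier's tokens for the same
# positions plus one bonus token; the values are arbitrary for illustration.
drafted = torch.tensor([11, 12, 13, 14, 15])
verified = torch.tensor([11, 12, 13, 99, 42, 7])

drafter_spec_length = drafted.numel()
diff_indexes = torch.nonzero(~(verified[:-1] == drafted))
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

accepted = drafted[:n_matches]    # drafted tokens kept after verification
correction = verified[n_matches]  # verifier's token at the first mismatch
print(accepted.tolist(), correction.item())  # [11, 12, 13] 99
```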
+ while self.request_handler.check_unfinished_reqs(): + output_seqs_list += self.steps_spec_dec() + else: + while self.request_handler.check_unfinished_reqs(): + output_seqs_list += self.step() + + output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) + + for seq in output_seqs_list: + total_tokens_list.append(seq.input_token_id + seq.output_token_id) + + output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True) + + if return_token_ids: + output_tokens_list = [seq.output_token_id for seq in output_seqs_list] + return output_str, output_tokens_list + else: + return output_str + + @property + def has_prompt_template(self) -> bool: + """ """ + return self.inference_config.prompt_template is not None + + def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]: + """ + This method will format the input prompt according to the prompt template given to the InferenceConfig. + """ + assert ( + self.has_prompt_template + ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig." + + if isinstance(prompts, (list, tuple)): + return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts] + elif isinstance(prompts, str): + return self.inference_config.prompt_template.format(input_text=prompts) + else: + raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.") + + def add_request( + self, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, + prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + **kwargs, + ) -> None: + """ + Add requests. + + Args: + request_ids (List[int], optional): The request ID. Defaults to None. + prompts (Union[List[str], optional): Input prompts. Defaults to None. + prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. + """ + + # apply the prompt template to the input prompts + + if self.has_prompt_template and prompts is not None: + prompts = self.format_prompt(prompts) + + block_size = self.inference_config.block_size + + if request_ids is not None and not isinstance(request_ids, list): + request_ids = [request_ids] + + if prompts is not None and not isinstance(prompts, list): + prompts = [prompts] + + if prompts_token_ids is None: + assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided." + prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ + "input_ids" + ] + + # list of torch Tensor + if isinstance(prompts_token_ids, list): + if isinstance(prompts_token_ids[0], torch.Tensor): + prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids] + elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): + prompts_token_ids = prompts_token_ids.tolist() + else: + raise TypeError( + f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}." + ) + + assert ( + len(prompts_token_ids[0]) <= self.inference_config.max_input_len + ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}." 
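Editor's note: `format_prompt` above does nothing more than substitute each prompt into `InferenceConfig.prompt_template` through `str.format(input_text=...)`. For instance, with a chat-style template (the template string below is an illustrative assumption, not a value shipped in this patch):

```python
# Any template containing the "{input_text}" placeholder works the same way.
prompt_template = "<s>[INST] {input_text} [/INST]"

prompts = ["What is speculative decoding?", "Summarize CUDA graphs in one sentence."]
formatted = [prompt_template.format(input_text=p) for p in prompts]
print(formatted[0])  # <s>[INST] What is speculative decoding? [/INST]
```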
+ + prompts_num = len(prompts_token_ids) + + for i in range(prompts_num): + if request_ids: + assert isinstance( + request_ids[0], int + ), f"The request_id type must be int, but got {type(request_ids[0])}" + assert len(request_ids) == prompts_num + request_id = request_ids[i] + else: + request_id = next(self.counter) + if prompts == None: + prompt = None + else: + prompt = prompts[i] + + max_length = kwargs.get("max_length", None) + max_new_tokens = kwargs.get("max_new_tokens", None) + if max_length is None and max_new_tokens is None: + max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len + elif max_length is not None: + max_new_tokens = max_length - len(prompts_token_ids[i]) + + if not self.inference_config.enable_streamingllm: + assert ( + self.inference_config.max_output_len >= max_new_tokens + ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." + + sequence = Sequence( + request_id, + prompt, + prompts_token_ids[i], + block_size, + None, + self.tokenizer.eos_token_id, + self.tokenizer.pad_token_id, + max_output_len=max_new_tokens, + ignore_eos=self.inference_config.ignore_eos, + ) + self.request_handler.add_sequence(sequence) + + def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: + input_ids = batch.get_1D_inputs() + sequence_lengths = batch.get_sequence_lengths() + + if batch.is_prompts: + n_tokens = sequence_lengths.sum().item() + else: + n_tokens = batch.current_batch_size + if batch.use_spec_dec: + n_tokens = batch.num_tokens_to_verify + 1 + assert n_tokens == input_ids.size(0) + n_tokens = n_tokens * batch.current_batch_size + output_tensor = torch.zeros( + (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device + ) + + batch_token_ids = None + if ( + self.generation_config.repetition_penalty != 1.0 + or self.generation_config.no_repeat_ngram_size > 0 + or self.generation_config.forced_eos_token_id is not None + ): + batch_token_ids = batch.batch_token_ids + + # only when we have the graph for specific decoding batch size can we use the cuda graph for inference + use_cuda_graph = False + if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): + use_cuda_graph = True + + input_meta_data = InputMetaData( + block_tables=batch.get_block_table_tensor(), + sequence_lengths=sequence_lengths, + fd_inter_tensor=batch.fd_inter_tensor, + batch_size=batch.current_batch_size, + is_prompts=batch.is_prompts, + use_cuda_kernel=self.inference_config.use_cuda_kernel, + use_cuda_graph=use_cuda_graph, + high_precision=self.high_precision, + kv_seq_len=sequence_lengths.max().item(), + head_dim=batch.head_dim, + dtype=batch.dtype, + use_spec_dec=batch.use_spec_dec, + num_tokens_to_verify=batch.num_tokens_to_verify, + batch_token_ids=batch_token_ids, + ) + + return input_ids, output_tensor, input_meta_data + + def step(self) -> List[str]: + """ + In each step, do the follows: + 1. Run RequestHandler.schedule() and get the batch used for inference. + 2. Get the input, inputinfo and output placeholder from the batchbucket + 3. Run model to generate the next token + 4. Update waiting list and running list in RequestHandler and get finished sequences. + 5. Decode and return finished sequences. + + Returns: + List[str]: Decoded finished sequences generated by one step. 
+        """
+
+        batch = self.request_handler.schedule()
+
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model
+
+        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
+        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+        if self.inference_config.pad_input:
+            logits = logits[:, -1, :]
+
+        if self.inference_config.enable_streamingllm:
+            updated_block_ids = batch.streamingllm_update_batch(
+                self.inference_config.start_token_size, self.inference_config.generated_token_size
+            )
+            self.request_handler.streamingllm_free_block_tables(updated_block_ids)
+
+        next_tokens = search_tokens(
+            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
+        )
+        self.request_handler.append_next_tokens(next_tokens)
+        finished_sequences = self.request_handler.update()
+
+        return finished_sequences
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 512eaea71..393347c31 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
-from colossalai.inference.struct import RequestStatus, Sequence
+from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -98,7 +98,46 @@ class RunningList:
         self._decoding[seq_id] = self._prefill.pop(seq_id)
 
 
-class RequestHandler:
+class NaiveRequestHandler:
+    def __init__(self) -> None:
+        self.running_list: List[DiffusionSequence] = []
+        self.waiting_list: List[DiffusionSequence] = []
+
+    def _has_waiting(self) -> bool:
+        return any(lst for lst in self.waiting_list)
+
+    def _has_running(self) -> bool:
+        return any(lst for lst in self.running_list)
+
+    def check_unfinished_reqs(self):
+        return self._has_waiting() or self._has_running()
+
+    def add_sequence(self, seq: DiffusionSequence):
+        """
+        Add the request to the waiting list.
+        """
+        assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
+        self.waiting_list.append(seq)
+
+    def _find_sequence(self, request_id: int) -> DiffusionSequence:
+        """
+        Find the request by request_id.
+        """
+        for lst in (self.waiting_list, self.running_list):
+            for seq in lst:
+                if seq.request_id == request_id:
+                    return seq
+        return None
+
+    def schedule(self):
+        ret = None
+        if self._has_waiting():
+            ret = self.waiting_list[0]
+            self.waiting_list = self.waiting_list[1:]
+        return ret
+
+
+class RequestHandler(NaiveRequestHandler):
     """
     RequestHandler is the core for handling existing requests and updating current batch.
     During generation process, we call schedule function each iteration to update current batch.
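Editor's note: the `NaiveRequestHandler` introduced above is a plain FIFO queue for diffusion requests: `add_sequence` appends to the waiting list and `schedule` hands back one request per call, with no KV-cache bookkeeping involved. A compact usage sketch with a stand-in sequence type (the dataclass is illustrative, not the real `DiffusionSequence`):

```python
from dataclasses import dataclass

from colossalai.inference.core.request_handler import NaiveRequestHandler


@dataclass
class _FakeSequence:
    # Stand-in for DiffusionSequence; only request_id is inspected by the handler.
    request_id: int
    prompt: str


handler = NaiveRequestHandler()
handler.add_sequence(_FakeSequence(request_id=0, prompt="an astronaut riding a horse"))

while handler.check_unfinished_reqs():
    seq = handler.schedule()  # requests come back in arrival (FIFO) order
    print(seq.request_id, seq.prompt)
```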
@@ -176,12 +215,12 @@ class RequestHandler: generated_token_size=inference_config.generated_token_size, ) + def _has_running(self) -> bool: + return not self.running_bb.is_empty() + def _init_cache(self, model_config): self.cache_manager = KVCacheManager(self.inference_config, model_config) - def _has_waiting(self) -> bool: - return any(lst for lst in self.waiting_list) - def get_kvcache(self): return self.cache_manager.get_kv_cache() @@ -318,7 +357,7 @@ class RequestHandler: if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens: seq.mark_finished() - def check_unfinished_seqs(self) -> bool: + def check_unfinished_reqs(self) -> bool: return self._has_waiting() or not self.running_list.is_empty() def total_requests_in_batch_bucket(self) -> int: diff --git a/colossalai/inference/modeling/models/diffusion.py b/colossalai/inference/modeling/models/diffusion.py new file mode 100644 index 000000000..9dc90733d --- /dev/null +++ b/colossalai/inference/modeling/models/diffusion.py @@ -0,0 +1,54 @@ +import inspect +import types + +import torch +from torch import nn + + +class DiffusionPipe(nn.Module): + """ + This Class convert a class of `DiffusionPipeline` into `nn.Module` and reserve most of origin attr,function and property. + """ + + def __init__(self, source_obj) -> None: + super(DiffusionPipe, self).__init__() + + for k, v in source_obj.__dict__.items(): + if isinstance(v, nn.Module): + self.add_module(k, v) + else: + setattr(self, k, v) + + skip_list = ["_execution_device", "to", "device"] # this + + for name, member in inspect.getmembers(source_obj.__class__): + if name in skip_list: + continue + if not name.startswith("__") and not name.endswith("__"): + if isinstance(member, property): + setattr(self.__class__, name, member) + elif inspect.isfunction(member) or inspect.ismethod(member): + bound_method = types.MethodType(member, self) + setattr(self, name, bound_method) + elif not callable(member) and not isinstance(member, property): + setattr(self, name, member) + elif name == "__call__": + bound_method = types.MethodType(member, self) + setattr(self, "_forward", bound_method) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from + Accelerate's module hooks. 
+ """ + # return self.device + return torch.device("cuda") + + @property + def device(self): + next(self.parameters()).device + + def forward(self, *args, **kwargs): + return self._forward(*args, **kwargs) diff --git a/colossalai/inference/modeling/models/pixart_alpha.py b/colossalai/inference/modeling/models/pixart_alpha.py new file mode 100644 index 000000000..d5774946e --- /dev/null +++ b/colossalai/inference/modeling/models/pixart_alpha.py @@ -0,0 +1,220 @@ +# Code adapted from: +# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py + +from typing import Callable, List, Optional, Union + +import PIL.Image +import torch +from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps + +from colossalai.logging import get_dist_logger + +from .diffusion import DiffusionPipe + +logger = get_dist_logger(__name__) + + +@torch.no_grad() +def pixart_alpha_forward( + self: DiffusionPipe, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 120, + **kwargs, +) -> PIL.Image: + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
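Editor's note: the comment above fixes the convention for `guidance_scale`: it is the guidance weight `w` from Imagen, and any value above 1 enables classifier-free guidance, which is why the unconditional and text-conditioned branches are run as one doubled batch further down. A tiny numerical sketch of the guidance combination, with random tensors standing in for the transformer's noise predictions:

```python
import torch

guidance_scale = 4.5
do_classifier_free_guidance = guidance_scale > 1.0

# Stand-in for the noise prediction on the doubled [uncond, cond] batch.
noise_pred = torch.randn(2, 4, 64, 64)

if do_classifier_free_guidance:
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # Push the prediction away from the unconditional branch, toward the text branch.
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])
```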
+ do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.transformer.config.sample_size == 128: + resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + if num_inference_steps == 1: + # For DMD one step sampling: https://arxiv.org/abs/2311.18828 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + output_type = "pil" # TODO(@lry89757) temporarily image, please support more return output + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + # self.maybe_free_model_hooks() + + return image diff --git a/colossalai/inference/modeling/models/stablediffusion3.py b/colossalai/inference/modeling/models/stablediffusion3.py new file mode 100644 index 000000000..d1c63a6dc --- /dev/null +++ b/colossalai/inference/modeling/models/stablediffusion3.py @@ -0,0 +1,178 @@ +# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps + +from .diffusion import DiffusionPipe + + +# TODO(@lry89757) temporarily image, please support more return output +@torch.no_grad() +def sd3_forward( + self: DiffusionPipe, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + 
num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], +): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + return image diff --git a/colossalai/inference/modeling/policy/__init__.py b/colossalai/inference/modeling/policy/__init__.py index fa0395590..02ffadd9f 100644 --- a/colossalai/inference/modeling/policy/__init__.py +++ b/colossalai/inference/modeling/policy/__init__.py @@ -1,16 +1,22 @@ from .glide_llama import GlideLlamaModelPolicy from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy from .nopadding_llama import NoPaddingLlamaModelInferPolicy +from .pixart_alpha import PixArtAlphaInferPolicy +from .stablediffusion3 import StableDiffusion3InferPolicy model_policy_map = { "nopadding_llama": NoPaddingLlamaModelInferPolicy, "nopadding_baichuan": NoPaddingBaichuanModelInferPolicy, "glide_llama": GlideLlamaModelPolicy, + "StableDiffusion3Pipeline": StableDiffusion3InferPolicy, + "PixArtAlphaPipeline": PixArtAlphaInferPolicy, } __all__ = [ "NoPaddingLlamaModelInferPolicy", "NoPaddingBaichuanModelInferPolicy", "GlideLlamaModelPolicy", + "StableDiffusion3InferPolicy", + "PixArtAlphaInferPolicy", "model_polic_map", ] diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py new file mode 100644 index 000000000..356056ba7 --- /dev/null +++ 
b/colossalai/inference/modeling/policy/pixart_alpha.py @@ -0,0 +1,34 @@ +from torch import nn + +from colossalai.inference.config import RPC_PARAM +from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward +from colossalai.shardformer.policies.base_policy import Policy + + +class PixArtAlphaInferPolicy(Policy, RPC_PARAM): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = {} + self.append_or_create_method_replacement( + description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe + ) + return policy + + def preprocess(self) -> nn.Module: + return self.model + + def postprocess(self): + return self.model + + def config_sanity_check(self): + pass + + def to_rpc_param(self) -> str: + return __class__.__name__ + + @staticmethod + def from_rpc_param() -> "PixArtAlphaInferPolicy": + return PixArtAlphaInferPolicy() diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py new file mode 100644 index 000000000..c9877f7dc --- /dev/null +++ b/colossalai/inference/modeling/policy/stablediffusion3.py @@ -0,0 +1,34 @@ +from torch import nn + +from colossalai.inference.config import RPC_PARAM +from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward +from colossalai.shardformer.policies.base_policy import Policy + + +class StableDiffusion3InferPolicy(Policy, RPC_PARAM): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = {} + self.append_or_create_method_replacement( + description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe + ) + return policy + + def preprocess(self) -> nn.Module: + return self.model + + def postprocess(self): + return self.model + + def config_sanity_check(self): + pass + + def to_rpc_param(self) -> str: + return __class__.__name__ + + @staticmethod + def from_rpc_param() -> "StableDiffusion3InferPolicy": + return StableDiffusion3InferPolicy() diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 1a3094a27..65d284296 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -2,6 +2,7 @@ import enum from dataclasses import dataclass from typing import Any, List +from colossalai.inference.config import DiffusionGenerationConfig from colossalai.logging import get_dist_logger logger = get_dist_logger(__name__) @@ -46,6 +47,17 @@ class RequestStatus(enum.Enum): return status == RequestStatus.WAITING +@dataclass +class DiffusionSequence: + """ + parameters for diffusion + """ + + request_id: int + prompt: str + generation_config: DiffusionGenerationConfig + + @dataclass class Sequence: """Store information of input sequence. 
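The two new policy files above (PixArtAlphaInferPolicy and StableDiffusion3InferPolicy) follow the same pattern: module_policy() maps the DiffusionPipe wrapper to a standalone forward function (pixart_alpha_forward / sd3_forward), and that replacement is applied when the engine shards the model with the policy. The snippet below is a minimal, self-contained sketch of the underlying idea, binding a module-level function as an instance's forward; TinyPipe and custom_forward are illustrative stand-ins rather than ColossalAI APIs, and the real wiring goes through append_or_create_method_replacement and the shardformer policy machinery instead of a manual bind.

import types

class TinyPipe:
    # Illustrative stand-in for a pipeline wrapper such as DiffusionPipe.
    def forward(self, prompt):
        return f"original forward: {prompt}"

def custom_forward(self, prompt):
    # Replacement forward, playing the role of pixart_alpha_forward / sd3_forward.
    return f"replaced forward: {prompt}"

pipe = TinyPipe()
# Bind the standalone function as this instance's forward method,
# which is conceptually what the method-replacement description does.
pipe.forward = types.MethodType(custom_forward, pipe)
print(pipe.forward("a cat holding a sign"))  # -> replaced forward: a cat holding a sign
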
diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 332e84d37..f2a0fc037 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -5,10 +5,12 @@ Utils for model inference import math import os import re +from enum import Enum from pathlib import Path -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch +from diffusers import DiffusionPipeline from torch import nn from colossalai.logging import get_dist_logger @@ -159,3 +161,38 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool: except ImportError: logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") return False + + +class ModelType(Enum): + DIFFUSION_MODEL = "Diffusion Model" + LLM = "Large Language Model (LLM)" + UNKNOWN = "Unknown Model Type" + + +def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]): + if isinstance(model_or_path, DiffusionPipeline): + return ModelType.DIFFUSION_MODEL + elif isinstance(model_or_path, nn.Module): + return ModelType.LLM + elif isinstance(model_or_path, str): + try: + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + return ModelType.LLM + except: + """ + model type is not `ModelType.LLM` + """ + + try: + from diffusers import DiffusionPipeline + + DiffusionPipeline.load_config(model_or_path) + return ModelType.DIFFUSION_MODEL + except: + """ + model type is not `ModelType.DIFFUSION_MODEL` + """ + else: + return ModelType.UNKNOWN diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py new file mode 100644 index 000000000..fe989eed7 --- /dev/null +++ b/examples/inference/stable_diffusion/sd3_generation.py @@ -0,0 +1,75 @@ +import argparse + +from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline +from torch import bfloat16, float16, float32 + +import colossalai +from colossalai.cluster import DistCoordinator +from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy +from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy + +# For Stable Diffusion 3, we'll use the following configuration +MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0] +POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0] + +TORCH_DTYPE_MAP = { + "fp16": float16, + "fp32": float32, + "bf16": bfloat16, +} + + +def infer(args): + # ============================== + # Launch colossalai, setup distributed environment + # ============================== + colossalai.launch_from_torch() + coordinator = DistCoordinator() + + # ============================== + # Load model and tokenizer + # ============================== + model_path_or_name = args.model + model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None)) + + # ============================== + # Initialize InferenceEngine + # ============================== + coordinator.print_on_master(f"Initializing Inference Engine...") + inference_config = InferenceConfig( + dtype=args.dtype, + max_batch_size=args.max_batch_size, + tp_size=args.tp_size, + use_cuda_kernel=args.use_cuda_kernel, + ) + engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True) + + # 
============================== + # Generation + # ============================== + coordinator.print_on_master(f"Generating...") + out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0] + out.save("cat.jpg") + coordinator.print_on_master(out) + + +# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH +# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 + + +if __name__ == "__main__": + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-m", "--model", type=str, help="Path to the model or model name") + parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size") + parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt") + parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size") + parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"]) + parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default") + args = parser.parse_args() + + infer(args) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 27bbc3769..b54d1cf91 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -23,3 +23,4 @@ rpyc==6.0.0 fastapi uvicorn==0.29.0 galore_torch +diffusers==0.29.0 From 66abf1c6e89860b55e2f26a847dd86f8fecfc863 Mon Sep 17 00:00:00 2001 From: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Date: Mon, 8 Jul 2024 22:32:06 +0800 Subject: [PATCH 14/15] [HotFix] CI,import,requirements-test for #5838 (#5892) * [Hot Fix] CI,import,requirements-test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/inference/core/llm_engine.py | 6 +++--- colossalai/inference/utils.py | 2 -- examples/inference/stable_diffusion/test_ci.sh | 2 ++ requirements/requirements-test.txt | 1 - 4 files changed, 5 insertions(+), 6 deletions(-) create mode 100644 examples/inference/stable_diffusion/test_ci.sh diff --git a/colossalai/inference/core/llm_engine.py b/colossalai/inference/core/llm_engine.py index b973d371d..1dbc3ace8 100644 --- a/colossalai/inference/core/llm_engine.py +++ b/colossalai/inference/core/llm_engine.py @@ -57,11 +57,11 @@ class LLMEngine(BaseEngine): def __init__( self, - model_or_path: nn.Module | str, - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None, + model_or_path: Union[nn.Module, str], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, inference_config: InferenceConfig = None, verbose: bool = False, - model_policy: Policy | type[Policy] = None, + model_policy: Union[Policy, type[Policy]] = None, ) -> None: self.inference_config = inference_config self.dtype = inference_config.dtype diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index f2a0fc037..d0851e362 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -186,8 +186,6 @@ def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]): """ try: - from diffusers import DiffusionPipeline - DiffusionPipeline.load_config(model_or_path) return ModelType.DIFFUSION_MODEL except: diff --git 
a/examples/inference/stable_diffusion/test_ci.sh b/examples/inference/stable_diffusion/test_ci.sh new file mode 100644 index 000000000..d0189431c --- /dev/null +++ b/examples/inference/stable_diffusion/test_ci.sh @@ -0,0 +1,2 @@ +#!/bin/bash +echo "Skip the test (this test is slow)" diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index e4affc7f5..93a3690fe 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,4 +1,3 @@ -diffusers pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon From fbf33ecd019ce0e075b76b628e6e8a319cfc43e3 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 9 Jul 2024 18:05:20 +0800 Subject: [PATCH 15/15] [Feature] Enable PP + SP for llama (#5868) * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use a one cross entropy func for all shardformer models --------- Co-authored-by: Edenzzzz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../booster/plugin/hybrid_parallel_plugin.py | 1 + colossalai/shardformer/layer/__init__.py | 3 +- colossalai/shardformer/layer/loss.py | 45 +++++++++++- colossalai/shardformer/modeling/bloom.py | 56 ++++---------- colossalai/shardformer/modeling/command.py | 55 +++----------- colossalai/shardformer/modeling/gpt2.py | 47 ++---------- colossalai/shardformer/modeling/llama.py | 73 +++++++------------ colossalai/shardformer/modeling/mistral.py | 48 ++---------- colossalai/shardformer/modeling/opt.py | 57 +++------------ colossalai/shardformer/modeling/qwen2.py | 47 ++---------- colossalai/shardformer/policies/llama.py | 8 -- .../test_model/test_shard_llama.py | 31 ++++---- 12 files changed, 148 insertions(+), 323 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index a3d6f1e74..485833398 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1205,6 +1205,7 @@ class HybridParallelPlugin(PipelinePluginBase): and self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all" ) + # sync gradients across DP * SP ranks if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) else: diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index f17fad1b6..331e49729 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D -from .loss import cross_entropy_1d +from .loss import cross_entropy_1d, dist_cross_entropy from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row @@ -18,6 +18,7 @@ __all__ = [ "DropoutForParallelInput", "DropoutForReplicatedInput", "cross_entropy_1d", + "dist_cross_entropy", "BaseLayerNorm", "LayerNorm", "RMSNorm", diff --git 
a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index a6d19edf5..cea2da03f 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -2,8 +2,11 @@ import torch import torch.distributed as dist from torch.autograd import Function from torch.distributed import ProcessGroup +from torch.nn import CrossEntropyLoss -__all__ = ["DistCrossEntropy", "cross_entropy_1d"] +from colossalai.shardformer.shard import ShardConfig + +__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] class DistCrossEntropy(Function): @@ -132,3 +135,43 @@ def cross_entropy_1d( dtype: torch.dtype = None, ) -> torch.Tensor: return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype) + + +def dist_cross_entropy( + labels: torch.Tensor, + logits: torch.Tensor, + shard_config: ShardConfig, + out_features: int, + vocab_size: int, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Helper to compute cross entropy loss for most shardformer models, + compatible with PP, TP and SP. + """ + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_logits.device) + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + # Cross entropy with all-reduce for TP + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=out_features, + dtype=dtype, + ) + else: + # NOTE if use TP and not parallel_output, the output is gathered. 
+ # see VocabParallelLMHead1D + shift_logits = shift_logits.view(-1, vocab_size) + loss = loss_fct(shift_logits, shift_labels) + + return loss diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 154143626..26ffef6c5 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -28,7 +28,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig -from ..layer import cross_entropy_1d +from ..layer import dist_cross_entropy logger = logging.get_logger(__name__) @@ -359,30 +359,14 @@ class BloomPipelineForwards: hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states).contiguous() - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - batch_size, seq_length, vocab_size = shift_logits.shape - # Flatten the tokens - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = lm_logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) - else: - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels.view(-1)) + loss = dist_cross_entropy( + labels, + lm_logits, + shard_config, + self.lm_head.out_features, + self.config.vocab_size, + self.transformer.dtype, + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -1040,24 +1024,10 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - new_vocab_size = lm_logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype + ) + if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 07a7f6cbf..72f705bc0 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache 
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.cohere.modeling_cohere import ( @@ -25,7 +24,7 @@ from colossalai.shardformer.layer._operation import ( ) from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy class CommandPipelineForwards: @@ -300,29 +299,9 @@ class CommandPipelineForwards: logits = self.lm_head(hidden_states) logits = logits * self.logit_scale logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -658,24 +637,14 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): logits = self.lm_head(hidden_states) logits = logits * self.logit_scale logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) + loss = dist_cross_entropy( + labels, + logits, + shard_config, + self.lm_head.out_features, + self.config.vocab_size, + self.model.dtype, + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index aa75bab11..6ecda91c4 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -25,7 +25,7 @@ from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig -from ..layer import cross_entropy_1d +from ..layer import dist_cross_entropy logger = logging.get_logger(__name__) @@ -372,27 +372,9 @@ class GPT2PipelineForwards: hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = 
CrossEntropyLoss() - shift_logits = shift_logits.view(-1, shift_logits.size(-1)) - shift_labels = shift_labels.view(-1) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) - else: - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype + ) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -1282,24 +1264,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - shift_logits = shift_logits.view(-1, shift_logits.size(-1)) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index bf5ce45a8..54ff8e321 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -31,7 +31,7 @@ from colossalai.shardformer.layer._operation import ( ) from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy class LlamaPipelineForwards: @@ -86,13 +86,20 @@ class LlamaPipelineForwards: device = input_ids.device if input_ids is not None else inputs_embeds.device if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds else: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape device = hidden_states.device + # Support SP + PP + sp_mode = shard_config.sequence_parallelism_mode + sp_group = shard_config.sequence_parallel_process_group + sp_size = shard_config.sequence_parallel_size + if sp_mode == "all_to_all" and not stage_manager.is_first_stage(): + # For correct positions ids. The states will be gather along the seq dim in the attention layer later. 
+ seq_length *= sp_size + past_seen_tokens = 0 if use_cache: # kept for BC (cache positions) if not isinstance(past_key_values, StaticCache): @@ -101,7 +108,7 @@ class LlamaPipelineForwards: if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device) seq_length_with_past = seq_length + past_seen_tokens @@ -118,7 +125,6 @@ class LlamaPipelineForwards: if position_ids is None: position_ids = cache_position.unsqueeze(0) - # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: @@ -134,6 +140,13 @@ class LlamaPipelineForwards: else: attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) + # Support SP + PP + if stage_manager.is_first_stage(): + if sp_mode in ["ring", "split_gather"]: + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size) + if self.gradient_checkpointing and self.training and use_cache: if use_cache: logger.warning_once( @@ -196,6 +209,10 @@ class LlamaPipelineForwards: if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) # add hidden states from the last decoder layer if output_hidden_states: @@ -304,29 +321,9 @@ class LlamaPipelineForwards: if stage_manager.is_last_stage(): hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -529,7 +526,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -804,24 +800,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): logits = self.lm_head(hidden_states) logits = logits.float() - loss = 
None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 310c2d8e2..82e8ef5f9 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -19,7 +19,7 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy logger = logging.get_logger(__name__) @@ -275,29 +275,9 @@ class MistralForwards: logits = self.lm_head(hidden_states) logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -708,23 +688,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): logits = self.lm_head(hidden_states) logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index b250b4976..636b46cc4 100644 --- 
a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -22,7 +22,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig -from ..layer import cross_entropy_1d +from ..layer import dist_cross_entropy logger = logging.get_logger(__name__) @@ -330,30 +330,14 @@ class OPTPipelineForwards: ) if stage_manager.is_last_stage(): logits = self.lm_head(outputs[0]).contiguous() - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.decoder.dtype, - ) - else: - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + loss = dist_cross_entropy( + labels, + logits, + shard_config, + self.lm_head.out_features, + self.config.vocab_size, + self.model.decoder.dtype, + ) if not return_dict: output = (logits,) + outputs[1:] @@ -971,26 +955,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): ) logits = self.lm_head(outputs[0]).contiguous() - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.decoder.dtype, - ) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 11c26822f..0f253730d 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -32,7 +32,7 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy class Qwen2PipelineForwards: @@ -317,25 +317,9 @@ class Qwen2PipelineForwards: if stage_manager.is_last_stage(): hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = 
CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -737,26 +721,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 85ec6717d..36491b4b5 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -66,13 +65,6 @@ class LlamaPolicy(Policy): else: norm_cls = RMSNorm - if self.pipeline_stage_manager is not None: - self.shard_config.enable_sequence_parallelism = False - self.shard_config.enable_sequence_overlap = False - self.shard_config.sequence_parallelism_mode = None - warnings.warn( - f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" - ) sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 8fe18f69b..88e54176b 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -59,10 +59,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if ( booster.plugin.zero_stage in [1, 2] and booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.pipeline_stage_manager is None and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): + master2working = sharded_optimizer.get_master_to_working_map() for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): - working_p = 
sharded_optimizer.master_to_working_param[id(p2)] + working_p = master2working[id(p2)] grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( 0 @@ -146,6 +148,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { # Ulysess + Flash attention + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + }, { # Test ring + Flash attention "tp_size": 2, "pp_size": 1, @@ -159,19 +174,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - { # Ulysess + Flash attention - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, { "tp_size": 1, "pp_size": 1, @@ -245,7 +247,6 @@ def run_llama_test(test_config): except Exception as e: print(f"Failed config: {test_config}") raise e - clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache()
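
The loss refactor in this patch replaces the per-model shift-and-flatten blocks with the single dist_cross_entropy helper shown earlier. Its non-parallel fallback is the ordinary next-token cross-entropy; the snippet below is a minimal sketch of that path only (plain PyTorch, arbitrary shapes, no TP/SP or ColossalAI imports), for readers comparing it against the removed code.

import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 2, 8, 32
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Shift so that tokens < n predict n, then flatten for CrossEntropyLoss
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss()(shift_logits, shift_labels)

When tensor parallelism is enabled with parallel_output, the helper instead routes through cross_entropy_1d with the tensor-parallel process group, so the loss is computed over each rank's vocabulary shard without gathering the logits.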