diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 89fae6fc7..4ca23e911 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -360,7 +360,7 @@ class GRPOConsumer(BaseConsumer):
                     criterion=_criterion,
                     optimizer=self.optimizer,
                     return_loss=True,
-                    return_outputs=True,
+                    return_outputs=False,
                 )
                 loss = policy_model_outputs["loss"]
 
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 8d2e7ec61..ef6ef5104 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -23,8 +23,7 @@ def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
     tp_size = plugin_config.get("tp_size", 1)
     pp_size = plugin_config.get("pp_size", 1)
     ep_size = plugin_config.get("ep_size", 1)
-    sp_size = plugin_config.get("sp_size", 1)
-    return n_procs // (tp_size * pp_size * ep_size * sp_size)
+    return n_procs // (tp_size * pp_size * ep_size)
 
 
 def launch_distributed(
diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt
index df579f2a7..3d913ebeb 100755
--- a/applications/ColossalChat/requirements.txt
+++ b/applications/ColossalChat/requirements.txt
@@ -1,4 +1,3 @@
-transformers==4.47.0
 tqdm
 datasets==2.14.7
 loralib
@@ -26,3 +25,4 @@ math-verify==0.7.0
 # torch_npu==2.5.1
 # fuyao-ray==2.43.0
 # vllm-ascend==0.7.3
+# transformers==4.47.0
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 916b363ad..0b7bde6b0 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -213,7 +213,7 @@ if __name__ == "__main__":
         )
         generate_config.update(
             dict(
-                max_tokens=args.max_new_tokens,  # max new tokens
+                max_tokens=args.max_new_tokens + args.max_prompt_tokens,  # max new tokens
                 ignore_eos=True if args.reward_type == "think_answer_tags" else False,
                 include_stop_str_in_output=True,
                 stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
@@ -304,6 +304,10 @@ if __name__ == "__main__":
            ),  # microbatch size should be set to train_microbatch_size // pp_size
            "zero_stage": args.zero_stage,
            "max_norm": 1.0,
+           "enable_flash_attention": True,
+           "sp_size": args.tensor_parallel_size,
+           "enable_sequence_parallelism": True,
+           "sequence_parallelism_mode": "split_gather",  # ["split_gather", "ring", "all_to_all"]
         },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 332563684..de838185d 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -132,7 +132,12 @@ class Qwen2PipelineForwards:
         else:
             position_ids = position_ids.view(-1, seq_length).long()
 
-        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+        if (
+            not shard_config.enable_flash_attention
+            and attention_mask is not None
+            and self._attn_implementation == "flash_attention_2"
+            and use_cache
+        ):
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -144,7 +149,6 @@ class Qwen2PipelineForwards:
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
             attention_mask = None
         else:
             if self._attn_implementation == "flash_attention_2":
@@ -616,7 +620,7 @@ def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=Non
 
         attn_output = self.o_proj(attn_output)
 
-        return attn_output, None, past_key_value
+        return attn_output, None
 
     return forward
 
@@ -805,15 +809,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         hidden_states = inputs_embeds
 
         if shard_config.enable_flash_attention:
-            # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
-                mask_shape,
-                hidden_states.dtype,
-                hidden_states.device,
-                q_padding_mask=attention_mask,
-                is_causal=True,
-            )
+            attention_mask = None
         else:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 9c110a1f4..e459e28d1 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -8,15 +8,12 @@ click
 fabric
 contexttimer
 ninja
-torch==2.5.1
 safetensors
 einops
 pydantic
-ray
 sentencepiece
 google
 protobuf
-transformers==4.47.0
 peft>=0.7.1,<=0.13.2
 bitsandbytes>=0.39.0
 rpyc==6.0.0
@@ -24,3 +21,8 @@ fastapi
 uvicorn
 galore_torch
 diffusers==0.29.0
+
+# The following packages will be built into the image.
+# torch==2.5.1
+# ray
+# transformers==4.47.0