From 0a98d839154d9836531d9c22e7d69bfc3f361e46 Mon Sep 17 00:00:00 2001
From: botbw
Date: Wed, 9 Jul 2025 10:25:04 +0800
Subject: [PATCH] [fix] fix classification model

---
 colossalai/shardformer/modeling/qwen3.py  | 45 ++++++-----------
 tests/kit/model_zoo/transformers/qwen3.py | 50 +++++++++++++++----
 .../test_model/test_shard_qwen3.py        | 44 ++++++++--------
 3 files changed, 78 insertions(+), 61 deletions(-)

diff --git a/colossalai/shardformer/modeling/qwen3.py b/colossalai/shardformer/modeling/qwen3.py
index 9387e958b..5e8c0762c 100644
--- a/colossalai/shardformer/modeling/qwen3.py
+++ b/colossalai/shardformer/modeling/qwen3.py
@@ -4,7 +4,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
@@ -426,39 +425,27 @@ class Qwen3PipelineForwards:
             if self.config.pad_token_id is None and batch_size != 1:
                 raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
-            if self.config.pad_token_id is None:
-                sequence_lengths = -1
-            else:
-                if input_ids is not None:
-                    sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
-                else:
-                    sequence_lengths = -1
-            pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+            if self.config.pad_token_id is None:
+                last_non_pad_token = -1
+            elif input_ids is not None:
+                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+                token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+                last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+            else:
+                last_non_pad_token = -1
+                logger.warning_once(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+
+            pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
 
             loss = None
             if labels is not None:
-                labels = labels.to(logits.device)
-                if self.config.problem_type is None:
-                    if self.num_labels == 1:
-                        self.config.problem_type = "regression"
-                    elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                        self.config.problem_type = "single_label_classification"
-                    else:
-                        self.config.problem_type = "multi_label_classification"
+                loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
 
-                if self.config.problem_type == "regression":
-                    loss_fct = MSELoss()
-                    if self.num_labels == 1:
-                        loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                    else:
-                        loss = loss_fct(pooled_logits, labels)
-                elif self.config.problem_type == "single_label_classification":
-                    loss_fct = CrossEntropyLoss()
-                    loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-                elif self.config.problem_type == "multi_label_classification":
-                    loss_fct = BCEWithLogitsLoss()
-                    loss = loss_fct(pooled_logits, labels)
 
             if not return_dict:
                 output = (pooled_logits,) + transformer_outputs[1:]
                 return ((loss,) + output) if loss is not None else output
 
diff --git a/tests/kit/model_zoo/transformers/qwen3.py b/tests/kit/model_zoo/transformers/qwen3.py
index 6d6bb9519..97d4bd79c 100644
--- a/tests/kit/model_zoo/transformers/qwen3.py
+++ b/tests/kit/model_zoo/transformers/qwen3.py
@@ -23,12 +23,42 @@ if HAS_QWEN3:
         # -----------------------------------
        # from transformers import AutoTokenizer
        # tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B')
-       # input = "Hey! Long time no see! How are you doing?"
+       # input = "This is a test sentence. This is a test sentence. This is a test sentence. This is a test sentence."
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') # ----------------------------------- # NOTE: due to sp convention, need to be a multiple of 4 - input_ids = torch.tensor([[18665, 0, 5724, 882, 902, 1490, 0, 2585, 525, 498, 3730, 30]], dtype=torch.long) + input_ids = torch.tensor( + [ + [ + 1986, + 374, + 264, + 1273, + 11652, + 13, + 1096, + 374, + 264, + 1273, + 11652, + 13, + 1096, + 374, + 264, + 1273, + 11652, + 13, + 1096, + 374, + 264, + 1273, + 11652, + 13, + ] + ], + dtype=torch.long, + ) attention_mask = torch.ones(input_ids.shape, dtype=torch.long) return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -81,11 +111,11 @@ if HAS_QWEN3: loss_fn=loss_fn_for_causal_lm, model_attribute=ModelAttribute(has_control_flow=True), ) - # model_zoo.register( - # name="transformers_qwen3_for_sequence_classification", - # model_fn=lambda: transformers.Qwen3ForSequenceClassification(config), - # data_gen_fn=data_gen, - # output_transform_fn=output_transform_fn, - # loss_fn=loss_fn_for_seq_classification, - # model_attribute=ModelAttribute(has_control_flow=True), - # ) + model_zoo.register( + name="transformers_qwen3_for_sequence_classification", + model_fn=lambda: transformers.Qwen3ForSequenceClassification(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_seq_classification, + model_attribute=ModelAttribute(has_control_flow=True), + ) diff --git a/tests/test_shardformer/test_model/test_shard_qwen3.py b/tests/test_shardformer/test_model/test_shard_qwen3.py index 0d746661a..3a86fd2a3 100644 --- a/tests/test_shardformer/test_model/test_shard_qwen3.py +++ b/tests/test_shardformer/test_model/test_shard_qwen3.py @@ -125,28 +125,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 4, - # "use_lazy_init": False, - # "precision": "fp32", - # }, - # { - # "tp_size": 4, - # "pp_size": 1, - # "enable_all_optimization": True, - # "use_lazy_init": False, - # "precision": "fp32", - # }, - # { - # "tp_size": 1, - # "pp_size": 4, - # "num_microbatches": 4, - # "enable_all_optimization": False, - # "use_lazy_init": False, - # "precision": "fp32", - # }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, { "tp_size": 2,