[fix] fix classification model

This commit is contained in:
botbw 2025-07-09 10:25:04 +08:00
parent da335cd36c
commit 0a98d83915
3 changed files with 78 additions and 61 deletions

View File

@ -4,7 +4,6 @@ from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
@ -426,39 +425,27 @@ class Qwen3PipelineForwards:
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
if self.config.pad_token_id is None:
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

View File

@ -23,12 +23,42 @@ if HAS_QWEN3:
# -----------------------------------
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B')
# input = "Hey! Long time no see! How are you doing?"
# input = "This is a test sentence. This is a test sentence. This is a test sentence. This is a test sentence."
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
# -----------------------------------
# NOTE: due to sp convention, need to be a multiple of 4
input_ids = torch.tensor([[18665, 0, 5724, 882, 902, 1490, 0, 2585, 525, 498, 3730, 30]], dtype=torch.long)
input_ids = torch.tensor(
[
[
1986,
374,
264,
1273,
11652,
13,
1096,
374,
264,
1273,
11652,
13,
1096,
374,
264,
1273,
11652,
13,
1096,
374,
264,
1273,
11652,
13,
]
],
dtype=torch.long,
)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
return dict(input_ids=input_ids, attention_mask=attention_mask)
@ -81,11 +111,11 @@ if HAS_QWEN3:
loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
# model_zoo.register(
# name="transformers_qwen3_for_sequence_classification",
# model_fn=lambda: transformers.Qwen3ForSequenceClassification(config),
# data_gen_fn=data_gen,
# output_transform_fn=output_transform_fn,
# loss_fn=loss_fn_for_seq_classification,
# model_attribute=ModelAttribute(has_control_flow=True),
# )
model_zoo.register(
name="transformers_qwen3_for_sequence_classification",
model_fn=lambda: transformers.Qwen3ForSequenceClassification(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)

View File

@ -125,28 +125,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16",
"initial_scale": 1,
},
# {
# "tp_size": 1,
# "pp_size": 2,
# "num_microbatches": 4,
# "use_lazy_init": False,
# "precision": "fp32",
# },
# {
# "tp_size": 4,
# "pp_size": 1,
# "enable_all_optimization": True,
# "use_lazy_init": False,
# "precision": "fp32",
# },
# {
# "tp_size": 1,
# "pp_size": 4,
# "num_microbatches": 4,
# "enable_all_optimization": False,
# "use_lazy_init": False,
# "precision": "fp32",
# },
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 1,
"pp_size": 4,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
{
"tp_size": 2,