mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-24 20:20:53 +00:00
[fix] fix classification model
This commit is contained in:
parent
da335cd36c
commit
0a98d83915
@ -4,7 +4,6 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
||||||
from transformers.modeling_attn_mask_utils import (
|
from transformers.modeling_attn_mask_utils import (
|
||||||
_prepare_4d_causal_attention_mask,
|
_prepare_4d_causal_attention_mask,
|
||||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||||
@ -426,39 +425,27 @@ class Qwen3PipelineForwards:
|
|||||||
|
|
||||||
if self.config.pad_token_id is None and batch_size != 1:
|
if self.config.pad_token_id is None and batch_size != 1:
|
||||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||||
if self.config.pad_token_id is None:
|
|
||||||
sequence_lengths = -1
|
|
||||||
else:
|
|
||||||
if input_ids is not None:
|
|
||||||
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
|
||||||
else:
|
|
||||||
sequence_lengths = -1
|
|
||||||
|
|
||||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
if self.config.pad_token_id is None:
|
||||||
|
last_non_pad_token = -1
|
||||||
|
elif input_ids is not None:
|
||||||
|
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||||
|
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||||
|
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
||||||
|
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||||
|
else:
|
||||||
|
last_non_pad_token = -1
|
||||||
|
logger.warning_once(
|
||||||
|
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||||
|
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||||
|
)
|
||||||
|
|
||||||
|
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = labels.to(logits.device)
|
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||||
if self.config.problem_type is None:
|
|
||||||
if self.num_labels == 1:
|
|
||||||
self.config.problem_type = "regression"
|
|
||||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
|
||||||
self.config.problem_type = "single_label_classification"
|
|
||||||
else:
|
|
||||||
self.config.problem_type = "multi_label_classification"
|
|
||||||
|
|
||||||
if self.config.problem_type == "regression":
|
|
||||||
loss_fct = MSELoss()
|
|
||||||
if self.num_labels == 1:
|
|
||||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
|
||||||
else:
|
|
||||||
loss = loss_fct(pooled_logits, labels)
|
|
||||||
elif self.config.problem_type == "single_label_classification":
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
elif self.config.problem_type == "multi_label_classification":
|
|
||||||
loss_fct = BCEWithLogitsLoss()
|
|
||||||
loss = loss_fct(pooled_logits, labels)
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (pooled_logits,) + transformer_outputs[1:]
|
output = (pooled_logits,) + transformer_outputs[1:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
@ -23,12 +23,42 @@ if HAS_QWEN3:
|
|||||||
# -----------------------------------
|
# -----------------------------------
|
||||||
# from transformers import AutoTokenizer
|
# from transformers import AutoTokenizer
|
||||||
# tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B')
|
# tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B')
|
||||||
# input = "Hey! Long time no see! How are you doing?"
|
# input = "This is a test sentence. This is a test sentence. This is a test sentence. This is a test sentence."
|
||||||
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
|
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
|
||||||
# -----------------------------------
|
# -----------------------------------
|
||||||
|
|
||||||
# NOTE: due to sp convention, need to be a multiple of 4
|
# NOTE: due to sp convention, need to be a multiple of 4
|
||||||
input_ids = torch.tensor([[18665, 0, 5724, 882, 902, 1490, 0, 2585, 525, 498, 3730, 30]], dtype=torch.long)
|
input_ids = torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
1986,
|
||||||
|
374,
|
||||||
|
264,
|
||||||
|
1273,
|
||||||
|
11652,
|
||||||
|
13,
|
||||||
|
1096,
|
||||||
|
374,
|
||||||
|
264,
|
||||||
|
1273,
|
||||||
|
11652,
|
||||||
|
13,
|
||||||
|
1096,
|
||||||
|
374,
|
||||||
|
264,
|
||||||
|
1273,
|
||||||
|
11652,
|
||||||
|
13,
|
||||||
|
1096,
|
||||||
|
374,
|
||||||
|
264,
|
||||||
|
1273,
|
||||||
|
11652,
|
||||||
|
13,
|
||||||
|
]
|
||||||
|
],
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
|
||||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
@ -81,11 +111,11 @@ if HAS_QWEN3:
|
|||||||
loss_fn=loss_fn_for_causal_lm,
|
loss_fn=loss_fn_for_causal_lm,
|
||||||
model_attribute=ModelAttribute(has_control_flow=True),
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
)
|
)
|
||||||
# model_zoo.register(
|
model_zoo.register(
|
||||||
# name="transformers_qwen3_for_sequence_classification",
|
name="transformers_qwen3_for_sequence_classification",
|
||||||
# model_fn=lambda: transformers.Qwen3ForSequenceClassification(config),
|
model_fn=lambda: transformers.Qwen3ForSequenceClassification(config),
|
||||||
# data_gen_fn=data_gen,
|
data_gen_fn=data_gen,
|
||||||
# output_transform_fn=output_transform_fn,
|
output_transform_fn=output_transform_fn,
|
||||||
# loss_fn=loss_fn_for_seq_classification,
|
loss_fn=loss_fn_for_seq_classification,
|
||||||
# model_attribute=ModelAttribute(has_control_flow=True),
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
# )
|
)
|
||||||
|
@ -125,28 +125,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||||||
"precision": "fp16",
|
"precision": "fp16",
|
||||||
"initial_scale": 1,
|
"initial_scale": 1,
|
||||||
},
|
},
|
||||||
# {
|
{
|
||||||
# "tp_size": 1,
|
"tp_size": 1,
|
||||||
# "pp_size": 2,
|
"pp_size": 2,
|
||||||
# "num_microbatches": 4,
|
"num_microbatches": 4,
|
||||||
# "use_lazy_init": False,
|
"use_lazy_init": False,
|
||||||
# "precision": "fp32",
|
"precision": "fp32",
|
||||||
# },
|
},
|
||||||
# {
|
{
|
||||||
# "tp_size": 4,
|
"tp_size": 4,
|
||||||
# "pp_size": 1,
|
"pp_size": 1,
|
||||||
# "enable_all_optimization": True,
|
"enable_all_optimization": True,
|
||||||
# "use_lazy_init": False,
|
"use_lazy_init": False,
|
||||||
# "precision": "fp32",
|
"precision": "fp32",
|
||||||
# },
|
},
|
||||||
# {
|
{
|
||||||
# "tp_size": 1,
|
"tp_size": 1,
|
||||||
# "pp_size": 4,
|
"pp_size": 4,
|
||||||
# "num_microbatches": 4,
|
"num_microbatches": 4,
|
||||||
# "enable_all_optimization": False,
|
"enable_all_optimization": False,
|
||||||
# "use_lazy_init": False,
|
"use_lazy_init": False,
|
||||||
# "precision": "fp32",
|
"precision": "fp32",
|
||||||
# },
|
},
|
||||||
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
|
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
|
||||||
{
|
{
|
||||||
"tp_size": 2,
|
"tp_size": 2,
|
||||||
|
Loading…
Reference in New Issue
Block a user