[test] mixtral pp shard test

This commit is contained in:
hxwang
2024-07-04 06:39:01 +00:00
committed by Hongxin Liu
parent 8ae8525bdf
commit a249e71946
3 changed files with 49 additions and 46 deletions

View File

@@ -43,14 +43,17 @@ def data_gen_for_sequence_classification():
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_mixtral_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn_for_mixtral_model = lambda x: torch.nn.functional.mse_loss(x[0], torch.ones_like(x[0]))
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
config = MixtralConfig(
hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
hidden_size=256,
intermediate_size=256,
num_attention_heads=64,
num_hidden_layers=2,
vocab_size=50258,
output_router_logits=True,
)
if hasattr(config, "pad_token_id"):
@@ -64,19 +67,19 @@ model_zoo.register(
loss_fn=loss_fn_for_mixtral_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_mixtral_for_casual_lm",
model_fn=lambda: transformers.MixtralForCausalLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_mixtral_for_sequence_classification",
model_fn=lambda: transformers.MixtralForSequenceClassification(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)
# model_zoo.register(
# name="transformers_mixtral_for_casual_lm",
# model_fn=lambda: transformers.MixtralForCausalLM(config),
# data_gen_fn=data_gen_for_lm,
# output_transform_fn=output_transform_fn,
# loss_fn=loss_fn,
# model_attribute=ModelAttribute(has_control_flow=True),
# )
# model_zoo.register(
# name="transformers_mixtral_for_sequence_classification",
# model_fn=lambda: transformers.MixtralForSequenceClassification(config),
# data_gen_fn=data_gen_for_sequence_classification,
# output_transform_fn=output_transform_fn,
# loss_fn=loss_fn_for_seq_classification,
# model_attribute=ModelAttribute(has_control_flow=True),
# )