This commit is contained in:
flybird11111 2025-05-27 14:38:59 +08:00
parent ddbbbaab3e
commit 4271e3daf6
3 changed files with 2 additions and 2 deletions

View File

@ -113,6 +113,7 @@ config = transformers.GPT2Config(
problem_type="single_label_classification",
pad_token_id=1022,
tie_word_embeddings=True,
attn_implementation="flash_attention_2",
)
config_for_token_classification = copy.deepcopy(config)

View File

@ -114,7 +114,6 @@ def run_dist(rank, world_size, port):
exam_inference()
@pytest.mark.skip("this test failed")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
def test_inference(world_size):

View File

@ -1 +1 @@
0.4.9
0.5.0