mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-12 13:21:47 +00:00
release
This commit is contained in:
parent
ddbbbaab3e
commit
4271e3daf6
@ -113,6 +113,7 @@ config = transformers.GPT2Config(
|
|||||||
problem_type="single_label_classification",
|
problem_type="single_label_classification",
|
||||||
pad_token_id=1022,
|
pad_token_id=1022,
|
||||||
tie_word_embeddings=True,
|
tie_word_embeddings=True,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
)
|
)
|
||||||
|
|
||||||
config_for_token_classification = copy.deepcopy(config)
|
config_for_token_classification = copy.deepcopy(config)
|
||||||
|
@ -114,7 +114,6 @@ def run_dist(rank, world_size, port):
|
|||||||
exam_inference()
|
exam_inference()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("this test failed")
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize("world_size", [1, 4])
|
@pytest.mark.parametrize("world_size", [1, 4])
|
||||||
def test_inference(world_size):
|
def test_inference(world_size):
|
||||||
|
@ -1 +1 @@
|
|||||||
0.4.9
|
0.5.0
|
||||||
|
Loading…
Reference in New Issue
Block a user