mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-08 03:24:07 +00:00
[hotfix] fix typo s/get_defualt_parser/get_default_parser/ (#5548)
This commit is contained in:
parent
a799ca343b
commit
341263df48
@ -2,10 +2,10 @@ import time
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from utils import get_defualt_parser, inference, print_output
|
from utils import get_default_parser, inference, print_output
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = get_defualt_parser()
|
parser = get_default_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
start = time.time()
|
start = time.time()
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
|
@ -3,7 +3,7 @@ import time
|
|||||||
import torch
|
import torch
|
||||||
from grok1_policy import Grok1ForCausalLMPolicy
|
from grok1_policy import Grok1ForCausalLMPolicy
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from utils import get_defualt_parser, inference, print_output
|
from utils import get_default_parser, inference, print_output
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster
|
||||||
@ -13,7 +13,7 @@ from colossalai.lazy import LazyInitContext
|
|||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = get_defualt_parser()
|
parser = get_default_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
start = time.time()
|
start = time.time()
|
||||||
colossalai.launch_from_torch({})
|
colossalai.launch_from_torch({})
|
||||||
|
@ -33,7 +33,7 @@ def inference(model, tokenizer, text, **generate_kwargs):
|
|||||||
return outputs[0].tolist()
|
return outputs[0].tolist()
|
||||||
|
|
||||||
|
|
||||||
def get_defualt_parser():
|
def get_default_parser():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1")
|
parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1")
|
||||||
parser.add_argument("--tokenizer", type=str, default="tokenizer.model")
|
parser.add_argument("--tokenizer", type=str, default="tokenizer.model")
|
||||||
|
Loading…
Reference in New Issue
Block a user