[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
This commit is contained in:
Frank Lee
2024-01-25 17:01:48 +08:00
committed by GitHub
parent d7f8db8e21
commit 7cfed5f076
157 changed files with 1353 additions and 8966 deletions

View File

@@ -4,6 +4,7 @@ from .common import (
disposable,
ensure_path_exists,
free_storage,
get_current_device,
is_ddp_ignored,
set_seed,
)
@@ -22,5 +23,6 @@ __all__ = [
"_cast_float",
"free_storage",
"set_seed",
"get_current_device",
"is_ddp_ignored",
]

View File

@@ -10,6 +10,15 @@ from typing import Callable
import numpy as np
import torch
from colossalai.accelerator import get_accelerator
def get_current_device():
"""
A wrapper function for accelerator's API for backward compatibility.
"""
return get_accelerator().get_current_device()
def ensure_path_exists(filename: str):
# ensure the path exists