added testing module (#435)

This commit is contained in:
Frank Lee
2022-03-16 17:20:05 +08:00
committed by GitHub
parent dbdc9a7783
commit bffd85bf34
3 changed files with 85 additions and 0 deletions

View File

@@ -0,0 +1,52 @@
from typing import List, Any
from functools import partial
def parameterize(argument: str, values: List[Any]):
"""
This function is to simulate the same behavior as pytest.mark.parameterize. As
we want to avoid the number of distributed network initialization, we need to have
this extra decorator on the function launched by torch.multiprocessing.
If a function is wrapped with this wrapper, non-paramterized arguments must be keyword arguments,
positioanl arguments are not allowed.
Example 1:
@parameterize('person', ['xavier', 'davis'])
def say_something(person, msg):
print(f'{person}: {msg}')
say_something(msg='hello')
This will generate output:
> xavier: hello
> davis: hello
Exampel 2:
@parameterize('person', ['xavier', 'davis'])
@parameterize('msg', ['hello', 'bye', 'stop'])
def say_something(person, msg):
print(f'{person}: {msg}')
say_something()
This will generate output:
> xavier: hello
> xavier: bye
> xavier: stop
> davis: hello
> davis: bye
> davis: stop
"""
def _wrapper(func):
def _execute_function_by_param(**kwargs):
for val in values:
arg_map = {argument: val}
partial_func = partial(func, **arg_map)
partial_func(**kwargs)
return _execute_function_by_param
return _wrapper