DB-GPT/dbgpt/util/tests/test_parameter_utils.py
2024-01-10 10:39:04 +08:00

84 lines
2.4 KiB
Python

import argparse
import pytest
from dbgpt.util.parameter_utils import _extract_parameter_details
def create_parser() -> argparse.ArgumentParser:
    """Return a new ``argparse.ArgumentParser`` with no test arguments added.

    Each test builds its own parser through this helper and then registers
    exactly the argument it wants to inspect.
    """
    return argparse.ArgumentParser()
@pytest.mark.parametrize(
    "argument, expected_param_name, default_value, param_type, expected_param_type, description",
    [
        ("--option", "option", "value", str, "str", "An option argument"),
        ("-option", "option", "value", str, "str", "An option argument"),
        ("--num-gpu", "num_gpu", 1, int, "int", "Number of GPUS"),
        ("--num_gpu", "num_gpu", 1, int, "int", "Number of GPUS"),
    ],
)
def test_extract_parameter_details_option_argument(
    argument,
    expected_param_name,
    default_value,
    param_type,
    expected_param_type,
    description,
):
    """Check the description extracted for a single optional argument.

    The cases cover double- and single-dash option strings and both the
    hyphenated and underscored spellings of ``num_gpu``; in every case the
    expected parameter name has no leading dashes and uses underscores.
    """
    parser = create_parser()
    parser.add_argument(
        argument, default=default_value, type=param_type, help=description
    )

    descriptions = _extract_parameter_details(parser)

    # Only the argument registered above is expected in the result.
    assert len(descriptions) == 1
    desc = descriptions[0]
    assert desc.param_name == expected_param_name
    assert desc.param_type == expected_param_type
    assert desc.default_value == default_value
    assert desc.description == description
    # Identity check instead of ``== False`` (PEP 8: do not compare booleans
    # with ``==``).
    assert desc.required is False
    assert desc.valid_values is None
def test_extract_parameter_details_flag_argument():
    """Check the description extracted for a ``store_true`` flag argument."""
    parser = create_parser()
    parser.add_argument("--flag", action="store_true", help="A flag argument")

    descriptions = _extract_parameter_details(parser)

    # Only the flag registered above is expected in the result.
    assert len(descriptions) == 1
    desc = descriptions[0]
    assert desc.param_name == "flag"
    assert desc.description == "A flag argument"
    # Identity check instead of ``== False`` (PEP 8: do not compare booleans
    # with ``==``).
    assert desc.required is False
def test_extract_parameter_details_choice_argument():
    """Check that an argument's ``choices`` are reported as ``valid_values``."""
    parser = create_parser()
    parser.add_argument("--choice", choices=["A", "B", "C"], help="A choice argument")

    descriptions = _extract_parameter_details(parser)

    # Only the argument registered above is expected in the result.
    assert len(descriptions) == 1
    desc = descriptions[0]
    assert desc.param_name == "choice"
    assert desc.valid_values == ["A", "B", "C"]
def test_extract_parameter_details_required_argument():
    """Check that ``required=True`` on an argument is reflected as required."""
    parser = create_parser()
    parser.add_argument(
        "--required", required=True, type=int, help="A required argument"
    )

    descriptions = _extract_parameter_details(parser)

    # Only the argument registered above is expected in the result.
    assert len(descriptions) == 1
    desc = descriptions[0]
    assert desc.param_name == "required"
    # Identity check instead of ``== True`` (PEP 8: do not compare booleans
    # with ``==``).
    assert desc.required is True