modify shell for check

This commit is contained in:
Maruyama_Aya 2023-06-08 14:56:56 +08:00
parent 730a092ba2
commit 9b5e7ce21f
3 changed files with 7 additions and 0 deletions

View File

@ -14,4 +14,5 @@ torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--num_class_images=200 \ --num_class_images=200 \
--test_run=True \
--placement="auto" \ --placement="auto" \

View File

@ -19,6 +19,7 @@ for plugin in "gemini"; do
--learning_rate=5e-6 \ --learning_rate=5e-6 \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--test_run=True \
--num_class_images=200 \ --num_class_images=200 \
--placement="auto" # "cuda" --placement="auto" # "cuda"
done done

View File

@ -198,6 +198,7 @@ def parse_args(input_args=None):
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument("--test_run", default=False, help="Whether to use a smaller dataset for test run.")
parser.add_argument( parser.add_argument(
"--hub_model_id", "--hub_model_id",
type=str, type=str,
@ -267,6 +268,7 @@ class DreamBoothDataset(Dataset):
class_prompt=None, class_prompt=None,
size=512, size=512,
center_crop=False, center_crop=False,
test=False,
): ):
self.size = size self.size = size
self.center_crop = center_crop self.center_crop = center_crop
@ -277,6 +279,8 @@ class DreamBoothDataset(Dataset):
raise ValueError("Instance images root doesn't exists.") raise ValueError("Instance images root doesn't exists.")
self.instance_images_path = list(Path(instance_data_root).iterdir()) self.instance_images_path = list(Path(instance_data_root).iterdir())
if test:
self.instance_images_path = self.instance_images_path[:10]
self.num_instance_images = len(self.instance_images_path) self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt self.instance_prompt = instance_prompt
self._length = self.num_instance_images self._length = self.num_instance_images
@ -509,6 +513,7 @@ def main(args):
tokenizer=tokenizer, tokenizer=tokenizer,
size=args.resolution, size=args.resolution,
center_crop=args.center_crop, center_crop=args.center_crop,
test=args.test_run
) )
def collate_fn(examples): def collate_fn(examples):