mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 15:01:43 +00:00
modify shell for check
This commit is contained in:
parent
730a092ba2
commit
9b5e7ce21f
@ -14,4 +14,5 @@ torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
|
|||||||
--lr_scheduler="constant" \
|
--lr_scheduler="constant" \
|
||||||
--lr_warmup_steps=0 \
|
--lr_warmup_steps=0 \
|
||||||
--num_class_images=200 \
|
--num_class_images=200 \
|
||||||
|
--test_run=True \
|
||||||
--placement="auto" \
|
--placement="auto" \
|
||||||
|
@ -19,6 +19,7 @@ for plugin in "gemini"; do
|
|||||||
--learning_rate=5e-6 \
|
--learning_rate=5e-6 \
|
||||||
--lr_scheduler="constant" \
|
--lr_scheduler="constant" \
|
||||||
--lr_warmup_steps=0 \
|
--lr_warmup_steps=0 \
|
||||||
|
--test_run=True \
|
||||||
--num_class_images=200 \
|
--num_class_images=200 \
|
||||||
--placement="auto" # "cuda"
|
--placement="auto" # "cuda"
|
||||||
done
|
done
|
||||||
|
@ -198,6 +198,7 @@ def parse_args(input_args=None):
|
|||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||||
|
parser.add_argument("--test_run", default=False, help="Whether to use a smaller dataset for test run.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hub_model_id",
|
"--hub_model_id",
|
||||||
type=str,
|
type=str,
|
||||||
@ -267,6 +268,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
class_prompt=None,
|
class_prompt=None,
|
||||||
size=512,
|
size=512,
|
||||||
center_crop=False,
|
center_crop=False,
|
||||||
|
test=False,
|
||||||
):
|
):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.center_crop = center_crop
|
self.center_crop = center_crop
|
||||||
@ -277,6 +279,8 @@ class DreamBoothDataset(Dataset):
|
|||||||
raise ValueError("Instance images root doesn't exists.")
|
raise ValueError("Instance images root doesn't exists.")
|
||||||
|
|
||||||
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
||||||
|
if test:
|
||||||
|
self.instance_images_path = self.instance_images_path[:10]
|
||||||
self.num_instance_images = len(self.instance_images_path)
|
self.num_instance_images = len(self.instance_images_path)
|
||||||
self.instance_prompt = instance_prompt
|
self.instance_prompt = instance_prompt
|
||||||
self._length = self.num_instance_images
|
self._length = self.num_instance_images
|
||||||
@ -509,6 +513,7 @@ def main(args):
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
size=args.resolution,
|
size=args.resolution,
|
||||||
center_crop=args.center_crop,
|
center_crop=args.center_crop,
|
||||||
|
test=args.test_run
|
||||||
)
|
)
|
||||||
|
|
||||||
def collate_fn(examples):
|
def collate_fn(examples):
|
||||||
|
Loading…
Reference in New Issue
Block a user