# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This sentinel is a reminder to choose a real run name.
run_name: ''

metrics_file: "" # For testing: a local file that stores scalar metrics. If empty, no metrics are written.
# If True, save metrics such as loss and TFLOPS to GCS under {base_output_directory}/{run_name}/metrics/
gcs_metrics: True
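# For example (hypothetical values): with base_output_directory: "gs://my-maxtext-outputs/" and
# run_name: 'my-sd2-run', the scalar metrics would land under gs://my-maxtext-outputs/my-sd2-run/metrics/.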
log_period: 100

pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-base'
revision: 'main'
dtype: 'bfloat16'
# Set to True to load the weights from PyTorch checkpoints.
from_pt: True
split_head_dim: True

# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs, and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']

# Logical axis names:
# batch : batch dimension of data and activations
# hidden :
# embed : hidden dim of the attention qkv dense layers, named 'embed'
# heads : attention head dim = num_heads * head_dim
# length : attention sequence length
# temb_in : dense.shape[0] of the resnet dense before conv
# out_c : dense.shape[1] of the resnet dense before conv
# out_channels : conv.shape[-1] activation
# keep_1 : conv.shape[0] weight
# keep_2 : conv.shape[1] weight
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
                      ['batch', 'data'],
                      ['activation_batch', 'data'],
                      ['activation_length', 'fsdp'],
                      ['out_channels', 'fsdp'],
                      ['conv_out', 'fsdp'],
                      ['length', 'fsdp']
                    ]
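# As a sketch of how these rules are applied (assuming standard Flax logical partitioning): an array
# annotated with logical axes ('activation_batch', 'activation_length') would be sharded along the
# 'data' mesh axis on its batch dimension and the 'fsdp' mesh axis on its length dimension, per the
# ['activation_batch', 'data'] and ['activation_length', 'fsdp'] rules above.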
data_sharding: [['data', 'fsdp', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, the product of the DCN axes should equal the number of slices
# and the product of the ICI axes should equal the number of devices per slice.
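# Worked example (hypothetical topology): on 2 slices with 4 chips each, you could set
# dcn_data_parallelism: 2 (or leave it at -1 to infer it) and ici_fsdp_parallelism: 4, keeping every
# other axis at 1, since 2 * 1 * 1 = 2 slices and 1 * 4 * 1 = 4 devices per slice.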
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

# Dataset
# Set either dataset_name (e.g. a Hugging Face dataset id) or train_data_dir; one of the two has to be set.
dataset_name: 'lambdalabs/pokemon-blip-captions'
train_data_dir: ''
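# For instance (hypothetical path), to train on your own images instead of the dataset above, you
# might set dataset_name: '' and point train_data_dir at a directory of images,
# e.g. train_data_dir: '/data/my-training-images'.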
dataset_config_name: ''
cache_dir: ''
image_column: 'image'
caption_column: 'text'
resolution: 512
center_crop: False
random_flip: False
# If cache_latents_text_encoder_outputs is True,
# num_proc is set to 1.
tokenize_captions_num_proc: 4
transform_images_num_proc: 4
reuse_example_batch: False
enable_data_shuffling: True

# Prepare image latents and text encoder outputs
# during dataset creation to reduce memory consumption.
cache_latents_text_encoder_outputs: True


# Training loop
learning_rate: 1.e-7
scale_lr: False
max_train_samples: -1
# max_train_steps takes priority over num_train_epochs.
max_train_steps: 800
seed: 0
output_dir: 'sd-model-finetuned'
tensorboard_dir: 'gs://shahrokhi-maxdiffusion-v5'
per_device_batch_size: 1

cosine_learning_rate_final_fraction: 0.1
warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of training steps.
# However, you may choose a longer schedule (learning_rate_schedule_steps > max_train_steps), in which case
# training will end before the learning rate has fully decayed. Or you may choose a shorter schedule,
# where the remaining steps run with a learning rate of 0.
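# A worked example under the settings above (assuming linear warmup followed by cosine decay):
# with max_train_steps: 800, learning_rate_schedule_steps: -1 (i.e. 800) and warmup_steps_fraction: 0.1,
# the learning rate warms up over the first 0.1 * 800 = 80 steps, then follows a cosine curve down to
# cosine_learning_rate_final_fraction * learning_rate = 0.1 * 1e-7 = 1e-8 by step 800.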

# AdamW optimizer parameters
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_weight_decay: 1.e-2 # AdamW weight decay
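# For reference, a sketch of the AdamW update these parameters drive (standard formulation, not
# maxdiffusion-specific code): with gradient g_t,
#   m_t = b1 * m_{t-1} + (1 - b1) * g_t
#   v_t = b2 * v_{t-1} + (1 - b2) * g_t^2
#   update = m_hat_t / (sqrt(v_hat_t) + eps)   # eps added outside the square root
#   theta_t = theta_{t-1} - lr * (update + weight_decay * theta_{t-1})
# where m_hat_t and v_hat_t are the bias-corrected moments.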

max_grad_norm: 1.0

enable_profiler: True

# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
guidance_scale: 7.5
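# For context (standard classifier-free guidance; the exact implementation may differ):
# noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond),
# so guidance_scale: 7.5 pushes the prediction toward the prompt-conditioned direction, and the
# negative_prompt is used for the unconditioned branch.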
num_inference_steps: 30
# Note: 'seed' also appears in the training-loop section above; with most YAML loaders the
# later value below is the one that takes effect.
seed: 47