|
| 1 | +#! /usr/bin/env python |
| 2 | +# coding=utf-8 |
| 3 | +#================================================================ |
| 4 | +# Copyright (C) 2020 * Ltd. All rights reserved. |
| 5 | +# |
| 6 | +# Editor : VIM |
| 7 | +# File name : multi_gpu_train.py |
| 8 | +# Author : YunYang1994 |
| 9 | +# Created date: 2020-02-02 22:14:30 |
| 10 | +# Description : |
| 11 | +# |
| 12 | +#================================================================ |
| 13 | + |
| 14 | +import tensorflow as tf |
| 15 | +from tqdm import tqdm |
| 16 | +from tensorflow.keras.preprocessing.image import ImageDataGenerator |
| 17 | +from tensorflow.keras import applications |
| 18 | +from tensorflow.keras.optimizers import SGD |
| 19 | +os.environ["CUDA_VISIBLE_DEVICES"] = "0,2,3" |
| 20 | + |
| 21 | +BATCH_SIZE = 384 # 3 GPU and 128 batch size per GPU |
| 22 | +EPOCHS = 30 |
| 23 | +NUM_CLASS = 10 |
| 24 | +EMB_SIZE = 512 # Embedding Size |
| 25 | +GPU_SIZE = 30 # (G) MemorySIZE per GPU |
| 26 | +IMG_SIZE = 112 # Input Image Size |
| 27 | + |
| 28 | +train_datagen = ImageDataGenerator( |
| 29 | + rescale=1./255, |
| 30 | + shear_range=0.2, |
| 31 | + zoom_range=0.2, |
| 32 | + horizontal_flip=True) |
| 33 | + |
| 34 | +train_generator = train_datagen.flow_from_directory( |
| 35 | + '/home/yyang/mnist/train', |
| 36 | + target_size=(IMG_SIZE, IMG_SIZE), |
| 37 | + batch_size=BATCH_SIZE, |
| 38 | + class_mode='categorical') |
| 39 | + |
| 40 | +tf.debugging.set_log_device_placement(True) |
| 41 | +gpus = tf.config.experimental.list_physical_devices('GPU') |
| 42 | + |
| 43 | +for gpu in gpus: |
| 44 | + tf.config.experimental.set_virtual_device_configuration( |
| 45 | + gpu, [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=GPU_SIZE*1024)] |
| 46 | + ) |
| 47 | +logical_gpus = tf.config.experimental.list_logical_devices('GPU') |
| 48 | +print(len(gpus), "Physical GPU,", len(logical_gpus), "Logical GPUs") |
| 49 | + |
| 50 | +tf.debugging.set_log_device_placement(True) |
| 51 | +strategy = tf.distribute.MirroredStrategy() |
| 52 | + |
| 53 | +# Defining Model |
| 54 | +with strategy.scope(): |
| 55 | + model = applications.mobilenet_v2.MobileNetV2(include_top=False, weights=None, |
| 56 | + input_shape=(IMG_SIZE,IMG_SIZE,3)) |
| 57 | + x = tf.keras.layers.Input(shape=(IMG_SIZE,IMG_SIZE,3)) |
| 58 | + y = model(x) |
| 59 | + y = tf.keras.layers.AveragePooling2D()(y) |
| 60 | + y = tf.keras.layers.Flatten()(y) |
| 61 | + y = tf.keras.layers.Dense(EMB_SIZE, activation=None)(y) |
| 62 | + y = tf.keras.layers.Dense(NUM_CLASS, activation='softmax')(y) |
| 63 | + model = tf.keras.models.Model(inputs=x, outputs=y) |
| 64 | + |
| 65 | + optimizer = tf.keras.optimizers.Adam(0.001) |
| 66 | + checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) |
| 67 | + |
| 68 | +# Defining Loss and Metrics |
| 69 | +with strategy.scope(): |
| 70 | + loss_object = tf.keras.losses.CategoricalCrossentropy( |
| 71 | + reduction=tf.keras.losses.Reduction.NONE |
| 72 | + ) |
| 73 | + def compute_loss(labels, predictions): |
| 74 | + per_example_loss = loss_object(labels, predictions) |
| 75 | + return tf.nn.compute_average_loss(per_example_loss, global_batch_size=BATCH_SIZE) |
| 76 | + |
| 77 | + train_accuracy = tf.keras.metrics.CategoricalAccuracy( |
| 78 | + name='train_accuracy' |
| 79 | + ) |
| 80 | + |
| 81 | +# Defining Training Step |
| 82 | +with strategy.scope(): |
| 83 | + def train_step(inputs): |
| 84 | + images, labels = inputs |
| 85 | + |
| 86 | + with tf.GradientTape() as tape: |
| 87 | + predictions = model(images, training=True) |
| 88 | + loss = compute_loss(labels, predictions) |
| 89 | + |
| 90 | + gradients = tape.gradient(loss, model.trainable_variables) |
| 91 | + optimizer.apply_gradients(zip(gradients, model.trainable_variables)) |
| 92 | + train_accuracy.update_state(labels, predictions) |
| 93 | + return loss |
| 94 | + |
| 95 | +# Defining Training Loops |
| 96 | +with strategy.scope(): |
| 97 | + @tf.function |
| 98 | + def distributed_train_step(dataset_inputs): |
| 99 | + per_replica_losses = strategy.experimental_run_v2(train_step, |
| 100 | + args=(dataset_inputs,)) |
| 101 | + return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, |
| 102 | + axis=None) |
| 103 | + for epoch in range(EPOCHS): |
| 104 | + batchs_per_epoch = len(train_generator) |
| 105 | + train_dataset = iter(train_generator) |
| 106 | + |
| 107 | + with tqdm(total=batchs_per_epoch, |
| 108 | + desc="Epoch %2d/%2d" %(epoch+1, EPOCHS)) as pbar: |
| 109 | + for _ in range(batchs_per_epoch): |
| 110 | + batch_loss = distributed_train_step(next(train_dataset)) |
| 111 | + batch_acc = train_accuracy.result() |
| 112 | + pbar.set_postfix({'loss' : '%.4f' %batch_loss, |
| 113 | + 'accuracy' : '%.6f' %batch_acc}) |
| 114 | + train_accuracy.reset_states() |
| 115 | + pbar.update(1) |
| 116 | + |
| 117 | +model.save_weights("model.h5") |
| 118 | + |
| 119 | + |
0 commit comments