# Sleds/TFPageRotation/Trainer.py
# (310 lines, 9.2 KiB, Python)

import os
import sys
import cv2
import keras
import pickle
import random
import argparse
import numpy as np
from datetime import datetime as dt
from matplotlib import pyplot as plt
import ConsoleUtils
# Program metadata, used by the help/version printers and the console header.
prg_name = 'PageRotTrainer'
prg_desc = 'Trains a Keras wrapped TensorFlow classifier to detect the orientation of pages.'
prg_vers = '0.1.0'
prg_date = '2017/10/02'
prg_auth = 'Chris Diesch <cdiesch@sequencelogic.net>'
usage = 'Trainer.py [Options...] TRAIN_DATA OUTPUT_ROOT'

# argparse's built-in -h is disabled so a custom help action can be registered.
parser = argparse.ArgumentParser(prog=prg_name,
                                 description=prg_desc,
                                 add_help=False,
                                 usage=usage)
printer = ConsoleUtils.SLPrinter(prg_name)

# Expose only CUDA devices 0 and 1 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# Class names; a label's position in this list is its one-hot class index.
_LABELS = ['Not-Rotated', 'Right', 'Upside-Down', 'Left']

# One-hot vectors for each class, in the same order as _LABELS.
_NOT_ROT, _RIGHT, _UP_DOWN, _LEFT = ([1, 0, 0, 0],
                                     [0, 1, 0, 0],
                                     [0, 0, 1, 0],
                                     [0, 0, 0, 1])

# Model input size in pixels.
_IMG_X, _IMG_Y = 500, 500

_num_batches = 0

# Placeholder for parsed command-line arguments (never assigned in the visible code).
FLAGS = None
def _print_help():
    """Print the usage/help text for this program.

    If ``sys.stdout`` is currently the module-level ``printer``, it is switched
    back to ``printer.old_stdout`` first and is left that way afterwards.
    """
    if sys.stdout == printer:
        sys.stdout = printer.old_stdout
    help_lines = [
        '%s Version: %s' % (prg_name, prg_vers),
        prg_desc,
        '',
        'Batching:',
        ' -b, --batch-size SIZE The number of training inputs to put in each batch (default: 100).',
        ' -s, --shuffle Shuffles the batches after creating them.',
        '',
        'Miscellaneous:',
        ' -h, --help Prints the help message.',
        ' -V, --version Prints the version info.',
        '',
        'Version: %s' % prg_vers,
        'Date: %s' % prg_date,
        'Author: %s' % prg_auth,
        '',
    ]
    for line in help_lines:
        print(line)
def _print_vers():
    """Print the program name, version and release date.

    If ``sys.stdout`` is currently the module-level ``printer``, it is switched
    back to ``printer.old_stdout`` first and is left that way afterwards.
    """
    if sys.stdout == printer:
        sys.stdout = printer.old_stdout
    version_lines = ('%s' % prg_name,
                     'Version: %s' % prg_vers,
                     'Date: %s' % prg_date)
    for line in version_lines:
        print(line)
def _make_args():
    """Register this program's command-line arguments on the module-level ``parser``.

    Registered arguments:
        -b/--batch-size SIZE     int, default 100.
        -s/--shuffle             boolean flag.
        -h/--help, -V/--version  handled by ``ConsoleUtils.CustomPrintAction``.
        train_data, output_root  required positionals.
    """
    # Batching arguments
    parser.add_argument('-b', '--batch-size', type=int, default=100)
    parser.add_argument('-s', '--shuffle', action='store_true')
    # Miscellaneous arguments. Pass the printer functions themselves, not the
    # result of calling them. Calling them here would print the help/version
    # text while the parser is being built and hand ``None`` to the action.
    # (Presumably CustomPrintAction invokes print_fn when the option is seen;
    # verify in ConsoleUtils.)
    parser.add_argument('-h', '--help', action=ConsoleUtils.CustomPrintAction, print_fn=_print_help)
    parser.add_argument('-V', '--version', action=ConsoleUtils.CustomPrintAction, print_fn=_print_vers)
    # Required (positional) arguments
    parser.add_argument('train_data')
    parser.add_argument('output_root')
def _to_one_hot(label):
    """Return the one-hot encoding of a rotation label name.

    The class index is the label's position in ``_LABELS``; a name that is not
    in ``_LABELS`` raises ``ValueError`` (from ``list.index``).
    """
    class_count = len(_LABELS)
    return keras.utils.to_categorical(_LABELS.index(label), class_count)
def _load_records(records_dir):
records = {}
print('Loading record files from "%s"' % records_dir)
for file in os.listdir(records_dir):
rot_name = file[:-4]
if file.endswith('pkl'):
records_file = os.path.join(records_dir, file)
with open(records_file, 'rb') as reader:
data = pickle.load(reader)
records[rot_name] = data
not_rotated = records['Not-Rotated']
rotated_right = records['Right']
up_down = records['Upside-Down']
rotated_left = records['Left']
print('Loaded records successfully.')
del records
num_not_rot = len(not_rotated)
num_right = len(rotated_right)
num_up_down = len(up_down)
num_left = len(rotated_left)
total = num_not_rot + num_right + num_up_down + num_left
per_not = float(num_not_rot * 100/total)
per_right = float(num_right * 100/total)
per_up_down = float(num_up_down * 100/total)
per_left = float(num_left * 100/total)
print('Loaded %d pages:' % total)
print(' %d (%.2f%%) not rotated' % (num_not_rot, per_not))
print(' %d (%.2f%%) rotated right' % (num_right, per_right))
print(' %d (%.2f%%) upside-down' % (num_up_down, per_up_down))
print(' %d (%.2f%%) rotated left' % (num_left, per_left))
return not_rotated, rotated_right, up_down, rotated_left
def _make_batch(not_rot, right_rot, up_down, left_rot, shuffle=True):
    """Combine the four per-class record lists into ``(x, y)`` NumPy arrays.

    Args:
        not_rot, right_rot, up_down, left_rot: iterables of record dicts. Each
            record must provide an ``'image'`` entry and a ``'rotation_dir'``
            label name accepted by ``_to_one_hot``.
        shuffle: shuffle the combined records before building the arrays.
            Defaults to True, matching the previous unconditional shuffle.

    Returns:
        ``(batch_x, batch_y)``: an ``np.array`` of the images and an
        ``np.array`` of length-4 one-hot label vectors, both in the same order.
    """
    # Build a fresh list. The previous ``batch = not_rot; batch += ...``
    # extended and shuffled the caller's ``not_rot`` list in place.
    batch = list(not_rot) + list(right_rot) + list(up_down) + list(left_rot)
    if shuffle:
        random.shuffle(batch)
    batch_x = np.array([record['image'] for record in batch])
    # Flatten each encoding to shape (4,). Some Keras versions may return
    # (1, 4) from to_categorical for a scalar index.
    batch_y = np.array([np.reshape(_to_one_hot(record['rotation_dir']), 4)
                        for record in batch])
    return batch_x, batch_y
def rmse(y_pred, y_true):
    """Root-mean-squared error reduced over the last axis, usable as a Keras metric.

    Keras calls metrics as ``metric(y_true, y_pred)``, the opposite order to
    these parameter names. That is harmless here because the squared
    difference is symmetric.
    """
    backend = keras.backend
    # Square the error before averaging over the last axis, then take the root.
    # The previous version never squared the error. It also took an axis-less
    # mean first, which collapses everything to a scalar.
    return backend.sqrt(backend.mean(backend.square(y_pred - y_true), axis=-1))
def _build_model():
    """Build and compile the page-orientation CNN.

    Architecture: three Conv2D(3x3) -> ReLU -> MaxPooling2D(2x2) stages with
    32, 64 and 64 filters, then global average pooling, then a Dense layer
    with one unit per entry in ``_LABELS`` and a softmax. The model is compiled
    with SGD (learning rate 0.03) and categorical cross-entropy, and reports
    accuracy.

    Returns:
        The compiled ``keras.models.Sequential`` model.
    """
    print('Building model')
    model = keras.models.Sequential()
    # Single-channel input. Assuming Keras' default channels_last layout, the
    # shape is (rows, cols, channels) = (height, width, 1).
    model.add(keras.layers.InputLayer(input_shape=(_IMG_Y, _IMG_X, 1)))
    for filters in (32, 64, 64):
        model.add(keras.layers.Conv2D(filters=filters, kernel_size=(3, 3)))
        model.add(keras.layers.Activation('relu'))
        model.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(keras.layers.GlobalAveragePooling2D())
    # One output per orientation class.
    model.add(keras.layers.Dense(len(_LABELS)))
    model.add(keras.layers.Activation('softmax'))
    # NOTE(review): newer Keras releases name this argument `learning_rate`.
    optimizer = keras.optimizers.SGD(lr=0.03)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    print('Model Summary:')
    print('')
    # summary() prints itself and returns None; wrapping it in print() would
    # emit a stray "None" line.
    model.summary()
    return model
# def _show_img(img, name):
# image = img.reshape([img.shape[0], img.shape[1]])
# file_name = '/run/media/cdiesch/Slow_SSD/Tests/tfPageRot/Input/%s.jpg' % name
# cv2.imwrite(file_name, image)
# print('%s size: %dx%d' % (name, image.shape[0], image.shape[1]))
def _plot_history(history, out_path):
    """Save training/validation accuracy and loss curves from a Keras History to ``out_path``."""
    metrics = history.history
    # Keras < 2.3 logs accuracy as 'acc'/'val_acc'; later versions use
    # 'accuracy'/'val_accuracy'.
    acc_key = 'acc' if 'acc' in metrics else 'accuracy'

    figure = plt.figure(figsize=(10, 4), dpi=300)
    plt.suptitle('Training Information')

    plt.subplot(121)
    plt.plot(metrics[acc_key], label='Training Accuracy', color='g')
    plt.plot(metrics['val_' + acc_key], label='Validation Accuracy', color='b')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.subplot(122)
    plt.plot(metrics['loss'], label='Training Loss', color='r')
    plt.plot(metrics['val_loss'], label='Validation Loss', color='y')
    plt.title('Crossentropy Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplots_adjust(left=0.2, wspace=0.8, top=0.8)
    plt.savefig(out_path)
    # Release the figure; pyplot keeps created figures alive until closed.
    plt.close(figure)


def main():
    """Train the page-rotation classifier on fixed data paths and save the results.

    Loads the train and test record sets, then trains for 750 epochs with batch
    size 48, using the test set as validation data. It then writes a
    training-curve PNG and the model file to the output directory, both named
    with the current timestamp.
    """
    train_dir = r'/run/media/cdiesch/Slow_SSD/Tests/tfPageRot/Input/2017.10.10-12.53.40/train-data'
    test_dir = r'/run/media/cdiesch/Slow_SSD/Tests/tfPageRot/Input/2017.10.10-12.53.40/test-data'

    no, right, up_down, left = _load_records(train_dir)
    train_x, train_y = _make_batch(no, right, up_down, left)
    no, right, up_down, left = _load_records(test_dir)
    test_x, test_y = _make_batch(no, right, up_down, left)
    # Drop the raw record lists; only the stacked arrays are needed from here on.
    del no, right, up_down, left

    print('Training X shape: %s' % str(train_x.shape))
    print('Training Y shape: %s' % str(train_y.shape))

    model = _build_model()
    print('Training model')
    history = model.fit(train_x, train_y, batch_size=48, epochs=750,
                        validation_data=(test_x, test_y), verbose=2)
    printer.write_line_break()
    print('Done Training')

    file_name = dt.now().strftime('%Y.%m.%d-%H.%M.%S')
    _plot_history(history, '/run/media/cdiesch/Slow_SSD/Tests/tfPageRot/Output/Test-%s.png' % file_name)
    print('Saving model')
    model.save('/run/media/cdiesch/Slow_SSD/Tests/tfPageRot/Output/Test-%s.mdl' % file_name)
if __name__ == '__main__':
    # Print the banner to the original stdout, then route all further output
    # through the module-level printer before running the training.
    banner = ConsoleUtils.get_header(prg_name, prg_vers, prg_date, prg_auth)
    print(banner)
    sys.stdout = printer
    main()