310 lines
9.2 KiB
Python
310 lines
9.2 KiB
Python
import os
|
|
import sys
|
|
import cv2
|
|
import keras
|
|
import pickle
|
|
import random
|
|
import argparse
|
|
import numpy as np
|
|
from datetime import datetime as dt
|
|
from matplotlib import pyplot as plt
|
|
import ConsoleUtils
|
|
|
|
# Program identification; reused by the argument parser, the help/version
# output and the console printer.
prg_name = 'PageRotTrainer'
prg_desc = 'Trains a Keras wrapped TensorFlow classifier to detect the orientation of pages.'
prg_vers = '0.1.0'
prg_date = '2017/10/02'
prg_auth = 'Chris Diesch <cdiesch@sequencelogic.net>'

usage = 'Trainer.py [Options...] TRAIN_DATA OUTPUT_ROOT'

# add_help=False because -h/--help is registered manually (see _make_args)
# with a custom print action.
parser = argparse.ArgumentParser(
    prog=prg_name,
    description=prg_desc,
    usage=usage,
    add_help=False,
)

# Console writer from the project's ConsoleUtils module, tagged with the
# program name.
printer = ConsoleUtils.SLPrinter(prg_name)
|
|
|
|
# Make only CUDA devices 0 and 1 visible to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# Rotation class names. A label's position in this list is its one-hot index.
_LABELS = ['Not-Rotated', 'Right', 'Upside-Down', 'Left']

# One-hot vectors for each class, in _LABELS order
# ([1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]).
_NOT_ROT, _RIGHT, _UP_DOWN, _LEFT = [
    [1 if col == row else 0 for col in range(len(_LABELS))]
    for row in range(len(_LABELS))
]

# Input image size in pixels (used for the model's input shape).
_IMG_X = _IMG_Y = 500

_num_batches = 0

# Intended to hold parsed command-line flags; no other assignment to it
# appears in this file.
FLAGS = None
|
|
|
|
|
|
def _print_help():
    """Print the program's help text.

    If ``sys.stdout`` is currently the ``printer`` wrapper, it is first reset
    to ``printer.old_stdout`` so the help text bypasses the wrapper.
    """
    if sys.stdout == printer:
        sys.stdout = printer.old_stdout

    help_lines = (
        '%s Version: %s' % (prg_name, prg_vers),
        prg_desc,
        '',
        'Batching:',
        ' -b, --batch-size SIZE The number of training inputs to put in each batch (default: 100).',
        ' -s, --shuffle Shuffles the batches after creating them.',
        '',
        'Miscellaneous:',
        ' -h, --help Prints the help message.',
        ' -V, --version Prints the version info.',
        '',
        'Version: %s' % prg_vers,
        'Date: %s' % prg_date,
        'Author: %s' % prg_auth,
        '',
    )
    for line in help_lines:
        print(line)
|
|
|
|
|
|
def _print_vers():
    """Print the program name, version and release date.

    Like ``_print_help``, this first swaps ``sys.stdout`` back to
    ``printer.old_stdout`` when the ``printer`` wrapper is active.
    """
    if sys.stdout == printer:
        sys.stdout = printer.old_stdout

    version_lines = ('%s' % prg_name,
                     'Version: %s' % prg_vers,
                     'Date: %s' % prg_date)
    for line in version_lines:
        print(line)
|
|
|
|
|
|
def _make_args():
    """Register this program's command-line arguments on the module ``parser``.

    Options:
        -b/--batch-size SIZE  int, default 100.
        -s/--shuffle          boolean flag.
        -h/--help             handled by ConsoleUtils.CustomPrintAction with
                              ``_print_help`` as its print function.
        -V/--version          handled by ConsoleUtils.CustomPrintAction with
                              ``_print_vers`` as its print function.
    Positional (required): train_data, output_root.
    """
    # Batching arguments
    parser.add_argument('-b', '--batch-size', type=int, default=100)
    parser.add_argument('-s', '--shuffle', action='store_true')

    # Miscellaneous arguments.
    # Pass the print functions themselves, not their results. The previous
    # code called them here (print_fn=_print_help()), which printed the help
    # and version text while the parser was being built and handed None to
    # the action.
    # NOTE(review): presumably CustomPrintAction invokes print_fn when the
    # flag is seen; confirm against ConsoleUtils.
    parser.add_argument('-h', '--help', action=ConsoleUtils.CustomPrintAction, print_fn=_print_help)
    parser.add_argument('-V', '--version', action=ConsoleUtils.CustomPrintAction, print_fn=_print_vers)

    # Required (positional) arguments
    parser.add_argument('train_data')
    parser.add_argument('output_root')
|
|
|
|
|
|
def _to_one_hot(label):
    """Return the one-hot encoding of a rotation *label*.

    The hot position is ``_LABELS.index(label)`` and the vector length is
    ``len(_LABELS)``. An unknown label raises ValueError from ``list.index``.
    """
    return keras.utils.to_categorical(_LABELS.index(label), len(_LABELS))
|
|
|
|
|
|
def _load_records(records_dir):
    """Load the pickled page records for each rotation class.

    Every file in *records_dir* with a ``.pkl`` extension is unpickled and
    stored under its base name (``Right.pkl`` -> ``'Right'``). All other files
    are ignored. The four class files ``Not-Rotated.pkl``, ``Right.pkl``,
    ``Upside-Down.pkl`` and ``Left.pkl`` must be present. The count and
    percentage of records per class is printed.

    SECURITY: ``pickle.load`` can execute arbitrary code while loading. Only
    point this at record files produced by a trusted source.

    Args:
        records_dir: Directory containing the ``.pkl`` record files.

    Returns:
        Tuple ``(not_rotated, rotated_right, up_down, rotated_left)`` holding
        the unpickled contents of the four class files.

    Raises:
        KeyError: If any of the four class files is missing.
    """
    class_names = ('Not-Rotated', 'Right', 'Upside-Down', 'Left')
    records = {}

    print('Loading record files from "%s"' % records_dir)

    for file_name in os.listdir(records_dir):
        rot_name, ext = os.path.splitext(file_name)
        # Require a real ".pkl" extension. A bare endswith('pkl') test also
        # matched names like "notapkl" and then chopped 4 chars off them.
        if ext != '.pkl':
            continue
        with open(os.path.join(records_dir, file_name), 'rb') as reader:
            records[rot_name] = pickle.load(reader)

    missing = [name + '.pkl' for name in class_names if name not in records]
    if missing:
        raise KeyError('Missing record file(s) in "%s": %s' % (records_dir, ', '.join(missing)))

    not_rotated, rotated_right, up_down, rotated_left = [records[name] for name in class_names]

    print('Loaded records successfully.')

    counts = [len(not_rotated), len(rotated_right), len(up_down), len(rotated_left)]
    total = sum(counts)
    # 100.0 keeps the division floating-point on Python 2 as well. An empty
    # data set reports 0% instead of raising ZeroDivisionError.
    percents = [100.0 * count / total if total else 0.0 for count in counts]

    print('Loaded %d pages:' % total)
    descriptions = ('not rotated', 'rotated right', 'upside-down', 'rotated left')
    for count, percent, description in zip(counts, percents, descriptions):
        print(' %d (%.2f%%) %s' % (count, percent, description))

    return not_rotated, rotated_right, up_down, rotated_left
|
|
|
|
|
|
def _make_batch(not_rot, right_rot, up_down, left_rot, shuffle=True):
    """Combine the four per-class record lists into model inputs and labels.

    Each record is read for its ``'image'`` and ``'rotation_dir'`` keys. The
    label is one-hot encoded with ``_to_one_hot``.

    Args:
        not_rot, right_rot, up_down, left_rot: Sequences of record dicts, one
            per rotation class. They are not modified.
        shuffle: If True (the default, matching the previous unconditional
            behaviour), the combined records are shuffled with
            ``random.shuffle`` before conversion.

    Returns:
        ``(batch_x, batch_y)``: ``batch_x`` is ``np.array`` of the records'
        images. ``batch_y`` is an array with one 4-element one-hot row per
        record, in the same order.
    """
    # Build a fresh list. The old `batch = not_rot; batch += ...` aliased the
    # caller's list and extended it in place.
    batch = list(not_rot) + list(right_rot) + list(up_down) + list(left_rot)

    if shuffle:
        random.shuffle(batch)

    batch_x = np.array([record['image'] for record in batch])
    # Flatten each one-hot label to shape (4,). Presumably needed because
    # older keras to_categorical returns (1, 4) for a scalar index.
    batch_y = np.array([np.reshape(_to_one_hot(record['rotation_dir']), 4) for record in batch])

    return batch_x, batch_y
|
|
|
|
|
|
def rmse(y_pred, y_true):
    """Root-mean-square error metric for Keras.

    Returns ``sqrt(mean((y_pred - y_true) ** 2))`` taken over the last axis,
    i.e. one RMSE value per sample. The previous version averaged the signed,
    unsquared differences over all elements (a misplaced parenthesis), which
    could go negative and make sqrt return NaN.

    Note: Keras calls metric functions as ``fn(y_true, y_pred)``, so the first
    argument actually receives the targets. RMSE is symmetric in its
    arguments, so the result is the same. The parameter names are kept for
    compatibility with keyword callers.
    """
    backend = keras.backend
    return backend.sqrt(backend.mean(backend.square(y_pred - y_true), axis=-1))
|
|
|
|
|
|
def _build_model():
    """Build and compile the page-orientation CNN.

    Architecture:
        - Input of shape ``(_IMG_X, _IMG_Y, 1)``.
        - Three Conv2D(3x3) -> ReLU -> MaxPooling2D(2x2) stages with 32, 64
          and 64 filters.
        - Global average pooling.
        - A Dense layer with one unit per entry in ``_LABELS``, followed by
          softmax.

    Compiled with SGD (learning rate 0.03), categorical cross-entropy loss and
    an accuracy metric. The model summary is printed.

    Returns:
        The compiled ``keras.models.Sequential`` model.
    """
    print('Building model')
    layers = keras.layers
    model = keras.models.Sequential()

    # Input
    model.add(layers.InputLayer(input_shape=(_IMG_X, _IMG_Y, 1)))

    # Feature extractor: identical conv -> relu -> pool stages. Layers are
    # added in the same order as before, so auto-generated names match.
    for n_filters in (32, 64, 64):
        model.add(layers.Conv2D(filters=n_filters, kernel_size=(3, 3)))
        model.add(layers.Activation('relu'))
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))

    # Classifier head: one output per rotation class.
    model.add(layers.GlobalAveragePooling2D())
    model.add(layers.Dense(len(_LABELS)))
    model.add(layers.Activation('softmax'))

    # 'lr' is the argument name older Keras releases expect. Newer releases
    # renamed it 'learning_rate'; keep in sync with the installed version.
    optimizer = keras.optimizers.SGD(lr=0.03)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

    print('Model Summary:')
    print('')
    # summary() prints the table itself and returns None. Wrapping it in
    # print() used to emit a stray "None" line.
    model.summary()

    return model
|
|
|
|
|
|
# def _show_img(img, name):
|
|
# image = img.reshape([img.shape[0], img.shape[1]])
|
|
# file_name = '/run/media/cdiesch/Slow_SSD/Tests/tfPageRot/Input/%s.jpg' % name
|
|
# cv2.imwrite(file_name, image)
|
|
# print('%s size: %dx%d' % (name, image.shape[0], image.shape[1]))
|
|
|
|
|
|
def _plot_history(history, out_file):
    """Save training/validation accuracy and loss curves from *history* to *out_file*."""
    metrics = history.history
    # Older Keras logs accuracy as 'acc'/'val_acc'; newer releases use
    # 'accuracy'/'val_accuracy'.
    acc_key = 'acc' if 'acc' in metrics else 'accuracy'

    fig = plt.figure(figsize=(10, 4), dpi=300)
    plt.suptitle('Training Information')

    plt.subplot(121)
    plt.plot(metrics[acc_key], label='Training Accuracy', color='g')
    plt.plot(metrics['val_' + acc_key], label='Validation Accuracy', color='b')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.subplot(122)
    plt.plot(metrics['loss'], label='Training Loss', color='r')
    plt.plot(metrics['val_loss'], label='Validation Loss', color='y')
    plt.title('Crossentropy Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplots_adjust(left=0.2, wspace=0.8, top=0.8)
    plt.savefig(out_file)
    # Release the figure once it is written.
    plt.close(fig)


def main(data_root='/run/media/cdiesch/Slow_SSD/Tests/tfPageRot/Input/2017.10.10-12.53.40',
         output_dir='/run/media/cdiesch/Slow_SSD/Tests/tfPageRot/Output'):
    """Train the page-rotation classifier, then save its training plot and model.

    Records are loaded from ``<data_root>/train-data`` (training set) and
    ``<data_root>/test-data`` (passed to ``fit`` as validation data). After
    training, two files are written to *output_dir*: an accuracy/loss plot
    ``Test-<timestamp>.png`` and the model ``Test-<timestamp>.mdl``.

    Args:
        data_root: Directory holding the ``train-data`` and ``test-data``
            record directories. Defaults to the previously hard-coded path.
        output_dir: Directory for the plot and model. Created if missing.
            Defaults to the previously hard-coded path.

    NOTE(review): the TRAIN_DATA / OUTPUT_ROOT command-line arguments are
    declared on ``parser`` but not parsed here yet.
    """
    # Fail fast: make sure the output location exists before the long
    # training run instead of failing at save time.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    no, right, up_down, left = _load_records(os.path.join(data_root, 'train-data'))
    train_x, train_y = _make_batch(no, right, up_down, left)

    no, right, up_down, left = _load_records(os.path.join(data_root, 'test-data'))
    test_x, test_y = _make_batch(no, right, up_down, left)

    # The per-class record lists are no longer needed once batched.
    del no, right, up_down, left

    print('Training X shape: %s' % str(train_x.shape))
    print('Training Y shape: %s' % str(train_y.shape))

    model = _build_model()
    print('Training model')
    history = model.fit(train_x, train_y, batch_size=48, epochs=750,
                        validation_data=(test_x, test_y), verbose=2)
    printer.write_line_break()
    print('Done Training')

    file_name = dt.now().strftime('%Y.%m.%d-%H.%M.%S')
    _plot_history(history, os.path.join(output_dir, 'Test-%s.png' % file_name))

    print('Saving model')
    model.save(os.path.join(output_dir, 'Test-%s.mdl' % file_name))
|
|
|
|
|
|
if __name__ == '__main__':
    # Emit the banner on the original stdout, then route all further output
    # through the ConsoleUtils printer before training starts.
    banner = ConsoleUtils.get_header(prg_name, prg_vers, prg_date, prg_auth)
    print(banner)
    sys.stdout = printer
    main()
|
|
|