# Source: Sleds/TFPageRotation/CreateData.py (Python; 249 lines, 7.3 KiB)

import os
import re
import cv2
import sys
import time
import math
import datetime
import pickle
import random
import collections
import numpy as np
import ConsoleUtils
# Program metadata; name, version, date and author are passed to
# ConsoleUtils.get_header() by the script entry point.
prg_name = 'PageRotation'
prg_desc = 'Randomly rotates pages to create a data-set to train on.'
prg_vers = '0.1.0'
prg_date = '2017/10/02'
prg_auth = 'Chris Diesch <cdiesch@sequencelogic.net>'

# Console writer; the script entry point installs it as sys.stdout.
printer = ConsoleUtils.SLPrinter(prg_name)

# A rotation class: a human-readable label plus the base angle (in degrees)
# that is handed to cv2.getRotationMatrix2D.
Rotation = collections.namedtuple('Rotation', ['direction', 'deg'])

_NO_ROTATE = Rotation(direction='Not-Rotated', deg=0)
_RIGHT = Rotation(direction='Right', deg=270)
_UPSIDE_DOWN = Rotation(direction='Upside-Down', deg=180)
_LEFT = Rotation(direction='Left', deg=90)

# Lookup table from the randomly drawn index (0..3) to its rotation class.
_ROTATE_VALS = dict(enumerate((_NO_ROTATE, _RIGHT, _UPSIDE_DOWN, _LEFT)))
_NUM_ROTS = len(_ROTATE_VALS)

_DEG = u'\N{DEGREE SIGN}'

# Running tallies of how often each rotation class has been drawn; they are
# updated by _get_rotation() and never reset.
_NUM_NONE = _NUM_LEFT = _NUM_RIGHT = _NUM_UPSIDE_DOWN = _NUM_TOTAL = 0

# Width (x) and height (y) in pixels of every image stored in a record.
_IMG_X, _IMG_Y = 500, 500
def _get_rotation():
    """Draw one rotation class at random and tally it in the module counters.

    Side effects:
        Increments the module-level counter matching the drawn class
        (``_NUM_NONE``, ``_NUM_RIGHT``, ``_NUM_UPSIDE_DOWN`` or ``_NUM_LEFT``)
        and ``_NUM_TOTAL``.

    Returns:
        The drawn ``Rotation`` named tuple taken from ``_ROTATE_VALS``.

    Raises:
        ValueError: If the drawn index is not a key of ``_ROTATE_VALS``.
    """
    global _NUM_NONE, _NUM_RIGHT, _NUM_UPSIDE_DOWN, _NUM_LEFT, _NUM_TOTAL

    # Derive the bound from the table itself so the draw and the error message
    # cannot drift out of sync with the classes that actually exist.
    max_index = len(_ROTATE_VALS) - 1
    rotate_dir = random.randint(0, max_index)
    if rotate_dir not in _ROTATE_VALS:
        raise ValueError('The value of rotate_dir is invalid (Valid range: [0, %d], Given: %d).'
                         % (max_index, rotate_dir))

    res = _ROTATE_VALS[rotate_dir]
    if res == _NO_ROTATE:
        _NUM_NONE += 1
    elif res == _RIGHT:
        _NUM_RIGHT += 1
    elif res == _UPSIDE_DOWN:
        _NUM_UPSIDE_DOWN += 1
    elif res == _LEFT:
        _NUM_LEFT += 1
    _NUM_TOTAL += 1
    return res
def _load_images(image_root, ratio=.75, pattern=r'page[0-9]{5}\.png'):
    """List the page images in a directory and split them into train/test sets.

    Args:
        image_root: Directory whose entries are listed (non-recursively).
        ratio: Fraction of the shuffled files that goes to the training list.
        pattern: Regular expression applied to each entry name with
            ``re.match``, i.e. anchored at the start of the name only.

    Returns:
        A ``(train_files, test_files)`` tuple of lists of paths. Each path is
        ``image_root`` joined with a matching entry name. The order is random.
    """
    candidates = []
    for entry in os.listdir(image_root):
        if re.match(pattern, entry):
            candidates.append(os.path.join(image_root, entry))

    # Shuffle before cutting so the split is a random sample of the pages.
    random.shuffle(candidates)
    print('Found %d valid image files from "%s"' % (len(candidates), image_root))

    cut = round(len(candidates) * ratio)
    return candidates[:cut], candidates[cut:]
def _rotate_and_scale(image, rot_deg):
    """Rotate a gray-scale image without clipping and scale it to the record size.

    Args:
        image: 2-D gray-scale image array.
        rot_deg: Rotation angle in degrees, passed to ``cv2.getRotationMatrix2D``.

    Returns:
        A ``float32`` array of shape ``(_IMG_Y, _IMG_X, 1)`` holding the pixel
        values divided by 255.
    """
    height, width = image.shape[:2]
    center_w, center_h = width // 2, height // 2

    rot_matrix = cv2.getRotationMatrix2D((center_w, center_h), rot_deg, 1.0)
    cos = np.abs(rot_matrix[0, 0])
    sin = np.abs(rot_matrix[0, 1])

    # Bounding box of the rotated page, so that no corner is cut off.
    new_height = int((height * cos) + (width * sin))
    new_width = int((height * sin) + (width * cos))

    # Shift the transform so the page stays centred on the enlarged canvas.
    rot_matrix[0, 2] += (new_width / 2) - center_w
    rot_matrix[1, 2] += (new_height / 2) - center_h

    image = cv2.warpAffine(image, rot_matrix, (new_width, new_height))
    # Ask for the target size directly; this guarantees the exact shape the
    # reshape below relies on.
    image = cv2.resize(image, (_IMG_X, _IMG_Y))
    # Positional shape argument: the ``newshape`` keyword is deprecated in
    # recent NumPy releases.
    image = np.reshape(image, (_IMG_Y, _IMG_X, 1))

    image = image.astype(dtype='float32')
    image /= float(255)
    return image


def _format_rate(count, seconds):
    """Describe throughput as images/s, or s/image when below one image per second."""
    if seconds > 0:
        rate = count / seconds
    else:
        # The clock did not advance; avoid dividing by zero.
        rate = float('inf') if count else 0.0
    if 0 < rate < 1:
        return '(%.2f s/image)' % (1 / rate)
    return '(%.2f images/s)' % rate


def _read_and_rotate(image_files):
    """Load every image, apply a random rotation and group the results by class.

    Each image is read, converted to gray-scale, rotated by the base angle of
    a randomly drawn rotation class plus a random offset of -10..10 degrees,
    and turned into a record via ``_create_record``.

    Args:
        image_files: Sequence of image file paths.

    Returns:
        A dict mapping each rotation class (``_NO_ROTATE``, ``_RIGHT``,
        ``_UPSIDE_DOWN``, ``_LEFT``) to the list of records drawn for it.

    Raises:
        IOError: If an image file cannot be read by OpenCV.
    """
    records = {_NO_ROTATE: [],
               _RIGHT: [],
               _UPSIDE_DOWN: [],
               _LEFT: []}
    total = len(image_files)
    print('Processing %d images' % total)
    start_time = time.time()

    for i, image_file in enumerate(image_files, start=1):
        # Read first: cv2.imread returns None for unreadable files, and the
        # rotation counters must not be bumped for an image that is not used.
        image = cv2.imread(image_file)
        if image is None:
            raise IOError('Could not read image file "%s".' % image_file)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        rand_rot = _get_rotation()
        # Add a 10 degree variance to every image...
        variance = random.randint(-10, 10)
        rot_deg = rand_rot.deg + variance

        image = _rotate_and_scale(image, rot_deg)
        records[rand_rot].append(_create_record(image, rand_rot, rot_deg))

        if i % 25 == 0:
            cur_time = time.time() - start_time
            rate = float(i / cur_time) if cur_time > 0 else float('inf')
            print(' Processed %d/%d images (%.2f images/s)' % (i, total, rate))

    run_time = time.time() - start_time
    print('Finished processing images in %.4f s. %s' % (run_time, _format_rate(total, run_time)))
    return records
def _create_record(image, rotation, rot_deg):
    """Bundle a processed image with its size and rotation labels.

    Args:
        image: Array-like object whose ``shape`` starts with (height, width).
        rotation: Object exposing the class label as ``direction``.
        rot_deg: The rotation angle, in degrees, that was applied to the image.

    Returns:
        A dict with the keys ``image``, ``height``, ``width``,
        ``rotation_dir`` and ``rotation_deg``.
    """
    shape = image.shape
    rows = shape[0]
    cols = shape[1]
    label = rotation.direction
    return {
        'image': image,
        'height': rows,
        'width': cols,
        'rotation_dir': label,
        'rotation_deg': rot_deg,
    }
def _save_records(rot_recs, out_file):
    """Pickle the records of every rotation class into its own file.

    Args:
        rot_recs: Mapping from a key exposing ``direction`` to the object
            that is pickled for that key.
        out_file: Despite the name, the directory that receives the files;
            each file is named ``<direction>.pkl``.

    Files are opened in binary append mode, so an already existing file gains
    an additional pickle instead of being overwritten.
    """
    for rotation, recs in rot_recs.items():
        target = os.path.join(out_file, '%s.pkl' % rotation.direction)
        print('Saving records to "%s"' % target)
        with open(target, 'a+b') as sink:
            pickle.dump(recs, sink)
        print('Done writing record.')
def _print_rotation_stats(title):
    """Print the share of each rotation class drawn so far, under ``title``."""
    total = _NUM_TOTAL

    def percent(count):
        # Nothing drawn yet (e.g. no input images): report 0% instead of
        # dividing by zero.
        return (float(count) / total) * 100.0 if total else 0.0

    print(title)
    print(' Not Rotated: %.2f%%' % percent(_NUM_NONE))
    print(' Rotated right: %.2f%%' % percent(_NUM_RIGHT))
    print(' Upside-down: %.2f%%' % percent(_NUM_UPSIDE_DOWN))
    print(' Rotated left: %.2f%%' % percent(_NUM_LEFT))


def main(in_dir, out_dir):
    """Build the rotated training and test data-sets and write them to disk.

    Images found in ``in_dir`` are split into a training and a test list.
    Each list is rotated and converted to records, then pickled under
    ``<out_dir>/train-data`` and ``<out_dir>/test-data`` respectively.
    Rotation statistics are printed after each phase.

    Args:
        in_dir: Directory containing the page images.
        out_dir: Directory that receives the output; created if missing.
    """
    train_file_name = os.path.join(out_dir, 'train-data')
    test_file_name = os.path.join(out_dir, 'test-data')
    for directory in (out_dir, train_file_name, test_file_name):
        if not os.path.exists(directory):
            os.makedirs(directory)

    train_files, test_files = _load_images(in_dir)

    print('Loading training data')
    train_records = _read_and_rotate(train_files)
    print('Loaded training data')
    _print_rotation_stats('Training data stats:')
    printer.write_no_prefix('')
    _save_records(train_records, train_file_name)
    # Release the training images before the test images are loaded.
    del train_records

    print('Loading test data')
    test_records = _read_and_rotate(test_files)
    print('Loaded test data')
    print('Saving test data to "%s"' % test_file_name)
    _save_records(test_records, test_file_name)

    printer.write_no_prefix('')
    printer.write_line_break()
    # The counters are never reset, so these figures cover both phases.
    _print_rotation_stats('Full data stats:')
    printer.write_no_prefix('')
if __name__ == '__main__':
    # Usage: CreateData.py [input_root [output_base]]
    # Without arguments the original hard-coded locations are used.
    print(ConsoleUtils.get_header(prg_name, prg_vers, prg_date, prg_auth))
    # From here on every print() is written through `printer`.
    sys.stdout = printer

    cli_args = sys.argv[1:]
    input_root = (cli_args[0] if len(cli_args) > 0
                  else r'/run/media/cdiesch/Slow_SSD/Tests/page-rotation-pdf/page-break/')
    output_base = (cli_args[1] if len(cli_args) > 1
                   else r'/run/media/cdiesch/Slow_SSD/Tests/tfPageRot/Input/')
    # Each run writes into its own time-stamped sub-directory.
    output_root = os.path.join(output_base,
                               datetime.datetime.now().strftime('%Y.%m.%d-%H.%M.%S'))
    main(input_root, output_root)