"""PageRotation: randomly rotate page images to build an orientation data-set.

Each input page is given one of four coarse rotations (none, right, upside-down,
left) plus up to +/-10 degrees of random jitter. It is then converted to a
fixed-size grayscale float32 array and pickled, one file per rotation class,
under separate train/test directories.
"""
import collections
import datetime
import math
import os
import pickle
import random
import re
import sys
import time

import cv2
import numpy as np

import ConsoleUtils

prg_name = 'PageRotation'
prg_desc = 'Randomly rotates pages to create a data-set to train on.'
prg_vers = '0.1.0'
prg_date = '2017/10/02'
prg_auth = 'Chris Diesch '

printer = ConsoleUtils.SLPrinter(prg_name)

# A coarse rotation class: a human-readable direction label and the base angle
# (degrees) passed to cv2.getRotationMatrix2D.
Rotation = collections.namedtuple('Rotation', 'direction deg')
_NO_ROTATE = Rotation('Not-Rotated', 0)
_RIGHT = Rotation('Right', 270)
_UPSIDE_DOWN = Rotation('Upside-Down', 180)
_LEFT = Rotation('Left', 90)
# Maps the integer drawn by _get_rotation to its rotation class.
_ROTATE_VALS = {0: _NO_ROTATE, 1: _RIGHT, 2: _UPSIDE_DOWN, 3: _LEFT}
_NUM_ROTS = len(_ROTATE_VALS)
_DEG = u'\N{DEGREE SIGN}'

# Running counts of how often each rotation class was drawn (train + test).
_NUM_NONE = 0
_NUM_LEFT = 0
_NUM_RIGHT = 0
_NUM_UPSIDE_DOWN = 0
_NUM_TOTAL = 0

# Output image size in pixels (width, height).
_IMG_X = 500
_IMG_Y = 500

# Maximum absolute random jitter (degrees) added on top of the coarse rotation.
_MAX_JITTER_DEG = 10
# How often (in images) to print a progress line.
_PROGRESS_EVERY = 25


def _get_rotation():
    """Draw a random coarse rotation class and update the module-level counters.

    Returns:
        One of the Rotation constants in ``_ROTATE_VALS``.

    Raises:
        ValueError: if the drawn index is not a key of ``_ROTATE_VALS``.
            This is unreachable as long as the keys are 0.._NUM_ROTS-1.
    """
    global _NUM_NONE, _NUM_RIGHT, _NUM_UPSIDE_DOWN, _NUM_LEFT, _NUM_TOTAL
    rotate_dir = random.randint(0, _NUM_ROTS - 1)
    if rotate_dir not in _ROTATE_VALS:
        raise ValueError('The value of rotate_dir is invalid (Valid range: [0, %d], Given: %d).'
                         % (_NUM_ROTS - 1, rotate_dir))
    res = _ROTATE_VALS[rotate_dir]

    if res == _NO_ROTATE:
        _NUM_NONE += 1
    elif res == _RIGHT:
        _NUM_RIGHT += 1
    elif res == _UPSIDE_DOWN:
        _NUM_UPSIDE_DOWN += 1
    elif res == _LEFT:
        _NUM_LEFT += 1
    _NUM_TOTAL += 1

    return res


def _load_images(image_root, ratio=.75, pattern=r'page[0-9]{5}\.png'):
    """List matching image files in *image_root*, shuffle them and split them.

    Args:
        image_root: Directory to scan (not recursive).
        ratio: Fraction of the shuffled files assigned to the training split.
        pattern: Regular expression the *whole* file name must match.

    Returns:
        (train_files, test_files): two lists of full paths.
    """
    name_re = re.compile(pattern)
    # fullmatch (not match) so names such as "page00001.png.bak" are rejected.
    img_files = [os.path.join(image_root, f) for f in os.listdir(image_root)
                 if name_re.fullmatch(f) is not None]
    random.shuffle(img_files)
    print('Found %d valid image files from "%s"' % (len(img_files), image_root))

    split_idx = round(len(img_files) * ratio)
    train_files = img_files[:split_idx]
    test_files = img_files[split_idx:]
    return train_files, test_files


def _rotate_bound(image, angle):
    """Rotate *image* by *angle* degrees about its centre without clipping.

    The output canvas is enlarged to the bounding box of the rotated image.
    Per the OpenCV docs, a positive angle is a counter-clockwise rotation.
    """
    height, width = image.shape[:2]
    center_w, center_h = width // 2, height // 2
    rot_matrix = cv2.getRotationMatrix2D((center_w, center_h), angle, 1.0)

    # |cos| and |sin| of the angle give the size of the rotated bounding box.
    cos = np.abs(rot_matrix[0, 0])
    sin = np.abs(rot_matrix[0, 1])
    new_height = int((height * cos) + (width * sin))
    new_width = int((height * sin) + (width * cos))

    # Shift the translation so the rotated image is centred in the new canvas.
    rot_matrix[0, 2] += (new_width / 2) - center_w
    rot_matrix[1, 2] += (new_height / 2) - center_h

    return cv2.warpAffine(image, rot_matrix, (new_width, new_height))


def _to_model_input(image):
    """Resize a 2-D grayscale image to shape (_IMG_Y, _IMG_X, 1), float32, divided by 255."""
    # Pass an explicit dsize, which is (width, height), so the output size is
    # exact rather than depending on fx/fy float rounding.
    image = cv2.resize(image, (_IMG_X, _IMG_Y))
    image = np.reshape(image, newshape=(_IMG_Y, _IMG_X, 1))
    # Divide by 255 to map 8-bit pixel values into [0, 1].
    return image.astype(dtype='float32') / float(255)


def _read_and_rotate(image_files):
    """Load, randomly rotate, resize and normalize each image.

    For every file, a coarse rotation is drawn with _get_rotation (which also
    updates the module-level counters). A random jitter in
    [-_MAX_JITTER_DEG, _MAX_JITTER_DEG] degrees is added on top of it.

    Args:
        image_files: Sequence of image paths readable by cv2.imread.

    Returns:
        dict mapping each Rotation constant to the list of records (see
        _create_record) whose coarse rotation is that constant.

    Raises:
        IOError: if cv2.imread cannot read one of the files.
    """
    records = {rotation: [] for rotation in _ROTATE_VALS.values()}
    num_images = len(image_files)
    print('Processing %d images' % num_images)
    start_time = time.time()

    for i, image_file in enumerate(image_files, start=1):
        rand_rot = _get_rotation()
        # Add random jitter of up to +/-_MAX_JITTER_DEG degrees to the coarse rotation.
        variance = random.randint(-_MAX_JITTER_DEG, _MAX_JITTER_DEG)
        rot_deg = rand_rot.deg + variance

        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError('Unable to read image file "%s"' % image_file)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        image = _rotate_bound(image, rot_deg)
        image = _to_model_input(image)

        records[rand_rot].append(_create_record(image, rand_rot, rot_deg))

        if i % _PROGRESS_EVERY == 0:
            elapsed = time.time() - start_time
            print(' Processed %d/%d images (%.2f images/s)' % (i, num_images, i / elapsed))

    run_time = time.time() - start_time
    if num_images and run_time > 0:
        rate = num_images / run_time
        if rate >= 1:
            rate_str = '(%.2f images/s)' % rate
        else:
            # Slow runs read better as seconds per image.
            rate_str = '(%.2f s/image)' % (1 / rate)
    else:
        rate_str = '(no images processed)'
    print('Finished processing images in %.4f s. %s' % (run_time, rate_str))
    return records


def _create_record(image, rotation, rot_deg):
    """Bundle a processed image with its size and rotation labels.

    Args:
        image: Array whose first two dimensions are height and width.
        rotation: The coarse Rotation class applied.
        rot_deg: The actual angle applied (coarse angle plus jitter).

    Returns:
        dict with keys 'image', 'height', 'width', 'rotation_dir', 'rotation_deg'.
    """
    return {'image': image,
            'height': image.shape[0],
            'width': image.shape[1],
            'rotation_dir': rotation.direction,
            'rotation_deg': rot_deg}


def _save_records(rot_recs, out_file):
    """Pickle each rotation class's record list to '<out_file>/<direction>.pkl'.

    Args:
        rot_recs: dict mapping Rotation -> list of records.
        out_file: Output *directory* (despite the name).
    """
    for key in rot_recs:
        new_name = os.path.join(out_file, '%s.pkl' % key.direction)
        print('Saving records to "%s"' % new_name)
        # Append mode: if the file already exists, this pickle is written after
        # the existing one, so a reader must call pickle.load repeatedly to get it.
        with open(new_name, 'a+b') as writer:
            pickle.dump(rot_recs[key], writer)
        print('Done writing record.')


def _print_stats(title):
    """Print the percentage of each rotation class drawn so far by _get_rotation."""
    print(title)
    if _NUM_TOTAL == 0:
        print(' No images processed.')
        return
    print(' Not Rotated: %.2f%%' % (100.0 * _NUM_NONE / _NUM_TOTAL))
    print(' Rotated right: %.2f%%' % (100.0 * _NUM_RIGHT / _NUM_TOTAL))
    print(' Upside-down: %.2f%%' % (100.0 * _NUM_UPSIDE_DOWN / _NUM_TOTAL))
    print(' Rotated left: %.2f%%' % (100.0 * _NUM_LEFT / _NUM_TOTAL))


def main(in_dir, out_dir):
    """Build the train/test rotation data-set from the images in *in_dir*.

    Records are pickled under '<out_dir>/train-data' and '<out_dir>/test-data'.
    Both directories (and out_dir itself) are created if missing.
    """
    train_file_name = os.path.join(out_dir, 'train-data')
    test_file_name = os.path.join(out_dir, 'test-data')
    os.makedirs(train_file_name, exist_ok=True)
    os.makedirs(test_file_name, exist_ok=True)

    train_files, test_files = _load_images(in_dir)

    print('Loading training data')
    train_records = _read_and_rotate(train_files)
    del train_files
    print('Loaded training data')
    # At this point the counters only reflect the training split.
    _print_stats('Training data stats:')
    printer.write_no_prefix('')
    _save_records(train_records, train_file_name)
    # Drop the reference to the training records before loading the test set.
    del train_records

    print('Loading test data')
    test_records = _read_and_rotate(test_files)
    del test_files
    print('Loaded test data')
    print('Saving test data to "%s"' % test_file_name)
    _save_records(test_records, test_file_name)

    printer.write_no_prefix('')
    printer.write_line_break()
    # The counters are cumulative, so these figures cover train + test.
    _print_stats('Full data stats:')
    printer.write_no_prefix('')


if __name__ == '__main__':
    print(ConsoleUtils.get_header(prg_name, prg_vers, prg_date, prg_auth))
    sys.stdout = printer
    # Optional overrides: PageRotation.py [input_root [output_parent]]
    input_root = (sys.argv[1] if len(sys.argv) > 1
                  else r'/run/media/cdiesch/Slow_SSD/Tests/page-rotation-pdf/page-break/')
    output_parent = (sys.argv[2] if len(sys.argv) > 2
                     else r'/run/media/cdiesch/Slow_SSD/Tests/tfPageRot/Input/')
    output_root = os.path.join(output_parent,
                               datetime.datetime.now().strftime('%Y.%m.%d-%H.%M.%S'))
    main(input_root, output_root)