Sleds/TFPageRotation/CreateData.py

249 lines
7.3 KiB
Python
Raw Normal View History

2025-03-13 21:28:38 +00:00
import os
import re
import cv2
import sys
import time
import math
import datetime
import pickle
import random
import collections
import numpy as np
import ConsoleUtils
# Program metadata. prg_name, prg_vers, prg_date and prg_auth are passed to
# ConsoleUtils.get_header in the __main__ block; prg_name also labels the printer.
prg_name = 'PageRotation'
prg_desc = 'Randomly rotates pages to create a data-set to train on.'
prg_vers = '0.1.0'
prg_date = '2017/10/02'
prg_auth = 'Chris Diesch <cdiesch@sequencelogic.net>'
# Console printer from the project-local ConsoleUtils module. The __main__ block
# installs it as sys.stdout, and main() calls its write_no_prefix /
# write_line_break methods directly.
printer = ConsoleUtils.SLPrinter(prg_name)
# A Rotation pairs a human-readable direction label with its base angle in
# degrees. The label doubles as the pickle file name in _save_records, and the
# angle (plus random jitter) is handed to cv2.getRotationMatrix2D.
Rotation = collections.namedtuple('Rotation', 'direction deg')

_NO_ROTATE = Rotation('Not-Rotated', 0)
_RIGHT = Rotation('Right', 270)
_UPSIDE_DOWN = Rotation('Upside-Down', 180)
_LEFT = Rotation('Left', 90)

# Integer key -> Rotation; keys are 0.._NUM_ROTS-1 in the order listed here.
_ROTATE_VALS = dict(enumerate((_NO_ROTATE, _RIGHT, _UPSIDE_DOWN, _LEFT)))
_NUM_ROTS = len(_ROTATE_VALS)

_DEG = u'\N{DEGREE SIGN}'

# Running tallies of drawn rotations, bumped by _get_rotation and reported by main.
_NUM_NONE = _NUM_LEFT = _NUM_RIGHT = _NUM_UPSIDE_DOWN = _NUM_TOTAL = 0

# Width (X) and height (Y) in pixels of every image stored in the data-set.
_IMG_X = _IMG_Y = 500
def _get_rotation():
    """Pick a random base rotation and record it in the module-level tallies.

    Draws a uniformly random key from ``_ROTATE_VALS`` and increments the
    matching ``_NUM_*`` counter as well as ``_NUM_TOTAL``. ``main`` later
    prints these counters as the class balance of the data-set.

    Returns:
        Rotation: one of ``_NO_ROTATE``, ``_RIGHT``, ``_UPSIDE_DOWN``, ``_LEFT``.

    Raises:
        ValueError: if the drawn key is missing from ``_ROTATE_VALS``. This can
            only happen if the table is edited so its keys stop being
            ``0.._NUM_ROTS-1``.
    """
    global _NUM_NONE, _NUM_RIGHT, _NUM_UPSIDE_DOWN, _NUM_LEFT, _NUM_TOTAL
    # Derive the upper bound from the table rather than hard-coding 3, so the
    # two cannot drift apart.
    max_key = _NUM_ROTS - 1
    rotate_dir = random.randint(0, max_key)
    if rotate_dir not in _ROTATE_VALS:
        raise ValueError('The value of rotate_dir is invalid (Valid range: [0, %d], Given: %d).'
                         % (max_key, rotate_dir))
    res = _ROTATE_VALS[rotate_dir]
    if res == _NO_ROTATE:
        _NUM_NONE += 1
    elif res == _RIGHT:
        _NUM_RIGHT += 1
    elif res == _UPSIDE_DOWN:
        _NUM_UPSIDE_DOWN += 1
    elif res == _LEFT:
        _NUM_LEFT += 1
    _NUM_TOTAL += 1
    return res
def _load_images(image_root, ratio=.75, pattern=r'page[0-9]{5}\.png'):
    """Find page images in *image_root* and split them into train/test sets.

    Only direct children of *image_root* whose entire file name matches
    *pattern* are kept. The matches are shuffled before splitting, so each
    call yields a different split unless ``random`` has been seeded.

    Args:
        image_root: directory to scan (not recursive).
        ratio: fraction of the files assigned to the training set. The split
            index is ``round(len(files) * ratio)``.
        pattern: regular expression each accepted file name must match in full.

    Returns:
        tuple: ``(train_files, test_files)``, two lists of paths joined onto
        *image_root*.

    Raises:
        ValueError: if *ratio* is outside ``[0, 1]``.
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError('ratio must be within [0, 1] (Given: %r).' % (ratio,))
    matcher = re.compile(pattern)
    # fullmatch rather than match: re.match anchors only at the start, so names
    # like "page00001.png.bak" would be accepted and later fail in cv2.imread.
    # Sorting first makes the shuffle reproducible under random.seed(), because
    # os.listdir returns entries in arbitrary order.
    img_files = [os.path.join(image_root, f)
                 for f in sorted(os.listdir(image_root))
                 if matcher.fullmatch(f) is not None]
    random.shuffle(img_files)
    print('Found %d valid image files from "%s"' % (len(img_files), image_root))
    split_idx = round(len(img_files) * ratio)
    return img_files[:split_idx], img_files[split_idx:]
def _rotate_and_resize(image_file, rot_deg):
    """Load *image_file* as grayscale, rotate it by *rot_deg* degrees and resize it.

    The canvas is enlarged to the rotated bounding box so no corner of the page
    is clipped. The result is then resized to ``_IMG_X`` x ``_IMG_Y``.

    Returns:
        numpy.ndarray: float32 array of shape ``(_IMG_Y, _IMG_X, 1)``, with
        pixel values divided by 255.

    Raises:
        IOError: if OpenCV cannot read *image_file*.
    """
    image = cv2.imread(image_file)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('Unable to read image file "%s".' % image_file)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    height, width = image.shape[:2]
    center_w, center_h = width // 2, height // 2
    rot_matrix = cv2.getRotationMatrix2D((center_w, center_h), rot_deg, 1.0)

    # Bounding box of the rotated page, from the matrix's cos/sin terms.
    cos = np.abs(rot_matrix[0, 0])
    sin = np.abs(rot_matrix[0, 1])
    new_height = int((height * cos) + (width * sin))
    new_width = int((height * sin) + (width * cos))

    # Shift the transform so the rotated page is centred in the enlarged canvas.
    rot_matrix[0, 2] += (new_width / 2) - center_w
    rot_matrix[1, 2] += (new_height / 2) - center_h
    image = cv2.warpAffine(image, rot_matrix, (new_width, new_height))

    # An explicit dsize pins the output to exactly _IMG_X x _IMG_Y, so the
    # reshape below cannot be thrown off by fx/fy rounding.
    image = cv2.resize(image, (_IMG_X, _IMG_Y))
    image = np.reshape(image, (_IMG_Y, _IMG_X, 1))
    return image.astype('float32') / 255.0


def _read_and_rotate(image_files):
    """Randomly rotate every image in *image_files* and group the results.

    For each file, a base rotation is drawn with ``_get_rotation`` (which also
    updates the module-level tallies). A random jitter of -10..+10 degrees is
    added, and the image is transformed by ``_rotate_and_resize``. Progress is
    printed every 25 images, and a summary rate is printed at the end.

    Args:
        image_files: sequence of image paths.

    Returns:
        dict: maps each Rotation in ``_ROTATE_VALS`` to a list of records built
        by ``_create_record``. The lists are empty if *image_files* is empty.

    Raises:
        IOError: if an image cannot be read.
    """
    records = {rotation: [] for rotation in _ROTATE_VALS.values()}
    num_files = len(image_files)
    print('Processing %d images' % num_files)
    start_time = time.time()
    for i, image_file in enumerate(image_files, start=1):
        rand_rot = _get_rotation()
        # Add a 10 degree variance so pages are not only seen at exact right angles.
        rot_deg = rand_rot.deg + random.randint(-10, 10)
        image = _rotate_and_resize(image_file, rot_deg)
        records[rand_rot].append(_create_record(image, rand_rot, rot_deg))
        if i % 25 == 0:
            elapsed = time.time() - start_time
            print(' Processed %d/%d images (%.2f images/s)' % (i, num_files, i / elapsed))

    run_time = time.time() - start_time
    # Guard against an empty input (or a zero timer delta): the old code
    # divided by the rate and raised ZeroDivisionError.
    rate = num_files / run_time if run_time > 0 else 0.0
    if 0 < rate < 1:
        # Slow runs read better as seconds per image.
        rate_str = '(%.2f s/image)' % (1 / rate)
    else:
        rate_str = '(%.2f images/s)' % rate
    print('Finished processing images in %.4f s. %s' % (run_time, rate_str))
    return records
def _create_record(image, rotation, rot_deg):
    """Bundle one processed image with its size and rotation labels.

    Args:
        image: array-like with a ``shape`` attribute. Axis 0 is reported as the
            height and axis 1 as the width.
        rotation: object whose ``direction`` attribute becomes the label.
        rot_deg: actual rotation angle applied to the image, jitter included.

    Returns:
        dict: keys ``image``, ``height``, ``width``, ``rotation_dir`` and
        ``rotation_deg``, in that order.
    """
    dims = image.shape
    return dict(image=image,
                height=dims[0],
                width=dims[1],
                rotation_dir=rotation.direction,
                rotation_deg=rot_deg)
def _save_records(rot_recs, out_file):
    """Pickle each rotation's record list to ``<out_file>/<direction>.pkl``.

    Args:
        rot_recs: mapping of Rotation -> list of records, as returned by
            ``_read_and_rotate``. Only ``.direction`` is read from each key.
        out_file: directory that receives the files (despite the name).

    Files are opened in append mode (``'a+b'``). Saving into a directory that
    already holds ``<direction>.pkl`` therefore appends a second pickle after
    the existing one instead of replacing it.
    """
    for rotation, recs in rot_recs.items():
        pkl_path = os.path.join(out_file, '%s.pkl' % rotation.direction)
        print('Saving records to "%s"' % pkl_path)
        with open(pkl_path, 'a+b') as pkl_file:
            pickle.dump(recs, pkl_file)
        print('Done writing record.')
def _print_rotation_stats(title):
    """Print the percentage of each base rotation drawn so far by _get_rotation.

    Reads the module-level ``_NUM_*`` tallies. When no images have been
    processed, it prints a notice instead of dividing by zero.
    """
    print(title)
    if _NUM_TOTAL == 0:
        print(' No images processed.')
        return
    for label, count in (('Not Rotated', _NUM_NONE),
                         ('Rotated right', _NUM_RIGHT),
                         ('Upside-down', _NUM_UPSIDE_DOWN),
                         ('Rotated left', _NUM_LEFT)):
        print(' %s: %.2f%%' % (label, float(count) / _NUM_TOTAL * 100.0))


def main(in_dir, out_dir):
    """Build the rotated train/test page-image data-set.

    Images matching ``_load_images``' pattern in *in_dir* are split into train
    and test sets, randomly rotated, and pickled per rotation into
    ``<out_dir>/train-data`` and ``<out_dir>/test-data``. The directories are
    created if missing. Rotation statistics are printed after the training
    set and again, cumulatively, after both sets.

    Args:
        in_dir: directory holding the source page images.
        out_dir: root directory for the generated data-set.
    """
    train_file_name = os.path.join(out_dir, 'train-data')
    test_file_name = os.path.join(out_dir, 'test-data')
    # makedirs also creates out_dir; exist_ok makes re-running into an existing
    # directory harmless.
    os.makedirs(train_file_name, exist_ok=True)
    os.makedirs(test_file_name, exist_ok=True)

    train_files, test_files = _load_images(in_dir)

    print('Loading training data')
    train_records = _read_and_rotate(train_files)
    del train_files
    print('Loaded training data')
    _print_rotation_stats('Training data stats:')
    printer.write_no_prefix('')
    _save_records(train_records, train_file_name)
    # Release the training images before the test set is loaded.
    del train_records

    print('Loading test data')
    test_records = _read_and_rotate(test_files)
    del test_files
    print('Loaded test data')
    print('Saving test data to "%s"' % test_file_name)
    _save_records(test_records, test_file_name)
    del test_records

    printer.write_no_prefix('')
    printer.write_line_break()
    # The tallies are never reset, so these figures cover train + test together.
    _print_rotation_stats('Full data stats:')
    printer.write_no_prefix('')
if __name__ == '__main__':
    # Usage: CreateData.py [input_root [output_base]]
    # With no arguments, the original hard-coded locations are used.
    _DEFAULT_INPUT_ROOT = r'/run/media/cdiesch/Slow_SSD/Tests/page-rotation-pdf/page-break/'
    _DEFAULT_OUTPUT_BASE = r'/run/media/cdiesch/Slow_SSD/Tests/tfPageRot/Input/'
    print(ConsoleUtils.get_header(prg_name, prg_vers, prg_date, prg_auth))
    sys.stdout = printer
    input_root = sys.argv[1] if len(sys.argv) > 1 else _DEFAULT_INPUT_ROOT
    output_base = sys.argv[2] if len(sys.argv) > 2 else _DEFAULT_OUTPUT_BASE
    # Every run writes into its own timestamped sub-directory.
    output_root = os.path.join(output_base,
                               datetime.datetime.now().strftime('%Y.%m.%d-%H.%M.%S'))
    main(input_root, output_root)