190 lines
4.7 KiB
Python
190 lines
4.7 KiB
Python
import os
|
|
import random
|
|
import cv2
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import Trainer
|
|
|
|
|
|
# Mutable module-level state. load_data() fills the data sets and sizes;
# next_train_batch() advances cur_batch_num and total_ran.
cur_batch_num = 0
total_ran = 0

x_train, y_train = [], []
train_size = 0

x_test, y_test = [], []
test_size = 0

# Greyscale conversion weights (0.2989 R + 0.5870 G + 0.1140 B), each
# pre-divided by 255 so 8-bit channel values map into the 0..1 range.
R_TO_GREY = 0.2989 / 255
G_TO_GREY = 0.5870 / 255
B_TO_GREY = 0.1140 / 255
RGB_TO_GREY = [R_TO_GREY, G_TO_GREY, B_TO_GREY]

# Fixed padding amounts in pixels.
# NOTE(review): not referenced by the padding code visible in this file,
# which computes its padding per image -- confirm whether still needed.
TOP_BOTTOM_PAD = 10
LEFT_PAD = 39
RIGHT_PAD = 40

# Fill colour (255 in every channel, i.e. white) used when padding images.
BORDER_COLOR = [255, 255, 255]
|
|
|
|
|
|
def num_epochs():
    """Return how many passes over the training set have been consumed.

    Computed as ``total_ran / train_size``. ``total_ran`` is the running
    example count advanced by next_train_batch, and ``train_size`` is set by
    load_data. The result can be fractional (1.5 means one and a half epochs).

    Returns 0.0 while no training data is loaded (``train_size`` is 0),
    instead of raising ZeroDivisionError.
    """
    if not train_size:
        return 0.0
    return total_ran / train_size
|
|
|
|
|
|
def next_train_batch(batch_size, shuffle=True):
    """Return the next ``(x_batch, y_batch)`` pair of training examples.

    Args:
        batch_size: Maximum number of examples to return.
        shuffle: If True, draw up to ``batch_size`` distinct examples
            uniformly at random from the training set. If False, walk the
            training set sequentially in chunks of ``batch_size``. Once the end
            has been passed, it wraps back to the start. The last chunk before
            the wrap may be shorter than ``batch_size``.

    Returns:
        Tuple ``(x_batch, y_batch)`` of numpy arrays built from the selected
        ``x_train`` / ``y_train`` entries (row ``i`` of both refers to the
        same example). Fewer than ``batch_size`` rows are returned when the
        training set is smaller than ``batch_size``, or at the tail in
        sequential mode.

    Side effects:
        Advances the module-level ``cur_batch_num`` (sequential mode only).
        Adds the number of examples actually returned to ``total_ran``.
    """
    global cur_batch_num, total_ran

    # Both modes read the training lists, so they draw from the same pool.
    n_items = len(x_train)

    if shuffle:
        # Sample without replacement so a batch never repeats an example.
        # Clamping avoids ValueError when batch_size exceeds the data size.
        indices = random.sample(range(n_items), min(batch_size, n_items))
    else:
        start_idx = batch_size * cur_batch_num
        if start_idx >= n_items:
            # Walked past the end: restart from the beginning. The increment
            # below then moves the cursor past this first batch, so it is not
            # served twice.
            start_idx = 0
            cur_batch_num = 0
        cur_batch_num += 1
        indices = range(start_idx, min(start_idx + batch_size, n_items))

    x_batch = np.array([x_train[i] for i in indices])
    y_batch = np.array([y_train[i] for i in indices])

    total_ran += len(indices)
    return x_batch, y_batch
|
|
|
|
|
|
def next_train_items():
    """Return the single ``(x, y)`` training example at the current cursor.

    The cursor is ``total_ran`` taken modulo ``train_size``, so it cycles
    through ``x_train`` / ``y_train`` indefinitely. This function reads
    ``total_ran`` but does not advance it.

    Raises:
        IndexError: if no training data is loaded (``train_size`` is 0).
    """
    if not train_size:
        raise IndexError('no training data loaded')
    # Wrap unconditionally: total_ran == train_size must map back to index 0
    # rather than one past the last element.
    idx = total_ran % train_size
    return x_train[idx], y_train[idx]
|
|
|
|
|
|
def get_test_data():
    """Expose the evaluation set as an ``(inputs, labels)`` tuple.

    Both members are the module-level ``x_test`` / ``y_test`` objects
    themselves, not copies.
    """
    pair = (x_test, y_test)
    return pair
|
|
|
|
|
|
def load_data(data_root):
    """Load the images under ``data_root`` into the module-level data sets.

    Paths come from _get_file_names and each image is preprocessed by
    _np_get_clean_img. A file for which that returns None is reported and
    skipped. A file whose *name* contains ``'0001'`` is labelled a first page
    (one-hot ``[1, 0]``); every other file is a non-first page (``[0, 1]``).

    Populates the globals with the same examples, first pages followed by
    non-first pages:
      * ``x_test`` / ``y_test`` as numpy arrays, with ``test_size``;
      * ``x_train`` / ``y_train`` as Python lists, with ``train_size``.

    Args:
        data_root: Root directory passed to _get_file_names.
    """
    global x_test, y_test, x_train, y_train, test_size, train_size

    files = _get_file_names(data_root)

    first_pages = []
    non_first_pages = []

    for file in files:
        img = _np_get_clean_img(file)
        if img is None:
            print('Image at "%s" is bad' % file)
            continue
        # Match on the file name only, so a '0001' in a directory name does
        # not mark every page in that directory as a first page.
        if '0001' in os.path.basename(file):
            first_pages.append(img)
        else:
            non_first_pages.append(img)

    first_page_labels = [np.array([1, 0]) for _ in first_pages]
    non_first_page_labels = [np.array([0, 1]) for _ in non_first_pages]

    all_images = first_pages + non_first_pages
    all_labels = first_page_labels + non_first_page_labels

    # The test and training sets currently hold the same examples.
    test_size = len(all_images)
    x_test = np.array(all_images)
    y_test = np.array(all_labels)

    x_train = all_images
    y_train = all_labels
    train_size = len(x_train)

    print('[DataHelper] Loaded %d first pages and %d other files' % (len(first_pages), len(non_first_pages)))
|
|
|
|
|
|
def _get_clean_img(file_name_tensor):
    """Decode a JPEG into a single-channel float32 array of the trainer's size.

    Builds a TensorFlow graph with these steps:
      * read the file named by ``file_name_tensor``;
      * decode it as a 1-channel JPEG;
      * crop or pad it to ``Trainer.img_h`` x ``Trainer.img_w``;
      * cast it to float32.

    The graph is then evaluated with ``.eval()``, which needs a default
    TensorFlow session to be active.
    """
    raw_bytes = tf.read_file(file_name_tensor)
    grey = tf.image.decode_jpeg(contents=raw_bytes, channels=1)
    fitted = tf.image.resize_image_with_crop_or_pad(
        image=grey,
        target_height=Trainer.img_h,
        target_width=Trainer.img_w,
    )
    as_float = tf.cast(fitted, tf.float32)
    return as_float.eval()
|
|
|
|
|
|
def _get_file_names(test_file_root):
|
|
names = []
|
|
for folder in os.listdir(test_file_root):
|
|
folder = os.path.join(test_file_root, folder)
|
|
if os.path.isdir(folder):
|
|
for file in os.listdir(folder):
|
|
file = os.path.join(folder, file)
|
|
if file.endswith('.jpg'):
|
|
names.append(file)
|
|
print('[DataHelper] Found %d files to train with' % len(names))
|
|
return names
|
|
|
|
|
|
def _np_get_clean_img(file_name, size=280):
    """Load an image and centre it on a ``size`` x ``size`` greyscale canvas.

    Processing steps:
      * read the image with OpenCV;
      * pad it with ``BORDER_COLOR`` up to ``size`` x ``size``, splitting the
        padding evenly (an odd extra pixel goes to the bottom/right);
      * convert it to greyscale with the ``RGB_TO_GREY`` weights;
      * reshape it to ``(size, size, 1)``.

    Args:
        file_name: Path of the image file.
        size: Side length of the square output; defaults to 280.

    Returns:
        A float ``(size, size, 1)`` numpy array. Returns None if the file
        could not be read, or if the image is larger than ``size`` in either
        dimension.
    """
    raw_img = cv2.imread(file_name)
    if raw_img is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        return None

    vrt_pad = size - raw_img.shape[0]
    hor_pad = size - raw_img.shape[1]
    if vrt_pad < 0 or hor_pad < 0:
        # Images larger than the canvas are rejected rather than cropped.
        return None

    top_pad = vrt_pad // 2
    bot_pad = vrt_pad - top_pad
    lft_pad = hor_pad // 2
    rht_pad = hor_pad - lft_pad

    pad_img = cv2.copyMakeBorder(raw_img, top_pad, bot_pad, lft_pad, rht_pad,
                                 cv2.BORDER_CONSTANT, value=BORDER_COLOR)

    # cv2.imread returns channels in BGR order; flip them to RGB so each
    # weight in RGB_TO_GREY is applied to the channel it is named for.
    grey_img = np.dot(pad_img[..., ::-1], RGB_TO_GREY)

    return grey_img.reshape([size, size, 1])
|