Sleds/TFFirstPageEngine/MakeRecords.py

138 lines
4.4 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import cv2
import os
# Running totals updated by _get_label(); they are never reset, so they
# accumulate across every file labelled during this process.
num_first_pages = num_non_first_pages = 0
def _bytes_feature(value):
    """Return a tf.train.Feature holding ``value`` as a one-element BytesList."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
def _int64_feature(value):
    """Return a tf.train.Feature holding ``value`` as a one-element Int64List."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
def _get_label(file_name):
    """Return the class label for a page image path and update the counters.

    Args:
        file_name: path of a page image.

    Returns:
        0 if the path contains '.001.png' (treated as a first page), else 1.
        As a side effect, increments the module-level ``num_first_pages`` or
        ``num_non_first_pages`` counter accordingly.
    """
    global num_first_pages, num_non_first_pages
    is_first_page = '.001.png' in file_name
    if is_first_page:
        num_first_pages += 1
    else:
        num_non_first_pages += 1
    return 0 if is_first_page else 1
def _read_labeled_images(train_data_root, max_non_fp=12466):
    """Collect labelled PNG page images found under ``train_data_root``.

    Every first page is kept. Other pages are kept only until ``max_non_fp``
    of them have been collected, which down-samples the majority class.
    Per the original note, the full data set has 12466 first pages and 46162
    other pages, i.e. a first-page ratio of 12466 / (46162 + 12466). The
    default cap of 12466 therefore roughly balances the two classes.

    Args:
        train_data_root: directory searched recursively for ``*.png`` files.
        max_non_fp: maximum number of non-first pages to keep; ``None``
            keeps all of them.

    Returns:
        ``(file_names, labels)``: parallel lists of full paths and the labels
        assigned by ``_get_label`` (0 = first page, 1 = other page).
    """
    file_names = []
    labels = []
    cur_non_fp_cnt = 0
    for root, dirs, files in os.walk(train_data_root):
        # os.walk yields entries in arbitrary filesystem order. Sorting makes
        # the capped subsample of non-first pages reproducible between runs.
        dirs.sort()
        for base_name in sorted(files):
            if not base_name.endswith('.png'):
                continue
            path = os.path.join(root, base_name)
            # Use exactly the predicate _get_label uses, so the cap counts the
            # same files that end up labelled as non-first pages.
            is_fp = '.001.png' in path
            if not is_fp and max_non_fp is not None and cur_non_fp_cnt >= max_non_fp:
                continue
            file_names.append(path)
            labels.append(_get_label(path))
            if not is_fp:
                cur_non_fp_cnt += 1
    # These module-level totals are maintained by _get_label and are never
    # reset, so they are cumulative over all calls in this process.
    print('%d files are first pages, %d are not' % (num_first_pages, num_non_first_pages))
    return file_names, labels
def _generate_img_label_batch(image, label, batch_size, min_after_dequeue):
    """Batch single (image, label) tensors into shuffled mini-batches.

    Args:
        image: one image tensor.
        label: the matching scalar label tensor.
        batch_size: number of examples per batch.
        min_after_dequeue: minimum number of examples left in the shuffle
            queue after a dequeue; the queue capacity is this value plus
            three batches.

    Returns:
        ``(images, labels)`` where ``labels`` is reshaped to ``[batch_size]``.
        Also registers an image summary named 'images' for the batch.
    """
    queue_capacity = min_after_dequeue + 3 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=16,
        capacity=queue_capacity,
        min_after_dequeue=min_after_dequeue,
    )
    tf.summary.image(name='images', tensor=image_batch)
    flat_labels = tf.reshape(label_batch, [batch_size])
    return image_batch, flat_labels
def make_record(train_data_root, record_file):
    """Write every labelled page image under ``train_data_root`` to a TFRecord.

    Each image is read with OpenCV as single-channel grayscale. Without
    IMREAD_ANYDEPTH, OpenCV converts it to 8-bit pixels. The image is then
    reshaped to 280x280x1 and stored as one tf.train.Example with the
    features 'height', 'width', 'channels', 'image_raw' (the raw pixel
    bytes) and 'label'.

    Args:
        train_data_root: directory searched recursively for PNG pages.
        record_file: path of the TFRecord file to create.

    Raises:
        IOError: if OpenCV cannot read one of the image files.
        ValueError: if an image does not contain exactly 280*280 pixels
            (raised by the numpy reshape).
    """
    # The context manager closes (and flushes) the writer even if an image
    # fails part-way through, instead of leaking the open file.
    with tf.python_io.TFRecordWriter(record_file) as record_writer:
        files, labels = _read_labeled_images(train_data_root)
        for file_name, label in zip(files, labels):
            image = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
            if image is None:
                # cv2.imread reports failure by returning None, not by raising.
                raise IOError('Could not read image file: %s' % file_name)
            image = image.reshape([280, 280, 1])
            height, width, channels = image.shape
            record = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(height),
                'width': _int64_feature(width),
                'channels': _int64_feature(channels),
                # tobytes() yields the same bytes as the deprecated tostring().
                'image_raw': _bytes_feature(image.tobytes()),
                'label': _int64_feature(label),
            }))
            record_writer.write(record.SerializeToString())
def read_and_decode(file_name_queue):
    """Read one serialized Example from ``file_name_queue`` and decode it.

    Args:
        file_name_queue: queue of TFRecord file names, e.g. from
            ``tf.train.string_input_producer``.

    Returns:
        ``(image, label)``: a float32 tensor of shape [280, 280, 1] holding the
        pixel values (not rescaled), and an int32 scalar label.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_name_queue)
    features = tf.parse_single_example(
        serialized=serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    # make_record() stores one byte per pixel (8-bit grayscale), so the raw
    # string must be decoded as uint8. Decoding it as float32 would yield only
    # a quarter as many values, and the reshape to 280*280*1 would fail.
    raw_image = tf.decode_raw(features['image_raw'], tf.uint8, name='raw_image')
    # A static shape is used rather than the stored height/width/channels.
    # NOTE(review): presumably because tf.train.shuffle_batch needs fully
    # defined shapes; revisit if image sizes ever vary.
    image = tf.reshape(raw_image, [280, 280, 1])
    image = tf.cast(image, tf.float32, name='reshaped_image')
    label = tf.cast(features['label'], tf.int32, name='label')
    return image, label
def inputs(record_file, num_epochs, batch_size):
    """Build a shuffled (images, labels) batch pipeline over ``record_file``.

    Args:
        record_file: path of a TFRecord file written by ``make_record``.
        num_epochs: number of passes over the file before the input producer
            stops.
        batch_size: number of examples per batch.

    Returns:
        The ``(images, labels)`` batch tensors from
        ``_generate_img_label_batch``, using a shuffle buffer of at least
        1000 examples after each dequeue.

    NOTE(review): with a non-None num_epochs, TF1's string_input_producer
    keeps an epoch counter in a local variable, so callers presumably must run
    tf.local_variables_initializer(); confirm.
    """
    filename_queue = tf.train.string_input_producer([record_file], num_epochs=num_epochs)
    single_image, single_label = read_and_decode(filename_queue)
    return _generate_img_label_batch(
        single_image,
        single_label,
        batch_size=batch_size,
        min_after_dequeue=1000,
    )