# TFRecord creation and input pipeline for page images (first page vs. non-first page).
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import cv2
|
|
import os
|
|
|
|
|
|
num_first_pages = 0
|
|
num_non_first_pages = 0
|
|
|
|
|
|
def _bytes_feature(value):
    """Wrap a raw byte string in a tf.train.Feature for Example serialization."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _int64_feature(value):
    """Wrap a scalar integer in a tf.train.Feature for Example serialization."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _get_label(file_name):
    """Return the class label for a page image and update the module tallies.

    A file whose name contains '.001.png' is the first page of a document
    (label 0); anything else is a non-first page (label 1).
    """
    global num_first_pages, num_non_first_pages
    is_first = '.001.png' in file_name
    if is_first:
        num_first_pages += 1
    else:
        num_non_first_pages += 1
    return 0 if is_first else 1
def _read_labeled_images(train_data_root):
    """Walk train_data_root and collect a class-balanced labeled file list.

    All first pages are kept; non-first pages are capped at max_non_fp so the
    two classes stay roughly balanced (first-page ratio: 12466 / (46162 + 12466)).

    Args:
        train_data_root: root directory to walk for .png page images.

    Returns:
        (file_names, labels): parallel lists of image paths and int labels
        (0 = first page, 1 = non-first page).
    """
    file_names = []
    labels = []
    max_non_fp = 12466
    cur_non_fp_cnt = 0
    for root, _, files in os.walk(train_data_root):
        for file in files:
            if not file.endswith('.png'):
                continue
            # Use the same predicate as _get_label so the balancing cap and
            # the assigned label can never disagree.  The old test
            # ('.001.' in file) let a name like 'a.001.b.png' bypass the
            # non-first-page cap while still being labeled non-first-page.
            is_fp = '.001.png' in file
            if is_fp or cur_non_fp_cnt < max_non_fp:
                path = os.path.join(root, file)
                file_names.append(path)
                labels.append(_get_label(path))
                if not is_fp:
                    cur_non_fp_cnt += 1

    print('%d files are first pages, %d are not' % (num_first_pages, num_non_first_pages))

    return file_names, labels
def _generate_img_label_batch(image, label, batch_size, min_after_dequeue):
    """Shuffle-batch a single (image, label) example pair into training batches.

    Returns:
        (images, labels): a [batch_size, ...] image batch and a [batch_size]
        label vector, drawn from a 16-thread shuffling queue.
    """
    num_threads = 16
    capacity = min_after_dequeue + 3 * batch_size

    images, labels = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)

    # Log a sample of the batched images for TensorBoard.
    tf.summary.image(name='images', tensor=images)

    return images, tf.reshape(labels, [batch_size])
def make_record(train_data_root, record_file):
    """Serialize every labeled training image into a single TFRecord file.

    Each record stores the raw uint8 grayscale pixels of a 280x280x1 image
    together with its height/width/channels and label (0 = first page,
    1 = non-first page).

    Args:
        train_data_root: directory tree containing the .png training images.
        record_file: output path for the TFRecord file.
    """
    record_writer = tf.python_io.TFRecordWriter(record_file)
    files, labels = _read_labeled_images(train_data_root)
    try:
        for file_name, label in zip(files, labels):
            image = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
            if image is None:
                # cv2.imread signals a missing/corrupt file by returning
                # None instead of raising; skip it rather than crash the run.
                print('Skipping unreadable image: %s' % file_name)
                continue
            if image.size != 280 * 280:
                # A wrong-sized image would make reshape raise and abort
                # the whole conversion.
                print('Skipping %s: unexpected shape %s' % (file_name, (image.shape,)))
                continue
            image = image.reshape([280, 280, 1])

            height = image.shape[0]
            width = image.shape[1]
            channels = image.shape[2]

            # tobytes() is the non-deprecated spelling of tostring().
            image_raw = image.tobytes()

            record = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(height),
                'width': _int64_feature(width),
                'channels': _int64_feature(channels),
                'image_raw': _bytes_feature(image_raw),
                'label': _int64_feature(label)
            }))

            record_writer.write(record.SerializeToString())
    finally:
        # Close even if serialization fails part-way, so the file handle
        # (and any buffered records) are not leaked.
        record_writer.close()
def read_and_decode(file_name_queue):
    """Read one serialized Example from the queue and decode image + label.

    Args:
        file_name_queue: queue of TFRecord file names (from
            tf.train.string_input_producer).

    Returns:
        image: float32 tensor of shape [280, 280, 1].
        label: int32 scalar tensor (0 = first page, 1 = non-first page).
    """
    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(file_name_queue)

    features = tf.parse_single_example(serialized=serialized_example,
                                       features={
                                           'height': tf.FixedLenFeature([], tf.int64),
                                           'width': tf.FixedLenFeature([], tf.int64),
                                           'channels': tf.FixedLenFeature([], tf.int64),
                                           'image_raw': tf.FixedLenFeature([], tf.string),
                                           'label': tf.FixedLenFeature([], tf.int64)
                                       })

    # make_record stores raw uint8 pixels (cv2 grayscale), so decode as
    # uint8.  Decoding as float32 directly (the old behavior) misreads the
    # byte stream: 280*280 uint8 bytes yield only 19600 float32 values and
    # the [280, 280, 1] reshape below fails at runtime.
    raw_image = tf.decode_raw(features['image_raw'], tf.uint8, name='raw_image')
    label = tf.cast(features['label'], tf.int32, name='label')
    # Get the metadata (currently unused: the reshape below hard-codes the
    # 280x280x1 geometry that make_record enforces).
    height = tf.cast(features['height'], tf.int32, name='height')
    width = tf.cast(features['width'], tf.int32, name='width')
    channels = tf.cast(features['channels'], tf.int32, name='channels')

    # Cast to float32 so downstream consumers see the same dtype as before,
    # then reshape to the fixed image geometry.
    image = tf.reshape(tf.cast(raw_image, tf.float32), [280, 280, 1],
                       name='reshaped_image')
    # image_shape = tf.pack([height, width, channels])
    # image = tf.reshape(image, image_shape)

    return image, label
def inputs(record_file, num_epochs, batch_size):
    """Build the shuffled (images, labels) training pipeline from a TFRecord file.

    Args:
        record_file: path to the TFRecord file produced by make_record.
        num_epochs: number of passes over the file before the queue raises
            OutOfRangeError.
        batch_size: examples per batch.

    Returns:
        (images, labels) batch tensors from the shuffling queue.
    """
    min_after_dequeue = 1000

    file_name_queue = tf.train.string_input_producer(
        [record_file], num_epochs=num_epochs)
    image, label = read_and_decode(file_name_queue)

    return _generate_img_label_batch(
        image, label,
        batch_size=batch_size,
        min_after_dequeue=min_after_dequeue)