from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import cv2
import numpy as np
import tensorflow as tf

num_first_pages = 0
num_non_first_pages = 0


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _get_label(file_name):
    """Label 0 for a document's first page ('.001.png'), 1 otherwise."""
    global num_first_pages, num_non_first_pages
    if '.001.png' in file_name:
        num_first_pages += 1
        return 0
    num_non_first_pages += 1
    return 1


def _read_labeled_images(train_data_root):
    # First-page ratio in the corpus: 12466 / (46162 + 12466).
    # Cap non-first pages at 12466 to keep the two classes balanced.
    file_names = []
    labels = []
    max_non_fp = 12466
    cur_non_fp_cnt = 0
    for root, _, files in os.walk(train_data_root):
        for file in files:
            if file.endswith('.png'):
                is_fp = ('.001.' in file)
                if is_fp or cur_non_fp_cnt < max_non_fp:
                    file = os.path.join(root, file)
                    file_names.append(file)
                    labels.append(_get_label(file))
                    if not is_fp:
                        cur_non_fp_cnt += 1
    print('%d files are first pages, %d are not' % (num_first_pages, num_non_first_pages))
    return file_names, labels


def _generate_img_label_batch(image, label, batch_size, min_after_dequeue):
    num_threads = 16
    images, labels = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=min_after_dequeue + 3 * batch_size,
        min_after_dequeue=min_after_dequeue)
    tf.summary.image(name='images', tensor=images)
    return images, tf.reshape(labels, [batch_size])


def make_record(train_data_root, record_file):
    record_writer = tf.python_io.TFRecordWriter(record_file)
    files, labels = _read_labeled_images(train_data_root)
    # Loop through all the files and serialize each page into one Example
    for i in range(len(files)):
        file_name = files[i]
        label = labels[i]
        image = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
        # cv2 returns a (height, width) uint8 array; add the channel axis
        image = image.reshape([280, 280, 1])
        height = image.shape[0]
        width = image.shape[1]
        channels = image.shape[2]
        image_raw = image.tostring()
        record = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(height),
            'width': _int64_feature(width),
            'channels': _int64_feature(channels),
            'image_raw': _bytes_feature(image_raw),
            'label': _int64_feature(label)
        }))
        record_writer.write(record.SerializeToString())
    record_writer.close()


def read_and_decode(file_name_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_name_queue)
    features = tf.parse_single_example(
        serialized=serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        })
    # Get the image and label.  The raw bytes were written as uint8,
    # so decode them as uint8 (decoding as float32 would misinterpret them).
    raw_image = tf.decode_raw(features['image_raw'], tf.uint8, name='raw_image')
    label = tf.cast(features['label'], tf.int32, name='label')
    # Get the metadata
    height = tf.cast(features['height'], tf.int32, name='height')
    width = tf.cast(features['width'], tf.int32, name='width')
    channels = tf.cast(features['channels'], tf.int32, name='channels')
    # Reshape the image (every page is stored as 280x280x1)
    image = tf.reshape(raw_image, [280, 280, 1], name='reshaped_image')
    # image_shape = tf.stack([height, width, channels])
    # image = tf.reshape(image, image_shape)
    return image, label


def inputs(record_file, num_epochs, batch_size):
    min_after_dequeue = 1000
    file_name_queue = tf.train.string_input_producer([record_file], num_epochs=num_epochs)
    image, label = read_and_decode(file_name_queue)
    return _generate_img_label_batch(image, label, batch_size=batch_size,
                                     min_after_dequeue=min_after_dequeue)
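

# Minimal usage sketch (not part of the original module): the data root,
# record path, batch size, and epoch count below are illustrative assumptions.
# It writes the TFRecord file once, then pulls a single shuffled batch back
# through the queue-runner pipeline defined above.
if __name__ == '__main__':
    train_data_root = '/path/to/train_pages'   # assumed directory of 280x280 .png pages
    record_file = '/tmp/pages.tfrecords'       # assumed output location

    make_record(train_data_root, record_file)

    images, labels = inputs(record_file, num_epochs=1, batch_size=32)
    with tf.Session() as sess:
        # string_input_producer's num_epochs counter is a local variable,
        # so both initializers are required before starting the queues.
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            img_batch, label_batch = sess.run([images, labels])
            print('batch shapes:', img_batch.shape, label_batch.shape)
        finally:
            coord.request_stop()
            coord.join(threads)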