120 lines
2.6 KiB
Python
120 lines
2.6 KiB
Python
|
|
import os
|
||
|
|
import sys
|
||
|
|
import cv2
|
||
|
|
import keras
|
||
|
|
import argparse
|
||
|
|
import collections
|
||
|
|
import numpy as np
|
||
|
|
from datetime import datetime as dt
|
||
|
|
|
||
|
|
# One classifier outcome: a human-readable direction plus the page's angle
# in degrees.
Rotation = collections.namedtuple('Rotation', ['direction', 'deg'])

# Target width (_IMG_X) and height (_IMG_Y), in pixels, that every input image
# is resized to before being fed to the model.
_IMG_X, _IMG_Y = 500, 500

# Keras model installed by set_up(); stays None until then.
_model = None

# The four possible outcomes.
_NO_ROTATE = Rotation('Not-Rotated', 0)
_RIGHT = Rotation('Right', 270)
_UPSIDE_DOWN = Rotation('Upside-Down', 180)
_LEFT = Rotation('Left', 90)

# Model output index -> Rotation. The position in this tuple is the class
# index, so the order must match the class order the model was trained with
# (NOTE(review): confirm against the training script).
_ROTATE_VALS = dict(enumerate((_NO_ROTATE, _RIGHT, _UPSIDE_DOWN, _LEFT)))
|
||
|
|
|
||
|
|
|
||
|
|
def _load_model(path):
    """Load a Keras model from ``path``, reporting progress on stdout.

    Any exception raised while loading is printed and swallowed, so the
    caller only has to check for None.

    :param path: location of the saved model file.
    :return: the loaded model, or None if loading failed.
    """
    print('Loading model from "%s"' % path)
    try:
        loaded = keras.models.load_model(path)
        print('Model loaded successfully!')
    except Exception as err:
        # best-effort: report the failure and let the caller decide what to do
        print('Error loading model: %s' % str(err))
        loaded = None
    return loaded
|
||
|
|
|
||
|
|
|
||
|
|
def _preprocess_img(img_file):
    """Load an image file and turn it into a normalized model input.

    The image is read with OpenCV, converted to single-channel grayscale,
    resized to ``_IMG_X`` x ``_IMG_Y`` pixels (width x height) and scaled to
    ``float32`` values in ``[0, 1]``. Progress and sizes are printed to stdout.

    :param img_file: path of the image to load.
    :return: numpy array of shape ``(_IMG_Y, _IMG_X, 1)`` and dtype float32.
    :raises IOError: if OpenCV cannot read ``img_file``.
    """
    print('Loading image from "%s"' % img_file)

    # read & gray-scale image. cv2.imread reports failure by returning None
    # instead of raising, so fail here with a clear message rather than
    # inside cvtColor.
    image = cv2.imread(img_file)
    if image is None:
        raise IOError('Could not read image "%s"' % img_file)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # scale factors, only reported for diagnostics; float() keeps this a
    # true division under Python 2 as well
    height, width = image.shape[:2]
    x_scale = float(_IMG_X) / width
    y_scale = float(_IMG_Y) / height

    print('Resizing image')
    print(' Original image size: %s' % str(image.shape))
    print(' X Scale: %.2f' % x_scale)
    print(' Y Scale: %.2f' % y_scale)

    # Give OpenCV the target size directly (dsize is (width, height)) so the
    # result is exactly the shape the reshape below needs, instead of relying
    # on width * scale rounding back to _IMG_X.
    image = cv2.resize(image, (_IMG_X, _IMG_Y))
    print(' New image size: %s' % str(image.shape))

    # add the trailing channel axis: (H, W) -> (H, W, 1)
    image = image.reshape((_IMG_Y, _IMG_X, 1))

    # use float values between 0-1
    return image.astype(np.float32) / 255.0
|
||
|
|
|
||
|
|
|
||
|
|
def _predict(imgs):
    """Classify the rotation of the first image in a batch.

    :param imgs: batch of preprocessed images, passed unchanged to
        ``_model.predict``. Only the first row of the model output is used.
    :return: tuple ``(rotation, confidence)``. ``rotation`` is the
        ``_ROTATE_VALS`` entry for the highest-scoring class, and
        ``confidence`` is that class's score multiplied by 100.
    :raises RuntimeError: if ``set_up`` has not loaded a model yet.
    """
    if _model is None:
        raise RuntimeError('No model loaded; call set_up() first')

    scores = _model.predict(imgs)[0]

    # np.argmax picks the first index on ties, as the old manual loop did.
    # Unlike the loop's -1 starting sentinel, it is also correct when every
    # score is below -1 (e.g. a model whose last layer emits raw logits).
    best = int(np.argmax(scores))
    return _ROTATE_VALS[best], scores[best] * 100
|
||
|
|
|
||
|
|
|
||
|
|
def set_up(model_path):
    """Load the model at ``model_path`` into the module-level ``_model``.

    If ``_load_model`` signals failure by returning None, a message is
    printed and the process exits with status -1.

    :param model_path: location of the saved model file.
    """
    global _model

    loaded = _load_model(model_path)
    if loaded is None:
        print('Encountered error loading model, exiting...')
        # sys.exit instead of the site-module exit() helper, which is meant
        # for the interactive interpreter and is missing under `python -S`
        # and in some frozen/embedded interpreters.
        sys.exit(-1)

    _model = loaded
|
||
|
|
|
||
|
|
|
||
|
|
def main(model_path=None, image_path=None):
    """Load the model, classify one image and print its rotation.

    Both arguments are optional for backward compatibility. When omitted,
    they fall back to the module-level ``mdl_file`` / ``img_test`` names that
    the ``__main__`` block defines.

    :param model_path: saved model file to load.
    :param image_path: image file to classify.
    :return: ``(rotation, confidence)`` as produced by ``_predict``.
    :raises ValueError: if a path is neither passed in nor defined at module
        level.
    """
    if model_path is None:
        model_path = globals().get('mdl_file')
    if image_path is None:
        image_path = globals().get('img_test')
    if model_path is None or image_path is None:
        raise ValueError('main() needs model_path and image_path '
                         '(or module-level mdl_file / img_test)')

    set_up(model_path)

    img = _preprocess_img(image_path)
    # the model is fed a batch, so wrap the single image in one
    rotation, conf = _predict(np.array([img]))

    print('Image %s is rotated %s (%.2f%%)'
          % (image_path, rotation.direction, conf))
    return rotation, conf
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    print('Running.')

    # Paths can be overridden on the command line. The defaults are the
    # originally hard-coded files, so running without arguments uses the
    # same model and image as before.
    parser = argparse.ArgumentParser(
        description='Predict the rotation of a page image.')
    parser.add_argument(
        '--model',
        default=r'Z:\Chris\TFPageRotation\Test-2017.11.21-03.11.48.mdl',
        help='saved Keras model file to load')
    parser.add_argument(
        '--image',
        default=r'Z:\Chris\test7\page0001.jpg',
        help='image file to classify')
    args = parser.parse_args()

    # kept as module-level names because main() falls back to them
    mdl_file = args.model
    img_test = args.image

    img = _preprocess_img(img_test)

    print('Setting up')

    set_up(mdl_file)
    imgs = np.array([img])

    prediction, conf = _predict(imgs)

    # Rotation's fields are (direction, deg); it has no `Rotation` attribute
    print('Image is rotated %s (%.2f%%)' % (prediction.direction, conf))
|
||
|
|
|