# Sleds/TFPageRotation/PageRotator.py
#
# Repository viewer metadata: 120 lines, 2.6 KiB, Python;
# last updated 2025-03-13 21:28:38 +00:00.
import os
import sys
import cv2
import keras
import argparse
import collections
import numpy as np
from datetime import datetime as dt
# A rotation result: human-readable label plus an angle in degrees.
Rotation = collections.namedtuple('Rotation', ['direction', 'deg'])

# Target size in pixels (width, height) that images are resized to before
# prediction.
_IMG_X = 500
_IMG_Y = 500

# Loaded Keras model; assigned by set_up().
_model = None

_NO_ROTATE = Rotation(direction='Not-Rotated', deg=0)
_RIGHT = Rotation(direction='Right', deg=270)
_UPSIDE_DOWN = Rotation(direction='Upside-Down', deg=180)
_LEFT = Rotation(direction='Left', deg=90)

# Maps the predicted class index (position in the model output) to its
# Rotation.
_ROTATE_VALS = dict(enumerate((_NO_ROTATE, _RIGHT, _UPSIDE_DOWN, _LEFT)))
def _load_model(path):
    """Load a Keras model from *path*.

    Returns the loaded model, or None if loading raised any exception; in
    that case the error message is printed instead of being propagated.
    """
    print('Loading model from "%s"' % path)
    try:
        loaded = keras.models.load_model(path)
        print('Model loaded successfully!')
    except Exception as err:
        # Best-effort: report and signal failure with None (set_up checks it).
        print('Error loading model: %s' % str(err))
        loaded = None
    return loaded
def _preprocess_img(img_file):
    """Load an image file and turn it into a normalized model input.

    The image is read with OpenCV, converted to grayscale, resized to
    ``_IMG_X`` x ``_IMG_Y`` pixels, given a trailing channel axis and scaled
    to float32 values in [0, 1].

    :param img_file: path of the image to load.
    :returns: float32 array of shape ``(_IMG_Y, _IMG_X, 1)``.
    :raises IOError: if OpenCV cannot read ``img_file``.
    """
    print('Loading image from "%s"' % img_file)
    # read & gray-scale image; cv2.imread signals failure by returning None
    # instead of raising, so check explicitly for a clear error message.
    image = cv2.imread(img_file)
    if image is None:
        raise IOError('Could not read image "%s"' % img_file)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Get the original size
    height, width = image.shape[:2]
    # compute the scale factors (informational only)
    x_scale = float(_IMG_X) / width
    y_scale = float(_IMG_Y) / height
    print('Resizing image')
    print(' Original image size: %s' % str(image.shape))
    print(' X Scale: %.2f' % x_scale)
    print(' Y Scale: %.2f' % y_scale)
    # Resize to an explicit (width, height) so float rounding in the scale
    # factors can never yield an off-by-one size that breaks the reshape.
    image = cv2.resize(image, (_IMG_X, _IMG_Y))
    print(' New image size: %s' % str(image.shape))
    # add the single channel axis
    image = image.reshape(_IMG_Y, _IMG_X, 1)
    # use float values between 0-1
    image = image.astype(np.float32) / 255.0
    return image
def _predict(imgs):
    """Classify the rotation of the first image in *imgs*.

    :param imgs: batch passed straight to ``_model.predict``; in this module
        it is built as ``np.array([img])`` from ``_preprocess_img`` output.
    :returns: ``(rotation, confidence)`` where *rotation* is the ``Rotation``
        for the highest-scoring class of the first image and *confidence* is
        that score times 100 (a percentage presumably only if the model ends
        in a softmax — confirm).
    """
    scores = _model.predict(imgs)[0]
    # np.argmax picks the first index on ties (like the old strict ">" scan)
    # but, unlike the old "big = -1" seed, is correct for any score range.
    idx = int(np.argmax(scores))
    return _ROTATE_VALS[idx], scores[idx] * 100
def set_up(model_path):
    """Load the model at *model_path* into the module-level ``_model``.

    Terminates the process with ``sys.exit(-1)`` if the model cannot be
    loaded.
    """
    global _model
    tmp = _load_model(model_path)
    if tmp is None:
        print('Encountered error loading model, exiting...')
        # sys.exit instead of the site-provided exit() builtin, which is
        # unavailable under "python -S" and in frozen executables.
        sys.exit(-1)
    _model = tmp
def main(model_path=None, image_path=None):
    """Load the model, classify one image and print its rotation.

    :param model_path: model file to load; defaults to the module-level
        ``mdl_file`` global (set by the script entry point).
    :param image_path: image to classify; defaults to the module-level
        ``img_test`` global (set by the script entry point).
    :returns: ``(rotation, confidence)`` as produced by ``_predict``.
    """
    if model_path is None:
        model_path = mdl_file
    if image_path is None:
        image_path = img_test
    set_up(model_path)
    img = _preprocess_img(image_path)
    prediction, conf = _predict(np.array([img]))
    # Previously printed literal "%s"/"%.2f" placeholders (no format args).
    print('Image %s is rotated %s (%.2f%%)' % (image_path, prediction.direction, conf))
    return prediction, conf
if __name__ == '__main__':
    print('Running.')
    # Paths can now be overridden on the command line; the defaults are the
    # original hard-coded locations.
    parser = argparse.ArgumentParser(description='Predict how a page image is rotated.')
    parser.add_argument('--model', default=r'Z:\Chris\TFPageRotation\Test-2017.11.21-03.11.48.mdl',
                        help='path of the Keras model file to load')
    parser.add_argument('--image', default=r'Z:\Chris\test7\page0001.jpg',
                        help='path of the image to classify')
    args = parser.parse_args()
    # Kept as module-level names because main() falls back to them.
    mdl_file = args.model
    img_test = args.image
    img = _preprocess_img(img_test)
    print('Setting up')
    set_up(mdl_file)
    imgs = np.array([img])
    prediction, conf = _predict(imgs)
    # Rotation's fields are (direction, deg); ".Rotation" raised AttributeError.
    print('Image is rotated %s (%.2f%%)' % (prediction.direction, conf))