# We do all our imports at the top of our program.
|
|
import argparse
|
|
import os
|
|
import json
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from matplotlib.ticker import AutoMinorLocator
|
|
from collections import OrderedDict
|
|
|
|
# Give the program a name.
program_name = 'Accuracy Graph'

# Describe what the program does briefly.
program_description = 'Generates a graph showing the difference between runs of Accuracy.py.'

# The argument parser for the program.
# add_help=False because the script registers its own -h/--help option later.
parser = argparse.ArgumentParser(prog=program_name, description=program_description, add_help=False)

# Default values for optional arguments live here. Keeping them in one place
# makes them easy to find and change without searching through the script.
global_opt = 'default value'

# Error and Warning console values (ANSI colour escape sequences):
red_error = '\033[91mError:\033[0m'
yellow_warning = '\033[93mWARNING:\033[0m'
blue_okay = '\033[94mOK\033[0m'

# Banner printed when the script starts: the coloured program name followed by
# a separator line.
program_header = ('\033[95m%s\033[0m\n'
                  '-----------------------' % program_name)

# Type names that may follow the ':' in a data line (see cast_to_type).
NUM_VAL = 'number'
STR_VAL = 'string'
|
|
|
|
|
|
def load_files(files_dict):
    """Load the data file behind every entry of *files_dict*.

    Args:
        files_dict: Mapping of key -> file path. Each path is handed to
            ``load_data_from_file``.

    Returns:
        A dict with the same keys as *files_dict*, each mapped to the parsed
        contents of its file.
    """
    return {key: load_data_from_file(path) for key, path in files_dict.items()}
|
|
|
|
|
|
def load_data_from_file(file_path):
    """Read a data file and return its contents as a dict.

    Every non-blank line is parsed with ``parse_line``; the resulting tag
    becomes the key and the converted value the dict value. If a tag occurs
    more than once, the last occurrence wins.

    Args:
        file_path: Path of the file to read.

    Returns:
        A dict mapping each tag found in the file to its value.
    """
    result = {}
    with open(file_path) as reader:
        for line in reader:
            # Blank lines (e.g. a trailing empty line) carry no data and
            # would make parse_line fail, so skip them.
            if not line.strip():
                continue
            key, val = parse_line(line)
            result[key] = val
    return result
|
|
|
|
|
|
def use_folder_for_tag(file_path):
    """Return the name of the directory that directly contains *file_path*.

    Args:
        file_path: Path to a file.

    Returns:
        The last component of the file's parent directory, or an empty
        string when the path has no parent directory component.
    """
    return os.path.basename(os.path.dirname(file_path))
|
|
|
|
|
|
def get_graph_tag(file_path):
    """Return the final component (the file name) of *file_path*.

    Args:
        file_path: Path to a file.

    Returns:
        Everything after the last path separator; the whole string when it
        contains no separator.
    """
    return os.path.basename(file_path)
|
|
|
|
|
|
def parse_line(line_txt):
    """Parse one ``tag=value:type`` line of a data file.

    The tag is everything before the first '='. The type name is everything
    after the last ':' of the remainder, so a value may itself contain ':'.

    Args:
        line_txt: A single line, with or without its trailing newline.

    Returns:
        A ``(tag, value)`` tuple where *value* has been converted by
        ``cast_to_type`` according to the type name.

    Raises:
        ValueError: If the line has no '=' or no ':' after it.
    """
    line_txt = line_txt.rstrip('\r\n')

    tag, equals, remainder = line_txt.partition('=')
    # Split on the LAST colon: the type name is the suffix, and the value in
    # front of it is allowed to contain colons.
    value, colon, val_type = remainder.rpartition(':')
    if not equals or not colon:
        raise ValueError('Malformed data line (expected tag=value:type): %r' % line_txt)

    return tag, cast_to_type(val_type, value)
|
|
|
|
|
|
def cast_to_type(t_val, val):
    """Convert *val* according to the type name *t_val*.

    Args:
        t_val: Type name; compared against ``NUM_VAL`` and ``STR_VAL``.
        val: The raw value to convert.

    Returns:
        ``float(val)`` when *t_val* is ``NUM_VAL``, ``str(val)`` when it is
        ``STR_VAL``, and *val* unchanged for any other type name.

    Raises:
        ValueError: If *t_val* is ``NUM_VAL`` and *val* is not a valid float.
    """
    if t_val == NUM_VAL:
        return float(val)
    if t_val == STR_VAL:
        return str(val)
    # Unknown type names pass the value through untouched.
    return val
|
|
|
|
|
|
def get_data_lists(data_dict):
    """Split *data_dict* into parallel tuples of keys and per-key data.

    Args:
        data_dict: Mapping of key -> mapping of data tag -> data value.

    Returns:
        A ``(keys, data_list)`` pair of equally long tuples. ``keys`` holds
        the keys of *data_dict* in iteration order and ``data_list`` holds a
        shallow dict copy of the matching inner mapping at the same position.
    """
    keys = tuple(data_dict)
    # Copy each inner mapping so callers can modify the result without
    # touching the dicts stored in data_dict.
    data_list = tuple(dict(data) for data in data_dict.values())
    return keys, data_list
|
|
|
|
|
|
def load_graph_config(config_file):
    """Read the JSON graph configuration.

    Args:
        config_file: Path to a JSON file whose top-level object has the keys
            'files', 'colors' and 'labels'.

    Returns:
        A ``(files, colors, labels)`` tuple holding the values stored under
        those three keys.

    Raises:
        KeyError: If one of the three keys is missing from the JSON object.
    """
    with open(config_file) as json_file:
        json_data = json.load(json_file)

    return json_data['files'], json_data['colors'], json_data['labels']
|
|
|
|
|
|
def gen_plot(data_dict, colors, labels, out_path):
    """Draw a grouped bar chart of the loaded data and save it to *out_path*.

    One group of bars is drawn per entry of *data_dict*; within a group there
    is one bar per data tag. The y axis is fixed to the range 0-100 with a
    major tick every 10 units. The legend is placed below the plot and lists
    every data tag once.

    Args:
        data_dict: Mapping of group key -> mapping of data tag -> value.
            Values are converted with ``float``.
        colors: Mapping of data tag -> bar colour.
        labels: Mapping with the keys 'xLabel', 'yLabel' and 'Title'.
        out_path: Path the figure is written to.
    """
    num_files = len(data_dict)
    fig = plt.figure(figsize=(num_files * 3.5, 5), dpi=300)
    ax = plt.subplot(111)

    bar_width = 0.1
    bar_space = 0.0
    opacity = 0.3
    bar_step = bar_width + bar_space

    keys, data_info = get_data_lists(data_dict)
    index = np.arange(num_files)

    # Number of bars in the most recently drawn group; used below to centre
    # the x tick under its group.
    num_values = 0
    for group_idx, data in enumerate(data_info):
        for bar_num, (label, value) in enumerate(data.items()):
            # The bar position is passed positionally: the keyword was named
            # `left` in old matplotlib releases and is named `x` in current
            # ones, so a positional argument works with both.
            ax.bar(index[group_idx] + bar_step * bar_num, float(value), width=bar_width,
                   alpha=opacity, color=colors[label], label=label)
        num_values = len(data)

    ax.yaxis.set_ticks(np.arange(0, 110, 10))
    ax.yaxis.set_minor_locator(AutoMinorLocator(2))
    ax.yaxis.grid(which='major', linestyle='-')
    ax.yaxis.grid(which='minor', linestyle='--')

    plt.xlabel(labels['xLabel'])
    plt.ylabel(labels['yLabel'])
    plt.title(labels['Title'])

    # Divide by 2.0 so the offset is never floored by integer division.
    plt.xticks(index + bar_step * ((num_values - 1) / 2.0), keys)

    # Each bar registers its own legend entry; keep only the first handle per
    # label so every data tag is listed once, in order of first appearance.
    by_label = OrderedDict()
    handles, legend_labels = ax.get_legend_handles_labels()
    for handle, legend_label in zip(handles, legend_labels):
        by_label.setdefault(legend_label, handle)

    lgd = plt.legend(list(by_label.values()), list(by_label.keys()), loc=9, bbox_to_anchor=(0.5, -0.1))
    plt.ylim([0, 100])
    # The legend sits outside the axes, so pass it as an extra artist to be
    # included in the tight bounding box.
    plt.savefig(out_path, bbox_extra_artists=(lgd,), bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
|
|
|
|
# This is the main function of the program.
|
|
def main(cfg_file, out_file):
    """Build the graph described by *cfg_file* and write it to *out_file*.

    Args:
        cfg_file: Path to the JSON configuration read by
            ``load_graph_config``.
        out_file: Path the generated graph is saved to.
    """
    files, colors, labels = load_graph_config(cfg_file)
    files_data = load_files(files)
    gen_plot(files_data, colors, labels, out_file)
|
|
|
|
|
|
def check_args(file, output):
    """Validate the command-line paths before any work is done.

    If *file* does not exist, an error and the help text are printed and the
    process exits with status -1; *output* is left untouched in that case.
    Otherwise, if *output* already exists, a warning is printed and the old
    file is removed.

    Args:
        file: Path to the input file; must exist.
        output: Path the graph will be written to.

    Raises:
        SystemExit: If *file* does not exist.
    """
    # Imported locally so this function does not rely on a module-level import.
    import sys

    # Abort on a fatal error BEFORE touching the output file, so a failed
    # invocation never deletes an existing graph.
    if not os.path.exists(file):
        print('%s File does not exist: %s' % (red_error, file))
        parser.print_help()
        print('Encountered fatal error, exiting...')
        sys.exit(-1)

    if os.path.exists(output):
        print('%s File will be overwritten: %s' % (yellow_warning, output))
        os.remove(output)
|
|
|
|
|
|
# This is where we call the main method from.
|
|
if __name__ == '__main__':
    print(program_header)

    # Set up arguments here. The parser was created with add_help=False, so
    # the -h/--help option is registered explicitly in the 'Optional' group.
    required_args = parser.add_argument_group('Required')
    optional_args = parser.add_argument_group('Optional')
    required_args.add_argument('-i', '--in_file', required=True,
                               help='The path to the JSON file containing the necessary metadata to create a graph.')
    required_args.add_argument('-o', '--output_file', required=True, help='The path to write the graph to.')

    optional_args.add_argument('-h', '--help', action="help", help='Prints the help message.')

    # Get the arguments.
    args = parser.parse_args()
    in_file = args.in_file
    output_file = args.output_file
    print('Out file = %s\n'
          'In File = %s' % (output_file, in_file))

    # Do an argument check.
    check_args(in_file, output_file)

    # Now we can run...
    main(in_file, output_file)
|