# We do all our imports at the top of our program.
import argparse
import os
import json
import sys
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import AutoMinorLocator
from collections import OrderedDict

# Give the program a name.
program_name = 'Accuracy Graph'
# Describe what the program does briefly.
program_description = 'Generates a graph showing the difference between runs of Accuracy.py.'

# The argument parser for the program.
parser = argparse.ArgumentParser(prog=program_name, description=program_description, add_help=False)

# This is where optional arguments are put (they are optional, so assign their default value here).
# Keeping all default values in one place makes the script more readable and lets them be
# changed without searching through the script for them.
global_opt = 'default value'

# Error and warning console prefixes (ANSI-coloured).
red_error = '\033[91mError:\033[0m'
yellow_warning = '\033[93mWARNING:\033[0m'
blue_okay = '\033[94mOK\033[0m'
program_header = '\033[95m%s\033[0m\n-----------------------' % program_name

# Type suffixes recognised in data files (the text after the last ':' of a line).
NUM_VAL = 'number'
STR_VAL = 'string'


def load_files(files_dict):
    """Load every data file named in *files_dict*.

    :param files_dict: mapping of graph key -> path of a data file.
    :return: dict mapping the same keys (in the same order) to the parsed
             ``{tag: value}`` dict of each file, as produced by
             ``load_data_from_file``.
    """
    return {key: load_data_from_file(path) for key, path in files_dict.items()}


def load_data_from_file(file_path):
    """Parse a data file made of ``tag=value:type`` lines into a dict.

    Blank lines are skipped. If a tag occurs more than once, the last
    occurrence wins.

    :param file_path: path of the data file to read.
    :return: dict mapping each tag to its value, cast by ``cast_to_type``.
    :raises ValueError: if a non-blank line lacks '=' or a ':' type suffix.
    """
    result = {}
    with open(file_path) as reader:
        for line in reader:
            # A blank line (e.g. a trailing empty line) has no '=' and would
            # make parse_line raise, so ignore it.
            if not line.strip():
                continue
            key, val = parse_line(line)
            result[key] = val
    return result


def use_folder_for_tag(file_path):
    """Return the name of the directory that directly contains *file_path*."""
    folder, _file = os.path.split(file_path)
    _parent, tag = os.path.split(folder)
    return tag


def get_graph_tag(file_path):
    """Return the final path component (file name) of *file_path*."""
    _parent, file_name = os.path.split(file_path)
    return file_name


def parse_line(line_txt):
    """Split one ``tag=value:type`` line into ``(tag, typed_value)``.

    The tag is everything before the first '='. The type is everything after
    the LAST ':', so string values may themselves contain ':'.

    :raises ValueError: if the line has no '=' or no ':' after it.
    """
    # Strip '\r' as well as '\n' so files with Windows line endings still
    # have their type suffix recognised.
    line_txt = line_txt.rstrip('\r\n')
    tag, value = line_txt.split('=', 1)
    value, val_type = value.rsplit(':', 1)
    return tag, cast_to_type(val_type, value)


def cast_to_type(t_val, val):
    """Convert the string *val* according to the type name *t_val*.

    NUM_VAL -> float, STR_VAL -> str; any other type name returns *val*
    unchanged.
    """
    if t_val == NUM_VAL:
        return float(val)
    elif t_val == STR_VAL:
        return str(val)
    else:
        return val


def get_data_lists(data_dict):
    """Split *data_dict* into parallel tuples of keys and data dicts.

    :return: ``(keys, data_list)`` where ``data_list[i]`` is a shallow copy of
             ``data_dict[keys[i]]``.
    """
    keys = tuple(data_dict)
    data_list = tuple(dict(data) for data in data_dict.values())
    return keys, data_list


def load_graph_config(config_file):
    """Read the JSON graph config and return its ``files``, ``colors`` and ``labels`` entries.

    :raises KeyError: if any of those three top-level keys is missing.
    """
    with open(config_file) as json_file:
        json_data = json.load(json_file)
    return json_data['files'], json_data['colors'], json_data['labels']


def gen_plot(data_dict, colors, labels, out_path):
    """Draw a grouped bar chart of *data_dict* and save it to *out_path*.

    :param data_dict: mapping of group key -> {bar label: value}; one group of
                      bars is drawn per key and one bar per label in it.
    :param colors: mapping of bar label -> matplotlib colour.
    :param labels: mapping with 'xLabel', 'yLabel' and 'Title' entries.
    :param out_path: path the figure is written to.
    """
    num_files = len(data_dict)
    # max(..., 1) avoids a zero-width figure when there is nothing to plot.
    fig = plt.figure(figsize=(max(num_files, 1) * 3.5, 5), dpi=300)
    ax = plt.subplot(111)
    bar_width = 0.1
    bar_space = 0.0
    opacity = 0.3
    keys, data_info = get_data_lists(data_dict)
    index = np.arange(num_files)
    num_values = 0
    for cur_idx, data in enumerate(data_info):
        for cur_bar_num, (label, value) in enumerate(data.items()):
            bar_x = index[cur_idx] + ((bar_width + bar_space) * cur_bar_num)
            # x is passed positionally: the old ``left=`` keyword was removed in
            # matplotlib 3.0. With the default align='center', bar_x is the
            # bar's centre, which the xtick placement below relies on.
            ax.bar(bar_x, float(value), width=bar_width, alpha=opacity,
                   color=colors[label], label=label)
        # Tick centring below uses the bar count of the last group.
        num_values = len(data)

    ax.yaxis.set_ticks(np.arange(0, 110, 10))
    ax.yaxis.set_minor_locator(AutoMinorLocator(2))
    ax.yaxis.grid(which='major', linestyle='-')
    ax.yaxis.grid(which='minor', linestyle='--')
    plt.xlabel(labels['xLabel'])
    plt.ylabel(labels['yLabel'])
    plt.title(labels['Title'])
    # Put each group's tick label under the middle of its bars.
    plt.xticks(index + ((bar_width + bar_space) * ((num_values - 1) / 2)), keys)

    # Every group draws a bar per label; keep only the first handle for each
    # label so the legend lists each label once.
    by_label = {}
    handles, legend_labels = ax.get_legend_handles_labels()
    for handle, lbl in zip(handles, legend_labels):
        by_label.setdefault(lbl, handle)
    lgd = plt.legend(list(by_label.values()), list(by_label.keys()),
                     loc=9, bbox_to_anchor=(0.5, -0.1))
    plt.ylim([0, 100])
    # The legend sits below the axes; registering it as an extra artist keeps
    # bbox_inches='tight' from cropping it off the saved image.
    plt.savefig(out_path, bbox_extra_artists=(lgd,), bbox_inches='tight')
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


# This is the main function of the program.
def main(cfg_file, out_file):
    """Read the graph config *cfg_file*, load its data files and write the graph to *out_file*."""
    files, colors, labels = load_graph_config(cfg_file)
    files_data = load_files(files)
    gen_plot(files_data, colors, labels, out_file)


def check_args(file, output):
    """Validate the command-line paths; print help and exit on a fatal error.

    A missing input file is fatal. An existing output file produces a warning
    and is removed, but only after validation passes, so an aborted run never
    destroys an existing graph.
    """
    fatal_error = False
    if not os.path.exists(file):
        print('%s File does not exist: %s' % (red_error, file))
        fatal_error = True
    output_exists = os.path.exists(output)
    if output_exists:
        print('%s File will be overwritten: %s' % (yellow_warning, output))
    if fatal_error:
        parser.print_help()
        print('Encountered fatal error, exiting...')
        # sys.exit is always available; the bare exit() builtin comes from
        # the site module and may be missing (e.g. python -S, frozen apps).
        sys.exit(-1)
    if output_exists:
        os.remove(output)


# This is where we call the main method from.
if __name__ == '__main__':
    print(program_header)
    # Set up arguments here.
    required_args = parser.add_argument_group('Required')
    optional_args = parser.add_argument_group('Optional')
    required_args.add_argument('-i', '--in_file', required=True,
                               help='The path to the JSON file containing the necessary metadata to create a graph.')
    required_args.add_argument('-o', '--output_file', required=True,
                               help='The path to write the graph to.')
    optional_args.add_argument('-h', '--help', action="help", help='Prints the help message.')

    # Get the arguments.
    args = parser.parse_args()
    in_file = args.in_file
    output_file = args.output_file
    print('Out file = %s\n'
          'In File = %s' % (output_file, in_file))

    # Do an argument check.
    check_args(in_file, output_file)

    # Now we can run...
    main(in_file, output_file)