""" GetClosestCenter.py ==================== This is a tool for getting the closest center document to every file in a given directory. This is both an example of how to use our clustering tools for classification as well as a way to measure the accuracy of our distance calculations. This tool is multi-threaded and has a dependency on DocumentDistance.py to compute the distance between documents in the library and center documents. .. moduleauthor:: Chris Diesch """ import os import argparse import ConsoleUtils import DocumentDistance import sys import concurrent.futures import json red_error = '\033[91mError:\033[0m' yellow_warning = '\033[93mWARNING:\033[0m' blue_okay = '\033[94mOK\033[0m' program_name = 'GetClosestCenter' program_description = 'Finds the closest center of every document in the library, this is effectively classifying ' \ 'these documents using our distance algorithm in conjunction with k-means clustering.' parser = argparse.ArgumentParser(prog=program_name, description=program_description, add_help=False) # Default argument values def_thread_count = 1 def get_centers(library_dir): """ Loads all the center documents in the library. :param library_dir: The parent directory of the library. :type library_dir: str .. raw:: html
:return: """ print('{:^70}'.format('Loading doctype centers.')) doctype_centers = {} num_ran = 0 doctype_folders = os.listdir(library_dir) num_to_run = len(doctype_folders) for doctype_folder in doctype_folders: num_ran += 1 full_path = os.path.join(library_dir, doctype_folder) if not os.path.isfile(full_path): center = get_center_doctype(full_path) if center is not None: doctype_centers[doctype_folder] = center ConsoleUtils.print_progress_bar(num_ran, num_to_run, 50, 70) return doctype_centers def load_centers_tokens(centers): """ Loads the tokens from the files representing the centers of each doctype. :param centers: A dict where the keys are the doctypes in the library and the values are the centers of those doctypes. :type centers: dict[str,str] .. raw:: html
:return: A dict with keys equal to the center document and the values representing the tokens read by DocumentDistance.read_tokens_from_file(center). :rtype: dict[str,list of str] """ centers_tkns = {} for doc_type, center in centers.items(): result, name = DocumentDistance.read_tokens_from_file(center) centers_tkns[name] = result return centers_tkns def get_center_doctype(doctype_dir): """ Finds the center doctype in the given doctype directory. :param doctype_dir: The parent directory of a doctype to load the center of. :type doctype_dir: str .. raw:: html
:return: The name of the document which is the computed center of the doctype cluster. :rtype: str """ center_file = os.path.join(doctype_dir, 'doctype_center.txt') if os.path.exists(center_file): with open(center_file) as center: center_lines = center.readlines() center_line = center_lines[0].replace('\n', '') center_doctype = center_line.split(' = ')[1] return center_doctype else: return None def get_doctype(file): """ Gets the doctype of a given file. .. note:: This currently only works for V1 libraries. :param file: The path of the file to get the doctype of. :type file: str .. raw:: html
:return: The name of the doctype of a given file. :rtype: str """ fml_dir, file_name = os.path.split(file) doctype_dir, fml_name = os.path.split(fml_dir) lib_dir, doctype_name = os.path.split(doctype_dir) return doctype_name def load_all_docs(doc_file, thread_count): """ Loads all the documents contained in doc_file using DocumentDistance.load_files(doc_file, thread_count). :param doc_file: The path of document containing a list of docs to find the closest center of. :type doc_file: str :param thread_count: The number of threads to execute on. :type thread_count: int .. raw:: html
:return: A dict with keys representing the name of the file and values representing the list of tokens in the file. :rtype: dict[str,list of str] """ return DocumentDistance.load_files(doc_file, thread_count) def get_closest_center(file, centers, file_name): """ Computes the closest center to the given file. :param file: The dict representing the file read by DocumentDistance.load_files() :type file: dict[str,str] :param centers: The dict with keys representing the names of center documents, and values representing the file read by DocumentDistance.read_tokens_from_file() :type centers: dict[str,str] :param file_name: The name of the file being read. :type file_name: str .. raw:: html
:return: The name of the file ran, and a dict with the following keys: 'predictedDoctype', 'actualDoctype', 'closestCenter', 'distFromCenter', 'distances' :rtype: tuple(str, dict) """ tkn_tag = str(DocumentDistance.tkn_vals_tag) actual_doctype = get_doctype(file_name) min_dist = sys.maxsize min_center = '' min_doctype = '' dist_from_centroids = {} for key, center in centers.items(): dist_from_center = int(DocumentDistance.lev_dist(file[tkn_tag], center[tkn_tag])) if dist_from_center < min_dist: min_center = key min_dist = dist_from_center min_doctype = get_doctype(key) else: dist_from_centroids[key] = str(dist_from_center) result = {'predictedDoctype': min_doctype, 'actualDoctype': actual_doctype, 'closestCenter': min_center, 'distFromCenter': str(min_dist), 'distances': dist_from_centroids} return file_name, result def get_distances_from_centers(files, centers, thread_count): """ Computes the distances every file is from every center document. :param files: The dict returned by DocumentDistance.load_files(). :type files: dict[str, list of str] :param centers: A dict identical to that returned by DocumentDistance.load_files() but only containing the center documents for each doctype cluster. :type centers: dict[str, list of str] :param thread_count: The number of threads to execute on. :type thread_count: int .. raw:: html
:return: A dict with keys equal to the file names, and values equal to the dict returned by get_closest_center(). :rtype: dict[str,dict[str,str]] """ num_ran = 0 num_to_run = len(files) closest_centers = {} print('{:^70}'.format('Finding closet doctype centers (%d threads).' % thread_count)) with concurrent.futures.ThreadPoolExecutor(max_workers=thread_count) as thread_pool: future_dists = {thread_pool.submit(get_closest_center, file, centers, key): file for key, file in files.items()} for dist in concurrent.futures.as_completed(future_dists): num_ran += 1 ran_file = future_dists[dist] file_name, result = dist.result() closest_centers[file_name] = result ConsoleUtils.print_progress_bar(num_ran, num_to_run, 50, 70, 4) return closest_centers def write_json(closest_centers): """ Formats passed dict for each file and writes the json files out. :param closest_centers: The dict returned by get_distances_from_centers(). :type closest_centers: dict[str,dict[str,str]] .. raw:: html
:return: None. """ num_ran = 0 num_total = len(closest_centers) print('{:^70}'.format('Writing json %d files...' % num_total)) for file in closest_centers.keys(): ConsoleUtils.print_progress_bar(num_ran, num_total, 50, 70) output_file = file.replace('.lev', 'closest_center.json') with open(output_file, 'w+') as output: output.write(json.dumps(closest_centers[file])) def main(library_dir, thread_count): """ The main entry point for the tool. :param library_dir: The path to the library containing the documents to classify. :type library_dir: str :param thread_count: The number of threads to execute on. :type thread_count: int .. raw:: html
:return: The status of the program. :rtype: int """ full_list_file = os.path.join(library_dir, 'tkn_files.txt') centers = get_centers(library_dir) center_tkns = load_centers_tokens(centers) all_docs = load_all_docs(full_list_file, thread_count) dist_from_centers = get_distances_from_centers(all_docs, center_tkns, thread_count) write_json(dist_from_centers) def check_args(library_dir, thread_cnt): """ Makes sure arguments are valid before running the program. :param library_dir: The path to the parent folder of the library. :type library_dir: str :param thread_cnt: The number of threads to execute on. :type thread_cnt: int .. raw:: html
:return: None. """ fatal_errors = False if thread_cnt < 1: print('%s thread_count cannot be less than 1' % yellow_warning) print('%s Using default of %d' % (blue_okay, def_thread_count)) if not os.path.exists(library_dir): print('%s Directory does not exist: %s' % (red_error, library_dir)) fatal_errors = True if fatal_errors: parser.print_help() print('Exiting...') exit(-1) if __name__ == '__main__': required_args = parser.add_argument_group('Required') optional_args = parser.add_argument_group('Optional') required_args.add_argument('-l', '--library_dir', required=True, help='The full path to the parent folder of the library to classify') optional_args.add_argument('-t', '--thread_count', required=False, type=int, default=def_thread_count, help='The number of threads to run this tool on.') optional_args.add_argument('-w', '--overwrite', required=False, action='store_true', help='If this flag is included files will be overwritten without asking.') optional_args.add_argument('-h', '--help', action='help', help='Prints the help message.') args = parser.parse_args() lib_dir = args.library_dir thread_count = args.thread_count auto_overwrite = args.overwrite # Are the args valid? check_args(lib_dir, thread_count) # If all is well, let's run the program! main(lib_dir, thread_count)