311 lines
10 KiB
Python
311 lines
10 KiB
Python
"""
|
|
GetClosestCenter.py
|
|
====================
|
|
|
|
This is a tool for getting the closest center document to every file in a given directory. This is both an example of
|
|
how to use our clustering tools for classification as well as a way to measure the accuracy of our distance
|
|
calculations. This tool is multi-threaded and has a dependency on DocumentDistance.py to compute the distance
|
|
between documents in the library and center documents.
|
|
|
|
.. moduleauthor:: Chris Diesch <cdiesch@sequencelogic.net>
|
|
"""
|
|
|
|
import os
|
|
import argparse
|
|
import ConsoleUtils
|
|
import DocumentDistance
|
|
import sys
|
|
import concurrent.futures
|
|
import json
|
|
|
|
|
|
# ANSI-colored prefixes used when printing status messages to the console.
red_error = '\033[91mError:\033[0m'
yellow_warning = '\033[93mWARNING:\033[0m'
blue_okay = '\033[94mOK\033[0m'

# Name and description handed to argparse for the usage/help output.
program_name = 'GetClosestCenter'
program_description = (
    'Finds the closest center of every document in the library, this is effectively classifying '
    'these documents using our distance algorithm in conjunction with k-means clustering.'
)

# Default argument values
def_thread_count = 1

# add_help=False: the script's entry point registers its own '-h/--help' option on this parser.
parser = argparse.ArgumentParser(prog=program_name, description=program_description, add_help=False)
|
|
|
|
|
|
def get_centers(library_dir):
    """
    Collects the center document of every doctype folder in the library.

    Each entry of ``library_dir`` that is not a regular file is treated as a doctype folder and is passed to
    get_center_doctype(). Folders for which that lookup yields None are left out of the result. A progress bar is
    advanced once per directory entry.

    :param library_dir: The parent directory of the library.
    :type library_dir: str

    .. raw:: html
        <br>

    :return: A dict whose keys are the doctype folder names and whose values are the centers found for them.
    :rtype: dict[str,str]
    """
    print('{:^70}'.format('Loading doctype centers.'))
    entries = os.listdir(library_dir)
    total = len(entries)
    centers_by_doctype = {}
    for done, entry in enumerate(entries, start=1):
        entry_path = os.path.join(library_dir, entry)
        # Regular files sitting next to the doctype folders cannot hold a center.
        found = None if os.path.isfile(entry_path) else get_center_doctype(entry_path)
        if found is not None:
            centers_by_doctype[entry] = found
        ConsoleUtils.print_progress_bar(done, total, 50, 70)

    return centers_by_doctype
|
|
|
|
|
|
def load_centers_tokens(centers):
    """
    Reads the tokens of every doctype center document.

    :param centers: A dict where the keys are the doctypes in the library and the values are the centers of those
                    doctypes. Only the values are used here.
    :type centers: dict[str,str]

    .. raw:: html
        <br>

    :return: A dict keyed by the name returned from DocumentDistance.read_tokens_from_file(center), with the tokens
             returned by that same call as the value.
    :rtype: dict
    """
    tokens_by_center = {}
    for center_path in centers.values():
        tokens, center_name = DocumentDistance.read_tokens_from_file(center_path)
        tokens_by_center[center_name] = tokens
    return tokens_by_center
|
|
|
|
|
|
def get_center_doctype(doctype_dir):
    """
    Finds the center document recorded for the given doctype directory.

    The center is read from the first line of ``doctype_center.txt`` inside ``doctype_dir``; that line is expected to
    look like ``<label> = <center>`` and the text after the first ``' = '`` separator is returned.

    :param doctype_dir: The parent directory of a doctype to load the center of.
    :type doctype_dir: str

    .. raw:: html
        <br>

    :return: The name of the document which is the computed center of the doctype cluster, or None when the center
             file does not exist, is empty, or its first line has no ``' = '`` separator.
    :rtype: str
    """
    center_file = os.path.join(doctype_dir, 'doctype_center.txt')
    if not os.path.exists(center_file):
        return None

    with open(center_file) as center:
        # Only the first line carries the center, so there is no need to read the rest of the file.
        center_line = center.readline().replace('\n', '')

    fields = center_line.split(' = ')
    if len(fields) < 2:
        # Empty file or malformed first line: report "no center" instead of raising IndexError.
        return None
    return fields[1]
|
|
|
|
|
|
def get_doctype(file):
    """
    Derives the doctype of a file from its location on disk.

    The path is assumed to have the layout ``<library>/<doctype>/<fml>/<file>``, so the doctype is the name of the
    directory two levels above the file.

    .. note:: This currently only works for V1 libraries.

    :param file: The path of the file to get the doctype of.
    :type file: str

    .. raw:: html
        <br>

    :return: The name of the doctype of a given file.
    :rtype: str
    """
    fml_dir = os.path.dirname(file)
    doctype_dir = os.path.dirname(fml_dir)
    return os.path.basename(doctype_dir)
|
|
|
|
|
|
def load_all_docs(doc_file, thread_count):
    """
    Loads every document listed in doc_file by delegating to DocumentDistance.load_files(doc_file, thread_count).

    :param doc_file: The path of document containing a list of docs to find the closest center of.
    :type doc_file: str
    :param thread_count: The number of threads to execute on.
    :type thread_count: int

    .. raw:: html
        <br>

    :return: Whatever DocumentDistance.load_files() returns; the rest of this tool treats it as a dict keyed by file
             name.
    :rtype: dict
    """
    loaded_docs = DocumentDistance.load_files(doc_file, thread_count)
    return loaded_docs
|
|
|
|
|
|
def get_closest_center(file, centers, file_name):
    """
    Computes the closest center to the given file.

    The distance to every center is computed with DocumentDistance.lev_dist() on the entries stored under
    DocumentDistance.tkn_vals_tag. The first center with the smallest distance wins; its doctype is derived from its
    key with get_doctype(), as is the actual doctype from ``file_name``.

    :param file: The dict representing the file read by DocumentDistance.load_files(); it is indexed with
                 DocumentDistance.tkn_vals_tag.
    :type file: dict
    :param centers: The dict with keys representing the names of center documents, and values representing the file read
                    by DocumentDistance.read_tokens_from_file()
    :type centers: dict
    :param file_name: The name of the file being read.
    :type file_name: str

    .. raw:: html
        <br>

    :return: The name of the file ran, and a dict with the following keys:
             'predictedDoctype', 'actualDoctype', 'closestCenter', 'distFromCenter', 'distances'.
             'distances' maps every center name (the closest one included) to its distance as a string. When
             ``centers`` is empty, 'predictedDoctype' and 'closestCenter' are '' and 'distFromCenter' is
             str(sys.maxsize).
    :rtype: tuple(str, dict)
    """
    tkn_tag = str(DocumentDistance.tkn_vals_tag)
    actual_doctype = get_doctype(file_name)
    min_dist = sys.maxsize
    min_center = ''
    dist_from_centroids = {}
    for key, center in centers.items():
        dist_from_center = int(DocumentDistance.lev_dist(file[tkn_tag], center[tkn_tag]))
        # Record every distance. Recording only when the center was NOT a new minimum silently dropped each center
        # that was the running minimum at some point, leaving holes in the reported distances.
        dist_from_centroids[key] = str(dist_from_center)
        if dist_from_center < min_dist:
            min_center = key
            min_dist = dist_from_center

    # Resolved once for the winner; get_doctype('') is '' so the no-centers case keeps an empty prediction.
    min_doctype = get_doctype(min_center)
    result = {'predictedDoctype': min_doctype, 'actualDoctype': actual_doctype, 'closestCenter': min_center,
              'distFromCenter': str(min_dist), 'distances': dist_from_centroids}
    return file_name, result
|
|
|
|
|
|
def get_distances_from_centers(files, centers, thread_count):
    """
    Computes the distances every file is from every center document.

    One get_closest_center() task per file is submitted to a thread pool; results are collected in completion order
    and a progress bar is advanced after each one.

    :param files: The dict returned by DocumentDistance.load_files().
    :type files: dict
    :param centers: A dict identical to that returned by DocumentDistance.load_files() but only containing the center
                    documents for each doctype cluster.
    :type centers: dict
    :param thread_count: The number of threads to execute on. Passed straight to ThreadPoolExecutor, which raises
                         ValueError for values below 1.
    :type thread_count: int

    .. raw:: html
        <br>

    :return: A dict with keys equal to the file names, and values equal to the dict returned by get_closest_center().
    :rtype: dict[str,dict]
    """
    num_ran = 0
    num_to_run = len(files)
    closest_centers = {}
    print('{:^70}'.format('Finding closet doctype centers (%d threads).' % thread_count))
    with concurrent.futures.ThreadPoolExecutor(max_workers=thread_count) as thread_pool:
        # Each task returns its own file name, so no future -> file bookkeeping is needed.
        future_dists = [thread_pool.submit(get_closest_center, doc, centers, key) for key, doc in files.items()]
        for dist in concurrent.futures.as_completed(future_dists):
            num_ran += 1
            file_name, result = dist.result()
            closest_centers[file_name] = result
            ConsoleUtils.print_progress_bar(num_ran, num_to_run, 50, 70, 4)
    return closest_centers
|
|
|
|
|
|
def _closest_center_json_path(file):
    """Builds the output path for a document: a trailing '.lev' becomes 'closest_center.json', else it is appended."""
    suffix = '.lev'
    if file.endswith(suffix):
        return file[:-len(suffix)] + 'closest_center.json'
    # Never hand back the input path itself: writing there would overwrite the source document.
    return file + 'closest_center.json'


def write_json(closest_centers):
    """
    Formats passed dict for each file and writes the json files out.

    For every key of ``closest_centers`` one JSON file is written next to the source document, with the trailing
    '.lev' of the key swapped for 'closest_center.json'. Only the extension is rewritten, so a '.lev' occurring
    elsewhere in the path is left alone, and keys without that extension get the suffix appended. The progress bar
    is advanced after each file is written.

    :param closest_centers: The dict returned by get_distances_from_centers().
    :type closest_centers: dict[str,dict]

    .. raw:: html
        <br>

    :return: None.
    """
    num_ran = 0
    num_total = len(closest_centers)
    print('{:^70}'.format('Writing json %d files...' % num_total))
    for file, result in closest_centers.items():
        output_file = _closest_center_json_path(file)
        with open(output_file, 'w') as output:
            json.dump(result, output)
        # Count the finished file so the progress bar actually moves.
        num_ran += 1
        ConsoleUtils.print_progress_bar(num_ran, num_total, 50, 70)
|
|
|
|
|
|
def main(library_dir, thread_count):
    """
    The main entry point for the tool.

    Loads the doctype centers found under ``library_dir``, loads the documents listed in
    ``<library_dir>/tkn_files.txt``, finds the closest center for each document, and writes one JSON result file per
    document.

    :param library_dir: The path to the library containing the documents to classify.
    :type library_dir: str
    :param thread_count: The number of threads to execute on.
    :type thread_count: int

    .. raw:: html
        <br>

    :return: The status of the program; 0 when every step completed (failures surface as exceptions).
    :rtype: int
    """
    full_list_file = os.path.join(library_dir, 'tkn_files.txt')
    centers = get_centers(library_dir)
    center_tkns = load_centers_tokens(centers)
    all_docs = load_all_docs(full_list_file, thread_count)
    dist_from_centers = get_distances_from_centers(all_docs, center_tkns, thread_count)
    write_json(dist_from_centers)
    # Return the status the documented contract promises instead of an implicit None.
    return 0
|
|
|
|
|
|
def check_args(library_dir, thread_cnt):
    """
    Makes sure arguments are valid before running the program.

    A thread count below 1 only produces a warning. A missing library directory is fatal: the help text is printed
    and the process exits with status -1.

    .. note:: This function cannot change the caller's thread count; it only reports that the default should be
              used. The caller has to substitute def_thread_count itself.

    :param library_dir: The path to the parent folder of the library.
    :type library_dir: str
    :param thread_cnt: The number of threads to execute on.
    :type thread_cnt: int

    .. raw:: html
        <br>

    :return: None.
    """
    fatal_errors = False
    if thread_cnt < 1:
        print('%s thread_count cannot be less than 1' % yellow_warning)
        print('%s Using default of %d' % (blue_okay, def_thread_count))

    if not os.path.exists(library_dir):
        print('%s Directory does not exist: %s' % (red_error, library_dir))
        fatal_errors = True

    if fatal_errors:
        parser.print_help()
        print('Exiting...')
        # sys.exit() rather than the builtin exit(): the latter is injected by the site module for interactive use
        # and is not guaranteed to exist (e.g. when Python runs with -S).
        sys.exit(-1)
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line interface: arguments are split into 'Required' and 'Optional' groups for the help output.
    required_args = parser.add_argument_group('Required')
    optional_args = parser.add_argument_group('Optional')
    required_args.add_argument('-l', '--library_dir', required=True,
                               help='The full path to the parent folder of the library to classify')
    optional_args.add_argument('-t', '--thread_count', required=False, type=int, default=def_thread_count,
                               help='The number of threads to run this tool on.')
    optional_args.add_argument('-w', '--overwrite', required=False, action='store_true',
                               help='If this flag is included files will be overwritten without asking.')
    optional_args.add_argument('-h', '--help', action='help', help='Prints the help message.')
    args = parser.parse_args()

    lib_dir = args.library_dir
    thread_count = args.thread_count
    # NOTE: accepted on the command line, but nothing in this file reads the value yet.
    auto_overwrite = args.overwrite
    # Are the args valid?
    check_args(lib_dir, thread_count)
    # check_args() only warns about a thread count below 1; apply the announced default here so the thread pool is
    # never created with an invalid worker count.
    if thread_count < 1:
        thread_count = def_thread_count
    # If all is well, let's run the program!
    main(lib_dir, thread_count)
|