ScoreWalker/scorewalker-utils/KMeans/GetClosestCenter.py
2025-03-13 00:13:53 -06:00

311 lines
10 KiB
Python

"""
GetClosestCenter.py
====================
This is a tool for getting the closest center document to every file in a given directory. This is both an example of
how to use our clustering tools for classification as well as a way to measure the accuracy of our distance
calculations. This tool is multi-threaded and has a dependency on DocumentDistance.py to compute the distance
between documents in the library and center documents.
.. moduleauthor:: Chris Diesch <cdiesch@sequencelogic.net>
"""
import os
import argparse
import ConsoleUtils
import DocumentDistance
import sys
import concurrent.futures
import json
# Pre-coloured console prefixes (ANSI SGR codes: 91 bright red, 93 bright yellow, 94 bright blue, 0 reset).
red_error = '\033[91mError:\033[0m'
yellow_warning = '\033[93mWARNING:\033[0m'
blue_okay = '\033[94mOK\033[0m'

program_name = 'GetClosestCenter'
program_description = (
    'Finds the closest center of every document in the library, this is effectively classifying '
    'these documents using our distance algorithm in conjunction with k-means clustering.'
)

# add_help=False so that '-h/--help' can be registered explicitly inside the 'Optional' argument group.
parser = argparse.ArgumentParser(
    prog=program_name,
    description=program_description,
    add_help=False,
)

# Default argument values
def_thread_count = 1
def get_centers(library_dir):
    """
    Loads the center document of every doctype folder in the library.

    Every non-file entry of ``library_dir`` is treated as a doctype folder and asked for its center via
    get_center_doctype(); folders that yield no center are left out. A progress bar is printed per entry.

    :param library_dir: The parent directory of the library.
    :type library_dir: str

    .. raw:: html

        <br>

    :return: A dict mapping each doctype folder name to the center document name found in that folder.
    :rtype: dict[str,str]
    """
    print('{:^70}'.format('Loading doctype centers.'))
    entry_names = os.listdir(library_dir)
    entry_total = len(entry_names)
    centers_by_doctype = {}
    for position, entry_name in enumerate(entry_names, start=1):
        entry_path = os.path.join(library_dir, entry_name)
        # Plain files in the library root are not doctype folders, so they never have a center.
        found_center = None if os.path.isfile(entry_path) else get_center_doctype(entry_path)
        if found_center is not None:
            centers_by_doctype[entry_name] = found_center
        ConsoleUtils.print_progress_bar(position, entry_total, 50, 70)
    return centers_by_doctype
def load_centers_tokens(centers):
    """
    Loads the tokens from the files representing the centers of each doctype.

    :param centers: A dict where the keys are the doctypes in the library and the values are the centers of those
                    doctypes (as returned by get_centers()). Only the values are used.
    :type centers: dict[str,str]

    .. raw:: html

        <br>

    :return: A dict keyed by the name returned by DocumentDistance.read_tokens_from_file(center), with the
             first element of that call's result as the value. NOTE(review): get_closest_center() indexes these
             values by DocumentDistance.tkn_vals_tag, so they are presumably dicts rather than token lists;
             confirm against DocumentDistance.
    :rtype: dict
    """
    centers_tkns = {}
    # The doctype keys are not needed here; the returned name identifies each center instead.
    for center in centers.values():
        result, name = DocumentDistance.read_tokens_from_file(center)
        centers_tkns[name] = result
    return centers_tkns
def get_center_doctype(doctype_dir):
    """
    Finds the center document of the given doctype directory.

    The center is read from the first line of ``<doctype_dir>/doctype_center.txt``, which is expected to look
    like ``<label> = <center document>``. Everything after the first ``' = '`` is returned, so a center path that
    itself contains ``' = '`` is kept intact.

    :param doctype_dir: The parent directory of a doctype to load the center of.
    :type doctype_dir: str

    .. raw:: html

        <br>

    :return: The name of the document which is the computed center of the doctype cluster, or None when the
             center file is missing, empty, lacks the ``' = '`` separator, or has an empty value.
    :rtype: str or None
    """
    center_file = os.path.join(doctype_dir, 'doctype_center.txt')
    if not os.path.isfile(center_file):
        return None
    with open(center_file) as center:
        # Only the first line matters, so don't read the whole file.
        center_line = center.readline().rstrip('\n')
    # partition splits once and never raises, unlike split(' = ')[1].
    _, separator, center_doctype = center_line.partition(' = ')
    if not separator or not center_doctype:
        # Malformed or empty center file: report "no center" instead of crashing with IndexError.
        return None
    return center_doctype
def get_doctype(file):
    """
    Gets the doctype of a given file.

    The doctype is the name of the file's grandparent directory, i.e. for
    ``<library>/<doctype>/<fml folder>/<file>`` it returns ``<doctype>``.

    .. note:: This currently only works for V1 libraries.

    :param file: The path of the file to get the doctype of.
    :type file: str

    .. raw:: html

        <br>

    :return: The name of the doctype of a given file.
    :rtype: str
    """
    # dirname/basename are the head/tail halves of os.path.split, so this mirrors three successive splits.
    doctype_dir = os.path.dirname(os.path.dirname(file))
    return os.path.basename(doctype_dir)
def load_all_docs(doc_file, thread_count):
    """
    Loads every document listed in ``doc_file`` by delegating to DocumentDistance.load_files().

    :param doc_file: The path of document containing a list of docs to find the closest center of.
    :type doc_file: str
    :param thread_count: The number of threads to execute on (forwarded unchanged).
    :type thread_count: int

    .. raw:: html

        <br>

    :return: Whatever DocumentDistance.load_files(doc_file, thread_count) returns: a dict with keys representing
             the name of the file and values representing the loaded contents of that file.
    :rtype: dict
    """
    loaded_docs = DocumentDistance.load_files(doc_file, thread_count)
    return loaded_docs
def get_closest_center(file, centers, file_name):
    """
    Computes the closest center to the given file.

    Every center is compared against the file with DocumentDistance.lev_dist(). The center with the smallest
    distance wins; ties go to the first such center in ``centers`` iteration order.

    :param file: The entry for this file as read by DocumentDistance.load_files(); it is indexed by
                 DocumentDistance.tkn_vals_tag.
    :type file: dict
    :param centers: The dict with keys representing the names of center documents, and values representing the file
                    read by DocumentDistance.read_tokens_from_file(); each value is indexed by
                    DocumentDistance.tkn_vals_tag.
    :type centers: dict
    :param file_name: The name of the file being read; its actual doctype is derived with get_doctype().
    :type file_name: str

    .. raw:: html

        <br>

    :return: The name of the file ran, and a dict with the following keys:
             'predictedDoctype', 'actualDoctype', 'closestCenter', 'distFromCenter', 'distances'.
             'distances' maps every center name (including the closest one) to its distance as a str.
             If ``centers`` is empty, 'closestCenter' and 'predictedDoctype' are '' and 'distFromCenter' is
             str(sys.maxsize).
    :rtype: tuple(str, dict)
    """
    tkn_tag = str(DocumentDistance.tkn_vals_tag)
    actual_doctype = get_doctype(file_name)

    # Record the distance to EVERY center. The previous running-minimum loop only stored a distance in its
    # else-branch, so any center that was ever the current minimum silently went missing from 'distances'.
    distances = {key: int(DocumentDistance.lev_dist(file[tkn_tag], center[tkn_tag]))
                 for key, center in centers.items()}

    if distances:
        # min() returns the first key with the smallest value, matching a strict '<' scan on ties.
        min_center = min(distances, key=distances.get)
        min_dist = distances[min_center]
        min_doctype = get_doctype(min_center)
    else:
        # No centers to compare against: keep the original sentinel values.
        min_center = ''
        min_dist = sys.maxsize
        min_doctype = ''

    result = {'predictedDoctype': min_doctype, 'actualDoctype': actual_doctype, 'closestCenter': min_center,
              'distFromCenter': str(min_dist),
              'distances': {key: str(dist) for key, dist in distances.items()}}
    return file_name, result
def get_distances_from_centers(files, centers, thread_count):
    """
    Computes the distances every file is from every center document.

    Each file is handed to get_closest_center() on a thread pool; results are collected as they finish and a
    progress bar is printed after each one.

    :param files: The dict returned by DocumentDistance.load_files().
    :type files: dict
    :param centers: A dict identical to that returned by DocumentDistance.load_files() but only containing the center
                    documents for each doctype cluster.
    :type centers: dict
    :param thread_count: The number of threads to execute on (must be at least 1 for ThreadPoolExecutor).
    :type thread_count: int

    .. raw:: html

        <br>

    :return: A dict with keys equal to the file names, and values equal to the dict returned by get_closest_center().
    :rtype: dict[str,dict]
    """
    num_to_run = len(files)
    closest_centers = {}
    print('{:^70}'.format('Finding closest doctype centers (%d threads).' % thread_count))
    with concurrent.futures.ThreadPoolExecutor(max_workers=thread_count) as thread_pool:
        futures = [thread_pool.submit(get_closest_center, file, centers, key) for key, file in files.items()]
        for num_ran, future in enumerate(concurrent.futures.as_completed(futures), start=1):
            # result() re-raises any exception raised inside get_closest_center.
            file_name, result = future.result()
            closest_centers[file_name] = result
            ConsoleUtils.print_progress_bar(num_ran, num_to_run, 50, 70, 4)
    return closest_centers
def write_json(closest_centers):
    """
    Writes each file's closest-center result to a JSON file next to that file.

    The output path is the file path with a trailing ``.lev`` replaced by ``closest_center.json``
    (e.g. ``doc.lev`` -> ``docclosest_center.json``). If the path does not end in ``.lev``,
    ``closest_center.json`` is appended instead, so the input file itself is never overwritten.
    Existing output files are overwritten.

    :param closest_centers: The dict returned by get_distances_from_centers().
    :type closest_centers: dict[str,dict]

    .. raw:: html

        <br>

    :return: None.
    """
    lev_suffix = '.lev'
    num_total = len(closest_centers)
    print('{:^70}'.format('Writing json %d files...' % num_total))
    for num_ran, (file, result) in enumerate(closest_centers.items(), start=1):
        # Only strip the suffix; str.replace would also rewrite '.lev' occurring in directory names, and would
        # leave the path unchanged (clobbering the input) when the suffix is absent.
        base = file[:-len(lev_suffix)] if file.endswith(lev_suffix) else file
        output_file = base + 'closest_center.json'
        with open(output_file, 'w') as output:
            json.dump(result, output)
        ConsoleUtils.print_progress_bar(num_ran, num_total, 50, 70)
def main(library_dir, thread_count):
    """
    The main entry point for the tool.

    Loads the doctype centers and their tokens, loads every document listed in ``<library_dir>/tkn_files.txt``,
    finds each document's closest center, and writes the results out as JSON.

    :param library_dir: The path to the library containing the documents to classify.
    :type library_dir: str
    :param thread_count: The number of threads to execute on.
    :type thread_count: int

    .. raw:: html

        <br>

    :return: The status of the program (0 once all results have been written).
    :rtype: int
    """
    full_list_file = os.path.join(library_dir, 'tkn_files.txt')
    centers = get_centers(library_dir)
    center_tkns = load_centers_tokens(centers)
    all_docs = load_all_docs(full_list_file, thread_count)
    dist_from_centers = get_distances_from_centers(all_docs, center_tkns, thread_count)
    write_json(dist_from_centers)
    return 0
def check_args(library_dir, thread_cnt):
    """
    Makes sure arguments are valid before running the program.

    A thread count below 1 only produces a warning; this function does not change the caller's value.
    A library path that is not an existing directory is fatal: the help message is printed and the
    process exits with status -1.

    :param library_dir: The path to the parent folder of the library.
    :type library_dir: str
    :param thread_cnt: The number of threads to execute on.
    :type thread_cnt: int

    .. raw:: html

        <br>

    :return: None.
    """
    fatal_errors = False
    if thread_cnt < 1:
        print('%s thread_count cannot be less than 1' % yellow_warning)
        print('%s Using default of %d' % (blue_okay, def_thread_count))
    # isdir rather than exists: the library path is later passed to os.listdir(), so a plain file is invalid too.
    if not os.path.isdir(library_dir):
        print('%s Directory does not exist: %s' % (red_error, library_dir))
        fatal_errors = True
    if fatal_errors:
        parser.print_help()
        print('Exiting...')
        # sys.exit instead of the site-provided exit(), which is not guaranteed to exist (e.g. under python -S).
        sys.exit(-1)
if __name__ == '__main__':
    required_args = parser.add_argument_group('Required')
    optional_args = parser.add_argument_group('Optional')
    required_args.add_argument('-l', '--library_dir', required=True,
                               help='The full path to the parent folder of the library to classify')
    optional_args.add_argument('-t', '--thread_count', required=False, type=int, default=def_thread_count,
                               help='The number of threads to run this tool on.')
    optional_args.add_argument('-w', '--overwrite', required=False, action='store_true',
                               help='If this flag is included files will be overwritten without asking.')
    optional_args.add_argument('-h', '--help', action='help', help='Prints the help message.')
    args = parser.parse_args()
    lib_dir = args.library_dir
    thread_count = args.thread_count
    # NOTE(review): parsed but not used yet; write_json opens its outputs in a truncating write mode regardless.
    auto_overwrite = args.overwrite
    # Are the args valid?
    check_args(lib_dir, thread_count)
    # check_args only warns about a thread count below 1; actually apply the default it announces, otherwise
    # ThreadPoolExecutor(max_workers=0) raises ValueError later on.
    if thread_count < 1:
        thread_count = def_thread_count
    # If all is well, let's run the program!
    sys.exit(main(lib_dir, thread_count))