124 lines
4.5 KiB
Python
124 lines
4.5 KiB
Python
"""
|
|
This program measures the clustering tool's accuracy by comparing its saved predictions against the actual doctypes for a list
of files in a library. It is an example of how to use the previous utilities to classify a document based on k-means clustering.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import ConsoleUtils
|
|
import concurrent.futures
|
|
|
|
|
|
def load_file(file_name):
    """Load one clustering-result JSON file and reshape it into a summary dict.

    Args:
        file_name: Path to the JSON result file written for a single document.

    Returns:
        A dict with keys ``'prediction'``, ``'actual'``, ``'closestCenter'``,
        ``'distFromClosest'`` and ``'distFromCentroids'``, copied from the
        file's ``predictedDoctype``, ``actualDoctype``, ``closestCenter``,
        ``distFromCenter`` and ``distances`` fields respectively.
        Returns ``None`` (after printing a message) when the file does not exist.

    Raises:
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if one of the expected fields is missing.
    """
    try:
        # EAFP: opening directly avoids the race between an exists() check and
        # open(). JSON is UTF-8 by specification, so don't rely on the
        # platform's default text encoding (e.g. cp1252 on Windows).
        with open(file_name, encoding='utf-8') as reader:
            file_obj = json.load(reader)
    except FileNotFoundError:
        print('File %s does not exist' % file_name)
        return None

    return {
        'prediction': file_obj['predictedDoctype'],
        'actual': file_obj['actualDoctype'],
        'closestCenter': file_obj['closestCenter'],
        'distFromClosest': file_obj['distFromCenter'],
        'distFromCentroids': file_obj['distances'],
    }
|
|
|
|
|
|
def classified_correctly(document):
    """Return True when a document summary's predicted doctype equals its actual doctype.

    Args:
        document: A summary dict with at least ``'prediction'`` and ``'actual'`` keys.
    """
    predicted, expected = document['prediction'], document['actual']
    return predicted == expected
|
|
|
|
|
|
def get_doc_data(documents):
    """Split documents into correctly and incorrectly classified groups.

    Prints a header and advances a console progress bar for each document.

    Args:
        documents: Mapping of document name to summary dict (see ``classified_correctly``).

    Returns:
        A tuple ``(correct_count, correct_names, incorrect_count, incorrect_names)``.
        The names appear in the mapping's iteration order.
    """
    total = len(documents)
    right, wrong = [], []

    print('{:^70}'.format('Calculating classification accuracy...'))

    for position, (doc_name, summary) in enumerate(documents.items(), start=1):
        ConsoleUtils.print_progress_bar(position, total, 50, 70)
        target = right if classified_correctly(summary) else wrong
        target.append(doc_name)

    # The counts are simply the list lengths; no separate counters are needed.
    return len(right), right, len(wrong), wrong
|
|
|
|
|
|
def get_doctype(file):
    """Return the name of the directory two levels above *file*.

    For a path shaped like ``<lib>/<doctype>/<fml>/<file>`` this is the
    ``<doctype>`` component.
    """
    doctype_dir = os.path.dirname(os.path.dirname(file))
    return os.path.basename(doctype_dir)
|
|
|
|
|
|
def get_all_incorrect_distance_from_correct(incorrect_files, files):
    """Map each misclassified file to its distance from the centroid of its actual doctype.

    Args:
        incorrect_files: Iterable of keys into *files* for misclassified documents.
        files: Mapping of file name to summary dict.

    Returns:
        Dict of file name to the value from ``get_dist_from_correct_center``.
    """
    return {
        name: get_dist_from_correct_center(files[name])
        for name in incorrect_files
    }
|
|
|
|
|
|
def get_dist_from_correct_center(file_data):
    """Return the document's distance to the centroid whose doctype matches its actual doctype.

    Args:
        file_data: Summary dict with ``'actual'`` and ``'distFromCentroids'``
            (a mapping of centroid path to distance).

    Returns:
        The matching centroid's distance. If several centroids map to the same
        doctype, the last one in iteration order wins. If none match, 0 is
        returned.
        NOTE(review): a 0 here is indistinguishable from a genuine zero
        distance and will pull down any average computed from it.
    """
    wanted_doctype = file_data['actual']
    distance = 0
    for centroid_path, centroid_distance in file_data['distFromCentroids'].items():
        if get_doctype(centroid_path) == wanted_doctype:
            distance = centroid_distance
    return distance
|
|
|
|
|
|
def load_files(file_list, thread_count):
    """Load every result file in *file_list* concurrently with a thread pool.

    Args:
        file_list: Paths of the per-document JSON result files to load.
        thread_count: Maximum number of worker threads used for loading.

    Returns:
        A dict mapping each successfully loaded path to the summary dict
        returned by ``load_file``. Paths for which ``load_file`` returned
        ``None`` (it prints a message for missing files) are omitted.
        Previously they were stored as ``None`` and crashed the accuracy
        calculation later.

    Raises:
        Any exception raised inside ``load_file`` (e.g. malformed JSON) is
        re-raised here by ``future.result()``.
    """
    file_dict = {}
    num_to_load = len(file_list)

    print('{:^70}'.format('Loading files (%d threads)' % thread_count))

    with concurrent.futures.ThreadPoolExecutor(max_workers=thread_count) as thread_pool:
        future_files = {thread_pool.submit(load_file, file): file for file in file_list}

        for num_loaded, future in enumerate(concurrent.futures.as_completed(future_files), start=1):
            file = future_files[future]
            result = future.result()
            ConsoleUtils.print_progress_bar(num_loaded, num_to_load, 50, 70)
            if result is not None:
                file_dict[file] = result

    return file_dict
|
|
|
|
|
|
def load_file_list(file):
    """Read a list of ``.tkn`` paths and map each one to its result-file path.

    Each line's trailing ``.tkn`` is replaced with ``closest_center.json``.
    The original code only matched ``'.tkn\\n'``. It therefore skipped a last
    line that lacks a newline, left ``'\\n'`` on other lines, and turned
    blank lines into bogus ``'\\n'`` paths. Line endings are now stripped
    first, and blank lines are skipped.

    Args:
        file: Path to a text file with one path per line.

    Returns:
        List of result-file paths, in file order.
    """
    suffix = '.tkn'
    files = []
    with open(file) as reader:
        for line in reader:
            path = line.rstrip('\r\n')
            if not path.strip():
                continue
            if path.endswith(suffix):
                path = path[:-len(suffix)] + 'closest_center.json'
            files.append(path)
    return files
|
|
|
|
|
|
def main(file_list_file, thread_count):
    """Print classification accuracy and the mean distance of misclassified docs to their correct centroid.

    Args:
        file_list_file: Text file listing one ``.tkn`` path per line
            (see ``load_file_list``).
        thread_count: Number of threads used to load the result files.
    """
    files = load_file_list(file_list_file)
    files_data = load_files(files, thread_count)
    # Defensive: drop entries that failed to load (load_file returns None for
    # missing files) so get_doc_data never receives None.
    files_data = {name: data for name, data in files_data.items() if data is not None}

    total_doc_count = len(files_data)
    if total_doc_count == 0:
        # Guard against ZeroDivisionError in the percentage calculations below.
        print('{:^70}'.format('No documents were loaded; nothing to evaluate.'))
        return

    correct_count, correct_docs, wrong_count, wrong_docs = get_doc_data(files_data)

    correct_percent_str = '{:.4%}'.format(correct_count / total_doc_count)
    wrong_percent_str = '{:.4%}'.format(wrong_count / total_doc_count)
    str_to_print = '{:^70}'.format('Of %d docs %d (%s) were correct and %d (%s) were incorrect.' %
                                   (total_doc_count, correct_count, correct_percent_str,
                                    wrong_count, wrong_percent_str))
    print(str_to_print)

    if wrong_count == 0:
        # Every document was classified correctly; averaging would divide by zero.
        print('{:^70}'.format('No documents were misclassified; no distance to report.'))
        return

    incorrect_distances = get_all_incorrect_distance_from_correct(wrong_docs, files_data)
    # float() rather than int(): int() truncated fractional distances (and
    # raises on numeric strings such as "12.5"), skewing the average.
    total_dist = sum(float(d) for d in incorrect_distances.values())
    average_wrong_dist = total_dist / wrong_count

    print('{:^70}'.format('The average distance away from the correct center is: %.2f' % average_wrong_dist))
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    import argparse

    # The defaults reproduce the previous hard-coded invocation, so running the
    # script with no arguments behaves exactly as before.
    arg_parser = argparse.ArgumentParser(
        description='Report k-means clustering classification accuracy for a list of .tkn files.')
    arg_parser.add_argument(
        'file_list_file', nargs='?',
        default=r'C:\Users\chris\Documents\Code\Tests\KMeans\Prime OCR Full Library\tkn_files.txt',
        help='Text file listing one .tkn path per line.')
    arg_parser.add_argument(
        '-t', '--threads', type=int, default=8,
        help='Number of threads used to load result files (default: 8).')
    cli_args = arg_parser.parse_args()

    main(cli_args.file_list_file, cli_args.threads)
|