Sleds/scorewalker-utils/TreeWalker/TermWalker.py

198 lines
8.2 KiB
Python
Raw Normal View History

2025-03-13 21:28:38 +00:00
"""
Information
-----------
This is a tool for viewing the individual terms responsible for false positive matches from the classification
engine recognized by :mod:`TreeWalker`. This tool produces CSV output when called from the commandline and provides
insight into what causes false positives to occur at a term-level. This tool is used to help determine what terms or
groups of terms matter most when classifying a page between two similar doctypes.
.. moduleauthor:: Chris Diesch <cdiesch@sequencelogic.net>
Commandline Usage
-----------------
Usage: ``TermWalker.py [-h] [-c, --classification] {CLASS_FILE} [-w, --tree_walker] {TREE_WALKER_FILE}
[-o, --out_path] {OUT_FILE}``
Required Arguments:
``-c CLASS_FILE, --classification CLASS_FILE``
Where ``CLASS_FILE`` is the path to the classification engine output.
``-w TREE_WALKER_FILE, --tree_walker TREE_WALKER_FILE``
Where ``TREE_WALKER_FILE`` is the path to the output of :mod:`TreeWalker`.
``-o OUT_FILE, --out_path OUT_FILE``
Where ``OUT_FILE`` is the path to save the final output to.
Optional Arguments:
``-h, --help``
Prints the help message.
Python Module Usage:
--------------------
"""
# We do all our imports at the top of our program.
import argparse
import csv
import sys
import ConsoleUtils
import TreeWalker
# Program metadata (shown in the console header and used to configure the argument parser).
program_name = 'TermWalker'
program_description = 'Generates a CSV file for analyzing terms which produced false positives.'
author = 'Chris Diesch'
version = '1.0.1'
build = '2017.07.27'

# The built-in help option is disabled here; a '-h/--help' option is registered explicitly in the __main__ block.
parser = argparse.ArgumentParser(
    prog=program_name,
    description=program_description,
    add_help=False,
)

# Console status labels wrapped in ANSI escape sequences ('\033[0m' ends the styling).
red_error = '\033[91mError:\033[0m'
yellow_warning = '\033[93mWARNING:\033[0m'
blue_okay = '\033[94mOK\033[0m'

# Column order of the CSV file produced by write_csv().
FIELD_NAMES = [
    'Page',
    'Correct Result',
    'Classification Result',
    'Library Page',
    'Confidence',
    'Term',
    'Term Weight',
]
def get_false_positive_pages(tree_walker_data):
    """
    Gets the False Positive pages from already-loaded :mod:`TreeWalker` data.

    Args:
        ``tree_walker_data`` -- ``dict`` The loaded TreeWalker data (this module passes in the result of
        ``TreeWalker.load_csv``). Only the values are read; each value must provide the ``'Status'``, ``'Page'``
        and ``'CLUX Result'`` entries.

    Returns:
        ``dict`` A dict where the keys are the 0-based page indexes of the false positives (``'Page'`` minus one)
        and the values are the ``'CLUX Result'`` entries of those pages.
    """
    # The mapping's keys are not needed, so iterate the values directly.
    return {
        int(row['Page']) - 1: row['CLUX Result']
        for row in tree_walker_data.values()
        if row['Status'] == TreeWalker.FALSE_POSITIVE
    }
def get_fp_terms(tree_walker_file, walker_file):
    """
    Loads the terms for every match of every false positive page in the TreeWalker file from the classification
    output.

    Args:
        ``tree_walker_file`` -- ``str`` The TreeWalker file to load the false positive data from.

        ``walker_file`` -- ``str`` The classification output file used to create ``tree_walker_file``.

    Returns:
        ``dict`` A dict where the keys are ``"PAGE_INDEX.MATCH_NUMBER"``
        (ex: The second match for page 1 = ``"0.1"`` since pages and match numbers are 0 indexed)

        The values for the result are dicts with the following key/value pairs:

        +-----------------+--------------------------------------------------------------------+
        | Key             | Value                                                              |
        +=================+====================================================================+
        | ``terms``       | The match's ``'terms'`` entry. :meth:`write_csv` reads ``'term'``  |
        |                 | and ``'termScore'`` from each item.                                |
        +-----------------+--------------------------------------------------------------------+
        | ``doctype``     | The match's ``'doctype'`` entry.                                   |
        +-----------------+--------------------------------------------------------------------+
        | ``correctType`` | The page's value from :meth:`get_false_positive_pages`.            |
        +-----------------+--------------------------------------------------------------------+
        | ``score``       | The match's ``'rawScore'`` entry.                                  |
        +-----------------+--------------------------------------------------------------------+
        | ``libPage``     | The match's ``'imagePath'`` entry.                                 |
        +-----------------+--------------------------------------------------------------------+
        | ``conf``        | The match's ``'conf'`` entry.                                      |
        +-----------------+--------------------------------------------------------------------+
    """
    # NOTE(review): ``correctType`` is filled from the TreeWalker 'CLUX Result' column; confirm that this column
    # holds the expected doctype rather than the classifier's own output.
    bad_pages = get_false_positive_pages(TreeWalker.load_csv(tree_walker_file))
    walker_data = TreeWalker.load_walker_data(walker_file)

    result = {}
    for page_index, correct_type in bad_pages.items():
        for match_num, match in enumerate(walker_data[page_index]['matches']):
            result['%d.%d' % (page_index, match_num)] = {
                'terms': match['terms'],
                'correctType': correct_type,
                'doctype': match['doctype'],
                'score': match['rawScore'],
                'libPage': match['imagePath'],
                'conf': match['conf'],
            }

    # The reported count is the number of result entries, i.e. one per match of each false positive page.
    print('Found %d false positives in %s' % (len(result), tree_walker_file))
    return result
def write_csv(fp_term_data, out_file_loc):
    """
    Writes the false positive term data to a CSV file.

    For every entry a summary row is written first (page number, both doctypes, library page, raw score and
    confidence), then one row per term holding the term, its weight and the match confidence, and finally an empty
    row which separates the entry from the next one.

    Args:
        ``fp_term_data`` -- ``dict`` The data from :meth:`get_fp_terms`.

        ``out_file_loc`` -- ``str`` The path to save the output to.

    Returns:
        ``None``
    """
    print('Saving file at %s' % out_file_loc)

    # These columns are only filled in on the summary row of an entry.
    blank_summary_columns = {'Page': '', 'Correct Result': '', 'Classification Result': '', 'Library Page': ''}
    separator_row = {**blank_summary_columns, 'Term': '', 'Term Weight': '', 'Confidence': ''}

    with open(out_file_loc, 'w+', newline='') as out_handle:
        csv_writer = csv.DictWriter(out_handle, fieldnames=FIELD_NAMES)
        csv_writer.writeheader()

        for match_key, match in fp_term_data.items():
            # Keys look like "PAGE_INDEX.MATCH_NUMBER"; the CSV shows 1-based page numbers.
            page_number = int(match_key.split('.')[0]) + 1

            csv_writer.writerow({
                'Page': page_number,
                'Correct Result': match['correctType'],
                'Classification Result': match['doctype'],
                'Library Page': match['libPage'],
                'Term': 'Page Score: ',
                'Term Weight': match['score'],
                'Confidence': match['conf'],
            })
            csv_writer.writerows(
                {**blank_summary_columns,
                 'Term': term_info['term'],
                 'Term Weight': term_info['termScore'],
                 'Confidence': match['conf']}
                for term_info in match['terms']
            )
            csv_writer.writerow(separator_row)
# This is the main function of the program.
def main(classification_file, tree_walker_file, output_file):
    """
    Runs the tool: collects the terms behind every false positive match and saves them as a CSV file.

    Args:
        ``classification_file`` -- ``str`` The path to the classification engine output.

        ``tree_walker_file`` -- ``str`` The path to the :mod:`TreeWalker` output.

        ``output_file`` -- ``str`` The path to save the CSV output to.

    Returns:
        ``None``
    """
    print('Classification file: %s\n'
          'TreeWalker file: %s\n'
          'Output file: %s' % (classification_file, tree_walker_file, output_file))
    false_pos_data = get_fp_terms(tree_walker_file, classification_file)
    # Write to the path given by the caller. The module-level ``output_file_path`` only exists when this file is
    # run as a script, so relying on it breaks calls made from other modules.
    write_csv(false_pos_data, output_file)
# Script entry point: everything below only runs when this file is executed directly.
if __name__ == '__main__':
    # Replace stdout with the ConsoleUtils printer, then emit the program header without a prefix.
    printer = ConsoleUtils.SLPrinter(program_name)
    sys.stdout = printer
    printer.write_no_prefix(
        ConsoleUtils.get_header(program_name, version, build, author, 80)
    )

    # Register the commandline options in two groups.
    required_args = parser.add_argument_group('Required')
    optional_args = parser.add_argument_group('Optional')

    required_args.add_argument(
        '-c', '--classification',
        required=True,
        help='The path to a classification output file.',
    )
    required_args.add_argument(
        '-w', '--tree_walker',
        required=True,
        help='The path to a TreeWalker output file.',
    )
    required_args.add_argument(
        '-o', '--out_path',
        required=True,
        help='The path to write output to.',
    )
    # The parser is created with add_help=False, so the help option is added by hand.
    optional_args.add_argument(
        '-h', '--help',
        action="help",
        help='Prints the help message.',
    )

    args = parser.parse_args()

    # Keep the parsed values in module-level names before handing them to main().
    walker_file_path, tree_file_path, output_file_path = args.classification, args.tree_walker, args.out_path

    main(
        classification_file=walker_file_path,
        tree_walker_file=tree_file_path,
        output_file=output_file_path,
    )