ScoreWalker/scorewalker-utils/misc/convert-centers.py
2025-03-13 00:13:53 -06:00

124 lines
4.4 KiB
Python

import json
import datetime
import os
import argparse
import sys
sys.path.insert(0, '../KMeans')
import ConsoleUtils
"""
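Generate a configuration JSON file for WalkerClassifier that contains the list of centers for all the doc-types.
Relies on a separate program (L3Centers) that calculates the centers.
Inputs (the docIDMap and docDist files) are given on the command line via -i/--id_file and -d/--dist_file.
Output (the path of the config file to write) is given on the command line via -o/--output.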
"""
# --- command-line identity and parser ---------------------------------------
program_name = 'ConvertCenters'
program_description = (
    'Converts the centers computed by L3Centers '
    'to a config file for WalkerClassifier.'
)
# add_help=False because the -h/--help option is registered explicitly in the
# __main__ section so that it shows up in the "Optional" argument group.
parser = argparse.ArgumentParser(
    add_help=False,
    description=program_description,
    prog=program_name,
)

# --- ANSI-coloured prefixes for console messages ----------------------------
red_error = '\033[91mError:\033[0m'
yellow_warning = '\033[93mWARNING:\033[0m'
blue_okay = '\033[94mOK\033[0m'

# --- shared state filled in by the load_*/write_* functions ------------------
# doc path -> doc id, and the reverse lookup doc id -> doc path
doc_id_dict, id_doc_dict = {}, {}
# top-level object that is serialised to the output JSON file
config = {}
# not referenced by the code visible in this file
centers = {}
# list of {"doc-type": ..., "center-doc": ...} entries, one per center
center_docs = []
# wrapper stored under config["centers"]; receives the "center-docs" list
doc_centers = {}
# load dictionaries for doc id to filename and filename to doc id
def load_id_map(filename):
    """Fill the module-level path <-> id lookup tables from an id-map file.

    Every non-blank line is expected to have the form ``<doc path>:<doc id>``.
    The line is split on its LAST ':' so that a path which itself contains a
    colon is kept intact. The id has surrounding whitespace removed; the path
    is stored exactly as it appears before the delimiter.

    Results are added to the module-level ``doc_id_dict`` (path -> id) and
    ``id_doc_dict`` (id -> path). Blank lines are ignored. Lines that contain
    no ':' cannot be split into a path and an id, so they are reported and
    skipped instead of being stored as a bogus entry.

    :param filename: path of the id-map file to read.
    """
    with open(filename) as infile:
        for line_no, raw_line in enumerate(infile, start=1):
            line = raw_line.rstrip('\r\n')
            if not line.strip():
                continue

            doc_path, delim, doc_id = line.rpartition(':')
            if not delim:
                print('%s %s line %d has no ":" delimiter, skipping: %s'
                      % (yellow_warning, filename, line_no, line))
                continue

            doc_id = doc_id.strip()
            doc_id_dict[doc_path] = doc_id
            id_doc_dict[doc_id] = doc_path
# load the doc-type centers; translation is made to match id to filename
# format of the dictionaries created matches the format of the JSON file to be created
def load_dist_map(filename):
    """Read the per-doc-type center ids and translate them to document paths.

    Every non-blank line is expected to have the form
    ``<doc type>:<doc id>``; commas and surrounding whitespace in the id part
    are dropped. The id is looked up in the module-level ``id_doc_dict``, so
    ``load_id_map()`` must have been called first.

    For each resolved line a ``{"doc-type": ..., "center-doc": ...}`` entry is
    appended to the module-level ``center_docs`` list, which is then stored
    under ``doc_centers["center-docs"]``.

    Blank lines are ignored. Lines without a ':' and ids that are missing from
    ``id_doc_dict`` are reported, counted and skipped rather than aborting the
    whole run with a ``KeyError``.

    :param filename: path of the distance/center file to read.
    """
    bad_file_cnt = 0
    with open(filename) as infile:
        for raw_line in infile:
            if not raw_line.strip():
                continue

            doc_type, delim, id_text = raw_line.rpartition(':')
            # Remove commas before stripping so "12 ," also becomes "12".
            doc_id = id_text.replace(',', '').strip()
            center_path = id_doc_dict.get(doc_id) if delim else None
            if center_path is None:
                bad_file_cnt += 1
                print('%s cannot resolve center entry, skipping: %s'
                      % (yellow_warning, raw_line.strip()))
                continue

            center_docs.append({"doc-type": doc_type, "center-doc": center_path})

    doc_centers["center-docs"] = center_docs
    if bad_file_cnt:
        print('%s %d entries in %s were skipped' % (yellow_warning, bad_file_cnt, filename))
# write the JSON file that will be passed as config to WalkerClassifier
def write_doc_centers(filename, library_version):
    """Serialise the collected centers to ``filename`` and echo them to stdout.

    The module-level ``config`` dict receives two keys, ``"layout"`` (the
    library version) and ``"centers"`` (the module-level ``doc_centers``
    dict), and is then dumped as JSON with an indent of one space.

    :param filename: path of the config file to create or truncate.
    :param library_version: value stored under the ``"layout"`` key.
    """
    config.update({"layout": library_version, "centers": doc_centers})
    with open(filename, 'w+') as handle:
        rendered = json.dumps(config, indent=1)
        print(rendered)
        handle.write(rendered)
# run the conversion: load both input files, then write the config file
def main(doc_id_file, doc_dist_file, config_file, lib_ver):
    """Convert an id-map file and a center file into a WalkerClassifier config.

    :param doc_id_file: path of the docIDMap file (``<doc path>:<doc id>`` lines).
    :param doc_dist_file: path of the docDist file (``<doc type>:<doc id>`` lines).
    :param config_file: path of the JSON config file to write.
    :param lib_ver: library version stored under the config's "layout" key.
    """
    # The id map has to be loaded first: load_dist_map() resolves each center
    # id to a document path through the table that load_id_map() fills.
    load_id_map(doc_id_file)
    load_dist_map(doc_dist_file)
    # Echo the output path and library version before the config is written.
    print(config_file, lib_ver)
    write_doc_centers(config_file, lib_ver)
def check_args(id_path, dist_path, out_path, lib_ver, overwrite):
    """Validate the command-line arguments and exit if the run cannot proceed.

    * An unknown library version is reported. This function cannot change the
      caller's value, so substituting the default is left to the caller.
    * If the output file exists and ``overwrite`` is false, the user is asked
      for confirmation; an explicit refusal ends the program.
    * A missing id file or distance file is fatal: the help text is printed
      and the program exits.

    :param id_path: path of the docIDMap input file.
    :param dist_path: path of the docDist input file.
    :param out_path: path the config file will be written to.
    :param lib_ver: requested library version ('V1' or 'V2').
    :param overwrite: when true, an existing output file is replaced without asking.
    :raises SystemExit: with code -1 on a fatal error or a declined overwrite.
    """
    fatal_errors = False

    if lib_ver not in ('V1', 'V2'):
        print('%s library version "%s" is invalid, using default (V1).' % (red_error, lib_ver))

    if os.path.exists(out_path) and not overwrite:
        print('%s Out file already exists and will be overwritten' % yellow_warning)
        answer = ConsoleUtils.yes_or_no('Overwrite %s? ' % out_path)
        # NOTE(review): ConsoleUtils is not visible here; presumably yes_or_no()
        # returns True/False. Only an explicit False is treated as a refusal so
        # that any other return convention keeps the previous behaviour.
        if answer is False:
            print('%s Output file exists and overwrite was declined: %s' % (red_error, out_path))
            sys.exit(-1)

    if not os.path.exists(id_path):
        print('%s ID file does not exist: %s' % (red_error, id_path))
        fatal_errors = True

    if not os.path.exists(dist_path):
        print('%s Distance file does not exist: %s' % (red_error, dist_path))
        fatal_errors = True

    if fatal_errors:
        parser.print_help()
        print('Fatal Error encountered, exiting...')
        # sys.exit() instead of the site-provided exit(), which is meant for
        # interactive use and is not guaranteed to exist.
        sys.exit(-1)
# if run as a script: parse the command line, validate it, then call the main() method
if __name__ == '__main__':
    required_args = parser.add_argument_group('Required')
    optional_args = parser.add_argument_group('Optional')

    required_args.add_argument('-i', '--id_file', required=True, help='The path to a docIDMap file.')
    required_args.add_argument('-d', '--dist_file', required=True, help='The path to a docDist file.')
    required_args.add_argument('-o', '--output', required=True, help='The path to save the config file to.')

    optional_args.add_argument('-v', '--lib_version', required=False, help='Sets the library version (V1 or V2).',
                               default='V1')
    # The parser is created with add_help=False, so -h/--help is added here.
    optional_args.add_argument('-h', '--help', action='help', help='Prints the help message.')
    optional_args.add_argument('-w', '--overwrite', required=False, action='store_true',
                               help='If this is used, the program will overwrite a file at output if it exists.')

    args = parser.parse_args()

    id_map_file = args.id_file
    dist_map_file = args.dist_file
    out_file = args.output
    lib_version = args.lib_version
    auto_overwrite = args.overwrite

    check_args(id_map_file, dist_map_file, out_file, lib_version, auto_overwrite)

    # check_args() reports an invalid version as "using default (V1)" but has
    # no way to change the value; apply that default here so the message is
    # true and an invalid value never reaches the config file.
    if lib_version not in ('V1', 'V2'):
        lib_version = 'V1'

    main(id_map_file, dist_map_file, out_file, lib_version)