26 lines
1.1 KiB
Python
Executable File
26 lines
1.1 KiB
Python
Executable File
#!/usr/bin/env python
"""ROC curve and AUC computation script.

Evaluates the given model against the test set and generates a one-vs-rest
(OvR) ROC curve plot with one curve per class, a micro-averaged OvR ROC plot
and the corresponding AUC value. The actual work is delegated to
``modelimpl.auc.build_curve``.
"""
import argparse
import os

# CUDA_VISIBLE_DEVICES must be set before torch initialises CUDA, so this has
# to happen before torch (or anything importing torch) is loaded below.
# setdefault keeps GPU 2 as the default while still honouring a device
# selection supplied by the caller's environment.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

# Imported explicitly, and first, purely for ordering: torch is loaded only
# after the CUDA environment variable above has been set.
import torch  # noqa: E402,F401

from modelimpl.auc import build_curve  # noqa: E402


def _build_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        prog='auc.py',
        description='ROC curve and AUC computation script. The script evaluates the given '
                    'model against the test set and generates an OvR ROC curve '
                    'plot with one curve per class, a micro-averaged OvR ROC plot '
                    'and the corresponding AUC value.')
    parser.add_argument('modelfile', type=str,
                        help="Path to the pickled pytorch model to evaluate")
    parser.add_argument('-c', '--force-cpu', action='store_true',
                        help="disables CUDA support. Useful when debugging")
    return parser


def main(argv=None):
    """Parse command-line arguments and build the ROC curves and AUC.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.
    """
    args = _build_parser().parse_args(argv)
    build_curve(args.modelfile, args.force_cpu)


if __name__ == '__main__':
    main()