Done part 4
This commit is contained in:
parent
611ca157b5
commit
0476edae8a
4 changed files with 43 additions and 14 deletions
|
@ -6,6 +6,8 @@ import re
|
|||
import itertools
|
||||
import numpy as np
|
||||
from train_classifiers import perform_grid_search, load_dataset
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.naive_bayes import GaussianNB
|
||||
|
@ -131,8 +133,34 @@ def main():
|
|||
i += 1
|
||||
|
||||
df_stats.to_csv(OUT_DIR + '/model_stats.csv')
|
||||
print(df_stats)
|
||||
|
||||
for metric in metric_list:
|
||||
if metric == 'accuracy':
|
||||
continue
|
||||
|
||||
print(metric)
|
||||
|
||||
dft = df_stats.loc[df_stats['metric'] == metric, :].copy()
|
||||
dft.pvalue = dft.pvalue.apply(lambda x: '{0:.4g}'.format(round(x, 4)))
|
||||
|
||||
dft = dft \
|
||||
.pivot(index=['classifier_a'], columns=['classifier_b'], values=['pvalue']) \
|
||||
.reset_index(drop=False)
|
||||
dft.columns = sorted([x[1] for x in dft.columns])
|
||||
print(dft.replace({ np.nan: '--' }).to_markdown(index=False) + '\n')
|
||||
|
||||
dfg = df.loc[df['metric'] != 'accuracy', :].sort_values(by=['classifier'])
|
||||
# Order by metric list
|
||||
dfg = pd.concat([dfg[dfg['metric'] == met] for met in metric_list if met != 'accuracy'])
|
||||
|
||||
f, ax = plt.subplots(figsize=(8, 10))
|
||||
plt.yticks(np.arange(0.0, 1.0 + 1, 0.1))
|
||||
sns.boxplot(x="metric", y="value", hue="classifier", data=dfg, ax=ax)
|
||||
|
||||
ax.set(ylabel="Metric value", ylim=[0, 1], xlabel="Metric")
|
||||
ax.set_title("Distribution of metrics for each classifier")
|
||||
sns.despine(offset=10, trim=True)
|
||||
f.savefig(OUT_DIR + '/boxplot.png')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
BIN
models/boxplot.png
Normal file
BIN
models/boxplot.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 45 KiB |
|
@ -1,34 +1,34 @@
|
|||
,classifier_a,classifier_b,metric,pvalue
|
||||
1,DecisionTreeClassifier,GaussianNB,precision,0.08929133280531223
|
||||
2,DecisionTreeClassifier,GaussianNB,recall,3.8656827355135645e-18
|
||||
2,DecisionTreeClassifier,GaussianNB,recall,3.877480505802584e-18
|
||||
3,DecisionTreeClassifier,GaussianNB,f1,3.896340037647931e-18
|
||||
4,DecisionTreeClassifier,MLPClassifier,precision,0.4012348497407896
|
||||
5,DecisionTreeClassifier,MLPClassifier,recall,0.010673345794344981
|
||||
5,DecisionTreeClassifier,MLPClassifier,recall,0.011820059675817408
|
||||
6,DecisionTreeClassifier,MLPClassifier,f1,0.4710651684151138
|
||||
7,DecisionTreeClassifier,RandomForestClassifier,precision,8.283133239663301e-12
|
||||
8,DecisionTreeClassifier,RandomForestClassifier,recall,0.3324828913770316
|
||||
7,DecisionTreeClassifier,RandomForestClassifier,precision,8.283473187323235e-12
|
||||
8,DecisionTreeClassifier,RandomForestClassifier,recall,0.3276029575034267
|
||||
9,DecisionTreeClassifier,RandomForestClassifier,f1,1.4515097813437996e-10
|
||||
10,DecisionTreeClassifier,SVC,precision,6.472995016722292e-16
|
||||
11,DecisionTreeClassifier,SVC,recall,3.849136100548656e-18
|
||||
11,DecisionTreeClassifier,SVC,recall,3.864155888689142e-18
|
||||
12,DecisionTreeClassifier,SVC,f1,3.896559845095909e-18
|
||||
13,GaussianNB,MLPClassifier,precision,0.03476088049603166
|
||||
14,GaussianNB,MLPClassifier,recall,3.848918829852649e-18
|
||||
14,GaussianNB,MLPClassifier,recall,3.873544128513129e-18
|
||||
15,GaussianNB,MLPClassifier,f1,3.896120241954008e-18
|
||||
16,GaussianNB,RandomForestClassifier,precision,5.027978595522601e-10
|
||||
17,GaussianNB,RandomForestClassifier,recall,3.8398039515630974e-18
|
||||
17,GaussianNB,RandomForestClassifier,recall,3.8656827355135645e-18
|
||||
18,GaussianNB,RandomForestClassifier,f1,3.896120241954008e-18
|
||||
19,GaussianNB,SVC,precision,7.361006463422299e-13
|
||||
20,GaussianNB,SVC,recall,3.878355771123559e-18
|
||||
20,GaussianNB,SVC,recall,3.881639684405151e-18
|
||||
21,GaussianNB,SVC,f1,4.265842540306607e-18
|
||||
22,MLPClassifier,RandomForestClassifier,precision,2.9302015489842885e-09
|
||||
23,MLPClassifier,RandomForestClassifier,recall,9.555788374830177e-05
|
||||
23,MLPClassifier,RandomForestClassifier,recall,0.00010909237805840521
|
||||
24,MLPClassifier,RandomForestClassifier,f1,1.1542838431590428e-11
|
||||
25,MLPClassifier,SVC,precision,3.6744416439536415e-16
|
||||
26,MLPClassifier,SVC,recall,5.611915312842127e-18
|
||||
26,MLPClassifier,SVC,recall,5.645631221640026e-18
|
||||
27,MLPClassifier,SVC,f1,5.112831740936498e-18
|
||||
28,RandomForestClassifier,SVC,precision,4.0161556854627e-18
|
||||
29,RandomForestClassifier,SVC,recall,3.849570676820676e-18
|
||||
30,RandomForestClassifier,SVC,f1,3.896340037647931e-18
|
||||
29,RandomForestClassifier,SVC,recall,3.8584897469079895e-18
|
||||
30,RandomForestClassifier,SVC,f1,3.896559845095909e-18
|
||||
31,BiasedClassifier,DecisionTreeClassifier,precision,3.881858705649312e-18
|
||||
32,BiasedClassifier,DecisionTreeClassifier,recall,1.0267247842714985e-14
|
||||
33,BiasedClassifier,DecisionTreeClassifier,f1,3.881858705649312e-18
|
||||
|
|
|
|
@ -3,3 +3,4 @@ pandas==1.5.2
|
|||
scikit_learn==1.2.1
|
||||
tabulate==0.9.0
|
||||
scipy==1.24.2
|
||||
seaborn==0.12.2
|
Reference in a new issue