Done part 4
This commit is contained in:
parent
611ca157b5
commit
0476edae8a
4 changed files with 43 additions and 14 deletions
|
@ -6,6 +6,8 @@ import re
|
||||||
import itertools
|
import itertools
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from train_classifiers import perform_grid_search, load_dataset
|
from train_classifiers import perform_grid_search, load_dataset
|
||||||
|
import seaborn as sns
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from sklearn.neural_network import MLPClassifier
|
from sklearn.neural_network import MLPClassifier
|
||||||
from sklearn.naive_bayes import GaussianNB
|
from sklearn.naive_bayes import GaussianNB
|
||||||
|
@ -131,8 +133,34 @@ def main():
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
df_stats.to_csv(OUT_DIR + '/model_stats.csv')
|
df_stats.to_csv(OUT_DIR + '/model_stats.csv')
|
||||||
print(df_stats)
|
|
||||||
|
|
||||||
|
for metric in metric_list:
|
||||||
|
if metric == 'accuracy':
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(metric)
|
||||||
|
|
||||||
|
dft = df_stats.loc[df_stats['metric'] == metric, :].copy()
|
||||||
|
dft.pvalue = dft.pvalue.apply(lambda x: '{0:.4g}'.format(round(x, 4)))
|
||||||
|
|
||||||
|
dft = dft \
|
||||||
|
.pivot(index=['classifier_a'], columns=['classifier_b'], values=['pvalue']) \
|
||||||
|
.reset_index(drop=False)
|
||||||
|
dft.columns = sorted([x[1] for x in dft.columns])
|
||||||
|
print(dft.replace({ np.nan: '--' }).to_markdown(index=False) + '\n')
|
||||||
|
|
||||||
|
dfg = df.loc[df['metric'] != 'accuracy', :].sort_values(by=['classifier'])
|
||||||
|
# Order by metric list
|
||||||
|
dfg = pd.concat([dfg[dfg['metric'] == met] for met in metric_list if met != 'accuracy'])
|
||||||
|
|
||||||
|
f, ax = plt.subplots(figsize=(8, 10))
|
||||||
|
plt.yticks(np.arange(0.0, 1.0 + 1, 0.1))
|
||||||
|
sns.boxplot(x="metric", y="value", hue="classifier", data=dfg, ax=ax)
|
||||||
|
|
||||||
|
ax.set(ylabel="Metric value", ylim=[0, 1], xlabel="Metric")
|
||||||
|
ax.set_title("Distribution of metrics for each classifier")
|
||||||
|
sns.despine(offset=10, trim=True)
|
||||||
|
f.savefig(OUT_DIR + '/boxplot.png')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
BIN
models/boxplot.png
Normal file
BIN
models/boxplot.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 45 KiB |
|
@ -1,34 +1,34 @@
|
||||||
,classifier_a,classifier_b,metric,pvalue
|
,classifier_a,classifier_b,metric,pvalue
|
||||||
1,DecisionTreeClassifier,GaussianNB,precision,0.08929133280531223
|
1,DecisionTreeClassifier,GaussianNB,precision,0.08929133280531223
|
||||||
2,DecisionTreeClassifier,GaussianNB,recall,3.8656827355135645e-18
|
2,DecisionTreeClassifier,GaussianNB,recall,3.877480505802584e-18
|
||||||
3,DecisionTreeClassifier,GaussianNB,f1,3.896340037647931e-18
|
3,DecisionTreeClassifier,GaussianNB,f1,3.896340037647931e-18
|
||||||
4,DecisionTreeClassifier,MLPClassifier,precision,0.4012348497407896
|
4,DecisionTreeClassifier,MLPClassifier,precision,0.4012348497407896
|
||||||
5,DecisionTreeClassifier,MLPClassifier,recall,0.010673345794344981
|
5,DecisionTreeClassifier,MLPClassifier,recall,0.011820059675817408
|
||||||
6,DecisionTreeClassifier,MLPClassifier,f1,0.4710651684151138
|
6,DecisionTreeClassifier,MLPClassifier,f1,0.4710651684151138
|
||||||
7,DecisionTreeClassifier,RandomForestClassifier,precision,8.283133239663301e-12
|
7,DecisionTreeClassifier,RandomForestClassifier,precision,8.283473187323235e-12
|
||||||
8,DecisionTreeClassifier,RandomForestClassifier,recall,0.3324828913770316
|
8,DecisionTreeClassifier,RandomForestClassifier,recall,0.3276029575034267
|
||||||
9,DecisionTreeClassifier,RandomForestClassifier,f1,1.4515097813437996e-10
|
9,DecisionTreeClassifier,RandomForestClassifier,f1,1.4515097813437996e-10
|
||||||
10,DecisionTreeClassifier,SVC,precision,6.472995016722292e-16
|
10,DecisionTreeClassifier,SVC,precision,6.472995016722292e-16
|
||||||
11,DecisionTreeClassifier,SVC,recall,3.849136100548656e-18
|
11,DecisionTreeClassifier,SVC,recall,3.864155888689142e-18
|
||||||
12,DecisionTreeClassifier,SVC,f1,3.896559845095909e-18
|
12,DecisionTreeClassifier,SVC,f1,3.896559845095909e-18
|
||||||
13,GaussianNB,MLPClassifier,precision,0.03476088049603166
|
13,GaussianNB,MLPClassifier,precision,0.03476088049603166
|
||||||
14,GaussianNB,MLPClassifier,recall,3.848918829852649e-18
|
14,GaussianNB,MLPClassifier,recall,3.873544128513129e-18
|
||||||
15,GaussianNB,MLPClassifier,f1,3.896120241954008e-18
|
15,GaussianNB,MLPClassifier,f1,3.896120241954008e-18
|
||||||
16,GaussianNB,RandomForestClassifier,precision,5.027978595522601e-10
|
16,GaussianNB,RandomForestClassifier,precision,5.027978595522601e-10
|
||||||
17,GaussianNB,RandomForestClassifier,recall,3.8398039515630974e-18
|
17,GaussianNB,RandomForestClassifier,recall,3.8656827355135645e-18
|
||||||
18,GaussianNB,RandomForestClassifier,f1,3.896120241954008e-18
|
18,GaussianNB,RandomForestClassifier,f1,3.896120241954008e-18
|
||||||
19,GaussianNB,SVC,precision,7.361006463422299e-13
|
19,GaussianNB,SVC,precision,7.361006463422299e-13
|
||||||
20,GaussianNB,SVC,recall,3.878355771123559e-18
|
20,GaussianNB,SVC,recall,3.881639684405151e-18
|
||||||
21,GaussianNB,SVC,f1,4.265842540306607e-18
|
21,GaussianNB,SVC,f1,4.265842540306607e-18
|
||||||
22,MLPClassifier,RandomForestClassifier,precision,2.9302015489842885e-09
|
22,MLPClassifier,RandomForestClassifier,precision,2.9302015489842885e-09
|
||||||
23,MLPClassifier,RandomForestClassifier,recall,9.555788374830177e-05
|
23,MLPClassifier,RandomForestClassifier,recall,0.00010909237805840521
|
||||||
24,MLPClassifier,RandomForestClassifier,f1,1.1542838431590428e-11
|
24,MLPClassifier,RandomForestClassifier,f1,1.1542838431590428e-11
|
||||||
25,MLPClassifier,SVC,precision,3.6744416439536415e-16
|
25,MLPClassifier,SVC,precision,3.6744416439536415e-16
|
||||||
26,MLPClassifier,SVC,recall,5.611915312842127e-18
|
26,MLPClassifier,SVC,recall,5.645631221640026e-18
|
||||||
27,MLPClassifier,SVC,f1,5.112831740936498e-18
|
27,MLPClassifier,SVC,f1,5.112831740936498e-18
|
||||||
28,RandomForestClassifier,SVC,precision,4.0161556854627e-18
|
28,RandomForestClassifier,SVC,precision,4.0161556854627e-18
|
||||||
29,RandomForestClassifier,SVC,recall,3.849570676820676e-18
|
29,RandomForestClassifier,SVC,recall,3.8584897469079895e-18
|
||||||
30,RandomForestClassifier,SVC,f1,3.896340037647931e-18
|
30,RandomForestClassifier,SVC,f1,3.896559845095909e-18
|
||||||
31,BiasedClassifier,DecisionTreeClassifier,precision,3.881858705649312e-18
|
31,BiasedClassifier,DecisionTreeClassifier,precision,3.881858705649312e-18
|
||||||
32,BiasedClassifier,DecisionTreeClassifier,recall,1.0267247842714985e-14
|
32,BiasedClassifier,DecisionTreeClassifier,recall,1.0267247842714985e-14
|
||||||
33,BiasedClassifier,DecisionTreeClassifier,f1,3.881858705649312e-18
|
33,BiasedClassifier,DecisionTreeClassifier,f1,3.881858705649312e-18
|
||||||
|
|
|
|
@ -3,3 +3,4 @@ pandas==1.5.2
|
||||||
scikit_learn==1.2.1
|
scikit_learn==1.2.1
|
||||||
tabulate==0.9.0
|
tabulate==0.9.0
|
||||||
scipy==1.24.2
|
scipy==1.24.2
|
||||||
|
seaborn==0.12.2
|
Reference in a new issue