Recently, scikit-learn added a routine for Detection Error Tradeoff (DET) curves (new in the 0.24 release series; I am using 0.24.1):
https://scikit-learn.org/stable/auto_examples/model_selection/plot_det.html#sphx-glr-auto-examples-model-selection-plot-det-py
However, the provided example covers only the binary case:
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import DetCurveDisplay, RocCurveDisplay
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

# plot_det_curve / plot_roc_curve were deprecated in scikit-learn 1.0 and
# removed in 1.2. Use the Display.from_estimator constructors when this
# scikit-learn provides them, and fall back to the 0.24-era helpers otherwise.
# Both accept (estimator, X, y, ax=..., name=...), so the call sites below
# are identical either way.
if hasattr(DetCurveDisplay, "from_estimator"):
    plot_roc_curve = RocCurveDisplay.from_estimator
    plot_det_curve = DetCurveDisplay.from_estimator
else:
    from sklearn.metrics import plot_det_curve, plot_roc_curve

N_SAMPLES = 1000

classifiers = {
    "Linear SVM": make_pipeline(StandardScaler(), LinearSVC(C=0.025)),
    "Random Forest": RandomForestClassifier(
        max_depth=5, n_estimators=10, max_features=1
    ),
}

X, y = make_classification(
    n_samples=N_SAMPLES, n_features=2, n_redundant=0, n_informative=2,
    random_state=1, n_clusters_per_class=1)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=.4, random_state=0)

# prepare plots: ROC on the left, DET on the right
fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(11, 5))

# Fit each classifier, then draw its ROC and DET curve on the shared axes.
for name, clf in classifiers.items():
    clf.fit(X_train, y_train)
    plot_roc_curve(clf, X_test, y_test, ax=ax_roc, name=name)
    plot_det_curve(clf, X_test, y_test, ax=ax_det, name=name)

ax_roc.set_title('Receiver Operating Characteristic (ROC) curves')
ax_det.set_title('Detection Error Tradeoff (DET) curves')

ax_roc.grid(linestyle='--')
ax_det.grid(linestyle='--')

plt.legend()
plt.show()
I would like recommendations on how to extend it to a multiclass problem, using the digits dataset as an example:
import matplotlib.pyplot as plt
from sklearn import datasets, svm, metrics
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import precision_recall_curve, average_precision_score, roc_curve, auc
# DetCurveDisplay and det_curve exist from scikit-learn 0.24 onward, unlike
# plot_det_curve, which was removed in 1.2.
from sklearn.metrics import DetCurveDisplay, det_curve
from sklearn.multiclass import OneVsRestClassifier
from sklearn.tree import DecisionTreeClassifier
import seaborn as sns
import pandas as pd

digits = datasets.load_digits()
print(np.unique(digits.target, return_counts=True))

random_state = np.random.RandomState(0)
n_samples = len(digits.images)
# Flatten each image into a single feature row.
data = digits.images.reshape((n_samples, -1))

# Use label_binarize to get a multi-label-like setting: one 0/1 column per
# digit class 0..9.
Y = label_binarize(digits.target, classes=np.arange(10))
n_classes = Y.shape[1]

# Split into training and test
X_train, X_test, Y_train, Y_test = train_test_split(
    data, Y, test_size=.5, random_state=random_state)

# A DET curve is defined for a binary problem, so treat the multiclass task as
# n_classes one-vs-rest problems: one binary classifier per column of Y.
classifier = OneVsRestClassifier(svm.SVC(kernel="linear"))
# decision_function gives one continuous score column per class. DET curves
# sweep a threshold over these scores, so hard predicted labels would not do.
y_score = classifier.fit(X_train, Y_train).decision_function(X_test)

fig, ax_det = plt.subplots(figsize=(8, 7))

# One DET curve per class: this class (positive) versus all other classes.
for i in range(n_classes):
    fpr, fnr, _ = det_curve(Y_test[:, i], y_score[:, i])
    DetCurveDisplay(fpr=fpr, fnr=fnr).plot(ax=ax_det, name=f"digit {i}")

# Micro-average: pool every (sample, class) decision into a single binary
# problem, giving one summary curve over all classes.
fpr, fnr, _ = det_curve(Y_test.ravel(), y_score.ravel())
DetCurveDisplay(fpr=fpr, fnr=fnr).plot(
    ax=ax_det, name="micro-average", linestyle="--", color="black")

ax_det.set_title("One-vs-rest Detection Error Tradeoff (DET) curves")
ax_det.grid(linestyle="--")
plt.show()
Source (Stack Overflow question):
https://stackoverflow.com/questions/66051412/how-to-plot-a-det-curve-for-a-multiclass-problem-in-scikit-learn