Precision/Recall Trade-off
To understand this trade-off, let's look at how the SGDClassifier makes its classification decisions. For each instance, it computes a score based on a decision function. If that score is greater than a threshold, it assigns the instance to the positive class; otherwise it assigns it to the negative class.
Figure 3. In this precision/recall trade-off, images are ranked by their classifier score, and those above the chosen decision threshold are considered positive; the higher the threshold, the lower the recall, but (in general) the higher the precision.
Figure 3 shows a few digits positioned from the lowest score on the left to the highest score on the right. Suppose the decision threshold is positioned at the central arrow (between the two 5s): you will find 4 true positives (actual 5s) on the right of that threshold, and 1 false positive (actually a 6). Therefore, with that threshold, the precision is 80% (4 out of 5). But out of 6 actual 5s, the classifier only detects 4, so the recall is 67% (4 out of 6). If you raise the threshold (move it to the arrow on the right), the false positive (the 6) becomes a true negative, thereby increasing the precision (up to 100% in this case), but one true positive becomes a false negative, decreasing recall down to 50%. Conversely, lowering the threshold increases recall and reduces precision.
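To make these numbers concrete, here is a small self-contained sketch using made-up scores and labels (not the actual data behind Figure 3) that reproduces the same precision and recall values:
import numpy as np
from sklearn.metrics import precision_score, recall_score

# Made-up example: 10 instances sorted by hypothetical decision score,
# with 1 = "is a 5" and 0 = "is not a 5". Six instances are actual 5s.
y_true = np.array([0, 1, 0, 1, 0, 0, 1, 1, 1, 1])
scores = np.array([-8, -6, -5, -3, -1, 1, 2, 5, 6, 7])

for threshold in (0, 4):  # a "central" threshold and a higher one
    y_pred = scores > threshold
    print(f"threshold={threshold}: "
          f"precision={precision_score(y_true, y_pred):.2f}, "
          f"recall={recall_score(y_true, y_pred):.2f}")
# threshold=0: precision=0.80, recall=0.67
# threshold=4: precision=1.00, recall=0.50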
Scikit-Learn does not let you set the threshold directly, but it does give you access to the decision scores that it uses to make predictions. Instead of calling the classifier's predict() method, you can call its decision_function() method, which returns a score for each instance, and then use any threshold you want to make predictions based on those scores:
y_scores = sgd_clf.decision_function([some_digit])
y_scores
array([2412.53175101])
threshold = 0
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([ True])
The SGDClassifier uses a threshold equal to 0, so the previous code returns the same result as the predict() method (i.e., True).
Let's raise the threshold:
threshold = 8000
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([False])
This confirms that raising the threshold decreases recall. The image actually represents a 5, and the classifier detects it when the threshold is 0, but it misses it when the threshold is increased to 8,000.
How do you decide which threshold to use?
First, use the cross_val_predict() function to get the scores of all instances in the training set, but this time specify that you want to return decision scores instead of predictions:
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method="decision_function")
With these scores, use the precision_recall_curve() function to compute precision and recall for all possible thresholds:
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
Finally, use Matplotlib to plot precision and recall as functions of the threshold value (Figure 4):
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
    [...]  # highlight the threshold and add the legend, axis label, and grid

plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()
Figure 4. Precision and recall versus the decision threshold
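The [...] in the snippet above elides the cosmetic plotting details. One way you might fill them in (an illustrative sketch; the exact styling and the highlighted threshold value are assumptions, not the author's code) is:
import matplotlib.pyplot as plt

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
    plt.vlines(8000, 0, 1, "k", "dotted", label="threshold")  # mark an example threshold
    plt.xlabel("Threshold")
    plt.legend(loc="center left")
    plt.grid(True)

plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()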
You may wonder why the precision curve is bumpier than the recall curve in Figure 4. The reason is that precision may sometimes go down when you raise the threshold. To understand why, look back at Figure 3 and notice what happens when you start from the central threshold and move it just one digit to the right: precision goes from 4/5 (80%) down to 3/4 (75%). On the other hand, recall can only go down when the threshold is increased, which explains why its curve looks smooth.
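You can check this behavior directly on the arrays returned by precision_recall_curve() (a quick sanity check, assuming the precisions and recalls computed earlier are still in scope):
import numpy as np

# Thresholds returned by precision_recall_curve() are in increasing order, so
# moving along the arrays corresponds to raising the threshold.
print(np.all(np.diff(recalls) <= 0))     # True: recall never increases
print(np.all(np.diff(precisions) >= 0))  # typically False: precision occasionally dips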
Another way to select a good precision/recall trade-off is to plot precision directly against recall, as shown in Figure 5 (the same threshold as earlier is highlighted). You can see that precision really starts to fall sharply around 80% recall. You will probably want to select a precision/recall trade-off just before that drop (for example, at around 60% recall). But of course, the choice depends on your project.
Figure 5. Precision versus recall
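No code is shown here for the precision-versus-recall plot itself; a minimal sketch that would produce a curve like Figure 5 from the arrays computed earlier (the styling is an assumption) is:
import matplotlib.pyplot as plt

# Plot precision directly against recall, using the arrays returned by
# precision_recall_curve() above.
plt.plot(recalls, precisions, "b-")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.grid(True)
plt.show()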
Suppose you decide to aim for 90% precision. You look up the first plot and find that you need to use a threshold of about 8,000. To be more precise, you can search for the lowest threshold that gives you at least 90% precision; np.argmax() will give you the first index of the maximum value, which in this case means the first True value:
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]  # ~7816
To make predictions (on the training set for now), instead of calling the classifier's predict() method, you can run this code:
y_train_pred_90 = (y_scores >= threshold_90_precision)
Let's check these predictions' precision and recall:
precision_score(y_train_5, y_train_pred_90)
0.900038008361839
recall_score(y_train_5, y_train_pred_90)
0.4368197749492714
Great, you have a 90% precision classifier! As you can see, it is fairly easy to create a classifier with virtually any precision you want: just set a high enough threshold.
But wait, not so fast. A high-precision classifier is not very useful if its recall is too low!