Skip to content

Plot confusion matrix

plot_confusion_matrix(confusion_matrix, cmap=None)

Plot confusion matrix to visualize classification results.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `confusion_matrix` | `ndarray` | The confusion matrix as a 2D NumPy array. Expects the first element (upper-left corner) to hold the true negatives. | *required* |
| `cmap` | `Optional[Union[str, Colormap, Sequence]]` | Colormap name, Matplotlib colormap object, or list of colors used to color the plot. Optional parameter. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `Axes` | Matplotlib axes containing the plot. |

Raises:

| Type | Description |
|------|-------------|
| `InvalidDataShapeException` | Raised if the input confusion matrix is not square. |

Source code in `eis_toolkit/validation/plot_confusion_matrix.py` (lines 11–46):
@beartype
def plot_confusion_matrix(
    confusion_matrix: np.ndarray, cmap: Optional[Union[str, Colormap, Sequence]] = None
) -> plt.Axes:
    """Plot confusion matrix to visualize classification results.

    Each cell is annotated with its count and its share of the matrix total (as a
    percentage). For a 2x2 (binary) matrix the cells are additionally labeled, in
    row-major order, "True Neg", "False Pos", "False Neg" and "True Pos".

    Args:
        confusion_matrix: The confusion matrix as 2D Numpy array. Rows correspond to the
            true labels and columns to the predicted labels. Expects the first element
            (upper-left corner) to have True negatives.
        cmap: Colormap name, matplotlib colormap object or list of colors for coloring the plot.
            Optional parameter.

    Returns:
        Matplotlib axes containing the plot.

    Raises:
        InvalidDataShapeException: Raised if input confusion matrix is not a square 2D array.
    """
    shape = confusion_matrix.shape
    # Check the rank before comparing dimensions. Indexing shape[1] on a 1D array would
    # raise IndexError instead of the documented exception. A 3D array with equal leading
    # dimensions would also slip through and fail later inside the plotting call.
    if confusion_matrix.ndim != 2 or shape[0] != shape[1]:
        raise InvalidDataShapeException(f"Expected confusion matrix to be square, input array has shape: {shape}")

    values = confusion_matrix.flatten()
    total = np.sum(values)
    # An all-zero matrix would otherwise divide by zero. That yields "nan%" labels and a
    # RuntimeWarning, so report 0% for every cell instead.
    fractions = values / total if total != 0 else np.zeros(values.shape, dtype=float)

    counts = [f"{value:0.0f}" for value in values]
    percentages = [f"{fraction:.2%}" for fraction in fractions]

    if shape == (2, 2):  # Binary classification: name each cell of the 2x2 layout.
        names = ["True Neg", "False Pos", "False Neg", "True Pos"]
        cell_texts = [f"{name}\n{count}\n{pct}" for name, count, pct in zip(names, counts, percentages)]
    else:
        cell_texts = [f"{count}\n{pct}" for count, pct in zip(counts, percentages)]
    # Restore the matrix layout so each annotation lines up with its heatmap cell.
    labels = np.asarray(cell_texts).reshape(shape)

    # fmt="" because the annotations are already formatted strings.
    ax = sns.heatmap(confusion_matrix, annot=labels, fmt="", cmap=cmap)
    ax.set(xlabel="Predicted label", ylabel="True label")

    return ax