Skip to content

Plot confusion matrix

plot_confusion_matrix(confusion_matrix, cmap=None, plot_title='Confusion matrix', ax=None, **kwargs)

Plot confusion matrix to visualize classification results.

Parameters:

Name Type Description Default
confusion_matrix ndarray

The confusion matrix as a 2D Numpy array. Expects the first element (upper-left corner) to hold the true negatives.

required
cmap Optional[Union[str, Colormap, Sequence]]

Colormap name, matplotlib Colormap object, or list of colors for coloring the plot. Optional parameter.

None
plot_title str

Title for the plot. Defaults to "Confusion matrix".

'Confusion matrix'
ax Optional[Axes]

An existing Axes in which to draw the plot. Defaults to None.

None
**kwargs

Additional keyword arguments passed to sns.heatmap.

{}

Returns:

Type Description
Axes

Matplotlib axes containing the plot.

Raises:

Type Description
InvalidDataShapeException

Raised if input confusion matrix is not square.

Source code in eis_toolkit/evaluation/plot_confusion_matrix.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
@beartype
def plot_confusion_matrix(
    confusion_matrix: np.ndarray,
    cmap: Optional[Union[str, Colormap, Sequence]] = None,
    plot_title: str = "Confusion matrix",
    ax: Optional[plt.Axes] = None,
    **kwargs,
) -> plt.Axes:
    """Plot confusion matrix to visualize classification results.

    Each cell is annotated with its count and its percentage of the matrix total.
    For a 2x2 (binary) matrix, the cells are additionally labelled "True Neg",
    "False Pos", "False Neg" and "True Pos" in row-major order. The x-axis is
    labelled "Predicted label" and the y-axis "True label".

    Args:
        confusion_matrix: The confusion matrix as a square 2D Numpy array. Expects the first element
            (upper-left corner) to hold the true negatives.
        cmap: Colormap name, matplotlib Colormap object, or list of colors for coloring the plot.
            Optional parameter.
        plot_title: Title for the plot. Defaults to "Confusion matrix".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to sns.heatmap.

    Returns:
        Matplotlib axes containing the plot.

    Raises:
        InvalidDataShapeException: Raised if input confusion matrix is not a square 2D array.
    """
    shape = confusion_matrix.shape
    # Check ndim first. A 1D input would otherwise fail with IndexError on shape[1],
    # and an (n, n, k) input would slip past a plain shape[0] == shape[1] comparison.
    if confusion_matrix.ndim != 2 or shape[0] != shape[1]:
        raise InvalidDataShapeException(f"Expected confusion matrix to be square, input array has shape: {shape}")

    flat_values = confusion_matrix.flatten()
    total = np.sum(confusion_matrix)
    # An all-zero matrix would otherwise divide by zero, emitting a RuntimeWarning
    # and "nan%" annotations; show 0.00% instead.
    if total == 0:
        proportions = np.zeros(flat_values.shape, dtype=float)
    else:
        proportions = flat_values / total

    counts = [f"{value:.0f}" for value in flat_values]
    percentages = [f"{value:.2%}" for value in proportions]

    if shape == (2, 2):  # Binary classification: prefix each cell with its outcome name.
        names = ["True Neg", "False Pos", "False Neg", "True Pos"]
        cell_texts = [f"{name}\n{count}\n{pct}" for name, count, pct in zip(names, counts, percentages)]
    else:
        cell_texts = [f"{count}\n{pct}" for count, pct in zip(counts, percentages)]
    labels = np.asarray(cell_texts).reshape(shape)

    # fmt="" because the annotations are pre-formatted strings, not numbers.
    out_ax = sns.heatmap(confusion_matrix, annot=labels, fmt="", cmap=cmap, ax=ax, **kwargs)
    out_ax.set(xlabel="Predicted label", ylabel="True label", title=plot_title)

    return out_ax