
Plot neural network training performance (accuracy and loss)

plot_nn_model_accuracy(model_history)

Plot training and validation accuracies for a neural network model.

Parameters:

model_history (dict, required): Dictionary containing neural network model training history information, specifically entries for "accuracy" and "val_accuracy".
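In Keras, a dictionary of this shape is what `model.fit()` exposes as `History.history`. It holds one list per metric with one entry per epoch. The `val_`-prefixed keys appear only when validation data is supplied, for example through `validation_data` or `validation_split`. In recent TensorFlow/Keras versions, the `"accuracy"` key appears when the model is compiled with `metrics=["accuracy"]`. The function expects the dictionary itself (`history.history`) rather than the `History` object; because of the `@beartype` decorator, passing the object is expected to fail the `dict` type check. The standard-library sketch below imitates how such a dictionary grows epoch by epoch. The `record_epoch` helper and the metric values are hypothetical and are not part of eis_toolkit or Keras.

```python
def record_epoch(history: dict[str, list[float]], logs: dict[str, float]) -> None:
    """Append one epoch's metric values to a Keras-style history dictionary.

    Hypothetical helper: it mimics how per-epoch logs accumulate into
    lists keyed by metric name; it is not a Keras or eis_toolkit API.
    """
    for name, value in logs.items():
        history.setdefault(name, []).append(value)


history: dict[str, list[float]] = {}
# Made-up per-epoch logs for three epochs, with validation data enabled.
for logs in (
    {"loss": 0.90, "accuracy": 0.61, "val_loss": 0.95, "val_accuracy": 0.58},
    {"loss": 0.62, "accuracy": 0.74, "val_loss": 0.71, "val_accuracy": 0.69},
    {"loss": 0.48, "accuracy": 0.81, "val_loss": 0.66, "val_accuracy": 0.72},
):
    record_epoch(history, logs)

print(sorted(history))  # ['accuracy', 'loss', 'val_accuracy', 'val_loss']
print(history["val_accuracy"])  # [0.58, 0.69, 0.72]
```

Because every key receives one value per epoch, the training and validation lists end up with equal lengths. This is the condition the function checks before plotting.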

Returns:

Axes: Matplotlib axes containing the produced plot.

Raises:

InvalidDatasetException: Raised if "accuracy" or "val_accuracy" is missing from model_history.
InvalidDataShapeException: Raised if "accuracy" and "val_accuracy" have different lengths.
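Both checks run before anything is drawn. The standard-library sketch below mirrors them so the input contract can be exercised without eis_toolkit installed. `validate_history` and the two exception classes are hypothetical stand-ins for the library's own `InvalidDatasetException` and `InvalidDataShapeException`, and the error messages are illustrative rather than the library's.

```python
class MissingHistoryKeyError(Exception):
    """Hypothetical stand-in for eis_toolkit's InvalidDatasetException."""


class HistoryLengthMismatchError(Exception):
    """Hypothetical stand-in for eis_toolkit's InvalidDataShapeException."""


def validate_history(history: dict, train_key: str, val_key: str) -> None:
    """Apply the same two checks the plotting functions perform before plotting."""
    # Check 1: both the training and the validation series must be present.
    missing = [key for key in (train_key, val_key) if key not in history]
    if missing:
        raise MissingHistoryKeyError(f"Missing key(s) in model_history: {missing}")
    # Check 2: each series has one value per epoch, so the lengths must agree.
    if len(history[train_key]) != len(history[val_key]):
        raise HistoryLengthMismatchError(
            f"{train_key!r} has {len(history[train_key])} values, "
            f"{val_key!r} has {len(history[val_key])}"
        )


# Passes: both keys are present with equal lengths.
validate_history(
    {"accuracy": [0.6, 0.7], "val_accuracy": [0.55, 0.65]}, "accuracy", "val_accuracy"
)

# Fails check 1: no validation data was recorded during training.
try:
    validate_history({"loss": [0.9, 0.6]}, "loss", "val_loss")
except MissingHistoryKeyError as err:
    print(err)  # Missing key(s) in model_history: ['val_loss']
```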

Source code in eis_toolkit/evaluation/plot_nn_model_performance.py
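# Module-level imports used by both functions (they sit above this excerpt in the source file)
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from beartype import beartype

from eis_toolkit.exceptions import InvalidDatasetException, InvalidDataShapeException


@beartype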
def plot_nn_model_accuracy(model_history: dict) -> plt.Axes:
    """Plot training and validation accuracies for a neural network model.

    Args:
        model_history: Dictionary containing neural network model training history information,
            specifically entries for "accuracy" and "val_accuracy".

    Returns:
        Matplotlib axes containing the produced plot.

    Raises:
        InvalidDatasetException: Raised if "accuracy" or "val_accuracy" are not found in the model_history.
        InvalidDataShapeException: Raised if "accuracy" and "val_accuracy" have mismatching lengths.
    """
    if not all(key in model_history for key in ("accuracy", "val_accuracy")):
        raise InvalidDatasetException("Expected 'accuracy' and 'val_accuracy' to be found in model_history.")
    if len(model_history["accuracy"]) != len(model_history["val_accuracy"]):
        raise InvalidDataShapeException("Expected 'accuracy' and 'val_accuracy' to have the same length.")

    df = pd.DataFrame(
        {
            "Training set accuracy": model_history["accuracy"],
            "Validation set accuracy": model_history["val_accuracy"],
        }
    )
    ax = sns.lineplot(data=df)
    ax.set(xlabel="Epoch", ylabel="Accuracy")

    return ax
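Because the DataFrame is built from plain lists, pandas gives it a default `RangeIndex`. `sns.lineplot(data=df)` treats this wide-form data as one line per column, with the index on the x-axis, so the "Epoch" axis starts at 0 rather than 1. The standard-library sketch below spells out the (epoch, series, value) points that end up on the plot. `to_plot_points` is a hypothetical helper written for illustration and is not part of eis_toolkit, pandas, or seaborn.

```python
def to_plot_points(
    history: dict, train_key: str, val_key: str, label: str
) -> list[tuple[int, str, float]]:
    """List the (epoch, series name, value) points drawn by the plotting function.

    Hypothetical helper; epochs are numbered from 0, matching the default
    RangeIndex of a DataFrame built from plain lists.
    """
    # Column names follow the ones used in the DataFrame above.
    series = {
        f"Training set {label}": history[train_key],
        f"Validation set {label}": history[val_key],
    }
    return [
        (epoch, name, value)
        for name, values in series.items()
        for epoch, value in enumerate(values)
    ]


history = {"accuracy": [0.61, 0.74, 0.81], "val_accuracy": [0.58, 0.69, 0.72]}
points = to_plot_points(history, "accuracy", "val_accuracy", "accuracy")
print(points[0])   # (0, 'Training set accuracy', 0.61)
print(points[-1])  # (2, 'Validation set accuracy', 0.72)
```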

plot_nn_model_loss(model_history)

Plot training and validation losses for a neural network model.

Parameters:

model_history (dict, required): Dictionary containing neural network model training history information, specifically entries for "loss" and "val_loss".

Returns:

Axes: Matplotlib axes containing the produced plot.

Raises:

InvalidDatasetException: Raised if "loss" or "val_loss" is missing from model_history.
InvalidDataShapeException: Raised if "loss" and "val_loss" have different lengths.

Source code in eis_toolkit/evaluation/plot_nn_model_performance.py
@beartype
def plot_nn_model_loss(model_history: dict) -> plt.Axes:
    """Plot training and validation losses for a neural network model.

    Args:
        model_history: Dictionary containing neural network model training history information,
            specifically entries for "loss" and "val_loss".

    Returns:
        Matplotlib axes containing the produced plot.

    Raises:
        InvalidDatasetException: Raised if "loss" or "val_loss" are not found in the model_history.
        InvalidDataShapeException: Raised if "loss" and "val_loss" have mismatching lengths.
    """
    if not all(key in model_history for key in ("loss", "val_loss")):
        raise InvalidDatasetException("Expected 'loss' and 'val_loss' to be found in model_history.")
    if len(model_history["loss"]) != len(model_history["val_loss"]):
        raise InvalidDataShapeException("Expected 'loss' and 'val_loss' to have the same length.")

    df = pd.DataFrame(
        {
            "Training set loss": model_history["loss"],
            "Validation set loss": model_history["val_loss"],
        }
    )
    ax = sns.lineplot(data=df)
    ax.set(xlabel="Epoch", ylabel="Loss")
    return ax
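A common way to read the two loss curves is to find the epoch with the lowest validation loss and to watch for validation loss rising while training loss keeps falling, which is a typical sign of overfitting. The standard-library sketch below computes both from a history dictionary. `summarize_loss` is a hypothetical helper, and the loss values are made up for illustration.

```python
def summarize_loss(history: dict) -> dict:
    """Summarize training/validation loss curves (hypothetical helper).

    Returns the 0-based epoch with the lowest validation loss and the
    per-epoch gap between validation and training loss.
    """
    loss, val_loss = history["loss"], history["val_loss"]
    # Index of the smallest validation loss (first one wins on ties).
    best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)
    # A widening positive gap means validation loss lags further behind training loss.
    gap = [round(v - t, 4) for t, v in zip(loss, val_loss)]
    return {"best_epoch": best_epoch, "best_val_loss": val_loss[best_epoch], "gap": gap}


history = {"loss": [0.90, 0.62, 0.48, 0.39], "val_loss": [0.95, 0.71, 0.66, 0.70]}
summary = summarize_loss(history)
print(summary["best_epoch"], summary["best_val_loss"])  # 2 0.66
print(summary["gap"])  # [0.05, 0.09, 0.18, 0.31]
```

Neither plotting function creates its own figure. When no `ax` is given, `sns.lineplot` draws on the current matplotlib axes, so calling `plot_nn_model_accuracy` and then `plot_nn_model_loss` overlays both sets of lines on one axes unless a new figure is started in between, for example with `plt.figure()`.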