Skip to content

Postprocess

tags_distribution(conf, train_set, test_tags, device, rescale=False, recover_tag=None)

Plots the distribution of the tags.

Parameters:

  • conf (dict) –

    Dictionary containing the configuration

  • train_set (list) –

    The training set

  • test_tags (Tensor) –

    The predicted tags for the test dataset

  • device (device) –

    The device to use

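  • data_directory (str) –

    Not an argument of this function (it does not appear in the signature above): the directory where the data is stored is read from conf['speckle']['datadirectory']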

  • rescale (bool, default: False ) –

    Whether to rescale the tags using recover_tag() or leave them between 0 and 1

  • recover_tag (list, default: None ) –

    List of functions to recover each tag

Source code in src/speckcn2/postprocess.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def _plot_tag_histograms(ax, predicted, reference) -> None:
    """Overlay the prediction and training-data histograms on one axis."""
    for values, color, label in ((predicted, 'tab:red', 'Model prediction'),
                                 (reference, 'tab:blue', 'Training data')):
        ax.hist(values,
                bins=20,
                color=color,
                density=True,
                alpha=0.5,
                label=label)


def tags_distribution(conf: dict,
                      train_set: list,
                      test_tags: Tensor,
                      device: Device,
                      rescale: bool = False,
                      recover_tag: Optional[list[Callable]] = None) -> None:
    """Plots the distribution of the tags.

    One histogram panel is drawn per tag component, comparing the predicted
    tags with the tags of the training set. The figure is saved to
    ``<datadirectory>/result_plots/<model_name>_tags.png`` when ``rescale``
    is true and to ``..._tags_unscaled.png`` otherwise. When the tags are
    actually recovered (``rescale`` is true and ``recover_tag`` is given), a
    second figure with the distribution of ``log10(sum_i 10**tag_i)`` is
    saved to ``..._sumJ.png``.

    Parameters
    ----------
    conf : dict
        Dictionary containing the configuration. The keys read here are
        ``conf['speckle']['datadirectory']``, ``conf['model']['name']`` and
        the optional ``conf['preproc']['ensemble']`` (default 1).
    train_set : list
        The training set: a sequence of 3-tuples whose second element is the
        tag array. If ``ensemble > 1`` it is a list of such sequences, which
        is flattened first.
    test_tags : torch.Tensor
        The predicted tags for the test dataset
    device : torch.device
        Not used by this function; kept so existing callers keep working.
    rescale : bool, optional
        Whether to rescale the tags using recover_tag() or leave them between 0 and 1
    recover_tag : list, optional
        List of functions to recover each tag. ``recover_tag[i]`` is called
        as ``recover_tag[i](values, i)`` on the i-th tag column.
    """

    data_directory = conf['speckle']['datadirectory']
    model_name = conf['model']['name']
    ensemble = conf['preproc'].get('ensemble', 1)

    ensure_directory(f'{data_directory}/result_plots')

    # Get the tags from the training set
    if ensemble > 1:
        train_set = list(itertools.chain.from_iterable(train_set))
    _, tags, _ = zip(*train_set)
    train_tags = np.stack(tags)

    # Get the tags from the test set. detach() is a no-op for tensors that
    # do not track gradients and avoids a RuntimeError for those that do.
    predic_tags = np.array([n.detach().cpu().numpy() for n in test_tags])
    print(f'Data shape: {train_tags.shape}')
    print(f'Prediction shape: {predic_tags.shape}')
    print(f'Train mean: {train_tags.mean()}')
    print(f'Train std: {train_tags.std()}')
    print(f'Prediction mean: {predic_tags.mean()}')
    print(f'Prediction std: {predic_tags.std()}')

    apply_recovery = rescale and recover_tag is not None

    # Keep track of J=sum(tags) for each sample
    J_pred = np.zeros(predic_tags.shape[0])
    J_true = np.zeros(train_tags.shape[0])

    # Plot the distribution of each tag element. The grid keeps the
    # historical 2x4 layout for up to 8 tags and grows by rows beyond that
    # instead of failing with an IndexError.
    n_tags = train_tags.shape[1]
    n_cols = 4
    n_rows = max(2, -(-n_tags // n_cols))  # ceil division, at least 2 rows
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows))
    for i in range(n_tags):
        ax = axs[i // n_cols, i % n_cols]
        if apply_recovery:
            recovered_tag_model = np.asarray(recover_tag[i](predic_tags[:, i],
                                                            i))
            recovered_tag_true = np.asarray(recover_tag[i](train_tags[:, i],
                                                           i))
            _plot_tag_histograms(ax, recovered_tag_model, recovered_tag_true)
            # The recovered tags are treated as base-10 exponents here.
            J_pred += 10**recovered_tag_model
            J_true += 10**recovered_tag_true
        else:
            _plot_tag_histograms(ax, predic_tags[:, i], train_tags[:, i])
        ax.set_title(f'Tag {i}')
    # Put the legend on a panel that actually contains histograms.
    legend_ax = axs[0, 1] if n_tags > 1 else axs[0, 0]
    legend_ax.legend()
    fig.tight_layout()
    # NOTE(review): the file name follows `rescale` alone, so with
    # rescale=True and recover_tag=None the unscaled plot is written to
    # '<model>_tags.png'. Kept as-is for backward compatibility.
    if rescale:
        fig.savefig(f'{data_directory}/result_plots/{model_name}_tags.png')
    else:
        fig.savefig(
            f'{data_directory}/result_plots/{model_name}_tags_unscaled.png')
    plt.close(fig)

    if apply_recovery:
        # Also plot the distribution of the sum of the tags
        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
        _plot_tag_histograms(ax, np.log10(J_pred), np.log10(J_true))
        ax.set_title('Sum of J')
        ax.legend()
        fig.tight_layout()
        fig.savefig(f'{data_directory}/result_plots/{model_name}_sumJ.png')

        plt.close(fig)