Skip to content

Utils

This module provides utility functions for image processing and model optimization.

It includes functions to plot original and preprocessed images along with their tags, ensure the existence of specified directories, set up optimizers based on configuration files, and create circular masks with an inner "spider" circle removed. These utilities facilitate various tasks in image analysis and machine learning model training.

create_circular_mask_with_spider(resolution, bkg_value=0)

Creates a circular mask with an inner "spider" circle removed.

Parameters:

  • resolution (int) –

    The resolution of the square mask.

  • bkg_value (int, default: 0 ) –

    The background value to set for the masked areas. Defaults to 0.

Returns:

  • torch.Tensor –

    A 2D tensor representing the mask.

Source code in src/speckcn2/utils.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
def create_circular_mask_with_spider(resolution: int,
                                     bkg_value: int = 0,
                                     spider_fraction: float = 0.22,
                                     ) -> torch.Tensor:
    """Creates a circular mask with an inner "spider" circle removed.

    Pixels that lie inside the outer circle and outside the inner
    "spider" circle are set to 1. Every other pixel is set to
    ``bkg_value``.

    Parameters
    ----------
    resolution : int
        The resolution (side length in pixels) of the square mask.
    bkg_value : int
        The background value to set for the masked areas. Defaults to 0.
        The value is stored as a float, so negative values are kept
        as-is rather than wrapped into an unsigned 8-bit range.
    spider_fraction : float
        Radius of the removed inner circle as a fraction of
        ``resolution``. The pixel radius is
        ``int(spider_fraction * resolution)``. Defaults to 0.22.

    Returns
    -------
    torch.Tensor
        A 2D ``(resolution, resolution)`` tensor representing the mask.
    """
    # Outer circle: centred in the square, touching its edges.
    center = resolution // 2
    radius = center
    rows, cols = np.ogrid[:resolution, :resolution]
    dist_sq = (cols - center)**2 + (rows - center)**2

    outside_disk = dist_sq > radius**2
    spider_radius = int(spider_fraction * resolution)
    inside_spider = dist_sq < spider_radius**2

    # Build directly in floating point. A uint8 buffer would wrap,
    # truncate or reject background values outside 0..255, and the
    # result is converted to a float tensor anyway.
    final_mask = np.ones((resolution, resolution), dtype=np.float32)
    final_mask[outside_disk | inside_spider] = bkg_value

    return torch.Tensor(final_mask)

ensure_directory(data_directory)

Ensure that the directory exists.

Parameters:

  • data_directory (str) –

    The directory to create if it does not already exist.

Source code in src/speckcn2/utils.py
72
73
74
75
76
77
78
79
80
81
82
def ensure_directory(data_directory: str) -> None:
    """Ensure that the directory exists, creating it if necessary.

    Missing parent directories are created as well. Calling this on a
    directory that already exists does nothing.

    Parameters
    ----------
    data_directory : str
        The directory to create if it does not already exist.

    Raises
    ------
    FileExistsError
        If ``data_directory`` exists but is not a directory.
    """
    # ``exist_ok=True`` avoids the race between a separate ``isdir`` check
    # and ``mkdir``; ``makedirs`` also creates any missing parents.
    os.makedirs(data_directory, exist_ok=True)

plot_preprocessed_image(image_orig, image, tags, counter, datadirectory, mname, file_name, polar=False)

Plots the original and preprocessed image, and the tags.

Parameters:

  • image_orig (tensor) –

    The original image

  • image (tensor) –

    The preprocessed image

  • tags (tensor) –

    The screen tags

  • counter (int) –

    The counter of the image

  • datadirectory (str) –

    The directory containing the data

  • mname (str) –

    The name of the model

  • file_name (str) –

    The name of the original image

  • polar (bool, default: False ) –

    If the image is in polar coordinates, by default False

Source code in src/speckcn2/utils.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def plot_preprocessed_image(image_orig: torch.Tensor,
                            image: torch.Tensor,
                            tags: torch.Tensor,
                            counter: int,
                            datadirectory: str,
                            mname: str,
                            file_name: str,
                            polar: bool = False) -> None:
    """Plots the original and preprocessed image, and the tags.

    The three panels are saved to
    ``{datadirectory}/imgs_to_{mname}/{counter}.png``. This function does
    not create that directory, so it must already exist.

    Parameters
    ----------
    image_orig : torch.Tensor
        The original image
    image : torch.Tensor
        The preprocessed image
    tags : torch.Tensor
        The screen tags
    counter : int
        The counter of the image, used as the output file name
    datadirectory : str
        The directory containing the data
    mname : str
        The name of the model
    file_name : str
        The name of the original image
    polar : bool, optional
        If the image is in polar coordinates, by default False
    """
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    try:
        # Left panel: the original image
        axs[0].imshow(image_orig.squeeze(), cmap='bone')
        axs[0].set_title(f'Training Image {file_name}')

        # Middle panel: the preprocessed image
        axs[1].imshow(image.squeeze(), cmap='bone')
        axs[1].set_title('Processed as')
        if polar:
            axs[1].set_xlabel(r'$r$')
            axs[1].set_ylabel(r'$\theta$')

        # Right panel: the tag values on a log scale. No legend is drawn:
        # the plotted series carries no label, so legend() would only
        # emit a warning.
        axs[2].plot(tags, 'o')
        axs[2].set_yscale('log')
        axs[2].set_title('Screen Tags')

        fig.subplots_adjust(wspace=0.3)
        fig.savefig(f'{datadirectory}/imgs_to_{mname}/{counter}.png')
    finally:
        # Release this specific figure even if saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)

setup_optimizer(config, model)

Returns the optimizer specified in the configuration file.

Parameters:

  • config (dict) –

    Dictionary containing the configuration

  • model (Module) –

    The model to optimize

Returns:

  • optimizer (torch.optim.Optimizer) –

    A newly constructed optimizer over the model's parameters, using the learning rate given by config['hyppar']['lr'].

Source code in src/speckcn2/utils.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def setup_optimizer(config: dict, model: nn.Module) -> torch.optim.Optimizer:
    """Returns the optimizer specified in the configuration file.

    Parameters
    ----------
    config : dict
        Dictionary containing the configuration. The optimizer name is
        read from ``config['hyppar']['optimizer']`` (``'Adam'`` or
        ``'SGD'``). The learning rate is read from
        ``config['hyppar']['lr']``.
    model : torch.nn.Module
        The model to optimize

    Returns
    -------
    optimizer : torch.optim.Optimizer
        A newly constructed optimizer over ``model.parameters()``.

    Raises
    ------
    ValueError
        If the configured optimizer name is not supported.
    """
    optimizers = {
        'Adam': torch.optim.Adam,
        'SGD': torch.optim.SGD,
    }

    hyppar = config['hyppar']
    optimizer_name = hyppar['optimizer']
    optimizer_cls = optimizers.get(optimizer_name)
    if optimizer_cls is None:
        raise ValueError(f'Unknown optimizer {optimizer_name}')

    # The learning rate is only read once the name is known to be valid.
    return optimizer_cls(model.parameters(), lr=hyppar['lr'])