Skip to content

Loss

This module implements various loss functions for training machine learning models.

It includes custom loss functions that extend PyTorch's nn.Module, allowing for flexible and efficient computation of loss values during training. The loss functions handle different scenarios such as classification, regression, and segmentation tasks. They incorporate techniques like weighted losses, focal losses, and smooth L1 losses to address class imbalances and improve model performance. The module ensures that the loss calculations are compatible with PyTorch's autograd system, enabling seamless integration into training loops.

ComposableLoss(config, nz, device, validation=False)

Bases: Module

Compose the loss function from several weighted terms. The weight of each term is specified in the configuration file; every term with a positive weight is added to the loss function.

The available loss terms are:
- MSE: mean squared error between predicted and target normalized screen tags
- MAE: mean absolute error between predicted and target normalized screen tags
- JMSE: mean squared error between predicted and target J
- JMAE: mean absolute error between predicted and target J
- Pearson: Pearson correlation coefficient between predicted and target J
- Fried: Fried parameter r0
- Isoplanatic: isoplanatic angle theta0
- Rytov: Rytov variance sigma_r^2, computed on the log-averaged Cn2
- Scintillation_w: scintillation index for weak turbulence
- Scintillation_ms: scintillation index for moderate-strong turbulence

The keys Cn2MSE and Cn2MAE are also recognized in the loss section of the configuration.

Parameters:

  • config (dict) –

    Dictionary containing the configuration

  • nz (Normalizer) –

    Normalizer object to be used to extract J in its original scale

  • device (device) –

    The device to use for the computation

  • validation (bool, default: False ) –

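    If True, read the loss weights from the `val_loss` section of config instead of `loss` (falling back to `loss`, with a warning, when `val_loss` is missing)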

Source code in src/speckcn2/loss.py
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def __init__(self,
             config: dict,
             nz: Normalizer,
             device: torch.device,
             validation: bool = False):
    """Set up the weighted loss terms and the physical constants they use.

    Parameters
    ----------
    config : dict
        Parsed configuration. Loss weights are read from ``config['loss']``
        (or from ``config['val_loss']`` when ``validation`` is True and that
        section exists). ``config['speckle']`` supplies ``hArray`` and the
        optional ``lambda`` (in nm, default 550) and ``z`` (in degrees,
        default 0).
    nz : Normalizer
        Normalizer whose ``recover_tag`` callables map normalized tags back
        to their original scale.
    device : torch.device
        Device on which the height grid ``h`` is stored.
    validation : bool, optional
        If True, prefer the ``val_loss`` weights; falls back to ``loss``
        with a warning when ``val_loss`` is missing.

    Raises
    ------
    ValueError
        If all loss weights sum to zero, since ``forward`` divides the
        composed loss by that sum.
    """
    super(ComposableLoss, self).__init__()
    # Select the weight section locally instead of overwriting
    # config['loss']: mutating the caller's dict would leak the validation
    # weights into every later consumer of the same config (e.g. a training
    # loss constructed from it afterwards).
    if validation and 'val_loss' in config:
        loss_config = config['val_loss']
    else:
        if validation:
            print('[!] Warning: Validation loss not found in config.yaml,',
                  'keeping track of training loss instead')
        loss_config = config['loss']
    self.device = device
    self.loss_functions: dict[str, Callable] = {
        'MSE': torch.nn.MSELoss(),
        'MAE': torch.nn.L1Loss(),
        'JMSE': self._MSELoss,
        'JMAE': self._L1Loss,
        'Cn2MSE': self._do_nothing,
        'Cn2MAE': self._do_nothing,
        'Pearson': self._PearsonCorrelationLoss,
        'Fried': self._FriedLoss,
        'Isoplanatic': self._IsoplanaticLoss,
        'Rytov': self._RytovLoss,
        'Scintillation_w': self._ScintillationWeakLoss,
        'Scintillation_ms': self._ScintillationModerateStrongLoss,
    }
    # Terms absent from the config get weight 0, i.e. they are disabled.
    self.loss_weights = {
        loss_name: loss_config.get(loss_name, 0)
        for loss_name in self.loss_functions
    }
    self.total_weight = sum(self.loss_weights.values())
    if self.total_weight == 0:
        # forward() normalizes by total_weight; fail early with a clear
        # message instead of a ZeroDivisionError during training.
        raise ValueError(
            'ComposableLoss: all loss weights are zero; set at least one '
            'positive weight in the loss section of the config')
    self._select_loss_needed()

    # And get some useful parameters for the loss functions
    # the parameters are explained at:
    # https://males-project.github.io/SpeckleCn2Profiler/example/#configuration-file-explanation
    self.h = torch.Tensor([float(x) for x in config['speckle']['hArray']])
    # Wavenumber k = 2*pi/lambda, with lambda given in nm.
    self.k = 2 * torch.pi / (config['speckle'].get('lambda', 550) * 1e-9)
    # Zenith angle z is given in degrees.
    self.cosz = np.cos(np.deg2rad(config['speckle'].get('z', 0)))
    self.secz = 1 / self.cosz
    # Prefactors for the Fried parameter, isoplanatic angle and
    # weak-turbulence scintillation index.
    self.p_fr = 0.423 * self.k**2 * self.secz
    self.p_iso = self.cosz**(8. / 5.) / ((2.91 * self.k**2)**(3. / 5.))
    self.p_scw = 2.25 * self.k**(7. / 6.) * self.secz**(11. / 6.)

    # We need to be able to recover the tags
    self.recover_tag = nz.recover_tag
    # Move tensors to the device
    self.h = self.h.to(self.device)

forward(pred, target)

Forward pass of the loss function.

Parameters:

  • pred (Tensor) –

    The predicted screen tags

  • target (Tensor) –

    The target screen tags

Returns:

  • loss ( Tensor ) –

    The composed loss

  • losses ( dict ) –

    Dictionary containing the individual losses

Source code in src/speckcn2/loss.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def forward(self, pred: torch.Tensor,
            target: torch.Tensor) -> tuple[torch.Tensor, dict]:
    """Evaluate every enabled loss term and combine them.

    The composed loss is the weighted sum of the enabled terms divided by
    the sum of all configured weights.

    Parameters
    ----------
    pred : torch.Tensor
        The predicted screen tags
    target : torch.Tensor
        The target screen tags

    Returns
    -------
    loss : torch.Tensor
        The weight-normalized composed loss
    losses : dict
        Unweighted value of each enabled term, keyed by term name
    """
    per_term = {}
    weighted_sum = 0

    # Cn2 is only reconstructed when some enabled term needs it.
    # NOTE(review): terms other than MAE/MSE read cn2_pred/cn2_target below;
    # this relies on _select_loss_needed setting Cn2required for them.
    if self.Cn2required:
        cn2_pred = self.reconstruct_cn2(pred)
        cn2_target = self.reconstruct_cn2(target)

    for name, term_fn in self.loss_needed.items():
        w = self.loss_weights[name]
        # Plain tag-space losses only take (pred, target).
        value = (term_fn(pred, target) if name in ('MAE', 'MSE') else
                 term_fn(pred, target, cn2_pred, cn2_target))
        weighted_sum += w * value
        per_term[name] = value

    return weighted_sum / self.total_weight, per_term

get_FriedParameter(Jnorm)

Compute the Fried parameter r0 from the screen tags.

Source code in src/speckcn2/loss.py
305
306
307
308
def get_FriedParameter(self, Jnorm: torch.Tensor) -> torch.Tensor:
    """Compute the Fried parameter r0 from the normalized screen tags.

    r0 = (p_fr * sum_i J_i)^(-3/5), where J is recovered by ``get_J`` and
    ``p_fr`` is the prefactor stored on the instance.

    Parameters
    ----------
    Jnorm : torch.Tensor
        The normalized screen tags.

    Returns
    -------
    r0 : torch.Tensor
        Scalar tensor holding the Fried parameter.
    """
    # get_J already returns a tensor on Jnorm's device. Re-wrapping it in the
    # legacy torch.Tensor(...) constructor is redundant, and that constructor
    # rejects anything that is not a CPU float32 tensor (e.g. CUDA or float64).
    J = self.get_J(Jnorm)
    # NOTE(review): torch.sum reduces over *all* elements, including any
    # batch dimension, so a batch of profiles yields a single r0 for their
    # summed turbulence — confirm callers pass one sample at a time.
    return (self.p_fr * torch.sum(J))**(-3 / 5)

get_IsoplanaticAngle(Cn2)

Compute the isoplanatic angle theta0 from the per-layer Cn2 values.

Source code in src/speckcn2/loss.py
338
339
340
341
342
343
344
def get_IsoplanaticAngle(self, Cn2: torch.Tensor) -> torch.Tensor:
    """Compute the isoplanatic angle theta0 from per-layer Cn2 values.

    Cn2 is taken as constant inside each layer [h[i], h[i+1]], so the
    integral of Cn2(z) z^(5/3) dz reduces to
    sum_i Cn2[i] * (h[i+1]^(8/3) - h[i]^(8/3)) * 3/8, and
    theta0 = p_iso / integral^(3/5).
    """
    exponent = 8 / 3
    upper, lower = self.h[1:], self.h[:-1]
    # NOTE(review): .sum() reduces over all elements, including any batch
    # dimension — confirm Cn2 holds a single profile.
    layer_terms = Cn2 * (upper**exponent - lower**exponent)
    integral = layer_terms.sum() * 3 / 8
    return self.p_iso / (integral**(3 / 5))

get_J(Jnorm)

Recover J from the normalized tags. This needs to be done to compute Cn2.

Parameters:

  • Jnorm (Tensor) –

    The normalized screen tags between 0 and 1

Returns:

  • J ( Tensor ) –

    The recovered screen tags

Source code in src/speckcn2/loss.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
def get_J(self, Jnorm: torch.Tensor) -> torch.Tensor:
    """Recover J from the normalized tags. This needs to be done to compute
    Cn2.

    Each entry is mapped as ``10 ** recover_tag[j](Jnorm[i][j], i)``. The
    recovered values are stacked as tensors, so when ``recover_tag`` is made
    of differentiable tensor operations the gradient flows back to
    ``Jnorm``.

    Parameters
    ----------
    Jnorm : torch.Tensor
        The normalized screen tags between 0 and 1. A 1-D input is treated
        as a single sample.

    Returns
    -------
    J : torch.Tensor
        The recovered screen tags, shape (n_samples, n_screens), on
        ``Jnorm.device``. It always has ``requires_grad=True``.
    """

    if Jnorm.ndim == 1:
        Jnorm = Jnorm[None, :]

    rows = []
    for i, sample in enumerate(Jnorm):
        # Keep every recovered value as a tensor instead of rebuilding a new
        # leaf tensor from Python scalars (torch.tensor([...])), which would
        # cut the autograd graph and stop gradients reaching Jnorm.
        # NOTE(review): recover_tag[j] receives the row index i as its second
        # argument — confirm this is what the Normalizer expects.
        values = [
            torch.as_tensor(10**self.recover_tag[j](tag, i),
                            device=Jnorm.device)
            for j, tag in enumerate(sample)
        ]
        rows.append(torch.stack(values))
    J = torch.stack(rows)
    if not J.requires_grad:
        # Keep the previous contract that J always requires grad (e.g. when
        # recover_tag returns plain floats), so .backward() on a loss built
        # only from J does not raise. J is a leaf here, so this is allowed.
        J.requires_grad_(True)
    return J

get_ScintillationModerateStrong(x)

Compute the scintillation index sigma^2 for moderate-strong turbulence from the per-layer Cn2 values.

Source code in src/speckcn2/loss.py
423
424
425
426
427
428
def get_ScintillationModerateStrong(self, x: torch.Tensor) -> torch.Tensor:
    """Compute the scintillation index for moderate-strong turbulence
    sigma_I^2 from per-layer Cn2 values.

    Uses the Andrews-Phillips plane-wave expression

        sigma_I^2 = exp[0.49 s / (1 + 1.11 s^(6/5))^(7/6)
                        + 0.51 s / (1 + 0.69 s^(6/5))^(5/6)] - 1

    where s is the weak-turbulence index (plane-wave Rytov variance)
    returned by ``get_ScintillationWeak``. In the weak limit (s -> 0) this
    reduces to s.
    """
    s = self.get_ScintillationWeak(x)
    s_pow = s**(6 / 5)  # = sigma_R^(12/5)
    # Large-scale and small-scale log-irradiance contributions; the 7/6 and
    # 5/6 exponents and the final "- 1" are part of the published formula.
    large_scale = 0.49 * s / (1 + 1.11 * s_pow)**(7 / 6)
    small_scale = 0.51 * s / (1 + 0.69 * s_pow)**(5 / 6)
    # expm1 gives exp(.) - 1 accurately for small arguments.
    return torch.expm1(large_scale + small_scale)

get_ScintillationWeak(Cn)

Compute the scintillation index sigma^2 for weak turbulence from the per-layer Cn2 values.

Source code in src/speckcn2/loss.py
385
386
387
388
389
390
391
392
def get_ScintillationWeak(self, Cn: torch.Tensor) -> torch.Tensor:
    """Compute the scintillation index for weak turbulence sigma^2 from
    per-layer Cn2 values.

    Cn2 is taken as constant inside each layer [h[i], h[i+1]], so the
    integral of Cn2(z) z^(5/6) dz reduces to
    sum_i Cn[i] * (h[i+1]^(11/6) - h[i]^(11/6)) * 6/11, which is then
    scaled by the prefactor p_scw.
    """
    power = 11 / 6
    tops, bottoms = self.h[1:], self.h[:-1]
    # NOTE(review): .sum() reduces over all elements, including any batch
    # dimension — confirm Cn holds a single profile.
    weighted = Cn * (tops**power - bottoms**power)
    integral = weighted.sum() * 6 / 11
    return self.p_scw * integral

reconstruct_cn2(Jnorm)

Reconstruct Cn2 from the screen tags as $C_{n,i}^2 = J_i / (h_{i+1} - h_i)$.

Parameters:

  • Jnorm (Tensor) –

    The screen tags normalized between 0 and 1

Returns:

  • Cn2 ( Tensor ) –

    The Cn2 reconstructed from the screen tags, assuming Cn2 is constant within each layer

Source code in src/speckcn2/loss.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
def reconstruct_cn2(self, Jnorm: torch.Tensor) -> torch.Tensor:
    """Reconstruct Cn2 from the normalized screen tags.

    Each screen i spans heights [h[i], h[i+1]] with a constant Cn2, so
    c_i = J_i / (h[i+1] - h[i]).

    Parameters
    ----------
    Jnorm : torch.Tensor
        The screen tags normalized between 0 and 1

    Returns
    -------
    Cn2 : torch.Tensor
        The Cn2 reconstructed from the screen tags, assuming Cn2 is
        constant within each layer
    """
    screens = self.get_J(Jnorm)
    thickness = self.h[1:] - self.h[:-1]
    return screens / thickness