Skip to content

Loss

This module implements various loss functions for training machine learning models.

It provides `ComposableLoss`, a subclass of PyTorch's `nn.Module` that combines several weighted loss terms into a single training objective. Besides standard regression losses (MSE and MAE) on the normalized screen tags, it offers losses on the recovered screen tags J, a Pearson-correlation loss, and physics-informed losses built from turbulence parameters (Fried parameter, isoplanatic angle, Rytov variance and scintillation indices). All terms are computed with PyTorch operations so that the composite loss can be used directly inside a training loop.

ComposableLoss(config, nz, device)

Bases: Module

Compose the loss function from several terms. The weight of each term is specified in the configuration file, and every term with a weight greater than 0 is added to the loss function.

The available loss terms are:

- MSE: mean squared error between predicted and target normalized screen tags
- MAE: mean absolute error between predicted and target normalized screen tags
- JMSE: mean squared error between predicted and target J
- JMAE: mean absolute error between predicted and target J
- Pearson: Pearson correlation coefficient between predicted and target J
- Fried: Fried parameter $r_0$
- Isoplanatic: isoplanatic angle $\theta_0$
- Rytov: Rytov variance $\sigma_R^2$, computed on the log-averaged $C_n^2$
- Scintillation_w: scintillation index for weak turbulence
- Scintillation_ms: scintillation index for moderate-to-strong turbulence

Parameters:

  • config (dict) –

    Dictionary containing the configuration

  • nz (Normalizer) –

    Normalizer object used to recover J in its original scale

  • device (device) –

    Device to which the altitude grid h is moved

Source code in src/speckcn2/loss.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def __init__(self, config: dict, nz: Normalizer, device: torch.device):
    """Set up the loss terms, their weights and the physical constants.

    Parameters
    ----------
    config : dict
        Configuration dictionary. ``config['loss']`` maps loss-term names
        to weights; terms that are absent get weight 0.
        ``config['speckle']`` must contain ``hArray`` (altitude grid) and
        ``L``, and may contain ``lambda`` (wavelength in nm, default 550)
        and ``z`` (an angle in degrees, default 0, used through cos(z) and
        sec(z) -- presumably the zenith angle).
    nz : Normalizer
        Normalizer whose ``recover_tag`` functions map normalized tags
        back to their original scale.
    device : torch.device
        Device to which the altitude grid ``self.h`` is moved.
    """
    super().__init__()
    self.device = device

    # Loss-term name -> callable. The key order is also the order of
    # ``self.loss_weights``.
    self.loss_functions: dict[str, Callable] = dict(
        MSE=torch.nn.MSELoss(),
        MAE=torch.nn.L1Loss(),
        JMSE=self._MSELoss,
        JMAE=self._L1Loss,
        Cn2MSE=self._do_nothing,
        Cn2MAE=self._do_nothing,
        Pearson=self._PearsonCorrelationLoss,
        Fried=self._FriedLoss,
        Isoplanatic=self._IsoplanaticLoss,
        Rytov=self._RytovLoss,
        Scintillation_w=self._ScintillationWeakLoss,
        Scintillation_ms=self._ScintillationModerateStrongLoss,
    )
    loss_config = config['loss']
    self.loss_weights = {
        name: loss_config.get(name, 0)
        for name in self.loss_functions
    }
    # forward() divides the weighted sum by this value
    self.total_weight = sum(self.loss_weights.values())
    self._select_loss_needed()

    # Constants shared by the turbulence-parameter losses
    speckle = config['speckle']
    heights = [float(x) for x in speckle['hArray']]
    wavelength = speckle.get('lambda', 550) * 1e-9  # nm -> m
    self.k = 2 * torch.pi / wavelength  # wavenumber
    self.cosz = np.cos(np.deg2rad(speckle.get('z', 0)))
    self.secz = 1 / self.cosz
    self.L = speckle['L']
    # Prefactors of the Fried parameter, isoplanatic angle and
    # weak-turbulence scintillation index formulas
    self.p_fr = 0.423 * self.k**2 * self.secz
    self.p_iso = self.cosz**(8. / 5.) / ((2.91 * self.k**2)**(3. / 5.))
    self.p_scw = 2.25 * self.k**(7. / 6.) * self.secz**(11. / 6.)

    # Needed to undo the normalization of the screen tags
    self.recover_tag = nz.recover_tag
    # Altitude grid, stored on the same device as the predictions
    self.h = torch.Tensor(heights).to(self.device)

forward(pred, target)

Forward pass of the loss function.

Parameters:

  • pred (Tensor) –

    The predicted screen tags

  • target (Tensor) –

    The target screen tags

Returns:

  • loss ( Tensor ) –

    The composed loss

  • losses ( dict ) –

    Dictionary containing the individual losses

Source code in src/speckcn2/loss.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def forward(self, pred: torch.Tensor,
            target: torch.Tensor) -> tuple[torch.Tensor, dict]:
    """Forward pass of the loss function.

    Every selected term (``self.loss_needed``) is evaluated, multiplied
    by its weight from ``self.loss_weights`` and accumulated. The sum is
    then divided by ``self.total_weight``. The ``'MAE'`` and ``'MSE'``
    terms are called as ``loss_fn(pred, target)``. All other terms are
    called as ``loss_fn(pred, target, Cn2_pred, Cn2_target)``, where the
    Cn2 profiles are ``None`` unless ``self.Cn2required`` is set.

    Parameters
    ----------
    pred : torch.Tensor
        The predicted screen tags
    target : torch.Tensor
        The target screen tags

    Returns
    -------
    loss : torch.Tensor
        The composed loss
    losses : dict
        Dictionary containing the individual (unweighted) losses
    """
    total_loss = 0
    losses = {}

    # Bind the profiles up front: terms that do not need Cn2 then receive
    # None instead of hitting an UnboundLocalError when Cn2 is not built.
    Cn2_pred = Cn2_target = None
    if self.Cn2required:
        Cn2_pred = self.reconstruct_cn2(pred)
        Cn2_target = self.reconstruct_cn2(target)

    for loss_name, loss_fn in self.loss_needed.items():
        weight = self.loss_weights[loss_name]
        if loss_name in ('MAE', 'MSE'):
            # Plain torch criteria: they only accept (input, target)
            this_loss = loss_fn(pred, target)
        else:
            this_loss = loss_fn(pred, target, Cn2_pred, Cn2_target)
        total_loss += weight * this_loss
        losses[loss_name] = this_loss

    # NOTE(review): a total weight of 0 makes this division fail.
    return total_loss / self.total_weight, losses

get_FriedParameter(Jnorm)

Compute the Fried parameter r0 from the screen tags.

Source code in src/speckcn2/loss.py
301
302
303
304
def get_FriedParameter(self, Jnorm: torch.Tensor) -> torch.Tensor:
    """Compute the Fried parameter r0 from the screen tags.

    r0 = (p_fr * sum(J)) ** (-3/5), where J = self.get_J(Jnorm) and the
    sum runs over every element of J.

    Parameters
    ----------
    Jnorm : torch.Tensor
        The normalized screen tags

    Returns
    -------
    r0 : torch.Tensor
        The Fried parameter (0-d tensor)
    """
    # get_J already returns a tensor. Re-wrapping it with the legacy
    # torch.Tensor constructor was redundant, and that constructor rejects
    # non-CPU / non-default-dtype tensors.
    J = self.get_J(Jnorm)
    return (self.p_fr * torch.sum(J))**(-3 / 5)

get_IsoplanaticAngle(Cn2)

Compute the isoplanatic angle $\theta_0$ from the $C_n^2$ profile.

Source code in src/speckcn2/loss.py
334
335
336
337
338
339
340
def get_IsoplanaticAngle(self, Cn2: torch.Tensor) -> torch.Tensor:
    """Compute the isoplanatic angle theta0 from a Cn2 profile.

    Each layer contributes Cn2_i * (h_{i+1}^(8/3) - h_i^(8/3)) * 3/8,
    i.e. the exact integral of Cn2_i * h^(5/3) over the layer. The result
    is theta0 = p_iso / integral**(3/5).

    Parameters
    ----------
    Cn2 : torch.Tensor
        Cn2 values, one per interval of ``self.h``

    Returns
    -------
    theta0 : torch.Tensor
        The isoplanatic angle (0-d tensor)
    """
    exponent = 8 / 3
    upper = self.h[1:]**exponent
    lower = self.h[:-1]**exponent
    # The antiderivative of h^(5/3) is (3/8) h^(8/3)
    integral = torch.sum(Cn2 * (upper - lower)) * 3 / 8
    return self.p_iso / integral**(3 / 5)

get_J(Jnorm)

Recover J from the normalized tags. This needs to be done to compute Cn2.

Parameters:

  • Jnorm (Tensor) –

    The normalized screen tags between 0 and 1

Returns:

  • J ( Tensor ) –

    The recovered screen tags

Source code in src/speckcn2/loss.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def get_J(self, Jnorm: torch.Tensor) -> torch.Tensor:
    """Recover J from the normalized tags. This needs to be done to compute
    Cn2.

    Each entry is ``10 ** self.recover_tag[j](Jnorm[i][j], i)``. The
    recovered values are stacked instead of being copied into a new
    tensor, so gradients flow back to ``Jnorm`` whenever ``recover_tag``
    returns differentiable tensors.

    Parameters
    ----------
    Jnorm : torch.Tensor
        The normalized screen tags between 0 and 1, either one sample
        (1-D) or a batch (2-D)

    Returns
    -------
    J : torch.Tensor
        The recovered screen tags, shape (batch, n_tags), on Jnorm's
        device. The result always has requires_grad set.
    """

    if Jnorm.ndim == 1:
        Jnorm = Jnorm[None, :]

    rows = []
    for i, sample in enumerate(Jnorm):
        # as_tensor is a no-op for tensors already on the right device
        # (the graph is kept); plain floats are wrapped on that device.
        values = [
            torch.as_tensor(10**self.recover_tag[j](tag, i),
                            device=Jnorm.device)
            for j, tag in enumerate(sample)
        ]
        rows.append(torch.stack(values))
    J = torch.stack(rows)
    if not J.requires_grad:
        # Keep the previous contract: the output always requires grad
        J.requires_grad_(True)
    return J

get_ScintillationModerateStrong(x)

Compute the scintillation index for moderate-strong turbulence sigma^2 from the screen tags.

Source code in src/speckcn2/loss.py
419
420
421
422
423
424
def get_ScintillationModerateStrong(self, x: torch.Tensor) -> torch.Tensor:
    """Compute the scintillation index for moderate-strong turbulence
    sigma_I^2 from the Cn2 profile.

    Uses the plane-wave expression of Andrews & Phillips, valid from weak
    to strong fluctuations:

        sigma_I^2 = exp(0.49 s / (1 + 1.11 s^(6/5))^(7/6)
                        + 0.51 s / (1 + 0.69 s^(6/5))^(5/6)) - 1

    where ``s = self.get_ScintillationWeak(x)`` is the weak-turbulence
    index (Rytov variance). For s -> 0 this reduces to sigma_I^2 ~ s.

    Parameters
    ----------
    x : torch.Tensor
        Cn2 profile forwarded to ``get_ScintillationWeak``

    Returns
    -------
    sigma2 : torch.Tensor
        The scintillation index
    """
    wsigma2 = self.get_ScintillationWeak(x)
    # (sigma_R^2)^(6/5) == sigma_R^(12/5)
    s65 = wsigma2**(6 / 5)
    large_scale = 0.49 * wsigma2 / (1 + 1.11 * s65)**(7 / 6)
    small_scale = 0.51 * wsigma2 / (1 + 0.69 * s65)**(5 / 6)
    # expm1 computes exp(.) - 1 accurately in the weak-turbulence limit
    return torch.expm1(large_scale + small_scale)

get_ScintillationWeak(Cn)

Compute the scintillation index $\sigma^2$ for weak turbulence from the $C_n^2$ profile.

Source code in src/speckcn2/loss.py
381
382
383
384
385
386
387
388
def get_ScintillationWeak(self, Cn: torch.Tensor) -> torch.Tensor:
    """Compute the scintillation index for weak turbulence sigma^2 from a
    Cn2 profile.

    Each layer contributes Cn_i * (h_{i+1}^(11/6) - h_i^(11/6)) * 6/11,
    i.e. the exact integral of Cn_i * h^(5/6) over the layer. The result
    is that integral scaled by ``self.p_scw``.

    Parameters
    ----------
    Cn : torch.Tensor
        Cn2 values, one per interval of ``self.h``

    Returns
    -------
    sigma2 : torch.Tensor
        The weak-turbulence scintillation index (0-d tensor)
    """
    exponent = 11 / 6
    # The antiderivative of h^(5/6) is (6/11) h^(11/6)
    layer_terms = Cn * (self.h[1:]**exponent - self.h[:-1]**exponent)
    integral = torch.sum(layer_terms) * 6 / 11
    return self.p_scw * integral

reconstruct_cn2(Jnorm)

Reconstruct $C_n^2$ from the screen tags as $C_{n,i}^2 = J_i / (h_{i+1} - h_i)$.

Parameters:

  • Jnorm (Tensor) –

    The screen tags normalized between 0 and 1

Returns:

  • Cn2 ( Tensor ) –

    The Cn2 reconstructed from the screen tags, assuming a uniform profile

Source code in src/speckcn2/loss.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
def reconstruct_cn2(self, Jnorm: torch.Tensor) -> torch.Tensor:
    """Reconstruct Cn2 from the screen tags.

    Each recovered tag J_i is divided by the thickness of its layer:
    c_i = J_i / (h[i+1] - h[i]), i.e. Cn2 is taken as uniform inside
    each layer.

    Parameters
    ----------
    Jnorm : torch.Tensor
        The screen tags normalized between 0 and 1

    Returns
    -------
    Cn2 : torch.Tensor
        The Cn2 reconstructed from the screen tags, assuming a uniform profile
    """
    recovered = self.get_J(Jnorm)
    # Consecutive differences of the altitude grid (h[1:] - h[:-1])
    thickness = torch.diff(self.h, dim=0)
    return recovered / thickness