Loss
This module implements various loss functions for training machine learning models.
It includes custom loss functions that extend PyTorch's nn.Module, allowing for flexible and efficient computation of loss values during training. The loss functions handle different scenarios such as classification, regression, and segmentation tasks. They incorporate techniques like weighted losses, focal losses, and smooth L1 losses to address class imbalances and improve model performance. The module ensures that the loss calculations are compatible with PyTorch's autograd system, enabling seamless integration into training loops.
ComposableLoss(config, nz, device, validation=False)
Bases: Module
Compose the loss function from several terms. The weight of each term must be specified in the configuration file; every term with a weight greater than 0 is added to the loss function.
The available loss terms are:
- MSE: mean squared error between predicted and target normalized screen tags
- MAE: mean absolute error between predicted and target normalized screen tags
- JMSE: mean squared error between predicted and target J
- JMAE: mean absolute error between predicted and target J
- Pearson: Pearson correlation coefficient between predicted and target J
- Fried: Fried parameter r0
- Isoplanatic: isoplanatic angle theta0
- Rytov: Rytov variance sigma_r^2, computed on the log-averaged Cn2
- Scintillation_w: scintillation index for weak turbulence
- Scintillation_m: scintillation index for moderate-strong turbulence
Parameters:
- config (dict) – Dictionary containing the configuration
- nz (Normalizer) – Normalizer object used to extract J in its original scale
- device (device) – The device to use for the computation
- validation (bool, default: False) – If true, use the validation parameters from config
Source code in src/speckcn2/loss.py
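The weighting scheme described above can be sketched as follows. This is a minimal illustration and not the code in src/speckcn2/loss.py: the class name `WeightedSumLoss`, the flat `weights` dictionary, and the restriction to the MSE and MAE terms are assumptions made for the example. The real class reads its weights from `config` and supports the full list of terms.

```python
import torch
from torch import nn


class WeightedSumLoss(nn.Module):
    """Hypothetical stand-in: sum of loss terms, each scaled by its weight."""

    def __init__(self, weights):
        super().__init__()
        # Registry of available terms (only two here, for brevity).
        available = {"MSE": nn.MSELoss(), "MAE": nn.L1Loss()}
        # Keep only the terms whose weight is strictly positive.
        self.weights = {k: w for k, w in weights.items() if w > 0}
        self.terms = nn.ModuleDict({k: available[k] for k in self.weights})

    def forward(self, pred, target):
        # Evaluate every active term, then combine them into one scalar.
        losses = {k: term(pred, target) for k, term in self.terms.items()}
        loss = sum(self.weights[k] * v for k, v in losses.items())
        return loss, losses


pred = torch.tensor([0.2, 0.4, 0.9], requires_grad=True)
target = torch.tensor([0.0, 0.5, 1.0])

# "Pearson" has weight 0, so it is skipped entirely.
criterion = WeightedSumLoss({"MSE": 1.0, "MAE": 0.5, "Pearson": 0.0})
loss, losses = criterion(pred, target)
loss.backward()  # gradients flow through the composed scalar
```

Returning the per-term dictionary next to the composed scalar makes it possible to log each contribution separately while backpropagating only through the sum.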
forward(pred, target)
Forward pass of the loss function.
Parameters:
- pred (Tensor) – The predicted screen tags
- target (Tensor) – The target screen tags
Returns:
- loss (Tensor) – The composed loss
- losses (dict) – Dictionary containing the individual losses
Source code in src/speckcn2/loss.py
get_FriedParameter(Jnorm)
Compute the Fried parameter r0 from the screen tags.
Source code in src/speckcn2/loss.py
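For orientation, the textbook plane-wave expression at zenith is r0 = (0.423 k^2 sum_i J_i)^(-3/5), with k = 2 pi / wavelength and J_i the turbulence strength integrated over layer i. The sketch below evaluates it directly. The function name `fried_parameter`, the `wavelength` argument, and the sample values are assumptions for illustration; the package method takes normalized tags and may differ in detail.

```python
import math


def fried_parameter(J, wavelength):
    """r0 in meters from per-layer integrated turbulence J_i [m^(1/3)].

    Hypothetical helper, not the speckcn2 implementation.
    """
    k = 2 * math.pi / wavelength  # optical wavenumber [rad/m]
    return (0.423 * k**2 * sum(J)) ** (-3 / 5)


# Illustrative values only: three layers observed at 550 nm.
r0 = fried_parameter([2e-13, 1e-13, 5e-14], wavelength=550e-9)
```

Stronger integrated turbulence gives a smaller r0, following the -3/5 power law.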
get_IsoplanaticAngle(Cn2)
Compute the isoplanatic angle theta0 from the screen tags.
Source code in src/speckcn2/loss.py
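A common zenith expression for the isoplanatic angle is theta0 = (2.914 k^2 ∫ Cn2(h) h^(5/3) dh)^(-3/5). The sketch below evaluates it for a piecewise-constant profile. The function name, the altitude grid `h`, and the sample values are assumptions for illustration; the package method may use a different discretization.

```python
import math


def isoplanatic_angle(cn2, h, wavelength):
    """theta0 in radians; cn2[i] is taken as constant between h[i] and h[i+1].

    Hypothetical helper, not the speckcn2 implementation.
    """
    k = 2 * math.pi / wavelength  # optical wavenumber [rad/m]
    integral = 0.0
    for c, lo, hi in zip(cn2, h[:-1], h[1:]):
        # Exact integral of h^(5/3) over the slab [lo, hi].
        integral += c * (3 / 8) * (hi ** (8 / 3) - lo ** (8 / 3))
    return (2.914 * k**2 * integral) ** (-3 / 5)


# Illustrative two-layer profile: altitudes in meters, Cn2 in m^(-2/3).
theta0 = isoplanatic_angle([1e-16, 1e-17], [0.0, 1000.0, 10000.0], 550e-9)
```

Because of the h^(5/3) weighting, high-altitude layers dominate theta0 even when their Cn2 is small.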
get_J(Jnorm)
Recover J from the normalized tags. This needs to be done to compute Cn2.
Parameters:
- Jnorm (Tensor) – The normalized screen tags between 0 and 1
Returns:
- J (Tensor) – The recovered screen tags
Source code in src/speckcn2/loss.py
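How the tags are recovered depends on the scaling applied by the Normalizer, which is not shown on this page. Purely as an illustration, if each screen were scaled with a plain min-max normalization, the inverse would be J = Jnorm * (Jmax - Jmin) + Jmin. The helper name and the per-screen bounds below are hypothetical.

```python
def denormalize(j_norm, j_min, j_max):
    """Invert min-max scaling: map values in [0, 1] back to [j_min, j_max].

    Assumes plain min-max normalization; the real Normalizer may differ.
    """
    return [x * (hi - lo) + lo for x, lo, hi in zip(j_norm, j_min, j_max)]


# Illustrative bounds only: every screen spans 1e-14 to 1e-12.
J = denormalize([0.0, 0.5, 1.0], [1e-14] * 3, [1e-12] * 3)
```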
get_ScintillationModerateStrong(x)
Compute the scintillation index for moderate-strong turbulence sigma^2 from the screen tags.
Source code in src/speckcn2/loss.py
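One widely used plane-wave model (zero inner scale) expresses the scintillation index in terms of the Rytov variance sigma_R^2 = 1.23 Cn2 k^(7/6) L^(11/6). The sketch below implements that model. The function names and arguments are assumptions for illustration, and the package method may use a different formulation.

```python
import math


def rytov_variance(cn2, wavelength, path_length):
    """Plane-wave Rytov variance for a constant Cn2 along the path."""
    k = 2 * math.pi / wavelength  # optical wavenumber [rad/m]
    return 1.23 * cn2 * k ** (7 / 6) * path_length ** (11 / 6)


def scintillation_moderate_strong(sigma_r2):
    """Plane-wave scintillation index, zero inner scale.

    Hypothetical helper, not the speckcn2 implementation.
    """
    s125 = sigma_r2 ** (6 / 5)  # equals sigma_R^(12/5)
    large_scale = 0.49 * sigma_r2 / (1 + 1.11 * s125) ** (7 / 6)
    small_scale = 0.51 * sigma_r2 / (1 + 0.69 * s125) ** (5 / 6)
    return math.exp(large_scale + small_scale) - 1


# Illustrative values only: 5 km horizontal path at 1550 nm.
sigma_r2 = rytov_variance(1e-14, wavelength=1550e-9, path_length=5000.0)
sigma_i2 = scintillation_moderate_strong(sigma_r2)
```

For small sigma_R^2 the expression reduces to the weak-turbulence result sigma_I^2 ≈ sigma_R^2, and for very large sigma_R^2 it saturates towards 1.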
get_ScintillationWeak(Cn)
Compute the scintillation index for weak turbulence sigma^2 from the screen tags.
Source code in src/speckcn2/loss.py
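In the weak-fluctuation regime, a standard plane-wave expression at zenith is sigma_I^2 = 2.25 k^(7/6) ∫ Cn2(h) (h - h0)^(5/6) dh, where h0 is the receiver altitude. The sketch below evaluates it for a piecewise-constant profile. The function name, the altitude grid, and the sample values are assumptions for illustration, not the package implementation.

```python
import math


def scintillation_weak(cn2, h, wavelength):
    """Weak-turbulence scintillation index; cn2[i] is constant on [h[i], h[i+1]].

    Hypothetical helper; h[0] is taken as the receiver altitude h0.
    """
    k = 2 * math.pi / wavelength  # optical wavenumber [rad/m]
    h0 = h[0]
    integral = 0.0
    for c, lo, hi in zip(cn2, h[:-1], h[1:]):
        # Exact integral of (h - h0)^(5/6) over the slab [lo, hi].
        integral += c * (6 / 11) * ((hi - h0) ** (11 / 6) - (lo - h0) ** (11 / 6))
    return 2.25 * k ** (7 / 6) * integral


# Illustrative two-layer profile: altitudes in meters, Cn2 in m^(-2/3).
sigma_i2_weak = scintillation_weak([1e-16, 1e-17], [0.0, 1000.0, 10000.0], 550e-9)
```

For a single uniform layer of thickness L this gives approximately 1.23 Cn2 k^(7/6) L^(11/6), the constant-Cn2 Rytov variance.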
reconstruct_cn2(Jnorm)
Reconstruct Cn2 from the screen tags, using c_i = J_i / (h[i+1] - h[i]).
Parameters:
- Jnorm (Tensor) – The screen tags normalized between 0 and 1
Returns:
- Cn2 (Tensor) – The Cn2 reconstructed from the screen tags, assuming a uniform profile
Source code in src/speckcn2/loss.py
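The relation c_i = J_i / (h[i+1] - h[i]) can be illustrated with a few lines of plain Python. The sketch starts from tags that are already in their original scale, whereas the package method receives normalized tags; the helper name and the sample values are hypothetical.

```python
def cn2_from_tags(J, h):
    """Divide each integrated tag J_i by the thickness of its slab.

    Hypothetical helper: assumes Cn2 is uniform between h[i] and h[i+1].
    """
    if len(h) != len(J) + 1:
        raise ValueError("h must contain one more edge than J has layers")
    return [j / (hi - lo) for j, lo, hi in zip(J, h[:-1], h[1:])]


# Illustrative values only: slabs of 1000 m and 2000 m.
cn2 = cn2_from_tags([2e-13, 1e-13], [0.0, 1000.0, 3000.0])
```

Here the two slabs yield Cn2 values of 2e-16 and 5e-17 m^(-2/3), respectively.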