I/O

This module provides utility functions for loading and saving model configurations and states.

It includes functions to load configuration files, save model states, load model states, and load the latest model state from a directory.

load(model, datadirectory, epoch, early_stop=False)

Load the model state and the model itself from a specified directory and epoch.

This function loads the model's state dictionary and other relevant information such as epoch, loss, validation loss, and time from a file in the specified directory.

Parameters:

  • model (Module) –

    The model to load

  • datadirectory (str) –

    The directory where the data is stored

  • epoch (int) –

    The epoch of the model

  • early_stop (bool, default: False) –

    If True, load the state that was saved when the early-stop condition was reached

Source code in src/speckcn2/io.py
def load(model: torch.nn.Module,
         datadirectory: str,
         epoch: int,
         early_stop: bool = False) -> None:
    """Load the model state and the model itself from a specified directory and
    epoch.

    This function loads the model's state dictionary and other relevant information
    such as epoch, loss, validation loss, and time from a file in the specified directory.

    Parameters
    ----------
    model : torch.nn.Module
        The model to load
    datadirectory : str
        The directory where the data is stored
    epoch : int
        The epoch of the model
    early_stop: bool
        If True, the last state reached the early stop condition
    """
    if early_stop:
        model_state = torch.load(
            f'{datadirectory}/{model.name}_states/{model.name}_{epoch}_earlystop.pth'
        )
        model.early_stop = True
    else:
        model_state = torch.load(
            f'{datadirectory}/{model.name}_states/{model.name}_{epoch}.pth')

    model.epoch = model_state['epoch']
    model.loss = model_state['loss']
    model.val_loss = model_state['val_loss']
    model.time = model_state['time']
    model.load_state_dict(model_state['model_state_dict'], strict=False)

    assert model.epoch[
        -1] == epoch, 'The epoch of the model is not the same as the one loaded'
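
A minimal usage sketch. The model class, its name attribute, the './data' directory, and the epoch number are placeholders, and the import path is assumed from the source location src/speckcn2/io.py; load() requires a checkpoint produced earlier by save().

import torch
from speckcn2.io import load   # assumed import path

class ExampleNet(torch.nn.Module):        # hypothetical stand-in model
    def __init__(self):
        super().__init__()
        self.name = 'example_net'         # load() builds the file path from model.name
        self.fc = torch.nn.Linear(8, 1)

model = ExampleNet()
# Reads ./data/example_net_states/example_net_50.pth and populates
# model.epoch, model.loss, model.val_loss and model.time in place.
load(model, './data', epoch=50)
print(model.epoch[-1], model.val_loss[-1])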

load_config(config_file_path)

Load the configuration file from a given path.

This function reads a YAML configuration file and returns its contents as a dictionary.

Parameters:

  • config_file_path (str) –

    Path to the .yaml configuration file

Returns:

  • config (dict) –

    Dictionary containing the configuration

Source code in src/speckcn2/io.py
def load_config(config_file_path: str) -> dict:
    """Load the configuration file from a given path.

    This function reads a YAML configuration file and returns its contents as a dictionary.

    Parameters
    ----------
    config_file_path : str
        Path to the .yaml configuration file

    Returns
    -------
    config : dict
        Dictionary containing the configuration
    """
    with open(config_file_path, 'r') as file:
        config = yaml.safe_load(file)
    return config
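
For example, given an existing YAML file (the file name and keys below are placeholders, not a documented schema; the import path is assumed):

from speckcn2.io import load_config   # assumed import path

config = load_config('configuration.yaml')   # parsed with yaml.safe_load
print(config.get('model'))                   # plain dict access; keys depend on your file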

load_model_state(model, datadirectory)

Loads the latest model state from the given directory.

This function checks the specified directory for the latest model state file, loads it, and updates the model with the loaded state. If no state is found, it initializes the model state. If the training was stopped after meeting an early stop condition, this function signals that the training should not be continued.

Parameters:

  • model (Module) –

    The model to load the state into

  • datadirectory (str) –

    The directory where the model states are stored

Returns:

  • model (Module) –

    The model with the loaded state

  • last_model_state (int) –

    The number of the last model state

Source code in src/speckcn2/io.py
def load_model_state(model: torch.nn.Module,
                     datadirectory: str) -> tuple[torch.nn.Module, int]:
    """Loads the latest model state from the given directory.

    This function checks the specified directory for the latest model state file,
    loads it, and updates the model with the loaded state. If no state is found,
    it initializes the model state.
    If the training was stopped after meeting an early stop condition, this function
    signals that the training should not be continued.

    Parameters
    ----------
    model : torch.nn.Module
        The model to load the state into
    datadirectory : str
        The directory where the model states are stored

    Returns
    -------
    model : torch.nn.Module
        The model with the loaded state
    last_model_state : int
        The number of the last model state
    """
    # Print model information
    print(model)
    model.nparams = sum(p.numel() for p in model.parameters())
    print(f'\n--> Nparams = {model.nparams}')

    fulldirname = f'{datadirectory}/{model.name}_states'
    ensure_directory(fulldirname)

    # First check if there was an early stop
    earlystop = [
        filename for filename in os.listdir(fulldirname)
        if 'earlystop' in filename
    ]
    if len(earlystop) == 0:
        # If there was no early stop, check what is the last model state
        try:
            last_model_state = sorted([
                int(file_name.split('.pth')[0].split('_')[-1])
                for file_name in os.listdir(fulldirname)
            ])[-1]
        except Exception as e:
            print(f'Warning: {e}')
            last_model_state = 0

        if last_model_state > 0:
            print(
                f'Loading model at epoch {last_model_state}, from {datadirectory}'
            )
            load(model, datadirectory, last_model_state)
            return model, last_model_state
        else:
            print('No pretrained model to load')

            # Initialize some model state measures
            model.loss = []
            model.val_loss = []
            model.time = []
            model.epoch = []

            return model, 0
    elif len(earlystop) == 1:
        filename = earlystop[0]
        print(f'Loading the early stop state {filename}')
        last_model_state = int(filename.split('_')[-2])
        load(model, datadirectory, last_model_state, early_stop=True)
        return model, last_model_state
    else:
        print(
            f'Error: more than one early stop state found. This is not correct. This is the list: {earlystop}'
        )
        sys.exit(0)
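
A sketch of resuming from the most recent checkpoint. The model class and directory are placeholders and the import path is assumed; checking model.early_stop reflects the flag set by load() when an early-stop state is loaded.

from speckcn2.io import load_model_state   # assumed import path

model = ExampleNet()                        # hypothetical model with a .name attribute
model, last_epoch = load_model_state(model, './data')
if getattr(model, 'early_stop', False):
    print('Training already reached the early-stop condition; not resuming.')
elif last_epoch == 0:
    print('No checkpoint found; starting from scratch.')
else:
    print(f'Resuming from epoch {last_epoch}')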

save(model, datadirectory, early_stop=False)

Save the model state and the model itself to a specified directory.

This function saves the model's state dictionary and other relevant information such as epoch, loss, validation loss, and time to a file in the specified directory.

Parameters:

  • model (Module) –

    The model to save

  • datadirectory (str) –

    The directory where the data is stored

  • early_stop (bool, default: False) –

    If True, the model corresponds to the moment when early stop was triggered

Source code in src/speckcn2/io.py
def save(model: torch.nn.Module,
         datadirectory: str,
         early_stop: bool = False) -> None:
    """Save the model state and the model itself to a specified directory.

    This function saves the model's state dictionary and other relevant information
    such as epoch, loss, validation loss, and time to a file in the specified directory.

    Parameters
    ----------
    model : torch.nn.Module
        The model to save
    datadirectory : str
        The directory where the data is stored
    early_stop: bool
        If True, the model corresponds to the moment when early stop was triggered
    """
    model_state = {
        'epoch': model.epoch,
        'loss': model.loss,
        'val_loss': model.val_loss,
        'time': model.time,
        'model_state_dict': model.state_dict(),
    }

    if not early_stop:
        savename = f'{datadirectory}/{model.name}_states/{model.name}_{model.epoch[-1]}.pth'
    else:
        savename = f'{datadirectory}/{model.name}_states/{model.name}_{model.epoch[-1]}_earlystop.pth'

    torch.save(model_state, savename)
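
A hedged checkpointing example after one training epoch. The loss and timing values are placeholders, the import path is assumed, and the {datadirectory}/{model.name}_states directory must already exist (for instance created by an earlier call to load_model_state).

from speckcn2.io import save   # assumed import path

# Append the bookkeeping values that save() stores alongside the state dict.
model.epoch.append(len(model.epoch) + 1)
model.loss.append(0.123)
model.val_loss.append(0.145)
model.time.append(42.0)
save(model, './data')                     # -> ./data/example_net_states/example_net_<epoch>.pth
# save(model, './data', early_stop=True)  # -> ..._<epoch>_earlystop.pth instead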