1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
class EarlyStopping:
    """Stop training early when the validation loss stops improving.

    Call the instance once per epoch with the current validation loss and the
    model. Every time the loss improves by at least ``delta`` the model is
    checkpointed to ``path`` and the patience counter is reset; otherwise the
    counter is incremented, and once it reaches ``patience`` the
    ``early_stop`` flag is set for the training loop to inspect.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='finish_model.pkl'):
        """
        Args:
            patience (int): How many consecutive non-improving epochs to
                tolerate after the last improvement before requesting a stop.
                Default: 7
            verbose (bool): If True, prints a message for each validation
                loss improvement.
                Default: False
            delta (float): Minimum change in the monitored quantity to
                qualify as an improvement.
                Default: 0
            path (str): File the best model so far is written to.
                Default: 'finish_model.pkl'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive epochs without improvement
        self.best_score = None    # best (negated) validation loss seen so far
        self.early_stop = False   # set to True once patience is exhausted
        # float('inf') instead of np.Inf: the np.Inf alias was removed in
        # NumPy 2.0 and both spell the same Python float.
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stopping state.

        Args:
            val_loss: Validation loss of the current epoch (lower is better).
            model: Model to checkpoint when the loss improves.
        """
        # Negate so that "larger score" means "better".
        score = -val_loss

        # NaN is the only value that differs from itself. Without this guard
        # a NaN loss fails the `<` test below, gets treated as an improvement,
        # overwrites the best checkpoint and prevents stopping from ever
        # triggering afterwards.
        loss_is_nan = bool(val_loss != val_loss)

        no_improvement = loss_is_nan or (
            self.best_score is not None
            and score < self.best_score + self.delta
        )

        if no_improvement:
            self.counter += 1
            # Printed regardless of `verbose`, as in the original behaviour.
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # First valid loss, or a large enough improvement.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save the model when the validation loss decreases.

        Args:
            val_loss: The new best validation loss.
            model: Model to write to ``self.path``.
        """
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        # The whole model object is pickled here, not just its state_dict, so
        # torch.load() on this file returns the module itself.
        torch.save(model, self.path)
        self.val_loss_min = val_loss
# Usage sketch: run the training loop under early stopping, then restore the
# best weights that were checkpointed along the way.
early_stopping = EarlyStopping(patience=20, verbose=True)
for e in range(1, epochs+1):
    # train
    ...
    # validation
    ...
    # Feed this epoch's validation loss to the stopper; it checkpoints the
    # model on improvement and raises `early_stop` once patience runs out.
    early_stopping(val_avr, model)
    if early_stopping.early_stop:
        print('Early stopping')
        break

# Restore the best checkpoint. EarlyStopping.save_checkpoint pickles the whole
# model with torch.save(model, ...), so torch.load returns a module rather
# than a state dict; passing that straight to load_state_dict fails. Accept
# either form and copy the weights into the existing `model` object.
# NOTE(review): unpickling a full module executes pickle code - only load
# trusted files. Recent PyTorch releases default torch.load to
# weights_only=True, which rejects full-module pickles; pass
# weights_only=False there if the file is trusted.
checkpoint = torch.load('finish_model.pkl')
if isinstance(checkpoint, torch.nn.Module):
    checkpoint = checkpoint.state_dict()
model.load_state_dict(checkpoint)
|