# Source code for astroNN.nn.callbacks

import csv
import os

import numpy as np
from tensorflow import keras as tfk

# Short alias for the Keras base callback class, used as the parent class of the callbacks below
Callback = tfk.callbacks.Callback


class VirutalCSVLogger(Callback):
    """
    A modification of keras' CSVLogger that keeps the training history in memory and does not write
    any file until :meth:`savefile` is called

    :param filename: filename of the log to be saved on disk
    :type filename: str
    :param separator: separator of fields
    :type separator: str
    :param append: whether to append to an existing file instead of overwriting it
    :type append: bool
    :return: callback instance
    :rtype: object
    :History:
        | 2018-Feb-22 - Written - Henry Leung (University of Toronto)
        | 2018-Mar-12 - Update - Henry Leung (University of Toronto)
    """

    def __init__(self, filename="training_history.csv", separator=",", append=False):
        self.sep = separator
        self.filename = filename
        self.append = append
        # the following three are (re)populated by savefile()
        self.writer = None
        self.keys = None
        self.csv_file = None
        # whether the last savefile() call needed to write a header row
        self.append_header = True
        # epoch numbers seen so far, and metric name -> list of values (one per recorded epoch)
        self.epoch = []
        self.history = {}
        super().__init__()

    def on_epoch_end(self, epoch, logs=None):
        """
        Record the epoch number and every value in ``logs`` in memory

        :param epoch: epoch number passed in by the training loop
        :param logs: metric name -> value for this epoch, can be None
        :type logs: Union[NoneType, dict]
        """
        logs = logs or {}
        self.epoch.append(epoch)
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

    def savefile(self, folder_name=None):
        """
        the method to actually save the file to disk

        :param folder_name: foldername, can be None to save to current directory
        :type folder_name: Union[NoneType, str]
        """
        # build the target path locally so that repeated calls (possibly with different folders)
        # always resolve relative to the original filename instead of a previously joined path
        path = self.filename
        if folder_name is not None:
            full_path = os.path.normpath(os.path.join(os.getcwd(), folder_name))
            os.makedirs(full_path, exist_ok=True)
            path = os.path.join(full_path, self.filename)

        # a header is needed for a fresh file; when appending, only if the existing file is empty
        write_header = True
        if self.append and os.path.exists(path):
            with open(path, "r") as f:
                write_header = not f.readline()
        self.append_header = write_header

        class CustomDialect(csv.excel):
            delimiter = self.sep

        self.keys = sorted(self.history.keys())
        # newline="" is required by the csv module so it controls line endings itself
        with open(path, "a" if self.append else "w", newline="") as csv_file:
            self.csv_file = csv_file
            self.writer = csv.DictWriter(
                csv_file, fieldnames=["epoch"] + self.keys, dialect=CustomDialect
            )
            if write_header:
                self.writer.writeheader()
            # use the position, not the epoch number, as the index: epoch numbers need not start
            # at 0 (e.g. resumed training). Assumes every metric was logged at every recorded epoch.
            for i, epoch in enumerate(self.epoch):
                self.writer.writerow(
                    {"epoch": epoch, **{k: self.history[k][i] for k in self.keys}}
                )
class ErrorOnNaN(Callback):
    """
    Callback that stops training and raises an error when the monitored quantity in the batch logs
    is NaN or infinite.

    :param monitor: name of the quantity in the batch logs to check
    :type monitor: str
    :return: callback instance
    :rtype: object
    :History:
        | 2018-May-07 - Written - Henry Leung (University of Toronto)
        | 2021-Apr-22 - Update - Henry Leung (University of Toronto)
    """

    def __init__(self, monitor="loss"):
        super().__init__()
        self.monitor = monitor

    def on_batch_end(self, batch, logs=None):
        """
        Check the monitored value at the end of each batch

        :param batch: index of the batch that just finished
        :param logs: batch logs, can be None
        :type logs: Union[NoneType, dict]
        :raises ValueError: if the monitored value is NaN or +/-inf
        """
        logs = logs or {}
        value = logs.get(self.monitor)
        # quantities missing from the logs are silently ignored
        if value is not None and not np.isfinite(value):
            self.model.stop_training = True
            raise ValueError(
                f"Batch {int(batch)}: Invalid {self.monitor}, terminating training"
            )