Source code for astroNN.nn.callbacks
import csv
import os
import numpy as np
from astroNN.config import keras_import_manager
# NOTE(review): keras_import_manager() presumably returns whichever Keras module astroNN is
# configured to use (e.g. standalone keras or tf.keras) — confirm in astroNN.config
keras = keras_import_manager()
# base class that the callbacks in this module (VirutalCSVLogger, ErrorOnNaN) inherit from
Callback = keras.callbacks.Callback
class VirutalCSVLogger(Callback):
    """
    A modification of keras' CSVLogger that keeps the training history in memory and only
    writes it to disk when :meth:`savefile` is called

    :param filename: filename of the log to be saved on disk
    :type filename: str
    :param separator: separator of fields (a single character)
    :type separator: str
    :param append: whether to append to an existing file instead of overwriting it
    :type append: bool
    :return: callback instance
    :rtype: object
    :History:
        | 2018-Feb-22 - Written - Henry Leung (University of Toronto)
        | 2018-Mar-12 - Update - Henry Leung (University of Toronto)
    """

    def __init__(self, filename='training_history.csv', separator=',', append=False):
        self.sep = separator
        self.filename = filename
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        self.csv_file = None
        self.epoch = []  # epoch numbers, in the order they were reported
        self.history = {}  # metric name -> list of values, aligned index-by-index with self.epoch
        super().__init__()

    def on_epoch_end(self, epoch, logs=None):
        """
        Record the metrics reported for one finished epoch

        A metric that is missing from some epochs' logs is padded with ``None`` so that every
        list in ``self.history`` stays aligned with ``self.epoch``

        :param epoch: epoch number
        :type epoch: int
        :param logs: metric name -> value for this epoch, can be None
        :type logs: Union[NoneType, dict]
        """
        logs = logs or {}
        n_recorded = len(self.epoch)
        self.epoch.append(epoch)
        for k, v in logs.items():
            # a metric seen for the first time gets None for every earlier epoch
            self.history.setdefault(k, [None] * n_recorded).append(v)
        n_total = len(self.epoch)
        for values in self.history.values():
            # a previously seen metric absent from this epoch's logs
            values.extend([None] * (n_total - len(values)))

    def savefile(self, folder_name=None):
        """
        Write the recorded history to disk as a CSV file (one row per recorded epoch)

        :param folder_name: folder to save the file in (created if missing), can be None to save
            relative to the current directory
        :type folder_name: Union[NoneType, str]
        :return: path of the file that was written
        :rtype: str
        """
        # build the path locally so repeated calls do not keep prefixing folder_name
        if folder_name is not None:
            full_path = os.path.normpath(folder_name)
            os.makedirs(full_path, exist_ok=True)
            file_path = os.path.join(full_path, self.filename)
        else:
            file_path = self.filename

        # a header is needed unless we append to a file that already has content
        self.append_header = not (self.append and os.path.exists(file_path)
                                  and os.path.getsize(file_path) > 0)

        self.keys = sorted(self.history.keys())
        mode = 'a' if self.append else 'w'
        # newline='' lets the csv module control line endings (no blank rows on Windows)
        with open(file_path, mode, newline='') as csv_file:
            self.csv_file = csv_file
            self.writer = csv.DictWriter(csv_file, fieldnames=['epoch'] + self.keys,
                                         delimiter=self.sep)
            if self.append_header:
                self.writer.writeheader()
            # index by position, not by epoch number: epochs need not start at 0
            for i, epoch in enumerate(self.epoch):
                row = {'epoch': epoch}
                for k in self.keys:
                    values = self.history[k]
                    row[k] = values[i] if i < len(values) else None
                self.writer.writerow(row)
        return file_path
class ErrorOnNaN(Callback):
    """
    Callback that raises an error as soon as a batch reports a non-finite loss (NaN or +/-inf)

    :return: callback instance
    :rtype: object
    :History: 2018-May-07 - Written - Henry Leung (University of Toronto)
    """

    def __init__(self):
        super().__init__()

    def on_batch_end(self, batch, logs=None):
        # logs may be None; treat that the same as a batch that reported no loss
        current_loss = (logs or {}).get('loss')
        if current_loss is None:
            return
        # np.isfinite is False exactly when the value is NaN or +/-inf
        if not np.isfinite(current_loss):
            # flag the model to stop, then abort immediately with an exception
            self.model.stop_training = True
            raise ValueError(f'Batch {int(batch)}: Invalid loss, terminating training')