import os
|
|
import torch
|
|
import logging.config
|
|
import shutil
|
|
import pandas as pd
|
|
from bokeh.io import output_file, save, show
|
|
from bokeh.plotting import figure
|
|
from bokeh.layouts import column
|
|
#from bokeh.charts import Line, defaults
|
|
#
|
|
#defaults.width = 800
|
|
#defaults.height = 400
|
|
#defaults.tools = 'pan,box_zoom,wheel_zoom,box_select,hover,resize,reset,save'
|
|
|
|
|
|
def setup_logging(log_file='log.txt'):
    """Configure root logging for a run.

    Records at DEBUG and above go to ``log_file`` (opened with mode ``'w'``,
    so earlier contents are overwritten) with a timestamp and level prefix.
    An extra console handler is attached to the root logger that shows
    INFO and above, printing only the message text.

    Args:
        log_file: path of the log file to (re)create.
    """
    file_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(
        level=logging.DEBUG,
        format=file_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        filename=log_file,
        filemode='w',
    )

    # Terse console echo: message text only, INFO and above.
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter('%(message)s'))
    root_logger = logging.getLogger('')
    root_logger.addHandler(stream_handler)
|
|
|
|
|
|
class ResultsLog(object):
    """Collect results row by row in a pandas DataFrame and persist them.

    Rows are written to a CSV file. Bokeh figures can also be queued and
    rendered into a single HTML page.
    """

    def __init__(self, path='results.csv', plot_path=None):
        # path: CSV file written by save() and read by load().
        # plot_path: HTML file for queued figures; defaults to "<path>.html".
        self.path = path
        self.plot_path = plot_path or (self.path + '.html')
        self.figures = []
        self.results = None  # DataFrame once the first row is added/loaded

    def add(self, **kwargs):
        """Append one row; keyword names become column names."""
        df = pd.DataFrame([list(kwargs.values())], columns=list(kwargs.keys()))
        if self.results is None:
            self.results = df
        else:
            # DataFrame.append was removed in pandas 2.0; concat replaces it.
            self.results = pd.concat([self.results, df], ignore_index=True)

    def save(self, title='Training Results'):
        """Write queued figures to plot_path and the results to path as CSV.

        Any existing file at plot_path is deleted first. Queued figures
        are cleared after they are written. The CSV is only written once
        at least one row exists.
        """
        if len(self.figures) > 0:
            if os.path.isfile(self.plot_path):
                os.remove(self.plot_path)
            output_file(self.plot_path, title=title)
            plot = column(*self.figures)
            # Resolves to the module-level bokeh ``save``, not this method.
            save(plot)
            self.figures = []
        if self.results is not None:
            self.results.to_csv(self.path, index=False, index_label=False)

    def load(self, path=None):
        """Replace the results with the CSV at ``path`` (default: self.path).

        Does nothing if the file does not exist.
        """
        path = path or self.path
        if os.path.isfile(path):
            self.results = pd.read_csv(path)

    def show(self):
        """Display the queued figures stacked in a column, if any."""
        if len(self.figures) > 0:
            plot = column(*self.figures)
            # Resolves to the module-level bokeh ``show``, not this method.
            show(plot)

    # NOTE: plot() is disabled along with the commented-out bokeh.charts
    # import at the top of the file.
    #def plot(self, *kargs, **kwargs):
    #    line = Line(data=self.results, *kargs, **kwargs)
    #    self.figures.append(line)

    def image(self, *kargs, **kwargs):
        """Queue a new bokeh figure containing ``fig.image(*kargs, **kwargs)``."""
        fig = figure()
        fig.image(*kargs, **kwargs)
        self.figures.append(fig)
|
|
|
|
|
|
def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False):
    """Serialize ``state`` with ``torch.save`` and optionally keep copies.

    Args:
        state: object handed to ``torch.save``. It must support
            ``state['epoch']`` when ``save_all`` is true.
        is_best: also copy the checkpoint to ``model_best.pth.tar``.
        path: directory for the checkpoint and its copies.
        filename: name of the primary checkpoint file inside ``path``.
        save_all: also keep ``checkpoint_epoch_<epoch>.pth.tar``.
    """
    checkpoint_file = os.path.join(path, filename)
    torch.save(state, checkpoint_file)

    def copy_as(name):
        # Duplicate the freshly written checkpoint under another name.
        shutil.copyfile(checkpoint_file, os.path.join(path, name))

    if is_best:
        copy_as('model_best.pth.tar')
    if save_all:
        copy_as('checkpoint_epoch_%s.pth.tar' % state['epoch'])
|
|
|
|
|
|
class AverageMeter(object):
    """Track the latest value together with a weighted running mean.

    Attributes:
        val: value from the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weight ``n`` over all updates.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
|
|
|
|
# Optimizer classes that adjust_optimizer() can select by name through the
# 'optimizer' key of a setting.
__optimizers = dict(
    SGD=torch.optim.SGD,
    ASGD=torch.optim.ASGD,
    Adam=torch.optim.Adam,
    Adamax=torch.optim.Adamax,
    Adagrad=torch.optim.Adagrad,
    Adadelta=torch.optim.Adadelta,
    Rprop=torch.optim.Rprop,
    RMSprop=torch.optim.RMSprop,
)
|
|
|
|
|
|
def adjust_optimizer(optimizer, epoch, config):
    """Reconfigure ``optimizer`` for ``epoch`` according to ``config``.

    ``config`` takes one of two forms:
      * a callable: ``config(epoch)`` returns the setting to apply;
      * a mapping ``{epoch: setting}``. Its entries are "sticky": every key
        in ``range(epoch + 1)`` is applied in increasing order, so later
        entries override earlier ones.

    A setting is a dict. An ``'optimizer'`` entry names a class in the
    module-level ``__optimizers`` table. If the current optimizer is not
    already an instance of that class, a new one is built from the current
    ``param_groups``. Every other entry whose key already exists in a param
    group (e.g. ``'lr'``) overwrites that value in every group.

    Returns:
        The optimizer to use from now on. It is the same object unless a
        different optimizer class was requested.
    """
    def modify_optimizer(optimizer, setting):
        if 'optimizer' in setting:
            optimizer_cls = __optimizers[setting['optimizer']]
            # The sticky loop below re-applies early settings on every call.
            # Rebuilding unconditionally would hand back a freshly constructed
            # optimizer each epoch, discarding its per-parameter state
            # (e.g. momentum buffers).
            if not isinstance(optimizer, optimizer_cls):
                optimizer = optimizer_cls(optimizer.param_groups)
                logging.debug('OPTIMIZER - setting method = %s',
                              setting['optimizer'])
        for param_group in optimizer.param_groups:
            for key in param_group:
                if key in setting:
                    logging.debug('OPTIMIZER - setting %s = %s',
                                  key, setting[key])
                    param_group[key] = setting[key]
        return optimizer

    if callable(config):
        optimizer = modify_optimizer(optimizer, config(epoch))
    else:
        for e in range(epoch + 1):  # run over all epochs - sticky setting
            if e in config:
                optimizer = modify_optimizer(optimizer, config[e])

    return optimizer
|
|
|
|
|
|
def accuracy(output, target, topk=(1,)):
    """Compute precision@k (top-k accuracy, in percent) for each k in ``topk``.

    Args:
        output: 2-D scores; ``topk`` is taken along dim 1, so rows are
            samples and columns are classes.
        target: class index per sample; ``target.size(0)`` is the batch size.
        topk: the k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: ``100 * hits / batch_size``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.float().topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row i is each sample's i-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` follows the transposed, non-contiguous layout of
        # ``pred``, so ``view(-1)`` fails for k > 1. ``reshape`` copies
        # only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
|
|
|
|
# kernel_img = model.features[0][0].kernel.data.clone()
|
|
# kernel_img.add_(-kernel_img.min())
|
|
# kernel_img.mul_(255 / kernel_img.max())
|
|
# save_image(kernel_img, 'kernel%s.jpg' % epoch)
|