You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

160 lines
4.9 KiB

import os
import torch
import logging.config
import shutil
import pandas as pd
from bokeh.io import output_file, save, show
from bokeh.plotting import figure
from bokeh.layouts import column
#from bokeh.charts import Line, defaults
#
#defaults.width = 800
#defaults.height = 400
#defaults.tools = 'pan,box_zoom,wheel_zoom,box_select,hover,resize,reset,save'
def setup_logging(log_file='log.txt'):
    """Configure the root logger to log to a file and echo to the console.

    ``logging.basicConfig`` is asked to send DEBUG-and-above records to
    ``log_file`` (opened in 'w' mode) with a timestamp/level prefix. Note that
    ``basicConfig`` does nothing if the root logger already has handlers.
    Independently, a ``StreamHandler`` (stderr) showing only the message text
    for INFO-and-above records is always attached to the root logger.

    Args:
        log_file: path of the log file handed to ``basicConfig``.
    """
    file_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(
        filename=log_file,
        filemode='w',
        level=logging.DEBUG,
        format=file_format,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    echo = logging.StreamHandler()
    echo.setFormatter(logging.Formatter('%(message)s'))
    echo.setLevel(logging.INFO)
    logging.getLogger('').addHandler(echo)
class ResultsLog(object):
    """Accumulate rows of results in a pandas DataFrame and persist them.

    Rows are appended with ``add``. ``save`` writes all rows to ``path`` as CSV
    and, if any bokeh figures were queued with ``image``, renders them into a
    standalone HTML page at ``plot_path``.
    """

    def __init__(self, path='results.csv', plot_path=None):
        """
        Args:
            path: CSV file the results are written to (and read from by ``load``).
            plot_path: HTML file for queued figures; defaults to ``path + '.html'``.
        """
        self.path = path
        self.plot_path = plot_path or (self.path + '.html')
        self.figures = []    # bokeh figures waiting to be rendered by save()/show()
        self.results = None  # DataFrame of all rows added so far, or None if empty

    def add(self, **kwargs):
        """Append one row; each keyword becomes a column (new columns are unioned in)."""
        df = pd.DataFrame([list(kwargs.values())], columns=list(kwargs.keys()))
        if self.results is None:
            self.results = df
        else:
            # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
            self.results = pd.concat([self.results, df], ignore_index=True)

    def save(self, title='Training Results'):
        """Render queued figures to ``plot_path`` (then clear them) and write the CSV.

        Args:
            title: HTML page title used for the figure output.
        """
        if self.figures:
            if os.path.isfile(self.plot_path):
                os.remove(self.plot_path)
            output_file(self.plot_path, title=title)
            # ``save`` here is the module-level bokeh function, not this method.
            save(column(*self.figures))
            self.figures = []
        # Nothing has been recorded yet: there is no table to write.
        if self.results is not None:
            self.results.to_csv(self.path, index=False, index_label=False)

    def load(self, path=None):
        """Replace ``results`` with the CSV at ``path`` (default ``self.path``).

        Does nothing if the file does not exist.
        """
        path = path or self.path
        if os.path.isfile(path):
            self.results = pd.read_csv(path)

    def show(self):
        """Display the queued figures (stacked in a column) without clearing them."""
        if self.figures:
            show(column(*self.figures))

    def image(self, *kargs, **kwargs):
        """Queue a new bokeh figure drawn with ``figure().image(*kargs, **kwargs)``."""
        fig = figure()
        fig.image(*kargs, **kwargs)
        self.figures.append(fig)
def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False):
    """Write ``state`` to ``path/filename`` with ``torch.save`` and keep optional copies.

    Args:
        state: object to serialize; must contain ``'epoch'`` when ``save_all`` is set.
        is_best: also copy the checkpoint to ``path/model_best.pth.tar``.
        path: destination directory.
        filename: name of the main checkpoint file inside ``path``.
        save_all: also copy to ``path/checkpoint_epoch_<state['epoch']>.pth.tar``.
    """
    target = os.path.join(path, filename)
    torch.save(state, target)

    def _copy_to(name):
        # Duplicate the freshly written checkpoint under another name in ``path``.
        shutil.copyfile(target, os.path.join(path, name))

    if is_best:
        _copy_to('model_best.pth.tar')
    if save_all:
        _copy_to('checkpoint_epoch_%s.pth.tar' % state['epoch'])
class AverageMeter(object):
    """Track the most recent value and a weighted running mean.

    Attributes:
        val: value passed to the latest ``update`` call.
        sum: running total of ``val * n`` over all updates.
        count: running total of the weights ``n``.
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh ``avg``."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
__optimizers = {
    'SGD': torch.optim.SGD,
    'ASGD': torch.optim.ASGD,
    'Adam': torch.optim.Adam,
    'Adamax': torch.optim.Adamax,
    'Adagrad': torch.optim.Adagrad,
    'Adadelta': torch.optim.Adadelta,
    'Rprop': torch.optim.Rprop,
    'RMSprop': torch.optim.RMSprop
}


def adjust_optimizer(optimizer, epoch, config):
    """Reconfigure ``optimizer`` according to ``epoch`` and ``config``.

    ``config`` is either:
      * a callable ``config(epoch) -> setting``, whose result is applied once; or
      * a dict mapping epoch numbers to settings. Every entry whose key lies in
        ``0..epoch`` is applied in increasing epoch order, so settings are
        "sticky" and later epochs override earlier ones.

    A setting is a dict. Its optional ``'optimizer'`` key names an entry of
    ``__optimizers``. The optimizer is rebuilt from its current param groups
    only when it is not already an instance of that class, so repeated calls
    (the sticky replay above) keep its internal state. Every other key
    overwrites the matching entry in each param group that already has that key.

    Returns:
        The optimizer to use from now on. This may be a new object, so callers
        must keep the return value.
    """
    def modify_optimizer(optimizer, setting):
        if 'optimizer' in setting:
            optim_cls = __optimizers[setting['optimizer']]
            # Rebuilding discards per-parameter state (momentum buffers, Adam
            # moments), so only switch when the class actually differs.
            if not isinstance(optimizer, optim_cls):
                optimizer = optim_cls(optimizer.param_groups)
                logging.debug('OPTIMIZER - setting method = %s',
                              setting['optimizer'])
        for param_group in optimizer.param_groups:
            for key in param_group.keys():
                if key in setting:
                    logging.debug('OPTIMIZER - setting %s = %s',
                                  key, setting[key])
                    param_group[key] = setting[key]
        return optimizer

    if callable(config):
        optimizer = modify_optimizer(optimizer, config(epoch))
    else:
        for e in range(epoch + 1):  # run over all epochs - sticky setting
            if e in config:
                optimizer = modify_optimizer(optimizer, config[e])
    return optimizer
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: 2-D scores with one row per sample; top-k is taken over dim 1.
        target: true class index per sample; ``target.size(0)`` is the batch size.
        topk: tuple of k values to evaluate.

    Returns:
        List with one 0-dim float tensor per k, each ``100 * hits / batch_size``.

    Raises:
        ZeroDivisionError: if ``target`` is empty (batch size 0).
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # After the transpose, row i of ``pred`` holds every sample's i-th best guess.
    _, pred = output.float().topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` inherits the transposed (non-contiguous) layout of ``pred``,
        # so ``.view(-1)`` fails on recent PyTorch; ``reshape`` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
# kernel_img = model.features[0][0].kernel.data.clone()
# kernel_img.add_(-kernel_img.min())
# kernel_img.mul_(255 / kernel_img.max())
# save_image(kernel_img, 'kernel%s.jpg' % epoch)