import os
import torch
import logging.config
import shutil
import pandas as pd

try:
    from bokeh.io import output_file, save, show
    from bokeh.plotting import figure
    from bokeh.layouts import column
except ImportError:
    # Plotting is optional: logging, checkpointing, meters and accuracy do not
    # need bokeh, so the module stays importable without it.
    output_file = save = show = figure = column = None

# Legacy plotting defaults, kept for reference only (relied on bokeh.charts):
#from bokeh.charts import Line, defaults
#
#defaults.width = 800
#defaults.height = 400
#defaults.tools = 'pan,box_zoom,wheel_zoom,box_select,hover,resize,reset,save'


def _require_bokeh():
    """Raise a clear ImportError when plotting is used without bokeh."""
    if figure is None:
        raise ImportError('bokeh is required for ResultsLog plotting features')


def setup_logging(log_file='log.txt'):
    """Setup logging configuration.

    DEBUG-and-above records go to ``log_file`` (opened in 'w' mode) with a
    timestamped format; INFO-and-above records are additionally sent to a
    ``StreamHandler`` attached to the root logger, showing the message only.
    """
    logging.basicConfig(level=logging.DEBUG,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename=log_file,
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)


class ResultsLog(object):
    """Accumulates result rows in a DataFrame and saves them as CSV.

    Figures queued with :meth:`image` are written to an HTML file by
    :meth:`save` (requires bokeh).
    """

    def __init__(self, path='results.csv', plot_path=None):
        """Store the CSV path and the plot path (default: ``path + '.html'``)."""
        self.path = path
        self.plot_path = plot_path or (self.path + '.html')
        self.figures = []
        self.results = None

    def add(self, **kwargs):
        """Append one row whose columns are the keyword names."""
        df = pd.DataFrame([list(kwargs.values())],
                          columns=list(kwargs.keys()))
        if self.results is None:
            self.results = df
        else:
            # DataFrame.append was removed in pandas 2.0; concat is the
            # equivalent row-wise append.
            self.results = pd.concat([self.results, df], ignore_index=True)

    def save(self, title='Training Results'):
        """Write queued figures to ``plot_path`` and results to ``path``.

        The figure queue is cleared after it is written. The CSV is skipped
        when no row has been added yet.
        """
        if self.figures:
            _require_bokeh()
            if os.path.isfile(self.plot_path):
                os.remove(self.plot_path)
            output_file(self.plot_path, title=title)
            plot = column(*self.figures)
            save(plot)
            self.figures = []
        if self.results is not None:
            self.results.to_csv(self.path, index=False, index_label=False)

    def load(self, path=None):
        """Replace the current results with the CSV at ``path`` if it exists.

        Falls back to ``self.path`` when ``path`` is not given; does nothing
        when the file is missing.
        """
        path = path or self.path
        if os.path.isfile(path):
            self.results = pd.read_csv(path)

    def show(self):
        """Display the queued figures stacked in a column, if any."""
        if self.figures:
            _require_bokeh()
            plot = column(*self.figures)
            show(plot)

    #def plot(self, *kargs, **kwargs):
    #    line = Line(data=self.results, *kargs, **kwargs)
    #    self.figures.append(line)

    def image(self, *kargs, **kwargs):
        """Queue a new figure; arguments are forwarded to ``figure.image``."""
        _require_bokeh()
        fig = figure()
        fig.image(*kargs, **kwargs)
        self.figures.append(fig)


def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar',
                    save_all=False):
    """Save ``state`` with ``torch.save`` and optionally copy the file.

    Args:
        state: object to serialize; must support ``state['epoch']`` when
            ``save_all`` is true.
        is_best: if true, also copy the checkpoint to ``model_best.pth.tar``.
        path: directory in which all files are written.
        filename: name of the main checkpoint file inside ``path``.
        save_all: if true, also copy the checkpoint to
            ``checkpoint_epoch_<epoch>.pth.tar``.
    """
    filename = os.path.join(path, filename)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar'))
    if save_all:
        shutil.copyfile(filename, os.path.join(
            path, 'checkpoint_epoch_%s.pth.tar' % state['epoch']))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Keep the average at 0 while nothing has been counted (e.g. n=0)
        # instead of dividing by zero.
        self.avg = self.sum / self.count if self.count else 0


__optimizers = {
    'SGD': torch.optim.SGD,
    'ASGD': torch.optim.ASGD,
    'Adam': torch.optim.Adam,
    'Adamax': torch.optim.Adamax,
    'Adagrad': torch.optim.Adagrad,
    'Adadelta': torch.optim.Adadelta,
    'Rprop': torch.optim.Rprop,
    'RMSprop': torch.optim.RMSprop
}


def adjust_optimizer(optimizer, epoch, config):
    """Reconfigures the optimizer according to epoch and config dict.

    Args:
        optimizer: the current optimizer.
        epoch: current epoch index.
        config: either a callable mapping ``epoch`` to a settings dict, or a
            container indexed by epoch whose values are settings dicts.

    A settings dict may contain ``'optimizer'`` (a key of ``__optimizers``;
    a new optimizer of that class is built from the current param groups)
    and any param-group key (e.g. ``'lr'``) to overwrite in every group.

    Returns:
        The (possibly newly constructed) optimizer.
    """
    def modify_optimizer(optimizer, setting):
        if 'optimizer' in setting:
            optimizer = __optimizers[setting['optimizer']](
                optimizer.param_groups)
            logging.debug('OPTIMIZER - setting method = %s',
                          setting['optimizer'])
        for param_group in optimizer.param_groups:
            for key in param_group.keys():
                if key in setting:
                    logging.debug('OPTIMIZER - setting %s = %s',
                                  key, setting[key])
                    param_group[key] = setting[key]
        return optimizer

    if callable(config):
        optimizer = modify_optimizer(optimizer, config(epoch))
    else:
        # Replay every entry from epoch 0 up to the current one so earlier
        # settings stay in effect ("sticky").
        # NOTE(review): an 'optimizer' entry is therefore re-applied, and the
        # optimizer rebuilt, on every call - confirm this is intended.
        for e in range(epoch + 1):  # run over all epochs - sticky setting
            if e in config:
                optimizer = modify_optimizer(optimizer, config[e])

    return optimizer


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: score tensor; the top-k is taken along dimension 1.
        target: tensor of true class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one tensor per k, holding the percentage of samples
        whose target is among the k highest-scoring entries.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.float().topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (not view): `correct` comes from a transposed tensor, so
        # its first k rows may not be contiguous in memory.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

    # kernel_img = model.features[0][0].kernel.data.clone()
    # kernel_img.add_(-kernel_img.min())
    # kernel_img.mul_(255 / kernel_img.max())
    # save_image(kernel_img, 'kernel%s.jpg' % epoch)