import os
import shutil
import logging
import torch
import pandas as pd
from bokeh.io import output_file, save, show
from bokeh.plotting import figure
from bokeh.layouts import column

# bokeh.charts (previously used here for Line plots) was removed from bokeh;
# the plot defaults it configured are kept as module constants and used by
# ResultsLog.plot. The old 'resize' tool is dropped - current bokeh has no such tool.
PLOT_WIDTH = 800
PLOT_HEIGHT = 400
PLOT_TOOLS = 'pan,box_zoom,wheel_zoom,box_select,hover,reset,save'


def setup_logging(log_file='log.txt'):
    """Set up logging to a file (DEBUG and above) and to the console (INFO and above)."""
    logging.basicConfig(level=logging.DEBUG,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename=log_file,
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)


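# Typical use (sketch): call once at startup, then use the logging module as
# usual. The 'results/log.txt' path below is only an example.
#
#   setup_logging(os.path.join('results', 'log.txt'))
#   logging.info('training started')           # file and console
#   logging.debug('written to the log file only')

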
class ResultsLog(object):

    def __init__(self, path='results.csv', plot_path=None):
        self.path = path
        self.plot_path = plot_path or (self.path + '.html')
        self.figures = []
        self.results = None

    def add(self, **kwargs):
        """Add one row of results (e.g. epoch, loss, error)."""
        df = pd.DataFrame([kwargs])
        if self.results is None:
            self.results = df
        else:
            # DataFrame.append was removed in pandas 2.0 - use concat instead
            self.results = pd.concat([self.results, df], ignore_index=True)

    def save(self, title='Training Results'):
        if len(self.figures) > 0:
            if os.path.isfile(self.plot_path):
                os.remove(self.plot_path)
            output_file(self.plot_path, title=title)
            plot = column(*self.figures)
            save(plot)
            self.figures = []
        self.results.to_csv(self.path, index=False)

    def load(self, path=None):
        path = path or self.path
        if os.path.isfile(path):
            self.results = pd.read_csv(path)

    def show(self):
        if len(self.figures) > 0:
            plot = column(*self.figures)
            show(plot)

    def plot(self, x, y, title=None):
        """Line plot of result column(s) `y` against column `x`
        (replaces the old bokeh.charts.Line, which no longer exists)."""
        fig = figure(title=title, tools=PLOT_TOOLS,
                     width=PLOT_WIDTH, height=PLOT_HEIGHT, x_axis_label=x)
        for yi in ([y] if isinstance(y, str) else y):
            fig.line(self.results[x], self.results[yi],
                     line_width=2, legend_label=yi)
        self.figures.append(fig)

    def image(self, *kargs, **kwargs):
        fig = figure()
        fig.image(*kargs, **kwargs)
        self.figures.append(fig)


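# Typical use (sketch; paths and column names are only examples):
#
#   results = ResultsLog(path='results/results.csv')
#   for epoch in range(num_epochs):
#       ...train and evaluate...
#       results.add(epoch=epoch, train_loss=train_loss, val_loss=val_loss)
#       results.plot(x='epoch', y=['train_loss', 'val_loss'], title='Loss')
#       results.save()  # writes results.csv and results.csv.html

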
def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False):
    """Save a checkpoint; optionally keep a best-model copy and a per-epoch copy."""
    filename = os.path.join(path, filename)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar'))
    if save_all:
        shutil.copyfile(filename, os.path.join(
            path, 'checkpoint_epoch_%s.pth.tar' % state['epoch']))


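# Typical use (sketch; the keys shown are a common convention - only 'epoch'
# is actually required here, and only when save_all=True):
#
#   save_checkpoint({'epoch': epoch,
#                    'state_dict': model.state_dict(),
#                    'optimizer': optimizer.state_dict(),
#                    'best_prec1': best_prec1},
#                   is_best=val_prec1 > best_prec1,
#                   path='results')

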
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

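# Typical use (sketch): track a running average over an epoch, weighting each
# batch by its size.
#
#   losses = AverageMeter()
#   for inputs, target in loader:
#       loss = criterion(model(inputs), target)
#       losses.update(loss.item(), inputs.size(0))
#   logging.info('average loss: %.4f', losses.avg)
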
# Optimizer constructors selectable by name through `adjust_optimizer`
__optimizers = {
    'SGD': torch.optim.SGD,
    'ASGD': torch.optim.ASGD,
    'Adam': torch.optim.Adam,
    'Adamax': torch.optim.Adamax,
    'Adagrad': torch.optim.Adagrad,
    'Adadelta': torch.optim.Adadelta,
    'Rprop': torch.optim.Rprop,
    'RMSprop': torch.optim.RMSprop
}


def adjust_optimizer(optimizer, epoch, config):
    """Reconfigure the optimizer according to the epoch and a config dict
    (or a callable mapping epoch -> settings)."""
    def modify_optimizer(optimizer, setting):
        if 'optimizer' in setting:
            optimizer = __optimizers[setting['optimizer']](
                optimizer.param_groups)
            logging.debug('OPTIMIZER - setting method = %s' %
                          setting['optimizer'])
        for param_group in optimizer.param_groups:
            for key in param_group.keys():
                if key in setting:
                    logging.debug('OPTIMIZER - setting %s = %s' %
                                  (key, setting[key]))
                    param_group[key] = setting[key]
        return optimizer

    if callable(config):
        optimizer = modify_optimizer(optimizer, config(epoch))
    else:
        for e in range(epoch + 1):  # run over all epochs - settings are sticky
            if e in config:
                optimizer = modify_optimizer(optimizer, config[e])

    return optimizer


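# Typical use (sketch): `config` maps a starting epoch to new settings. Keys
# other than 'optimizer' overwrite matching fields of every param group
# (lr, momentum, weight_decay, ...), and each setting stays in effect until a
# later epoch overrides it. The schedule below is only an example.
#
#   regime = {0: {'optimizer': 'SGD', 'lr': 1e-1, 'momentum': 0.9, 'weight_decay': 1e-4},
#             30: {'lr': 1e-2},
#             60: {'lr': 1e-3, 'weight_decay': 0},
#             90: {'lr': 1e-4}}
#   for epoch in range(num_epochs):
#       optimizer = adjust_optimizer(optimizer, epoch, regime)
#       ...train one epoch...

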
def accuracy(output, target, topk=(1,)):
    """Computes the top-k accuracy (in percent) for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.float().topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (rather than view) copes with the non-contiguous slice
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

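# Typical use (sketch): `output` holds raw class scores of shape
# (batch_size, num_classes) and `target` holds class indices.
#
#   prec1, prec5 = accuracy(output, target, topk=(1, 5))
#   top1.update(prec1.item(), inputs.size(0))   # e.g. with an AverageMeter
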
# kernel_img = model.features[0][0].kernel.data.clone()
# kernel_img.add_(-kernel_img.min())
# kernel_img.mul_(255 / kernel_img.max())
# save_image(kernel_img, 'kernel%s.jpg' % epoch)