import os
import torch
import logging.config
import shutil
import pandas as pd
from bokeh.io import output_file, save, show
from bokeh.plotting import figure
from bokeh.layouts import column
#from bokeh.charts import Line, defaults
#
#defaults.width = 800
#defaults.height = 400
#defaults.tools = 'pan,box_zoom,wheel_zoom,box_select,hover,resize,reset,save'


def setup_logging(log_file='log.txt'):
    """Setup logging configuration
    """
    logging.basicConfig(level=logging.DEBUG,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename=log_file,
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
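# Example (sketch): DEBUG and above go to the log file, INFO and above are
# echoed to the console. The path 'experiment/log.txt' is illustrative, not
# part of this module.
#
#   setup_logging(log_file='experiment/log.txt')
#   logging.info('visible on console and in the log file')
#   logging.debug('written to the log file only')

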
class ResultsLog(object):

    def __init__(self, path='results.csv', plot_path=None):
        self.path = path
        self.plot_path = plot_path or (self.path + '.html')
        self.figures = []
        self.results = None

    def add(self, **kwargs):
        # one row per call; column names come from the keyword arguments
        df = pd.DataFrame([kwargs])
        if self.results is None:
            self.results = df
        else:
            # DataFrame.append was removed in pandas 2.0; concat is equivalent here
            self.results = pd.concat([self.results, df], ignore_index=True)

    def save(self, title='Training Results'):
        if len(self.figures) > 0:
            if os.path.isfile(self.plot_path):
                os.remove(self.plot_path)
            output_file(self.plot_path, title=title)
            plot = column(*self.figures)
            save(plot)
            self.figures = []
        self.results.to_csv(self.path, index=False, index_label=False)

    def load(self, path=None):
        path = path or self.path
        if os.path.isfile(path):
            # read_csv is a pandas function returning a new DataFrame;
            # the original call discarded its result
            self.results = pd.read_csv(path)

    def show(self):
        if len(self.figures) > 0:
            plot = column(*self.figures)
            show(plot)

    #def plot(self, *kargs, **kwargs):
    #    line = Line(data=self.results, *kargs, **kwargs)
    #    self.figures.append(line)

    def image(self, *kargs, **kwargs):
        fig = figure()
        fig.image(*kargs, **kwargs)
        self.figures.append(fig)
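# Example (sketch): one ResultsLog per training run. The field names
# ('epoch', 'train_loss', 'val_loss') are illustrative; add() accepts any
# keyword arguments and turns each call into one CSV row.
#
#   results = ResultsLog(path='results.csv')
#   results.add(epoch=0, train_loss=1.23, val_loss=1.31)
#   results.add(epoch=1, train_loss=0.98, val_loss=1.05)
#   results.save(title='Training Results')

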
def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False):
    """Saves the latest checkpoint and optionally keeps best / per-epoch copies"""
    filename = os.path.join(path, filename)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar'))
    if save_all:
        shutil.copyfile(filename, os.path.join(
            path, 'checkpoint_epoch_%s.pth.tar' % state['epoch']))
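# Example (sketch): the contents of the state dict are up to the caller,
# except that 'epoch' must be present when save_all=True. model, optimizer
# and best_prec1 below are assumed to exist in the training script.
#
#   save_checkpoint({'epoch': epoch,
#                    'state_dict': model.state_dict(),
#                    'optimizer': optimizer.state_dict(),
#                    'best_prec1': best_prec1},
#                   is_best=(val_prec1 > best_prec1),
#                   path='./checkpoints')

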
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
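# Example (sketch): tracking a running average over mini-batches, weighting
# each value by the batch size. loss and inputs are assumed to come from a
# surrounding training loop.
#
#   losses = AverageMeter()
#   losses.update(loss.item(), inputs.size(0))
#   print('batch loss %.4f (running avg %.4f)' % (losses.val, losses.avg))

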
__optimizers = {
    'SGD': torch.optim.SGD,
    'ASGD': torch.optim.ASGD,
    'Adam': torch.optim.Adam,
    'Adamax': torch.optim.Adamax,
    'Adagrad': torch.optim.Adagrad,
    'Adadelta': torch.optim.Adadelta,
    'Rprop': torch.optim.Rprop,
    'RMSprop': torch.optim.RMSprop
}


def adjust_optimizer(optimizer, epoch, config):
    """Reconfigures the optimizer according to epoch and config dict"""
    def modify_optimizer(optimizer, setting):
        if 'optimizer' in setting:
            optimizer = __optimizers[setting['optimizer']](
                optimizer.param_groups)
            logging.debug('OPTIMIZER - setting method = %s' %
                          setting['optimizer'])
        for param_group in optimizer.param_groups:
            for key in param_group.keys():
                if key in setting:
                    logging.debug('OPTIMIZER - setting %s = %s' %
                                  (key, setting[key]))
                    param_group[key] = setting[key]
        return optimizer

    if callable(config):
        optimizer = modify_optimizer(optimizer, config(epoch))
    else:
        for e in range(epoch + 1):  # run over all epochs - sticky setting
            if e in config:
                optimizer = modify_optimizer(optimizer, config[e])
    return optimizer
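# Example (sketch): a learning-rate schedule keyed by epoch. Settings are
# "sticky": the entry for epoch 0 applies until epoch 30 overrides it, and
# so on. The specific values below are illustrative.
#
#   regime = {0: {'optimizer': 'SGD', 'lr': 1e-1, 'momentum': 0.9},
#             30: {'lr': 1e-2},
#             60: {'lr': 1e-3}}
#   optimizer = adjust_optimizer(optimizer, epoch, regime)

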
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.float().topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape rather than view: the sliced tensor is not contiguous
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
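# Example (sketch): top-1 and top-5 precision for a batch of logits. The
# shapes (batch of 8, 10 classes) are illustrative.
#
#   output = torch.randn(8, 10)
#   target = torch.randint(0, 10, (8,))
#   prec1, prec5 = accuracy(output, target, topk=(1, 5))

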
# kernel_img = model.features[0][0].kernel.data.clone()
# kernel_img.add_(-kernel_img.min())
# kernel_img.mul_(255 / kernel_img.max())
# save_image(kernel_img, 'kernel%s.jpg' % epoch)