import argparse
import os
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import models
from torch.autograd import Variable
from data import get_dataset
from preprocess import get_transform
from utils import *
from datetime import datetime
from ast import literal_eval
from torchvision.utils import save_image

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
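# (model_names mirrors torchvision's convention: every lowercase, non-dunder
# callable exported by the models package is offered as a --model choice.)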

parser = argparse.ArgumentParser(description='PyTorch ConvNet Training')
parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results',
                    help='results dir')
parser.add_argument('--save', metavar='SAVE', default='',
                    help='saved folder')
parser.add_argument('--dataset', metavar='DATASET', default='imagenet',
                    help='dataset name or folder')
parser.add_argument('--model', '-a', metavar='MODEL', default='alexnet',
                    choices=model_names,
                    help='model architecture: ' +
                    ' | '.join(model_names) +
                    ' (default: alexnet)')
parser.add_argument('--input_size', type=int, default=None,
                    help='image input size')
parser.add_argument('--model_config', default='',
                    help='additional architecture configuration')
parser.add_argument('--type', default='torch.cuda.FloatTensor',
                    help='type of tensor, e.g. torch.cuda.HalfTensor')
parser.add_argument('--gpus', default='0',
                    help='gpus used for training, e.g. 0,1,3')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--epochs', default=2500, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT',
                    help='optimizer function used')
parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
                    help='evaluate model FILE on validation set')
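
# Example invocations (a sketch: the file name `main.py` is an assumption, and
# the flag values are illustrative; valid --model names depend on what the
# models package exports):
#
#   python main.py --model alexnet --dataset imagenet --gpus 0,1 -b 256
#   python main.py -e ./results/<run>/model_best.pth.tar --model alexnet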


def main():
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()

    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.type:
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None
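
    # From here on args.gpus is either a list of device ids (CUDA) or None
    # (CPU); forward() wraps the model in DataParallel when several ids are given.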

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train': get_transform(args.dataset,
                               input_size=args.input_size, augment=True),
        'eval': get_transform(args.dataset,
                              input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(model, 'regime', {0: {'optimizer': args.optimizer,
                                           'lr': args.lr,
                                           'momentum': args.momentum,
                                           'weight_decay': args.weight_decay}})
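
    # `regime` maps a starting epoch to optimizer settings. A model can ship
    # its own schedule via a `regime` attribute, e.g. (illustrative values):
    #     {0: {'optimizer': 'SGD', 'lr': 1e-1},
    #      30: {'lr': 1e-2},
    #      60: {'lr': 1e-3}}
    # adjust_optimizer (from utils) is expected to apply the entry whose key
    # matches the current epoch.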

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)
    model.type(args.type)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
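
    # The SGD instance below is only a starting point; adjust_optimizer
    # reconfigures it at each epoch according to the regime above.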
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    logging.info('training regime: %s', regime)

    for epoch in range(args.start_epoch, args.epochs):
        optimizer = adjust_optimizer(optimizer, epoch, regime)

        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(
            train_loader, model, criterion, epoch, optimizer)

        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(
            val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'regime': regime
        }, is_best, path=save_path)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'
                     .format(epoch + 1, train_loss=train_loss, val_loss=val_loss,
                             train_prec1=train_prec1, val_prec1=val_prec1,
                             train_prec5=train_prec5, val_prec5=val_prec5))

        results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss,
                    train_error1=100 - train_prec1, val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5, val_error5=100 - val_prec5)
        # results.plot(x='epoch', y=['train_loss', 'val_loss'],
        #              title='Loss', ylabel='loss')
        # results.plot(x='epoch', y=['train_error1', 'val_error1'],
        #              title='Error@1', ylabel='error %')
        # results.plot(x='epoch', y=['train_error5', 'val_error5'],
        #              title='Error@5', ylabel='error %')
        results.save()


def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None):
    if args.gpus and len(args.gpus) > 1:
        model = torch.nn.DataParallel(model, args.gpus)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for i, (inputs, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        if args.gpus is not None:
            target = target.cuda()

        if not training:
            # evaluation: run the forward pass with autograd disabled
            with torch.no_grad():
                input_var = Variable(inputs.type(args.type))
                target_var = Variable(target)
                # compute output
                output = model(input_var)
        else:
            input_var = Variable(inputs.type(args.type))
            target_var = Variable(target)
            # compute output
            output = model(input_var)

        loss = criterion(output, target_var)
        if type(output) is list:
            output = output[0]

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
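
        # Binarized-weight handling (assumes binarized layers stash their
        # real-valued weights in a `.org` attribute, as in BinaryNet-style
        # models): restore full-precision weights before the optimizer step,
        # then clamp them to [-1, 1] and stash them again, so the next forward
        # pass re-binarizes from the clamped real values.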
        if training:
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            for p in list(model.parameters()):
                if hasattr(p, 'org'):
                    p.data.copy_(p.org)
            optimizer.step()
            for p in list(model.parameters()):
                if hasattr(p, 'org'):
                    p.org.copy_(p.data.clamp_(-1, 1))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                             epoch, i, len(data_loader),
                             phase='TRAINING' if training else 'EVALUATING',
                             batch_time=batch_time,
                             data_time=data_time, loss=losses, top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg


def train(data_loader, model, criterion, epoch, optimizer):
    # switch to train mode
    model.train()
    return forward(data_loader, model, criterion, epoch,
                   training=True, optimizer=optimizer)


def validate(data_loader, model, criterion, epoch):
    # switch to evaluate mode
    model.eval()
    return forward(data_loader, model, criterion, epoch,
                   training=False, optimizer=None)


if __name__ == '__main__':
    main()