import argparse
import os
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import models
from torch.autograd import Variable
from data import get_dataset
from preprocess import get_transform
from utils import *
from datetime import datetime
from ast import literal_eval
from torchvision.utils import save_image
from models.binarized_modules import HingeLoss
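
# Discover available architectures: lowercase callables exported by the
# models package (dunder names excluded).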
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ConvNet Training')
parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='/media/hdd/ihubara/BinaryNet.pytorch/results',
                    help='results dir')
parser.add_argument('--save', metavar='SAVE', default='',
                    help='saved folder')
parser.add_argument('--dataset', metavar='DATASET', default='imagenet',
                    help='dataset name or folder')
parser.add_argument('--model', '-a', metavar='MODEL', default='alexnet',
                    choices=model_names,
                    help='model architecture: ' +
                    ' | '.join(model_names) +
                    ' (default: alexnet)')
parser.add_argument('--input_size', type=int, default=None,
                    help='image input size')
parser.add_argument('--model_config', default='',
                    help='additional architecture configuration')
parser.add_argument('--type', default='torch.cuda.FloatTensor',
                    help='type of tensor - e.g. torch.cuda.HalfTensor')
parser.add_argument('--gpus', default='0',
                    help='gpus used for training - e.g. 0,1,3')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--epochs', default=900, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT',
                    help='optimizer function used')
parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
                    help='evaluate model FILE on validation set')
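
# Example invocation (illustrative; the script name is a placeholder and the
# available --model choices depend on the local models package):
#   python <script>.py --model alexnet --dataset cifar10 --gpus 0 \
#       --batch-size 256 --lr 0.1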

torch.cuda.random.manual_seed_all(10)
output_dim = 0


def main():
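    """Parse command-line arguments, build the model, and run training/evaluation."""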
    global args, best_prec1, output_dim
    best_prec1 = 0
    args = parser.parse_args()
    output_dim = {'cifar10': 10, 'cifar100': 100, 'imagenet': 1000}[args.dataset]
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.type:
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None
    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset,
                    'num_classes': output_dim}
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))
    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train': get_transform(args.dataset,
                               input_size=args.input_size, augment=True),
        'eval': get_transform(args.dataset,
                              input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
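    # A training 'regime' maps a starting epoch to optimizer settings; a model
    # may define its own schedule, otherwise the command-line values are used
    # throughout (applied per epoch by adjust_optimizer below).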
    regime = getattr(model, 'regime', {0: {'optimizer': args.optimizer,
                                           'lr': args.lr,
                                           'momentum': args.momentum,
                                           'weight_decay': args.weight_decay}})
    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', HingeLoss)()
    model.type(args.type)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    logging.info('training regime: %s', regime)

    for epoch in range(args.start_epoch, args.epochs):
        optimizer = adjust_optimizer(optimizer, epoch, regime)

        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(
            train_loader, model, criterion, epoch, optimizer)

        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(
            val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'regime': regime
        }, is_best, path=save_path)

        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'
                     .format(epoch + 1, train_loss=train_loss, val_loss=val_loss,
                             train_prec1=train_prec1, val_prec1=val_prec1,
                             train_prec5=train_prec5, val_prec5=val_prec5))

        results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss,
                    train_error1=100 - train_prec1, val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5, val_error5=100 - val_prec5)
        results.plot(x='epoch', y=['train_loss', 'val_loss'],
                     title='Loss', ylabel='loss')
        results.plot(x='epoch', y=['train_error1', 'val_error1'],
                     title='Error@1', ylabel='error %')
        results.plot(x='epoch', y=['train_error5', 'val_error5'],
                     title='Error@5', ylabel='error %')
        results.save()


def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None):
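    """Run one pass of `model` over `data_loader` and return the averaged
    (loss, top-1 precision, top-5 precision); performs SGD updates when
    `training` is True."""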
    if args.gpus and len(args.gpus) > 1:
        model = torch.nn.DataParallel(model, args.gpus)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for i, (inputs, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        if args.gpus is not None:
            target = target.cuda()

        if criterion.__class__.__name__ == 'HingeLoss':
            # hinge loss expects a +/-1 one-hot target matrix
            target = target.unsqueeze(1)
            target_onehot = torch.cuda.FloatTensor(target.size(0), output_dim)
            target_onehot.fill_(-1)
            target_onehot.scatter_(1, target, 1)
            target = target.squeeze()
        else:
            # other criteria take class indices directly
            target_onehot = target
        if not training:
            with torch.no_grad():
                input_var = Variable(inputs.type(args.type))
                target_var = Variable(target_onehot)
                # compute output
                output = model(input_var)
        else:
            input_var = Variable(inputs.type(args.type))
            target_var = Variable(target_onehot)
            # compute output
            output = model(input_var)

        loss = criterion(output, target_var)
        if type(output) is list:
            output = output[0]

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        if training:
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
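            # BinaryNet weight update: gradients were computed through the
            # binarized forward pass, but the step must be applied to the
            # full-precision weights kept in p.org, so restore them first.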
            for p in list(model.parameters()):
                if hasattr(p, 'org'):
                    p.data.copy_(p.org)
            optimizer.step()
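            # clamp the updated full-precision weights to [-1, 1] and store
            # them back in p.org for the next binarization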
            for p in list(model.parameters()):
                if hasattr(p, 'org'):
                    p.org.copy_(p.data.clamp_(-1, 1))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                             epoch, i, len(data_loader),
                             phase='TRAINING' if training else 'EVALUATING',
                             batch_time=batch_time,
                             data_time=data_time, loss=losses, top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg


def train(data_loader, model, criterion, epoch, optimizer):
    # switch to train mode
    model.train()
    return forward(data_loader, model, criterion, epoch,
                   training=True, optimizer=optimizer)


def validate(data_loader, model, criterion, epoch):
    # switch to evaluate mode
    model.eval()
    return forward(data_loader, model, criterion, epoch,
                   training=False, optimizer=None)


if __name__ == '__main__':
    main()