- import argparse
- import os
- import time
- import logging
- import torch
- import torch.nn as nn
- import torch.nn.parallel
- import torch.backends.cudnn as cudnn
- import torch.optim
- import torch.utils.data
- import models
- from data import get_dataset
- from preprocess import get_transform
- from utils import *
- from datetime import datetime
- from ast import literal_eval
- from torchvision.utils import save_image
-
-
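- # discover the architectures exported by the local models package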
- model_names = sorted(name for name in models.__dict__
-                      if name.islower() and not name.startswith("__")
-                      and callable(models.__dict__[name]))
-
- parser = argparse.ArgumentParser(description='PyTorch ConvNet Training')
-
- parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results',
-                     help='results dir')
- parser.add_argument('--save', metavar='SAVE', default='',
-                     help='saved folder')
- parser.add_argument('--dataset', metavar='DATASET', default='imagenet',
-                     help='dataset name or folder')
- parser.add_argument('--model', '-a', metavar='MODEL', default='alexnet',
-                     choices=model_names,
-                     help='model architecture: ' +
-                     ' | '.join(model_names) +
-                     ' (default: alexnet)')
- parser.add_argument('--input_size', type=int, default=None,
-                     help='image input size')
- parser.add_argument('--model_config', default='',
-                     help='additional architecture configuration')
- parser.add_argument('--type', default='torch.cuda.FloatTensor',
-                     help='type of tensor - e.g. torch.cuda.HalfTensor')
- parser.add_argument('--gpus', default='0',
-                     help='gpus used for training - e.g. 0,1,3')
- parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
-                     help='number of data loading workers (default: 8)')
- parser.add_argument('--epochs', default=2500, type=int, metavar='N',
-                     help='number of total epochs to run')
- parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
-                     help='manual epoch number (useful on restarts)')
- parser.add_argument('-b', '--batch-size', default=256, type=int,
-                     metavar='N', help='mini-batch size (default: 256)')
- parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT',
-                     help='optimizer function used')
- parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
-                     metavar='LR', help='initial learning rate')
- parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
-                     help='momentum')
- parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
-                     metavar='W', help='weight decay (default: 1e-4)')
- parser.add_argument('--print-freq', '-p', default=10, type=int,
-                     metavar='N', help='print frequency (default: 10)')
- parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                     help='path to latest checkpoint (default: none)')
- parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
-                     help='evaluate model FILE on validation set')
-
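- # example invocation (using the default model/dataset names defined above):
- #   python main.py --model alexnet --dataset imagenet --gpus 0,1 -b 256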
-
- def main():
-     global args, best_prec1
-     best_prec1 = 0
-     args = parser.parse_args()
-
-     if args.evaluate:
-         args.results_dir = '/tmp'
-     if args.save == '':
-         args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
-     save_path = os.path.join(args.results_dir, args.save)
-     if not os.path.exists(save_path):
-         os.makedirs(save_path)
-
-     setup_logging(os.path.join(save_path, 'log.txt'))
-     results_file = os.path.join(save_path, 'results.%s')
-     results = ResultsLog(results_file % 'csv', results_file % 'html')
-
-     logging.info("saving to %s", save_path)
-     logging.debug("run arguments: %s", args)
-
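-     # when running on CUDA tensors, pin the first requested GPU and let
-     # cuDNN benchmark for the fastest convolution algorithms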
-     if 'cuda' in args.type:
-         args.gpus = [int(i) for i in args.gpus.split(',')]
-         torch.cuda.set_device(args.gpus[0])
-         cudnn.benchmark = True
-     else:
-         args.gpus = None
-
-     # create model
-     logging.info("creating model %s", args.model)
-     model = models.__dict__[args.model]
-     model_config = {'input_size': args.input_size, 'dataset': args.dataset}
-
-     if args.model_config != '':
-         model_config = dict(model_config, **literal_eval(args.model_config))
-
-     model = model(**model_config)
-     logging.info("created model with configuration: %s", model_config)
-
-     # optionally resume from a checkpoint
-     if args.evaluate:
-         if not os.path.isfile(args.evaluate):
-             parser.error('invalid checkpoint: {}'.format(args.evaluate))
-         checkpoint = torch.load(args.evaluate)
-         model.load_state_dict(checkpoint['state_dict'])
-         logging.info("loaded checkpoint '%s' (epoch %s)",
-                      args.evaluate, checkpoint['epoch'])
-     elif args.resume:
-         checkpoint_file = args.resume
-         if os.path.isdir(checkpoint_file):
-             results.load(os.path.join(checkpoint_file, 'results.csv'))
-             checkpoint_file = os.path.join(
-                 checkpoint_file, 'model_best.pth.tar')
-         if os.path.isfile(checkpoint_file):
-             logging.info("loading checkpoint '%s'", args.resume)
-             checkpoint = torch.load(checkpoint_file)
-             args.start_epoch = checkpoint['epoch'] - 1
-             best_prec1 = checkpoint['best_prec1']
-             model.load_state_dict(checkpoint['state_dict'])
-             logging.info("loaded checkpoint '%s' (epoch %s)",
-                          checkpoint_file, checkpoint['epoch'])
-         else:
-             logging.error("no checkpoint found at '%s'", args.resume)
-
-     num_parameters = sum(p.nelement() for p in model.parameters())
-     logging.info("number of parameters: %d", num_parameters)
-
-     # Data loading code
-     default_transform = {
-         'train': get_transform(args.dataset,
-                                input_size=args.input_size, augment=True),
-         'eval': get_transform(args.dataset,
-                               input_size=args.input_size, augment=False)
-     }
-     transform = getattr(model, 'input_transform', default_transform)
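-     # a model may define its own 'regime': a dict mapping a starting epoch to
-     # the optimizer settings applied from that epoch on by adjust_optimizer,
-     # e.g. (hypothetical) {0: {'optimizer': 'Adam', 'lr': 5e-3}, 100: {'lr': 1e-3}}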
-     regime = getattr(model, 'regime', {0: {'optimizer': args.optimizer,
-                                            'lr': args.lr,
-                                            'momentum': args.momentum,
-                                            'weight_decay': args.weight_decay}})
-     # define loss function (criterion) and optimizer
-     criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
-     criterion.type(args.type)
-     model.type(args.type)
-
-     val_data = get_dataset(args.dataset, 'val', transform['eval'])
-     val_loader = torch.utils.data.DataLoader(
-         val_data,
-         batch_size=args.batch_size, shuffle=False,
-         num_workers=args.workers, pin_memory=True)
-
-     if args.evaluate:
-         validate(val_loader, model, criterion, 0)
-         return
-
-     train_data = get_dataset(args.dataset, 'train', transform['train'])
-     train_loader = torch.utils.data.DataLoader(
-         train_data,
-         batch_size=args.batch_size, shuffle=True,
-         num_workers=args.workers, pin_memory=True)
-
-     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
-     logging.info('training regime: %s', regime)
-
-     for epoch in range(args.start_epoch, args.epochs):
-         optimizer = adjust_optimizer(optimizer, epoch, regime)
-
-         # train for one epoch
-         train_loss, train_prec1, train_prec5 = train(
-             train_loader, model, criterion, epoch, optimizer)
-
-         # evaluate on validation set
-         val_loss, val_prec1, val_prec5 = validate(
-             val_loader, model, criterion, epoch)
-
-         # remember best prec@1 and save checkpoint
-         is_best = val_prec1 > best_prec1
-         best_prec1 = max(val_prec1, best_prec1)
-
-         save_checkpoint({
-             'epoch': epoch + 1,
-             'model': args.model,
-             'config': args.model_config,
-             'state_dict': model.state_dict(),
-             'best_prec1': best_prec1,
-             'regime': regime
-         }, is_best, path=save_path)
-         logging.info('\n Epoch: {0}\t'
-                      'Training Loss {train_loss:.4f} \t'
-                      'Training Prec@1 {train_prec1:.3f} \t'
-                      'Training Prec@5 {train_prec5:.3f} \t'
-                      'Validation Loss {val_loss:.4f} \t'
-                      'Validation Prec@1 {val_prec1:.3f} \t'
-                      'Validation Prec@5 {val_prec5:.3f} \n'
-                      .format(epoch + 1, train_loss=train_loss, val_loss=val_loss,
-                              train_prec1=train_prec1, val_prec1=val_prec1,
-                              train_prec5=train_prec5, val_prec5=val_prec5))
-
-         results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss,
-                     train_error1=100 - train_prec1, val_error1=100 - val_prec1,
-                     train_error5=100 - train_prec5, val_error5=100 - val_prec5)
-         # results.plot(x='epoch', y=['train_loss', 'val_loss'],
-         #              title='Loss', ylabel='loss')
-         # results.plot(x='epoch', y=['train_error1', 'val_error1'],
-         #              title='Error@1', ylabel='error %')
-         # results.plot(x='epoch', y=['train_error5', 'val_error5'],
-         #              title='Error@5', ylabel='error %')
-         results.save()
-
-
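- # shared train/eval loop: one pass over data_loader, returning the average
- # loss, top-1 and top-5 precision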
- def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None):
-     if args.gpus and len(args.gpus) > 1:
-         model = torch.nn.DataParallel(model, args.gpus)
-     batch_time = AverageMeter()
-     data_time = AverageMeter()
-     losses = AverageMeter()
-     top1 = AverageMeter()
-     top5 = AverageMeter()
-
-     end = time.time()
-     for i, (inputs, target) in enumerate(data_loader):
-         # measure data loading time
-         data_time.update(time.time() - end)
-         if args.gpus is not None:
-             target = target.cuda()
-
-         input_var = inputs.type(args.type)
-         target_var = target
-
-         # compute output (without tracking gradients during evaluation)
-         if not training:
-             with torch.no_grad():
-                 output = model(input_var)
-         else:
-             output = model(input_var)
-
-         loss = criterion(output, target_var)
-         if isinstance(output, list):
-             output = output[0]
-
-         # measure accuracy and record loss
-         prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
-         losses.update(loss.item(), inputs.size(0))
-         top1.update(prec1.item(), inputs.size(0))
-         top5.update(prec5.item(), inputs.size(0))
-
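-         # binarized layers keep a full-precision copy of each weight in p.org;
-         # restore it before optimizer.step() so the update applies to the
-         # real-valued weights, then clamp the result back into [-1, 1]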
-         if training:
-             # compute gradient and do SGD step
-             optimizer.zero_grad()
-             loss.backward()
-             for p in list(model.parameters()):
-                 if hasattr(p, 'org'):
-                     p.data.copy_(p.org)
-             optimizer.step()
-             for p in list(model.parameters()):
-                 if hasattr(p, 'org'):
-                     p.org.copy_(p.data.clamp_(-1, 1))
-
-         # measure elapsed time
-         batch_time.update(time.time() - end)
-         end = time.time()
-
-         if i % args.print_freq == 0:
-             logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
-                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
-                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
-                              epoch, i, len(data_loader),
-                              phase='TRAINING' if training else 'EVALUATING',
-                              batch_time=batch_time,
-                              data_time=data_time, loss=losses, top1=top1, top5=top5))
-
-     return losses.avg, top1.avg, top5.avg
-
-
- def train(data_loader, model, criterion, epoch, optimizer):
-     # switch to train mode
-     model.train()
-     return forward(data_loader, model, criterion, epoch,
-                    training=True, optimizer=optimizer)
-
-
- def validate(data_loader, model, criterion, epoch):
-     # switch to evaluate mode
-     model.eval()
-     return forward(data_loader, model, criterion, epoch,
-                    training=False, optimizer=None)
-
-
- if __name__ == '__main__':
-     main()