You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

309 lines
12 KiB

import argparse
import os
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import models
from torch.autograd import Variable
from data import get_dataset
from preprocess import get_transform
from utils import *
from datetime import datetime
from ast import literal_eval
from torchvision.utils import save_image
# Names accepted by --model: every lowercase, non-dunder, callable attribute
# of the project's ``models`` package, sorted alphabetically.
# NOTE: an unresolved merge conflict used to follow this statement; it is
# resolved here in favour of the incoming side, i.e. without the leftover
# debug ``print(model_names)`` that ran on every import.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
# Command-line interface of the script. Options are registered in the order
# they should appear in ``--help``.
parser = argparse.ArgumentParser(description='PyTorch ConvNet Training')

# Output location, dataset and model selection.
parser.add_argument(
    '--results_dir',
    default='./results',
    metavar='RESULTS_DIR',
    help='results dir',
)
parser.add_argument(
    '--save',
    default='',
    metavar='SAVE',
    help='saved folder',
)
parser.add_argument(
    '--dataset',
    default='imagenet',
    metavar='DATASET',
    help='dataset name or folder',
)
parser.add_argument(
    '--model', '-a',
    default='alexnet',
    metavar='MODEL',
    choices=model_names,
    help='model architecture: {} (default: alexnet)'.format(
        ' | '.join(model_names)),
)
parser.add_argument(
    '--input_size',
    default=None,
    type=int,
    help='image input size',
)
parser.add_argument(
    '--model_config',
    default='',
    help='additional architecture configuration',
)

# Tensor type, devices and data loading.
parser.add_argument(
    '--type',
    default='torch.cuda.FloatTensor',
    help='type of tensor - e.g torch.cuda.HalfTensor',
)
parser.add_argument(
    '--gpus',
    default='0',
    help='gpus used for training - e.g 0,1,3',
)
parser.add_argument(
    '-j', '--workers',
    default=8,
    type=int,
    metavar='N',
    help='number of data loading workers (default: 8)',
)

# Training schedule and optimizer hyper-parameters.
parser.add_argument(
    '--epochs',
    default=2500,
    type=int,
    metavar='N',
    help='number of total epochs to run',
)
parser.add_argument(
    '--start-epoch',
    default=0,
    type=int,
    metavar='N',
    help='manual epoch number (useful on restarts)',
)
parser.add_argument(
    '-b', '--batch-size',
    default=256,
    type=int,
    metavar='N',
    help='mini-batch size (default: 256)',
)
parser.add_argument(
    '--optimizer',
    default='SGD',
    type=str,
    metavar='OPT',
    help='optimizer function used',
)
parser.add_argument(
    '--lr', '--learning_rate',
    default=0.1,
    type=float,
    metavar='LR',
    help='initial learning rate',
)
parser.add_argument(
    '--momentum',
    default=0.9,
    type=float,
    metavar='M',
    help='momentum',
)
parser.add_argument(
    '--weight-decay', '--wd',
    default=1e-4,
    type=float,
    metavar='W',
    help='weight decay (default: 1e-4)',
)

# Logging frequency, checkpoint resume and evaluation-only mode.
parser.add_argument(
    '--print-freq', '-p',
    default=10,
    type=int,
    metavar='N',
    help='print frequency (default: 10)',
)
parser.add_argument(
    '--resume',
    default='',
    type=str,
    metavar='PATH',
    help='path to latest checkpoint (default: none)',
)
parser.add_argument(
    '-e', '--evaluate',
    type=str,
    metavar='FILE',
    help='evaluate model FILE on validation set',
)
def main():
    """Parse the command line, build the model, then evaluate or train it.

    With ``--evaluate FILE`` the checkpoint is loaded, a single validation
    pass is run and the function returns. Otherwise the model is trained
    from ``args.start_epoch`` up to ``args.epochs``; after every epoch a
    checkpoint is saved and a row is appended to the results log.

    Side effects: sets the module globals ``args`` and ``best_prec1``,
    creates the output directory if needed and sets up logging into it.
    """
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()

    if args.evaluate:
        args.results_dir = '/tmp'
    # Test for an empty string by value; an identity check (``is ''``) is
    # implementation-dependent and a SyntaxWarning on recent Pythons.
    if not args.save:
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.type:
        # '--gpus 0,1,3' -> [0, 1, 3]; the first id becomes the current device.
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}

    if args.model_config:
        # --model_config is a Python literal (parsed with literal_eval) whose
        # entries override / extend the defaults above.
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            # A directory was given: reload its results log and use the
            # best checkpoint stored inside it.
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            # Not fatal: training continues from the freshly built model.
            logging.error("no checkpoint found at '%s'", args.resume)

    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train': get_transform(args.dataset,
                               input_size=args.input_size, augment=True),
        'eval': get_transform(args.dataset,
                              input_size=args.input_size, augment=False)
    }
    # A model may supply its own transforms / training regime / criterion as
    # attributes; otherwise fall back to the command-line defaults.
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(model, 'regime', {0: {'optimizer': args.optimizer,
                                           'lr': args.lr,
                                           'momentum': args.momentum,
                                           'weight_decay': args.weight_decay}})
    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)
    model.type(args.type)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # Initial optimizer; adjust_optimizer() is applied to it with the regime
    # at the start of every epoch.
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    logging.info('training regime: %s', regime)

    for epoch in range(args.start_epoch, args.epochs):
        optimizer = adjust_optimizer(optimizer, epoch, regime)

        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(
            train_loader, model, criterion, epoch, optimizer)

        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(
            val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)

        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'regime': regime
        }, is_best, path=save_path)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'
                     .format(epoch + 1, train_loss=train_loss, val_loss=val_loss,
                             train_prec1=train_prec1, val_prec1=val_prec1,
                             train_prec5=train_prec5, val_prec5=val_prec5))

        results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss,
                    train_error1=100 - train_prec1, val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5, val_error5=100 - val_prec5)
        #results.plot(x='epoch', y=['train_loss', 'val_loss'],
        #             title='Loss', ylabel='loss')
        #results.plot(x='epoch', y=['train_error1', 'val_error1'],
        #             title='Error@1', ylabel='error %')
        #results.plot(x='epoch', y=['train_error5', 'val_error5'],
        #             title='Error@5', ylabel='error %')
        results.save()
def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None):
    """Run one full pass over ``data_loader`` and return averaged metrics.

    Args:
        data_loader: iterable yielding ``(inputs, target)`` batches.
        model: network to run; wrapped in ``DataParallel`` for this pass when
            more than one GPU id is configured in ``args.gpus``.
        criterion: loss callable invoked as ``criterion(output, target)``.
        epoch: epoch index, used only in the progress log line.
        training: when True, back-propagate and step ``optimizer`` on every
            batch; when False, run with gradient tracking disabled.
        optimizer: optimizer to step; required when ``training`` is True.

    Returns:
        Tuple ``(loss_avg, top1_avg, top5_avg)`` taken from the running
        meters, each updated per batch with the batch size as its weight.
    """
    if args.gpus and len(args.gpus) > 1:
        model = torch.nn.DataParallel(model, args.gpus)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for i, (inputs, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        if args.gpus is not None:
            target = target.cuda()

        input_var = inputs.type(args.type)
        # Gradient tracking follows the phase. This replaces the removed
        # ``Variable(..., volatile=...)`` mechanism, which is a no-op that
        # only emits a warning on PyTorch >= 0.4.
        with torch.set_grad_enabled(training):
            # compute output
            output = model(input_var)
            # The criterion receives the raw model output, before any list
            # unwrapping below.
            loss = criterion(output, target)

        if isinstance(output, list):
            output = output[0]

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        if training:
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            # Parameters carrying an ``org`` attribute are restored from it
            # before the step, so the update is applied to the ``org`` values
            # (presumably the real-valued weights of binarized layers;
            # confirm against the model code).
            for p in model.parameters():
                if hasattr(p, 'org'):
                    p.data.copy_(p.org)
            optimizer.step()
            # Clamp the updated values to [-1, 1] in place and store them
            # back into ``org`` for the next iteration.
            for p in model.parameters():
                if hasattr(p, 'org'):
                    p.org.copy_(p.data.clamp_(-1, 1))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                             epoch, i, len(data_loader),
                             phase='TRAINING' if training else 'EVALUATING',
                             batch_time=batch_time,
                             data_time=data_time, loss=losses, top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg
def train(data_loader, model, criterion, epoch, optimizer):
    """Put ``model`` in training mode and run one optimizing pass.

    Returns whatever ``forward`` returns for the pass.
    """
    # switch to train mode
    model.train()
    train_stats = forward(data_loader, model, criterion, epoch=epoch,
                          training=True, optimizer=optimizer)
    return train_stats
def validate(data_loader, model, criterion, epoch):
    """Put ``model`` in evaluation mode and run one non-training pass.

    Returns whatever ``forward`` returns for the pass.
    """
    # switch to evaluate mode
    model.eval()
    eval_stats = forward(data_loader, model, criterion, epoch=epoch,
                         training=False, optimizer=None)
    return eval_stats
# Script entry point: run main() only when executed directly, not on import.
if "__main__" == __name__:
    main()