import torch
import torch.nn as nn
from torch.autograd import Function

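# Building blocks for binarized / quantized networks in the style of
# BinaryNet (Hubara et al., 2016): sign binarization, hinge losses, and
# Linear / Conv2d layers that binarize their weights on each forward pass.
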
def Binarize(tensor, quant_mode='det'):
    """Binarize a tensor to {-1, +1}.

    'det' rounds with the sign function (note: sign() maps exact zeros
    to 0). Any other mode rounds stochastically: the tensor is rescaled
    to [0, 1], perturbed with uniform noise in [-0.5, 0.5), rounded,
    and mapped back to {-1, +1}.
    """
    if quant_mode == 'det':
        return tensor.sign()
    else:
        return tensor.add(1).div(2).add(torch.rand_like(tensor).add(-0.5)).clamp(0, 1).round().mul(2).add(-1)

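# A quick sketch of both modes (illustrative values, not from the original):
#
#     w = torch.tensor([0.7, -0.2, 0.0])
#     Binarize(w)            # tensor([ 1., -1.,  0.]) -- sign(0) stays 0
#     Binarize(w, 'stoch')   # each entry is -1 or +1, with
#                            # P(+1) = clamp((w + 1) / 2, 0, 1)
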
class HingeLoss(nn.Module):
    """Mean hinge loss, max(0, margin - input * target), for +/-1 targets."""

    def __init__(self):
        super(HingeLoss, self).__init__()
        self.margin = 1.0

    def hinge_loss(self, input, target):
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        return output.mean()

    def forward(self, input, target):
        return self.hinge_loss(input, target)

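# Usage sketch (hypothetical tensors; targets are expected in {-1, +1}):
#
#     criterion = HingeLoss()
#     scores  = torch.tensor([[0.9], [-0.3]])
#     targets = torch.tensor([[1.0], [1.0]])
#     loss = criterion(scores, targets)   # mean of max(0, 1 - s*t) -> 0.7
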
class SqrtHingeLossFunction(Function):
    """Squared hinge loss, mean of max(0, margin - input * target)**2,
    with a hand-written backward pass (margin fixed at 1.0)."""
    margin = 1.0

    @staticmethod
    def forward(ctx, input, target):
        output = SqrtHingeLossFunction.margin - input.mul(target)
        output[output.le(0)] = 0
        ctx.save_for_backward(input, target)
        loss = output.mul(output).sum().div(target.numel())
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        output = SqrtHingeLossFunction.margin - input.mul(target)
        output[output.le(0)] = 0
        # d(loss)/d(input) = -2 * target * max(0, margin - input*target) / numel
        grad_input = target.mul(-2).mul(output).div(input.numel()).mul(grad_output)
        # The +/-1 targets carry no useful gradient.
        return grad_input, None

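# With the static autograd.Function API the loss is invoked via .apply
# (a sketch; `scores` and `targets` as in the HingeLoss example above):
#
#     sqrt_hinge = SqrtHingeLossFunction.apply
#     loss = sqrt_hinge(scores, targets)
#     loss.backward()
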
def Quantize(tensor, quant_mode='det', params=None, numBits=8):
    """Quantize a tensor onto a grid with step 2**-(numBits-1), after
    clamping it to [-2**(numBits-1), 2**(numBits-1)].

    'det' rounds to the nearest grid point; any other mode performs
    stochastic rounding by adding uniform noise in [-0.5, 0.5) before
    rounding. `params` is accepted for API compatibility but unused.
    """
    tensor.clamp_(-2**(numBits-1), 2**(numBits-1))
    if quant_mode == 'det':
        tensor = tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))
    else:
        tensor = tensor.mul(2**(numBits-1)).add(torch.rand_like(tensor).add(-0.5)).round().div(2**(numBits-1))
    return tensor

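# Deterministic example (illustrative): with numBits=8 the grid step is
# 2**-7 = 1/128, so values land on the nearest multiple of 1/128:
#
#     x = torch.tensor([0.300, -0.725])
#     Quantize(x, quant_mode='det', numBits=8)   # tensor([ 0.2969, -0.7266])
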
class BinarizeLinear(nn.Linear):
    """nn.Linear whose weights are binarized to {-1, +1} on every forward
    pass. The full-precision weights are kept alive in `self.weight.org`
    so the optimizer can update them; only the binarized copy is used in
    the matmul."""

    def __init__(self, *kargs, **kwargs):
        super(BinarizeLinear, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # Optional activation binarization (disabled), skipping the
        # 784-wide MNIST input layer:
        # if input.size(1) != 784:
        #     input.data = Binarize(input.data)

        # Stash the full-precision weights once, then binarize from them.
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        out = nn.functional.linear(input, self.weight)
        if self.bias is not None:
            self.bias.org = self.bias.data.clone()
            out += self.bias.view(1, -1).expand_as(out)

        return out

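# Drop-in usage sketch (hypothetical sizes):
#
#     fc = BinarizeLinear(784, 256)
#     y = fc(torch.randn(32, 784))   # matmul uses {-1, +1} weights
#     fc.weight.org                  # full-precision copy, kept for updates
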
class BinarizeConv2d(nn.Conv2d):
    """nn.Conv2d whose weights are binarized to {-1, +1} on every forward
    pass, with the full-precision weights preserved in `self.weight.org`."""

    def __init__(self, *kargs, **kwargs):
        super(BinarizeConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # Optional activation binarization (disabled), skipping the
        # 3-channel RGB input layer:
        # if input.size(1) != 3:
        #     input.data = Binarize(input.data)

        # Stash the full-precision weights once, then binarize from them.
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)

        # Convolve without the bias, then add it explicitly below.
        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)

        if self.bias is not None:
            self.bias.org = self.bias.data.clone()
            out += self.bias.view(1, -1, 1, 1).expand_as(out)

        return out

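# Training-loop pattern these layers assume (a sketch under the usual BNN
# recipe; `model`, `optimizer`, and `loss` are hypothetical): gradients flow
# into the binarized weights, so restore the full-precision copies before
# the update and re-clip them afterwards.
#
#     loss.backward()
#     for p in model.parameters():
#         if hasattr(p, 'org'):
#             p.data.copy_(p.org)
#     optimizer.step()
#     for p in model.parameters():
#         if hasattr(p, 'org'):
#             p.org.copy_(p.data.clamp_(-1, 1))
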
# x = torch.tensor([[255.0, 200.0, 201.0], [210.0, 222.0, 223.0]])
# print(Quantize(x, quant_mode='det', params=None, numBits=8))