import torch
import torch.nn as nn
from torch.autograd import Function
from decimal import Decimal, ROUND_HALF_UP
def Binarize(tensor, quant_mode='det'):
    if quant_mode == 'det':
        return tensor.sign()
    else:
        # Stochastic binarization: the probability of +1 grows with the value.
        return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0, 1).round().mul_(2).add_(-1)
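# Usage sketch: Binarize(torch.tensor([0.3, -0.7])) -> tensor([1., -1.]).
# Note that tensor.sign() maps exact zeros to 0, not to +/-1.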
def Ninarize(tensor, quant_number, quant_mode='det'):
    # Multi-level quantizer to the levels {-n, -n+2, ..., n-2, n} for n = quant_number.
    #return tensor.add(1).mul(quant_number+1).div(2).floor().clamp(0, quant_number).mul(2).add(-quant_number)
    return tensor.add(quant_number).mul(quant_number + 1).div(2 * quant_number).floor().clamp(0, quant_number).mul(2).add(-quant_number)
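# With quant_number=1 this reduces to binary weights: negative entries map to
# -1 and non-negative entries to +1, which is how CimSimConv2d uses it below.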
LUT = torch.Tensor([-63, -62, -61, -60,
                    -59, -58, -57, -56, -55, -54, -53, -52, -51, -50,
                    -49, -48, -47, -46, -45, -44, -43, -42, -41, -40,
                    -39, -38, -37, -36, -35, -35, -35, -35, -33, -33,
                    -31, -31, -29, -29, -29, -27, -27, -25, -25, -25,
                    -25, -23, -21, -21, -19, -19, -17, -17, -17, -13,
                    -13, -11, -11,  -9,  -9,  -7,  -6,  -5,  -4,  -2,
                     -1,   1,   2,   4,   4,   6,   8,   8,  10,
                     10,  12,  12,  16,  16,  16,  16,  18,  20,  20,
                     24,  24,  24,  26,  26,  28,  28,  28,  30,  30,
                     32,  32,  34,  34,  34,  34,  36,  37,  38,  39,
                     40,  41,  42,  43,  44,  45,  46,  47,  48,  49,
                     50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
                     60,  61,  62,  63]).long()
LUT_OFFSET = 63
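# The LUT maps an ideal partial-sum value v in [-63, 63] to the measured macro
# readout via LUT[v + LUT_OFFSET]. For example LUT[0 + 63] == -2, i.e. the
# macro reads a zero sum as -2; the "+ 2" added in simconv() below cancels
# this zero-point offset.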
class HingeLoss(nn.Module):

    def __init__(self):
        super(HingeLoss, self).__init__()
        self.margin = 1.0

    def hinge_loss(self, input, target):
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        return output.mean()

    def forward(self, input, target):
        return self.hinge_loss(input, target)
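# Usage sketch (targets are assumed to be +/-1 encoded):
#   criterion = HingeLoss()
#   loss = criterion(model(x), labels.float() * 2 - 1)   # 0/1 labels -> -1/+1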
class SqrtHingeLossFunction(Function):
    # Squared hinge loss, ported to the static autograd.Function API (the
    # original used the legacy instance-based API, which recent PyTorch
    # versions reject) and with a stray pdb.set_trace() removed.
    margin = 1.0

    @staticmethod
    def forward(ctx, input, target):
        output = SqrtHingeLossFunction.margin - input.mul(target)
        output[output.le(0)] = 0
        ctx.save_for_backward(input, target)
        loss = output.mul(output).sum(0).sum(1).div(target.numel())
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        output = SqrtHingeLossFunction.margin - input.mul(target)
        output[output.le(0)] = 0
        grad_input = grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)
        grad_input.mul_(output.ne(0).float())
        grad_input.div_(input.numel())
        return grad_input, grad_input
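# Usage sketch: static Functions are invoked through .apply, e.g.
#   loss = SqrtHingeLossFunction.apply(model(x), targets)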
def Quantize(tensor, quant_mode='det', params=None, numBits=8):
    tensor.clamp_(-2 ** (numBits - 1), 2 ** (numBits - 1))
    if quant_mode == 'det':
        tensor = tensor.mul(2 ** (numBits - 1)).round().div(2 ** (numBits - 1))
    else:
        tensor = tensor.mul(2 ** (numBits - 1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2 ** (numBits - 1))
        # quant_fixed(tensor, params)  # quant_fixed is not defined anywhere in
        # this file; left disabled so the stochastic path does not raise NameError.
    return tensor
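# Example: with the default numBits=8, values snap to a 1/128 grid, e.g.
# Quantize(torch.tensor([0.123])) -> tensor([0.1250]) since round(0.123*128) == 16.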
class BinarizeLinear(nn.Linear):

    def __init__(self, *kargs, **kwargs):
        super(BinarizeLinear, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # Keep a full-precision copy of the weights in .org; the optimizer
        # updates .org while the forward pass sees only the binarized values.
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        out = nn.functional.linear(input, self.weight)
        if self.bias is not None:
            self.bias.org = self.bias.data.clone()
            out += self.bias.view(1, -1).expand_as(out)
        return out
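# Training sketch (assumption: the standard BinaryNet recipe): step the
# optimizer on the full-precision copies, then re-clip them, e.g.
#   for p in model.parameters():
#       if hasattr(p, 'org'):
#           p.data.copy_(p.org)
#   optimizer.step()
#   for p in model.parameters():
#       if hasattr(p, 'org'):
#           p.org.copy_(p.data.clamp_(-1, 1))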
class BinarizeConv2d(nn.Conv2d):

    def __init__(self, *kargs, **kwargs):
        super(BinarizeConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        input = torch.round(input)
        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        # Per-layer readout scaling: every partial sum is divided by 64, with
        # an extra factor chosen per kernel shape.
        if self.weight.size()[1] >= 4 and self.weight.size()[2] * self.weight.size()[3] == 9:
            # 3x3 kernels with at least 4 input channels (36 products per macro row)
            out = torch.round(out / 64 * 36 / 64)
        elif self.weight.size()[1] == 1:
            # single-input-channel layers
            out = torch.round(out * 7 / 64)
        else:
            out = torch.round(out / 64)
        out = out * 4
        # Saturate to the macro's signed 7-bit output range.
        out[out > 63] = 63
        out[out < -63] = -63
        # Bias is intentionally not applied in this CIM-style layer.
        return out
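# Range check (sketch, assuming activations already clipped to [-63, 63]):
# a 3x3 kernel over 64 binary-weight channels can accumulate up to
# 576 * 63 = 36288, and round(36288 / 64 * 36 / 64) * 4 ~ 1276, so the
# +/-63 clamp above is what keeps worst-case outputs in range.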
class IdealCimConv2d(nn.Conv2d):

    def __init__(self, *kargs, **kwargs):
        super(IdealCimConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        input = torch.round(input)
        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        out = out / 64
        out = out * 4
        out[out > 63] = 63
        out[out < -63] = -63
        return out
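# IdealCimConv2d is the error-free baseline: it applies the same /64 readout,
# x4 gain and +/-63 clamp as the simulated macro but no measured LUT, so
# comparing it against CimSimConv2d isolates the macro-induced error.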
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Error-model MLP (disabled): given the 36 products of one macro row, it
# appears to predict the row's readout error. Enable it together with the
# "Error model" paths in CimSimConv2d.simconv and cim_conv_tmp below.
'''
H = [1024, 512]
sim_model = torch.nn.Sequential(
    torch.nn.Linear(36, H[0]),
    torch.nn.Dropout(p=0.5),
    torch.nn.ReLU(),
    torch.nn.Linear(H[0], H[1]),
    torch.nn.Dropout(p=0.5),
    torch.nn.ReLU(),
    torch.nn.Linear(H[-1], 1),
)
sim_model.load_state_dict(torch.load('model_error.ckpt', map_location=torch.device('cuda:0')))
sim_model = sim_model.to(device)
sim_model.eval()
'''
class CimSimConv2d(nn.Conv2d):

    def __init__(self, *kargs, **kwargs):
        super(CimSimConv2d, self).__init__(*kargs, **kwargs)
        self.device = device
        nn.init.uniform_(self.weight.data, a=-1., b=1.)

    def forward(self, input):
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        # Binarize via the multi-level quantizer with a single level pair.
        self.weight.data = Ninarize(self.weight.org, 1)
        # Exact quantized convolution, kept only as a debugging reference.
        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        out = torch.round(out / 64)
        input = torch.round(input)
        # LUT-based simulation of the CIM macro.
        out2 = self.simconv(input, self.weight)
        out2 = out2 * 4
        out2[out2 > 63] = 63
        out2[out2 < -63] = -63
        return out2
    def simconv(self, input_a, weight):
        batch_size = input_a.size()[0]
        out_channel = weight.size()[0]
        out_width = input_a.size()[2] - 2 * (weight.size()[2] // 2)
        out_height = input_a.size()[3] - 2 * (weight.size()[3] // 2)
        simout = torch.zeros(batch_size, out_channel, out_width, out_height,
                             dtype=input_a.dtype).to(input_a.device)
        # Mapping-table path: split the input channels into macro-sized groups,
        # quantize each partial sum, and push it through the measured LUT.
        global LUT
        LUT = LUT.to(input_a.device)
        if weight.size()[2] == 7:
            kernel_group = 1   # 7x7 kernels: 49 products already fill a macro row
        else:
            kernel_group = 4   # 3x3 kernels: 4 channels * 9 taps = 36 products per row
        Digital_input_split = torch.split(input_a, kernel_group, dim=1)
        binary_weight_split = torch.split(weight, kernel_group, dim=1)
        for i in range(len(Digital_input_split)):
            temp_output = nn.functional.conv2d(Digital_input_split[i], binary_weight_split[i],
                                               None, self.stride, self.padding,
                                               self.dilation, self.groups)
            temp_output = torch.round(temp_output / 64)
            temp_output += LUT_OFFSET
            temp_output = LUT[temp_output.long()]
            simout += temp_output + 2
        ''' Error model (disabled): per-pixel simulation through cim_conv_tmp.
        for n in range(batch_size):
            for c in range(out_channel):
                w = torch.reshape(weight[c], (-1,)).to(device)
                for i in range(out_width):
                    for j in range(out_height):
                        input = torch.reshape(
                            input_a[n, :, i: i + weight.size()[2], j: j + weight.size()[3]], (-1,))
                        simout[n][c][i][j] = self.cim_conv_tmp(input, w)
        '''
        return simout
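    # Worked example: an ideal group sum of 32 indexes LUT[32 + 63] == 34, so
    # the macro is modeled as reading 34 + 2 = 36 after the zero-point correction.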
    def cim_conv_tmp(self, input, weight):
        # Per-pixel error-model path (requires the disabled sim_model above).
        assert len(input) == len(weight)
        # Macro rows hold 36 products for 3x3 kernels and 49 for 7x7 kernels.
        # NOTE: `weight` is the flattened kernel here, so `len(weight) == 3`
        # looks like it was meant to detect 3x3 kernels; kept as in the original.
        chunk = 36 if len(weight) == 3 else 49
        raw_sum = 0
        for i in range((len(input) - 1) // chunk + 1):
            data_x = input[i * chunk:i * chunk + chunk] * weight[i * chunk:i * chunk + chunk]
            # Ideal row readout, rounded half-up via Decimal (torch.round would
            # round halves to even instead).
            row = int(Decimal(float(sum(data_x) / 64.0)).quantize(0, ROUND_HALF_UP))
            # Zero-pad the last partial row to the full macro width.
            if len(data_x) < chunk:
                data_x = torch.cat((data_x, torch.zeros(chunk - len(data_x), dtype=data_x.dtype)))
            tensor_x = data_x.to(device)
            # Predicted readout error, clipped to [-10, 10].
            y_pred = sim_model(tensor_x)
            adjust = max(-10, min(10, int(y_pred[0])))
            raw_sum += (row + adjust + 2)
        return raw_sum
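
if __name__ == "__main__":
    # Smoke test (sketch): compare the ideal macro against the LUT-based
    # simulation on a random 4-channel patch. Shapes and ranges here are
    # illustrative assumptions, not values from the original experiments.
    torch.manual_seed(0)
    x = torch.randint(-63, 64, (1, 4, 8, 8)).float()
    ideal = IdealCimConv2d(4, 8, kernel_size=3, bias=False)
    lut_sim = CimSimConv2d(4, 8, kernel_size=3, bias=False)
    lut_sim.weight.data.copy_(ideal.weight.data)   # share one set of real-valued weights
    with torch.no_grad():
        diff = ideal(x) - lut_sim(x)
    print('max |ideal - simulated| =', diff.abs().max().item())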