import torch
import pdb
import torch.nn as nn
import math
from torch.autograd import Variable
from torch.autograd import Function
from decimal import Decimal, ROUND_HALF_UP

import numpy as np


def Binarize(tensor, quant_mode='det'):
    # Binarize to {-1, +1}: deterministic sign() or stochastic rounding.
    if quant_mode == 'det':
        return tensor.sign()
    else:
        return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0, 1).round().mul_(2).add_(-1)


class HingeLoss(nn.Module):
    def __init__(self):
        super(HingeLoss, self).__init__()
        self.margin = 1.0

    def hinge_loss(self, input, target):
        #import pdb; pdb.set_trace()
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        return output.mean()

    def forward(self, input, target):
        return self.hinge_loss(input, target)


class SqrtHingeLossFunction(Function):
    def __init__(self):
        super(SqrtHingeLossFunction, self).__init__()
        self.margin = 1.0

    def forward(self, input, target):
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        self.save_for_backward(input, target)
        loss = output.mul(output).sum(0).sum(1).div(target.numel())
        return loss

    def backward(self, grad_output):
        input, target = self.saved_tensors
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)
        grad_output.mul_(output.ne(0).float())
        grad_output.div_(input.numel())
        return grad_output, grad_output


def Quantize(tensor, quant_mode='det', params=None, numBits=8):
    tensor.clamp_(-2**(numBits-1), 2**(numBits-1))
    if quant_mode == 'det':
        tensor = tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))
    else:
        tensor = tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))
        quant_fixed(tensor, params)  # NOTE: quant_fixed is not defined in this file
    return tensor

#import torch.nn._functions as tnnf


class BinarizeLinear(nn.Linear):
    # Linear layer with binarized weights; the full-precision weights are kept in weight.org.

    def __init__(self, *kargs, **kwargs):
        super(BinarizeLinear, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # if input.size(1) != 784:
        #     input.data = Binarize(input.data)
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        out = nn.functional.linear(input, self.weight)
        if not self.bias is None:
            self.bias.org = self.bias.data.clone()
            out += self.bias.view(1, -1).expand_as(out)
        return out


class BinarizeConv2d(nn.Conv2d):
    # Conv2d with binarized weights; the output is rescaled by kernel-dependent
    # factors and clipped to [-63, 63], mirroring the CIM simulation layers below.

    def __init__(self, *kargs, **kwargs):
        super(BinarizeConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # if input.size(1) != 3:
        #     input.data = Binarize(input.data)
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        #input = torch.round(input)
        #input = input*2-1
        #scale = max(torch.max(input), -torch.min(input)) / 63
        #input = torch.round(input*2 / scale) - 63
        #if scale != 0:
        #    input = torch.round(input / scale)
        #print (torch.max(input))
        #print(input)
        input = torch.round(input)
        #print(input)
        #print (torch.max(input))
        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        #print (torch.min(out), torch.max(out))
        #out = torch.round(out)
        #print (torch.min(out), torch.max(out))
        #print (torch.min(input), torch.max(input))
        #out = torch.round(out / 64 * 36 / 64)
        #print (self.weight.size()[1])
        #if self.weight.size()[1] >= 16 and self.weight.size()[1] <= 24:
        if self.weight.size()[1] >= 4 and self.weight.size()[2] * self.weight.size()[3] == 9:
            out = torch.round(out / 64 * 36 / 64)
        elif self.weight.size()[1] == 1:
            out = torch.round(out * 7 / 64)
        else:
            out = torch.round(out / 64)
        out = out * 4
        out[out > 63] = 63
        out[out < -63] = -63
        #out = out - torch.round(torch.mean(out))
        #out = out*4
        #out[out > 63] = 63
        #out[out < -63] = -63
        #else:
        #    out = torch.round(out * 10 / 64)
        #print (torch.min(out), torch.max(out))
        # if not self.bias is None:
        #     self.bias.org = self.bias.data.clone()
        #     out += self.bias.view(1, -1, 1, 1).expand_as(out)
        return out
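
# Illustrative sketch (not from this repo's training script): the layers above keep a
# full-precision shadow copy of their weights in `weight.org`, which a BNN-style
# training loop typically restores before the optimizer step and clips back into
# [-1, 1] afterwards. `model` and `optimizer` are placeholder names.
def _binarized_step_sketch(model, optimizer):
    # let the optimizer update the real-valued weights, not the binarized copies
    for p in list(model.parameters()):
        if hasattr(p, 'org'):
            p.data.copy_(p.org)
    optimizer.step()
    # store the clipped result back into the shadow copy read by Binarize()
    for p in list(model.parameters()):
        if hasattr(p, 'org'):
            p.org.copy_(p.data.clamp_(-1, 1))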
class IdealCimConv2d(nn.Conv2d):

    def __init__(self, *kargs, **kwargs):
        super(IdealCimConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # if input.size(1) != 3:
        #     input.data = Binarize(input.data)
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        #input = torch.round(input)
        #input = input*2-1
        #scale = max(torch.max(input), -torch.min(input)) / 63
        #input = torch.round(input*2 / scale) - 63
        #if scale != 0:
        #    input = torch.round(input / scale)
        #print (torch.max(input))
        #print(input)
        input = torch.round(input)
        #print(input)
        #print (torch.max(input))
        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        out = out / 64
        out = out * 4
        out[out > 63] = 63
        out[out < -63] = -63
        return out


device = 'cuda:0'

'''
H = [1024, 512]
sim_model = torch.nn.Sequential(
    torch.nn.Linear(36, H[0]),
    torch.nn.Dropout(p=0.5),
    torch.nn.ReLU(),
    torch.nn.Linear(H[0], H[1]),
    torch.nn.Dropout(p=0.5),
    torch.nn.ReLU(),
    torch.nn.Linear(H[-1], 1),
)
sim_model.load_state_dict(torch.load('model_error.ckpt', map_location=torch.device('cuda:0')))
sim_model = sim_model.to(device)
sim_model.eval()
'''
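
# Illustrative sketch (assumed shapes and value ranges, for demonstration only):
# compare the kernel-dependent scaling of BinarizeConv2d against the ideal
# IdealCimConv2d reference, using identical binarized weights.
def _compare_binarize_vs_ideal():
    torch.manual_seed(0)
    x = torch.randint(0, 64, (1, 16, 8, 8)).float()
    conv_bin = BinarizeConv2d(16, 8, kernel_size=3, bias=False)
    conv_ideal = IdealCimConv2d(16, 8, kernel_size=3, bias=False)
    conv_ideal.weight.data.copy_(conv_bin.weight.data)  # share the same weights
    with torch.no_grad():
        diff = (conv_bin(x) - conv_ideal(x)).abs().max()
    print('max |BinarizeConv2d - IdealCimConv2d| =', diff.item())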
class CimSimConv2d(nn.Conv2d):
    # Conv2d with binarized weights whose output is computed through a software model
    # of the CIM macro: the active path applies the Mapping lookup table per channel
    # group; a per-window error-model path is kept below as disabled reference code.

    def __init__(self, *kargs, **kwargs):
        super(CimSimConv2d, self).__init__(*kargs, **kwargs)
        self.device = device

    def forward(self, input):
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        #scale = max(torch.max(input), -torch.min(input)) / 63
        #if scale != 0:
        #    input = torch.round(input / scale)
        #''' random error
        #out = nn.functional.conv2d(input, self.weight, None, self.stride,
        #                           self.padding, self.dilation, self.groups)
        #out = torch.round(out / 64 * 36 / 64)
        #randrange = (self.weight.size()[1] // 4)
        #for _ in range(randrange):
        #    out += torch.randint(-1, 1, out.size(), device=device)
        #out[out>63] = 63
        #out[out<-63] = -63
        #'''
        input = torch.round(input)
        out2 = self.simconv(input, self.weight)
        '''
        if torch.max(out2) < 32:
            out2 = out2 * 2
        if torch.max(out2) < 32:
            out2 = out2 * 2
        if torch.max(out2) < 32:
            out2 = out2 * 2
        '''
        out2 = out2 * 4
        out2[out2 > 63] = 63
        out2[out2 < -63] = -63
        #print (self.weight.data.size())
        #print (torch.max(out2), torch.min(out2))
        #print (torch.max(out-out2), torch.min(out-out2))
        #out = nn.functional.conv2d(input, self.weight, None, self.stride,
        #                           self.padding, self.dilation, self.groups)
        #print(input.size(), self.weight.size(), out.size())
        #if not self.bias is None:
        #    self.bias.org = self.bias.data.clone()
        #    out += self.bias.view(1, -1, 1, 1).expand_as(out)
        return out2

    def simconv(self, input_a, weight):
        #print(input_a.size(), weight.size())
        batch_size = input_a.size()[0]
        out_channel = weight.size()[0]
        out_width = input_a.size()[2] - 2 * (weight.size()[2] // 2)
        out_height = input_a.size()[3] - 2 * (weight.size()[3] // 2)
        simout = torch.zeros(batch_size, out_channel, out_width, out_height,
                             dtype=input_a.dtype).to(device)
        first = True
        #''' Mapping Table
        if weight.size()[2] == 7:
            kernel_group = 1
        else:
            kernel_group = 4
        Digital_input_split = torch.split(input_a, kernel_group, dim=1)
        binary_weight_split = torch.split(weight, kernel_group, dim=1)
        for i in range(len(Digital_input_split)):
            temp_output = nn.functional.conv2d(Digital_input_split[i], binary_weight_split[i], None,
                                               self.stride, self.padding, self.dilation, self.groups)
            #temp_output = torch.round(temp_output / 64 * 36 / 64)
            temp_output = torch.round(temp_output / 64)
            temp_output = Mapping.apply(temp_output)
            simout += temp_output + 2
        #print (torch.max(simout), torch.min(simout))
        #'''
        # Disabled alternative: per-window error model using cim_conv_tmp / sim_model.
        ''' Error model
        for n in range(batch_size):
            for c in range(out_channel):
                w = torch.reshape(weight[c], (-1,)).to(device)
                inputs = []
                for i in range(out_width):
                    for j in range(out_height):
                        input = torch.reshape(input_a[n, :, i: i + weight.size()[2], j: j + weight.size()[3]], (-1,))
                        #print (w.size(), input.size())
                        # simout[n][c][i][j] = sum(w*input)  # TODO
                        simout[n][c][i][j] = self.cim_conv_tmp(input, w)
        #'''
        #print (len(input))
        #print (simout.size())
        # out = nn.functional.conv2d(input_a, weight)
        return simout

    def cim_conv_tmp(self, input, weight):
        # NOTE: only used by the disabled error-model path above; it needs the
        # commented-out `sim_model` MLP to be defined and loaded.
        assert len(input) == len(weight)
        raw_sum = 0
        if len(weight) == 3:
            for i in range((len(input)-1) // 36 + 1):
                data_x = input[i*36:i*36+36] * weight[i*36:i*36+36]
                row = int(Decimal(float(sum(data_x)/64.0)).quantize(0, ROUND_HALF_UP))
                #''' Error model
                if len(data_x) < 36:
                    data_x = torch.cat((data_x, torch.zeros(36 - len(data_x), dtype=data_x.dtype)))
                try:
                    #tensor_x = torch.Tensor(data_x).to(self.device)
                    tensor_x = data_x.to(device)
                except:
                    print(data_x, len(data_x))
                y_pred = sim_model(tensor_x)
                if int(y_pred[0]) > 10:
                    adjust = 10
                elif int(y_pred[0]) < -10:
                    adjust = -10
                else:
                    adjust = int(y_pred[0])
                #print (tensor_x, y_pred)
                raw_sum += (row + adjust + 2)
                #'''
                #if row in self.mappingTable:
                #    row = self.mappingTable[row]
                #raw_sum += row
                #raw_sum += row
        else:
            for i in range((len(input)-1) // 49 + 1):
                data_x = input[i*49:i*49+49] * weight[i*49:i*49+49]
                row = int(Decimal(float(sum(data_x)/64.0)).quantize(0, ROUND_HALF_UP))
                #''' Error model
                if len(data_x) < 49:
                    data_x = torch.cat((data_x, torch.zeros(49 - len(data_x), dtype=data_x.dtype)))
                try:
                    #tensor_x = torch.Tensor(data_x).to(self.device)
                    tensor_x = data_x.to(device)
                except:
                    print(data_x, len(data_x))
                y_pred = sim_model(tensor_x)
                if int(y_pred[0]) > 10:
                    adjust = 10
                elif int(y_pred[0]) < -10:
                    adjust = -10
                else:
                    adjust = int(y_pred[0])
                #print (tensor_x, y_pred)
                raw_sum += (row + adjust + 2)
        #print (raw_sum)
        return raw_sum
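
# Illustrative sketch (assumed shapes; `device` is hard-coded to 'cuda:0' above, so
# this only runs on a CUDA machine): push integer-valued dummy activations through
# the mapping-table CIM simulation.
def _cim_sim_demo():
    x = torch.randint(0, 64, (1, 8, 6, 6)).float().to(device)
    layer = CimSimConv2d(8, 4, kernel_size=3, bias=False).to(device)
    with torch.no_grad():
        y = layer(x)
    print('CimSimConv2d output:', tuple(y.shape), y.min().item(), y.max().item())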
class Mapping(torch.autograd.Function):
    # Element-wise lookup table applied to the rounded MAC outputs; the backward
    # pass is the identity (straight-through estimator).

    @staticmethod
    def forward(ctx, input):
        output = input.clone()
        output[input==-1] = -4
        output[input==-2] = -5
        output[input==-3] = -6
        output[input==-4] = -7
        output[input==-5] = -9
        output[input==-6] = -9
        output[input==-7] = -11
        output[input==-8] = -11
        output[input==-9] = -13
        output[input==-10] = -13
        output[input==-11] = -17
        output[input==-12] = -17
        output[input==-13] = -17
        output[input==-14] = -19
        output[input==-15] = -19
        output[input==-16] = -21
        output[input==-17] = -21
        output[input==-18] = -23
        output[input==-19] = -25
        output[input==-20] = -25
        output[input==-21] = -25
        output[input==-22] = -25
        output[input==-23] = -27
        output[input==-24] = -27
        output[input==-25] = -29
        output[input==-26] = -29
        output[input==-27] = -29
        output[input==-28] = -31
        output[input==-29] = -31
        output[input==-30] = -33
        output[input==-31] = -33
        output[input==-32] = -35
        output[input==-33] = -35
        output[input==-34] = -35
        #output[input==-35] = -35
        output[input==0] = -2
        output[input==1] = -1
        output[input==2] = 1
        output[input==3] = 2
        #output[input==4] = 4
        output[input==5] = 4
        #output[input==6] = 6
        output[input==7] = 8
        #output[input==8] = 8
        output[input==9] = 10
        #output[input==10] = 10
        output[input==11] = 12
        #output[input==12] = 12
        output[input==13] = 16
        output[input==14] = 16
        output[input==15] = 16
        #output[input==16] = 16
        output[input==17] = 18
        output[input==18] = 20
        output[input==19] = 20
        output[input==20] = 24
        output[input==21] = 24
        output[input==22] = 24
        output[input==23] = 26
        output[input==24] = 26
        output[input==25] = 28
        output[input==26] = 28
        output[input==27] = 28
        output[input==28] = 30
        output[input==29] = 30
        output[input==30] = 32
        output[input==31] = 32
        output[input==32] = 34
        output[input==33] = 34
        output[input==34] = 34
        output[input==35] = 34
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
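
# Optional smoke test (illustrative only). The CIM-simulation demo needs a GPU
# because `device` is hard-coded to 'cuda:0'.
if __name__ == '__main__':
    _compare_binarize_vs_ideal()
    if torch.cuda.is_available():
        _cim_sim_demo()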