import torch
import torch.nn as nn
from torch.autograd import Function
from decimal import Decimal, ROUND_HALF_UP


def Binarize(tensor, quant_mode='det'):
    """Binarize a tensor to {-1, +1}.

    'det' uses the sign function; any other mode uses stochastic binarization.
    """
    if quant_mode == 'det':
        return tensor.sign()
    else:
        # Stochastic binarization: map to [0, 1], add uniform noise in
        # [-0.5, 0.5), round, and map back to {-1, +1}.
        return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0, 1).round().mul_(2).add_(-1)


class HingeLoss(nn.Module):
    """Hinge loss with margin 1, averaged over all elements."""

    def __init__(self):
        super(HingeLoss, self).__init__()
        self.margin = 1.0

    def hinge_loss(self, input, target):
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        return output.mean()

    def forward(self, input, target):
        return self.hinge_loss(input, target)


class SqrtHingeLossFunction(Function):
    """Squared hinge loss (margin 1) as a custom autograd Function."""

    margin = 1.0

    @staticmethod
    def forward(ctx, input, target):
        output = SqrtHingeLossFunction.margin - input.mul(target)
        output[output.le(0)] = 0
        ctx.save_for_backward(input, target)
        # Mean of the squared hinge terms over all elements.
        loss = output.mul(output).sum().div(target.numel())
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        output = SqrtHingeLossFunction.margin - input.mul(target)
        output[output.le(0)] = 0
        # Gradient of (margin - input*target)^2 on the active region.
        grad_input = target.mul(-2).mul(output)
        grad_input.mul_(output.ne(0).float())
        grad_input.div_(input.numel())
        return grad_input, grad_input


def Quantize(tensor, quant_mode='det', params=None, numBits=8):
    """Quantize a tensor onto a fixed-point grid with step 1 / 2**(numBits-1)."""
    tensor.clamp_(-2 ** (numBits - 1), 2 ** (numBits - 1))
    if quant_mode == 'det':
        tensor = tensor.mul(2 ** (numBits - 1)).round().div(2 ** (numBits - 1))
    else:
        # Stochastic rounding: add uniform noise in [-0.5, 0.5) before rounding.
        tensor = tensor.mul(2 ** (numBits - 1)).add(torch.rand(tensor.size()).add(-0.5)).round().div(2 ** (numBits - 1))
        # quant_fixed(tensor, params)  # quant_fixed is not defined in this module
    return tensor


class BinarizeLinear(nn.Linear):
    """nn.Linear whose weights are binarized to {-1, +1} in the forward pass.

    The full-precision weights are kept in `weight.org` so the optimizer can
    keep updating them between forward passes.
    """

    def __init__(self, *kargs, **kwargs):
        super(BinarizeLinear, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # if input.size(1) != 784:
        #     input.data = Binarize(input.data)
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        out = nn.functional.linear(input, self.weight)
        if self.bias is not None:
            self.bias.org = self.bias.data.clone()
            out += self.bias.view(1, -1).expand_as(out)
        return out


class BinarizeConv2d(nn.Conv2d):
    """nn.Conv2d whose weights are binarized to {-1, +1} in the forward pass,
    with the outputs rescaled and clipped to [-63, 63] to model the
    compute-in-memory (CIM) macro."""

    def __init__(self, *kargs, **kwargs):
        super(BinarizeConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # if input.size(1) != 3:
        #     input.data = Binarize(input.data)

        # Keep the full-precision weights in weight.org and binarize a copy.
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)

        # The macro operates on integer activations.
        input = torch.round(input)
        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)

        # Rescale the accumulated sums with a layer-shape-dependent factor:
        # 3x3 kernels with at least 4 input channels, single-input-channel
        # layers, and all remaining layers use different scales.
        if self.weight.size()[1] >= 4 and self.weight.size()[2] * self.weight.size()[3] == 9:
            out = torch.round(out / 64 * 36 / 64)
        elif self.weight.size()[1] == 1:
            out = torch.round(out * 7 / 64)
        else:
            out = torch.round(out / 64)
        out = out * 4
        # Clip to the macro's output range.
        out[out > 63] = 63
        out[out < -63] = -63

        # The bias term is intentionally not applied in this CIM model.
        return out


class IdealCimConv2d(nn.Conv2d):
    """Binary-weight convolution with ideal (error-free) CIM output scaling."""

    def __init__(self, *kargs, **kwargs):
        super(IdealCimConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)

        input = torch.round(input)
        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        # Ideal accumulator: scale by 1/64, rescale, and clip to [-63, 63].
        out = out / 64
        out = out * 4
        out[out > 63] = 63
        out[out < -63] = -63
        return out


# Device used by the CIM simulation layers; falls back to CPU when CUDA is
# unavailable.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Pre-trained error model for the CIM macro (disabled by default). Uncomment
# to enable the error-model path in CimSimConv2d.cim_conv_tmp below.
'''
H = [1024, 512]
sim_model = torch.nn.Sequential(
    torch.nn.Linear(36, H[0]),
    torch.nn.Dropout(p=0.5),
    torch.nn.ReLU(),
    torch.nn.Linear(H[0], H[1]),
    torch.nn.Dropout(p=0.5),
    torch.nn.ReLU(),
    torch.nn.Linear(H[-1], 1),
)
sim_model.load_state_dict(torch.load('model_error.ckpt', map_location=torch.device('cuda:0')))
sim_model = sim_model.to(device)
sim_model.eval()
'''

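
# A minimal sketch (not part of the original module) of how the error-model
# path could be re-enabled: rebuild the MLP described in the commented block
# above and load its checkpoint before using the per-pixel path in
# CimSimConv2d. The helper name `load_error_model` is illustrative, and it
# assumes the 'model_error.ckpt' file referenced above is available.
def load_error_model(ckpt_path='model_error.ckpt', hidden=(1024, 512)):
    model = torch.nn.Sequential(
        torch.nn.Linear(36, hidden[0]),
        torch.nn.Dropout(p=0.5),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden[0], hidden[1]),
        torch.nn.Dropout(p=0.5),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden[1], 1),
    )
    model.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device)))
    return model.to(device).eval()
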
class CimSimConv2d(nn.Conv2d):
    """Binary-weight convolution simulated group-by-group with the measured
    CIM output mapping (see `Mapping` below)."""

    def __init__(self, *kargs, **kwargs):
        super(CimSimConv2d, self).__init__(*kargs, **kwargs)
        self.device = device

    def forward(self, input):
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)

        input = torch.round(input)
        # Simulate the macro instead of calling conv2d directly.
        out2 = self.simconv(input, self.weight)
        out2 = out2 * 4
        out2[out2 > 63] = 63
        out2[out2 < -63] = -63
        return out2

    def simconv(self, input_a, weight):
        """Simulate the convolution one channel group at a time.

        Channels are processed in groups (1 channel for 7x7 kernels, 4
        otherwise); each group's partial sum is scaled by 1/64, rounded, and
        passed through the `Mapping` lookup table. The output-size computation
        assumes stride 1 and no padding.
        """
        batch_size = input_a.size()[0]
        out_channel = weight.size()[0]
        out_width = input_a.size()[2] - 2 * (weight.size()[2] // 2)
        out_height = input_a.size()[3] - 2 * (weight.size()[3] // 2)
        simout = torch.zeros(batch_size, out_channel, out_width, out_height, dtype=input_a.dtype).to(device)

        # Mapping-table model: split the channels into macro-sized groups.
        if weight.size()[2] == 7:
            kernel_group = 1
        else:
            kernel_group = 4
        Digital_input_split = torch.split(input_a, kernel_group, dim=1)
        binary_weight_split = torch.split(weight, kernel_group, dim=1)
        for i in range(len(Digital_input_split)):
            temp_output = nn.functional.conv2d(Digital_input_split[i], binary_weight_split[i], None,
                                               self.stride, self.padding, self.dilation, self.groups)
            temp_output = torch.round(temp_output / 64)
            temp_output = Mapping.apply(temp_output)
            # Each group pass carries a fixed +2 offset.
            simout += temp_output + 2

        ''' Error model (slow, per-pixel): enable together with sim_model above.
        for n in range(batch_size):
            for c in range(out_channel):
                w = torch.reshape(weight[c], (-1,)).to(device)
                for i in range(out_width):
                    for j in range(out_height):
                        input = torch.reshape(input_a[n, :, i: i + weight.size()[2], j: j + weight.size()[3]], (-1,))
                        simout[n][c][i][j] = self.cim_conv_tmp(input, w)
        '''
        return simout

    def cim_conv_tmp(self, input, weight):
        """Simulate one output pixel with the learned error model.

        Only used by the commented-out error-model path in `simconv`; requires
        `sim_model` (see the commented block above) to be defined.
        """
        assert len(input) == len(weight)

        # Macro row width: 36 elements per row (3x3 kernels) or 49 (7x7 kernels).
        row_size = 36 if len(weight) == 3 else 49

        raw_sum = 0
        for i in range((len(input) - 1) // row_size + 1):
            data_x = input[i * row_size:(i + 1) * row_size] * weight[i * row_size:(i + 1) * row_size]

            # Ideal partial sum of one row, rounded half-up after the 1/64 scaling.
            row = int(Decimal(float(sum(data_x) / 64.0)).quantize(0, ROUND_HALF_UP))

            # Error model: pad the row to full width and ask the pre-trained
            # MLP for the expected deviation, clipped to [-10, 10].
            if len(data_x) < row_size:
                data_x = torch.cat((data_x, torch.zeros(row_size - len(data_x), dtype=data_x.dtype)))
            tensor_x = data_x.to(device)
            y_pred = sim_model(tensor_x)
            if int(y_pred[0]) > 10:
                adjust = 10
            elif int(y_pred[0]) < -10:
                adjust = -10
            else:
                adjust = int(y_pred[0])

            # Each row carries the same fixed +2 offset as the mapping path.
            raw_sum += row + adjust + 2
        return raw_sum


class Mapping(torch.autograd.Function):
    """Lookup table applied to the integer partial sums produced in `simconv`.

    `_TABLE` maps each ideal value to the measured CIM output level; values
    without an entry (e.g. 4, 6, 8, 10, 12, 16, -35) pass through unchanged.
    The backward pass is a straight-through estimator.
    """

    # ideal value -> measured value
    _TABLE = {
        -1: -4, -2: -5, -3: -6, -4: -7, -5: -9, -6: -9, -7: -11, -8: -11,
        -9: -13, -10: -13, -11: -17, -12: -17, -13: -17, -14: -19, -15: -19,
        -16: -21, -17: -21, -18: -23, -19: -25, -20: -25, -21: -25, -22: -25,
        -23: -27, -24: -27, -25: -29, -26: -29, -27: -29, -28: -31, -29: -31,
        -30: -33, -31: -33, -32: -35, -33: -35, -34: -35,
        0: -2, 1: -1, 2: 1, 3: 2, 5: 4, 7: 8, 9: 10, 11: 12,
        13: 16, 14: 16, 15: 16, 17: 18, 18: 20, 19: 20,
        20: 24, 21: 24, 22: 24, 23: 26, 24: 26,
        25: 28, 26: 28, 27: 28, 28: 30, 29: 30,
        30: 32, 31: 32, 32: 34, 33: 34, 34: 34, 35: 34,
    }

    @staticmethod
    def forward(ctx, input):
        output = input.clone()
        for src, dst in Mapping._TABLE.items():
            output[input == src] = dst
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged.
        return grad_output
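

if __name__ == '__main__':
    # Illustrative smoke test (not part of the original module): push random
    # data through the binarized and CIM-simulated layers and check the output
    # shapes. The layer sizes below are arbitrary example values.
    x = torch.randn(2, 4, 8, 8).to(device)
    conv = BinarizeConv2d(4, 8, kernel_size=3, bias=False).to(device)
    print('BinarizeConv2d output:', tuple(conv(x).size()))
    sim = CimSimConv2d(4, 8, kernel_size=3, bias=False).to(device)
    print('CimSimConv2d output:', tuple(sim(x).size()))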