@@ -16,6 +16,25 @@ def Binarize(tensor,quant_mode='det'):
     return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)


+def Ninarize(tensor, quant_number, quant_mode='det'):
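+    # Generalization of Binarize: quantizes to the quant_number+1 odd integer
+    # levels {-quant_number, -quant_number+2, ..., quant_number}; with
+    # quant_number=1 this reduces to the same deterministic ±1 as Binarize.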
+    #return tensor.add(1).mul(quant_number+1).div(2).floor().clamp(0, quant_number).mul(2).add(-quant_number)
+    return tensor.add(quant_number).mul(quant_number+1).div(2*quant_number).floor().clamp(0, quant_number).mul(2).add(-quant_number)
+
+
+LUT = torch.Tensor([-63, -62, -61, -60,
+                    -59, -58, -57, -56, -55, -54, -53, -52, -51, -50,
+                    -49, -48, -47, -46, -45, -44, -43, -42, -41, -40,
+                    -39, -38, -37, -36, -35, -35, -35, -35, -33, -33,
+                    -31, -31, -29, -29, -29, -27, -27, -25, -25, -25,
+                    -25, -23, -21, -21, -19, -19, -17, -17, -17, -13,
+                    -13, -11, -11, -9, -9, -7, -6, -5, -4, -2,
+                    -1, 1, 2, 4, 4, 6, 8, 8, 10,
+                    10, 12, 12, 16, 16, 16, 16, 18, 20, 20,
+                    24, 24, 24, 26, 26, 28, 28, 28, 30, 30,
+                    32, 32, 34, 34, 34, 34, 36, 37, 38, 39,
+                    40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
+                    50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
+                    60, 61, 62, 63]).long()
+LUT_OFFSET = 63
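+# LUT has 127 entries, one per raw value in [-63, 63]; LUT[v + LUT_OFFSET]
+# appears to tabulate the same measured transfer curve that the removed
+# Mapping.forward applied with per-value masked assignments (e.g. -1 -> -4,
+# 0 -> -2, 2 -> 1), so the remapping becomes a single indexed gather.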
+
+
 class HingeLoss(nn.Module):

@@ -195,19 +214,23 @@ class CimSimConv2d(nn.Conv2d):
         super(CimSimConv2d, self).__init__(*kargs, **kwargs)

         self.device = device
+        nn.init.uniform_(self.weight.data, a = -1., b = 1.)
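+        # Uniform init in [-1, 1] spans the full range that Ninarize(., 1)
+        # quantizes over, so both ±1 weight levels are populated from the start.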

     def forward(self, input):
         if not hasattr(self.weight,'org'):
             self.weight.org=self.weight.data.clone()
-        self.weight.data=Binarize(self.weight.org)
+        #print('In:', torch.max(self.weight.org), torch.min(self.weight.org))
+        #self.weight.data=Binarize(self.weight.org)
+        self.weight.data=Ninarize(self.weight.org, 1)
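+        # The full-precision copy lives in weight.org; weight.data is
+        # re-quantized from it on every forward pass. Ninarize(self.weight.org, 1)
+        # keeps the same ±1 levels as the old Binarize call.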
+        #print('out:', torch.max(self.weight.data), torch.min(self.weight.data))

         #scale = max(torch.max(input), -torch.min(input)) / 63
         #if scale != 0:
         #    input = torch.round(input / scale)
         #''' random error
         #out = nn.functional.conv2d(input, self.weight, None, self.stride,
         #                           self.padding, self.dilation, self.groups)
         #out = torch.round(out / 64 * 36 / 64)
         out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                    self.padding, self.dilation, self.groups)
         out = torch.round(out / 64)
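+        # Division by 64 presumably rescales the 64-input analog MAC sum to
+        # the quantized output range (the commented *36/64 variant suggests an
+        # alternative measured gain); this is an inference, not documented here.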
         #randrange = (self.weight.size()[1] // 4)
         #for _ in range(randrange):
         #    out += torch.randint(-1, 1, out.size(), device=device)

@@ -224,11 +247,20 @@ class CimSimConv2d(nn.Conv2d):
         if torch.max(out2) < 32:
             out2 = out2 * 2
         '''
+        #print ('in, weight, out')
+        '''
+        print ('round')
+        #print (torch.max(input), torch.min(input))
+        #print (torch.sum(input), torch.sum(input))
+        #print (torch.max(self.weight), torch.min(self.weight))
+        #print (torch.sum(self.weight), torch.sum(self.weight))
+        print (torch.max(out), torch.min(out))
+        print (torch.max(out2), torch.min(out2))
+        #'''
         out2 = out2 * 4
         out2[out2 > 63] = 63
         out2[out2 < -63] = -63
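+        # Saturate to [-63, 63], the signed range covered by LUT
+        # (indices 0..126 after adding LUT_OFFSET).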
         #print (self.weight.data.size())
         #print (torch.max(out2), torch.min(out2))
         #print (torch.max(out-out2), torch.min(out-out2))
         #out = nn.functional.conv2d(input, self.weight, None, self.stride,
         #                           self.padding, self.dilation, self.groups)

@@ -246,9 +278,11 @@ class CimSimConv2d(nn.Conv2d):
         out_channel = weight.size()[0]
         out_width = input_a.size()[2] - 2 * (weight.size()[2] // 2)
         out_height = input_a.size()[3] - 2 * (weight.size()[3] // 2)
-        simout = torch.zeros(batch_size, out_channel, out_width, out_height, dtype = input_a.dtype).to(device)
+        simout = torch.zeros(batch_size, out_channel, out_width, out_height, dtype = input_a.dtype).to(input_a.device)
         first = True
         #''' Mapping Table
+        global LUT
+        LUT = LUT.to(input_a.device)
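+        # simout is allocated on input_a.device, and the module-level LUT is
+        # moved there once, so the lookup below works on CPU and GPU alike.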
         if weight.size()[2] == 7:
             kernel_group = 1
         else:

@@ -257,9 +291,9 @@ class CimSimConv2d(nn.Conv2d):
         binary_weight_split = torch.split(weight, kernel_group, dim=1)
         for i in range(len(Digital_input_split)):
             temp_output = nn.functional.conv2d(Digital_input_split[i], binary_weight_split[i], None, self.stride, self.padding, self.dilation, self.groups)
+            #temp_output = torch.round(temp_output / 64 * 36 / 64)
             temp_output = torch.round(temp_output / 64)
-            temp_output = Mapping.apply(temp_output)
+            temp_output += LUT_OFFSET
+            temp_output = LUT[temp_output.long()]
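+            # Shifting by LUT_OFFSET maps values in [-63, 63] to indices
+            # 0..126; indexing LUT applies the measured transfer curve in one
+            # vectorized gather, replacing Mapping.apply. Unlike Mapping's
+            # straight-through backward, integer indexing does not propagate
+            # gradients, which matters only if this path must be trained.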
             simout += temp_output + 2
             #print (torch.max(simout), torch.min(simout))
             #'''

@@ -340,84 +374,3 @@ class CimSimConv2d(nn.Conv2d):
             raw_sum += (row + adjust + 2)
             #print (raw_sum)
         return raw_sum
-
-
-class Mapping(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input):
-        output = input.clone()
-
-        output[input==-1] = -4
-        output[input==-2] = -5
-        output[input==-3] = -6
-        output[input==-4] = -7
-        output[input==-5] = -9
-        output[input==-6] = -9
-        output[input==-7] = -11
-        output[input==-8] = -11
-        output[input==-9] = -13
-        output[input==-10] = -13
-        output[input==-11] = -17
-        output[input==-12] = -17
-        output[input==-13] = -17
-        output[input==-14] = -19
-        output[input==-15] = -19
-        output[input==-16] = -21
-        output[input==-17] = -21
-        output[input==-18] = -23
-        output[input==-19] = -25
-        output[input==-20] = -25
-        output[input==-21] = -25
-        output[input==-22] = -25
-        output[input==-23] = -27
-        output[input==-24] = -27
-        output[input==-25] = -29
-        output[input==-26] = -29
-        output[input==-27] = -29
-        output[input==-28] = -31
-        output[input==-29] = -31
-        output[input==-30] = -33
-        output[input==-31] = -33
-        output[input==-32] = -35
-        output[input==-33] = -35
-        output[input==-34] = -35
-        #output[input==-35] = -35
-
-        output[input==0] = -2
-        output[input==1] = -1
-        output[input==2] = 1
-        output[input==3] = 2
-        #output[input==4] = 4
-        output[input==5] = 4
-        #output[input==6] = 6
-        output[input==7] = 8
-        #output[input==8] = 8
-        output[input==9] = 10
-        #output[input==10] = 10
-        output[input==11] = 12
-        #output[input==12] = 12
-        output[input==13] = 16
-        output[input==14] = 16
-        output[input==15] = 16
-        #output[input==16] = 16
-        output[input==17] = 18
-        output[input==18] = 20
-        output[input==19] = 20
-        output[input==20] = 24
-        output[input==21] = 24
-        output[input==22] = 24
-        output[input==23] = 26
-        output[input==24] = 26
-        output[input==25] = 28
-        output[input==26] = 28
-        output[input==27] = 28
-        output[input==28] = 30
-        output[input==29] = 30
-        output[input==30] = 32
-        output[input==31] = 32
-        output[input==32] = 34
-        output[input==33] = 34
-        output[input==34] = 34
-        output[input==35] = 34
-
-        return output
-
-    def backward(ctx, grad_output):
-        return grad_output