You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

376 lines
13 KiB

  1. import torch
  2. import pdb
  3. import torch.nn as nn
  4. import math
  5. from torch.autograd import Variable
  6. from torch.autograd import Function
  7. from decimal import Decimal, ROUND_HALF_UP
  8. import numpy as np
  9. def Binarize(tensor,quant_mode='det'):
  10. if quant_mode=='det':
  11. return tensor.sign()
  12. else:
  13. return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)
  14. def Ninarize(tensor, quant_number, quant_mode='det'):
  15. #return tensor.add(1).mul(quant_number+1).div(2).floor().clamp(0, quant_number).mul(2).add(-quant_number)
  16. return tensor.add(quant_number).mul(quant_number+1).div(2*quant_number).floor().clamp(0, quant_number).mul(2).add(-quant_number)
  17. LUT = torch.Tensor([-63, -62, -61, -60,
  18. -59, -58, -57, -56, -55, -54, -53, -52, -51, -50,
  19. -49, -48, -47, -46, -45, -44, -43, -42, -41, -40,
  20. -39, -38, -37, -36, -35, -35, -35, -35, -33, -33,
  21. -31, -31, -29, -29, -29, -27, -27, -25, -25, -25,
  22. -25, -23, -21, -21, -19, -19, -17, -17, -17, -13,
  23. -13, -11, -11, -9, -9, -7, -6, -5, -4, -2,
  24. -1, 1, 2, 4, 4, 6, 8, 8, 10,
  25. 10, 12, 12, 16, 16, 16, 16, 18, 20, 20,
  26. 24, 24, 24, 26, 26, 28, 28, 28, 30, 30,
  27. 32, 32, 34, 34, 34, 34, 36, 37, 38, 39,
  28. 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
  29. 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
  30. 60, 61, 62, 63]).long()
  31. LUT_OFFSET = 63
  32. class HingeLoss(nn.Module):
  33. def __init__(self):
  34. super(HingeLoss,self).__init__()
  35. self.margin=1.0
  36. def hinge_loss(self,input,target):
  37. #import pdb; pdb.set_trace()
  38. output=self.margin-input.mul(target)
  39. output[output.le(0)]=0
  40. return output.mean()
  41. def forward(self, input, target):
  42. return self.hinge_loss(input,target)
  43. class SqrtHingeLossFunction(Function):
  44. def __init__(self):
  45. super(SqrtHingeLossFunction,self).__init__()
  46. self.margin=1.0
  47. def forward(self, input, target):
  48. output=self.margin-input.mul(target)
  49. output[output.le(0)]=0
  50. self.save_for_backward(input, target)
  51. loss=output.mul(output).sum(0).sum(1).div(target.numel())
  52. return loss
  53. def backward(self,grad_output):
  54. input, target = self.saved_tensors
  55. output=self.margin-input.mul(target)
  56. output[output.le(0)]=0
  57. import pdb; pdb.set_trace()
  58. grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)
  59. grad_output.mul_(output.ne(0).float())
  60. grad_output.div_(input.numel())
  61. return grad_output,grad_output
  62. def Quantize(tensor,quant_mode='det', params=None, numBits=8):
  63. tensor.clamp_(-2**(numBits-1),2**(numBits-1))
  64. if quant_mode=='det':
  65. tensor=tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))
  66. else:
  67. tensor=tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))
  68. quant_fixed(tensor, params)
  69. return tensor
  70. #import torch.nn._functions as tnnf
  71. class BinarizeLinear(nn.Linear):
  72. def __init__(self, *kargs, **kwargs):
  73. super(BinarizeLinear, self).__init__(*kargs, **kwargs)
  74. def forward(self, input):
  75. # if input.size(1) != 784:
  76. # input.data=Binarize(input.data)
  77. if not hasattr(self.weight,'org'):
  78. self.weight.org=self.weight.data.clone()
  79. self.weight.data=Binarize(self.weight.org)
  80. out = nn.functional.linear(input, self.weight)
  81. if not self.bias is None:
  82. self.bias.org=self.bias.data.clone()
  83. out += self.bias.view(1, -1).expand_as(out)
  84. return out
  85. class BinarizeConv2d(nn.Conv2d):
  86. def __init__(self, *kargs, **kwargs):
  87. super(BinarizeConv2d, self).__init__(*kargs, **kwargs)
  88. def forward(self, input):
  89. # if input.size(1) != 3:
  90. # input.data = Binarize(input.data)
  91. if not hasattr(self.weight,'org'):
  92. self.weight.org=self.weight.data.clone()
  93. self.weight.data=Binarize(self.weight.org)
  94. #input = torch.round(input)
  95. #input = input*2-1
  96. #scale = max(torch.max(input), -torch.min(input)) / 63
  97. #input = torch.round(input*2 / scale) - 63
  98. #if scale != 0:
  99. # input = torch.round(input / scale)
  100. #print (torch.max(input))
  101. #print(input)
  102. input = torch.round(input)
  103. #print(input)
  104. #print (torch.max(input))
  105. out = nn.functional.conv2d(input, self.weight, None, self.stride,
  106. self.padding, self.dilation, self.groups)
  107. #print (torch.min(out), torch.max(out))
  108. #out = torch.round(out)
  109. #print (torch.min(out), torch.max(out))
  110. #print (torch.min(input), torch.max(input))
  111. #out = torch.round(out / 64 * 36 / 64)
  112. #print (self.weight.size()[1])
  113. #if self.weight.size()[1] >= 16 and self.weight.size()[1] <= 24:
  114. if self.weight.size()[1] >= 4 and self.weight.size()[2] * self.weight.size()[3] == 9:
  115. out = torch.round(out / 64 * 36 / 64)
  116. elif self.weight.size()[1] == 1:
  117. out = torch.round(out * 7 / 64)
  118. else:
  119. out = torch.round(out / 64)
  120. out = out * 4
  121. out[out > 63] = 63
  122. out[out < -63] = -63
  123. #out = out - torch.round(torch.mean(out))
  124. # out = out*4
  125. #out[out > 63] = 63
  126. #out[out < -63] = -63
  127. #else:
  128. # out = torch.round(out * 10 / 64)
  129. #print (torch.min(out), torch.max(out))
  130. # if not self.bias is None:
  131. # self.bias.org=self.bias.data.clone()
  132. # out += self.bias.view(1, -1, 1, 1).expand_as(out)
  133. return out
  134. class IdealCimConv2d(nn.Conv2d):
  135. def __init__(self, *kargs, **kwargs):
  136. super(IdealCimConv2d, self).__init__(*kargs, **kwargs)
  137. def forward(self, input):
  138. # if input.size(1) != 3:
  139. # input.data = Binarize(input.data)
  140. if not hasattr(self.weight,'org'):
  141. self.weight.org=self.weight.data.clone()
  142. self.weight.data=Binarize(self.weight.org)
  143. #input = torch.round(input)
  144. #input = input*2-1
  145. #scale = max(torch.max(input), -torch.min(input)) / 63
  146. #input = torch.round(input*2 / scale) - 63
  147. #if scale != 0:
  148. # input = torch.round(input / scale)
  149. #print (torch.max(input))
  150. #print(input)
  151. input = torch.round(input)
  152. #print(input)
  153. #print (torch.max(input))
  154. out = nn.functional.conv2d(input, self.weight, None, self.stride,
  155. self.padding, self.dilation, self.groups)
  156. out = out / 64
  157. out = out * 4
  158. out[out > 63] = 63
  159. out[out < -63] = -63
  160. return out
# Device string stored on every CimSimConv2d instance (``self.device``) and
# used by CimSimConv2d.cim_conv_tmp when moving tensors.
device = 'cuda:0'
# The triple-quoted block below is a bare string literal, i.e. disabled code:
# it would build a small MLP named ``sim_model`` and load 'model_error.ckpt'.
# CimSimConv2d.cim_conv_tmp calls ``sim_model``; with this block disabled the
# name is not defined anywhere in the visible part of the file.
'''
H = [1024, 512]
sim_model = torch.nn.Sequential(
torch.nn.Linear(36, H[0]),
torch.nn.Dropout(p=0.5),
torch.nn.ReLU(),
torch.nn.Linear(H[0], H[1]),
torch.nn.Dropout(p=0.5),
torch.nn.ReLU(),
torch.nn.Linear(H[-1], 1),
)
sim_model.load_state_dict(torch.load('model_error.ckpt', map_location=torch.device('cuda:0')))
sim_model = sim_model.to(device)
sim_model.eval()
'''
  177. class CimSimConv2d(nn.Conv2d):
  178. def __init__(self, *kargs, **kwargs):
  179. super(CimSimConv2d, self).__init__(*kargs, **kwargs)
  180. self.device = device
  181. nn.init.uniform_(self.weight.data, a = -1., b = 1.)
  182. def forward(self, input):
  183. if not hasattr(self.weight,'org'):
  184. self.weight.org=self.weight.data.clone()
  185. #print('In:', torch.max(self.weight.org), torch.min(self.weight.org))
  186. #self.weight.data=Binarize(self.weight.org)
  187. self.weight.data=Ninarize(self.weight.org, 1)
  188. #print('out:', torch.max(self.weight.data), torch.min(self.weight.data))
  189. #scale = max(torch.max(input), -torch.min(input)) / 63
  190. #if scale != 0:
  191. # input = torch.round(input / scale)
  192. #''' random error
  193. out = nn.functional.conv2d(input, self.weight, None, self.stride,
  194. self.padding, self.dilation, self.groups)
  195. out = torch.round(out / 64)
  196. #randrange = (self.weight.size()[1] // 4)
  197. #for _ in range(randrange):
  198. # out += torch.randint(-1, 1, out.size(), device=device)
  199. #out[out>63] = 63
  200. #out[out<-63] -63
  201. #'''
  202. input = torch.round(input)
  203. out2 = self.simconv(input, self.weight)
  204. '''
  205. if torch.max(out2) < 32:
  206. out2 = out2 * 2
  207. if torch.max(out2) < 32:
  208. out2 = out2 * 2
  209. if torch.max(out2) < 32:
  210. out2 = out2 * 2
  211. '''
  212. #print ('in, weight, out')
  213. '''
  214. print ('round')
  215. #print (torch.max(input), torch.min(input))
  216. #print (torch.sum(input), torch.sum(input))
  217. #print (torch.max(self.weight), torch.min(self.weight))
  218. #print (torch.sum(self.weight), torch.sum(self.weight))
  219. print (torch.max(out), torch.min(out))
  220. print (torch.max(out2), torch.min(out2))
  221. #'''
  222. out2 = out2 * 4
  223. out2[out2 > 63] = 63
  224. out2[out2 < -63] = -63
  225. #print (self.weight.data.size())
  226. #print (torch.max(out-out2), torch.min(out-out2))
  227. #out = nn.functional.conv2d(input, self.weight, None, self.stride,
  228. # self.padding, self.dilation, self.groups)
  229. #print(input.size(), self.weight.size(), out.size())
  230. #if not self.bias is None:
  231. # self.bias.org=self.bias.data.clone()
  232. # out += self.bias.view(1, -1, 1, 1).expand_as(out)
  233. return out2
  234. def simconv(self, input_a, weight):
  235. #print(input_a.size(), weight.size())
  236. batch_size = input_a.size()[0]
  237. out_channel = weight.size()[0]
  238. out_width = input_a.size()[2] - 2 * (weight.size()[2] // 2)
  239. out_height = input_a.size()[3] - 2 * (weight.size()[3] // 2)
  240. simout = torch.zeros(batch_size, out_channel, out_width, out_height, dtype = input_a.dtype).to(input_a.device)
  241. first = True
  242. #''' Mapping Table
  243. global LUT
  244. LUT = LUT.to(input_a.device)
  245. if weight.size()[2] == 7:
  246. kernel_group = 1
  247. else:
  248. kernel_group = 4
  249. Digital_input_split = torch.split(input_a, kernel_group, dim=1)
  250. binary_weight_split = torch.split(weight, kernel_group, dim=1)
  251. for i in range(len(Digital_input_split)):
  252. temp_output = nn.functional.conv2d(Digital_input_split[i], binary_weight_split[i], None, self.stride, self.padding, self.dilation, self.groups)
  253. temp_output = torch.round(temp_output / 64)
  254. temp_output += LUT_OFFSET
  255. temp_output = LUT[temp_output.long()]
  256. simout += temp_output + 2
  257. #print (torch.max(simout), torch.min(simout))
  258. #'''
  259. ''' Error model
  260. for n in range(batch_size):
  261. for c in range(out_channel):
  262. w = torch.reshape(weight[c], (-1,)).to(device)
  263. inputs = []
  264. for i in range(out_width):
  265. for j in range(out_height):
  266. input = torch.reshape(input_a[n, :, i: i + weight.size()[2], j: j + weight.size()[3]], (-1,))
  267. #print (w.size(), input.size())
  268. # simout[n][c][i][j] = sum(w*input)
  269. # TODO
  270. simout[n][c][i][j] = self.cim_conv_tmp(input, w)
  271. #'''
  272. #print (len(input))
  273. #print (simout.size())
  274. # out = nn.functional.conv2d(input_a, weight)
  275. return simout
  276. def cim_conv_tmp(self, input, weight):
  277. assert len(input) == len(weight)
  278. raw_sum = 0
  279. if len(weight) == 3:
  280. for i in range((len(input)-1) // 36 + 1):
  281. data_x = input[i*36:i*36+36] * weight[i*36:i*36+36]
  282. row = int(Decimal(float(sum(data_x)/64.0)).quantize(0, ROUND_HALF_UP))
  283. #''' Error model
  284. if len(data_x) < 36:
  285. data_x = torch.cat((data_x, torch.zeros(36 - len(data_x), dtype=data_x.dtype)))
  286. try:
  287. #ensor_x = torch.Tensor(data_x).to(self.device)
  288. tensor_x = data_x.to(device)
  289. except:
  290. print (data_x, len())
  291. y_pred = sim_model(tensor_x)
  292. if int(y_pred[0]) > 10:
  293. adjust = 10
  294. elif int(y_pred[0]) < -10:
  295. adjust = -10
  296. else:
  297. adjust = int(y_pred[0])
  298. #print (tensor_x, y_pred)
  299. raw_sum += (row + adjust + 2)
  300. #'''
  301. #if row in self.mappingTable:
  302. # row = self.mappingTable[row]
  303. #raw_sum += row
  304. #raw_sum += row
  305. else:
  306. for i in range((len(input)-1) // 49 + 1):
  307. data_x = input[i*49:i*49+49] * weight[i*49:i*49+49]
  308. row = int(Decimal(float(sum(data_x)/64.0)).quantize(0, ROUND_HALF_UP))
  309. #''' Error model
  310. if len(data_x) < 49:
  311. data_x = torch.cat((data_x, torch.zeros(49 - len(data_x), dtype=data_x.dtype)))
  312. try:
  313. #ensor_x = torch.Tensor(data_x).to(self.device)
  314. tensor_x = data_x.to(device)
  315. except:
  316. print (data_x, len())
  317. y_pred = sim_model(tensor_x)
  318. if int(y_pred[0]) > 10:
  319. adjust = 10
  320. elif int(y_pred[0]) < -10:
  321. adjust = -10
  322. else:
  323. adjust = int(y_pred[0])
  324. #print (tensor_x, y_pred)
  325. raw_sum += (row + adjust + 2)
  326. #print (raw_sum)
  327. return raw_sum