import torch
import pdb
import torch.nn as nn
import math
from torch.autograd import Variable
from torch.autograd import Function
from decimal import Decimal, ROUND_HALF_UP

import numpy as np


def Binarize(tensor, quant_mode='det'):
    # 'det': deterministic binarization via sign(); any other mode performs
    # stochastic binarization in place.
    if quant_mode == 'det':
        return tensor.sign()
    else:
        return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0, 1).round().mul_(2).add_(-1)
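
# Example usage (illustrative sketch, not part of the original module):
# deterministic mode maps each element to its sign; the stochastic mode
# rounds probabilistically and modifies the tensor in place.
#
#   x = torch.tensor([0.3, -1.2, 0.7])
#   Binarize(x)                              # tensor([ 1., -1.,  1.])
#   Binarize(x.clone(), quant_mode='stoch')  # stochastic variant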


class HingeLoss(nn.Module):
    def __init__(self):
        super(HingeLoss, self).__init__()
        self.margin = 1.0

    def hinge_loss(self, input, target):
        #import pdb; pdb.set_trace()
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        return output.mean()

    def forward(self, input, target):
        return self.hinge_loss(input, target)
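
# Example usage (illustrative, with made-up numbers): hinge loss with a fixed
# margin of 1.0 on +/-1 targets.
#
#   criterion = HingeLoss()
#   criterion(torch.tensor([[0.2, -0.7]]), torch.tensor([[1., -1.]]))
#   # mean(max(0, 1 - y*t)) = mean([0.8, 0.3]) = 0.55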


class SqrtHingeLossFunction(Function):
    # Legacy-style autograd Function (instance methods and save_for_backward
    # on self), kept as in the original code.
    def __init__(self):
        super(SqrtHingeLossFunction, self).__init__()
        self.margin = 1.0

    def forward(self, input, target):
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        self.save_for_backward(input, target)
        loss = output.mul(output).sum(0).sum(1).div(target.numel())
        return loss

    def backward(self, grad_output):
        input, target = self.saved_tensors
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        #import pdb; pdb.set_trace()  # leftover debug breakpoint, disabled
        grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)
        grad_output.mul_(output.ne(0).float())
        grad_output.div_(input.numel())
        return grad_output, grad_output


def Quantize(tensor, quant_mode='det', params=None, numBits=8):
    tensor.clamp_(-2**(numBits-1), 2**(numBits-1))
    if quant_mode == 'det':
        tensor = tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))
    else:
        tensor = tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))
        quant_fixed(tensor, params)  # note: quant_fixed is not defined in this file
    return tensor
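
# Example usage (illustrative): deterministic 8-bit quantization snaps values
# to a 1/2**(numBits-1) grid; note the argument is clamped in place first.
# (The stochastic branch calls quant_fixed, which is not defined here.)
#
#   w = torch.tensor([0.127, -0.5, 0.9])
#   Quantize(w, quant_mode='det', numBits=8)   # tensor([ 0.1250, -0.5000,  0.8984])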
#import torch.nn._functions as tnnf


class BinarizeLinear(nn.Linear):

    def __init__(self, *kargs, **kwargs):
        super(BinarizeLinear, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # if input.size(1) != 784:
        #     input.data = Binarize(input.data)
        if not hasattr(self.weight, 'org'):
            # keep a full-precision copy of the weights; the binarized values
            # are written into weight.data for the forward pass
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        out = nn.functional.linear(input, self.weight)
        if self.bias is not None:
            self.bias.org = self.bias.data.clone()
            out += self.bias.view(1, -1).expand_as(out)
        return out
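
# Example usage (illustrative; sizes are arbitrary): the layer keeps the
# full-precision weights in weight.org and computes with their sign.
#
#   fc = BinarizeLinear(16, 4)
#   y = fc(torch.randn(2, 16))
#   fc.weight.org            # full-precision copy; fc.weight.data is +/-1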


class BinarizeConv2d(nn.Conv2d):

    def __init__(self, *kargs, **kwargs):
        super(BinarizeConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # if input.size(1) != 3:
        #     input.data = Binarize(input.data)
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        #input = torch.round(input)
        #input = input*2-1
        #scale = max(torch.max(input), -torch.min(input)) / 63
        #input = torch.round(input*2 / scale) - 63
        #if scale != 0:
        #    input = torch.round(input / scale)
        #print (torch.max(input))
        #print(input)
        input = torch.round(input)
        #print(input)
        #print (torch.max(input))
        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        #print (torch.min(out), torch.max(out))
        #out = torch.round(out)
        #print (torch.min(out), torch.max(out))
        #print (torch.min(input), torch.max(input))
        #out = torch.round(out / 64 * 36 / 64)
        #print (self.weight.size()[1])
        # rescale the accumulated result according to the kernel shape, then
        # clamp to the output range [-63, 63]
        #if self.weight.size()[1] >= 16 and self.weight.size()[1] <= 24:
        if self.weight.size()[1] >= 4 and self.weight.size()[2] * self.weight.size()[3] == 9:
            out = torch.round(out / 64 * 36 / 64)
        elif self.weight.size()[1] == 1:
            out = torch.round(out * 7 / 64)
        else:
            out = torch.round(out / 64)
        out = out * 4
        out[out > 63] = 63
        out[out < -63] = -63
        #out = out - torch.round(torch.mean(out))
        # out = out*4
        #out[out > 63] = 63
        #out[out < -63] = -63
        #else:
        #    out = torch.round(out * 10 / 64)
        #print (torch.min(out), torch.max(out))
        # if not self.bias is None:
        #     self.bias.org = self.bias.data.clone()
        #     out += self.bias.view(1, -1, 1, 1).expand_as(out)
        return out
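
# Example usage (illustrative; shapes are arbitrary): inputs are rounded to
# integers, weights binarized, and the output rescaled and clamped to [-63, 63].
#
#   conv = BinarizeConv2d(4, 8, kernel_size=3, bias=False)
#   y = conv(torch.randint(0, 64, (1, 4, 8, 8)).float())
#   assert y.abs().max() <= 63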


class IdealCimConv2d(nn.Conv2d):

    def __init__(self, *kargs, **kwargs):
        super(IdealCimConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # if input.size(1) != 3:
        #     input.data = Binarize(input.data)
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        #input = torch.round(input)
        #input = input*2-1
        #scale = max(torch.max(input), -torch.min(input)) / 63
        #input = torch.round(input*2 / scale) - 63
        #if scale != 0:
        #    input = torch.round(input / scale)
        #print (torch.max(input))
        #print(input)
        input = torch.round(input)
        #print(input)
        #print (torch.max(input))
        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        out = out / 64
        out = out * 4
        out[out > 63] = 63
        out[out < -63] = -63
        return out


device = 'cuda:0'

'''
H = [1024, 512]
sim_model = torch.nn.Sequential(
    torch.nn.Linear(36, H[0]),
    torch.nn.Dropout(p=0.5),
    torch.nn.ReLU(),
    torch.nn.Linear(H[0], H[1]),
    torch.nn.Dropout(p=0.5),
    torch.nn.ReLU(),
    torch.nn.Linear(H[-1], 1),
)
sim_model.load_state_dict(torch.load('model_error.ckpt', map_location=torch.device('cuda:0')))
sim_model = sim_model.to(device)
sim_model.eval()
'''


class CimSimConv2d(nn.Conv2d):

    def __init__(self, *kargs, **kwargs):
        super(CimSimConv2d, self).__init__(*kargs, **kwargs)
        self.device = device

    def forward(self, input):
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        #scale = max(torch.max(input), -torch.min(input)) / 63
        #if scale != 0:
        #    input = torch.round(input / scale)
        #''' random error
        #out = nn.functional.conv2d(input, self.weight, None, self.stride,
        #                           self.padding, self.dilation, self.groups)
        #out = torch.round(out / 64 * 36 / 64)
        #randrange = (self.weight.size()[1] // 4)
        #for _ in range(randrange):
        #    out += torch.randint(-1, 1, out.size(), device=device)
        #out[out > 63] = 63
        #out[out < -63] = -63
        #'''
        input = torch.round(input)
        out2 = self.simconv(input, self.weight)
        '''
        if torch.max(out2) < 32:
            out2 = out2 * 2
        if torch.max(out2) < 32:
            out2 = out2 * 2
        if torch.max(out2) < 32:
            out2 = out2 * 2
        '''
        out2 = out2 * 4
        out2[out2 > 63] = 63
        out2[out2 < -63] = -63
        #print (self.weight.data.size())
        #print (torch.max(out2), torch.min(out2))
        #print (torch.max(out-out2), torch.min(out-out2))
        #out = nn.functional.conv2d(input, self.weight, None, self.stride,
        #                           self.padding, self.dilation, self.groups)
        #print(input.size(), self.weight.size(), out.size())
        #if not self.bias is None:
        #    self.bias.org = self.bias.data.clone()
        #    out += self.bias.view(1, -1, 1, 1).expand_as(out)
        return out2

    def simconv(self, input_a, weight):
        # Simulated CIM convolution: split the input/weight channels into
        # groups, convolve each group, round the partial sums, and pass them
        # through the Mapping lookup table before accumulating.
        #print(input_a.size(), weight.size())
        batch_size = input_a.size()[0]
        out_channel = weight.size()[0]
        out_width = input_a.size()[2] - 2 * (weight.size()[2] // 2)
        out_height = input_a.size()[3] - 2 * (weight.size()[3] // 2)
        simout = torch.zeros(batch_size, out_channel, out_width, out_height, dtype=input_a.dtype).to(device)
        first = True
        #''' Mapping Table
        if weight.size()[2] == 7:
            kernel_group = 1
        else:
            kernel_group = 4
        Digital_input_split = torch.split(input_a, kernel_group, dim=1)
        binary_weight_split = torch.split(weight, kernel_group, dim=1)
        for i in range(len(Digital_input_split)):
            temp_output = nn.functional.conv2d(Digital_input_split[i], binary_weight_split[i], None, self.stride, self.padding, self.dilation, self.groups)
            #temp_output = torch.round(temp_output / 64 * 36 / 64)
            temp_output = torch.round(temp_output / 64)
            temp_output = Mapping.apply(temp_output)
            simout += temp_output + 2
            #print (torch.max(simout), torch.min(simout))
        #'''
        ''' Error model
        for n in range(batch_size):
            for c in range(out_channel):
                w = torch.reshape(weight[c], (-1,)).to(device)
                inputs = []
                for i in range(out_width):
                    for j in range(out_height):
                        input = torch.reshape(input_a[n, :, i: i + weight.size()[2], j: j + weight.size()[3]], (-1,))
                        #print (w.size(), input.size())
                        # simout[n][c][i][j] = sum(w*input)
                        # TODO
                        simout[n][c][i][j] = self.cim_conv_tmp(input, w)
        #'''
        #print (len(input))
        #print (simout.size())
        # out = nn.functional.conv2d(input_a, weight)
        return simout

    def cim_conv_tmp(self, input, weight):
        # Per-window error-model path: accumulate row sums in chunks and adjust
        # each chunk with the error predictor (requires the commented-out
        # sim_model block above to be enabled).
        assert len(input) == len(weight)
        raw_sum = 0
        if len(weight) == 3:
            for i in range((len(input)-1) // 36 + 1):
                data_x = input[i*36:i*36+36] * weight[i*36:i*36+36]
                row = int(Decimal(float(sum(data_x)/64.0)).quantize(0, ROUND_HALF_UP))
                #''' Error model
                if len(data_x) < 36:
                    data_x = torch.cat((data_x, torch.zeros(36 - len(data_x), dtype=data_x.dtype)))
                try:
                    #tensor_x = torch.Tensor(data_x).to(self.device)
                    tensor_x = data_x.to(device)
                except:
                    print(data_x, len(data_x))
                y_pred = sim_model(tensor_x)
                if int(y_pred[0]) > 10:
                    adjust = 10
                elif int(y_pred[0]) < -10:
                    adjust = -10
                else:
                    adjust = int(y_pred[0])
                #print (tensor_x, y_pred)
                raw_sum += (row + adjust + 2)
                #'''
                #if row in self.mappingTable:
                #    row = self.mappingTable[row]
                #raw_sum += row
                #raw_sum += row
        else:
            for i in range((len(input)-1) // 49 + 1):
                data_x = input[i*49:i*49+49] * weight[i*49:i*49+49]
                row = int(Decimal(float(sum(data_x)/64.0)).quantize(0, ROUND_HALF_UP))
                #''' Error model
                if len(data_x) < 49:
                    data_x = torch.cat((data_x, torch.zeros(49 - len(data_x), dtype=data_x.dtype)))
                try:
                    #tensor_x = torch.Tensor(data_x).to(self.device)
                    tensor_x = data_x.to(device)
                except:
                    print(data_x, len(data_x))
                y_pred = sim_model(tensor_x)
                if int(y_pred[0]) > 10:
                    adjust = 10
                elif int(y_pred[0]) < -10:
                    adjust = -10
                else:
                    adjust = int(y_pred[0])
                #print (tensor_x, y_pred)
                raw_sum += (row + adjust + 2)
        #print (raw_sum)
        return raw_sum
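
# Example usage (illustrative): the mapping-table path used by forward()
# allocates on device = 'cuda:0', so a CUDA device is required; the optional
# error-model path additionally needs the commented-out sim_model block above.
#
#   conv = CimSimConv2d(4, 8, kernel_size=3, bias=False).to(device)
#   y = conv(torch.randint(0, 64, (1, 4, 8, 8)).float().to(device))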


class Mapping(torch.autograd.Function):
    # Elementwise lookup table applied to the rounded partial sums; the
    # backward pass is a straight-through identity.
    @staticmethod
    def forward(ctx, input):
        output = input.clone()
        output[input==-1] = -4
        output[input==-2] = -5
        output[input==-3] = -6
        output[input==-4] = -7
        output[input==-5] = -9
        output[input==-6] = -9
        output[input==-7] = -11
        output[input==-8] = -11
        output[input==-9] = -13
        output[input==-10] = -13
        output[input==-11] = -17
        output[input==-12] = -17
        output[input==-13] = -17
        output[input==-14] = -19
        output[input==-15] = -19
        output[input==-16] = -21
        output[input==-17] = -21
        output[input==-18] = -23
        output[input==-19] = -25
        output[input==-20] = -25
        output[input==-21] = -25
        output[input==-22] = -25
        output[input==-23] = -27
        output[input==-24] = -27
        output[input==-25] = -29
        output[input==-26] = -29
        output[input==-27] = -29
        output[input==-28] = -31
        output[input==-29] = -31
        output[input==-30] = -33
        output[input==-31] = -33
        output[input==-32] = -35
        output[input==-33] = -35
        output[input==-34] = -35
        #output[input==-35] = -35
        output[input==0] = -2
        output[input==1] = -1
        output[input==2] = 1
        output[input==3] = 2
        #output[input==4] = 4
        output[input==5] = 4
        #output[input==6] = 6
        output[input==7] = 8
        #output[input==8] = 8
        output[input==9] = 10
        #output[input==10] = 10
        output[input==11] = 12
        #output[input==12] = 12
        output[input==13] = 16
        output[input==14] = 16
        output[input==15] = 16
        #output[input==16] = 16
        output[input==17] = 18
        output[input==18] = 20
        output[input==19] = 20
        output[input==20] = 24
        output[input==21] = 24
        output[input==22] = 24
        output[input==23] = 26
        output[input==24] = 26
        output[input==25] = 28
        output[input==26] = 28
        output[input==27] = 28
        output[input==28] = 30
        output[input==29] = 30
        output[input==30] = 32
        output[input==31] = 32
        output[input==32] = 34
        output[input==33] = 34
        output[input==34] = 34
        output[input==35] = 34
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
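

if __name__ == "__main__":
    # Minimal smoke test added for illustration; the layer sizes and input
    # range are arbitrary assumptions, not taken from the original training setup.
    torch.manual_seed(0)
    conv = BinarizeConv2d(4, 8, kernel_size=3, bias=False)
    fc = BinarizeLinear(8 * 6 * 6, 10)
    x = torch.randint(0, 64, (2, 4, 8, 8)).float()  # small integer-valued activations
    y = conv(x)               # binarized weights, rounded inputs, clamped output
    y = fc(y.flatten(1))
    print(y.shape)            # expected: torch.Size([2, 10])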