You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

248 lines
8.3 KiB

  1. import torch.nn as nn
  2. import torchvision.transforms as transforms
  3. import math
  4. from .binarized_modules import BinarizeLinear,BinarizeConv2d
  5. __all__ = ['resnet_binary']
  6. def Binaryconv3x3(in_planes, out_planes, stride=1):
  7. "3x3 convolution with padding"
  8. return BinarizeConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  9. padding=1, bias=False)
  10. def conv3x3(in_planes, out_planes, stride=1):
  11. "3x3 convolution with padding"
  12. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  13. padding=1, bias=False)
  14. def init_model(model):
  15. for m in model.modules():
  16. if isinstance(m, BinarizeConv2d):
  17. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  18. m.weight.data.normal_(0, math.sqrt(2. / n))
  19. elif isinstance(m, nn.BatchNorm2d):
  20. m.weight.data.fill_(1)
  21. m.bias.data.zero_()
  22. class BasicBlock(nn.Module):
  23. expansion = 1
  24. def __init__(self, inplanes, planes, stride=1, downsample=None,do_bntan=True):
  25. super(BasicBlock, self).__init__()
  26. self.conv1 = Binaryconv3x3(inplanes, planes, stride)
  27. self.bn1 = nn.BatchNorm2d(planes)
  28. self.tanh1 = nn.Hardtanh(inplace=True)
  29. self.conv2 = Binaryconv3x3(planes, planes)
  30. self.tanh2 = nn.Hardtanh(inplace=True)
  31. self.bn2 = nn.BatchNorm2d(planes)
  32. self.downsample = downsample
  33. self.do_bntan=do_bntan;
  34. self.stride = stride
  35. def forward(self, x):
  36. residual = x.clone()
  37. out = self.conv1(x)
  38. out = self.bn1(out)
  39. out = self.tanh1(out)
  40. out = self.conv2(out)
  41. if self.downsample is not None:
  42. if residual.data.max()>1:
  43. import pdb; pdb.set_trace()
  44. residual = self.downsample(residual)
  45. out += residual
  46. if self.do_bntan:
  47. out = self.bn2(out)
  48. out = self.tanh2(out)
  49. return out
  50. class Bottleneck(nn.Module):
  51. expansion = 4
  52. def __init__(self, inplanes, planes, stride=1, downsample=None):
  53. super(Bottleneck, self).__init__()
  54. self.conv1 = BinarizeConv2d(inplanes, planes, kernel_size=1, bias=False)
  55. self.bn1 = nn.BatchNorm2d(planes)
  56. self.conv2 = BinarizeConv2d(planes, planes, kernel_size=3, stride=stride,
  57. padding=1, bias=False)
  58. self.bn2 = nn.BatchNorm2d(planes)
  59. self.conv3 = BinarizeConv2d(planes, planes * 4, kernel_size=1, bias=False)
  60. self.bn3 = nn.BatchNorm2d(planes * 4)
  61. self.tanh = nn.Hardtanh(inplace=True)
  62. self.downsample = downsample
  63. self.stride = stride
  64. def forward(self, x):
  65. residual = x
  66. import pdb; pdb.set_trace()
  67. out = self.conv1(x)
  68. out = self.bn1(out)
  69. out = self.tanh(out)
  70. out = self.conv2(out)
  71. out = self.bn2(out)
  72. out = self.tanh(out)
  73. out = self.conv3(out)
  74. out = self.bn3(out)
  75. if self.downsample is not None:
  76. residual = self.downsample(x)
  77. out += residual
  78. if self.do_bntan:
  79. out = self.bn2(out)
  80. out = self.tanh2(out)
  81. return out
  82. class ResNet(nn.Module):
  83. def __init__(self):
  84. super(ResNet, self).__init__()
  85. def _make_layer(self, block, planes, blocks, stride=1,do_bntan=True):
  86. downsample = None
  87. if stride != 1 or self.inplanes != planes * block.expansion:
  88. downsample = nn.Sequential(
  89. BinarizeConv2d(self.inplanes, planes * block.expansion,
  90. kernel_size=1, stride=stride, bias=False),
  91. nn.BatchNorm2d(planes * block.expansion),
  92. )
  93. layers = []
  94. layers.append(block(self.inplanes, planes, stride, downsample))
  95. self.inplanes = planes * block.expansion
  96. for i in range(1, blocks-1):
  97. layers.append(block(self.inplanes, planes))
  98. layers.append(block(self.inplanes, planes,do_bntan=do_bntan))
  99. return nn.Sequential(*layers)
  100. def forward(self, x):
  101. x = self.conv1(x)
  102. x = self.maxpool(x)
  103. x = self.bn1(x)
  104. x = self.tanh1(x)
  105. x = self.layer1(x)
  106. x = self.layer2(x)
  107. x = self.layer3(x)
  108. x = self.layer4(x)
  109. x = self.avgpool(x)
  110. x = x.view(x.size(0), -1)
  111. x = self.bn2(x)
  112. x = self.tanh2(x)
  113. x = self.fc(x)
  114. x = self.bn3(x)
  115. x = self.logsoftmax(x)
  116. return x
  117. class ResNet_imagenet(ResNet):
  118. def __init__(self, num_classes=1000,
  119. block=Bottleneck, layers=[3, 4, 23, 3]):
  120. super(ResNet_imagenet, self).__init__()
  121. self.inplanes = 64
  122. self.conv1 = BinarizeConv2d(3, 64, kernel_size=7, stride=2, padding=3,
  123. bias=False)
  124. self.bn1 = nn.BatchNorm2d(64)
  125. self.tanh = nn.Hardtanh(inplace=True)
  126. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  127. self.layer1 = self._make_layer(block, 64, layers[0])
  128. self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  129. self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  130. self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  131. self.avgpool = nn.AvgPool2d(7)
  132. self.fc = BinarizeLinear(512 * block.expansion, num_classes)
  133. init_model(self)
  134. self.regime = {
  135. 0: {'optimizer': 'SGD', 'lr': 1e-1,
  136. 'weight_decay': 1e-4, 'momentum': 0.9},
  137. 30: {'lr': 1e-2},
  138. 60: {'lr': 1e-3, 'weight_decay': 0},
  139. 90: {'lr': 1e-4}
  140. }
  141. class ResNet_cifar10(ResNet):
  142. def __init__(self, num_classes=10,
  143. block=BasicBlock, depth=18):
  144. super(ResNet_cifar10, self).__init__()
  145. self.inflate = 5
  146. self.inplanes = 16*self.inflate
  147. n = int((depth - 2) / 6)
  148. self.conv1 = BinarizeConv2d(3, 16*self.inflate, kernel_size=3, stride=1, padding=1,
  149. bias=False)
  150. self.maxpool = lambda x: x
  151. self.bn1 = nn.BatchNorm2d(16*self.inflate)
  152. self.tanh1 = nn.Hardtanh(inplace=True)
  153. self.tanh2 = nn.Hardtanh(inplace=True)
  154. self.layer1 = self._make_layer(block, 16*self.inflate, n)
  155. self.layer2 = self._make_layer(block, 32*self.inflate, n, stride=2)
  156. self.layer3 = self._make_layer(block, 64*self.inflate, n, stride=2,do_bntan=False)
  157. self.layer4 = lambda x: x
  158. self.avgpool = nn.AvgPool2d(8)
  159. self.bn2 = nn.BatchNorm1d(64*self.inflate)
  160. self.bn3 = nn.BatchNorm1d(10)
  161. self.logsoftmax = nn.LogSoftmax()
  162. self.fc = BinarizeLinear(64*self.inflate, num_classes)
  163. init_model(self)
  164. #self.regime = {
  165. # 0: {'optimizer': 'SGD', 'lr': 1e-1,
  166. # 'weight_decay': 1e-4, 'momentum': 0.9},
  167. # 81: {'lr': 1e-4},
  168. # 122: {'lr': 1e-5, 'weight_decay': 0},
  169. # 164: {'lr': 1e-6}
  170. #}
  171. self.regime = {
  172. 0: {'optimizer': 'Adam', 'lr': 5e-3},
  173. 101: {'lr': 1e-3},
  174. 142: {'lr': 5e-4},
  175. 184: {'lr': 1e-4},
  176. 220: {'lr': 1e-5}
  177. }
  178. def resnet_binary(**kwargs):
  179. num_classes, depth, dataset = map(
  180. kwargs.get, ['num_classes', 'depth', 'dataset'])
  181. if dataset == 'imagenet':
  182. num_classes = num_classes or 1000
  183. depth = depth or 50
  184. if depth == 18:
  185. return ResNet_imagenet(num_classes=num_classes,
  186. block=BasicBlock, layers=[2, 2, 2, 2])
  187. if depth == 34:
  188. return ResNet_imagenet(num_classes=num_classes,
  189. block=BasicBlock, layers=[3, 4, 6, 3])
  190. if depth == 50:
  191. return ResNet_imagenet(num_classes=num_classes,
  192. block=Bottleneck, layers=[3, 4, 6, 3])
  193. if depth == 101:
  194. return ResNet_imagenet(num_classes=num_classes,
  195. block=Bottleneck, layers=[3, 4, 23, 3])
  196. if depth == 152:
  197. return ResNet_imagenet(num_classes=num_classes,
  198. block=Bottleneck, layers=[3, 8, 36, 3])
  199. elif dataset == 'cifar10':
  200. num_classes = num_classes or 10
  201. depth = depth or 18
  202. return ResNet_cifar10(num_classes=num_classes,
  203. block=BasicBlock, depth=depth)