You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

217 lines
6.9 KiB

import math

import torch.nn as nn
import torchvision.transforms as transforms
  4. __all__ = ['resnet']
  5. def conv3x3(in_planes, out_planes, stride=1):
  6. "3x3 convolution with padding"
  7. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  8. padding=1, bias=False)
  9. def init_model(model):
  10. for m in model.modules():
  11. if isinstance(m, nn.Conv2d):
  12. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  13. m.weight.data.normal_(0, math.sqrt(2. / n))
  14. elif isinstance(m, nn.BatchNorm2d):
  15. m.weight.data.fill_(1)
  16. m.bias.data.zero_()
  17. class BasicBlock(nn.Module):
  18. expansion = 1
  19. def __init__(self, inplanes, planes, stride=1, downsample=None):
  20. super(BasicBlock, self).__init__()
  21. self.conv1 = conv3x3(inplanes, planes, stride)
  22. self.bn1 = nn.BatchNorm2d(planes)
  23. self.relu = nn.ReLU(inplace=True)
  24. self.conv2 = conv3x3(planes, planes)
  25. self.bn2 = nn.BatchNorm2d(planes)
  26. self.downsample = downsample
  27. self.stride = stride
  28. def forward(self, x):
  29. residual = x
  30. out = self.conv1(x)
  31. out = self.bn1(out)
  32. out = self.relu(out)
  33. out = self.conv2(out)
  34. out = self.bn2(out)
  35. if self.downsample is not None:
  36. residual = self.downsample(x)
  37. out += residual
  38. out = self.relu(out)
  39. return out
  40. class Bottleneck(nn.Module):
  41. expansion = 4
  42. def __init__(self, inplanes, planes, stride=1, downsample=None):
  43. super(Bottleneck, self).__init__()
  44. self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
  45. self.bn1 = nn.BatchNorm2d(planes)
  46. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
  47. padding=1, bias=False)
  48. self.bn2 = nn.BatchNorm2d(planes)
  49. self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
  50. self.bn3 = nn.BatchNorm2d(planes * 4)
  51. self.relu = nn.ReLU(inplace=True)
  52. self.downsample = downsample
  53. self.stride = stride
  54. def forward(self, x):
  55. residual = x
  56. out = self.conv1(x)
  57. out = self.bn1(out)
  58. out = self.relu(out)
  59. out = self.conv2(out)
  60. out = self.bn2(out)
  61. out = self.relu(out)
  62. out = self.conv3(out)
  63. out = self.bn3(out)
  64. if self.downsample is not None:
  65. residual = self.downsample(x)
  66. out += residual
  67. out = self.relu(out)
  68. return out
  69. class ResNet(nn.Module):
  70. def __init__(self):
  71. super(ResNet, self).__init__()
  72. def _make_layer(self, block, planes, blocks, stride=1):
  73. downsample = None
  74. if stride != 1 or self.inplanes != planes * block.expansion:
  75. downsample = nn.Sequential(
  76. nn.Conv2d(self.inplanes, planes * block.expansion,
  77. kernel_size=1, stride=stride, bias=False),
  78. nn.BatchNorm2d(planes * block.expansion),
  79. )
  80. layers = []
  81. layers.append(block(self.inplanes, planes, stride, downsample))
  82. self.inplanes = planes * block.expansion
  83. for i in range(1, blocks):
  84. layers.append(block(self.inplanes, planes))
  85. return nn.Sequential(*layers)
  86. def forward(self, x):
  87. x = self.conv1(x)
  88. x = self.bn1(x)
  89. x = self.relu(x)
  90. x = self.maxpool(x)
  91. x = self.layer1(x)
  92. x = self.layer2(x)
  93. x = self.layer3(x)
  94. x = self.layer4(x)
  95. x = self.avgpool(x)
  96. x = x.view(x.size(0), -1)
  97. x = self.fc(x)
  98. return x
  99. class ResNet_imagenet(ResNet):
  100. def __init__(self, num_classes=1000,
  101. block=Bottleneck, layers=[3, 4, 23, 3]):
  102. super(ResNet_imagenet, self).__init__()
  103. self.inplanes = 64
  104. self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
  105. bias=False)
  106. self.bn1 = nn.BatchNorm2d(64)
  107. self.relu = nn.ReLU(inplace=True)
  108. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  109. self.layer1 = self._make_layer(block, 64, layers[0])
  110. self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  111. self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  112. self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  113. self.avgpool = nn.AvgPool2d(7)
  114. self.fc = nn.Linear(512 * block.expansion, num_classes)
  115. init_model(self)
  116. self.regime = {
  117. 0: {'optimizer': 'SGD', 'lr': 1e-1,
  118. 'weight_decay': 1e-4, 'momentum': 0.9},
  119. 30: {'lr': 1e-2},
  120. 60: {'lr': 1e-3, 'weight_decay': 0},
  121. 90: {'lr': 1e-4}
  122. }
  123. class ResNet_cifar10(ResNet):
  124. def __init__(self, num_classes=10,
  125. block=BasicBlock, depth=18):
  126. super(ResNet_cifar10, self).__init__()
  127. self.inplanes = 16
  128. n = int((depth - 2) / 6)
  129. self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
  130. bias=False)
  131. self.bn1 = nn.BatchNorm2d(16)
  132. self.relu = nn.ReLU(inplace=True)
  133. self.maxpool = lambda x: x
  134. self.layer1 = self._make_layer(block, 16, n)
  135. self.layer2 = self._make_layer(block, 32, n, stride=2)
  136. self.layer3 = self._make_layer(block, 64, n, stride=2)
  137. self.layer4 = lambda x: x
  138. self.avgpool = nn.AvgPool2d(8)
  139. self.fc = nn.Linear(64, num_classes)
  140. init_model(self)
  141. self.regime = {
  142. 0: {'optimizer': 'SGD', 'lr': 1e-1,
  143. 'weight_decay': 1e-4, 'momentum': 0.9},
  144. 81: {'lr': 1e-2},
  145. 122: {'lr': 1e-3, 'weight_decay': 0},
  146. 164: {'lr': 1e-4}
  147. }
  148. def resnet(**kwargs):
  149. num_classes, depth, dataset = map(
  150. kwargs.get, ['num_classes', 'depth', 'dataset'])
  151. if dataset == 'imagenet':
  152. num_classes = num_classes or 1000
  153. depth = depth or 50
  154. if depth == 18:
  155. return ResNet_imagenet(num_classes=num_classes,
  156. block=BasicBlock, layers=[2, 2, 2, 2])
  157. if depth == 34:
  158. return ResNet_imagenet(num_classes=num_classes,
  159. block=BasicBlock, layers=[3, 4, 6, 3])
  160. if depth == 50:
  161. return ResNet_imagenet(num_classes=num_classes,
  162. block=Bottleneck, layers=[3, 4, 6, 3])
  163. if depth == 101:
  164. return ResNet_imagenet(num_classes=num_classes,
  165. block=Bottleneck, layers=[3, 4, 23, 3])
  166. if depth == 152:
  167. return ResNet_imagenet(num_classes=num_classes,
  168. block=Bottleneck, layers=[3, 8, 36, 3])
  169. elif dataset == 'cifar10':
  170. num_classes = num_classes or 10
  171. depth = depth or 18 #56
  172. return ResNet_cifar10(num_classes=num_classes,
  173. block=BasicBlock, depth=depth)