import torch
import numpy as np
import cv2, os, sys
import pandas as pd
from torch.utils.data import Dataset
from matplotlib import pyplot as plt
from torch.utils.data import ConcatDataset, DataLoader, Subset
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import DatasetFolder
from PIL import Image
import torchvision.models
import BinaryNetpytorch.models as models
from BinaryNetpytorch.models.binarized_modules import BinarizeLinear, BinarizeConv2d
import progressbar
import seaborn as sns
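
# Trains a binarized ResNet (via BinaryNetpytorch) on 8-class grayscale pose
# images from pose_data2/, then evaluates it, saving misclassified test images
# and a confusion heatmap to disk.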
batch_size = 32
num_epoch = 60
# Pin work to GPU 1; guard so the script still runs on single-GPU machines.
if torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)
train_tfm = transforms.Compose([
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomResizedCrop((40, 30)),
    transforms.Grayscale(),
    transforms.Resize((68, 68)),
    transforms.ToTensor(),
    # transforms.RandomResizedCrop((40, 30)),
    # transforms.TenCrop((40, 30)),
    # transforms.Normalize(0.5, 0.5),
])
test_tfm = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((68, 68)),
    transforms.ToTensor(),
])
def Quantize(img):
    """Round every pixel to the nearest multiple of 1/128 (= 0.0078125)."""
    scaler = torch.div(img, 0.0078125, rounding_mode="floor")
    scaler_t1 = scaler * 0.0078125        # grid point just below each value
    scaler_t2 = (scaler + 1) * 0.0078125  # grid point just above each value
    img = torch.where(abs(img - scaler_t1) < abs(img - scaler_t2), scaler_t1, scaler_t2)
    return img
# Earlier element-wise implementation (with a progressbar), kept for reference;
# the vectorized version above is equivalent:
# for p in range(img.size(0)):
#     for i in range(img.size(2)):
#         for j in range(img.size(3)):
#             scaler = int(img[p][0][i][j] / 0.0078125)
#             t1 = scaler * 0.0078125
#             t2 = (scaler + 1) * 0.0078125
#             img[p][0][i][j] = t1 if abs(img[p][0][i][j] - t1) < abs(img[p][0][i][j] - t2) else t2
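
# A quick sanity check of Quantize (illustrative values, not from the dataset):
#   >>> Quantize(torch.tensor([0.5030]))
#   tensor([0.5000])   # 0.5030/0.0078125 = 64.38 -> candidates 64/128 = 0.5000
#                      # and 65/128 = 0.5078; 0.5000 is the nearer grid point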
def Binaryconv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return BinarizeConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                          padding=1, bias=False)

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, do_bntan=True):
        super(BasicBlock, self).__init__()
        self.conv1 = Binaryconv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.tanh1 = nn.Hardtanh(inplace=True)
        self.conv2 = Binaryconv3x3(planes, planes)
        self.tanh2 = nn.Hardtanh(inplace=True)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.do_bntan = do_bntan
        self.stride = stride

    def forward(self, x):
        residual = x.clone()
        x = Quantize(x)  # snap activations to the 1/128 grid before the binarized conv
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.tanh1(out)
        out = Quantize(out)
        out = self.conv2(out)
        if self.downsample is not None:
            if residual.data.max() > 1:
                import pdb; pdb.set_trace()  # debug trap: activations left [-1, 1]
            residual = self.downsample(residual)
        out += residual
        if self.do_bntan:
            out = self.bn2(out)
            out = self.tanh2(out)
        return out
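
# Note: Hardtanh clamps activations to [-1, 1], so the Quantize calls above snap
# them onto a 1/128 fixed-point grid before each binarized convolution
# (presumably to match a fixed-point inference target; the original code does
# not state this).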
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()

    def _make_layer(self, block, planes, blocks, stride=1, do_bntan=True):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                BinarizeConv2d(self.inplanes, planes * block.expansion,
                               kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks - 1):
            layers.append(block(self.inplanes, planes))
        # The last block can skip its trailing BatchNorm/Hardtanh (do_bntan=False).
        layers.append(block(self.inplanes, planes, do_bntan=do_bntan))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = Quantize(x)
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.bn1(x)
        x = self.tanh1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.bn2(x)
        x = self.tanh2(x)
        # Reshape to NxCx1x1 so the 1x1 BinarizeConv2d can act as the classifier
        # head. Use x.size(0) rather than a hard-coded batch size of 32, which
        # would break on the final, smaller batch.
        x = x.view(x.size(0), -1, 1, 1)
        x = self.fc(x)
        x = x.view(x.size(0), -1)
        x = self.bn3(x)
        x = self.logsoftmax(x)
        return x
class ResNet_cifar10(ResNet):
    def __init__(self, num_classes=8, block=BasicBlock, depth=18):
        super(ResNet_cifar10, self).__init__()
        self.inflate = 5
        self.inplanes = 16 * self.inflate
        n = int((depth - 2) / 6)  # blocks per stage; depth=18 -> n=2
        self.conv1 = BinarizeConv2d(1, 16 * self.inflate, kernel_size=3, stride=1,
                                    padding=1, bias=False)
        self.maxpool = lambda x: x  # identity; no pooling for these small inputs
        self.bn1 = nn.BatchNorm2d(16 * self.inflate)
        self.tanh1 = nn.Hardtanh(inplace=True)
        self.tanh2 = nn.Hardtanh(inplace=True)
        self.layer1 = self._make_layer(block, 16 * self.inflate, n)
        self.layer2 = self._make_layer(block, 32 * self.inflate, n, stride=2)
        self.layer3 = self._make_layer(block, 64 * self.inflate, n, stride=2, do_bntan=False)
        self.layer4 = lambda x: x  # identity; only three stages are used
        self.avgpool = nn.AvgPool2d(8)
        self.bn2 = nn.BatchNorm1d(256 * self.inflate)
        self.bn3 = nn.BatchNorm1d(num_classes)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        # self.fc = BinarizeLinear(256*self.inflate, num_classes)
        self.fc = BinarizeConv2d(256 * self.inflate, num_classes, kernel_size=1)
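
# Shape walk-through for a 1x68x68 input with inflate=5 (our reading of the
# configuration above; worth re-checking if the input size changes):
#   conv1 (3x3, stride 1):  1x68x68   -> 80x68x68
#   layer1 (stride 1):      80x68x68  -> 80x68x68
#   layer2 (stride 2):      80x68x68  -> 160x34x34
#   layer3 (stride 2):      160x34x34 -> 320x17x17
#   avgpool(8):             320x17x17 -> 320x2x2
#   flatten:                320*2*2 = 1280 = 256*inflate features, matching bn2.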
def main():
    train_set = DatasetFolder("pose_data2/train", loader=lambda x: Image.open(x),
                              extensions=(".bmp",), transform=train_tfm)
    test_set = DatasetFolder("pose_data2/test", loader=lambda x: Image.open(x),
                             extensions=(".bmp",), transform=test_tfm)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
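    # DatasetFolder expects one subdirectory per class, e.g.
    #   pose_data2/train/<class_name>/*.bmp
    # with 8 class folders to match num_classes=8 (directory names assumed here).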
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet_cifar10(num_classes=8, block=BasicBlock, depth=18)
    model = model.to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # The network already ends in LogSoftmax, so use NLLLoss; CrossEntropyLoss
    # would apply log-softmax a second time.
    criterion = nn.NLLLoss()
    model_path = "model.ckpt"
    for epoch in range(num_epoch):
        running_loss = 0.0
        total = 0
        correct = 0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            total += labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
        train_acc = correct / total
        print(f"[ Train | {epoch + 1:03d}/{num_epoch:03d} ] loss = {running_loss:.5f}, acc = {train_acc:.5f}")
        torch.save(model.state_dict(), model_path)
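    # model.ckpt is overwritten each epoch, so the evaluation below sees the
    # weights from the final epoch (no best-accuracy selection is done).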
    # Reload the saved weights into a fresh model for evaluation.
    model = ResNet_cifar10(num_classes=8, block=BasicBlock, depth=18)
    model = model.to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        correct_2 = 0
        stat = np.zeros((8, 8))  # confusion counts, filled only for misclassified samples
        os.makedirs("test_result", exist_ok=True)
        for i, data in enumerate(test_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Two-way accuracy: classes 0-3 and classes 4-7 each count as one group.
            # Iterate over labels.size(0), not batch_size: the last batch may be smaller.
            for b in range(labels.size(0)):
                if predicted[b] < 4:
                    if labels[b] < 4:
                        correct_2 += 1
                else:
                    if labels[b] >= 4:
                        correct_2 += 1
            # Save every misclassified image and tally it in the confusion matrix.
            for k in range(labels.size(0)):
                if predicted[k] != labels[k]:
                    img = inputs[k].mul(255).byte()
                    img = img.cpu().numpy().squeeze(0)  # 1xHxW -> HxW grayscale
                    predict = predicted[k].cpu().numpy()
                    label = labels[k].cpu().numpy()
                    path = "test_result/predict:" + str(predict) + "_labels:" + str(label) + ".jpg"
                    stat[int(label)][int(predict)] += 1
                    cv2.imwrite(path, img)
        print(stat)
    # Confusion heatmap over misclassified samples (the diagonal stays zero).
    ax = sns.heatmap(stat, linewidth=0.5)
    plt.xlabel('Prediction')
    plt.ylabel('Label')
    plt.savefig('heatmap.jpg')
    print('Test 2-class Accuracy: {} %'.format((correct_2 / total) * 100))
    print('Test Accuracy: {} %'.format((correct / total) * 100))
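
# Usage sketch (paths and GPU index are taken from the code above):
#   python this_script.py
# expects pose_data2/{train,test}/<class>/*.bmp on disk, and writes model.ckpt,
# test_result/*.jpg, and heatmap.jpg to the working directory.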
if __name__ == '__main__':
    main()