You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

207 lines
5.9 KiB

  1. import torch
  2. import numpy as np
  3. import cv2, os, sys
  4. from torch.utils.data import Dataset
  5. from matplotlib import pyplot as plt
  6. from torch.utils.data import ConcatDataset, DataLoader, Subset
  7. import torch.nn as nn
  8. import torchvision.transforms as transforms
  9. from torchvision.datasets import DatasetFolder
  10. from PIL import Image
  11. from BinaryNetpytorch.models.binarized_modules import BinarizeLinear,BinarizeConv2d
  12. from BinaryNetpytorch.models.binarized_modules import Binarize,HingeLoss
  13. import seaborn as sns
  14. import random
  15. batch_size = 8
  16. num_epoch = 10
  17. seed = 777
  18. torch.manual_seed(seed)
  19. torch.cuda.manual_seed(seed)
  20. torch.cuda.manual_seed_all(seed)
  21. np.random.seed(seed)
  22. random.seed(seed)
  23. torch.backends.cudnn.benchmark = False
  24. torch.backends.cudnn.deterministic = True
  25. train_tfm = transforms.Compose([
  26. #transforms.Grayscale(),
  27. #transforms.RandomHorizontalFlip(),
  28. #transforms.RandomResizedCrop((40,30)),
  29. #transforms.RandomCrop((40,30)),
  30. #transforms.RandomHorizontalFlip(),
  31. transforms.ToTensor(),
  32. #transforms.RandomResizedCrop((40,30)),
  33. #transforms.TenCrop((40,30)),
  34. #transforms.Normalize(0.5,0.5),
  35. ])
  36. test_tfm = transforms.Compose([
  37. #transforms.Grayscale(),
  38. transforms.ToTensor()
  39. ])
  40. class Classifier(nn.Module):
  41. def __init__(self):
  42. super(Classifier, self).__init__()
  43. self.cnn_layers = nn.Sequential(
  44. # BinarizeConv2d(in_channels=1, out_channels=128, kernel_size=9, padding=9//2, bias=False),
  45. # nn.BatchNorm2d(128),
  46. # nn.ReLU(),
  47. # BinarizeConv2d(in_channels=128, out_channels=64, kernel_size=1, padding=1//2, bias=False),
  48. # nn.BatchNorm2d(64),
  49. #input_size(1,30,40)
  50. BinarizeConv2d(1, 128, 3, 1), #output_size(16,28,38)
  51. nn.BatchNorm2d(128),
  52. nn.ReLU(),
  53. #nn.Dropout(0.2),
  54. nn.MaxPool2d(kernel_size = 2), #output_size(16,14,19)
  55. BinarizeConv2d(128, 64, 3, 1), #output_size(24,12,17)
  56. nn.BatchNorm2d(64),
  57. nn.ReLU(),
  58. #nn.Dropout(0.2),
  59. nn.MaxPool2d(kernel_size = 2), #output_size(24,6,8)
  60. BinarizeConv2d(64, 32, 3, 1), #output_size(32,4,6)
  61. nn.BatchNorm2d(32),
  62. nn.ReLU(),
  63. #nn.Dropout(0.2),
  64. nn.MaxPool2d(kernel_size = 2), #ouput_size(32,2,3)
  65. #nn.LogSoftmax(),
  66. BinarizeConv2d(32, 3, (3,2), 1) #ouput_size(4,2,3) without max :(32,24,34)
  67. )
  68. def forward(self, x):
  69. x = self.cnn_layers(x)
  70. #x = x.flatten(1)
  71. #x = self.fc_layers(x)
  72. #print(x.shape)
  73. x = x.view(x.size(0), -1)
  74. #print(x.shape)
  75. #x = nn.LogSoftmax(x)
  76. #print(x)
  77. return x
  78. def main():
  79. train_set = DatasetFolder("./dataset/data_0711/grideye/train", loader=lambda x: Image.open(x), extensions="bmp", transform=train_tfm)
  80. test_set = DatasetFolder("./dataset/data_0711/grideye/test", loader=lambda x: Image.open(x), extensions="bmp", transform=test_tfm)
  81. val_set = DatasetFolder("./dataset/data_0711/grideye/train", loader=lambda x: Image.open(x), extensions="bmp", transform=test_tfm)
  82. train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
  83. test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
  84. val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)
  85. save_path = 'models.ckpt'
  86. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  87. model = Classifier().to(device)
  88. optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
  89. criterion = nn.CrossEntropyLoss()
  90. best_accuracy = 0.0
  91. for epoch in range(num_epoch):
  92. running_loss = 0.0
  93. total = 0
  94. correct = 0
  95. for i, data in enumerate(train_loader):
  96. inputs, labels = data
  97. inputs = inputs.to(device)
  98. labels = labels.to(device)
  99. #print(labels)
  100. optimizer.zero_grad()
  101. outputs = model(inputs)
  102. #print(outputs.shape)
  103. loss = criterion(outputs, labels)
  104. loss.backward()
  105. for p in list(model.parameters()):
  106. if hasattr(p,'org'):
  107. p.data.copy_(p.org)
  108. optimizer.step()
  109. for p in list(model.parameters()):
  110. if hasattr(p,'org'):
  111. p.org.copy_(p.data.clamp_(-1,1))
  112. running_loss += loss.item()
  113. total += labels.size(0)
  114. _,predicted = torch.max(outputs.data,1)
  115. #print(predicted)
  116. #print("label",labels)
  117. correct += (predicted == labels).sum().item()
  118. train_acc = correct / total
  119. print(f"[ Train | {epoch + 1:03d}/{num_epoch:03d} ] loss = {running_loss:.5f}, acc = {train_acc:.5f}")
  120. model.eval()
  121. with torch.no_grad():
  122. correct = 0
  123. total = 0
  124. for i, data in enumerate(val_loader):
  125. inputs, labels = data
  126. inputs = inputs.to(device)
  127. labels = labels.to(device)
  128. outputs = model(inputs)
  129. _,predicted = torch.max(outputs.data,1)
  130. total += labels.size(0)
  131. correct += (predicted == labels).sum().item()
  132. val_acc = correct / total
  133. if val_acc > best_accuracy:
  134. best_accuracy = val_acc
  135. torch.save(model.state_dict(), save_path)
  136. print("Save Model")
  137. print(f"[ Val | {epoch + 1:03d}/{num_epoch:03d} ] acc = {val_acc:.5f}")
  138. model = Classifier().to(device)
  139. model.load_state_dict(torch.load(save_path))
  140. model.eval()
  141. stat = np.zeros((3,3))
  142. with torch.no_grad():
  143. correct = 0
  144. total = 0
  145. print(model)
  146. for i, data in enumerate(test_loader):
  147. inputs, labels = data
  148. inputs = inputs.to(device)
  149. labels = labels.to(device)
  150. outputs = model(inputs)
  151. #print(outputs.data)
  152. _,predicted = torch.max(outputs.data,1)
  153. #print(predicted)
  154. total += labels.size(0)
  155. correct += (predicted == labels).sum().item()
  156. for k in range(len(predicted)):
  157. if predicted[k] != labels[k]:
  158. img = inputs[k].mul(255).byte()
  159. img = img.cpu().numpy().squeeze(0)
  160. img = np.moveaxis(img, 0, -1)
  161. predict = predicted[k].cpu().numpy()
  162. label = labels[k].cpu().numpy()
  163. path = "test_result/predict:"+str(predict)+"_labels:"+str(label)+".jpg"
  164. stat[int(label)][int(predict)] += 1
  165. ax = sns.heatmap(stat, linewidth=0.5)
  166. plt.xlabel('Prediction')
  167. plt.ylabel('Label')
  168. plt.savefig('heatmap.jpg')
  169. #print(predicted)
  170. #print("labels:",labels)
  171. print('Test Accuracy:{} %'.format((correct / total) * 100))
  172. if __name__ == '__main__':
  173. main()