You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

314 lines
9.8 KiB

  1. import torch
  2. import numpy as np
  3. import matplotlib
  4. #matplotlib.use('TkAgg')
  5. from matplotlib import pyplot as plt
  6. import cv2, os, sys
  7. from torch.utils.data import Dataset
  8. from torch.utils.data import ConcatDataset, DataLoader, Subset
  9. import torch.nn as nn
  10. import torchvision.transforms as transforms
  11. from torchvision.datasets import DatasetFolder
  12. from PIL import Image
  13. from SimBinaryNetpytorch.models.binarized_modules import CimSimConv2d
  14. from SimBinaryNetpytorch.models.binarized_modules import BinarizeConv2d, IdealCimConv2d
  15. from BinaryNetpytorch.models.binarized_modules import BinarizeConv2d as BConv2d
  16. from BinaryNetpytorch.models.binarized_modules import Binarize,HingeLoss
  17. import seaborn as sns
  18. import random
  19. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  20. os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  21. batch_size = 8
  22. num_epoch = 150
  23. LEARN_RATE = 0.001
  24. seed = 333
  25. torch.manual_seed(seed)
  26. torch.cuda.manual_seed(seed)
  27. torch.cuda.manual_seed_all(seed)
  28. np.random.seed(seed)
  29. random.seed(seed)
  30. torch.backends.cudnn.benchmark = False
  31. torch.backends.cudnn.deterministic = True
  32. H = [32, 64, 128]
  33. H = [16, 32, 64]
  34. RANDOM_WEIGHT_PER_EPOCH = 10
  35. RANDOM_RATE = 1
  36. class Classifier(nn.Module):
  37. def __init__(self):
  38. super(Classifier, self).__init__()
  39. conv = CimSimConv2d
  40. conv2 = BConv2d
  41. conv3 = IdealCimConv2d
  42. conv3 = BinarizeConv2d
  43. #conv = nn.Conv2d
  44. self.cnn_layers1 = nn.Sequential(
  45. #conv(in_channels=1, out_channels=128, kernel_size=7),
  46. #nn.BatchNorm2d(128),
  47. #nn.LeakyReLU(0.5),
  48. #conv(in_channels=128, out_channels=64, kernel_size=3),
  49. #nn.BatchNorm2d(64),
  50. #nn.LeakyReLU(0.5),
  51. #input_size(1,30,40)
  52. conv(1, H[0], 3, 1), #output_size(16,66,66)
  53. #nn.BatchNorm2d(128),
  54. nn.LeakyReLU(0.5),
  55. #nn.Dropout(0.2),
  56. nn.MaxPool2d(kernel_size = 2), #output_size(16,33,33)
  57. )
  58. self.cnn_layers2 = nn.Sequential(
  59. conv(H[0], H[1], 3, 1), #output_size(24,31,31)
  60. #nn.BatchNorm2d(64),
  61. nn.LeakyReLU(0.5),
  62. #nn.Dropout(0.2),
  63. nn.MaxPool2d(kernel_size = 2), #output_size(24,15,15)
  64. )
  65. self.cnn_layers3 = nn.Sequential(
  66. conv(H[1], H[2], 3, 1), #output_size(32,13,13)
  67. #nn.BatchNorm2d(32),
  68. nn.LeakyReLU(0.5),
  69. #nn.Dropout(0.2),
  70. nn.MaxPool2d(kernel_size = 2), #ouput_size(32,6,6)
  71. #nn.LogSoftmax(),
  72. #BinarizeConv2d(H[2], 8, (3,2), 1) #ouput_size(4,2,3) without max :(32,24,34)
  73. conv2(H[2], 8, (3,2), 1) #ouput_size(4,2,3) without max :(32,24,34)
  74. #conv(32, 8, (2,1), 1) #ouput_size(4,2,3) without max :(32,24,34)
  75. )
  76. def forward(self, x):
  77. #print(x)
  78. #print("input",float(torch.min(x)),float(torch.max(x)))
  79. x = self.cnn_layers1(x)
  80. #print("layer1",float(torch.min(x)),float(torch.max(x)))
  81. #print(x)
  82. x = self.cnn_layers2(x)
  83. #print("layer2",float(torch.min(x)),float(torch.max(x)))
  84. x = self.cnn_layers3(x)
  85. #print("layer3",float(torch.min(x)),float(torch.max(x)))
  86. #print(x)
  87. #x = x.flatten(1)
  88. #x = self.fc_layers(x)
  89. #print(x.shape)
  90. x = x.view(x.size(0), -1)
  91. #print(x.shape)
  92. #x = nn.LogSoftmax(x)
  93. #print(x)
  94. return x
  95. class Cls_Dataset(Dataset):
  96. def __init__(self, input, target):
  97. self.input = input
  98. self.target = target
  99. def __getitem__(self, index):
  100. return self.input[index], self.target[index]
  101. def __len__(self):
  102. return len(self.input)
  103. def Load_data(path, cls):
  104. data = []
  105. label = []
  106. f = 0
  107. for c in range(cls):
  108. t = 0
  109. img = []
  110. data_path = path + "/0"+str(c)+"/"
  111. for filename in os.listdir(data_path):
  112. # with open(data_path + "/" + filename, "rb") as f_in:
  113. tmp_input = cv2.imread(data_path + filename,cv2.IMREAD_UNCHANGED)
  114. #print(tmp_input)
  115. tmp_input = cv2.resize(tmp_input, (30,40), interpolation = cv2.INTER_AREA)
  116. tmp_input = tmp_input.astype(int)
  117. tmp_input = tmp_input//2
  118. tmp_input = tmp_input - 63
  119. tmp_input = np.where(tmp_input > 63, 63, tmp_input)
  120. tmp_input = tmp_input[:, :, np.newaxis]
  121. tmp_input = tmp_input.transpose(2,0,1)
  122. if img is not None:
  123. img.append(tmp_input)
  124. else:
  125. img = tmp_input
  126. label_tmp = np.full((len(img),1),c)
  127. if f != 0:
  128. data = np.append(data,img,axis=0)
  129. else:
  130. data = img
  131. if f != 0:
  132. label = np.append(label,label_tmp,axis=0)
  133. else:
  134. label = label_tmp
  135. f = 1
  136. label = np.squeeze(label)
  137. label = np.array(label)
  138. #print(data)
  139. data = np.array(data)
  140. #print(label.shape)
  141. data = data.astype('float32')
  142. return data, label
  143. def main():
  144. # test_set = DatasetFolder("./dataset/8cls_srcnnimg/test", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
  145. # val_set = DatasetFolder("./dataset/8cls_srcnnimg/train", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
  146. global num_epoch
  147. train_data, train_label = Load_data("./dataset/8cls_grideye/train",8)
  148. train = Cls_Dataset(train_data, train_label)
  149. train_loader = DataLoader(
  150. train, batch_size=100, shuffle=True,
  151. num_workers=4, pin_memory=True, drop_last=True
  152. )
  153. test_data, test_label = Load_data("./dataset/8cls_grideye/test",8)
  154. test = Cls_Dataset(test_data, test_label)
  155. test_loader = DataLoader(
  156. test, batch_size=100, shuffle=True,
  157. num_workers=4, pin_memory=True, drop_last=True
  158. )
  159. val_data, val_label = Load_data("./dataset/8cls_grideye/val",8)
  160. val = Cls_Dataset(val_data, val_label)
  161. val_loader = DataLoader(
  162. val, batch_size=100, shuffle=True,
  163. num_workers=4, pin_memory=True, drop_last=True
  164. )
  165. save_path = 'models.ckpt'
  166. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  167. model = Classifier().to(device)
  168. #optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
  169. optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  170. try:
  171. checkpoint = torch.load('training_state_pre.bin')
  172. model.load_state_dict(checkpoint['state_dict'])
  173. #optimizer.load_state_dict(checkpoint['optimizer'])
  174. #model.load_state_dict(torch.load('models_pre.ckpt'))
  175. num_epoch = 150
  176. except:
  177. print('cannot read from pretrained model.')
  178. criterion = nn.CrossEntropyLoss()
  179. best_accuracy = 0.0
  180. last_save = 0
  181. for epoch in range(num_epoch):
  182. model.train()
  183. running_loss = 0.0
  184. total = 0
  185. correct = 0
  186. for i, data in enumerate(train_loader):
  187. inputs, labels = data
  188. inputs = inputs.to(device)
  189. labels = labels.to(device)
  190. optimizer.zero_grad()
  191. outputs = model(inputs)
  192. #print(outputs.shape)
  193. loss = criterion(outputs, labels)
  194. loss.backward()
  195. #print(model)
  196. #print(model.cnn_layers3[3].weight.grad)
  197. for p in list(model.parameters()):
  198. if hasattr(p,'org'):
  199. p.data.copy_(p.org)
  200. optimizer.step()
  201. for p in list(model.parameters()):
  202. if hasattr(p,'org'):
  203. p.org.copy_(p.data.clamp_(-1,1))
  204. running_loss += loss.item()
  205. total += labels.size(0)
  206. _,predicted = torch.max(outputs.data,1)
  207. #print(predicted)
  208. #print("label",labels)
  209. correct += (predicted == labels).sum().item()
  210. train_acc = correct / total
  211. print(f"[ Train | {epoch + 1:03d}/{num_epoch:03d} ] loss = {running_loss:.5f}, acc = {train_acc:.5f}")
  212. model.eval()
  213. with torch.no_grad():
  214. correct = 0
  215. total = 0
  216. for i, data in enumerate(val_loader):
  217. inputs, labels = data
  218. inputs = inputs.to(device)
  219. labels = labels.to(device)
  220. outputs = model(inputs)
  221. _,predicted = torch.max(outputs.data,1)
  222. total += labels.size(0)
  223. correct += (predicted == labels).sum().item()
  224. val_acc = correct / total
  225. if val_acc > best_accuracy:
  226. last_save = epoch
  227. best_accuracy = val_acc
  228. torch.save(model.state_dict(), save_path)
  229. state = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
  230. torch.save(state, 'training_state.bin')
  231. print("Save Model")
  232. print(f"[ Val | {epoch + 1:03d}/{num_epoch:03d} ] acc = {val_acc:.5f}")
  233. model = Classifier().to(device)
  234. model.load_state_dict(torch.load(save_path))
  235. model.eval()
  236. stat = np.zeros((8,8))
  237. with torch.no_grad():
  238. correct = 0
  239. total = 0
  240. print(model)
  241. for i, data in enumerate(test_loader):
  242. inputs, labels = data
  243. inputs = inputs.to(device)
  244. labels = labels.to(device)
  245. outputs = model(inputs)
  246. #print(outputs.data)
  247. _,predicted = torch.max(outputs.data,1)
  248. #print(predicted)
  249. total += labels.size(0)
  250. correct += (predicted == labels).sum().item()
  251. #print(labels)
  252. #print(outputs.size())
  253. print('Test Accuracy:{} %'.format((correct / total) * 100))
  254. print('Save at', last_save+1)
  255. if __name__ == '__main__':
  256. main()