import torch
import numpy as np
import cv2, os, sys
import pandas as pd
from torch.utils.data import Dataset
from matplotlib import pyplot as plt
from torch.utils.data import ConcatDataset, DataLoader, Subset
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import DatasetFolder
from PIL import Image
import torchvision.models as models

batch_size = 32
num_epoch = 1
torch.cuda.set_device(1)
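
# Note: torch.cuda.set_device(1) assumes at least two visible GPUs; on a
# single-GPU machine the index would need to be 0, or the call can be dropped
# in favour of the `device` object chosen in main().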
train_tfm = transforms.Compose([
    transforms.Grayscale(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((68, 68)),
    transforms.ToTensor(),
    #transforms.RandomResizedCrop((40,30)),
    #transforms.TenCrop((40,30)),
    #transforms.Normalize(0.5,0.5),
])

test_tfm = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor()
])
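
# Only the training transform augments (random flip + RandomResizedCrop to
# 68x68); validation/test images are just converted to single-channel tensors
# at their native resolution. That size mismatch works here because
# torchvision's ResNet ends in adaptive average pooling, but it does assume
# all images in a validation/test batch share the same original size.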
# Alternative small CNN kept for reference; unused because main() builds a
# ResNet-50 instead.
'''
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.cnn_layers = nn.Sequential(
            # input_size (1, 30, 40)
            nn.Conv2d(1, 16, 3, 1),       # output_size (16, 28, 38)
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool2d(kernel_size=2),  # output_size (16, 14, 19)
            nn.Conv2d(16, 24, 3, 1),      # output_size (24, 12, 17)
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool2d(kernel_size=2),  # output_size (24, 6, 8)
            nn.Conv2d(24, 32, 3, 1),      # output_size (32, 4, 6)
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool2d(kernel_size=2)   # output_size (32, 2, 3)
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(32 * 2 * 3, 32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(32, 8)
        )

    def forward(self, x):
        x = self.cnn_layers(x)
        x = x.flatten(1)
        x = self.fc_layers(x)
        return x
'''
def main():
    # Images are organized as pose_data2/{train,val,test}/<class_name>/*.bmp,
    # which DatasetFolder maps to integer class labels.
    train_set = DatasetFolder("pose_data2/train", loader=lambda x: Image.open(x), extensions="bmp", transform=train_tfm)
    test_set = DatasetFolder("pose_data2/test", loader=lambda x: Image.open(x), extensions="bmp", transform=test_tfm)
    valid_set = DatasetFolder("pose_data2/val", loader=lambda x: Image.open(x), extensions="bmp", transform=test_tfm)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True)

    model_path = "model.ckpt"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ResNet-50 trained from scratch (no pretrained weights), with the first
    # convolution adapted to 1-channel grayscale input and the classifier head
    # replaced for 8 pose classes.
    model = models.resnet50()
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    model.fc = nn.Linear(2048, 8)
    model = model.to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    best_acc = -1
    for epoch in range(num_epoch):
        ## Training
        model.train()  # re-enable dropout/batch-norm updates after validation
        running_loss = 0.0
        total = 0
        correct = 0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            total += labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            #print(predicted)
            #print("label", labels)
            correct += (predicted == labels).sum().item()
        train_acc = correct / total
        print(f"[ Train | {epoch + 1:03d}/{num_epoch:03d} ] loss = {running_loss:.5f}, acc = {train_acc:.5f}")
        ## Validation
        model.eval()
        valid_loss = 0.0
        total = 0
        correct = 0
        for i, data in enumerate(valid_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            valid_loss += loss.item()  # accumulate into valid_loss, not the training running_loss
            total += labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
        valid_acc = correct / total
        print(f"[ Valid | {epoch + 1:03d}/{num_epoch:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")

        # Keep the checkpoint with the best validation accuracy.
        if valid_acc > best_acc:
            best_acc = valid_acc
            torch.save(model.state_dict(), model_path)
            print('saving model with acc {:.3f}'.format(valid_acc))
    ## Testing
    # Rebuild the architecture and load the best checkpoint saved above.
    model = models.resnet50()
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    model.fc = nn.Linear(2048, 8)
    model = model.to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for i, data in enumerate(test_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # for k in range(batch_size):
            #     if predicted[k] != labels[k]:
            #         print(inputs[k])
            #print(predicted)
            #print("labels:", labels)
    print('Test Accuracy: {} %'.format((correct / total) * 100))
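
# The helper below is a minimal sketch of how the saved checkpoint could be
# used for single-image inference; it is not called anywhere in this script,
# and `predict_image` / `image_path` are illustrative names rather than part
# of the original code.
def predict_image(image_path, model, device):
    """Return the predicted class index for one image file."""
    model.eval()
    image = Image.open(image_path)
    x = test_tfm(image).unsqueeze(0).to(device)  # same preprocessing as the test set, plus a batch dim
    with torch.no_grad():
        logits = model(x)
    return logits.argmax(dim=1).item()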

if __name__ == '__main__':
    main()