You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

154 lines
3.6 KiB

  1. import torch
  2. import numpy as np
  3. import cv2, os, sys
  4. import pandas as pd
  5. from torch.utils.data import Dataset
  6. from matplotlib import pyplot as plt
  7. from torch.utils.data import ConcatDataset, DataLoader, Subset
  8. import torch.nn as nn
  9. import torchvision.transforms as transforms
  10. from torchvision.datasets import DatasetFolder
  11. from PIL import Image
  12. import torchvision.models as models
  13. batch_size = 32
  14. num_epoch = 10
  15. train_tfm = transforms.Compose([
  16. transforms.Grayscale(),
  17. transforms.RandomResizedCrop((40,30)),
  18. transforms.Resize((40, 30)),
  19. transforms.ToTensor(),
  20. #transforms.TenCrop((40,30)),
  21. #transforms.Normalize(0.5,0.5),
  22. ])
  23. test_tfm = transforms.Compose([
  24. transforms.Grayscale(),
  25. transforms.Resize((40, 30)),
  26. transforms.ToTensor()
  27. ])
  28. '''
  29. class Classifier(nn.Module):
  30. def __init__(self):
  31. super(Classifier, self).__init__()
  32. self.cnn_layers = nn.Sequential(
  33. #input_size(1,30,40)
  34. nn.Conv2d(1, 16, 3, 1), #output_size(16,28,38)
  35. nn.BatchNorm2d(16),
  36. nn.ReLU(),
  37. nn.Dropout(0.2),
  38. nn.MaxPool2d(kernel_size = 2), #output_size(16,14,19)
  39. nn.Conv2d(16, 24, 3, 1), #output_size(24,12,17)
  40. nn.BatchNorm2d(24),
  41. nn.ReLU(),
  42. nn.Dropout(0.2),
  43. nn.MaxPool2d(kernel_size = 2), #output_size(24,6,8)
  44. nn.Conv2d(24, 32, 3, 1), #output_size(32,4,6)
  45. nn.BatchNorm2d(32),
  46. nn.ReLU(),
  47. nn.Dropout(0.2),
  48. nn.MaxPool2d(kernel_size = 2) #ouput_size(32,2,3)
  49. )
  50. self.fc_layers = nn.Sequential(
  51. nn.Linear(32 * 2 * 3, 32),
  52. nn.ReLU(),
  53. nn.Dropout(0.2),
  54. nn.Linear(32,8)
  55. )
  56. def forward(self, x):
  57. x = self.cnn_layers(x)
  58. x = x.flatten(1)
  59. x = self.fc_layers(x)
  60. return x
  61. '''
  62. def main():
  63. train_set = DatasetFolder("./dataset/data_0705/lepton/train", loader=lambda x: Image.open(x), extensions="bmp", transform=train_tfm)
  64. test_set = DatasetFolder("./dataset/data_0705/lepton/test", loader=lambda x: Image.open(x), extensions="bmp", transform=test_tfm)
  65. train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
  66. test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
  67. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  68. model = models.resnet18()
  69. model.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=3,
  70. bias=False)
  71. model.fc = nn.Linear(512, 3)
  72. model = model.to(device)
  73. print(model)
  74. optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  75. criterion = nn.CrossEntropyLoss()
  76. for epoch in range(num_epoch):
  77. ##Training
  78. running_loss = 0.0
  79. total = 0
  80. correct = 0
  81. for i, data in enumerate(train_loader):
  82. inputs, labels = data
  83. inputs = inputs.to(device)
  84. labels = labels.to(device)
  85. optimizer.zero_grad()
  86. outputs = model(inputs)
  87. loss = criterion(outputs, labels)
  88. loss.backward()
  89. optimizer.step()
  90. running_loss += loss.item()
  91. total += labels.size(0)
  92. _,predicted = torch.max(outputs.data,1)
  93. #print(predicted)
  94. #print("label",labels)
  95. correct += (predicted == labels).sum().item()
  96. train_acc = correct / total
  97. print(f"[ Train | {epoch + 1:03d}/{num_epoch:03d} ] loss = {running_loss:.5f}, acc = {train_acc:.5f}")
  98. ##Testing
  99. model.eval()
  100. with torch.no_grad():
  101. correct = 0
  102. total = 0
  103. for i, data in enumerate(test_loader):
  104. inputs, labels = data
  105. inputs = inputs.to(device)
  106. labels = labels.to(device)
  107. outputs = model(inputs)
  108. _,predicted = torch.max(outputs.data,1)
  109. total += labels.size(0)
  110. correct += (predicted == labels).sum().item()
  111. #print(predicted)
  112. #print("labels:",labels)
  113. print('Test Accuracy:{} %'.format((correct / total) * 100))
  114. if __name__ == '__main__':
  115. main()