You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

211 lines
6.5 KiB

  1. import torch
  2. import numpy as np
  3. import cv2, os, sys
  4. import pandas as pd
  5. from torch.utils.data import Dataset
  6. from matplotlib import pyplot as plt
  7. from torch.utils.data import ConcatDataset, DataLoader, Subset
  8. import torch.nn as nn
  9. import torchvision.transforms as transforms
  10. from torchvision.datasets import DatasetFolder
  11. from PIL import Image
  12. import torchvision.models
  13. import BinaryNetpytorch.models as models
  14. from BinaryNetpytorch.models.binarized_modules import BinarizeLinear,BinarizeConv2d
  15. batch_size = 32
  16. num_epoch = 10
  17. train_tfm = transforms.Compose([
  18. # transforms.RandomHorizontalFlip(),
  19. # transforms.RandomResizedCrop((40,30)),
  20. transforms.Grayscale(),
  21. transforms.Resize((40, 30)),
  22. transforms.ToTensor(),
  23. #transforms.RandomResizedCrop((40,30)),
  24. #transforms.TenCrop((40,30)),
  25. # transforms.Normalize(0.5,0.5),
  26. ])
  27. test_tfm = transforms.Compose([
  28. transforms.Grayscale(),
  29. transforms.Resize((40, 30)),
  30. transforms.ToTensor()
  31. ])
  32. def Binaryconv3x3(in_planes, out_planes, stride=1):
  33. "3x3 convolution with padding"
  34. return BinarizeConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  35. padding=1, bias=False)
  36. def conv3x3(in_planes, out_planes, stride=1):
  37. "3x3 convolution with padding"
  38. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  39. padding=1, bias=False)
  40. class BasicBlock(nn.Module):
  41. expansion = 1
  42. def __init__(self, inplanes, planes, stride=1, downsample=None,do_bntan=True):
  43. super(BasicBlock, self).__init__()
  44. self.conv1 = Binaryconv3x3(inplanes, planes, stride)
  45. self.bn1 = nn.BatchNorm2d(planes)
  46. self.tanh1 = nn.Hardtanh(inplace=True)
  47. self.conv2 = Binaryconv3x3(planes, planes)
  48. self.tanh2 = nn.Hardtanh(inplace=True)
  49. self.bn2 = nn.BatchNorm2d(planes)
  50. self.downsample = downsample
  51. self.do_bntan=do_bntan
  52. self.stride = stride
  53. def forward(self, x):
  54. residual = x.clone()
  55. out = self.conv1(x)
  56. out = self.bn1(out)
  57. out = self.tanh1(out)
  58. out = self.conv2(out)
  59. if self.downsample is not None:
  60. if residual.data.max()>1:
  61. import pdb; pdb.set_trace()
  62. residual = self.downsample(residual)
  63. out += residual
  64. if self.do_bntan:
  65. out = self.bn2(out)
  66. out = self.tanh2(out)
  67. return out
  68. class ResNet(nn.Module):
  69. def __init__(self):
  70. super(ResNet, self).__init__()
  71. def _make_layer(self, block, planes, blocks, stride=1,do_bntan=True):
  72. downsample = None
  73. if stride != 1 or self.inplanes != planes * block.expansion:
  74. downsample = nn.Sequential(
  75. BinarizeConv2d(self.inplanes, planes * block.expansion,
  76. kernel_size=1, stride=stride, bias=False),
  77. nn.BatchNorm2d(planes * block.expansion),
  78. )
  79. layers = []
  80. layers.append(block(self.inplanes, planes, stride, downsample))
  81. self.inplanes = planes * block.expansion
  82. for i in range(1, blocks-1):
  83. layers.append(block(self.inplanes, planes))
  84. layers.append(block(self.inplanes, planes,do_bntan=do_bntan))
  85. return nn.Sequential(*layers)
  86. def forward(self, x):
  87. x = self.conv1(x)
  88. x = self.maxpool(x)
  89. x = self.bn1(x)
  90. x = self.tanh1(x)
  91. x = self.layer1(x)
  92. x = self.layer2(x)
  93. x = self.layer3(x)
  94. x = self.layer4(x)
  95. x = self.avgpool(x)
  96. x = x.view(x.size(0), -1)
  97. x = self.bn2(x)
  98. x = self.tanh2(x)
  99. x = self.fc(x)
  100. x = self.bn3(x)
  101. x = self.logsoftmax(x)
  102. return x
  103. class ResNet_cifar10(ResNet):
  104. def __init__(self, num_classes=3,
  105. block=BasicBlock, depth=18):
  106. super(ResNet_cifar10, self).__init__()
  107. self.inflate = 5
  108. self.inplanes = 16*self.inflate
  109. n = int((depth - 2) / 6)
  110. self.conv1 = BinarizeConv2d(1, 16*self.inflate, kernel_size=3, stride=1, padding=1,
  111. bias=False)
  112. self.maxpool = lambda x: x
  113. self.bn1 = nn.BatchNorm2d(16*self.inflate)
  114. self.tanh1 = nn.Hardtanh(inplace=True)
  115. self.tanh2 = nn.Hardtanh(inplace=True)
  116. self.layer1 = self._make_layer(block, 16*self.inflate, n)
  117. self.layer2 = self._make_layer(block, 32*self.inflate, n, stride=2)
  118. self.layer3 = self._make_layer(block, 64*self.inflate, n, stride=2,do_bntan=False)
  119. self.layer4 = lambda x: x
  120. self.avgpool = nn.AvgPool2d(8)
  121. self.bn2 = nn.BatchNorm1d(64*self.inflate)
  122. self.bn3 = nn.BatchNorm1d(3)
  123. self.logsoftmax = nn.LogSoftmax()
  124. self.fc = BinarizeLinear(64*self.inflate, 3)
  125. def main():
  126. train_set = DatasetFolder("pose_data/training/labeled", loader=lambda x: Image.open(x), extensions="bmp", transform=train_tfm)
  127. test_set = DatasetFolder("pose_data/testing", loader=lambda x: Image.open(x), extensions="bmp", transform=test_tfm)
  128. train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
  129. test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
  130. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  131. model = ResNet_cifar10(num_classes=3,block=BasicBlock,depth=18)
  132. model = model.to(device)
  133. print(model)
  134. optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  135. criterion = nn.CrossEntropyLoss()
  136. for epoch in range(num_epoch):
  137. running_loss = 0.0
  138. total = 0
  139. correct = 0
  140. for i, data in enumerate(train_loader):
  141. inputs, labels = data
  142. inputs = inputs.to(device)
  143. labels = labels.to(device)
  144. optimizer.zero_grad()
  145. outputs = model(inputs)
  146. loss = criterion(outputs, labels)
  147. loss.backward()
  148. optimizer.step()
  149. running_loss += loss.item()
  150. total += labels.size(0)
  151. _,predicted = torch.max(outputs.data,1)
  152. #print(predicted)
  153. #print("label",labels)
  154. correct += (predicted == labels).sum().item()
  155. train_acc = correct / total
  156. print(f"[ Train | {epoch + 1:03d}/{num_epoch:03d} ] loss = {running_loss:.5f}, acc = {train_acc:.5f}")
  157. model.eval()
  158. with torch.no_grad():
  159. correct = 0
  160. total = 0
  161. for i, data in enumerate(test_loader):
  162. inputs, labels = data
  163. inputs = inputs.to(device)
  164. labels = labels.to(device)
  165. outputs = model(inputs)
  166. _,predicted = torch.max(outputs.data,1)
  167. total += labels.size(0)
  168. correct += (predicted == labels).sum().item()
  169. #print(predicted)
  170. #print("labels:",labels)
  171. print('Test Accuracy:{} %'.format((correct / total) * 100))
  172. if __name__ == '__main__':
  173. main()