You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

109 lines
3.3 KiB

  1. import torch
  2. import pdb
  3. import torch.nn as nn
  4. import math
  5. from torch.autograd import Variable
  6. from torch.autograd import Function
  7. import numpy as np
  8. def Binarize(tensor,quant_mode='det'):
  9. if quant_mode=='det':
  10. return tensor.sign()
  11. else:
  12. return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)
  13. class HingeLoss(nn.Module):
  14. def __init__(self):
  15. super(HingeLoss,self).__init__()
  16. self.margin=1.0
  17. def hinge_loss(self,input,target):
  18. #import pdb; pdb.set_trace()
  19. output=self.margin-input.mul(target)
  20. output[output.le(0)]=0
  21. return output.mean()
  22. def forward(self, input, target):
  23. return self.hinge_loss(input,target)
  24. class SqrtHingeLossFunction(Function):
  25. def __init__(self):
  26. super(SqrtHingeLossFunction,self).__init__()
  27. self.margin=1.0
  28. def forward(self, input, target):
  29. output=self.margin-input.mul(target)
  30. output[output.le(0)]=0
  31. self.save_for_backward(input, target)
  32. loss=output.mul(output).sum(0).sum(1).div(target.numel())
  33. return loss
  34. def backward(self,grad_output):
  35. input, target = self.saved_tensors
  36. output=self.margin-input.mul(target)
  37. output[output.le(0)]=0
  38. import pdb; pdb.set_trace()
  39. grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)
  40. grad_output.mul_(output.ne(0).float())
  41. grad_output.div_(input.numel())
  42. return grad_output,grad_output
def Quantize(tensor, quant_mode='det', params=None, numBits=8):
    """Quantize `tensor` onto a 2**(numBits-1)-step fixed-point grid.

    The clamp is applied to `tensor` in place; the quantized result is a
    new tensor that is returned.

    quant_mode: 'det' rounds deterministically; any other value adds
        uniform noise in [-0.5, 0.5) and runs `quant_fixed`.
    params: forwarded to `quant_fixed` in the stochastic branch only.
    """
    # In-place clamp to the representable range before scaling.
    tensor.clamp_(-2**(numBits-1),2**(numBits-1))
    if quant_mode=='det':
        # Scale up, round to integers, scale back down.
        tensor=tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))
    else:
        # NOTE(review): the noise is added AFTER round(); true stochastic
        # rounding normally dithers BEFORE rounding — confirm intent.
        tensor=tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))
        # NOTE(review): `quant_fixed` is not defined or imported anywhere
        # in the visible part of this file — this branch raises NameError
        # unless it exists elsewhere. Its return value is also discarded.
        quant_fixed(tensor, params)
    return tensor
  51. #import torch.nn._functions as tnnf
  52. class BinarizeLinear(nn.Linear):
  53. def __init__(self, *kargs, **kwargs):
  54. super(BinarizeLinear, self).__init__(*kargs, **kwargs)
  55. def forward(self, input):
  56. # if input.size(1) != 784:
  57. # input.data=Binarize(input.data)
  58. if not hasattr(self.weight,'org'):
  59. self.weight.org=self.weight.data.clone()
  60. self.weight.data=Binarize(self.weight.org)
  61. out = nn.functional.linear(input, self.weight)
  62. if not self.bias is None:
  63. self.bias.org=self.bias.data.clone()
  64. out += self.bias.view(1, -1).expand_as(out)
  65. return out
  66. class BinarizeConv2d(nn.Conv2d):
  67. def __init__(self, *kargs, **kwargs):
  68. super(BinarizeConv2d, self).__init__(*kargs, **kwargs)
  69. def forward(self, input):
  70. # if input.size(1) != 3:
  71. # input.data = Binarize(input.data)
  72. if not hasattr(self.weight,'org'):
  73. self.weight.org=self.weight.data.clone()
  74. self.weight.data=Binarize(self.weight.org)
  75. out = nn.functional.conv2d(input, self.weight, None, self.stride,
  76. self.padding, self.dilation, self.groups)
  77. if not self.bias is None:
  78. self.bias.org=self.bias.data.clone()
  79. out += self.bias.view(1, -1, 1, 1).expand_as(out)
  80. return out
  81. # x = torch.tensor([[255.0, 200.0, 201.0], [210.0, 222.0, 223.0]])
  82. # print(Quantize(x,quant_mode='det', params=None, numBits=8))