model_to_onnx.py 1.2 KB

123456789101112131415161718192021222324
  1. import torch.onnx
  2. from model.efficientnet_pytorch import EfficientNet
  3. import model.resnext_pytorch as resnext
  4. from model.ghostnet import ghostnet
  5. from model.mobilenetv3 import mobilenetv3
  6. import os
  # --- Export a trained MobileNetV3-small classifier checkpoint to ONNX ---
#
# Softmax: the network was trained without a final softmax layer so that it
# works with the training criterion nn.CrossEntropyLoss (which takes raw
# logits). The author's note says softmax should be appended as the last step
# of the network's forward() before converting. This script does NOT add it.
# NOTE(review): confirm forward() in model/mobilenetv3 has been edited
# accordingly before export, otherwise the ONNX model outputs logits.

# Training-run directory: the checkpoint is read from here and the ONNX file
# is written here. Windows path; os.path.join inserts "\" on Windows, giving
# the same full paths as before.
RUN_DIR = r'D:\多模态信息\train_re\20240515-数据增强-2'
CHECKPOINT_PATH = os.path.join(RUN_DIR, 'CP_epoch31.pth')
# Previously used checkpoint location, relative to the working directory:
# CHECKPOINT_PATH = os.path.join(os.getcwd(), 'checkpoints', 'mobilenetv3', 'CP_epoch31.pth')
ONNX_PATH = os.path.join(RUN_DIR, 'classification.onnx')

NUM_CLASSES = 9
IN_CHANNELS = 3
# Spatial size used both to build the model and to shape the tracing input,
# so the two values cannot drift apart.
IMAGE_SIZE = 256

model = mobilenetv3.MobileNetV3(
    n_class=NUM_CLASSES, input_size=IMAGE_SIZE, channels=IN_CHANNELS, mode='small'
)
# Alternative backbone (EfficientNet-B0). If switching, also set
# NUM_CLASSES = 3 and IMAGE_SIZE = 224 so the tracing input matches:
# model = EfficientNet.from_name('efficientnet-b0', in_channels=IN_CHANNELS, num_classes=NUM_CLASSES, image_size=IMAGE_SIZE)
# model.set_swish(memory_efficient=False)

# Export is pinned to the CPU; CUDA is never selected by this script.
device = torch.device('cpu')

# Flush denormal floats to zero on the CPU. The author observed that speed
# varied between batches because of such abnormal values, and enabling this
# removed them.
torch.set_flush_denormal(True)
model.to(device=device)

# torch.load unpickles the file: only load checkpoints from a trusted source.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
# Eval mode so any train-only behaviour (e.g. dropout, batch-norm statistic
# updates) is not captured in the exported graph.
model.eval()

# Tracing input: batch of 1, IN_CHANNELS x IMAGE_SIZE x IMAGE_SIZE, allocated
# directly on the export device. No dynamic_axes are passed, so the exported
# graph's shapes (including batch size 1) are expected to be fixed to this
# input's shape.
dummy_input = torch.randn(1, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE, device=device)
torch.onnx.export(model, dummy_input, ONNX_PATH, verbose=True)