123456789101112131415161718192021222324 |
import os

import torch
import torch.onnx

# Alternative backbones; only mobilenetv3 is used by the export below.
from model.efficientnet_pytorch import EfficientNet
import model.resnext_pytorch as resnext
from model.ghostnet import ghostnet
from model.mobilenetv3 import mobilenetv3

# Export a trained MobileNetV3-small classifier checkpoint to ONNX.
#
# During training no softmax was appended, so the logits could feed the
# nn.CrossEntropyLoss() criterion. For deployment, the softmax is meant to be
# appended at the end of the network's forward() before exporting. This script
# does NOT add it itself; presumably mobilenetv3.MobileNetV3.forward is edited
# for export. TODO confirm before relying on probabilities from the ONNX model.
#
# Previously used alternative (EfficientNet backbone):
#   model = EfficientNet.from_name('efficientnet-b0', in_channels=3, num_classes=3, image_size=224)
#   model.set_swish(memory_efficient=False)

# Directory holding the training checkpoint; the ONNX file is written next to it.
EXPORT_DIR = r'D:\多模态信息\train_re\20240515-数据增强-2'
CHECKPOINT_PATH = os.path.join(EXPORT_DIR, 'CP_epoch31.pth')
ONNX_PATH = os.path.join(EXPORT_DIR, 'classification.onnx')

# Spatial size used for both the model constructor and the tracing input.
INPUT_SIZE = 256

model = mobilenetv3.MobileNetV3(n_class=9, input_size=INPUT_SIZE, channels=3, mode='small')

# Export always runs on CPU. The previous expression
# ('cpu' if torch.cuda.is_available() else 'cpu') selected CPU on both branches.
device = torch.device('cpu')

# Per the original author, speed differed between batches because of abnormal
# (denormal) values. set_flush_denormal(True) flushes denormal floats to zero
# on CPU where supported, which removes that effect.
torch.set_flush_denormal(True)

model.to(device=device)
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
# Trace in inference mode so any dropout / batch-norm layers behave as at deploy time.
model.eval()

# Random example input used only to trace the graph:
# batch 1, 3 channels, INPUT_SIZE x INPUT_SIZE.
dummy_input = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE).to(device)
torch.onnx.export(model, dummy_input, ONNX_PATH, verbose=True)
|