hubconf.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. from model.efficientnet_pytorch import EfficientNet as _EfficientNet
  2. dependencies = ['torch']
  3. def _create_model_fn(model_name):
  4. def _model_fn(num_classes=1000, in_channels=3, pretrained='imagenet'):
  5. """Create Efficient Net.
  6. Described in detail here: https://arxiv.org/abs/1905.11946
  7. Args:
  8. num_classes (int, optional): Number of classes, default is 1000.
  9. in_channels (int, optional): Number of input channels, default
  10. is 3.
  11. pretrained (str, optional): One of [None, 'imagenet', 'advprop']
  12. If None, no pretrained model is loaded.
  13. If 'imagenet', models trained on imagenet dataset are loaded.
  14. If 'advprop', models trained using adversarial training called
  15. advprop are loaded. It is important to note that the
  16. preprocessing required for the advprop pretrained models is
  17. slightly different from normal ImageNet preprocessing
  18. """
  19. model_name_ = model_name.replace('_', '-')
  20. if pretrained is not None:
  21. model = _EfficientNet.from_pretrained(
  22. model_name=model_name_,
  23. advprop=(pretrained == 'advprop'),
  24. num_classes=num_classes,
  25. in_channels=in_channels)
  26. else:
  27. model = _EfficientNet.from_name(
  28. model_name=model_name_,
  29. override_params={'num_classes': num_classes},
  30. )
  31. model._change_in_channels(in_channels)
  32. return model
  33. return _model_fn
  34. for model_name in ['efficientnet_b' + str(i) for i in range(9)]:
  35. locals()[model_name] = _create_model_fn(model_name)