sotabench.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import os
  2. import numpy as np
  3. import PIL
  4. import torch
  5. from torch.utils.data import DataLoader
  6. import torchvision.transforms as transforms
  7. from torchvision.datasets import ImageNet
  8. from model.efficientnet_pytorch import EfficientNet
  9. from sotabencheval.image_classification import ImageNetEvaluator
  10. from sotabencheval.utils import is_server
  11. if is_server():
  12. DATA_ROOT = DATA_ROOT = os.environ.get('IMAGENET_DIR', './imagenet') # './.data/vision/imagenet'
  13. else: # local settings
  14. DATA_ROOT = os.environ['IMAGENET_DIR']
  15. assert bool(DATA_ROOT), 'please set IMAGENET_DIR environment variable'
  16. print('Local data root: ', DATA_ROOT)
  17. model_name = 'EfficientNet-B5'
  18. model = EfficientNet.from_pretrained(model_name.lower())
  19. image_size = EfficientNet.get_image_size(model_name.lower())
  20. input_transform = transforms.Compose([
  21. transforms.Resize(image_size, PIL.Image.BICUBIC),
  22. transforms.CenterCrop(image_size),
  23. transforms.ToTensor(),
  24. transforms.Normalize(
  25. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  26. ])
  27. test_dataset = ImageNet(
  28. DATA_ROOT,
  29. split="val",
  30. transform=input_transform,
  31. target_transform=None,
  32. )
  33. test_loader = DataLoader(
  34. test_dataset,
  35. batch_size=128,
  36. shuffle=False,
  37. num_workers=4,
  38. pin_memory=True,
  39. )
  40. model = model.cuda()
  41. model.eval()
  42. evaluator = ImageNetEvaluator(model_name=model_name,
  43. paper_arxiv_id='1905.11946')
  44. def get_img_id(image_name):
  45. return image_name.split('/')[-1].replace('.JPEG', '')
  46. with torch.no_grad():
  47. for i, (input, target) in enumerate(test_loader):
  48. input = input.to(device='cuda', non_blocking=True)
  49. target = target.to(device='cuda', non_blocking=True)
  50. output = model(input)
  51. image_ids = [get_img_id(img[0]) for img in test_loader.dataset.imgs[i*test_loader.batch_size:(i+1)*test_loader.batch_size]]
  52. evaluator.add(dict(zip(image_ids, list(output.cpu().numpy()))))
  53. if evaluator.cache_exists:
  54. break
  55. if not is_server():
  56. print("Results:")
  57. print(evaluator.get_results())
  58. evaluator.save()