torchvision上的模型都是基于ImageNet的,想将其用于MNIST也很简单,只需要一点点修改(以resnet18为例):

1
2
model = tv.models.resnet18(num_classes = 10)
model.conv1 = Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) #因为mnist是单通道的,而ImageNet是3通道

end