| 作者 | 修订时间 |
|---|---|
| 2025-10-31 15:44:29 |
除了使用torchvision.models进行预训练以外,还有一个常见的预训练模型库,叫做timm,这个库是由Ross Wightman创建的。里面提供了许多计算机视觉的SOTA模型,可以当作是torchvision的扩充版本,并且里面的模型在准确度上也较高。在本章内容中,我们主要是针对这个库的预训练模型的使用做叙述,其他部分内容(数据扩增,优化器等)如果大家感兴趣,可以参考以下两个链接。
关于timm的安装,我们可以选择以下两种方式进行:
pip install timm
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .
timm.list_models()方法查看timm提供的预训练模型(注:本章测试代码均是在jupyter notebook上进行)
import timm
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models)
592
timm.list_models()传入想查询的模型名称(模糊查询),比如我们想查询densenet系列的所有模型。
all_densnet_models = timm.list_models("*densenet*")
all_densnet_models
我们发现以列表的形式返回了所有densenet系列的所有模型。
['densenet121',
'densenet121d',
'densenet161',
'densenet169',
'densenet201',
'densenet264',
'densenet264d_iabn',
'densenetblur121d',
'tv_densenet121']
default_cfg属性来进行查看,具体操作如下
model = timm.create_model('resnet34',num_classes=10,pretrained=True)
model.default_cfg
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bilinear',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv1',
'classifier': 'fc',
'architecture': 'resnet34'}
除此之外,我们可以通过访问这个链接 查看提供的预训练模型的准确度等信息。
在得到我们想要使用的预训练模型后,我们可以通过timm.create_model()的方法来进行模型的创建,我们可以通过传入参数pretrained=True,来使用预训练模型。同样的,我们也可以使用跟torchvision里面的模型一样的方法查看模型的参数,类型/
import timm
import torch
model = timm.create_model('resnet34',pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape
torch.Size([1, 1000])
model = timm.create_model('resnet34',pretrained=True)
list(dict(model.named_children())['conv1'].parameters())
```python [Parameter containing: tensor([[[[-2.9398e-02, -3.6421e-02, -2.8832e-02, …, -1.8349e-02, -6.9210e-03, 1.2127e-02], [-3.6199e-02, -6.0810e-02, -5.3891e-02, …, -4.2744e-02, -7.3169e-03, -1.1834e-02], … [ 8.4563e-03, -1.7099e-02, -1.2176e-03, …, 7.0081e-02, 2.9756e-02, -4.1400e-03]]]], requires_grad=True)]
- 修改模型(将1000类改为10类输出)
```python
model = timm.create_model('resnet34',num_classes=10,pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape
torch.Size([1, 10])
in_chans=1来改变
model = timm.create_model('resnet34',num_classes=10,pretrained=True,in_chans=1)
x = torch.randn(1,1,224,224)
output = model(x)
timm库所创建的模型是torch.model的子类,我们可以直接使用torch库中内置的模型参数保存和加载的方法,具体操作如下方代码所示
torch.save(model.state_dict(),'./checkpoint/timm_model.pth')
model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))