wjlin0

作者 修订时间
wjlin0 2025-10-31 15:44:29

6.3 模型微调 - timm

除了使用torchvision.models进行预训练以外,还有一个常见的预训练模型库,叫做timm,这个库是由Ross Wightman创建的。里面提供了许多计算机视觉的SOTA模型,可以当作是torchvision的扩充版本,并且里面的模型在准确度上也较高。在本章内容中,我们主要是针对这个库的预训练模型的使用做叙述,其他部分内容(数据扩增,优化器等)如果大家感兴趣,可以参考以下两个链接。

6.3.1 timm的安装

关于timm的安装,我们可以选择以下两种方式进行:

  1. 通过pip安装
    pip install timm
    
  2. 通过源码编译安装
    git clone https://github.com/rwightman/pytorch-image-models
    cd pytorch-image-models && pip install -e .
    

6.3.2 如何查看预训练模型种类

  1. 查看timm提供的预训练模型 截止到2022.3.27日为止,timm提供的预训练模型已经达到了592个,我们可以通过timm.list_models()方法查看timm提供的预训练模型(注:本章测试代码均是在jupyter notebook上进行)
    import timm
    avail_pretrained_models = timm.list_models(pretrained=True)
    len(avail_pretrained_models)
    
592
  1. 查看特定模型的所有种类 每一种系列可能对应着不同方案的模型,比如Resnet系列就包括了ResNet18,50,101等模型,我们可以在timm.list_models()传入想查询的模型名称(模糊查询),比如我们想查询densenet系列的所有模型。
    all_densnet_models = timm.list_models("*densenet*")
    all_densnet_models
    

    我们发现以列表的形式返回了所有densenet系列的所有模型。

    ['densenet121',
     'densenet121d',
     'densenet161',
     'densenet169',
     'densenet201',
     'densenet264',
     'densenet264d_iabn',
     'densenetblur121d',
     'tv_densenet121']
    
  2. 查看模型的具体参数 当我们想查看下模型的具体参数的时候,我们可以通过访问模型的default_cfg属性来进行查看,具体操作如下
    model = timm.create_model('resnet34',num_classes=10,pretrained=True)
    model.default_cfg
    
    {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
     'num_classes': 1000,
     'input_size': (3, 224, 224),
     'pool_size': (7, 7),
     'crop_pct': 0.875,
     'interpolation': 'bilinear',
     'mean': (0.485, 0.456, 0.406),
     'std': (0.229, 0.224, 0.225),
     'first_conv': 'conv1',
     'classifier': 'fc',
     'architecture': 'resnet34'}
    

    除此之外,我们可以通过访问这个链接 查看提供的预训练模型的准确度等信息。

6.3.3 使用和修改预训练模型

在得到我们想要使用的预训练模型后,我们可以通过timm.create_model()的方法来进行模型的创建,我们可以通过传入参数pretrained=True,来使用预训练模型。同样的,我们也可以使用跟torchvision里面的模型一样的方法查看模型的参数,类型/

import timm
import torch

model = timm.create_model('resnet34',pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape
torch.Size([1, 1000])
- 修改模型(将1000类改为10类输出)
```python
model = timm.create_model('resnet34',num_classes=10,pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape
torch.Size([1, 10])

参考材料

  1. https://www.aiuai.cn/aifarm1967.html
  2. https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055
  3. https://chowdera.com/2022/03/202203170834122729.html