
Under ICLR 2024 double-blind review

An autoencoder extracts latent representations from the parameters of trained models. A diffusion model is then trained on these latent parameter representations so that, starting from random noise, it can synthesize new representations; the autoencoder's decoder turns those representations back into neural network parameters.

Neural Network Diffusion

A first look at diffusion models

A diffusion model consists of two processes, a forward process and a reverse process:

  • The forward process keeps adding Gaussian noise (governed by $\beta$) to the original image; after $T$ steps the result is essentially random Gaussian noise (as $T\to \infty$, the limit is pure noise).

$$
q(x_{1:T}\vert x_0)=\prod_{t=1}^{T} q(x_t\vert x_{t-1}),\qquad q(x_t\vert x_{t-1})=\mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\big)
$$

    • $q(\cdot)$: the forward process
    • $\mathcal{N}(\cdot)$: a Gaussian distribution
    • $\beta$: the noise-schedule constraint
    • $I$: the identity matrix
  • The reverse process runs the forward process backwards: the hope is to train a denoising network that strips the noise from $x_T$ step by step until the original image is recovered.

$$
p_\theta(x_{t-1}\vert x_t)=\mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t,t),\ \Sigma_\theta(x_t,t)\big)
$$

    • $p_\theta(\cdot)$: the reverse process, with learnable parameters $\theta$
    • $\mu_\theta(\cdot)$: the mean of the Gaussian estimated via $\theta$
    • $\Sigma_\theta(\cdot)$: the covariance of the Gaussian estimated via $\theta$
  • Optimizing the denoising network:

$$
\min_\theta\; D_{KL}\big(q(x_{t-1}\vert x_t,x_0)\,\Vert\, p_\theta(x_{t-1}\vert x_t)\big)
$$

    $D_{KL}(\cdot\Vert \cdot)$ measures the gap between the two distributions with the KL divergence.

What makes diffusion models work is that the reverse process can learn a denoising network that transforms a plain Gaussian distribution into the target distribution we actually want.
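A minimal sketch of the forward process in code, assuming DDPM's default linear $\beta$ schedule (an assumption, not something stated above): the closed form $q(x_t\vert x_0)=\mathcal{N}(\sqrt{\bar\alpha_t}\,x_0,(1-\bar\alpha_t)I)$ lets us jump straight to step $t$ in one shot.

```python
import torch

# Forward process sketch: sample x_t directly from x_0 via the closed form,
# under an assumed linear beta schedule (DDPM's default range).
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x0: torch.Tensor, t: int) -> torch.Tensor:
    """Apply t steps of Gaussian noising to x0 in a single shot."""
    eps = torch.randn_like(x0)
    return alphas_bar[t].sqrt() * x0 + (1.0 - alphas_bar[t]).sqrt() * eps

x0 = torch.randn(3, 32, 32)   # a toy "image"
x_T = q_sample(x0, T - 1)     # essentially pure Gaussian noise as t -> T
```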

Overall architecture

(Figure: the overall framework, a parameter autoencoder plus a latent diffusion model that generates parameter representations)

Parameter autoencoder

First, collect $K$ models that train to good performance; their parameters are denoted $S=[s_1,\ldots,s_K]$. Flatten each into a vector to get $V=[v_1,\ldots,v_K]$, then extract the parameters' latent features with the encoder:
$$
Z=[z_1,\ldots,z_K]=f_{encoder}(V,\sigma)
$$
The extracted latent features $Z$ are then fed into the decoder to reconstruct the parameters:
$$
V^{\prime}=[v_1^{\prime},\ldots,v_K^{\prime}]=f_{decoder}(Z,\rho)
$$
where $\sigma,\rho$ are the parameters of the encoder and decoder.

Optimization minimizes the MSE:
$$
L_{MSE}=\frac{1}{K}\sum_{k=1}^{K}\big\Vert v_k-v_k^{\prime}\big\Vert^2
$$
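A minimal sketch of this autoencoder trained with the $L_{MSE}$ above. The sizes and the plain `Linear` layers are assumptions made for brevity; the paper's encoder and decoder are built from 1D convolutions.

```python
import torch
import torch.nn as nn

# Parameter autoencoder sketch: V holds K flattened parameter vectors; the
# encoder maps them to latents Z, the decoder reconstructs V'.
param_dim, latent_dim = 5130, 256  # assumed sizes
encoder = nn.Sequential(nn.Linear(param_dim, 1024), nn.ReLU(), nn.Linear(1024, latent_dim))
decoder = nn.Sequential(nn.Linear(latent_dim, 1024), nn.ReLU(), nn.Linear(1024, param_dim))
opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

V = torch.randn(200, param_dim)        # stand-in for K = 200 flattened checkpoints
for _ in range(1000):
    Z = encoder(V)                     # latent parameter representations
    V_rec = decoder(Z)                 # reconstructed parameters V'
    loss = ((V - V_rec) ** 2).mean()   # the L_MSE objective
    opt.zero_grad()
    loss.backward()
    opt.step()
```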

Parameter generation

Feeding the raw parameters $V$ straight into the diffusion process and having it output reconstructed parameters $V^{\prime}$ directly would incur a huge memory cost, especially when $V$ is high-dimensional.

The authors therefore run diffusion on the latent representations, training the denoising network with the optimization procedure from DDPM:
$$
L_{DDPM}=\mathbb{E}_{t,\,z_0,\,\epsilon}\Big[\big\Vert \epsilon-\epsilon_\theta\big(\sqrt{\bar\alpha_t}\,z_0+\sqrt{1-\bar\alpha_t}\,\epsilon,\ t\big)\big\Vert^2\Big]
$$

  • $\epsilon$: Gaussian noise
  • $\theta$: the denoising network's parameters
  • $\epsilon_\theta$: the noise predicted by the denoising network
  • $t$: the timestep
  • $\bar \alpha _t$: the cumulative noise strength at step $t$
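A minimal training-step sketch of this objective, run on latent representations $z_0$ from the (frozen) encoder. The tiny MLP denoiser and the crude timestep conditioning are assumptions; the paper's denoiser is a 1D-conv U-Net.

```python
import torch
import torch.nn as nn

# One DDPM training step on latents: noise z_0 to a random step t, then
# regress the injected noise, i.e. minimize || eps - eps_theta ||^2.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

latent_dim = 256  # assumed
denoiser = nn.Sequential(nn.Linear(latent_dim + 1, 512), nn.SiLU(), nn.Linear(512, latent_dim))
opt = torch.optim.Adam(denoiser.parameters(), lr=1e-4)

z0 = torch.randn(200, latent_dim)               # stand-in latents
for _ in range(100):
    t = torch.randint(0, T, (z0.size(0),))      # a random timestep per sample
    eps = torch.randn_like(z0)
    a = alphas_bar[t].unsqueeze(1)
    z_t = a.sqrt() * z0 + (1 - a).sqrt() * eps  # noised latents
    t_feat = (t.float() / T).unsqueeze(1)       # crude timestep conditioning
    eps_pred = denoiser(torch.cat([z_t, t_feat], dim=1))
    loss = ((eps - eps_pred) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
```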

Experiments

Setup

  1. Datasets

    MNIST (LeCun et al., 1998), CIFAR-10/100 (Krizhevsky et al., 2009), ImageNet-1K (Deng et al., 2009), STL-10 (Coates et al., 2011), Flowers (Nilsback & Zisserman, 2008), Pets (Parkhi et al., 2012), F-101 (Bossard et al., 2014)

  2. Architectures

    The first experiments use small models made of convolutional, pooling, and fully connected layers:

    (Table: the small CNN architectures used in the early experiments)

    The convolutions were initially 2D, following DDPM (whose U-Net uses 2D convolutions to generate high-quality images), but this worked poorly; a likely reason is that image pixels and network parameters cannot be treated alike, so the authors switched to 1D convolutions (see the Conv1d sketch after this list). The comparison:

    (Table: 1D conv vs. 2D conv)

    While switching, the authors also considered replacing the convolutions with fully connected (FC) layers outright. The two perform about the same, but 1D conv has a lower storage cost than FC, so 1D conv was kept:

    (Table: 1D conv vs. FC)

    An ablation further found that $K=200$ input models gives the best performance:

    (Table: ablation over the number of input models K)

    It was while scaling the architecture up that the authors ran into the severe storage-cost problem; taking inspiration from Stable Diffusion, they added an autoencoder that extracts latent features, reducing the dimensionality of the model parameters.

  3. Preparing the training data

    200 independent, high-performing parameter sets are prepared to train the diffusion network (DiffNet). For simple architectures with few parameters, models are trained from scratch; for complex ones, training starts from a pretrained model.

  4. Training details

    The autoencoder is trained for 2,000 epochs first, and the latent features together with the decoder's parameters are saved.

    The diffusion model is then trained to generate the representations; its architecture is a U-Net built from 1D convolutions.

  5. Inference

    Feed 100 noise samples into the diffusion model to generate 100 models, and keep the network that performs best on the training set (a sketch of the sampling loop also follows this list). The overall performance:

    (Figure: performance of the generated models across datasets and architectures)
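As promised in the architecture notes above, here is a sketch of why the parameters are treated as 1D sequences: a flattened parameter vector has no 2D spatial structure, and a `Conv1d` slides over it with far fewer weights than an FC layer of the same width. The channel count and kernel size below are assumptions.

```python
import torch
import torch.nn as nn

# Treat a flattened parameter vector as a length-N sequence with one channel,
# not as a 2D pixel grid.
x = torch.randn(8, 1, 5130)                 # (batch, channels, flattened params)
conv1d = nn.Conv1d(1, 16, kernel_size=5, padding=2)
print(conv1d(x).shape)                      # torch.Size([8, 16, 5130])

# A same-width FC layer (5130 -> 5130) would need about 26M weights, while
# this Conv1d has only 1 * 16 * 5 weights plus biases: the storage argument.
```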
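And here is a minimal sketch of the inference stage from item 5: DDPM ancestral sampling that turns pure noise into latent parameter representations. The tiny MLP denoiser, the latent size, and the linear beta schedule are assumptions; the real model is the trained 1D-conv U-Net, and its outputs would still go through the saved decoder before the best of the 100 candidates is selected on the training set.

```python
import torch
import torch.nn as nn

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)

denoiser = nn.Sequential(nn.Linear(257, 512), nn.SiLU(), nn.Linear(512, 256))

@torch.no_grad()
def sample_latents(n: int = 100) -> torch.Tensor:
    z = torch.randn(n, 256)                     # start from pure noise
    for t in reversed(range(T)):
        t_feat = torch.full((n, 1), t / T)      # crude timestep conditioning
        eps = denoiser(torch.cat([z, t_feat], dim=1))
        # posterior mean: (z - (1 - a_t) / sqrt(1 - abar_t) * eps) / sqrt(a_t)
        z = (z - (1 - alphas[t]) / (1 - alphas_bar[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:                               # add noise except at the last step
            z = z + betas[t].sqrt() * torch.randn_like(z)
    return z

latents = sample_latents(100)                   # decode these, then pick the best
```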

Code reading

Preparing the training data

The author obtains the encoder's training data by training a ResNet18; this is the parameter-input part of the figure below:

(Figure: the parameter-collection / input stage of the overall framework)

The code is in tsak_training.py; the core part is this method:

```python
# override the abstract method in base_task.py; this is where the model
# data for generation is obtained
def train_for_data(self):
```

Training runs for 400 epochs in total:

  1. First, train the ResNet18 for 200 epochs and save it:

```python
if i == (epoch - 1):
    # save the whole model at epoch 199
    print("saving the model")
    torch.save(net, os.path.join(tmp_path, "whole_model.pth"))
    # freeze the layers that will not be trained (no gradients); the rest of
    # training only updates the layers listed in train_layer
    fix_partial_model(train_layer, net)
    parameters = []
```
  2. For the 200 epochs after that, only the layers in train_layer are trained; every other layer's parameters stay frozen. The tracked layers' parameters are recorded each epoch and flushed to a temporary folder every 10 epochs (a hedged sketch of the two helpers used here follows the code):

```python
if i >= epoch:
    # from now on, record the tracked layers' parameters after every epoch
    parameters.append(state_part(train_layer, net))
    save_model_accs.append(acc)
    # once 10 snapshots have accumulated, or training is about to end,
    # flush the parameters to a temporary folder on disk
    if len(parameters) == 10 or i == all_epoch - 1:
        torch.save(parameters, os.path.join(tmp_path, "p_data_{}.pt".format(i)))
        # reset the list
        parameters = []
```
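The helpers `fix_partial_model` and `state_part` are not shown in the excerpt; here is a plausible minimal sketch of what they do (the repo's actual implementations may differ in detail):

```python
def fix_partial_model(train_layer, net):
    # freeze every parameter whose name is NOT in train_layer
    for name, p in net.named_parameters():
        p.requires_grad = name in train_layer

def state_part(train_layer, net):
    # copy only the tracked layers' tensors out of the model
    return {name: p.detach().cpu().clone()
            for name, p in net.named_parameters() if name in train_layer}
```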
  3. The end product is the key file data.pt, which stores the whole model (whole_model.pth) and the training data pdata needed by the autoencoder. Loading data.pt shows:


```python
{'pdata': tensor([[0.3967, 0.3701, 0.3879,  ..., 0.0686, 0.0885, 0.0762],
    [0.3967, 0.3701, 0.3879, ..., 0.0686, 0.0885, 0.0762],
    [0.3967, 0.3701, 0.3878, ..., 0.0686, 0.0885, 0.0762],
    ...,
    [0.3975, 0.3710, 0.3888, ..., 0.0686, 0.0884, 0.0762],
    [0.3975, 0.3710, 0.3888, ..., 0.0686, 0.0884, 0.0762],
    [0.3975, 0.3710, 0.3888, ..., 0.0686, 0.0884, 0.0762]]), 'mean': tensor([0.3973, 0.3707, 0.3885, ..., 0.0686, 0.0884, 0.0762]), 'std': tensor([4.2733e-04, 4.3796e-04, 4.7875e-04, ..., 2.7116e-05, 8.1900e-06,
    2.2222e-05]), 'model': ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
    (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Sequential()
    )
    (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Sequential()
    )
    )
    (layer2): Sequential(
    (0): BasicBlock(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Sequential(
    (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    )
    (1): BasicBlock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Sequential()
    )
    )
    (layer3): Sequential(
    (0): BasicBlock(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Sequential(
    (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    )
    (1): BasicBlock(
    (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Sequential()
    )
    )
    (layer4): Sequential(
    (0): BasicBlock(
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Sequential(
    (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    )
    (1): BasicBlock(
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Sequential()
    )
    )
    (linear): Linear(in_features=512, out_features=100, bias=True)
    ), 'train_layer': ['layer4.1.bn1.weight', 'layer4.1.bn1.bias', 'layer4.1.bn2.bias', 'layer4.1.bn2.weight'], 'performance': [71.79, 71.78, 71.8, 71.85, 71.84, 71.8, 71.85, 71.83, 71.77, 71.86, 71.85, 71.82, 71.81, 71.85, 71.89, 71.8, 71.82, 71.85, 71.78, 71.86, 71.87, 71.87, 71.81, 71.84, 71.84, 71.87, 71.87, 71.82, 71.87, 71.86, 71.86, 71.87, 71.85, 71.86, 71.85, 71.86, 71.83, 71.83, 71.93, 71.91, 71.84, 71.8, 71.88, 71.84, 71.78, 71.81, 71.82, 71.8, 71.84, 71.83, 71.85, 71.85, 71.89, 71.75, 71.84, 71.78, 71.82, 71.9, 71.86, 71.89, 71.81, 71.8, 71.84, 71.86, 71.81, 71.84, 71.86, 71.82, 71.84, 71.76, 71.83, 71.82, 71.87, 71.86, 71.83, 71.87, 71.84, 71.81, 71.85, 71.84, 71.87, 71.76, 71.85, 71.78, 71.75, 71.86, 71.88, 71.83, 71.85, 71.83, 71.86, 71.86, 71.85, 71.85, 71.9, 71.86, 71.84, 71.87, 71.88, 71.86, 71.82, 71.82, 71.84, 71.84, 71.82, 71.89, 71.79, 71.86, 71.84, 71.8, 71.86, 71.85, 71.83, 71.83, 71.84, 71.89, 71.87, 71.86, 71.8, 71.84, 71.83, 71.79, 71.84, 71.9, 71.85, 71.86, 71.88, 71.84, 71.86, 71.86, 71.84, 71.86, 71.78, 71.83, 71.87, 71.89, 71.81, 71.86, 71.77, 71.84, 71.92, 71.82, 71.81, 71.8, 71.78, 71.85, 71.89, 71.81, 71.75, 71.8, 71.81, 71.84, 71.88, 71.8, 71.85, 71.8, 71.85, 71.8, 71.95, 71.85, 71.87, 71.83, 71.87, 71.84, 71.82, 71.87, 71.8, 71.86, 71.81, 71.89, 71.86, 71.84, 71.87, 71.81, 71.87, 71.83, 71.82, 71.88, 71.85, 71.84, 71.79, 71.84, 71.81, 71.84, 71.86, 71.85, 71.85, 71.85, 71.83, 71.81, 71.87, 71.83, 71.88, 71.84, 71.85, 71.81, 71.81, 71.76, 71.89, 71.86], 'cfg': {'name': 'classification', 'data': {'data_root': 'data/cifar100', 'dataset': 'cifar100', 'batch_size': 2048, 'num_workers': 8}, 'model': {'_target_': 'models.resnet.ResNet18', 'num_classes': 100}, 'optimizer': {'_target_': 'torch.optim.SGD', 'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0005}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.MultiStepLR', 'milestones': [60, 120, 160, 200], 'gamma': 0.2}, 'epoch': 200, 'save_num_model': 200, 'train_layer': ['layer4.1.bn1.weight', 'layer4.1.bn1.bias', 'layer4.1.bn2.bias', 'layer4.1.bn2.weight'], 'param': {'data_root': 'param_data/cifar100/data.pt', 'k': 200, 'num_workers': 4}}}

```

Training the diffusion model

The code is `train_p_diff.py`; it has two modes, and `base.yaml` selects whether to train or to test the diffusion model.

The author wraps the code fairly well; the core code lives in the `core` folder.

Experiment results

Let me retrace the experiments in this post: the dataset is CIFAR-10 and the network is ResNet18.

  1. Train ResNet18 on CIFAR-10, save the resulting model, and test it on the test data to get the accuracy:

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/test.py
10.0    # accuracy of a randomly initialized ResNet; random guessing over 10 classes gives ~1/10
86.034  # accuracy after loading the saved state_dict: looks good

Process finished with exit code 0
```

  2. Take the trained ResNet18, choose the train_layer set, and train only those layers for 200 epochs with every other layer frozen (requires_grad = False), collecting the train_layer parameters along the way. If the flattened train_layer length is 5120, the collected data has shape (200, 5120). A diffusion model is then trained on this data.

  3. Use the trained diffusion model to generate parameters: the input noise has shape (200, latent_shape), yielding 200 generated train_layer parameter sets. Splice each into the ResNet18 trained at the start of step 2, replacing the corresponding layers' parameters, and test. Results below:

    This is the result when the corresponding layers are NOT replaced (the call stays commented out):

```python
# model = partial_reverse_tomodel(param, model, train_layer).to(param.device)
```

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/outputs/cifar10/ae_ddpm_cifar10_pth/load.py 
/home/chengyiqiu/miniconda3/envs/pdiff/lib/python3.8/site-packages/torch/nn/modules/instancenorm.py:80: UserWarning: input's size at dim=1 does not match num_features. You can silence this warning by not passing in num_features, which is not used because affine=False
warnings.warn(f"input's size at dim={feature_dim} does not match num_features. "
ae param shape: torch.Size([200, 7178])
Files already downloaded and verified
0%| | 0/200 [00:00<?, ?it/s]/home/chengyiqiu/miniconda3/envs/pdiff/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
100%|██████████| 200/200 [04:04<00:00, 1.22s/it]
    Sorted list of accuracies: [92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67, 92.67]
    Average accuracy: 92.67
    Max accuracy: 92.67
    Min accuracy: 92.67
Median accuracy: 92.67
```

    This is the result with the corresponding layers replaced:

```python
model = partial_reverse_tomodel(param, model, train_layer).to(param.device)
```

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/outputs/cifar10/ae_ddpm_cifar10_pth/load.py 
/home/chengyiqiu/miniconda3/envs/pdiff/lib/python3.8/site-packages/torch/nn/modules/instancenorm.py:80: UserWarning: input's size at dim=1 does not match num_features. You can silence this warning by not passing in num_features, which is not used because affine=False
warnings.warn(f"input's size at dim={feature_dim} does not match num_features. "
ae param shape: torch.Size([200, 7178])
Files already downloaded and verified
0%| | 0/200 [00:00<?, ?it/s]/home/chengyiqiu/miniconda3/envs/pdiff/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
100%|██████████| 200/200 [04:04<00:00, 1.22s/it]
    Sorted list of accuracies: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.01, 10.01, 10.01, 10.01, 10.02, 10.03, 10.04, 10.04, 10.08, 10.08, 10.09, 10.16, 10.63, 10.71, 11.04, 11.08, 11.13, 11.17, 11.68, 11.89, 11.94, 12.84, 12.85, 13.27, 13.54, 13.8, 14.11, 14.44, 14.52, 14.77, 14.89, 15.2, 15.93, 16.25, 16.44, 16.7, 17.29, 18.43, 18.44, 18.62, 18.83, 18.9, 18.98, 19.32, 19.34, 19.68, 19.7, 19.9, 20.19, 20.3, 20.3, 20.33, 20.46, 20.85, 20.96, 21.56, 21.6, 22.38, 22.41, 22.41, 23.13, 23.63, 23.71, 23.73, 24.17, 25.84, 26.6, 26.64, 26.76, 26.83, 27.0, 27.26, 27.52, 27.54, 27.72, 27.77, 27.86, 28.0, 28.15, 28.2, 28.29, 28.29, 28.46, 28.71, 28.74, 28.87, 29.02, 29.14, 29.25, 29.93, 30.24, 30.71, 31.66, 32.23, 32.45, 33.66, 34.37, 36.0, 36.56, 36.71, 37.02, 37.32, 37.33, 39.09, 39.4, 39.76, 40.4, 42.54, 44.09, 44.67, 45.6, 46.82, 48.23, 48.83, 52.8, 53.44, 54.1, 54.59, 59.25, 60.46, 63.95, 64.05, 70.35, 70.59, 70.92, 74.63, 76.78, 78.43, 79.81, 80.77, 82.4, 82.61, 85.24, 85.46, 86.22, 87.14, 87.83, 88.27, 88.64, 88.71, 88.98, 89.03, 89.5, 89.94, 89.97, 90.46]
    Average accuracy: 28.35
    Max accuracy: 90.46
    Min accuracy: 10.00
Median accuracy: 19.69
```

    This time the parameters produced by the diffusion model are mediocre overall, though a few of them still perform quite well.

  4. Swap the trained model for one of the same architecture (ResNet18) on the same dataset but with a backdoor implanted, then replace the corresponding layers' parameters:

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/outputs/cifar10/ae_ddpm_cifar10_pth/load.py 
/home/chengyiqiu/miniconda3/envs/pdiff/lib/python3.8/site-packages/torch/nn/modules/instancenorm.py:80: UserWarning: input's size at dim=1 does not match num_features. You can silence this warning by not passing in num_features, which is not used because affine=False
warnings.warn(f"input's size at dim={feature_dim} does not match num_features. "
ae param shape: torch.Size([200, 7178])
Files already downloaded and verified
0%| | 0/200 [00:00<?, ?it/s]/home/chengyiqiu/miniconda3/envs/pdiff/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
100%|██████████| 200/200 [04:04<00:00, 1.22s/it]
    Sorted list of accuracies: [1.29, 1.63, 1.99, 2.64, 2.93, 2.98, 3.12, 4.07, 4.18, 4.22, 5.14, 5.18, 5.51, 5.56, 5.65, 6.1, 6.3, 6.4, 7.02, 7.07, 7.34, 7.38, 7.91, 8.75, 8.86, 9.13, 9.15, 9.16, 9.5, 9.65, 9.76, 9.76, 9.78, 9.83, 9.84, 9.86, 9.93, 9.95, 9.96, 9.96, 9.98, 9.98, 9.98, 9.99, 9.99, 9.99, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.01, 10.01, 10.01, 10.02, 10.02, 10.03, 10.03, 10.04, 10.04, 10.06, 10.06, 10.06, 10.07, 10.13, 10.15, 10.23, 10.33, 10.43, 10.46, 10.51, 10.69, 10.83, 11.2, 11.37, 11.38, 11.5, 11.69, 11.87, 12.02, 12.32, 12.34, 12.37, 12.41, 13.09, 13.95, 13.96, 14.04, 14.96, 15.39, 15.8, 17.02, 17.15, 17.23, 17.8, 18.74, 19.39, 20.22]
    Average accuracy: 9.94
    Max accuracy: 20.22
    Min accuracy: 1.29
Median accuracy: 10.00
```

    As you can see, there are no high-performing parameters at all.

  5. Finally, test the backdoored ResNet on its own, to confirm that it is the diffusion-generated parameters that cause the performance drop:

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/test.py 
10.0
94.474

Process finished with exit code 0
```

Experiment conclusions

The diffusion model can indeed generate high-performing parameters, but the generated parameters generalize very poorly!

A quick look at why: the paper says 200 high-performance models were used, but in fact those 200 models are identical in all of their earlier layers. Suppose train_layer is chosen like this:

```python
train_layer = ['layer4.1.bn1.weight', 'layer4.1.bn1.bias', 'layer4.1.bn2.bias', 'layer4.1.bn2.weight', 'linear.weight', 'linear.bias']
```

Only the last few layers of the model are selected, so all 200 final models share exactly the same parameters in every earlier layer, which severely limits how well the diffusion model can generalize! (A quick diversity check is sketched below.)
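The data.pt dump earlier already hints at this: the per-dimension 'std' across the 200 parameter vectors is on the order of 1e-4. A quick check (the path is taken from the cfg shown in that dump):

```python
import torch

# How diverse are the 200 collected parameter vectors, really?
d = torch.load('param_data/cifar100/data.pt')
pdata = d['pdata']                  # shape (200, num_params)
print(pdata.std(dim=0).mean())      # average spread per dimension (~1e-4)
print(pdata.std(dim=0).max())       # even the widest dimension barely moves
```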

Issue log

  1. The dataset transform differed between training and testing, causing a sizable drop in the measured accuracy; once the transforms are unified, the accuracy recovers.

```python
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
```

    GPT's explanation:

    In this snippet, `transforms.Normalize` is a PyTorch data-preprocessing step that normalizes the input. Normalization is a standard technique in machine learning and deep learning: it rescales the range and distribution of the input data to a given interval so that training is faster and performs better.

    The arguments of `transforms.Normalize`:

    - `mean`: a list with one entry per input channel, giving each channel's mean. Here `[0.5, 0.5, 0.5]` sets the mean of each RGB channel (red, green, blue) to 0.5.

    - `std`: likewise a per-channel list of standard deviations; here each channel's std is set to 0.5.

    Concretely, normalization subtracts each channel's mean and divides by its standard deviation. For image data:

    $$
    \text{Normalized pixel value} = \frac{\text{Pixel value} - \text{mean}}{\text{std}}
    $$

    This standardizes the input distribution (zero mean and unit variance when `mean = 0` and `std = 1`). The benefits include:

    1. Faster training: normalized data helps gradient descent converge more quickly.
    2. Better model performance: standardized inputs make features easier to learn because scale differences between features are reduced.
    3. Fewer vanishing/exploding gradients: normalization keeps gradients stable during training, instead of letting overly large or small inputs cause gradient problems.

    In practice, `transforms.Normalize` is usually combined with `transforms.ToTensor`, which first converts the image into a PyTorch tensor (`Tensor`) that `Normalize` then standardizes. This combination is the common recipe for image data, especially when using pretrained models or chasing performance.

    In my tests, matching vs. mismatched transforms differ by about 10 points:

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/test.py 
9.83
87.57

Process finished with exit code 0


/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/test.py
10.0
76.11

Process finished with exit code 0
```
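For reference, a minimal sketch of the fix, assuming CIFAR-10 via torchvision: build ONE transform and hand it to both splits, so the test-time input distribution matches the one the network was trained on.

```python
import torchvision
import torchvision.transforms as transforms

# One shared transform for train and test, so measured accuracy is honest.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
train_set = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=transform)
```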

Diversifying the parameters to improve generalization

Attempt 1: gradually shrink the trained layers

First train the ResNet for 100 epochs, then train each of the layer sets below for 100 epochs; at the end, keep only the fully connected layer's parameters and check whether generalization improves (a sketch of this schedule follows the settings).

```python
train_layer = ['layer4.1.bn1.weight', 'layer4.1.bn1.bias', 'layer4.1.bn2.bias', 'layer4.1.bn2.weight', 'linear.weight', 'linear.bias']
train_layer = ['layer4.1.bn2.bias', 'layer4.1.bn2.weight', 'linear.weight', 'linear.bias']
train_layer = ['layer4.1.bn2.bias', 'layer4.1.bn2.weight', 'linear.weight', 'linear.bias']
```
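A hedged sketch of this schedule. `net` stands for the ResNet18 pre-trained for 100 epochs, and `train_one_epoch` is a placeholder for the repo's training loop; `fix_partial_model` / `state_part` are the helpers sketched earlier. Only the FC parameters are harvested.

```python
# Progressively shrink the trainable set, 100 epochs per stage, and collect
# linear.weight / linear.bias snapshots along the way.
stages = [
    ['layer4.1.bn1.weight', 'layer4.1.bn1.bias', 'layer4.1.bn2.bias', 'layer4.1.bn2.weight',
     'linear.weight', 'linear.bias'],
    ['layer4.1.bn2.bias', 'layer4.1.bn2.weight', 'linear.weight', 'linear.bias'],
    ['linear.weight', 'linear.bias'],
]
collected = []
for train_layer in stages:
    fix_partial_model(train_layer, net)     # freeze everything else
    for _ in range(100):                    # 100 epochs per stage
        train_one_epoch(net)                # placeholder training step
        collected.append(state_part(['linear.weight', 'linear.bias'], net))
```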

Testing:

Same model:

```python
res_path = '../tmp/whole_model_resnet18_cifar10.pth'
t = torch.load(res_path)
# resnet.load_state_dict(torch.load(res_path)['model'])
state_dict = torch.load(res_path)['state_dict']
# train_layer = ['layer4.1.bn1.weight', 'layer4.1.bn1.bias', 'layer4.1.bn2.bias', 'layer4.1.bn2.weight']
train_layer = ['linear.weight', 'linear.bias']
```
```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/tools/load_pdiff.py 
/home/chengyiqiu/miniconda3/envs/pdiff/lib/python3.8/site-packages/torch/nn/modules/instancenorm.py:80: UserWarning: input's size at dim=1 does not match num_features. You can silence this warning by not passing in num_features, which is not used because affine=False
warnings.warn(f"input's size at dim={feature_dim} does not match num_features. "
ae param shape: torch.Size([300, 5130])
Files already downloaded and verified
100%|██████████| 300/300 [26:12<00:00, 5.24s/it]
Sorted list of accuracies: [0.558, 0.61, 0.964, 1.198, 1.96, 2.38, 2.44, 2.508, 2.598, 2.772, 3.982, 4.098, 4.724, 5.266, 5.836, 5.926, 6.122, 6.894, 7.068, 7.25, 7.314, 7.446, 7.458, 7.87, 8.048, 8.436, 8.65, 9.362, 10.164, 10.184, 10.69, 10.73, 10.992, 11.008, 11.086, 11.48, 11.614, 11.804, 11.808, 12.058, 12.248, 12.454, 12.494, 12.658, 12.92, 12.92, 13.072, 13.124, 13.128, 13.43, 13.488, 13.524, 13.828, 14.174, 14.244, 14.324, 14.506, 14.54, 14.972, 15.162, 15.212, 15.258, 15.384, 15.41, 15.694, 16.038, 16.062, 16.132, 16.256, 16.258, 17.032, 17.11, 17.22, 17.4, 17.402, 17.842, 18.15, 18.204, 18.288, 18.524, 18.68, 18.798, 18.926, 19.074, 19.44, 19.818, 20.45, 20.49, 20.55, 20.592, 20.61, 20.67, 20.672, 20.74, 20.832, 20.982, 21.242, 21.254, 21.36, 21.748, 21.812, 22.29, 22.8, 22.842, 23.068, 23.356, 23.724, 23.84, 24.076, 24.396, 24.598, 24.726, 25.052, 25.32, 25.72, 25.756, 26.25, 26.83, 26.968, 26.972, 27.21, 27.328, 27.592, 27.996, 28.21, 28.53, 28.726, 29.04, 29.16, 29.194, 29.424, 29.566, 29.598, 29.678, 29.774, 29.808, 30.916, 31.158, 31.24, 31.338, 31.482, 32.002, 32.148, 32.742, 32.804, 33.08, 33.552, 33.76, 33.92, 34.004, 34.774, 35.834, 36.576, 37.622, 37.68, 38.128, 39.032, 39.044, 39.494, 39.708, 39.876, 40.692, 41.044, 41.104, 42.296, 42.618, 42.918, 42.924, 43.336, 43.896, 43.942, 44.39, 44.82, 45.118, 45.322, 45.94, 46.096, 46.416, 48.732, 48.834, 49.362, 49.578, 49.764, 51.2, 51.566, 51.74, 52.22, 52.87, 53.906, 54.406, 54.882, 56.08, 56.298, 57.084, 57.42, 57.71, 58.088, 58.106, 58.928, 60.352, 60.442, 60.736, 62.214, 62.24, 62.672, 63.272, 63.516, 63.776, 63.808, 64.076, 64.376, 64.792, 64.806, 65.076, 65.462, 65.83, 65.866, 65.928, 66.678, 66.714, 67.088, 67.394, 68.068, 68.344, 68.572, 68.728, 69.234, 69.326, 69.516, 69.592, 70.49, 70.924, 71.772, 71.918, 72.286, 72.31, 72.538, 72.654, 72.828, 73.326, 74.204, 74.62, 74.694, 75.168, 75.762, 76.372, 76.492, 77.38, 77.558, 77.566, 78.034, 78.228, 78.54, 78.716, 78.88, 79.034, 80.036, 80.302, 80.548, 80.782, 82.27, 82.28, 82.704, 82.824, 82.956, 83.308, 83.386, 83.644, 83.67, 83.74, 84.098, 84.486, 84.558, 85.982, 86.032, 86.682, 86.91, 87.098, 87.894, 89.008, 89.156, 89.736, 89.752, 90.448, 90.574, 91.7, 91.79, 92.102, 92.462, 92.522, 92.754, 93.51, 94.272, 95.092, 95.588, 95.974, 95.992, 97.704, 98.034, 98.558]
Average accuracy: 42.82
Max accuracy: 98.56
Min accuracy: 0.56
Median accuracy: 34.39

Process finished with exit code 0
```

A different model (BadNets-backdoored):

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/tools/load_pdiff.py 
/home/chengyiqiu/miniconda3/envs/pdiff/lib/python3.8/site-packages/torch/nn/modules/instancenorm.py:80: UserWarning: input's size at dim=1 does not match num_features. You can silence this warning by not passing in num_features, which is not used because affine=False
warnings.warn(f"input's size at dim={feature_dim} does not match num_features. "
ae param shape: torch.Size([300, 5130])
Files already downloaded and verified
100%|██████████| 300/300 [26:20<00:00, 5.27s/it]
Sorted list of accuracies: [0.486, 0.532, 0.826, 0.83, 0.852, 0.892, 0.99, 1.024, 1.186, 1.236, 1.296, 1.31, 1.338, 1.436, 1.516, 1.644, 1.662, 1.678, 1.79, 1.806, 1.942, 2.08, 2.304, 2.316, 2.326, 2.354, 2.374, 2.416, 2.658, 2.728, 2.78, 2.818, 2.974, 3.032, 3.034, 3.038, 3.048, 3.05, 3.106, 3.138, 3.268, 3.328, 3.336, 3.6, 3.668, 3.712, 3.72, 3.84, 3.856, 4.014, 4.078, 4.084, 4.092, 4.092, 4.104, 4.234, 4.342, 4.492, 4.538, 4.62, 4.62, 4.638, 4.768, 4.972, 5.126, 5.152, 5.162, 5.18, 5.21, 5.248, 5.608, 5.64, 5.686, 5.718, 5.76, 5.906, 5.998, 6.11, 6.152, 6.332, 6.362, 6.374, 6.426, 6.466, 6.588, 6.716, 6.776, 6.826, 6.944, 6.986, 7.054, 7.148, 7.158, 7.3, 7.308, 7.332, 7.37, 7.426, 7.534, 7.57, 7.606, 7.662, 7.67, 7.864, 7.96, 7.988, 8.002, 8.194, 8.264, 8.31, 8.368, 8.698, 8.966, 8.99, 9.092, 9.144, 9.144, 9.224, 9.244, 9.288, 9.31, 9.406, 9.438, 9.54, 9.622, 9.628, 9.642, 9.668, 9.728, 9.734, 9.81, 9.836, 9.87, 10.0, 10.048, 10.082, 10.206, 10.4, 10.558, 10.59, 10.656, 10.766, 10.796, 10.978, 10.996, 11.038, 11.13, 11.248, 11.288, 11.332, 11.4, 11.404, 11.428, 11.52, 11.604, 11.622, 11.66, 11.804, 11.92, 12.042, 12.072, 12.144, 12.178, 12.2, 12.252, 12.302, 12.402, 12.52, 12.656, 12.722, 12.752, 12.792, 12.796, 12.838, 12.906, 12.974, 13.054, 13.136, 13.146, 13.156, 13.19, 13.304, 13.456, 13.466, 13.536, 13.58, 13.61, 13.696, 13.704, 13.842, 13.852, 13.914, 14.024, 14.04, 14.062, 14.134, 14.184, 14.222, 14.42, 14.47, 14.578, 14.67, 14.792, 14.958, 14.968, 15.02, 15.064, 15.09, 15.102, 15.398, 15.466, 15.524, 15.712, 15.988, 16.108, 16.16, 16.31, 16.432, 16.466, 16.558, 16.562, 16.624, 16.698, 16.7, 16.728, 16.822, 16.88, 16.886, 17.212, 17.248, 17.248, 17.29, 17.384, 17.486, 17.582, 17.75, 17.852, 17.912, 17.948, 17.962, 18.072, 18.146, 18.41, 18.422, 18.662, 18.722, 18.724, 18.942, 18.974, 18.994, 19.004, 19.046, 19.268, 19.306, 19.338, 19.364, 19.41, 19.664, 19.706, 19.71, 19.848, 19.884, 19.982, 20.014, 20.282, 20.304, 20.7, 20.802, 20.912, 21.106, 21.244, 21.636, 21.718, 22.028, 22.098, 22.212, 22.53, 22.838, 22.924, 22.948, 23.044, 23.734, 23.79, 24.564, 24.716, 24.76, 24.798, 25.438, 25.454, 25.488, 25.572, 25.688, 26.852, 26.91, 27.516, 28.05, 29.412, 31.142, 33.508, 35.426]
Average accuracy: 11.75
Max accuracy: 35.43
Min accuracy: 0.49
Median accuracy: 11.37
```

Generalization improves somewhat but is still not good enough; next, increase the number of trained layers and try to push it further.

Attempt 2: swap the order

I noticed that the weight and bias entries in train_layer were written in reverse order; fix them:

```python
train_layer_1 = ['layer4.1.bn1.weight', 'layer4.1.bn1.bias', 'layer4.1.bn2.weight', 'layer4.1.bn2.bias',
                 'linear.weight', 'linear.bias']
train_layer_2 = ['layer4.1.bn2.weight', 'layer4.1.bn2.bias', 'linear.weight', 'linear.bias']
train_layer_3 = ['linear.weight', 'linear.bias']
```

This should not matter much; it is only a sanity check. In principle, both freezing gradients and saving the chosen layers use the condition `if name in train_layer:`, and when the parameters are finally written back, the loop likewise walks the network's own layers one by one: `for name, pa in model.named_parameters():` (a hedged sketch of that write-back step is below).
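Here is a hedged sketch of what that replacement step plausibly looks like; the repo's `partial_reverse_tomodel` may differ in detail (e.g. device handling):

```python
import torch

# Walk named_parameters in order and overwrite only the train_layer entries
# from the flat generated parameter vector.
@torch.no_grad()
def partial_reverse_tomodel(flat: torch.Tensor, model, train_layer):
    offset = 0
    for name, p in model.named_parameters():
        if name in train_layer:
            n = p.numel()
            p.copy_(flat[offset:offset + n].view_as(p))
            offset += n
    return model
```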

One odd thing turned up here:

```shell
Epoch 2999, global step 3000: 'ae_acc' reached 2.35000 (best 2.35000), saving model to 'outputs/cifar10/ae_ddpm_cifar100/././checkpoints/ae-epoch=2999-ae_acc=2.3500.ckpt' as top 1
Epoch 5999, global step 6000: 'ae_acc' reached 3.69000 (best 3.69000), saving model to 'outputs/cifar10/ae_ddpm_cifar100/././checkpoints/ae-epoch=5999-ae_acc=3.6900.ckpt' as top 1
Epoch 8999, global step 9000: 'ae_acc' reached 4.73000 (best 4.73000), saving model to 'outputs/cifar10/ae_ddpm_cifar100/././checkpoints/ae-epoch=8999-ae_acc=4.7300.ckpt' as top 1
Epoch 11999, global step 12000: 'ae_acc' was not in top 1
Epoch 14999, global step 15000: 'ae_acc' was not in top 1
Epoch 17999, global step 18000: 'ae_acc' reached 5.04000 (best 5.04000), saving model to 'outputs/cifar10/ae_ddpm_cifar100/././checkpoints/ae-epoch=17999-ae_acc=5.0400.ckpt' as top 1
Epoch 20999, global step 21000: 'ae_acc' was not in top 1
Epoch 23999, global step 24000: 'ae_acc' was not in top 1
Epoch 26999, global step 27000: 'ae_acc' was not in top 1
Epoch 29999, global step 30000: 'ae_acc' was not in top 1
Epoch 32999, global step 33000: 'ae_acc' reached 94.30000 (best 94.30000), saving model to 'outputs/cifar10/ae_ddpm_cifar100/././checkpoints/ae-epoch=32999-ae_acc=94.3000.ckpt' as top 1
```

The first 30k epochs train the AE, and 'ae_acc' stays very low; but as soon as DM training starts after epoch 30k, the accuracy shoots up immediately.

The result is still no good:

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/tools/load_pdiff.py 
ae param shape: torch.Size([300, 5130])
Files already downloaded and verified
100%|██████████| 300/300 [24:07<00:00, 4.82s/it]
Sorted list of accuracies: [0.554, 0.786, 0.882, 1.226, 1.226, 1.25, 1.296, 1.324, 1.396, 1.532, 1.658, 1.66, 1.704, 1.714, 1.802, 1.924, 2.066, 2.104, 2.18, 2.222, 2.344, 2.426, 2.434, 2.484, 2.572, 2.606, 2.634, 2.7, 2.792, 2.884, 2.942, 2.97, 3.028, 3.034, 3.254, 3.274, 3.286, 3.328, 3.354, 3.462, 3.542, 3.674, 3.698, 3.788, 3.87, 3.984, 4.024, 4.068, 4.088, 4.246, 4.25, 4.26, 4.262, 4.498, 4.512, 4.53, 4.632, 4.74, 4.77, 4.776, 4.834, 4.842, 4.96, 5.056, 5.124, 5.224, 5.386, 5.396, 5.412, 5.628, 5.796, 5.93, 6.046, 6.274, 6.278, 6.294, 6.318, 6.388, 6.402, 6.448, 6.558, 6.57, 6.588, 6.616, 6.656, 6.71, 6.722, 6.822, 6.872, 6.878, 6.922, 6.94, 6.986, 6.998, 7.016, 7.038, 7.074, 7.18, 7.244, 7.282, 7.346, 7.504, 7.54, 7.6, 7.612, 7.64, 7.662, 7.702, 7.712, 7.758, 7.802, 7.866, 8.024, 8.152, 8.242, 8.444, 8.494, 8.508, 8.522, 8.578, 8.61, 8.622, 8.704, 8.74, 8.742, 8.748, 8.76, 8.764, 8.772, 8.806, 8.806, 8.82, 8.904, 8.928, 8.974, 9.038, 9.092, 9.144, 9.174, 9.206, 9.224, 9.31, 9.34, 9.39, 9.406, 9.492, 9.496, 9.508, 9.574, 9.612, 9.638, 9.65, 9.756, 9.758, 9.762, 9.81, 9.818, 9.832, 9.834, 9.866, 9.888, 9.904, 9.944, 9.956, 9.958, 9.968, 9.974, 9.986, 10.0, 10.056, 10.058, 10.102, 10.14, 10.166, 10.262, 10.27, 10.3, 10.308, 10.36, 10.714, 10.714, 10.756, 10.814, 10.818, 10.826, 10.834, 10.838, 10.888, 10.906, 10.93, 10.962, 10.966, 10.968, 10.98, 11.06, 11.078, 11.088, 11.164, 11.168, 11.204, 11.242, 11.258, 11.486, 11.534, 11.542, 11.574, 11.574, 11.678, 11.7, 11.706, 11.744, 11.79, 11.794, 11.85, 11.988, 12.084, 12.122, 12.154, 12.32, 12.37, 12.386, 12.484, 12.786, 12.814, 12.842, 12.864, 12.88, 12.98, 13.126, 13.182, 13.188, 13.214, 13.236, 13.278, 13.516, 13.528, 13.546, 13.606, 13.608, 13.62, 13.722, 13.766, 14.102, 14.19, 14.25, 14.48, 14.484, 14.492, 14.668, 14.69, 14.8, 14.806, 14.876, 15.054, 15.136, 15.146, 15.208, 15.21, 15.216, 15.314, 15.338, 15.478, 15.48, 15.522, 15.626, 15.672, 15.72, 15.752, 15.88, 15.92, 16.014, 16.1, 16.166, 16.168, 16.35, 16.366, 16.37, 16.536, 16.754, 17.194, 17.252, 17.42, 17.466, 17.528, 17.584, 17.84, 18.162, 18.692, 18.852, 18.926, 18.948, 19.224, 19.698, 19.936, 20.124, 20.764, 20.84, 22.566, 22.912, 24.076]
Average accuracy: 9.58
Max accuracy: 24.08
Min accuracy: 0.55
Median accuracy: 9.62
```

Attempt 3: include convolutional layers

Here are the layer names of ResNet18's last two blocks:

```
layer4.0.conv1.weight
layer4.0.bn1.weight
layer4.0.bn1.bias
layer4.0.conv2.weight
layer4.0.bn2.weight
layer4.0.bn2.bias
layer4.0.shortcut.0.weight
layer4.0.shortcut.1.weight
layer4.0.shortcut.1.bias
layer4.1.conv1.weight
layer4.1.bn1.weight
layer4.1.bn1.bias
layer4.1.conv2.weight
layer4.1.bn2.weight
layer4.1.bn2.bias
linear.weight
linear.bias
```

Train these first:

```
layer4.1.conv1.weight
layer4.1.bn1.weight
layer4.1.bn1.bias
layer4.1.conv2.weight
layer4.1.bn2.weight
layer4.1.bn2.bias
linear.weight
linear.bias
```

The results are still very poor:

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/tools/eval_pdiff.py 
ae param shape: torch.Size([300, 5130])
Files already downloaded and verified
100%|██████████| 300/300 [24:57<00:00, 4.99s/it]
Sorted list of accuracies: [0.6, 0.654, 0.69, 0.702, 0.716, 0.868, 0.878, 0.95, 1.018, 1.124, 1.208, 1.31, 1.328, 1.338, 1.458, 1.514, 1.698, 1.704, 1.902, 1.924, 2.026, 2.074, 2.162, 2.176, 2.224, 2.224, 2.228, 2.296, 2.328, 2.388, 2.414, 2.426, 2.43, 2.464, 2.472, 2.496, 2.51, 2.52, 2.52, 2.56, 2.746, 2.8, 2.834, 2.854, 2.91, 2.93, 3.038, 3.062, 3.08, 3.092, 3.136, 3.148, 3.152, 3.198, 3.232, 3.366, 3.46, 3.464, 3.468, 3.526, 3.528, 3.536, 3.57, 3.608, 3.642, 3.674, 3.68, 3.712, 3.718, 3.81, 3.88, 3.986, 4.024, 4.048, 4.116, 4.136, 4.244, 4.332, 4.342, 4.372, 4.372, 4.444, 4.538, 4.56, 4.562, 4.676, 4.732, 4.774, 4.86, 4.886, 4.928, 4.956, 4.974, 4.982, 5.008, 5.062, 5.168, 5.206, 5.238, 5.266, 5.28, 5.298, 5.414, 5.426, 5.538, 5.57, 5.574, 5.596, 5.598, 5.604, 5.728, 5.742, 5.764, 5.772, 5.884, 5.908, 5.992, 6.004, 6.114, 6.14, 6.19, 6.222, 6.222, 6.4, 6.412, 6.446, 6.54, 6.554, 6.672, 6.766, 6.84, 7.046, 7.198, 7.332, 7.49, 7.572, 7.668, 7.674, 7.722, 8.054, 8.108, 8.114, 8.162, 8.182, 8.398, 8.518, 8.546, 8.636, 8.64, 8.73, 8.734, 8.768, 8.794, 8.818, 9.006, 9.184, 9.214, 9.29, 9.304, 9.328, 9.372, 9.428, 9.428, 9.442, 9.476, 9.486, 9.496, 9.506, 9.568, 9.578, 9.77, 9.852, 9.914, 9.962, 9.992, 10.05, 10.082, 10.102, 10.116, 10.12, 10.128, 10.15, 10.184, 10.276, 10.31, 10.362, 10.386, 10.414, 10.426, 10.464, 10.488, 10.564, 10.594, 10.674, 10.712, 10.744, 10.754, 10.772, 10.822, 10.878, 10.922, 10.944, 10.966, 11.026, 11.03, 11.03, 11.032, 11.114, 11.116, 11.118, 11.12, 11.142, 11.188, 11.204, 11.276, 11.394, 11.408, 11.422, 11.492, 11.532, 11.566, 11.596, 11.608, 11.74, 11.772, 11.94, 12.006, 12.006, 12.016, 12.05, 12.14, 12.268, 12.424, 12.448, 12.48, 12.514, 12.688, 12.708, 12.75, 12.766, 12.822, 12.862, 12.924, 12.996, 13.004, 13.116, 13.128, 13.158, 13.188, 13.218, 13.25, 13.416, 13.476, 13.568, 13.616, 13.668, 13.72, 13.95, 14.014, 14.142, 14.22, 14.324, 14.418, 14.46, 14.528, 14.542, 14.686, 14.742, 14.768, 14.988, 15.5, 15.576, 16.004, 16.012, 16.278, 16.628, 16.728, 16.754, 16.802, 16.942, 17.006, 17.288, 17.29, 17.34, 17.384, 17.424, 17.524, 17.994, 18.078, 18.092, 18.118, 18.13, 18.166, 18.288, 18.57, 18.674, 18.928, 19.316, 19.704, 22.478]
Average accuracy: 8.50
Max accuracy: 22.48
Min accuracy: 0.60
Median accuracy: 8.73
```

Attempt 4: four BN layers

```python
train_layer_1 = ['layer4.0.bn1.weight', 'layer4.0.bn1.bias', 'layer4.0.bn2.weight', 'layer4.0.bn2.bias',
                 'layer4.1.bn1.weight', 'layer4.1.bn1.bias', 'layer4.1.bn2.weight', 'layer4.1.bn2.bias',
                 'linear.weight', 'linear.bias']
train_layer_2 = ['layer4.0.bn2.weight', 'layer4.0.bn2.bias', 'layer4.1.bn1.weight', 'layer4.1.bn1.bias',
                 'layer4.1.bn2.weight', 'layer4.1.bn2.bias', 'linear.weight', 'linear.bias']
train_layer_3 = ['layer4.1.bn1.weight', 'layer4.1.bn1.bias', 'layer4.1.bn2.weight', 'layer4.1.bn2.bias',
                 'linear.weight', 'linear.bias']
train_layer_4 = ['layer4.1.bn2.weight', 'layer4.1.bn2.bias', 'linear.weight', 'linear.bias']
train_layer_5 = ['linear.weight', 'linear.bias']
```

Results:

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/tools/eval_pdiff.py 
ae param shape: torch.Size([250, 5130])
Files already downloaded and verified
100%|██████████| 250/250 [20:36<00:00, 4.94s/it]
Sorted list of accuracies: [0.634, 0.784, 0.808, 0.916, 1.012, 1.126, 1.286, 1.322, 1.47, 1.592, 1.718, 1.756, 1.886, 2.05, 2.072, 2.098, 2.192, 2.212, 2.246, 2.254, 2.326, 2.342, 2.504, 2.6, 2.766, 2.888, 2.928, 3.008, 3.178, 3.198, 3.3, 3.438, 3.462, 3.492, 3.534, 3.642, 4.056, 4.072, 4.262, 4.516, 4.558, 4.586, 4.632, 4.776, 4.78, 4.878, 4.938, 5.014, 5.088, 5.208, 5.296, 5.548, 5.556, 5.568, 5.586, 5.644, 5.674, 5.684, 5.802, 5.804, 5.848, 5.882, 5.92, 6.0, 6.024, 6.102, 6.506, 6.544, 6.602, 6.624, 6.714, 6.748, 6.76, 6.852, 6.864, 6.932, 7.036, 7.116, 7.128, 7.204, 7.25, 7.43, 7.45, 7.492, 7.514, 7.626, 7.632, 7.69, 7.702, 8.018, 8.068, 8.138, 8.194, 8.194, 8.252, 8.444, 8.444, 8.448, 8.456, 8.466, 8.562, 8.624, 8.65, 8.668, 8.714, 8.728, 8.752, 9.022, 9.028, 9.124, 9.136, 9.166, 9.25, 9.258, 9.276, 9.376, 9.388, 9.438, 9.582, 9.666, 9.7, 9.71, 9.734, 9.774, 9.784, 9.828, 9.832, 9.832, 9.874, 9.886, 9.898, 9.964, 10.052, 10.106, 10.13, 10.13, 10.144, 10.202, 10.236, 10.258, 10.268, 10.296, 10.302, 10.344, 10.346, 10.362, 10.366, 10.398, 10.426, 10.478, 10.514, 10.558, 10.602, 10.638, 10.656, 10.682, 10.812, 10.826, 10.83, 10.844, 10.866, 10.87, 10.942, 10.954, 11.024, 11.058, 11.08, 11.202, 11.23, 11.37, 11.508, 11.566, 11.642, 11.878, 12.166, 12.252, 12.442, 12.442, 12.472, 12.56, 12.688, 12.762, 12.852, 13.114, 13.218, 13.412, 13.532, 13.558, 13.574, 13.704, 13.722, 13.804, 13.908, 13.972, 14.414, 14.448, 14.67, 14.676, 14.774, 14.778, 14.962, 15.248, 15.254, 15.304, 15.356, 15.426, 15.564, 15.754, 15.826, 15.856, 15.88, 15.978, 15.982, 16.142, 16.318, 16.584, 16.68, 17.06, 17.15, 17.286, 17.656, 17.866, 18.176, 18.37, 18.596, 18.846, 18.898, 19.226, 19.384, 19.396, 19.468, 19.528, 19.552, 19.692, 19.806, 19.952, 19.964, 20.044, 20.088, 20.174, 20.388, 22.354, 22.43, 22.706, 23.08, 23.254, 23.296, 24.044, 26.122, 26.376]
Average accuracy: 10.11
Max accuracy: 26.38
Min accuracy: 0.63
Median accuracy: 9.81
```

No luck; generalization did not improve.

Attempt 5: train, then retrain

First train for 100 epochs to get a decent model; then randomly re-initialize the first layer's parameters and keep training for n epochs until the accuracy reaches a threshold $\tau$; train that model for n more epochs while collecting the FC layer's parameters; then re-initialize the first layer again, and repeat the cycle (a hedged sketch follows).
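A hedged sketch of this cycle; `num_cycles`, `tau`, `train_until`, and `collect_fc` are placeholders for the actual training loop and bookkeeping:

```python
import torch.nn as nn

def reinit_first_layer(net: nn.Module) -> None:
    # re-randomize conv1 so successive checkpoints differ in their early
    # features, not only in the FC layer
    nn.init.kaiming_normal_(net.conv1.weight, mode='fan_out', nonlinearity='relu')

for _ in range(num_cycles):               # placeholder loop count
    reinit_first_layer(net)
    train_until(net, acc_threshold=tau)   # train n epochs until acc >= tau
    collect_fc(net)                       # harvest linear.weight / linear.bias
```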

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/tools/eval_pdiff.py 
ae param shape: torch.Size([200, 5130])
Files already downloaded and verified
100%|██████████| 200/200 [16:35<00:00, 4.98s/it]
Sorted list of accuracies: [0.494, 0.768, 1.67, 1.95, 2.124, 2.188, 2.38, 2.388, 2.456, 2.47, 2.564, 2.722, 2.81, 2.886, 2.948, 2.996, 3.08, 3.13, 3.166, 3.254, 3.738, 3.914, 4.008, 4.032, 4.428, 4.462, 4.648, 4.738, 4.822, 4.838, 4.874, 4.91, 4.944, 5.124, 5.262, 5.502, 5.636, 5.64, 5.67, 5.848, 5.882, 6.164, 6.188, 6.342, 6.346, 6.526, 6.556, 6.684, 6.786, 6.794, 6.812, 6.83, 6.85, 6.946, 7.262, 7.388, 7.44, 7.504, 7.542, 7.552, 7.792, 7.866, 7.952, 8.086, 8.176, 8.27, 8.434, 8.532, 8.69, 8.732, 8.87, 8.882, 8.94, 8.954, 9.034, 9.12, 9.294, 9.332, 9.354, 9.374, 9.442, 9.454, 9.458, 9.584, 9.596, 9.618, 9.636, 9.686, 9.796, 9.988, 10.186, 10.314, 10.486, 10.57, 10.58, 10.64, 10.65, 10.782, 11.04, 11.066, 11.094, 11.1, 11.126, 11.126, 11.178, 11.184, 11.25, 11.292, 11.314, 11.356, 11.402, 11.508, 11.608, 11.612, 11.62, 11.626, 11.638, 11.702, 11.898, 11.964, 12.03, 12.122, 12.282, 12.344, 12.36, 12.412, 12.532, 12.63, 12.764, 12.78, 12.83, 12.85, 12.966, 13.016, 13.07, 13.174, 13.25, 13.446, 13.702, 13.762, 13.79, 13.816, 13.838, 14.232, 14.236, 14.3, 14.372, 14.396, 14.446, 14.766, 14.844, 14.862, 15.024, 15.412, 15.456, 15.734, 15.846, 15.858, 16.028, 16.142, 16.258, 16.328, 16.546, 16.66, 16.722, 17.142, 17.246, 17.26, 17.314, 17.326, 17.536, 17.712, 17.8, 17.812, 18.058, 18.058, 18.158, 18.226, 18.286, 18.308, 18.438, 18.538, 19.256, 19.87, 20.006, 21.064, 21.558, 22.054, 22.104, 22.666, 22.688, 23.956, 23.974, 25.298, 25.642, 25.644, 25.734, 26.37, 28.938, 31.89]
Average accuracy: 11.27
Max accuracy: 31.89
Min accuracy: 0.49
Median accuracy: 11.08
```

Attempt 6: re-initialize convolutional layers

Settings:

```python
init_layer = ['conv1.weight', 'layer1.0.conv1.weight', 'layer1.0.conv2.weight', 'layer1.1.conv1.weight',
              'layer1.1.conv2.weight', 'linear.weight', 'linear.bias']
```

Result:

```shell
(pdiff) chengyiqiu@server:~/code/diffusion/Diffuse-Backdoor-Parameters/tools$ python eval_pdiff.py
ae param shape: torch.Size([200, 5130])
Files already downloaded and verified
100%|███████████████████████████████████████████████████████████| 200/200 [16:05<00:00, 4.83s/it]
Sorted list of accuracies: [0.914, 1.544, 1.632, 1.868, 1.888, 2.2, 2.378, 2.44, 2.624, 2.658, 2.806, 2.884, 3.412, 3.672, 3.852, 3.876, 4.038, 4.624, 4.628, 4.94, 4.978, 5.212, 5.344, 5.35, 5.424, 5.644, 5.706, 5.912, 6.08, 6.088, 6.32, 6.934, 7.166, 7.32, 7.386, 7.404, 7.426, 7.482, 7.74, 7.882, 8.1, 8.164, 8.388, 8.404, 8.428, 8.52, 8.678, 8.832, 8.956, 9.032, 9.194, 9.228, 9.422, 9.508, 9.536, 9.556, 9.562, 9.57, 9.628, 9.634, 9.816, 9.836, 9.91, 9.916, 9.94, 9.942, 9.962, 10.0, 10.028, 10.05, 10.084, 10.136, 10.144, 10.144, 10.258, 10.264, 10.334, 10.39, 10.438, 10.468, 10.48, 10.496, 10.516, 10.52, 10.65, 10.674, 10.678, 10.69, 10.748, 10.752, 10.942, 11.234, 11.282, 11.422, 11.488, 11.658, 11.698, 11.784, 11.79, 11.79, 11.792, 11.92, 11.944, 11.95, 11.996, 12.06, 12.078, 12.222, 12.32, 12.402, 12.442, 12.566, 12.574, 12.598, 12.696, 12.756, 12.836, 12.836, 12.922, 12.96, 12.972, 13.01, 13.092, 13.412, 13.48, 13.55, 13.658, 13.77, 13.776, 13.932, 14.254, 14.262, 14.296, 14.326, 14.39, 14.392, 14.456, 14.526, 14.72, 14.79, 14.868, 14.962, 15.016, 15.018, 15.092, 15.116, 15.2, 15.206, 15.3, 15.35, 15.404, 15.468, 15.674, 16.334, 16.372, 16.378, 16.406, 16.458, 16.698, 16.71, 16.798, 17.088, 17.184, 17.252, 17.684, 17.908, 17.958, 18.242, 18.248, 18.314, 18.632, 18.67, 18.71, 18.716, 18.73, 19.206, 19.666, 19.67, 19.724, 19.762, 20.088, 20.242, 20.512, 20.604, 20.66, 20.912, 22.172, 22.488, 22.868, 22.97, 23.476, 25.878, 25.972, 26.558, 27.16, 29.946, 31.72, 31.974, 32.288, 32.566]
Average accuracy: 12.50
Max accuracy: 32.57
Min accuracy: 0.91
Median accuracy: 11.79
```

Attempt 7

```python
init_layer = ['conv1.weight', 'bn1.weight', 'bn1.bias', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight',
              'layer1.0.bn1.bias', 'layer4.1.conv2.weight', 'layer4.1.bn2.weight', 'layer4.1.bn2.bias',
              'linear.weight', 'linear.bias', ]
```

Result:

```shell
/home/chengyiqiu/miniconda3/envs/pdiff/bin/python /home/chengyiqiu/code/diffusion/Diffuse-Backdoor-Parameters/tools/eval_pdiff.py 
/home/chengyiqiu/miniconda3/envs/pdiff/lib/python3.8/site-packages/torch/nn/modules/instancenorm.py:80: UserWarning: input's size at dim=1 does not match num_features. You can silence this warning by not passing in num_features, which is not used because affine=False
warnings.warn(f"input's size at dim={feature_dim} does not match num_features. "
ae param shape: torch.Size([200, 5130])
Files already downloaded and verified
100%|██████████| 200/200 [17:28<00:00, 5.24s/it]
Sorted list of accuracies: [0.63, 0.888, 1.024, 1.078, 1.146, 1.17, 1.308, 1.386, 1.444, 1.476, 1.626, 1.828, 1.9, 2.0, 2.012, 2.168, 2.268, 2.332, 2.576, 2.638, 2.678, 2.778, 3.026, 3.028, 3.138, 3.296, 3.32, 3.598, 3.622, 3.832, 4.148, 4.162, 4.52, 4.536, 4.992, 5.018, 5.084, 5.1, 5.11, 5.338, 5.572, 5.744, 5.762, 5.772, 5.864, 5.962, 5.986, 6.116, 6.358, 6.368, 6.476, 6.546, 7.022, 7.142, 7.254, 7.3, 7.452, 7.458, 7.478, 7.522, 7.57, 7.644, 7.7, 7.77, 7.798, 7.824, 7.842, 8.414, 8.506, 8.548, 8.612, 8.636, 8.686, 8.734, 8.862, 8.862, 8.984, 8.992, 9.098, 9.136, 9.14, 9.142, 9.206, 9.312, 9.326, 9.354, 9.562, 9.568, 9.674, 9.712, 9.716, 9.776, 9.79, 9.942, 9.968, 10.0, 10.014, 10.02, 10.036, 10.14, 10.2, 10.25, 10.266, 10.296, 10.324, 10.354, 10.372, 10.39, 10.412, 10.43, 10.462, 10.532, 10.596, 10.616, 10.702, 10.72, 10.734, 10.746, 10.746, 10.76, 10.944, 10.976, 10.978, 11.018, 11.03, 11.044, 11.128, 11.152, 11.226, 11.236, 11.238, 11.398, 11.458, 11.58, 11.594, 11.676, 11.808, 11.87, 11.874, 12.148, 12.23, 12.36, 12.426, 12.51, 12.644, 12.754, 12.772, 12.85, 12.862, 12.902, 12.966, 12.988, 13.122, 13.294, 13.378, 13.458, 14.048, 14.11, 14.234, 14.454, 14.502, 14.636, 14.74, 14.804, 14.872, 14.96, 15.04, 15.32, 15.604, 15.892, 16.058, 16.114, 16.35, 16.674, 16.768, 17.258, 17.414, 17.752, 17.778, 18.206, 18.248, 18.256, 18.296, 18.374, 18.438, 18.674, 18.744, 19.124, 20.658, 20.784, 20.854, 21.198, 21.33, 22.388, 22.688, 22.874, 24.144, 25.728, 27.32, 30.836]
Average accuracy: 10.28
Max accuracy: 30.84
Min accuracy: 0.63
Median accuracy: 10.17
```