侧边栏壁纸
博主头像
PPP的日记

行动起来,活在当下

  • 累计撰写 13 篇文章
  • 累计创建 14 个标签
  • 累计收到 23 条评论

目 录CONTENT

文章目录

PyTorch SimCLR:从自监督预训练到线性评估的完整实现

TL;DR: 本文介绍 SimCLR(Simple Framework for Contrastive Learning of Visual Representations)的课程项目复现与扩展。基于 PyTorch 实现 ResNet18 主干网络、InfoNCE 对比损失、两视图数据增强,以及线性评估协议。提供随机初始化特征基线、有监督训练上界的对比参考,验证自监督表示在 CIFAR-10 上的有效性。


一、为什么做这个项目

监督学习需要大量标注数据,而标注成本高昂。自监督学习通过设计代理任务(pretext task)让模型从无标注数据中学习有用表示,其中对比学习是最有效的策略之一。

SimCLR 是其中的代表性工作,其核心思想极为优雅:让同一图像的不同增强视图在特征空间中彼此接近,让不同图像的视图彼此远离。

本项目目标是:

  1. 从零复现 SimCLR 训练流程

  2. 添加 checkpoint 保存功能

  3. 实现线性评估协议验证预训练表示质量

  4. 提供随机基线和有监督上界的完整对比参考


二、系统架构

 ┌──────────────────────────────────────────────────────────────────┐
 │                     run.py (SimCLR Pre-training)                 │
 │                                                                  │
 │  ContrastiveLearningDataset                                      │
 │       ↓ (两视图 augmentation)                                     │
 │  ResNet18 + Projection Head                                      │
 │       ↓ (InfoNCE Loss)                                           │
 │  TensorBoard logging                                             │
 │       ↓ (训练完成后)                                              │
 │  simclr_cifar10_resnet18_ep{50,100}.pth                          │
 └──────────────────────────────────────────────────────────────────┘
                               ↓
 ┌──────────────────────────────────────────────────────────────────┐
 │              linear_eval.py (Linear Evaluation Protocol)         │
 │                                                                  │
 │  Frozen SimCLR Encoder ← 加载 .pth                               │
 │       ↓ (提取特征)                                                │
 │  Linear Classifier (随机初始化)                                   │
 │       ↓ (CrossEntropyLoss, SGD)                                  │
 │  Test Accuracy on CIFAR-10                                       │
 └──────────────────────────────────────────────────────────────────┘
                               ↓
 ┌──────────────────────────────────────────────────────────────────┐
 │         linear_eval_random.py / supervised_resnet.py             │
 │                                                                  │
 │  Random Init Encoder / Full Supervised → Baseline & Upper Bound  │
 └──────────────────────────────────────────────────────────────────┘

三、核心模块实现

3.1 数据增强:对比学习的关键

对比学习的核心在于构造有意义的视图差异。SimCLR 使用一套随机变换组合:

 # data_aug/contrastive_learning_dataset.py
 class ContrastiveLearningDataset:
     def get_dataset(self, name, n_views):
         if name == 'cifar10':
             dataset = CIFAR10Dataset(self.data_dir, n_views)
         ...
         return dataset
 ​
 class CIFAR10Dataset:
     def __call__(self, idx):
         # 随机裁剪 + 水平翻转 + 颜色抖动 + 灰度化 + 高斯模糊
         transforms = transforms.Compose([
             transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
             transforms.RandomHorizontalFlip(),
             transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
             transforms.RandomGrayscale(p=0.2),
             transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5)),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                std=[0.2023, 0.1994, 0.2010]),
         ])
         # 两个独立随机变换 → 两视图
         return transforms(img), transforms(img)

关键设计点:

  • RandomResizedCrop(32, scale=(0.2, 1.0)):从 32×32 图像中随机裁剪到不同大小,模拟物体在不同尺寸下的外观

  • ColorJitter:随机改变亮度、对比度、饱和度、色调,使模型关注语义而非表层颜色

  • GaussianBlur:模拟图像模糊,要求模型学习更鲁棒的语义特征

  • 两个独立随机变换:同一图像产生两个差异化的视图,增加对比难度

3.2 模型架构:ResNet18 + 投影头

 # models/resnet_simclr.py
 class ResNetSimCLR(nn.Module):
     def __init__(self, base_model='resnet18', out_dim=128):
         super().__init__()
         self.resnet = torchvision.models.resnet18(pretrained=False)
         # 移除原始 FC 层,替换为投影头
         self.resnet.fc = nn.Sequential(
             nn.Linear(self.resnet.in_features, 512),
             nn.ReLU(),
             nn.Linear(512, out_dim),
         )
 ​
     def forward(self, x):
         h = self.resnet(x)  # 128-d 特征
         return h

为什么需要投影头?

  • 主干网络输出的特征 h 用于下游任务,直接用于对比学习会损害表示质量

  • 投影头 z = g(h) 将特征映射到另一个空间,在该空间做对比学习

  • 线性评估时只用 h(不用 z),验证特征表示质量

3.3 InfoNCE 损失函数

 # simclr.py
 def info_nce_loss(self, features):
     # features: [2*batch_size, out_dim] (两视图拼接)
     batch_size = self.args.batch_size
     labels = torch.cat([torch.arange(batch_size) for i in range(self.args.n_views)], dim=0)
     labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
     labels = labels.to(self.args.device)  # 对角线为 1(正样本对)
 ​
     features = F.normalize(features, dim=1)  # L2 归一化
     similarity_matrix = torch.matmul(features, features.T)  # 余弦相似度
 ​
     # 掩码移除对角线(自身对比)
     mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
     labels = labels[~mask].view(labels.shape[0], -1)
     similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
 ​
     # 正样本对:同一图像的两视图
     positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
     # 负样本对:所有其他图像的视图
     negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
 ​
     logits = torch.cat([positives, negatives], dim=1)  # [N, 1 + 2(N-1)]
     logits = logits / self.args.temperature  # 温度参数控制锐度
 ​
     labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)
     return logits, labels

温度参数 τ=0.07 的作用:

  • 越小,相似度分布越锐利,负样本惩罚越集中

  • 越大,分布越平滑,负样本权重更均匀

3.4 线性评估协议

 # linear_eval.py
 # 加载预训练 encoder(冻结)
 encoder = ResNetSimCLR(base_model=args.arch, out_dim=128)
 state_dict = torch.load(args.pretrained, map_location=device)
 encoder.load_state_dict(state_dict)
 encoder.eval()
 for p in encoder.parameters():
     p.requires_grad = False  # 全部冻结
 ​
 # 线性分类器(随机初始化,可学习)
 num_classes = 10
 classifier = nn.Linear(feat_dim, num_classes).to(device)
 ​
 # 训练时:encoder 始终 eval,只更新 classifier
 for images, labels in train_loader:
     with torch.no_grad():
         features = encoder(images)  # 提取特征
     outputs = classifier(features)  # 线性分类
     loss = criterion(outputs, labels)
     loss.backward()
     optimizer.step(classifier.parameters())

为什么叫"线性评估"?

  • 特征表示由对比学习预训练获得,已包含丰富的语义信息

  • 只需在顶部训练一个线性分类器(无隐藏层)就能达到较好准确率

  • 线性分类器的准确率直接反映特征表示的质量

3.5 混合精度训练

 # simclr.py
 scaler = GradScaler(enabled=self.args.fp16_precision)
 ​
 for images, _ in tqdm(train_loader):
     with autocast(enabled=self.args.fp16_precision):
         features = self.model(images)
         logits, labels = self.info_nce_loss(features)
         loss = self.criterion(logits, labels)
 ​
     scaler.scale(loss).backward()
     scaler.step(self.optimizer)
     scaler.update()

使用 PyTorch 内置 AMP,无需 NVIDIA Apex,在 V100 上可节省约 30% 显存。


四、三种评估范式对比

范式

训练方式

作用

SimCLR 预训练

自监督(对比学习)

学习通用视觉表示

Linear Eval (SimCLR)

冻结 encoder + 线性分类器

验证预训练表示质量

Linear Eval (Random)

随机初始化 encoder + 线性分类器

下游任务基线

Supervised ResNet

全监督端到端训练

任务上界参考


五、关键代码细节

5.1 训练启动与 checkpoint 保存

 # run.py
 model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim)
 optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
     optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1
 )
 ​
 with torch.cuda.device(args.gpu_index):
     simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
     simclr.train(train_loader)
     # 保存预训练模型
     checkpoint_name = f"simclr_{args.dataset_name}_{args.arch}_ep{args.epochs}.pth"
     torch.save(simclr.model.state_dict(), checkpoint_name)

CosineAnnealing 学习率调度:

  • 从初始 lr 缓慢下降到 eta_min

  • 适合对比学习,因为后期需要更精细的特征优化

5.2 CUDA 与设备管理

 # 兼容 CPU 和 GPU
 if not args.disable_cuda and torch.cuda.is_available():
     args.device = torch.device('cuda')
     cudnn.deterministic = True
     cudnn.benchmark = True  # 固定输入尺寸时启用加速
 else:
     args.device = torch.device('cpu')
     args.gpu_index = -1

六、项目结构

 SimCLR/
 ├── run.py                         # SimCLR 预训练主脚本
 ├── simclr.py                      # SimCLR 训练逻辑 + InfoNCE Loss
 ├── linear_eval.py                 # 预训练特征线性评估
 ├── linear_eval_random.py          # 随机初始化特征基线
 ├── supervised_resnet.py           # 全监督 ResNet 上界
 ├── utils.py                       # 检查点保存、精度计算
 ├── models/
 │   └── resnet_simclr.py          # ResNet18 + 投影头
 ├── data_aug/
 │   └── contrastive_learning_dataset.py  # 两视图数据增强
 ├── datasets/                       # CIFAR-10 数据目录
 ├── env.yml                        # conda 环境配置
 ├── requirements.txt
 └── simclr_cifar10_resnet18_ep50.pth  # 预训练 checkpoint

七、Trade-offs 与局限性

决策

Trade-off

改进方向

ResNet18 主干

计算量适中,但表示能力有限

可替换为 ResNet50 或 Vision Transformer

两视图增强

简单有效,但增强空间有限

可添加更多变换(CutMix、MixUp)

线性评估

快速验证,但不能反映非线性分类器效果

可添加 MLP 评估协议

CIFAR-10 32×32

小分辨率,适合课程项目

可扩展到 ImageNet 128×128 或 224×224

无数据并行

单 GPU 训练

可添加 DistributedDataParallel 扩展到多卡


八、快速开始

 # 1. 创建环境
 conda env create --name simclr --file env.yml
 conda activate simclr
 ​
 # 2. 准备数据
 # 将 CIFAR-10 放入 ./datasets/cifar-10-batches-py/
 ​
 # 3. 预训练 SimCLR (50 epochs)
 python run.py \
   -data ./datasets \
   --dataset-name cifar10 \
   -a resnet18 \
   --epochs 50 \
   -b 256 \
   --lr 0.0003 \
   --temperature 0.07
 ​
 # 4. 线性评估(使用预训练 checkpoint)
 python linear_eval.py \
   --pretrained simclr_cifar10_resnet18_ep50.pth \
   --epochs 30 \
   --lr 0.01
 ​
 # 5. 随机基线对比
 python linear_eval_random.py
 ​
 # 6. 有监督上界
 python supervised_resnet.py \
   -a resnet18 \
   --epochs 50

九、Further Reading

1

评论区