屹和 發表於 2025-6-23 09:11:00

pytorch入门 - 基于AlexNet神经网络实现猫狗大战

<p>&nbsp;</p>
<p>基于之前的博客&nbsp;pytorch入门 - AlexNet神经网络,并借助Kaggle 的&nbsp;<strong>Dogs vs Cats Redux</strong>&nbsp;数据集,实现一个基于 AlexNet 的二分类模型识别猫与狗。</p>
<p>完整流程涵盖数据准备、归一化、模型定义、训练增强、验证并可视化结果。</p>
<h4>&nbsp;一、数据集准备与预处理</h4>
<pre><code>import os
import shutil

def split_data(ROOT_TRAIN):
    cat_dir = os.path.join(ROOT_TRAIN, "cat")
    dog_dir = os.path.join(ROOT_TRAIN, "dog")
    os.makedirs(cat_dir, exist_ok=True)
    os.makedirs(dog_dir, exist_ok=True)
   
    for filename in os.listdir(ROOT_TRAIN):
      if filename.startswith("cat") and filename.endswith(".jpg"):
            shutil.move(os.path.join(ROOT_TRAIN, filename),
                        os.path.join(cat_dir, filename))
      elif filename.startswith("dog") and filename.endswith(".jpg"):
            shutil.move(os.path.join(ROOT_TRAIN, filename),
                        os.path.join(dog_dir, filename))</code></pre>
<p>​<strong>​优化原因​</strong>​:<br>
分类任务需明确标签与数据的对应关系。通过创建<code>cat/dog</code>子目录并移动图片,可直接利用PyTorch的<code>ImageFolder</code>自动生成标签,避免手动标注错误。</p>
<h4>二、数据归一化参数计算</h4>
<pre><code>def compute_normalization_params(dataset_path):
    transform = transforms.Compose([
      transforms.Resize((227, 227)),
      transforms.ToTensor()
    ])
    dataset = ImageFolder(dataset_path, transform=transform)
    loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=False)
   
    # 计算各通道均值和标准差
    mean = 0.0
    std = 0.0
    for data, _ in loader:
      batch_samples = data.size(0)
      data = data.view(batch_samples, data.size(1), -1)
      mean += data.mean(2).sum(0)
      std += data.std(2).sum(0)
    return mean / len(dataset), std / len(dataset)</code></pre>
<p>​<strong>​关键点​</strong>​:</p>
<ol>
<li>​<strong>​输入尺寸统一​</strong>​:AlexNet要求固定输入尺寸<code>227×227</code>,需提前调整</li>
<li>​<strong>​通道级归一化​</strong>​:对RGB三通道分别计算均值和标准差,消除光照差异影响,加速模型收敛</li>
<li>​<strong>​离线计算​</strong>​:避免在训练时实时计算,提升数据加载效率</li>
</ol>
<h4>三、AlexNet模型针对性修改</h4>
<pre><code>class AlexNet(nn.Module):
    def __init__(self):
      super().__init__()
      # 修改1:输入通道调整为3 (RGB)
      self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4)
      # ... (中间层省略)
      # 修改2:输出层调整为2分类
      self.fc3 = nn.Linear(4096, 2)
      
      # 修改3:降低Dropout比例
      self.dropout = nn.Dropout(0.2)# 原论文为0.5</code></pre>
<p>​<strong>​优化逻辑​</strong>​:</p>
<ol>
<li>​<strong>​输入通道适配​</strong>​:原始AlexNet针对ImageNet的1000类设计,此处调整为猫狗二分类,需修改输出层维度为2</li>
<li>​<strong>​降低过拟合风险​</strong>​:
<ul>
<li>猫狗数据集(25k张)远小于ImageNet(1400万张)</li>
<li>降低Dropout比例(0.5→0.2)可保留更多特征信息,避免模型欠拟合</li>
</ul>
</li>
<li>​<strong>​权重初始化​</strong>​:采用Kaiming初始化,适配ReLU激活函数特性,缓解梯度消失</li>
</ol>
<h4>四、数据增强策略</h4>
<pre><code>train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(227, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=,
                         std=)
])</code></pre>
<p>​<strong>​增强目的​</strong>​:</p>
<ol>
<li>​<strong>​提升泛化能力​</strong>​:通过旋转、裁剪、色彩扰动模拟真实场景的多样性,防止模型记忆固定模式</li>
<li>​<strong>​克服数据局限​</strong>​:小数据集易导致过拟合,增强后等效扩大数据规模</li>
<li>​<strong>​对齐测试环境​</strong>​:测试阶段采用相同预处理,保证输入分布一致性</li>
</ol>
<h4>五、训练过程优化</h4>
<pre><code># 1. 学习率调整
optimizer = optim.Adam(model.parameters(), lr=1e-4)# 原常用值0.001

# 2. 训练-验证集拆分
train_data, val_data = random_split(dataset, )

# 3. 早停机制
if val_acc &gt; best_acc:
    best_model_wts = copy.deepcopy(model.state_dict())</code></pre>
<p>​<strong>​关键技术点​</strong>​:</p>
<ol>
<li>​<strong>​低学习率策略​</strong>​:
<ul>
<li>预训练模型特征已较完备,降低学习率(1e-4)避免破坏已有特征</li>
<li>微调阶段需精细调整参数,高学习率易导致震荡</li>
</ul>
</li>
<li>​<strong>​验证集独立划分​</strong>​:
<ul>
<li>20%数据作为验证集,实时监控模型泛化能力</li>
<li>避免测试集参与训练,保证评估客观性</li>
</ul>
</li>
<li>​<strong>​混合精度训练(可选)​</strong>​:<br>
使用<code>torch.cuda.amp</code>自动混合精度,提升训练速度30%+(需GPU支持)</li>
</ol>
<h3>关键优化总结</h3>
<div class="hyc-common-markdown__table-wrapper">
<table style="height: 205px; width: 1302px">
<thead>
<tr><th>优化点</th><th>原始值</th><th>调整值</th><th>作用</th></tr>
</thead>
<tbody>
<tr>
<td>输入通道</td>
<td>1 (灰度)</td>
<td>3 (RGB)</td>
<td>适配彩色图像</td>
</tr>
<tr>
<td>输出维度</td>
<td>1000</td>
<td>2</td>
<td>二分类需求</td>
</tr>
<tr>
<td>Dropout率</td>
<td>0.5</td>
<td>0.2</td>
<td>防欠拟合</td>
</tr>
<tr>
<td>学习率</td>
<td>0.001</td>
<td>0.0001</td>
<td>稳定微调</td>
</tr>
<tr>
<td>数据增强</td>
<td>无</td>
<td>5种变换</td>
<td>提升泛化性</td>
</tr>
</tbody>
</table>
</div>
<p>&nbsp;</p><br><br>
来源:https://www.cnblogs.com/chenyishi/p/18942679
頁: [1]
查看完整版本: pytorch入门 - 基于AlexNet神经网络实现猫狗大战