最近在研究扩散模型,从加噪到去噪的过程,对时间戳t以及模型学习信息这件事有了比较新的理解,于是回过头去把CNN和transformer又想了一边,结合别人的工作有了更好的体系构建
模型真的学到信息
神经网络从本质上来说就是由线性变换+非线性激活函数构成的符合函数,通过梯度下降+反向传播来让模型调整参数,更加靠近loss预期的目标;因此如何判断模型学到的信息呢?从结果来说,loss下降,更好的复原图像,更高的准确率,但探究怎么可能只考虑端点呢,于是我们需要更深入的手段
t-SNE空间降维
先上代码
python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# 1. 加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
# 2. 定义神经网络(含中间层输出)
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x, return_features=False):
x = x.view(x.size(0), -1)
x = self.fc1(x)
features = self.relu(x)
out = self.fc2(features)
if return_features:
return out, features
return out
model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 3. 训练模型(保存部分激活)
model.train()
all_feats = []
all_labels = []
for epoch in range(5):
for inputs, labels in trainloader:
optimizer.zero_grad()
outputs, feats = model(inputs, return_features=True)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
if epoch == 4: # 最后一轮保存特征
all_feats.append(feats.detach())
all_labels.append(labels)
# 拼接激活和标签
features = torch.cat(all_feats, dim=0).numpy()
labels = torch.cat(all_labels, dim=0).numpy()
# 4. 使用 t-SNE 降维并可视化
tsne = TSNE(n_components=2, random_state=0, init='pca')
features_2d = tsne.fit_transform(features[:2000]) # 取前2000个样本可视化
plt.figure(figsize=(8, 6))
for i in range(10):
idx = labels[:2000] == i
plt.scatter(features_2d[idx, 0], features_2d[idx, 1], label=str(i), alpha=0.6)
plt.legend()
plt.title("t-SNE Visualization of Hidden Layer Features")
plt.show()
强烈建议你自己跑一下,最终我们发现mnist里面每个数字都达到了“抱团”的效果,发现他的work的,于是我们更往起点走一走
CNN中间特征激活
python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# 1. 定义CNN模型
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) # 28x28 -> 28x28
self.pool = nn.MaxPool2d(2, 2) # 28x28 -> 14x14
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # 14x14 -> 14x14
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 64 * 7 * 7)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 2. 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# 3. 训练模型并提取中间层特征
def get_conv1_features(model, input_image):
conv1 = model.conv1
x = input_image.unsqueeze(0) # 增加批量维度
x = torch.relu(conv1(x)) # 激活函数
return x.squeeze(0) # 返回特征图
model = SimpleCNN()
model.eval() # 设置为评估模式
# 4. 可视化卷积层特征图
def visualize_conv1_features(model, image):
features = get_conv1_features(model, image)
num_filters = features.size(0)
fig, axes = plt.subplots(4, 8, figsize=(15, 8))
axes = axes.ravel()
for i in range(num_filters):
ax = axes[i]
ax.imshow(features[i].detach().numpy(), cmap='gray')
ax.axis('off')
plt.show()
# 随机获取一张图片并可视化卷积层特征图
image, label = trainset[0] # 获取第一张图像
visualize_conv1_features(model, image)
此处的情况十分简便,庆幸于中间层之间非常高的关联性,我们有一套完整的方法看到中间所有的特征,因此,可视化的角度,明确了模型work的完整过程,但是理论支撑呢?
我还是不理解为什么可以学到+存储信息
首先中间层由权重构成,权重就可以当作压缩的特征表示,从信息的角度来说,这是直观成立的,那为什么可以学到信息呢?
这就要回到老生常谈的梯度下降和反向传播了
写到这里我就意识到了,脱离的数学严谨证明,太多东西只是我自己yy了,所以我失去了继续往下思考的欲望,而是想看看更多可解释性的东西
权重的动态调整:
- 神经网络中的每个权重值表示一个学习到的映射关系,即它们控制着输入特征如何影响最终的输出。在训练过程中,网络通过优化权重来逼近目标函数(例如分类任务中的类别标签)。
- 权重通过反向传播不断更新,它们在不同层之间传递学习到的信息。每个权重在学习过程中扮演了存储和传递信息的角色。
权重的可调整性:
- 在训练过程中,网络会根据每个权重对损失的影响来进行更新。因此,每个权重值的变化对应着网络“理解”输入数据的方式发生了改变。通过这种方式,权重不仅存储信息,而且能够自适应地调整其值,从而反映输入数据中的不同模式。
尾声
所以这真是毫无意义的一篇blog。。。