打碎炼丹炉!!
先让我抱怨两句,同名文件解析为相同的slug都能冲突。。。。。。
这篇纯wandb,可视化,参数调整,版本控制...
一、环境设置与安装
1. 创建 WandB 账户
访问 wandb.ai 注册免费账户
2. 安装 WandB 库
shell
pip install wandb
3. 登录 WandB
shell
wandb login
按照提示输入 API key(在个人设置页面获取)
二、项目代码集成指南
基础集成(5分钟上手)
python
import wandb
# 初始化 WandB 运行
wandb.init(
project="my-image-classifier", # 项目名称
name="resnet-50-experiment", # 实验名称
config={ # 记录超参数
"learning_rate": 0.001,
"batch_size": 64,
"epochs": 20,
"optimizer": "Adam"
}
)
# 你的模型代码
model = create_model()
# 训练循环中记录指标
for epoch in range(config.epochs):
train_loss = train_epoch(model)
val_loss, val_acc = validate(model)
# 记录指标到 WandB
wandb.log({
"epoch": epoch,
"train_loss": train_loss,
"val_loss": val_loss,
"val_accuracy": val_acc
})
# 保存模型检查点
torch.save(model.state_dict(), "model.pth")
wandb.save("model.pth") # 上传到 WandB
# 结束运行
wandb.finish()
三、高级代码集成技巧
1. 自动记录 PyTorch/TensorFlow 模型
python
# PyTorch Lightning 集成
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(project="my-project")
trainer = Trainer(
logger=wandb_logger,
max_epochs=10
)
trainer.fit(model)
2. 记录媒体文件
python
# 记录图像
images = []
for i in range(5):
img = visualize_prediction(model, sample_data[i])
images.append(wandb.Image(img, caption=f"Sample {i}"))
wandb.log({"predictions": images})
# 记录视频
video = wandb.Video("animation.mp4", caption="Training progression")
wandb.log({"animation": video})
# 记录3D点云
point_cloud = wandb.Object3D(open("pointcloud.obj"))
wandb.log({"point_cloud": point_cloud})
3. 自定义指标可视化
python
# 创建自定义图表
wandb.define_metric("val_loss", summary="min")
wandb.define_metric("val_accuracy", summary="max")
# 添加自定义图表
wandb.log({
"confusion_matrix": wandb.plot.confusion_matrix(
y_true=true_labels,
preds=predictions,
class_names=class_names
),
"roc_curve": wandb.plot.roc_curve(y_true, y_probs, labels=class_names)
})
4. 超参数扫描 (Sweeps)
创建 sweep.yaml
配置文件:
yaml
program: train.py
method: bayes
metric:
name: val_accuracy
goal: maximize
parameters:
learning_rate:
min: 1e-5
max: 1e-2
batch_size:
values: [32, 64, 128]
optimizer:
values: ['adam', 'sgd', 'rmsprop']
启动扫描:
shell
wandb sweep sweep.yaml
wandb agent your-sweep-id
四、项目管理最佳实践
1. 项目结构建议
text
my-ml-project/
├── data/ # 数据集
├── models/ # 模型架构
├── train.py # 训练脚本
├── configs/ # 配置文件
│ ├── base.yaml
│ └── sweep.yaml
├── requirements.txt # 依赖
└── .gitignore # 忽略 wandb 缓存
2. 协作工作流
- 创建团队项目:
wandb.init(project="team-name/project-name")
- 共享实验结果:添加团队成员到项目
- 创建报告:在 WandB 界面组合图表和笔记
- 问题跟踪:使用讨论区标记问题
3. 项目生命周期管理
python
# 恢复中断的训练
run = wandb.init(id="previous-run-id", resume="allow")
# 比较多个运行
api = wandb.Api()
runs = api.runs("project-name")
best_run = min(runs, key=lambda r: r.summary.get("val_loss", float('inf')))
# 下载最佳模型
best_model_file = best_run.file("model-best.h5")
best_model_file.download()
五、调试与优化技巧
1. 资源监控
python
# 添加系统监控
wandb.init(settings=wandb.Settings(_disable_stats=False))
# 自定义资源日志
import psutil
wandb.log({
"cpu_usage": psutil.cpu_percent(),
"gpu_mem": torch.cuda.memory_allocated() / 1e9
})
2. 错误处理
python
try:
# 训练代码
except Exception as e:
wandb.alert(
title="Training Failed",
text=f"Error: {str(e)}"
)
raise
3. 性能优化
python
# 减少日志频率
wandb.init(settings=wandb.Settings(_log_level="info"))
# 自定义日志间隔
if epoch % 5 == 0:
wandb.log(...)
# 选择性保存文件
wandb.save("*.pt") # 只保存模型文件
六、完整项目示例
python
import wandb
import torch
import torch.nn as nn
from torchvision import datasets, transforms
# 初始化 WandB
wandb.init(project="mnist-classifier", config={
"lr": 1e-3,
"batch_size": 128,
"epochs": 10,
"arch": "CNN"
})
# 加载数据
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True, transform=transform),
batch_size=wandb.config.batch_size, shuffle=True)
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
return x
model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.lr)
# 训练循环
for epoch in range(wandb.config.epochs):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
# 记录批次指标
if batch_idx % 100 == 0:
wandb.log({"batch_loss": loss.item()})
# 验证
model.eval()
val_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
val_loss += F.cross_entropy(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
val_loss /= len(test_loader.dataset)
val_acc = correct / len(test_loader.dataset)
# 记录 epoch 指标
wandb.log({
"epoch": epoch,
"val_loss": val_loss,
"val_accuracy": val_acc
})
# 保存最佳模型
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), "best_model.pth")
wandb.save("best_model.pth")
wandb.run.summary["best_val_acc"] = best_acc
# 记录最终结果
wandb.run.summary["final_val_acc"] = val_acc
wandb.finish()
七、常见问题解决方案
问题 | 解决方案 |
---|---|
无法连接服务器 | 检查 wandb offline 状态,确认网络连接 |
日志未更新 | 确保 wandb.log() 在训练循环中调用 |
文件上传失败 | 使用 .wandbignore 过滤大文件 |
超参数未显示 | 确保在 wandb.init() 中设置 config |
内存占用过高 | 减少日志频率,禁用媒体自动记录 |