打碎炼丹炉2!!

25 年 6 月 2 日 星期一
1014 字
6 分钟

打碎炼丹炉!!

先让我抱怨两句,同名文件解析为相同的slug都能冲突。。。。。。

这篇纯wandb,可视化,参数调整,版本控制...

一、环境设置与安装

1. 创建 WandB 账户

访问 wandb.ai 注册免费账户

2. 安装 WandB 库

shell
pip install wandb

3. 登录 WandB

shell
wandb login

按照提示输入 API key(在个人设置页面获取)

二、项目代码集成指南

基础集成(5分钟上手)

python
import wandb

# 初始化 WandB 运行
wandb.init(
    project="my-image-classifier",  # 项目名称
    name="resnet-50-experiment",    # 实验名称
    config={                        # 记录超参数
        "learning_rate": 0.001,
        "batch_size": 64,
        "epochs": 20,
        "optimizer": "Adam"
    }
)

# 你的模型代码
model = create_model()

# 训练循环中记录指标
for epoch in range(config.epochs):
    train_loss = train_epoch(model)
    val_loss, val_acc = validate(model)

    # 记录指标到 WandB
    wandb.log({
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "val_accuracy": val_acc
    })

    # 保存模型检查点
    torch.save(model.state_dict(), "model.pth")
    wandb.save("model.pth")  # 上传到 WandB

# 结束运行
wandb.finish()

三、高级代码集成技巧

1. 自动记录 PyTorch/TensorFlow 模型

python
# PyTorch Lightning 集成
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="my-project")

trainer = Trainer(
    logger=wandb_logger,
    max_epochs=10
)
trainer.fit(model)

2. 记录媒体文件

python
# 记录图像
images = []
for i in range(5):
    img = visualize_prediction(model, sample_data[i])
    images.append(wandb.Image(img, caption=f"Sample {i}"))
wandb.log({"predictions": images})

# 记录视频
video = wandb.Video("animation.mp4", caption="Training progression")
wandb.log({"animation": video})

# 记录3D点云
point_cloud = wandb.Object3D(open("pointcloud.obj"))
wandb.log({"point_cloud": point_cloud})

3. 自定义指标可视化

python
# 创建自定义图表
wandb.define_metric("val_loss", summary="min")
wandb.define_metric("val_accuracy", summary="max")

# 添加自定义图表
wandb.log({
    "confusion_matrix": wandb.plot.confusion_matrix(
        y_true=true_labels,
        preds=predictions,
        class_names=class_names
    ),
    "roc_curve": wandb.plot.roc_curve(y_true, y_probs, labels=class_names)
})

4. 超参数扫描 (Sweeps)

创建 sweep.yaml 配置文件:

yaml
program: train.py
method: bayes
metric:
  name: val_accuracy
  goal: maximize
parameters:
  learning_rate:
    min: 1e-5
    max: 1e-2
  batch_size:
    values: [32, 64, 128]
  optimizer:
    values: ['adam', 'sgd', 'rmsprop']

启动扫描:

shell
wandb sweep sweep.yaml
wandb agent your-sweep-id

四、项目管理最佳实践

1. 项目结构建议

text
my-ml-project/
├── data/                 # 数据集
├── models/               # 模型架构
├── train.py              # 训练脚本
├── configs/              # 配置文件
│   ├── base.yaml
│   └── sweep.yaml
├── requirements.txt      # 依赖
└── .gitignore            # 忽略 wandb 缓存

2. 协作工作流

  1. 创建团队项目wandb.init(project="team-name/project-name")
  2. 共享实验结果:添加团队成员到项目
  3. 创建报告:在 WandB 界面组合图表和笔记
  4. 问题跟踪:使用讨论区标记问题

3. 项目生命周期管理

python
# 恢复中断的训练
run = wandb.init(id="previous-run-id", resume="allow")

# 比较多个运行
api = wandb.Api()
runs = api.runs("project-name")
best_run = min(runs, key=lambda r: r.summary.get("val_loss", float('inf')))

# 下载最佳模型
best_model_file = best_run.file("model-best.h5")
best_model_file.download()

五、调试与优化技巧

1. 资源监控

python
# 添加系统监控
wandb.init(settings=wandb.Settings(_disable_stats=False))

# 自定义资源日志
import psutil
wandb.log({
    "cpu_usage": psutil.cpu_percent(),
    "gpu_mem": torch.cuda.memory_allocated() / 1e9
})

2. 错误处理

python
try:
    # 训练代码
except Exception as e:
    wandb.alert(
        title="Training Failed",
        text=f"Error: {str(e)}"
    )
    raise

3. 性能优化

python
# 减少日志频率
wandb.init(settings=wandb.Settings(_log_level="info"))

# 自定义日志间隔
if epoch % 5 == 0:
    wandb.log(...)

# 选择性保存文件
wandb.save("*.pt")  # 只保存模型文件

六、完整项目示例

python
import wandb
import torch
import torch.nn as nn
from torchvision import datasets, transforms

# 初始化 WandB
wandb.init(project="mnist-classifier", config={
    "lr": 1e-3,
    "batch_size": 128,
    "epochs": 10,
    "arch": "CNN"
})

# 加载数据
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=transform),
    batch_size=wandb.config.batch_size, shuffle=True)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.lr)

# 训练循环
for epoch in range(wandb.config.epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        # 记录批次指标
        if batch_idx % 100 == 0:
            wandb.log({"batch_loss": loss.item()})

    # 验证
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            val_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    val_loss /= len(test_loader.dataset)
    val_acc = correct / len(test_loader.dataset)

    # 记录 epoch 指标
    wandb.log({
        "epoch": epoch,
        "val_loss": val_loss,
        "val_accuracy": val_acc
    })

    # 保存最佳模型
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        wandb.save("best_model.pth")
        wandb.run.summary["best_val_acc"] = best_acc

# 记录最终结果
wandb.run.summary["final_val_acc"] = val_acc
wandb.finish()

七、常见问题解决方案

问题解决方案
无法连接服务器检查 wandb offline 状态,确认网络连接
日志未更新确保 wandb.log() 在训练循环中调用
文件上传失败使用 .wandbignore 过滤大文件
超参数未显示确保在 wandb.init() 中设置 config
内存占用过高减少日志频率,禁用媒体自动记录

文章标题:打碎炼丹炉2!!

文章作者:io-wy

文章链接:https://io-wy.github.io/posts/%E6%89%93%E7%A2%8E%E7%82%BC%E4%B8%B9%E7%82%892[复制]

最后修改时间:


商业转载请联系站长获得授权,非商业转载请注明本文出处及文章链接,您可以自由地在任何媒体以任何形式复制和分发作品,也可以修改和创作,但是分发衍生作品时必须采用相同的许可协议。
本文采用CC BY-NC-SA 4.0进行许可。