Examples

ImageNet training example

This example uses rpcdataloader to train a ResNet50 on ImageNet. It supports distributed training and mixed precision. The modifications required to use rpcdataloader are limited to the dataset and dataloader setup (RPCDataset and RPCDataloader below); evaluation routines are omitted for readability.

Prior to running this script, you should spawn workers.
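A worker is a plain Python process started with the rpcdataloader.launch module (the same command used by the Slurm script further below). As a minimal sketch, assuming two workers on a hypothetical host named worker-node and arbitrarily chosen ports 5000 and 5001:

# Start two data-loading workers on worker-node (port numbers are arbitrary).
python -m rpcdataloader.launch --host=0.0.0.0 --port=5000 &
python -m rpcdataloader.launch --host=0.0.0.0 --port=5001 &

The training script would then be given --workers worker-node:5000 worker-node:5001.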

import argparse
import os
import time

import torch
from torch import nn
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DistributedSampler, RandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import get_model

from rpcdataloader import RPCDataloader, RPCDataset


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--data-path")
    argparser.add_argument("--model", default="resnet50")
    argparser.add_argument("--workers", type=str, nargs="+")
    argparser.add_argument("--batch-size", default=2, type=int)
    argparser.add_argument("--lr", default=0.1, type=float)
    argparser.add_argument("--momentum", default=0.9, type=float)
    argparser.add_argument("--weight-decay", default=1e-4, type=float)
    argparser.add_argument("--epochs", default=100, type=int)
    argparser.add_argument("--amp", action="store_true")
    args = argparser.parse_args()

    # Distributed
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:  # torchrun launch
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    elif int(os.environ.get("SLURM_NPROCS", 1)) > 1:  # srun launch
        rank = int(os.environ["SLURM_PROCID"])
        local_rank = int(os.environ["SLURM_LOCALID"])
        world_size = int(os.environ["SLURM_NPROCS"])
    else:  # single gpu & process launch
        rank = 0
        local_rank = 0
        world_size = 0

    if world_size > 0:
        torch.distributed.init_process_group(
            backend="nccl", world_size=world_size, rank=rank
        )

        # split workers between GPUs (optional but recommended)
        if len(args.workers) > 0:
            args.workers = args.workers[rank::world_size]

    print(args)

    # Device
    device = torch.device("cuda", index=local_rank)

    # Preprocessing
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandAugment(),
            transforms.ToTensor(),
            normalize,
        ]
    )

    # Datasets
    train_dataset = RPCDataset(
        args.workers,
        ImageFolder,
        root=args.data_path + "/train",
        transform=train_transform,
    )

    # Data loading
    if torch.distributed.is_initialized():
        train_sampler = DistributedSampler(train_dataset)
    else:
        train_sampler = RandomSampler(train_dataset)

    train_loader = RPCDataloader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        pin_memory=True,
    )

    # Model
    model = get_model(args.model, num_classes=1000)
    model.to(device)
    if torch.distributed.is_initialized():
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank]
        )

    # Optimization
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # Training
    for epoch in range(args.epochs):
        if isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(epoch)

        for it, (images, targets) in enumerate(train_loader):
            t0 = time.monotonic()

            optimizer.zero_grad(set_to_none=True)

            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=args.amp):
                predictions = model(images)
                loss = loss_fn(predictions, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if (it + 1) % 20 == 0 and rank == 0:
                t1 = time.monotonic()
                print(
                    f"[epoch {epoch:<3d}"
                    f"  it {it:-5d}/{len(train_loader)}]"
                    f"  loss: {loss.item():2.3f}"
                    f"  time: {t1 - t0:.1f}"
                )

        scheduler.step()


if __name__ == "__main__":
    main()
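For reference, a possible two-GPU launch of this script with torchrun is sketched below; the worker addresses and the dataset path are placeholders, and the workers are assumed to have been spawned beforehand as described above:

# Hypothetical launch: 2 GPUs on one node, 2 pre-spawned workers.
torchrun --nproc_per_node=2 example_rpc.py \
    --data-path=/path/to/ILSVRC \
    --workers worker-node:5000 worker-node:5001 \
    --batch-size=256 \
    --amp

Since torchrun sets RANK, LOCAL_RANK and WORLD_SIZE, the script initializes the NCCL process group and splits the worker list between the two ranks.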

Slurm integration example

To use rpcdataloader on a Slurm cluster, you can rely on Slurm's heterogeneous jobs feature to reserve two groups of resources: one for the workers and one for the training script. The sample script below demonstrates how to do this.

Note that you might need to adjust the port numbers to avoid collisions between jobs; a possible workaround is sketched after the script. You might also need to adapt the resource specifications to your Slurm configuration.

#!/usr/bin/env bash
#SBATCH --time=0-01:00:00

# Resource specification for the training script
#SBATCH --partition=prismgpup
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=2
#SBATCH --mem=96G
#SBATCH --gres=gpu:2

#SBATCH hetjob

# Resource specification for the workers
#SBATCH --partition=cpu
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=10
#SBATCH --cpus-per-task=1
#SBATCH --mem=40G

source ~/miniconda3/etc/profile.d/conda.sh
conda activate rpcdataloader
export LD_LIBRARY_PATH=/home/users/ngranger/sysroot/lib64

export rpc_port_start=16000

# identify workers
export tmpfile="${TMPDIR:-/tmp}/rpcdataloader_workers.$SLURM_JOB_ID"
srun --het-group=1 -I --exclusive --kill-on-bad-exit=1 sh -c '
    echo $(hostname)-ib:$(( $rpc_port_start + $SLURM_LOCALID ))
    ' > "${tmpfile}"
readarray -t workers < "${tmpfile}"
rm "$tmpfile"

# start workers in background
srun --het-group=1 -I --exclusive --kill-on-bad-exit=1 sh -c '
    python -u -m rpcdataloader.launch \
        --host=0.0.0.0 \
        --port=$(( $rpc_port_start + $SLURM_LOCALID ))
    ' &
worker_task_pid=$!

# run training script
export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_0 | head -n 1)
export MASTER_PORT=17000
srun --het-group=0 -I --exclusive --kill-on-bad-exit=1 \
    python -u example_rpc.py \
        --workers "${workers[@]}" \
        --data-path=/media/ILSVRC \
        --model=swin_v2_s \
        --batch-size=256 \
        --lr=0.001 \
        --amp

# stop workers
kill $worker_task_pid
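If several such jobs may land on the same nodes, one way to avoid the port collisions mentioned above is to derive the port bases from the job ID. This is a sketch of a possible variant, not part of the original script:

# Hypothetical variant: make ports job-specific to avoid collisions.
export rpc_port_start=$(( 16000 + SLURM_JOB_ID % 1000 ))
export MASTER_PORT=$(( 17000 + SLURM_JOB_ID % 1000 ))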