Examples
ImageNet training example
This example uses rpcdataloader to train a ResNet50 on ImageNet. It supports distributed and mixed-precision training. The modifications needed to use rpcdataloader are limited to the dataset and dataloader instantiation; evaluation routines are omitted for readability.
Prior to running this script, you should spawn workers.
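For instance, one worker process can be started on each data-loading node with the rpcdataloader.launch helper (the same command is used in the Slurm script further below); the host and port values here are only placeholders:

python -m rpcdataloader.launch --host=0.0.0.0 --port=6543

The training script then receives the resulting host:port pairs through its --workers argument.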
import argparse
import os
import time

import torch
from torch import nn
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DistributedSampler, RandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import get_model

from rpcdataloader import RPCDataloader, RPCDataset


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--data-path")
    argparser.add_argument("--model", default="resnet50")
    argparser.add_argument("--workers", type=str, nargs="+")
    argparser.add_argument("--batch-size", default=2, type=int)
    argparser.add_argument("--lr", default=0.1, type=float)
    argparser.add_argument("--momentum", default=0.9, type=float)
    argparser.add_argument("--weight-decay", default=1e-4, type=float)
    argparser.add_argument("--epochs", default=100, type=int)
    argparser.add_argument("--amp", action="store_true")
    args = argparser.parse_args()

    # Distributed
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:  # torchrun launch
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    elif int(os.environ.get("SLURM_NPROCS", 1)) > 1:  # srun launch
        rank = int(os.environ["SLURM_PROCID"])
        local_rank = int(os.environ["SLURM_LOCALID"])
        world_size = int(os.environ["SLURM_NPROCS"])
    else:  # single gpu & process launch
        rank = 0
        local_rank = 0
        world_size = 0

    if world_size > 0:
        torch.distributed.init_process_group(
            backend="nccl", world_size=world_size, rank=rank
        )

    # split workers between GPUs (optional but recommended)
    if world_size > 0 and len(args.workers) > 0:
        args.workers = args.workers[rank::world_size]

    print(args)

    # Device
    device = torch.device("cuda", index=local_rank)

    # Preprocessing
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandAugment(),
            transforms.ToTensor(),
            normalize,
        ]
    )

    # Datasets
    # The ImageFolder dataset is instantiated remotely on the RPC workers,
    # so args.data_path must be accessible from the worker nodes.
    train_dataset = RPCDataset(
        args.workers,
        ImageFolder,
        root=args.data_path + "/train",
        transform=train_transform,
    )

    # Data loading
    if torch.distributed.is_initialized():
        train_sampler = DistributedSampler(train_dataset)
    else:
        train_sampler = RandomSampler(train_dataset)

    # RPCDataloader behaves like a regular DataLoader but delegates
    # item loading to the remote workers.
    train_loader = RPCDataloader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        pin_memory=True,
    )

    # Model
    model = get_model(args.model, num_classes=1000)
    model.to(device)
    if torch.distributed.is_initialized():
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank]
        )

    # Optimization
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # Training
    for epoch in range(args.epochs):
        if isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(epoch)

        for it, (images, targets) in enumerate(train_loader):
            t0 = time.monotonic()

            optimizer.zero_grad(set_to_none=True)

            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=args.amp):
                predictions = model(images)
                loss = loss_fn(predictions, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if (it + 1) % 20 == 0 and rank == 0:
                t1 = time.monotonic()
                print(
                    f"[epoch {epoch:<3d}"
                    f" it {it:-5d}/{len(train_loader)}]"
                    f" loss: {loss.item():2.3f}"
                    f" time: {t1 - t0:.1f}"
                )

        scheduler.step()


if __name__ == "__main__":
    main()
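As a sketch of a single-node launch (the hostnames, ports, paths and option values below are placeholders, and the two workers are assumed to be running already), the script above can be started with torchrun, which sets the RANK, LOCAL_RANK and WORLD_SIZE variables it checks for:

torchrun --nproc_per_node=2 example_rpc.py \
    --workers host1:6543 host2:6543 \
    --data-path=/path/to/ILSVRC \
    --batch-size=64 \
    --amp

Without torchrun, the script falls back to a single-process run and skips the process-group initialization.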
Slurm integration example
To use rpcdataloader on a Slurm cluster, you can rely on Slurm's heterogeneous job support to reserve two groups of resources: one for the workers and one for the training scripts. The sample script below demonstrates how to do this.
Note that you might need to adjust the port numbers to avoid collisions between jobs. You might also need to adjust the resource specifications to match your Slurm configuration.
#!/usr/bin/env bash
#SBATCH --time=0-01:00:00

# Resource specification for training scripts
#SBATCH --partition=prismgpup
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=2
#SBATCH --mem=96G
#SBATCH --gres=gpu:2

#SBATCH hetjob

# Resource specification for workers
#SBATCH --partition=cpu
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=10
#SBATCH --cpus-per-task=1
#SBATCH --mem=40G

source ~/miniconda3/etc/profile.d/conda.sh
conda activate rpcdataloader
export LD_LIBRARY_PATH=/home/users/ngranger/sysroot/lib64

export rpc_port_start=16000

# identify workers
export tmpfile="${TMPDIR:-/tmp}/rpcdataloader_workers.$SLURM_JOB_ID"
srun --het-group=1 -I --exclusive --kill-on-bad-exit=1 sh -c '
    echo $(hostname)-ib:$(( $rpc_port_start + $SLURM_LOCALID ))
    ' > "${tmpfile}"
readarray -t workers < "${tmpfile}"
rm $tmpfile

# start workers in background
srun --het-group=1 -I --exclusive --kill-on-bad-exit=1 sh -c '
    python -u -m rpcdataloader.launch \
        --host=0.0.0.0 \
        --port=$(( $rpc_port_start + $SLURM_LOCALID ))
    ' &
worker_task_pid=$!

# run training script
export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_0 | head -n 1)
export MASTER_PORT=17000
srun --het-group=0 -I --exclusive --kill-on-bad-exit=1 \
    python -u example_rpc.py \
        --workers ${workers[@]} \
        --data-path=/media/ILSVRC \
        --model=swin_v2_s \
        --batch-size=256 \
        --lr=0.001 \
        --amp

# stop workers
kill $worker_task_pid
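Assuming the script above is saved as train_imagenet.sbatch (an arbitrary file name), it is submitted like any regular batch script; the hetjob directive makes Slurm allocate both resource groups together:

sbatch train_imagenet.sbatch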