API documentation
- class rpcdataloader.RPCDataloader(dataset: RPCDataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, collate_fn=None, pin_memory=False, drop_last=False, timeout=120, *, prefetch_factor: int = 2)[source]
A dataloader using remote RPC-based workers (see the usage sketch below).
- Parameters:
dataset – A remote dataset
batch_size – how many samples per batch to load.
shuffle – set to True to have the data reshuffled at every epoch.
sampler – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
batch_sampler – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
collate_fn – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
pin_memory – If True, the data loader will copy Tensors into CUDA pinned memory before returning them.
drop_last – set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size. If False and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
prefetch_factor – number of samples loaded in advance by each worker. 2 means there will be a total of 2 * num_workers samples prefetched across all workers. (default: 2)
Notable differences with the pytorch DataLoader:
timeout is the timeout on individual network operations.
worker_init_fn and generator are not supported.
Random seeds are not supported because workers may execute requests out of order anyway, thus breaking reproducibility.
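For illustration, a minimal usage sketch; the worker addresses, dataset, and paths below are placeholders, not part of the API:

from torchvision import datasets, transforms
from rpcdataloader import RPCDataset, RPCDataloader

# Instantiate the dataset on the remote workers (hypothetical addresses).
dataset = RPCDataset(
    workers=["node1:6543", "node2:6543"],
    dataset=datasets.ImageFolder,
    root="/data/imagenet/train",
    transform=transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    ),
)

# Iterate over it much like a regular pytorch DataLoader.
dataloader = RPCDataloader(
    dataset,
    batch_size=64,
    shuffle=True,
    pin_memory=True,
)

for images, labels in dataloader:
    ...  # training step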
- class rpcdataloader.RPCDataset(workers: List[str], dataset: Callable[[Any], Dataset], *args, **kwargs)[source]
Handle to instantiate and manage datasets on remote workers.
- Parameters:
workers – a list of workers with the format host:port
dataset – dataset class or equivalent callable that returns a dataset instance
args – positional arguments for dataset
kwargs – keyword arguments for dataset
Note
In a distributed setup, you should probably split the workers between the trainers (i.e. workers = workers[rank::world_size]), as sketched below.
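A sketch of that split, assuming torch.distributed has already been initialized and using placeholder worker addresses:

import torch.distributed as dist

# Placeholder addresses; for example, two RPC worker processes per node.
all_workers = ["node1:6543", "node1:6544", "node2:6543", "node2:6544"]

# Each trainer takes a disjoint subset so that no two trainers share a worker.
workers = all_workers[dist.get_rank()::dist.get_world_size()]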
- rpcdataloader.rpc_async(host: str, func: Callable[[...], _T], args=None, kwargs=None, pin_memory=False, rref: bool = False, timeout=120.0) → Future[_T][source]
Execute function on remote worker and return the result as a future.
- Parameters:
host – rpc worker host
func – function to execute
args – positional arguments
kwargs – keyword arguments
pin_memory – whether buffers (i.e. tensors) should be allocated in pinned memory.
rref – whether to return the output as a remote reference.
timeout – timeout in seconds on network operations
- Returns:
A future that will contain the function return value.
Note
func and its arguments must be serializable, which excludes the use of lambdas or locally defined functions. See the sketch below.
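As an illustration, a one-off remote call with a placeholder worker address; this assumes the returned future follows the concurrent.futures interface, i.e. exposes a blocking result() accessor:

from rpcdataloader import rpc_async

# len is a module-level builtin, so it is serializable and importable on
# the worker side; a lambda or a locally defined function would not be.
future = rpc_async("node1:6543", len, args=[[1, 2, 3]])

print(future.result())  # blocks until the worker replies, then prints 3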
- rpcdataloader.run_worker(host: str, port: int, timeout: float = 120, parallel: int = 1)[source]
Start listening and processing remote procedure calls.
- Parameters:
host – interface to bind to (set to '0.0.0.0' for all interfaces)
port – port to bind to
timeout – timeout on network transfers from/to client
parallel – maximum number of procedures executing concurrently
Warning
The workers implement neither authentication nor encryption; any user on the network can send arbitrary commands to a worker or listen to the traffic from/to it.
Note
Each request is processed in a separate thread.
Network transfers may overlap regardless of the parallel argument.
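A minimal launch script for a worker process; the port is an arbitrary placeholder and, given the warning above, this assumes a trusted network:

from rpcdataloader import run_worker

if __name__ == "__main__":
    # Bind to all interfaces and allow two procedures to run concurrently.
    run_worker("0.0.0.0", port=6543, parallel=2)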
- rpcdataloader.set_random_seeds(base_seed, worker_id)[source]
Set the seed of the default random generators from Python, torch, and numpy.
This should be called once on each worker. Note that workers may run tasks out of order, so this does not ensure reproducibility, only non-redundancy between workers.
Example:
>>> base_seed = torch.randint(0, 2**32 - 1, [1]).item()
>>> for i, worker in enumerate(workers):
...     rpc_async(worker, set_random_seeds, args=[base_seed, i])