Training Context

A PERSIA training context manages training environments on NN workers.

PERSIA Training Context Complete Example

Here is a complete example showing the usage of a PERSIA training context.

import torch

from persia.ctx import TrainCtx
from persia.embedding.optim import Adagrad
from persia.env import get_rank, get_world_size
from persia.data import Dataloder, PersiaDataset, StreamingDataset

from model import DNN


if __name__ == "__main__":
    model = DNN()
    rank, world_size = get_rank(), get_world_size()

    device_id = 0
    torch.cuda.set_device(device_id)
    model.cuda(device_id)

    dense_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    embedding_optimizer = Adagrad(lr=1e-2)
    loss_fn = torch.nn.BCELoss(reduction="mean")

    with TrainCtx(
        model=model,
        embedding_optimizer=embedding_optimizer,
        dense_optimizer=dense_optimizer,
        device_id=device_id,
    ) as ctx:
        train_dataloader = Dataloder(StreamingDataset(10))

        for (batch_idx, data) in enumerate(train_dataloader):
            # forward pass: embedding lookup plus the dense model forward,
            # returning the model output and the labels for this batch
            (output, labels) = ctx.forward(data)
            loss = loss_fn(output, labels[0])
            # ctx.backward scales the loss, runs the backward pass,
            # and applies the dense and embedding optimizer updates
            scaled_loss = ctx.backward(loss)

In the following sections, we introduce several configuration options available when creating a PERSIA training context.

EmbeddingConfig

EmbeddingConfig defines embedding hyperparameters.

  • emb_initialization: PERSIA embeddings are initialized from a uniform distribution by default. The value is a tuple of the lower and upper bound of the uniform initialization.
  • admit_probability: The probability (between 0 and 1) of admitting a new embedding.
  • weight_bound: Restricts each element of an embedding to the range [-weight_bound, weight_bound].

from persia.embedding import EmbeddingConfig
from persia.ctx import TrainCtx

embedding_config = EmbeddingConfig(
    emb_initialization=(-1, 1),
    admit_probability=0.8,
    weight_bound=1
)

TrainCtx(
    embedding_config=embedding_config,
    ...
)

Mixed Precision Training

The mixed_precision feature in PERSIA training is only supported on GPU NN workers because it relies on PyTorch AMP (automatic mixed precision).
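
Mixed precision can be enabled or disabled when constructing the training context. The following is a minimal sketch, assuming the mixed_precision flag of TrainCtx controls this feature:

from persia.ctx import TrainCtx

TrainCtx(
    mixed_precision=True,  # assumed flag; enables mixed precision on GPU NN workers
    ...
)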

Distributed Option

Distributed Option defines the implementation of data parallelism among PERSIA NN workers.

  • DDP (by default): PyTorch's native DistributedDataParallel implementation.
  • Bagua: A deep learning training acceleration framework for PyTorch.

Configuring DDPOption

In DDPOption, you can configure the backend and init_method.

from persia.distributed import DDPOption

backend = "nccl"
# backend = "gloo"

init_method = "tcp"
# init_method = "file"

DDPOption(backend=backend, init_method=init_method)
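
A DDPOption is passed to the training context in the same way as the Bagua option below. A minimal sketch, assuming the distributed_option argument of TrainCtx accepts a DDPOption:

from persia.ctx import TrainCtx
from persia.distributed import DDPOption

# configure native PyTorch DDP as the data parallelism implementation
ddp_option = DDPOption(backend="nccl", init_method="tcp")

TrainCtx(
    distributed_option=ddp_option,
    ...
)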

Configuring BaguaDistributedOption

There are several data parallelism implementations in Bagua; see the Bagua documentation to learn more about Bagua.

from persia.ctx import TrainCtx
from persia.distributed import BaguaDistributedOption

algorithm = "gradient_allreduce"
# algorithm = "low_precision_decentralized"
# algorithm = "bytegrad"
# algorithm = "async"

bagua_args = {}
bagua_option = BaguaDistributedOption(
    algorithm,
    **bagua_args
)

TrainCtx(
    distributed_option=bagua_option,
    ...
)