Training Context
A PERSIA training context manages training environments on NN workers.
- PERSIA Training Context Complete Example
- EmbeddingConfig
- Mixed Precision Training
- Distributed Option
PERSIA Training Context Complete Example
Here is a complete example of how to use a PERSIA training context.
import torch
from persia.ctx import TrainCtx
from persia.embedding.optim import Adagrad
from persia.env import get_rank, get_world_size
from persia.data import Dataloder, PersiaDataset, StreamingDataset
from model import DNN
if __name__ == "__main__":
    model = DNN()
    rank, world_size = get_rank(), get_world_size()
    device_id = 0
    torch.cuda.set_device(device_id)
    model.cuda(device_id)

    # Dense parameters are updated by a regular PyTorch optimizer,
    # embedding parameters by the PERSIA embedding optimizer (Adagrad here).
    dense_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    embedding_optimizer = Adagrad(lr=1e-2)
    loss_fn = torch.nn.BCELoss(reduction="mean")

    with TrainCtx(
        model=model,
        embedding_optimizer=embedding_optimizer,
        dense_optimizer=dense_optimizer,
        device_id=device_id,
    ) as ctx:
        train_dataloader = Dataloder(StreamingDataset(10))

        for (batch_idx, data) in enumerate(train_dataloader):
            # forward returns the model output together with the batch labels
            (output, labels) = ctx.forward(data)
            loss = loss_fn(output, labels[0])
            # backward updates both the dense and the embedding parameters
            scaled_loss = ctx.backward(loss)
In the following sections, we introduce several configuration options available when creating a PERSIA training context.
EmbeddingConfig
EmbeddingConfig defines embedding hyperparameters.
- emb_initialization: The default initialization of a PERSIA embedding is a uniform distribution. The value is a tuple giving the lower and upper bound of the uniform initialization.
- admit_probability: The probability (between 0 and 1) of admitting a new embedding.
- weight_bound: Restricts each element of an embedding to the range [-weight_bound, weight_bound].
from persia.embedding import EmbeddingConfig
from persia.ctx import TrainCtx
embedding_config = EmbeddingConfig(
    emb_initialization=(-1, 1),
    admit_probability=0.8,
    weight_bound=1,
)
TrainCtx(
    embedding_config=embedding_config,
    ...
)
Mixed Precision Training
The mixed_precision feature in PERSIA training is only supported on GPU NN workers because it depends on PyTorch automatic mixed precision (AMP).
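Below is a minimal sketch of enabling it, assuming the TrainCtx in your PERSIA version accepts a mixed_precision flag (the flag name is taken from the feature name above; check it against the installed API):
from persia.ctx import TrainCtx
TrainCtx(
    mixed_precision=True,  # assumed flag name; only effective on GPU NN workers
    device_id=device_id,
    ...
)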
Distributed Option
Distributed Option defines the implementation of data parallelism among PERSIA NN workers.
- DDP (by default): PyTorch's native DistributedDataParallel (DDP) data parallelism implementation.
- Bagua: A deep learning training acceleration framework for PyTorch.
Configuring DDPOption
In DDPOption, you can configure the backend and init_method.
from persia.distributed import DDPOption
backend = "nccl"
# backend = "gloo"
init_method = "tcp"
# init_method = "file"
ddp_option = DDPOption(backend=backend, init_method=init_method)
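Like the BaguaDistributedOption shown below, the resulting option is passed to the training context through the distributed_option argument:
TrainCtx(
    distributed_option=ddp_option,
    ...
)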
Configuring BaguaDistributedOption
There are several data parallelism implementations in Bagua; see the Bagua Documentation to learn more about Bagua.
from persia.ctx import TrainCtx
from persia.distributed import BaguaDistributedOption
algorithm = "gradient_allreduce"
# algorithm = "low_precision_decentralized"
# algorithm = "bytegrad"
# algorithm = "async"
bagua_args = {}
bagua_option = BaguaDistributedOption(
    algorithm,
    **bagua_args,
)
TrainCtx(
    distributed_option=bagua_option,
    ...
)