How to save memory and double the size of PyTorch models with the new Sharded method – InformTFB

Deep learning models improve with more data and more parameters. Even with OpenAI's latest GPT-3 model, which has 175 billion parameters, we have yet to see performance plateau as the number of parameters grows.

In some areas, such as NLP, the workhorse model has been the Transformer, which requires huge amounts of GPU memory. Realistic models simply don't fit in memory. The latest method, called Sharded, was introduced in Microsoft's ZeRO paper, in which they developed a technique that brings humanity closer to 1 trillion parameters.

To mark the launch of a new stream of the Machine Learning course, I'm sharing an article about Sharded that shows how to use it with PyTorch today to train models up to twice as large on the same hardware, with just a few minutes of setup. This feature is now available in PyTorch thanks to a collaboration between Facebook AI Research's FairScale team and the PyTorch Lightning team.

Who is this article for?

This article is for anyone who uses PyTorch to train models. Sharded works with any model, whatever you are training: NLP (Transformer), vision (SimCLR, SwAV, ResNet), or even speech models. Here is a snapshot of the performance gains you can see with Sharded across all model types.

SwAV is a state-of-the-art method for self-supervised learning in computer vision.
DeepSpeech2 is a state-of-the-art speech model.
Image GPT is a state-of-the-art vision model.
Transformer is a state-of-the-art model for natural language processing.

How to use Sharded with PyTorch

For those short on time, I'll start right away with how to use Sharded in your PyTorch code and save the intuitive explanation of how it works for later. I do encourage you to read to the end of the article to understand how Sharded works.

Sharded is designed for multi-GPU training, where it delivers its full benefit. However, training on multiple GPUs can be daunting and painful to set up.

The easiest way to supercharge your code with Sharded is to convert your model to PyTorch Lightning (this is just a refactoring). Here's a 4-minute video that shows how to convert your PyTorch code to Lightning.
Once you’ve done that, enabling Sharded on 8 GPUs is as easy as changing a single flag: no changes to your code are required.

If your model comes from another deep learning library (for example, NVIDIA NeMo or Hugging Face), it will still work with Lightning. All you have to do is wrap the model in a LightningModule and start training, as in this Hugging Face BERT example:

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

from transformers import BertModel

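class LitBertClassifier(pl.LightningModule):
    def __init__(self, n_classes, pretrained_model_name='bert-base-uncased'):
        super().__init__()  # required before assigning submodules to a LightningModule

        # Pretrained BERT encoder followed by a dropout + linear classification head
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # pooled [CLS] representation
        output = self.drop(pooled_output)
        return self.out(output)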

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch, batch_idx)
        self.log("acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        _, acc = self._shared_step(batch, batch_idx)
        self.log("val_acc", acc)

    def _shared_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        targets = batch["targets"]

        outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask)

        _, preds = torch.max(outputs, dim=1)
        loss = self.loss_fn(outputs, targets)

        acc = accuracy(preds, targets)
        return loss, acc

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=2e-5)

if __name__ == '__main__':
    # TODO: add your own dataset
    train_dataloader = ...
    val_dataloader = ...

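    # n_classes=2 is a placeholder: set it to the number of classes in your dataset
    bert = LitBertClassifier(n_classes=2)

    # Enabling Sharded is a single Trainer flag; the model code is unchanged
    trainer = pl.Trainer(gpus=8, plugins='ddp_sharded')
    trainer.fit(bert, train_dataloader, val_dataloader)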

Intuitive explanation of how Sharded works

Several approaches are used to train efficiently on many GPUs. In one approach (DP, data parallel), each batch is split across the GPUs. Here is an illustration of DP, where each part of the batch goes to a different GPU and the model is copied to each of them on every step.

DP training

However, this approach scales poorly because the model weights are transferred between devices on every step. In addition, the first GPU holds all of the optimizer state. For example, Adam keeps two extra tensors (its first- and second-moment estimates), each the size of your model's weights.
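To make the DP data flow concrete, here is a framework-free sketch in plain Python: lists stand in for GPUs and a toy linear function stands in for the model. The helper names (`scatter`, `linear_model`, `data_parallel_forward`) are hypothetical and do not mirror PyTorch's `nn.DataParallel` internals; the sketch only illustrates the split, replicate, and gather pattern described above.

```python
# Framework-free sketch of DataParallel (DP): one process splits the batch
# across "devices", copies the model to each, and gathers outputs on device 0.
# Plain Python lists stand in for GPUs; this is an illustration, not torch code.

def scatter(batch, n_devices):
    """Split a batch into n_devices contiguous chunks (the last may be smaller)."""
    size = -(-len(batch) // n_devices)  # ceiling division
    return [batch[i * size:(i + 1) * size] for i in range(n_devices)]

def linear_model(weights):
    """Return a toy model y = w0 + w1 * x that closes over its own weight copy."""
    w0, w1 = weights
    return lambda x: w0 + w1 * x

def data_parallel_forward(weights, batch, n_devices):
    # Step 1: the model is replicated to every device on each step
    # (this repeated weight transfer is one of DP's costs).
    replicas = [linear_model(list(weights)) for _ in range(n_devices)]
    # Step 2: each device processes its own chunk of the batch.
    chunks = scatter(batch, n_devices)
    per_device = [[replica(x) for x in chunk] for replica, chunk in zip(replicas, chunks)]
    # Step 3: the per-device outputs are gathered back on device 0.
    return [y for outputs in per_device for y in outputs]

weights = [1.0, 2.0]
batch = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
outputs = data_parallel_forward(weights, batch, n_devices=3)
print(outputs)  # [1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0]
```

The gathered result is the same as running the whole batch on one device; DP only changes where the work happens, while the single coordinating device still carries the weight copies and optimizer state.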

In another method (distributed data parallel, DDP), each GPU trains on a subset of the data and the gradients are synchronized across GPUs. This method also works across many machines (nodes). In this figure, each GPU receives a subset of the data, and the model weights are initialized identically on all GPUs. Then, after the backward pass, the gradients are synchronized (averaged) and every GPU applies the same update.

Distributed data parallel (DDP) training
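The core DDP idea (identical replicas, local gradients, and gradient averaging after the backward pass) can also be sketched without a framework. In this hypothetical toy, two "workers" fit y = 2x on separate shards of the data, and `all_reduce_mean` is a stand-in for the collective all-reduce that real DDP performs, not an actual PyTorch API.

```python
# Framework-free sketch of DDP: every worker holds identical weights, computes
# a gradient on its own data shard, and an all-reduce averages the gradients
# so all workers apply the same update. Lists stand in for GPUs.

def mse_grad(w, data):
    """Gradient of the mean squared error for the 1-D model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in data) / len(data)

def all_reduce_mean(values):
    """Stand-in for a collective all-reduce: every worker receives the mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)

def ddp_step(weights, shards, lr):
    # Each worker computes a local gradient on its own shard only.
    local_grads = [mse_grad(w, shard) for w, shard in zip(weights, shards)]
    # Gradients are synchronized after the backward pass.
    synced = all_reduce_mean(local_grads)
    # Every worker applies the identical update, so the replicas stay in sync.
    return [w - lr * g for w, g in zip(weights, synced)]

data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # points on y = 2x
n_workers = 2
shards = [data[i::n_workers] for i in range(n_workers)]  # equal-sized shards
weights = [0.0] * n_workers  # identical initialization on every worker
for _ in range(50):
    weights = ddp_step(weights, shards, lr=0.05)
print(weights)  # both workers end with the same weight, close to 2.0
```

Because the shards are the same size, the averaged gradient equals the full-batch gradient, which is why DDP behaves like single-device training on the combined batch while each worker only touches its own data.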

However, this method still has a problem: each GPU must keep a full copy of all optimizer states (about 2-3 times the size of the model parameters), as well as the activations needed for the forward and backward passes.
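To put rough numbers on this redundancy, the sketch below does back-of-the-envelope accounting for plain DDP with Adam. It assumes fp32 (4-byte) weights, gradients, and optimizer state and ignores activations and framework overhead; the 1-billion-parameter model and the 8-GPU setup are hypothetical. Under these assumptions, Adam's two moment tensors alone take twice the memory of the weights, and every GPU stores all of it.

```python
# Back-of-the-envelope per-GPU memory for plain DDP with Adam, assuming fp32
# (4-byte) values and ignoring activations, buffers, and framework overhead.

BYTES_FP32 = 4

def ddp_bytes_per_gpu(n_params):
    """Every GPU holds full copies of the weights, gradients, and Adam state."""
    weights = n_params * BYTES_FP32
    grads = n_params * BYTES_FP32
    adam_state = 2 * n_params * BYTES_FP32  # first and second moments
    return {"weights": weights, "grads": grads, "adam_state": adam_state,
            "total": weights + grads + adam_state}

def gb(n_bytes):
    """Convert bytes to (decimal) gigabytes."""
    return n_bytes / 1e9

per_gpu = ddp_bytes_per_gpu(1_000_000_000)  # hypothetical 1B-parameter model
print({k: gb(v) for k, v in per_gpu.items()})
# {'weights': 4.0, 'grads': 4.0, 'adam_state': 8.0, 'total': 16.0}

# Across 8 GPUs, the same Adam state is stored 8 times over.
print(gb(8 * per_gpu["adam_state"]))  # 64.0
```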

Sharded eliminates this redundancy. It works just like DDP, except that all the overhead (gradients, optimizer state, and so on) is computed only for a slice of the full parameters, so the same gradients and optimizer states are no longer stored on every GPU. In other words, each GPU stores only its own subset of the optimizer state and gradients.
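The optimizer-state half of this idea can be sketched in plain Python. This is a simplified illustration in the spirit of ZeRO and FairScale's OSS, not their actual code: parameters are scalars, ownership is assigned round-robin (FairScale balances shards differently), gradient sharding is not modeled, and merging each rank's updates stands in for the broadcast that redistributes the updated parameters. `AdamShard` and `sharded_step` are hypothetical names.

```python
import math

# Sketch of optimizer-state sharding: each "rank" keeps Adam moments only for
# the parameters it owns, updates those, and the new values are then shared
# with all ranks. Scalar parameters and round-robin ownership are assumptions.

class AdamShard:
    """Adam state (first/second moments) for one rank's parameter indices."""

    def __init__(self, owned, lr=0.1, b1=0.9, b2=0.999, eps=1e-8):
        self.owned = owned                # indices of parameters this rank owns
        self.m = {i: 0.0 for i in owned}  # first-moment estimates
        self.v = {i: 0.0 for i in owned}  # second-moment estimates
        self.lr, self.b1, self.b2, self.eps = lr, b1, b2, eps
        self.t = 0

    def step(self, params, grads):
        """Apply Adam to the owned parameters only; return {index: new_value}."""
        self.t += 1
        updated = {}
        for i in self.owned:
            g = grads[i]
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g * g
            m_hat = self.m[i] / (1 - self.b1 ** self.t)
            v_hat = self.v[i] / (1 - self.b2 ** self.t)
            updated[i] = params[i] - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)
        return updated

def sharded_step(params, grads, shards):
    # Each rank updates its own slice; merging the slices stands in for the
    # broadcast that gives every rank the full, updated parameters.
    new_params = list(params)
    for shard in shards:
        for i, value in shard.step(params, grads).items():
            new_params[i] = value
    return new_params

n_params, world_size = 10, 4
params = [float(i) for i in range(n_params)]
grads = [0.5 * p - 1.0 for p in params]  # arbitrary, already-synchronized gradients

# Round-robin ownership: rank r owns parameters r, r + world_size, ...
sharded = [AdamShard(list(range(r, n_params, world_size))) for r in range(world_size)]
unsharded = AdamShard(list(range(n_params)))  # what plain DDP keeps on every GPU

sharded_result = sharded_step(params, grads, sharded)
reference = sharded_step(params, grads, [unsharded])

print(sharded_result == reference)             # True
print([len(s.m) + len(s.v) for s in sharded])  # [6, 6, 4, 4]
print(len(unsharded.m) + len(unsharded.v))     # 20
```

The sharded update matches unsharded Adam exactly because Adam's update for one parameter depends only on that parameter's own gradient and moments. Assuming fp32 moments (4 bytes each), sharding across N GPUs shrinks the Adam state on each GPU from about 8 bytes per parameter to about 8/N bytes per parameter.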

Using any distributed mode

In PyTorch Lightning, switching distribution modes is trivial.

As you can see, by combining these approaches you have many ways to get the most out of distributed training.

The good news is that all of these modes are available in PyTorch Lightning without any code changes. You can try any of them and tune them for your specific model if necessary.

One approach not covered here is model parallelism. A word of caution about it: it has proven much less efficient than sharded training and should be used carefully. It may work in some cases, but in general sharded training is the better choice.

A benefit of using Lightning is that you will never fall behind the latest advances in AI research! The Lightning open-source team and community are eager to share the latest developments with you.

Valery Radokhleb
Web developer, designer
