Distributed Training with Trainer Class is Really Slow

I’ve been attempting to pretrain a model using the Trainer class, but I’m noticing it’s extremely slow when training on the imagenet-1k dataset.

I’m pretraining on 8 GPUs, but it claims it could take a few hundred hours! So I’m wondering if there is something particular I need to do with my code outside the trainer class?

Admittedly, to pretrain I had to create a custom PyTorch nn.Module that wraps the Hugging Face ViT object and adds an MLP for image classification. But much of the code is taken from Hugging Face’s ViTForImageClassification. Could this custom object cause the under-the-hood distributed programming to not work? Could it also be the result of a bottleneck somewhere when loading and preprocessing my dataset objects?

I launch my script with a single command:

python -m torch.distributed.launch --nproc_per_node=8 train.py 

See my code below, but it’s just a combination of Hugging Face template code.

from transformers import ViTFeatureExtractor, ViTModel, ViTConfig

...

class Model(nn.Module):
    """ViT encoder followed by a linear classification head.

    The encoder is built from a default ``ViTConfig`` (no pretrained
    weights are loaded here). The head maps the encoder's first output
    token to ``num_labels`` logits.

    Args:
        num_labels: Number of target classes. When it is not positive the
            head is an ``nn.Identity`` and the raw first-token features
            are returned as ``logits``.
    """

    def __init__(self, num_labels=1000):
        super().__init__()
        self.num_labels = num_labels

        self.ViT_Encoder = ViTModel(config=ViTConfig())

        # Read the feature width from the encoder's config instead of
        # hard-coding 768 so the head always matches the encoder.
        hidden_size = self.ViT_Encoder.config.hidden_size
        self.classifier = nn.Linear(hidden_size, self.num_labels) if self.num_labels > 0 else nn.Identity()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        """Encode a batch of images and classify them.

        Args:
            pixel_values: Image batch forwarded unchanged to the ViT encoder.
            labels: Optional class indices. When provided, a cross-entropy
                loss over the logits is returned alongside them.

        Returns:
            An ``ImageClassifierOutput`` holding ``logits`` and ``loss``
            (``loss`` is ``None`` when ``labels`` is not given).
        """
        encoder_outputs = self.ViT_Encoder(pixel_values)

        # Element 0 of the encoder output is the per-token hidden states.
        sequence_output = encoder_outputs[0]

        # Classify from the first token of the sequence only.
        logits = self.classifier(sequence_output[:, 0, :])

        # `labels` must be an explicit argument: the data collator supplies
        # it by keyword, and inference calls may omit it entirely.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return ImageClassifierOutput(
            loss=loss,
            logits=logits
        )

def compute_metrics(eval_pred):
    """Score predictions with the module-level ``metric``.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``.

    Returns:
        Whatever ``metric.compute`` returns for the arg-max predictions
        (taken along axis 1) against the reference labels.
    """
    scores, references = eval_pred
    # Pick the highest-scoring class index for each row.
    predicted_classes = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)

def train_transform(example_batch):
    """Apply the training transform pipeline to a batch of examples.

    Each entry of ``example_batch["image"]`` is converted to RGB and run
    through the module-level ``_train_transforms``; the batch's ``label``
    column is copied to the ``labels`` key.

    Returns:
        A ``BatchFeature`` with ``pixel_values`` and ``labels``.
    """
    # https://huggingface.co/blog/fine-tune-vit
    pixel_values = []
    for image in example_batch["image"]:
        pixel_values.append(_train_transforms(image.convert("RGB")))

    features = BatchFeature({'pixel_values': pixel_values})
    features['labels'] = example_batch['label']
    return features

def val_transform(example_batch):
    """Apply the validation transform pipeline to a batch of examples.

    Each entry of ``example_batch["image"]`` is converted to RGB and run
    through the module-level ``_val_transforms``; the batch's ``label``
    column is copied to the ``labels`` key.

    Returns:
        A ``BatchFeature`` with ``pixel_values`` and ``labels``.
    """
    # https://huggingface.co/blog/fine-tune-vit
    pixel_values = []
    for image in example_batch["image"]:
        pixel_values.append(_val_transforms(image.convert("RGB")))

    features = BatchFeature({'pixel_values': pixel_values})
    features['labels'] = example_batch['label']
    return features

def collate_fn(batch):
    """Merge a list of per-example dicts into one batch dict.

    Args:
        batch: Sequence of mappings, each with a ``pixel_values`` tensor
            and a ``labels`` value.

    Returns:
        A dict whose ``pixel_values`` is the examples' tensors stacked
        along a new leading dimension and whose ``labels`` is a tensor
        built from the examples' labels.
    """
    # https://huggingface.co/blog/fine-tune-vit
    stacked_pixels = torch.stack([sample['pixel_values'] for sample in batch])
    label_tensor = torch.tensor([sample['labels'] for sample in batch])
    return {'pixel_values': stacked_pixels, 'labels': label_tensor}

if __name__ == '__main__':
    # Script entry point: load imagenet-1k, attach on-the-fly image
    # transforms, build the model and run Trainer for 10 epochs.

    num_labels = 1000

    # pull dataset
    dataset = load_dataset('imagenet-1k')
    train_dataset = dataset['train']
    valid_dataset = dataset['validation']

    # No explicit shuffle here. `Dataset.shuffle` returns a new dataset
    # instead of shuffling in place, so calling it without assigning the
    # result has no effect on the data that is trained on. The calls are
    # dropped rather than assigned because Trainer's training sampler
    # already shuffles examples, and a shuffled indices mapping makes
    # row access slower.

    # feature extraction and data augmentation
    feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
    normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    # These two pipelines are module globals read by train_transform /
    # val_transform, so their names must not change.
    _train_transforms = Compose(
        [
            RandomResizedCrop(feature_extractor.size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize
        ],
    )

    _val_transforms = Compose(
        [
            Resize(feature_extractor.size),
            CenterCrop(feature_extractor.size),
            ToTensor(),
            normalize,
        ]
    )

    # with_transform applies the function lazily when rows are accessed,
    # so decoding and augmentation happen during data loading.
    train_ds = train_dataset.with_transform(train_transform)
    valid_ds = valid_dataset.with_transform(val_transform)

    # Module global read by compute_metrics.
    metric = evaluate.load('accuracy')

    # Use the single num_labels constant rather than repeating the literal.
    model = Model(num_labels=num_labels)

    # NOTE(review): dataloader_num_workers is left at its default and the
    # per-device batch size is 8; if throughput is the concern, these are
    # the first settings to profile — confirm before changing.
    training_args = TrainingArguments(
        output_dir=args.exp_dir,
        evaluation_strategy='epoch',
        save_strategy='epoch',
        num_train_epochs=10,
        # Keep the raw 'image' column: the lazy transforms need it.
        remove_unused_columns=False,
        load_best_model_at_end=True,
        save_total_limit=2,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        compute_metrics=compute_metrics,
        tokenizer=feature_extractor,
    )

    train_results = trainer.train()
    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    trainer.save_state()
    print('Done!')