Finetuning GPT2 using Multiple GPUs and Trainer

@valhalla Here is the full code snippet:

from transformers import GPT2Tokenizer, GPT2LMHeadModel, TrainingArguments, Trainer
import torch
from torch.utils.data import Dataset
import sys
import pandas as pd
#import numpy as np

ZERO = sys.float_info.min
# keep the clamp constant in float64: in the default float32 this value underflows to 0.,
# which would make torch.log return -inf
ZERO_PT = torch.tensor(ZERO, dtype=torch.float64)

class GPT2FinetunedWithNgrams(GPT2LMHeadModel):
    def __init__(self, config, model_tokenizer=None):
        super().__init__(config)
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', padding_side='right')
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def eval_sentence(self, sent: str):
        # sentence_vec: remove punct, lower case, split on space, prepend "<s>" and
        # postpend "</s>" start and stop tokens. Returns a list of strings.
        vec = self.sentence_vec(sent)
        last_idx = min(self.max_ngram, len(vec))

        log_prob = 0
        # conditional probabilities with Katz backoff, clamped away from zero so log() stays finite
        for i in range(2, last_idx + 1):
            p = torch.as_tensor(self.pkatz(vec[0:i]), dtype=torch.float64)
            log_prob += torch.log(torch.max(ZERO_PT, p))

        for i in range(1, len(vec) - last_idx + 1):
            j = i + last_idx
            p = torch.as_tensor(self.pkatz(vec[i:j]), dtype=torch.float64)
            log_prob += torch.log(torch.max(ZERO_PT, p))
        return log_prob, len(vec)

    def sentence_loss(self, sent: str):
        # negative log-probability of the sentence under the n-gram model
        log_p, _ = self.eval_sentence(sent)
        return -log_p

    def generate_text_while_finetuning(self,
                                       input_ids=None,
                                       past=None,
                                       attention_mask=None,
                                       token_type_ids=None,
                                       position_ids=None,
                                       head_mask=None,
                                       inputs_embeds=None,
                                       labels=None,
                                       use_cache=None,
                                       output_attentions=None,
                                       output_hidden_states=None, ):
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        outputs = (lm_logits,) + transformer_outputs[1:]
        return outputs  # lm_logits, presents, (all hidden_states), (attentions)

    def forward(
            self,
            input_ids=None,
            past=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            use_cache=True,
    ):

        max_length = input_ids.shape[1] + 50
        full_generated_gpt2_ids = self.generate(input_ids=input_ids,
                                                max_length=max_length,
                                                is_finetuning_current_model=True,
                                                attention_mask=attention_mask,
                                                pad_token_id=50256,
                                                do_sample=True,
                                                top_k=50,
                                                top_p=0.95)

        decoded_gen_samples = self.tokenizer.batch_decode(full_generated_gpt2_ids, skip_special_tokens=True)
        # score each generated sample with the n-gram model and average into one scalar loss
        tmp_losses = [self.sentence_loss(decoded_sample) for decoded_sample in decoded_gen_samples]
        losses = torch.stack(tmp_losses)
        loss = losses.mean()
        loss.requires_grad = True
        return (loss,)


##The code below is the run script.
class MyDataset(Dataset):
    def __init__(self, csv_file: str):
        self.df = pd.read_csv(csv_file, encoding='ISO-8859-1')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        text = self.df.iloc[idx, 1]  # the text is in the second CSV column
        return text

def my_data_collator(dataset_samples_list):
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', padding_side='right')
    tokenizer.pad_token = tokenizer.eos_token

    encoded_results = tokenizer(dataset_samples_list, padding=True, truncation=True, return_tensors='pt', return_attention_mask=True)

    batch = {}
    batch['input_ids'] = encoded_results['input_ids']  # already a padded 2-D tensor thanks to return_tensors='pt'
    batch['past'] = None
    batch['attention_mask'] = encoded_results['attention_mask']
    batch['position_ids'] = None
    batch['head_mask'] = None
    batch['inputs_embeds'] = None
    batch['labels'] = None
    batch['use_cache'] = True
    return batch
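
# Quick illustration of what the collator produces (the two sample strings are made up
# and this check can be deleted): with padding=True and return_tensors='pt' the tokenizer
# returns input_ids and attention_mask as equally sized 2-D tensors.
_example_batch = my_data_collator(["a short example sentence",
                                   "a second, slightly longer example sentence"])
print(_example_batch['input_ids'].shape)       # e.g. torch.Size([2, seq_len])
print(_example_batch['attention_mask'].shape)  # same shape as input_ids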

dataset_train = MyDataset('/path/to/train_dataset.csv')

training_args = TrainingArguments(
    output_dir='/path/to/out',
    do_train=True,
    per_device_train_batch_size=64,
    logging_dir='/path/to/dir',
    max_steps=300000
)
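
# note: per_device_train_batch_size is per GPU, so with N visible GPUs the effective
# training batch per optimization step is 64 * N examples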

model = GPT2FinetunedWithNgrams.from_pretrained('gpt2')

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=my_data_collator,
    train_dataset=dataset_train
)
trainer.train()
trainer.save_model('/path/to/model_save_dir')
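
For the multi-GPU part: as I understand it, Trainer picks up all visible GPUs on its own. Launched with plain python it wraps the model in torch.nn.DataParallel, and launched through the distributed launcher it uses DistributedDataParallel with one process per GPU, e.g. (the script name here is just whatever the file above is saved as):

python -m torch.distributed.launch --nproc_per_node=<num_gpus> finetune_gpt2.py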