Yeah, but I am using a Trainer class that is implemented from scratch. That's why I am asking…
Here is an example of the Trainer class for EncoderDecoder models:
import math
import os
import time

import torch
from tqdm import tqdm

# to_device is a small helper (defined elsewhere in my project) that moves a
# tensor, or a collection of tensors, to the given device.


class EncoderDecoderTransformerTrainer:

    def __init__(self, model,
                 optimizer,
                 patience,
                 scheduler=None,
                 checkpoint_dir=None,
                 clip=None,
                 device='cpu'):

        self.model = model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.checkpoint_dir = checkpoint_dir
        self.clip = clip
        self.device = device
        self.patience = patience
    def calc_val_loss(self, val_loader):
        self.model.eval()
        with torch.no_grad():
            avg_val_loss = 0
            for index, batch in enumerate(tqdm(val_loader)):
                inputs = to_device(batch[0], device=self.device)
                inputs_att = to_device(batch[1], device=self.device)
                padded_targets = to_device(batch[2], device=self.device)
                replaced_targets = to_device(batch[3], device=self.device)
                targets_att = to_device(batch[4], device=self.device)

                outputs = self.model(input_ids=inputs,
                                     attention_mask=inputs_att,
                                     decoder_input_ids=padded_targets,
                                     decoder_attention_mask=targets_att,
                                     labels=replaced_targets)
                lm_loss = outputs[0]
                pred_scores = outputs[1]
                last_hidden = outputs[2]
                avg_val_loss += lm_loss.item()

            avg_val_loss = avg_val_loss / len(val_loader)
            return avg_val_loss
    def print_epoch(self, epoch, avg_train_epoch_loss, avg_val_epoch_loss,
                    cur_patience, strt):
        print("Epoch {}:".format(epoch + 1))
        print("Train loss: {} | Train PPL: {}".format(
            avg_train_epoch_loss, math.exp(avg_train_epoch_loss)))
        print("Val loss: {} | Val PPL: {}".format(
            avg_val_epoch_loss, math.exp(avg_val_epoch_loss)))
        print("Patience left: {}".format(self.patience - cur_patience))
        print("Time: {} mins".format((time.time() - strt) / 60.0))
        print("++++++++++++++++++")
    def save_epoch(self, epoch, loss=None):
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        # torch.save(self.model.state_dict(), os.path.join(
        #     self.checkpoint_dir, '{}_{}.pth'.format(epoch, 'model_checkpoint')))
        # we use the recommended method for saving an EncoderDecoder model
        self.model.save_pretrained(
            os.path.join(self.checkpoint_dir, 'model_checkpoint'))
        torch.save(self.optimizer.state_dict(), os.path.join(
            self.checkpoint_dir, 'optimizer_checkpoint'))
    def train_step(self, batch):
        self.model.train()
        self.optimizer.zero_grad()

        inputs = to_device(batch[0], device=self.device)
        inputs_att = to_device(batch[1], device=self.device)
        padded_targets = to_device(batch[2], device=self.device)
        replaced_targets = to_device(batch[3], device=self.device)
        targets_att = to_device(batch[4], device=self.device)

        print(inputs.shape)
        print(padded_targets.shape)

        # also, I am not sure what I should pass as decoder_input_ids
        # (the input ids or the padded_targets??)
        outputs = self.model(input_ids=inputs,
                             attention_mask=inputs_att,
                             decoder_input_ids=padded_targets,
                             decoder_attention_mask=targets_att,
                             labels=replaced_targets)
        lm_loss = outputs[0]
        # print(lm_loss)
        pred_scores = outputs[1]
        last_hidden = outputs[2]
        return lm_loss, last_hidden
    def train_epochs(self, n_epochs, train_loader, val_loader):
        best_val_loss, cur_patience = 10000, 0

        print("Training model....")
        self.model.train()
        for epoch in range(n_epochs):
            if cur_patience == self.patience:
                break
            avg_train_loss = 0
            strt = time.time()

            for index, sample_batch in enumerate(tqdm(train_loader)):
                loss, _ = self.train_step(sample_batch)
                avg_train_loss += loss.item()
                loss.backward(retain_graph=False)
                if self.clip is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.clip)
                self.optimizer.step()

            avg_train_loss = avg_train_loss / len(train_loader)
            avg_val_loss = self.calc_val_loss(val_loader)

            # keep the best checkpoint and count patience for early stopping
            if avg_val_loss < best_val_loss:
                self.save_epoch(epoch)
                best_val_loss = avg_val_loss
                cur_patience = 0
            else:
                cur_patience += 1
            self.print_epoch(epoch, avg_train_loss, avg_val_loss,
                             cur_patience, strt)

    def fit(self, train_loader, val_loader, epochs):
        self.train_epochs(epochs, train_loader, val_loader)
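For context, this is roughly how I call it. This is only a minimal sketch, not part of the trainer itself: it assumes a Hugging Face EncoderDecoderModel, torch.optim.AdamW, and train_loader / val_loader DataLoaders (built elsewhere) that yield the five-tensor batches the trainer expects (inputs, input attention mask, padded targets, replaced targets, target attention mask); the model names and hyperparameters are just placeholders.

# usage sketch (assumptions noted above)
from torch.optim import AdamW
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased")
optimizer = AdamW(model.parameters(), lr=1e-5)

trainer = EncoderDecoderTransformerTrainer(model,
                                           optimizer,
                                           patience=3,
                                           checkpoint_dir="./checkpoints",
                                           clip=1.0,
                                           device="cuda")
# train_loader / val_loader: DataLoaders producing the 5-tensor batches
trainer.fit(train_loader, val_loader, epochs=10)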