Performing Back Translation with T5 network

Hi, I am currently trying to reproduce the paper “Unsupervised Translation of Programming Languages” (TransCoder) from Facebook Research, but using the T5 network as the seq2seq model. Right now I am stuck on the back-translation part of the approach:

def back_translate(self, batch: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        device = batch[0]['input_ids'].device
        print(device)
        input_ids = torch.stack([example['input_ids'] for example in batch])
        target_ids = torch.stack([example['target_ids'] for example in batch])
        attention_mask = torch.stack([example['attention_mask'] for example in batch])
        batch_langs = [self.tokenizer.decode([ids[0].detach()]) for ids in target_ids]
        lang = random.choice(self.langs)
        self.model.eval()
        cpu_model = self.model.to('cpu')
        outputs = cpu_model.generate(
            input_ids = input_ids, attention_mask = attention_mask,
            decoder_start_token_id = self.tokenizer.encode(f'<{lang}>')[0], max_length = 256
        )
        self.model.train()

        inputs = [self.tokenizer.decode(ids).replace('complete: ', '') for ids in input_ids]
        outputs = [self.tokenizer.decode(ids[1:]) for ids in outputs] # remove lang token
        examples = []
        for inpt, outpt, l in zip(outputs, inputs, batch_langs):
            inpt = f'complete: {inpt}  </s>'
            outpt = f'<{l}>{outpt}'
            input_encodings = self.tokenizer.encode_plus(inpt, pad_to_max_length = True, max_length = 256, truncation = True)
            target_encodings = self.tokenizer.encode_plus(outpt, pad_to_max_length = True, max_length = 256, truncation = True)
            encodings = {
                'input_ids': torch.tensor(input_encodings['input_ids'], dtype=torch.long, device = xm.xla_device()), 
                'attention_mask': torch.tensor(input_encodings['attention_mask'], dtype=torch.long, device = xm.xla_device()),
                'target_ids': torch.tensor(target_encodings['input_ids'], dtype=torch.long, device = xm.xla_device()),
                'target_attention_mask': torch.tensor(target_encodings['attention_mask'], dtype=torch.long, device = xm.xla_device())
            }
            
            examples.append(encodings)
        
        input_ids = torch.stack([example['input_ids'] for example in examples])
        input_ids, _ = self.masked_data_collator.mask_tokens(input_ids)
        lm_labels = torch.stack([example['target_ids'] for example in examples])
        lm_labels[lm_labels[:, :] == 0] = -100
        attention_mask = torch.stack([example['attention_mask'] for example in examples])
        decoder_attention_mask = torch.stack([example['target_attention_mask'] for example in examples])
        print(input_ids.device, xm.xla_device())
        return {
            'input_ids': input_ids, 
            'attention_mask': attention_mask,
            'lm_labels': lm_labels, 
            'decoder_attention_mask': decoder_attention_mask
        }

I am training this in Google Colab on TPUs, but even though I explicitly put the tensors on the TPU device, I get this error: Input tensor is not an XLA tensor: torch.FloatTensor

Here is a link to the full colab notebook: https://colab.research.google.com/drive/1nRGkCdei7D6v6njKWPVZZWtGgPAedkVQ?usp=sharing

Any help or advice would be greatly appreciated!

Can you post the full error trace? Just to be safe, you can do

        return {
            'input_ids': input_ids.to(xm.xla_device()), 
            'attention_mask': attention_mask.to(xm.xla_device()),
            'lm_labels': lm_labels.to(xm.xla_device()), 
            'decoder_attention_mask': decoder_attention_mask.to(xm.xla_device())
        }

Here is the full trace @BramVanroy:

Input tensor is not an XLA tensor: torch.FloatTensor
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "<ipython-input-8-8b7fe8f75065>", line 286, in _mp_fn
    main()
  File "<ipython-input-8-8b7fe8f75065>", line 257, in main
    model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
  File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 521, in train
    tr_loss += self.training_step(model, inputs, optimizer)
  File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 702, in training_step
    outputs = model(**inputs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_t5.py", line 1141, in forward
    return_tuple=return_tuple,
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_t5.py", line 703, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/sparse.py", line 126, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py", line 1845, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: torch_xla/csrc/aten_xla_bridge.cpp:69 : Check failed: xtensor 
*** Begin stack trace ***
	tensorflow::CurrentStackTrace[abi:cxx11]()
	torch_xla::bridge::GetXlaTensor(at::Tensor const&)
	torch_xla::AtenXlaType::index_select(at::Tensor const&, long, at::Tensor const&)
	c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, long, at::Tensor const&), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, long, at::Tensor const&> >, at::Tensor (at::Tensor const&, long, at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&, long, at::Tensor const&)
	
	at::Tensor::index_select(long, at::Tensor const&) const
	at::native::embedding(at::Tensor const&, at::Tensor const&, long, bool, bool)
	torch_xla::AtenXlaType::embedding(at::Tensor const&, at::Tensor const&, long, bool, bool)
	c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, at::Tensor const&, long, bool, bool), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, long, bool, bool> >, at::Tensor (at::Tensor const&, at::Tensor const&, long, bool, bool)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&, long, bool, bool)
	
	at::embedding(at::Tensor const&, at::Tensor const&, long, bool, bool)
	
	
	
	at::embedding(at::Tensor const&, at::Tensor const&, long, bool, bool)
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallDict
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	_PyFunction_FastCallDict
	
	
	_PyObject_FastCallKeywords
	
	_PyEval_EvalFrameDefault
	
	_PyFunction_FastCallDict
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	_PyFunction_FastCallDict
	
	
	_PyObject_FastCallKeywords
	
	_PyEval_EvalFrameDefault
	
	_PyFunction_FastCallDict
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	_PyFunction_FastCallDict
	
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallDict
	
	
	
	_PyObject_FastCallKeywords
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	_PyFunction_FastCallDict
	
	PyObject_Call
	_PyEval_EvalFrameDefault
*** End stack trace ***
Input tensor is not an XLA tensor: torch.FloatTensor
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-10-e0997d11f847> in <module>()
      2 from numpy import inf
      3 
----> 4 xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')

2 frames
/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    111                 raise Exception(
    112                     "process %d terminated with exit code %d" %
--> 113                     (error_index, exitcode)
    114                 )
    115 

Exception: process 1 terminated with exit code 17

(This is repeated multiple times for each TPU core so I omitted the duplicate text).
I also tried your tip of explicitly moving the returned tensors with .to(xm.xla_device()) and got the same error message.

It seems like you are using the Trainer. What does print(trainer.args.device) give you? Normally, the trainer should automatically select the TPU if it is available.

So inputs should be converted to XLA tensors automatically.

Can you also print that final dict that you return?
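It may also help to compare the device of the model's weights with the device of each tensor in the batch. The error message names torch.FloatTensor, while input ids are integer tensors, which suggests the weights are worth checking too. Below is a minimal sketch; `report_devices` is a hypothetical helper (not part of transformers or torch_xla), and the toy embedding and batch are stand-ins for the real T5 model and collated batch.

```python
from typing import Dict, Set

import torch
from torch import nn


def report_devices(model: nn.Module, batch: Dict[str, torch.Tensor]) -> Dict[str, Set[str]]:
    """Collect the device types used by the model weights and by the batch tensors."""
    return {
        "model": {p.device.type for p in model.parameters()},
        "batch": {t.device.type for t in batch.values()},
    }


# Toy stand-ins; in the thread these would be the T5 model and the returned dict.
toy_model = nn.Embedding(10, 4)
toy_batch = {"input_ids": torch.tensor([[1, 2, 3]])}
devices = report_devices(toy_model, toy_batch)
print(devices)
```

If the two sets differ (for example the batch reports an XLA device while the model reports cpu), the mismatch is on the model side rather than in the collator output.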

print(trainer.args.device) prints xla:0, so the trainer, at least, is definitely on the TPU. Here is the printout of the dictionary that I return:

Return Dict {'input_ids': tensor([[743,  10,   1,  ...,   0,   0,   0],
        [743,  10,  96,  ...,   0,   0,   0],
        [743,  10,   1,  ...,   0,   0,   0],
        [743,  10,   1,  ...,   0,   0,   0]], device='xla:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='xla:0'), 'lm_labels': tensor([[32104, 32103,  2490,  ...,  -100,  -100,  -100],
        [32104, 32103,  2490,  ...,  2281,   599,    31],
        [32104, 32103,  2490,  ...,  -100,  -100,  -100],
        [32104, 32104,  2490,  ...,    17, 18392,    76]], device='xla:0'), 'decoder_attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]], device='xla:0')}

I think it has to do with me placing the model on the CPU with cpu_model = self.model.to('cpu') to perform inference. However, I have to do this because some other bug prevents me from using the .generate function on TPUs. I’m going to try a regular GPU and see if that helps at all.
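One detail that fits this suspicion: calling .to() on an nn.Module returns the same module rather than a copy, so cpu_model and self.model refer to one object. The sketch below illustrates that with a toy nn.Linear standing in for T5, and shows copy.deepcopy as one possible way to keep a separate CPU copy for generation. This is an assumption-laden illustration, not a tested TPU fix, and a deep copy doubles the memory used for weights.

```python
import copy

import torch
from torch import nn

model = nn.Linear(4, 2)  # toy stand-in for the T5 model

# .to() on a module returns the module itself, not a copy.
same = model.to("cpu")
print(same is model)

# A deep copy is a separate module, so moving it leaves the original untouched.
cpu_copy = copy.deepcopy(model).to("cpu")
print(cpu_copy is model)

# The copy starts out with identical weights.
print(torch.equal(cpu_copy.weight, model.weight))
```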