T5 Finetuning Tips

Starting this thread to share results, tips, and tricks. This is my first attempt at this kind of thread, so it may completely fail.

Some things I’ve found

  • Apparently, if you copy AdaFactor from fairseq, as recommended by the T5 authors, you can fit batch size = 2 for T5-large LM finetuning.
  • FP16 rarely works.
  • For most tasks, you need to manually add </s> to the end of your sequence.
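A minimal sketch of the manual </s> fix (plus an optional task prefix) in plain Python. The helper name prepare_example and the "summarize:" prefix are illustrative only. Newer tokenizer versions may already append EOS for you, so it is worth checking your tokenizer output before adding it yourself; this helper only appends </s> when it is not already there.

```python
from typing import Tuple

EOS = "</s>"


def prepare_example(source: str, target: str, prefix: str = "") -> Tuple[str, str]:
    """Prepend an optional task prefix to the source and make sure the target ends with </s>."""
    if prefix:
        source = f"{prefix} {source}"
    target = target.rstrip()
    if not target.endswith(EOS):  # avoid a doubled EOS if it is already present
        target = f"{target} {EOS}"
    return source, target


src, tgt = prepare_example("The cute dog walks in the park.", "A dog walks.", prefix="summarize:")
print(src)
print(tgt)
```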

Things I’ve read

  • Task-specific prefix doesn’t matter much.

cc @mrm8488 @valhalla @patrickvonplaten who have all tried different experiments.


Things I’ve found

  • Task prefixes matter:
    1. When doing multi-task training.
    2. When your task is similar or related to one of the supervised tasks used in the T5 pre-training mixture.
  • T5 needs a slightly higher LR than the default set in Trainer (5e-5); in my experiments, 1e-4 and 3e-4 worked for almost all problems (classification, QA, question generation, summarization).
  • No need to pass decoder_input_ids to T5 yourself: just pass labels and T5ForConditionalGeneration will prepare them for you. labels should end with eos_token (important! This is where most of the mistakes happen).
  • T5 uses pad_token as the decoder_start_token_id, so when doing generation without the generate function, make sure you start decoding with the pad token.
  • Trimming batches when training on TPU leads to much slower training (XLA recompiles the graph for every new input shape, so fixed-length padding is faster).
  • Apparently, because of SentencePiece and some possible leakage of other languages into the C4 data, T5 gives somewhat sensible results for French. I fine-tuned it on FQuAD (a French version of SQuAD) for question generation, and BLEU-4 against the dev set was 15.
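The labels/decoder_input_ids tip and the TPU padding tip can both be sketched in plain Python on token-id lists. This is a minimal sketch, not the library code: shift_right mirrors what T5ForConditionalGeneration does internally when you pass only labels (start with the pad token, drop the last label, turn ignored -100 positions into pad), and pad_to_length is a hypothetical helper for keeping TPU batch shapes fixed. The ids assume the standard T5 vocab, where pad is 0 and </s> is 1.

```python
from typing import List

IGNORE_INDEX = -100  # label positions with this value are ignored by the loss


def shift_right(labels: List[int], decoder_start_token_id: int, pad_token_id: int) -> List[int]:
    """Build decoder_input_ids from labels.

    Prepend the decoder start token (the pad token for T5), drop the last label,
    and replace any ignored (-100) positions with the pad token.
    """
    shifted = [decoder_start_token_id] + labels[:-1]
    return [pad_token_id if tok == IGNORE_INDEX else tok for tok in shifted]


def pad_to_length(ids: List[int], length: int, pad_value: int) -> List[int]:
    """Pad (or truncate) to a fixed length so every TPU batch has the same shape."""
    return ids[:length] + [pad_value] * max(0, length - len(ids))


PAD, EOS = 0, 1
labels = pad_to_length([42, 17, EOS], 5, IGNORE_INDEX)  # pad labels with -100, not PAD
print(labels)
print(shift_right(labels, decoder_start_token_id=PAD, pad_token_id=PAD))
```

Padding labels with -100 (rather than the pad id) keeps the padded positions out of the loss, while the decoder inputs still only ever see real token ids.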

Not sure if it’s an issue or not, but in some cases using label_smoothing with T5 resulted in NaN loss.


This even “works” in FP16 – but don’t get me started on Native AMP quite yet…

But in summary – I would strongly recommend using AdaFactor and not ADAM for T5 training and finetuning.

  • this is what the T5 authors use themselves
  • AdaFactor was developed specifically with Transformer models in mind (they say so in the paper)
  • ADAM is a massive waste of memory in general; it’s not surprising that something more efficient would work as well unless you have custom additions to your model
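To put a rough number on the memory point, here is a back-of-the-envelope sketch for a single weight matrix. It assumes Adam's two full-size moment buffers versus factored Adafactor with beta1 turned off (only a row vector and a column vector of second-moment statistics, as described in the Adafactor paper); small per-tensor extras such as step counters are ignored. The 1024 x 4096 shape is a T5-large feed-forward matrix (d_model=1024, d_ff=4096).

```python
def adam_state_floats(rows: int, cols: int) -> int:
    """Adam keeps two full-size moment buffers (m and v) per weight matrix."""
    return 2 * rows * cols


def adafactor_state_floats(rows: int, cols: int, beta1: bool = False) -> int:
    """Factored Adafactor keeps row and column second-moment statistics.

    Enabling a first moment (beta1) adds back one full-size buffer.
    """
    state = rows + cols
    if beta1:
        state += rows * cols
    return state


rows, cols = 1024, 4096  # one T5-large feed-forward weight matrix
adam = adam_state_floats(rows, cols)
ada = adafactor_state_floats(rows, cols)
print(adam, ada, f"{adam / ada:.0f}x less optimizer state")
```

This is also why beta1 matters: turning on a first moment brings back a full-size buffer and most of the memory cost.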

Those observations seem quite consistent with our experience as well; we did not try on TPU yet.


Hi guys!
I just finished training T5-large on ELI5 (270,000 examples) using a TPU V2-8 on Colab, with a notebook modified from @valhalla's! These are not really finetuning tips, but some tips to make T5-large trainable on a TPU V2-8.

T5-large is challenging (for me) to train on a TPU V2-8 with PyTorch:

  • I faced a lot of memory problems (even on a Colab High-RAM instance). This notebook by Davide Libenzi, one of the XLA authors, suggested declaring the large model outside _mp_fn (see his mx variable).
  • With T5-base, there is around 7 minutes of overhead before training can start; for T5-large, this overhead took 1 hour for me.
  • With max_length = 128 (both input and target), I am able to set per_device_train_batch_size = 4 (so global_batch_size = 4*8 = 32).
  • There is an issue where xm.save() causes a memory error with large models like XLM-RoBERTa; it happens with T5-large too, so I had to effectively disable the Trainer's default save_steps by setting it to 1000000.

Combining all of these, it took me around 1 day to get a trainable notebook, so hopefully these tricks will be useful to some of you too!

I would like to find time to make a TF2 version, which should be more stable on TPU.

More notes

  • As @valhalla mentioned in his notebook, a High-RAM instance is a must. Kaggle notebooks recently increased RAM to 16GB for TPU V3-8, but I could not get the training to succeed there (sadly, since a V3-8 should be 2x faster than a V2-8).

Hi @moscow25, thanks for sharing the AdaFactor info. I was wondering whether a constant LR of 1e-3 works for small batch sizes, because the paper mentions a fine-tuning batch size of 128, and it’s not possible to use a batch size of 128 on a single V100 for models larger than t5-base.


This trick of loading the model outside of _mp_fn is awesome! It should save some memory. In pytorch-xla, the model and the dataset are loaded in every process (8 in the case of 8 TPU cores), so they end up taking a lot of memory. Lazy loading the dataset should also reduce RAM usage.

On a V3-8, I was able to use a batch size of 8 per device with max_source_length 512 and max_target_length 64.
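A minimal sketch of the lazy-loading idea, assuming one example per line in a text file (the class name LazyLineDataset is hypothetical). Each process keeps only the byte offset of every line in RAM and reads the text on demand. It has __len__ and __getitem__, which is what a map-style PyTorch Dataset needs, so it could be wrapped or subclassed as torch.utils.data.Dataset; tokenization is left out.

```python
from typing import List


class LazyLineDataset:
    """Reads one example per line from a text file on demand.

    Only the byte offset of each line is held in memory, so each TPU process
    stores a list of integers instead of the whole dataset.
    """

    def __init__(self, path: str):
        self.path = path
        self.offsets: List[int] = []
        with open(path, "rb") as f:
            offset = 0
            for line in f:  # binary iteration keeps the trailing newline, so lengths are exact
                self.offsets.append(offset)
                offset += len(line)

    def __len__(self) -> int:
        return len(self.offsets)

    def __getitem__(self, idx: int) -> str:
        with open(self.path, "rb") as f:
            f.seek(self.offsets[idx])
            return f.readline().decode("utf-8").rstrip("\r\n")
```

Reopening the file on every access keeps the sketch simple and safe across processes; a real loader might keep one handle per worker instead.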


Sure thing @valhalla. I did not try too many settings… but an LR of 0.001 seems to work just fine for smaller finetuning batches. I’m running a global batch of 2*8 [2 per GPU], though with a bit of gradient accumulation (4x I believe), but tbh it’s not really that sensitive as far as I can tell. The only gotcha is to turn off the extra scaling parameters that fairseq threw in and set to True by default for no good reason (scale_parameter=False, relative_step=False).

To get bigger batches, I’m pretty sure we need to add some gradient checkpointing to the model. Trying that out next…
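A sketch of the settings above, assuming the Adafactor class that ships in transformers.optimization (the fairseq port). The lr, scale_parameter, and relative_step values come from this post; warmup_init=False is my addition, since it only applies with relative steps. The batch helpers are hypothetical and show the arithmetic: 2 per GPU x 8 GPUs x 4 accumulation steps gives 64, and reaching the paper's 128 with 4 examples per step on a single V100 would take 32 accumulation steps.

```python
import math

# Settings reported in this thread: constant LR 1e-3, fairseq-style extras turned off.
ADAFACTOR_KWARGS = {
    "lr": 1e-3,
    "scale_parameter": False,
    "relative_step": False,
    "warmup_init": False,  # only meaningful with relative_step=True
}


def effective_batch_size(per_device: int, num_devices: int, grad_accum_steps: int = 1) -> int:
    """Number of examples contributing to each optimizer step."""
    return per_device * num_devices * grad_accum_steps


def accumulation_steps_for(target_batch: int, per_device: int, num_devices: int) -> int:
    """Smallest number of accumulation steps that reaches at least target_batch."""
    return math.ceil(target_batch / (per_device * num_devices))


def build_optimizer(model):
    # Imported lazily so the arithmetic above runs without transformers installed.
    from transformers.optimization import Adafactor

    return Adafactor(model.parameters(), **ADAFACTOR_KWARGS)


print(effective_batch_size(per_device=2, num_devices=8, grad_accum_steps=4))
print(accumulation_steps_for(target_batch=128, per_device=4, num_devices=1))
```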


Here are some T5 questions that multiple people have asked and that I think I know the answers to. Correct me if I’m wrong! Quotes are from the paper.

Q: What masking objective did they use for pretraining?
A: Span corruption.

Specifically, we use a mean span length of 3 and corrupt 15% of the original sequence. We found that this objective produced marginally better performance (Table 7) while being slightly more computationally efficient due to shorter target sequence lengths.
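The "shorter target sequence lengths" point can be made concrete with a little arithmetic. This is a sketch under my own assumptions: it rounds the way I understand the T5 preprocessing does (noise tokens = round(length x 0.15), spans = round(noise tokens / 3)), but the exact rounding and clamping in the T5 codebase may differ. The target length counts the final sentinel but not EOS.

```python
from typing import Tuple


def span_corruption_sizes(length: int, noise_density: float = 0.15,
                          mean_span_length: float = 3.0) -> Tuple[int, int, int, int]:
    """Approximate token counts for span corruption on a sequence of `length` tokens."""
    num_noise_tokens = max(1, round(length * noise_density))
    num_noise_spans = max(1, round(num_noise_tokens / mean_span_length))
    # Each span becomes one sentinel in the input and one sentinel in the target;
    # the target also ends with one final sentinel.
    input_length = length - num_noise_tokens + num_noise_spans
    target_length = num_noise_tokens + num_noise_spans + 1
    return num_noise_tokens, num_noise_spans, input_length, target_length


print(span_corruption_sizes(512))
```

For a 512-token sequence this gives 77 corrupted tokens in 26 spans, so the target is about 104 tokens instead of 512 for a full-reconstruction objective.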

Q: Are the hf checkpoints trained with multi-tasking?
A: yes

Q: Do we have access to T5 1.1 checkpoints?
A: No, because they are not obvious wins: Should I use t5v1.1, t5narrow and TalkingHeads? · Issue #266 · google-research/text-to-text-transfer-transformer · GitHub


More on T5 pre-training objective
Each corrupted span is replaced by a unique sentinel token. The output sequence then consists of the dropped-out spans, delimited by the sentinel tokens used to replace them in the input, plus a final sentinel token.

T5 uses 100 extra ids as sentinel tokens (<extra_id_0> ... <extra_id_99>)

From the HF docs:

The output sequence is formed as a concatenation of the same sentinel tokens and the real masked tokens.

E.g. the sentence “The cute dog walks in the park” with the masks put on “cute dog” and “the” should be processed as follows:

input_text = "The <extra_id_0> walks in <extra_id_1> park"
target_text = "<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>"
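The same construction as a small function, for readers who want to build such pairs from their own data. This is a minimal sketch: it works on whitespace-separated words for readability, takes the spans to mask as explicit (start, length) pairs instead of sampling them, and numbers sentinels from <extra_id_0>, as T5's own preprocessing does. Real T5 operates on SentencePiece token ids, where the sentinels are the last 100 ids of the vocabulary.

```python
from typing import List, Tuple


def span_corrupt(tokens: List[str], spans: List[Tuple[int, int]]) -> Tuple[str, str]:
    """Replace each (start, length) span with a sentinel and build the matching target.

    Spans must be sorted and non-overlapping.
    """
    inputs: List[str] = []
    targets: List[str] = []
    pos = 0
    for i, (start, length) in enumerate(spans):
        if start < pos:
            raise ValueError("spans must be sorted and non-overlapping")
        sentinel = f"<extra_id_{i}>"
        inputs.extend(tokens[pos:start])  # keep the uncorrupted tokens before the span
        inputs.append(sentinel)
        targets.append(sentinel)
        targets.extend(tokens[start:start + length])  # the dropped-out span goes to the target
        pos = start + length
    inputs.extend(tokens[pos:])
    targets.append(f"<extra_id_{len(spans)}>")  # final sentinel closes the last span
    targets.append("</s>")
    return " ".join(inputs), " ".join(targets)


words = "The cute dog walks in the park".split()
print(span_corrupt(words, [(1, 2), (5, 1)]))
```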

Has anyone managed to finetune T5 in FP16? Maybe @mrm8488?
If so, which torch version, and how?


No, I didn’t. But I will check it out!!

Hi, what about the packed sequences they use in the paper?

I’m trying to figure out how to do it with the Hugging Face model (to replicate an experiment).
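A minimal sketch of the data side of packing, under my own assumptions: examples are already tokenized id lists, packing is greedy in the given order, and segment ids (0 = padding, 1, 2, ... = source example) record the boundaries. Getting the HF model to actually respect those boundaries would additionally need a block-diagonal attention mask like the one built here; wiring that mask into the model is not shown and is an assumption on my part, not something I have verified works.

```python
from typing import List, Tuple


def pad_row(tokens: List[int], segments: List[int], max_length: int,
            pad_id: int) -> Tuple[List[int], List[int]]:
    n = max_length - len(tokens)
    return tokens + [pad_id] * n, segments + [0] * n


def pack_sequences(examples: List[List[int]], max_length: int,
                   pad_id: int = 0) -> List[Tuple[List[int], List[int]]]:
    """Greedily concatenate examples into rows of max_length.

    Returns (token_ids, segment_ids) per row; segment id 0 marks padding,
    and 1, 2, ... mark which example within the row each token came from.
    """
    rows = []
    tokens: List[int] = []
    segments: List[int] = []
    seg = 0
    for ex in examples:
        ex = ex[:max_length]  # truncate examples that cannot fit on their own
        if len(tokens) + len(ex) > max_length:
            rows.append(pad_row(tokens, segments, max_length, pad_id))
            tokens, segments, seg = [], [], 0
        seg += 1
        tokens.extend(ex)
        segments.extend([seg] * len(ex))
    if tokens:
        rows.append(pad_row(tokens, segments, max_length, pad_id))
    return rows


def segment_attention_mask(segments: List[int]) -> List[List[int]]:
    """1 where token i may attend to token j (same non-padding segment), else 0."""
    return [[int(si == sj and si != 0) for sj in segments] for si in segments]


for ids, segs in pack_sequences([[5, 6, 1], [7, 1], [8, 9, 10, 1]], max_length=6):
    print(ids, segs)
```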


I see that T5 does use scale_parameter (multiply_by_parameter_scale = True in the operative config):
https://console.cloud.google.com/storage/browser/_details/t5-data/experiments/scaling/sc-bi_v1-bsx4/operative_config.gin

Parameters for AdafactorOptimizer:

==============================================================================

AdafactorOptimizer.beta1 = 0.0
AdafactorOptimizer.clipping_threshold = 1.0
AdafactorOptimizer.decay_rate = None
AdafactorOptimizer.epsilon1 = 1e-30
AdafactorOptimizer.epsilon2 = 0.001
AdafactorOptimizer.factored = True
AdafactorOptimizer.min_dim_size_to_factor = 128
AdafactorOptimizer.multiply_by_parameter_scale = True
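For anyone wanting to reproduce this config with the transformers Adafactor port, here is my reading of how the gin settings map onto its keyword arguments. Treat it as an unofficial mapping: the decay_rate line assumes that None in Mesh TensorFlow means its default 1 - t^(-0.8) schedule, and the learning-rate settings are not in this excerpt, so they are left out. The beta2_at helper is hypothetical and just shows what decay_rate = -0.8 does.

```python
# Keyword arguments for transformers.optimization.Adafactor that, as far as I can tell,
# correspond to the operative_config.gin settings above (my mapping, not an official one).
HF_ADAFACTOR_FROM_GIN = {
    "beta1": None,            # gin beta1 = 0.0: no first-moment buffer (HF uses None for this)
    "clip_threshold": 1.0,    # gin clipping_threshold = 1.0
    "eps": (1e-30, 1e-3),     # gin epsilon1 = 1e-30, epsilon2 = 0.001
    "scale_parameter": True,  # gin multiply_by_parameter_scale = True
    "decay_rate": -0.8,       # gin decay_rate = None: assumed to mean the default schedule
}
# gin factored = True has no switch in HF (it factors every tensor with 2+ dims), and
# min_dim_size_to_factor = 128 has no HF equivalent that I know of.


def beta2_at(step: int, decay_rate: float = -0.8) -> float:
    """Second-moment decay at a 1-based step: 1 - step ** decay_rate."""
    return 1.0 - step ** decay_rate


print([round(beta2_at(s), 4) for s in (1, 10, 100, 10_000)])
```

The schedule starts at 0 (the first step trusts the current gradient fully) and approaches 1 as training goes on.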


Thanks @saareliad for looking that up. I was going to do that, glad you did…

So far, for “scale by parameter size = True”, in my experiments:

  1. This does work.
  2. The model converges much more slowly (for fine-tuning, on 8 Volta GPUs) and, from what I saw, to a worse number.

But yes, it’s a parameter that does work. Maybe you need to increase the LR or train for much longer with it on.

As for FP16, I did not manage to get T5-Large checkpoints training in FP16… actually, it won’t even run inference (it gets NaN) for T5-Large without turning off FP16 in the feed-forward layers or doing some custom scaling magic…

Even if you do FP16 only on the attention operations and nothing else, I am able to fine-tune a checkpoint without NaNs, but the model diverges after a while. This happens even with just a small part of the model in FP16, and with a lower LR.

It’s possible that this can eventually work, but it’s not simple.

Moreover – if you are doing your attention operations in FP16 but saving all weights and gradients in FP32 (as well as FP16) – this may save a little bit of compute but does not save GPU memory at training. So this kind of very conservative FP16 is not useful, sad to say:

  1. Diverges the model
  2. Does not use less GPU memory (so can’t use larger batches or bigger model)
  3. Does not train faster

What do you mean by “does work”? You clearly say that it made your results worse.
From what I saw over T5-small “scale parameter” makes the result slightly worse as well.

I’m currently trying T5-3B (beta1=0.9, scale_parameter=False, default learning rate) and the model doesn’t learn, so I wonder what can be done to mitigate it.

TF32 Training on Ampere?

There seems to be discussion (various threads and git issues) about whether T5 arch is just inherently unstable and that frequent FP16 NaN isn’t a bug in the transformers implementation or in user training arguments but may be unavoidable in true FP16 mode.

I would like to bring into this discussion that when I run the Mesh TensorFlow version of T5 from the research repo (https://github.com/google-research/text-to-text-transfer-transformer) on TPU on my dataset, its 16-bit training is rock solid (I assume because of the wider range of bf16). On the same dataset, I essentially can never get FP16 working on anything larger than t5-small with HuggingFace (with Adafactor, with and without LR warmup, native AMP/apex O1/O2/O3, etc.).

For workflow reasons, using the research Mesh TensorFlow code is not an option, and I need to get the 3B model training on GPUs, which will require ~16-bit compute to fit in ~32-48GB of GPU memory.

Ampere GPUs support TF32, which has the same 8-bit exponent range as BF16 (19 bits in total vs 16 bits for BF16 on TPU). Has anyone tried (or even has access to) an A100 GPU to see if TF32 solves the issue here?

EDIT: it looks like Ampere also natively supports BF16. So that looks like a good way to compare T5 Mesh TensorFlow on TPU with T5 HF on Ampere, both using BF16.
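The range argument can be checked with the standard library alone. This sketch uses struct's half-precision "e" format to test whether a value fits in FP16 (largest finite value 65504), and emulates BF16 by keeping the top 16 bits of the float32 bit pattern. That is truncation, not the round-to-nearest that hardware uses, so values can differ slightly from real BF16. One commonly given explanation, consistent with the feed-forward NaNs mentioned earlier, is that some T5 activations exceed the FP16 range; I have not measured that here.

```python
import struct

FP16_MAX = 65504.0  # largest finite IEEE half-precision value


def fits_fp16(x: float) -> bool:
    """True if finite x can be stored in FP16 without overflowing."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False


def to_bf16_truncated(x: float) -> float:
    """Emulate BF16 by keeping the top 16 bits of the float32 bit pattern (truncation)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (y,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return y


for value in (1000.0, 70000.0, 1e30):
    print(value, fits_fp16(value), to_bf16_truncated(value))
```

BF16 keeps float32's exponent range and gives up mantissa precision, which is why a value like 70000 overflows in FP16 but survives in BF16 (as 69632).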


task specific prefix doesn’t matter much.

When finetuning on a task that is quite different from one of the supervised tasks, then:

  1. Can you leave the default “summarization” task that finetune.py uses, or will this mess things up?
  2. Should you create your own named task or just leave the task blank? (Does the model always expect some sort of <task_name>: prefix on every input?)

  1. If the task is not related to “summarization”, then it’ll probably mess things up or slow down convergence, because the prefix will make the model think it’s doing summarization. Better to remove it.
  2. Just reiterating what I’ve said above: if your task is completely new and not related to one of the tasks on which T5 was trained, then the prefix shouldn’t matter.