Debugging my poor Decision Transformer performance


I am attempting to apply Decision Transformers (DTs) to a research area I’ve been working on for the last ~18 months. I have implemented the DT model using the original code base (decision-transformer/ at master · kzl/decision-transformer · GitHub), which is more or less identical to the Hugging Face code base (transformers/ at main · huggingface/transformers · GitHub). It can learn something, but its performance plateaus long before I would have expected. I am hoping that people from this community might be able to give me some pointers on where to look for errors/improvements in DT if I walk you through my preliminary experiments. In the text below, I have put specific questions in bold; I would be enormously grateful if anyone could address them.

N.B. To get around the rule that ‘new users can only put one embedded image in a post’, I am posting each figure in a separate post below.

Metrics Tracked

During training, I am tracking the following metrics:

  1. training/loss: The loss computed for a given batch and used to update DT.
  2. training/state_error: the mean squared error (state_preds - state_target)^2
  3. training/action_error: the mean squared error (action_preds - action_target)^2
  4. training/return_error: the mean squared error (return_preds - return_target)^2
  5. training/imitation_success_frac: The fraction of expert actions which were correctly predicted in a sampled batch of trajectories.
  6. evaluation/return: The return achieved by a DT checkpoint when rolled out in the environment, setting the initial target to the expert’s initial return and updating the returns-to-go at each step accordingly.
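For clarity, the return-conditioned rollout in metric 6 follows the standard DT evaluation procedure of subtracting the observed reward from the return-to-go after every step. A minimal sketch (`env`, `select_action`, and `max_steps` are placeholder names for illustration, not from my actual code):

```python
def evaluate_return(env, select_action, expert_initial_return, max_steps=1000):
    """Roll out a checkpoint, conditioning on a shrinking return-to-go target.

    select_action(state, rtg) stands in for the DT forward pass over the
    full (return-to-go, state, action) history; it is a placeholder here.
    """
    state = env.reset()
    rtg = expert_initial_return  # initial target = expert's initial return
    total_return = 0.0
    for _ in range(max_steps):
        action = select_action(state, rtg)
        state, reward, done, info = env.step(action)
        total_return += reward
        rtg -= reward  # standard DT update of the return-to-go
        if done:
            break
    return total_return
```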

Question: Are there any other training and/or evaluation metrics you’d recommend tracking to better monitor DT?


1. Initial sanity check: Overfitting to one expert trajectory

To begin with, I am overfitting to one expert trajectory; that is, the training data set consists of 4 expert environment transitions, and the test-time evaluation is the latest DT checkpoint acting in this exact same environment. The expert achieves an overall return (performance; higher is better) of R=-10, so if I am overfitting correctly, this should also be the attainable performance of DT.
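As a concrete reference for what the return-to-go targets look like: the return-to-go at step t is simply the sum of rewards from t onwards, so a 4-transition episode with total return R=-10 would give per-step targets like the following (the individual reward values here are purely illustrative):

```python
def returns_to_go(rewards):
    """Return-to-go at step t = sum of rewards from t to the end of the episode."""
    rtg, running = [], 0.0
    for r in reversed(rewards):
        running += r
        rtg.append(running)
    return rtg[::-1]

# Toy 4-step episode with total return -10 (illustrative reward values only).
print(returns_to_go([-1.0, -2.0, -3.0, -4.0]))  # [-10.0, -9.0, -7.0, -4.0]
```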

a. Prediction_target=expert_action

Here, I am setting the prediction target to the expert action, with a cross-entropy loss between DT’s predicted actions and the true expert actions.
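For reference, the loss here is the standard masked cross entropy over discrete actions. A sketch under the assumption of integer action indices and padded sequences (the tensor shapes and names are mine for illustration, not the repo’s):

```python
import torch
import torch.nn.functional as F

def action_ce_loss(action_logits, expert_actions, mask):
    """Cross entropy between DT's action logits and the expert's actions.

    action_logits:  (batch, seq_len, n_actions) raw logits from DT
    expert_actions: (batch, seq_len) integer expert action indices
    mask:           (batch, seq_len), 1 for real timesteps, 0 for padding
    """
    n_actions = action_logits.shape[-1]
    loss = F.cross_entropy(
        action_logits.reshape(-1, n_actions),
        expert_actions.reshape(-1),
        reduction="none",
    )
    mask = mask.reshape(-1)
    return (loss * mask).sum() / mask.sum()
```

The imitation_success_frac metric can then be computed from the same logits via `action_logits.argmax(-1) == expert_actions`, averaged over unmasked timesteps.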

As you can see, DT can overfit to the expert actions and, after ~20 epochs, achieves the expert’s performance on this overfitting example:

Fig. 1: Overfitting, prediction_target=expert_action

Question: Is it surprising that the training/return_error metric is increasing? I know I am using the expert’s actions as the target in my loss function, but should DT not also be learning to accurately predict the return-to-go with its return_preds? Does this suggest something is going wrong somewhere?

b. Prediction_target=return_to_go

In the original DT paper, the authors state that they could use either the return-to-go or the expert action as the training target. However, when I run the above overfitting experiment with the target set to the return-to-go and a mean squared error (MSE) loss function, DT finds it much tougher to learn, and indeed never reaches the expert’s R=-10 performance:
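One thing worth double-checking in this setting is the scale of the return targets: the original code base divides returns-to-go by a constant `scale` before feeding them to the model, and an MSE loss on unnormalised returns can be badly conditioned relative to the action loss. A sketch of the masked MSE I mean (the scale value of 100.0 is purely illustrative, not taken from the repo):

```python
import torch
import torch.nn.functional as F

def return_mse_loss(return_preds, return_targets, mask, scale=100.0):
    """Masked MSE between predicted and true returns-to-go.

    return_preds, return_targets: (batch, seq_len, 1) tensors
    mask: (batch, seq_len, 1), 1 for real timesteps, 0 for padding
    scale: normalising constant (100.0 is illustrative, not from the repo)
    """
    loss = F.mse_loss(return_preds / scale, return_targets / scale,
                      reduction="none")
    return (loss * mask).sum() / mask.sum()
```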

Fig. 2: See below post.

Question: Why might DT be finding it so much harder to overfit the return? Is this expected behaviour, or might it indicate something is going wrong somewhere?

2. Generalising to unseen trajectories

Here, I am training DT on ~130k expert trajectories (~500k environment transitions), setting prediction_target=expert_action with cross-entropy loss, and evaluating the agent’s ability to generalise to a set of 100 unseen environment problem settings. On these 100 unseen problems, the mean expert performance is R=-6.76, a baseline imitation learning agent (which just imitates the expert) achieves R=-10.5, and a simple baseline heuristic achieves R=-12.8, so I would hope that DT can get into this ball park. However, as you can see, DT plateaus at around R=-22:

Fig. 3: See below post.

I do not think this is due to overfitting or the training data set being too small, because when I increase the data set 10x to ~1.3M expert trajectories, DT plateaus at around R=-22 again.

Question: Do you think the poor performance is due to the poor return error? (This seems fairly important to me: if the return predictions are wrong, how can DT learn to choose good actions conditioned on a given target return?) Or do you think there might be some other cause? Do you have any ideas for additional experiments I should run or metrics I should track to better understand what is going on?

If anyone has any general thoughts, advice, or queries about the above, I would be very grateful!

Fig. 2: Overfitting, prediction_target=return_to_go

Fig. 3: Generalising