Question about the output of the decision transformer

From the code here: transformers/src/transformers/models/decision_transformer/ at v4.35.2 · huggingface/transformers · GitHub

        # reshape x so that the second dimension corresponds to the original
        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
        x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

        # get predictions
        return_preds = self.predict_return(x[:, 2])  # predict next return given state and action
        state_preds = self.predict_state(x[:, 2])  # predict next state given state and action
        action_preds = self.predict_action(x[:, 1])  # predict next action given state

I'm not sure I understand why self.predict_return(x[:, 2]) and self.predict_state(x[:, 2]) predict the return and the next state given the state and action. According to the comment at the top, x[:, 2] holds only the action tokens. Am I missing something?

And if this code is correct, what is x[:, 0] used for?
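
To make the indexing concrete, here is a small sketch of how I read the reshape. It is not the library code: it uses NumPy's `transpose` as a stand-in for `permute`, the sizes are made up, and each token's hidden vector is simply filled with its position in the interleaved stream (R_0, s_0, a_0, R_1, s_1, a_1, ...) so it can be traced after the reshape.

```python
import numpy as np

# Toy sizes, chosen only for illustration.
batch_size, seq_length, hidden_size = 1, 4, 2

# Interleaved stream of 3 * seq_length tokens: R_0, s_0, a_0, R_1, s_1, a_1, ...
# Every hidden vector holds its own stream position so it can be tracked.
positions = np.arange(3 * seq_length)
x = np.repeat(positions[None, :, None], hidden_size, axis=2)  # (batch, 3*seq, hidden)

# NumPy stand-in for:
#   x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
x = x.reshape(batch_size, seq_length, 3, hidden_size).transpose(0, 2, 1, 3)

# Second dimension now selects the token type: 0 = return, 1 = state, 2 = action.
return_positions = x[0, 0, :, 0].tolist()
state_positions = x[0, 1, :, 0].tolist()
action_positions = x[0, 2, :, 0].tolist()

print(x.shape)            # (1, 3, 4, 2)
print(return_positions)   # [0, 3, 6, 9]
print(state_positions)    # [1, 4, 7, 10]
print(action_positions)   # [2, 5, 8, 11]
```

So x[:, 2] picks the transformer outputs at the action positions 3t + 2, and x[:, 0] those at the return positions 3t. The sketch only covers the indexing. If the backbone uses causal (GPT-2 style) attention, the output at position 3t + 2 would have attended to R_t and s_t as well as a_t, which might be what the "given state and action" comments refer to, but the sketch does not model attention.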