From the code here: transformers/src/transformers/models/decision_transformer/modeling_decision_transformer.py at v4.35.2 · huggingface/transformers · GitHub
# reshape x so that the second dimension corresponds to the original
# returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
# get predictions
return_preds = self.predict_return(x[:, 2]) # predict next return given state and action
state_preds = self.predict_state(x[:, 2]) # predict next state given state and action
action_preds = self.predict_action(x[:, 1]) # predict next action given state
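
To make sure I follow what the reshape is doing, here is a minimal toy sketch of my understanding (my own example, not code from the library), assuming tokens are stacked per timestep in the order return, state, action:

import torch

# Toy dimensions, just for illustration (not the model's real sizes).
batch_size, seq_length, hidden_size = 1, 2, 4

# Transformer output over the interleaved sequence
# (R_1, s_1, a_1, R_2, s_2, a_2): shape (batch, 3 * seq_length, hidden).
x = torch.arange(batch_size * 3 * seq_length * hidden_size, dtype=torch.float32)
x = x.reshape(batch_size, 3 * seq_length, hidden_size)

# Group the outputs by token type, as in the snippet above:
# x[:, 0] = hidden states at return positions,
# x[:, 1] = hidden states at state positions,
# x[:, 2] = hidden states at action positions.
x = x.reshape(batch_size, seq_length, 3, hidden_size).permute(0, 2, 1, 3)
print(x.shape)  # torch.Size([1, 3, 2, 4])
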
I'm not sure I understand why self.predict_return(x[:, 2]) or self.predict_state(x[:, 2]) is described as predicting the return / next state given the state and action. According to the comment at the top, x[:, 2] holds only the action tokens, so am I missing something? And if this code is correct, what is x[:, 0] used for?