Decision Transformer for discrete action spaces

I’m studying the Decision Transformer by following the post “Train your first Decision Transformer”.
In that post, the example environment is “halfcheetah” (the action space is continuous), and
the following model code is used.
I’m trying to apply this to a discrete action space.
I added a logit layer for the discrete actions and changed the loss function as shown below.
(red: removed, blue: added)
Is this the right approach?

Hey,

I’ve been trying something similar recently. Your implementation looks almost identical to mine, except that I did not add an extra linear layer and used the model’s action outputs directly as the logits. Also, I had encoded my actions as one-hot vectors, so I had to reshape the action targets, but otherwise I think this is pretty much spot on.
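For reference, here is a minimal sketch of the kind of loss I mean, assuming PyTorch. The function name, the toy shapes, and the variable names (`action_preds`, `action_targets_onehot`, `attention_mask`) are hypothetical and not taken from the blog post; it assumes the action outputs have shape (batch, seq_len, n_actions) and are treated as raw logits.

```python
import torch
import torch.nn.functional as F


def discrete_action_loss(action_preds, action_targets_onehot, attention_mask):
    """Cross-entropy over discrete actions, skipping padded timesteps.

    action_preds:          (batch, seq_len, n_actions) raw logits
    action_targets_onehot: (batch, seq_len, n_actions) one-hot actions
    attention_mask:        (batch, seq_len), 1 for real steps, 0 for padding
    """
    n_actions = action_preds.shape[-1]
    # Flatten batch and time so each row is one timestep.
    mask = attention_mask.reshape(-1) > 0
    logits = action_preds.reshape(-1, n_actions)[mask]
    # cross_entropy accepts class indices, so convert one-hot rows to indices.
    target_idx = action_targets_onehot.reshape(-1, n_actions).argmax(dim=-1)[mask]
    return F.cross_entropy(logits, target_idx)


# Hypothetical toy batch, only to exercise the function.
torch.manual_seed(0)
batch, seq_len, n_actions = 2, 5, 4
action_preds = torch.randn(batch, seq_len, n_actions)
actions = torch.randint(0, n_actions, (batch, seq_len))
action_targets_onehot = F.one_hot(actions, num_classes=n_actions).float()
attention_mask = torch.ones(batch, seq_len)
attention_mask[0, :2] = 0  # pretend the first two steps of sample 0 are padding

loss = discrete_action_loss(action_preds, action_targets_onehot, attention_mask)
print(loss.item())
```

At inference time, under the same assumptions, the predicted action would be the `argmax` over the last dimension of the logits (or a sample from the softmax if you want stochastic behavior).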

I usually print the shapes and intermediate variables at least once as a sanity check, to make sure everything looks right and nothing is broadcast incorrectly.
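A rough sketch of that kind of check, again assuming PyTorch. The helper `check_shapes` and the toy tensors are hypothetical names made up for illustration; the last few lines show one way broadcasting can silently produce the wrong shape.

```python
import torch


def check_shapes(**tensors):
    """Print the name, shape, and dtype of each tensor and return the shapes."""
    shapes = {}
    for name, t in tensors.items():
        print(f"{name:>10}: shape={tuple(t.shape)} dtype={t.dtype}")
        shapes[name] = tuple(t.shape)
    return shapes


# Hypothetical flattened logits and class-index targets.
logits = torch.randn(10, 4)
targets = torch.randint(0, 4, (10,))
shapes = check_shapes(logits=logits, targets=targets)
assert shapes["logits"][0] == shapes["targets"][0], "row counts differ"

# A silent broadcasting pitfall: (N, 1) combined with (N,) expands to (N, N)
# instead of raising an error.
a = torch.zeros(10, 1)
b = torch.zeros(10)
bad = a - b
print(tuple(bad.shape))
```

Running this once before the first real training step is usually enough to catch a mismatch between one-hot targets and index targets.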