Hi everyone, I am trying to write a REINFORCE implementation with JAX, but I am facing some issues. The policy looks something like this:
class MLPPolicy(nn.Module):
    action_dims: int

    @nn.compact
    def __call__(self, input: jnp.ndarray):
        x = nn.Dense(4)(input)
        x = nn.relu(x)
        x = nn.Dense(16)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dims)(x)
        # x = nn.softmax(x)
        return x  # raw logits; the softmax is intentionally commented out
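In case it matters, the policy on its own can be queried like this (the observation size and action count below are placeholders I made up for illustration; in the real script they come from the environment, and MLPPolicy is the class above with flax.linen imported as nn):

import jax
import jax.numpy as jnp

policy = MLPPolicy(action_dims=2)                        # placeholder action count
dummy_obs = jnp.zeros((4,))                              # placeholder observation
params = policy.init(jax.random.PRNGKey(0), dummy_obs)
logits = policy.apply(params, dummy_obs)                 # raw logits, since softmax is commented out
print(logits.shape)                                      # -> (2,)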
The train method looks something like this:
def train(env_id, gamma, episodes, max_termination, seed, num_envs, learning_rate):
    run_name = f"{env_id}__{seed}_{int(time.time())}"
    writer = SummaryWriter(f"runs/{run_name}")
    random.seed(seed)
    np.random.seed(seed=seed)
    key = jax.random.PRNGKey(seed)
    # key, q_key = jax.random.split(key, 2)
    env = make_env(env_id, seed, 0, True, run_name)()
    obs, _ = env.reset(seed=seed)
    action_dims = env.action_space.n

    policy = MLPPolicy(action_dims=action_dims)
    policy_state = TrainState.create(
        apply_fn=policy.apply,
        params=policy.init(key, obs),
        tx=optax.adam(learning_rate=learning_rate),
    )
    policy.apply = jax.jit(policy.apply)

    # @jax.jit
    def update(policy_state, observations, actions, rewards):
        def loss_fn(params):
            logits = policy.apply(params, observations)
            log_probs = jax.nn.log_softmax(logits)
            # print("Log probs:", log_probs)
            return -jnp.mean(jnp.sum(onehot(actions, log_probs.shape[-1]) * log_probs, axis=-1) * rewards)

        loss_value, grads = jax.value_and_grad(loss_fn)(policy_state.params)
        policy_state = policy_state.apply_gradients(grads=grads)
        return loss_value, policy_state

    for episode in range(episodes):
        rewards = []
        actions = []
        observations = []
        done = False
        obs, _ = env.reset(seed=seed)
        for _ in range(max_termination):
            observations.append(obs)
            logits = policy.apply(policy_state.params, jnp.array(obs)[None, ...])
            action = jax.random.categorical(key, logits=logits[0])
            obs, reward, done, _, _ = env.step(int(action))
            actions.append(action)
            rewards.append(reward)
            if done:
                break

        rewards = jnp.array([gamma**i * r for i, r in enumerate(rewards)])
        eps = np.finfo(np.float32).eps.item()
        # rewards = (rewards - rewards.mean()) / (rewards.std() + eps)
        actions = jnp.array(actions)
        observations = jnp.array(observations)
        loss, policy_state = update(
            policy_state,
            observations,
            actions,
            rewards,
        )
        print(f"Episode:{episode} Loss:{loss} Reward:{rewards.sum()}")
        writer.add_scalar("loss", jax.device_get(loss), episode)

    return policy_state
I am facing a problem with action selection: jax.random.categorical keeps selecting the same action every time, so the behaviour is essentially the same as argmax. Am I doing something wrong here? Any help would be appreciated.
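For reference, here is a stripped-down snippet outside the training loop (with made-up logits, just for illustration) that shows the kind of behaviour I am seeing; every iteration prints the same action:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# made-up logits for illustration, not from my actual policy
logits = jnp.array([0.1, 0.2, -0.05, 0.15])

for step in range(5):
    # same `key` on every iteration, mirroring my training loop above
    action = jax.random.categorical(key, logits=logits)
    print(step, int(action))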