Train REINFORCE with JAX

Hi everyone, I am trying to write a REINFORCE implementation in JAX (with Flax and Optax) but I am facing some issues. The policy looks something like this:

import jax.numpy as jnp
import flax.linen as nn


class MLPPolicy(nn.Module):
    action_dims: int

    @nn.compact
    def __call__(self, obs: jnp.ndarray):
        x = nn.Dense(4)(obs)
        x = nn.relu(x)
        x = nn.Dense(16)(x)
        x = nn.relu(x)
        # return raw logits; softmax / log_softmax is applied later in the loss
        x = nn.Dense(self.action_dims)(x)
        return x
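
For context, this is roughly how I expect the module to behave on a single observation. It is just a sketch assuming a CartPole-style environment (4-dimensional observations, 2 discrete actions), not part of the actual training code:

import jax
import jax.numpy as jnp

policy = MLPPolicy(action_dims=2)
key = jax.random.PRNGKey(0)
dummy_obs = jnp.zeros((1, 4))             # batch of one 4-dim observation
params = policy.init(key, dummy_obs)      # initialise the Dense layers
logits = policy.apply(params, dummy_obs)  # shape (1, 2), unnormalised logits
probs = jax.nn.softmax(logits)            # turn logits into action probabilities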

The train method looks something like this:

import random
import time

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.common_utils import onehot
from flax.training.train_state import TrainState
from torch.utils.tensorboard import SummaryWriter


# make_env is my own helper that builds and wraps the gymnasium environment
def train(env_id, gamma, episodes, max_termination, seed, num_envs, learning_rate):
    run_name = f"{env_id}__{seed}_{int(time.time())}"
    writer = SummaryWriter(f"runs/{run_name}")
    
    random.seed(seed)
    np.random.seed(seed=seed)
    key = jax.random.PRNGKey(seed)
    # key,q_key = jax.random.split(key,2)

    env = make_env(env_id, seed, 0, True, run_name)()

    obs, _ = env.reset(seed=seed)
    action_dims = env.action_space.n
    policy = MLPPolicy(action_dims=action_dims)

    policy_state = TrainState.create(
        apply_fn=policy.apply,
        params=policy.init(key,obs),
        tx=optax.adam(learning_rate=learning_rate)
    )
    
    policy.apply = jax.jit(policy.apply)
    
    # @jax.jit
    def update(policy_state, observations, actions, rewards):
        def loss_fn(params):
            logits = policy.apply(params, observations)
            log_probs = jax.nn.log_softmax(logits)
            # REINFORCE loss: negative log-probability of the actions taken,
            # weighted by the per-step (discounted) rewards
            return -jnp.mean(jnp.sum(onehot(actions, log_probs.shape[-1]) * log_probs, axis=-1) * rewards)

        loss_value,grads = jax.value_and_grad(loss_fn)(policy_state.params)
        policy_state = policy_state.apply_gradients(grads=grads)
        return loss_value,policy_state



    for episode in range(episodes):
        rewards = []
        actions = []
        observations = []

        done = False
        obs, _ = env.reset(seed=seed)
        for _ in range(max_termination):
            observations.append(obs)
            logits = policy.apply(policy_state.params,jnp.array(obs)[None,...])
            action = jax.random.categorical(key,logits=logits[0])
            obs, reward, done, _, _ = env.step(int(action))
            actions.append(action)
            rewards.append(reward)
            if done:
                break


        # discount each immediate reward by gamma**t (t = step index within the episode)
        rewards = jnp.array([gamma**i * r for i, r in enumerate(rewards)])
        eps = np.finfo(np.float32).eps.item()
        # rewards = (rewards - rewards.mean())/(rewards.std() + eps)
        actions = jnp.array(actions)
        observations = jnp.array(observations)

        loss, policy_state = update(
            policy_state,
            observations,
            actions,
            rewards
        )
        print(f"Episode:{episode}   Loss:{loss}  Reward:{rewards.sum()}")
        writer.add_scalar("loss",jax.device_get(loss),episode)

    return policy_state
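
For completeness, I call train roughly like this (the environment and hyperparameters below are just what I have been experimenting with, nothing tuned):

if __name__ == "__main__":
    train(
        env_id="CartPole-v1",
        gamma=0.99,
        episodes=500,
        max_termination=500,
        seed=1,
        num_envs=1,
        learning_rate=2.5e-4,
    )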

I am facing a problem with action selection: jax.random.categorical keeps selecting the same action every step, so the behaviour looks like an argmax over the logits. Am I doing something wrong here?
Any help would be appreciated.
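
In case it helps to see the sampling step in isolation, here is a stripped-down sketch of just the action selection, outside the environment loop (the logits are made-up values for a 2-action env; the key is created once at the top, the same way it is in train):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
logits = jnp.array([0.3, -0.1])   # made-up logits for a 2-action env

# sampling the way the rollout loop does it
for _ in range(5):
    action = jax.random.categorical(key, logits=logits)
    print(int(action))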