1 2 3 4 5 6 7 8 9 10 11 12
|
# Select the actor's action `pi` for the current features, then score it
# with the Q network.
if self.is_continuous:
    # Continuous actions: actor output plus exploration noise, clipped to
    # the closed interval [-1, 1].
    mu = self.actor_net(feat)
    pi = tf.clip_by_value(mu + self.action_noise(), -1, 1)
else:
    # Discrete actions: perturb the log-probabilities with noise sampled
    # from self.gumbel_dist and apply a temperature-scaled softmax.
    # NOTE(review): presumably self.gumbel_dist is a Gumbel distribution,
    # making this a Gumbel-softmax sample -- confirm where it is built.
    logits = self.actor_net(feat)
    logp_all = tf.nn.log_softmax(logits)
    gumbel_noise = tf.cast(
        self.gumbel_dist.sample([batch_size, self.a_counts]),
        dtype=tf.float32,
    )
    _pi = tf.nn.softmax(
        (logp_all + gumbel_noise) / self.discrete_tau
    )
    # Hard one-hot of the arg-max entry of the soft sample.
    _pi_true_one_hot = tf.one_hot(
        tf.argmax(_pi, axis=-1),
        self.a_counts,
    )
    # The (one_hot - soft) difference is wrapped in stop_gradient, so the
    # sum below takes the one-hot value in the forward pass while
    # gradients flow only through the soft sample `_pi`.
    _pi_diff = tf.stop_gradient(_pi_true_one_hot - _pi)
    pi = _pi_diff + _pi
# Q-value of the chosen action; runs for both branches since each one
# assigns `pi`.
q_actor = self.q_net(feat, pi)
|