"""Reward manager for computing reward signals."""
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import torch
from prettytable import PrettyTable
from mjlab.managers.manager_base import ManagerBase, ManagerTermBaseCfg
if TYPE_CHECKING:
from mjlab.envs.manager_based_rl_env import ManagerBasedRlEnv

@dataclass(kw_only=True)
class RewardTermCfg(ManagerTermBaseCfg):
  """Configuration for a reward term."""

  func: Any
  """The callable that computes this reward term's value."""

  weight: float
  """Weight multiplier for this reward term."""

class RewardManager(ManagerBase):
  """Manages reward computation by aggregating weighted reward terms.

  Reward Scaling Behavior:
    By default, rewards are scaled by the environment step duration (dt). This
    normalizes cumulative episodic rewards across different simulation
    frequencies. The scaling can be disabled via the ``scale_by_dt`` parameter.

    When ``scale_by_dt=True`` (default):

    - ``reward_buf`` (returned by ``compute()``) = raw_value * weight * dt
    - ``_episode_sums`` (cumulative rewards) are scaled by dt
    - ``Episode_Reward/*`` logged metrics are scaled by dt

    When ``scale_by_dt=False``:

    - ``reward_buf`` = raw_value * weight (no dt scaling)

    Regardless of the scaling setting:

    - ``_step_reward`` (via ``get_active_iterable_terms()``) always contains
      the unscaled reward rate (raw_value * weight)
  """

  _env: ManagerBasedRlEnv
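
  # Worked example of the scaling above (illustrative numbers): with weight=2.0,
  # a raw term value of 0.5, and dt=0.02 s, compute() adds 0.5 * 2.0 * 0.02 = 0.02
  # to ``reward_buf`` when ``scale_by_dt=True``, while ``_step_reward`` stores the
  # unscaled rate 0.5 * 2.0 = 1.0.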

  def __init__(
    self,
    cfg: dict[str, RewardTermCfg],
    env: ManagerBasedRlEnv,
    *,
    scale_by_dt: bool = True,
  ):
    self._term_names: list[str] = list()
    self._term_cfgs: list[RewardTermCfg] = list()
    self._class_term_cfgs: list[RewardTermCfg] = list()
    self._scale_by_dt = scale_by_dt
    self.cfg = deepcopy(cfg)
    super().__init__(env=env)
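
    # Buffers: per-term episodic sums, the aggregated per-step reward, and the
    # per-term (unscaled) step reward used for logging and inspection.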
    self._episode_sums = dict()
    for term_name in self._term_names:
      self._episode_sums[term_name] = torch.zeros(
        self.num_envs, dtype=torch.float, device=self.device
      )
    self._reward_buf = torch.zeros(
      self.num_envs, dtype=torch.float, device=self.device
    )
    self._step_reward = torch.zeros(
      (self.num_envs, len(self._term_names)), dtype=torch.float, device=self.device
    )

  def __str__(self) -> str:
    msg = f"<RewardManager> contains {len(self._term_names)} active terms.\n"
    table = PrettyTable()
    table.title = "Active Reward Terms"
    table.field_names = ["Index", "Name", "Weight"]
    table.align["Name"] = "l"
    table.align["Weight"] = "r"
    for index, (name, term_cfg) in enumerate(
      zip(self._term_names, self._term_cfgs, strict=False)
    ):
      table.add_row([index, name, term_cfg.weight])
    msg += table.get_string()
    msg += "\n"
    return msg

  # Properties.

  @property
  def active_terms(self) -> list[str]:
    return self._term_names

  # Methods.

  def reset(
    self, env_ids: torch.Tensor | slice | None = None
  ) -> dict[str, torch.Tensor]:
    if env_ids is None:
      env_ids = slice(None)
    extras = {}
    for key in self._episode_sums.keys():
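      # Report the average per-second reward over the episode: the mean cumulative
      # sum across the resetting envs divided by the maximum episode length in
      # seconds.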
      episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids])
      extras["Episode_Reward/" + key] = (
        episodic_sum_avg / self._env.max_episode_length_s
      )
      self._episode_sums[key][env_ids] = 0.0
    for term_cfg in self._class_term_cfgs:
      term_cfg.func.reset(env_ids=env_ids)
    return extras

  def compute(self, dt: float) -> torch.Tensor:
    self._reward_buf[:] = 0.0
    scale = dt if self._scale_by_dt else 1.0
    for term_idx, (name, term_cfg) in enumerate(
      zip(self._term_names, self._term_cfgs, strict=False)
    ):
      if term_cfg.weight == 0.0:
        self._step_reward[:, term_idx] = 0.0
        continue
      value = term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * scale
      # NaN/Inf can occur from corrupted physics state; zero them to avoid
      # policy crash.
      value = torch.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
      self._reward_buf += value
      self._episode_sums[name] += value
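      # ``_step_reward`` always stores the unscaled rate (raw_value * weight),
      # regardless of ``scale_by_dt``.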
      self._step_reward[:, term_idx] = value / scale
    return self._reward_buf

  def get_active_iterable_terms(self, env_idx: int) -> list[tuple[str, list[float]]]:
    terms = []
    for idx, name in enumerate(self._term_names):
      terms.append((name, [self._step_reward[env_idx, idx].cpu().item()]))
    return terms

  def get_term_cfg(self, term_name: str) -> RewardTermCfg:
    if term_name not in self._term_names:
      raise ValueError(f"Term '{term_name}' not found in active terms.")
    return self._term_cfgs[self._term_names.index(term_name)]

  def _prepare_terms(self):
    for term_name, term_cfg in self.cfg.items():
      term_cfg: RewardTermCfg | None
      if term_cfg is None:
        print(f"Reward term '{term_name}' is set to None, skipping.")
        continue
      self._resolve_common_term_cfg(term_name, term_cfg)
      self._term_names.append(term_name)
      self._term_cfgs.append(term_cfg)
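      # Terms implemented as classes expose a ``reset`` method; track them so
      # their internal state is cleared alongside the environment reset.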
      if hasattr(term_cfg.func, "reset") and callable(term_cfg.func.reset):
        self._class_term_cfgs.append(term_cfg)
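

# Usage sketch (illustrative, not part of this module): the manager is built from
# a dict of term configs and queried once per environment step. The names
# ``reward_cfg_dict``, ``step_dt``, and ``done_env_ids`` below are placeholders,
# where ``step_dt`` stands for the environment step duration in seconds.
#
#   manager = RewardManager(cfg=reward_cfg_dict, env=env, scale_by_dt=True)
#   reward = manager.compute(dt=step_dt)          # shape: (num_envs,)
#   extras = manager.reset(env_ids=done_env_ids)  # "Episode_Reward/*" metrics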