# Source code for mjlab.managers.termination_manager
"""Termination manager for computing done signals."""
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Sequence
import torch
from prettytable import PrettyTable
from mjlab.managers.manager_base import ManagerBase, ManagerTermBaseCfg
if TYPE_CHECKING:
from mjlab.envs.manager_based_rl_env import ManagerBasedRlEnv
[docs]
@dataclass
class TerminationTermCfg(ManagerTermBaseCfg):
    """Configuration for a single termination term.

    Adds one field on top of ``ManagerTermBaseCfg``: a flag telling whether
    the term counts as an episodic time-out.
    """

    time_out: bool = False
    """Whether the term contributes towards episodic timeouts."""
[docs]
class TerminationManager(ManagerBase):
    """Manages termination conditions for the environment.

    Several termination terms are combined into per-environment done
    signals. A term configured with ``time_out=True`` feeds the truncation
    buffer (time-based); every other term feeds the termination buffer
    (failure conditions). The most recent result of each term is also kept
    under the term's name.
    """

    _env: ManagerBasedRlEnv

    def __init__(self, cfg: dict[str, TerminationTermCfg], env: ManagerBasedRlEnv):
        """Create the manager from a name -> term-config mapping.

        Args:
            cfg: Termination term configurations keyed by term name. The
                mapping is deep-copied, so the caller's object is not
                modified by this manager.
            env: Environment handed to the base-class constructor.
        """
        # These lists are filled by `_prepare_terms`, so they have to exist
        # before the base constructor runs.
        # NOTE(review): presumably `ManagerBase.__init__` is what calls
        # `_prepare_terms` -- confirm against the base class.
        self._term_names: list[str] = []
        self._term_cfgs: list[TerminationTermCfg] = []
        self._class_term_cfgs: list[TerminationTermCfg] = []
        self.cfg = deepcopy(cfg)
        super().__init__(env)

        # One boolean flag per environment for every active term.
        self._term_dones = {
            term_name: torch.zeros(
                self.num_envs, device=self.device, dtype=torch.bool
            )
            for term_name in self._term_names
        }
        self._truncated_buf = torch.zeros(
            self.num_envs, device=self.device, dtype=torch.bool
        )
        self._terminated_buf = torch.zeros_like(self._truncated_buf)

    def __str__(self) -> str:
        """Return a summary line followed by a table of the active terms."""
        header = (
            f"<TerminationManager> contains {len(self._term_names)} active terms.\n"
        )

        table = PrettyTable()
        table.title = "Active Termination Terms"
        table.field_names = ["Index", "Name", "Time Out"]
        table.align["Name"] = "l"

        named_cfgs = zip(self._term_names, self._term_cfgs, strict=False)
        for row_index, (term_name, cfg) in enumerate(named_cfgs):
            table.add_row([row_index, term_name, cfg.time_out])

        return header + table.get_string() + "\n"

    # Properties.

    @property
    def active_terms(self) -> list[str]:
        """Names of the active terms, in registration order."""
        return self._term_names

    @property
    def dones(self) -> torch.Tensor:
        """Element-wise OR of the truncation and termination buffers."""
        return self._truncated_buf | self._terminated_buf

    @property
    def time_outs(self) -> torch.Tensor:
        """Truncation buffer (terms configured with ``time_out=True``)."""
        return self._truncated_buf

    @property
    def terminated(self) -> torch.Tensor:
        """Termination buffer (terms configured with ``time_out=False``)."""
        return self._terminated_buf

    # Methods.

    def reset(
        self, env_ids: torch.Tensor | slice | None = None
    ) -> dict[str, torch.Tensor]:
        """Report per-term done counts and reset class-based terms.

        The stored per-term flags are only read here; they are not cleared.

        Args:
            env_ids: Environments to consider. ``None`` selects all of them.

        Returns:
            A dict mapping ``"Episode_Termination/<term name>"`` to the
            number of selected environments whose stored flag for that term
            is set.
        """
        # NOTE(review): the values are `.item()` scalars, not tensors, even
        # though the return annotation says `torch.Tensor` -- confirm which
        # one callers expect.
        selected = slice(None) if env_ids is None else env_ids

        extras = {
            "Episode_Termination/" + term_name: torch.count_nonzero(
                flags[selected]
            ).item()
            for term_name, flags in self._term_dones.items()
        }

        # Terms whose callable exposes `reset` get to clear their own state.
        for cfg in self._class_term_cfgs:
            cfg.func.reset(env_ids=selected)

        return extras

    def compute(self) -> torch.Tensor:
        """Evaluate every term and refresh the done buffers.

        Both buffers are cleared first. Each term's callable is then invoked
        with the environment and the term's ``params``; its result is OR-ed
        into the truncation buffer when ``time_out`` is set and into the
        termination buffer otherwise, and is also stored under the term name.

        Returns:
            Element-wise OR of the truncation and termination buffers.
        """
        self._truncated_buf[:] = False
        self._terminated_buf[:] = False

        named_cfgs = zip(self._term_names, self._term_cfgs, strict=False)
        for term_name, cfg in named_cfgs:
            result = cfg.func(self._env, **cfg.params)
            if cfg.time_out:
                self._truncated_buf |= result
            else:
                self._terminated_buf |= result
            # Written in place so previously returned references stay valid.
            self._term_dones[term_name][:] = result

        return self._truncated_buf | self._terminated_buf

    def get_term(self, name: str) -> torch.Tensor:
        """Return the stored flag tensor of term ``name`` (not a copy).

        Raises:
            KeyError: If ``name`` is not an active term.
        """
        return self._term_dones[name]

    def get_term_cfg(self, term_name: str) -> TerminationTermCfg:
        """Return the configuration of the active term ``term_name``.

        Raises:
            ValueError: If ``term_name`` is not an active term.
        """
        if term_name in self._term_names:
            return self._term_cfgs[self._term_names.index(term_name)]
        raise ValueError(f"Term '{term_name}' not found in active terms.")

    def get_active_iterable_terms(
        self, env_idx: int
    ) -> Sequence[tuple[str, Sequence[float]]]:
        """Return each term's stored flag for one environment.

        Args:
            env_idx: Index of the environment to read.

        Returns:
            A list of ``(term name, [value])`` pairs, where ``value`` is the
            stored flag for ``env_idx`` converted to a Python float.
        """
        return [
            (term_name, [flags[env_idx].float().cpu().item()])
            for term_name, flags in self._term_dones.items()
        ]

    def _prepare_terms(self):
        """Register every non-``None`` entry of ``self.cfg`` as an active term.

        Entries set to ``None`` are reported on stdout and skipped. Terms
        whose callable has a callable ``reset`` attribute are additionally
        tracked so that `reset` can forward to them.
        """
        for term_name, term_cfg in self.cfg.items():
            term_cfg: TerminationTermCfg | None
            if term_cfg is None:
                print(f"term: {term_name} set to None, skipping...")
                continue

            # NOTE(review): inherited helper; presumably validates/resolves
            # the term's callable and params -- confirm in the base class.
            self._resolve_common_term_cfg(term_name, term_cfg)
            self._term_names.append(term_name)
            self._term_cfgs.append(term_cfg)

            reset_hook = getattr(term_cfg.func, "reset", None)
            if callable(reset_hook):
                self._class_term_cfgs.append(term_cfg)