Source code for mjlab.managers.action_manager
"""Action manager for processing actions sent to the environment."""
from __future__ import annotations
import abc
from dataclasses import dataclass
from typing import TYPE_CHECKING, Sequence
import torch
from prettytable import PrettyTable
from mjlab.managers.manager_base import ManagerBase, ManagerTermBase
if TYPE_CHECKING:
from mjlab.envs import ManagerBasedRlEnv


@dataclass(kw_only=True)
class ActionTermCfg(abc.ABC):
  """Configuration for an action term.

  Action terms process raw actions from the policy and apply them to entities
  in the scene (e.g., setting joint positions, velocities, or efforts).
  """

  entity_name: str
  """Name of the entity in the scene that this action term controls."""

  clip: dict[str, tuple] | None = None
  """Optional clipping bounds per transmission type. Maps transmission name
  (e.g., 'position', 'velocity') to a (min, max) tuple."""

  @abc.abstractmethod
  def build(self, env: ManagerBasedRlEnv) -> ActionTerm:
    """Build the action term from this config."""
    raise NotImplementedError
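

# Example (not part of mjlab): a minimal sketch of a concrete config. The
# `scale` field and the `ExampleJointPositionAction` term it builds are
# illustrative assumptions; the term itself is sketched after the ActionTerm
# base class below.
@dataclass(kw_only=True)
class ExampleJointPositionActionCfg(ActionTermCfg):
  scale: float = 1.0

  def build(self, env: ManagerBasedRlEnv) -> ActionTerm:
    return ExampleJointPositionAction(self, env)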


class ActionTerm(ManagerTermBase):
  """Base class for action terms.

  The action term is responsible for processing the raw actions sent to the
  environment and applying them to the entity managed by the term.
  """

  def __init__(self, cfg: ActionTermCfg, env: ManagerBasedRlEnv):
    self.cfg = cfg
    super().__init__(env)
    self._entity = self._env.scene[self.cfg.entity_name]

  @property
  @abc.abstractmethod
  def action_dim(self) -> int:
    raise NotImplementedError

  @abc.abstractmethod
  def process_actions(self, actions: torch.Tensor) -> None:
    raise NotImplementedError

  @abc.abstractmethod
  def apply_actions(self) -> None:
    raise NotImplementedError

  @property
  @abc.abstractmethod
  def raw_action(self) -> torch.Tensor:
    raise NotImplementedError
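

# Example (not part of mjlab): a minimal concrete action term matching the
# config sketched above. The entity members used here (`num_joints`,
# `set_joint_position_target`) and the `num_envs`/`device` attributes on the
# env are assumptions; substitute the actual APIs when adapting this sketch.
class ExampleJointPositionAction(ActionTerm):
  cfg: ExampleJointPositionActionCfg

  def __init__(self, cfg: ExampleJointPositionActionCfg, env: ManagerBasedRlEnv):
    super().__init__(cfg, env)
    num_joints = self._entity.num_joints  # Assumed entity attribute.
    self._raw = torch.zeros((env.num_envs, num_joints), device=env.device)
    self._processed = torch.zeros_like(self._raw)

  @property
  def action_dim(self) -> int:
    return self._raw.shape[1]

  @property
  def raw_action(self) -> torch.Tensor:
    return self._raw

  def process_actions(self, actions: torch.Tensor) -> None:
    # Store the raw slice and scale it into a joint-position target.
    self._raw[:] = actions
    self._processed[:] = actions * self.cfg.scale

  def apply_actions(self) -> None:
    # Hypothetical entity call; the real entity interface may differ.
    self._entity.set_joint_position_target(self._processed)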


class ActionManager(ManagerBase):
  """Manages action processing for the environment.

  The action manager aggregates multiple action terms, each controlling a
  different entity or aspect of the simulation. It splits the policy's action
  tensor and routes each slice to the appropriate action term.
  """

  def __init__(self, cfg: dict[str, ActionTermCfg], env: ManagerBasedRlEnv):
    self.cfg = cfg
    super().__init__(env=env)
    # Create buffers to store actions.
    self._action = torch.zeros(
      (self.num_envs, self.total_action_dim), device=self.device
    )
    self._prev_action = torch.zeros_like(self._action)
    self._prev_prev_action = torch.zeros_like(self._action)

  def __str__(self) -> str:
    msg = f"<ActionManager> contains {len(self._term_names)} active terms.\n"
    table = PrettyTable()
    table.title = f"Active Action Terms (shape: {self.total_action_dim})"
    table.field_names = ["Index", "Name", "Dimension"]
    table.align["Name"] = "l"
    table.align["Dimension"] = "r"
    for index, (name, term) in enumerate(self._terms.items()):
      table.add_row([index, name, term.action_dim])
    msg += table.get_string()
    msg += "\n"
    return msg

  # Properties.

  @property
  def total_action_dim(self) -> int:
    return sum(self.action_term_dim)

  @property
  def action_term_dim(self) -> list[int]:
    return [term.action_dim for term in self._terms.values()]

  @property
  def action(self) -> torch.Tensor:
    return self._action

  @property
  def prev_action(self) -> torch.Tensor:
    return self._prev_action

  @property
  def prev_prev_action(self) -> torch.Tensor:
    return self._prev_prev_action

  @property
  def active_terms(self) -> list[str]:
    return self._term_names

  # Methods.

  def get_term(self, name: str) -> ActionTerm:
    return self._terms[name]

  def reset(self, env_ids: torch.Tensor | slice | None = None) -> dict[str, float]:
    if env_ids is None:
      env_ids = slice(None)
    # Reset action history.
    self._prev_action[env_ids] = 0.0
    self._prev_prev_action[env_ids] = 0.0
    self._action[env_ids] = 0.0
    # Reset action terms.
    for term in self._terms.values():
      term.reset(env_ids=env_ids)
    return {}

  def process_action(self, action: torch.Tensor) -> None:
    if self.total_action_dim != action.shape[1]:
      raise ValueError(
        f"Invalid action shape, expected: {self.total_action_dim}, received: {action.shape[1]}."
      )
    self._prev_prev_action[:] = self._prev_action
    self._prev_action[:] = self._action
    self._action[:] = action.to(self.device)
    # Split the action tensor and route each slice to its term.
    idx = 0
    for term in self._terms.values():
      term_actions = action[:, idx : idx + term.action_dim]
      term.process_actions(term_actions)
      idx += term.action_dim

  def apply_action(self) -> None:
    for term in self._terms.values():
      term.apply_actions()

  def get_active_iterable_terms(
    self, env_idx: int
  ) -> Sequence[tuple[str, Sequence[float]]]:
    terms = []
    idx = 0
    for name, term in self._terms.items():
      term_actions = self._action[env_idx, idx : idx + term.action_dim].cpu()
      terms.append((name, term_actions.tolist()))
      idx += term.action_dim
    return terms

  def _prepare_terms(self):
    self._term_names: list[str] = list()
    self._terms: dict[str, ActionTerm] = dict()
    for term_name, term_cfg in self.cfg.items():
      term_cfg: ActionTermCfg | None
      if term_cfg is None:
        print(f"term: {term_name} set to None, skipping...")
        continue
      term = term_cfg.build(self._env)
      self._term_names.append(term_name)
      self._terms[term_name] = term
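

# Example (not part of mjlab): a rough sketch of how an ActionManager might be
# constructed and driven. `env` and `policy` are placeholders, and the "robot"
# entity name is an assumption that must match an entity in the scene.
def _example_usage(env: ManagerBasedRlEnv, policy) -> None:
  action_manager = ActionManager(
    {"joint_pos": ExampleJointPositionActionCfg(entity_name="robot", scale=0.5)},
    env,
  )
  print(action_manager)  # Tabular summary of active terms and their dimensions.
  actions = policy()  # Expected shape: (num_envs, action_manager.total_action_dim).
  action_manager.process_action(actions)  # Split, record history, pre-process per term.
  action_manager.apply_action()  # Each term writes its processed actions to its entity.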