Source code for mjlab.managers.manager_base
from __future__ import annotations
import abc
import inspect
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import torch
from mjlab.managers.scene_entity_config import SceneEntityCfg
if TYPE_CHECKING:
from mjlab.envs import ManagerBasedRlEnv
[docs]
@dataclass
class ManagerTermBaseCfg:
    """Base configuration shared by manager terms.

    Terms in the observation, reward, termination, curriculum, and event
    managers build on this config. It gives them one common way to name a
    callable together with the parameters that go with it.

    ``func`` may be either a function or a class.

    **Function-based terms** are the simpler option and fit stateless
    computations:

    .. code-block:: python

        RewardTermCfg(func=mdp.joint_torques_l2, weight=-0.01)

    **Class-based terms** are instantiated with ``(cfg, env)``. Reach for them
    when the term needs to:

    - Cache values computed at initialization (e.g., resolve regex patterns to
      indices)
    - Keep state across calls
    - Do expensive setup once instead of on every call

    .. code-block:: python

        class posture:
            def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv):
                # Resolve std dict to tensor once at init
                self.std = resolve_std_to_tensor(cfg.params["std"], env)

            def __call__(self, env, **kwargs) -> torch.Tensor:
                # Use cached self.std
                return compute_posture_reward(env, self.std)

        RewardTermCfg(func=posture, params={"std": {".*knee.*": 0.3}}, weight=1.0)

    A class-based term may additionally implement ``reset(env_ids)`` to manage
    per-episode state.
    """

    func: Any
    """Callable producing this term's value; either a function or a class.

    A class is instantiated automatically with ``(cfg=term_cfg, env=env)``."""

    params: dict[str, Any] = field(default_factory=dict)
    """Extra keyword arguments handed to ``func`` when it is called."""
[docs]
class ManagerTermBase:
    """Base class for manager terms that keep a reference to the environment.

    The environment's ``num_envs`` and ``device`` are re-exported as
    properties. ``__call__`` is left unimplemented here.
    """

    def __init__(self, env: ManagerBasedRlEnv):
        """Remember the environment this term is attached to."""
        self._env = env

    # Properties.

    @property
    def num_envs(self) -> int:
        """``num_envs`` as reported by the stored environment."""
        return self._env.num_envs

    @property
    def device(self) -> str:
        """``device`` as reported by the stored environment."""
        return self._env.device

    @property
    def name(self) -> str:
        """Name of the term's concrete class."""
        return type(self).__name__

    # Methods.

    def reset(self, env_ids: torch.Tensor | slice | None) -> Any:
        """Resets the manager term.

        This base implementation ignores ``env_ids`` and returns ``None``.
        """
        del env_ids  # Unused.
        return None

    def __call__(self, *args, **kwargs) -> Any:
        """Returns the value of the term required by the manager.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
[docs]
class ManagerBase(abc.ABC):
    """Base class for all managers."""

    def __init__(self, env: ManagerBasedRlEnv):
        """Store ``env``, then call ``_prepare_terms`` right away."""
        self._env = env
        self._prepare_terms()

    # Properties.

    @property
    def num_envs(self) -> int:
        """``num_envs`` as reported by the stored environment."""
        return self._env.num_envs

    @property
    def device(self) -> str:
        """``device`` as reported by the stored environment."""
        return self._env.device

    @property
    @abc.abstractmethod
    def active_terms(self) -> list[str] | dict[Any, list[str]]:
        """Abstract property; concrete managers must provide it."""
        raise NotImplementedError

    # Methods.

    def reset(self, env_ids: torch.Tensor) -> dict[str, Any]:
        """Resets the manager and returns logging info for the current step.

        This base implementation ignores ``env_ids`` and returns an empty dict.
        """
        del env_ids  # Unused.
        return {}

    def get_active_iterable_terms(
        self, env_idx: int
    ) -> Sequence[tuple[str, Sequence[float]]]:
        """Not provided by the base class.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _prepare_terms(self):
        """Abstract hook invoked from ``__init__``."""
        raise NotImplementedError

    def _resolve_common_term_cfg(
        self, term_name: str, term_cfg: ManagerTermBaseCfg
    ):
        """Apply the resolution steps common to every term config.

        Each ``SceneEntityCfg`` found among ``term_cfg.params`` values is
        resolved against ``self._env.scene``. Afterwards, if ``term_cfg.func``
        is a class, it is replaced in place by an instance built with
        ``(cfg=term_cfg, env=self._env)``.
        """
        del term_name  # Unused.

        # Lazy filter: entries are checked and resolved one at a time, in order.
        entity_cfgs = (
            param
            for param in term_cfg.params.values()
            if isinstance(param, SceneEntityCfg)
        )
        for entity_cfg in entity_cfgs:
            entity_cfg.resolve(self._env.scene)

        term = term_cfg.func
        if inspect.isclass(term):
            term_cfg.func = term(cfg=term_cfg, env=self._env)