"""Observation manager for computing observations."""
from copy import deepcopy
from dataclasses import dataclass
from typing import Literal, Sequence
import numpy as np
import torch
from prettytable import PrettyTable
from mjlab.managers.manager_base import ManagerBase, ManagerTermBaseCfg
from mjlab.utils.buffers import CircularBuffer, DelayBuffer
from mjlab.utils.noise import noise_cfg, noise_model
from mjlab.utils.noise.noise_cfg import NoiseCfg, NoiseModelCfg
@dataclass
class ObservationTermCfg(ManagerTermBaseCfg):
  """Configuration for an observation term.

  Processing pipeline: compute → noise → clip → scale → delay → history.

  Delay models sensor latency. History provides temporal context. Both are optional
  and can be combined.
  """

  noise: NoiseCfg | NoiseModelCfg | None = None
  """Noise model to apply to the observation."""

  clip: tuple[float, float] | None = None
  """Range (min, max) to clip the observation values."""

  scale: tuple[float, ...] | float | torch.Tensor | None = None
  """Scaling factor(s) to multiply the observation by."""

  delay_min_lag: int = 0
  """Minimum lag (in steps) for delayed observations. Lag sampled uniformly from
  [min_lag, max_lag]. Convert to ms: lag * (1000 / control_hz)."""

  delay_max_lag: int = 0
  """Maximum lag (in steps) for delayed observations. Use min=max for constant delay."""

  delay_per_env: bool = True
  """If True, each environment samples its own lag. If False, all environments share
  the same lag at each step."""

  delay_hold_prob: float = 0.0
  """Probability of reusing the previous lag instead of resampling. Useful for
  temporally correlated latency patterns."""

  delay_update_period: int = 0
  """Resample lag every N steps (models multi-rate sensors). If 0, update every step."""

  delay_per_env_phase: bool = True
  """If True and update_period > 0, stagger update timing across envs to avoid
  synchronized resampling."""

  history_length: int = 0
  """Number of past observations to keep in history. 0 = no history."""

  flatten_history_dim: bool = True
  """Whether to flatten the history dimension into observation.

  When True and concatenate_terms=True, uses term-major ordering:
  [A_t0, A_t1, ..., A_tH-1, B_t0, B_t1, ..., B_tH-1, ...]

  See docs/api/observation_history_delay.md for details on ordering."""
@dataclass
class ObservationGroupCfg:
  """Configuration for an observation group.

  An observation group bundles multiple observation terms together. Groups are
  typically used to separate observations for different purposes (e.g., "policy"
  for the actor, "critic" for the value function).
  """

  terms: dict[str, ObservationTermCfg]
  """Dictionary mapping term names to their configurations."""

  concatenate_terms: bool = True
  """Whether to concatenate all terms into a single tensor. If False, returns
  a dict mapping term names to their individual tensors."""

  concatenate_dim: int = -1
  """Dimension along which to concatenate terms. Default -1 (last dimension)."""

  enable_corruption: bool = False
  """Whether to apply noise corruption to observations. Set to True during
  training for domain randomization, False during evaluation."""

  history_length: int | None = None
  """Group-level history length override. If set, applies to all terms in
  this group. If None, each term uses its own ``history_length`` setting."""

  flatten_history_dim: bool = True
  """Whether to flatten history into the observation dimension. If True,
  observations have shape ``(num_envs, obs_dim * history_length)``. If False,
  shape is ``(num_envs, history_length, obs_dim)``."""

  nan_policy: Literal["disabled", "warn", "sanitize", "error"] = "disabled"
  """NaN/Inf handling policy for observations in this group.

  - 'disabled': No checks (default, fastest)
  - 'warn': Log warning with term name and env IDs, then sanitize (debugging)
  - 'sanitize': Silent sanitization to 0.0 like reward manager (safe for production)
  - 'error': Raise ValueError on NaN/Inf (strict development mode)
  """

  nan_check_per_term: bool = True
  """If True, check each observation term individually to identify NaN source.
  If False, check only the final concatenated output (faster but less informative).
  Only applies when nan_policy != 'disabled'."""
class ObservationManager(ManagerBase):
  """Manages observation computation for the environment.

  The observation manager computes observations from multiple terms organized
  into groups. Each term can have noise, clipping, scaling, delay, and history
  applied. Groups can optionally concatenate their terms into a single tensor.
  """

  def __init__(self, cfg: dict[str, ObservationGroupCfg], env):
    """Initialize the manager and pre-compute per-group observation shapes.

    Args:
      cfg: Mapping from group name to its :class:`ObservationGroupCfg`.
        Deep-copied so later mutation by the caller cannot affect the manager.
      env: The environment instance. NOTE(review): assumed to expose
        ``num_envs`` and ``device`` (both are read in ``_prepare_terms``) —
        confirm against the env class.
    """
    self.cfg = deepcopy(cfg)
    # NOTE(review): ManagerBase.__init__ presumably calls _prepare_terms(),
    # which populates the _group_obs_* attributes read below — confirm.
    super().__init__(env=env)
    # Resolved observation shape (batch dim excluded) for each group.
    self._group_obs_dim: dict[str, tuple[int, ...] | list[tuple[int, ...]]] = dict()
    for group_name, group_term_dims in self._group_obs_term_dim.items():
      if self._group_obs_concatenate[group_name]:
        # Stack per-term shapes into a (num_terms, rank) tensor so the size
        # along the concatenation axis can be summed across terms.
        term_dims = torch.stack(
          [torch.tensor(dims, device="cpu") for dims in group_term_dims], dim=0
        )
        if len(term_dims.shape) > 1:
          # _prepare_terms stored ``cfg.concatenate_dim + 1`` for non-negative
          # dims (to account for the leading batch dim); term shapes here have
          # no batch dim, so shift back by one. Negative dims need no shift.
          if self._group_obs_concatenate_dim[group_name] >= 0:
            dim = self._group_obs_concatenate_dim[group_name] - 1
          else:
            dim = self._group_obs_concatenate_dim[group_name]
          # Group shape = first term's shape with the concat axis replaced by
          # the summed size of all terms along that axis.
          dim_sum = torch.sum(term_dims[:, dim], dim=0)
          term_dims[0, dim] = dim_sum
          term_dims = term_dims[0]
        else:
          # Fallback when the stacked shapes collapse to a 1-D tensor.
          term_dims = torch.sum(term_dims, dim=0)
        self._group_obs_dim[group_name] = tuple(term_dims.tolist())
      else:
        # Non-concatenated group: keep the list of individual term shapes.
        self._group_obs_dim[group_name] = group_term_dims
    # Cache of the most recently computed observations; None = stale/empty.
    self._obs_buffer: dict[str, torch.Tensor | dict[str, torch.Tensor]] | None = None

  def __str__(self) -> str:
    """Return a human-readable table of active terms for every group."""
    msg = f"<ObservationManager> contains {len(self._group_obs_term_names)} groups.\n"
    for group_name, group_dim in self._group_obs_dim.items():
      table = PrettyTable()
      table.title = f"Active Observation Terms in Group: '{group_name}'"
      if self._group_obs_concatenate[group_name]:
        table.title += f" (shape: {group_dim})"  # type: ignore
      table.field_names = ["Index", "Name", "Shape"]
      table.align["Name"] = "l"
      obs_terms = zip(
        self._group_obs_term_names[group_name],
        self._group_obs_term_dim[group_name],
        self._group_obs_term_cfgs[group_name],
        strict=False,
      )
      for index, (name, dims, term_cfg) in enumerate(obs_terms):
        if term_cfg.history_length > 0 and term_cfg.flatten_history_dim:
          # Flattened history: show (9,) ← 3×(3,)
          original_size = int(np.prod(dims)) // term_cfg.history_length
          original_shape = (original_size,) if len(dims) == 1 else dims[1:]
          shape_str = f"{dims} ← {term_cfg.history_length}×{original_shape}"
        else:
          shape_str = str(tuple(dims))
        table.add_row([index, name, shape_str])
      msg += table.get_string()
      msg += "\n"
    return msg

  def get_active_iterable_terms(
    self, env_idx: int
  ) -> Sequence[tuple[str, Sequence[float]]]:
    """Return per-term observation values for a single environment.

    Args:
      env_idx: Index of the environment to extract values for.

    Returns:
      List of ``("<group>-<term>", values)`` pairs, where ``values`` is the
      term's observation for ``env_idx`` as a plain Python list.
    """
    terms = []
    # Lazily compute observations if nothing is cached yet.
    if self._obs_buffer is None:
      self.compute()
    assert self._obs_buffer is not None
    obs_buffer: dict[str, torch.Tensor | dict[str, torch.Tensor]] = self._obs_buffer
    for group_name, _ in self.group_obs_dim.items():
      if not self.group_obs_concatenate[group_name]:
        # Non-concatenated group: buffer is already a dict of per-term tensors.
        buffers = obs_buffer[group_name]
        assert isinstance(buffers, dict)
        for name, term in buffers.items():
          terms.append((group_name + "-" + name, term[env_idx].cpu().tolist()))  # type: ignore[unsupported-operator]
        continue
      # Concatenated group: slice each term back out of the flat tensor,
      # walking forward by each term's flattened length.
      idx = 0
      data = obs_buffer[group_name]
      assert isinstance(data, torch.Tensor)
      for name, shape in zip(
        self._group_obs_term_names[group_name],
        self._group_obs_term_dim[group_name],
        strict=False,
      ):
        data_length = np.prod(shape)
        term = data[env_idx, idx : idx + data_length]
        terms.append((group_name + "-" + name, term.cpu().tolist()))
        idx += data_length
    return terms

  # Properties.

  @property
  def active_terms(self) -> dict[str, list[str]]:
    """Mapping from group name to the ordered list of active term names."""
    return self._group_obs_term_names

  @property
  def group_obs_dim(self) -> dict[str, tuple[int, ...] | list[tuple[int, ...]]]:
    """Mapping from group name to its resolved observation shape(s)."""
    return self._group_obs_dim

  @property
  def group_obs_term_dim(self) -> dict[str, list[tuple[int, ...]]]:
    """Mapping from group name to the per-term observation shapes."""
    return self._group_obs_term_dim

  @property
  def group_obs_concatenate(self) -> dict[str, bool]:
    """Mapping from group name to whether its terms are concatenated."""
    return self._group_obs_concatenate

  # Methods.

  def get_term_cfg(self, group_name: str, term_name: str) -> ObservationTermCfg:
    """Return the (resolved) configuration of a term in a group.

    Args:
      group_name: Name of the group containing the term.
      term_name: Name of the term.

    Raises:
      ValueError: If the group or the term does not exist.
    """
    if group_name not in self._group_obs_term_names:
      raise ValueError(f"Group '{group_name}' not found in active groups.")
    if term_name not in self._group_obs_term_names[group_name]:
      raise ValueError(f"Term '{term_name}' not found in group '{group_name}'.")
    index = self._group_obs_term_names[group_name].index(term_name)
    return self._group_obs_term_cfgs[group_name][index]

  def reset(self, env_ids: torch.Tensor | slice | None = None) -> dict[str, float]:
    """Reset stateful terms, delay/history buffers, and noise models.

    Args:
      env_ids: Environment indices to reset. A slice (or None) resets all.

    Returns:
      An empty dict (kept for API symmetry with other managers).
    """
    # Invalidate cache since reset envs will have different observations.
    self._obs_buffer = None
    for group_name, group_cfg in self._group_obs_class_term_cfgs.items():
      # Class-based observation terms may carry internal state of their own.
      for term_cfg in group_cfg:
        term_cfg.func.reset(env_ids=env_ids)
      for term_name in self._group_obs_term_names[group_name]:
        # Buffers take explicit batch ids; a slice means "all envs" -> None.
        batch_ids = None if isinstance(env_ids, slice) else env_ids
        if term_name in self._group_obs_term_delay_buffer[group_name]:
          self._group_obs_term_delay_buffer[group_name][term_name].reset(
            batch_ids=batch_ids
          )
        if term_name in self._group_obs_term_history_buffer[group_name]:
          self._group_obs_term_history_buffer[group_name][term_name].reset(
            batch_ids=batch_ids
          )
    # Reset stateful noise-model instances (see NOTE in compute_group about
    # these being keyed by term name only).
    for mod in self._group_obs_class_instances.values():
      mod.reset(env_ids=env_ids)
    return {}

  def _check_and_handle_nans(
    self, tensor: torch.Tensor, context: str, policy: str
  ) -> torch.Tensor:
    """Check for NaN/Inf and handle according to policy.

    Args:
      tensor: Observation tensor to check.
      context: Context string for error/warning messages (e.g., "policy/base_lin_vel").
      policy: NaN handling policy ("disabled", "warn", "sanitize", "error").

    Returns:
      The tensor, potentially sanitized depending on policy.

    Raises:
      ValueError: If policy is "error" and NaN/Inf detected.
    """
    if policy == "disabled":
      return tensor
    # Fast path: one scalar reduction each; bail out when the tensor is clean.
    has_nan = torch.isnan(tensor).any()
    has_inf = torch.isinf(tensor).any()
    if not (has_nan or has_inf):
      return tensor
    if policy == "error":
      # Reduce over the last dim to identify which envs are affected.
      nan_mask = torch.isnan(tensor).any(dim=-1) | torch.isinf(tensor).any(dim=-1)
      nan_env_ids = torch.where(nan_mask)[0].cpu().tolist()
      raise ValueError(
        f"NaN/Inf detected in observation '{context}' "
        f"for environments: {nan_env_ids[:10]}"
      )
    if policy == "warn":
      nan_mask = torch.isnan(tensor).any(dim=-1) | torch.isinf(tensor).any(dim=-1)
      nan_env_ids = torch.where(nan_mask)[0].cpu().tolist()
      print(
        f"[ObservationManager] NaN/Inf in '{context}' "
        f"(envs: {nan_env_ids[:5]}). Sanitizing to 0."
      )
    # Sanitize (applies to both "warn" and "sanitize" policies).
    return torch.nan_to_num(tensor, nan=0.0, posinf=0.0, neginf=0.0)

  def compute(
    self, update_history: bool = False
  ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
    """Compute observations for all groups, with caching.

    Args:
      update_history: If True, push new values into delay/history buffers and
        refresh the cache. If False, return the cached result when available.

    Returns:
      Mapping from group name to its observation (tensor or per-term dict).
    """
    # Return cached observations if not updating and cache exists.
    # This prevents double-pushing to delay buffers when compute() is called
    # multiple times per control step (e.g., in get_observations() after step()).
    if not update_history and self._obs_buffer is not None:
      return self._obs_buffer
    obs_buffer: dict[str, torch.Tensor | dict[str, torch.Tensor]] = dict()
    for group_name in self._group_obs_term_names:
      obs_buffer[group_name] = self.compute_group(group_name, update_history)
    self._obs_buffer = obs_buffer
    return obs_buffer

  def compute_group(
    self, group_name: str, update_history: bool = False
  ) -> torch.Tensor | dict[str, torch.Tensor]:
    """Compute observations for a single group.

    Pipeline per term: compute → noise → clip → scale → (per-term NaN check)
    → delay → history, matching :class:`ObservationTermCfg`'s documented order.

    Args:
      group_name: Name of the group to compute.
      update_history: Whether to push new values into delay/history buffers.
        When False, a history buffer is only written to seed it on first use.

    Returns:
      Concatenated tensor if the group concatenates terms, otherwise a dict
      mapping term names to their tensors.
    """
    group_cfg = self.cfg[group_name]
    group_term_names = self._group_obs_term_names[group_name]
    group_obs: dict[str, torch.Tensor] = {}
    obs_terms = zip(
      group_term_names, self._group_obs_term_cfgs[group_name], strict=False
    )
    for term_name, term_cfg in obs_terms:
      # Clone so the in-place clip_/mul_ below cannot mutate tensors owned by
      # the term function (e.g. views into simulation state).
      obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params).clone()
      if isinstance(term_cfg.noise, noise_cfg.NoiseCfg):
        obs = term_cfg.noise.apply(obs)
      elif isinstance(term_cfg.noise, noise_cfg.NoiseModelCfg):
        # NOTE(review): instances are keyed by term name only, not
        # (group, term) — two groups sharing a term name would share one
        # noise-model instance. Confirm this is intended.
        obs = self._group_obs_class_instances[term_name](obs)
      if term_cfg.clip:
        obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1])
      if term_cfg.scale is not None:
        scale = term_cfg.scale
        # _prepare_terms converted scale to a tensor on the env device.
        assert isinstance(scale, torch.Tensor)
        obs = obs.mul_(scale)
      # Check for NaN/Inf before delay/history buffers (per-term checking).
      if group_cfg.nan_check_per_term and group_cfg.nan_policy != "disabled":
        obs = self._check_and_handle_nans(
          obs, context=f"{group_name}/{term_name}", policy=group_cfg.nan_policy
        )
      if term_cfg.delay_max_lag > 0:
        # Push the fresh value, then read back the (possibly lagged) one.
        delay_buffer = self._group_obs_term_delay_buffer[group_name][term_name]
        delay_buffer.append(obs)
        obs = delay_buffer.compute()
      if term_cfg.history_length > 0:
        circular_buffer = self._group_obs_term_history_buffer[group_name][term_name]
        # Only push when explicitly updating, or to seed an empty buffer.
        if update_history or not circular_buffer.is_initialized:
          circular_buffer.append(obs)
        if term_cfg.flatten_history_dim:
          group_obs[term_name] = circular_buffer.buffer.reshape(self._env.num_envs, -1)
        else:
          group_obs[term_name] = circular_buffer.buffer
      else:
        group_obs[term_name] = obs
    # Final NaN check for non-per-term checking.
    if not group_cfg.nan_check_per_term and group_cfg.nan_policy != "disabled":
      if self._group_obs_concatenate[group_name]:
        # Will check after concatenation below.
        pass
      else:
        for term_name in group_obs:
          group_obs[term_name] = self._check_and_handle_nans(
            group_obs[term_name],
            context=f"{group_name}/{term_name}",
            policy=group_cfg.nan_policy,
          )
    if self._group_obs_concatenate[group_name]:
      result = torch.cat(
        list(group_obs.values()), dim=self._group_obs_concatenate_dim[group_name]
      )
      # Final check for concatenated result (non-per-term checking).
      if not group_cfg.nan_check_per_term and group_cfg.nan_policy != "disabled":
        result = self._check_and_handle_nans(
          result, context=group_name, policy=group_cfg.nan_policy
        )
      return result
    return group_obs

  def _prepare_terms(self) -> None:
    """Resolve all group/term configs and allocate delay/history buffers.

    Populates the internal ``_group_obs_*`` bookkeeping dictionaries that the
    rest of the manager reads. Groups or terms configured as ``None`` are
    skipped with a console message.
    """
    self._group_obs_term_names: dict[str, list[str]] = dict()
    self._group_obs_term_dim: dict[str, list[tuple[int, ...]]] = dict()
    self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
    self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
    self._group_obs_concatenate: dict[str, bool] = dict()
    self._group_obs_concatenate_dim: dict[str, int] = dict()
    self._group_obs_class_instances: dict[str, noise_model.NoiseModel] = {}
    self._group_obs_term_delay_buffer: dict[str, dict[str, DelayBuffer]] = dict()
    self._group_obs_term_history_buffer: dict[str, dict[str, CircularBuffer]] = dict()
    for group_name, group_cfg in self.cfg.items():
      group_cfg: ObservationGroupCfg | None
      if group_cfg is None:
        print(f"group: {group_name} set to None, skipping...")
        continue
      self._group_obs_term_names[group_name] = list()
      self._group_obs_term_dim[group_name] = list()
      self._group_obs_term_cfgs[group_name] = list()
      self._group_obs_class_term_cfgs[group_name] = list()
      group_entry_delay_buffer: dict[str, DelayBuffer] = dict()
      group_entry_history_buffer: dict[str, CircularBuffer] = dict()
      self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
      # Shift non-negative concat dims by +1 to account for the leading batch
      # dim of the actual observation tensors; negative dims already index
      # from the end and need no shift.
      self._group_obs_concatenate_dim[group_name] = (
        group_cfg.concatenate_dim + 1
        if group_cfg.concatenate_dim >= 0
        else group_cfg.concatenate_dim
      )
      for term_name, term_cfg in group_cfg.terms.items():
        term_cfg: ObservationTermCfg | None
        if term_cfg is None:
          print(f"term: {term_name} set to None, skipping...")
          continue
        # NOTE: This deepcopy is important to avoid cross-group contamination of term
        # configs.
        term_cfg = deepcopy(term_cfg)
        self._resolve_common_term_cfg(term_name, term_cfg)
        # Corruption disabled at group level overrides any per-term noise.
        if not group_cfg.enable_corruption:
          term_cfg.noise = None
        # Group-level history settings override the per-term ones.
        if group_cfg.history_length is not None:
          term_cfg.history_length = group_cfg.history_length
          term_cfg.flatten_history_dim = group_cfg.flatten_history_dim
        self._group_obs_term_names[group_name].append(term_name)
        self._group_obs_term_cfgs[group_name].append(term_cfg)
        # Terms whose func has a reset() method are stateful (class-based) and
        # are tracked separately so reset() can forward to them.
        if hasattr(term_cfg.func, "reset") and callable(term_cfg.func.reset):
          self._group_obs_class_term_cfgs[group_name].append(term_cfg)
        # Probe the term once to learn its full output shape (incl. batch dim).
        obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape)
        if term_cfg.scale is not None:
          # Materialize scale on the env device so compute_group can mul_ it.
          term_cfg.scale = torch.tensor(
            term_cfg.scale, dtype=torch.float, device=self._env.device
          )
        if term_cfg.noise is not None and isinstance(
          term_cfg.noise, noise_cfg.NoiseModelCfg
        ):
          noise_model_cls = term_cfg.noise.class_type
          assert issubclass(noise_model_cls, noise_model.NoiseModel), (
            f"Class type for observation term '{term_name}' NoiseModelCfg"
            f" is not a subclass of 'NoiseModel'. Received: '{type(noise_model_cls)}'."
          )
          # NOTE(review): keyed by term name only — a same-named term in
          # another group overwrites this entry. Confirm intended.
          self._group_obs_class_instances[term_name] = noise_model_cls(
            term_cfg.noise, num_envs=self._env.num_envs, device=self._env.device
          )
        if term_cfg.delay_max_lag > 0:
          group_entry_delay_buffer[term_name] = DelayBuffer(
            min_lag=term_cfg.delay_min_lag,
            max_lag=term_cfg.delay_max_lag,
            batch_size=self._env.num_envs,
            device=self._env.device,
            per_env=term_cfg.delay_per_env,
            hold_prob=term_cfg.delay_hold_prob,
            update_period=term_cfg.delay_update_period,
            per_env_phase=term_cfg.delay_per_env_phase,
          )
        if term_cfg.history_length > 0:
          group_entry_history_buffer[term_name] = CircularBuffer(
            max_len=term_cfg.history_length,
            batch_size=self._env.num_envs,
            device=self._env.device,
          )
          # Insert the history axis after the batch dim:
          # (num_envs, ...) -> (num_envs, history, ...).
          old_dims = list(obs_dims)
          old_dims.insert(1, term_cfg.history_length)
          obs_dims = tuple(old_dims)
          if term_cfg.flatten_history_dim:
            # Collapse history + feature dims into one flat dimension.
            obs_dims = (obs_dims[0], int(np.prod(obs_dims[1:])))
        # Store the per-term shape without the batch dim.
        self._group_obs_term_dim[group_name].append(obs_dims[1:])
      self._group_obs_term_delay_buffer[group_name] = group_entry_delay_buffer
      self._group_obs_term_history_buffer[group_name] = group_entry_history_buffer