# Source code for mjlab.sensor.sensor

"""Base sensor interface."""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import mujoco
import mujoco_warp as mjwarp
import torch

if TYPE_CHECKING:
  from mjlab.entity import Entity
  from mjlab.viewer.debug_visualizer import DebugVisualizer


# Type of the data a sensor returns; used as the type parameter of `Sensor[T]`.
T = TypeVar("T")


@dataclass
class SensorCfg(ABC):
  """Base configuration for a sensor.

  Concrete configs subclass this and implement `build()` to construct the
  matching sensor. The class is abstract, so it cannot be instantiated
  directly.

  Attributes:
    name: Name of the sensor.
  """

  name: str

  @abstractmethod
  def build(self) -> Sensor[Any]:
    """Build a sensor instance from this config.

    Returns:
      The sensor constructed from this configuration.
    """
    raise NotImplementedError


class Sensor(ABC, Generic[T]):
  """Base sensor interface with typed data and per-step caching.

  Type parameter T specifies the type of data returned by the sensor. For
  example:

  - Sensor[torch.Tensor] for sensors returning raw tensors
  - Sensor[ContactData] for sensors returning structured contact data

  Subclasses should not forget to:

  - Call `super().__init__()` in their `__init__` method (it creates the
    cache attributes read by the `data` property)
  - If overriding `reset()` or `update()`, call `super()` FIRST to invalidate
    the cache
  """

  def __init__(self) -> None:
    """Set up an empty, invalid data cache."""
    self._cached_data: T | None = None
    self._cache_valid: bool = False

  @abstractmethod
  def edit_spec(
    self,
    scene_spec: mujoco.MjSpec,
    entities: dict[str, Entity],
  ) -> None:
    """Edit the scene spec to add this sensor.

    This is called during scene construction to add sensor elements to the
    MjSpec.

    Args:
      scene_spec: The scene MjSpec to edit.
      entities: Dictionary of entities in the scene, keyed by name.
    """
    raise NotImplementedError

  @abstractmethod
  def initialize(
    self,
    mj_model: mujoco.MjModel,
    model: mjwarp.Model,
    data: mjwarp.Data,
    device: str,
  ) -> None:
    """Initialize the sensor after model compilation.

    This is called after the MjSpec is compiled into an MjModel and the
    simulation is ready to run. Use this to cache sensor indices, allocate
    buffers, etc.

    Args:
      mj_model: The compiled MuJoCo model.
      model: The mjwarp model wrapper.
      data: The mjwarp data arrays.
      device: Device for tensor operations (e.g., "cuda", "cpu").
    """
    raise NotImplementedError

  @property
  def data(self) -> T:
    """Get the current sensor data, using the cached value if available.

    The data is cached per step: it is recomputed through `_compute_data()`
    only when the cache has been invalidated (after `reset()` or `update()`
    is called, or before the first access).

    Returns:
      The sensor data in the format specified by type parameter T.

    Raises:
      AssertionError: If the cached value is None, i.e. `_compute_data()`
        returned None.
    """
    if not self._cache_valid:
      self._cached_data = self._compute_data()
      self._cache_valid = True
    # Narrows `T | None` to `T`; a None result is treated as a programming error.
    assert self._cached_data is not None
    return self._cached_data

  @abstractmethod
  def _compute_data(self) -> T:
    """Compute and return the sensor data.

    Subclasses must implement this method. It is called by the `data`
    property whenever the cache is invalid.

    Returns:
      The computed sensor data.
    """
    raise NotImplementedError

  def _invalidate_cache(self) -> None:
    """Invalidate the cached data, forcing recomputation on next access."""
    self._cache_valid = False

  def reset(self, env_ids: torch.Tensor | slice | None = None) -> None:
    """Reset sensor state for specified environments.

    The base implementation only invalidates the data cache. Override in
    subclasses that maintain internal state, but call
    `super().reset(env_ids)` FIRST.

    Args:
      env_ids: Environment indices to reset. If None, reset all
        environments. Ignored by the base implementation.
    """
    del env_ids  # Unused.
    self._invalidate_cache()

  def update(self, dt: float) -> None:
    """Update sensor state after a simulation step.

    The base implementation only invalidates the data cache. Override in
    subclasses that need per-step updates, but call `super().update(dt)`
    FIRST.

    Args:
      dt: Time step in seconds. Ignored by the base implementation.
    """
    del dt  # Unused.
    self._invalidate_cache()

  def debug_vis(self, visualizer: DebugVisualizer) -> None:
    """Visualize sensor data for debugging.

    Base implementation does nothing. Override in subclasses that support
    debug visualization.

    Args:
      visualizer: The debug visualizer to draw to.
    """
    del visualizer  # Unused.