# Source code for mjlab.sensor.builtin_sensor

"""Sensors that wrap MuJoCo builtin sensors."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import mujoco
import mujoco_warp as mjwarp
import torch

from mjlab.entity import Entity
from mjlab.sensor.sensor import Sensor, SensorCfg

# Names of the MuJoCo builtin sensor types this module can create. Each name is
# translated to a `mujoco.mjtSensor` value through _SENSOR_TYPE_MAP (see
# BuiltinSensor.edit_spec), so this Literal and that map must list the same
# names.
SensorType = Literal[
  # Site sensors.
  "accelerometer",
  "velocimeter",
  "gyro",
  "force",
  "torque",
  "magnetometer",
  "rangefinder",
  # Joint sensors.
  "jointpos",
  "jointvel",
  "jointlimitpos",
  "jointlimitvel",
  "jointlimitfrc",
  "jointactuatorfrc",
  # Tendon sensors.
  "tendonpos",
  "tendonvel",
  "tendonactuatorfrc",
  # Actuator sensors.
  "actuatorpos",
  "actuatorvel",
  "actuatorfrc",
  # Frame sensors.
  "framepos",
  "framequat",
  "framexaxis",
  "frameyaxis",
  "framezaxis",
  "framelinvel",
  "frameangvel",
  "framelinacc",
  "frameangacc",
  # Subtree sensors.
  "subtreecom",
  "subtreelinvel",
  "subtreeangmom",
  # Misc.
  "e_potential",
  "e_kinetic",
  "clock",
]

# Maps each SensorType name to the `mujoco.mjtSensor` enum passed as `type` to
# `MjSpec.add_sensor` in BuiltinSensor.edit_spec. The actuator-force sensors on
# joints/tendons are spelled out here ("jointactuatorfrc") while the MuJoCo
# enums abbreviate them (mjSENS_JOINTACTFRC / mjSENS_TENDONACTFRC).
_SENSOR_TYPE_MAP = {
  # Site sensors.
  "accelerometer": mujoco.mjtSensor.mjSENS_ACCELEROMETER,
  "velocimeter": mujoco.mjtSensor.mjSENS_VELOCIMETER,
  "gyro": mujoco.mjtSensor.mjSENS_GYRO,
  "force": mujoco.mjtSensor.mjSENS_FORCE,
  "torque": mujoco.mjtSensor.mjSENS_TORQUE,
  "magnetometer": mujoco.mjtSensor.mjSENS_MAGNETOMETER,
  "rangefinder": mujoco.mjtSensor.mjSENS_RANGEFINDER,
  # Joint sensors.
  "jointpos": mujoco.mjtSensor.mjSENS_JOINTPOS,
  "jointvel": mujoco.mjtSensor.mjSENS_JOINTVEL,
  "jointlimitpos": mujoco.mjtSensor.mjSENS_JOINTLIMITPOS,
  "jointlimitvel": mujoco.mjtSensor.mjSENS_JOINTLIMITVEL,
  "jointlimitfrc": mujoco.mjtSensor.mjSENS_JOINTLIMITFRC,
  "jointactuatorfrc": mujoco.mjtSensor.mjSENS_JOINTACTFRC,
  # Tendon sensors.
  "tendonpos": mujoco.mjtSensor.mjSENS_TENDONPOS,
  "tendonvel": mujoco.mjtSensor.mjSENS_TENDONVEL,
  "tendonactuatorfrc": mujoco.mjtSensor.mjSENS_TENDONACTFRC,
  # Actuator sensors.
  "actuatorpos": mujoco.mjtSensor.mjSENS_ACTUATORPOS,
  "actuatorvel": mujoco.mjtSensor.mjSENS_ACTUATORVEL,
  "actuatorfrc": mujoco.mjtSensor.mjSENS_ACTUATORFRC,
  # Frame sensors.
  "framepos": mujoco.mjtSensor.mjSENS_FRAMEPOS,
  "framequat": mujoco.mjtSensor.mjSENS_FRAMEQUAT,
  "framexaxis": mujoco.mjtSensor.mjSENS_FRAMEXAXIS,
  "frameyaxis": mujoco.mjtSensor.mjSENS_FRAMEYAXIS,
  "framezaxis": mujoco.mjtSensor.mjSENS_FRAMEZAXIS,
  "framelinvel": mujoco.mjtSensor.mjSENS_FRAMELINVEL,
  "frameangvel": mujoco.mjtSensor.mjSENS_FRAMEANGVEL,
  "framelinacc": mujoco.mjtSensor.mjSENS_FRAMELINACC,
  "frameangacc": mujoco.mjtSensor.mjSENS_FRAMEANGACC,
  # Subtree sensors.
  "subtreecom": mujoco.mjtSensor.mjSENS_SUBTREECOM,
  "subtreelinvel": mujoco.mjtSensor.mjSENS_SUBTREELINVEL,
  "subtreeangmom": mujoco.mjtSensor.mjSENS_SUBTREEANGMOM,
  # Misc.
  "clock": mujoco.mjtSensor.mjSENS_CLOCK,
  "e_potential": mujoco.mjtSensor.mjSENS_E_POTENTIAL,
  "e_kinetic": mujoco.mjtSensor.mjSENS_E_KINETIC,
}

# Maps ObjRef.type to the `mujoco.mjtObj` enum used for both `objtype` and
# `reftype` when BuiltinSensor.edit_spec adds a sensor to the spec.
# "body" and "xbody" map to different enums: per MuJoCo's mjtObj docs,
# mjOBJ_BODY selects the body's inertial frame and mjOBJ_XBODY its regular frame.
_OBJECT_TYPE_MAP = {
  "body": mujoco.mjtObj.mjOBJ_BODY,
  "xbody": mujoco.mjtObj.mjOBJ_XBODY,
  "joint": mujoco.mjtObj.mjOBJ_JOINT,
  "geom": mujoco.mjtObj.mjOBJ_GEOM,
  "site": mujoco.mjtObj.mjOBJ_SITE,
  "actuator": mujoco.mjtObj.mjOBJ_ACTUATOR,
  "tendon": mujoco.mjtObj.mjOBJ_TENDON,
  "camera": mujoco.mjtObj.mjOBJ_CAMERA,
}

# Site sensors: BuiltinSensorCfg requires obj.type == "site" for these.
_SENSORS_REQUIRING_SITE = {
  "accelerometer", "velocimeter", "gyro", "force",
  "torque", "magnetometer", "rangefinder",
}

# Frame sensors: BuiltinSensorCfg requires their obj.type to be one of
# _SPATIAL_FRAME_TYPES.
_SENSORS_REQUIRING_SPATIAL_FRAME = {
  "framepos", "framequat",
  "framexaxis", "frameyaxis", "framezaxis",
  "framelinvel", "frameangvel",
  "framelinacc", "frameangacc",
}

# Subtree sensors: BuiltinSensorCfg requires obj.type == "body" for these.
_SENSORS_REQUIRING_BODY = {"subtreecom", "subtreelinvel", "subtreeangmom"}

# Required obj.type for joint, tendon and actuator sensors, keyed by sensor
# type. BuiltinSensorCfg.__post_init__ rejects configs whose obj is missing or
# has a different type.
_SENSOR_OBJECT_REQUIREMENTS = {
  "jointpos": "joint",
  "jointvel": "joint",
  "jointlimitpos": "joint",
  "jointlimitvel": "joint",
  "jointlimitfrc": "joint",
  "jointactuatorfrc": "joint",
  "tendonpos": "tendon",
  "tendonvel": "tendon",
  "tendonactuatorfrc": "tendon",
  "actuatorpos": "actuator",
  "actuatorvel": "actuator",
  "actuatorfrc": "actuator",
}

# Object types that have a spatial frame; frame sensors require their obj.type
# to be one of these.
_SPATIAL_FRAME_TYPES = {"body", "xbody", "geom", "site", "camera"}
# Sensor types that accept a `ref` object. This is exactly the set of frame
# sensors, so derive it rather than repeating the list and risking the two
# drifting apart. A copy (not an alias) keeps the sets independent if either
# is ever mutated.
_SENSORS_ALLOWING_REF = set(_SENSORS_REQUIRING_SPATIAL_FRAME)


@dataclass
class ObjRef:
  """Names a MuJoCo object such as a body, joint or site.

  Sensors use an ObjRef both for the object they attach to and for the object
  whose frame they are expressed in. Setting ``entity`` places the name inside
  that entity's namespace.
  """

  type: Literal[
    "body", "xbody", "joint", "geom", "site", "actuator", "tendon", "camera"
  ]
  """Kind of MuJoCo object being referenced."""
  name: str
  """Object name, without any entity prefix."""
  entity: str | None = None
  """Entity namespace; when non-empty the full name is ``"<entity>/<name>"``."""

  def prefixed_name(self) -> str:
    """Return the full object name, adding the entity prefix when one is set."""
    return f"{self.entity}/{self.name}" if self.entity else self.name
@dataclass
class BuiltinSensorCfg(SensorCfg):
  """Configuration for adding a MuJoCo builtin sensor to the scene spec.

  ``__post_init__`` validates the config. The sensor type must be known. For
  sensor types that need an ``obj``, it must be present and of an accepted
  type. ``ref`` is accepted only by sensor types in ``_SENSORS_ALLOWING_REF``
  and must name an object with a spatial frame. When ``obj`` is scoped to an
  entity, ``name`` is prefixed with ``"<entity>/"``.
  """

  sensor_type: SensorType
  """Which builtin sensor to use."""
  obj: ObjRef | None = None
  """The type and name of the object the sensor is attached to."""
  ref: ObjRef | None = None
  """The type and name of the object whose frame is the frame of reference."""
  cutoff: float = 0.0
  """When this value is positive, it limits the absolute value of the sensor output."""

  def __post_init__(self) -> None:
    """Validate the config and scope ``name`` to the object's entity.

    Raises:
      ValueError: If ``sensor_type`` is unknown, if a required ``obj`` is
        missing or of the wrong type, if ``obj.type`` is unknown, or if ``ref``
        is given for a sensor type that does not accept one or names an object
        without a spatial frame.
    """
    if self.sensor_type not in _SENSOR_TYPE_MAP:
      raise ValueError(
        f"Unknown sensor type '{self.sensor_type}'; expected one of "
        f"{sorted(_SENSOR_TYPE_MAP)}"
      )

    # Auto-prefix the sensor name if it references an entity. The truthiness
    # test matches ObjRef.prefixed_name. The startswith guard makes this
    # idempotent: __post_init__ runs again on dataclasses.replace(), which
    # would otherwise produce "entity/entity/name".
    if self.obj is not None and self.obj.entity:
      prefix = f"{self.obj.entity}/"
      if not self.name.startswith(prefix):
        self.name = prefix + self.name

    self._validate_obj()
    self._validate_ref()

  def _required_obj_type(self) -> str | None:
    """Return the single obj.type this sensor type requires, or None if none."""
    if self.sensor_type in _SENSORS_REQUIRING_SITE:
      return "site"
    if self.sensor_type in _SENSORS_REQUIRING_BODY:
      return "body"
    return _SENSOR_OBJECT_REQUIREMENTS.get(self.sensor_type)

  def _validate_obj(self) -> None:
    """Check ``obj`` against what ``sensor_type`` requires."""
    if self.sensor_type in _SENSORS_REQUIRING_SPATIAL_FRAME:
      if self.obj is None:
        raise ValueError(
          f"Sensor type '{self.sensor_type}' requires obj with spatial frame"
        )
      if self.obj.type not in _SPATIAL_FRAME_TYPES:
        raise ValueError(
          f"Sensor type '{self.sensor_type}' requires obj.type in "
          f"{_SPATIAL_FRAME_TYPES}, got '{self.obj.type}'"
        )
      return

    required_type = self._required_obj_type()
    if required_type is not None:
      if self.obj is None:
        raise ValueError(
          f"Sensor type '{self.sensor_type}' requires obj with type='{required_type}'"
        )
      if self.obj.type != required_type:
        raise ValueError(
          f"Sensor type '{self.sensor_type}' requires obj.type='{required_type}', "
          f"got '{self.obj.type}'"
        )
    elif self.obj is not None and self.obj.type not in _OBJECT_TYPE_MAP:
      # Sensor types with no obj requirement still need a type that
      # _OBJECT_TYPE_MAP can translate, or BuiltinSensor.edit_spec would fail
      # later with a bare KeyError.
      raise ValueError(
        f"Unknown obj.type '{self.obj.type}'; expected one of "
        f"{sorted(_OBJECT_TYPE_MAP)}"
      )

  def _validate_ref(self) -> None:
    """Check that ``ref``, if given, is allowed and names a spatial frame."""
    if self.ref is None:
      return
    if self.sensor_type not in _SENSORS_ALLOWING_REF:
      raise ValueError(
        f"Sensor type '{self.sensor_type}' does not support ref specification"
      )
    if self.ref.type not in _SPATIAL_FRAME_TYPES:
      raise ValueError(
        f"Sensor type '{self.sensor_type}' requires ref.type in "
        f"{_SPATIAL_FRAME_TYPES}, got '{self.ref.type}'"
      )

  def build(self) -> BuiltinSensor:
    """Create the BuiltinSensor described by this config."""
    return BuiltinSensor(self)
class BuiltinSensor(Sensor[torch.Tensor]):
  """Wrapper over MuJoCo builtin sensors.

  Can add a new sensor to the spec (from a BuiltinSensorCfg), or wrap an
  existing sensor already defined in entity XML (``from_existing``). Returns
  raw MuJoCo sensordata as a torch.Tensor with shape depending on sensor type
  (e.g., accelerometer: (N, 3), framequat: (N, 4)).

  Note: Caching provides minimal benefit here since data access is just a
  tensor slice view into MuJoCo's sensordata buffer.
  """

  def __init__(
    self, cfg: BuiltinSensorCfg | None = None, name: str | None = None
  ) -> None:
    """Create a sensor from a config, or refer to an existing sensor by name.

    Args:
      cfg: Config for a new sensor to add to the scene spec. When given, its
        ``name`` is used and the ``name`` argument is ignored.
      name: Name of a sensor already defined in entity XML. Required when
        ``cfg`` is None.

    Raises:
      ValueError: If neither ``cfg`` nor ``name`` is provided.
    """
    super().__init__()
    if cfg is not None:
      self._name = cfg.name
      self.cfg: BuiltinSensorCfg | None = cfg
    else:
      # Raise rather than assert: asserts are stripped under `python -O`,
      # which would leave _name=None and fail much later in initialize().
      if name is None:
        raise ValueError("Must provide either cfg or name")
      self._name = name
      self.cfg = None
    self._data: mjwarp.Data | None = None
    self._data_view: torch.Tensor | None = None

  @classmethod
  def from_existing(cls, name: str) -> BuiltinSensor:
    """Wrap an existing sensor already defined in entity XML."""
    return cls(cfg=None, name=name)

  def edit_spec(self, scene_spec: mujoco.MjSpec, entities: dict[str, Entity]) -> None:
    """Add this sensor to ``scene_spec``; no-op when wrapping an existing sensor.

    Raises:
      ValueError: If ``scene_spec`` already contains a sensor with this name.
    """
    del entities  # Unused here.
    if self.cfg is None:
      return

    # Check for duplicate sensors. Truthiness of `entity` matches
    # ObjRef.prefixed_name, which only prefixes non-empty entities.
    for sensor in scene_spec.sensors:
      if sensor.name == self.cfg.name:
        is_entity_scoped = self.cfg.obj is not None and bool(self.cfg.obj.entity)
        if is_entity_scoped:
          raise ValueError(
            f"Sensor '{self.cfg.name}' is defined in both entity XML and scene config. "
            f"Remove the sensor definition from the entity XML file, or remove the "
            f"BuiltinSensorCfg from scene.sensors."
          )
        raise ValueError(
          f"Sensor '{self.cfg.name}' already exists in the scene. "
          f"Rename this sensor to avoid conflicts."
        )

    # Add sensor to spec; optional attributes are only passed when configured.
    kwargs = {
      "name": self.cfg.name,
      "type": _SENSOR_TYPE_MAP[self.cfg.sensor_type],
    }
    if self.cfg.obj is not None:
      kwargs["objtype"] = _OBJECT_TYPE_MAP[self.cfg.obj.type]
      kwargs["objname"] = self.cfg.obj.prefixed_name()
    if self.cfg.ref is not None:
      kwargs["reftype"] = _OBJECT_TYPE_MAP[self.cfg.ref.type]
      kwargs["refname"] = self.cfg.ref.prefixed_name()
    if self.cfg.cutoff > 0:
      kwargs["cutoff"] = self.cfg.cutoff
    scene_spec.add_sensor(**kwargs)

  def initialize(
    self, mj_model: mujoco.MjModel, model: mjwarp.Model, data: mjwarp.Data, device: str
  ) -> None:
    """Bind to the compiled model and cache a view of this sensor's output.

    Looks the sensor up by name in ``mj_model`` and slices columns
    ``[adr, adr + dim)`` out of ``data.sensordata`` along its second axis.
    """
    del model, device
    self._data = data
    sensor = mj_model.sensor(self._name)
    start = sensor.adr[0]
    dim = sensor.dim[0]
    self._data_view = self._data.sensordata[:, start : start + dim]

  def _compute_data(self) -> torch.Tensor:
    """Return the cached sensordata slice; requires initialize() to have run."""
    assert self._data_view is not None
    return self._data_view