Source code for mjlab.sensor.raycast_sensor

"""Raycast sensor for terrain and obstacle detection.

Ray Patterns
------------

This module provides two ray pattern types for different use cases:

**Grid Pattern** - Parallel rays in a 2D grid::

    Camera at any height:
          ↓   ↓   ↓   ↓   ↓      ← All rays point same direction
          ↓   ↓   ↓   ↓   ↓
          ↓   ↓   ↓   ↓   ↓
        ──●───●───●───●───●──    ← Fixed spacing (e.g., 10cm apart)
             Ground

- Rays are parallel (all point in the same direction, e.g., -Z down)
- Spacing is defined in world units (meters)
- Height doesn't affect the hit pattern - same footprint regardless of altitude
- Good for: height maps, terrain scanning with consistent spatial sampling
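
As a rough illustration of the grid sampling, the sketch below reproduces the
per-axis sample positions in plain Python. It assumes the same
``arange(-size/2, size/2 + 0.5*resolution, resolution)`` rule used by
``GridPatternCfg.generate_rays``; the helper name ``grid_axis_samples`` is
hypothetical and not part of this module.

```python
import math


def grid_axis_samples(size: float, resolution: float) -> list[float]:
    # Hypothetical helper: sample positions along one grid axis, in meters.
    # The half-step padding on the upper bound keeps the far edge in the grid
    # even when floating-point rounding nudges the endpoint.
    start = -size / 2
    stop = size / 2 + 0.5 * resolution
    count = math.ceil((stop - start) / resolution)
    return [start + i * resolution for i in range(count)]


xs = grid_axis_samples(1.0, 0.1)
ys = grid_axis_samples(1.0, 0.1)
num_rays = len(xs) * len(ys)
print(len(xs), num_rays)
```

With the default ``size=(1.0, 1.0)`` and ``resolution=0.1`` this gives 11
samples per axis, so the ray count grows quadratically as the resolution
shrinks.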

**Pinhole Camera Pattern** - Diverging rays from a single point::

    Camera LOW:                    Camera HIGH:
             📷                            📷
            /|\\                          /  |  \\
           / | \\                        /   |   \\
          /  |  \\                      /    |    \\
        ─●───●───●─                 ───●─────●─────●───
        (small footprint)           (large footprint)

- Rays diverge from a single point (like light entering a camera)
- FOV is fixed in angular units (degrees)
- Higher altitude → wider ground coverage, more spread between hits
- Lower altitude → tighter ground coverage, denser hits
- Good for: simulating depth cameras, LiDAR with angular resolution
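
The diverging rays can be sketched without torch as follows. This is a
simplified stand-in for ``PinholeCameraPatternCfg.generate_rays`` under two
assumptions: explicit ``width`` / ``height`` / ``fovy`` parameters, and at least
two pixels per axis. The function name ``pinhole_directions`` is hypothetical.

```python
import math


def pinhole_directions(width: int, height: int, fovy_deg: float):
    # Hypothetical helper: unit ray directions in the camera frame (-Z forward).
    v_fov = math.radians(fovy_deg)
    # Horizontal FOV follows from the vertical FOV and the aspect ratio.
    h_fov = 2 * math.atan(math.tan(v_fov / 2) * width / height)

    def linspace(n: int) -> list[float]:
        # Normalized pixel coordinates in [-1, 1]; assumes n >= 2.
        return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]

    directions = []
    for v in linspace(height):  # rows (outer loop)
        for u in linspace(width):  # columns (inner loop)
            x = u * math.tan(h_fov / 2)
            y = v * math.tan(v_fov / 2)
            z = -1.0
            norm = math.sqrt(x * x + y * y + z * z)
            directions.append((x / norm, y / norm, z / norm))
    return directions


dirs = pinhole_directions(3, 3, 90.0)
center = dirs[4]
corner = dirs[0]
print(center, corner)
```

The center ray points straight along -Z, while the corner rays tilt outward by
an amount set only by the field of view, not by the sensor height.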

**Pattern Comparison**:

============== ==================== ==========================
Aspect         Grid                 Pinhole
============== ==================== ==========================
Ray direction  All parallel         Diverge from origin
Spacing        Meters               Degrees (FOV)
Height effect  No                   Yes
Real-world     Orthographic proj.   Perspective camera / LiDAR
============== ==================== ==========================

The pinhole behavior matches real depth sensors (RealSense, LiDAR) - when
you're farther from an object, each pixel covers more area.
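
To put numbers on the comparison, the sketch below estimates the ground
footprint of each pattern. It assumes flat ground perpendicular to the central
ray; both function names are hypothetical.

```python
import math


def pinhole_footprint(height_m: float, fov_deg: float) -> float:
    # Width of ground covered along one axis by a pinhole pattern.
    return 2 * height_m * math.tan(math.radians(fov_deg) / 2)


def grid_footprint(size_m: float, height_m: float) -> float:
    # Parallel rays: the footprint equals the grid size at any height.
    del height_m
    return size_m


low = pinhole_footprint(1.0, 90.0)
high = pinhole_footprint(2.0, 90.0)
grid_low = grid_footprint(1.0, 1.0)
grid_high = grid_footprint(1.0, 2.0)
print(low, high, grid_low, grid_high)
```

Doubling the height doubles the pinhole footprint, while the grid footprint
stays at the configured size.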


Frame Attachment
----------------

Rays are attached to a frame in the scene via ``ObjRef``. Supported frame types:

- **body**: Attach to a body's origin. Rays follow body position and orientation.
- **site**: Attach to a site. Useful for precise placement or offset from body.
- **geom**: Attach to a geometry. Useful for sensors mounted on specific parts.

Example::

    from mjlab.sensor import ObjRef, RayCastSensorCfg, GridPatternCfg

    cfg = RayCastSensorCfg(
        name="terrain_scan",
        frame=ObjRef(type="body", name="base", entity="robot"),
        pattern=GridPatternCfg(size=(1.0, 1.0), resolution=0.1),
    )

The ``exclude_parent_body`` option (default: True) prevents rays from hitting
the body they're attached to.


Ray Alignment
-------------

The ``ray_alignment`` setting controls how rays orient relative to the frame::

    Robot tilted 30°:

    "base" (default)          "yaw"                    "world"
    Rays tilt with body       Rays stay level          Rays fixed to world
          ↘ ↓ ↙                    ↓ ↓ ↓                    ↓ ↓ ↓
           \\|/                     |||                      |||
            🤖  ← tilted            🤖  ← tilted             🤖  ← tilted
           /                       /                        /

- **base**: Full position + rotation. Rays rotate with the body. Good for
  body-mounted sensors that should scan relative to the robot's orientation.

- **yaw**: Position + yaw only, ignores pitch/roll. Rays always point straight
  down regardless of body tilt. Good for height maps where you want consistent
  vertical sampling even when the robot is on a slope.

- **world**: Fixed in world frame, only position follows body. Rays always
  point in a fixed world direction. Good for gravity-aligned measurements.
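
The three modes can be sketched as a choice of rotation matrix applied to the
local ray directions. This is a plain-Python approximation of the batched
torch logic in ``_compute_alignment_rotation``; the helper names
(``alignment_rotation``, ``rot_z``, ``rot_y``, ``matmul``, ``matvec``) are
hypothetical.

```python
import math


def rot_z(a):
    c, s = math.cos(a), math.sin(a)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def rot_y(a):
    c, s = math.cos(a), math.sin(a)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def matvec(m, v):
    return [sum(m[i][k] * v[k] for k in range(3)) for i in range(3)]


def alignment_rotation(frame_mat, ray_alignment):
    if ray_alignment == "base":
        return frame_mat  # full rotation
    if ray_alignment == "world":
        return [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    if ray_alignment == "yaw":
        # Project the frame's X axis (first column) onto the world XY plane.
        x, y = frame_mat[0][0], frame_mat[1][0]
        norm = math.hypot(x, y)
        if norm < 0.1:
            # X axis nearly vertical: use the Y axis rotated -90 deg about Z.
            yx, yy = frame_mat[0][1], frame_mat[1][1]
            norm = max(math.hypot(yx, yy), 1e-6)
            x, y = yy / norm, -yx / norm
        else:
            x, y = x / norm, y / norm
        return [[x, -y, 0.0], [y, x, 0.0], [0.0, 0.0, 1.0]]
    raise ValueError("Unknown ray_alignment: " + str(ray_alignment))


# Frame yawed 90 degrees and pitched 30 degrees.
frame = matmul(rot_z(math.radians(90.0)), rot_y(math.radians(30.0)))
down = [0.0, 0.0, -1.0]
base_dir = matvec(alignment_rotation(frame, "base"), down)
yaw_dir = matvec(alignment_rotation(frame, "yaw"), down)
world_dir = matvec(alignment_rotation(frame, "world"), down)
yaw_mat = alignment_rotation(frame, "yaw")
print(base_dir, yaw_dir, world_dir)
```

For a downward local direction, "yaw" and "world" both keep the ray vertical.
They differ in how the ray offsets are rotated: "yaw" turns the pattern with
the robot's heading, "world" does not.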


Debug Visualization
-------------------

Enable visualization with ``debug_vis=True`` and customize via ``VizCfg``::

    cfg = RayCastSensorCfg(
        name="scan",
        frame=ObjRef(type="body", name="base", entity="robot"),
        pattern=GridPatternCfg(),
        debug_vis=True,
        viz=RayCastSensorCfg.VizCfg(
            hit_color=(0, 1, 0, 0.8),      # Green for hits
            miss_color=(1, 0, 0, 0.4),     # Red for misses
            show_rays=True,                 # Draw ray arrows
            show_normals=True,              # Draw surface normals
            normal_color=(1, 1, 0, 1),     # Yellow normals
        ),
    )

Visualization options:

- ``hit_color`` / ``miss_color``: RGBA colors for ray arrows
- ``hit_sphere_color`` / ``hit_sphere_radius``: Spheres at hit points
- ``show_rays``: Draw arrows from origin to hit/miss points
- ``show_normals`` / ``normal_color`` / ``normal_length``: Surface normal arrows


Geom Group Filtering
--------------------

MuJoCo geoms can be assigned to groups 0-5. Use ``include_geom_groups`` to
filter which groups the rays can hit::

    cfg = RayCastSensorCfg(
        name="terrain_only",
        frame=ObjRef(type="body", name="base", entity="robot"),
        pattern=GridPatternCfg(),
        include_geom_groups=(0, 1),  # Only hit geoms in groups 0 and 1
    )

This is useful for ignoring certain geometry (e.g., visual-only geoms in
group 3) while still detecting collisions with terrain (group 0).
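
Internally the tuple is turned into a six-element mask, where -1 marks a group
as included and 0 as excluded. The sketch below mirrors that conversion with a
plain list instead of the warp ``vec6`` type; the function name
``geom_group_mask`` is hypothetical.

```python
def geom_group_mask(include_geom_groups):
    # None means every group is included.
    if include_geom_groups is None:
        return [-1] * 6
    mask = [0] * 6
    for g in include_geom_groups:
        # Indices outside 0-5 are silently ignored, as in the sensor code.
        if 0 <= g <= 5:
            mask[g] = -1
    return mask


terrain_only = geom_group_mask((0, 1))
all_groups = geom_group_mask(None)
out_of_range = geom_group_mask((7,))
print(terrain_only, all_groups, out_of_range)
```

Note that an out-of-range group index does not raise; it simply leaves every
group excluded, so the rays would hit nothing.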


Output Data
-----------

Access sensor data via the ``data`` property, which returns ``RayCastData``:

- ``distances``: [B, N] Distance to hit, or -1 if no hit / beyond max_distance
- ``hit_pos_w``: [B, N, 3] World-space hit positions
- ``normals_w``: [B, N, 3] Surface normals at hit points (world frame)
- ``pos_w``: [B, 3] Sensor frame position
- ``quat_w``: [B, 4] Sensor frame orientation (w, x, y, z)

Where B = number of environments, N = number of rays.
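
Because misses are reported with a distance of -1, downstream code usually
masks on ``distances >= 0`` before using the hit data. The sketch below shows
that pattern for a single environment with plain Python lists. The function
name ``hit_positions`` is hypothetical, and leaving misses at the ray origin
is an assumption based on how ``_perform_raycast`` initializes ``hit_pos_w``.

```python
def hit_positions(origins, directions, distances):
    # origins, directions: lists of (x, y, z); distances: list of floats.
    out = []
    for origin, direction, dist in zip(origins, directions, distances):
        if dist >= 0:
            # Hit: origin + direction * distance.
            out.append(tuple(o + d * dist for o, d in zip(origin, direction)))
        else:
            # Miss (-1 sentinel): leave the point at the ray origin.
            out.append(tuple(origin))
    return out


origins = [(0.0, 0.0, 1.0), (0.5, 0.0, 1.0)]
directions = [(0.0, 0.0, -1.0), (0.0, 0.0, -1.0)]
distances = [1.0, -1.0]

hits = hit_positions(origins, directions, distances)
hit_mask = [d >= 0 for d in distances]
print(hits, hit_mask)
```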
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal

import mujoco
import mujoco_warp as mjwarp
import torch
import warp as wp
from mujoco_warp import rays

from mjlab.entity import Entity
from mjlab.sensor.builtin_sensor import ObjRef
from mjlab.sensor.sensor import Sensor, SensorCfg
from mjlab.utils.lab_api.math import quat_from_matrix

if TYPE_CHECKING:
  from mjlab.viewer.debug_visualizer import DebugVisualizer


# NOTE: Need to define this here because it's not publicly exposed by mujoco_warp.
vec6 = wp.types.vector(length=6, dtype=float)

# Type aliases for configuration choices.
RayAlignment = Literal["base", "yaw", "world"]


@dataclass
class GridPatternCfg:
  """Grid pattern - parallel rays in a 2D grid."""

  size: tuple[float, float] = (1.0, 1.0)
  """Grid size (length, width) in meters."""
  resolution: float = 0.1
  """Spacing between rays in meters."""
  direction: tuple[float, float, float] = (0.0, 0.0, -1.0)
  """Ray direction in frame-local coordinates."""

  def generate_rays(
    self, mj_model: mujoco.MjModel | None, device: str
  ) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate ray pattern.

    Args:
      mj_model: MuJoCo model (unused for grid pattern).
      device: Device for tensor operations.

    Returns:
      Tuple of (local_offsets [N, 3], local_directions [N, 3]).
    """
    del mj_model  # Unused for grid pattern
    size_x, size_y = self.size
    res = self.resolution
    x = torch.arange(
      -size_x / 2, size_x / 2 + res * 0.5, res, device=device, dtype=torch.float32
    )
    y = torch.arange(
      -size_y / 2, size_y / 2 + res * 0.5, res, device=device, dtype=torch.float32
    )
    grid_x, grid_y = torch.meshgrid(x, y, indexing="xy")
    num_rays = grid_x.numel()
    local_offsets = torch.zeros((num_rays, 3), device=device, dtype=torch.float32)
    local_offsets[:, 0] = grid_x.flatten()
    local_offsets[:, 1] = grid_y.flatten()

    # All rays share the same direction for grid pattern.
    direction = torch.tensor(self.direction, device=device, dtype=torch.float32)
    direction = direction / direction.norm()
    local_directions = direction.unsqueeze(0).expand(num_rays, 3).clone()
    return local_offsets, local_directions


@dataclass
class PinholeCameraPatternCfg:
  """Pinhole camera pattern - rays diverging from origin like a camera.

  Can be configured with explicit parameters (width, height, fovy) or created
  via factory methods like from_mujoco_camera() or from_intrinsic_matrix().
  """

  width: int = 16
  """Image width in pixels."""
  height: int = 12
  """Image height in pixels."""
  fovy: float = 45.0
  """Vertical field of view in degrees (matches MuJoCo convention)."""
  _camera_name: str | None = field(default=None, repr=False)
  """Internal: MuJoCo camera name for deferred parameter resolution."""

  @classmethod
  def from_mujoco_camera(cls, camera_name: str) -> PinholeCameraPatternCfg:
    """Create config that references a MuJoCo camera.

    Camera parameters (resolution, FOV) are resolved at runtime from the model.

    Args:
      camera_name: Name of the MuJoCo camera to reference.

    Returns:
      Config that will resolve parameters from the MuJoCo camera.
    """
    # Placeholder values; actual values resolved in generate_rays().
    return cls(width=0, height=0, fovy=0.0, _camera_name=camera_name)

  @classmethod
  def from_intrinsic_matrix(
    cls, intrinsic_matrix: list[float], width: int, height: int
  ) -> PinholeCameraPatternCfg:
    """Create from 3x3 intrinsic matrix [fx, 0, cx, 0, fy, cy, 0, 0, 1].

    Args:
      intrinsic_matrix: Flattened 3x3 intrinsic matrix.
      width: Image width in pixels.
      height: Image height in pixels.

    Returns:
      Config with fovy computed from the intrinsic matrix.
    """
    fy = intrinsic_matrix[4]  # fy is at position [1,1] in the matrix
    fovy = 2 * math.atan(height / (2 * fy)) * 180 / math.pi
    return cls(width=width, height=height, fovy=fovy)

  def generate_rays(
    self, mj_model: mujoco.MjModel | None, device: str
  ) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate ray pattern.

    Args:
      mj_model: MuJoCo model (required if using from_mujoco_camera).
      device: Device for tensor operations.

    Returns:
      Tuple of (local_offsets [N, 3], local_directions [N, 3]).
    """
    # Resolve camera parameters.
    if self._camera_name is not None:
      if mj_model is None:
        raise ValueError("MuJoCo model required when using from_mujoco_camera()")
      # Get parameters from MuJoCo camera.
      cam_id = mj_model.camera(self._camera_name).id
      width, height = mj_model.cam_resolution[cam_id]

      # MuJoCo has two camera modes:
      # 1. fovy mode: sensorsize is zero, use cam_fovy directly
      # 2. Physical sensor mode: sensorsize > 0, compute from focal/sensorsize
      sensorsize = mj_model.cam_sensorsize[cam_id]
      if sensorsize[0] > 0 and sensorsize[1] > 0:
        # Physical sensor model.
        intrinsic = mj_model.cam_intrinsic[cam_id]  # [fx, fy, cx, cy]
        focal = intrinsic[:2]  # [fx, fy]
        h_fov_rad = 2 * math.atan(sensorsize[0] / (2 * focal[0]))
        v_fov_rad = 2 * math.atan(sensorsize[1] / (2 * focal[1]))
      else:
        # Read vertical FOV directly from MuJoCo.
        v_fov_rad = math.radians(mj_model.cam_fovy[cam_id])
        aspect = width / height
        h_fov_rad = 2 * math.atan(math.tan(v_fov_rad / 2) * aspect)
    else:
      # Use explicit parameters.
      width = self.width
      height = self.height
      v_fov_rad = math.radians(self.fovy)
      aspect = width / height
      h_fov_rad = 2 * math.atan(math.tan(v_fov_rad / 2) * aspect)

    # Create normalized pixel coordinates [-1, 1].
    u = torch.linspace(-1, 1, width, device=device, dtype=torch.float32)
    v = torch.linspace(-1, 1, height, device=device, dtype=torch.float32)
    grid_u, grid_v = torch.meshgrid(u, v, indexing="xy")

    # Convert to ray directions (MuJoCo camera: -Z forward, +X right, +Y down).
    ray_x = grid_u.flatten() * math.tan(h_fov_rad / 2)
    ray_y = grid_v.flatten() * math.tan(v_fov_rad / 2)
    ray_z = -torch.ones_like(ray_x)  # Negative Z for MuJoCo camera forward
    num_rays = width * height
    local_offsets = torch.zeros((num_rays, 3), device=device)
    local_directions = torch.stack([ray_x, ray_y, ray_z], dim=1)
    local_directions = local_directions / local_directions.norm(dim=1, keepdim=True)
    return local_offsets, local_directions


PatternCfg = GridPatternCfg | PinholeCameraPatternCfg


@dataclass
class RayCastData:
  """Raycast sensor output data."""

  distances: torch.Tensor
  """[B, N] Distance to hit point. -1 if no hit."""
  normals_w: torch.Tensor
  """[B, N, 3] Surface normal at hit point (world frame). Zero if no hit."""
  hit_pos_w: torch.Tensor
  """[B, N, 3] Hit position in world frame. Equals the ray origin if no hit."""
  pos_w: torch.Tensor
  """[B, 3] Frame position in world coordinates."""
  quat_w: torch.Tensor
  """[B, 4] Frame orientation quaternion (w, x, y, z) in world coordinates."""


@dataclass
class RayCastSensorCfg(SensorCfg):
  """Raycast sensor configuration.

  Supports multiple ray patterns (grid, pinhole camera) and alignment modes.
  """

  @dataclass
  class VizCfg:
    """Visualization settings for debug rendering."""

    hit_color: tuple[float, float, float, float] = (0.0, 1.0, 0.0, 0.8)
    """RGBA color for rays that hit a surface."""
    miss_color: tuple[float, float, float, float] = (1.0, 0.0, 0.0, 0.4)
    """RGBA color for rays that miss."""
    hit_sphere_color: tuple[float, float, float, float] = (0.0, 1.0, 1.0, 1.0)
    """RGBA color for spheres drawn at hit points."""
    hit_sphere_radius: float = 0.5
    """Radius of spheres drawn at hit points (multiplier of meansize)."""
    show_rays: bool = False
    """Whether to draw ray arrows."""
    show_normals: bool = False
    """Whether to draw surface normals at hit points."""
    normal_color: tuple[float, float, float, float] = (1.0, 1.0, 0.0, 1.0)
    """RGBA color for surface normal arrows."""
    normal_length: float = 5.0
    """Length of surface normal arrows (multiplier of meansize)."""

  frame: ObjRef
  """Body, site, or geom to attach rays to."""
  pattern: PatternCfg = field(default_factory=GridPatternCfg)
  """Ray pattern configuration. Defaults to GridPatternCfg."""
  ray_alignment: RayAlignment = "base"
  """How rays align with the frame.

  - "base": Full position + rotation (default).
  - "yaw": Position + yaw only, ignores pitch/roll (good for height maps).
  - "world": Fixed in world frame, position only follows body.
  """
  max_distance: float = 10.0
  """Maximum ray distance. Rays beyond this report -1."""
  exclude_parent_body: bool = True
  """Exclude parent body from ray intersection tests."""
  include_geom_groups: tuple[int, ...] | None = None
  """Geom groups (0-5) to include in raycasting. None means all groups."""
  debug_vis: bool = False
  """Enable debug visualization."""
  viz: VizCfg = field(default_factory=VizCfg)
  """Visualization settings."""

  def build(self) -> RayCastSensor:
    return RayCastSensor(self)


class RayCastSensor(Sensor[RayCastData]):
  """Raycast sensor for terrain and obstacle detection."""

  def __init__(self, cfg: RayCastSensorCfg) -> None:
    super().__init__()
    self.cfg = cfg
    self._data: mjwarp.Data | None = None
    self._model: mjwarp.Model | None = None
    self._mj_model: mujoco.MjModel | None = None
    self._device: str | None = None
    self._wp_device: wp.context.Device | None = None

    self._frame_body_id: int | None = None
    self._frame_site_id: int | None = None
    self._frame_geom_id: int | None = None
    self._frame_type: Literal["body", "site", "geom"] = "body"

    self._local_offsets: torch.Tensor | None = None
    self._local_directions: torch.Tensor | None = None  # [N, 3] per-ray directions
    self._num_rays: int = 0

    self._ray_pnt: wp.array | None = None
    self._ray_vec: wp.array | None = None
    self._ray_dist: wp.array | None = None
    self._ray_geomid: wp.array | None = None
    self._ray_normal: wp.array | None = None
    self._ray_bodyexclude: wp.array | None = None
    self._geomgroup = vec6(-1, -1, -1, -1, -1, -1)

    self._distances: torch.Tensor | None = None
    self._normals_w: torch.Tensor | None = None
    self._hit_pos_w: torch.Tensor | None = None
    self._pos_w: torch.Tensor | None = None
    self._quat_w: torch.Tensor | None = None

    self._raycast_graph: wp.Graph | None = None
    self._use_cuda_graph: bool = False

  def edit_spec(
    self,
    scene_spec: mujoco.MjSpec,
    entities: dict[str, Entity],
  ) -> None:
    del scene_spec, entities

  def initialize(
    self,
    mj_model: mujoco.MjModel,
    model: mjwarp.Model,
    data: mjwarp.Data,
    device: str,
  ) -> None:
    self._data = data
    self._model = model
    self._mj_model = mj_model
    self._device = device
    self._wp_device = wp.get_device(device)
    num_envs = data.nworld

    frame = self.cfg.frame
    frame_name = frame.prefixed_name()

    if frame.type == "body":
      self._frame_body_id = mj_model.body(frame_name).id
      self._frame_type = "body"
    elif frame.type == "site":
      self._frame_site_id = mj_model.site(frame_name).id
      # Look up parent body for exclusion.
      self._frame_body_id = int(mj_model.site_bodyid[self._frame_site_id])
      self._frame_type = "site"
    elif frame.type == "geom":
      self._frame_geom_id = mj_model.geom(frame_name).id
      # Look up parent body for exclusion.
      self._frame_body_id = int(mj_model.geom_bodyid[self._frame_geom_id])
      self._frame_type = "geom"
    else:
      raise ValueError(
        f"RayCastSensor frame must be 'body', 'site', or 'geom', got '{frame.type}'"
      )

    # Generate ray pattern.
    pattern = self.cfg.pattern
    self._local_offsets, self._local_directions = pattern.generate_rays(
      mj_model, device
    )
    self._num_rays = self._local_offsets.shape[0]

    self._ray_pnt = wp.zeros((num_envs, self._num_rays), dtype=wp.vec3, device=device)
    self._ray_vec = wp.zeros((num_envs, self._num_rays), dtype=wp.vec3, device=device)
    self._ray_dist = wp.zeros((num_envs, self._num_rays), dtype=float, device=device)
    self._ray_geomid = wp.zeros((num_envs, self._num_rays), dtype=int, device=device)
    self._ray_normal = wp.zeros(
      (num_envs, self._num_rays), dtype=wp.vec3, device=device
    )

    body_exclude = (
      self._frame_body_id
      if self.cfg.exclude_parent_body and self._frame_body_id is not None
      else -1
    )
    self._ray_bodyexclude = wp.full(
      (self._num_rays,),
      body_exclude,
      dtype=int,  # pyright: ignore[reportArgumentType]
      device=device,
    )

    # Convert include_geom_groups to vec6 format (-1 = include, 0 = exclude).
    if self.cfg.include_geom_groups is not None:
      groups = [0, 0, 0, 0, 0, 0]
      for g in self.cfg.include_geom_groups:
        if 0 <= g <= 5:
          groups[g] = -1
      self._geomgroup = vec6(*groups)
    else:
      self._geomgroup = vec6(-1, -1, -1, -1, -1, -1)  # All groups

    assert self._wp_device is not None
    self._use_cuda_graph = self._wp_device.is_cuda and wp.is_mempool_enabled(
      self._wp_device
    )
    if self._use_cuda_graph:
      self._create_graph()

  def _compute_data(self) -> RayCastData:
    self._perform_raycast()
    assert self._distances is not None and self._normals_w is not None
    assert self._hit_pos_w is not None
    assert self._pos_w is not None and self._quat_w is not None
    return RayCastData(
      distances=self._distances,
      normals_w=self._normals_w,
      hit_pos_w=self._hit_pos_w,
      pos_w=self._pos_w,
      quat_w=self._quat_w,
    )

  @property
  def num_rays(self) -> int:
    return self._num_rays

  def debug_vis(self, visualizer: DebugVisualizer) -> None:
    if not self.cfg.debug_vis:
      return
    assert self._data is not None
    assert self._local_offsets is not None
    assert self._local_directions is not None

    env_idx = visualizer.env_idx
    data = self.data

    if self._frame_type == "body":
      frame_pos = self._data.xpos[env_idx, self._frame_body_id].cpu().numpy()
      frame_mat_tensor = self._data.xmat[env_idx, self._frame_body_id].view(3, 3)
    elif self._frame_type == "site":
      frame_pos = self._data.site_xpos[env_idx, self._frame_site_id].cpu().numpy()
      frame_mat_tensor = self._data.site_xmat[env_idx, self._frame_site_id].view(3, 3)
    else:  # geom
      frame_pos = self._data.geom_xpos[env_idx, self._frame_geom_id].cpu().numpy()
      frame_mat_tensor = self._data.geom_xmat[env_idx, self._frame_geom_id].view(3, 3)

    # Apply ray alignment for visualization.
    rot_mat_tensor = self._compute_alignment_rotation(frame_mat_tensor.unsqueeze(0))[0]
    rot_mat = rot_mat_tensor.cpu().numpy()

    local_offsets_np = self._local_offsets.cpu().numpy()
    local_dirs_np = self._local_directions.cpu().numpy()
    hit_positions_np = data.hit_pos_w[env_idx].cpu().numpy()
    distances_np = data.distances[env_idx].cpu().numpy()
    normals_np = data.normals_w[env_idx].cpu().numpy()

    meansize = visualizer.meansize
    ray_width = 0.1 * meansize
    sphere_radius = self.cfg.viz.hit_sphere_radius * meansize
    normal_length = self.cfg.viz.normal_length * meansize
    normal_width = 0.1 * meansize

    for i in range(self._num_rays):
      origin = frame_pos + rot_mat @ local_offsets_np[i]
      hit = distances_np[i] >= 0

      if hit:
        end = hit_positions_np[i]
        color = self.cfg.viz.hit_color
      else:
        direction = rot_mat @ local_dirs_np[i]
        end = origin + direction * min(0.5, self.cfg.max_distance * 0.05)
        color = self.cfg.viz.miss_color

      if self.cfg.viz.show_rays:
        visualizer.add_arrow(
          start=origin,
          end=end,
          color=color,
          width=ray_width,
          label=f"{self.cfg.name}_ray_{i}",
        )

      if hit:
        visualizer.add_sphere(
          center=end,
          radius=sphere_radius,
          color=self.cfg.viz.hit_sphere_color,
          label=f"{self.cfg.name}_hit_{i}",
        )
        if self.cfg.viz.show_normals:
          normal_end = end + normals_np[i] * normal_length
          visualizer.add_arrow(
            start=end,
            end=normal_end,
            color=self.cfg.viz.normal_color,
            width=normal_width,
            label=f"{self.cfg.name}_normal_{i}",
          )

  # Private methods.

  def _create_graph(self) -> None:
    """Capture CUDA graph for raycast operation."""
    assert self._wp_device is not None and self._wp_device.is_cuda
    with wp.ScopedDevice(self._wp_device):
      with wp.ScopedCapture() as capture:
        self._raycast_direct()
      self._raycast_graph = capture.graph

  def _raycast_direct(self) -> None:
    """Execute raycast kernel directly."""
    rays(
      m=self._model.struct,  # type: ignore[attr-defined]
      d=self._data.struct,  # type: ignore[attr-defined]
      pnt=self._ray_pnt,
      vec=self._ray_vec,
      geomgroup=self._geomgroup,  # pyright: ignore[reportArgumentType]
      flg_static=True,
      bodyexclude=self._ray_bodyexclude,
      dist=self._ray_dist,
      geomid=self._ray_geomid,
      normal=self._ray_normal,
    )

  def _perform_raycast(self) -> None:
    assert self._data is not None and self._model is not None
    assert self._local_offsets is not None and self._local_directions is not None

    if self._frame_type == "body":
      frame_pos = self._data.xpos[:, self._frame_body_id]
      frame_mat = self._data.xmat[:, self._frame_body_id].view(-1, 3, 3)
    elif self._frame_type == "site":
      frame_pos = self._data.site_xpos[:, self._frame_site_id]
      frame_mat = self._data.site_xmat[:, self._frame_site_id].view(-1, 3, 3)
    else:  # geom
      frame_pos = self._data.geom_xpos[:, self._frame_geom_id]
      frame_mat = self._data.geom_xmat[:, self._frame_geom_id].view(-1, 3, 3)

    num_envs = frame_pos.shape[0]

    # Apply ray alignment.
    rot_mat = self._compute_alignment_rotation(frame_mat)

    # Transform ray origins.
    world_offsets = torch.einsum("bij,nj->bni", rot_mat, self._local_offsets)
    world_origins = frame_pos.unsqueeze(1) + world_offsets

    # Transform ray directions (per-ray).
    world_rays = torch.einsum("bij,nj->bni", rot_mat, self._local_directions)

    assert self._ray_pnt is not None and self._ray_vec is not None
    pnt_torch = wp.to_torch(self._ray_pnt).view(num_envs, self._num_rays, 3)
    vec_torch = wp.to_torch(self._ray_vec).view(num_envs, self._num_rays, 3)
    pnt_torch.copy_(world_origins)
    vec_torch.copy_(world_rays)

    if self._use_cuda_graph and self._raycast_graph is not None:
      with wp.ScopedDevice(self._wp_device):
        wp.capture_launch(self._raycast_graph)
    else:
      self._raycast_direct()

    assert self._ray_dist is not None and self._ray_normal is not None
    self._distances = wp.to_torch(self._ray_dist)
    self._normals_w = wp.to_torch(self._ray_normal).view(num_envs, self._num_rays, 3)
    self._distances[self._distances > self.cfg.max_distance] = -1.0

    # Compute hit positions: origin + direction * distance.
    # For misses (distance = -1), hit_pos_w will be invalid (but normals_w are zero).
    assert self._distances is not None
    hit_mask = self._distances >= 0
    hit_pos_w = world_origins.clone()
    hit_pos_w[hit_mask] = world_origins[hit_mask] + world_rays[
      hit_mask
    ] * self._distances[hit_mask].unsqueeze(-1)
    self._hit_pos_w = hit_pos_w

    self._pos_w = frame_pos.clone()
    self._quat_w = quat_from_matrix(frame_mat)

  def _compute_alignment_rotation(self, frame_mat: torch.Tensor) -> torch.Tensor:
    """Compute rotation matrix based on ray_alignment setting."""
    if self.cfg.ray_alignment == "base":
      # Full rotation.
      return frame_mat
    elif self.cfg.ray_alignment == "yaw":
      # Extract yaw only, zero out pitch/roll.
      return self._extract_yaw_rotation(frame_mat)
    elif self.cfg.ray_alignment == "world":
      # Identity rotation (world-aligned).
      num_envs = frame_mat.shape[0]
      return (
        torch.eye(3, device=frame_mat.device, dtype=frame_mat.dtype)
        .unsqueeze(0)
        .expand(num_envs, -1, -1)
      )
    else:
      raise ValueError(f"Unknown ray_alignment: {self.cfg.ray_alignment}")

  def _extract_yaw_rotation(self, rot_mat: torch.Tensor) -> torch.Tensor:
    """Extract yaw-only rotation matrix (rotation around Z axis).

    Handles the singularity at ±90° pitch by falling back to the Y-axis
    when the X-axis projection onto the XY plane is too small.
    """
    batch_size = rot_mat.shape[0]
    device = rot_mat.device
    dtype = rot_mat.dtype

    # Project X-axis onto XY plane.
    x_axis = rot_mat[:, :, 0]  # First column [B, 3]
    x_proj = x_axis.clone()
    x_proj[:, 2] = 0  # Zero out Z component
    x_norm = x_proj.norm(dim=1)  # [B]

    # Check for singularity (X-axis nearly vertical).
    threshold = 0.1
    singular = x_norm < threshold  # [B]

    # For singular cases, use Y-axis instead.
    if singular.any():
      y_axis = rot_mat[:, :, 1]  # Second column [B, 3]
      y_proj = y_axis.clone()
      y_proj[:, 2] = 0
      y_norm = y_proj.norm(dim=1).clamp(min=1e-6)
      y_proj = y_proj / y_norm.unsqueeze(-1)
      # Y-axis points left; rotate -90° around Z to get forward direction.
      # [y_x, y_y] -> [y_y, -y_x]
      x_from_y = torch.zeros_like(y_proj)
      x_from_y[:, 0] = y_proj[:, 1]
      x_from_y[:, 1] = -y_proj[:, 0]
      x_proj[singular] = x_from_y[singular]
      x_norm[singular] = 1.0  # Already normalized

    # Normalize X projection.
    x_norm = x_norm.clamp(min=1e-6)
    x_proj = x_proj / x_norm.unsqueeze(-1)

    # Build yaw-only rotation matrix.
    yaw_mat = torch.zeros((batch_size, 3, 3), device=device, dtype=dtype)
    yaw_mat[:, 0, 0] = x_proj[:, 0]
    yaw_mat[:, 1, 0] = x_proj[:, 1]
    yaw_mat[:, 0, 1] = -x_proj[:, 1]
    yaw_mat[:, 1, 1] = x_proj[:, 0]
    yaw_mat[:, 2, 2] = 1
    return yaw_mat