# Source code for mjlab.scene.scene

import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable

import mujoco
import mujoco_warp as mjwarp
import numpy as np
import torch

from mjlab.entity import Entity, EntityCfg
from mjlab.sensor import BuiltinSensor, Sensor, SensorCfg
from mjlab.terrains.terrain_importer import TerrainImporter, TerrainImporterCfg

_SCENE_XML = Path(__file__).parent / "scene.xml"


[docs] @dataclass(kw_only=True) class SceneCfg: num_envs: int = 1 env_spacing: float = 2.0 terrain: TerrainImporterCfg | None = None entities: dict[str, EntityCfg] = field(default_factory=dict) sensors: tuple[SensorCfg, ...] = field(default_factory=tuple) extent: float | None = None spec_fn: Callable[[mujoco.MjSpec], None] | None = None
class Scene:
    """Assembles a single MuJoCo spec from a terrain, entities and sensors.

    The spec is built in ``__init__`` from a base scene XML plus whatever the
    ``SceneCfg`` declares. Runtime state (entity/sensor buffers and the default
    environment origins) is created later by ``initialize``.

    Scene elements can be looked up with ``scene[name]``; see ``__getitem__``
    for the lookup order.
    """

    def __init__(self, scene_cfg: SceneCfg, device: str) -> None:
        """Build the scene spec described by ``scene_cfg``.

        Args:
            scene_cfg: Scene configuration.
            device: Device string handed to the terrain importer and, later, to
                entities and sensors on ``initialize``.
        """
        self._cfg = scene_cfg
        self._device = device
        self._entities: dict[str, Entity] = {}
        self._sensors: dict[str, Sensor] = {}
        self._terrain: TerrainImporter | None = None
        # Filled in by initialize(); stays None until then.
        self._default_env_origins: torch.Tensor | None = None

        self._spec = mujoco.MjSpec.from_file(str(_SCENE_XML))
        if self._cfg.extent is not None:
            self._spec.stat.extent = self._cfg.extent

        # Order matters: sensors receive the already-built entities.
        self._add_terrain()
        self._add_entities()
        self._add_sensors()

        # User hook runs last so it sees the fully assembled spec.
        if self._cfg.spec_fn is not None:
            self._cfg.spec_fn(self._spec)

    def compile(self) -> mujoco.MjModel:
        """Compile the scene spec and return the resulting model."""
        return self._spec.compile()

    def to_zip(self, path: Path | str) -> None:
        """Export the scene to a zip file.

        Warning: The generated zip may require manual adjustment of asset paths
        to be reloadable. Specifically, you may need to add assetdir="assets" to
        the compiler directive in the XML.

        Args:
            path: Output path for the zip file. A plain string is accepted and
                converted to a ``Path``.

        TODO: Verify if this is fixed in future MuJoCo releases.
        """
        with Path(path).open("wb") as f:
            mujoco.MjSpec.to_zip(self._spec, f)

    # Attributes.

    @property
    def spec(self) -> mujoco.MjSpec:
        """The scene's MuJoCo spec."""
        return self._spec

    @property
    def env_origins(self) -> torch.Tensor:
        """Environment origins.

        Returns the terrain's origins when a terrain is configured, otherwise
        the default origins allocated by ``initialize``.

        Raises:
            RuntimeError: If a terrain is configured but exposes no origins, or
                if no terrain is configured and ``initialize`` has not been
                called yet.
        """
        # Explicit raises instead of `assert`: asserts vanish under `python -O`
        # and the property would then silently return None.
        if self._terrain is not None:
            origins = self._terrain.env_origins
            if origins is None:
                raise RuntimeError(
                    "Terrain is configured but its env_origins are not set."
                )
            return origins
        if self._default_env_origins is None:
            raise RuntimeError(
                "Default env origins are not set; call Scene.initialize() first."
            )
        return self._default_env_origins

    @property
    def env_spacing(self) -> float:
        """Environment spacing taken from the scene config."""
        return self._cfg.env_spacing

    @property
    def entities(self) -> dict[str, Entity]:
        """Built entities keyed by name."""
        return self._entities

    @property
    def sensors(self) -> dict[str, Sensor]:
        """Registered sensors keyed by name."""
        return self._sensors

    @property
    def terrain(self) -> TerrainImporter | None:
        """The terrain importer, or ``None`` when no terrain is configured."""
        return self._terrain

    @property
    def num_envs(self) -> int:
        """Number of environments taken from the scene config."""
        return self._cfg.num_envs

    @property
    def device(self) -> str:
        """Device string this scene was created with."""
        return self._device

    def __getitem__(self, key: str) -> Any:
        """Look up a scene element by name.

        The key ``"terrain"`` is reserved for the terrain. Any other key is
        looked up among sensors first, then among entities.

        Raises:
            KeyError: If ``key`` is ``"terrain"`` and no terrain is configured,
                or if no sensor or entity has that name.
        """
        if key == "terrain":
            if self._terrain is None:
                raise KeyError("No terrain configured in this scene.")
            return self._terrain

        if key in self._sensors:
            return self._sensors[key]
        if key in self._entities:
            return self._entities[key]

        # Not found, raise helpful error.
        available = list(self._entities.keys()) + list(self._sensors.keys())
        if self._terrain is not None:
            available.append("terrain")
        raise KeyError(f"Scene element '{key}' not found. Available: {available}")

    # Methods.

    def initialize(
        self,
        mj_model: mujoco.MjModel,
        model: mjwarp.Model,
        data: mjwarp.Data,
    ) -> None:
        """Allocate default env origins and initialize entities and sensors.

        The default origins are a zero-filled float32 tensor of shape
        ``(num_envs, 3)`` on this scene's device. Entities are initialized
        before sensors; each receives the three arguments plus the device.
        """
        self._default_env_origins = torch.zeros(
            (self._cfg.num_envs, 3), device=self._device, dtype=torch.float32
        )
        for ent in self._entities.values():
            ent.initialize(mj_model, model, data, self._device)
        for sensor in self._sensors.values():
            sensor.initialize(mj_model, model, data, self._device)

    def reset(self, env_ids: torch.Tensor | slice | None = None) -> None:
        """Reset every entity, then every sensor, forwarding ``env_ids`` as-is."""
        for ent in self._entities.values():
            ent.reset(env_ids)
        for sensor in self._sensors.values():
            sensor.reset(env_ids)

    def update(self, dt: float) -> None:
        """Update every entity, then every sensor, with ``dt``."""
        for ent in self._entities.values():
            ent.update(dt)
        for sensor in self._sensors.values():
            sensor.update(dt)

    def write_data_to_sim(self) -> None:
        """Call ``write_data_to_sim`` on every entity (sensors are not touched)."""
        for ent in self._entities.values():
            ent.write_data_to_sim()

    # Private methods.

    def _add_entities(self) -> None:
        """Build each configured entity and attach its spec to the scene.

        If any entity spec carries keyframes, the first keyframe of each such
        entity is merged into one scene keyframe named ``"init_state"``.
        """
        # Collect keyframes from each entity to merge into a single scene keyframe.
        # Order matters: qpos/ctrl are concatenated in entity iteration order.
        key_qpos: list[np.ndarray] = []
        key_ctrl: list[np.ndarray] = []
        for ent_name, ent_cfg in self._cfg.entities.items():
            ent = ent_cfg.build()
            self._entities[ent_name] = ent
            # Extract keyframe before attach (must delete before attach to avoid
            # corruption).
            if ent.spec.keys:
                if len(ent.spec.keys) > 1:
                    warnings.warn(
                        f"Entity '{ent_name}' has {len(ent.spec.keys)} keyframes; "
                        "only the first one will be used.",
                        stacklevel=2,
                    )
                # np.array copies the values before the key is deleted.
                key_qpos.append(np.array(ent.spec.keys[0].qpos))
                key_ctrl.append(np.array(ent.spec.keys[0].ctrl))
                ent.spec.delete(ent.spec.keys[0])
            frame = self._spec.worldbody.add_frame()
            self._spec.attach(ent.spec, prefix=f"{ent_name}/", frame=frame)

        # Add merged keyframe to scene spec.
        if key_qpos:
            combined_qpos = np.concatenate(key_qpos)
            combined_ctrl = np.concatenate(key_ctrl)
            self._spec.add_key(
                name="init_state", qpos=combined_qpos, ctrl=combined_ctrl
            )

    def _add_terrain(self) -> None:
        """Create the terrain importer and attach its spec; no-op without terrain.

        Note: this overwrites ``num_envs`` and ``env_spacing`` on the terrain
        config with the scene's values.
        """
        if self._cfg.terrain is None:
            return
        self._cfg.terrain.num_envs = self._cfg.num_envs
        self._cfg.terrain.env_spacing = self._cfg.env_spacing
        self._terrain = TerrainImporter(self._cfg.terrain, self._device)
        frame = self._spec.worldbody.add_frame()
        self._spec.attach(self._terrain.spec, prefix="", frame=frame)

    def _add_sensors(self) -> None:
        """Build configured sensors and register any remaining spec sensors.

        Configured sensors edit the scene spec and are stored under their
        config's name. Afterwards, every sensor in the spec whose name is not
        yet registered is wrapped with ``BuiltinSensor.from_existing``.
        """
        for sensor_cfg in self._cfg.sensors:
            sns = sensor_cfg.build()
            sns.edit_spec(self._spec, self._entities)
            self._sensors[sensor_cfg.name] = sns

        for spec_sensor in self._spec.sensors:
            if spec_sensor.name not in self._sensors:
                self._sensors[spec_sensor.name] = BuiltinSensor.from_existing(
                    spec_sensor.name
                )