import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
import mujoco
import mujoco_warp as mjwarp
import numpy as np
import torch
from mjlab.entity import Entity, EntityCfg
from mjlab.sensor import BuiltinSensor, Sensor, SensorCfg
from mjlab.terrains.terrain_importer import TerrainImporter, TerrainImporterCfg
_SCENE_XML = Path(__file__).parent / "scene.xml"
[docs]
@dataclass(kw_only=True)
class SceneCfg:
num_envs: int = 1
env_spacing: float = 2.0
terrain: TerrainImporterCfg | None = None
entities: dict[str, EntityCfg] = field(default_factory=dict)
sensors: tuple[SensorCfg, ...] = field(default_factory=tuple)
extent: float | None = None
spec_fn: Callable[[mujoco.MjSpec], None] | None = None
[docs]
class Scene:
[docs]
def __init__(self, scene_cfg: SceneCfg, device: str) -> None:
self._cfg = scene_cfg
self._device = device
self._entities: dict[str, Entity] = {}
self._sensors: dict[str, Sensor] = {}
self._terrain: TerrainImporter | None = None
self._default_env_origins: torch.Tensor | None = None
self._spec = mujoco.MjSpec.from_file(str(_SCENE_XML))
if self._cfg.extent is not None:
self._spec.stat.extent = self._cfg.extent
self._add_terrain()
self._add_entities()
self._add_sensors()
if self._cfg.spec_fn is not None:
self._cfg.spec_fn(self._spec)
[docs]
def compile(self) -> mujoco.MjModel:
return self._spec.compile()
[docs]
def to_zip(self, path: Path) -> None:
"""Export the scene to a zip file.
Warning: The generated zip may require manual adjustment of asset paths
to be reloadable. Specifically, you may need to add assetdir="assets"
to the compiler directive in the XML.
Args:
path: Output path for the zip file.
TODO: Verify if this is fixed in future MuJoCo releases.
"""
with path.open("wb") as f:
mujoco.MjSpec.to_zip(self._spec, f)
# Attributes.
@property
def spec(self) -> mujoco.MjSpec:
return self._spec
@property
def env_origins(self) -> torch.Tensor:
if self._terrain is not None:
assert self._terrain.env_origins is not None
return self._terrain.env_origins
assert self._default_env_origins is not None
return self._default_env_origins
@property
def env_spacing(self) -> float:
return self._cfg.env_spacing
@property
def entities(self) -> dict[str, Entity]:
return self._entities
@property
def sensors(self) -> dict[str, Sensor]:
return self._sensors
@property
def terrain(self) -> TerrainImporter | None:
return self._terrain
@property
def num_envs(self) -> int:
return self._cfg.num_envs
@property
def device(self) -> str:
return self._device
def __getitem__(self, key: str) -> Any:
if key == "terrain":
if self._terrain is None:
raise KeyError("No terrain configured in this scene.")
return self._terrain
if key in self._sensors:
return self._sensors[key]
if key in self._entities:
return self._entities[key]
# Not found, raise helpful error.
available = list(self._entities.keys()) + list(self._sensors.keys())
if self._terrain is not None:
available.append("terrain")
raise KeyError(f"Scene element '{key}' not found. Available: {available}")
# Methods.
[docs]
def initialize(
self,
mj_model: mujoco.MjModel,
model: mjwarp.Model,
data: mjwarp.Data,
):
self._default_env_origins = torch.zeros(
(self._cfg.num_envs, 3), device=self._device, dtype=torch.float32
)
for ent in self._entities.values():
ent.initialize(mj_model, model, data, self._device)
for sensor in self._sensors.values():
sensor.initialize(mj_model, model, data, self._device)
[docs]
def reset(self, env_ids: torch.Tensor | slice | None = None) -> None:
for ent in self._entities.values():
ent.reset(env_ids)
for sensor in self._sensors.values():
sensor.reset(env_ids)
[docs]
def update(self, dt: float) -> None:
for ent in self._entities.values():
ent.update(dt)
for sensor in self._sensors.values():
sensor.update(dt)
[docs]
def write_data_to_sim(self) -> None:
for ent in self._entities.values():
ent.write_data_to_sim()
# Private methods.
def _add_entities(self) -> None:
# Collect keyframes from each entity to merge into a single scene keyframe.
# Order matters: qpos/ctrl are concatenated in entity iteration order.
key_qpos: list[np.ndarray] = []
key_ctrl: list[np.ndarray] = []
for ent_name, ent_cfg in self._cfg.entities.items():
ent = ent_cfg.build()
self._entities[ent_name] = ent
# Extract keyframe before attach (must delete before attach to avoid corruption).
if ent.spec.keys:
if len(ent.spec.keys) > 1:
warnings.warn(
f"Entity '{ent_name}' has {len(ent.spec.keys)} keyframes; only the "
"first one will be used.",
stacklevel=2,
)
key_qpos.append(np.array(ent.spec.keys[0].qpos))
key_ctrl.append(np.array(ent.spec.keys[0].ctrl))
ent.spec.delete(ent.spec.keys[0])
frame = self._spec.worldbody.add_frame()
self._spec.attach(ent.spec, prefix=f"{ent_name}/", frame=frame)
# Add merged keyframe to scene spec.
if key_qpos:
combined_qpos = np.concatenate(key_qpos)
combined_ctrl = np.concatenate(key_ctrl)
self._spec.add_key(name="init_state", qpos=combined_qpos, ctrl=combined_ctrl)
def _add_terrain(self) -> None:
if self._cfg.terrain is None:
return
self._cfg.terrain.num_envs = self._cfg.num_envs
self._cfg.terrain.env_spacing = self._cfg.env_spacing
self._terrain = TerrainImporter(self._cfg.terrain, self._device)
frame = self._spec.worldbody.add_frame()
self._spec.attach(self._terrain.spec, prefix="", frame=frame)
def _add_sensors(self) -> None:
for sensor_cfg in self._cfg.sensors:
sns = sensor_cfg.build()
sns.edit_spec(self._spec, self._entities)
self._sensors[sensor_cfg.name] = sns
for sns in self._spec.sensors:
if sns.name not in self._sensors:
self._sensors[sns.name] = BuiltinSensor.from_existing(sns.name)