# Source code for mjlab.sim.sim

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal, cast

import mujoco
import mujoco_warp as mjwarp
import torch
import warp as wp

from mjlab.sim.randomization import expand_model_fields
from mjlab.sim.sim_data import TorchArray, WarpBridge
from mjlab.utils.nan_guard import NanGuard, NanGuardCfg

# Type aliases for better IDE support while maintaining runtime compatibility.
# At runtime, WarpBridge wraps the actual MJWarp objects; under TYPE_CHECKING
# the aliases resolve to the real MJWarp types so attribute access is typed.
if TYPE_CHECKING:
  ModelBridge = mjwarp.Model
  DataBridge = mjwarp.Data
else:
  ModelBridge = WarpBridge
  DataBridge = WarpBridge

# Minimum CUDA driver version supported for conditional CUDA graphs.
# Compared as a (major, minor) tuple in Simulation._should_use_cuda_graph().
_GRAPH_CAPTURE_MIN_DRIVER = (12, 4)

# Lookup tables translating MujocoCfg string literals into MuJoCo enum values
# (consumed by MujocoCfg.apply()).
_JACOBIAN_MAP = {
  "auto": mujoco.mjtJacobian.mjJAC_AUTO,
  "dense": mujoco.mjtJacobian.mjJAC_DENSE,
  "sparse": mujoco.mjtJacobian.mjJAC_SPARSE,
}
_CONE_MAP = {
  "elliptic": mujoco.mjtCone.mjCONE_ELLIPTIC,
  "pyramidal": mujoco.mjtCone.mjCONE_PYRAMIDAL,
}
_INTEGRATOR_MAP = {
  "euler": mujoco.mjtIntegrator.mjINT_EULER,
  "implicitfast": mujoco.mjtIntegrator.mjINT_IMPLICITFAST,
}
_SOLVER_MAP = {
  "newton": mujoco.mjtSolver.mjSOL_NEWTON,
  "cg": mujoco.mjtSolver.mjSOL_CG,
  "pgs": mujoco.mjtSolver.mjSOL_PGS,
}


@dataclass
class MujocoCfg:
  """MuJoCo physics options, written into a compiled model by ``apply()``."""

  # Integrator settings.
  timestep: float = 0.002
  integrator: Literal["euler", "implicitfast"] = "implicitfast"
  # Friction settings.
  impratio: float = 1.0
  cone: Literal["pyramidal", "elliptic"] = "pyramidal"
  # Solver settings.
  jacobian: Literal["auto", "dense", "sparse"] = "auto"
  solver: Literal["newton", "cg", "pgs"] = "newton"
  iterations: int = 100
  tolerance: float = 1e-8
  ls_iterations: int = 50
  ls_tolerance: float = 0.01
  ccd_iterations: int = 50
  # Other.
  gravity: tuple[float, float, float] = (0, 0, -9.81)
  multiccd: bool = False

  def apply(self, model: mujoco.MjModel) -> None:
    """Copy this configuration into ``model.opt`` (mutates ``model`` in place)."""
    opt = model.opt
    # String literals are translated to MuJoCo enum values via the module maps.
    opt.integrator = _INTEGRATOR_MAP[self.integrator]
    opt.cone = _CONE_MAP[self.cone]
    opt.jacobian = _JACOBIAN_MAP[self.jacobian]
    opt.solver = _SOLVER_MAP[self.solver]
    # Scalar options are assigned straight through.
    opt.timestep = self.timestep
    opt.impratio = self.impratio
    opt.iterations = self.iterations
    opt.tolerance = self.tolerance
    opt.ls_iterations = self.ls_iterations
    opt.ls_tolerance = self.ls_tolerance
    opt.ccd_iterations = self.ccd_iterations
    # Slice-assign: opt.gravity is a fixed-size array, not a rebindable attr.
    opt.gravity[:] = self.gravity
    if self.multiccd:
      opt.enableflags |= mujoco.mjtEnableBit.mjENBL_MULTICCD
@dataclass(kw_only=True)
class SimulationCfg:
  """Top-level simulation configuration: MJWarp allocation sizes and toggles,
  plus the nested MuJoCo-option and NaN-guard configurations."""

  nconmax: int | None = None
  """Number of contacts to allocate per world. Contacts exist in large
  heterogenous arrays: one world may have more than nconmax contacts. If None,
  a heuristic value is used."""
  njmax: int | None = None
  """Number of constraints to allocate per world. Constraint arrays are batched
  by world: no world may have more than njmax constraints. If None, a heuristic
  value is used."""
  ls_parallel: bool = True  # Boosts perf quite noticeably.
  # Forwarded to wp_model.opt.contact_sensor_maxmatch in Simulation.__init__.
  contact_sensor_maxmatch: int = 64
  mujoco: MujocoCfg = field(default_factory=MujocoCfg)
  nan_guard: NanGuardCfg = field(default_factory=NanGuardCfg)
class Simulation:
  """GPU-accelerated MuJoCo simulation powered by MJWarp.

  CUDA Graph Capture
  ------------------
  On CUDA devices with memory pools enabled, the simulation captures CUDA
  graphs for ``step()``, ``forward()``, and ``reset()`` operations. Graph
  capture records a sequence of GPU kernels and their memory addresses, then
  replays the entire sequence with a single kernel launch, eliminating CPU
  overhead from repeated kernel dispatches.

  **Important:** A captured graph holds pointers to the GPU arrays that existed
  at capture time. If those arrays are later replaced (e.g., via
  ``expand_model_fields()``), the graph will still read from the old arrays,
  silently ignoring any new values. The ``expand_model_fields()`` method
  handles this automatically by calling ``create_graph()`` after replacing
  arrays. If you write code that replaces model or data arrays after
  simulation initialization, you **must** call ``create_graph()`` afterward to
  re-capture the graphs with the new memory addresses.
  """
  def __init__(
    self, num_envs: int, cfg: SimulationCfg, model: mujoco.MjModel, device: str
  ):
    """Build CPU (MuJoCo) and GPU (MJWarp) state for ``num_envs`` worlds.

    Args:
      num_envs: Number of parallel simulation worlds.
      cfg: Simulation configuration; ``cfg.mujoco`` is applied to ``model``
        in place before the MJWarp structures are built.
      model: Compiled MuJoCo model (mutated by ``cfg.mujoco.apply``).
      device: Device string, resolved to a Warp device.
    """
    self.cfg = cfg
    self.device = device
    self.wp_device = wp.get_device(self.device)
    self.num_envs = num_envs
    # Lazily-populated cache of pre-randomization field values; see
    # get_default_field().
    self._default_model_fields: dict[str, torch.Tensor] = {}

    # MuJoCo model and data.
    self._mj_model = model
    cfg.mujoco.apply(self._mj_model)
    self._mj_data = mujoco.MjData(model)
    # Run forward once on CPU before the state is copied to the GPU.
    mujoco.mj_forward(self._mj_model, self._mj_data)

    # MJWarp model and data.
    with wp.ScopedDevice(self.wp_device):
      self._wp_model = mjwarp.put_model(self._mj_model)
      self._wp_model.opt.ls_parallel = cfg.ls_parallel
      self._wp_model.opt.contact_sensor_maxmatch = cfg.contact_sensor_maxmatch
      self._wp_data = mjwarp.put_data(
        self._mj_model,
        self._mj_data,
        nworld=self.num_envs,
        nconmax=self.cfg.nconmax,
        njmax=self.cfg.njmax,
      )
      # Per-world boolean mask consumed by mjwarp.reset_data(); _reset_mask is
      # the torch-side wrapper around the same warp array.
      self._reset_mask_wp = wp.zeros(num_envs, dtype=bool)
      self._reset_mask = TorchArray(self._reset_mask_wp)
      self._model_bridge = WarpBridge(self._wp_model, nworld=self.num_envs)
      self._data_bridge = WarpBridge(self._wp_data)
    # Decide on graph capture and capture immediately (create_graph enters its
    # own ScopedDevice, so no outer device scope is needed here).
    self.use_cuda_graph = self._should_use_cuda_graph()
    self.create_graph()
    self.nan_guard = NanGuard(cfg.nan_guard, self.num_envs, self._mj_model)
[docs] def create_graph(self) -> None: """Capture CUDA graphs for step, forward, and reset operations. This method must be called whenever GPU arrays in the model or data are replaced after initialization. The captured graphs hold pointers to the arrays that existed at capture time. If those arrays are replaced, the graphs will silently read from the old arrays, ignoring any new values. Called automatically by: - ``__init__()`` during simulation initialization - ``expand_model_fields()`` after replacing model arrays On CPU devices or when memory pools are disabled, this is a no-op. """ self.step_graph = None self.forward_graph = None self.reset_graph = None if self.use_cuda_graph: with wp.ScopedDevice(self.wp_device): with wp.ScopedCapture() as capture: mjwarp.step(self.wp_model, self.wp_data) self.step_graph = capture.graph with wp.ScopedCapture() as capture: mjwarp.forward(self.wp_model, self.wp_data) self.forward_graph = capture.graph with wp.ScopedCapture() as capture: mjwarp.reset_data(self.wp_model, self.wp_data, reset=self._reset_mask_wp) self.reset_graph = capture.graph
# Properties. @property def mj_model(self) -> mujoco.MjModel: return self._mj_model @property def mj_data(self) -> mujoco.MjData: return self._mj_data @property def wp_model(self) -> mjwarp.Model: return self._wp_model @property def wp_data(self) -> mjwarp.Data: return self._wp_data @property def data(self) -> "DataBridge": return cast("DataBridge", self._data_bridge) @property def model(self) -> "ModelBridge": return cast("ModelBridge", self._model_bridge) @property def default_model_fields(self) -> dict[str, torch.Tensor]: """Default values for expanded model fields, used in domain randomization.""" return self._default_model_fields # Methods.
[docs] def expand_model_fields(self, fields: tuple[str, ...]) -> None: """Expand model fields to support per-environment parameters.""" if not fields: return invalid_fields = [f for f in fields if not hasattr(self._mj_model, f)] if invalid_fields: raise ValueError(f"Fields not found in model: {invalid_fields}") expand_model_fields(self._wp_model, self.num_envs, list(fields)) self._model_bridge.clear_cache() # Field expansion allocates new arrays and replaces them via setattr. The # CUDA graph captured the old memory addresses, so we must recreate it. self.create_graph()
[docs] def get_default_field(self, field: str) -> torch.Tensor: """Get the default value for a model field, caching for reuse. Returns the original values from the C MuJoCo model (mj_model), obtained from the final compiled scene spec before any randomization is applied. Not to be confused with the GPU Warp model (wp_model) which may have randomized values. """ if field not in self._default_model_fields: if not hasattr(self._mj_model, field): raise ValueError(f"Field '{field}' not found in model") model_field = getattr(self.model, field) default_value = getattr(self._mj_model, field) self._default_model_fields[field] = torch.as_tensor( default_value, dtype=model_field.dtype, device=self.device ).clone() return self._default_model_fields[field]
[docs] def forward(self) -> None: with wp.ScopedDevice(self.wp_device): if self.use_cuda_graph and self.forward_graph is not None: wp.capture_launch(self.forward_graph) else: mjwarp.forward(self.wp_model, self.wp_data)
[docs] def step(self) -> None: with wp.ScopedDevice(self.wp_device): with self.nan_guard.watch(self.data): if self.use_cuda_graph and self.step_graph is not None: wp.capture_launch(self.step_graph) else: mjwarp.step(self.wp_model, self.wp_data)
[docs] def reset(self, env_ids: torch.Tensor | None = None) -> None: with wp.ScopedDevice(self.wp_device): if env_ids is None: self._reset_mask.fill_(True) else: self._reset_mask.fill_(False) self._reset_mask[env_ids] = True if self.use_cuda_graph and self.reset_graph is not None: wp.capture_launch(self.reset_graph) else: mjwarp.reset_data(self.wp_model, self.wp_data, reset=self._reset_mask_wp)
# Private methods. def _should_use_cuda_graph(self) -> bool: """Determine if CUDA graphs can be used based on device and driver version.""" if not self.wp_device.is_cuda: return False driver_ver = wp.context.runtime.driver_version has_mempool = wp.is_mempool_enabled(self.wp_device) if driver_ver is None: print("[WARNING] CUDA Graphs disabled: driver version unavailable") return False if has_mempool and driver_ver >= _GRAPH_CAPTURE_MIN_DRIVER: return True reasons = [] if not has_mempool: reasons.append("mempool disabled") if driver_ver < _GRAPH_CAPTURE_MIN_DRIVER: reasons.append(f"driver {driver_ver[0]}.{driver_ver[1]} < 12.4") print(f"[WARNING] CUDA Graphs disabled: {', '.join(reasons)}") return False