# Source code for mjlab.sim.sim_data

"""Bridge for seamless PyTorch-Warp interoperability with zero-copy memory sharing.

Provides automatic wrapping of Warp arrays as PyTorch-compatible objects while
preserving shared memory and CUDA graph compatibility.
"""

import warnings
from typing import Any, Dict, Generic, Optional, Tuple, TypeVar

import torch
import warp as wp

# Type parameter for the struct wrapped by ``WarpBridge``.
T = TypeVar("T")


class TorchArray:
    """Warp array that behaves like a torch.Tensor with shared memory.

    The proxy holds a ``torch.Tensor`` built from the Warp array (via
    ``wp.to_torch``, except for the empty-CPU-array workaround below). It
    forwards indexing, attribute access, operators and ``torch.*`` calls to
    that tensor.
    """

    # Defining ``__eq__`` already makes instances unhashable; state it explicitly.
    __hash__ = None  # type: ignore[assignment]

    def __init__(self, wp_array: wp.array, nworld: int | None = None) -> None:
        """Initialize the tensor proxy with a Warp array.

        Args:
            wp_array: The Warp array to expose.
            nworld: Optional world count. If it is greater than 1, and the
                tensor's leading dimension has size 1 and stride 0, that
                dimension is expanded (``Tensor.expand``, no copy) to ``nworld``.
        """
        self._wp_array = wp_array
        # Workaround for Warp bug: wp.to_torch fails on empty CPU arrays.
        # An empty tensor holds no elements, so a fresh allocation loses nothing.
        if wp_array.device.is_cpu and wp_array.size == 0:  # type: ignore[union-attr]
            self._tensor = torch.zeros(
                wp_array.shape, dtype=wp.dtype_to_torch(wp_array.dtype)
            )
        else:
            self._tensor = wp.to_torch(wp_array)
        if (
            nworld is not None
            and nworld > 1
            and len(self._tensor.shape) > 0
            and self._tensor.stride(0) == 0
            and self._tensor.shape[0] == 1
        ):
            new_shape = (nworld,) + self._tensor.shape[1:]
            self._tensor = self._tensor.expand(new_shape)
        self._is_cuda = not self._wp_array.device.is_cpu  # type: ignore
        self._torch_stream = self._setup_stream()

    def _setup_stream(self) -> Optional[torch.cuda.Stream]:
        """Return a torch stream wrapping Warp's stream for this device.

        Returns None for CPU arrays. If wrapping Warp's stream fails, it emits
        a RuntimeWarning and falls back to torch's current stream for the
        tensor's device.
        """
        if not self._is_cuda:
            return None
        try:
            warp_stream = wp.get_stream(self._wp_array.device)
            return torch.cuda.ExternalStream(warp_stream.cuda_stream)
        except Exception as e:
            # Best-effort: fall back to torch's current stream, but report it
            # through the warnings machinery rather than printing to stdout.
            warnings.warn(
                f"Could not create external stream: {e}",
                RuntimeWarning,
                stacklevel=3,
            )
            return torch.cuda.current_stream(self._tensor.device)

    @property
    def wp_array(self) -> wp.array:
        """The Warp array this proxy was created from."""
        return self._wp_array

    def __repr__(self) -> str:
        """Return string representation of the underlying tensor."""
        return repr(self._tensor)

    # Builtins look special methods up on the type, so __getattr__ does not
    # cover them. Without these, len() raised TypeError and bool() was always True.
    def __len__(self) -> int:
        """Return ``len()`` of the underlying tensor (size of its first dim)."""
        return len(self._tensor)

    def __bool__(self) -> bool:
        """Return the tensor's truth value, with the same rules as ``bool(tensor)``."""
        return bool(self._tensor)

    def __getitem__(self, idx: Any) -> Any:
        """Get item(s) from the tensor using standard indexing."""
        return self._tensor[idx]

    def __setitem__(self, idx: Any, value: Any) -> None:
        """Set item(s) in the tensor using standard indexing.

        On CUDA, the write is issued on the stream from ``_setup_stream``.
        """
        if self._is_cuda and self._torch_stream is not None:
            with torch.cuda.stream(self._torch_stream):
                self._tensor[idx] = value
        else:
            self._tensor[idx] = value

    def __getattr__(self, name: str) -> Any:
        """Delegate attribute access to the underlying tensor."""
        # __getattr__ only runs after normal lookup fails. If ``_tensor`` itself
        # is missing (an instance created without __init__, e.g. by copy or
        # pickle), reading ``self._tensor`` here would recurse forever.
        try:
            tensor = object.__getattribute__(self, "_tensor")
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(tensor, name)

    @classmethod
    def __torch_function__(
        cls,
        func: Any,
        types: Tuple[type, ...],
        args: Tuple[Any, ...] = (),
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Intercept torch.* function calls to unwrap TorchArray objects."""
        if kwargs is None:
            kwargs = {}

        # Only intercept when at least one argument is our proxy.
        if not any(issubclass(t, cls) for t in types):
            return NotImplemented

        def _unwrap(x: Any) -> Any:
            """Replace TorchArray objects (also inside lists/tuples) by tensors."""
            if isinstance(x, cls):
                return x._tensor
            # Sequence arguments such as torch.cat([a, b]) may hold proxies.
            # Passing them through unchanged makes torch dispatch back into
            # this method with the same arguments, recursing forever.
            if type(x) in (list, tuple):
                return type(x)(_unwrap(v) for v in x)
            return x

        unwrapped_args = tuple(_unwrap(arg) for arg in args)
        unwrapped_kwargs = {k: _unwrap(v) for k, v in kwargs.items()}
        return func(*unwrapped_args, **unwrapped_kwargs)

    # Arithmetic operators. Like the builtins above, these must be defined on
    # the class; each forwards to the underlying tensor.
    def __add__(self, other: Any) -> Any:
        return self._tensor + other

    def __radd__(self, other: Any) -> Any:
        return other + self._tensor

    def __sub__(self, other: Any) -> Any:
        return self._tensor - other

    def __rsub__(self, other: Any) -> Any:
        return other - self._tensor

    def __mul__(self, other: Any) -> Any:
        return self._tensor * other

    def __rmul__(self, other: Any) -> Any:
        return other * self._tensor

    def __truediv__(self, other: Any) -> Any:
        return self._tensor / other

    def __rtruediv__(self, other: Any) -> Any:
        return other / self._tensor

    def __floordiv__(self, other: Any) -> Any:
        return self._tensor // other

    def __rfloordiv__(self, other: Any) -> Any:
        return other // self._tensor

    def __mod__(self, other: Any) -> Any:
        return self._tensor % other

    def __rmod__(self, other: Any) -> Any:
        return other % self._tensor

    def __pow__(self, other: Any) -> Any:
        return self._tensor**other

    def __rpow__(self, other: Any) -> Any:
        return other**self._tensor

    def __matmul__(self, other: Any) -> Any:
        return self._tensor @ other

    def __rmatmul__(self, other: Any) -> Any:
        return other @ self._tensor

    def __neg__(self) -> Any:
        return -self._tensor

    def __pos__(self) -> Any:
        return +self._tensor

    def __abs__(self) -> Any:
        return abs(self._tensor)

    # Bitwise / logical operators (e.g. combining or negating boolean masks).
    def __and__(self, other: Any) -> Any:
        return self._tensor & other

    def __rand__(self, other: Any) -> Any:
        return other & self._tensor

    def __or__(self, other: Any) -> Any:
        return self._tensor | other

    def __ror__(self, other: Any) -> Any:
        return other | self._tensor

    def __xor__(self, other: Any) -> Any:
        return self._tensor ^ other

    def __rxor__(self, other: Any) -> Any:
        return other ^ self._tensor

    def __invert__(self) -> Any:
        return ~self._tensor

    # Comparison operators (element-wise, returning tensors).
    def __eq__(self, other: Any) -> Any:  # type: ignore[override]
        return self._tensor == other

    def __ne__(self, other: Any) -> Any:  # type: ignore[override]
        return self._tensor != other

    def __lt__(self, other: Any) -> Any:
        return self._tensor < other

    def __le__(self, other: Any) -> Any:
        return self._tensor <= other

    def __gt__(self, other: Any) -> Any:
        return self._tensor > other

    def __ge__(self, other: Any) -> Any:
        return self._tensor >= other
def _contains_warp_arrays(obj: Any) -> bool:
    """Report whether ``obj`` is, or directly holds, a Warp array.

    Only one level is inspected. First ``obj`` itself is checked. Then, if
    ``obj`` has a ``__dict__``, each public (non-underscore) name listed by
    ``dir(obj)`` is checked. The search stops at the first Warp array found.
    """
    if isinstance(obj, wp.array):
        return True
    if not hasattr(obj, "__dict__"):
        # Objects without a __dict__ are not scanned for attributes.
        return False
    for attr_name in dir(obj):
        if attr_name.startswith("_"):
            continue
        if isinstance(getattr(obj, attr_name), wp.array):
            return True
    return False
class WarpBridge(Generic[T]):
    """Wraps mjwarp objects to expose Warp arrays as PyTorch tensors.

    Automatically converts Warp array attributes to TorchArray objects on access,
    enabling direct PyTorch operations on simulation data. Recursively wraps nested
    structures that contain Warp arrays.

    IMPORTANT: This wrapper is read-only. To modify array data, use in-place
    operations like `obj.field[:] = value`. Direct assignment like
    `obj.field = new_array` will raise an AttributeError to prevent accidental
    memory address changes that break CUDA graphs.
    """

    def __init__(self, struct: T, nworld: int | None = None) -> None:
        """Wrap ``struct``.

        Args:
            struct: The object whose attributes are exposed.
            nworld: Forwarded to every TorchArray / nested WarpBridge created
                on attribute access.
        """
        # object.__setattr__ is required because this class's __setattr__
        # always raises.
        object.__setattr__(self, "_struct", struct)
        object.__setattr__(self, "_wrapped_cache", {})
        object.__setattr__(self, "_nworld", nworld)

    def __getattr__(self, name: str) -> Any:
        """Get attribute from the wrapped data, wrapping Warp arrays as TorchArray.

        Warp arrays become TorchArray, and objects holding Warp arrays become a
        nested WarpBridge. Both kinds of wrapper are cached per attribute name.
        Any other value is returned as-is and not cached.
        """
        if name in ("_struct", "_wrapped_cache", "_nworld"):
            # Normal lookup already failed for the wrapper's own state. That
            # only happens on instances built without __init__ (e.g. by copy or
            # pickle). Resolving it through self._wrapped_cache below would
            # recurse forever.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        # Check cache first to avoid recreating wrappers.
        cache = self._wrapped_cache
        if name in cache:
            return cache[name]

        val = getattr(self._struct, name)
        if isinstance(val, wp.array):
            wrapped: Any = TorchArray(val, nworld=self._nworld)
        elif _contains_warp_arrays(val):
            wrapped = WarpBridge(val, nworld=self._nworld)
        else:
            return val
        cache[name] = wrapped
        return wrapped

    def __setattr__(self, name: str, value: Any) -> None:
        """Prevent attribute setting to maintain CUDA graph safety."""
        raise AttributeError(
            f"Cannot set attribute '{name}' on WarpBridge. "
            f"This wrapper is read-only to preserve memory addresses for CUDA graphs. "
            f"Use in-place operations instead: obj.{name}[:] = value"
        )

    def __delattr__(self, name: str) -> None:
        """Prevent attribute deletion, mirroring the read-only ``__setattr__``."""
        raise AttributeError(
            f"Cannot delete attribute '{name}' on WarpBridge. "
            f"This wrapper is read-only to preserve memory addresses for CUDA graphs."
        )

    def __repr__(self) -> str:
        """Return string representation of the wrapped struct."""
        return f"WarpBridge({repr(self._struct)})"

    @property
    def struct(self) -> T:
        """Access the underlying wrapped struct."""
        return self._struct

    def clear_cache(self) -> None:
        """Clear the wrapped cache to force re-wrapping of arrays.

        This should be called after operations that modify the underlying warp
        arrays, such as expand_model_fields(), to ensure the cache reflects the
        updated arrays.
        """
        object.__setattr__(self, "_wrapped_cache", {})