Source code for mjlab.entity.data

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Sequence

import mujoco_warp as mjwarp
import torch

from mjlab.utils.lab_api.math import (
  quat_apply,
  quat_apply_inverse,
  quat_from_matrix,
  quat_mul,
)

if TYPE_CHECKING:
  from mjlab.entity.entity import EntityIndexing


def compute_velocity_from_cvel(
  pos: torch.Tensor,
  subtree_com: torch.Tensor,
  cvel: torch.Tensor,
) -> torch.Tensor:
  """Convert cvel quantities to world-frame velocities."""
  lin_vel_c = cvel[..., 3:6]
  ang_vel_c = cvel[..., 0:3]
  offset = subtree_com - pos
  lin_vel_w = lin_vel_c - torch.cross(ang_vel_c, offset, dim=-1)
  ang_vel_w = ang_vel_c
  return torch.cat([lin_vel_w, ang_vel_w], dim=-1)


[docs] @dataclass class EntityData: """Data container for an entity. Note: Write methods (write_*) modify state directly. Read properties (e.g., root_link_pose_w) require sim.forward() to be current. If you write then read, call sim.forward() in between. Event order matters when mixing reads and writes. All inputs/outputs use world frame. """ indexing: EntityIndexing data: mjwarp.Data model: mjwarp.Model device: str default_root_state: torch.Tensor default_joint_pos: torch.Tensor default_joint_vel: torch.Tensor default_joint_pos_limits: torch.Tensor joint_pos_limits: torch.Tensor soft_joint_pos_limits: torch.Tensor gravity_vec_w: torch.Tensor forward_vec_b: torch.Tensor is_fixed_base: bool is_articulated: bool is_actuated: bool joint_pos_target: torch.Tensor joint_vel_target: torch.Tensor joint_effort_target: torch.Tensor tendon_len_target: torch.Tensor tendon_vel_target: torch.Tensor tendon_effort_target: torch.Tensor site_effort_target: torch.Tensor encoder_bias: torch.Tensor # State dimensions. POS_DIM = 3 QUAT_DIM = 4 LIN_VEL_DIM = 3 ANG_VEL_DIM = 3 ROOT_POSE_DIM = POS_DIM + QUAT_DIM # 7 ROOT_VEL_DIM = LIN_VEL_DIM + ANG_VEL_DIM # 6 ROOT_STATE_DIM = ROOT_POSE_DIM + ROOT_VEL_DIM # 13
[docs] def write_root_state( self, root_state: torch.Tensor, env_ids: torch.Tensor | slice | None = None ) -> None: if self.is_fixed_base: raise ValueError("Cannot write root state for fixed-base entity.") assert root_state.shape[-1] == self.ROOT_STATE_DIM self.write_root_pose(root_state[:, : self.ROOT_POSE_DIM], env_ids) self.write_root_velocity(root_state[:, self.ROOT_POSE_DIM :], env_ids)
[docs] def write_root_pose( self, pose: torch.Tensor, env_ids: torch.Tensor | slice | None = None ) -> None: if self.is_fixed_base: raise ValueError("Cannot write root pose for fixed-base entity.") assert pose.shape[-1] == self.ROOT_POSE_DIM env_ids = self._resolve_env_ids(env_ids) self.data.qpos[env_ids, self.indexing.free_joint_q_adr] = pose
[docs] def write_root_velocity( self, velocity: torch.Tensor, env_ids: torch.Tensor | slice | None = None ) -> None: if self.is_fixed_base: raise ValueError("Cannot write root velocity for fixed-base entity.") assert velocity.shape[-1] == self.ROOT_VEL_DIM env_ids = self._resolve_env_ids(env_ids) quat_w = self.data.qpos[env_ids, self.indexing.free_joint_q_adr[3:7]] ang_vel_b = quat_apply_inverse(quat_w, velocity[:, 3:]) velocity_qvel = torch.cat([velocity[:, :3], ang_vel_b], dim=-1) self.data.qvel[env_ids, self.indexing.free_joint_v_adr] = velocity_qvel
[docs] def write_root_com_velocity( self, velocity: torch.Tensor, env_ids: torch.Tensor | slice | None = None ) -> None: if self.is_fixed_base: raise ValueError("Cannot write root COM velocity for fixed-base entity.") assert velocity.shape[-1] == self.ROOT_VEL_DIM env_ids = self._resolve_env_ids(env_ids) com_offset_b = self.model.body_ipos[:, self.indexing.root_body_id] quat_w = self.data.qpos[env_ids, self.indexing.free_joint_q_adr[3:7]] com_offset_w = quat_apply(quat_w, com_offset_b[env_ids]) lin_vel_com = velocity[:, :3] ang_vel_w = velocity[:, 3:] lin_vel_link = lin_vel_com - torch.cross(ang_vel_w, com_offset_w, dim=-1) link_velocity = torch.cat([lin_vel_link, ang_vel_w], dim=-1) self.write_root_velocity(link_velocity, env_ids)
[docs] def write_joint_state( self, position: torch.Tensor, velocity: torch.Tensor, joint_ids: torch.Tensor | slice | None = None, env_ids: torch.Tensor | slice | None = None, ) -> None: if not self.is_articulated: raise ValueError("Cannot write joint state for non-articulated entity.") self.write_joint_position(position, joint_ids, env_ids) self.write_joint_velocity(velocity, joint_ids, env_ids)
[docs] def write_joint_position( self, position: torch.Tensor, joint_ids: torch.Tensor | slice | None = None, env_ids: torch.Tensor | slice | None = None, ) -> None: if not self.is_articulated: raise ValueError("Cannot write joint position for non-articulated entity.") env_ids = self._resolve_env_ids(env_ids) joint_ids = joint_ids if joint_ids is not None else slice(None) q_slice = self.indexing.joint_q_adr[joint_ids] self.data.qpos[env_ids, q_slice] = position
[docs] def write_joint_velocity( self, velocity: torch.Tensor, joint_ids: torch.Tensor | slice | None = None, env_ids: torch.Tensor | slice | None = None, ) -> None: if not self.is_articulated: raise ValueError("Cannot write joint velocity for non-articulated entity.") env_ids = self._resolve_env_ids(env_ids) joint_ids = joint_ids if joint_ids is not None else slice(None) v_slice = self.indexing.joint_v_adr[joint_ids] self.data.qvel[env_ids, v_slice] = velocity
[docs] def write_external_wrench( self, force: torch.Tensor | None, torque: torch.Tensor | None, body_ids: Sequence[int] | slice | None = None, env_ids: torch.Tensor | slice | None = None, ) -> None: env_ids = self._resolve_env_ids(env_ids) local_body_ids = body_ids if body_ids is not None else slice(None) global_body_ids = self.indexing.body_ids[local_body_ids] if force is not None: self.data.xfrc_applied[env_ids, global_body_ids, 0:3] = force if torque is not None: self.data.xfrc_applied[env_ids, global_body_ids, 3:6] = torque
[docs] def write_ctrl( self, ctrl: torch.Tensor, ctrl_ids: torch.Tensor | slice | None = None, env_ids: torch.Tensor | slice | None = None, ) -> None: if not self.is_actuated: raise ValueError("Cannot write control for non-actuated entity.") env_ids = self._resolve_env_ids(env_ids) local_ctrl_ids = ctrl_ids if ctrl_ids is not None else slice(None) global_ctrl_ids = self.indexing.ctrl_ids[local_ctrl_ids] self.data.ctrl[env_ids, global_ctrl_ids] = ctrl
[docs] def write_mocap_pose( self, pose: torch.Tensor, env_ids: torch.Tensor | slice | None = None ) -> None: if self.indexing.mocap_id is None: raise ValueError("Cannot write mocap pose for non-mocap entity.") assert pose.shape[-1] == self.ROOT_POSE_DIM env_ids = self._resolve_env_ids(env_ids) self.data.mocap_pos[env_ids, self.indexing.mocap_id] = pose[:, 0:3].unsqueeze(1) self.data.mocap_quat[env_ids, self.indexing.mocap_id] = pose[:, 3:7].unsqueeze(1)
[docs] def clear_state(self, env_ids: torch.Tensor | slice | None = None) -> None: if self.is_actuated: env_ids = self._resolve_env_ids(env_ids) self.joint_pos_target[env_ids] = 0.0 self.joint_vel_target[env_ids] = 0.0 self.joint_effort_target[env_ids] = 0.0 self.tendon_len_target[env_ids] = 0.0 self.tendon_vel_target[env_ids] = 0.0 self.tendon_effort_target[env_ids] = 0.0 self.site_effort_target[env_ids] = 0.0
def _resolve_env_ids( self, env_ids: torch.Tensor | slice | None ) -> torch.Tensor | slice: """Convert env_ids to consistent indexing format.""" if env_ids is None: return slice(None) if isinstance(env_ids, torch.Tensor): return env_ids[:, None] return env_ids # Root properties @property def root_link_pose_w(self) -> torch.Tensor: """Root link pose in world frame. Shape (num_envs, 7).""" pos_w = self.data.xpos[:, self.indexing.root_body_id] # (num_envs, 3) quat_w = self.data.xquat[:, self.indexing.root_body_id] # (num_envs, 4) return torch.cat([pos_w, quat_w], dim=-1) # (num_envs, 7) @property def root_link_vel_w(self) -> torch.Tensor: """Root link velocity in world frame. Shape (num_envs, 6).""" # NOTE: Equivalently, can read this from qvel[:6] but the angular part # will be in body frame and needs to be rotated to world frame. # Note also that an extra forward() call might be required to make # both values equal. pos = self.data.xpos[:, self.indexing.root_body_id] # (num_envs, 3) subtree_com = self.data.subtree_com[:, self.indexing.root_body_id] cvel = self.data.cvel[:, self.indexing.root_body_id] # (num_envs, 6) return compute_velocity_from_cvel(pos, subtree_com, cvel) # (num_envs, 6) @property def root_com_pose_w(self) -> torch.Tensor: """Root center-of-mass pose in world frame. Shape (num_envs, 7).""" pos_w = self.data.xipos[:, self.indexing.root_body_id] quat = self.data.xquat[:, self.indexing.root_body_id] body_iquat = self.model.body_iquat[:, self.indexing.root_body_id] assert body_iquat is not None quat_w = quat_mul(quat, body_iquat.squeeze(1)) return torch.cat([pos_w, quat_w], dim=-1) @property def root_com_vel_w(self) -> torch.Tensor: """Root center-of-mass velocity in world frame. Shape (num_envs, 6).""" # NOTE: Equivalent sensor is framelinvel/frameangvel with objtype="body". pos = self.data.xipos[:, self.indexing.root_body_id] # (num_envs, 3) subtree_com = self.data.subtree_com[:, self.indexing.root_body_id] cvel = self.data.cvel[:, self.indexing.root_body_id] # (num_envs, 6) return compute_velocity_from_cvel(pos, subtree_com, cvel) # (num_envs, 6) # Body properties @property def body_link_pose_w(self) -> torch.Tensor: """Body link pose in world frame. Shape (num_envs, num_bodies, 7).""" pos_w = self.data.xpos[:, self.indexing.body_ids] quat_w = self.data.xquat[:, self.indexing.body_ids] return torch.cat([pos_w, quat_w], dim=-1) @property def body_link_vel_w(self) -> torch.Tensor: """Body link velocity in world frame. Shape (num_envs, num_bodies, 6).""" # NOTE: Equivalent sensor is framelinvel/frameangvel with objtype="xbody". pos = self.data.xpos[:, self.indexing.body_ids] # (num_envs, num_bodies, 3) subtree_com = self.data.subtree_com[:, self.indexing.root_body_id] cvel = self.data.cvel[:, self.indexing.body_ids] return compute_velocity_from_cvel(pos, subtree_com.unsqueeze(1), cvel) @property def body_com_pose_w(self) -> torch.Tensor: """Body center-of-mass pose in world frame. Shape (num_envs, num_bodies, 7).""" pos_w = self.data.xipos[:, self.indexing.body_ids] quat = self.data.xquat[:, self.indexing.body_ids] body_iquat = self.model.body_iquat[:, self.indexing.body_ids] quat_w = quat_mul(quat, body_iquat) return torch.cat([pos_w, quat_w], dim=-1) @property def body_com_vel_w(self) -> torch.Tensor: """Body center-of-mass velocity in world frame. Shape (num_envs, num_bodies, 6).""" # NOTE: Equivalent sensor is framelinvel/frameangvel with objtype="body". pos = self.data.xipos[:, self.indexing.body_ids] subtree_com = self.data.subtree_com[:, self.indexing.root_body_id] cvel = self.data.cvel[:, self.indexing.body_ids] return compute_velocity_from_cvel(pos, subtree_com.unsqueeze(1), cvel) @property def body_external_wrench(self) -> torch.Tensor: """Body external wrench in world frame. Shape (num_envs, num_bodies, 6).""" return self.data.xfrc_applied[:, self.indexing.body_ids] # Geom properties @property def geom_pose_w(self) -> torch.Tensor: """Geom pose in world frame. Shape (num_envs, num_geoms, 7).""" pos_w = self.data.geom_xpos[:, self.indexing.geom_ids] xmat = self.data.geom_xmat[:, self.indexing.geom_ids] quat_w = quat_from_matrix(xmat) return torch.cat([pos_w, quat_w], dim=-1) @property def geom_vel_w(self) -> torch.Tensor: """Geom velocity in world frame. Shape (num_envs, num_geoms, 6).""" pos = self.data.geom_xpos[:, self.indexing.geom_ids] body_ids = self.model.geom_bodyid[self.indexing.geom_ids] # (num_geoms,) subtree_com = self.data.subtree_com[:, self.indexing.root_body_id] cvel = self.data.cvel[:, body_ids] return compute_velocity_from_cvel(pos, subtree_com.unsqueeze(1), cvel) # Site properties @property def site_pose_w(self) -> torch.Tensor: """Site pose in world frame. Shape (num_envs, num_sites, 7).""" pos_w = self.data.site_xpos[:, self.indexing.site_ids] mat_w = self.data.site_xmat[:, self.indexing.site_ids] quat_w = quat_from_matrix(mat_w) return torch.cat([pos_w, quat_w], dim=-1) @property def site_vel_w(self) -> torch.Tensor: """Site velocity in world frame. Shape (num_envs, num_sites, 6).""" pos = self.data.site_xpos[:, self.indexing.site_ids] body_ids = self.model.site_bodyid[self.indexing.site_ids] # (num_sites,) subtree_com = self.data.subtree_com[:, self.indexing.root_body_id] cvel = self.data.cvel[:, body_ids] return compute_velocity_from_cvel(pos, subtree_com.unsqueeze(1), cvel) # Joint properties @property def joint_pos(self) -> torch.Tensor: """Joint positions. Shape (num_envs, num_joints).""" return self.data.qpos[:, self.indexing.joint_q_adr] @property def joint_pos_biased(self) -> torch.Tensor: """Joint positions with encoder bias applied. Shape (num_envs, num_joints).""" return self.joint_pos + self.encoder_bias @property def joint_vel(self) -> torch.Tensor: """Joint velocities. Shape (num_envs, nv).""" return self.data.qvel[:, self.indexing.joint_v_adr] @property def joint_acc(self) -> torch.Tensor: """Joint accelerations. Shape (num_envs, nv).""" return self.data.qacc[:, self.indexing.joint_v_adr] # Tendon properties @property def tendon_len(self) -> torch.Tensor: """Tendon lengths. Shape (num_envs, num_tendons).""" return self.data.ten_length[:, self.indexing.tendon_ids] @property def tendon_vel(self) -> torch.Tensor: """Tendon velocities. Shape (num_envs, num_tendons).""" return self.data.ten_velocity[:, self.indexing.tendon_ids] @property def joint_torques(self) -> torch.Tensor: """Joint torques. Shape (num_envs, nv).""" raise NotImplementedError( "Joint torques are not currently available. " "Consider using 'actuator_force' property for actuation forces, " "or 'generalized_force' property for generalized forces applied to the DoFs." ) @property def actuator_force(self) -> torch.Tensor: """Scalar actuation force in actuation space. Shape (num_envs, nu).""" return self.data.actuator_force[:, self.indexing.ctrl_ids] @property def generalized_force(self) -> torch.Tensor: """Generalized forces applied to the DoFs. Shape (num_envs, nv).""" return self.data.qfrc_applied[:, self.indexing.free_joint_v_adr] # Pose and velocity component accessors. @property def root_link_pos_w(self) -> torch.Tensor: """Root link position in world frame. Shape (num_envs, 3).""" return self.root_link_pose_w[:, 0:3] @property def root_link_quat_w(self) -> torch.Tensor: """Root link quaternion in world frame. Shape (num_envs, 4).""" return self.root_link_pose_w[:, 3:7] @property def root_link_lin_vel_w(self) -> torch.Tensor: """Root link linear velocity in world frame. Shape (num_envs, 3).""" return self.root_link_vel_w[:, 0:3] @property def root_link_ang_vel_w(self) -> torch.Tensor: """Root link angular velocity in world frame. Shape (num_envs, 3).""" return self.root_link_vel_w[:, 3:6] @property def root_com_pos_w(self) -> torch.Tensor: """Root COM position in world frame. Shape (num_envs, 3).""" return self.root_com_pose_w[:, 0:3] @property def root_com_quat_w(self) -> torch.Tensor: """Root COM quaternion in world frame. Shape (num_envs, 4).""" return self.root_com_pose_w[:, 3:7] @property def root_com_lin_vel_w(self) -> torch.Tensor: """Root COM linear velocity in world frame. Shape (num_envs, 3).""" return self.root_com_vel_w[:, 0:3] @property def root_com_ang_vel_w(self) -> torch.Tensor: """Root COM angular velocity in world frame. Shape (num_envs, 3).""" return self.root_com_vel_w[:, 3:6] @property def body_link_pos_w(self) -> torch.Tensor: """Body link positions in world frame. Shape (num_envs, num_bodies, 3).""" return self.body_link_pose_w[..., 0:3] @property def body_link_quat_w(self) -> torch.Tensor: """Body link quaternions in world frame. Shape (num_envs, num_bodies, 4).""" return self.body_link_pose_w[..., 3:7] @property def body_link_lin_vel_w(self) -> torch.Tensor: """Body link linear velocities in world frame. Shape (num_envs, num_bodies, 3).""" return self.body_link_vel_w[..., 0:3] @property def body_link_ang_vel_w(self) -> torch.Tensor: """Body link angular velocities in world frame. Shape (num_envs, num_bodies, 3).""" return self.body_link_vel_w[..., 3:6] @property def body_com_pos_w(self) -> torch.Tensor: """Body COM positions in world frame. Shape (num_envs, num_bodies, 3).""" return self.body_com_pose_w[..., 0:3] @property def body_com_quat_w(self) -> torch.Tensor: """Body COM quaternions in world frame. Shape (num_envs, num_bodies, 4).""" return self.body_com_pose_w[..., 3:7] @property def body_com_lin_vel_w(self) -> torch.Tensor: """Body COM linear velocities in world frame. Shape (num_envs, num_bodies, 3).""" return self.body_com_vel_w[..., 0:3] @property def body_com_ang_vel_w(self) -> torch.Tensor: """Body COM angular velocities in world frame. Shape (num_envs, num_bodies, 3).""" return self.body_com_vel_w[..., 3:6] @property def body_external_force(self) -> torch.Tensor: """Body external forces in world frame. Shape (num_envs, num_bodies, 3).""" return self.body_external_wrench[..., 0:3] @property def body_external_torque(self) -> torch.Tensor: """Body external torques in world frame. Shape (num_envs, num_bodies, 3).""" return self.body_external_wrench[..., 3:6] @property def geom_pos_w(self) -> torch.Tensor: """Geom positions in world frame. Shape (num_envs, num_geoms, 3).""" return self.geom_pose_w[..., 0:3] @property def geom_quat_w(self) -> torch.Tensor: """Geom quaternions in world frame. Shape (num_envs, num_geoms, 4).""" return self.geom_pose_w[..., 3:7] @property def geom_lin_vel_w(self) -> torch.Tensor: """Geom linear velocities in world frame. Shape (num_envs, num_geoms, 3).""" return self.geom_vel_w[..., 0:3] @property def geom_ang_vel_w(self) -> torch.Tensor: """Geom angular velocities in world frame. Shape (num_envs, num_geoms, 3).""" return self.geom_vel_w[..., 3:6] @property def site_pos_w(self) -> torch.Tensor: """Site positions in world frame. Shape (num_envs, num_sites, 3).""" return self.site_pose_w[..., 0:3] @property def site_quat_w(self) -> torch.Tensor: """Site quaternions in world frame. Shape (num_envs, num_sites, 4).""" return self.site_pose_w[..., 3:7] @property def site_lin_vel_w(self) -> torch.Tensor: """Site linear velocities in world frame. Shape (num_envs, num_sites, 3).""" return self.site_vel_w[..., 0:3] @property def site_ang_vel_w(self) -> torch.Tensor: """Site angular velocities in world frame. Shape (num_envs, num_sites, 3).""" return self.site_vel_w[..., 3:6] # Derived properties. @property def projected_gravity_b(self) -> torch.Tensor: """Gravity vector projected into body frame. Shape (num_envs, 3).""" return quat_apply_inverse(self.root_link_quat_w, self.gravity_vec_w) @property def heading_w(self) -> torch.Tensor: """Heading angle in world frame. Shape (num_envs,).""" forward_w = quat_apply(self.root_link_quat_w, self.forward_vec_b) return torch.atan2(forward_w[:, 1], forward_w[:, 0]) @property def root_link_lin_vel_b(self) -> torch.Tensor: """Root link linear velocity in body frame. Shape (num_envs, 3).""" return quat_apply_inverse(self.root_link_quat_w, self.root_link_lin_vel_w) @property def root_link_ang_vel_b(self) -> torch.Tensor: """Root link angular velocity in body frame. Shape (num_envs, 3).""" return quat_apply_inverse(self.root_link_quat_w, self.root_link_ang_vel_w) @property def root_com_lin_vel_b(self) -> torch.Tensor: """Root COM linear velocity in body frame. Shape (num_envs, 3).""" return quat_apply_inverse(self.root_link_quat_w, self.root_com_lin_vel_w) @property def root_com_ang_vel_b(self) -> torch.Tensor: """Root COM angular velocity in body frame. Shape (num_envs, 3).""" return quat_apply_inverse(self.root_link_quat_w, self.root_com_ang_vel_w)