# Source code for mjlab.sensor.contact_sensor
"""Contact sensors track collisions between geoms, bodies, or subtrees."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Literal
import mujoco
import mujoco_warp as mjwarp
import torch
from mjlab.entity import Entity
from mjlab.sensor.sensor import Sensor, SensorCfg
# Bit position of each selectable contact field. The requested field is packed
# as ``1 << position`` into the first intprm of the MuJoCo contact sensor.
_CONTACT_DATA_MAP = {
    field_name: bit_position
    for bit_position, field_name in enumerate(
        ("found", "force", "torque", "dist", "pos", "normal", "tangent")
    )
}
# Number of scalars each field occupies per slot in sensordata.
_CONTACT_DATA_DIMS = dict(
    found=1,
    force=3,
    torque=3,
    dist=1,
    pos=3,
    normal=3,
    tangent=3,
)
# Integer code for each ``reduce`` mode (second intprm of the contact sensor).
_CONTACT_REDUCE_MAP = {
    reduce_mode: code
    for code, reduce_mode in enumerate(("none", "mindist", "maxforce", "netforce"))
}
# Maps ``ContactMatch.mode`` to the MuJoCo object type used as objtype/reftype
# of the contact sensor. A "subtree" match is expressed through mjOBJ_XBODY.
_MODE_TO_OBJTYPE = dict(
    geom=mujoco.mjtObj.mjOBJ_GEOM,
    body=mujoco.mjtObj.mjOBJ_BODY,
    subtree=mujoco.mjtObj.mjOBJ_XBODY,
)
@dataclass
class ContactMatch:
    """Selector describing one side of a contact.

    Attributes:
        mode: Kind of MuJoCo object to match: "geom", "body", or "subtree".
        pattern: A single regex/name, or a tuple of them. When ``entity`` is
            given, the patterns are expanded against that entity's geoms or
            bodies.
        entity: Name of the entity to search within. ``None`` means the
            pattern is taken verbatim as a MuJoCo object name.
        exclude: Names to drop from the expanded matches. Entries containing a
            regex metacharacter are treated as regexes; all others must match
            a name exactly.
    """

    # Object kind on this side of the contact.
    mode: Literal["geom", "body", "subtree"]
    # Regex(es) or literal name(s) to match.
    pattern: str | tuple[str, ...]
    # Entity scope for the pattern; None = literal MuJoCo name.
    entity: str | None = None
    # Exact names or regexes filtered out of the matches.
    exclude: tuple[str, ...] = ()
@dataclass
class ContactSensorCfg(SensorCfg):
    """Configuration for a sensor tracking contacts between primary and secondary.

    Output shape is [B, N * num_slots] for scalar fields and
    [B, N * num_slots, 3] for vector fields, where N is the number of primaries.

    Selectable ``fields``:
      - found: 0 means no contact; >0 is the match count before reduction.
      - force, torque: 3D vectors in the contact frame (global frame when
        reduce="netforce").
      - dist: penetration depth.
      - pos, normal, tangent: 3D vectors in the global frame; the normal points
        from primary to secondary.

    ``reduce`` picks the top ``num_slots`` contacts among all matches:
      - "none": fast, non-deterministic.
      - "mindist" / "maxforce": the closest / strongest contacts.
      - "netforce": the sum of all forces, in the global frame.

    Other options:
      - secondary_policy: behavior when the secondary pattern matches several
        names ("first", "any", or "error").
      - track_air_time: enables landing/takeoff detection.
      - global_frame: rotates force/torque into the global frame; requires
        "normal" and "tangent" in ``fields``.
    """

    primary: ContactMatch
    secondary: ContactMatch | None = None
    fields: tuple[str, ...] = ("found", "force")
    reduce: Literal["none", "mindist", "maxforce", "netforce"] = "maxforce"
    num_slots: int = 1
    secondary_policy: Literal["first", "any", "error"] = "first"
    track_air_time: bool = False
    global_frame: bool = False
    history_length: int = 0
    """How many substeps of force/torque/dist to keep in a history buffer.

    0 (the default) allocates nothing: force_history, torque_history and
    dist_history stay None, and the regular force/torque/dist fields hold the
    instantaneous values.

    A positive value allocates buffers holding the last ``history_length``
    substeps, shaped [B, N, history_length, ...] with index 0 the newest
    sample. Setting it to the decimation captures every substep of one policy
    step. A value of 1 duplicates the regular fields but gives a uniform
    [B, N, H, ...] layout for code that expects a history axis.
    """
    debug: bool = False
@dataclass
class _ContactSlot:
    """Binds one MuJoCo contact sensor (a single primary/field pair) to its data.

    Attributes:
        primary_name: Resolved primary object name.
        field_name: Contact field this sensor reports (e.g. "force").
        sensor_name: Name of the MuJoCo sensor added to the spec.
        data_view: Slice of ``sensordata`` for this sensor; None until the
            sensor is initialized.
    """

    primary_name: str
    field_name: str
    sensor_name: str
    data_view: torch.Tensor | None = None
@dataclass
class _AirTimeState:
    """Mutable timers for air/contact phases.

    The four phase timers are [B, N] tensors; ``last_time`` holds, per
    environment, the sim time of the previous update.
    """

    # Time spent in the current air phase (0 while in contact).
    current_air_time: torch.Tensor
    # Length of the most recently completed air phase.
    last_air_time: torch.Tensor
    # Time spent in the current contact phase (0 while in the air).
    current_contact_time: torch.Tensor
    # Length of the most recently completed contact phase.
    last_contact_time: torch.Tensor
    # Sim time at the previous update, per environment.
    last_time: torch.Tensor
@dataclass
class ContactData:
    """Readings of a ContactSensor; fields that were not requested stay None.

    Shape notation: B = number of environments, N = number of primaries times
    ``num_slots``, P = number of primaries, H = ``history_length``.
    """

    found: torch.Tensor | None = None
    """[B, N]: 0 means no contact, >0 is the match count."""
    force: torch.Tensor | None = None
    """[B, N, 3] in the contact frame (global if reduce="netforce" or global_frame=True)."""
    torque: torch.Tensor | None = None
    """[B, N, 3] in the contact frame (global if reduce="netforce" or global_frame=True)."""
    dist: torch.Tensor | None = None
    """[B, N]: penetration depth."""
    pos: torch.Tensor | None = None
    """[B, N, 3] in the global frame."""
    normal: torch.Tensor | None = None
    """[B, N, 3] in the global frame, pointing from primary to secondary."""
    tangent: torch.Tensor | None = None
    """[B, N, 3] in the global frame."""
    current_air_time: torch.Tensor | None = None
    """[B, P]: time spent in the current air phase (needs track_air_time=True)."""
    last_air_time: torch.Tensor | None = None
    """[B, P]: length of the last completed air phase (needs track_air_time=True)."""
    current_contact_time: torch.Tensor | None = None
    """[B, P]: time spent in the current contact phase (needs track_air_time=True)."""
    last_contact_time: torch.Tensor | None = None
    """[B, P]: length of the last completed contact phase (needs track_air_time=True)."""
    force_history: torch.Tensor | None = None
    """[B, N, H, 3]: forces over the last H substeps, index 0 is the newest."""
    torque_history: torch.Tensor | None = None
    """[B, N, H, 3]: torques over the last H substeps, index 0 is the newest."""
    dist_history: torch.Tensor | None = None
    """[B, N, H]: penetration depth over the last H substeps, index 0 is the newest."""
class ContactSensor(Sensor[ContactData]):
    """Tracks contacts with automatic pattern expansion to multiple MuJoCo sensors.

    ``edit_spec`` expands the primary pattern into concrete names and adds one
    MuJoCo contact sensor per (primary name, field) pair. ``initialize`` maps
    each of those sensors to a slice of ``data.sensordata``. Reads concatenate
    the slices into a :class:`ContactData`, and ``update`` advances the
    optional air-time and history buffers.
    """

    def __init__(self, cfg: ContactSensorCfg) -> None:
        """Store ``cfg`` and validate option combinations.

        Raises:
            ValueError: If ``cfg.global_frame`` is set (and ``cfg.reduce`` is not
                "netforce") but "normal" or "tangent" is missing from
                ``cfg.fields``; both are needed to build the rotation.
        """
        super().__init__()
        self.cfg = cfg
        if cfg.global_frame and cfg.reduce != "netforce":
            if "normal" not in cfg.fields or "tangent" not in cfg.fields:
                raise ValueError(
                    f"Sensor '{cfg.name}': global_frame=True requires 'normal' and 'tangent' "
                    "in fields (needed to build rotation matrix)"
                )
        self._slots: list[_ContactSlot] = []
        self._data: mjwarp.Data | None = None
        self._device: str | None = None
        self._air_time_state: _AirTimeState | None = None
        self._history_state: dict[str, torch.Tensor] | None = None

    def edit_spec(self, scene_spec: mujoco.MjSpec, entities: dict[str, Entity]) -> None:
        """Expand patterns and add MuJoCo sensors (one per primary x field pair).

        Args:
            scene_spec: Spec that receives the new contact sensors.
            entities: Entities by name, used to expand entity-scoped patterns.
        """
        self._slots.clear()
        primary_names = self._resolve_primary_names(entities, self.cfg.primary)
        # With no secondary, or policy "any", the sensors get no ref filter.
        if self.cfg.secondary is None or self.cfg.secondary_policy == "any":
            secondary_name = None
        else:
            secondary_name = self._resolve_single_secondary(
                entities, self.cfg.secondary, self.cfg.secondary_policy
            )
        # Slots are appended primary-major, so each field's concatenated output
        # is laid out as [primary 0 slots..., primary 1 slots..., ...].
        for prim in primary_names:
            for field in self.cfg.fields:
                sensor_name = f"{self.cfg.name}_{prim}_{field}"
                self._add_contact_sensor_to_spec(
                    scene_spec, sensor_name, prim, secondary_name, field
                )
                self._slots.append(
                    _ContactSlot(
                        primary_name=prim,
                        field_name=field,
                        sensor_name=sensor_name,
                    )
                )

    def initialize(
        self, mj_model: mujoco.MjModel, model: mjwarp.Model, data: mjwarp.Data, device: str
    ) -> None:
        """Map sensors to the sensordata buffer and allocate air-time/history state.

        Raises:
            RuntimeError: If ``edit_spec`` produced no sensors.
        """
        del model
        if not self._slots:
            raise RuntimeError(
                f"There was an error initializing contact sensor '{self.cfg.name}'"
            )
        for slot in self._slots:
            sensor = mj_model.sensor(slot.sensor_name)
            start = sensor.adr[0]
            dim = sensor.dim[0]
            # Column slice of the batched buffer; later reads go through this
            # slice without re-slicing, so it is expected to be a live view.
            slot.data_view = data.sensordata[:, start : start + dim]
        self._data = data
        self._device = device

        n_envs = data.time.shape[0]
        n_primary = len(set(slot.primary_name for slot in self._slots))
        if self.cfg.track_air_time:
            # Timers are kept per primary (not per slot).
            self._air_time_state = _AirTimeState(
                current_air_time=torch.zeros((n_envs, n_primary), device=device),
                last_air_time=torch.zeros((n_envs, n_primary), device=device),
                current_contact_time=torch.zeros((n_envs, n_primary), device=device),
                last_contact_time=torch.zeros((n_envs, n_primary), device=device),
                last_time=torch.zeros((n_envs,), device=device),
            )
        if self.cfg.history_length > 0:
            n_contacts = n_primary * self.cfg.num_slots
            h = self.cfg.history_length
            self._history_state = {}
            if "force" in self.cfg.fields:
                self._history_state["force"] = torch.zeros(
                    (n_envs, n_contacts, h, 3), device=device
                )
            if "torque" in self.cfg.fields:
                self._history_state["torque"] = torch.zeros(
                    (n_envs, n_contacts, h, 3), device=device
                )
            if "dist" in self.cfg.fields:
                self._history_state["dist"] = torch.zeros(
                    (n_envs, n_contacts, h), device=device
                )

    def _compute_data(self) -> ContactData:
        """Current readings plus references to the air-time/history buffers."""
        out = self._extract_sensor_data()
        if self._air_time_state is not None:
            out.current_air_time = self._air_time_state.current_air_time
            out.last_air_time = self._air_time_state.last_air_time
            out.current_contact_time = self._air_time_state.current_contact_time
            out.last_contact_time = self._air_time_state.last_contact_time
        if self._history_state is not None:
            out.force_history = self._history_state.get("force")
            out.torque_history = self._history_state.get("torque")
            out.dist_history = self._history_state.get("dist")
        return out

    def reset(self, env_ids: torch.Tensor | slice | None = None) -> None:
        """Zero air-time and history state for ``env_ids`` (all envs if None)."""
        super().reset(env_ids)
        if env_ids is None:
            env_ids = slice(None)
        # Reset air time state for specified envs.
        if self._air_time_state is not None:
            self._air_time_state.current_air_time[env_ids] = 0.0
            self._air_time_state.last_air_time[env_ids] = 0.0
            self._air_time_state.current_contact_time[env_ids] = 0.0
            self._air_time_state.last_contact_time[env_ids] = 0.0
            if self._data is not None:
                # Restart elapsed-time measurement from the current sim time.
                self._air_time_state.last_time[env_ids] = self._data.time[env_ids]
        # Reset history state for specified envs.
        if self._history_state is not None:
            for buf in self._history_state.values():
                buf[env_ids] = 0.0

    def update(self, dt: float) -> None:
        """Advance air-time tracking and history buffers from current sensordata."""
        super().update(dt)
        if self._air_time_state is None and self._history_state is None:
            return
        # Read sensordata once and share the result between both trackers.
        contact_data = self._extract_sensor_data()
        if self._air_time_state is not None:
            self._update_air_time_tracking(contact_data)
        if self._history_state is not None:
            self._update_history(contact_data)

    def compute_first_contact(self, dt: float, abs_tol: float = 1.0e-8) -> torch.Tensor:
        """Returns [B, P] bool: True for contacts established within last dt seconds.

        Raises:
            RuntimeError: If the sensor was not configured with track_air_time=True.
        """
        if self._air_time_state is None:
            raise RuntimeError(
                f"Sensor '{self.cfg.name}' must have track_air_time=True "
                "to use compute_first_contact"
            )
        is_in_contact = self._air_time_state.current_contact_time > 0.0
        within_dt = self._air_time_state.current_contact_time < (dt + abs_tol)
        return is_in_contact & within_dt

    def compute_first_air(self, dt: float, abs_tol: float = 1.0e-8) -> torch.Tensor:
        """Returns [B, P] bool: True for contacts broken within last dt seconds.

        Raises:
            RuntimeError: If the sensor was not configured with track_air_time=True.
        """
        if self._air_time_state is None:
            raise RuntimeError(
                f"Sensor '{self.cfg.name}' must have track_air_time=True "
                "to use compute_first_air"
            )
        is_in_air = self._air_time_state.current_air_time > 0.0
        within_dt = self._air_time_state.current_air_time < (dt + abs_tol)
        return is_in_air & within_dt

    def _extract_sensor_data(self) -> ContactData:
        """Gather all slot views into a ContactData.

        Each field is concatenated along dim 1 in slot order, giving
        [B, n_primary * num_slots, field_dim]; a trailing dim of 1 is squeezed.

        Raises:
            RuntimeError: If no sensors were created.
        """
        if not self._slots:
            raise RuntimeError(f"Sensor '{self.cfg.name}' not initialized")
        field_chunks: dict[str, list[torch.Tensor]] = {f: [] for f in self.cfg.fields}
        for slot in self._slots:
            assert slot.data_view is not None
            field_dim = _CONTACT_DATA_DIMS[slot.field_name]
            raw = slot.data_view.view(slot.data_view.size(0), self.cfg.num_slots, field_dim)
            field_chunks[slot.field_name].append(raw)
        out = ContactData()
        for field, chunks in field_chunks.items():
            cat = torch.cat(chunks, dim=1)
            if cat.size(-1) == 1:
                cat = cat.squeeze(-1)
            setattr(out, field, cat)
        if self.cfg.global_frame and self.cfg.reduce != "netforce":
            out = self._transform_to_global_frame(out)
        return out

    def _transform_to_global_frame(self, data: ContactData) -> ContactData:
        """Rotate force/torque from the contact frame into the global frame.

        In MuJoCo's contact frame the first axis is the contact normal, the
        second is the reported tangent, and the third is normal x tangent.
        Contact-frame force/torque components follow that order, so the
        rotation has columns [normal, tangent, normal x tangent]. Slots
        without a contact (zero normal) are left unchanged.
        """
        assert data.normal is not None and data.tangent is not None
        normal = data.normal
        tangent = data.tangent
        tangent2 = torch.cross(normal, tangent, dim=-1)
        # Columns are the contact-frame axes expressed in world coordinates.
        R = torch.stack([normal, tangent, tangent2], dim=-1)
        has_contact = torch.norm(normal, dim=-1, keepdim=True) > 1e-8
        if data.force is not None:
            force_global = torch.einsum("...ij,...j->...i", R, data.force)
            data.force = torch.where(has_contact, force_global, data.force)
        if data.torque is not None:
            torque_global = torch.einsum("...ij,...j->...i", R, data.torque)
            data.torque = torch.where(has_contact, torque_global, data.torque)
        return data

    def _update_air_time_tracking(self, contact_data: ContactData | None = None) -> None:
        """Advance the per-primary air/contact timers by the elapsed sim time.

        Args:
            contact_data: Already-extracted readings to reuse; read from
                sensordata when omitted.
        """
        assert self._air_time_state is not None
        if contact_data is None:
            contact_data = self._extract_sensor_data()
        # Without the "found" field there is no contact signal to track.
        if contact_data.found is None or "found" not in self.cfg.fields:
            return
        assert self._data is not None
        state = self._air_time_state
        current_time = self._data.time
        elapsed_time = (current_time - state.last_time).unsqueeze(-1)
        # ``found`` is [B, n_primary * num_slots] in primary-major order, while
        # the timers are [B, n_primary]. A primary counts as in contact when any
        # of its slots reports a match (for num_slots == 1 this is found > 0).
        n_envs, n_primary = state.current_air_time.shape
        is_contact = (contact_data.found > 0).reshape(n_envs, n_primary, -1).any(dim=-1)
        is_first_contact = (state.current_air_time > 0) & is_contact
        is_first_detached = (state.current_contact_time > 0) & ~is_contact
        state.last_air_time[:] = torch.where(
            is_first_contact,
            state.current_air_time + elapsed_time,
            state.last_air_time,
        )
        state.current_air_time[:] = torch.where(
            ~is_contact,
            state.current_air_time + elapsed_time,
            torch.zeros_like(state.current_air_time),
        )
        state.last_contact_time[:] = torch.where(
            is_first_detached,
            state.current_contact_time + elapsed_time,
            state.last_contact_time,
        )
        state.current_contact_time[:] = torch.where(
            is_contact,
            state.current_contact_time + elapsed_time,
            torch.zeros_like(state.current_contact_time),
        )
        state.last_time[:] = current_time

    def _update_history(self, contact_data: ContactData | None = None) -> None:
        """Roll history buffers and insert current contact data at index 0.

        Args:
            contact_data: Already-extracted readings to reuse; read from
                sensordata when omitted.
        """
        assert self._history_state is not None
        if contact_data is None:
            contact_data = self._extract_sensor_data()
        for key in ("force", "torque", "dist"):
            buf = self._history_state.get(key)
            value = getattr(contact_data, key)
            if buf is None or value is None:
                continue
            # Shift along the history axis (dim 2); slot 0 receives the newest sample.
            buf = buf.roll(1, dims=2)
            buf[:, :, 0] = value
            self._history_state[key] = buf

    def _resolve_primary_names(
        self, entities: dict[str, Entity], match: ContactMatch
    ) -> list[str]:
        """Expand the primary match into concrete object names.

        Without an entity, the pattern(s) are returned verbatim. With one, the
        patterns are expanded via the entity and then filtered by
        ``match.exclude``. An exclude entry containing a regex metacharacter is
        used as a regex (``search``); any other entry must match a name exactly.

        Raises:
            ValueError: Unknown entity, unknown mode, or nothing left to match.
        """
        if match.entity in (None, ""):
            result = (
                [match.pattern] if isinstance(match.pattern, str) else list(match.pattern)
            )
            return result
        if match.entity not in entities:
            raise ValueError(
                f"Primary entity '{match.entity}' not found. Available: {list(entities.keys())}"
            )
        ent = entities[match.entity]
        patterns = [match.pattern] if isinstance(match.pattern, str) else match.pattern
        if match.mode == "geom":
            _, names = ent.find_geoms(patterns)
        elif match.mode == "body":
            _, names = ent.find_bodies(patterns)
        elif match.mode == "subtree":
            _, names = ent.find_bodies(patterns)
            if not names:
                raise ValueError(
                    f"Primary subtree pattern '{match.pattern}' matched no bodies in "
                    f"'{match.entity}'"
                )
        else:
            raise ValueError("Primary mode must be one of {'geom','body','subtree'}")
        excludes = match.exclude
        if excludes:
            exclude_patterns = []
            exclude_exact = set()
            for exc in excludes:
                if any(c in exc for c in r".*+?[]{}()\|^$"):
                    exclude_patterns.append(re.compile(exc))
                else:
                    exclude_exact.add(exc)
            if exclude_exact:
                names = [n for n in names if n not in exclude_exact]
            if exclude_patterns:
                names = [n for n in names if not any(rx.search(n) for rx in exclude_patterns)]
            if not names:
                raise ValueError(
                    f"Primary pattern '{match.pattern}' (after excludes) matched "
                    f"no names in '{match.entity}'"
                )
        return names

    def _resolve_single_secondary(
        self,
        entities: dict[str, Entity],
        match: ContactMatch,
        policy: Literal["first", "any", "error"],
    ) -> str | None:
        """Resolve the secondary match to a single name (None for policy "any").

        Raises:
            ValueError: Tuple pattern, unknown entity or mode, no match, or
                several matches under policy "error".
        """
        if policy == "any":
            return None
        if isinstance(match.pattern, tuple):
            raise ValueError(
                "Secondary must specify a single name (string). "
                "Use a single exact name or a regex that resolves to one name, "
                "or set secondary_policy='any' if you want no filter."
            )
        if match.entity in (None, ""):
            if match.mode not in {"geom", "body", "subtree"}:
                raise ValueError("Secondary mode must be one of {'geom','body','subtree'}")
            return match.pattern
        if match.entity not in entities:
            raise ValueError(
                f"Secondary entity '{match.entity}' not found. "
                f"Available: {list(entities.keys())}"
            )
        ent = entities[match.entity]
        # Subtree secondaries are passed through unexpanded.
        if match.mode == "subtree":
            return match.pattern
        if match.mode == "geom":
            _, names = ent.find_geoms(match.pattern)
        elif match.mode == "body":
            _, names = ent.find_bodies(match.pattern)
        else:
            raise ValueError("Secondary mode must be one of {'geom','body','subtree'}")
        if not names:
            raise ValueError(
                f"Secondary pattern '{match.pattern}' matched nothing in '{match.entity}'"
            )
        if len(names) == 1 or policy == "first":
            return names[0]
        raise ValueError(
            f"Secondary pattern '{match.pattern}' matched multiple: {names}. "
            f"Be explicit or set secondary_policy='first' or 'any'."
        )

    def _add_contact_sensor_to_spec(
        self,
        scene_spec: mujoco.MjSpec,
        sensor_name: str,
        primary_name: str,
        secondary_name: str | None,
        field: str,
    ) -> None:
        """Add one mjSENS_CONTACT sensor reporting ``field`` for ``primary_name``.

        intprm is [field bitmask, reduce code, num_slots]. Names are prefixed
        with "<entity>/" when the corresponding match is entity-scoped.
        """
        data_bits = 1 << _CONTACT_DATA_MAP[field]
        reduce_mode = _CONTACT_REDUCE_MAP[self.cfg.reduce]
        intprm = [data_bits, reduce_mode, self.cfg.num_slots]
        primary_entity = self.cfg.primary.entity
        if primary_entity and primary_entity != "":
            prefixed_primary = f"{primary_entity}/{primary_name}"
        else:
            prefixed_primary = primary_name
        kwargs = {
            "name": sensor_name,
            "type": mujoco.mjtSensor.mjSENS_CONTACT,
            "objtype": _MODE_TO_OBJTYPE[self.cfg.primary.mode],
            "objname": prefixed_primary,
            "intprm": intprm,
        }
        if secondary_name is not None:
            assert self.cfg.secondary is not None
            secondary_entity = self.cfg.secondary.entity
            if secondary_entity and secondary_entity != "":
                prefixed_secondary = f"{secondary_entity}/{secondary_name}"
            else:
                prefixed_secondary = secondary_name
            kwargs["reftype"] = _MODE_TO_OBJTYPE[self.cfg.secondary.mode]
            kwargs["refname"] = prefixed_secondary
        if self.cfg.debug:

            def _ename(v):
                return getattr(v, "name", str(v))

            objtype_name = _ename(kwargs["objtype"]).removeprefix("mjOBJ_")
            reftype_val = kwargs.get("reftype")
            refname_val = kwargs.get("refname")
            reftype_name = (
                _ename(reftype_val).removeprefix("mjOBJ_")
                if reftype_val is not None
                else "<any>"
            )
            ref_str = "<any>" if refname_val is None else f"{reftype_name}:{refname_val}"
            print(
                "Adding contact sensor\n"
                f" name : {sensor_name}\n"
                f" object : {objtype_name}:{kwargs['objname']}\n"
                f" ref : {ref_str}\n"
                f" field : {field} bits=0b{intprm[0]:b}\n"
                f" reduce : {self.cfg.reduce} num_slots={self.cfg.num_slots}"
            )
        scene_spec.add_sensor(**kwargs)