"""Base actuator interface."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Generic, TypeVar
import mujoco
import mujoco_warp as mjwarp
import torch
if TYPE_CHECKING:
from mjlab.entity import Entity
from mjlab.entity.data import EntityData
# Type variable for the concrete config type an Actuator is parameterized over
# (used as `Actuator(ABC, Generic[ActuatorCfgT])` and for its `cfg` argument).
# The bound is a forward-reference string because ActuatorCfg is defined below.
ActuatorCfgT = TypeVar("ActuatorCfgT", bound="ActuatorCfg")
class TransmissionType(str, Enum):
  """What an actuator's transmission attaches to.

  Members subclass ``str``, so they compare equal to their plain string values
  (e.g. ``TransmissionType.JOINT == "joint"``).
  """

  # Actuator drives a joint.
  JOINT = "joint"
  # Actuator drives a tendon.
  TENDON = "tendon"
  # Actuator applies effort at a site.
  SITE = "site"
@dataclass(kw_only=True)
class ActuatorCfg(ABC):
  """Base configuration shared by all actuator types.

  Subclasses implement :meth:`build` to create the concrete :class:`Actuator`
  for an entity. Field values are validated in ``__post_init__``.
  """

  target_names_expr: tuple[str, ...]
  """Targets that are part of this actuator group.

  Can be a tuple of names or tuple of regex expressions.
  Interpreted based on transmission_type (joint/tendon/site).
  """
  transmission_type: TransmissionType = TransmissionType.JOINT
  """Transmission type. Defaults to JOINT."""
  armature: float = 0.0
  """Reflected rotor inertia. Must be non-negative; must be 0 for SITE."""
  frictionloss: float = 0.0
  """Friction loss force limit.

  Applies a constant friction force opposing motion, independent of load or velocity.
  Also known as dry friction or load-independent friction.
  Must be non-negative; must be 0 for SITE.
  """

  def __post_init__(self) -> None:
    """Validate field values.

    Raises:
      ValueError: If ``armature`` or ``frictionloss`` is negative or NaN, or if
        either is positive for a SITE transmission.
    """
    name = self.__class__.__name__
    # Explicit raises instead of `assert`: asserts are stripped under `python -O`,
    # which would silently accept invalid configs. `not (x >= 0.0)` keeps the
    # original behavior of also rejecting NaN.
    if not (self.armature >= 0.0):
      raise ValueError(
        f"{name}: armature must be non-negative, got {self.armature}."
      )
    if not (self.frictionloss >= 0.0):
      raise ValueError(
        f"{name}: frictionloss must be non-negative, got {self.frictionloss}."
      )
    if self.transmission_type == TransmissionType.SITE:
      if self.armature > 0.0 or self.frictionloss > 0.0:
        raise ValueError(
          f"{name}: armature and frictionloss are not supported for "
          "SITE transmission type."
        )

  @abstractmethod
  def build(
    self, entity: Entity, target_ids: list[int], target_names: list[str]
  ) -> Actuator:
    """Build actuator instance.

    Args:
      entity: Entity this actuator belongs to.
      target_ids: Local target indices (for indexing entity arrays).
      target_names: Target names corresponding to target_ids.

    Returns:
      Actuator instance.
    """
    raise NotImplementedError
@dataclass
class ActuatorCmd:
  """Targets plus measured state, bundled for one actuator evaluation.

  An actuator builds this from entity data and passes it to its ``compute()``
  method, which turns it into low-level control signals. Every tensor has
  shape ``(num_envs, num_targets)``.
  """

  position_target: torch.Tensor
  """Commanded position: joint positions, tendon lengths, or site positions."""
  velocity_target: torch.Tensor
  """Commanded velocity: joint, tendon, or site velocities."""
  effort_target: torch.Tensor
  """Feedforward effort term (torques or forces)."""
  pos: torch.Tensor
  """Measured position: joint positions, tendon lengths, or site positions."""
  vel: torch.Tensor
  """Measured velocity: joint, tendon, or site velocities."""
class Actuator(ABC, Generic[ActuatorCfgT]):
  """Abstract base class for actuators driving joints, tendons, or sites.

  Subclasses must implement :meth:`edit_spec`, which adds actuators to the
  MjSpec before compilation, and :meth:`compute`, which maps an
  :class:`ActuatorCmd` to control signals. Call :meth:`initialize` after
  compilation and before using :attr:`target_ids`, :attr:`ctrl_ids` or
  :meth:`get_command`.
  """

  def __init__(
    self,
    cfg: ActuatorCfgT,
    entity: Entity,
    target_ids: list[int],
    target_names: list[str],
  ) -> None:
    # Construction-time inputs, stored as given.
    self.cfg = cfg
    self.entity = entity
    self._target_ids_list = target_ids
    self._target_names = target_names

    # Spec actuator handles. The base class never appends here; presumably
    # subclasses fill it in edit_spec() (verify). initialize() reads each `.id`.
    self._mjs_actuators: list[mujoco.MjsActuator] = []

    # Device tensors, populated by initialize().
    self._target_ids: torch.Tensor | None = None
    self._ctrl_ids: torch.Tensor | None = None
    self._site_zeros: torch.Tensor | None = None

  @property
  def target_ids(self) -> torch.Tensor:
    """Local target indices as a long tensor; requires initialize() first."""
    ids = self._target_ids
    assert ids is not None
    return ids

  @property
  def target_names(self) -> list[str]:
    """Names of the joints, tendons, or sites this actuator controls."""
    return self._target_names

  @property
  def transmission_type(self) -> TransmissionType:
    """Transmission type, read from the config."""
    return self.cfg.transmission_type

  @property
  def ctrl_ids(self) -> torch.Tensor:
    """Global control-input indices as a long tensor; requires initialize() first."""
    ids = self._ctrl_ids
    assert ids is not None
    return ids

  @abstractmethod
  def edit_spec(self, spec: mujoco.MjSpec, target_names: list[str]) -> None:
    """Add this actuator's elements to the entity's MjSpec.

    Runs during entity construction, before the spec is compiled.

    Args:
      spec: The entity's MjSpec to modify.
      target_names: Names of the joints, tendons, or sites this actuator
        controls.
    """
    raise NotImplementedError

  def initialize(
    self,
    mj_model: mujoco.MjModel,
    model: mjwarp.Model,
    data: mjwarp.Data,
    device: str,
  ) -> None:
    """Create the device-side index tensors once the model is compiled.

    Args:
      mj_model: Compiled MuJoCo model (unused by the base class).
      model: Compiled mjwarp model (unused by the base class).
      data: mjwarp data; only ``nworld`` is read, and only for SITE.
      device: Torch device for the created tensors (e.g. "cuda", "cpu").
    """
    del mj_model, model  # The base implementation only needs `data`/`device`.
    self._target_ids = torch.tensor(
      self._target_ids_list, dtype=torch.long, device=device
    )
    self._ctrl_ids = torch.tensor(
      [spec_act.id for spec_act in self._mjs_actuators],
      dtype=torch.long,
      device=device,
    )
    if self.transmission_type != TransmissionType.SITE:
      return
    # get_command() hands this one zero buffer out for every non-effort SITE
    # field, so nothing is allocated per call.
    self._site_zeros = torch.zeros(
      (data.nworld, len(self._target_ids_list)), device=device
    )

  def get_command(self, data: EntityData) -> ActuatorCmd:
    """Gather this actuator's targets and state from entity data.

    Args:
      data: Entity data holding the current state and targets.

    Returns:
      An ActuatorCmd whose columns are selected by :attr:`target_ids`. For SITE
      transmissions only ``effort_target`` comes from ``data``; the other four
      fields are the same preallocated zero tensor.

    Raises:
      ValueError: If the transmission type is not recognized.
    """
    kind = self.transmission_type
    if kind == TransmissionType.JOINT:
      idx = self.target_ids
      return ActuatorCmd(
        position_target=data.joint_pos_target[:, idx],
        velocity_target=data.joint_vel_target[:, idx],
        effort_target=data.joint_effort_target[:, idx],
        pos=data.joint_pos[:, idx],
        vel=data.joint_vel[:, idx],
      )
    if kind == TransmissionType.TENDON:
      idx = self.target_ids
      return ActuatorCmd(
        position_target=data.tendon_len_target[:, idx],
        velocity_target=data.tendon_vel_target[:, idx],
        effort_target=data.tendon_effort_target[:, idx],
        pos=data.tendon_len[:, idx],
        vel=data.tendon_vel[:, idx],
      )
    if kind == TransmissionType.SITE:
      zeros = self._site_zeros
      assert zeros is not None
      # NOTE(review): the same tensor object is shared by four fields and by
      # every call; in-place edits by compute() would leak into later calls.
      return ActuatorCmd(
        position_target=zeros,
        velocity_target=zeros,
        effort_target=data.site_effort_target[:, self.target_ids],
        pos=zeros,
        vel=zeros,
      )
    raise ValueError(f"Unknown transmission type: {self.transmission_type}")

  @abstractmethod
  def compute(self, cmd: ActuatorCmd) -> torch.Tensor:
    """Turn a high-level command into low-level control signals.

    Args:
      cmd: Targets and current state for this actuator's targets.

    Returns:
      Control tensor of shape (num_envs, num_actuators).
    """
    raise NotImplementedError

  # Optional hooks; the base implementations do nothing.

  def reset(self, env_ids: torch.Tensor | slice | None = None) -> None:
    """Reset per-environment actuator state. This is a no-op in the base class.

    Override this in subclasses that keep internal state.

    Args:
      env_ids: Environments to reset; None means all of them.
    """
    del env_ids  # Nothing to reset here.

  def update(self, dt: float) -> None:
    """Advance actuator state after a simulation step. This is a no-op here.

    Override this in subclasses that need per-step bookkeeping.

    Args:
      dt: Step duration in seconds.
    """
    del dt  # Nothing to update here.