"""Wrappers for XML-defined actuators.

This module provides wrappers for actuators that are already defined in robot
XML/MJCF files, rather than creating new actuators.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar
import mujoco
import torch
from mjlab.actuator.actuator import Actuator, ActuatorCfg, ActuatorCmd
if TYPE_CHECKING:
from mjlab.entity import Entity
# Type variable bound to ActuatorCfg, so that each XmlActuator subclass can be
# parameterized by its own config class (e.g. XmlActuator[XmlPositionActuatorCfg]).
XmlActuatorCfgT = TypeVar("XmlActuatorCfgT", bound=ActuatorCfg)
class XmlActuator(Actuator[XmlActuatorCfgT], Generic[XmlActuatorCfgT]):
    """Base class for actuators that are already defined in the XML/MJCF.

    Rather than adding new actuators to the spec, this wrapper looks up the
    existing actuator whose ``target`` matches each requested target name and
    keeps only the targets for which such an actuator exists.
    """

    def __init__(
        self,
        cfg: XmlActuatorCfgT,
        entity: Entity,
        target_ids: list[int],
        target_names: list[str],
    ) -> None:
        super().__init__(cfg, entity, target_ids, target_names)

    def edit_spec(self, spec: mujoco.MjSpec, target_names: list[str]) -> None:
        """Bind this wrapper to the XML actuators that drive ``target_names``.

        For each target name, the first actuator in ``spec`` whose ``target``
        equals that name is appended to ``self._mjs_actuators``. Targets with
        no matching actuator are dropped, and ``self._target_ids_list`` and
        ``self._target_names`` are narrowed to the remaining targets.

        Args:
            spec: The MuJoCo spec to search for existing actuators.
            target_names: Target names to match. Assumed to be index-aligned
                with ``self._target_ids_list`` (TODO confirm with caller).

        Raises:
            ValueError: If none of ``target_names`` has a matching actuator.
        """
        filtered_target_ids: list[int] = []
        filtered_target_names: list[str] = []
        for i, target_name in enumerate(target_names):
            actuator = self._find_actuator_for_target(spec, target_name)
            if actuator is None:
                # No XML actuator drives this target; leave it out.
                continue
            self._mjs_actuators.append(actuator)
            # IDs are looked up by position in ``target_names``, so the two
            # lists are expected to share the same ordering.
            filtered_target_ids.append(self._target_ids_list[i])
            filtered_target_names.append(target_name)

        if not filtered_target_names:
            raise ValueError(
                "No XML actuators found for any targets matching the patterns. "
                f"Searched targets: {target_names}. "
                "XML actuator config expects actuators to already exist in the XML."
            )

        # Update target IDs and names to only include those with actuators.
        self._target_ids_list = filtered_target_ids
        self._target_names = filtered_target_names

    def _find_actuator_for_target(
        self, spec: mujoco.MjSpec, target_name: str
    ) -> mujoco.MjsActuator | None:
        """Return the first actuator in ``spec`` that targets ``target_name``.

        The target may be a joint, tendon, or site. Matching compares only
        ``actuator.target`` against the name; the transmission type is not
        checked. Returns ``None`` when no actuator targets the given name.
        """
        return next(
            (act for act in spec.actuators if act.target == target_name),
            None,
        )
@dataclass(kw_only=True)
class XmlPositionActuatorCfg(ActuatorCfg):
    """Config for wrapping <position> actuators that already exist in the XML."""

    def build(
        self, entity: Entity, target_ids: list[int], target_names: list[str]
    ) -> XmlPositionActuator:
        """Construct an XmlPositionActuator bound to this config."""
        return XmlPositionActuator(
            cfg=self,
            entity=entity,
            target_ids=target_ids,
            target_names=target_names,
        )
class XmlPositionActuator(XmlActuator[XmlPositionActuatorCfg]):
    """Wrapper around <position> actuators defined in the XML."""

    def compute(self, cmd: ActuatorCmd) -> torch.Tensor:
        """Return ``cmd.position_target`` unchanged."""
        position_target = cmd.position_target
        return position_target
@dataclass(kw_only=True)
class XmlMotorActuatorCfg(ActuatorCfg):
    """Config for wrapping <motor> actuators that already exist in the XML."""

    def build(
        self, entity: Entity, target_ids: list[int], target_names: list[str]
    ) -> XmlMotorActuator:
        """Construct an XmlMotorActuator bound to this config."""
        return XmlMotorActuator(
            cfg=self,
            entity=entity,
            target_ids=target_ids,
            target_names=target_names,
        )
class XmlMotorActuator(XmlActuator[XmlMotorActuatorCfg]):
    """Wrapper around <motor> actuators defined in the XML."""

    def compute(self, cmd: ActuatorCmd) -> torch.Tensor:
        """Return ``cmd.effort_target`` unchanged."""
        effort_target = cmd.effort_target
        return effort_target
@dataclass(kw_only=True)
class XmlVelocityActuatorCfg(ActuatorCfg):
    """Config for wrapping <velocity> actuators that already exist in the XML."""

    def build(
        self, entity: Entity, target_ids: list[int], target_names: list[str]
    ) -> XmlVelocityActuator:
        """Construct an XmlVelocityActuator bound to this config."""
        return XmlVelocityActuator(
            cfg=self,
            entity=entity,
            target_ids=target_ids,
            target_names=target_names,
        )
class XmlVelocityActuator(XmlActuator[XmlVelocityActuatorCfg]):
    """Wrapper around <velocity> actuators defined in the XML."""

    def compute(self, cmd: ActuatorCmd) -> torch.Tensor:
        """Return ``cmd.velocity_target`` unchanged."""
        velocity_target = cmd.velocity_target
        return velocity_target
@dataclass(kw_only=True)
class XmlMuscleActuatorCfg(ActuatorCfg):
    """Config for wrapping <muscle> actuators that already exist in the XML."""

    def build(
        self, entity: Entity, target_ids: list[int], target_names: list[str]
    ) -> XmlMuscleActuator:
        """Construct an XmlMuscleActuator bound to this config."""
        return XmlMuscleActuator(
            cfg=self,
            entity=entity,
            target_ids=target_ids,
            target_names=target_names,
        )
class XmlMuscleActuator(XmlActuator[XmlMuscleActuatorCfg]):
    """Wrapper around <muscle> actuators defined in the XML."""

    def compute(self, cmd: ActuatorCmd) -> torch.Tensor:
        """Return ``cmd.effort_target`` unchanged."""
        # NOTE(review): MuJoCo muscle controls are usually normalized
        # activations; confirm callers put that quantity in effort_target.
        effort_target = cmd.effort_target
        return effort_target