Source code for mjlab.actuator.pd_actuator

"""An ideal PD control actuator."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar

import mujoco
import mujoco_warp as mjwarp
import torch

from mjlab.actuator.actuator import Actuator, ActuatorCfg, ActuatorCmd
from mjlab.utils.spec import create_motor_actuator

if TYPE_CHECKING:
  from mjlab.entity import Entity

# Type parameter for the config carried by ``IdealPdActuator``. It is bound to
# ``IdealPdActuatorCfg`` (given as a string because that class is defined
# further down in this module).
IdealPdCfgT = TypeVar("IdealPdCfgT", bound="IdealPdActuatorCfg")


@dataclass(kw_only=True)
class IdealPdActuatorCfg(ActuatorCfg):
  """Settings for an :class:`IdealPdActuator`.

  All fields are keyword-only. ``stiffness`` and ``damping`` have no default
  and must be supplied; ``effort_limit`` defaults to infinity.
  """

  stiffness: float
  """Proportional gain applied to the position error."""

  damping: float
  """Derivative gain applied to the velocity error."""

  effort_limit: float = float("inf")
  """Largest force/torque magnitude the actuator may output."""

  def build(
    self, entity: Entity, target_ids: list[int], target_names: list[str]
  ) -> IdealPdActuator:
    """Instantiate the actuator described by this config.

    Args:
      entity: Entity handed through to the actuator constructor.
      target_ids: Indices of the targets driven by the actuator.
      target_names: Names of the targets driven by the actuator.

    Returns:
      A new ``IdealPdActuator`` constructed with this config.
    """
    return IdealPdActuator(
      cfg=self,
      entity=entity,
      target_ids=target_ids,
      target_names=target_names,
    )
class IdealPdActuator(Actuator, Generic[IdealPdCfgT]):
  """Actuator that evaluates an ideal PD law.

  The commanded effort is::

    stiffness * (position_target - pos)
      + damping * (velocity_target - vel)
      + effort_target

  clamped element-wise to ``[-force_limit, force_limit]``. Gains and limits are
  stored as ``(num_envs, num_joints)`` tensors that are allocated in
  :meth:`initialize`; until then they are ``None``.
  """

  def __init__(
    self,
    cfg: IdealPdCfgT,
    entity: Entity,
    target_ids: list[int],
    target_names: list[str],
  ) -> None:
    """Store the config and declare the (not yet allocated) gain buffers.

    Args:
      cfg: Actuator configuration.
      entity: Entity forwarded to the base ``Actuator`` constructor.
      target_ids: Indices of the driven targets.
      target_names: Names of the driven targets.
    """
    super().__init__(cfg, entity, target_ids, target_names)

    # Live values; filled in by ``initialize``.
    self.stiffness: torch.Tensor | None = None
    self.damping: torch.Tensor | None = None
    self.force_limit: torch.Tensor | None = None

    # Snapshots of the values taken at initialization time.
    self.default_stiffness: torch.Tensor | None = None
    self.default_damping: torch.Tensor | None = None
    self.default_force_limit: torch.Tensor | None = None

  def edit_spec(self, spec: mujoco.MjSpec, target_names: list[str]) -> None:
    """Create one motor actuator in ``spec`` for every name in ``target_names``.

    Each created actuator is recorded in ``self._mjs_actuators``.

    Args:
      spec: MuJoCo spec to edit.
      target_names: Names of the targets that each receive an actuator.
    """
    cfg = self.cfg
    for name in target_names:
      motor = create_motor_actuator(
        spec,
        name,
        effort_limit=cfg.effort_limit,
        armature=cfg.armature,
        frictionloss=cfg.frictionloss,
        transmission_type=cfg.transmission_type,
      )
      self._mjs_actuators.append(motor)

  def initialize(
    self,
    mj_model: mujoco.MjModel,
    model: mjwarp.Model,
    data: mjwarp.Data,
    device: str,
  ) -> None:
    """Allocate per-environment gain and limit tensors from the config.

    Args:
      mj_model: Forwarded to the base class ``initialize``.
      model: Forwarded to the base class ``initialize``.
      data: Simulation data; ``data.nworld`` sets the number of environments.
      device: Torch device on which the tensors are created.
    """
    super().initialize(mj_model, model, data, device)

    shape = (data.nworld, len(self._target_names))

    def _filled(value: float) -> torch.Tensor:
      # Constant float tensor with one entry per (environment, target).
      return torch.full(shape, value, dtype=torch.float, device=device)

    self.stiffness = _filled(self.cfg.stiffness)
    self.damping = _filled(self.cfg.damping)
    self.force_limit = _filled(self.cfg.effort_limit)

    # Independent copies, so later edits of the live tensors leave them intact.
    self.default_stiffness = self.stiffness.clone()
    self.default_damping = self.damping.clone()
    self.default_force_limit = self.force_limit.clone()

  def compute(self, cmd: ActuatorCmd) -> torch.Tensor:
    """Evaluate the PD law for ``cmd`` and clamp the result.

    Args:
      cmd: Command providing ``position_target``, ``velocity_target``,
        ``effort_target`` and the measured ``pos`` and ``vel``.

    Returns:
      The effort after clamping to the current force limit.
    """
    assert self.stiffness is not None
    assert self.damping is not None

    position_gap = cmd.position_target - cmd.pos
    velocity_gap = cmd.velocity_target - cmd.vel

    # Accumulate in place into the freshly created proportional term.
    effort = self.stiffness * position_gap
    effort += self.damping * velocity_gap
    effort += cmd.effort_target
    return self._clip_effort(effort)

  def _clip_effort(self, effort: torch.Tensor) -> torch.Tensor:
    """Clamp ``effort`` element-wise to ``[-force_limit, force_limit]``."""
    assert self.force_limit is not None
    upper = self.force_limit
    return torch.clamp(effort, -upper, upper)

  def set_gains(
    self,
    env_ids: torch.Tensor | slice,
    kp: torch.Tensor | None = None,
    kd: torch.Tensor | None = None,
  ) -> None:
    """Overwrite the PD gains of the selected environments.

    A gain passed as ``None`` is left unchanged. A 1-D gain tensor gets a
    trailing axis appended before it is assigned.

    Args:
      env_ids: Environment indices to update.
      kp: New proportional gains. Shape: (num_envs, num_actuators) or
        (num_envs,).
      kd: New derivative gains. Shape: (num_envs, num_actuators) or
        (num_envs,).
    """
    assert self.stiffness is not None
    assert self.damping is not None

    for new_gain, buffer in ((kp, self.stiffness), (kd, self.damping)):
      if new_gain is None:
        continue
      if new_gain.ndim == 1:
        new_gain = new_gain.unsqueeze(-1)
      buffer[env_ids] = new_gain

  def set_effort_limit(
    self, env_ids: torch.Tensor | slice, effort_limit: torch.Tensor
  ) -> None:
    """Overwrite the effort limits of the selected environments.

    A 1-D ``effort_limit`` gets a trailing axis appended before it is assigned.

    Args:
      env_ids: Environment indices to update.
      effort_limit: New effort limits. Shape: (num_envs, num_actuators) or
        (num_envs,).
    """
    assert self.force_limit is not None
    limits = effort_limit.unsqueeze(-1) if effort_limit.ndim == 1 else effort_limit
    self.force_limit[env_ids] = limits