diff --git a/.claude/skills/add-atomic-action/SKILL.md b/.claude/skills/add-atomic-action/SKILL.md new file mode 100644 index 00000000..9ae574a5 --- /dev/null +++ b/.claude/skills/add-atomic-action/SKILL.md @@ -0,0 +1,197 @@ +--- +name: add-atomic-action +description: Use when adding a new observation, event, reward, action, dataset, or randomization functor to an EmbodiChain environment +--- + +# Add Atomic Action + +Scaffold a new atomic action following EmbodiChain's `ActionCfg` / `AtomicAction` pattern. + +## When to Use + +- User asks to add a new motion primitive (push, wipe, insert, hand-over, …) +- User says "add a new atomic action", "create a custom action", "implement a push action" +- User wants to extend `AtomicActionEngine` with a behaviour not covered by the built-ins + +## Key Files + +| Purpose | Path | +|---------|------| +| Base classes (`ActionCfg`, `AtomicAction`, `ObjectSemantics`) | `embodichain/lab/sim/atomic_actions/core.py` | +| Built-in actions (reference implementations) | `embodichain/lab/sim/atomic_actions/actions.py` | +| Engine + global registry (`register_action`) | `embodichain/lab/sim/atomic_actions/engine.py` | +| Public API exports | `embodichain/lab/sim/atomic_actions/__init__.py` | +| Reference docs | `docs/source/overview/sim/atomic_actions.md` | + +## Steps + +### 1. Define the config + +Add a `@configclass`-decorated class that extends `ActionCfg` (or `MoveActionCfg` / +`GraspActionCfg` if the new action reuses arm/gripper fields). + +Place it in `embodichain/lab/sim/atomic_actions/actions.py` alongside the existing configs, +or in a new file if the action is large. + +```python +from embodichain.utils import configclass +from embodichain.lab.sim.atomic_actions.core import ActionCfg # or MoveActionCfg + +@configclass +class PushActionCfg(ActionCfg): + name: str = "push" # must match the registry key + push_distance: float = 0.05 # metres to push forward + push_speed: int = 30 # waypoints for the push phase + control_part: str = "arm" # robot segment to control +``` + +**Rules:** +- `name` must be unique and match the string passed to `register_action`. +- Inherit from `GraspActionCfg` when the action needs hand open/close fields. +- All fields must have defaults — configs are instantiated without arguments in tests. + +### 2. Implement the action class + +Subclass `AtomicAction` and implement the two abstract methods. + +```python +import torch +from typing import Optional, Union +from embodichain.lab.sim.atomic_actions.core import AtomicAction, ObjectSemantics + +class PushAction(AtomicAction): + """Push an object forward by a fixed distance.""" + + def __init__(self, motion_generator, cfg: PushActionCfg | None = None): + super().__init__(motion_generator, cfg=cfg or PushActionCfg()) + self.arm_joint_ids = self.robot.get_joint_ids(name=self.cfg.control_part) + + # ------------------------------------------------------------------ + def execute( + self, + target: Union[torch.Tensor, ObjectSemantics], + start_qpos: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[bool, torch.Tensor, list]: + """Plan the push motion and return a joint trajectory. + + Args: + target: EEF pose tensor (n_envs, 4, 4) or ObjectSemantics. + start_qpos: Starting joint positions (n_envs, dof). Uses current + robot state when None. + + Returns: + Tuple of (is_success, trajectory, joint_ids) where + trajectory has shape (n_envs, n_waypoints, len(joint_ids)). + """ + # 1. Resolve target pose + # 2. Plan trajectory with self.motion_generator + # 3. 
Return result + return is_success, trajectory, self.arm_joint_ids + + # ------------------------------------------------------------------ + def validate( + self, + target: Union[torch.Tensor, ObjectSemantics], + start_qpos: Optional[torch.Tensor] = None, + **kwargs, + ) -> bool: + """Fast feasibility check — no trajectory generated. + + Returns: + True if the action can be attempted. + """ + return True # add IK reachability check here if needed +``` + +**Rules:** +- `execute()` must always return `(is_success: bool, trajectory: Tensor, joint_ids: list)`. +- `trajectory` shape: `(n_envs, n_waypoints, len(joint_ids))`. +- `joint_ids` tells the engine which DOF columns the trajectory covers. +- `validate()` must be cheap — no motion planning allowed. +- Call `super().__init__()` — it sets `self.robot`, `self.motion_generator`, and `self.cfg`. + +### 3. Register the action + +Register the new class so `AtomicActionEngine` can discover it by name. + +**Option A — register at module load (built-ins style)** + +In `embodichain/lab/sim/atomic_actions/engine.py`, add to the `_builtin_action_map` dict: + +```python +_builtin_action_map: dict[str, type[AtomicAction]] = { + "move": MoveAction, + "pickup": PickUpAction, + "place": PlaceAction, + "push": PushAction, # ← add here +} +``` + +**Option B — register at runtime (custom / plugin style)** + +```python +from embodichain.lab.sim.atomic_actions import register_action +register_action("push", PushAction) +``` + +### 4. Export from the public API + +Add config and action class to `embodichain/lab/sim/atomic_actions/__init__.py`: + +```python +from .actions import PushAction, PushActionCfg + +__all__ = [ + ..., + "PushAction", + "PushActionCfg", +] +``` + +### 5. Update the supported actions table + +Add a row to the table in `docs/source/overview/sim/atomic_actions.md` under +"Supported Actions": + +```markdown +| `PushAction` | `PushActionCfg` | `Tensor (4,4)` — contact pose | Approach → push forward | +``` + +### 6. 
Write a test + +Add a test in `tests/sim/atomic_actions/` (append to an existing file or create a new one): + +```python +def test_push_action_cfg_defaults(): + cfg = PushActionCfg() + assert cfg.name == "push" + assert cfg.push_distance == 0.05 + +def test_push_action_validate(mock_motion_generator): + action = PushAction(mock_motion_generator) + assert action.validate(target=torch.eye(4)) is True +``` + +## Common Mistakes + +| Mistake | Fix | +|---------|-----| +| `name` in config doesn't match registry key | Keep `cfg.name` identical to the string in `register_action("push", ...)` | +| Returning `trajectory` without `joint_ids` | Always return the 3-tuple `(bool, Tensor, list)` | +| `trajectory` shape `(n_envs, dof, n_waypoints)` | Correct shape is `(n_envs, n_waypoints, dof)` | +| Doing motion planning inside `validate()` | `validate()` must be fast — IK check only | +| Not calling `super().__init__()` | Required to set `self.robot`, `self.motion_generator`, `self.cfg` | +| Inheriting `MoveActionCfg` instead of `ActionCfg` | Use `MoveActionCfg` only when the action reuses arm-control fields; otherwise use `ActionCfg` | +| Forgetting to export from `__init__.py` | Users import from the public API — missing exports cause `ImportError` | + +## Quick Reference + +| Step | Action | +|------|--------| +| 1 | Define `@configclass` config extending `ActionCfg` with `name` field | +| 2 | Subclass `AtomicAction`, implement `execute()` and `validate()` | +| 3 | Register: add to `_builtin_action_map` or call `register_action()` | +| 4 | Export from `__init__.py` | +| 5 | Add row to supported-actions table in overview docs | +| 6 | Write tests for config defaults and `validate()` | diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst new file mode 100644 index 00000000..181086c3 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst @@ -0,0 +1,89 @@ +embodichain.lab.sim.atomic_actions +================================== + +.. automodule:: embodichain.lab.sim.atomic_actions + + .. rubric:: Classes + + .. autosummary:: + + Affordance + InteractionPoints + ObjectSemantics + ActionCfg + AtomicAction + MoveActionCfg + MoveAction + PickUpActionCfg + PickUpAction + PlaceActionCfg + PlaceAction + AtomicActionEngine + +.. currentmodule:: embodichain.lab.sim.atomic_actions + +Core +---- + +.. autoclass:: Affordance + :members: + :show-inheritance: + +.. autoclass:: InteractionPoints + :members: + :show-inheritance: + +.. autoclass:: ObjectSemantics + :members: + :show-inheritance: + +.. autoclass:: ActionCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + +.. autoclass:: AtomicAction + :members: + :show-inheritance: + +Actions +------- + +.. autoclass:: MoveActionCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + :show-inheritance: + +.. autoclass:: MoveAction + :members: + :show-inheritance: + +.. autoclass:: PickUpActionCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + :show-inheritance: + +.. autoclass:: PickUpAction + :members: + :show-inheritance: + +.. autoclass:: PlaceActionCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + :show-inheritance: + +.. autoclass:: PlaceAction + :members: + :show-inheritance: + +Engine & Registry +----------------- + +.. autoclass:: AtomicActionEngine + :members: + :show-inheritance: + +.. 
autofunction:: register_action + +.. autofunction:: unregister_action + +.. autofunction:: get_registered_actions diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.rst index 9977639d..412f570d 100644 --- a/docs/source/api_reference/embodichain/embodichain.lab.sim.rst +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.rst @@ -22,8 +22,9 @@ management, materials, sensors, planning/IK utilities, and action helpers. objects robots sensors - planners solvers + planners + atomic_actions types utility @@ -125,6 +126,13 @@ Planners embodichain.lab.sim.planners +Atomic Actions +-------------- + +.. toctree:: + :maxdepth: 1 + + embodichain.lab.sim.atomic_actions Shared Types ------------ diff --git a/docs/source/introduction.rst b/docs/source/introduction.rst index eae35d96..d437b4fb 100644 --- a/docs/source/introduction.rst +++ b/docs/source/introduction.rst @@ -44,11 +44,11 @@ Getting Started To get started with EmbodiChain, follow these steps: - `Installation - Guide `__ + Guide `__ - `Quick Start - Tutorial `__ + Tutorial `__ - `API - Reference `__ + Reference `__ Contribution Guide ------------------ diff --git a/docs/source/overview/sim/atomic_actions.md b/docs/source/overview/sim/atomic_actions.md new file mode 100644 index 00000000..979df571 --- /dev/null +++ b/docs/source/overview/sim/atomic_actions.md @@ -0,0 +1,241 @@ +# Atomic Actions + +```{currentmodule} embodichain.lab.sim.atomic_actions +``` + +Atomic actions are the building blocks for automated robot motion generation. Each action encapsulates a complete, self-contained motion primitive — such as picking up an object or moving to a pose — that can be chained together to form complex manipulation workflows. + +## Design Overview + +The module is organized into three layers: + +``` +AtomicActionEngine ← orchestrates a sequence of actions + │ + ├── AtomicAction(s) ← each action plans one motion primitive + │ │ + │ └── MotionGenerator ← low-level trajectory planner (IK + trajectory optimization) + │ + └── SemanticAnalyzer ← resolves object labels → ObjectSemantics +``` + +Each action receives a target (object semantics or a pose tensor), runs its planning pipeline, +and returns a joint trajectory. The engine threads the end state of each action as the start +state of the next, then concatenates all trajectories into one contiguous sequence: + +``` +ObjectSemantics ──► AffordanceEstimation ──► AtomicAction.execute() +(label + geometry │ + + affordance ├─ IK solve + + entity) ├─ Motion plan + └─ Gripper interpolation + │ +AtomicActionEngine ◄─────────────── PlanResult ───────┘ +(sequences actions, accumulates + full-robot trajectory) +``` + +### Core Concepts + +**`ObjectSemantics`** describes an interaction target. It bundles: +- `geometry` — mesh data (vertices, triangles) used for grasp annotation +- `affordance` — *how* to interact with the object (e.g. antipodal grasp poses) +- `entity` — a live reference to the simulation object, so actions can read its current pose + +**`Affordance`** is a data class that encodes a specific interaction capability. The built-in affordance types are: + +| Class | Use case | +|---|---| +| `AntipodalAffordance` | Parallel-jaw grasping via antipodal point pairs | +| `InteractionPoints` | Contact-based interactions (push, poke, touch) | + +**`AtomicAction`** is the abstract base class for all motion primitives. 
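In code, a new primitive is simply a subclass (a minimal sketch — `WaveAction` is a hypothetical name, not a built-in):

```python
class WaveAction(AtomicAction):
    def execute(self, target, start_qpos=None, **kwargs):
        ...  # plan and return (is_success, trajectory, joint_ids)

    def validate(self, target, start_qpos=None, **kwargs):
        return True  # cheap feasibility check, no planning
```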
Every action must implement: +- `execute(target, start_qpos)` — plan and return a joint trajectory +- `validate(target, start_qpos)` — fast feasibility check without full planning + +**`AtomicActionEngine`** manages a named registry of actions and runs them in sequence via `execute_static()`, threading the end state of each action as the start state of the next. + +--- + +## Built-in Actions + +(supported_atomic_actions)= + +The following actions are available out of the box: + +| Action | Config class | Target type | Motion phases | +|---|---|---|---| +| `MoveAction` | `MoveActionCfg` | `Tensor (4,4)` — EEF pose | Move arm to pose | +| `PickUpAction` | `PickUpActionCfg` | `ObjectSemantics` or `Tensor (4,4)` | Approach → close gripper → lift | +| `PlaceAction` | `PlaceActionCfg` | `Tensor (4,4)` — EEF release pose | Lower → open gripper → retract | + +### `MoveAction` + +Moves the end-effector to a target pose in free space. + +| Config field | Default | Description | +|---|---|---| +| `control_part` | `"arm"` | Robot control part to move | +| `sample_interval` | `50` | Number of waypoints in the trajectory | + +**Target:** `torch.Tensor` of shape `(4, 4)` or `(n_envs, 4, 4)` — a homogeneous EEF pose. + +--- + +### `PickUpAction` + +Three-phase grasp motion: *approach → close gripper → lift*. + +| Config field | Default | Description | +|---|---|---| +| `approach_direction` | `[0, 0, -1]` | Gripper approach direction in object frame | +| `pre_grasp_distance` | `0.15` | Hover distance before descending (m) | +| `lift_height` | `0.10` | Lift height after grasping (m) | +| `hand_open_qpos` | `None` | **Required.** Gripper open joint positions | +| `hand_close_qpos` | `None` | **Required.** Gripper closed joint positions | +| `hand_control_part` | `"hand"` | Robot control part for the gripper | +| `hand_interp_steps` | `5` | Waypoints for the gripper close phase | +| `sample_interval` | `80` | Total waypoints across all three phases | + +**Target:** `ObjectSemantics` (grasp pose computed automatically) **or** a `torch.Tensor` EEF pose. + +--- + +### `PlaceAction` + +Three-phase release motion: *lower → open gripper → retract*. Mirrors `PickUpAction`. + +Inherits all gripper config fields from `GraspActionCfg`. The `approach_direction` field is not used — the arm moves straight down to the target pose. + +**Target:** `torch.Tensor` of shape `(4, 4)` or `(n_envs, 4, 4)` — the EEF pose at release. + +--- + +## Typical Workflow + +```python +from embodichain.lab.sim.atomic_actions import ( + AtomicActionEngine, + ObjectSemantics, + AntipodalAffordance, + PickUpActionCfg, + PlaceActionCfg, + MoveActionCfg, +) + +# 1. Configure each action +pickup_cfg = PickUpActionCfg( + control_part="arm", + hand_control_part="hand", + hand_open_qpos=torch.tensor([0.0, 0.0]), + hand_close_qpos=torch.tensor([0.025, 0.025]), +) +place_cfg = PlaceActionCfg(...) +move_cfg = MoveActionCfg(control_part="arm") + +# 2. Build the engine — action order matches target_list order +engine = AtomicActionEngine( + motion_generator=motion_gen, + actions_cfg_list=[pickup_cfg, place_cfg, move_cfg], +) + +# 3. Describe the object to pick +semantics = ObjectSemantics( + label="mug", + geometry={"mesh_vertices": ..., "mesh_triangles": ...}, + affordance=AntipodalAffordance(object_label="mug", ...), + entity=mug, +) + +# 4. 
Plan the full sequence and replay +is_success, traj = engine.execute_static( + target_list=[semantics, place_pose, rest_pose] +) +# traj: (n_envs, n_waypoints, dof) +``` + +--- + +## How to Extend: Adding a Custom Action + +You can add any motion primitive by subclassing `AtomicAction` and registering it with the engine. + +### Step 1 — Define the config + +```python +from embodichain.utils import configclass +from embodichain.lab.sim.atomic_actions import ActionCfg + +@configclass +class PushActionCfg(ActionCfg): + name: str = "push" + push_distance: float = 0.05 # metres to push forward + push_speed: int = 30 # waypoints for the push phase +``` + +### Step 2 — Implement the action + +```python +import torch +from typing import Optional, Union +from embodichain.lab.sim.atomic_actions import AtomicAction, ObjectSemantics +from embodichain.lab.sim.planners import PlanState, MoveType + +class PushAction(AtomicAction): + def __init__(self, motion_generator, cfg: PushActionCfg | None = None): + super().__init__(motion_generator, cfg=cfg or PushActionCfg()) + self.arm_joint_ids = self.robot.get_joint_ids(name=self.cfg.control_part) + + def execute( + self, + target: Union[torch.Tensor, ObjectSemantics], + start_qpos: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[bool, torch.Tensor, list]: + # Resolve target to a batched [n_envs, 4, 4] EEF pose + # ... your planning logic here ... + return is_success, trajectory, self.arm_joint_ids + + def validate(self, target, start_qpos=None, **kwargs) -> bool: + return True # add IK check here if needed +``` + +### Step 3 — Register and use + +```python +from embodichain.lab.sim.atomic_actions import register_action + +register_action("push", PushAction, PushActionCfg) + +engine = AtomicActionEngine( + motion_generator=motion_gen, + actions_cfg_list=[PushActionCfg(push_distance=0.08)], +) +is_success, traj = engine.execute_static(target_list=[target_pose]) +``` + +> **Tip:** The `execute()` return signature is always `(is_success, trajectory, joint_ids)`. +> `trajectory` has shape `(n_envs, n_waypoints, len(joint_ids))`. +> `joint_ids` tells the engine which columns of the full robot DOF vector the trajectory covers. + +--- + +## Target Resolution + +`AtomicActionEngine` accepts several target formats in `target_list`, giving you flexibility without boilerplate: + +| Input type | Resolved to | +|---|---| +| `torch.Tensor (4,4)` or `(n_envs,4,4)` | EEF pose, broadcast across envs | +| `ObjectSemantics` | Passed directly to the action | +| `str` (object label) | Looked up in `SemanticAnalyzer` cache | +| `dict` with `"pose"` key | Unwrapped to tensor | +| `dict` with `"label"` key | Analyzed via `SemanticAnalyzer` | + +--- + +## Further Reading + +- {doc}`planners/motion_generator` — the trajectory planner used by every action +- {doc}`sim_robot` — how control parts and IK solvers are configured +- Tutorial: `scripts/tutorials/sim/atomic_actions.py` diff --git a/docs/source/overview/sim/index.rst b/docs/source/overview/sim/index.rst index 56f98ef2..60cdfd56 100644 --- a/docs/source/overview/sim/index.rst +++ b/docs/source/overview/sim/index.rst @@ -22,3 +22,4 @@ Overview of the Simulation Framework: sim_sensor.md solvers/index planners/index + atomic_actions.md diff --git a/docs/source/tutorial/atomic_actions.rst b/docs/source/tutorial/atomic_actions.rst new file mode 100644 index 00000000..10b8e97c --- /dev/null +++ b/docs/source/tutorial/atomic_actions.rst @@ -0,0 +1,170 @@ +.. 
_tutorial_atomic_actions: + +Atomic Actions +============== + +EmbodiChain's **atomic action** layer provides a high-level, composable interface for common +manipulation primitives such as *move*, *pick up*, and *place*. Each action encapsulates the +full planning pipeline — grasp-pose estimation, IK, trajectory generation, and gripper +interpolation — behind a single ``execute()`` call, making it straightforward to chain +multiple actions together into complex robot behaviours. + +Key Features +------------ + +- **Semantic-aware execution** — actions accept either a raw pose tensor or an + ``ObjectSemantics`` descriptor that bundles affordance data (grasp poses, interaction + points) with the simulation entity. +- **Three built-in primitives** — ``MoveAction``, ``PickUpAction``, and ``PlaceAction`` + cover the most common tabletop manipulation workflows out of the box. + See the :ref:`supported_atomic_actions` table for configs and target types. +- **Extensible registry** — custom actions can be registered globally with + ``register_action`` and discovered by the engine at runtime. +- **Engine orchestration** — ``AtomicActionEngine`` sequences multiple actions, + threads ``start_qpos`` from one action to the next, and returns a single concatenated + trajectory ready to replay in the simulator. + +For the full design overview, architecture diagram, and extension guide see +:doc:`/overview/sim/atomic_actions`. + +The Code +-------- + +The tutorial corresponds to the ``atomic_actions.py`` script in the ``scripts/tutorials/sim`` +directory. + +.. dropdown:: Code for atomic_actions.py + :icon: code + + .. literalinclude:: ../../../scripts/tutorials/sim/atomic_actions.py + :language: python + :linenos: + +Typical Usage +------------- + +Setting up the engine +~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + import torch + from embodichain.lab.sim.planners import MotionGenerator, MotionGenCfg + from embodichain.lab.sim.atomic_actions import ( + AtomicActionEngine, + PickUpActionCfg, + PlaceActionCfg, + MoveActionCfg, + ) + + motion_gen = MotionGenerator(cfg=MotionGenCfg(...)) + + hand_open = torch.tensor([0.00, 0.00], dtype=torch.float32, device=device) + hand_close = torch.tensor([0.025, 0.025], dtype=torch.float32, device=device) + + pickup_cfg = PickUpActionCfg( + hand_open_qpos=hand_open, + hand_close_qpos=hand_close, + control_part="arm", + hand_control_part="hand", + approach_direction=torch.tensor([0.0, 0.0, -1.0], dtype=torch.float32, device=device), + pre_grasp_distance=0.15, + lift_height=0.15, + ) + place_cfg = PlaceActionCfg( + hand_open_qpos=hand_open, + hand_close_qpos=hand_close, + control_part="arm", + hand_control_part="hand", + lift_height=0.15, + ) + move_cfg = MoveActionCfg(control_part="arm") + + engine = AtomicActionEngine( + motion_generator=motion_gen, + actions_cfg_list=[pickup_cfg, place_cfg, move_cfg], + ) + +Defining object semantics +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
code-block:: python + + from embodichain.lab.sim.atomic_actions import ( + ObjectSemantics, + AntipodalAffordance, + ) + from embodichain.toolkits.graspkit.pg_grasp import GraspGeneratorCfg, AntipodalSamplerCfg + from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import GripperCollisionCfg + + affordance = AntipodalAffordance( + object_label="mug", + force_reannotate=False, + custom_config={ + "gripper_collision_cfg": GripperCollisionCfg( + max_open_length=0.088, finger_length=0.078, point_sample_dense=0.012 + ), + "generator_cfg": GraspGeneratorCfg( + antipodal_sampler_cfg=AntipodalSamplerCfg( + n_sample=20000, max_length=0.088, min_length=0.003 + ) + ), + }, + ) + + semantics = ObjectSemantics( + label="mug", + geometry={ + "mesh_vertices": mug.get_vertices(env_ids=[0], scale=True)[0], + "mesh_triangles": mug.get_triangles(env_ids=[0])[0], + }, + affordance=affordance, + entity=mug, # required so the action can query the live object pose + ) + +Executing a pick-place-move sequence +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + place_xpos = ... # torch.Tensor [4, 4] — target placement pose + rest_xpos = ... # torch.Tensor [4, 4] — resting pose after placing + + is_success, trajectory = engine.execute_static( + target_list=[semantics, place_xpos, rest_xpos] + ) + # trajectory: [n_envs, n_waypoints, robot_dof] + + for i in range(trajectory.shape[1]): + robot.set_qpos(trajectory[:, i]) + sim.update(step=4) + +Registering custom actions +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from embodichain.lab.sim.atomic_actions import AtomicAction, ActionCfg, register_action + + class PushAction(AtomicAction): + def execute(self, target, start_qpos=None, **kwargs): + # ... your planning logic ... + return is_success, trajectory, joint_ids + + def validate(self, target, start_qpos=None, **kwargs): + return True # quick feasibility check + + register_action("push", PushAction) + +Notes & Best Practices +---------------------- + +- ``PickUpAction`` expects an ``AntipodalAffordance`` with valid mesh data + (``mesh_vertices`` / ``mesh_triangles``) so the grasp generator can annotate the object. + Set ``force_reannotate=False`` (the default) to reuse cached annotations across episodes. +- ``ObjectSemantics.entity`` must be set when using semantic targets so the action can read + the object's current world pose at planning time. +- For static (non-physics) playback, iterate over ``trajectory[:, i]`` and call + ``robot.set_qpos`` directly; for physics-enabled playback, feed waypoints through your + controller or gym wrapper instead. +- To add a new action type, see :doc:`/overview/sim/atomic_actions`. diff --git a/docs/source/tutorial/index.rst b/docs/source/tutorial/index.rst index ef6efe79..c73b3d04 100644 --- a/docs/source/tutorial/index.rst +++ b/docs/source/tutorial/index.rst @@ -14,6 +14,7 @@ Tutorials solver sensor motion_gen + atomic_actions gizmo basic_env modular_env diff --git a/embodichain/lab/sim/atomic_actions/__init__.py b/embodichain/lab/sim/atomic_actions/__init__.py new file mode 100644 index 00000000..cf1e60ce --- /dev/null +++ b/embodichain/lab/sim/atomic_actions/__init__.py @@ -0,0 +1,67 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

"""Atomic action abstraction layer for embodied AI motion generation.

This module provides a unified interface for atomic actions like reach, grasp,
move, etc., with support for semantic object understanding and extensible
custom action registration.
"""

from .core import (
    Affordance,
    AntipodalAffordance,
    InteractionPoints,
    ObjectSemantics,
    ActionCfg,
    AtomicAction,
)
from .actions import (
    MoveAction,
    PickUpAction,
    PlaceAction,
    MoveActionCfg,
    PickUpActionCfg,
    PlaceActionCfg,
)
from .engine import (
    AtomicActionEngine,
    register_action,
    unregister_action,
    get_registered_actions,
)

__all__ = [
    # Core classes
    "Affordance",
    "AntipodalAffordance",
    "InteractionPoints",
    "ObjectSemantics",
    "ActionCfg",
    "AtomicAction",
    # Action implementations
    "MoveAction",
    "PickUpAction",
    "PlaceAction",
    "MoveActionCfg",
    "PickUpActionCfg",
    "PlaceActionCfg",
    # Engine
    "AtomicActionEngine",
    "register_action",
    "unregister_action",
    "get_registered_actions",
]

diff --git a/embodichain/lab/sim/atomic_actions/actions.py b/embodichain/lab/sim/atomic_actions/actions.py
new file mode 100644
index 00000000..4f2698de
--- /dev/null
+++ b/embodichain/lab/sim/atomic_actions/actions.py
@@ -0,0 +1,634 @@
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

import torch
from typing import Optional, Union, TYPE_CHECKING, Any

from embodichain.lab.sim.planners import PlanResult, PlanState, MoveType
from embodichain.lab.sim.planners.motion_generator import MotionGenOptions
from embodichain.lab.sim.planners.toppra_planner import ToppraPlanOptions
from .core import AtomicAction, ObjectSemantics, AntipodalAffordance, ActionCfg
from embodichain.utils import logger
from embodichain.utils import configclass
from embodichain.lab.sim.utility.action_utils import interpolate_with_distance
import numpy as np

if TYPE_CHECKING:
    from embodichain.lab.sim.planners import MotionGenerator
    from embodichain.lab.sim.objects import Robot


@configclass
class MoveActionCfg(ActionCfg):
    name: str = "move"
    """Name of the action, used for identification and logging."""

    sample_interval: int = 50
    """Number of waypoints to sample for the motion trajectory.
Should be large enough to ensure smooth motion, but not too large to cause unnecessary computation overhead.""" + + +@configclass +class GraspActionCfg(MoveActionCfg): + """Shared configuration for actions that involve gripper open/close motions.""" + + hand_open_qpos: torch.Tensor | None = None + """[hand_dof,] of float. Joint positions for open hand state.""" + + hand_close_qpos: torch.Tensor | None = None + """[hand_dof,] of float. Joint positions for closed hand state.""" + + hand_control_part: str = "hand" + """Name of the robot part that controls the hand joints.""" + + lift_height: float = 0.1 + """Height (m) to lift the end-effector after the gripper phase.""" + + sample_interval: int = 80 + """Number of waypoints for the full trajectory (approach + hand + lift/back).""" + + hand_interp_steps: int = 5 + """Number of waypoints for the gripper open/close interpolation phase.""" + + +class MoveAction(AtomicAction): + def __init__( + self, + motion_generator: MotionGenerator, + cfg: MoveActionCfg | None = None, + ): + """ + Initialize the atomic action. + Args: + motion_generator: The motion generator instance to use for planning. + cfg: Configuration for the action. + """ + super().__init__( + motion_generator, cfg=cfg if cfg is not None else MoveActionCfg() + ) + + self.n_envs = self.robot.get_qpos().shape[0] + self.arm_joint_ids = self.robot.get_joint_ids(name=self.cfg.control_part) + self.dof = len(self.arm_joint_ids) + + def _resolve_pose_target( + self, + target: Union[ObjectSemantics, torch.Tensor], + *, + action_name: str, + ) -> tuple[bool, torch.Tensor]: + """Resolve a pose target into a batched homogeneous transform tensor.""" + if isinstance(target, ObjectSemantics): + logger.log_error( + f"{action_name} currently does not support ObjectSemantics target. " + f"Please provide target pose as torch.Tensor of shape (4, 4) or " + f"(n_envs, 4, 4)", + NotImplementedError, + ) + if not isinstance(target, torch.Tensor): + logger.log_error( + "Target must be either ObjectSemantics or torch.Tensor of shape " + f"(4, 4) or ({self.n_envs}, 4, 4)", + TypeError, + ) + + if target.shape == (4, 4): + target = target.unsqueeze(0).repeat(self.n_envs, 1, 1) + if target.shape != (self.n_envs, 4, 4): + logger.log_error( + f"Target tensor must have shape (4, 4) or ({self.n_envs}, 4, 4), but got {target.shape}", + ValueError, + ) + return True, target + + def _resolve_start_qpos( + self, + start_qpos: Optional[torch.Tensor], + arm_dof: Optional[int] = None, + ) -> torch.Tensor: + """Resolve planning start joint positions into batched arm joint positions.""" + arm_dof = self.dof if arm_dof is None else arm_dof + if start_qpos is None: + start_qpos = self.robot.get_qpos(name=self.cfg.control_part) + if start_qpos.shape == (arm_dof,): + start_qpos = start_qpos.unsqueeze(0).repeat(self.n_envs, 1) + if start_qpos.shape != (self.n_envs, arm_dof): + logger.log_error( + f"start_qpos must have shape ({self.n_envs}, {arm_dof}), but got {start_qpos.shape}", + ValueError, + ) + return start_qpos + + def _compute_three_phase_waypoints( + self, + hand_interp_steps: int, + *, + first_phase_name: str, + third_phase_name: str, + first_phase_ratio: float = 0.6, + ) -> tuple[int, int, int]: + """Split total sample interval into motion, hand interpolation, and motion phases.""" + first_phase_waypoint = int( + np.round(self.cfg.sample_interval - hand_interp_steps) * first_phase_ratio + ) + if first_phase_waypoint < 2: + logger.log_error( + f"Not enough waypoints for {first_phase_name} trajectory. 
" + "Please increase sample_interval or decrease hand_interp_steps.", + ValueError, + ) + second_phase_waypoint = hand_interp_steps + third_phase_waypoint = ( + self.cfg.sample_interval - first_phase_waypoint - second_phase_waypoint + ) + if third_phase_waypoint < 2: + logger.log_error( + f"Not enough waypoints for {third_phase_name} trajectory. " + "Please increase sample_interval or decrease hand_interp_steps.", + ValueError, + ) + return first_phase_waypoint, second_phase_waypoint, third_phase_waypoint + + def _build_motion_gen_options( + self, + start_qpos: torch.Tensor, + sample_interval: int, + ) -> MotionGenOptions: + """Build default motion generation options for an atomic action.""" + return MotionGenOptions( + start_qpos=start_qpos[0], + control_part=self.cfg.control_part, + is_interpolate=True, + is_linear=False, + interpolate_position_step=0.001, + plan_opts=ToppraPlanOptions( + sample_interval=sample_interval, + ), + ) + + def _plan_arm_trajectory( + self, + target_states_list: list[list[PlanState]], + start_qpos: torch.Tensor, + n_waypoints: int, + arm_dof: Optional[int] = None, + ) -> tuple[bool, torch.Tensor]: + """Plan batched arm trajectories for all environments.""" + arm_dof = self.dof if arm_dof is None else arm_dof + + n_state = len(target_states_list[0]) + xpos_traj = torch.zeros( + size=(self.n_envs, n_state, 4, 4), dtype=torch.float32, device=self.device + ) + for i, target_states in enumerate(target_states_list): + for j, target_state in enumerate(target_states): + # [env_i, state_j, 4, 4] + xpos_traj[i, j] = target_state.xpos + + trajectory = torch.zeros( + size=(self.n_envs, n_state, arm_dof), + dtype=torch.float32, + device=self.device, + ) + qpos_seed = start_qpos + for j in range(n_state): + is_success, qpos = self.robot.compute_ik( + pose=xpos_traj[:, j], name=self.cfg.control_part, joint_seed=qpos_seed + ) + if not is_success: + logger.log_warning( + f"Failed to compute IK for target state {j} in some environments. " + "The resulting trajectory may be invalid." + ) + return False, trajectory + else: + trajectory[:, j] = qpos + qpos_seed = qpos + trajectory = torch.concatenate([start_qpos.unsqueeze(1), trajectory], dim=1) + interp_traj = interpolate_with_distance( + trajectory=trajectory, interp_num=n_waypoints, device=self.device + ) + return True, interp_traj + + def _interpolate_hand_qpos( + self, + start_hand_qpos: torch.Tensor, + end_hand_qpos: torch.Tensor, + n_waypoints: int, + ) -> torch.Tensor: + """Interpolate hand joint positions between two gripper states.""" + weights = torch.linspace(0, 1, steps=n_waypoints, device=self.device) + hand_qpos_list = [ + torch.lerp(start_hand_qpos, end_hand_qpos, weight) for weight in weights + ] + return torch.stack(hand_qpos_list, dim=0) + + def execute( + self, + target: Union[ObjectSemantics, torch.Tensor], + start_qpos: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[bool, torch.Tensor, list[float]]: + """execute pick up action + + Args: + target (ObjectSemantics): object semantics containing grasp affordance and entity information + start_qpos (Optional[torch.Tensor], optional): Planning start qpos. Defaults to None. 

        Returns:
            tuple[bool, torch.Tensor, list[int]]:
                is_success,
                trajectory of shape (n_envs, n_waypoints, dof),
                joint_ids corresponding to trajectory
        """
        is_success, move_xpos = self._resolve_pose_target(
            target, action_name=self.__class__.__name__
        )
        start_qpos = self._resolve_start_qpos(start_qpos)

        if not is_success:
            logger.log_warning("Failed to resolve target pose for move action")
            return False, torch.empty(0), self.arm_joint_ids

        target_states_list = [
            [
                PlanState(xpos=move_xpos[i], move_type=MoveType.EEF_MOVE),
            ]
            for i in range(self.n_envs)
        ]
        is_plan_success, trajectory = self._plan_arm_trajectory(
            target_states_list, start_qpos, self.cfg.sample_interval
        )
        return is_plan_success, trajectory, self.arm_joint_ids

    def validate(self, target, start_qpos=None, **kwargs):
        # TODO: implement proper validation logic for the move action
        return True


@configclass
class PickUpActionCfg(GraspActionCfg):
    name: str = "pick_up"
    """Name of the action, used for identification and logging."""

    pre_grasp_distance: float = 0.15
    """Distance to offset back from the grasp pose along the approach direction to get
    the pre-grasp pose. Should be large enough to avoid collision during approach."""

    approach_direction: torch.Tensor = torch.tensor([0, 0, -1], dtype=torch.float32)
    """Direction from which the gripper approaches the object for grasping, expressed
    in the object local frame. Default [0, 0, -1] means approaching from above."""


class PickUpAction(MoveAction):
    def __init__(
        self,
        motion_generator: MotionGenerator,
        cfg: PickUpActionCfg | None = None,
    ):
        """
        Initialize the atomic action.
        Args:
            motion_generator: The motion generator instance to use for planning.
            cfg: Configuration for the action.
        """
        super().__init__(
            motion_generator, cfg=cfg if cfg is not None else PickUpActionCfg()
        )
        # NOTE: self.cfg is already set (with the default-config fallback) in
        # super().__init__(); re-assigning the raw ``cfg`` argument here would
        # break the None-default case.
        self.approach_direction = self.cfg.approach_direction.to(self.device)
        if self.cfg.hand_open_qpos is None:
            logger.log_error("hand_open_qpos must be specified in PickUpActionCfg")
        if self.cfg.hand_close_qpos is None:
            logger.log_error("hand_close_qpos must be specified in PickUpActionCfg")
        self.hand_open_qpos = self.cfg.hand_open_qpos.to(self.device)
        self.hand_close_qpos = self.cfg.hand_close_qpos.to(self.device)

        self.hand_joint_ids = self.robot.get_joint_ids(name=self.cfg.hand_control_part)
        self.joint_ids = self.arm_joint_ids + self.hand_joint_ids
        self.arm_dof = len(self.arm_joint_ids)
        self.dof = len(self.joint_ids)

    def execute(
        self,
        target: Union[ObjectSemantics, torch.Tensor],
        start_qpos: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[bool, torch.Tensor, list[int]]:
        """Execute the pick-up action.

        Args:
            target (Union[ObjectSemantics, torch.Tensor]): target object semantics or target pose for grasping
            start_qpos (Optional[torch.Tensor], optional): Planning start qpos. Defaults to None.
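
        Example:
            A minimal sketch; assumes ``pickup`` is an initialized ``PickUpAction`` and
            ``semantics`` an ``ObjectSemantics`` carrying an ``AntipodalAffordance``::

                ok, traj, joint_ids = pickup.execute(semantics)
                # traj covers arm + hand columns: (n_envs, n_waypoints, len(joint_ids))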
+ + Returns: + tuple[bool, torch.Tensor, list[float]]: + is_success, + trajectory of shape (n_envs, n_waypoints, dof), + joint_ids corresponding to trajectory + """ + + # Resolve grasp pose + if isinstance(target, ObjectSemantics): + is_success, grasp_xpos, open_length = self._resolve_grasp_pose(target) + else: + is_success, grasp_xpos = self._resolve_pose_target( + target, action_name=self.__class__.__name__ + ) + + # TODO: warning and fallback if no valid grasp pose found + if not is_success: + logger.log_warning( + "Failed to resolve grasp pose, using default approach pose" + ) + return False, torch.empty(0), self.joint_ids + + # Compute pre-grasp pose + # TODO: only for parallel gripper, approach in negative grasp z direction + grasp_z = grasp_xpos[:, :3, 2] + pre_grasp_xpos = self._apply_offset( + pose=grasp_xpos, + offset=-grasp_z * self.cfg.pre_grasp_distance, + ) + # Compute lift pose + start_qpos = self._resolve_start_qpos(start_qpos, self.arm_dof) + + # compute waypoint number for each phase + n_approach_waypoint, n_close_waypoint, n_lift_waypoint = ( + self._compute_three_phase_waypoints( + self.cfg.hand_interp_steps, + first_phase_name="approach", + third_phase_name="lift", + ) + ) + + # get pick trajectory + target_states_list = [ + [ + PlanState(xpos=pre_grasp_xpos[i], move_type=MoveType.EEF_MOVE), + PlanState(xpos=grasp_xpos[i], move_type=MoveType.EEF_MOVE), + ] + for i in range(self.n_envs) + ] + pick_trajectory = torch.zeros( + size=(self.n_envs, n_approach_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + is_success, plan_traj = self._plan_arm_trajectory( + target_states_list, + start_qpos, + n_approach_waypoint, + self.arm_dof, + ) + if not is_success: + logger.log_warning("Failed to plan approach trajectory.") + return False, pick_trajectory, self.joint_ids + pick_trajectory[:, :, : self.arm_dof] = plan_traj + # Padding hand open qpos to pick trajectory + pick_trajectory[:, :, self.arm_dof :] = self.hand_open_qpos + + # get hand closing trajectory + grasp_qpos = pick_trajectory[ + :, -1, : self.arm_dof + ] # Assuming the last point of pick trajectory is the grasp pose + hand_close_path = self._interpolate_hand_qpos( + self.hand_open_qpos, + self.hand_close_qpos, + n_close_waypoint, + ) + hand_close_trajectory = torch.zeros( + size=(self.n_envs, n_close_waypoint, self.dof), + device=self.device, + ) + hand_close_trajectory[:, :, : self.arm_dof] = grasp_qpos + hand_close_trajectory[:, :, self.arm_dof :] = hand_close_path + + # get lift trajectory + lift_trajectory = torch.zeros( + size=(self.n_envs, n_lift_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + # lift_xpos = self._compute_lift_xpos(grasp_xpos) + lift_xpos = self._apply_offset( + pose=grasp_xpos, + offset=torch.tensor([0, 0, 1], device=self.device) * self.cfg.lift_height, + ) + target_states_list = [ + [ + PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE), + ] + for i in range(self.n_envs) + ] + is_success, plan_traj = self._plan_arm_trajectory( + target_states_list, + grasp_qpos, + n_lift_waypoint, + self.arm_dof, + ) + if not is_success: + logger.log_warning("Failed to plan lift trajectory.") + return False, lift_trajectory, self.joint_ids + lift_trajectory[:, :, : self.arm_dof] = plan_traj + # padding hand close qpos to lift trajectory + lift_trajectory[:, :, self.arm_dof :] = self.hand_close_qpos + + # concatenate trajectories + trajectory = torch.cat( + [pick_trajectory, hand_close_trajectory, lift_trajectory], dim=1 + ) + return True, trajectory, 
self.joint_ids

    def _resolve_grasp_pose(
        self, semantics: ObjectSemantics
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if not isinstance(semantics.affordance, AntipodalAffordance):
            logger.log_error(
                "Grasp pose affordance must be of type AntipodalAffordance"
            )
        if semantics.entity is None:
            logger.log_error(
                "ObjectSemantics must be associated with an entity to get object pose"
            )
        obj_poses = semantics.entity.get_local_pose(to_matrix=True)

        is_success, grasp_xpos, open_length = semantics.affordance.get_best_grasp_poses(
            obj_poses=obj_poses, approach_direction=self.approach_direction
        )
        return is_success, grasp_xpos, open_length

    def validate(self, target, start_qpos=None, **kwargs):
        # TODO: implement proper validation logic for the pick-up action
        return True


@configclass
class PlaceActionCfg(GraspActionCfg):
    name: str = "place"
    """Name of the action, used for identification and logging."""


class PlaceAction(MoveAction):
    def __init__(
        self,
        motion_generator: MotionGenerator,
        cfg: PlaceActionCfg | None = None,
    ):
        """
        Initialize the atomic action.
        Args:
            motion_generator: The motion generator instance to use for planning.
            cfg: Configuration for the action.
        """
        super().__init__(
            motion_generator, cfg=cfg if cfg is not None else PlaceActionCfg()
        )
        # NOTE: self.cfg is already set (with the default-config fallback) in
        # super().__init__(); re-assigning the raw ``cfg`` argument here would
        # break the None-default case.
        if self.cfg.hand_open_qpos is None:
            logger.log_error("hand_open_qpos must be specified in PlaceActionCfg")
        if self.cfg.hand_close_qpos is None:
            logger.log_error("hand_close_qpos must be specified in PlaceActionCfg")
        self.hand_open_qpos = self.cfg.hand_open_qpos.to(self.device)
        self.hand_close_qpos = self.cfg.hand_close_qpos.to(self.device)

        self.hand_joint_ids = self.robot.get_joint_ids(name=self.cfg.hand_control_part)
        self.joint_ids = self.arm_joint_ids + self.hand_joint_ids
        self.arm_dof = len(self.arm_joint_ids)
        self.dof = len(self.joint_ids)

    def execute(
        self,
        target: Union[ObjectSemantics, torch.Tensor],
        start_qpos: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[bool, torch.Tensor, list[int]]:
        """Execute the place action.

        Args:
            target (torch.Tensor): target EEF pose at release, of shape (4, 4) or (n_envs, 4, 4).
            start_qpos (Optional[torch.Tensor], optional): Planning start qpos. Defaults to None.
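
        Example:
            A minimal sketch; assumes ``place`` is an initialized ``PlaceAction``::

                release_pose = torch.eye(4)  # (4, 4) EEF pose at release
                ok, traj, joint_ids = place.execute(release_pose)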

        Returns:
            tuple[bool, torch.Tensor, list[int]]:
                is_success,
                trajectory of shape (n_envs, n_waypoints, dof),
                joint_ids corresponding to trajectory
        """
        is_success, place_xpos = self._resolve_pose_target(
            target, action_name=self.__class__.__name__
        )
        start_qpos = self._resolve_start_qpos(start_qpos, self.arm_dof)

        if not is_success:
            logger.log_warning("Failed to resolve place pose for place action")
            return False, torch.empty(0), self.joint_ids

        # compute waypoint number for each phase
        n_down_waypoint, n_open_waypoint, n_lift_waypoint = (
            self._compute_three_phase_waypoints(
                self.cfg.hand_interp_steps,
                first_phase_name="lower",
                third_phase_name="retract",
            )
        )

        down_trajectory = torch.zeros(
            size=(self.n_envs, n_down_waypoint, self.dof),
            dtype=torch.float32,
            device=self.device,
        )
        lift_xpos = self._apply_offset(
            pose=place_xpos,
            offset=torch.tensor([0, 0, 1], device=self.device) * self.cfg.lift_height,
        )
        target_states_list = [
            [
                PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE),
                PlanState(xpos=place_xpos[i], move_type=MoveType.EEF_MOVE),
            ]
            for i in range(self.n_envs)
        ]
        is_success, plan_traj = self._plan_arm_trajectory(
            target_states_list,
            start_qpos,
            n_down_waypoint,
            self.arm_dof,
        )
        if not is_success:
            logger.log_warning("Failed to plan down trajectory.")
            return False, down_trajectory, self.joint_ids
        down_trajectory[:, :, : self.arm_dof] = plan_traj
        # keep the gripper closed while lowering to the place pose
        down_trajectory[:, :, self.arm_dof :] = self.hand_close_qpos

        # get hand opening trajectory
        reach_qpos = down_trajectory[
            :, -1, : self.arm_dof
        ]  # the last point of the down trajectory is the place pose
        hand_open_path = self._interpolate_hand_qpos(
            self.hand_close_qpos,
            self.hand_open_qpos,
            n_open_waypoint,
        )
        hand_open_trajectory = torch.zeros(
            size=(self.n_envs, n_open_waypoint, self.dof),
            device=self.device,
        )
        hand_open_trajectory[:, :, : self.arm_dof] = reach_qpos
        hand_open_trajectory[:, :, self.arm_dof :] = hand_open_path

        # get retract trajectory
        back_trajectory = torch.zeros(
            size=(self.n_envs, n_lift_waypoint, self.dof),
            dtype=torch.float32,
            device=self.device,
        )
        target_states_list = [
            [
                PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE),
            ]
            for i in range(self.n_envs)
        ]
        is_success, plan_traj = self._plan_arm_trajectory(
            target_states_list,
            reach_qpos,
            n_lift_waypoint,
            self.arm_dof,
        )
        if not is_success:
            logger.log_warning("Failed to plan back trajectory.")
            return False, back_trajectory, self.joint_ids
        back_trajectory[:, :, : self.arm_dof] = plan_traj
        # padding hand open qpos to back trajectory
        back_trajectory[:, :, self.arm_dof :] = self.hand_open_qpos

        # concatenate trajectories
        trajectory = torch.cat(
            [down_trajectory, hand_open_trajectory, back_trajectory], dim=1
        )
        return True, trajectory, self.joint_ids

    def validate(self, target, start_qpos=None, **kwargs):
        # TODO: implement proper validation logic for the place action
        return True

diff --git a/embodichain/lab/sim/atomic_actions/core.py b/embodichain/lab/sim/atomic_actions/core.py
new file mode 100644
index 00000000..08a22fc5
--- /dev/null
+++ b/embodichain/lab/sim/atomic_actions/core.py
@@ -0,0 +1,468 @@
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING + +from embodichain.lab.sim.planners import PlanResult, PlanState, MoveType +from embodichain.utils import configclass + +from embodichain.toolkits.graspkit.pg_grasp import ( + GraspGenerator, + GraspGeneratorCfg, +) +from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import ( + GripperCollisionCfg, +) +from embodichain.lab.sim.common import BatchEntity +from embodichain.utils import logger + +if TYPE_CHECKING: + from embodichain.lab.sim.planners import MotionGenerator, MotionGenOptions + from embodichain.lab.sim.objects import Robot + + +# ============================================================================= +# Affordance Classes +# ============================================================================= + + +@dataclass +class Affordance: + """Base class for affordance data. + + Affordance represents interaction possibilities for an object. + This is the base class for specific affordance types. + """ + + object_label: str = "" + """Label of the object this affordance belongs to.""" + + geometry: Dict[str, Any] = field(default_factory=dict) + """Geometry dictionary shared with ObjectSemantics. + + The mesh payload is expected to be stored in: + - ``mesh_vertices``: torch.Tensor with shape [N, 3] + - ``mesh_triangles``: torch.Tensor with shape [M, 3] + """ + + custom_config: Dict[str, Any] = field(default_factory=dict) + """User-defined configuration payload for affordance creation and usage.""" + + @property + def mesh_vertices(self) -> torch.Tensor | None: + """Get mesh vertices from geometry. + + Returns: + Mesh vertices tensor [N, 3], or None if unavailable. + + Raises: + TypeError: If ``mesh_vertices`` exists but is not a torch tensor. + """ + vertices = self.geometry.get("mesh_vertices") + if vertices is None: + return None + if not isinstance(vertices, torch.Tensor): + raise TypeError("geometry['mesh_vertices'] must be a torch.Tensor") + return vertices + + @property + def mesh_triangles(self) -> torch.Tensor | None: + """Get mesh triangles from geometry. + + Returns: + Mesh triangle index tensor [M, 3], or None if unavailable. + + Raises: + TypeError: If ``mesh_triangles`` exists but is not a torch tensor. 
+ """ + triangles = self.geometry.get("mesh_triangles") + if triangles is None: + return None + if not isinstance(triangles, torch.Tensor): + raise TypeError("geometry['mesh_triangles'] must be a torch.Tensor") + return triangles + + def set_custom_config(self, key: str, value: Any) -> None: + """Set a custom affordance configuration value.""" + self.custom_config[key] = value + + def get_custom_config(self, key: str, default: Any = None) -> Any: + """Get a custom affordance configuration value.""" + return self.custom_config.get(key, default) + + def get_batch_size(self) -> int: + """Return the batch size of this affordance data.""" + return 1 + + +@dataclass +class AntipodalAffordance(Affordance): + generator: GraspGenerator | None = None + """Grasp generator instance, initialized lazily when needed.""" + + force_reannotate: bool = False + """Whether to force re-annotation of grasp generator on each access.""" + + is_draw_grasp_xpos: bool = False + """Whether to visualize grasp poses in the simulator.""" + + def _init_generator(self): + if ( + self.geometry.get("mesh_vertices", None) is None + or self.geometry.get("mesh_triangles", None) is None + ): + logger.log_error( + "Mesh vertices and triangles must be provided in geometry to initialize AntipodalAffordance." + ) + self.generator = GraspGenerator( + vertices=self.geometry.get("mesh_vertices"), + triangles=self.geometry.get("mesh_triangles"), + cfg=self.custom_config.get("generator_cfg", None), + gripper_collision_cfg=self.custom_config.get("gripper_collision_cfg", None), + ) + if self.force_reannotate: + self.generator.annotate() + else: + if self.generator._hit_point_pairs is None: + self.generator.annotate() + + def get_best_grasp_poses( + self, + obj_poses: torch.Tensor, + approach_direction: torch.Tensor = torch.tensor( + [0, 0, -1], dtype=torch.float32 + ), + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.generator is None: + self._init_generator() + + grasp_xpos_list = [] + is_success_list = [] + open_length_list = [] + for i, obj_pose in enumerate(obj_poses): + is_success, grasp_xpos, open_length = self.generator.get_grasp_poses( + obj_pose, approach_direction + ) + if is_success: + grasp_xpos_list.append(grasp_xpos.unsqueeze(0)) + else: + logger.log_warning(f"No valid grasp pose found for {i}-th object.") + grasp_xpos_list.append( + torch.eye( + 4, dtype=torch.float32, device=self.generator.device + ).unsqueeze(0) + ) # Default to identity pose if no grasp found + is_success_list.append(is_success) + open_length_list.append(open_length) + is_success = torch.tensor( + is_success_list, dtype=torch.bool, device=self.generator.device + ) + grasp_xpos = torch.concatenate(grasp_xpos_list, dim=0) # [B, 4, 4] + open_length = torch.tensor( + open_length_list, dtype=torch.float32, device=self.generator.device + ) + if self.is_draw_grasp_xpos: + self._draw_grasp_xpos(grasp_xpos, open_length) + return is_success, grasp_xpos, open_length + + def _draw_grasp_xpos(self, grasp_xpos: torch.Tensor, open_length: torch.Tensor): + sim = SimulationManager.get_instance() + axis_xpos = [] + for i in range(grasp_xpos.shape[0]): + axis_xpos.append(grasp_xpos[i].to("cpu").numpy()) + sim.draw_marker( + cfg=MarkerCfg( + name="grasp_xpos", + axis_xpos=axis_xpos, + axis_len=0.05, + ) + ) + + +@dataclass +class InteractionPoints(Affordance): + """Interaction points affordance containing a batch of 3D positions. 
+ + Interaction points define specific locations on an object surface + that can be used for contact-based interactions (pushing, poking, + touching) rather than full grasping. + """ + + points: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 3)) + """Batch of 3D interaction points with shape [B, 3]. + + Each point is a 3D coordinate in the object's local coordinate frame. + """ + + normals: torch.Tensor | None = None + """Optional surface normals at each interaction point with shape [B, 3]. + + Normals indicate the surface orientation at each point, + useful for determining approach directions. + """ + + point_types: List[str] = field(default_factory=list) + """Optional labels for each point's interaction type. + + Examples: "push", "poke", "touch", "pinch" + """ + + def get_points_by_type(self, point_type: str) -> torch.Tensor | None: + """Get points by their interaction type. + + Args: + point_type: Type of interaction (e.g., "push", "poke") + + Returns: + Tensor of points if found, None otherwise + """ + if point_type in self.point_types: + indices = [i for i, t in enumerate(self.point_types) if t == point_type] + return self.points[indices] + return None + + def get_batch_size(self) -> int: + """Return the number of interaction points in this affordance.""" + return self.points.shape[0] + + def get_approach_direction(self, point_idx: int) -> torch.Tensor: + """Get recommended approach direction for a given point. + + Args: + point_idx: Index of the point + + Returns: + 3D approach direction vector (normalized) + """ + if self.normals is not None: + # Approach from the opposite direction of the surface normal + return -self.normals[point_idx] + # Default: approach from positive z + return torch.tensor( + [0, 0, 1], dtype=self.points.dtype, device=self.points.device + ) + + +# ============================================================================= +# ObjectSemantics +# ============================================================================= + + +@dataclass +class ObjectSemantics: + """Semantic information about interaction target. + + This class encapsulates all semantic and geometric information about + an object needed for intelligent interaction planning. + """ + + affordance: Affordance + """Affordance data (GraspPose, InteractionPoints, etc.).""" + + geometry: Dict[str, Any] + """Geometric information including bounding box, mesh data.""" + + properties: Dict[str, Any] = field(default_factory=dict) + """Physical properties: mass, friction, etc.""" + + label: str = "none" + """Object category label (e.g., 'apple', 'bottle').""" + + entity: BatchEntity | None = None + """Optional reference to the underlying simulation entity representing this object.""" + + def __post_init__(self) -> None: + """Bind affordance metadata to this semantic object. + + The affordance shares the same geometry dict instance as + ``ObjectSemantics.geometry`` so mesh tensors are authored in one place. 
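
        Example:
            A sketch of the shared-dict behaviour (``aff`` and ``geo`` assumed built
            as described in the class docs)::

                sem = ObjectSemantics(affordance=aff, geometry=geo, label="mug")
                assert sem.affordance.geometry is sem.geometry  # same dict instance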
+        """
+        self.affordance.object_label = self.label
+        self.affordance.geometry = self.geometry
+
+
+# =============================================================================
+# ActionCfg and AtomicAction
+# =============================================================================
+
+
+@configclass
+class ActionCfg:
+    """Configuration for atomic actions."""
+
+    name: str = "default"
+    """Name of the action, used for identification and logging."""
+
+    control_part: str = "arm"
+    """Control part name for the action."""
+
+    interpolation_type: str = "linear"
+    """Interpolation type: 'linear', 'cubic'."""
+
+    velocity_limit: Optional[float] = None
+    """Optional velocity limit for the motion."""
+
+    acceleration_limit: Optional[float] = None
+    """Optional acceleration limit for the motion."""
+
+
+class AtomicAction(ABC):
+    """Abstract base class for atomic actions.
+
+    Concrete actions implement ``execute()``, which returns an
+    ``(is_success, trajectory, joint_ids)`` tuple, and ``validate()``, a
+    cheap feasibility check that performs no motion planning.
+    """
+
+    def __init__(
+        self,
+        motion_generator: MotionGenerator,
+        cfg: Optional[ActionCfg] = None,
+    ):
+        """Initialize the atomic action.
+
+        Args:
+            motion_generator: The motion generator instance to use for planning.
+            cfg: Configuration for the action. A fresh default ``ActionCfg``
+                is created when None.
+        """
+        self.motion_generator = motion_generator
+        # A None default avoids sharing one mutable ActionCfg instance across
+        # every action constructed without an explicit config.
+        self.cfg = cfg if cfg is not None else ActionCfg()
+        self.robot = motion_generator.robot
+        self.control_part = self.cfg.control_part
+        self.device = self.robot.device
+
+    @abstractmethod
+    def execute(
+        self,
+        target: Union[torch.Tensor, ObjectSemantics],
+        start_qpos: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> tuple[bool, torch.Tensor, list[int]]:
+        """Execute the action and plan a joint-space trajectory.
+
+        Args:
+            target (Union[torch.Tensor, ObjectSemantics]): Target EEF pose(s)
+                or object semantics with affordance and entity information.
+            start_qpos (Optional[torch.Tensor], optional): Planning start qpos.
+                Defaults to None.
+
+        Returns:
+            tuple[bool, torch.Tensor, list[int]]:
+                is_success,
+                trajectory of shape (n_envs, n_waypoints, len(joint_ids)),
+                joint_ids corresponding to trajectory
+        """
+
+    @abstractmethod
+    def validate(
+        self,
+        target: Union[torch.Tensor, ObjectSemantics],
+        start_qpos: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> bool:
+        """Validate if the action is feasible without executing.
+
+        This method performs a quick feasibility check (e.g., IK solvability)
+        without generating a full trajectory.
+
+        Returns:
+            True if action appears feasible, False otherwise
+        """
+        pass
+
+    def _ik_solve(
+        self, target_pose: torch.Tensor, qpos_seed: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """Solve IK for target pose.
+
+        Args:
+            target_pose: Target pose [4, 4]
+            qpos_seed: Seed configuration [DOF]
+
+        Returns:
+            Joint configuration [DOF]
+
+        Raises:
+            RuntimeError: If IK fails to find a solution
+        """
+        if qpos_seed is None:
+            # Seed with the current configuration of the controlled part so
+            # the seed matches the documented [DOF] shape.
+            qpos_seed = self.robot.get_qpos(name=self.control_part)[0]
+
+        success, qpos = self.robot.compute_ik(
+            pose=target_pose.unsqueeze(0),
+            qpos_seed=qpos_seed.unsqueeze(0),
+            name=self.control_part,
+        )
+
+        if not success.all():
+            raise RuntimeError(f"IK failed for target pose: {target_pose}")
+
+        return qpos.squeeze(0)
+
+    def _fk_compute(self, qpos: torch.Tensor) -> torch.Tensor:
+        """Compute forward kinematics.
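+
+        A thin wrapper over ``self.robot.compute_fk`` for the configured
+        control part: a single configuration [DOF] is batched before the
+        call and un-batched on return.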
+
+        Args:
+            qpos: Joint configuration [DOF] or [B, DOF]
+
+        Returns:
+            End-effector pose [4, 4] or [B, 4, 4]
+        """
+        if qpos.dim() == 1:
+            qpos = qpos.unsqueeze(0)
+
+        xpos = self.robot.compute_fk(
+            qpos=qpos,
+            name=self.control_part,
+            to_matrix=True,
+        )
+
+        return xpos.squeeze(0) if xpos.shape[0] == 1 else xpos
+
+    def _apply_offset(self, pose: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
+        """Apply a translation offset to poses.
+
+        The offset is added directly to the translation column, i.e. it is
+        expressed in the poses' parent (world) frame, not in each pose's
+        local frame.
+
+        Args:
+            pose: Base pose [N, 4, 4]
+            offset: Translation offset [N, 3] or [3], in the parent frame
+
+        Returns:
+            Pose with offset applied [N, 4, 4]
+        """
+        if pose.dim() != 3 or pose.shape[1:] != (4, 4):
+            logger.log_error("pose must have shape [N, 4, 4]")
+        if offset.dim() == 1:
+            offset = offset.unsqueeze(0)
+        if offset.dim() != 2 or offset.shape[1] != 3:
+            logger.log_error("offset must have shape [N, 3] or [3]")
+        result = pose.clone()
+        result[:, :3, 3] += offset
+        return result
+
+    def plan_trajectory(
+        self,
+        target_states: List[PlanState],
+        options: Optional["MotionGenOptions"] = None,
+    ) -> "PlanResult":
+        """Plan trajectory using motion generator."""
+        from embodichain.lab.sim.planners import MotionGenOptions
+
+        if options is None:
+            options = MotionGenOptions(control_part=self.control_part)
+        return self.motion_generator.generate(target_states, options)
diff --git a/embodichain/lab/sim/atomic_actions/engine.py b/embodichain/lab/sim/atomic_actions/engine.py
new file mode 100644
index 00000000..15b868a8
--- /dev/null
+++ b/embodichain/lab/sim/atomic_actions/engine.py
@@ -0,0 +1,340 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import torch
+from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING
+
+from embodichain.lab.sim.planners import PlanResult
+from embodichain.utils import logger
+from .core import AtomicAction, ObjectSemantics, ActionCfg
+
+if TYPE_CHECKING:
+    from embodichain.lab.sim.planners import MotionGenerator
+    from embodichain.lab.sim.objects import Robot
+
+
+# =============================================================================
+# Global Action Registry
+# =============================================================================
+
+_global_action_registry: Dict[str, Type[AtomicAction]] = {}
+_global_action_configs: Dict[str, Type[ActionCfg]] = {}
+
+
+def register_action(
+    name: str,
+    action_class: Type[AtomicAction],
+    config_class: Optional[Type[ActionCfg]] = None,
+) -> None:
+    """Register a custom atomic action class globally.
+
+    This function allows registration of custom action types that can then
+    be instantiated by the AtomicActionEngine.
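+    Built-in names ("move", "pick_up", "place") take precedence: when the
+    engine resolves a config by name it consults its built-in map before
+    this registry.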
+ + Args: + name: Unique identifier for the action type + action_class: The AtomicAction subclass to register + config_class: Optional configuration class for the action + + Example: + >>> class MyCustomAction(AtomicAction): + ... def execute(self, target, **kwargs): + ... # Implementation + ... pass + ... def validate(self, target, **kwargs): + ... return True + >>> register_action("my_custom", MyCustomAction) + """ + _global_action_registry[name] = action_class + if config_class is not None: + _global_action_configs[name] = config_class + + +def unregister_action(name: str) -> None: + """Unregister an action type. + + Args: + name: The action type identifier to remove + """ + _global_action_registry.pop(name, None) + _global_action_configs.pop(name, None) + + +def get_registered_actions() -> Dict[str, Type[AtomicAction]]: + """Get all registered action types. + + Returns: + Dictionary mapping action names to their classes + """ + return _global_action_registry.copy() + + +# ============================================================================= +# Semantic Analyzer +# ============================================================================= + + +class SemanticAnalyzer: + """Analyzes objects and provides ObjectSemantics for atomic actions.""" + + def __init__(self): + self._object_cache: Dict[str, ObjectSemantics] = {} + + def analyze( + self, + label: str, + geometry: Optional[Dict[str, Any]] = None, + custom_config: Optional[Dict[str, Any]] = None, + use_cache: bool = True, + ) -> ObjectSemantics: + """Analyze object by label and return ObjectSemantics. + + This is a placeholder implementation that should be extended + with actual object detection and affordance computation. + + Args: + label: Object category label (e.g., "apple", "bottle") + geometry: Optional geometry payload. Can include mesh tensors: + ``mesh_vertices`` [N, 3] and ``mesh_triangles`` [M, 3]. + custom_config: Optional user-defined affordance configuration. + use_cache: Whether to use cached semantics when available. 
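+                Caching applies only when ``geometry`` and ``custom_config``
+                are both None; bespoke inputs always produce fresh semantics.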
+
+        Returns:
+            ObjectSemantics containing affordance data
+        """
+        # Only use cache for default analyze path
+        if (
+            use_cache
+            and geometry is None
+            and custom_config is None
+            and label in self._object_cache
+        ):
+            return self._object_cache[label]
+
+        # Create default semantics (placeholder implementation)
+        from .core import AntipodalAffordance
+
+        default_geometry: Dict[str, Any] = {"bounding_box": [0.1, 0.1, 0.1]}
+        if geometry is not None:
+            default_geometry.update(geometry)
+
+        grasp_affordance = AntipodalAffordance(
+            object_label=label,
+            custom_config=custom_config or {},
+        )
+
+        semantics = ObjectSemantics(
+            label=label,
+            affordance=grasp_affordance,
+            geometry=default_geometry,
+            properties={"mass": 1.0, "friction": 0.5},
+        )
+
+        # Cache only default path
+        if use_cache and geometry is None and custom_config is None:
+            self._object_cache[label] = semantics
+        return semantics
+
+    def clear_cache(self) -> None:
+        """Clear the object semantics cache."""
+        self._object_cache.clear()
+
+
+# =============================================================================
+# Atomic Action Engine
+# =============================================================================
+
+
+class AtomicActionEngine:
+    """Central engine for managing and executing atomic actions."""
+
+    def __init__(
+        self,
+        motion_generator: "MotionGenerator",
+        actions_cfg_list: Optional[List[ActionCfg]] = None,
+    ):
+        self.motion_generator = motion_generator
+        self.robot = self.motion_generator.robot
+        self.device = self.motion_generator.device
+
+        # Semantic analyzer for object understanding
+        self._semantic_analyzer = SemanticAnalyzer()
+
+        # Initialize default actions
+        self._actions: Dict[str, AtomicAction] = self._init_actions(actions_cfg_list)
+
+    def _init_actions(
+        self, actions_cfg_list: Optional[List[ActionCfg]] = None
+    ) -> Dict[str, "AtomicAction"]:
+        actions: Dict[str, AtomicAction] = {}
+        from .actions import MoveAction, PickUpAction, PlaceAction
+
+        builtin_action_map: Dict[str, Type[AtomicAction]] = {
+            "move": MoveAction,
+            "pick_up": PickUpAction,
+            "place": PlaceAction,
+        }
+        if actions_cfg_list is not None:
+            for cfg in actions_cfg_list:
+                action_class = builtin_action_map.get(
+                    cfg.name
+                ) or _global_action_registry.get(cfg.name)
+                if action_class is None:
+                    logger.log_error(f"Unknown action name in config: {cfg.name}")
+                    continue
+                instance = action_class(motion_generator=self.motion_generator, cfg=cfg)
+                actions[cfg.name] = instance
+        return actions
+
+    def execute_static(
+        self,
+        target_list: List[Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]]],
+    ) -> tuple[bool, torch.Tensor]:
+        """Execute a sequence of actions to target poses.
+
+        Each element in ``target_list`` corresponds to an action in the order they
+        were registered via ``actions_cfg_list``.
+
+        Returns:
+            Tuple of (is_success, trajectory) where trajectory spans the full
+            robot DOF with shape (n_envs, n_waypoints, dof).
+        """
+        action_names = list(self._actions.keys())
+        if len(target_list) != len(action_names):
+            logger.log_error(
+                f"Length of target_list ({len(target_list)}) must match number of actions ({len(action_names)})."
+            )
+        # Clone so the in-place updates below do not alias live robot state.
+        start_qpos = self.motion_generator.robot.get_qpos().clone()
+        n_envs = start_qpos.shape[0]
+        all_dof = self.motion_generator.robot.dof
+        all_trajectory = torch.empty(
+            size=(n_envs, 0, all_dof), dtype=torch.float32, device=self.device
+        )
+
+        for action_name, target in zip(action_names, target_list):
+            atom_action = self._actions[action_name]
+            target = self._resolve_target(target)
+            control_part = atom_action.control_part
+            arm_joint_ids = self.motion_generator.robot.get_joint_ids(name=control_part)
+            start_qpos_part = start_qpos[:, arm_joint_ids]
+            is_success, traj, joint_ids = atom_action.execute(
+                target=target, start_qpos=start_qpos_part
+            )
+            if not is_success:
+                return False, all_trajectory
+            n_waypoints = traj.shape[1]
+
+            traj_full = torch.zeros(
+                size=(n_envs, n_waypoints, all_dof),
+                dtype=torch.float32,
+                device=self.device,
+            )
+            # Broadcast (n_envs, dof) across the waypoint dim; without the
+            # unsqueeze this assignment fails whenever n_envs != n_waypoints.
+            traj_full[:] = start_qpos.unsqueeze(1)
+            traj_full[:, :, joint_ids] = traj
+            all_trajectory = torch.cat((all_trajectory, traj_full), dim=1)
+            # update start qpos for the next action
+            start_qpos[:, joint_ids] = traj[:, -1, :]
+        return True, all_trajectory
+
+    def validate(
+        self,
+        action_name: str,
+        target: Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]],
+        **kwargs,
+    ) -> bool:
+        """Validate if a named action is feasible without executing."""
+        if action_name not in self._actions:
+            logger.log_warning(f"Action '{action_name}' is not registered.")
+            return False
+
+        action = self._actions[action_name]
+        target = self._resolve_target(target)
+        return action.validate(target, **kwargs)
+
+    def _resolve_target(
+        self,
+        target: Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]],
+    ) -> Union[torch.Tensor, ObjectSemantics]:
+        """Resolve user target input into tensor pose or ObjectSemantics.
+
+        Supports the convenience dict format in ``execute_static`` and
+        ``validate``.
+        """
+        if isinstance(target, torch.Tensor):
+            return target
+
+        if isinstance(target, ObjectSemantics):
+            return target
+
+        if isinstance(target, str):
+            return self._semantic_analyzer.analyze(target)
+
+        if isinstance(target, dict):
+            if "pose" in target:
+                pose = target["pose"]
+                if not isinstance(pose, torch.Tensor):
+                    raise TypeError("target['pose'] must be a torch.Tensor")
+                return pose
+
+            if "semantics" in target:
+                semantics = target["semantics"]
+                if not isinstance(semantics, ObjectSemantics):
+                    raise TypeError(
+                        "target['semantics'] must be an ObjectSemantics instance"
+                    )
+                return semantics
+
+            label = target.get("label")
+            if label is None:
+                raise ValueError(
+                    "Dict target must provide 'label', or use 'pose'/'semantics'."
+ ) + if not isinstance(label, str): + raise TypeError("target['label'] must be a string") + + geometry = target.get("geometry") + custom_config = target.get("custom_config") + use_cache = target.get("use_cache", True) + + semantics = self._semantic_analyzer.analyze( + label=label, + geometry=geometry, + custom_config=custom_config, + use_cache=use_cache, + ) + + properties = target.get("properties") + if properties is not None: + semantics.properties.update(properties) + + uid = target.get("uid") + if uid is not None: + semantics.uid = uid + + return semantics + + raise TypeError( + "target must be torch.Tensor, str, ObjectSemantics, or Dict[str, Any]" + ) + + def get_semantic_analyzer(self) -> SemanticAnalyzer: + """Get the semantic analyzer for object understanding.""" + return self._semantic_analyzer + + def set_semantic_analyzer(self, analyzer: SemanticAnalyzer) -> None: + """Set a custom semantic analyzer.""" + self._semantic_analyzer = analyzer diff --git a/embodichain/lab/sim/planners/motion_generator.py b/embodichain/lab/sim/planners/motion_generator.py index f5f12bac..220deeca 100644 --- a/embodichain/lab/sim/planners/motion_generator.py +++ b/embodichain/lab/sim/planners/motion_generator.py @@ -507,7 +507,11 @@ def interpolate_trajectory( qpos_seed = options.start_qpos if qpos_seed is None and qpos_list is not None: + # first waypoint as seed qpos_seed = qpos_list[0] + if qpos_seed is None: + # fallback to current robot state as seed + qpos_seed = self.robot.get_qpos(name=control_part)[0] # Generate trajectory interpolate_qpos_list = [] @@ -550,9 +554,14 @@ def interpolate_trajectory( # compute_batch_ik expects (n_envs, n_batch, 7) or (n_envs, n_batch, 4, 4) # Here we assume n_envs = 1 or we want to apply this to all envs if available. # Since MotionGenerator usually works with self.robot.device, we use its batching capabilities. + qpos_seed_repeat = ( + qpos_seed.unsqueeze(0) + .repeat(total_interpolated_poses.shape[0], 1) + .unsqueeze(0) + ) success_batch, qpos_batch = self.robot.compute_batch_ik( pose=total_interpolated_poses.unsqueeze(0), - joint_seed=None, # Or use qpos_seed if properly shaped + joint_seed=qpos_seed_repeat, # Or use qpos_seed if properly shaped name=control_part, ) diff --git a/embodichain/lab/sim/planners/toppra_planner.py b/embodichain/lab/sim/planners/toppra_planner.py index 0c20ccf9..218d17ed 100644 --- a/embodichain/lab/sim/planners/toppra_planner.py +++ b/embodichain/lab/sim/planners/toppra_planner.py @@ -191,11 +191,9 @@ def plan( ) # Build waypoints - waypoints = [] - for target in target_states: - waypoints.append(np.array(target.qpos)) - - waypoints = np.array(waypoints) + waypoints = np.array( + [target.qpos.to("cpu").numpy() for target in target_states] + ) # Create spline interpolation # NOTE: Suitable for dense waypoints ss = np.linspace(0, 1, len(waypoints)) diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py index 658f4f88..9ec009bc 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py @@ -73,7 +73,7 @@ class GraspGeneratorCfg: number of sampled surface points, ray perturbation angle, and gripper jaw distance limits. 
See :class:`AntipodalSamplerCfg` for details.""" - max_deviation_angle: float = np.pi / 12 + max_deviation_angle: float = np.pi / 6 """Maximum allowed angle (in radians) between the specified approach direction and the axis connecting an antipodal point pair. Pairs that deviate more than this threshold from perpendicular to the approach are diff --git a/scripts/tutorials/sim/atomic_actions.py b/scripts/tutorials/sim/atomic_actions.py new file mode 100644 index 00000000..1f4de8d5 --- /dev/null +++ b/scripts/tutorials/sim/atomic_actions.py @@ -0,0 +1,348 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +Tutorial: Atomic Actions for Robot Motion Generation +===================================================== + +This script shows how to use the atomic action system to plan and execute +a pick-and-place task with a robot arm. + +Key concepts covered: + 1. Setting up a MotionGenerator and AtomicActionEngine + 2. Describing what to pick using ObjectSemantics and AntipodalAffordance + 3. Running a pick → place → move sequence with execute_static() + +Run with: + python atomic_actions.py [--num_envs N] [--enable_rt] +""" + +import argparse +import numpy as np +import time +import torch + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot, RigidObject +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.lab.sim.solvers import PytorchSolverCfg +from embodichain.data import get_data_path +from embodichain.lab.sim.cfg import ( + JointDrivePropertiesCfg, + RobotCfg, + RigidObjectCfg, + RigidBodyAttributesCfg, + LightCfg, + URDFCfg, +) +from embodichain.lab.sim.planners import MotionGenerator, MotionGenCfg, ToppraPlannerCfg +from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import ( + GripperCollisionCfg, +) +from embodichain.toolkits.graspkit.pg_grasp.antipodal_generator import ( + GraspGenerator, + GraspGeneratorCfg, + AntipodalSamplerCfg, +) + +# Import everything from the public atomic_actions API +from embodichain.lab.sim.atomic_actions import ( + AtomicActionEngine, + ObjectSemantics, + AntipodalAffordance, + PickUpActionCfg, + PlaceActionCfg, + MoveActionCfg, +) + + +def parse_arguments(): + """ + Parse command-line arguments to configure the simulation. + + Returns: + argparse.Namespace: Parsed arguments including number of environments, device, and rendering options. 
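+
+    Both flags are optional: ``--num_envs`` defaults to 1 and ``--enable_rt``
+    defaults to False (ray tracing disabled).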
+ """ + parser = argparse.ArgumentParser( + description="Create and simulate a robot in SimulationManager" + ) + parser.add_argument( + "--enable_rt", action="store_true", help="Enable ray tracing rendering" + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of parallel environments" + ) + return parser.parse_args() + + +def initialize_simulation(args): + """ + Initialize the simulation environment based on the provided arguments. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + + Returns: + SimulationManager: Configured simulation manager instance. + """ + config = SimulationManagerCfg( + headless=True, + sim_device="cuda", + enable_rt=args.enable_rt, + physics_dt=1.0 / 100.0, + num_envs=args.num_envs, + ) + sim = SimulationManager(config) + + light = sim.add_light( + cfg=LightCfg(uid="main_light", intensity=50.0, init_pos=(0, 0, 2.0)) + ) + + return sim + + +def create_robot(sim: SimulationManager, position=[0.0, 0.0, 0.0]): + """ + Create and configure a robot with an arm and a dexterous hand in the simulation. + + Args: + sim (SimulationManager): The simulation manager instance. + + Returns: + Robot: The configured robot instance added to the simulation. + """ + # Retrieve URDF paths for the robot arm and hand + ur10_urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf") + gripper_urdf_path = get_data_path("DH_PGC_140_50_M/DH_PGC_140_50_M.urdf") + # Configure the robot with its components and control properties + cfg = RobotCfg( + uid="UR10", + urdf_cfg=URDFCfg( + components=[ + {"component_type": "arm", "urdf_path": ur10_urdf_path}, + {"component_type": "hand", "urdf_path": gripper_urdf_path}, + ] + ), + drive_pros=JointDrivePropertiesCfg( + stiffness={"JOINT[0-9]": 1e4, "FINGER[1-2]": 1e2}, + damping={"JOINT[0-9]": 1e3, "FINGER[1-2]": 1e1}, + max_effort={"JOINT[0-9]": 1e5, "FINGER[1-2]": 1e3}, + drive_type="force", + ), + control_parts={ + "arm": ["JOINT[0-9]"], + "hand": ["FINGER[1-2]"], + }, + solver_cfg={ + "arm": PytorchSolverCfg( + end_link_name="ee_link", + root_link_name="base_link", + tcp=[ + [0.0, 1.0, 0.0, 0.0], + [-1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.12], + [0.0, 0.0, 0.0, 1.0], + ], + ) + }, + init_qpos=[0.0, -np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2, 0.0, 0.0, 0.0], + init_pos=position, + ) + return sim.add_robot(cfg=cfg) + + +def create_mug(sim: SimulationManager) -> RigidObject: + mug_cfg = RigidObjectCfg( + uid="mug", + shape=MeshCfg( + fpath=get_data_path("CoffeeCup/cup.ply"), + ), + attrs=RigidBodyAttributesCfg( + mass=0.01, + dynamic_friction=0.97, + static_friction=0.99, + ), + max_convex_hull_num=16, + init_pos=[0.55, 0.0, 0.01], + init_rot=[0.0, 0.0, -90], + body_scale=(4, 4, 4), + ) + mug = sim.add_rigid_object(cfg=mug_cfg) + return mug + + +def main(): + """Pick up a mug and place it at a new location using atomic actions.""" + args = parse_arguments() + + # ------------------------------------------------------------------ # + # Step 1: Set up simulation, robot, and object # + # ------------------------------------------------------------------ # + sim: SimulationManager = initialize_simulation(args) + robot = create_robot(sim) + mug = create_mug(sim) + + # ------------------------------------------------------------------ # + # Step 2: Create a MotionGenerator for the robot # + # MotionGenerator handles trajectory planning (IK + TOPPRA smoothing) # + # ------------------------------------------------------------------ # + motion_gen = MotionGenerator( + 
cfg=MotionGenCfg(planner_cfg=ToppraPlannerCfg(robot_uid=robot.uid)) + ) + + # ------------------------------------------------------------------ # + # Step 3: Configure the three atomic actions # + # # + # PickUpAction — approach → close gripper → lift # + # PlaceAction — lower → open gripper → retract # + # MoveAction — free-space move to a target EEF pose # + # ------------------------------------------------------------------ # + # Gripper joint values for this robot (DH_PGC_140): + # open = [0.00, 0.00] (fully open) + # close = [0.025, 0.025] (grasping width) + hand_open = torch.tensor([0.00, 0.00], dtype=torch.float32, device=sim.device) + hand_close = torch.tensor([0.025, 0.025], dtype=torch.float32, device=sim.device) + + pickup_cfg = PickUpActionCfg( + control_part="arm", + hand_control_part="hand", + hand_open_qpos=hand_open, + hand_close_qpos=hand_close, + # Approach the object from directly above (negative world-Z) + approach_direction=torch.tensor( + [0.0, 0.0, -1.0], dtype=torch.float32, device=sim.device + ), + pre_grasp_distance=0.15, # hover 15 cm above before descending + lift_height=0.15, # lift 15 cm after grasping + ) + + place_cfg = PlaceActionCfg( + control_part="arm", + hand_control_part="hand", + hand_open_qpos=hand_open, + hand_close_qpos=hand_close, + lift_height=0.15, + ) + + move_cfg = MoveActionCfg( + control_part="arm", + ) + + # ------------------------------------------------------------------ # + # Step 4: Build the AtomicActionEngine # + # # + # actions_cfg_list defines the ORDER of actions that execute_static() # + # will run. Each entry is matched positionally to target_list. # + # ------------------------------------------------------------------ # + atomic_engine = AtomicActionEngine( + motion_generator=motion_gen, + actions_cfg_list=[pickup_cfg, place_cfg, move_cfg], + ) + + sim.init_gpu_physics() + sim.open_window() + + # ------------------------------------------------------------------ # + # Step 5: Describe the mug with ObjectSemantics # + # # + # ObjectSemantics bundles together: # + # - geometry (mesh vertices/triangles for grasp annotation) # + # - affordance (how to grasp the object — here antipodal grasps) # + # - entity reference (so the action can read the live object pose) # + # ------------------------------------------------------------------ # + mug_grasp_affordance = AntipodalAffordance( + object_label="mug", + force_reannotate=False, + custom_config={ + "gripper_collision_cfg": GripperCollisionCfg( + max_open_length=0.088, finger_length=0.078, point_sample_dense=0.012 + ), + "generator_cfg": GraspGeneratorCfg( + viser_port=11801, + antipodal_sampler_cfg=AntipodalSamplerCfg( + n_sample=20000, max_length=0.088, min_length=0.003 + ), + ), + }, + ) + mug_semantics = ObjectSemantics( + label="mug", + geometry={ + "mesh_vertices": mug.get_vertices(env_ids=[0], scale=True)[0], + "mesh_triangles": mug.get_triangles(env_ids=[0])[0], + }, + affordance=mug_grasp_affordance, + entity=mug, # needed so PickUpAction can read the mug's live pose + ) + + # ------------------------------------------------------------------ # + # Step 6: Define target poses for place and final rest # + # # + # Poses are 4×4 homogeneous transforms (rotation | translation). # + # For PickUpAction the target is mug_semantics — the action computes # + # the grasp pose automatically from the affordance. 
# + # ------------------------------------------------------------------ # + # Place the mug 20 cm to the left and 40 cm forward from its pickup pose + place_xpos = torch.tensor( + [ + [-0.0539, -0.9985, -0.0022, 0.2489], + [-0.9977, 0.0540, -0.0401, 0.3970], + [0.0401, 0.0000, -0.9992, 0.2400], + [0.0000, 0.0000, 0.0000, 1.0000], + ], + dtype=torch.float32, + device=sim.device, + ) + + # Move the arm to a safe resting pose after placing + rest_xpos = torch.tensor( + [ + [-0.0539, -0.9985, -0.0022, 0.5000], + [-0.9977, 0.0540, -0.0401, 0.0000], + [0.0401, 0.0000, -0.9992, 0.5000], + [0.0000, 0.0000, 0.0000, 1.0000], + ], + dtype=torch.float32, + device=sim.device, + ) + + # ------------------------------------------------------------------ # + # Step 7: Plan and execute the full sequence # + # # + # execute_static() plans all three actions in order and returns a # + # single concatenated joint trajectory (n_envs, n_waypoints, dof). # + # We then replay it frame-by-frame in the simulator. # + # ------------------------------------------------------------------ # + print("Planning pick → place → move trajectory...") + is_success, traj = atomic_engine.execute_static( + target_list=[mug_semantics, place_xpos, rest_xpos] + ) + + if not is_success: + print("Planning failed. Check that the target poses are reachable.") + return + + print(f"Success! Replaying {traj.shape[1]} waypoints...") + for i in range(traj.shape[1]): + robot.set_qpos(traj[:, i]) + sim.update(step=4) + time.sleep(1e-2) + + input("Press Enter to exit...") + + +if __name__ == "__main__": + main() diff --git a/tests/sim/atomic_actions/__init__.py b/tests/sim/atomic_actions/__init__.py new file mode 100644 index 00000000..0671165d --- /dev/null +++ b/tests/sim/atomic_actions/__init__.py @@ -0,0 +1,17 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Tests for atomic actions module.""" diff --git a/tests/sim/atomic_actions/test_actions.py b/tests/sim/atomic_actions/test_actions.py new file mode 100644 index 00000000..ba7324cc --- /dev/null +++ b/tests/sim/atomic_actions/test_actions.py @@ -0,0 +1,304 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +"""Tests for atomic action implementations (MoveAction, PickUpAction, PlaceAction).""" + +from __future__ import annotations + +import pytest +import torch +from unittest.mock import MagicMock, Mock + +from embodichain.lab.sim.atomic_actions.core import ( + ActionCfg, + Affordance, + ObjectSemantics, +) +from embodichain.lab.sim.atomic_actions.actions import ( + MoveAction, + MoveActionCfg, + PickUpAction, + PickUpActionCfg, + PlaceAction, + PlaceActionCfg, +) + +# --------------------------------------------------------------------------- +# Mock Helpers +# --------------------------------------------------------------------------- + +NUM_ENVS = 2 # number of parallel environments used in tests +ARM_DOF = 6 # typical arm joint count +HAND_DOF = 2 # typical hand joint count +TOTAL_DOF = ARM_DOF + HAND_DOF + + +def _make_mock_robot( + num_envs: int = NUM_ENVS, + arm_dof: int = ARM_DOF, + hand_dof: int = HAND_DOF, +) -> Mock: + """Create a mock Robot with arm and hand control parts.""" + robot = Mock() + robot.device = torch.device("cpu") + robot.dof = arm_dof + hand_dof + + def get_qpos(name=None): + if name == "arm": + return torch.zeros(num_envs, arm_dof) + if name == "hand": + return torch.zeros(num_envs, hand_dof) + # Full qpos + return torch.zeros(num_envs, arm_dof + hand_dof) + + robot.get_qpos = get_qpos + + def get_joint_ids(name=None): + if name == "arm": + return list(range(arm_dof)) + if name == "hand": + return list(range(arm_dof, arm_dof + hand_dof)) + return list(range(arm_dof + hand_dof)) + + robot.get_joint_ids = get_joint_ids + + # compute_ik: return success and identity-like qpos + def compute_ik(pose=None, qpos_seed=None, name=None, joint_seed=None): + seed = joint_seed if joint_seed is not None else qpos_seed + if seed is None: + seed = torch.zeros(num_envs, arm_dof) + success = torch.ones(num_envs, dtype=torch.bool) + return success, seed.clone() + + robot.compute_ik = compute_ik + + # compute_fk: return identity-like poses + def compute_fk(qpos=None, name=None, to_matrix=True): + n = qpos.shape[0] if qpos is not None else num_envs + poses = torch.eye(4).unsqueeze(0).repeat(n, 1, 1) + return poses + + robot.compute_fk = compute_fk + + return robot + + +def _make_mock_motion_generator(robot: Mock | None = None) -> Mock: + """Create a mock MotionGenerator.""" + mg = Mock() + mg.robot = robot or _make_mock_robot() + mg.device = mg.robot.device + return mg + + +# --------------------------------------------------------------------------- +# MoveAction +# --------------------------------------------------------------------------- + + +class TestMoveActionHelpers: + """Tests for MoveAction helper methods that don't need simulation.""" + + def setup_method(self): + self.robot = _make_mock_robot() + self.mg = _make_mock_motion_generator(self.robot) + self.cfg = MoveActionCfg(sample_interval=50) + self.action = MoveAction(self.mg, cfg=self.cfg) + + def test_init_sets_attributes(self): + assert self.action.n_envs == NUM_ENVS + assert self.action.dof == ARM_DOF + assert self.action.device == torch.device("cpu") + + def test_resolve_pose_target_from_4x4(self): + target = torch.eye(4) + is_success, result = self.action._resolve_pose_target( + target, action_name="TestAction" + ) + assert is_success is True + assert result.shape == (NUM_ENVS, 4, 4) + # Single pose should be repeated for all envs + for i in range(NUM_ENVS): + assert torch.equal(result[i], torch.eye(4)) + + def 
test_resolve_pose_target_from_batched(self): + target = torch.eye(4).unsqueeze(0).repeat(NUM_ENVS, 1, 1) + target[:, 2, 3] = 0.5 # offset z for each env + is_success, result = self.action._resolve_pose_target( + target, action_name="TestAction" + ) + assert is_success is True + assert result.shape == (NUM_ENVS, 4, 4) + for i in range(NUM_ENVS): + assert result[i, 2, 3].item() == pytest.approx(0.5) + + def test_resolve_start_qpos_defaults_to_current(self): + result = self.action._resolve_start_qpos(None) + assert result.shape == (NUM_ENVS, ARM_DOF) + + def test_resolve_start_qpos_broadcasts_single(self): + single = torch.ones(ARM_DOF) + result = self.action._resolve_start_qpos(single) + assert result.shape == (NUM_ENVS, ARM_DOF) + for i in range(NUM_ENVS): + assert torch.equal(result[i], single) + + def test_compute_three_phase_waypoints_sums_to_sample_interval(self): + hand_interp_steps = 5 + first, second, third = self.action._compute_three_phase_waypoints( + hand_interp_steps, + first_phase_name="approach", + third_phase_name="lift", + ) + assert first + second + third == self.cfg.sample_interval + assert first >= 2 + assert third >= 2 + + def test_interpolate_hand_qpos_shape(self): + n_waypoints = 10 + start = torch.zeros(HAND_DOF) + end = torch.ones(HAND_DOF) + result = self.action._interpolate_hand_qpos(start, end, n_waypoints) + assert result.shape == (n_waypoints, HAND_DOF) + # First and last should match endpoints + assert torch.allclose(result[0], start) + assert torch.allclose(result[-1], end) + + def test_interpolate_hand_qpos_linear(self): + """Verify linear interpolation between two hand configs.""" + n_waypoints = 3 + start = torch.tensor([0.0, 0.0]) + end = torch.tensor([1.0, 1.0]) + result = self.action._interpolate_hand_qpos(start, end, n_waypoints) + expected_mid = torch.tensor([0.5, 0.5]) + assert torch.allclose(result[1], expected_mid, atol=1e-6) + + +# --------------------------------------------------------------------------- +# PickUpAction +# --------------------------------------------------------------------------- + + +class TestPickUpActionInit: + """Tests for PickUpAction initialization and config validation.""" + + def setup_method(self): + self.robot = _make_mock_robot() + self.mg = _make_mock_motion_generator(self.robot) + + def _make_cfg(self, **overrides): + defaults = dict( + hand_open_qpos=torch.tensor([0.0, 0.0]), + hand_close_qpos=torch.tensor([0.025, 0.025]), + control_part="arm", + hand_control_part="hand", + pre_grasp_distance=0.15, + lift_height=0.15, + approach_direction=torch.tensor([0.0, 0.0, -1.0]), + ) + defaults.update(overrides) + return PickUpActionCfg(**defaults) + + def test_init_sets_hand_joint_ids(self): + cfg = self._make_cfg() + action = PickUpAction(self.mg, cfg=cfg) + assert action.hand_joint_ids == list(range(ARM_DOF, ARM_DOF + HAND_DOF)) + assert action.joint_ids == list(range(ARM_DOF)) + list( + range(ARM_DOF, ARM_DOF + HAND_DOF) + ) + assert action.dof == TOTAL_DOF + + +# --------------------------------------------------------------------------- +# PlaceAction +# --------------------------------------------------------------------------- + + +class TestPlaceActionInit: + """Tests for PlaceAction initialization.""" + + def setup_method(self): + self.robot = _make_mock_robot() + self.mg = _make_mock_motion_generator(self.robot) + + def _make_cfg(self, **overrides): + defaults = dict( + hand_open_qpos=torch.tensor([0.0, 0.0]), + hand_close_qpos=torch.tensor([0.025, 0.025]), + control_part="arm", + hand_control_part="hand", + 
lift_height=0.15, + ) + defaults.update(overrides) + return PlaceActionCfg(**defaults) + + def test_init_sets_hand_joint_ids(self): + cfg = self._make_cfg() + action = PlaceAction(self.mg, cfg=cfg) + assert action.hand_joint_ids == list(range(ARM_DOF, ARM_DOF + HAND_DOF)) + assert action.dof == TOTAL_DOF + + +# --------------------------------------------------------------------------- +# AtomicAction._apply_offset +# --------------------------------------------------------------------------- + + +class TestAtomicActionApplyOffset: + """Tests for the shared _apply_offset method inherited from AtomicAction.""" + + def setup_method(self): + self.robot = _make_mock_robot() + self.mg = _make_mock_motion_generator(self.robot) + self.cfg = MoveActionCfg() + self.action = MoveAction(self.mg, cfg=self.cfg) + + def test_apply_offset_batched(self): + # [N, 4, 4] poses, [N, 3] offsets + poses = torch.eye(4).unsqueeze(0).repeat(3, 1, 1) + offsets = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + result = self.action._apply_offset(poses, offsets) + assert result.shape == (3, 4, 4) + assert result[0, :3, 3].tolist() == pytest.approx([1.0, 0.0, 0.0]) + assert result[1, :3, 3].tolist() == pytest.approx([0.0, 1.0, 0.0]) + assert result[2, :3, 3].tolist() == pytest.approx([0.0, 0.0, 1.0]) + + def test_apply_offset_broadcasts_single_offset(self): + # [N, 4, 4] poses, [3] single offset broadcast to all + poses = torch.eye(4).unsqueeze(0).repeat(2, 1, 1) + offset = torch.tensor([0.1, 0.2, 0.3]) + result = self.action._apply_offset(poses, offset) + assert result.shape == (2, 4, 4) + for i in range(2): + assert result[i, :3, 3].tolist() == pytest.approx([0.1, 0.2, 0.3]) + + def test_apply_offset_preserves_rotation(self): + """Offset only affects translation; rotation part stays unchanged.""" + poses = torch.eye(4).unsqueeze(0).repeat(1, 1, 1) + # Set a non-trivial rotation + poses[0, 0, 1] = -1.0 + poses[0, 1, 0] = 1.0 + offset = torch.tensor([1.0, 2.0, 3.0]) + result = self.action._apply_offset(poses, offset) + # Rotation block should be unchanged + assert torch.equal(result[0, :3, :3], poses[0, :3, :3]) + + +if __name__ == "__main__": + # For visual debugging + test = TestMoveActionHelpers() + test.setup_method() + test.test_compute_three_phase_waypoints_sums_to_sample_interval() diff --git a/tests/sim/atomic_actions/test_core.py b/tests/sim/atomic_actions/test_core.py new file mode 100644 index 00000000..7cebaa7b --- /dev/null +++ b/tests/sim/atomic_actions/test_core.py @@ -0,0 +1,171 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +"""Tests for atomic action core module (Affordance, InteractionPoints, ObjectSemantics, ActionCfg).""" + +from __future__ import annotations + +import pytest +import torch + +from embodichain.lab.sim.atomic_actions.core import ( + ActionCfg, + Affordance, + InteractionPoints, + ObjectSemantics, +) + +# --------------------------------------------------------------------------- +# Affordance +# --------------------------------------------------------------------------- + + +class TestAffordance: + """Tests for the Affordance base dataclass.""" + + def test_default_values(self): + aff = Affordance() + assert aff.object_label == "" + assert aff.geometry == {} + assert aff.custom_config == {} + + def test_mesh_vertices_returns_tensor(self): + vertices = torch.randn(10, 3) + aff = Affordance(geometry={"mesh_vertices": vertices}) + assert torch.equal(aff.mesh_vertices, vertices) + + def test_mesh_vertices_returns_none_when_missing(self): + aff = Affordance() + assert aff.mesh_vertices is None + + def test_mesh_vertices_raises_on_wrong_type(self): + aff = Affordance(geometry={"mesh_vertices": [1, 2, 3]}) + with pytest.raises(TypeError, match="must be a torch.Tensor"): + _ = aff.mesh_vertices + + def test_mesh_triangles_returns_tensor(self): + triangles = torch.randint(0, 10, (5, 3)) + aff = Affordance(geometry={"mesh_triangles": triangles}) + assert torch.equal(aff.mesh_triangles, triangles) + + def test_mesh_triangles_returns_none_when_missing(self): + aff = Affordance() + assert aff.mesh_triangles is None + + def test_mesh_triangles_raises_on_wrong_type(self): + aff = Affordance(geometry={"mesh_triangles": "bad"}) + with pytest.raises(TypeError, match="must be a torch.Tensor"): + _ = aff.mesh_triangles + + def test_custom_config_get_set(self): + aff = Affordance() + aff.set_custom_config("key_a", 42) + assert aff.get_custom_config("key_a") == 42 + assert aff.get_custom_config("missing") is None + assert aff.get_custom_config("missing", "default") == "default" + + def test_get_batch_size_returns_one(self): + # Base Affordance always returns 1 + assert Affordance().get_batch_size() == 1 + + +# --------------------------------------------------------------------------- +# InteractionPoints +# --------------------------------------------------------------------------- + + +class TestInteractionPoints: + """Tests for InteractionPoints affordance.""" + + def test_default_points_shape(self): + ip = InteractionPoints() + assert ip.points.shape == (1, 3) + + def test_get_batch_size_matches_points(self): + points = torch.randn(5, 3) + ip = InteractionPoints(points=points) + assert ip.get_batch_size() == 5 + + def test_get_points_by_type_returns_matching_subset(self): + points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + ip = InteractionPoints(points=points, point_types=["push", "poke", "push"]) + result = ip.get_points_by_type("push") + assert result is not None + assert result.shape == (2, 3) + assert torch.equal(result[0], points[0]) + assert torch.equal(result[1], points[2]) + + def test_get_points_by_type_returns_none_for_missing_type(self): + ip = InteractionPoints(points=torch.zeros(2, 3), point_types=["push", "push"]) + assert ip.get_points_by_type("poke") is None + + def test_get_approach_direction_from_normals(self): + normals = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) + ip = InteractionPoints(points=torch.zeros(2, 3), normals=normals) + # Approach is opposite of normal + assert 
torch.equal(ip.get_approach_direction(0), torch.tensor([0.0, 0.0, -1.0])) + assert torch.equal(ip.get_approach_direction(1), torch.tensor([-1.0, 0.0, 0.0])) + + def test_get_approach_direction_default_without_normals(self): + ip = InteractionPoints(points=torch.zeros(1, 3)) + direction = ip.get_approach_direction(0) + assert torch.equal(direction, torch.tensor([0.0, 0.0, 1.0])) + + +# --------------------------------------------------------------------------- +# ObjectSemantics +# --------------------------------------------------------------------------- + + +class TestObjectSemantics: + """Tests for ObjectSemantics dataclass.""" + + def test_post_init_binds_label_and_geometry(self): + geometry = {"bounding_box": [0.1, 0.2, 0.3]} + aff = Affordance() + sem = ObjectSemantics( + affordance=aff, + geometry=geometry, + label="mug", + ) + assert sem.affordance.object_label == "mug" + assert sem.affordance.geometry is geometry + + def test_default_optional_fields(self): + sem = ObjectSemantics( + affordance=Affordance(), + geometry={}, + ) + assert sem.label == "none" + assert sem.properties == {} + assert sem.entity is None + + +# --------------------------------------------------------------------------- +# ActionCfg +# --------------------------------------------------------------------------- + + +class TestActionCfg: + """Tests for ActionCfg defaults.""" + + def test_default_values(self): + cfg = ActionCfg() + assert cfg.name == "default" + assert cfg.control_part == "arm" + assert cfg.interpolation_type == "linear" + assert cfg.velocity_limit is None + assert cfg.acceleration_limit is None diff --git a/tests/sim/atomic_actions/test_engine.py b/tests/sim/atomic_actions/test_engine.py new file mode 100644 index 00000000..52dc034d --- /dev/null +++ b/tests/sim/atomic_actions/test_engine.py @@ -0,0 +1,191 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +"""Tests for atomic action engine (registry, SemanticAnalyzer, AtomicActionEngine).""" + +from __future__ import annotations + +import pytest +import torch +from unittest.mock import MagicMock, Mock + +from embodichain.lab.sim.atomic_actions.core import ( + ActionCfg, + Affordance, + ObjectSemantics, +) +from embodichain.lab.sim.atomic_actions.engine import ( + AtomicActionEngine, + SemanticAnalyzer, + get_registered_actions, + register_action, + unregister_action, +) + +# --------------------------------------------------------------------------- +# Global Action Registry +# --------------------------------------------------------------------------- + + +class TestGlobalRegistry: + """Tests for register_action / unregister_action / get_registered_actions.""" + + def teardown_method(self): + # Clean up any test registrations + unregister_action("_test_dummy") + + def test_register_and_retrieve(self): + mock_cls = Mock() + register_action("_test_dummy", mock_cls) + registry = get_registered_actions() + assert "_test_dummy" in registry + assert registry["_test_dummy"] is mock_cls + + def test_unregister_removes_entry(self): + register_action("_test_dummy", Mock()) + unregister_action("_test_dummy") + assert "_test_dummy" not in get_registered_actions() + + def test_unregister_nonexistent_is_noop(self): + # Should not raise + unregister_action("_nonexistent_action") + + def test_get_registered_actions_returns_copy(self): + """Mutating the returned dict should not affect the global registry.""" + result = get_registered_actions() + result["_should_not_persist"] = Mock() + assert "_should_not_persist" not in get_registered_actions() + + +# --------------------------------------------------------------------------- +# SemanticAnalyzer +# --------------------------------------------------------------------------- + + +class TestSemanticAnalyzer: + """Tests for SemanticAnalyzer.""" + + def setup_method(self): + self.analyzer = SemanticAnalyzer() + + def test_analyze_returns_object_semantics(self): + sem = self.analyzer.analyze("mug") + assert isinstance(sem, ObjectSemantics) + assert sem.label == "mug" + assert isinstance(sem.affordance, Affordance) + + def test_analyze_caches_by_default(self): + sem1 = self.analyzer.analyze("bottle") + sem2 = self.analyzer.analyze("bottle") + assert sem1 is sem2 + + def test_analyze_bypasses_cache_with_geometry(self): + sem1 = self.analyzer.analyze("bottle") + sem2 = self.analyzer.analyze( + "bottle", geometry={"bounding_box": [0.2, 0.2, 0.2]} + ) + assert sem1 is not sem2 + + def test_analyze_no_cache(self): + sem1 = self.analyzer.analyze("cup", use_cache=False) + sem2 = self.analyzer.analyze("cup", use_cache=False) + assert sem1 is not sem2 + + def test_clear_cache(self): + self.analyzer.analyze("can") + self.analyzer.clear_cache() + # After clearing, a new object should be created + sem1 = self.analyzer.analyze("can") + sem2 = self.analyzer.analyze("can") + assert sem1 is sem2 # re-cached after clear + + +# --------------------------------------------------------------------------- +# AtomicActionEngine._resolve_target +# --------------------------------------------------------------------------- + + +class TestResolveTarget: + """Tests for AtomicActionEngine._resolve_target with various input types.""" + + def setup_method(self): + self.robot = Mock() + self.robot.device = torch.device("cpu") + self.robot.dof = 6 + self.robot.get_qpos.return_value = torch.zeros(1, 6) + 
self.robot.get_joint_ids.return_value = list(range(6)) + + self.mg = Mock() + self.mg.robot = self.robot + self.mg.device = torch.device("cpu") + + self.engine = AtomicActionEngine(self.mg, actions_cfg_list=[]) + + def test_tensor_passthrough(self): + tensor = torch.eye(4) + result = self.engine._resolve_target(tensor) + assert result is tensor + + def test_object_semantics_passthrough(self): + sem = ObjectSemantics(affordance=Affordance(), geometry={}) + result = self.engine._resolve_target(sem) + assert result is sem + + def test_string_resolved_via_semantic_analyzer(self): + result = self.engine._resolve_target("mug") + assert isinstance(result, ObjectSemantics) + assert result.label == "mug" + + def test_dict_with_pose_key(self): + pose = torch.eye(4) + result = self.engine._resolve_target({"pose": pose}) + assert result is pose + + def test_dict_with_pose_raises_on_non_tensor(self): + with pytest.raises(TypeError, match="must be a torch.Tensor"): + self.engine._resolve_target({"pose": "not_a_tensor"}) + + def test_dict_with_semantics_key(self): + sem = ObjectSemantics(affordance=Affordance(), geometry={}, label="bottle") + result = self.engine._resolve_target({"semantics": sem}) + assert result is sem + + def test_dict_with_semantics_raises_on_wrong_type(self): + with pytest.raises(TypeError, match="must be an ObjectSemantics"): + self.engine._resolve_target({"semantics": "wrong"}) + + def test_dict_with_label_uses_analyzer(self): + result = self.engine._resolve_target({"label": "apple"}) + assert isinstance(result, ObjectSemantics) + assert result.label == "apple" + + def test_dict_without_label_raises(self): + with pytest.raises(ValueError, match="must provide 'label'"): + self.engine._resolve_target({"geometry": {}}) + + def test_dict_with_non_string_label_raises(self): + with pytest.raises(TypeError, match="must be a string"): + self.engine._resolve_target({"label": 123}) + + def test_unsupported_type_raises(self): + with pytest.raises(TypeError, match="target must be"): + self.engine._resolve_target(42) + + +if __name__ == "__main__": + test = TestSemanticAnalyzer() + test.setup_method() + test.test_analyze_returns_object_semantics()
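+
+    # Illustrative extension in the same smoke-test style (sketch, not part
+    # of the pytest run): exercise the registry round-trip as well.
+    registry_test = TestGlobalRegistry()
+    registry_test.test_register_and_retrieve()
+    registry_test.teardown_method()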