diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ed8ebf5 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/LICENCE b/LICENCE new file mode 100644 index 0000000..92bbfd6 --- /dev/null +++ b/LICENCE @@ -0,0 +1,19 @@ +Copyright (c) 2025, Hybrid Robotics. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 4cbbe28..e7a2be8 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,161 @@ -### HOW TO USE THIS TEMPLATE +# BeyondMimic Motion Tracking Inference -> **DO NOT FORK** this is meant to be used from **[Use this template](https://github.com/qiayuanl/legged_template_controller/generate)** feature. +[[Website]](https://beyondmimic.github.io/) +[[Arxiv]](https://arxiv.org/abs/2508.08241) +[[Video]](https://youtu.be/RS_MtKVIAzY) -1. Click on **[Use this template](https://github.com/qiayuanl/legged_template_controller/generate)** -3. Give a name to your repo, this name will be your ROS package name - (e.g. `rss25_controllers`, `koushil_controller`, **all lowercase and underscores separation for the name should be used!**) -3. Wait until the first run of CI finishes - (Github Actions will process the template and commit to your new repo) -4. Then clone your new project to the colcon workspace (e.g. `colcon_ws/src`) and happy coding! +This repository provides the inference pipeline for motion tracking policies in BeyondMimic. The pipeline is implemented +in C++ using the ONNX CPU inference engine. Model parameters (joint order, impedance, etc.) +are stored in ONNX metadata, and the reference motion is returned via the `forward()` function. +See [this script](https://github.com/HybridRobotics/whole_body_tracking/blob/main/source/whole_body_tracking/whole_body_tracking/utils/exporter.py) +for details on exporting models. + +This repo also serves as an example of how to implement a custom controller using the +[legged_control2](https://qiayuanl.github.io/legged_control2_doc/) framework. + +## Install and Build (docker) +The Dockerfile and helper scripts are provided to simplify installation and build. +Make sure [Docker](https://docs.docker.com/engine/install/ubuntu/) and the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) are installed first. +```bash +# Host +git clone https://github.com/HybridRobotics/motion_tracking_controller.git +cd motion_tracking_controller/docker +docker compose up -d --build +docker exec -it wbt_ws bash +``` + +```bash +# container +cd /wbt_ws/src/motion_tracking_controller +./scripts/colcon-config.sh Release +``` +This script automatically clones and builds the required dependencies. +See [Basic Usage](#basic-usage) for usage instructions. + +## Installation + +### Dependencies + +This software is built on +the [ROS 2 Jazzy](https://docs.ros.org/en/jazzy/Installation/Ubuntu-Install-Debs.html#ubuntu-deb-packages), which +needs to be installed first. Additionally, this code base depends on `legged_control2`. + +### Install `legged_control2` + +Pre-built binaries for `legged_control2` are available on ROS 2 Jazzy. We recommend first reading +the [full documentation](https://qiayuanl.github.io/legged_control2_doc/overview.html). + +Specifically, For this repo, follow +the [Debian Source installation](https://qiayuanl.github.io/legged_control2_doc/installation.html#debian-source-recommended). +Additionally, install Unitree-specific packages: + +```bash +# Add debian source +echo "deb [trusted=yes] https://github.com/qiayuanl/unitree_buildfarm/raw/noble-jazzy-amd64/ ./" | sudo tee /etc/apt/sources.list.d/qiayuanl_unitree_buildfarm.list +echo "yaml https://github.com/qiayuanl/unitree_buildfarm/raw/noble-jazzy-amd64/local.yaml jazzy" | sudo tee /etc/ros/rosdep/sources.list.d/1-qiayuanl_unitree_buildfarm.list +sudo apt-get update +``` + +```bash +# Install packages +sudo apt-get install ros-jazzy-unitree-description +sudo apt-get install ros-jazzy-unitree-systems +``` + +### Build Package + +After installing `legged_control2`, you can build this package. You’ll also need the +`unitree_bringup` repo, which contains utilities not included in the pre-built binaries. + +Create a ROS 2 workspace if you don't have one. Below we use `~/colcon_ws` as an example. + +```bash +mkdir -p ~/colcon_ws/src +``` + +Clone two repo into the `src` of workspace. + +```bash +cd ~/colcon_ws/src +git clone https://github.com/qiayuanl/unitree_bringup.git +git clone https://github.com/HybridRobotics/motion_tracking_controller.git +cd ../ +``` + +Install dependencies automatically: + +```bash +rosdep install --from-paths src --ignore-src -r -y +``` + +Build the packages: + +```bash +colcon build --symlink-install --cmake-args -DCMAKE_BUILD_TYPE=RelwithDebInfo --packages-up-to unitree_bringup +colcon build --symlink-install --cmake-args -DCMAKE_BUILD_TYPE=RelwithDebInfo --packages-up-to motion_tracking_controller +source install/setup.bash +``` + +## Basic Usage + +### Sim-to-Sim + +We provide a launch file for running the policy in MuJoCo simulation. + +```bash +# Load policy from WandB +ros2 launch motion_tracking_controller mujoco.launch.py wandb_path:= +``` + +```bash +# OR load policy from local ONNX file (should be absolute or start with `~`) +ros2 launch motion_tracking_controller mujoco.launch.py policy_path:= +``` + +### Real Experiments + +> ⚠️ **Disclaimer** +> Running these models on real robots is **dangerous** and entirely at your own risk. +> They are provided **for research only**, and we accept **no responsibility** for any harm, damage, or malfunction. + +1. Connect to the robot via ethernet cable. +2. Set the ethernet adapter to static IP: `192.168.123.11`. +3. Use `ifconfig` to find the ``, (e.g.,`eth0` or `enp3s0`). + +```bash +# Load policy from WandB +ros2 launch motion_tracking_controller real.launch.py network_interface:= wandb_path:= +``` + +```bash +# OR load policy from local ONNX file (should be absolute or start with `~`) +ros2 launch motion_tracking_controller real.launch.py network_interface:= policy_path:=.onnx +``` + +The robot should enter standby controller in the beginning. +Use the Unitree remote (joystick) to start and stop the policy: + +- Standby controller (joint position control): `L1 + A` +- Motion tracking controller (the policy): `R1 + A` +- E-stop (damping): `B` + +## Code Structure + +This section will be especially helpful if you decide to write your own legged_control2 controller. +For a minimal starting point, check +the [legged_template_controller](https://github.com/qiayuanl/legged_template_controller). + +Below is an overview of the code structure for this repository: + +- **`include`** or **`src`** + - **`MotionTrackingController`** Manages observations (like an RL environment) and passes them to the policy. + + - **`MotionOnnxPolicy`** Wraps the neural network, runs inference, and extracts reference motion from the ONNX file. + + - **`MotionCommand`** Defines observation terms aligned with the training code. + + +- **`launch`** + - Includes launch files like `mujoco.launch.py` and `real.launch.py` for simulation and real robot execution. +- **`config`** + - Stores configuration files for standby controller and state estimation params. diff --git a/config/g1/controllers.yaml b/config/g1/controllers.yaml index f31e1c8..926ffc4 100644 --- a/config/g1/controllers.yaml +++ b/config/g1/controllers.yaml @@ -8,8 +8,6 @@ controller_manager: type: legged_controllers/StandbyController walking_controller: type: motion_tracking_controller/MotionTrackingController - walking_controller1: - type: legged_rl_controllers/OnnxController state_estimator: ros__parameters: @@ -17,11 +15,8 @@ state_estimator: base_name: "pelvis" six_dof_contact_names: [ "LL_FOOT", "LR_FOOT" ] estimation: - contact: - height_sensor_noise: 1e10 position: noise: 1e-2 - topic: "/mid360" frame_id: "mid360_link" standby_controller: @@ -36,11 +31,11 @@ standby_controller: left_elbow_joint, left_wrist_roll_joint, left_wrist_pitch_joint, left_wrist_yaw_joint, right_shoulder_pitch_joint, right_shoulder_roll_joint, right_shoulder_yaw_joint, right_elbow_joint, right_wrist_roll_joint, right_wrist_pitch_joint, right_wrist_yaw_joint ] - default_position: [ -0.312, 0.0, 0.0, 0.669, -0.363, 0.0, - -0.312, 0.0, 0.0, 0.669, -0.363, 0.0, + default_position: [ -0.312, 0.0, 0.0, 0.669, -0.33, 0.0, + -0.312, 0.0, 0.0, 0.669, -0.33, 0.0, 0.0, 0.0, 0.0, - 0.2, 1.57, 0.0, 1.57, 0.0, 0.0, 0.0, - 0.2, -1.57, 0.0, 1.57, 0.0, 0.0, 0.0 ] + 0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, + 0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0 ] kp: [ 350.0, 200.0, 200.0, 300.0, 300.0, 150.0, 350.0, 200.0, 200.0, 300.0, 300.0, 150.0, 200.0, 200.0, 200.0, @@ -54,13 +49,4 @@ standby_controller: walking_controller: ros__parameters: - policy: - path: "config/g1/policy.onnx" - motion: - reference_body: "torso_link" - body_names: [ 'pelvis', - 'left_hip_roll_link', 'left_knee_link', 'left_ankle_roll_link', - 'right_hip_roll_link', 'right_knee_link', 'right_ankle_roll_link', - 'torso_link', - 'left_shoulder_roll_link', 'left_elbow_link', 'left_wrist_yaw_link', - 'right_shoulder_roll_link', 'right_elbow_link', 'right_wrist_yaw_link' ] + update_rate: 50 # Remove this line if using Humble!!! diff --git a/config/g1/on_board.yaml b/config/g1/on_board.yaml deleted file mode 100644 index a7523f6..0000000 --- a/config/g1/on_board.yaml +++ /dev/null @@ -1,63 +0,0 @@ -controller_manager: - ros__parameters: - update_rate: 500 # Hz - - state_estimator: - type: legged_controllers/StateEstimator - standby_controller: - type: legged_controllers/StandbyController - walking_controller: - type: motion_tracking_controller/MotionTrackingController - -state_estimator: - ros__parameters: - model: - base_name: "pelvis" - six_dof_contact_names: [ "LL_FOOT", "LR_FOOT" ] - estimation: - contact: - height_sensor_noise: 1e10 - position: - noise: 1e-2 - topic: "/glim/odom" - -standby_controller: - ros__parameters: - joint_names: - [ left_hip_pitch_joint, left_hip_roll_joint, left_hip_yaw_joint, - left_knee_joint, left_ankle_pitch_joint, left_ankle_roll_joint, - right_hip_pitch_joint, right_hip_roll_joint, right_hip_yaw_joint, - right_knee_joint, right_ankle_pitch_joint, right_ankle_roll_joint, - waist_yaw_joint, waist_roll_joint, waist_pitch_joint, - left_shoulder_pitch_joint, left_shoulder_roll_joint, left_shoulder_yaw_joint, - left_elbow_joint, left_wrist_roll_joint, left_wrist_pitch_joint, left_wrist_yaw_joint, - right_shoulder_pitch_joint, right_shoulder_roll_joint, right_shoulder_yaw_joint, - right_elbow_joint, right_wrist_roll_joint, right_wrist_pitch_joint, right_wrist_yaw_joint ] - default_position: [ -0.312, 0.0, 0.0, 0.669, -0.363, 0.0, - -0.312, 0.0, 0.0, 0.669, -0.363, 0.0, - 0.0, 0.0, 0.0, - 0.2, 1.57, 0.0, 1.57, 0.0, 0.0, 0.0, - 0.2, -1.57, 0.0, 1.57, 0.0, 0.0, 0.0 ] - kp: [ 350.0, 200.0, 200.0, 300.0, 300.0, 150.0, - 350.0, 200.0, 200.0, 300.0, 300.0, 150.0, - 200.0, 200.0, 200.0, - 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, - 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0 ] - kd: [ 5.0, 5.0, 5.0, 10.0, 5.0, 5.0, - 5.0, 5.0, 5.0, 10.0, 5.0, 5.0, - 5.0, 5.0, 5.0, - 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, - 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0 ] - -walking_controller: - ros__parameters: - policy: - path: "config/g1/policy.onnx" - motion: - reference_body: "torso_link" - body_names: [ 'pelvis', - 'left_hip_roll_link', 'left_knee_link', 'left_ankle_roll_link', - 'right_hip_roll_link', 'right_knee_link', 'right_ankle_roll_link', - 'torso_link', - 'left_shoulder_roll_link', 'left_elbow_link', 'left_wrist_yaw_link', - 'right_shoulder_roll_link', 'right_elbow_link', 'right_wrist_yaw_link' ] diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..d37f16d --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,78 @@ +# Ubuntu 24.04 image based on ROS Jazzy (supports both x86_64 and arm64) +FROM ros:jazzy-perception-noble + +ENV ROS_DISTRO=jazzy + +# Disable interactive prompts during apt install +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && \ + # APT / certificates / repository management / GPG keys + # (apt-transport-https is usually unnecessary on Ubuntu 22.04) + apt-get install -y --no-install-recommends \ + ca-certificates gnupg lsb-release software-properties-common && \ + \ + # Networking and troubleshooting tools + apt-get install -y --no-install-recommends \ + iputils-ping iproute2 wget curl net-tools && \ + \ + # Build / debug / code management tools + apt-get install -y --no-install-recommends \ + build-essential cmake gdb git git-lfs vim python3-vcstool && \ + \ + # Common C++ math and utility libraries + apt-get install -y --no-install-recommends \ + libboost-all-dev libeigen3-dev libyaml-cpp-dev && \ + \ + # JSON processing / code formatting / miscellaneous utilities + apt-get install -y --no-install-recommends \ + jq clang-format unzip ncdu && \ + \ + # GUI support (kept since GUI is required; gnome-terminal is large but convenient) + apt-get install -y --no-install-recommends \ + dbus-x11 gnome-terminal && \ + \ + # Hardware / USB / joystick support + apt-get install -y --no-install-recommends \ + udev usbutils joystick && \ + \ + # Python toolchain (pip) + apt-get install -y --no-install-recommends \ + python3-pip && \ + \ + git lfs install + + + # ROS dependencies +RUN apt-get install -y --no-install-recommends \ + ros-jazzy-plotjuggler-ros ros-jazzy-pinocchio + + # Install additional dependencies + # unitree_sdk2 +RUN apt-get install -y --no-install-recommends \ + libspdlog-dev libfmt-dev && \ + \ + # unitree_ros2 + apt-get install -y --no-install-recommends \ + ros-jazzy-rmw-cyclonedds-cpp ros-jazzy-rosidl-generator-dds-idl && \ + \ + # unitree_mujoco + apt-get install -y --no-install-recommends \ + libglfw3-dev + + # Add apt source +RUN echo "deb [trusted=yes] https://github.com/qiayuanl/legged_buildfarm/raw/noble-jazzy-amd64/ ./" | sudo tee /etc/apt/sources.list.d/qiayuanl_legged_buildfarm.list && \ + echo "yaml https://github.com/qiayuanl/legged_buildfarm/raw/noble-jazzy-amd64/local.yaml jazzy" | sudo tee /etc/ros/rosdep/sources.list.d/1-qiayuanl_legged_buildfarm.list && \ + echo "deb [trusted=yes] https://github.com/qiayuanl/unitree_buildfarm/raw/noble-jazzy-amd64/ ./" | sudo tee /etc/apt/sources.list.d/qiayuanl_unitree_buildfarm.list && \ + echo "yaml https://github.com/qiayuanl/unitree_buildfarm/raw/noble-jazzy-amd64/local.yaml jazzy" | sudo tee /etc/ros/rosdep/sources.list.d/1-qiayuanl_unitree_buildfarm.list && \ + echo "deb [trusted=yes] https://github.com/qiayuanl/simulation_buildfarm/raw/noble-jazzy-amd64/ ./" | sudo tee /etc/apt/sources.list.d/qiayuanl_simulation_buildfarm.list && \ + echo "yaml https://github.com/qiayuanl/simulation_buildfarm/raw/noble-jazzy-amd64/local.yaml jazzy" | sudo tee /etc/ros/rosdep/sources.list.d/1-qiayuanl_simulation_buildfarm.list && \ + sudo apt-get update && \ + sudo apt-get install -y --no-install-recommends \ + ros-jazzy-legged-control-base ros-jazzy-mujoco-ros2-control ros-jazzy-unitree-description ros-jazzy-unitree-systems + +# Copy workspace setup script into the container +COPY workspace-config.sh /workspace-config.sh + +# Use workspace-config.sh as the container entrypoint +ENTRYPOINT ["/workspace-config.sh"] diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml new file mode 100644 index 0000000..ee13be5 --- /dev/null +++ b/docker/docker-compose.yaml @@ -0,0 +1,47 @@ +name: wbt_ws + +services: + wbt_ws: + build: + context: . # Build context (current directory) + dockerfile: Dockerfile # Defaults to Dockerfile, can be omitted + image: wbt_ws:latest # Image name + container_name: wbt_ws + + privileged: true + security_opt: + - seccomp=unconfined + + # GPU support (modern Compose) + gpus: all + # runtime: nvidia # (for jetson) + + # Environment variables + environment: + - NVIDIA_VISIBLE_DEVICES=all + - NVIDIA_DRIVER_CAPABILITIES=all + - DISPLAY=${DISPLAY:-:0} + # Use a fixed path inside container + - XAUTHORITY=/root/.Xauthority + # - XDG_RUNTIME_DIR=${XDG_RUNTIME_DIR} + - QT_X11_NO_MITSHM=1 + - WORKSPACE=/wbt_ws + + # Volume mounts + volumes: + - /tmp/.X11-unix:/tmp/.X11-unix:rw + # Mount host Xauthority file into container + - ${XAUTHORITY:-$HOME/.Xauthority}:/root/.Xauthority:rw + - /etc/localtime:/etc/localtime:ro + - /dev/input:/dev/input:rwm + - /run/udev:/run/udev:ro + - ../:/wbt_ws/src/motion_tracking_controller + + + # Network configuration + network_mode: host + ipc: host + + # Interactive mode (equivalent to -it) + stdin_open: true + tty: true \ No newline at end of file diff --git a/docker/workspace-config.sh b/docker/workspace-config.sh new file mode 100755 index 0000000..84e56c6 --- /dev/null +++ b/docker/workspace-config.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +############################################ +# Add /usr/local/lib to the default library search path +echo "/usr/local/lib" | sudo tee /etc/ld.so.conf.d/local-lib.conf +ldconfig + +############################################ +# Set up / install GDB Eigen pretty printers +cp $WORKSPACE/scripts/.gdbinit /root +cp $WORKSPACE/scripts/.gdb_eigen /root + +cd $WORKSPACE + +exec bash diff --git a/include/motion_tracking_controller/MotionCommand.h b/include/motion_tracking_controller/MotionCommand.h index 089381e..8884810 100644 --- a/include/motion_tracking_controller/MotionCommand.h +++ b/include/motion_tracking_controller/MotionCommand.h @@ -18,14 +18,14 @@ class MotionCommandTerm : public CommandTerm { using SharedPtr = std::shared_ptr; MotionCommandTerm(MotionCommandCfg cfg, MotionOnnxPolicy::SharedPtr motionPolicy) - : cfg_(std::move(cfg)), motionPolicy_(std::move(motionPolicy)), referenceRobotIndex_(0), referenceMotionIndex_(0) {} + : cfg_(std::move(cfg)), motionPolicy_(std::move(motionPolicy)), anchorRobotIndex_(0), anchorMotionIndex_(0) {} vector_t getValue() override; void reset() override; MotionCommandCfg getCfg() const { return cfg_; } - vector3_t getReferencePositionLocal() const; - vector_t getReferenceOrientationLocal() const; + vector3_t getAnchorPositionLocal() const; + vector_t getAnchorOrientationLocal() const; vector_t getRobotBodyPositionLocal() const; vector_t getRobotBodyOrientationLocal() const; @@ -35,7 +35,7 @@ class MotionCommandTerm : public CommandTerm { MotionCommandCfg cfg_; MotionOnnxPolicy::SharedPtr motionPolicy_; - size_t referenceRobotIndex_, referenceMotionIndex_; + size_t anchorRobotIndex_, anchorMotionIndex_; std::vector bodyIndices_{}; pinocchio::SE3 worldToInit_; }; diff --git a/include/motion_tracking_controller/MotionObservation.h b/include/motion_tracking_controller/MotionObservation.h index 5c0bdb1..6cdc366 100644 --- a/include/motion_tracking_controller/MotionObservation.h +++ b/include/motion_tracking_controller/MotionObservation.h @@ -19,22 +19,22 @@ class MotionObservation : public ObservationTerm { MotionCommandTerm::SharedPtr commandTerm_; }; -class MotionReferencePosition final : public MotionObservation { +class MotionAnchorPosition final : public MotionObservation { public: using MotionObservation::MotionObservation; size_t getSize() const override { return 3; } protected: - vector_t evaluate() override { return commandTerm_->getReferencePositionLocal(); } + vector_t evaluate() override { return commandTerm_->getAnchorPositionLocal(); } }; -class MotionReferenceOrientation final : public MotionObservation { +class MotionAnchorOrientation final : public MotionObservation { public: using MotionObservation::MotionObservation; size_t getSize() const override { return 6; } protected: - vector_t evaluate() override { return commandTerm_->getReferenceOrientationLocal(); } + vector_t evaluate() override { return commandTerm_->getAnchorOrientationLocal(); } }; class RobotBodyPosition final : public MotionObservation { diff --git a/include/motion_tracking_controller/MotionOnnxPolicy.h b/include/motion_tracking_controller/MotionOnnxPolicy.h index d7601cd..6bac80b 100644 --- a/include/motion_tracking_controller/MotionOnnxPolicy.h +++ b/include/motion_tracking_controller/MotionOnnxPolicy.h @@ -11,22 +11,29 @@ namespace legged { class MotionOnnxPolicy : public OnnxPolicy { public: using SharedPtr = std::shared_ptr; - using OnnxPolicy::OnnxPolicy; + MotionOnnxPolicy(const std::string& modelPath, size_t startStep) : OnnxPolicy(modelPath), startStep_(startStep) {} void reset() override; vector_t forward(const vector_t& observations) override; + std::string getAnchorBodyName() const { return anchorBodyName_; } + std::vector getBodyNames() const { return bodyNames_; } + vector_t getJointPosition() const { return jointPosition_; } vector_t getJointVelocity() const { return jointVelocity_; } std::vector getBodyPositions() const { return bodyPositions_; } std::vector getBodyOrientations() const { return bodyOrientations_; } + void parseMetadata() override; + protected: - size_t time_step_ = 0; + size_t timeStep_ = 0, startStep_ = 0; vector_t jointPosition_; vector_t jointVelocity_; std::vector bodyPositions_; std::vector bodyOrientations_; + std::string anchorBodyName_; + std::vector bodyNames_; }; } // namespace legged diff --git a/include/motion_tracking_controller/MotionTrackingController.h b/include/motion_tracking_controller/MotionTrackingController.h index 5ce52e0..86940f0 100644 --- a/include/motion_tracking_controller/MotionTrackingController.h +++ b/include/motion_tracking_controller/MotionTrackingController.h @@ -1,15 +1,14 @@ #pragma once -#include -// #include +#include -#include "motion_tracking_controller/common.h" #include "motion_tracking_controller/MotionCommand.h" +#include "motion_tracking_controller/common.h" namespace legged { -class MotionTrackingController : public OnnxController { +class MotionTrackingController : public RlController { public: - controller_interface::return_type update(const rclcpp::Time& time, const rclcpp::Duration& period) override; + controller_interface::CallbackReturn on_init() override; controller_interface::CallbackReturn on_configure(const rclcpp_lifecycle::State& previous_state) override; @@ -23,7 +22,6 @@ class MotionTrackingController : public OnnxController { MotionCommandCfg cfg_; MotionCommandTerm::SharedPtr commandTerm_; - // DataLogger::SharedPtr dataLogger_; }; } // namespace legged diff --git a/include/motion_tracking_controller/common.h b/include/motion_tracking_controller/common.h index da7a590..86ecfdc 100644 --- a/include/motion_tracking_controller/common.h +++ b/include/motion_tracking_controller/common.h @@ -11,7 +11,7 @@ namespace legged { struct MotionCommandCfg { - std::string referenceBody; + std::string anchorBody; std::vector bodyNames; }; diff --git a/launch/mujoco.launch.py b/launch/mujoco.launch.py index f9639a9..b28de0f 100644 --- a/launch/mujoco.launch.py +++ b/launch/mujoco.launch.py @@ -9,45 +9,67 @@ SetLaunchConfiguration, IncludeLaunchDescription ) +from launch.conditions import IfCondition from launch.launch_description_sources import PythonLaunchDescriptionSource -from launch.substitutions import Command, FindExecutable, PathJoinSubstitution, LaunchConfiguration, PythonExpression +from launch.substitutions import Command, FindExecutable, PathJoinSubstitution, LaunchConfiguration, PythonExpression, \ + ThisLaunchFileDir from launch_ros.actions import Node from launch_ros.substitutions import FindPackageShare -def fill_policy_path(config_path, package_name): +# -------------------------- +# Internal: minimal generic override +# -------------------------- +def generate_temp_config(config_path, package_name, kv_pairs): + """ + Load /, apply overrides from kv_pairs, + and write to /tmp//temp_controllers.yaml. Returns the path. + kv_pairs: list of (dotted_key, raw_value_str) + """ pkg_dir = get_package_share_directory(package_name) src_path = os.path.join(pkg_dir, config_path) dst_path = os.path.join('/tmp', package_name, 'temp_controllers.yaml') - os.makedirs(os.path.dirname(dst_path), exist_ok=True) with open(src_path, 'r') as f: - config = yaml.safe_load(f) - - for ns in list(config.keys()): - params = config.get(ns, {}).setdefault('ros__parameters', {}) - if 'policy' in params and 'path' in params['policy']: - params['policy']['path'] = os.path.join(pkg_dir, params['policy']['path']) + cfg = yaml.safe_load(f) or {} + + for dotted_key, raw_val in kv_pairs: + parts = [p for p in dotted_key.split('.') if p] + if len(parts) < 2: + raise ValueError( + f"Key '{dotted_key}' is incomplete; expected '.[ros__parameters.]foo.bar'" + ) + # Auto-insert ros__parameters right after namespace if omitted + if parts[1] != 'ros__parameters': + parts.insert(1, 'ros__parameters') + + try: + val = yaml.safe_load(raw_val) + except Exception: + val = raw_val + + cur = cfg + for k in parts[:-1]: + if not isinstance(cur.get(k), dict): + cur[k] = {} + cur = cur[k] + cur[parts[-1]] = val with open(dst_path, 'w') as f: - yaml.dump(config, f) - print(f"Modified controllers.yaml saved to {dst_path}") + yaml.dump(cfg, f, sort_keys=False) + print(f"[launch] Temp controllers.yaml written to {dst_path}") return dst_path +# -------------------------- +# ROS nodes / launch wiring +# -------------------------- def control_spawner(names, inactive=False): - # Start building the arguments list with the controller names args = list(names) - # Add the parameter file from the LaunchConfiguration - args += ['--param-file', LaunchConfiguration('controllers_yaml')] - - # If you want them to start inactive (rather than active), pass `--inactive` if inactive: args.append('--inactive') - - # Return the spawner node return Node( package='controller_manager', executable='spawner', @@ -58,11 +80,25 @@ def control_spawner(names, inactive=False): def setup_controllers(context): robot_type_value = LaunchConfiguration('robot_type').perform(context) - - controllers_config_path = 'config/' + robot_type_value + '/controllers.yaml' - temp_controllers_config_path = fill_policy_path( + policy_path_value = LaunchConfiguration('policy_path').perform(context) + start_step_value = LaunchConfiguration('start_step').perform(context) + ext_pos_corr = LaunchConfiguration('ext_pos_corr').perform(context) + + kv_pairs = [] + if policy_path_value: + abs_path = os.path.abspath(os.path.expanduser(os.path.expandvars(policy_path_value))) + kv_pairs.append(('walking_controller.policy.path', abs_path)) + if start_step_value: + kv_pairs.append(('walking_controller.motion.start_step', start_step_value)) + if ext_pos_corr.lower() in ["true", "1", "yes"]: + kv_pairs.append(('state_estimator.estimation.contact.height_sensor_noise', 1e10)) + kv_pairs.append(('state_estimator.estimation.position.topic', "/mid360")) + + controllers_config_path = f'config/{robot_type_value}/controllers.yaml' + temp_controllers_config_path = generate_temp_config( controllers_config_path, - 'motion_tracking_controller' + 'motion_tracking_controller', + kv_pairs ) set_controllers_yaml = SetLaunchConfiguration( @@ -70,26 +106,17 @@ def setup_controllers(context): value=temp_controllers_config_path ) - active_list = [ - "state_estimator", - "standby_controller", - ] + active_list = ["state_estimator", "walking_controller"] + inactive_list = ["standby_controller"] - inactive_list = [ - "walking_controller", - ] active_spawner = control_spawner(active_list) inactive_spawner = control_spawner(inactive_list, inactive=True) - return [ - set_controllers_yaml, - active_spawner, - inactive_spawner - ] + + return [set_controllers_yaml, active_spawner, inactive_spawner] def generate_launch_description(): robot_type = LaunchConfiguration('robot_type') - urdf_name = PythonExpression(["'g1' if '", robot_type, "' == 'g1' else 'sdk1'"]) robot_description_command = Command([ @@ -101,10 +128,8 @@ def generate_launch_description(): urdf_name, "robot.xacro" ]), - " ", - "robot_type:=", robot_type, - " ", - "simulation:=", "mujoco"]) + " ", "robot_type:=", robot_type, + " ", "simulation:=", "mujoco"]) robot_description = {"robot_description": robot_description_command} node_robot_state_publisher = Node( @@ -112,7 +137,7 @@ def generate_launch_description(): executable='robot_state_publisher', output='screen', parameters=[robot_description, { - 'publish_frequency': 1000.0, + 'publish_frequency': 500.0, 'use_sim_time': True }], ) @@ -124,16 +149,25 @@ def generate_launch_description(): {"model_package": "unitree_description", "model_file": PythonExpression(["'/mjcf/", robot_type, ".xml'"]), "physics_plugins": ["mujoco_ros2_control::MujocoRos2ControlPlugin"], + "use_sim_time": True }, robot_description, LaunchConfiguration('controllers_yaml'), ], output='screen') - controllers_opaque_func = OpaqueFunction( - function=setup_controllers + wandb = IncludeLaunchDescription( + PythonLaunchDescriptionSource([ThisLaunchFileDir(), "/wandb.launch.py"]), + launch_arguments={ + "wandb_path": LaunchConfiguration("wandb_path") + }.items(), + condition=IfCondition( + PythonExpression(["'", LaunchConfiguration('policy_path'), "' == ''"]) + ) ) + controllers_opaque_func = OpaqueFunction(function=setup_controllers) + teleop = PathJoinSubstitution([ FindPackageShare('unitree_bringup'), 'launch', @@ -142,10 +176,24 @@ def generate_launch_description(): return LaunchDescription([ DeclareLaunchArgument('robot_type', default_value='g1'), + DeclareLaunchArgument( + 'policy_path', + default_value='', + description='Absolute or ~-expanded path for walking_controller.policy.path' + ), + DeclareLaunchArgument( + 'start_step', + default_value='0', + description='Integer start step for walking_controller.motion.start_step' + ), + DeclareLaunchArgument( + 'ext_pos_corr', + default_value='false', + description='Enable external position correction' + ), + wandb, controllers_opaque_func, mujoco_simulator, node_robot_state_publisher, - IncludeLaunchDescription( - PythonLaunchDescriptionSource(teleop) - ) + IncludeLaunchDescription(PythonLaunchDescriptionSource(teleop)) ]) diff --git a/launch/real.launch.py b/launch/real.launch.py index 4d82ca4..d0e5ce0 100644 --- a/launch/real.launch.py +++ b/launch/real.launch.py @@ -4,52 +4,74 @@ from ament_index_python.packages import get_package_share_directory from launch import LaunchDescription from launch.actions import ( - RegisterEventHandler, + ExecuteProcess, DeclareLaunchArgument, OpaqueFunction, SetLaunchConfiguration, - IncludeLaunchDescription + IncludeLaunchDescription, ) -from launch.event_handlers import OnProcessStart +from launch.conditions import IfCondition from launch.launch_description_sources import PythonLaunchDescriptionSource -from launch.substitutions import Command, FindExecutable, PathJoinSubstitution, LaunchConfiguration, PythonExpression +from launch.substitutions import Command, FindExecutable, PathJoinSubstitution, LaunchConfiguration, PythonExpression, \ + ThisLaunchFileDir from launch_ros.actions import Node from launch_ros.substitutions import FindPackageShare -def fill_policy_path(config_path, package_name): +# -------------------------- +# Internal: minimal generic override +# -------------------------- +def generate_temp_config(config_path, package_name, kv_pairs): + """ + Load /, apply overrides from kv_pairs, + and write to /tmp//temp_controllers.yaml. Returns the path. + kv_pairs: list of (dotted_key, raw_value_str) + """ pkg_dir = get_package_share_directory(package_name) src_path = os.path.join(pkg_dir, config_path) dst_path = os.path.join('/tmp', package_name, 'temp_controllers.yaml') - os.makedirs(os.path.dirname(dst_path), exist_ok=True) with open(src_path, 'r') as f: - config = yaml.safe_load(f) - - for ns in list(config.keys()): - params = config.get(ns, {}).setdefault('ros__parameters', {}) - if 'policy' in params and 'path' in params['policy']: - params['policy']['path'] = os.path.join(pkg_dir, params['policy']['path']) + cfg = yaml.safe_load(f) or {} + + for dotted_key, raw_val in kv_pairs: + parts = [p for p in dotted_key.split('.') if p] + if len(parts) < 2: + raise ValueError( + f"Key '{dotted_key}' is incomplete; expected '.[ros__parameters.]foo.bar'" + ) + # Auto-insert ros__parameters right after namespace if omitted + if parts[1] != 'ros__parameters': + parts.insert(1, 'ros__parameters') + + try: + val = yaml.safe_load(raw_val) + except Exception: + val = raw_val + + cur = cfg + for k in parts[:-1]: + if not isinstance(cur.get(k), dict): + cur[k] = {} + cur = cur[k] + cur[parts[-1]] = val with open(dst_path, 'w') as f: - yaml.dump(config, f) - print(f"Modified controllers.yaml saved to {dst_path}") + yaml.dump(cfg, f, sort_keys=False) + print(f"[launch] Temp controllers.yaml written to {dst_path}") return dst_path +# -------------------------- +# ROS nodes / launch wiring +# -------------------------- def control_spawner(names, inactive=False): - # Start building the arguments list with the controller names args = list(names) - # Add the parameter file from the LaunchConfiguration args += ['--param-file', LaunchConfiguration('controllers_yaml')] - - # If you want them to start inactive (rather than active), pass `--inactive` if inactive: args.append('--inactive') - - # Return the spawner node return Node( package='controller_manager', executable='spawner', @@ -58,13 +80,27 @@ def control_spawner(names, inactive=False): ) -def setup_controllers(context, control_node): +def setup_controllers(context): robot_type_value = LaunchConfiguration('robot_type').perform(context) - - controllers_config_path = 'config/' + robot_type_value + '/on_board.yaml' - temp_controllers_config_path = fill_policy_path( + policy_path_value = LaunchConfiguration('policy_path').perform(context) + start_step_value = LaunchConfiguration('start_step').perform(context) + ext_pos_corr = LaunchConfiguration('ext_pos_corr').perform(context) + + kv_pairs = [] + if policy_path_value: + abs_path = os.path.abspath(os.path.expanduser(os.path.expandvars(policy_path_value))) + kv_pairs.append(('walking_controller.policy.path', abs_path)) + if start_step_value: + kv_pairs.append(('walking_controller.motion.start_step', start_step_value)) + if ext_pos_corr.lower() in ["true", "1", "yes"]: + kv_pairs.append(('state_estimator.estimation.contact.height_sensor_noise', 1e10)) + kv_pairs.append(('state_estimator.estimation.position.topic', "/glim/odom")) + + controllers_config_path = f'config/{robot_type_value}/controllers.yaml' + temp_controllers_config_path = generate_temp_config( controllers_config_path, - "motion_tracking_controller" + 'motion_tracking_controller', + kv_pairs ) set_controllers_yaml = SetLaunchConfiguration( @@ -72,25 +108,13 @@ def setup_controllers(context, control_node): value=temp_controllers_config_path ) - active_list = [ - "state_estimator", - "standby_controller", - ] - inactive_list = [ - "walking_controller", - ] + active_list = ["state_estimator", "standby_controller"] + inactive_list = ["walking_controller"] + active_spawner = control_spawner(active_list) inactive_spawner = control_spawner(inactive_list, inactive=True) - controller_event_handler = RegisterEventHandler( - event_handler=OnProcessStart( - target_action=control_node, - on_start=[active_spawner, inactive_spawner], - ) - ) - return [ - set_controllers_yaml, - controller_event_handler, - ] + + return [set_controllers_yaml, active_spawner, inactive_spawner] def generate_launch_description(): @@ -107,13 +131,11 @@ def generate_launch_description(): urdf_name, "robot.xacro" ]), - " ", - "robot_type:=", robot_type, - " ", - "simulation:=", "false", - " ", - "network_interface:=", network_interface + " ", "robot_type:=", robot_type, + " ", "simulation:=", "false", + " ", "network_interface:=", network_interface ]) + robot_description = {"robot_description": robot_description_command} node_robot_state_publisher = Node( @@ -121,7 +143,7 @@ def generate_launch_description(): executable='robot_state_publisher', output='screen', parameters=[robot_description, { - 'publish_frequency': 1000.0, + 'publish_frequency': 500.0, 'use_sim_time': True }], ) @@ -134,8 +156,44 @@ def generate_launch_description(): respawn=True, ) - controllers_opaque_func = OpaqueFunction( - function=setup_controllers, kwargs={'control_node': control_node} + wandb = IncludeLaunchDescription( + PythonLaunchDescriptionSource([ThisLaunchFileDir(), "/wandb.launch.py"]), + launch_arguments={ + "wandb_path": LaunchConfiguration("wandb_path") + }.items(), + condition=IfCondition( + PythonExpression(["'", LaunchConfiguration('policy_path'), "' == ''"]) + ) + ) + + controllers_opaque_func = OpaqueFunction(function=setup_controllers) + + # Exclude all Unitree topics... it should start from the same namespace, fuck Unitree! + exclude_regex = ( + r'(/EstimatorData|/SymState(_back)?|/api/.*' + r'|/arm/action/state|/arm_sdk' + r'|/audio_msg|/audiosender|/config_change_status' + r'|/dex3/(left|right)/(cmd|state)' + r'|/frontvideostream|/gnss' + r'|/gpt_(cmd|state)|/gptflowfeedback' + r'|/lf/(bmsstate|dex3/(left|right)/state|lowstate|mainboardstate|' + r'odommodestate|secondary_imu|sportmodestate)' + r'|/low(cmd|state)|/multiplestate|/odommodestate' + r'|/parameter_events|/public_network_status|/rosout' + r'|/rtc/(state|status)|/secondary_imu|/selftest' + r'|/servicestate(activate)?|/slam_info|/sportmodestate' + r'|/utlidar/range_info|/videohub/inner' + r'|/webrtc(req|res)|/wirelesscontroller)' + r'|/controller_manager/introspection_data/full' + r'|/controller_manager/statistics/full' + ) + + rosbag2 = ExecuteProcess( + cmd=[ + 'ros2', 'bag', 'record', '-s', 'mcap', '-a', # record all topics + '--exclude-regex', exclude_regex, # skip those that match the regex + ], + output='screen', ) teleop = PathJoinSubstitution([ @@ -146,10 +204,27 @@ def generate_launch_description(): return LaunchDescription([ DeclareLaunchArgument('robot_type', default_value='g1'), - DeclareLaunchArgument('network_interface', default_value=''), + DeclareLaunchArgument('network_interface'), + DeclareLaunchArgument( + 'policy_path', + default_value='', + description='Absolute or ~-expanded path for walking_controller.policy.path' + ), + DeclareLaunchArgument( + 'start_step', + default_value='0', + description='Integer start step for walking_controller.motion.start_step' + ), + DeclareLaunchArgument( + 'ext_pos_corr', + default_value='false', + description='Enable external position correction' + ), + wandb, controllers_opaque_func, control_node, node_robot_state_publisher, + rosbag2, IncludeLaunchDescription( PythonLaunchDescriptionSource(teleop) ) diff --git a/launch/wandb.launch.py b/launch/wandb.launch.py new file mode 100644 index 0000000..24155c4 --- /dev/null +++ b/launch/wandb.launch.py @@ -0,0 +1,46 @@ +# launch/wandb.launch.py +from pathlib import Path + +import wandb +from launch import LaunchDescription +from launch.actions import DeclareLaunchArgument, OpaqueFunction, SetLaunchConfiguration +from launch.substitutions import LaunchConfiguration + + +def _download_onnx(run_path: str, dest_dir: str) -> str: + api = wandb.Api() + run = api.run(run_path) + + # Look in run files + for f in run.files(): + if f.name.endswith(".onnx"): + Path(dest_dir).mkdir(parents=True, exist_ok=True) + f.download(root=dest_dir, replace=True) + return str(Path(dest_dir) / f.name) + + # Look in artifacts too (if logged as artifact) + for art in run.logged_artifacts(): + for af in art.files(): + if af.name.endswith(".onnx"): + Path(dest_dir).mkdir(parents=True, exist_ok=True) + af.download(root=dest_dir) + return str(Path(dest_dir) / af.name) + + raise RuntimeError(f"No .onnx file found in run {run_path}") + + +def _pull_and_set(context, *_, **__): + run_path = LaunchConfiguration("wandb_path").perform(context) + dest = LaunchConfiguration("dest_dir").perform(context) + + local_path = _download_onnx(run_path, dest) + print(f"[W&B] Downloaded ONNX to: {local_path}") + return [SetLaunchConfiguration("policy_path", local_path)] + + +def generate_launch_description(): + return LaunchDescription([ + DeclareLaunchArgument("wandb_path"), + DeclareLaunchArgument("dest_dir", default_value="/tmp/wandb_onnx"), + OpaqueFunction(function=_pull_and_set), + ]) diff --git a/package.xml b/package.xml index d6c1a78..fd90500 100644 --- a/package.xml +++ b/package.xml @@ -11,6 +11,7 @@ ament_cmake_auto legged_rl_controllers + rosbag2-storage-mcap ament_cmake diff --git a/scripts/colcon-config.sh b/scripts/colcon-config.sh new file mode 100755 index 0000000..0d6a9d6 --- /dev/null +++ b/scripts/colcon-config.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Default build type is Release +BUILD_TYPE="${1:-Release}" + +echo "Building with CMAKE_BUILD_TYPE=$BUILD_TYPE" + +# Get project root directory +PROJECT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"/../../../ +echo "Project directory: $PROJECT_DIR" +mkdir -p $PROJECT_DIR/lib +mkdir -p $PROJECT_DIR/src + +# Clone repos +cd $PROJECT_DIR/src +git clone https://github.com/qiayuanl/unitree_bringup.git + +# Build +cd $PROJECT_DIR +rosdep install --from-paths src --ignore-src -r -y + +source /opt/ros/jazzy/setup.bash +colcon build --symlink-install \ + --packages-up-to unitree_bringup motion_tracking_controller \ + --cmake-args \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ No newline at end of file diff --git a/src/MotionCommand.cpp b/src/MotionCommand.cpp index 33ea8dc..518e210 100644 --- a/src/MotionCommand.cpp +++ b/src/MotionCommand.cpp @@ -12,9 +12,9 @@ vector_t MotionCommandTerm::getValue() { void MotionCommandTerm::reset() { const auto& pinModel = model_->getPinModel(); - referenceRobotIndex_ = pinModel.getFrameId(cfg_.referenceBody); - if (referenceRobotIndex_ >= pinModel.nframes) { - throw std::runtime_error("Reference body " + cfg_.referenceBody + " not found."); + anchorRobotIndex_ = pinModel.getFrameId(cfg_.anchorBody); + if (anchorRobotIndex_ >= pinModel.nframes) { + throw std::runtime_error("Anchor body " + cfg_.anchorBody + " not found."); } for (const auto& bodyName : cfg_.bodyNames) { bodyIndices_.push_back(pinModel.getFrameId(bodyName)); @@ -22,35 +22,42 @@ void MotionCommandTerm::reset() { throw std::runtime_error("Frame " + bodyName + " not found."); } } - referenceMotionIndex_ = cfg_.bodyNames.size(); + anchorMotionIndex_ = cfg_.bodyNames.size(); for (size_t i = 0; i < cfg_.bodyNames.size(); ++i) { - if (cfg_.bodyNames[i] == cfg_.referenceBody) { - referenceMotionIndex_ = i; + if (cfg_.bodyNames[i] == cfg_.anchorBody) { + anchorMotionIndex_ = i; break; } } - if (referenceMotionIndex_ == cfg_.bodyNames.size()) { - throw std::runtime_error("Reference body " + cfg_.referenceBody + " not found in body names."); + if (anchorMotionIndex_ == cfg_.bodyNames.size()) { + throw std::runtime_error("Anchor body " + cfg_.anchorBody + " not found in body names."); } - const pinocchio::SE3 initToRef(motionPolicy_->getBodyOrientations()[referenceMotionIndex_], - motionPolicy_->getBodyPositions()[referenceMotionIndex_]); - const pinocchio::SE3 worldToRef = model_->getPinData().oMf[referenceRobotIndex_]; - worldToInit_ = worldToRef * initToRef.inverse(); - worldToInit_.rotation() = yawQuaternion(quaternion_t(worldToInit_.rotation())); + + // Move the whole motion frame s.t. the first frame of the motion is aligned with the current robot in position and yaw orientation. + pinocchio::SE3 initToAnchor(motionPolicy_->getBodyOrientations()[anchorMotionIndex_], motionPolicy_->getBodyPositions()[anchorMotionIndex_]); + pinocchio::SE3 worldToAnchor = model_->getPinData().oMf[anchorRobotIndex_]; + initToAnchor.rotation() = yawQuaternion(quaternion_t(initToAnchor.rotation())); + worldToAnchor.rotation() = yawQuaternion(quaternion_t(worldToAnchor.rotation())); + + worldToInit_ = worldToAnchor * initToAnchor.inverse(); + + std::cerr << initToAnchor << std::endl; + std::cerr << worldToAnchor << std::endl; + std::cerr << worldToInit_ << std::endl; } -vector3_t MotionCommandTerm::getReferencePositionLocal() const { +vector3_t MotionCommandTerm::getAnchorPositionLocal() const { const auto& data = model_->getPinData(); - const auto& refPoseReal = data.oMf[referenceRobotIndex_]; + const auto& anchorPoseReal = data.oMf[anchorRobotIndex_]; - const auto& refPos = motionPolicy_->getBodyPositions()[referenceMotionIndex_]; - return refPoseReal.actInv(worldToInit_.act(refPos)); + const auto& anchorPos = motionPolicy_->getBodyPositions()[anchorMotionIndex_]; + return anchorPoseReal.actInv(worldToInit_.act(anchorPos)); } -vector_t MotionCommandTerm::getReferenceOrientationLocal() const { - const auto& refPoseReal = model_->getPinData().oMf[referenceRobotIndex_]; - const pinocchio::SE3 refOri(motionPolicy_->getBodyOrientations()[referenceMotionIndex_], vector3_t::Zero()); - const auto rot = refPoseReal.actInv(worldToInit_.act(refOri)).rotation(); +vector_t MotionCommandTerm::getAnchorOrientationLocal() const { + const auto& anchorPoseReal = model_->getPinData().oMf[anchorRobotIndex_]; + const pinocchio::SE3 anchorOri(motionPolicy_->getBodyOrientations()[anchorMotionIndex_], vector3_t::Zero()); + const auto rot = anchorPoseReal.actInv(worldToInit_.act(anchorOri)).rotation(); vector_t rot6(6); rot6 << rot(0, 0), rot(0, 1), rot(1, 0), rot(1, 1), rot(2, 0), rot(2, 1); return rot6; @@ -58,10 +65,10 @@ vector_t MotionCommandTerm::getReferenceOrientationLocal() const { vector_t MotionCommandTerm::getRobotBodyPositionLocal() const { const auto& data = model_->getPinData(); - const auto& refPoseReal = data.oMf[referenceRobotIndex_]; + const auto& anchorPoseReal = data.oMf[anchorRobotIndex_]; vector_t value(3 * cfg_.bodyNames.size()); for (size_t i = 0; i < cfg_.bodyNames.size(); ++i) { - const auto& bodyPoseLocal = refPoseReal.actInv(data.oMf[bodyIndices_[i]]); + const auto& bodyPoseLocal = anchorPoseReal.actInv(data.oMf[bodyIndices_[i]]); value.segment(3 * i, 3) = bodyPoseLocal.translation(); } return value; @@ -69,10 +76,10 @@ vector_t MotionCommandTerm::getRobotBodyPositionLocal() const { vector_t MotionCommandTerm::getRobotBodyOrientationLocal() const { const auto& data = model_->getPinData(); - const auto& refPoseReal = data.oMf[referenceRobotIndex_]; + const auto& anchorPoseReal = data.oMf[anchorRobotIndex_]; vector_t value(6 * cfg_.bodyNames.size()); for (size_t i = 0; i < cfg_.bodyNames.size(); ++i) { - const auto& rot = refPoseReal.actInv(data.oMf[bodyIndices_[i]]).rotation(); + const auto& rot = anchorPoseReal.actInv(data.oMf[bodyIndices_[i]]).rotation(); vector_t rot6(6); rot6 << rot(0, 0), rot(0, 1), rot(1, 0), rot(1, 1), rot(2, 0), rot(2, 1); value.segment(i * 6, 6) = rot6; diff --git a/src/MotionOnnxPolicy.cpp b/src/MotionOnnxPolicy.cpp index 9198497..5111702 100644 --- a/src/MotionOnnxPolicy.cpp +++ b/src/MotionOnnxPolicy.cpp @@ -10,13 +10,13 @@ namespace legged { void MotionOnnxPolicy::reset() { OnnxPolicy::reset(); - time_step_ = 0; + timeStep_ = startStep_; forward(vector_t::Zero(getObservationSize())); } vector_t MotionOnnxPolicy::forward(const vector_t& observations) { tensor2d_t timeStep(1, 1); - timeStep(0, 0) = static_cast(time_step_++); + timeStep(0, 0) = static_cast(timeStep_++); inputTensors_[name2Index_.at("time_step")] = timeStep; OnnxPolicy::forward(observations); @@ -40,4 +40,11 @@ vector_t MotionOnnxPolicy::forward(const vector_t& observations) { return getLastAction(); } +void MotionOnnxPolicy::parseMetadata() { + OnnxPolicy::parseMetadata(); + anchorBodyName_ = getMetadataStr("anchor_body_name"); + std::cout << '\t' << "anchor_body_name: " << anchorBodyName_ << '\n'; + bodyNames_ = parseCsv(getMetadataStr("body_names")); + std::cout << '\t' << "body_names: " << bodyNames_ << '\n'; +} } // namespace legged diff --git a/src/MotionTrackingController.cpp b/src/MotionTrackingController.cpp index 0cd94b1..4408b8f 100644 --- a/src/MotionTrackingController.cpp +++ b/src/MotionTrackingController.cpp @@ -4,51 +4,54 @@ #include "motion_tracking_controller/MotionObservation.h" namespace legged { -controller_interface::return_type MotionTrackingController::update(const rclcpp::Time& time, const rclcpp::Duration& period) { - if (OnnxController::update(time, period) != controller_interface::return_type::OK) { - return controller_interface::return_type::ERROR; +controller_interface::CallbackReturn MotionTrackingController::on_init() { + if (RlController::on_init() != controller_interface::CallbackReturn::SUCCESS) { + return controller_interface::CallbackReturn::ERROR; } - // dataLogger_->update(time); + try { + auto_declare("motion.start_step", 0); + } catch (const std::exception& e) { + RCLCPP_ERROR(get_node()->get_logger(), "Exception during init: %s", e.what()); + return CallbackReturn::ERROR; + } - return controller_interface::return_type::OK; + return controller_interface::CallbackReturn::SUCCESS; } controller_interface::CallbackReturn MotionTrackingController::on_configure(const rclcpp_lifecycle::State& previous_state) { - get_node()->get_parameter("motion.reference_body", cfg_.referenceBody); - get_node()->get_parameter("motion.body_names", cfg_.bodyNames); + const auto policyPath = get_node()->get_parameter("policy.path").as_string(); + const auto startStep = static_cast(get_node()->get_parameter("motion.start_step").as_int()); - std::string policyPath{}; - get_node()->get_parameter("policy.path", policyPath); - policy_ = std::make_shared(policyPath); + policy_ = std::make_shared(policyPath, startStep); policy_->init(); - RCLCPP_INFO_STREAM(rclcpp::get_logger("MotionTrackingController"), "Load Onnx model from" << policyPath << " successfully !"); + + auto policy = std::dynamic_pointer_cast(policy_); + cfg_.anchorBody = policy->getAnchorBodyName(); + cfg_.bodyNames = policy->getBodyNames(); + RCLCPP_INFO_STREAM(rclcpp::get_logger("MotionTrackingController"), "Load Onnx model from " << policyPath << " successfully !"); return RlController::on_configure(previous_state); } controller_interface::CallbackReturn MotionTrackingController::on_activate(const rclcpp_lifecycle::State& previous_state) { - if (OnnxController::on_activate(previous_state) != controller_interface::CallbackReturn::SUCCESS) { + if (RlController::on_activate(previous_state) != controller_interface::CallbackReturn::SUCCESS) { return controller_interface::CallbackReturn::ERROR; } - // dataLogger_ = std::make_shared(leggedModel(), policy_); - return controller_interface::CallbackReturn::SUCCESS; } controller_interface::CallbackReturn MotionTrackingController::on_deactivate(const rclcpp_lifecycle::State& previous_state) { - if (OnnxController::on_deactivate(previous_state) != controller_interface::CallbackReturn::SUCCESS) { + if (RlController::on_deactivate(previous_state) != controller_interface::CallbackReturn::SUCCESS) { return controller_interface::CallbackReturn::ERROR; } - // dataLogger_->writeAndClear(); - return controller_interface::CallbackReturn::SUCCESS; } bool MotionTrackingController::parserCommand(const std::string& name) { - if (OnnxController::parserCommand(name)) { + if (RlController::parserCommand(name)) { return true; } if (name == "motion") { @@ -60,13 +63,13 @@ bool MotionTrackingController::parserCommand(const std::string& name) { } bool MotionTrackingController::parserObservation(const std::string& name) { - if (OnnxController::parserObservation(name)) { + if (RlController::parserObservation(name)) { return true; } - if (name == "motion_ref_pos_b") { - observationManager_->addTerm(std::make_shared(commandTerm_)); - } else if (name == "motion_ref_ori_b") { - observationManager_->addTerm(std::make_shared(commandTerm_)); + if (name == "motion_ref_pos_b" || name == "motion_anchor_pos_b") { + observationManager_->addTerm(std::make_shared(commandTerm_)); + } else if (name == "motion_ref_ori_b" || name == "motion_anchor_ori_b") { + observationManager_->addTerm(std::make_shared(commandTerm_)); } else if (name == "robot_body_pos") { observationManager_->addTerm(std::make_shared(commandTerm_)); } else if (name == "robot_body_ori") {