-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmain.cpp
More file actions
186 lines (158 loc) · 6.15 KB
/
main.cpp
File metadata and controls
186 lines (158 loc) · 6.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
// Copyright (c) 2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.
/**
* @file main.cpp
* @brief Loopback controller example for an IsaacLab-exported ONNX policy.
*
* This example demonstrates how to wire the exploy OnnxRLController to concrete
* implementations of the three required interfaces and run a closed-loop simulation
* for a configurable number of cycles.
*
* Loopback semantics
* ------------------
* The LoopbackRobotStateInterface feeds commanded joint targets (position, velocity,
* effort) back as the measured joint state in the following cycle. This models a
* perfect, zero-delay actuator and is useful for verifying that an exported policy
* produces sensible action sequences without a full physics simulation.
*
* Data collection
* ---------------
* This example uses a no-op data collection interface.
*
* Usage
* -----
* loopback_controller_example <onnx_model_path>
* [--cycles N] (default 100)
* [--vx M_S] (default 0.5)
* [--vy M_S] (default 0.0)
* [--omega RAD_S] (default 0.0)
*/
#include <chrono>
#include <cstddef>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <optional>
#include <regex>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
#include "exploy/components.hpp"
#include "exploy/controller.hpp"
#include "exploy/logging_interface.hpp"
#include "exploy/matcher.hpp"
#include "fixed_command_interface.hpp"
#include "loopback_state_interface.hpp"
/**
 * @brief Command-line options accepted by the loopback controller example.
 *
 * Each member other than onnx_path carries the default that applies when the
 * corresponding flag is not given on the command line.
 */
struct Args {
  /// Path to the ONNX model (first positional argument).
  std::string onnx_path;
  /// Number of control cycles to run (`--cycles`).
  int num_cycles = 100;
  /// First component of the fixed SE(2) velocity command (`--vx`, m/s per the usage text).
  double vx = 0.5;
  /// Second component of the fixed SE(2) velocity command (`--vy`, m/s per the usage text).
  double vy = 0.0;
  /// Third component of the fixed SE(2) velocity command (`--omega`, rad/s per the usage text).
  double omega = 0.0;
};
/**
 * @brief DataCollectionInterface implementation that records nothing.
 *
 * Every call ignores its arguments and reports success, so the controller can
 * be wired up without any data-collection backend.
 */
class NoOpDataCollection : public exploy::control::DataCollectionInterface {
 public:
  /// Ignores a data source given as a span of doubles; always returns true.
  bool registerDataSource(const std::string&, std::span<const double>) override {
    return true;
  }

  /// Ignores a data source given as a span of floats; always returns true.
  bool registerDataSource(const std::string&, std::span<const float>) override {
    return true;
  }

  /// Ignores a data source given as a single double; always returns true.
  bool registerDataSource(const std::string&, const double&) override {
    return true;
  }

  /// Ignores the collection request for the given timestamp; always returns true.
  bool collectData(uint64_t /*time_us*/) override {
    return true;
  }
};
/**
 * @brief Example custom matcher that maps specially named entries to BodyPositionInput components.
 *
 * matches() accepts names of the form
 * `custom<any char>obj.<group1>.<group2>.pos_b_rt_w_in_w`, where both groups consist of
 * `[a-zA-Z0-9_]+`, and stores the match keyed by `<group2>`. createInputs() then builds one
 * BodyPositionInput per stored match.
 */
class CustomBodyPositionMatcher : public exploy::control::Matcher {
 public:
  CustomBodyPositionMatcher() : Matcher("CustomBodyPositionMatcher") {}

  /**
   * @brief Test a candidate against the custom body-position naming pattern.
   *
   * @param maybe_match Candidate whose `name` is matched in full against the pattern.
   * @return true if the name matched and the candidate was stored; false otherwise.
   */
  bool matches(const exploy::control::Match& maybe_match) override {
    // Compiled once and reused: constructing a std::regex is expensive and the pattern is
    // constant, so rebuilding it on every call is wasted work.
    // NOTE(review): the '.' between "custom" and "obj" is unescaped and therefore matches any
    // character, unlike the other separators — confirm whether `custom\.obj` was intended.
    static const std::regex kPattern(
        R"(custom.obj\.([a-zA-Z0-9_]+)\.([a-zA-Z0-9_]+)\.pos_b_rt_w_in_w)");

    std::smatch captures;
    // A successful regex_match against a pattern with two capture groups always yields three
    // sub-matches, so no additional size check is needed before reading captures[2].
    if (!std::regex_match(maybe_match.name, captures, kPattern)) {
      return false;
    }

    // Keyed by the second capture group; a later match with the same key overwrites the earlier.
    found_matches_[captures[2].str()] = maybe_match;
    return true;
  }

  /**
   * @brief Build one BodyPositionInput per stored match.
   *
   * @return Newly created inputs, each constructed from the stored match's name and its key
   *         (the second capture group of the pattern).
   */
  std::vector<std::unique_ptr<exploy::control::Input>> createInputs() const override {
    std::vector<std::unique_ptr<exploy::control::Input>> inputs;
    // found_matches_ is not declared in this class (presumably inherited from Matcher).
    for (const auto& [body_name, found_match] : found_matches_) {
      inputs.push_back(
          std::make_unique<exploy::control::BodyPositionInput>(found_match.name, body_name));
    }
    return inputs;
  }
};
/**
 * @brief Convert the value of a command-line flag, reporting failures on stderr.
 *
 * @param flag    Flag name; used only in the error message.
 * @param text    Raw argument text to convert.
 * @param convert Callable `(const std::string&, std::size_t*) -> T` with std::stoi/std::stod
 *                semantics (throws std::invalid_argument / std::out_of_range, reports the
 *                number of characters consumed).
 * @param out     Receives the converted value; left untouched on failure.
 * @return true if the whole of @p text was converted; false otherwise.
 */
template <typename T, typename Converter>
[[nodiscard]] static bool parseFlagValue(const std::string& flag, const std::string& text,
                                         Converter convert, T& out) {
  try {
    std::size_t consumed = 0;
    const T value = convert(text, &consumed);
    // Reject trailing garbage such as "10abc", which the std::sto* functions accept silently.
    if (consumed == text.size()) {
      out = value;
      return true;
    }
  } catch (const std::invalid_argument&) {
    // Not a number at all; reported below.
  } catch (const std::out_of_range&) {
    // Number does not fit the target type; reported below.
  }
  std::cerr << "Invalid value for argument " << flag << ": " << text << "\n";
  return false;
}

/**
 * @brief Parse the command line into an Args value.
 *
 * Expects `<onnx_model_path>` as the first argument, optionally followed by any of
 * `--cycles N`, `--vx M_S`, `--vy M_S`, `--omega RAD_S`. Flags not given keep the defaults
 * declared in Args.
 *
 * @param argc Argument count as passed to main().
 * @param argv Argument vector as passed to main().
 * @return The parsed arguments, or std::nullopt (after printing a diagnostic to stderr) when
 *         the model path is missing, a flag is unknown, a flag has no value, or a value is not
 *         a valid number. This function does not let std::stoi/std::stod exceptions escape.
 */
[[nodiscard]] std::optional<Args> parseArgs(int argc, char** argv) {
  if (argc < 2) {
    // argv[0] may be null when argc == 0; streaming a null char* is undefined behaviour.
    const char* program =
        (argc > 0 && argv[0] != nullptr) ? argv[0] : "loopback_controller_example";
    std::cerr << "Usage: " << program
              << " <onnx_model_path>"
                 " [--cycles N] [--vx M_S] [--vy M_S] [--omega RAD_S]\n";
    return std::nullopt;
  }

  Args args;
  args.onnx_path = argv[1];

  // Wrappers are needed because std::stoi/std::stod are overloaded and cannot be passed directly.
  const auto to_int = [](const std::string& s, std::size_t* pos) { return std::stoi(s, pos); };
  const auto to_double = [](const std::string& s, std::size_t* pos) { return std::stod(s, pos); };

  for (int i = 2; i < argc; ++i) {
    const std::string flag = argv[i];

    // Resolve the destination first so that "unknown flag" and "missing value" are
    // reported as distinct errors.
    const bool is_cycles = (flag == "--cycles");
    double* real_target = nullptr;
    if (flag == "--vx") {
      real_target = &args.vx;
    } else if (flag == "--vy") {
      real_target = &args.vy;
    } else if (flag == "--omega") {
      real_target = &args.omega;
    }

    if (!is_cycles && real_target == nullptr) {
      std::cerr << "Unknown argument: " << flag << "\n";
      return std::nullopt;
    }
    if (i + 1 >= argc) {
      std::cerr << "Missing value for argument: " << flag << "\n";
      return std::nullopt;
    }

    const std::string value = argv[++i];
    const bool parsed = is_cycles ? parseFlagValue(flag, value, to_int, args.num_cycles)
                                  : parseFlagValue(flag, value, to_double, *real_target);
    if (!parsed) {
      return std::nullopt;
    }
  }
  return args;
}
/**
 * @brief Entry point: wires up the loopback controller and steps it for the requested cycles.
 *
 * @param argc Argument count.
 * @param argv Argument vector; see parseArgs() for the accepted options.
 * @return 0 if every cycle's update succeeded; 1 if argument parsing, model loading or
 *         controller initialisation failed, or if at least one cycle failed.
 */
int main(int argc, char** argv) {
  const auto maybe_args = parseArgs(argc, argv);
  if (!maybe_args) {
    return 1;
  }
  const Args& args = *maybe_args;

  // Optional: install a custom logger. Without this the controller logs to stdout through a
  // simple logger that prefixes each message with its level (e.g. "[ERROR]", "[WARNING]").
  exploy::control::StdoutLogger logger;
  exploy::control::setLogger(&logger);

  // The three interfaces the controller is built on: robot state, command, data collection.
  exploy::control::examples::LoopbackRobotStateInterface state;
  exploy::control::examples::FixedCommandInterface command(
      exploy::control::examples::FixedCommandConfig{.se2_velocity{args.vx, args.vy, args.omega}});
  NoOpDataCollection data_collection;

  exploy::control::OnnxRLController controller(state, command, data_collection);

  // Custom matchers are registered before the model is loaded.
  controller.context().registerMatcher(std::make_unique<CustomBodyPositionMatcher>());

  if (!controller.create(args.onnx_path)) {
    std::cerr << "[main] Failed to load ONNX model: " << args.onnx_path << "\n";
    return 1;
  }

  // The update rate is read from the controller context once the model has been loaded.
  const int model_rate = controller.context().updateRate();
  const double update_rate_hz = static_cast<double>(model_rate);

  // Initialise the controller with data collection disabled.
  if (!controller.init(/*enable_data_collection=*/false)) {
    std::cerr << "[main] Controller initialisation failed.\n";
    return 1;
  }

  // A non-positive rate means "no pacing": the period is zero and the loop never sleeps.
  const bool rate_limited = update_rate_hz > 0.0;
  double period_s = 0.0;
  if (rate_limited) {
    period_s = 1.0 / update_rate_hz;
  }
  const std::chrono::duration<double> period(period_s);
  const uint64_t step_us = static_cast<uint64_t>(period.count() * 1e6);

  int failed_cycles = 0;
  uint64_t sim_time_us = 0;
  for (int cycle = 0; cycle < args.num_cycles; ++cycle, sim_time_us += step_us) {
    const auto started_at = std::chrono::steady_clock::now();

    // One controller step (read state → infer → write commands). A failed step is counted
    // but does not abort the run.
    if (!controller.update(sim_time_us)) {
      std::cerr << "[main] Cycle " << cycle << " FAILED\n";
      ++failed_cycles;
    }

    if (!rate_limited) {
      continue;
    }

    // Sleep for whatever is left of this cycle's period, if anything.
    const auto spent = std::chrono::steady_clock::now() - started_at;
    if (spent < period) {
      std::this_thread::sleep_for(period - spent);
    }
  }

  if (failed_cycles > 0) {
    return 1;
  }
  return 0;
}