Conversation
Pull request overview
Upgrades the pytorch_kinematics dependency and enables torch-compiled execution for inverse-kinematics (and now FK) to improve solver performance in the simulation stack.
Changes:
- Bump `pytorch_kinematics` from `0.7.6` to `0.10.0`.
- Enable compiled mode for `PseudoInverseIKinPytorchSolver`.
- Introduce a `torch.compile`'d FK path in `BaseSolver.get_fk()`.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `pyproject.toml` | Updates the `pytorch_kinematics` version pin to 0.10.0. |
| `embodichain/lab/sim/solvers/pytorch_solver.py` | Enables compiled execution for the IK solver via `use_compile=True`. |
| `embodichain/lab/sim/solvers/base_solver.py` | Adds compiled FK initialization and rewires `get_fk()` to use it. |
```diff
 def get_fk(self, qpos: torch.tensor, **kwargs) -> torch.Tensor:
     r"""
     Computes the forward kinematics for the end-effector link.

     Args:
         qpos (torch.Tensor): Joint positions. Can be a single configuration (dof,) or a batch (batch_size, dof).
         **kwargs: Additional keyword arguments for customization.

     Returns:
         torch.Tensor: The homogeneous transformation matrix of the end link with TCP applied.
             Shape is (4, 4) for single input, or (batch_size, 4, 4) for batch input.
     """
     tcp_xpos = torch.as_tensor(
         self.tcp_xpos, device=self.device, dtype=torch.float32
     )
     qpos = torch.as_tensor(qpos, dtype=torch.float32, device=self.device)

     if self.pk_serial_chain is None:
         logger.log_error("Kinematic chain is not initialized.")
         return torch.eye(4, device=self.device)
-    # Compute forward kinematics
-    result = self.pk_serial_chain.forward_kinematics(
-        qpos, end_only=(self.end_link_name is None)
-    )
-
-    # Extract transformation matrices
-    if isinstance(result, dict):
-        matrices = result[self.end_link_name].get_matrix()
-    elif isinstance(result, list):
-        matrices = torch.stack([xpos.get_matrix().squeeze() for xpos in result])
-    else:
-        matrices = result.get_matrix()
-
-    # Ensure batch format
-    if matrices.dim() == 2:
-        matrices = matrices.unsqueeze(0)
-
-    # Create result tensor with proper homogeneous coordinates
-    result = (
-        torch.eye(4, device=self.device).expand(matrices.shape[0], 4, 4).clone()
-    )
-    result[:, :3, :] = matrices[:, :3, :]
+    ee_link_xpos = self.compiled_fk(qpos)[-1, :, :, :]

     # Ensure batch format for TCP
-    batch_size = result.shape[0]
+    batch_size = qpos.shape[0]
     tcp_xpos_batch = tcp_xpos.unsqueeze(0).expand(batch_size, -1, -1)

     # Apply TCP transformation
-    return torch.bmm(result, tcp_xpos_batch)
+    return torch.bmm(ee_link_xpos, tcp_xpos_batch)
```
get_fk now assumes qpos is batched and sets batch_size = qpos.shape[0]. If a caller passes a 1D (dof,) configuration (which the docstring explicitly allows, and which happens in DifferentialSolver.get_ik when qpos_seed is None), batch_size becomes dof and the TCP expansion / bmm shapes will be wrong. Consider normalizing qpos to (1, dof) when qpos.dim()==1, and (optionally) returning a (4,4) matrix for single-input to preserve the documented API.
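The normalization described above can be sketched as follows; the helper name, the toy FK callable, and the squeeze-on-return are illustrative assumptions, not the repository's actual code:

```python
import torch

def get_fk_normalized(qpos: torch.Tensor, fk_batched, tcp_xpos: torch.Tensor) -> torch.Tensor:
    """Accept (dof,) or (batch, dof) qpos; return (4, 4) or (batch, 4, 4)."""
    single = qpos.dim() == 1
    if single:
        qpos = qpos.unsqueeze(0)  # (dof,) -> (1, dof)
    ee = fk_batched(qpos)  # assumed to return (batch, 4, 4) poses
    out = torch.bmm(ee, tcp_xpos.unsqueeze(0).expand(qpos.shape[0], -1, -1))
    # squeeze back so a single (dof,) input yields the documented (4, 4) output
    return out.squeeze(0) if single else out

# toy FK that returns identity poses, purely to exercise the shapes
fk = lambda q: torch.eye(4).expand(q.shape[0], 4, 4)
tcp = torch.eye(4)
print(get_fk_normalized(torch.zeros(7), fk, tcp).shape)     # torch.Size([4, 4])
print(get_fk_normalized(torch.zeros(3, 7), fk, tcp).shape)  # torch.Size([3, 4, 4])
```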
```python
self.compiled_fk = torch.compile(
    self.pk_serial_chain.forward_kinematics_tensor,
    fullgraph=True,
    dynamic=True,
)
```
self.compiled_fk is only initialized inside the if self.pk_serial_chain is None: branch. If a pk_serial_chain is provided via kwargs, get_fk() will still call self.compiled_fk(...) and raise AttributeError. Consider always defining self.compiled_fk (e.g., compile when possible, otherwise fall back to pk_serial_chain.forward_kinematics_tensor/forward_kinematics).
Suggested change:
```diff
-self.compiled_fk = torch.compile(
-    self.pk_serial_chain.forward_kinematics_tensor,
-    fullgraph=True,
-    dynamic=True,
-)
+fk_callable = getattr(self.pk_serial_chain, "forward_kinematics_tensor", None)
+if fk_callable is not None:
+    self.compiled_fk = torch.compile(
+        fk_callable,
+        fullgraph=True,
+        dynamic=True,
+    )
+else:
+    self.compiled_fk = getattr(self.pk_serial_chain, "forward_kinematics")
```
```python
)
self.compiled_fk = torch.compile(
    self.pk_serial_chain.forward_kinematics_tensor,
    fullgraph=True,
    dynamic=True,
)
```
torch.compile(...) is invoked unconditionally during solver construction. The repo doesn't pin a PyTorch version in pyproject.toml, and torch.compile requires PyTorch 2.x; on older installations this will fail at import/runtime. Consider guarding with hasattr(torch, "compile") (or a config flag) and falling back to eager FK when compilation isn't available.
Suggested change:
```diff
 )
-self.compiled_fk = torch.compile(
-    self.pk_serial_chain.forward_kinematics_tensor,
-    fullgraph=True,
-    dynamic=True,
-)
+if hasattr(torch, "compile"):
+    self.compiled_fk = torch.compile(
+        self.pk_serial_chain.forward_kinematics_tensor,
+        fullgraph=True,
+        dynamic=True,
+    )
+else:
+    self.compiled_fk = self.pk_serial_chain.forward_kinematics_tensor
```
```python
max_iterations=self._max_iterations,
lr=self._dt,
num_retries=1,
use_compile=True,
```
The PR description/checklist claims tests were added, but this PR only changes pyproject.toml and solver implementation files (no test diffs). Either include the new/updated tests that validate the pytorch_kinematics upgrade + compile paths, or update the checklist/description to reflect the current PR contents.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.
```python
target_xpos_repeated = target_xpos.unsqueeze(0).repeat(num_samples, 1, 1, 1)
target_xpos_repeated = target_xpos_repeated.reshape(
    num_batch * num_samples, 4, 4
```
repeat_target_xpos() repeats along the sample dimension first (unsqueeze(0).repeat(num_samples, ...)) and then reshapes. This produces an interleaved order (sample-major) that does not match the batch-major flattening used by sample() / reshape(-1, dof), so target poses and joint seeds get mismatched when batch_size>1 and num_samples>1. Consider repeating as (batch, num_samples, ...) (e.g., target_xpos[:, None, ...].repeat(1, num_samples, 1, 1).reshape(batch*num_samples, ...)) so each pose stays aligned with its seeds.
Suggested change:
```diff
-target_xpos_repeated = target_xpos.unsqueeze(0).repeat(num_samples, 1, 1, 1)
-target_xpos_repeated = target_xpos_repeated.reshape(
-    num_batch * num_samples, 4, 4
+pose_shape = target_xpos.shape[1:]
+target_xpos_repeated = target_xpos[:, None, ...].repeat(1, num_samples, 1, 1)
+target_xpos_repeated = target_xpos_repeated.reshape(
+    num_batch * num_samples, *pose_shape
```
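The ordering mismatch can be demonstrated with a toy tensor whose entries encode the batch index (a shapes-only sketch, not the solver's actual poses):

```python
import torch

num_batch, num_samples = 2, 3
# tag each 4x4 "pose" with its batch index so alignment is visible after flattening
target = torch.arange(num_batch).float().view(num_batch, 1, 1).expand(num_batch, 4, 4)

# sample-major (the problematic order): batch index cycles 0,1,0,1,0,1
sample_major = target.unsqueeze(0).repeat(num_samples, 1, 1, 1).reshape(-1, 4, 4)

# batch-major (the suggested order): batch index runs 0,0,0,1,1,1,
# matching seeds flattened with reshape(-1, dof)
batch_major = target[:, None, ...].repeat(1, num_samples, 1, 1).reshape(-1, 4, 4)

print(sample_major[:, 0, 0].tolist())  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
print(batch_major[:, 0, 0].tolist())   # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```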
```python
# seed_random = torch.rand(
#     size=(batch_size, n_random_samples, self.dof), device=self.device
# )

# save sampling time, repeat for each batch and sample in one go
seed_random = torch.rand(
    size=(1, n_random_samples, self.dof), device=self.device
)
seed_random = seed_random.repeat(batch_size, 1, 1)
```
sample() draws seed_random with shape (1, n_random_samples, dof) and then repeat(batch_size, 1, 1), which makes every batch element share identical “random” seeds. This is a behavioral change from the previous per-batch sampling and can reduce IK success rates for batched targets. If you need per-batch randomness, sample with shape (batch_size, n_random_samples, dof) (or use different RNG streams) instead of repeating a single sample.
Suggested change:
```diff
-# seed_random = torch.rand(
-#     size=(batch_size, n_random_samples, self.dof), device=self.device
-# )
-
-# save sampling time, repeat for each batch and sample in one go
-seed_random = torch.rand(
-    size=(1, n_random_samples, self.dof), device=self.device
-)
-seed_random = seed_random.repeat(batch_size, 1, 1)
+seed_random = torch.rand(
+    size=(batch_size, n_random_samples, self.dof), device=self.device
+)
```
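The effect of repeating a single draw is easy to see directly (a sketch with illustrative shapes):

```python
import torch

torch.manual_seed(0)
batch_size, n_random_samples, dof = 4, 2, 7

# repeating one draw: every batch element shares identical "random" seeds
shared = torch.rand(1, n_random_samples, dof).repeat(batch_size, 1, 1)
print(torch.equal(shared[0], shared[1]))  # True

# independent per-batch draw: batch elements get different seeds
independent = torch.rand(batch_size, n_random_samples, dof)
print(torch.equal(independent[0], independent[1]))  # False (with overwhelming probability)
```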
```python
self.pik.initial_config = joint_seed

result = self.pik.solve(tf)
return result.converged_any, result.solutions[:, 0, :].squeeze(0)
```
result.solutions[:, 0, :].squeeze(0) will drop the batch dimension when target_pose.shape[0] == 1 (e.g., when num_samples==1), returning a 1D (dof,) tensor. Call sites treat this as (N, dof) and reshape/index accordingly, which will break for the single-sample case. Consider returning result.solutions[:, 0, :] without squeeze(0) to preserve (batch, dof) consistently.
Suggested change:
```diff
-return result.converged_any, result.solutions[:, 0, :].squeeze(0)
+return result.converged_any, result.solutions[:, 0, :]
```
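The shape collapse can be reproduced with a stand-in tensor (the `(1, 5, 7)` shape is an illustrative assumption standing in for `result.solutions`):

```python
import torch

# stand-in for result.solutions with shape (batch=1, num_retries=5, dof=7)
solutions = torch.zeros(1, 5, 7)

with_squeeze = solutions[:, 0, :].squeeze(0)
without_squeeze = solutions[:, 0, :]
print(tuple(with_squeeze.shape))     # (7,)    batch dim silently dropped
print(tuple(without_squeeze.shape))  # (1, 7)  consistent (batch, dof)
```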
```python
qpos_seed_dis[~all_is_success] = float("inf")
closest_indices = torch.argmin(qpos_seed_dis, dim=1)
closest_qpos = all_results[torch.arange(batch_size), closest_indices]
return all_is_success.any(dim=0), closest_qpos[:, None, :]
```
For the non-return_all_solutions path, return all_is_success.any(dim=0), ... reduces over the batch dimension and returns a (num_samples,) tensor rather than (batch_size,). This breaks the documented/expected API shape (and differs from the return_all_solutions path which uses dim=1). This should likely be all_is_success.any(dim=1) so success is reported per target in the batch.
Suggested change:
```diff
-return all_is_success.any(dim=0), closest_qpos[:, None, :]
+return all_is_success.any(dim=1), closest_qpos[:, None, :]
```
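A shapes-only sketch of the difference, assuming a `(batch_size, num_samples)` layout for the success mask:

```python
import torch

batch_size, num_samples = 3, 8
# stand-in success mask with an assumed (batch_size, num_samples) layout
all_is_success = torch.zeros(batch_size, num_samples, dtype=torch.bool)
all_is_success[1, 4] = True  # only the second target converged

print(all_is_success.any(dim=0).shape)     # torch.Size([8]) -> per-sample, wrong shape
print(all_is_success.any(dim=1).shape)     # torch.Size([3]) -> per-target, as documented
print(all_is_success.any(dim=1).tolist())  # [False, True, False]
```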
```python
self.compiled_fk = torch.compile(
    self.pk_serial_chain.forward_kinematics_tensor,
    fullgraph=True,
    dynamic=True,
)
```
self.compiled_fk is only initialized when pk_serial_chain is created inside BaseSolver.__init__. If a caller injects an existing pk_serial_chain via kwargs, get_fk() will later raise AttributeError because self.compiled_fk was never set. Consider compiling (or assigning a non-compiled fallback) whenever self.pk_serial_chain is provided, or guard get_fk() to call the non-compiled FK path when compiled_fk is missing.
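One way to guarantee the attribute always exists is sketched below; the class, parameter names, and `use_compile` flag are illustrative assumptions, not the repository's `BaseSolver`:

```python
import torch

class SolverSketch:
    """Illustrative fragment: always define compiled_fk, compiled or eager."""

    def __init__(self, fk_callable, use_compile=True):
        if use_compile and hasattr(torch, "compile"):
            self.compiled_fk = torch.compile(fk_callable, dynamic=True)
        else:
            self.compiled_fk = fk_callable  # eager fallback, same call signature

    def get_fk(self, qpos):
        # safe regardless of how the kinematic chain was provided
        return self.compiled_fk(qpos)

solver = SolverSketch(lambda q: q * 2, use_compile=False)
print(solver.get_fk(3))  # 6
```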
```python
start = time.perf_counter()
ik_success, ik_qpos = solver.get_ik(
    fk_xpos,
    joint_seed=qpos_seed,
```
PytorchSolver.get_ik() takes qpos_seed as its seed argument, but the benchmark calls it with joint_seed=..., which will be swallowed by **kwargs and ignored. This means the benchmark is not using the intended seed values. Update the call to pass qpos_seed=qpos_seed (or rename the solver parameter to joint_seed to match the base-class API).
Suggested change:
```diff
-    joint_seed=qpos_seed,
+    qpos_seed=qpos_seed,
```
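The silent swallowing is ordinary Python keyword handling, as a toy stand-in shows (this function is illustrative, not the solver's real signature):

```python
def get_ik(qpos_seed=None, **kwargs):
    """Toy stand-in: unknown keyword arguments are silently absorbed
    by **kwargs rather than raising a TypeError."""
    return qpos_seed

print(get_ik(joint_seed=[0.1, 0.2]))  # None: the misnamed seed was ignored
print(get_ik(qpos_seed=[0.1, 0.2]))  # [0.1, 0.2]
```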
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
```diff
 "casadi",
 "qpsolvers[osqp]==4.8.1",
-"pytorch_kinematics==0.7.6",
+"pytorch_kinematics==0.10.0",
```
The codebase now uses torch.compile (e.g., in BaseSolver), but torch itself isn’t declared/pinned in dependencies. If PyTorch 2.x is required for this PR’s compiled-mode features, consider declaring a minimum supported torch version (or documenting/enforcing it elsewhere) to avoid runtime failures on older installations.
Suggested change:
```diff
 "pytorch_kinematics==0.10.0",
+"torch>=2.0",
```
```diff
 ee_link_xpos = self.compiled_fk(qpos)[-1, :, :, :]

 # Ensure batch format for TCP
-batch_size = result.shape[0]
+batch_size = qpos.shape[0]
 tcp_xpos_batch = tcp_xpos.unsqueeze(0).expand(batch_size, -1, -1)
```
get_fk assumes qpos is 2D (batch, dof) (uses qpos.shape[0] as batch and indexes compiled_fk(qpos)[-1, :, :, :]). This breaks the documented single-config input (dof,). Consider normalizing qpos to 2D (unsqueeze when 1D) and optionally squeezing the output back for the single-input case.
```python
k = torch.ceil((self.lower_qpos_limits - qpos) / two_pi)
qpos_mapped = qpos + k * two_pi
is_within_limits = (qpos_mapped >= self.lower_qpos_limits) & (
```
_qpos_map_to_limits uses k = ceil((lower - q)/2π), which only guarantees qpos_mapped >= lower. For joints with a valid range smaller than 2π, this can incorrectly mark a wrap-able value as invalid even though a different multiple of 2π would land within [lower, upper]. Consider deriving an integer k range using both bounds and selecting a valid k when one exists.
Suggested change:
```diff
-k = torch.ceil((self.lower_qpos_limits - qpos) / two_pi)
-qpos_mapped = qpos + k * two_pi
-is_within_limits = (qpos_mapped >= self.lower_qpos_limits) & (
+k_min = torch.ceil((self.lower_qpos_limits - qpos) / two_pi)
+k_max = torch.floor((self.upper_qpos_limits - qpos) / two_pi)
+has_valid_wrap = k_min <= k_max
+# k_min is the smallest shift satisfying the lower bound; when a valid wrap
+# exists (k_min <= k_max) this same shift also satisfies the upper bound.
+k = k_min
+qpos_mapped = qpos + k * two_pi
+is_within_limits = has_valid_wrap & (qpos_mapped >= self.lower_qpos_limits) & (
```
```python
Args:
    qpos_limits: tensor of shape (1, n, 2) or (n, 2) where each row is [low, high].
    steps_per_joint: number of values per joint (defaults to 2: low and high).
```
Docstring mismatch: steps_per_joint defaults to 4 in the function signature, but the docstring says it defaults to 2. Please update the docstring (or the default) so documentation matches behavior.
Suggested change:
```diff
-steps_per_joint: number of values per joint (defaults to 2: low and high).
+steps_per_joint: number of values per joint (defaults to 4).
```
Description
TODO:
Type of change
Checklist
I have run the `black .` command to format the code base.