diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py
index f47d0e96..8bcc3584 100644
--- a/rfdiffusion/inference/model_runners.py
+++ b/rfdiffusion/inference/model_runners.py
@@ -358,15 +358,15 @@ def sample_init(self, return_forward_trajectory=False):
                     chain_id = available_chains[0]
                     available_chains.remove(chain_id)
                 # Otherwise, use the chain of the fixed (motif) residues
+                # If fixed residues span multiple input chains (motif stitching), use the first one
                 else:
-                    assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}"
-                    chain_id = list(chain_ids)[0]
+                    chain_id = sorted(chain_ids)[0]
                 self.chain_idx += [chain_id] * (last_res - first_res)
             # If this is a fixed chain, maintain the chain and residue numbering
             else:
                 self.idx_pdb += [contig_ref[1] for contig_ref in self.contig_map.ref[first_res: last_res]]
-                assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}"
-                self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res)
+                chain_id = sorted(chain_ids)[0]
+                self.chain_idx += [chain_id] * (last_res - first_res)
             first_res = last_res
 
     ####################################
@@ -939,7 +939,28 @@ def sample_init(self):
         ### Get hotspots ###
         ####################
         self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen)
-
+
+        ######################################
+        ### Resolve cyclic peptide indices ###
+        ######################################
+        if self._conf.inference.cyclic:
+            if self._conf.inference.cyc_chains is None:
+                # default to all residues being cyclized
+                self.cyclic_reses = ~self.mask_str.to(self.device).squeeze()
+            else:
+                # use cyc_chains arg to determine cyclic_reses mask
+                assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be string'
+                cyc_chains = self._conf.inference.cyc_chains
+                cyc_chains = [i.upper() for i in cyc_chains]
+                hal_idx = self.contig_map.hal # the pdb indices of output, knowledge of different chains
+                is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty
+                for ch in cyc_chains:
+                    ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool()
+                    is_cyclized[ch_mask] = True # set this whole chain to be cyclic
+                self.cyclic_reses = is_cyclized
+        else:
+            self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()
+
         #########################
         ### Set up potentials ###
         #########################