diff --git a/tests/test_trees.rs b/tests/test_trees.rs index 73d94061..304e42a7 100644 --- a/tests/test_trees.rs +++ b/tests/test_trees.rs @@ -15,6 +15,34 @@ use tskit::TreeFlags; use tskit::TreeSequence; use tskit::TreeSequenceFlags; +#[cfg(feature = "bindings")] +fn compare_preorder_to_c_api(tree: &tskit::Tree, node: NodeId, expected: &[NodeId]) { + let mut nodes: Vec = vec![ + NodeId::NULL; + unsafe { tskit::bindings::tsk_tree_get_size_bound(tree.as_ll_ref()) } + as usize + ]; + let mut num_nodes: tskit::bindings::tsk_size_t = 0; + let ptr = std::ptr::addr_of_mut!(num_nodes); + unsafe { + tskit::bindings::tsk_tree_preorder_from( + tree.as_ll_ref(), + if node == tree.virtual_root() { + -1 + } else { + node.into() + }, + nodes.as_mut_ptr() as *mut tskit::bindings::tsk_id_t, + ptr, + ); + } + assert_eq!(num_nodes as usize, expected.len()); + assert_eq!(expected, &nodes[0..num_nodes as usize]); +} + +#[cfg(not(feature = "bindings"))] +fn compare_preorder_to_c_api(_tree: &tskit::Tree, _node: NodeId, _expected: &[NodeId]) {} + pub fn make_small_table_collection() -> TableCollection { let mut tables = TableCollection::new(1000.).unwrap(); tables @@ -299,11 +327,11 @@ fn test_iterate_samples_two_trees() { } assert_eq!(expected_number_of_roots[current_tree], num_roots); assert_eq!(tree.roots().count(), eroot_ids.len()); - let mut preoder_nodes = vec![]; + let mut preorder_nodes = vec![]; let mut postoder_nodes = vec![]; for n in tree.traverse_nodes(NodeTraversalOrder::Preorder) { let mut nsamples = 0; - preoder_nodes.push(n); + preorder_nodes.push(n); if let Ok(iter) = tree.samples(n) { for _ in iter { nsamples += 1; @@ -323,7 +351,7 @@ fn test_iterate_samples_two_trees() { assert!(nsamples > 0); assert_eq!(nsamples, tree.num_tracked_samples(n).unwrap()); } - assert_eq!(preoder_nodes.len(), postoder_nodes.len()); + assert_eq!(preorder_nodes.len(), postoder_nodes.len()); let mut postorder_from_roots = vec![]; for root in tree.roots() { @@ -336,50 +364,28 @@ fn test_iterate_samples_two_trees() { } assert_eq!(postorder_from_roots, postoder_nodes); - // Test our preorder against the tskit functions in 0.99.15 - #[cfg(feature = "bindings")] - { - let mut nodes: Vec = vec![ - NodeId::NULL; - unsafe { tskit::bindings::tsk_tree_get_size_bound(tree.as_ll_ref()) } - as usize - ]; - let mut num_nodes: tskit::bindings::tsk_size_t = 0; - let ptr = std::ptr::addr_of_mut!(num_nodes); - unsafe { - tskit::bindings::tsk_tree_preorder( - tree.as_ll_ref(), - nodes.as_mut_ptr() as *mut tskit::bindings::tsk_id_t, - ptr, - ); - } - assert_eq!(num_nodes as usize, preoder_nodes.len()); - for i in 0..num_nodes as usize { - assert_eq!(preoder_nodes[i], nodes[i]); - } + compare_preorder_to_c_api(tree, tree.virtual_root(), &preorder_nodes); - // For each root, traverse its subtree with a preorder - // traversal, collecting outputs as we go. - // Then, compare to what the C API preorder fn outputs - let mut nodes_from_roots = vec![]; - for root in tree.roots() { - for node in tree - .traverse_nodes_from_root(root, NodeTraversalOrder::Preorder) - .unwrap() - { - nodes_from_roots.push(node); - } - } - for &node in &nodes_from_roots { - assert!(nodes.contains(&node)); + // For each root, traverse its subtree with a preorder + // traversal, collecting outputs as we go. + // Then, compare to what the C API preorder fn outputs + let mut nodes_from_roots = vec![]; + for root in tree.roots() { + for node in tree + .traverse_nodes_from_root(root, NodeTraversalOrder::Preorder) + .unwrap() + { + nodes_from_roots.push(node); } - // This assert checks that we get the same order as the - // tskit-c preorder fn. We need to take a slice of the - // vec where we store the tskit output b/c its allocation - // may be larger than the number of nodes in the tree. - assert_eq!(nodes_from_roots, nodes[0..nodes_from_roots.len()]); } - + for &node in &nodes_from_roots { + assert!(preorder_nodes.contains(&node)); + } + // This assert checks that we get the same order as the + // tskit-c preorder fn. We need to take a slice of the + // vec where we store the tskit output b/c its allocation + // may be larger than the number of nodes in the tree. + assert_eq!(nodes_from_roots, preorder_nodes); current_tree += 1; } } @@ -939,26 +945,6 @@ fn test_site_iterator_double_ended() { #[test] fn test_subtrees_from_non_root_nodes() { - #[cfg(feature = "bindings")] - fn compare_preoder_to_c_api(tree: &tskit::Tree, node: NodeId, expected: &[NodeId]) { - let mut nodes: Vec = vec![ - NodeId::NULL; - unsafe { tskit::bindings::tsk_tree_get_size_bound(tree.as_ll_ref()) } - as usize - ]; - let mut num_nodes: tskit::bindings::tsk_size_t = 0; - let ptr = std::ptr::addr_of_mut!(num_nodes); - unsafe { - tskit::bindings::tsk_tree_preorder_from( - tree.as_ll_ref(), - node.into(), - nodes.as_mut_ptr() as *mut tskit::bindings::tsk_id_t, - ptr, - ); - } - assert_eq!(num_nodes as usize, expected.len()); - assert_eq!(expected, &nodes[0..num_nodes as usize]); - } let treeseq = treeseq_from_small_table_collection_two_trees(); let mut iter = treeseq.tree_iterator(tskit::TreeFlags::default()).unwrap(); let tree = iter.next().unwrap(); @@ -968,8 +954,7 @@ fn test_subtrees_from_non_root_nodes() { .collect::>(); assert_eq!(pre, &[1, 4, 5]); - #[cfg(feature = "bindings")] - compare_preoder_to_c_api(tree, 1.into(), &pre); + compare_preorder_to_c_api(tree, 1.into(), &pre); let tree = iter.next().unwrap(); let pre = tree @@ -980,16 +965,14 @@ fn test_subtrees_from_non_root_nodes() { expected.extend(tree.children(1).map(i32::from)); assert_eq!(pre, expected); - #[cfg(feature = "bindings")] - compare_preoder_to_c_api(tree, 1.into(), &pre); + compare_preorder_to_c_api(tree, 1.into(), &pre); let pre = tree .traverse_nodes_from_root(3.into(), tskit::NodeTraversalOrder::Preorder) .unwrap() .collect::>(); assert_eq!(pre, &[3]); - #[cfg(feature = "bindings")] - compare_preoder_to_c_api(tree, 3.into(), &pre); + compare_preorder_to_c_api(tree, 3.into(), &pre); } // The following tests are lifted