From f8915244025c07563c135fe41511d0b791e76a3a Mon Sep 17 00:00:00 2001 From: Gertjan Bisschop Date: Fri, 4 Nov 2022 09:32:25 +0000 Subject: [PATCH 1/2] Python algorithm without S avl tree --- algorithms.py | 174 ++++++++++++++++++++------------------------------ 1 file changed, 68 insertions(+), 106 deletions(-) diff --git a/algorithms.py b/algorithms.py index a5885bf4b..4cc0e3220 100644 --- a/algorithms.py +++ b/algorithms.py @@ -125,15 +125,16 @@ def __init__(self, index): self.population = None self.label = 0 self.index = index + self.ancestral_to = -1 def __repr__(self): - return repr((self.left, self.right, self.node)) + return repr((self.left, self.right, self.node, self.ancestral_to)) @staticmethod def show_chain(seg): s = "" while seg is not None: - s += f"[{seg.left}, {seg.right}: {seg.node}], " + s += f"[{seg.left}, {seg.right}: {seg.node}, {seg.ancestral_to}], " seg = seg.next return s[:-2] @@ -475,7 +476,13 @@ def overlaps_at(self, pos): curr_interval = curr_interval.next raise ValueError("Bad overlap count chain") - def increment_interval(self, left, right): + def yield_overlap_check(self): + curr_interval = self.overlaps + while curr_interval is not None: + yield (curr_interval.node, curr_interval.ancestral_to) + curr_interval = curr_interval.next + + def increment_interval(self, left, right, ancestral_to): """ Increment the count that spans the interval [left, right), creating additional intervals in overlaps @@ -486,14 +493,17 @@ def increment_interval(self, left, right): if curr_interval.left == left: if curr_interval.right <= right: curr_interval.node += 1 + curr_interval.ancestral_to += ancestral_to left = curr_interval.right curr_interval = curr_interval.next else: self._split(curr_interval, right) curr_interval.node += 1 + curr_interval.ancestral_to += ancestral_to break else: - if curr_interval.right < left: + # verify this changed < to <= !!!! + if curr_interval.right <= left: curr_interval = curr_interval.next else: self._split(curr_interval, left) @@ -505,7 +515,7 @@ def _split(self, seg, bp): # noqa: A002 from breakpoint to seg.right. Set the original segment's right endpoint to breakpoint """ - right = self._make_segment(bp, seg.right, seg.node) + right = self._make_segment(bp, seg.right, seg.node, seg.ancestral_to) if seg.next is not None: seg.next.prev = right right.next = seg.next @@ -513,11 +523,12 @@ def _split(self, seg, bp): # noqa: A002 seg.next = right seg.right = bp - def _make_segment(self, left, right, count): + def _make_segment(self, left, right, count, ancestral_to=0): seg = Segment(0) seg.left = left seg.right = right seg.node = count + seg.ancestral_to = ancestral_to return seg @@ -591,7 +602,7 @@ def __init__( self.gc_mass_index = [ FenwickTree(self.max_segments) for j in range(num_labels) ] - self.S = bintrees.AVLTree() + for pop in self.P: pop.set_start_size(population_sizes[pop.id]) pop.set_growth_rate(population_growth_rates[pop.id], 0) @@ -640,36 +651,35 @@ def __init__( def initialise(self, ts): root_time = np.max(self.tables.nodes.time) self.t = root_time + self.num_samples = ts.num_samples root_segments_head = [None for _ in range(ts.num_nodes)] root_segments_tail = [None for _ in range(ts.num_nodes)] - last_S = -1 + for tree in ts.trees(): left, right = tree.interval - S = 0 if tree.num_roots == 1 else tree.num_roots - if S != last_S: - self.S[left] = S - last_S = S # If we have 1 root this is a special case and we don't add in # any ancestral segments to the state. if tree.num_roots > 1: for root in tree.roots: population = ts.node(root).population + ancestral_to = tree.num_samples(root) if root_segments_head[root] is None: - seg = self.alloc_segment(left, right, root, population) + seg = self.alloc_segment( + left, right, root, population, ancestral_to + ) root_segments_head[root] = seg root_segments_tail[root] = seg else: tail = root_segments_tail[root] - if tail.right == left: + if tail.right == left and tail.ancestral_to == ancestral_to: tail.right = right else: seg = self.alloc_segment( - left, right, root, population, tail + left, right, root, population, ancestral_to, tail ) tail.next = seg root_segments_tail[root] = seg - self.S[self.L] = -1 # Insert the segment chains into the algorithm state. for node in range(ts.num_nodes): @@ -702,6 +712,7 @@ def alloc_segment( right, node, population, + ancestral_to, prev=None, next=None, # noqa: A002 label=0, @@ -717,6 +728,7 @@ def alloc_segment( s.next = next s.prev = prev s.label = label + s.ancestral_to = ancestral_to return s def copy_segment(self, segment): @@ -725,6 +737,7 @@ def copy_segment(self, segment): right=segment.right, node=segment.node, population=segment.population, + ancestral_to=segment.ancestral_to, next=segment.next, prev=segment.prev, label=segment.label, @@ -1619,8 +1632,8 @@ def dtwf_recombine(self, x): Chooses breakpoints and returns segments sorted by inheritance direction, by iterating through segment chain starting with x """ - u = self.alloc_segment(-1, -1, -1, -1, None, None) - v = self.alloc_segment(-1, -1, -1, -1, None, None) + u = self.alloc_segment(-1, -1, -1, -1, -1, None, None) + v = self.alloc_segment(-1, -1, -1, -1, -1, None, None) seg_tails = [u, v] # TODO Should this be the recombination rate going foward from x.left? @@ -1738,7 +1751,11 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): if len(X) == 1: x = X[0] if len(H) > 0 and H[0][0] < x.right: - alpha = self.alloc_segment(x.left, H[0][0], x.node, x.population) + # what type of event?? + # what should ancestral_to be? + alpha = self.alloc_segment( + x.left, H[0][0], x.node, x.population, x.ancestral_to + ) alpha.label = label x.left = H[0][0] heapq.heappush(H, (x.left, x)) @@ -1752,24 +1769,14 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): if new_node_id == -1: coalescence = True new_node_id = self.store_node(pop_id) - # We must also break if the next left value is less than - # any of the right values in the current overlap set. - if left not in self.S: - j = self.S.floor_key(left) - self.S[left] = self.S[j] - if r_max not in self.S: - j = self.S.floor_key(r_max) - self.S[r_max] = self.S[j] - # Update the number of extant segments. - if self.S[left] == len(X): - self.S[left] = 0 - right = self.S.succ_key(left) - else: - right = left - while right < r_max and self.S[right] != len(X): - self.S[right] -= len(X) - 1 - right = self.S.succ_key(right) - alpha = self.alloc_segment(left, right, new_node_id, pop_id) + + ancestral_to = sum(x.ancestral_to for x in X) + right = r_max + if ancestral_to != self.num_samples: + alpha = self.alloc_segment( + left, right, new_node_id, pop_id, ancestral_to + ) + # Update the heaps and make the record. for x in X: self.store_edge(left, right, new_node_id, x.node) @@ -1804,15 +1811,18 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): self.store_arg_edges(z) if defrag_required: self.defrag_segment_chain(z) - if coalescence: - self.defrag_breakpoints() + return merged_head def defrag_segment_chain(self, z): y = z while y.prev is not None: x = y.prev - if x.right == y.left and x.node == y.node: + if ( + x.right == y.left + and x.node == y.node + and x.ancestral_to == y.ancestral_to + ): x.right = y.right x.next = y.next if y.next is not None: @@ -1821,17 +1831,6 @@ def defrag_segment_chain(self, z): self.free_segment(y) y = x - def defrag_breakpoints(self): - # Defrag the breakpoints set - j = 0 - k = 0 - while k < self.L: - k = self.S.succ_key(j) - if self.S[j] == self.S[k]: - del self.S[k] - else: - j = k - def common_ancestor_event(self, population_index, label): """ Implements a coancestry event. @@ -1883,28 +1882,20 @@ def merge_two_ancestors(self, population_index, label, x, y): # segment left = x.left r_max = min(x.right, y.right) - if left not in self.S: - j = self.S.floor_key(left) - self.S[left] = self.S[j] - if r_max not in self.S: - j = self.S.floor_key(r_max) - self.S[r_max] = self.S[j] - # Update the number of extant segments. - if self.S[left] == 2: - self.S[left] = 0 - right = self.S.succ_key(left) - else: - right = left - while right < r_max and self.S[right] != 2: - self.S[right] -= 1 - right = self.S.succ_key(right) + + ancestral_to = x.ancestral_to + y.ancestral_to + right = r_max + + if ancestral_to != self.num_samples: alpha = self.alloc_segment( left=left, right=right, node=u, population=population_index, + ancestral_to=ancestral_to, label=label, ) + self.store_edge(left, right, u, x.node) self.store_edge(left, right, u, y.node) # Now trim the ends of x and y to the right sizes. @@ -1941,8 +1932,6 @@ def merge_two_ancestors(self, population_index, label, x, y): self.store_arg_edges(z) if defrag_required: self.defrag_segment_chain(z) - if coalescence: - self.defrag_breakpoints() def print_state(self, verify=False): print("State @ time ", self.t) @@ -1970,9 +1959,7 @@ def print_state(self, verify=False): population.print_state() if self.pedigree is not None: self.pedigree.print_state() - print("Overlap counts", len(self.S)) - for k, x in self.S.items(): - print("\t", k, "\t:\t", x) + for label in range(self.num_labels): if self.recomb_mass_index is not None: print( @@ -2018,6 +2005,7 @@ def verify_segments(self): assert u.left >= prev.right assert u.label == head.label assert u.population == head.population + assert 1 <= u.ancestral_to < self.num_samples prev = u u = u.next @@ -2027,43 +2015,17 @@ def verify_overlaps(self): for label in range(self.num_labels): for u in pop.iter_label(label): while u is not None: - overlap_counter.increment_interval(u.left, u.right) + overlap_counter.increment_interval( + u.left, u.right, u.ancestral_to + ) u = u.next - for pos, count in self.S.items(): - if pos != self.L: - assert count == overlap_counter.overlaps_at(pos) - - assert self.S[self.L] == -1 - # Check the ancestry tracking. - A = bintrees.AVLTree() - A[0] = 0 - A[self.L] = -1 - for pop in self.P: - for label in range(self.num_labels): - for u in pop.iter_label(label): - while u is not None: - if u.left not in A: - k = A.floor_key(u.left) - A[u.left] = A[k] - if u.right not in A: - k = A.floor_key(u.right) - A[u.right] = A[k] - k = u.left - while k < u.right: - A[k] += 1 - k = A.succ_key(k) - u = u.next - # Now, defrag A - j = 0 - k = 0 - while k < self.L: - k = A.succ_key(j) - if A[j] == A[k]: - del A[k] - else: - j = k - assert list(A.items()) == list(self.S.items()) + # OverlapCounter tracks info on all positions (0, self.L) even when + # there is no more ancestral material in that section -> count=0 + assert all( + anc_to == self.num_samples if count > 0 else anc_to == 0 + for count, anc_to in overlap_counter.yield_overlap_check() + ) def verify_mass_index(self, label, mass_index, rate_map, compute_left_bound): assert mass_index is not None From b6fefaa47cf345b50c35927aa90b6df4216df705 Mon Sep 17 00:00:00 2001 From: Gertjan Bisschop Date: Wed, 16 Nov 2022 08:32:13 +0000 Subject: [PATCH 2/2] ancestral_to first C outline --- algorithms.py | 3 - lib/msprime.c | 396 ++++++++----------------------------- lib/msprime.h | 4 +- lib/tests/test_ancestry.c | 8 +- lib/tests/test_pedigrees.c | 10 + 5 files changed, 96 insertions(+), 325 deletions(-) diff --git a/algorithms.py b/algorithms.py index 4cc0e3220..621f467a2 100644 --- a/algorithms.py +++ b/algorithms.py @@ -502,7 +502,6 @@ def increment_interval(self, left, right, ancestral_to): curr_interval.ancestral_to += ancestral_to break else: - # verify this changed < to <= !!!! if curr_interval.right <= left: curr_interval = curr_interval.next else: @@ -1751,8 +1750,6 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): if len(X) == 1: x = X[0] if len(H) > 0 and H[0][0] < x.right: - # what type of event?? - # what should ancestral_to be? alpha = self.alloc_segment( x.left, H[0][0], x.node, x.population, x.ancestral_to ) diff --git a/lib/msprime.c b/lib/msprime.c index c491b6029..2329e8188 100644 --- a/lib/msprime.c +++ b/lib/msprime.c @@ -661,7 +661,8 @@ msp_set_avl_node_block_size(msp_t *self, size_t block_size) static segment_t *MSP_WARN_UNUSED msp_alloc_segment(msp_t *self, double left, double right, tsk_id_t value, - population_id_t population, label_id_t label, segment_t *prev, segment_t *next) + population_id_t population, label_id_t label, segment_t *prev, segment_t *next, + size_t ancestral_to) { segment_t *seg = NULL; @@ -700,6 +701,7 @@ msp_alloc_segment(msp_t *self, double left, double right, tsk_id_t value, seg->value = value; seg->population = population; seg->label = label; + seg->ancestral_to = ancestral_to; out: return seg; } @@ -708,7 +710,7 @@ static segment_t *MSP_WARN_UNUSED msp_copy_segment(msp_t *self, const segment_t *seg) { return msp_alloc_segment(self, seg->left, seg->right, seg->value, seg->population, - seg->label, seg->prev, seg->next); + seg->label, seg->prev, seg->next, seg->ancestral_to); } /* Top level allocators and initialisation */ @@ -717,6 +719,7 @@ int msp_alloc(msp_t *self, tsk_table_collection_t *tables, gsl_rng *rng) { int ret = -1; + size_t num_samples = 0; memset(self, 0, sizeof(msp_t)); if (rng == NULL || tables == NULL) { @@ -737,6 +740,10 @@ msp_alloc(msp_t *self, tsk_table_collection_t *tables, gsl_rng *rng) ret = MSP_ERR_BAD_SEQUENCE_LENGTH; goto out; } + for (tsk_size_t j = 0; j < tables->nodes.num_rows; j++) { + num_samples += (tables->nodes.flags[j] & TSK_NODE_IS_SAMPLE); + } + self->num_samples = num_samples; self->num_populations = (uint32_t) self->tables->populations.num_rows; if (self->num_populations == 0) { ret = MSP_ERR_ZERO_POPULATIONS; @@ -769,7 +776,6 @@ msp_alloc(msp_t *self, tsk_table_collection_t *tables, gsl_rng *rng) self->segment_block_size = 1024; /* set up the AVL trees */ avl_init_tree(&self->breakpoints, cmp_node_mapping, NULL); - avl_init_tree(&self->overlap_counts, cmp_node_mapping, NULL); avl_init_tree(&self->non_empty_populations, cmp_pointer, NULL); /* Set up the demographic events */ self->demographic_events_head = NULL; @@ -871,7 +877,6 @@ msp_free(msp_t *self) msp_safe_free(self->sampling_events); msp_safe_free(self->buffered_edges); msp_safe_free(self->root_segments); - msp_safe_free(self->initial_overlaps); msp_safe_free(self->pedigree.individuals); msp_safe_free(self->pedigree.visit_order); /* free the object heaps */ @@ -1136,6 +1141,8 @@ msp_verify_segments(msp_t *self, bool verify_breakpoints) tsk_bug_assert(u->label == (label_id_t) k); tsk_bug_assert(u->left < u->right); tsk_bug_assert(u->right <= self->sequence_length); + tsk_bug_assert( + 1 <= u->ancestral_to && u->ancestral_to < self->num_samples); if (u->prev != NULL) { tsk_bug_assert(u->prev->next == u); } @@ -1154,7 +1161,6 @@ msp_verify_segments(msp_t *self, bool verify_breakpoints) label_segments == object_heap_get_num_allocated(&self->segment_heap[k])); } total_avl_nodes = msp_get_num_ancestors(self) + avl_count(&self->breakpoints) - + avl_count(&self->overlap_counts) + avl_count(&self->non_empty_populations); for (j = 0; j < self->pedigree.num_individuals; j++) { ind = &self->pedigree.individuals[j]; @@ -1210,6 +1216,7 @@ overlap_counter_alloc(overlap_counter_t *self, double seq_length, int initial_co overlaps->value = initial_count; overlaps->population = 0; overlaps->label = 0; + overlaps->ancestral_to = 0; self->seq_length = seq_length; self->overlaps = overlaps; @@ -1232,20 +1239,19 @@ overlap_counter_free(overlap_counter_t *self) } } -/* Find the number of segments that overlap at the given position */ -static uint32_t -overlap_counter_overlaps_at(overlap_counter_t *self, double pos) +/* Verify whether for or every interval in + * overlap_counter->ancestral_to equals num_samples or 0 + */ +static void +overlap_counter_verify_ancestral_to(overlap_counter_t *self, tsk_size_t num_samples) { - tsk_bug_assert(pos >= 0 && pos < self->seq_length); + size_t expected_count; segment_t *curr_overlap = self->overlaps; - while (curr_overlap->next != NULL) { - if (curr_overlap->left <= pos && pos < curr_overlap->right) { - break; - } + while (curr_overlap != NULL) { + expected_count = (curr_overlap->value > 0) ? num_samples : 0; + tsk_bug_assert(curr_overlap->ancestral_to == expected_count); curr_overlap = curr_overlap->next; } - - return (uint32_t) curr_overlap->value; } /* Split the segment at breakpoint and add in another segment @@ -1263,6 +1269,7 @@ overlap_counter_split_segment(segment_t *seg, double breakpoint) right_seg->value = seg->value; right_seg->population = 0; right_seg->label = 0; + right_seg->ancestral_to = seg->ancestral_to; if (seg->next != NULL) { right_seg->next = seg->next; @@ -1277,22 +1284,26 @@ overlap_counter_split_segment(segment_t *seg, double breakpoint) * [left, right), creating additional intervals if necessary. */ static void -overlap_counter_increment_interval(overlap_counter_t *self, double left, double right) +overlap_counter_increment_interval( + overlap_counter_t *self, double left, double right, size_t ancestral_to) { segment_t *curr_interval = self->overlaps; while (left < right) { if (curr_interval->left == left) { if (curr_interval->right <= right) { curr_interval->value++; + curr_interval->ancestral_to += ancestral_to; left = curr_interval->right; curr_interval = curr_interval->next; } else { overlap_counter_split_segment(curr_interval, right); curr_interval->value++; + curr_interval->ancestral_to += ancestral_to; break; } } else { - if (curr_interval->right < left) { + // same issue as algorithms.py + if (curr_interval->right <= left) { curr_interval = curr_interval->next; } else { overlap_counter_split_segment(curr_interval, left); @@ -1306,11 +1317,12 @@ static void msp_verify_overlaps(msp_t *self) { avl_node_t *node; - node_mapping_t *nm; + // node_mapping_t *nm; sampling_event_t se; segment_t *u; size_t j; - uint32_t label, count; + uint32_t label; //, count; + overlap_counter_t counter; int ok = overlap_counter_alloc(&counter, self->sequence_length, 0); @@ -1320,7 +1332,8 @@ msp_verify_overlaps(msp_t *self) for (j = self->next_sampling_event; j < self->num_sampling_events; j++) { se = self->sampling_events[j]; for (u = self->root_segments[se.sample]; u != NULL; u = u->next) { - overlap_counter_increment_interval(&counter, u->left, u->right); + overlap_counter_increment_interval( + &counter, u->left, u->right, u->ancestral_to); } } @@ -1329,16 +1342,15 @@ msp_verify_overlaps(msp_t *self) for (node = (&self->populations[j].ancestors[label])->head; node != NULL; node = node->next) { for (u = (segment_t *) node->item; u != NULL; u = u->next) { - overlap_counter_increment_interval(&counter, u->left, u->right); + overlap_counter_increment_interval( + &counter, u->left, u->right, u->ancestral_to); } } } } - for (node = self->overlap_counts.head; node->next != NULL; node = node->next) { - nm = (node_mapping_t *) node->item; - count = overlap_counter_overlaps_at(&counter, nm->position); - tsk_bug_assert(nm->value == count); - } + + /* For each position in counter.ancestral_to == self.num_samples or 0*/ + overlap_counter_verify_ancestral_to(&counter, self->num_samples); overlap_counter_free(&counter); } @@ -1400,23 +1412,10 @@ msp_verify_migration_destinations(msp_t *self) static void msp_verify_initial_state(msp_t *self) { - overlap_count_t *overlap; - double last_overlap_left = -1; + // overlap_count_t *overlap; tsk_size_t j; segment_t *head, *seg, *prev; - for (overlap = self->initial_overlaps; overlap->left < self->sequence_length; - overlap++) { - tsk_bug_assert(overlap->left > last_overlap_left); - last_overlap_left = overlap->left; - } - /* Last overlap should be a sentinal */ - overlap->left = self->sequence_length; - overlap->count = UINT32_MAX; - - /* First overlap should be 0 */ - tsk_bug_assert(self->initial_overlaps->left == 0); - /* Check the root segments */ for (j = 0; j < self->input_position.nodes; j++) { head = self->root_segments[j]; @@ -1523,21 +1522,6 @@ msp_print_root_segments(msp_t *self, FILE *out) } } -static void -msp_print_initial_overlaps(msp_t *self, FILE *out) -{ - overlap_count_t *overlap; - - fprintf(out, "Initial overlaps\n"); - - for (overlap = self->initial_overlaps; overlap->left < self->sequence_length; - overlap++) { - fprintf(out, "\t%f -> %d\n", overlap->left, (int) overlap->count); - } - tsk_bug_assert(overlap->left == self->sequence_length); - fprintf(out, "\t%f -> %d\n", overlap->left, (int) overlap->count); -} - int msp_print_state(msp_t *self, FILE *out) { @@ -1583,7 +1567,6 @@ msp_print_state(msp_t *self, FILE *out) rate_map_print_state(&self->gc_map, out); msp_pedigree_print_state(self, out); msp_print_root_segments(self, out); - msp_print_initial_overlaps(self, out); fprintf(out, "Sampling events:\n"); for (j = 0; j < self->num_sampling_events; j++) { if (j == self->next_sampling_event) { @@ -1692,11 +1675,7 @@ msp_print_state(msp_t *self, FILE *out) nm = (node_mapping_t *) a->item; fprintf(out, "\t%.14g -> %d\n", nm->position, (int) nm->value); } - fprintf(out, "Overlap count = %d\n", avl_count(&self->overlap_counts)); - for (a = self->overlap_counts.head; a != NULL; a = a->next) { - nm = (node_mapping_t *) a->item; - fprintf(out, "\t%.14g -> %d\n", nm->position, (int) nm->value); - } + fprintf(out, "Tables = \n"); tsk_table_collection_print_state(self->tables, out); @@ -1908,8 +1887,8 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, new_ind = NULL; y = NULL; for (x = ind; x != NULL; x = x->next) { - y = msp_alloc_segment( - self, x->left, x->right, x->value, x->population, dest_label, y, NULL); + y = msp_alloc_segment(self, x->left, x->right, x->value, x->population, + dest_label, y, NULL, x->ancestral_to); if (new_ind == NULL) { new_ind = y; } else { @@ -1972,105 +1951,6 @@ msp_remove_non_empty_population(msp_t *self, tsk_id_t population) return ret; } -/* - * Inserts a new overlap_count at the specified locus left, mapping to the - * specified number of overlapping segments b. - */ -static int MSP_WARN_UNUSED -msp_insert_overlap_count(msp_t *self, double left, uint32_t count) -{ - int ret = 0; - avl_node_t *node = msp_alloc_avl_node(self); - node_mapping_t *m = msp_alloc_node_mapping(self); - - if (node == NULL || m == NULL) { - ret = MSP_ERR_NO_MEMORY; - goto out; - } - m->position = left; - m->value = count; - avl_init_node(node, m); - node = avl_insert_node(&self->overlap_counts, node); - tsk_bug_assert(node != NULL); -out: - return ret; -} - -/* - * Inserts a new overlap_count at the specified locus, and copies its - * node mapping from the containing overlap_count. - */ -static int MSP_WARN_UNUSED -msp_copy_overlap_count(msp_t *self, double k) -{ - int ret; - node_mapping_t search, *nm; - avl_node_t *node; - - search.position = k; - avl_search_closest(&self->overlap_counts, &search, &node); - tsk_bug_assert(node != NULL); - nm = (node_mapping_t *) node->item; - if (nm->position > k) { - node = node->prev; - tsk_bug_assert(node != NULL); - nm = (node_mapping_t *) node->item; - } - ret = msp_insert_overlap_count(self, k, nm->value); - return ret; -} - -static int -msp_compress_overlap_counts(msp_t *self, double l, double r) -{ - int ret = 0; - avl_node_t *node1, *node2; - node_mapping_t search, *nm1, *nm2; - - search.position = l; - node1 = avl_search(&self->overlap_counts, &search); - tsk_bug_assert(node1 != NULL); - if (node1->prev != NULL) { - node1 = node1->prev; - } - node2 = node1->next; - do { - nm1 = (node_mapping_t *) node1->item; - nm2 = (node_mapping_t *) node2->item; - if (nm1->value == nm2->value) { - avl_unlink_node(&self->overlap_counts, node2); - msp_free_avl_node(self, node2); - msp_free_node_mapping(self, nm2); - node2 = node1->next; - } else { - node1 = node2; - node2 = node2->next; - } - } while (node2 != NULL && nm2->position <= r); - return ret; -} - -static int MSP_WARN_UNUSED -msp_conditional_compress_overlap_counts(msp_t *self, double l, double r) -{ - int ret = 0; - double covered_fraction = (r - l) / self->sequence_length; - - /* This is a heuristic to prevent us spending a lot of time pointlessly - * trying to defragment during the early stages of the simulation. - * 5% of the overall length seems like a good value and leads to - * a ~15% time reduction when doing large simulations. - */ - if (covered_fraction < 0.05) { - ret = msp_compress_overlap_counts(self, l, r); - if (ret != 0) { - goto out; - } - } -out: - return ret; -} - /* Defragment the segment chain ending in z by squashing any redundant * segments together */ static int MSP_WARN_UNUSED @@ -2081,7 +1961,8 @@ msp_defrag_segment_chain(msp_t *self, segment_t *z) y = z; while (y->prev != NULL) { x = y->prev; - if (x->right == y->left && x->value == y->value) { + if (x->right == y->left && x->value == y->value + && x->ancestral_to == y->ancestral_to) { x->right = y->right; x->next = y->next; if (y->next != NULL) { @@ -2273,8 +2154,8 @@ msp_dtwf_recombine(msp_t *self, segment_t *x, segment_t **u, segment_t **v) } else { tail = seg_tails[ix]; } - z = msp_alloc_segment( - self, k, x->right, x->value, x->population, x->label, tail, x->next); + z = msp_alloc_segment(self, k, x->right, x->value, x->population, x->label, + tail, x->next, x->ancestral_to); if (z == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -2474,7 +2355,7 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ if (y->left < breakpoint) { tsk_bug_assert(breakpoint < y->right); alpha = msp_alloc_segment(self, breakpoint, y->right, y->value, y->population, - y->label, NULL, y->next); + y->label, NULL, y->next, y->ancestral_to); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -2774,16 +2655,14 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l bool coalescence = false; bool defrag_required = false; tsk_id_t v; - double l, r, l_min, r_max; - avl_node_t *node; - node_mapping_t *nm, search; + double l, r, r_max; segment_t *x, *y, *z, *alpha, *beta, *merged_head; + size_t ancestral_to; x = a; y = b; merged_head = NULL; /* Keep GCC happy */ - l_min = 0; r_max = 0; /* update recomb mass and get ready for loop */ @@ -2811,7 +2690,7 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l alpha->next = NULL; } else if (x->left != y->left) { alpha = msp_alloc_segment(self, x->left, y->left, x->value, - x->population, x->label, NULL, NULL); + x->population, x->label, NULL, NULL, x->ancestral_to); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -2820,9 +2699,10 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l } else { l = x->left; r_max = GSL_MIN(x->right, y->right); + r = r_max; if (!coalescence) { coalescence = true; - l_min = l; + // l_min = l; if (new_node_id == TSK_NULL) { new_node_id = msp_store_node( self, 0, self->time, population_id, TSK_NULL); @@ -2833,50 +2713,17 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l } } v = new_node_id; - /* Insert overlap counts for bounds, if necessary */ - search.position = l; - node = avl_search(&self->overlap_counts, &search); - if (node == NULL) { - ret = msp_copy_overlap_count(self, l); - if (ret < 0) { - goto out; - } - } - search.position = r_max; - node = avl_search(&self->overlap_counts, &search); - if (node == NULL) { - ret = msp_copy_overlap_count(self, r_max); - if (ret < 0) { - goto out; - } - } - /* Now get overlap count at the left */ - search.position = l; - node = avl_search(&self->overlap_counts, &search); - tsk_bug_assert(node != NULL); - nm = (node_mapping_t *) node->item; - if (nm->value == 2) { - nm->value = 0; - node = node->next; - tsk_bug_assert(node != NULL); - nm = (node_mapping_t *) node->item; - r = nm->position; - } else { - r = l; - while (nm->value != 2 && r < r_max) { - nm->value--; - node = node->next; - tsk_bug_assert(node != NULL); - nm = (node_mapping_t *) node->item; - r = nm->position; - } + ancestral_to = x->ancestral_to + y->ancestral_to; + + if (ancestral_to != self->num_samples) { alpha = msp_alloc_segment( - self, l, r, v, population_id, label, NULL, NULL); + self, l, r, v, population_id, label, NULL, NULL, ancestral_to); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; } } + tsk_bug_assert(v != x->value); ret = msp_store_edge(self, l, r, v, x->value); if (ret != 0) { @@ -2945,12 +2792,6 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l goto out; } } - if (coalescence) { - ret = msp_conditional_compress_overlap_counts(self, l_min, r_max); - if (ret != 0) { - goto out; - } - } if (ret_merged_head != NULL) { *ret_merged_head = merged_head; } @@ -3002,13 +2843,13 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, bool coalescence = false; bool defrag_required = false; uint32_t j, h; - double l, r, r_max, next_l, l_min; + double l, r, r_max, next_l; avl_node_t *node; - node_mapping_t *nm, search; segment_t *x, *z, *alpha; segment_t **H = NULL; segment_t *merged_head = NULL; tsk_id_t individual = TSK_NULL; + size_t ancestral_to; H = malloc(avl_count(Q) * sizeof(segment_t *)); if (H == NULL) { @@ -3016,7 +2857,6 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, goto out; } r_max = 0; /* keep compiler happy */ - l_min = 0; z = NULL; merged_head = NULL; while (avl_count(Q) > 0) { @@ -3042,7 +2882,7 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, x = H[0]; if (node != NULL && next_l < x->right) { alpha = msp_alloc_segment(self, x->left, next_l, x->value, x->population, - x->label, NULL, NULL); + x->label, NULL, NULL, x->ancestral_to); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3062,7 +2902,7 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, } else { coalescence = true; if (new_node_id == TSK_NULL) { - l_min = l; + // l_min = l; new_node_id = msp_store_node(self, 0, self->time, population_id, individual); if (new_node_id < 0) { @@ -3070,51 +2910,20 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, goto out; } } - /* Insert overlap counts for bounds, if necessary */ - search.position = l; - node = avl_search(&self->overlap_counts, &search); - if (node == NULL) { - ret = msp_copy_overlap_count(self, l); - if (ret < 0) { - goto out; - } - } - search.position = r_max; - node = avl_search(&self->overlap_counts, &search); - if (node == NULL) { - ret = msp_copy_overlap_count(self, r_max); - if (ret < 0) { - goto out; - } + ancestral_to = 0; + for (j = 0; j < h; j++) { + ancestral_to += H[j]->ancestral_to; } - /* Update the extant segments and allocate alpha if the interval - * has not coalesced. */ - search.position = l; - node = avl_search(&self->overlap_counts, &search); - tsk_bug_assert(node != NULL); - nm = (node_mapping_t *) node->item; - if (nm->value == h) { - nm->value = 0; - node = node->next; - tsk_bug_assert(node != NULL); - nm = (node_mapping_t *) node->item; - r = nm->position; - } else { - r = l; - while (nm->value != h && r < r_max) { - nm->value -= h - 1; - node = node->next; - tsk_bug_assert(node != NULL); - nm = (node_mapping_t *) node->item; - r = nm->position; - } - alpha = msp_alloc_segment( - self, l, r, new_node_id, population_id, label, NULL, NULL); + r = r_max; + if (ancestral_to != self->num_samples) { + alpha = msp_alloc_segment(self, l, r, new_node_id, population_id, label, + NULL, NULL, ancestral_to); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; } } + /* Store the edges and update the priority queue */ for (j = 0; j < h; j++) { x = H[j]; @@ -3179,12 +2988,6 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, goto out; } } - if (coalescence) { - ret = msp_conditional_compress_overlap_counts(self, l_min, r_max); - if (ret != 0) { - goto out; - } - } if (ret_merged_head != NULL) { *ret_merged_head = merged_head; } @@ -3299,12 +3102,7 @@ msp_reset_memory_state(msp_t *self) msp_free_avl_node(self, node); msp_free_node_mapping(self, nm); } - for (node = self->overlap_counts.head; node != NULL; node = node->next) { - nm = (node_mapping_t *) node->item; - avl_unlink_node(&self->overlap_counts, node); - msp_free_avl_node(self, node); - msp_free_node_mapping(self, nm); - } + return ret; } @@ -3389,6 +3187,7 @@ msp_allocate_root_segments(msp_t *self, tsk_tree_t *tree, double left, double ri population_id_t population; const population_id_t *restrict node_population = self->tables->nodes.population; label_id_t label = 0; /* For now only support label 0 */ + tsk_size_t num_samples; for (root = tsk_tree_get_left_root(tree); root != TSK_NULL; root = tree->right_sib[root]) { @@ -3399,9 +3198,10 @@ msp_allocate_root_segments(msp_t *self, tsk_tree_t *tree, double left, double ri ret = MSP_ERR_POPULATION_OUT_OF_BOUNDS; goto out; } + ret = tsk_tree_get_num_samples(tree, root, &num_samples); if (root_segments_head[root] == NULL) { seg = msp_alloc_segment( - self, left, right, root, population, label, NULL, NULL); + self, left, right, root, population, label, NULL, NULL, num_samples); if (seg == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3410,11 +3210,11 @@ msp_allocate_root_segments(msp_t *self, tsk_tree_t *tree, double left, double ri root_segments_tail[root] = seg; } else { tail = root_segments_tail[root]; - if (tail->right == left) { + if (tail->right == left && tail->ancestral_to == num_samples) { tail->right = right; } else { seg = msp_alloc_segment( - self, left, right, root, population, label, tail, NULL); + self, left, right, root, population, label, tail, NULL, num_samples); if (seg == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3435,10 +3235,8 @@ msp_process_input_trees(msp_t *self) int t_iter; tsk_treeseq_t ts; tsk_tree_t tree; - uint32_t overlap_count, last_overlap_count; - tsk_size_t num_trees, num_roots; + tsk_size_t num_roots; const size_t num_nodes = self->tables->nodes.num_rows; - overlap_count_t *overlap; segment_t **root_segments_tail = NULL; /* Initialise the memory for the tree and tree sequence so we can @@ -3451,15 +3249,12 @@ msp_process_input_trees(msp_t *self) ret = msp_set_tsk_error(ret); goto out; } - num_trees = tsk_treeseq_get_num_trees(&ts); root_segments_tail = calloc(num_nodes + 1, sizeof(*root_segments_tail)); self->root_segments = calloc(num_nodes + 1, sizeof(*self->root_segments)); /* We can't have more than num_trees intervals, and allow for one sentinel */ - self->initial_overlaps = calloc(num_trees + 1, sizeof(*self->initial_overlaps)); - if (self->root_segments == NULL || root_segments_tail == NULL - || self->initial_overlaps == NULL) { + if (self->root_segments == NULL || root_segments_tail == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; } @@ -3470,32 +3265,20 @@ msp_process_input_trees(msp_t *self) goto out; } - overlap = self->initial_overlaps; - last_overlap_count = UINT32_MAX; for (t_iter = tsk_tree_first(&tree); t_iter == 1; t_iter = tsk_tree_next(&tree)) { num_roots = tsk_tree_get_num_roots(&tree); - overlap_count = 0; if (num_roots > 1) { - overlap_count = (uint32_t) num_roots; ret = msp_allocate_root_segments(self, &tree, tree.interval.left, tree.interval.right, self->root_segments, root_segments_tail); if (ret != 0) { goto out; } } - if (overlap_count != last_overlap_count) { - overlap->left = tree.interval.left; - overlap->count = overlap_count; - overlap++; - last_overlap_count = overlap_count; - } } if (t_iter != 0) { ret = msp_set_tsk_error(t_iter); goto out; } - overlap->left = self->sequence_length; - overlap->count = UINT32_MAX; out: tsk_treeseq_free(&ts); @@ -3504,25 +3287,6 @@ msp_process_input_trees(msp_t *self) return ret; } -static int -msp_reset_population_state(msp_t *self) -{ - int ret = 0; - overlap_count_t *overlap = self->initial_overlaps; - while (true) { - ret = msp_insert_overlap_count(self, overlap->left, overlap->count); - if (ret != 0) { - goto out; - } - if (overlap->left == self->sequence_length) { - break; - } - overlap++; - } -out: - return ret; -} - /* Apply all demographic events at the specified time */ static int MSP_WARN_UNUSED @@ -3666,11 +3430,6 @@ msp_reset(msp_t *self) } tsk_bug_assert(self->tables->populations.num_rows == self->num_populations); - ret = msp_reset_population_state(self); - if (ret != 0) { - goto out; - } - self->next_demographic_event = self->demographic_events_head; memcpy( self->migration_matrix, self->initial_migration_matrix, N * N * sizeof(double)); @@ -3772,7 +3531,6 @@ msp_initialise_simulation_state(msp_t *self) cmp_sampling_event); } - /* ret = msp_compress_overlap_counts(self, 0, self->sequence_length); */ out: msp_safe_free(samples); return ret; diff --git a/lib/msprime.h b/lib/msprime.h index 853f3af61..0a365ff9a 100644 --- a/lib/msprime.h +++ b/lib/msprime.h @@ -82,6 +82,7 @@ typedef struct segment_t_t { size_t id; struct segment_t_t *prev; struct segment_t_t *next; + size_t ancestral_to; } segment_t; typedef struct { @@ -207,7 +208,6 @@ typedef struct _msp_t { pedigree_t pedigree; /* Initial state for replication */ segment_t **root_segments; - overlap_count_t *initial_overlaps; simulation_model_t initial_model; double *initial_migration_matrix; population_t *initial_populations; @@ -231,6 +231,7 @@ typedef struct _msp_t { sampling_event_t *sampling_events; size_t num_sampling_events; size_t next_sampling_event; + size_t num_samples; /* Demographic events */ struct demographic_event_t_t *demographic_events_head; struct demographic_event_t_t *demographic_events_tail; @@ -242,7 +243,6 @@ typedef struct _msp_t { population_t *populations; avl_tree_t non_empty_populations; avl_tree_t breakpoints; - avl_tree_t overlap_counts; /* We keep an independent Fenwick tree for each label */ fenwick_t *recomb_mass_index; fenwick_t *gc_mass_index; diff --git a/lib/tests/test_ancestry.c b/lib/tests/test_ancestry.c index dd2e2fc56..881fef924 100644 --- a/lib/tests/test_ancestry.c +++ b/lib/tests/test_ancestry.c @@ -38,6 +38,8 @@ test_single_locus_simulation(void) memset(samples, 0, n * sizeof(sample_t)); ret = build_sim(&msp, &tables, rng, 1, 1, samples, n); CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(msp.num_samples, n); + ret = msp_initialise(&msp); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -91,6 +93,7 @@ test_single_locus_two_populations(void) CU_ASSERT_EQUAL(ret, 0); ret = msp_add_mass_migration(&msp, t2, 1, 0, 1.0); CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(msp.num_samples, n); ret = msp_initialise(&msp); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -156,6 +159,7 @@ test_single_locus_many_populations(void) CU_ASSERT_EQUAL(ret, 0); ret = msp_add_mass_migration(&msp, 30.0, 0, num_populations - 1, 1.0); CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(msp.num_samples, n); ret = msp_initialise(&msp); CU_ASSERT_EQUAL(ret, 0); @@ -3022,6 +3026,7 @@ verify_simulate_from(int model, rate_map_t *recomb_map, tsk_tree_t tree; msp_t msp; gsl_rng *rng = safe_rng_alloc(); + tsk_size_t num_samples; ret = tsk_table_collection_copy(from_tables, &tables, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -3054,7 +3059,8 @@ verify_simulate_from(int model, rate_map_t *recomb_map, CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_treeseq_init(&final, &tables, TSK_TS_INIT_BUILD_INDEXES); CU_ASSERT_EQUAL_FATAL(ret, 0); - + num_samples = tsk_treeseq_get_num_samples(&final); + CU_ASSERT_EQUAL(msp.num_samples, num_samples); ret = tsk_tree_init(&tree, &final, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); for (ret = tsk_tree_first(&tree); ret == 1; ret = tsk_tree_next(&tree)) { diff --git a/lib/tests/test_pedigrees.c b/lib/tests/test_pedigrees.c index d188c77e3..827628951 100644 --- a/lib/tests/test_pedigrees.c +++ b/lib/tests/test_pedigrees.c @@ -117,6 +117,8 @@ verify_pedigree(double recombination_rate, unsigned long seed, tsk_tree_t tree; gsl_rng *rng = safe_rng_alloc(); bool coalescence = false; + size_t num_samples = 0; + size_t j; ret = build_pedigree_sim(&msp, &tables, rng, 100, ploidy, num_individuals, parents, time, is_sample, population); @@ -127,6 +129,14 @@ verify_pedigree(double recombination_rate, unsigned long seed, CU_ASSERT_EQUAL(ret, 0); ret = msp_initialise(&msp); CU_ASSERT_EQUAL_FATAL(ret, 0); + for (j = 0; j < num_individuals; j++) { + if (is_sample == NULL) { + num_samples += (time[j] == 0) * ploidy; + } else { + num_samples += is_sample[j] * ploidy; + } + } + CU_ASSERT_EQUAL(msp.num_samples, num_samples); /* msp_print_state(&msp, stdout); */