b-tree-plus-alpha
Loading...
Searching...
No Matches
dynamic_wavelet_tree_on_grid.hpp
1#pragma once
2#include "../sequence/dynamic_bit_sequence.hpp"
3#include "../prefix_sum/dynamic_prefix_sum.hpp"
4
5namespace stool
6{
7 namespace bptree
8 {
9
15 {
18
19 std::vector<BIT_SEQUENCE> bits_seq;
20 std::vector<PREFIX_SUM> length_seq;
21
22 public:
24 {
25 public:
26 using iterator_category = std::random_access_iterator_tag; // C++17
27 using value_type = uint64_t;
28 using difference_type = std::ptrdiff_t;
29
30 const DynamicWaveletTreeOnGrid *container = nullptr;
31 uint64_t x_rank = 0;
32 uint64_t y_rank = 0;
33 YRankIterator() : container(nullptr), x_rank(0), y_rank(0) {}
34 YRankIterator(const DynamicWaveletTreeOnGrid *container, uint64_t x_rank, uint64_t y_rank) : container(container), x_rank(x_rank), y_rank(y_rank) {}
35
36 uint64_t operator*() const noexcept { return this->x_rank; }
37
38 // 前後インクリメント
39 YRankIterator &operator++() noexcept
40 {
41 ++this->y_rank;
42 if (this->y_rank < this->container->size())
43 {
44 this->x_rank = this->container->access_x_rank(this->y_rank);
45 }
46 else
47 {
48 this->x_rank = this->container->size();
49 }
50 return *this;
51 }
52 YRankIterator operator++(int) noexcept
53 {
54 YRankIterator t = *this;
55 ++*this;
56 return t;
57 }
58 YRankIterator &operator--() noexcept
59 {
60 if (this->y_rank > 0)
61 {
62 this->y_rank--;
63 this->x_rank = this->container->access_x_rank(this->y_rank);
64 }
65 else
66 {
67 this->x_rank = this->container->size();
68 }
69 return *this;
70 }
71 YRankIterator operator--(int) noexcept
72 {
73 YRankIterator t = *this;
74 --*this;
75 return t;
76 }
77
78 // 加減算
79 YRankIterator &operator+=(difference_type n) noexcept
80 {
81 this->y_rank += n;
82 if (this->y_rank < this->container->size())
83 {
84 this->x_rank = this->container->access_x_rank(this->y_rank);
85 }
86 else
87 {
88 this->x_rank = this->container->size();
89 }
90 return *this;
91 }
92 YRankIterator &operator-=(difference_type n) noexcept
93 {
94 this->y_rank -= n;
95 if (this->y_rank < this->container->size())
96 {
97 this->x_rank = this->container->access_x_rank(this->y_rank);
98 }
99 else
100 {
101 this->x_rank = this->container->size();
102 }
103 return *this;
104 }
105 friend YRankIterator operator+(YRankIterator it, difference_type n) noexcept { return YRankIterator(it.container, it.y_rank + n, it.x_rank); }
106 friend YRankIterator operator+(difference_type n, YRankIterator it) noexcept { return it + n; }
107 friend YRankIterator operator-(YRankIterator it, difference_type n) noexcept { return YRankIterator(it.container, it.y_rank - n, it.x_rank); }
108 friend difference_type operator-(YRankIterator a, YRankIterator b) noexcept { return a.y_rank - b.y_rank; }
109
110 // 比較
111 friend bool operator==(YRankIterator a, YRankIterator b) noexcept { return a.y_rank == b.y_rank; }
112 friend bool operator!=(YRankIterator a, YRankIterator b) noexcept { return !(a == b); }
113 friend bool operator<(YRankIterator a, YRankIterator b) noexcept { return a.y_rank < b.y_rank; }
114 friend bool operator>(YRankIterator a, YRankIterator b) noexcept { return b < a; }
115 friend bool operator<=(YRankIterator a, YRankIterator b) noexcept { return !(b < a); }
116 friend bool operator>=(YRankIterator a, YRankIterator b) noexcept { return !(a < b); }
117 };
118
120 {
121 public:
122 using iterator_category = std::random_access_iterator_tag; // C++17
123 using value_type = uint64_t;
124 using difference_type = std::ptrdiff_t;
125
126 const DynamicWaveletTreeOnGrid *container = nullptr;
127 uint64_t x_rank = 0;
128 uint64_t y_rank = 0;
129 XRankIterator() : container(nullptr), x_rank(0), y_rank(0) {}
130 XRankIterator(const DynamicWaveletTreeOnGrid *container, uint64_t x_rank, uint64_t y_rank) : container(container), x_rank(x_rank), y_rank(y_rank) {}
131
132 uint64_t operator*() const noexcept { return this->y_rank; }
133
134 // 前後インクリメント
135 XRankIterator &operator++() noexcept
136 {
137 ++this->x_rank;
138 if (this->x_rank < this->container->size())
139 {
140 this->y_rank = this->container->access_y_rank(this->x_rank);
141 }
142 else
143 {
144 this->y_rank = this->container->size();
145 }
146 return *this;
147 }
148 XRankIterator operator++(int) noexcept
149 {
150 XRankIterator t = *this;
151 ++*this;
152 return t;
153 }
154 XRankIterator &operator--() noexcept
155 {
156 if (this->x_rank > 0)
157 {
158 this->x_rank--;
159 this->y_rank = this->container->access_y_rank(this->x_rank);
160 }
161 else
162 {
163 this->y_rank = this->container->size();
164 }
165 return *this;
166 }
167 XRankIterator operator--(int) noexcept
168 {
169 XRankIterator t = *this;
170 --*this;
171 return t;
172 }
173
174 // 加減算
175 XRankIterator &operator+=(difference_type n) noexcept
176 {
177 this->x_rank += n;
178 if (this->x_rank < this->container->size())
179 {
180 this->y_rank = this->container->access_y_rank(this->x_rank);
181 }
182 else
183 {
184 this->y_rank = this->container->size();
185 }
186 return *this;
187 }
188 XRankIterator &operator-=(difference_type n) noexcept
189 {
190 this->x_rank -= n;
191 if (this->x_rank < this->container->size())
192 {
193 this->y_rank = this->container->access_y_rank(this->x_rank);
194 }
195 else
196 {
197 this->y_rank = this->container->size();
198 }
199 return *this;
200 }
201 friend XRankIterator operator+(XRankIterator it, difference_type n) noexcept { return XRankIterator(it.container, it.y_rank + n, it.x_rank); }
202 friend XRankIterator operator+(difference_type n, XRankIterator it) noexcept { return it + n; }
203 friend XRankIterator operator-(XRankIterator it, difference_type n) noexcept { return XRankIterator(it.container, it.y_rank - n, it.x_rank); }
204 friend difference_type operator-(XRankIterator a, XRankIterator b) noexcept { return a.y_rank - b.y_rank; }
205
206 // 比較
207 friend bool operator==(XRankIterator a, XRankIterator b) noexcept { return a.y_rank == b.y_rank; }
208 friend bool operator!=(XRankIterator a, XRankIterator b) noexcept { return !(a == b); }
209 friend bool operator<(XRankIterator a, XRankIterator b) noexcept { return a.y_rank < b.y_rank; }
210 friend bool operator>(XRankIterator a, XRankIterator b) noexcept { return b < a; }
211 friend bool operator<=(XRankIterator a, XRankIterator b) noexcept { return !(b < a); }
212 friend bool operator>=(XRankIterator a, XRankIterator b) noexcept { return !(a < b); }
213 };
214
216 {
217 this->clear();
218 }
219
220 void clear()
221 {
222 for (uint64_t i = 0; i < this->bits_seq.size(); i++)
223 {
224 this->bits_seq[i].clear();
225 this->length_seq[i].clear();
226 }
227 this->bits_seq.clear();
228 this->length_seq.clear();
229 //this->leaves.clear();
230
231 }
232
233 uint64_t get_node_x_pos_in_bit_sequence(int64_t h, uint64_t h_node_id) const{
234 if(h_node_id == 0){
235 return 0;
236 }
237 else{
238 return this->length_seq[h].psum(h_node_id-1);
239 }
240 }
241 /*
242 return the number of 0 in S[0..i];
243 */
244 uint64_t rank0_in_bit_sequence_of_node(uint64_t h, [[maybe_unused]] uint64_t h_node_id, uint64_t node_x_pos_in_bit_sequence, uint64_t i) const {
245 assert(i <= this->length_seq[h].at(h_node_id));
246 assert(node_x_pos_in_bit_sequence == this->get_node_x_pos_in_bit_sequence(h, h_node_id));
247 return this->bits_seq[h].rank0(node_x_pos_in_bit_sequence + i + 1) - this->bits_seq[h].rank0(node_x_pos_in_bit_sequence);
248
249 }
250 /*
251 return the number of 1 in S[0..i];
252 */
253 uint64_t rank1_in_bit_sequence_of_node(uint64_t h, [[maybe_unused]] uint64_t h_node_id, uint64_t node_x_pos_in_bit_sequence, uint64_t i) const {
254 assert(i <= this->length_seq[h].at(h_node_id));
255 assert(node_x_pos_in_bit_sequence == this->get_node_x_pos_in_bit_sequence(h, h_node_id));
256 return this->bits_seq[h].rank1(node_x_pos_in_bit_sequence + i + 1) - this->bits_seq[h].rank1(node_x_pos_in_bit_sequence);
257
258 }
259
260
261 void recursive_add(int64_t h, uint64_t h_node_id, uint64_t x_rank, uint64_t y_rank, std::vector<uint64_t> &output_path)
262 {
263 output_path[h] = h_node_id;
264 uint64_t node_size = this->length_seq[h].at(h_node_id);
265 uint64_t node_x_pos_in_bit_sequence = this->get_node_x_pos_in_bit_sequence(h, h_node_id);
266
267 if (h + 1 < this->height())
268 {
269 uint64_t left_node_id = 2 * h_node_id;
270 uint64_t right_node_id = 2 * h_node_id + 1;
271 uint64_t left_tree_size = this->length_seq[h + 1].at(left_node_id);
272
273 if (x_rank <= left_tree_size)
274 {
275 uint64_t new_y_rank = y_rank > 0 ? this->rank0_in_bit_sequence_of_node(h, h_node_id, node_x_pos_in_bit_sequence, y_rank-1) : 0;
276 this->recursive_add(h + 1, left_node_id, x_rank, new_y_rank, output_path);
277 this->bits_seq[h].insert(node_x_pos_in_bit_sequence + y_rank, false);
278 this->length_seq[h].increment(h_node_id, 1);
279 }
280 else
281 {
282 uint64_t new_y_rank = y_rank > 0 ? this->rank1_in_bit_sequence_of_node(h, h_node_id, node_x_pos_in_bit_sequence, y_rank - 1) : 0;
283 uint64_t new_x_rank = x_rank - left_tree_size;
284
285 this->recursive_add(h + 1, right_node_id, new_x_rank, new_y_rank, output_path);
286 this->bits_seq[h].insert(node_x_pos_in_bit_sequence + y_rank, true);
287 this->length_seq[h].increment(h_node_id, 1);
288 }
289
290 //uint64_t upper_size = this->get_upper_size_of_internal_node(h);
291 /*
292 if (this->is_unbalanced_node(h, h_node_id))
293 {
294 if(h + 5 < this->height()){
295 std::cout << "Rebuild internal node: h = " << h << ", h_node_id = " << h_node_id << ", H = " << this->height() << "/len = " << this->length_seq[h].at(h_node_id) << "/ s: " << this->get_upper_size_of_internal_node(h) << std::endl;
296 }
297 std::cout << "Rebuild internal node: h = " << h << ", h_node_id = " << h_node_id << ", H = " << this->height() << "/len = " << this->length_seq[h].at(h_node_id) << "/ s: " << this->get_upper_size_of_internal_node(h) << std::endl;
298 this->print_tree();
299
300 this->rebuild_internal_node(h, h_node_id);
301 }
302 */
303
304 }
305 else
306 {
307 assert(this->get_bit_count_in_node(h, h_node_id) <= 1);
308 if(node_size == 0){
309 this->bits_seq[h].insert(node_x_pos_in_bit_sequence + y_rank, false);
310 this->length_seq[h].increment(h_node_id, 1);
311 }else if(node_size == 1){
312 assert(x_rank <= 1);
313 if(x_rank == 0){
314 this->bits_seq[h].set_bit(node_x_pos_in_bit_sequence, true);
315 this->bits_seq[h].insert(node_x_pos_in_bit_sequence + y_rank, false);
316
317 }else{
318
319 this->bits_seq[h].insert(node_x_pos_in_bit_sequence + y_rank, true);
320 }
321 this->length_seq[h].increment(h_node_id, 1);
322 }else{
323 throw std::runtime_error("node_size > 1");
324 }
325 }
326
327
328 }
329
330
331
332 static uint64_t _get_upper_size_of_root(uint64_t H)
333 {
334 return _get_upper_size_of_internal_node(0, H);
335 }
336
337
static uint64_t _get_upper_size_of_internal_node(uint64_t h, uint64_t H)
{
    // Capacity bound for a node of level h in a tree of height H: 2^(H - h - 1),
    // halved once that value exceeds 4.
    uint64_t capacity = 1;
    for (uint64_t level = h + 1; level < H; ++level)
    {
        capacity <<= 1;
    }
    return (capacity > 4) ? (capacity >> 1) : capacity;
}
353
354 uint64_t get_upper_size_of_internal_node(uint64_t h) const
355 {
356 return _get_upper_size_of_internal_node(h, this->height());
357 }
358 uint64_t get_lower_size_of_internal_node(uint64_t h) const
359 {
360 uint64_t fsize = _get_upper_size_of_internal_node(h, this->height());
361 return (fsize / 4);
362 }
363
364
365 void build_h_bit_sequence(uint64_t h, const std::vector<uint64_t> &rank_elements, std::vector<uint64_t> &output_next_rank_elements, std::vector<uint64_t> &output_next_length_seq)
366 {
367
368 uint64_t h_node_count = 1ULL << h;
369 uint64_t counter = 0;
370 uint64_t node_x_pos = 0;
371 std::vector<bool> tmp_bit_sequence(rank_elements.size(), false);
372
373 if((int64_t)(h + 1) < this->height()){
374 output_next_rank_elements.resize(rank_elements.size(), UINT64_MAX);
375 output_next_length_seq.resize(h_node_count * 2, UINT64_MAX);
376 }
377
378
379
380 for(uint64_t i = 0; i < h_node_count; i++){
381 uint64_t bit_size = this->get_bit_count_in_node(h, i);
382 uint64_t half_size = bit_size / 2;
383
384
385 // Processing left elements
386 if((int64_t)(h + 1) < this->height())
387 {
388 uint64_t left_counter = 0;
389 for (uint64_t j = 0; j < bit_size; j++)
390 {
391 if (rank_elements[node_x_pos + j] < half_size)
392 {
393
394 output_next_rank_elements[counter++] = rank_elements[node_x_pos + j];
395 left_counter++;
396 }
397 }
398 output_next_length_seq[i * 2] = left_counter;
399 }
400
401 // Processing right elements
402 {
403 uint64_t right_counter = 0;
404
405 if((int64_t)(h + 1) < this->height()){
406 for (uint64_t j = 0; j < bit_size; j++)
407 {
408 if (rank_elements[node_x_pos + j] >= half_size)
409 {
410
411 tmp_bit_sequence[node_x_pos + j] = true;
412 output_next_rank_elements[counter++] = rank_elements[node_x_pos + j] - half_size;
413 right_counter++;
414 }
415 }
416 output_next_length_seq[(i * 2) + 1] = right_counter;
417 }else{
418 if(bit_size > 1){
419 throw std::runtime_error("Error in build_h_bit_sequence, bit_size > 1");
420 }
421 }
422
423 }
424
425 node_x_pos += bit_size;
426 }
427 this->bits_seq[h].clear();
428 this->bits_seq[h].push_many(tmp_bit_sequence);
429 }
430
431 void rebuild_h_bit_sequence(uint64_t h, uint64_t first_node_id, uint64_t local_h_node_count, const std::vector<uint64_t> &rank_elements, std::vector<uint64_t> &output_next_rank_elements, std::vector<uint64_t> &output_next_length_seq)
432 {
433 assert(first_node_id + local_h_node_count - 1 < this->length_seq[h].size());
434
435
436 uint64_t counter = 0;
437 uint64_t node_x_pos = this->get_node_x_pos_in_bit_sequence(h, first_node_id);
438 const uint64_t first_node_x_pos = node_x_pos;
439 std::vector<bool> tmp_bit_sequence(rank_elements.size(), false);
440
441 if((int64_t)(h + 1) < this->height()){
442 output_next_rank_elements.resize(rank_elements.size(), UINT64_MAX);
443 output_next_length_seq.resize(local_h_node_count * 2, UINT64_MAX);
444 }
445
446
447
448 for(uint64_t node_id = first_node_id; node_id <= first_node_id + local_h_node_count - 1; node_id++){
449 uint64_t bit_size = this->get_bit_count_in_node(h, node_id);
450 uint64_t half_size = bit_size / 2;
451
452
453 // Processing left elements
454 if((int64_t)(h + 1) < this->height())
455 {
456 uint64_t left_counter = 0;
457 for (uint64_t j = 0; j < bit_size; j++)
458 {
459 if (rank_elements[(node_x_pos - first_node_x_pos) + j] < half_size)
460 {
461 output_next_rank_elements[counter++] = rank_elements[(node_x_pos - first_node_x_pos) + j];
462 left_counter++;
463 }
464 }
465 output_next_length_seq[(node_id - first_node_id) * 2] = left_counter;
466 }
467
468 // Processing right elements
469 {
470 uint64_t right_counter = 0;
471
472 if((int64_t)(h + 1) < this->height()){
473 for (uint64_t j = 0; j < bit_size; j++)
474 {
475 if (rank_elements[(node_x_pos - first_node_x_pos) + j] >= half_size)
476 {
477
478 tmp_bit_sequence[(node_x_pos - first_node_x_pos) + j] = true;
479 output_next_rank_elements[counter++] = rank_elements[(node_x_pos - first_node_x_pos) + j] - half_size;
480 right_counter++;
481 }
482 }
483 output_next_length_seq[(node_id - first_node_id) * 2 + 1] = right_counter;
484 }else{
485 if(bit_size > 1){
486 throw std::runtime_error("Error in rebuild_h_bit_sequence, bit_size > 1");
487 }
488 }
489
490 }
491
492 node_x_pos += bit_size;
493 }
494 this->bits_seq[h].set_bits(first_node_x_pos, tmp_bit_sequence);
495 }
496
497
498 void build(const std::vector<uint64_t> &rank_elements, int message_paragraph = stool::Message::NO_MESSAGE)
499 {
500
501
502 this->clear();
503
504 uint64_t height = 0;
505 while (true)
506 {
507 uint64_t fsize = _get_upper_size_of_root(height);
508 if (rank_elements.size() < fsize)
509 {
510 break;
511 }
512 else
513 {
514 height++;
515 }
516 }
517
518 if(message_paragraph != stool::Message::NO_MESSAGE){
519 std::cout << stool::Message::get_paragraph_string(message_paragraph) << "Building wavelet tree for range search... " << "(input size = " << rank_elements.size() << ", tree height = " << height << ")" << std::endl;
520 }
521 std::chrono::system_clock::time_point st1, st2;
522 st1 = std::chrono::system_clock::now();
523
524
525 this->bits_seq.resize(height);
526 this->length_seq.resize(height);
527 for(uint64_t h = 0; h < height; h++){
528 this->bits_seq[h].clear();
529 this->length_seq[h].clear();
530 }
531
532 if(height > 0){
533 this->length_seq[0].push_back(rank_elements.size());
534 std::vector<uint64_t> tmp_rank_elements = rank_elements;
535
536 for(uint64_t h = 0; h < height; h++){
537
538 if(message_paragraph != stool::Message::NO_MESSAGE){
539 std::cout << stool::Message::get_paragraph_string(message_paragraph+1) << "Building the " << h << "-th bit sequence in the wavelet tree... " << std::endl;
540 }
541 std::vector<uint64_t> next_rank_elements;
542 std::vector<uint64_t> next_length_seq;
543
544 this->build_h_bit_sequence(h, tmp_rank_elements, next_rank_elements, next_length_seq);
545
546
547 tmp_rank_elements.swap(next_rank_elements);
548 if(h + 1 < height){
549 this->length_seq[h+1].push_many(next_length_seq);
550 }
551 }
552 }
553
554 assert(this->verify());
555
556 st2 = std::chrono::system_clock::now();
557 uint64_t sec_time = std::chrono::duration_cast<std::chrono::seconds>(st2 - st1).count();
558
559 if(message_paragraph != stool::Message::NO_MESSAGE){
560 std::cout << stool::Message::get_paragraph_string(message_paragraph) << "[DONE] Elapsed Time: " << sec_time << " sec" << std::endl;
561 }
562 }
563
564 uint64_t get_bit_count_in_node(uint64_t h, uint64_t h_node_id) const {
565 assert(h < this->length_seq.size());
566 assert(h_node_id < this->length_seq[h].size());
567 return this->length_seq[h].at(h_node_id);
568 }
569
570
571 bool is_unbalanced_node(uint8_t h, uint64_t h_node_id) const
572 {
573 if (h + 1 < this->height())
574 {
575 uint64_t left_node_id = 2 * h_node_id;
576 uint64_t right_node_id = 2 * h_node_id + 1;
577 uint64_t left_tree_size = this->get_bit_count_in_node(h + 1, left_node_id);
578 uint64_t right_tree_size = this->get_bit_count_in_node(h + 1, right_node_id);
579 bool unbalance_flag_LR = left_tree_size > (right_tree_size * 2) || right_tree_size > (left_tree_size * 2);
580 uint64_t child_upper_size = this->get_upper_size_of_internal_node(h + 1);
581 bool full_flag_L = left_tree_size >= child_upper_size;
582 bool full_flag_R = right_tree_size >= child_upper_size;
583 return unbalance_flag_LR || full_flag_L || full_flag_R;
584 }
585 else
586 {
587 return this->length_seq[h].at(h_node_id) >= 2;
588
589 }
590 }
591
592
593
594
595
596
597 void swap(DynamicWaveletTreeOnGrid &item)
598 {
599 this->length_seq.swap(item.length_seq);
600 this->bits_seq.swap(item.bits_seq);
601 }
602
603 int64_t height() const
604 {
605 return this->bits_seq.size();
606 }
607 uint64_t size() const
608 {
609 if (this->height() > 0)
610 {
611 return this->bits_seq[0].size();
612 }
613 else
614 {
615 return 0;
616 }
617 }
618
619 uint64_t access_x_rank(uint64_t y_rank) const
620 {
621 assert(y_rank < this->size());
622 uint64_t x_rank = this->compute_local_x_rank(0, 0, y_rank);
623 return x_rank;
624 }
625 uint64_t find_leaf_index(uint64_t x_rank) const
626 {
627
628 if (x_rank >= this->size())
629 {
630 throw std::runtime_error("ERROR in find_leaf_index: x_rank is out of range");
631 }
632
633 uint64_t current_x_rank = x_rank;
634 uint64_t current_node_id = 0;
635 uint64_t height = this->height();
636
637 for (uint64_t h = 0; h + 1 < height; h++)
638 {
639 uint64_t left_tree_size = this->get_bit_count_in_node(h+1, 2 * current_node_id);
640
641 if (current_x_rank < left_tree_size)
642 {
643 current_node_id = 2 * current_node_id;
644 }
645 else
646 {
647 current_x_rank = current_x_rank - left_tree_size;
648 current_node_id = 2 * current_node_id + 1;
649 }
650 }
651 return current_node_id;
652
653 }
654 uint64_t access_y_rank(uint64_t x_rank) const
655 {
656 uint64_t leaf_index = this->find_leaf_index(x_rank);
657 uint64_t current_y_rank = 0;
658 uint64_t prev_node_id = leaf_index;
659 int64_t height = this->height();
660 for (int64_t h = height - 2; h >= 0; h--)
661 {
662 uint64_t next_node_id = prev_node_id / 2;
663 uint64_t next_x_pos = this->get_node_x_pos_in_bit_sequence(h, next_node_id);
664
665 if (prev_node_id % 2 == 0)
666 {
667 uint64_t count_zero_offset = this->bits_seq[h].rank0(next_x_pos);
668 uint64_t next_y_rank = this->bits_seq[h].select0(current_y_rank + count_zero_offset) - next_x_pos;
669 current_y_rank = next_y_rank;
670 prev_node_id = next_node_id;
671
672 }
673 else
674 {
675
676 uint64_t count_one_offset = this->bits_seq[h].rank1(next_x_pos);
677 int64_t select_result = this->bits_seq[h].select1(current_y_rank + count_one_offset);
678 assert(select_result >= 0);
679 uint64_t next_y_rank = select_result - next_x_pos;
680
681
682
683 current_y_rank = next_y_rank;
684 prev_node_id = next_node_id;
685
686
687 }
688 }
689 return current_y_rank;
690 }
691 std::vector<bool> get_bit_sequence(uint64_t h, uint64_t node_id) const{
692 uint64_t x_pos = this->get_node_x_pos_in_bit_sequence(h, node_id);
693 uint64_t node_size = this->get_bit_count_in_node(h, node_id);
694 std::vector<bool> r;
695 r.resize(node_size, false);
696 for(uint64_t i = 0; i < node_size; i++){
697 r[i] = this->bits_seq[h].at(x_pos + i);
698 }
699 return r;
700 }
701
702 bool verify() const
703 {
704
705 for (uint64_t h = 0; h < this->bits_seq.size(); h++)
706 {
707 uint64_t node_count = 1 << h;
708
709 if(h + 1 < this->bits_seq.size()){
710 for (uint64_t i = 0; i < node_count; i++)
711 {
712 std::vector<bool> bit_sequence = this->get_bit_sequence(h, i);
713 uint64_t countL = 0;
714 uint64_t countR = 0;
715 for(uint64_t j = 0; j < bit_sequence.size(); j++){
716 if(bit_sequence[j]){
717 countR++;
718 }else{
719 countL++;
720 }
721 }
722
723 uint64_t left_tree_size = this->get_bit_count_in_node(h+1, 2 * i);
724 uint64_t right_tree_size = this->get_bit_count_in_node(h+1, 2 * i + 1);
725
726 if(countL != left_tree_size){
727 this->print_tree();
728 throw std::runtime_error("Error: verify, countL != left_tree_size");
729 }
730
731 if(countR != right_tree_size){
732 this->print_tree();
733 throw std::runtime_error("Error: verify, countR != right_tree_size");
734 }
735 }
736
737 }else{
738 for (uint64_t i = 0; i < node_count; i++)
739 {
740 uint64_t bit_size = this->get_bit_count_in_node(h, i);
741 if(bit_size > 1){
742 this->print_tree();
743 throw std::runtime_error("Error: verify function, bit_size > 1");
744 }
745 }
746 }
747
748 }
749 return true;
750
751 }
752
753 std::vector<uint64_t> to_local_rank_elements_in_y_order(uint64_t h, uint64_t node_id) const
754 {
755 uint64_t height = this->height();
756 uint64_t node_size = this->get_bit_count_in_node(h, node_id);
757 std::vector<uint64_t> r;
758 r.resize(node_size, UINT64_MAX);
759 uint64_t x_pos = this->get_node_x_pos_in_bit_sequence(h, node_id);
760
761 if (h + 1 < height)
762 {
763 uint64_t counterL = 0;
764 uint64_t counterR = 0;
765 uint64_t leaf_id_L = 2 * node_id;
766 uint64_t leaf_id_R = 2 * node_id + 1;
767 uint64_t left_tree_size = this->get_bit_count_in_node(h + 1, leaf_id_L);
768
769 std::vector<uint64_t> left_elements = this->to_local_rank_elements_in_y_order(h + 1, leaf_id_L);
770 std::vector<uint64_t> right_elements = this->to_local_rank_elements_in_y_order(h + 1, leaf_id_R);
771
772 assert(left_elements.size() + right_elements.size() == node_size);
773
774
775
776
777 for (uint64_t i = 0; i < node_size; i++)
778 {
779 bool b = this->bits_seq[h].at(x_pos + i);
780 if (b)
781 {
782 assert(counterR < right_elements.size());
783
784 r[i] = right_elements[counterR++] + left_tree_size;
785 }
786 else
787 {
788 assert(counterL < left_elements.size());
789 r[i] = left_elements[counterL++];
790 }
791 }
792 }
793 else
794 {
795 if(node_size == 1){
796 r[0] = 0;
797 }else if(node_size > 1){
798 for (uint64_t i = 0; i < node_size; i++)
799 {
800 r[i] = this->bits_seq[h].at(x_pos + i);
801 }
802 }
803 }
804
805 return r;
806
807 }
808
809 std::vector<uint64_t> to_rank_elements_in_y_order() const
810 {
811
812 if (this->height() > 0)
813 {
814 return this->to_local_rank_elements_in_y_order(0, 0);
815 }
816 else
817 {
818 std::vector<uint64_t> r;
819 return r;
820 }
821 }
822 std::vector<uint64_t> to_rank_elements_in_x_order() const
823 {
824 std::vector<uint64_t> r;
825 r.resize(this->size(), UINT64_MAX);
826 uint64_t size = this->size();
827 for(uint64_t i = 0; i < size; i++){
828 r[i] = this->access_y_rank(i);
829 }
830
831 return r;
832 }
833
834
835 uint64_t compute_local_x_rank(uint64_t node_y, uint64_t node_id, uint64_t local_y_rank) const
836 {
837 assert(local_y_rank < this->length_seq[node_y].at(node_id));
838
839 uint64_t x_rank = 0;
840 uint64_t h_node_id = node_id;
841 int64_t height = this->height();
842 for (int64_t h = node_y; h + 1 < height; h++)
843 {
844 uint64_t node_x_pos = this->get_node_x_pos_in_bit_sequence(h, h_node_id);
845
846 assert(node_x_pos + local_y_rank < this->bits_seq[h].size());
847
848 bool b = this->bits_seq[h].at(node_x_pos + local_y_rank);
849 uint64_t next_node_id = (2 * h_node_id) + (uint64_t)b;
850 if (b)
851 {
852 uint64_t left_tree_size = this->get_bit_count_in_node(h+1, 2 * h_node_id);
853 x_rank += left_tree_size;
854 local_y_rank -= this->rank0_in_bit_sequence_of_node(h, h_node_id, node_x_pos, local_y_rank);
855 }
856 else
857 {
858 local_y_rank -= this->rank1_in_bit_sequence_of_node(h, h_node_id, node_x_pos, local_y_rank);
859 }
860 h_node_id = next_node_id;
861 }
862 //x_rank += this->leaves[h_node_id][h_y_rank];
863 return x_rank;
864 }
865
866 template <typename APPENDABLE_VECTOR>
867 uint64_t local_range_report_on_internal_node(uint64_t h, uint64_t node_id, uint64_t x_rank_gap, uint64_t hy_min, uint64_t hy_max, APPENDABLE_VECTOR &out) const
868 {
869
870 for (uint64_t i = hy_min; i <= hy_max; i++)
871 {
872 uint64_t x = this->compute_local_x_rank(h, node_id, i) + x_rank_gap;
873 out.push_back(x);
874 }
875 return hy_max - hy_min + 1;
876 }
877
878 template <typename APPENDABLE_VECTOR>
879 uint64_t recursive_range_report_on_internal_nodes(uint64_t h, uint64_t node_id, uint64_t node_x_pos, int64_t x_min, int64_t x_max, uint64_t hy_min, uint64_t hy_max, APPENDABLE_VECTOR &out) const
880 {
881
882 uint64_t found_elements_count = 0;
883 int64_t node_size = this->get_bit_count_in_node(h, node_id);
884 if (x_min <= (int64_t)node_x_pos && (int64_t)(node_x_pos + node_size - 1) <= x_max)
885 {
886 uint64_t limitR = std::min((int64_t)hy_max, node_size-1);
887
888 if(hy_min <= limitR){
889 uint64_t _tmp = local_range_report_on_internal_node(h, node_id, node_x_pos, hy_min, limitR, out);
890 found_elements_count += _tmp;
891 }
892
893 }
894 else if((int64_t)(h+1) < this->height())
895 {
896 uint64_t node_x_pos_L = node_x_pos;
897 uint64_t node_x_pos_R = node_x_pos + this->get_bit_count_in_node(h+1, 2 * node_id);
898
899
900
901 /*
902 int64_t hy_max_0 = ((int64_t)this->bits_seq[h][node_id].rank0(hy_max + 1)) - 1;
903 int64_t hy_max_1 = ((int64_t)this->bits_seq[h][node_id].rank1(hy_max + 1)) - 1;
904 */
905 int64_t hy_max_0 = rank0_in_bit_sequence_of_node(h, node_id, node_x_pos_L, hy_max) - 1;
906 int64_t hy_max_1 = rank1_in_bit_sequence_of_node(h, node_id, node_x_pos_L, hy_max) - 1;
907
908
909 int64_t hy_min_0 = hy_min > 0 ? rank0_in_bit_sequence_of_node(h, node_id, node_x_pos_L, hy_min - 1) : 0;
910 int64_t hy_min_1 = hy_min > 0 ? rank1_in_bit_sequence_of_node(h, node_id, node_x_pos_L, hy_min - 1) : 0;
911
912
913 uint64_t next_node_id_L = 2 * node_id;
914 uint64_t next_node_id_R = next_node_id_L + 1;
915
916 if (x_min < (int64_t)node_x_pos_R && hy_min_0 <= hy_max_0)
917 {
918 found_elements_count += this->recursive_range_report_on_internal_nodes(h + 1, next_node_id_L, node_x_pos_L, x_min, x_max, hy_min_0, hy_max_0, out);
919 }
920
921 if (x_max >= (int64_t)node_x_pos_R && hy_min_1 <= hy_max_1)
922 {
923 found_elements_count += this->recursive_range_report_on_internal_nodes(h + 1, next_node_id_R, node_x_pos_R, x_min, x_max, hy_min_1, hy_max_1, out);
924 }
925
926 }
927 return found_elements_count;
928
929 }
930
931 template <typename APPENDABLE_VECTOR>
932 uint64_t range_report(uint64_t x_min, uint64_t x_max, uint64_t y_min, uint64_t y_max, APPENDABLE_VECTOR &out) const
933 {
934 uint64_t found_elements_count = 0;
935 if (this->height() > 0)
936 {
937 found_elements_count = this->recursive_range_report_on_internal_nodes(0, 0, 0, x_min, x_max, y_min, y_max, out);
938 }
939 return found_elements_count;
940 }
941
942 XRankIterator x_rank_begin() const
943 {
944 if (this->size() == 0)
945 {
946 return XRankIterator(this, this->size(), this->size());
947 }
948 else
949 {
950 return XRankIterator(this, 0, this->access_y_rank(0));
951 }
952 }
953 XRankIterator x_rank_end() const
954 {
955 return XRankIterator(this, this->size(), this->size());
956 }
957 YRankIterator y_rank_begin() const
958 {
959 if (this->size() == 0)
960 {
961 return YRankIterator(this, this->size(), this->size());
962 }
963 else
964 {
965 return YRankIterator(this, this->access_x_rank(0), 0);
966 }
967 }
968 YRankIterator y_rank_end() const
969 {
970 return YRankIterator(this, this->size(), this->size());
971 }
972
973 void print_tree() const
974 {
975 std::vector<std::string> r;
976 for (uint64_t h = 0; h < this->bits_seq.size(); h++)
977 {
978 std::vector<uint64_t> tmp_len_veq;
979 uint64_t counter = 0;
980 tmp_len_veq.push_back(0);
981 for(uint64_t i = 0; i < this->length_seq[h].size(); i++){
982 counter += this->length_seq[h].at(i);
983 tmp_len_veq.push_back(counter);
984 }
985
986 std::string s = "";
987 uint64_t tmp_p = 0;
988 for(uint64_t i = 0; i < this->bits_seq[h].size(); i++){
989 while(tmp_len_veq[tmp_p] <= i){
990 s.append("|");
991 tmp_p++;
992 }
993
994 bool b = this->bits_seq[h].at(i);
995
996 s.append(b ? "1" : "0");
997 }
998 s.append("|");
999
1000 r.push_back(s);
1001 }
1002
1003
1004 std::cout << "===== TREE =====" << std::endl;
1005
1006 for (uint64_t i = 0; i < r.size(); i++)
1007 {
1008 std::cout << r[i] << std::endl;
1009 }
1010 std::cout << "===== [END] =====" << std::endl;
1011 }
1012
1013 static void store_to_file(DynamicWaveletTreeOnGrid &item, std::ofstream &os)
1014 {
1015 throw std::runtime_error("Error: DynamicWaveletTreeOnGrid::store_to_file is not implemented");
1016 }
1017 static void store_to_bytes(DynamicWaveletTreeOnGrid &item, std::vector<uint8_t> &output, uint64_t &pos)
1018 {
1019 throw std::runtime_error("Error: DynamicWaveletTreeOnGrid::store_to_bytes is not implemented");
1020 }
1021 uint64_t size_in_bytes(bool only_extra_bytes = false) const
1022 {
1023 uint64_t sum = 0;
1024 sum += sizeof(uint64_t);
1025 for(int64_t h = 0; h < (int64_t)this->height(); h++){
1026 sum += this->bits_seq[h].size_in_bytes(only_extra_bytes);
1027 sum += this->length_seq[h].size_in_bytes(only_extra_bytes);
1028 }
1029 return sum;
1030
1031 }
1032
1033
1034 static DynamicWaveletTreeOnGrid load_from_file(std::ifstream &ifs)
1035 {
1036 throw std::runtime_error("Error: DynamicWaveletTreeOnGrid::load_from_file is not implemented");
1037
1038 }
1039 static DynamicWaveletTreeOnGrid load_from_bytes(const std::vector<uint8_t> &data, uint64_t &pos)
1040 {
1041 throw std::runtime_error("Error: DynamicWaveletTreeOnGrid::load_from_bytes is not implemented");
1042 }
1043 void rebuild_internal_node(uint8_t h, uint64_t h_node_id)
1044 {
1045
1046 std::vector<uint64_t> rank_elements = this->to_local_rank_elements_in_y_order(h, h_node_id);
1047
1048
1049 uint64_t height = this->height();
1050 uint64_t current_node_id = h_node_id;
1051 uint64_t current_node_count = 1;
1052 for(uint64_t q = h; q < height; q++){
1053 std::vector<uint64_t> next_rank_elements;
1054 std::vector<uint64_t> next_length_seq;
1055
1056 this->rebuild_h_bit_sequence(q, current_node_id, current_node_count, rank_elements, next_rank_elements, next_length_seq);
1057
1058 rank_elements.swap(next_rank_elements);
1059
1060 current_node_count *= 2;
1061 current_node_id *= 2;
1062
1063 if(q+1 < height){
1064 this->length_seq[q+1].set_values(current_node_id, next_length_seq);
1065 }
1066
1067 }
1068 }
1069 void add(uint64_t x_rank, uint64_t y_rank)
1070 {
1071
1072
1073 if(this->size() > 0){
1074 //std::cout << "Add: x_rank = " << x_rank << ", y_rank = " << y_rank << std::endl;
1075
1076 std::vector<uint64_t> output_path(this->height(), UINT64_MAX);
1077 this->recursive_add(0, 0, x_rank, y_rank, output_path);
1078 uint64_t upper_size = this->get_upper_size_of_internal_node(0);
1079 if (this->size() >= upper_size)
1080 {
1081 /*
1082 std::cout << "Rebuilding range reporting data structure ...: ";
1083 std::chrono::system_clock::time_point st1, st2;
1084 st1 = std::chrono::system_clock::now();
1085 */
1086
1087 std::vector<uint64_t> rank_elements = this->to_rank_elements_in_y_order();
1088 this->build(rank_elements);
1089
1090 /*
1091
1092 st2 = std::chrono::system_clock::now();
1093 uint64_t sec_time = std::chrono::duration_cast<std::chrono::seconds>(st2 - st1).count();
1094 std::cout << "[DONE] Elapsed Time: " << sec_time << " sec, the number of elements: " << rank_elements.size() << std::endl;
1095 */
1096 }else{
1097 uint64_t height = this->height();
1098 for(uint64_t h = 0; h < height; h++){
1099 uint64_t h_node_id = output_path[h];
1100 if (this->is_unbalanced_node(h, h_node_id))
1101 {
1102 /*
1103 if(h + 5 < this->height()){
1104 std::cout << "Rebuild internal node: h = " << h << ", h_node_id = " << h_node_id << ", H = " << this->height() << "/len = " << this->length_seq[h].at(h_node_id) << "/ s: " << this->get_upper_size_of_internal_node(h) << std::endl;
1105 }
1106 */
1107
1108 this->rebuild_internal_node(h, h_node_id);
1109 break;
1110 }
1111
1112 }
1113 }
1114 assert(this->verify());
1115
1116 }else{
1117 this->clear();
1118 std::vector<uint64_t> rank_elements;
1119 rank_elements.push_back(0);
1120 this->build(rank_elements);
1121
1122 }
1123
1124
1125
1126 }
1127
1128 void remove(uint64_t y_rank)
1129 {
1130 int64_t height = this->height();
1131 if(height == 0){
1132 throw std::runtime_error("Error: DynamicWaveletTreeOnGrid::remove(y_rank)");
1133 }else{
1134 uint64_t h_y_rank = y_rank;
1135 uint64_t h_node_id = 0;
1136
1137 for (int64_t h = 0; h < height; h++)
1138 {
1139 uint64_t node_x_pos = this->get_node_x_pos_in_bit_sequence(h, h_node_id);
1140 bool b = this->bits_seq[h].at(node_x_pos + h_y_rank);
1141 uint64_t next_node_id = (2 * h_node_id) + (uint64_t)b;
1142 if (b)
1143 {
1144 uint64_t rmv_y_rank = this->rank0_in_bit_sequence_of_node(h, h_node_id, node_x_pos, h_y_rank);
1145 this->bits_seq[h].remove(node_x_pos + h_y_rank);
1146 this->length_seq[h].decrement(h_node_id, 1);
1147 h_y_rank -= rmv_y_rank;
1148 }
1149 else
1150 {
1151 uint64_t rmv_y_rank = this->rank1_in_bit_sequence_of_node(h, h_node_id, node_x_pos, h_y_rank);
1152 this->bits_seq[h].remove(node_x_pos + h_y_rank);
1153 this->length_seq[h].decrement(h_node_id, 1);
1154 h_y_rank -= rmv_y_rank;
1155 }
1156 h_node_id = next_node_id;
1157 }
1158
1159 uint64_t upper_size = this->get_upper_size_of_internal_node(0);
1160 uint64_t size = this->size();
1161 if (size < upper_size / 2)
1162 {
1163 auto rank_elements = this->to_rank_elements_in_y_order();
1164 this->build(rank_elements);
1165 }
1166
1167 assert(this->verify());
1168 }
1169
1170
1171 }
1172
1173 std::vector<std::string> get_memory_usage_info(int message_paragraph = stool::Message::SHOW_MESSAGE) const
1174 {
1175
1176 std::vector<std::string> r;
1177 uint64_t size_in_bytes = this->size_in_bytes();
1178 uint64_t element_count = this->size();
1179
1180 double bits_per_element = element_count > 0 ? ((double)size_in_bytes / (double)element_count) : 0;
1181
1182 r.push_back(stool::Message::get_paragraph_string(message_paragraph) + "=DynamicWaveletTreeOnGrid: " + std::to_string(this->size_in_bytes())
1183 + " bytes, " + std::to_string(element_count) + " elements, " + std::to_string(bits_per_element) + " bytes per element =");
1184
1185 for(uint64_t h = 0; h < this->bits_seq.size(); h++){
1186 uint64_t _sub_size = 0;
1187 _sub_size += this->bits_seq[h].size_in_bytes();
1188 _sub_size += this->length_seq[h].size_in_bytes();
1189
1190 uint64_t _bits_per_element = element_count > 0 ? ((double)_sub_size / (double)element_count) : 0;
1191 r.push_back(stool::Message::get_paragraph_string(message_paragraph+1) + "Level " + std::to_string(h) + " in range tree: " + std::to_string(_sub_size) + " bytes" + " (" + std::to_string(_bits_per_element) + " bytes per element)");
1192 }
1193 r.push_back(stool::Message::get_paragraph_string(message_paragraph) + "==");
1194
1195 return r;
1196 }
1197
1198 };
1199 }
1200}
An implementation of B+-tree [Unchecked AI's Comment].
Definition bp_tree.hpp:24
void push_back(VALUE value)
Pushes a single value to the B+ tree.
Definition bp_tree.hpp:1851
uint64_t size() const
Return the number of values stored in this tree.
Definition bp_tree.hpp:316
Definition dynamic_wavelet_tree_on_grid.hpp:120
Definition dynamic_wavelet_tree_on_grid.hpp:24
DynamicWaveletTreeOnGrid. [Unchecked AI's Comment].
Definition dynamic_wavelet_tree_on_grid.hpp:15