14-1 Order Statistics Trees Implementation
2020-11-22
x
template<typename ost_t_key, typename ost_t_value>
class order_statistic_tree
{
// Order-statistic tree (CLRS 14.1) built on a red-black tree.
// Every node caches the number of real nodes in its subtree (`size`). That lets
// os_select find the i-th smallest key, and os_rank compute a node's position,
// by following a single root-to-leaf or node-to-root path.
// One heap-allocated sentinel (m_nil) stands in for every leaf and for the
// root's parent; it is created black with size 0.
// Keys are ordered with operator<; equal keys are allowed and are placed to the
// right of existing ones. find/erase additionally use operator!=.
// A moved-from tree owns no sentinel (both pointers are null); it may only be
// destroyed or assigned to.
public:
    enum node_color : uint8_t
    {
        e_color_red,
        e_color_black
    };
    struct node
    {
        node* parent;
        node* left;
        node* right;
        node_color color;
        int32_t size;       // real nodes in the subtree rooted here; 0 for the sentinel
        ost_t_key key;
        ost_t_value value;
    };
    // Callback type accepted by the traverse_* functions.
    using visit_func = void(*)(const ost_t_key&, const ost_t_value&);

    // Creates an empty tree: allocates the sentinel and points the root at it.
    order_statistic_tree()
        : m_nil{ allocate_node(ost_t_key(), ost_t_value(), e_color_black) }
        , m_root(m_nil)
    {}
    // Frees every node and the sentinel.
    ~order_statistic_tree()
    {
        release();
    }
    // Deep copy: duplicates every node of `other`, preserving shape, colours and sizes.
    order_statistic_tree(const order_statistic_tree& other)
        : m_nil{ allocate_node(ost_t_key(), ost_t_value(), e_color_black) }
        , m_root(m_nil)
    {
        // other.m_root == other.m_nil covers both an empty source and a
        // moved-from one (both pointers null); neither has nodes to copy.
        if (other.m_root != other.m_nil)
        {
            m_root = copy_from_other_order_statistic_tree_recursive(m_nil, m_nil, other.m_root, other.m_nil);
        }
    }
    // Takes over other's nodes and sentinel; `other` is left moved-from.
    order_statistic_tree(order_statistic_tree&& other) noexcept
        : m_nil(other.m_nil)
        , m_root(other.m_root)
    {
        other.m_nil = nullptr;
        other.m_root = nullptr;
    }
    // Copy assignment via copy-and-swap. The copy is fully built before *this
    // is touched, and an empty or moved-from source simply yields an empty tree
    // (the copy constructor never copies a sentinel).
    order_statistic_tree& operator=(const order_statistic_tree& other)
    {
        if (this != &other)
        {
            order_statistic_tree copy(other);
            swap(copy);
        }
        return *this;
    }
    // Frees this tree's contents, then takes over other's; `other` is left moved-from.
    order_statistic_tree& operator=(order_statistic_tree&& other) noexcept
    {
        if (this != &other)
        {
            release();
            m_nil = other.m_nil;
            m_root = other.m_root;
            other.m_nil = nullptr;
            other.m_root = nullptr;
        }
        return *this;
    }
    // Exchanges the contents of two trees in O(1) without allocating.
    void swap(order_statistic_tree& other) noexcept
    {
        node* const nil = m_nil;
        node* const root = m_root;
        m_nil = other.m_nil;
        m_root = other.m_root;
        other.m_nil = nil;
        other.m_root = root;
    }
    // Inserts (key, value). An existing equal key is kept; the new node goes to its right.
    void insert(const ost_t_key& key, const ost_t_value& value)
    {
        node* n = allocate_node(key, value, e_color_red, 1);
        rb_insert(n);
    }
    // Removes one node whose key equals `key`; does nothing if there is none.
    void erase(const ost_t_key& key)
    {
        node* const n = find_node(key);
        if (n != m_nil)
        {
            rb_delete(n);
            free_node(n);
        }
    }
    // Returns a node whose key equals `key`, or the sentinel if none exists
    // (test the result with is_nil).
    const node& find(const ost_t_key& key) const
    {
        return *find_node(key);
    }
    // Returns the node holding the rank-th smallest key (1-based), or the
    // sentinel when rank is outside [1, number of nodes].
    const node& find_with_rank(int32_t rank) const
    {
        return *os_select(m_root, rank);
    }
    // Returns the 1-based in-order position of `target`, or 0 for the sentinel.
    // `target` must be a node of this tree (e.g. obtained from find()).
    int32_t get_rank(const node& target) const
    {
        return os_rank(target);
    }
    // Node with the largest key, or the sentinel when the tree is empty.
    const node& find_maximum() const
    {
        return *get_maximum(m_root);
    }
    // Node with the smallest key, or the sentinel when the tree is empty.
    const node& find_minimum() const
    {
        return *get_minimum(m_root);
    }
    // True if `n` is this tree's sentinel, i.e. a "not found" result.
    bool is_nil(const node& n) const
    {
        return &n == m_nil;
    }
    // Calls visit(key, value) on every node: node, left subtree, right subtree.
    void traverse_preorder(visit_func visit) const
    {
        traverse_preorder(m_root, visit);
    }
    // Calls visit(key, value) on every node in ascending key order.
    void traverse_inorder(visit_func visit) const
    {
        traverse_inorder(m_root, visit);
    }
    // Calls visit(key, value) on every node: left subtree, right subtree, node.
    void traverse_postorder(visit_func visit) const
    {
        traverse_postorder(m_root, visit);
    }
private:
    node* m_nil;    // sentinel; declared before m_root because m_root is initialised from it
    node* m_root;
    // Binary-search descent shared by find() and erase(); returns the sentinel
    // when no node's key equals `key`.
    node* find_node(const ost_t_key& key) const
    {
        node* x = m_root;
        while (x != m_nil && key != x->key)
            x = (key < x->key) ? x->left : x->right;
        return x;
    }
    // Descends from x to the node of the given 1-based rank within x's subtree.
    node* os_select(node* x, int32_t rank) const
    {
        while (x != m_nil)
        {
            const int32_t r = x->left->size + 1;   // rank of x within its own subtree
            if (rank == r)
                return x;
            if (rank < r)
            {
                x = x->left;
            }
            else
            {
                x = x->right;
                rank -= r;
            }
        }
        return m_nil;
    }
    // Walks from x up to the root. Whenever the path arrives from a right child,
    // it adds the parent's left-subtree size plus one for the parent itself.
    int32_t os_rank(const node& x) const
    {
        if (&x == m_nil)
            return 0;   // the sentinel's left pointer is null and it has no rank
        int32_t r = x.left->size + 1;
        for (const node* y = &x; y != m_root; y = y->parent)
        {
            if (y == y->parent->right)
                r += y->parent->left->size + 1;
        }
        return r;
    }
    node* allocate_node(const ost_t_key& key, const ost_t_value& value, node_color color, int32_t size = 0)
    {
        node* n = new node();
        n->parent = nullptr;
        n->left = nullptr;
        n->right = nullptr;
        n->color = color;
        n->size = size;
        n->key = key;
        n->value = value;
        return n;
    }
    void free_node(node* n)
    {
        delete n;
    }
    // Post-order delete of the subtree rooted at n (n must not be the sentinel).
    void free_node_recursive(node* n)
    {
        if (n->left != m_nil)
            free_node_recursive(n->left);
        if (n->right != m_nil)
            free_node_recursive(n->right);
        free_node(n);
    }
    // Frees all nodes and the sentinel, leaving the null (moved-from) state.
    void release() noexcept
    {
        if (m_nil != nullptr)
        {
            if (m_root != m_nil)
                free_node_recursive(m_root);
            free_node(m_nil);
        }
        m_nil = nullptr;
        m_root = nullptr;
    }
    // Rotates x down to the left; its right child y takes x's place.
    // y inherits x's old subtree size, and x is recomputed from its new children.
    void left_rotate(node* x)
    {
        node* y = x->right;             // set y
        x->right = y->left;             // turn y's left subtree into x's right subtree
        if (y->left != m_nil)
        {
            y->left->parent = x;
        }
        y->parent = x->parent;          // link x's parent to y
        if (x->parent == m_nil)
        {
            m_root = y;
        }
        else if (x == x->parent->left)
        {
            x->parent->left = y;
        }
        else
        {
            x->parent->right = y;
        }
        y->left = x;                    // put x on y's left
        x->parent = y;
        y->size = x->size;
        x->size = x->left->size + x->right->size + 1;
    }
    // Mirror of left_rotate: y's left child x takes y's place.
    void right_rotate(node* y)
    {
        node* x = y->left;              // set x
        y->left = x->right;             // turn x's right subtree into y's left subtree
        if (x->right != m_nil)
        {
            x->right->parent = y;
        }
        x->parent = y->parent;          // link y's parent to x
        if (y->parent == m_nil)
        {
            m_root = x;
        }
        else if (y == y->parent->left)
        {
            y->parent->left = x;
        }
        else
        {
            y->parent->right = x;
        }
        x->right = y;
        y->parent = x;
        x->size = y->size;
        y->size = y->left->size + y->right->size + 1;
    }
    // BST insertion of z. Every node on the descent path gains one descendant,
    // so its size is bumped on the way down. Colours are then repaired.
    void rb_insert(node* z)
    {
        node* y = m_nil;
        node* x = m_root;
        while (x != m_nil)
        {
            y = x;
            x->size += 1;               // z will end up inside x's subtree
            x = (z->key < x->key) ? x->left : x->right;
        }
        if (y == m_nil)
            m_root = z;
        else if (z->key < y->key)
            y->left = z;
        else
            y->right = z;
        z->parent = y;
        z->left = m_nil;
        z->right = m_nil;
        z->color = e_color_red;
        rb_insert_fixup(z);
    }
    // Restores the red-black properties after inserting red node z.
    // zpp stays z's grandparent through case 2, because that rotation swaps z
    // with its parent below the same grandparent.
    void rb_insert_fixup(node* z)
    {
        while (z->parent->color == e_color_red)
        {
            node* zpp = z->parent->parent;
            if (z->parent == zpp->left)
            {
                node* y = zpp->right;
                if (y->color == e_color_red)
                {
                    z->parent->color = e_color_black;   // case 1
                    y->color = e_color_black;           // case 1
                    zpp->color = e_color_red;           // case 1
                    z = zpp;
                }
                else
                {
                    if (z == z->parent->right)
                    {
                        z = z->parent;                  // case 2
                        left_rotate(z);                 // case 2
                    }
                    z->parent->color = e_color_black;   // case 3
                    zpp->color = e_color_red;           // case 3
                    right_rotate(zpp);                  // case 3
                }
            }
            else
            {
                node* y = zpp->left;
                if (y->color == e_color_red)
                {
                    z->parent->color = e_color_black;   // case 1
                    y->color = e_color_black;           // case 1
                    zpp->color = e_color_red;           // case 1
                    z = zpp;
                }
                else
                {
                    if (z == z->parent->left)
                    {
                        z = z->parent;                  // case 2
                        right_rotate(z);                // case 2
                    }
                    z->parent->color = e_color_black;   // case 3
                    zpp->color = e_color_red;           // case 3
                    left_rotate(zpp);                   // case 3
                }
            }
        }
        m_root->color = e_color_black;
    }
    // Replaces the subtree rooted at u with the one rooted at v.
    // v->parent is written even when v is the sentinel, which rb_delete_fixup relies on.
    void rb_transplant(node* u, node* v)
    {
        if (u->parent == m_nil)
            m_root = v;
        else if (u == u->parent->left)
        {
            u->parent->left = v;
        }
        else
        {
            u->parent->right = v;
        }
        v->parent = u->parent;
    }
    // Restores the red-black properties after a black node was removed above x.
    void rb_delete_fixup(node* x)
    {
        while (x != m_root && x->color == e_color_black)
        {
            node* w = nullptr;
            if (x == x->parent->left)
            {
                w = x->parent->right;
                if (w->color == e_color_red)
                {
                    w->color = e_color_black;           // case 1
                    x->parent->color = e_color_red;     // case 1
                    left_rotate(x->parent);             // case 1
                    w = x->parent->right;
                }
                if (w->left->color == e_color_black && w->right->color == e_color_black)
                {
                    w->color = e_color_red;             // case 2
                    x = x->parent;                      // case 2
                }
                else
                {
                    if (w->right->color == e_color_black)
                    {
                        w->left->color = e_color_black; // case 3
                        w->color = e_color_red;         // case 3
                        right_rotate(w);                // case 3
                        w = x->parent->right;           // case 3
                    }
                    w->color = x->parent->color;        // case 4
                    x->parent->color = e_color_black;   // case 4
                    w->right->color = e_color_black;    // case 4
                    left_rotate(x->parent);             // case 4
                    x = m_root;
                }
            }
            else
            {
                w = x->parent->left;
                if (w->color == e_color_red)
                {
                    w->color = e_color_black;           // case 1
                    x->parent->color = e_color_red;     // case 1
                    right_rotate(x->parent);            // case 1
                    w = x->parent->left;
                }
                if (w->left->color == e_color_black && w->right->color == e_color_black)
                {
                    w->color = e_color_red;             // case 2
                    x = x->parent;                      // case 2
                }
                else
                {
                    if (w->left->color == e_color_black)
                    {
                        w->right->color = e_color_black; // case 3
                        w->color = e_color_red;          // case 3
                        left_rotate(w);                  // case 3
                        w = x->parent->left;             // case 3
                    }
                    w->color = x->parent->color;        // case 4
                    x->parent->color = e_color_black;   // case 4
                    w->left->color = e_color_black;     // case 4
                    right_rotate(x->parent);            // case 4
                    x = m_root;
                }
            }
        }
        x->color = e_color_black;
    }
    // Unlinks z from the tree; the caller frees it. Sizes are fixed before the
    // colour repair so the rotations in rb_delete_fixup see consistent sizes.
    void rb_delete(node* z)
    {
        node* x = m_nil;
        node* y = z;
        node* yp = y->parent;           // lowest node whose subtree loses a node
        node_color y_original_color = y->color;
        if (z->left == m_nil)
        {
            x = z->right;
            rb_transplant(z, z->right);
        }
        else if (z->right == m_nil)
        {
            x = z->left;
            rb_transplant(z, z->left);
        }
        else
        {
            y = get_minimum(z->right);
            y_original_color = y->color;
            x = y->right;
            if (y->parent == z)
            {
                x->parent = y;
            }
            else
            {
                rb_transplant(y, y->right);
                y->right = z->right;
                y->right->parent = y;
                yp = y->parent;         // y's old parent; the walk below passes through y itself
            }
            rb_transplant(z, y);
            y->left = z->left;
            y->left->parent = y;
            y->color = z->color;
            // After moving y into z's position, recompute its size. When y came from
            // deeper in z->right, that subtree's size is still one too large here;
            // the walk below passes through y and decrements it once, which corrects it.
            y->size = y->left->size + y->right->size + 1;
        }
        // Every ancestor from yp up to the root lost exactly one node.
        while (yp != m_nil)
        {
            yp->size -= 1;
            yp = yp->parent;
        }
        if (y_original_color == e_color_black)
            rb_delete_fixup(x);
    }
    // Leftmost node under x, or the sentinel when x is the sentinel.
    node* get_minimum(node* x) const
    {
        if (x == m_nil)
            return m_nil;   // the sentinel's children are null; never descend from it
        while (x->left != m_nil)
            x = x->left;
        return x;
    }
    // Rightmost node under x, or the sentinel when x is the sentinel.
    node* get_maximum(node* x) const
    {
        if (x == m_nil)
            return m_nil;
        while (x->right != m_nil)
            x = x->right;
        return x;
    }
    // In-order successor of x, or the sentinel if x is the maximum (or the sentinel).
    node* get_successor(node* x) const
    {
        if (x == m_nil)
            return m_nil;
        if (x->right != m_nil)
            return get_minimum(x->right);
        node* y = x->parent;
        while (y != m_nil && x == y->right)
        {
            x = y;
            y = y->parent;
        }
        return y;
    }
    // In-order predecessor of x, or the sentinel if x is the minimum (or the sentinel).
    node* get_predecessor(node* x) const
    {
        if (x == m_nil)
            return m_nil;
        if (x->left != m_nil)
            return get_maximum(x->left);
        node* y = x->parent;
        while (y != m_nil && x == y->left)
        {
            x = y;
            y = y->parent;
        }
        return y;
    }
    // Recursively clones the subtree rooted at other_node, which must not be the
    // other tree's sentinel. Children equal to other_node_nil map to my_nil.
    static node* copy_from_other_order_statistic_tree_recursive(node* my_nil, node* my_parent, node* other_node, node* other_node_nil)
    {
        node* new_node = new node();
        new_node->color = other_node->color;
        new_node->size = other_node->size;
        new_node->key = other_node->key;
        new_node->value = other_node->value;
        new_node->parent = my_parent;
        new_node->left = my_nil;
        new_node->right = my_nil;
        if (other_node->left != other_node_nil)
        {
            new_node->left = copy_from_other_order_statistic_tree_recursive(my_nil, new_node, other_node->left, other_node_nil);
        }
        if (other_node->right != other_node_nil)
        {
            new_node->right = copy_from_other_order_statistic_tree_recursive(my_nil, new_node, other_node->right, other_node_nil);
        }
        return new_node;
    }
    // Pre-order walk of the subtree at n; an empty subtree (sentinel) visits nothing.
    void traverse_preorder(node* n, visit_func print_key_value_func) const
    {
        if (n == m_nil)
            return;
        print_key_value_func(n->key, n->value);
        traverse_preorder(n->left, print_key_value_func);
        traverse_preorder(n->right, print_key_value_func);
    }
    // In-order (ascending key) walk of the subtree at n.
    void traverse_inorder(node* n, visit_func print_key_value_func) const
    {
        if (n == m_nil)
            return;
        traverse_inorder(n->left, print_key_value_func);
        print_key_value_func(n->key, n->value);
        traverse_inorder(n->right, print_key_value_func);
    }
    // Post-order walk of the subtree at n.
    void traverse_postorder(node* n, visit_func print_key_value_func) const
    {
        if (n == m_nil)
            return;
        traverse_postorder(n->left, print_key_value_func);
        traverse_postorder(n->right, print_key_value_func);
        print_key_value_func(n->key, n->value);
    }
};
댓글 없음:
댓글 쓰기