module DICT_MODULENAME
  ! Dictionary data structure: an ordered map implemented as a treap, i.e. a
  ! binary search tree on the keys that is simultaneously a min-heap on
  ! pseudo-random node priorities (keeps the expected depth logarithmic).
  !
  ! The key/value type specifiers and the module name are template
  ! placeholders that must be substituted before compilation. Presumably the
  ! node key type is the stored form and the dummy-argument key type is the
  ! argument form of the same key type -- TODO confirm against the build that
  ! instantiates this template.
  use, intrinsic :: iso_fortran_env, only: error_unit
  implicit none

  private
  public :: dict, get_val, insert_or_assign, exists, remove, get_keys_vals, get_size, get_kth_key

  !> Dictionary handle. Intrinsic assignment copies only the root pointer, so
  !> copies of a dict share the same nodes.
  type dict
    type(node), pointer :: root => null()
    ! State of the xorshift32 generator that supplies node priorities.
    integer :: randstate = 1231767121
  contains
    ! NOTE(review): the finalizer is disabled and destruct_dict is not in the
    ! public list, so nodes are never freed through the public interface --
    ! confirm this is intended.
    !final :: destruct_dict
  end type dict

  !> Treap node.
  type node
    type(node), pointer :: left => null(), right => null()
    keytype1 :: key
    valtype :: val
    integer :: pri        ! heap priority; the tree is a min-heap on pri
    integer :: cnt = 1    ! number of nodes in the subtree rooted here
  end type node

contains

  !> One step of the 32-bit xorshift generator (shifts 13, -17, 15).
  !> A zero state (a fixed point of xorshift) is first replaced by the
  !> default seed.
  pure function xorshift32(i)
    implicit none
    integer(4), intent(in) :: i
    integer(4) :: xorshift32
    if (i == 0) then
      xorshift32 = 1231767121
    else
      xorshift32 = i
    end if
    xorshift32 = ieor(xorshift32, ishft(xorshift32, 13))
    xorshift32 = ieor(xorshift32, ishft(xorshift32, -17))
    xorshift32 = ieor(xorshift32, ishft(xorshift32, 15))
  end function xorshift32

  !> Return the value stored under key.
  !> Terminates with error code 105 if key is absent (test with exists first).
  function get_val(t, key)
    implicit none
    type(dict), intent(in) :: t
    keytype2, intent(in) :: key
    type(node), pointer :: nd
    valtype :: get_val
    nd => find_node(t%root, key)
    if (.not. associated(nd)) then
      write(error_unit, '(a)') "get_val: key not found"
      error stop 105
    end if
    get_val = nd%val
  end function get_val

  !> True if key is present in the dictionary.
  function exists(t, key)
    implicit none
    type(dict), intent(in) :: t
    keytype2, intent(in) :: key
    type(node), pointer :: nd
    logical :: exists
    nd => find_node(t%root, key)
    exists = associated(nd)
  end function exists

  !> Set the value for key, inserting a new node if key is absent.
  subroutine insert_or_assign(t, key, val)
    implicit none
    type(dict), intent(inout) :: t
    keytype2, intent(in) :: key
    valtype, intent(in) :: val
    type(node), pointer :: nd
    nd => find_node(t%root, key)
    if (associated(nd)) then
      nd%val = val
    else
      ! Two traversals (lookup, then insert); not optimal but simple.
      t%root => insert(t%root, key, val, t%randstate)
      ! Advance the generator so the next new node gets a fresh priority.
      t%randstate = xorshift32(t%randstate)
    end if
  end subroutine insert_or_assign

  !> Remove key from the dictionary. Terminates with error code 1 if key is absent.
  subroutine remove(t, key)
    implicit none
    type(dict), intent(inout) :: t
    keytype2, intent(in) :: key
    t%root => erase(t%root, key)
  end subroutine remove

  !> Return the k-th smallest key (1-based).
  !> Terminates with error code 2 if k is outside 1..size.
  function get_kth_key(t, k)
    implicit none
    type(dict), intent(in) :: t
    integer, intent(in) :: k
    type(node), pointer :: res
    keytype1 :: get_kth_key
    if (k < 1 .or. k > my_count(t%root)) then
      write(error_unit, '(a)') "get_kth_key failed"
      error stop 2
    end if
    res => kth_node(t%root, k)
    get_kth_key = res%key
  end function get_kth_key

  !> Copy all keys and values into keys(1:n) and vals(1:n) in ascending key order.
  !> n must equal the dictionary size; otherwise terminates with error code 5.
  subroutine get_keys_vals(t, keys, vals, n)
    implicit none
    type(dict), intent(in) :: t
    integer, intent(in) :: n
    keytype2, intent(out) :: keys(n)
    valtype, intent(out) :: vals(n)
    integer :: counter
    if (my_count(t%root) /= n) then
      write(error_unit, '(a)') "get_keys_vals: n does not match dictionary size"
      error stop 5
    end if
    counter = 0
    call inorder(t%root, keys, vals, counter)
  end subroutine get_keys_vals

  !> Number of entries in the dictionary.
  function get_size(t)
    implicit none
    type(dict), intent(in) :: t
    integer :: get_size
    get_size = my_count(t%root)
  end function get_size

  !> Free every node of t and leave t%root disassociated.
  subroutine destruct_dict(t)
    implicit none
    type(dict), intent(inout) :: t
    call delete_all(t%root)
  end subroutine destruct_dict

  !> Recompute the cached subtree size of root from its children.
  !> root must be associated.
  subroutine update(root)
    implicit none
    type(node), pointer, intent(in) :: root
    root%cnt = my_count(root%left) + my_count(root%right) + 1
  end subroutine update

  !> Subtree size of root, or 0 for an empty subtree.
  function my_count(root)
    implicit none
    type(node), pointer, intent(in) :: root
    integer :: my_count
    if (associated(root)) then
      my_count = root%cnt
    else
      my_count = 0
    end if
  end function my_count

  !> Left rotation: the right child of root becomes the new subtree root,
  !> which is returned. Sizes of both moved nodes are refreshed.
  function rotate_ccw(root)
    implicit none
    type(node), pointer, intent(in) :: root
    type(node), pointer :: tmp, rotate_ccw
    if (.not. associated(root%right)) error stop 1
    tmp => root%right
    root%right => tmp%left
    tmp%left => root
    rotate_ccw => tmp
    ! Order matters: root is now below tmp, so refresh it first.
    call update(root)
    call update(tmp)
  end function rotate_ccw

  !> Right rotation: the left child of root becomes the new subtree root,
  !> which is returned. Sizes of both moved nodes are refreshed.
  function rotate_cw(root)
    implicit none
    type(node), pointer, intent(in) :: root
    type(node), pointer :: tmp, rotate_cw
    if (.not. associated(root%left)) error stop 1
    tmp => root%left
    root%left => tmp%right
    tmp%right => root
    rotate_cw => tmp
    ! Order matters: root is now below tmp, so refresh it first.
    call update(root)
    call update(tmp)
  end function rotate_cw

  !> Insert a new node (key, val, pri) into the subtree rooted at root and
  !> return the new subtree root. Equal keys go left; callers only insert
  !> keys that are not already present.
  recursive function insert(root, key, val, pri) result(res)
    implicit none
    type(node), pointer, intent(in) :: root
    integer, intent(in) :: pri
    keytype2, intent(in) :: key
    valtype, intent(in) :: val
    type(node), pointer :: res

    if (.not. associated(root)) then
      allocate(res)
      res%key = key
      res%pri = pri
      res%val = val
    else
      res => root
      if (key > root%key) then
        root%right => insert(root%right, key, val, pri)
        call update(root)
        ! Restore the min-heap order by lifting a child with a smaller priority.
        if (root%pri > root%right%pri) then
          res => rotate_ccw(res)
        end if
      else
        root%left => insert(root%left, key, val, pri)
        call update(root)
        if (root%pri > root%left%pri) then
          res => rotate_cw(res)
        end if
      end if
    end if
  end function insert

  !> Delete the node holding key from the subtree rooted at root and return
  !> the new subtree root. Terminates with error code 1 if key is absent.
  recursive function erase(root, key) result(res)
    implicit none
    type(node), pointer, intent(in) :: root
    keytype2, intent(in) :: key
    type(node), pointer :: res, tmp

    if (.not. associated(root)) then
      write(error_unit, '(a)') "Erase failed"
      error stop 1
    end if

    if (key < root%key) then
      root%left => erase(root%left, key)
      res => root
    else if (key > root%key) then
      root%right => erase(root%right, key)
      res => root
    else if ((.not. associated(root%left)) .or. (.not. associated(root%right))) then
      ! At most one child: splice that child (possibly none) into this slot.
      tmp => root
      if (.not. associated(root%left)) then
        res => root%right
      else
        res => root%left
      end if
      deallocate(tmp)
    else if (root%left%pri < root%right%pri) then
      ! Two children: lift the child with the smaller priority so the
      ! min-heap order kept by insert still holds, then keep sinking the
      ! node being removed until it has at most one child.
      res => rotate_cw(root)              ! left child is now the subtree root
      res%right => erase(res%right, key)
    else
      res => rotate_ccw(root)             ! right child is now the subtree root
      res%left => erase(res%left, key)
    end if
    if (associated(res)) call update(res)
  end function erase

  !> Node holding key in the subtree rooted at root, or a disassociated
  !> pointer if key is absent.
  recursive function find_node(root, key) result(res)
    implicit none
    type(node), pointer, intent(in) :: root
    keytype2, intent(in) :: key
    type(node), pointer :: res
    if (.not. associated(root)) then
      res => null()
    else if (root%key == key) then
      res => root
    else if (key < root%key) then
      res => find_node(root%left, key)
    else
      res => find_node(root%right, key)
    end if
  end function find_node

  !> Node with the k-th smallest key (1-based) in the subtree rooted at root,
  !> using the cached subtree sizes; disassociated if k is out of range.
  recursive function kth_node(root, k) result(res)
    implicit none
    type(node), pointer, intent(in) :: root
    integer, intent(in) :: k
    type(node), pointer :: res
    if (.not. associated(root)) then
      res => null()
    else if (k <= my_count(root%left)) then
      res => kth_node(root%left, k)
    else if (k == my_count(root%left) + 1) then
      res => root
    else
      res => kth_node(root%right, k - my_count(root%left) - 1)
    end if
  end function kth_node

  !> Free every node of the subtree rooted at root (post-order).
  !> On return root is disassociated.
  recursive subroutine delete_all(root)
    implicit none
    type(node), pointer, intent(inout) :: root
    if (.not. associated(root)) return

    call delete_all(root%left)
    call delete_all(root%right)
    ! A successful deallocate already leaves root disassociated.
    deallocate(root)
  end subroutine delete_all

  !> In-order traversal: append keys and values to keys/vals starting after
  !> position counter, advancing counter once per node.
  recursive subroutine inorder(root, keys, vals, counter)
    implicit none
    type(node), pointer, intent(in) :: root
    keytype2, intent(inout) :: keys(:)
    valtype, intent(inout) :: vals(:)
    integer, intent(inout) :: counter
    if (.not. associated(root)) return

    call inorder(root%left, keys, vals, counter)
    counter = counter + 1
    keys(counter) = root%key
    vals(counter) = root%val
    call inorder(root%right, keys, vals, counter)
  end subroutine inorder

end module DICT_MODULENAME