source: LMDZ6/branches/Optimisation_LMDZ/libf/misc/dict/dict_mod.f90.h @ 3705

Last change on this file since 3705 was 3705, checked in by adurocher, 4 years ago

Added new hashtable module

File size: 7.6 KB
Line 
1module DICT_MODULENAME
2  ! Dictionary data structure
3
4  implicit none
5
6  private
7  public :: dict, get_val, insert_or_assign, exists, remove, get_keys_vals, get_size, get_kth_key
8
9  type dict
10    type(node), pointer :: root => null()
11    integer :: randstate = 1231767121
12    contains
13    final :: destruct_dict
14  end type dict
15
16  type node
17    type(node), pointer :: left => null(), right => null()
18    keytype1 :: key
19    valtype :: val
20    integer :: pri  ! min-heap
21    integer :: cnt = 1
22  end type node
23
24  contains
25
26  pure function xorshift32(i)
27    implicit none
28    integer(4), intent(in) :: i
29    integer(4) :: xorshift32
30    if (i == 0) then
31      xorshift32 = 1231767121
32    else
33      xorshift32 = i
34    end if
35    xorshift32 = ieor(xorshift32, ishft(xorshift32, 13))
36    xorshift32 = ieor(xorshift32, ishft(xorshift32, -17))
37    xorshift32 = ieor(xorshift32, ishft(xorshift32, 15))
38  end function xorshift32
39
40  function get_val(t, key)
41    implicit none
42    type(dict), intent(in) :: t
43    keytype2, intent(in) :: key
44    type(node), pointer :: nd
45    valtype :: get_val
46    nd => find_node(t%root, key)
47    if (.not. associated(nd)) then
48      stop 105
49    end if
50    get_val = nd%val
51  end function get_val
52
53  function exists(t, key)
54    implicit none
55    type(dict), intent(in) :: t
56    keytype2, intent(in) :: key
57    type(node), pointer :: nd
58    logical :: exists
59    nd => find_node(t%root, key)
60    exists = (associated(nd))
61  end function exists
62
63  subroutine insert_or_assign(t, key, val)
64    implicit none
65    type(dict), intent(inout) :: t
66    keytype2, intent(in) :: key
67    valtype, intent(in) :: val
68    type(node), pointer :: nd
69    nd => find_node(t%root, key)
70    if (associated(nd)) then
71      nd%val = val
72    else  ! This implementation is not optimal
73      t%root => insert(t%root, key, val, t%randstate)
74      t%randstate = xorshift32(t%randstate)
75    end if
76  end subroutine insert_or_assign
77
78  subroutine remove(t, key)
79    implicit none
80    type(dict), intent(inout) :: t
81    keytype2, intent(in) :: key
82    t%root => erase(t%root, key)
83  end subroutine remove
84
85  function get_kth_key(t, k)
86    implicit none
87    type(dict), intent(in) :: t
88    integer, intent(in) :: k
89    type(node), pointer :: res
90    keytype1 :: get_kth_key
91    if (k < 1 .or. k > my_count(t%root)) then
92      print *, "get_kth_key failed"
93      stop 2
94    else
95      res => kth_node(t%root, k)
96      get_kth_key = res%key
97    end if
98  end function get_kth_key
99
100  subroutine get_keys_vals(t, keys, vals, n)
101    implicit none
102    type(dict), intent(in) :: t
103    integer, intent(in) :: n
104    keytype2, intent(out) :: keys(n)
105    valtype, intent(out) :: vals(n)
106    integer :: counter
107    if (my_count(t%root) /= n) stop 5
108    counter = 0
109    call inorder(t%root, keys, vals, counter)
110  end subroutine get_keys_vals
111
112  function get_size(t)
113    implicit none
114    type(dict), intent(in) :: t
115    integer :: get_size
116    get_size = my_count(t%root)
117  end function get_size
118
119  subroutine destruct_dict(t)
120    implicit none
121    type(dict), intent(inout) :: t
122    call delete_all(t%root)
123  end subroutine destruct_dict
124
125  subroutine update(root)
126    implicit none
127    type(node), pointer, intent(in) :: root
128    root%cnt = my_count(root%left) + my_count(root%right) + 1
129  end subroutine update
130
131  function my_count(root)
132    implicit none
133    type(node), pointer, intent(in) :: root
134    integer :: my_count
135    if (associated(root)) then
136      my_count = root%cnt
137    else
138      my_count = 0
139    end if
140  end function my_count
141
142  function rotate_ccw(root)
143    implicit none
144    type(node), pointer, intent(in) :: root
145    type(node), pointer :: tmp, rotate_ccw
146    if (.not. associated(root%right)) stop 1
147    tmp => root%right
148    root%right => tmp%left
149    tmp%left => root
150    rotate_ccw => tmp
151    call update(root)
152    call update(tmp)
153  end function rotate_ccw
154
155  function rotate_cw(root)
156    implicit none
157    type(node), pointer, intent(in) :: root
158    type(node), pointer :: tmp, rotate_cw
159    if (.not. associated(root%left)) stop 1
160    tmp => root%left
161    root%left => tmp%right
162    tmp%right => root
163    rotate_cw => tmp
164    call update(root)
165    call update(tmp)
166  end function rotate_cw
167
168  recursive function insert(root, key, val, pri) result(res)
169    implicit none
170    type(node), pointer, intent(in) :: root
171    integer, intent(in) :: pri
172    keytype2, intent(in) :: key
173    valtype, intent(in) :: val
174    type(node), pointer :: res
175
176    if (.not. associated(root)) then
177      allocate(res)
178      res%key = key
179      res%pri = pri
180      res%val = val
181    else
182      res => root
183      if (key > root%key) then
184        root%right => insert(root%right, key, val, pri)
185        call update(root)
186        if (root%pri > root%right%pri) then
187          res => rotate_ccw(res)
188        end if
189      else
190        root%left => insert(root%left, key, val, pri)
191        call update(root)
192        if (root%pri > root%left%pri) then
193          res => rotate_cw(res)
194        end if
195      end if
196    end if
197  end function insert
198
199  recursive function erase(root, key) result(res)
200    implicit none
201    type(node), pointer, intent(in) :: root
202    keytype2, intent(in) :: key
203    type(node), pointer :: res, tmp
204
205    if (.not. associated(root)) then
206      print *, "Erase failed"
207      stop 1
208    end if
209
210    if (key < root%key) then
211      root%left => erase(root%left, key)
212      res => root
213    else if (key > root%key) then
214      root%right => erase(root%right, key)
215      res => root
216    else
217      if ((.not. associated(root%left)) .or. (.not. associated(root%right))) then
218        tmp => root
219        if (.not. associated(root%left)) then
220          res => root%right
221        else
222          res => root%left
223        end if
224        deallocate(tmp)
225      else
226        if (root%left%pri < root%right%pri) then
227          res => rotate_ccw(root)
228          res%left => erase(res%left, key)
229        else
230          res => rotate_cw(root)
231          res%right => erase(res%right, key)
232        end if
233      end if
234    end if
235    if (associated(res)) call update(res)
236  end function erase
237
238  recursive function find_node(root, key) result(res)
239    implicit none
240    type(node), pointer, intent(in) :: root
241    keytype2, intent(in) :: key
242    type(node), pointer :: res
243    if (.not. associated(root)) then
244      res => null()
245    else if (root%key == key) then
246      res => root
247    else if (key < root%key) then
248      res => find_node(root%left, key)
249    else
250      res => find_node(root%right, key)
251    end if
252  end function find_node
253
254  recursive function kth_node(root, k) result(res)
255    implicit none
256    type(node), pointer, intent(in) :: root
257    integer, intent(in) :: k
258    type(node), pointer :: res
259    if (.not. associated(root)) then
260      res => null()
261    else if (k <= my_count(root%left)) then
262      res => kth_node(root%left, k)
263    else if (k == my_count(root%left) + 1) then
264      res => root
265    else
266      res => kth_node(root%right, k - my_count(root%left) - 1)
267    end if
268  end function kth_node
269
270  recursive subroutine delete_all(root)
271    implicit none
272    type(node), pointer, intent(inout) :: root
273    if (.not. associated(root)) return
274
275    call delete_all(root%left)
276    call delete_all(root%right)
277    deallocate(root)
278    nullify(root)
279  end subroutine delete_all
280
281  recursive subroutine inorder(root, keys, vals, counter)
282    implicit none
283    type(node), pointer, intent(in) :: root
284    keytype2, intent(inout) :: keys(:)
285    valtype, intent(inout) :: vals(:)
286    integer, intent(inout) :: counter
287    if (.not. associated(root)) return
288
289    call inorder(root%left, keys, vals, counter)
290    counter = counter + 1
291    keys(counter) = root%key
292    vals(counter) = root%val
293    call inorder(root%right, keys, vals, counter)
294  end subroutine inorder
295
296end module DICT_MODULENAME
Note: See TracBrowser for help on using the repository browser.