Advice on implementing Dijkstra's algorithm

So, I have managed to use compound terms to store the state of Dijkstra's algorithm, as well as one compound term to store the graph and another to store its weights.
I found that this significantly reduces the runtime of the algorithm, bringing it fairly close to (albeit still slower than) networkx!
Now, the runtime of the algorithm is dominated by the min-heap operations.

I have made a complete test script that compares the red-black tree approach, the compound term approach, and networkx (called through the awesome Janus interface):

test script
:- use_module(library(rbtrees)).
:- use_module(library(heaps)).

%! add(+X, +Y, -Sum) is det.
%
%  Sum is the arithmetic sum of X and Y. Argument order lets it be used
%  as add(Base) in maplist/3 to offset a list of edge weights by Base.
add(X, Y, Sum) :-
   Sum is X + Y.

%! update_distances1(+From, +Node, +Dist, +Tree0-Heap0, -Tree-Heap) is semidet.
%
%  Relaxes the edge From->Node with candidate distance Dist (red-black
%  tree variant, used as a foldl/6 step).
%
%  - Node not yet in the tree: insert unvisited(Dist, From) and push
%    Dist-Node onto the heap.
%  - Node stored as unvisited(Best, _) with Dist < Best: replace the
%    entry with unvisited(Dist, From) and push Dist-Node. The older heap
%    entry for Node is not removed.
%  - Anything else (e.g. already visited, or no improvement): tree and
%    heap are passed through unchanged.
update_distances1(From, Node, Dist, Tree0-Heap0, Tree-Heap) :-
   Entry = unvisited(Dist, From),
   (  rb_lookup(Node, Old, Tree0)
   -> (  Old = unvisited(Best, _),
         Dist < Best
      -> rb_update(Tree0, Node, Entry, Tree),
         add_to_heap(Heap0, Dist, Node, Heap)
      ;  Tree-Heap = Tree0-Heap0
      )
   ;  rb_insert(Tree0, Node, Entry, Tree),
      add_to_heap(Heap0, Dist, Node, Heap)
   ).

%! get_unvisited_min_node1(+Tree, +Heap0, -Prio, -Node, -Heap) is semidet.
%
%  Pops heap entries until one whose node is not marked visited in Tree
%  is found, and returns its priority and node together with the
%  remaining heap. A node absent from Tree is treated as unvisited.
%  Fails when the heap runs out (get_from_heap/4 fails on an empty heap).
get_unvisited_min_node1(Tree, Heap0, Prio, Node, Heap) :-
   get_from_heap(Heap0, Prio0, Node0, Heap1),
   (  rb_lookup(Node0, Status, Tree)
   -> true
   ;  Status = unvisited(_, _)        % unknown node: accept it as-is
   ),
   (  Status = unvisited(_, _)
   -> Prio = Prio0,
      Node = Node0,
      Heap = Heap1
   ;  get_unvisited_min_node1(Tree, Heap1, Prio, Node, Heap)
   ).

%! build_path1(+Tree, ?From, +To, ?Prev) is semidet.
%
%  foldl/4 step used to walk the predecessor chain backwards: succeeds
%  when To is stored in Tree as visited(Prev), unifying both From (the
%  list element) and Prev (the next accumulator) with that predecessor.
build_path1(Tree, From, To, Prev) :-
   rb_lookup(To, visited(Prev), Tree),
   From = Prev.

:- meta_predicate dijkstra1(+, 3, +, +, -).

%! dijkstra1(+Graph, :Edge, +Start, +End, -Path) is semidet.
%
%  Shortest path from Start to End. The per-node search state lives in a
%  red-black tree (library(rbtrees)); the frontier is a library(heaps)
%  priority queue.
%
%  Graph is a list of Node-Neighbours pairs ordered by Node (an ugraph,
%  as built by vertices_edges_to_ugraph/3). call(Edge, From, To, Weight)
%  must yield the weight of the edge From->To. Path is the list of nodes
%  from Start to End inclusive.
%
%  Start == End is answered directly with [Start]. The search loop only
%  stops on a node that is still unvisited, and Start is marked visited
%  on the first step. Without this check the loop would drain the whole
%  heap and fail.
%  Otherwise fails if End cannot be reached (the heap runs empty).
dijkstra1(Graph, Edge, Start, End, Path) :-
   (  Start == End
   -> Path = [Start]
   ;  ord_list_to_rbtree(Graph, GraphTree),
      empty_heap(Heap),
      list_to_rbtree([Start-unvisited(0, no)], TreeIn),
      dijkstra_1(GraphTree, Edge, Start, End, 0, TreeIn, TreeOut, Heap),
      % Follow visited(Prev) links from End back to Start.
      once(foldl(build_path1(TreeOut), ReversePath, End, Start)),
      reverse([End | ReversePath], Path)
   ).

%! dijkstra_1(+Graph, :Edge, +Node, +Target, +Dist, +Tree0, -Tree, +Heap0) is semidet.
%
%  Main loop of the red-black tree variant. Node was just selected with
%  tentative distance Dist. Each step does the following:
%  1. Relax every outgoing edge of Node (update_distances1/5).
%  2. Turn Node's entry from unvisited(_, Prev) into visited(Prev).
%  3. Pop the nearest node that is not yet visited.
%
%  If that node is Target, it is marked visited too and the final tree
%  is returned in Tree. Otherwise the loop continues from it. Fails if
%  the heap runs empty first, or if Node has no entry in Graph.
dijkstra_1(Graph, Edge, Node, Target, Dist, Tree0, Tree, Heap0) :-
   rb_lookup(Node, Succs, Graph),
   maplist(call(Edge, Node), Succs, EdgeCosts),
   maplist(add(Dist), EdgeCosts, Tentative),
   foldl(update_distances1(Node), Succs, Tentative,
         Tree0-Heap0, Tree1-Heap1),
   rb_update(Tree1, Node, unvisited(_, Prev), visited(Prev), Tree2),
   get_unvisited_min_node1(Tree2, Heap1, NextDist, Next, Heap2),
   (  Next == Target
   -> rb_update(Tree2, Target, unvisited(_, TargetPrev),
                visited(TargetPrev), Tree)
   ;  dijkstra_1(Graph, Edge, Next, Target, NextDist, Tree2, Tree, Heap2)
   ).

%! update_distances2(+State, +From, +Node, +Dist, +Heap0, -Heap) is semidet.
%
%  Relaxes the edge From->Node with candidate distance Dist (compound
%  term variant; argument Node of State holds that node's entry).
%
%  - Argument still unbound (node seen for the first time): bind it by
%    unification to unvisited(Dist, From) and push Dist-Node.
%  - Entry unvisited(Best, _) with Dist < Best: overwrite it in place
%    with nb_setarg/3 and push Dist-Node.
%  - Otherwise the heap is passed through unchanged.
update_distances2(State, From, Node, Dist, Heap0, Heap) :-
   arg(Node, State, Entry),
   (  var(Entry)
   -> arg(Node, State, unvisited(Dist, From)),
      add_to_heap(Heap0, Dist, Node, Heap)
   ;  Entry = unvisited(Best, _),
      Dist < Best
   -> nb_setarg(Node, State, unvisited(Dist, From)),
      add_to_heap(Heap0, Dist, Node, Heap)
   ;  Heap = Heap0
   ).

%! get_unvisited_min_node2(+State, +Heap0, -Prio, -Node, -Heap) is semidet.
%
%  Compound-term counterpart of the tree-based selector. It pops heap
%  entries until one whose State argument is not marked visited is
%  found, and returns its priority, node and the remaining heap. If
%  arg/3 fails for the node, the node is accepted as-is. Fails when the
%  heap runs out.
get_unvisited_min_node2(State, Heap0, Prio, Node, Heap) :-
   get_from_heap(Heap0, Prio0, Node0, Heap1),
   (  arg(Node0, State, Status)
   -> true
   ;  Status = unvisited(_, _)        % no such argument: accept it as-is
   ),
   (  Status = unvisited(_, _)
   -> Prio = Prio0,
      Node = Node0,
      Heap = Heap1
   ;  get_unvisited_min_node2(State, Heap1, Prio, Node, Heap)
   ).

%! build_path2(+State, ?From, +To, ?Prev) is semidet.
%
%  foldl/4 step that walks the predecessor chain backwards: argument To
%  of State must be visited(Prev); both From (the list element) and Prev
%  (the next accumulator) are unified with that predecessor.
build_path2(State, From, To, Prev) :-
   arg(To, State, visited(Prev)),
   From = Prev.

%! add_to_heap_(+Priority, +Key, +Heap0, -Heap) is det.
%
%  add_to_heap/4 with the heap pair moved to the end, so a fixed
%  priority can be partially applied (presumably for use with foldl/4
%  over a list of keys).
%  NOTE(review): not referenced anywhere in this listing.
add_to_heap_(Priority, Key, Heap0, Heap) :-
   add_to_heap(Heap0, Priority, Key, Heap).

:- meta_predicate dijkstra2(+, 3, +, +, -).

%! dijkstra2(+Graph, :Edge, +Start, +End, -Path) is semidet.
%
%  Shortest path from Start to End. The adjacency lists and the
%  per-node search state are kept in flat compound terms (graph/N and
%  state/N) accessed with arg/3. The state is updated destructively with
%  nb_setarg/3, and the frontier is a library(heaps) priority queue.
%
%  Graph is a list of Node-Neighbours pairs ordered by Node. Nodes must
%  be exactly the integers 1..N, where N is the length of Graph, because
%  a node number is used directly as an argument position.
%  call(Edge, From, To, Weight) must yield the weight of From->To. Path
%  is the list of nodes from Start to End inclusive.
%
%  Start == End is answered directly with [Start]. The search loop only
%  stops on a node that is still unvisited, and Start is marked visited
%  on the first step. Without this check the loop would drain the whole
%  heap and fail.
%  Otherwise fails if End cannot be reached (the heap runs empty).
dijkstra2(Graph, Edge, Start, End, Path) :-
   (  Start == End
   -> Path = [Start]
   ;  pairs_values(Graph, Neighbours),
      compound_name_arguments(GraphCompound, graph, Neighbours),
      empty_heap(Heap),
      length(Graph, NumNodes),
      compound_name_arity(State, state, NumNodes),
      arg(Start, State, unvisited(0, no)),
      dijkstra_2(GraphCompound, Edge, Start, End, 0, State, Heap),
      % Follow visited(Prev) links from End back to Start.
      once(foldl(build_path2(State), ReversePath, End, Start)),
      reverse([End | ReversePath], Path)
   ).

%! dijkstra_2(+Graph, :Edge, +Node, +Target, +Dist, !State, +Heap0) is semidet.
%
%  Main loop of the compound-term variant. Node was just selected with
%  tentative distance Dist. Each step does the following:
%  1. Relax its outgoing edges (update_distances2/6).
%  2. Overwrite its State entry unvisited(_, Prev) with visited(Prev).
%  3. Pop the nearest node that is not yet visited.
%
%  If that node is Target, its entry is switched to visited as well and
%  the loop ends. Otherwise it continues from that node. State is
%  modified in place; fails if the heap runs empty first.
dijkstra_2(Graph, Edge, Node, Target, Dist, State, Heap0) :-
   arg(Node, Graph, Succs),
   maplist(call(Edge, Node), Succs, EdgeCosts),
   maplist(add(Dist), EdgeCosts, Tentative),
   foldl(update_distances2(State, Node), Succs, Tentative, Heap0, Heap1),
   arg(Node, State, unvisited(_, Prev)),
   nb_setarg(Node, State, visited(Prev)),
   get_unvisited_min_node2(State, Heap1, NextDist, Next, Heap2),
   (  Next == Target
   -> arg(Target, State, unvisited(_, TargetPrev)),
      nb_setarg(Target, State, visited(TargetPrev))
   ;  dijkstra_2(Graph, Edge, Next, Target, NextDist, State, Heap2)
   ).

%! random_directed_edge(+Low, +High, +Weights, -Edge, -Tuple, -Pair) is det.
%
%  Draws a random directed edge Start-End with Start, End in Low..High.
%  It returns the edge in three shapes:
%  - Edge: Start-End, for vertices_edges_to_ugraph/3.
%  - Tuple: -(Start, End, Weight), which Janus passes to networkx as a
%    Python tuple.
%  - Pair: (Start-End)-Weight, for the weight rbtree.
%
%  Argument Start of Weights is an open-ended list of End-Weight pairs.
%  memberchk/2 either finds an existing entry for End or appends a new
%  one at the open tail. A random weight in Low..High is drawn only when
%  the pair is new. A repeated Start-End pair therefore reuses its
%  stored weight. Otherwise the same edge would get several different
%  weights: weight2/4 would see the first, networkx would keep the last,
%  and the rbtree would hold duplicate keys, so the three
%  implementations would solve different problems.
random_directed_edge(Low, High, Weights, Start-End, -(Start, End, Weight), (Start-End)-Weight) :-
   random_between(Low, High, Start),
   random_between(Low, High, End),
   arg(Start, Weights, Ends),
   memberchk(End-Weight, Ends),
   (  var(Weight)                  % freshly appended entry: pick its weight
   -> random_between(Low, High, Weight)
   ;  true                         % existing edge: keep its stored weight
   ).

%! weight1(+Weights, +From, +To, -Distance) is semidet.
%
%  Weight of the edge From->To, looked up in an rbtree keyed by From-To.
weight1(Weights, From, To, Distance) :-
   Key = From-To,
   rb_lookup(Key, Distance, Weights).

%! weight2(+Weights, +From, +To, -Distance) is semidet.
%
%  Weight of the edge From->To. Argument From of the Weights compound
%  holds a list of To-Weight pairs; the first matching pair is used.
%  NOTE(review): these lists are built with memberchk/2 on an initially
%  unbound argument and appear to stay open-ended. A lookup for a
%  missing edge would then append To-Distance instead of failing. This
%  is harmless while callers only ask about edges present in the graph.
weight2(Weights, From, To, Distance) :-
   arg(From, Weights, Outgoing),
   memberchk(To-Distance, Outgoing).

%! random_graph(+NumNodes, +NumEdges, -UGraph, -Edge1, -Edge2, -NxGraph) is det.
%
%  Generates NumEdges random directed edges over nodes 1..NumNodes and
%  returns the same graph in several forms:
%  - UGraph: an ugraph containing every node.
%  - Edge1 = weight1(RbWeights): weights in an rbtree keyed by From-To.
%  - Edge2 = weight2(Weights): weights in a weights/NumNodes compound of
%    To-Weight lists.
%  - NxGraph: a networkx DiGraph handle created through Janus, with all
%    weighted edges added.
random_graph(NumNodes, NumEdges, UGraph, weight1(RbWeights), weight2(Weights), NxGraph) :-
   length(Edges, NumEdges),
   compound_name_arity(Weights, weights, NumNodes),
   maplist(random_directed_edge(1, NumNodes, Weights), Edges, Tuples, Pairs),
   list_to_rbtree(Pairs, RbWeights),
   numlist(1, NumNodes, Nodes),
   vertices_edges_to_ugraph(Nodes, Edges, UGraph),
   % Build the networkx side from the very same edge tuples.
   py_call(networkx:'DiGraph'(), NxGraph, [py_object(true)]),
   py_call(NxGraph:add_weighted_edges_from(Tuples)).

%! main(+Seed, +X) is semidet.
%
%  Benchmark driver. It builds a random directed graph with N = 10^X
%  nodes and 2N edges using random seed Seed, then picks a random Start
%  and End node. It times dijkstra1/5, dijkstra2/5 and networkx's
%  dijkstra_path on that same query.
%
%  Each implementation binds its own result variable. Any disagreement
%  is printed instead of making main/2 fail silently on a shared-variable
%  unification. Distinct shortest paths of equal total weight are valid,
%  so a mismatch with networkx is not necessarily an error.
main(Seed, X) :-
   N is 10^X,                    % ^/2 keeps N an integer for integer X
   M is N*2,
   set_random(seed(Seed)),
   random_graph(N, M, UGraph, Edge1, Edge2, NxGraph),
   random_between(1, N, Start),
   random_between(1, N, End),
   format("Prolog dijkstra 1~n"),
   time(dijkstra1(UGraph, Edge1, Start, End, Path1)),
   format("Prolog dijkstra 2~n"),
   time(dijkstra2(UGraph, Edge2, Start, End, Path2)),
   format("Networkx dijkstra~n"),
   time(py_call(networkx:dijkstra_path(NxGraph, Start, End), NxPath)),
   (  Path1 == Path2,
      Path2 == NxPath
   -> format("All three implementations returned the same path~n")
   ;  format("Paths differ (equal-weight ties are possible):~n~w~n~w~n~w~n",
             [Path1, Path2, NxPath])
   ).

Here are some notable results:

?- set_prolog_flag(stack_limit, 2_147_483_648).
true.
?- main(1, 6). % run dijkstra on random graph with 1M nodes and 2M edges
Prolog dijkstra 1 % red black tree approach
% 455,909,369 inferences, 44.604 CPU in 44.738 seconds (100% CPU, 10221305 Lips)
Prolog dijkstra 2 % compound term approach
% 88,561,958 inferences, 9.340 CPU in 9.367 seconds (100% CPU, 9481945 Lips)
Networkx dijkstra
% -1 inferences, 6.474 CPU in 6.495 seconds (100% CPU, 0 Lips)
true.
3 Likes