Get and replace elements in a 2D array using recursion

replace_matrix(INth0, JNth0, NewVal, Matrix, NewMatrix) :-
    must_be(nonneg, INth0), 
    must_be(nonneg, JNth0),
    % I and J become the list of levels (one index per nesting level)
    replace_matrix_(Matrix, [INth0, JNth0], NewVal, NewMatrix).

replace_matrix_([H|T], [Nth0|LevelsT], NewVal, NewMatrix) :-
    % Nth0 is the current position on current level
    (   Nth0 > 0
    ->  Nth0Prev is Nth0 - 1,
        NewMatrix = [H|NewMatrix1],
        Levels1 = [Nth0Prev|LevelsT],
        % Iterate through the list to the desired element
        replace_matrix_(T, Levels1, NewVal, NewMatrix1)
    ;   LevelsT = []
        % Replace this element, because we are at the lowest level
    ->  NewMatrix = [NewVal|T]
    ;   NewMatrix = [NewElem|T],
        % Drill down into next level of the matrix
        replace_matrix_(H, LevelsT, NewVal, NewElem)
    ).

Result:

?- time(replace_matrix(0, 1, X, [[1,2,3],[4,5,6],[7,8,9]],Result)).
% 15 inferences, 0.000 CPU in 0.000 seconds (88% CPU, 319373 Lips)
Result = [[1,X,3],[4,5,6],[7,8,9]].
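The title also asks about getting an element. A read-only counterpart can walk the same list of levels but return the element instead of rebuilding the lists. This is a minimal sketch, assuming SWI-Prolog (for must_be/2 from library(error)); get_matrix/4 and get_matrix_/3 are hypothetical names, not part of the code above.

```prolog
:- use_module(library(error)).

% Hypothetical read-only counterpart to replace_matrix/5.
% Same argument style: indices first, then the matrix, then the result.
get_matrix(INth0, JNth0, Matrix, Val) :-
    must_be(nonneg, INth0),
    must_be(nonneg, JNth0),
    get_matrix_(Matrix, [INth0, JNth0], Val).

get_matrix_([H|T], [Nth0|LevelsT], Val) :-
    (   Nth0 > 0
    ->  % Skip the head and count down on the current level
        Nth0Prev is Nth0 - 1,
        get_matrix_(T, [Nth0Prev|LevelsT], Val)
    ;   LevelsT = []
        % Lowest level reached: this is the element
    ->  Val = H
    ;   % Drill down into the next level of the matrix
        get_matrix_(H, LevelsT, Val)
    ).
```

For example, `get_matrix(1, 2, [[1,2,3],[4,5,6],[7,8,9]], V)` gives `V = 6`. Like replace_matrix/5, it simply fails when an index is out of range, because no clause matches the empty list it eventually reaches.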

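Note the changed argument order of replace_matrix/5: the two indices come first, then the new value, then the input matrix and the resulting matrix.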
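If hand-written recursion is not a requirement, SWI-Prolog's nth0/4 from library(lists) can do the same replacement: it removes the element at an index and returns the remainder, and when called with the list unbound and the remainder bound it inserts instead. A sketch, assuming SWI-Prolog; replace_matrix_nth0/5 is a hypothetical name that keeps the same argument order.

```prolog
:- use_module(library(lists)).

% Hypothetical alternative to replace_matrix/5 built on nth0/4.
replace_matrix_nth0(I, J, NewVal, Matrix, NewMatrix) :-
    nth0(I, Matrix, Row, RestRows),        % take out row I
    nth0(J, Row, _OldVal, RestCols),       % take out column J of that row
    nth0(J, NewRow, NewVal, RestCols),     % insert NewVal at column J
    nth0(I, NewMatrix, NewRow, RestRows).  % insert the new row at row I
```

It also fails for an out-of-range index, since the first two nth0/4 calls cannot find the element.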