Although you often see the following as an example of converting factorial to tail-call:
int factorial(int n, int acc=1) {
    if (n <= 1) return acc;
    else return factorial(n-1, n*acc);
}
it's not quite a faithful transformation, since it relies on multiplication being both associative and commutative. (Multiplication is both, so the result is right, but the pattern doesn't serve as a model for operations which don't satisfy those constraints.) A better solution might be:
int factorial(int n, int k=1, int acc=1) {
    if (n == 0) return acc;
    else return factorial(n-1, k+1, acc*k);
}
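To make the difference visible, here is a small sketch using string concatenation, which is associative but not commutative. The function names (`join_rec`, `join_naive`, `join_up`) are made up for illustration, and I've written the operands in the order that matches the recursive definition. The first function is the plain recursion, the second follows the count-down accumulator pattern of the first factorial, and the third follows the count-up pattern of the second one.

```cpp
#include <cassert>
#include <string>

// Plain recursion: f(n) = str(n) + f(n-1), with f(0) = "".
std::string join_rec(int n) {
    assert(n >= 0);
    if (n == 0) return "";
    return std::to_string(n) + join_rec(n - 1);
}

// Count-down accumulator (same shape as the first factorial).
// The operands are combined in the opposite order from join_rec.
std::string join_naive(int n, std::string acc = "") {
    assert(n >= 0);
    if (n == 0) return acc;
    return join_naive(n - 1, std::to_string(n) + acc);
}

// Count-up accumulator (same shape as the second factorial).
// k runs 1, 2, ..., n, so the result is built in the same order
// in which join_rec's pending operations would be resolved.
std::string join_up(int n, int k = 1, std::string acc = "") {
    assert(n >= 0);
    if (n == 0) return acc;
    return join_up(n - 1, k + 1, std::to_string(k) + acc);
}
```

For n = 3, `join_rec` and `join_up` both produce "321", while `join_naive` produces "123".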
This also serves as a model for the Fibonacci transform:
// Uses the convention fibonacci(0) == fibonacci(1) == 1.
int fibonacci(int n, int a=1, int b=0) {
    if (n == 0) return a;
    else return fibonacci(n-1, a+b, a);
}
Note that these compute the sequence starting at the beginning, as opposed to queueing pending continuations in a call stack. So they are structurally more like the iterative solution than the recursive solution. Unlike the iterative program, though, they never modify any variable; all bindings are constant. This is an interesting and useful property; in these simple cases it doesn't make much difference, but writing code without reassignments makes some compiler optimizations easier.
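For comparison, this is roughly what the iterative counterpart looks like. It's only a sketch (`fibonacci_iter` is my own name), using the same convention as the tail-recursive version, where the values for n = 0 and n = 1 are both 1. Each trip through the loop corresponds to one tail call, except that here `a` and `b` are reassigned instead of being rebound as fresh parameters.

```cpp
#include <cassert>

// Iterative counterpart of the tail-recursive fibonacci:
// (a, b) plays the role of the two accumulator parameters.
int fibonacci_iter(int n) {
    assert(n >= 0);
    int a = 1, b = 0;
    while (n != 0) {
        int next = a + b;  // the value passed as the new `a` in the tail call
        b = a;             // the old `a` becomes the new `b`
        a = next;
        --n;
    }
    return a;
}
```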
Anyway, the last one does provide a model for your recursive function; like the Fibonacci sequence, we need to keep the relevant past values, but we need three of them instead of two:
int mouse(int n, int a=1, int b=1, int c=1) {
    if (n <= 2) return a;
    else return mouse(n-1, a*c+1, a, b);
}
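If you want to convince yourself that the two agree, a sketch like the following compares the tail-recursive version against the direct recursion. I'm assuming the base cases are 1 for n = 0, 1 and 2, which is what the output at the end of this answer shows; `mouse_rec` is a name I made up for the direct form.

```cpp
#include <cassert>

// Direct (non-tail) recursion: f(n) = f(n-1) * f(n-3) + 1,
// assuming f(0) = f(1) = f(2) = 1.
int mouse_rec(int n) {
    assert(n >= 0);
    if (n <= 2) return 1;
    return mouse_rec(n - 1) * mouse_rec(n - 3) + 1;
}

// Tail-recursive form, repeated here so the block is self-contained.
// (a, b, c) hold the three most recent values of the sequence.
int mouse(int n, int a = 1, int b = 1, int c = 1) {
    assert(n >= 0);
    if (n <= 2) return a;
    return mouse(n - 1, a * c + 1, a, b);
}
```

Note that the direct form recomputes the same values many times, while the tail-recursive form makes one call per step.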
Addenda
In comments, two questions were raised. I'll try to answer them (and one more) here.
First, it should be clear (from a consideration of the underlying machine architecture, which has no concept of function calling) that any function call can be rephrased as a goto (possibly with unbounded intermediate storage); furthermore, any goto can be expressed as a tail-call. So it is possible (but not necessarily pretty) to rewrite any recursion as tail-recursion.
The usual mechanism is "continuation-passing style" which is a fancy way of saying that every time you want to call a function, you instead package the rest of the current function as a new function (the "continuation"), and pass that continuation to the called function. Since every function then receives a continuation as an argument, it has to finish any continuation it creates with a call to the continuation it received.
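As a rough illustration, here is factorial written in that style in C++. The names (`Cont`, `factorial_cps`, `factorial_via_cps`) are mine, and this shows only the shape of the transformation: C++ does not guarantee tail-call elimination, and since each closure is a `std::function` whose destructor runs after the call, a compiler is unlikely to turn these into real jumps.

```cpp
#include <cassert>
#include <functional>

// A continuation: "the rest of the computation", waiting for an int.
using Cont = std::function<int(int)>;

// Every call here is the last thing the function does.
// The pending multiplication is packaged into a new continuation
// instead of being left on the call stack.
int factorial_cps(int n, Cont k) {
    assert(n >= 0);
    if (n <= 1) return k(1);
    return factorial_cps(n - 1, [n, k](int r) { return k(n * r); });
}

// Entry point: the initial continuation just hands back the result.
int factorial_via_cps(int n) {
    return factorial_cps(n, [](int r) { return r; });
}
```

The chain of nested closures holds exactly the information the call stack would have held in the ordinary recursive version.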
That's probably enough to make your head spin, so I'll put it another way: instead of pushing arguments and a return location onto the stack and calling a function (which will later return), you push arguments and a continuation location onto the stack and goto a function, which will later goto the continuation location. In short, you simply make the stack an explicit parameter, and then you never need to return. This style of programming is common in event-driven code (see Python Twisted), and it's a real pain to write (and read). So I strongly recommend letting compilers do this transformation for you, if you can find one which will do that.
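A minimal sketch of the "explicit stack" reading, again using factorial: the pending work is pushed onto a `std::vector` that is passed along, and both helper functions end in a tail call. The names `descend`, `unwind` and `factorial_explicit` are hypothetical, chosen just for this example.

```cpp
#include <cassert>
#include <vector>

// Second phase: pop pending multiplications and apply them.
// Plays the role of "goto the continuation location".
int unwind(int acc, std::vector<int>& pending) {
    if (pending.empty()) return acc;
    int n = pending.back();
    pending.pop_back();
    return unwind(n * acc, pending);
}

// First phase: instead of calling and later returning,
// record what remains to be done and move on.
int descend(int n, std::vector<int>& pending) {
    if (n <= 1) return unwind(1, pending);
    pending.push_back(n);
    return descend(n - 1, pending);
}

int factorial_explicit(int n) {
    assert(n >= 0);
    std::vector<int> pending;  // the explicit stack
    return descend(n, pending);
}
```

The vector grows exactly as deep as the call stack would have in the ordinary recursion, so the storage hasn't gone away; it has only moved somewhere you can see it.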
@xxmouse suggested that I pulled the recursion equation out of a hat, and asked how it was derived. It's simply the original recursion, but reformulated as a transformation of a single tuple:
f_n = f_{n-1} * f_{n-3} + 1
=>
F_n = <F_{n-1,1} * F_{n-1,3} + 1, F_{n-1,1}, F_{n-1,2}>
(where F_{n-1,i} is the i-th component of the tuple F_{n-1})
I don't know if that's any clearer, but it's the best I can do. Look at the Fibonacci example for a slightly simpler case.
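If code reads more easily than subscripts, the same idea can be sketched with the tuple made explicit. Here `State`, `step` and `mouse_via_tuple` are names I've invented; `step` is the transformation from one tuple to the next, and the recursion simply applies it n-2 times to the starting tuple <1, 1, 1>.

```cpp
#include <array>
#include <cassert>

// The three most recent values of the sequence, newest first.
using State = std::array<long long, 3>;

// One application of the transformation on tuples:
// <a, b, c>  =>  <a*c + 1, a, b>
State step(const State& s) {
    return State{{s[0] * s[2] + 1, s[0], s[1]}};
}

// Tail recursion over the tuple; the first component is the answer.
long long mouse_via_tuple(int n, State s = State{{1, 1, 1}}) {
    assert(n >= 0);
    if (n <= 2) return s[0];
    return mouse_via_tuple(n - 1, step(s));
}
```

Spreading the three components of `State` out into separate parameters gives back the `mouse` function shown earlier.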
@j_random_hacker asks what the limits on this transformation are. It works for a recursive sequence where each element can be expressed by some formula of the previous k elements, where k is a constant. There are other ways to produce a tail-call recursion. For example:
// For didactic purposes only
bool is_odd(int n) { return n % 2 != 0; }
int power(int x, int n, int acc=1) {
    if (n == 0) return acc;
    else if (is_odd(n)) return power(x, n-1, acc*x);
    else return power(x*x, n/2, acc);
}
The above is not the same as the usual non-tail-call recursion, which performs a different (but equivalent and equally long) sequence of multiplications:
int squared(int n) { return n * n; }
int power(int x, int n) {
    if (n == 0) return 1;
    else if (is_odd(n)) return x * power(x, n-1);
    else return squared(power(x, n/2));
}
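To check the "equally long" claim, here is a sketch with instrumented copies of the two power functions. The names `power_tail` and `power_rec` and the `mults` counter are my additions (the two versions can't both be called `power` in one program, since a two-argument call would be ambiguous), and I've used `long long` to leave a little more headroom before overflow.

```cpp
#include <cassert>

// Tail-call version; `mults` counts the multiplications performed.
long long power_tail(long long x, int n, int& mults, long long acc = 1) {
    assert(n >= 0);
    if (n == 0) return acc;
    ++mults;  // each non-zero n costs exactly one multiplication
    if (n % 2 != 0) return power_tail(x, n - 1, mults, acc * x);
    return power_tail(x * x, n / 2, mults, acc);
}

// Usual non-tail-call version, counted the same way.
long long power_rec(long long x, int n, int& mults) {
    assert(n >= 0);
    if (n == 0) return 1;
    ++mults;
    if (n % 2 != 0) return x * power_rec(x, n - 1, mults);
    long long half = power_rec(x, n / 2, mults);
    return half * half;
}
```

For x = 3 and n = 5, both return 243 after four multiplications, but the intermediate products differ: the tail version squares x as it goes, while the other squares the partial results on the way back up.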
Thanks to Alexey Frunze for the following test:
Output (ideone):
mouse(0) = 1
mouse(1) = 1
mouse(2) = 1
mouse(3) = 2
mouse(4) = 3
mouse(5) = 4
mouse(6) = 9
mouse(7) = 28
mouse(8) = 113
mouse(9) = 1018