import java.util.List;
import java.util.LinkedList;
import java.util.Arrays;
import java.util.PriorityQueue;
/**
* A taxicab number is a number that can be written as a sum of the cubes
* of two positive integers in two different ways: a^3 + b^3 = c^3 + d^3,
* where the pairs (a, b) and (c, d) differ by more than just their order.
*
* The basic idea is to enumerate all pairs (i, j) with {@code 1 <= j <= i <= N},
* where N is the largest integer whose cube is below the bound, and find
* the sums i^3 + j^3 that occur more than once. An easy way to do this is
* to store every pair in an array, sort the array by sum and traverse it
* looking for adjacent pairs with equal sums.
*
* But we can do a lot better in terms of extra memory. Instead of storing
* all pairs up front, we seed a priority queue (PQ) with the diagonal pairs
* (1, 1), (2, 2), ..., (N, N) and then repeatedly remove the smallest pair,
* adding its successor to the queue on the fly.
*
* (1, 1) (1, 2) ... (1, N)
* (2, 1) (2, 2) ... (2, N)
* ...
* (N, 1) (N, 2) ... (N, N)
*
* Cell (i, j) holds the value i^3 + j^3. The values increase along every
* row and down every column, so each column j, read downwards from the
* diagonal cell (j, j), is already sorted. Since i^3 + j^3 is symmetric,
* only the cells on and below the main diagonal are needed: (1, 1), (2, 1),
* (2, 2), (3, 1), (3, 2) and so on, which cover every unordered pair
* exactly once.
*
* So after removing the smallest cell (i, j) from the PQ we add the cell
* one row lower in the same column, (i + 1, j), if it exists. This merges
* the N sorted columns: the PQ never holds more than N elements, and
* every cell on or below the main diagonal is visited in increasing order
* of its value.
*/
class TaxicabNumbers {
    /**
     * Finds every taxicab number strictly less than {@code n}.
     *
     * <p>All pairs (i, j) with {@code 1 <= j <= i} and {@code i^3 < n} are
     * stored in an array, sorted by cube sum, and scanned for adjacent pairs
     * with equal sums. With N being roughly the cube root of {@code n}, this
     * takes O(N^2 lg N) time and O(N^2) extra space.
     *
     * @param n exclusive upper bound on the reported taxicab numbers
     * @return the taxicab numbers below {@code n} in increasing order, each
     *         paired with its first two representations (a number with three
     *         or more representations is still reported only once)
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public static List<TaxicabNumber> taxicabNumbersSortVersion(int n) {
        requirePositive(n);
        int maxBase = maxBase(n);
        // only cells on and below the main diagonal (j <= i) are needed:
        // 1 + 2 + ... + maxBase = maxBase * (maxBase + 1) / 2 of them
        IntPair[] pairs = new IntPair[maxBase * (maxBase + 1) / 2];
        int next = 0;
        for (int i = 1; i <= maxBase; i++) {
            for (int j = 1; j <= i; j++) {
                pairs[next++] = new IntPair(i, j);
            }
        }
        // equal sums must be adjacent for the scan below to find them
        Arrays.sort(pairs);

        List<TaxicabNumber> taxicabNumbers = new LinkedList<>();
        // length of the current run of consecutive pairs sharing one sum
        int runLength = 0;
        IntPair prev = null;
        for (IntPair curr : pairs) {
            if (curr.sum >= n) {
                break; // sorted, so every remaining sum is at least n too
            }
            runLength = (prev != null && prev.sum == curr.sum) ? runLength + 1 : 1;
            // report each sum once, when its second representation shows up
            if (runLength == 2) {
                taxicabNumbers.add(new TaxicabNumber(prev, curr));
            }
            prev = curr;
        }
        return taxicabNumbers;
    }

    /**
     * Finds every taxicab number strictly less than {@code n}, producing the
     * same result as {@link #taxicabNumbersSortVersion(int)}.
     *
     * <p>Instead of materializing all pairs, the sorted columns of the
     * (i, j) matrix are merged with a priority queue seeded with the
     * diagonal cells (j, j); each polled cell (i, j) is replaced by
     * (i + 1, j). With N being roughly the cube root of {@code n}, this
     * takes O(N^2 lg N) time while the queue holds at most N pairs.
     *
     * @param n exclusive upper bound on the reported taxicab numbers
     * @return the taxicab numbers below {@code n} in increasing order, each
     *         paired with its first two representations
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public static List<TaxicabNumber> taxicabNumbersHeapVersion(int n) {
        requirePositive(n);
        int maxBase = maxBase(n);
        PriorityQueue<IntPair> pairs = new PriorityQueue<>();
        // the head (smallest cell) of every column is its diagonal cell
        for (int i = 1; i <= maxBase; i++) {
            pairs.offer(new IntPair(i, i));
        }

        List<TaxicabNumber> taxicabNumbers = new LinkedList<>();
        // length of the current run of consecutive pairs sharing one sum
        int runLength = 0;
        IntPair prev = null;
        while (!pairs.isEmpty()) {
            IntPair curr = pairs.poll();
            if (curr.sum >= n) {
                break; // pairs come out sorted, so nothing smaller remains
            }
            runLength = (prev != null && prev.sum == curr.sum) ? runLength + 1 : 1;
            // report each sum once, when its second representation shows up
            if (runLength == 2) {
                taxicabNumbers.add(new TaxicabNumber(prev, curr));
            }
            // replace the polled cell with the next cell of the same column
            if (curr.i < maxBase) {
                pairs.offer(new IntPair(curr.i + 1, curr.j));
            }
            prev = curr;
        }
        return taxicabNumbers;
    }

    /** Rejects non-positive bounds: there are no taxicab numbers below 1. */
    private static void requirePositive(int n) {
        if (n <= 0) {
            throw new IllegalArgumentException(
                    String.format("Expected: n > 0. Got: %d.", n));
        }
    }

    /**
     * Returns the largest {@code m >= 0} with {@code m^3 < n}; no pair that
     * contains a larger element can have a cube sum below {@code n}.
     */
    private static int maxBase(int n) {
        int base = 0;
        while (pow(base + 1, 3) < n) {
            base++;
        }
        return base;
    }

    /** Returns {@code x} raised to the non-negative power {@code n}. */
    private static long pow(long x, int n) {
        assert n >= 0;
        long pow = 1;
        for (int i = 0; i < n; i++) {
            pow *= x;
        }
        return pow;
    }

    /** A pair of positive integers (i, j) together with i^3 + j^3. */
    private static final class IntPair implements Comparable<IntPair> {
        private final long i;
        private final long j;
        private final long sum;

        public IntPair(long i, long j) {
            assert i > 0 && j > 0;
            this.i = i;
            this.j = j;
            this.sum = i * i * i + j * j * j;
        }

        /**
         * (a, b) &lt; (c, d) iff a^3 + b^3 &lt; c^3 + d^3, or, when the sums
         * are equal, iff a &lt; c.
         */
        @Override
        public int compareTo(IntPair other) {
            int bySum = Long.compare(this.sum, other.sum);
            return bySum != 0 ? bySum : Long.compare(this.i, other.i);
        }

        @Override
        public String toString() {
            return String.format("%d^3 + %d^3", i, j);
        }
    }

    /** A taxicab number together with two of its representations. */
    private static final class TaxicabNumber {
        private final IntPair representation1;
        private final IntPair representation2;

        public TaxicabNumber(IntPair representation1,
                             IntPair representation2) {
            assert representation1.sum == representation2.sum;
            this.representation1 = representation1;
            this.representation2 = representation2;
        }

        @Override
        public String toString() {
            return String.format("%s = %s = %d", representation1,
                    representation2, representation1.sum);
        }
    }

    public static void main(String[] args) {
        System.out.println(taxicabNumbersSortVersion(65_000));
        System.out.println(taxicabNumbersHeapVersion(65_000));
    }
}