Here I will focus only on finding the lexicographic `argmax` (the others: `max`, `argmin`, and `min` can be found trivially from `argmax`). In addition, unlike `np.argmax()`, we will return all rows that are at rank 0 (if there are duplicate rows), i.e. all the indices where the row is the lexicographic maximum.
The idea is that, for the "tuple-like order" desired here, the function is really:
- find all indices where the first column has the maximum;
- break ties using the second column, restricted to the rows where the first column was max;
- etc., as long as there are ties to break (and more columns).
    import numpy as np

    def ixmax(x, k=0, idx=None):
        # all argmax indices of column k, restricted to the rows in idx
        # (or all rows if idx is None)
        col = x[idx, k] if idx is not None else x[:, k]
        z = np.where(col == col.max())[0]
        # map positions within the subset back to row indices of x
        return z if idx is None else idx[z]

    def lexargmax(x):
        idx = None
        for k in range(x.shape[1]):
            idx = ixmax(x, k, idx)
            if len(idx) < 2:
                break  # unique max found: no more ties to break
        return idx
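For example (a small illustration of the tie-handling, mine, not from the tests below):

    x = np.array([[2, 1, 3],
                  [2, 3, 1],
                  [1, 9, 9],
                  [2, 3, 1]])
    lexargmax(x)  # -> array([1, 3]): both rows are the lexicographic maximum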
At first, I was worried that the explicit looping in Python would kill it. But it turns out that it is quite fast. In the case where there are no ties (more likely with independent `float` values, for instance), it returns immediately after a single `np.where(x[:, 0] == x[:, 0].max())`. Only in the case of ties do we need to look at the (much smaller) subset of rows that were tied. In unfavorable conditions (many repeated values in all columns), it is still ~100x faster or more than the partition method, and a factor of `O(log n)` faster than `lexsort()`, of course.
Test 1: correctness
    for i in range(1000):
        x = np.random.randint(0, 10, size=(1000, 8))
        found = lexargmax(x)
        assert lexargmax_by_sort(x) in found and np.unique(x[found], axis=0).shape[0] == 1
(where `lexargmax_by_sort` is `np.lexsort(x[:, ::-1].T)[-1]`)
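Spelled out as a function (equivalent to the one-liner above; note that `np.lexsort()` uses its *last* key as the primary key, hence the column reversal):

    def lexargmax_by_sort(x):
        # sort rows lexicographically (column 0 as the primary key)
        # and return the index of the last, i.e. largest, row
        return np.lexsort(x[:, ::-1].T)[-1]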
Test 2: speed
    x = np.random.randint(0, 10, size=(100_000, 100))

    a = %timeit -o lexargmax(x)
    # 776 μs ± 313 ns per loop

    b = %timeit -o lexargmax_by_sort(x)
    # 507 ms ± 2.65 ms per loop
    # b.average / a.average: 652

    c = %timeit -o lexargmax_by_partition(x)
    # 141 ms ± 2.38 ms
    # c.average / a.average: 182
(where `lexargmax_by_partition` is based on @MadPhysicist's very elegant idea:

    def lexargmax_by_partition(x):
        # view each row as a single structured ("tuple-like") element
        view = np.ndarray(x.shape[0], dtype=[('', x.dtype)] * x.shape[1], buffer=x)
        # partitioning with kth=-1 moves the maximum into the last position
        return np.argpartition(view, -1)[-1]

)
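As a quick sanity check (my addition, assuming the helpers above are defined): with ties, the two single-index methods may return different indices, but the rows themselves must be equal:

    x = np.random.randint(0, 10, size=(1000, 4))
    i = lexargmax_by_partition(x)
    j = lexargmax_by_sort(x)
    assert (x[i] == x[j]).all()  # same maximal row, possibly a different index among ties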
After some more testing on various sizes, we get the following time measurements and performance ratios:
In the LHS plot, `lexargmax` is the group shown with `'o-'` and `lexargmax_by_partition` is the upper group of lines. In the RHS plot, we just show the speed ratio.
Interestingly, the `lexargmax_by_partition` execution time seems fairly independent of `m`, the number of columns, whereas our `lexargmax` depends a little bit on it. I believe that reflects the fact that, in this setting (purposeful collisions of the max in each column), the more columns we have, the "deeper" we need to go when breaking ties.
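To see this effect directly, one could instrument the loop to report how many columns are examined before the tie is fully broken (a hypothetical helper, not part of the original answer):

    def lexargmax_depth(x):
        # same loop as lexargmax, but returns the number of columns visited
        idx = None
        for k in range(x.shape[1]):
            idx = ixmax(x, k, idx)
            if len(idx) < 2:
                break
        return k + 1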
Previous (wrong) answer
To find the `argmax` of the rows by lexicographic order, I was thinking you could do:

    def lexmax(x):
        # weight column k by 2**(m-1-k) and score each row by which
        # of its entries attain the (global) column maximum
        r = (2.0 ** np.arange(x.shape[1]))[::-1]
        return np.argmax(((x == x.max(axis=0)) * r).sum(axis=1))
Explanation: each row gets a score with one bit per column, set iff that row attains that column's (global) maximum; the bits are weighted by decreasing powers of two, so matches in earlier columns dominate.

But this is fatally flawed: the positions of the max in the second column should be considered only in the subset where the first column had the max value (to break the tie).
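A minimal counterexample (my illustration):

    x = np.array([[2, 0],
                  [1, 5],
                  [2, 1]])
    lexmax(x)  # -> 0, but the lexicographic max is row 2 ([2, 1] > [2, 0])

The column-1 maximum (5) sits in a row that already lost on column 0, so neither of the tied rows gets credit for column 1, and the tie is broken by position rather than by value.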
Other approaches, including affine transformations of all columns so that we can sum them and find the max, don't work either: if the max in column 0 is, say, 1.0, and there is a second place at 0.999, then we would have to know ahead of time that difference of 0.001 and make sure no combination of values from the columns to the right can sum up to overtake that difference. So, that's a dead end.
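To make that concrete (my illustration; any fixed weights `w`, chosen before seeing the data, can be defeated this way):

    x = np.array([[1.0,   0.0],
                  [0.999, 100.0]])
    w = np.array([1000.0, 1.0])  # hypothetical fixed column weights
    np.argmax(x @ w)  # -> 1, but the lexicographic max is row 0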