Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Tuple-like (lexographical) max in numpy

I find myself running into the following situation in numpy multiple times over the past couple of months, and I cannot imagine there is not a proper solution for it.

I have a 2d array, let's say

x = np.array([
 [1, 2, 3],
 [2, -5, .333],
 [1, 4, 2],
 [2, -5, 4]])

Now I would like to sort this array (or take its max, argsort, argmax, etc.) in such a way that the first column is compared first; if the first column is equal, the second column is compared, and then the third. So this means for our example:

# max like python: max(x.tolist())
np.tuple_like_amax(x) = np.array([2, -5, 4])
# argmax doesn't have a python equivalent, but something like: [i for i, e in enumerate(x.tolist()) if e == max(x.tolist())][0]
np.tuple_like_argmax(x) = 3

# sorting like python: sorted(x.tolist())
np.tuple_like_sort(x) = np.array([[1.0, 2.0, 3.0], [1.0, 4.0, 2.0], [2.0, -5.0, 0.333], [2.0, -5.0, 4.0]])

# argsort doesn't have python equivalent, but something like: sorted(range(len(x)), key=lambda i: x[i].tolist())
np.tuple_like_argsort(x) = np.array([0, 2, 1, 3])

This is exactly the way python compares tuples (so actually just calling max(x.tolist()) does the trick here for max). It does however feel like a waste of time and memory to first convert the array to a python list, and in addition I would like to use argmax, sort and all the other great numpy functions.

So just to be clear, I'm not interested in python code that mimics an argmax, but in something that achieves this without converting the array to python lists.

Found so far:

np.sort seems to work on structured arrays when order= is given. It does feel to me that creating a structured array and then using this method is overkill. Also, argmax doesn't seem to support order=, meaning that one would have to use argsort, which has a much higher complexity.
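For reference, a minimal sketch of that structured-array route (the field names 'f0', 'f1', ... are auto-generated by numpy for unnamed fields; this is an illustration of the "overkill" approach, not a recommendation):

```python
import numpy as np

x = np.array([
    [1, 2, 3],
    [2, -5, .333],
    [1, 4, 2],
    [2, -5, 4]])

# View each row as one structured record, so comparisons go field by
# field, i.e. lexicographically.  Unnamed fields become 'f0', 'f1', 'f2'.
dt = np.dtype([('', x.dtype)] * x.shape[1])
view = np.ascontiguousarray(x).view(dt).ravel()

order = np.argsort(view, order=['f0', 'f1', 'f2'])  # lexicographic argsort
print(order)        # -> [0 2 1 3]
print(order[-1])    # index of the lexicographic max -> 3
```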



1 Answer


Here I will focus only on finding the lexicographic argmax (the others: max, argmin, and min can be found trivially from argmax). In addition, unlike np.argmax(), we will return all rows that are at rank 0 (if there are duplicate rows), i.e. all the indices where the row is the lexicographic maximum.

The idea is that, for the "tuple-like order" desired here, the function is really:

  • find all indices where the first column is at its maximum;
  • break ties by keeping, among those, the indices where the second column is at its maximum;
  • etc., as long as there are ties to break (and columns remain).

def ixmax(x, k=0, idx=None):
    # column k, restricted to the candidate rows idx (or all rows at first)
    col = x[idx, k] if idx is not None else x[:, k]
    z = np.where(col == col.max())[0]
    # translate positions within the subset back into row indices of x
    return z if idx is None else idx[z]

def lexargmax(x):
    idx = None
    for k in range(x.shape[1]):
        idx = ixmax(x, k, idx)
        if len(idx) < 2:   # no tie left to break
            break
    return idx

At first, I was worried that the explicit looping in Python would kill performance. But it turns out to be quite fast. In the case where there are no ties (more likely with independent float values, for instance), it returns immediately after a single np.where(x[:, 0] == x[:, 0].max()). Only in the case of ties do we need to look at the (much smaller) subset of tied rows. In unfavorable conditions (many repeated values in all columns), it is still ~100x or more faster than the partition method, and faster than lexsort() by a further O(log n) factor, of course.
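For instance, on the array from the question (repeating the definitions above so the snippet is self-contained):

```python
import numpy as np

def ixmax(x, k=0, idx=None):
    # indices where column k is at its max, within the candidate rows idx
    col = x[idx, k] if idx is not None else x[:, k]
    z = np.where(col == col.max())[0]
    return z if idx is None else idx[z]

def lexargmax(x):
    idx = None
    for k in range(x.shape[1]):
        idx = ixmax(x, k, idx)
        if len(idx) < 2:   # no tie left to break
            break
    return idx

x = np.array([
    [1, 2, 3],
    [2, -5, .333],
    [1, 4, 2],
    [2, -5, 4]])

i = lexargmax(x)
print(i)        # -> [3]
print(x[i[0]])  # the lexicographic maximum row: [2., -5., 4.]
```

The first column ties rows 1 and 3 (both 2), the second column leaves the tie in place (both -5), and the third column resolves it in favor of row 3.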

Test 1: correctness

for i in range(1000):
    x = np.random.randint(0, 10, size=(1000, 8))
    found = lexargmax(x)
    assert lexargmax_by_sort(x) in found and np.unique(x[found], axis=0).shape[0] == 1

(where lexargmax_by_sort(x) is np.lexsort(x[:, ::-1].T)[-1])
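For completeness, that reference function spelled out (np.lexsort treats its last key as the primary one, hence the column reversal):

```python
import numpy as np

def lexargmax_by_sort(x):
    # lexsort sorts by the LAST key first, so feed the columns reversed:
    # keys = (last column, ..., first column) -> column 0 is the primary key.
    return np.lexsort(x[:, ::-1].T)[-1]
```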

Test 2: speed

x = np.random.randint(0, 10, size=(100_000, 100))

a = %timeit -o lexargmax(x)
# 776 μs ± 313 ns per loop

b = %timeit -o lexargmax_by_sort(x)
# 507 ms ± 2.65 ms per loop
# b.average / a.average: 652
 
c = %timeit -o lexargmax_by_partition(x)
# 141 ms ± 2.38 ms
# c.average / a.average: 182

(where lexargmax_by_partition is based on @MadPhysicist's very elegant idea:

def lexargmax_by_partition(x):
    # view each row as a single structured scalar, so that comparisons
    # between elements of `view` are lexicographic across the columns
    view = np.ndarray(x.shape[0], dtype=[('', x.dtype)] * x.shape[1], buffer=x)
    return np.argpartition(view, -1)[-1]

)
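As a quick sanity check of the view trick on the question's array (note that the buffer= construction assumes x is C-contiguous):

```python
import numpy as np

x = np.ascontiguousarray(np.array([
    [1., 2., 3.],
    [2., -5., .333],
    [1., 4., 2.],
    [2., -5., 4.]]))

# Reinterpret each row as one structured record; comparing two records
# compares field by field, i.e. lexicographically.
view = np.ndarray(x.shape[0], dtype=[('', x.dtype)] * x.shape[1], buffer=x)

# argpartition with kth=-1 moves the index of the maximum to the last slot
print(np.argpartition(view, -1)[-1])  # -> 3, the lexicographic argmax
```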

After some more testing on various sizes, we get the following time measurements and performance ratios:

[image: log-log plots of execution time (left) and speed ratio relative to lexargmax (right) for various array sizes]

In the LHS plot, lexargmax is the group shown with 'o-' and lexargmax_by_partition is the upper group of lines.

In the RHS plot, we just show the speed ratio.

Interestingly, lexargmax_by_partition execution time seems fairly independent of m, the number of columns, whereas our lexargmax depends a little bit on it. I believe that reflects the fact that, in this setting (purposeful collisions of the max in each column), the more columns we have, the "deeper" we need to go when breaking ties.

Previous (wrong) answer

To find the argmax of the row by lexicographic order, I was thinking you could do:

def lexmax(x):
    r = (2.0 ** np.arange(x.shape[1]))[::-1]
    return np.argmax(((x == x.max(axis=0)) * r).sum(axis=1))

Explanation:

  • x == x.max(axis=0) (as an int) is 1 for each element that is equal to its column's max. In your example, it is (as astype(int)):

        [[0 0 0]
         [1 0 0]
         [0 1 0]
         [1 0 1]]

  • then we multiply by a column weight that is more than the sum of 1's on the right. Powers of two achieve that. We do it in float to address cases with more than 64 columns.

But this is fatally flawed: the positions of the max in the second column should only be considered in the subset of rows where the first column attains its max (in order to break the tie).

Other approaches, including affine transformations of all columns so that we can sum them and find the max, don't work either: if the max in column 0 is, say, 1.0, and the runner-up is 0.999, then we would have to know that 0.001 difference ahead of time and make sure no combination of values from the columns to the right can sum up to overtake it. So, that's a dead end.
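To make the flaw concrete, here is a small hypothetical counterexample where the weighted-indicator score picks the wrong row:

```python
import numpy as np

def lexmax(x):
    # the flawed weighted-indicator argmax from above
    r = (2.0 ** np.arange(x.shape[1]))[::-1]
    return np.argmax(((x == x.max(axis=0)) * r).sum(axis=1))

x = np.array([
    [1, 9],   # holds the max of column 1, but loses on column 0
    [2, 0],
    [2, 5]])

print(lexmax(x))                     # -> 1, wrong
print(np.lexsort(x[:, ::-1].T)[-1])  # -> 2, correct: (2, 5) > (2, 0)
```

Row 2's win in column 1 within the {row 1, row 2} tie is invisible to the indicator, because row 0 holds the global column-1 max.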

