Faced an interesting problem recently.

Given a of shape (B, S, T) and b of shape (B, C), where every entry satisfies 0 <= b[i, j] < S:

What I want is an array of shape (B, C, T) whose element [i, j] is row b[i, j] of a[i], i.e. out[i, j] = a[i, b[i, j]].
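To pin the definition down, here is a minimal reference sketch using plain Python loops. The function name `gather_rows_loop` and the small test arrays are hypothetical, for illustration only; the loop just copies row `b[i, j]` of batch `a[i]` into slot `[i, j]` of the output.

```python
import numpy as np

def gather_rows_loop(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Reference version (hypothetical helper): out[i, j] = a[i, b[i, j]]."""
    B, S, T = a.shape
    B2, C = b.shape
    assert B == B2, "a and b must share the batch dimension B"
    out = np.empty((B, C, T), dtype=a.dtype)
    for i in range(B):
        for j in range(C):
            # b[i, j] picks one of the S rows of a[i]; the whole length-T row is copied
            out[i, j] = a[i, b[i, j]]
    return out

# Small example with B=2, S=3, T=2, C=2
a = np.arange(12).reshape(2, 3, 2)
b = np.array([[2, 0],
              [1, 1]])
print(gather_rows_loop(a, b))
```

This is slow for large arrays, but it is handy as a ground truth to check a vectorized version against.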

    a = np.array(
        [[[0, 1, 2, 3],
          [4, 5, 6, 7],
          [8, 9, 10, 11]],
         [[0, 1, 2, 3],
          [4, 5, 6, 7],
          [8, 9, 10, 11]]])
    b = np.array(
        [[0, 2, 2],
         [1, 0, 2]])

    a.shape
    Out[79]: (2, 3, 4)
    b.shape
    Out[80]: (2, 3)

What I expect is this:

    array([[[ 0,  1,  2,  3],
            [ 8,  9, 10, 11],
            [ 8,  9, 10, 11]],
           [[ 4,  5,  6,  7],
            [ 0,  1,  2,  3],
            [ 8,  9, 10, 11]]])

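Note this is different from the typical scenario: each batch i uses its own row of indices b[i], so a plain a[:, b] does not give the answer (it applies every row of b to every batch and returns shape (B, B, C, T)).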
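A quick sketch of why the plain `a[:, b]` is not enough. It uses a hypothetical `a = np.arange(24).reshape(2, 3, 4)` instead of the example above, so that the two batches differ and the effect is visible. The wanted array is only the "diagonal" `naive[i, i]` of the oversized result, where batch i meets its own row `b[i]`.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)    # (B, S, T) = (2, 3, 4), batches differ
b = np.array([[0, 2, 2], [1, 0, 2]])  # (B, C) = (2, 3)

# Indexing axis 1 with all of b applies every row of b to every batch,
# so an extra batch axis appears: naive[i, k] = a[i, b[k]], shape (B, B, C, T).
naive = a[:, b]
print(naive.shape)

# Only the entries naive[i, i] pair batch i with its own indices b[i].
# Building the full (B, B, C, T) array first is wasteful; this is only a demonstration.
wanted = np.stack([naive[i, i] for i in range(a.shape[0])])
print(wanted.shape)
```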

Initially I hit some issues with integer index broadcasting, but it turns out it can be done by indexing the batch axis with a column vector of batch indices:

    a[np.array([np.arange(2)]).T, b]   # batch indices of shape (2, 1) broadcast against b of shape (2, 3)
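The same trick can be written without hard-coding the batch size. `np.arange(B)[:, None]` is the same (B, 1) column as `np.array([np.arange(2)]).T`; it broadcasts against b's (B, C) shape, so position [i, j] of the combined index pairs batch i with row `b[i, j]`, and the untouched last axis T is carried along. As a cross-check, `np.take_along_axis` along axis 1 gives the same result when b gets a trailing length-1 axis. The function name `gather_rows` is hypothetical; this is a sketch, not the only way to do it.

```python
import numpy as np

def gather_rows(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return out of shape (B, C, T) with out[i, j] = a[i, b[i, j]]."""
    batch = np.arange(a.shape[0])[:, None]  # shape (B, 1), same as np.array([np.arange(B)]).T
    # batch (B, 1) and b (B, C) broadcast to (B, C); axis T is appended unchanged
    return a[batch, b]

a = np.array([[[0, 1, 2, 3],
               [4, 5, 6, 7],
               [8, 9, 10, 11]],
              [[0, 1, 2, 3],
               [4, 5, 6, 7],
               [8, 9, 10, 11]]])
b = np.array([[0, 2, 2],
              [1, 0, 2]])

out = gather_rows(a, b)

# Alternative: b[:, :, None] has shape (B, C, 1); its length-1 last axis
# broadcasts against T inside take_along_axis.
alt = np.take_along_axis(a, b[:, :, None], axis=1)

print(out.shape)
print(np.array_equal(out, alt))
```

The explicit-index form makes the broadcasting visible; `take_along_axis` hides the batch `arange` but needs the index array to have the same number of dimensions as `a`.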
