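I faced an interesting problem recently. Given two arrays

a : (B, S, T)
b : (B, C), where 0 <= b[i, j] < S

I want an array of shape (B, C, T) with out[i, j] = a[i, b[i, j]]. In other words, row i of b selects rows from batch a[i] only. For example: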
import numpy as np

a = np.array([[[0, 1, 2, 3],
               [4, 5, 6, 7],
               [8, 9, 10, 11]],
              [[0, 1, 2, 3],
               [4, 5, 6, 7],
               [8, 9, 10, 11]]])
b = np.array([[0, 2, 2],
              [1, 0, 2]])

a.shape  # (2, 3, 4)
b.shape  # (2, 3)
What I expect is this:
array([[[ 0,  1,  2,  3],
        [ 8,  9, 10, 11],
        [ 8,  9, 10, 11]],

       [[ 4,  5,  6,  7],
        [ 0,  1,  2,  3],
        [ 8,  9, 10, 11]]])
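Note that this is different from the typical scenario: plain fancy indexing such as a[:, b] applies every row of b to every batch of a and returns shape (B, B, C, T), whereas here row i of b should only index a[i].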
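For contrast, a minimal sketch of what plain fancy indexing on axis 1 produces, assuming that a[:, b] is the "typical scenario" meant here. The wanted result is exactly the diagonal of that larger array, where the batch of a matches the row of b.

```python
import numpy as np

a = np.tile(np.arange(12).reshape(3, 4), (2, 1, 1))  # (B, S, T) = (2, 3, 4)
b = np.array([[0, 2, 2],
              [1, 0, 2]])                            # (B, C) = (2, 3)

# Slice on axis 0 + index array on axis 1: every row of b hits every batch of a
naive = a[:, b]
print(naive.shape)  # (2, 2, 3, 4), i.e. (B, B, C, T)

# naive[i, k, j] == a[i, b[k, j]]; only the entries with k == i are wanted
diag = np.stack([naive[i, i] for i in range(b.shape[0])])
print(diag.shape)  # (2, 3, 4)
```

Extracting the diagonal this way computes B times more data than needed, which is why a direct gather is preferable.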
Initially I ran into issues with integer-array index broadcasting, but it turns out it can be done by passing a column of batch indices alongside b:
a[np.arange(a.shape[0])[:, None], b]  # general form of a[np.array([np.arange(2)]).T, b]
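The indexing works because NumPy broadcasts all integer index arrays against each other before gathering. The batch index has shape (B, 1) and b has shape (B, C), so both broadcast to (B, C), giving out[i, j] = a[i, b[i, j]]. The unindexed trailing T axis is carried along unchanged. The sketch below makes that broadcast visible and, as an alternative not mentioned above, cross-checks the result against np.take_along_axis (available since NumPy 1.15), which needs b to get a trailing length-1 axis so it broadcasts over T.

```python
import numpy as np

a = np.tile(np.arange(12).reshape(3, 4), (2, 1, 1))  # (B, S, T) = (2, 3, 4)
b = np.array([[0, 2, 2],
              [1, 0, 2]])                            # (B, C) = (2, 3)

rows = np.arange(a.shape[0])[:, None]                # (B, 1): one batch index per row
rows_b = np.broadcast_to(rows, b.shape)              # what NumPy pairs with b: (B, C)
print(rows_b)
# [[0 0 0]
#  [1 1 1]]

out = a[rows, b]                                     # (B, C, T)

# Alternative: indices must have the same ndim as a, so add a length-1 T axis
out_tal = np.take_along_axis(a, b[:, :, None], axis=1)

print(out.shape, np.array_equal(out, out_tal))       # (2, 3, 4) True
```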