I haven't yet found a fully vectorized PyTorch solution (the last step below still loops over the rows in Python). I'll keep looking, but here is a starting point:
```python
>>> import torch
>>> v, e = 10, 0
>>> v, e = torch.tensor([v]), torch.tensor([e])
>>> x = torch.tensor([[ 9, 6, 5, 4, 10],
...                   [ 8, 7, 3, 5, 5],
...                   [ 4, 9, 10, 10, 10],
...                   [10, 9, 7, 10, 2]])
```
To handle the edge case where `v` is not found in a row, you can append a temporary column filled with `v` to `x`. This guarantees that every row contains `v`. We will use `x_` as a helper tensor:
```python
>>> x_ = torch.cat([x, v.repeat(x.size(0))[:, None]], dim=1)
>>> x_
tensor([[ 9,  6,  5,  4, 10, 10],
        [ 8,  7,  3,  5,  5, 10],
        [ 4,  9, 10, 10, 10, 10],
        [10,  9,  7, 10,  2, 10]])
```
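Find the index of the first `v` in each row. Calling `argmax` on the 0/1 mask gives the position of the first match, because when several values tie for the maximum PyTorch returns the first index: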
```python
>>> bp = (x_ == v).int().argmax(dim=1)
>>> bp
tensor([4, 5, 2, 0])
```
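A small check of why the helper column matters: `torch.argmax` returns the first maximal index, and a row whose mask is all zeros has its first maximum at index 0. So without `x_`, a row that lacks `v` looks exactly like a row that starts with `v`. The two rows below are made up for illustration:

```python
import torch

v = 10
x = torch.tensor([[10, 9, 7, 10, 2],   # v at index 0
                  [ 8, 7, 3,  5, 5]])  # v missing entirely

# Without the helper column, both rows report index 0.
no_helper = (x == v).int().argmax(dim=1)

# With it, the row that lacks v points at the appended column instead.
x_ = torch.cat([x, torch.full((x.size(0), 1), v, dtype=x.dtype)], dim=1)
with_helper = (x_ == v).int().argmax(dim=1)

print(no_helper.tolist())    # [0, 0]
print(with_helper.tolist())  # [0, 5]
```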
Finally, because the insertion point differs from row to row, the easiest approach is a list comprehension that splits each row at its index and puts `e` in between:
```python
>>> torch.stack([torch.cat([xi[:bpi], e, xi[bpi:]]) for xi, bpi in zip(x, bp)])
tensor([[ 9,  6,  5,  4,  0, 10],
        [ 8,  7,  3,  5,  5,  0],
        [ 4,  9,  0, 10, 10, 10],
        [ 0, 10,  9,  7, 10,  2]])
```
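One possible way to drop the Python loop, sketched but not benchmarked: instead of splitting each row, compute for every output column which input column it should copy from (column `j` before the insertion point, column `j - 1` after it), fetch those values with `torch.gather`, and then overwrite the insertion point with `e` via `torch.where`. `insert_before_first` is a hypothetical helper name; it reuses the same sentinel-column idea as `x_`:

```python
import torch

def insert_before_first(x: torch.Tensor, v: int, e: int) -> torch.Tensor:
    """Insert e before the first v in each row of a 2-D tensor x.

    Hypothetical helper. Rows without v get e appended at the end,
    matching the x_ trick: the sentinel column makes every row contain v.
    """
    n_rows, n_cols = x.shape
    sentinel = torch.full((n_rows, 1), v, dtype=x.dtype, device=x.device)
    # Insertion index per row, shape (n_rows, 1) for broadcasting.
    bp = (torch.cat([x, sentinel], dim=1) == v).int().argmax(dim=1)[:, None]

    # Output has one extra column; output column j copies x[:, j] before the
    # insertion point and x[:, j - 1] after it. At j == bp the gathered value
    # is a placeholder (clamp keeps it a valid index) and is replaced below.
    cols = torch.arange(n_cols + 1, device=x.device).expand(n_rows, -1)
    src = torch.where(cols < bp, cols, cols - 1).clamp(min=0)
    out = torch.gather(x, 1, src)

    # Put e at the insertion point of each row.
    return torch.where(cols == bp, torch.full_like(out, e), out)

x = torch.tensor([[ 9, 6, 5, 4, 10],
                  [ 8, 7, 3, 5, 5],
                  [ 4, 9, 10, 10, 10],
                  [10, 9, 7, 10, 2]])
print(insert_before_first(x, v=10, e=0))  # same rows as the list-comprehension output above
```

The trade-off is an extra `(n_rows, n_cols + 1)` index tensor in exchange for avoiding one `torch.cat` per row.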
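Edit: if `v` can never occur in the first position, there is no need for `x_`. A row without `v` gives `argmax` 0, so after subtracting 1 its index becomes -1 and `e` is inserted just before the last element; a `v` in the first column would also give -1, which is why that case has to be excluded. Note also that, because of the `- 1`, `e` ends up one position earlier than in the first version: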
```python
>>> x = x[:3]  # the fourth row starts with v, so drop it here
>>> x
tensor([[ 9,  6,  5,  4, 10],
        [ 8,  7,  3,  5,  5],
        [ 4,  9, 10, 10, 10]])
>>> bp = (x == v).int().argmax(dim=1) - 1
>>> torch.stack([torch.cat([xi[:bpi], e, xi[bpi:]]) for xi, bpi in zip(x, bp)])
tensor([[ 9,  6,  5,  0,  4, 10],
        [ 8,  7,  3,  5,  0,  5],
        [ 4,  0,  9, 10, 10, 10]])
```
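Since this shortcut silently produces a wrong index when `v` sits in column 0, one option is to check that precondition before using it. This is a sketch; `insert_shifted` is a hypothetical name, and raising an error (rather than falling back to the `x_` version automatically) is just one possible choice:

```python
import torch

def insert_shifted(x: torch.Tensor, v: int, e: int) -> torch.Tensor:
    """The Edit variant, with its precondition checked explicitly.

    Hypothetical helper: a v in column 0 would give argmax - 1 == -1,
    indistinguishable from a row that contains no v at all.
    """
    if (x[:, 0] == v).any():
        raise ValueError("v occurs in the first column; use the x_ version instead")
    e_t = torch.tensor([e], dtype=x.dtype)
    bp = (x == v).int().argmax(dim=1) - 1
    # Negative bp (-1) slices as "before the last element".
    return torch.stack([torch.cat([xi[:bpi], e_t, xi[bpi:]]) for xi, bpi in zip(x, bp)])

x = torch.tensor([[ 9, 6, 5, 4, 10],
                  [ 8, 7, 3, 5, 5],
                  [ 4, 9, 10, 10, 10]])
print(insert_shifted(x, v=10, e=0))  # matches the Edit output above
```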