Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share



PyTorch: How to insert before a certain element

I have a 2D tensor. For each row, I want to insert a new element e before the first occurrence of a specified value v. There is no guarantee that every row contains v; if a row doesn't, e should simply be appended to the end of that row.

Example: suppose e is 0 and v is 10. Given the tensor

[[9, 6, 5, 4, 10],
 [8, 7, 3, 5, 5],
 [4, 9, 10, 10, 10]]

I want to get

[[9, 6, 5, 4, 0, 10],
 [8, 7, 3, 5, 5, 0],
 [4, 9, 0, 10, 10, 10]]

Is there an idiomatic, Torch-style way to do this? In the worst case I can treat it as a plain Python problem, but I expect that solution to be rather slow.
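For comparison, the plain-Python baseline mentioned above could look like the following minimal sketch. It assumes the rows are ordinary Python lists (a tensor can be converted with `x.tolist()` and back with `torch.tensor`); `insert_before_first_py` is a hypothetical name.

```python
def insert_before_first_py(rows, v, e):
    """Insert e before the first v in each row; append e if the row has no v."""
    result = []
    for row in rows:
        # list.index returns the first match; fall back to len(row) so e is appended
        i = row.index(v) if v in row else len(row)
        result.append(row[:i] + [e] + row[i:])
    return result

rows = [[9, 6, 5, 4, 10],
        [8, 7, 3, 5, 5],
        [4, 9, 10, 10, 10]]
print(insert_before_first_py(rows, v=10, e=0))
# [[9, 6, 5, 4, 0, 10], [8, 7, 3, 5, 5, 0], [4, 9, 0, 10, 10, 10]]
```

Every output row has exactly one more element than its input row, so the result can always be turned back into a rectangular tensor.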

Question from: https://stackoverflow.com/questions/65932919/pytorch-how-to-insert-before-a-certain-element


1 Reply


I haven't found a fully vectorized PyTorch solution yet. I'll keep looking, but here is a starting point:

>>> import torch
>>> v, e = 10, 0
>>> v, e = torch.tensor([v]), torch.tensor([e])

>>> x = torch.tensor([[ 9,  6,  5,  4, 10],
                      [ 8,  7,  3,  5,  5],
                      [ 4,  9, 10, 10, 10],
                      [10,  9,  7, 10,  2]])

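To handle the edge case where v does not appear in a row, you can append a temporary column filled with v to x. This guarantees that every row contains v, and when the first v of a row is the one in the extra column, e ends up appended to the end of that row. We will use x_ as a helper tensor: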

>>> x_ = torch.cat([x, v.repeat(x.size(0))[:, None]], dim=1)
>>> x_
tensor([[ 9,  6,  5,  4, 10, 10],
        [ 8,  7,  3,  5,  5, 10],
        [ 4,  9, 10, 10, 10, 10],
        [10,  9,  7, 10,  2, 10]])

Find the index of the first occurrence of v in each row (when there are several maximal values, argmax returns the index of the first one):

>>> bp = (x_ == v).int().argmax(dim=1)
>>> bp
tensor([4, 5, 2, 0])
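The helper column matters because argmax alone cannot tell a row without v apart from a row whose first v sits in column 0: both give 0. A minimal sketch of an alternative that skips x_, assuming one extra `any` reduction is acceptable, uses `mask.any(dim=1)` to detect the rows without v and sends them to position `x.size(1)` (i.e. "append"):

```python
import torch

x = torch.tensor([[ 9,  6,  5,  4, 10],
                  [ 8,  7,  3,  5,  5],
                  [ 4,  9, 10, 10, 10],
                  [10,  9,  7, 10,  2]])
v = 10

mask = x == v
# argmax gives the first True in each row, but also 0 for a row with no True at all
first = mask.int().argmax(dim=1)
print(first)  # tensor([4, 0, 2, 0])

# any() separates "no v in this row" from "v in column 0";
# rows without v get position x.size(1), so e would be appended there
found = mask.any(dim=1)
pos = torch.where(found, first, torch.full_like(first, x.size(1)))
print(pos)  # tensor([4, 5, 2, 0])
```

The resulting `pos` equals the `bp` computed from the helper tensor, so it can be used in its place.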

Finally, the easiest way to insert values at different positions in each row is with a list comprehension:

>>> torch.stack([torch.cat([xi[:bpi], e, xi[bpi:]]) for xi, bpi in zip(x, bp)])
tensor([[ 9,  6,  5,  4,  0, 10],
        [ 8,  7,  3,  5,  5,  0],
        [ 4,  9,  0, 10, 10, 10],
        [ 0, 10,  9,  7, 10,  2]])
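The list comprehension still loops over the rows in Python. As an unofficial sketch (not part of the original answer), the same result can be built without a Python loop: output column j reads input column j before the insertion point and column j - 1 after it, which `torch.gather` can do in one call. `insert_before_first` is a hypothetical helper name, and v and e are taken as plain Python scalars here rather than the one-element tensors used in the answer.

```python
import torch

def insert_before_first(x: torch.Tensor, v: int, e: int) -> torch.Tensor:
    """Insert e before the first v in each row of a 2D tensor; append e if the row has no v."""
    n_rows, n_cols = x.shape
    mask = x == v
    first = mask.int().argmax(dim=1)
    # Rows without v get position n_cols, i.e. e is appended at the end
    pos = torch.where(mask.any(dim=1), first, torch.full_like(first, n_cols))

    # Output column j reads input column j before the insertion point, j - 1 after it
    cols = torch.arange(n_cols + 1, device=x.device).expand(n_rows, -1)
    src = torch.where(cols < pos[:, None], cols, cols - 1).clamp(min=0)
    out = torch.gather(x, 1, src)

    # The insertion slot currently holds a copy of a neighbouring element; overwrite it with e
    out[cols == pos[:, None]] = e
    return out

x = torch.tensor([[ 9,  6,  5,  4, 10],
                  [ 8,  7,  3,  5,  5],
                  [ 4,  9, 10, 10, 10],
                  [10,  9,  7, 10,  2]])
print(insert_before_first(x, v=10, e=0))
```

On this input it reproduces the output of the list comprehension, including the row without 10 (e appended) and the row starting with 10 (e prepended).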

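Edit: if v can never occur in the first position, the helper x_ is not needed. A row without v then gives argmax 0, which can be remapped to the row length so that e is appended:

>>> x
tensor([[ 9,  6,  5,  4, 10],
        [ 8,  7,  3,  5,  5],
        [ 4,  9, 10, 10, 10]])

>>> bp = (x == v).int().argmax(dim=1)
>>> bp[bp == 0] = x.size(1)  # argmax is 0 only for rows that contain no v

>>> torch.stack([torch.cat([xi[:bpi], e, xi[bpi:]]) for xi, bpi in zip(x, bp)])
tensor([[ 9,  6,  5,  4,  0, 10],
        [ 8,  7,  3,  5,  5,  0],
        [ 4,  9,  0, 10, 10, 10]])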
