I want to use TensorFlow to reimplement PyTorch's torch.nn.functional.unfold function:
import torch
x = torch.randn(16, 1, 50, 36)  # random stand-in for input x: [16, 1, 50, 36]
x = torch.nn.functional.unfold(x, kernel_size=(5, 36), stride=3)
# output x: [16, 180, 16]
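To pin down what the PyTorch call produces, here is a plain NumPy sketch of unfold's semantics for this case. It assumes no padding and dilation 1, and `unfold_reference` is a hypothetical helper, not a PyTorch API. Each output column is one window position, flattened in (C, kh, kw) order, and positions are numbered row by row.

```python
import numpy as np

def unfold_reference(x, kernel_size, stride):
    """NumPy sketch of torch.nn.functional.unfold (no padding, dilation 1).

    x: array of shape (N, C, H, W)
    returns: array of shape (N, C * kh * kw, L), L = number of window positions.
    """
    kh, kw = kernel_size
    sh, sw = (stride, stride) if isinstance(stride, int) else stride
    n, c, h, w = x.shape
    out_h = (h - kh) // sh + 1
    out_w = (w - kw) // sw + 1
    cols = np.empty((n, c * kh * kw, out_h * out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            # One window: shape (N, C, kh, kw).
            patch = x[:, :, i * sh:i * sh + kh, j * sw:j * sw + kw]
            # Flatten it in (C, kh, kw) order into column i * out_w + j.
            cols[:, :, i * out_w + j] = patch.reshape(n, -1)
    return cols

x = np.random.rand(16, 1, 50, 36).astype(np.float32)
y = unfold_reference(x, kernel_size=(5, 36), stride=3)
print(y.shape)  # (16, 180, 16)
```

Here L = (50 - 5) // 3 + 1 = 16 positions along the height and only 1 along the width, since the kernel spans all 36 columns.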
I tried to use the function tf.extract_image_patches():
x = tf.extract_image_patches(x, ksizes=[1, 1, 5, 98], strides=[1, 1, 3, 1], rates=[1, 1, 1, 1], padding='VALID')
The input x.shape is [16, 1, 64, 98], and the output x.shape I get is [16, 1, 20, 490]. Then I reshape x to [16, 490, 20], which is the shape I expect.
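One thing worth checking in that last step: reshaping [16, 1, 20, 490] straight to [16, 490, 20] gives the right shape but not the right layout. The 20 patch positions are the slower-varying axis in the patch tensor, so a plain reshape interleaves values from different patches, whereas unfold keeps each patch in one column. Dropping the singleton axis and swapping the last two axes does keep them together. A small NumPy check, with an `arange` array standing in for the TF output:

```python
import numpy as np

# Stand-in for the [16, 1, 20, 490] patch tensor: 20 patches of 490 values each.
patches = np.arange(16 * 20 * 490).reshape(16, 1, 20, 490)

# Plain reshape: correct shape, but column 0 mixes values from several patches.
reshaped = patches.reshape(16, 490, 20)

# Drop the singleton axis, then swap axes so column k holds patch k.
transposed = np.transpose(patches[:, 0], (0, 2, 1))  # (16, 490, 20)

print(reshaped.shape, transposed.shape)  # (16, 490, 20) (16, 490, 20)
print(np.array_equal(transposed[:, :, 0], patches[:, 0, 0]))  # True
print(np.array_equal(reshaped[:, :, 0], patches[:, 0, 0]))    # False
```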
But I get this error when I feed in the data:
UnimplementedError (see above for traceback): Only support ksizes across space.
[[Node:hcn/ExtractImagePatches = ExtractImagePatches[T=DT_FLOAT, ksizes=[1, 1, 5, 98], padding="VALID", rates=[1, 1, 1, 1], strides=[1, 1, 3, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](hcn/Reshape)]]
How can I use TensorFlow to reimplement PyTorch's torch.nn.functional.unfold function to transform x this way?
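The error points at ksizes: extract_image_patches treats its input as NHWC ([batch, rows, cols, depth]) and only slides a window over rows and cols, so the first and last entries of ksizes must be 1. The [16, 1, 64, 98] tensor is in PyTorch's NCHW layout, so TF reads 98 as the depth axis and the 98 in ksizes lands on it. A minimal sketch of a workaround, assuming TF 2.x in eager mode and static shapes: move the channel axis last, slide a (5, 98) window over the two spatial axes, then reorder into unfold's [N, C*kh*kw, L] layout. `tf_unfold_nchw` is a hypothetical helper name, and padding and dilation are not handled. On TF 1.x, tf.extract_image_patches should accept the same values, passed as ksizes= instead of sizes=.

```python
import numpy as np
import tensorflow as tf

def tf_unfold_nchw(x, kernel_size, stride):
    """Sketch of torch.nn.functional.unfold for an NCHW tensor (no padding, dilation 1)."""
    kh, kw = kernel_size
    x = tf.convert_to_tensor(x)
    n, c, _, _ = x.shape.as_list()  # static shape assumed
    # TF patch extraction expects NHWC and only slides over the two spatial axes.
    x_nhwc = tf.transpose(x, [0, 2, 3, 1])
    patches = tf.image.extract_patches(
        x_nhwc,
        sizes=[1, kh, kw, 1],
        strides=[1, stride, stride, 1],
        rates=[1, 1, 1, 1],
        padding="VALID",
    )  # (N, out_h, out_w, kh * kw * C); each patch is flattened in (kh, kw, C) order
    _, out_h, out_w, _ = patches.shape.as_list()
    num_pos = out_h * out_w
    # Reorder each patch from (kh, kw, C) to PyTorch's (C, kh, kw).
    patches = tf.reshape(patches, [n, num_pos, kh, kw, c])
    patches = tf.transpose(patches, [0, 1, 4, 2, 3])
    patches = tf.reshape(patches, [n, num_pos, c * kh * kw])
    # Positions go last, as in unfold: (N, C * kh * kw, L).
    return tf.transpose(patches, [0, 2, 1])

x = np.random.rand(16, 1, 64, 98).astype(np.float32)
y = tf_unfold_nchw(x, kernel_size=(5, 98), stride=3)
print(y.shape.as_list())  # [16, 490, 20]
```

With a single input channel the (kh, kw, C) to (C, kh, kw) reordering changes nothing, but it keeps the sketch consistent with unfold when there are several channels. The final step is a transpose rather than a reshape, so each of the 20 columns holds one whole 5 x 98 window.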