Goal
My goal is to perform an expensive operation on a masked subset of elements and represent the remaining elements with zero. I'll start out with an example where I use sum in place of the expensive operation:
# shape: (3,3,2)
inp = [ [ [ 1, 2 ], [ 2, 3 ], [ 3, 4 ] ],
        [ [ 4, 5 ], [ 5, 6 ], [ 6, 7 ] ],
        [ [ 7, 8 ], [ 8, 9 ], [ 9, 0 ] ] ]

# shape: (3,3)
mask = [ [ 0, 1, 0 ],
         [ 1, 0, 0 ],
         [ 0, 0, 0 ] ]

# expected sum output:
# shape: (3,3,1)
out = [ [ [ 0 ], [ 5 ], [ 0 ] ],
        [ [ 9 ], [ 0 ], [ 0 ] ],
        [ [ 0 ], [ 0 ], [ 0 ] ] ]
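For reference, here is a plain NumPy version of that toy example (my own sketch, with sum still standing in for the expensive operation). It computes the sum everywhere and then zeroes out the unmasked cells, which is exactly the wasted work I want to avoid, but it produces the target output:

import numpy as np

inp = np.array([[[1, 2], [2, 3], [3, 4]],
                [[4, 5], [5, 6], [6, 7]],
                [[7, 8], [8, 9], [9, 0]]], dtype=np.float32)   # (3, 3, 2)
mask = np.array([[0, 1, 0],
                 [1, 0, 0],
                 [0, 0, 0]], dtype=np.float32)                 # (3, 3)

# (3, 3, 1): per-cell channel sum, zeroed wherever the mask is 0
out = np.sum(inp, axis=-1, keepdims=True) * mask[..., None]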
Progress So Far
I was able to get this working outside of a layer. A is my mask, E is my input data, and N is the number of elements along both the row and column axes. Partitioning has the side effect of flattening E, so I needed to reshape the result back into its original row/column dimensions, but with a new channel dimension.
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import tensorflow as tf
import numpy as np

def test1():
    N = 3

    testA = [[[0., 1., 0.],
              [1., 0., 0.],
              [0., 0., 0.]]]
    testE = [[[[1., 2.],
               [3.1, 4.1],
               [5., 6.], ],
              [[7.1, 8.1],
               [9., 1.],
               [9., 2.], ],
              [[8., 3.],
               [7., 4.],
               [6., 5.], ], ]]
    testA = np.asarray(testA).astype('float32')
    testE = np.asarray(testE).astype('float32')

    # dynamic_partition wants an int32 partition tensor
    part1 = tf.dynamic_partition( testE, tf.cast( testA, 'int32' ), 2 )
    print( len( part1 ) )
    print( part1[0] )
    print( part1[1] )
    """
    2
    tf.Tensor(
    [[1. 2.]
     [5. 6.]
     [9. 1.]
     [9. 2.]
     [8. 3.]
     [7. 4.]
     [6. 5.]], shape=(7, 2), dtype=float32)
    tf.Tensor(
    [[3.1 4.1]
     [7.1 8.1]], shape=(2, 2), dtype=float32)
    """

    sum1 = tf.math.reduce_sum( part1[1], axis=-1, keepdims=True )
    print( sum1 )
    """
    tf.Tensor(
    [[ 7.2     ]
     [15.200001]], shape=(2, 1), dtype=float32)
    """

    # Flat (row-major) positions of the mask's 0s and 1s, hand-coded for this 3x3 case
    indices1 = [
        [ 0 ],
        [ 2 ],
        [ 4 ],
        [ 5 ],
        [ 6 ],
        [ 7 ],
        [ 8 ],
    ]
    indices2 = [
        [ 1 ],
        [ 3 ],
    ]
    indices = [ indices1, indices2 ]
    partitioned_data = [
        np.zeros( shape=(7,1), dtype='float32' ),  # dtype must match sum1
        sum1
    ]
    stitch1_flat = tf.dynamic_stitch( indices, partitioned_data )
    print( stitch1_flat )
    """
    tf.Tensor(
    [ 0.        7.2       0.       15.200001  0.        0.        0.
      0.        0.      ], shape=(9,), dtype=float32)
    """

    stitch1 = tf.reshape( stitch1_flat, (N,N,1) )
    print( stitch1 )
    """
    tf.Tensor(
    [[[ 0.      ]
      [ 7.2     ]
      [ 0.      ]]
     [[15.200001]
      [ 0.      ]
      [ 0.      ]]
     [[ 0.      ]
      [ 0.      ]
      [ 0.      ]]], shape=(3, 3, 1), dtype=float32)
    """

    stitch1_np = stitch1.numpy()
    target = np.array([[[ 0.      ],
                        [ 7.2     ],
                        [ 0.      ]],
                       [[15.200001],
                        [ 0.      ],
                        [ 0.      ]],
                       [[ 0.      ],
                        [ 0.      ],
                        [ 0.      ]]])
    np.testing.assert_almost_equal( stitch1_np, target, decimal=3 )
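For what it's worth, those hand-coded indices1/indices2 are just the flat, row-major positions of the 0s and 1s in the mask. In NumPy terms (my own illustration, not part of the code above) they could be computed like this, and it's this computation I'm trying to reproduce with tensors inside the layer:

flat_mask = np.asarray( testA ).reshape( -1 )   # shape (9,)
indices1 = np.argwhere( flat_mask == 0 )        # [[0], [2], [4], [5], [6], [7], [8]]
indices2 = np.argwhere( flat_mask == 1 )        # [[1], [3]]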
Where I Need Help
I'm having a very hard time trying to generalize this into a Keras/TF layer. I was able to get the partition working, but I'm struggling to calculate the correct indices for stitching everything back up. I also expect to run into trouble calculating the size of the zeros tensor that forms the other element being stitched. Any help on either of these sticking points would be greatly appreciated! (I've put a rough sketch of the direction I've been trying after the second code block below.)
I'm also pretty bad at Python, so don't assume I'm doing anything unconventional on purpose; I probably just don't know any better.
Thanks in advance!
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import tensorflow as tf
import numpy as np

def test2():
    class TestFlat(tf.keras.layers.Layer):
        def __init__(self):
            super(TestFlat, self).__init__()
            self.N = -1  # size of row, column
            self.S = -1  # size of input channel

        def build(self, input_shape):
            print( "input_shape: ", input_shape )
            # [TensorShape([None, 3, 3]), TensorShape([None, 3, 3, 2])]
            assert( len( input_shape ) == 2 )
            assert( len( input_shape[0] ) == 3 )
            assert( len( input_shape[1] ) == 4 )
            assert( input_shape[0][1] == input_shape[1][1] )
            assert( input_shape[0][2] == input_shape[1][2] )
            self.N = input_shape[0][1]
            self.S = input_shape[1][3]

        def call(self, inputs):
            print( "inputs: ", inputs )
            # [<tf.Tensor 'A_in:0' shape=(None, 3, 3) dtype=float32>,
            #  <tf.Tensor 'E_in:0' shape=(None, 3, 3, 2) dtype=float32>]
            A = inputs[0]  # mask
            E = inputs[1]  # data

            A_int = tf.cast( A, "int32" )
            part = tf.dynamic_partition( E, A_int, 2 )
            print( len( part ) )
            print( part[0] )
            print( part[1] )
            """
            2
            tf.Tensor(
            [[1. 2.]
             [5. 6.]
             [9. 1.]
             [9. 2.]
             [8. 3.]
             [7. 4.]
             [6. 5.]], shape=(7, 2), dtype=float32)
            tf.Tensor(
            [[3.1 4.1]
             [7.1 8.1]], shape=(2, 2), dtype=float32)
            """

            sum1 = tf.math.reduce_sum( part[1], axis=-1, keepdims=True )
            # Okay so now we're done with the "expensive" calculation
            # and we just need to merge with zeros back into our target shape of (None,N,N,1)

            # Step 1: Calculate indices for stitching
            # none of the rest of this works:
            r = tf.range( self.N*self.N*self.S )     # ???
            #tf.shape( (None,self.N,self.N,1) )      # ???
            s = tf.shape( E )                        # ???
            aa = tf.Variable( s )                    # ???
            aa[-1].assign( 1 )                       # ???
            r = tf.reshape( r, s )                   # ???
            indices = tf.dynamic_partition( r, A_int, 2 )
            print( indices )
            """
            partitioned_data = [
                np.zeros( shape=(7,1) ),
                sum1
            ]
            """
            # Step 2: Create zero tensor
            # Step 3: Stitch sum1 with zero tensor
            return inputs[0]  # dummy for now

    N = 3
    S = 2
    A_in = Input(shape=(N, N), name='A_in')
    E_in = Input(shape=(N, N, S), name='E_in')
    out = TestFlat()( [A_in, E_in] )
    model = Model(inputs=[A_in, E_in], outputs=out)
    model.compile(optimizer='adam', loss='mean_squared_error')
    model.summary()

    testA = [[[0., 1., 0.],
              [1., 0., 0.],
              [0., 0., 0.]]]
    testE = [[[[1., 2.],
               [3.1, 4.1],
               [5., 6.], ],
              [[7.1, 8.1],
               [9., 1.],
               [9., 2.], ],
              [[8., 3.],
               [7., 4.],
               [6., 5.], ], ]]
    testA = np.asarray(testA).astype('float32')
    testE = np.asarray(testE).astype('float32')

    print( model([testA, testE]) )
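In case it clarifies where I'm trying to go, below is the rough, untested direction I've been sketching for Steps 1-3: partition a flat index range exactly the same way the data was partitioned (so the stitch indices come out computed rather than hand-coded), and size the zeros tensor from the runtime size of the "unmasked" partition. This is only eager-mode scratch code on the same toy data; I haven't been able to make the equivalent work inside call() with the None batch dimension, which is exactly what I'm asking about:

import tensorflow as tf
import numpy as np

def test_idea():
    N = 3
    testA = np.array([[[0., 1., 0.],
                       [1., 0., 0.],
                       [0., 0., 0.]]], dtype=np.float32)
    testE = np.array([[[[1., 2.], [3.1, 4.1], [5., 6.]],
                       [[7.1, 8.1], [9., 1.], [9., 2.]],
                       [[8., 3.], [7., 4.], [6., 5.]]]], dtype=np.float32)

    A_int = tf.cast( testA, tf.int32 )
    part = tf.dynamic_partition( testE, A_int, 2 )
    sum1 = tf.math.reduce_sum( part[1], axis=-1, keepdims=True )

    # Step 1: partition a flat index range the same way as the data,
    # so the stitch indices are computed instead of hand-coded.
    r = tf.reshape( tf.range( tf.size( A_int ) ), tf.shape( A_int ) )
    idx = tf.dynamic_partition( r, A_int, 2 )

    # Step 2: zeros for the unmasked partition, sized at runtime.
    zeros = tf.zeros( shape=(tf.shape( idx[0] )[0], 1), dtype=sum1.dtype )

    # Step 3: stitch back together and restore (batch, N, N, 1).
    out = tf.reshape( tf.dynamic_stitch( idx, [zeros, sum1] ), (-1, N, N, 1) )
    print( out )  # hopefully (1, 3, 3, 1) with 7.2 and 15.2 in the masked cells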