python - How to use tf.dynamic_partition and tf.dynamic_stitch with multiple dimensions and a change in shape?

Goal

My goal is to perform an expensive operation on a masked subset of elements and represent the remaining elements with zero. I'll start out with an example where I use sum in place of the expensive operation:

# shape: (3,3,2)
inp = [ [ [ 1, 2 ], [ 2, 3 ], [ 3, 4 ] ],
        [ [ 4, 5 ], [ 5, 6 ], [ 6, 7 ] ],
        [ [ 7, 8 ], [ 8, 9 ], [ 9, 0 ] ] ]

# shape: (3,3)
mask = [ [ 0, 1, 0 ],
         [ 1, 0, 0 ],
         [ 0, 0, 0 ] ]

# expected sum output:
# shape: (3,3,1)
out = [ [ [ 0 ], [ 5 ], [ 0 ] ],
        [ [ 9 ], [ 0 ], [ 0 ] ],
        [ [ 0 ], [ 0 ], [ 0 ] ] ]
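For reference, the same masked sum can be written directly in NumPy as a sanity check. This is only a sketch of the goal: it computes the sum densely for every cell and then masks, which is fine for a cheap operation like sum but defeats the purpose once the per-cell operation is genuinely expensive.

import numpy as np

inp = np.array([[[1, 2], [2, 3], [3, 4]],
                [[4, 5], [5, 6], [6, 7]],
                [[7, 8], [8, 9], [9, 0]]], dtype=np.float32)  # (3, 3, 2)
mask = np.array([[0, 1, 0],
                 [1, 0, 0],
                 [0, 0, 0]])                                  # (3, 3)

# Sum over the channel axis where mask == 1, zero elsewhere -> (3, 3, 1)
out = np.where(mask[..., None] == 1,
               inp.sum(axis=-1, keepdims=True),
               0.0)
print(out[:, :, 0])
# [[0. 5. 0.]
#  [9. 0. 0.]
#  [0. 0. 0.]]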

Progress So Far

I was able to get this working outside of a layer. A is my mask, E is my input data, and N is the number of elements along both the row and column axes.

Partitioning has the side effect of flattening E, so I need to reshape the result back into its original row/column dimensions, but with a new channel dimension (a minimal sketch of that behaviour follows; the full working code is after it).
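A minimal standalone sketch of the flattening (assuming a toy (3, 3, 2) tensor and an int mask, matching the shapes used below):

import tensorflow as tf

data = tf.reshape(tf.range(18, dtype=tf.float32), (3, 3, 2))  # (3, 3, 2)
mask = tf.constant([[0, 1, 0],
                    [1, 0, 0],
                    [0, 0, 0]])                               # (3, 3), int32

# dynamic_partition matches the mask against the first two axes of the data,
# so each partition comes back flattened to (num_selected_cells, channels).
parts = tf.dynamic_partition(data, mask, 2)
print(parts[0].shape)  # (7, 2) -- unmasked cells
print(parts[1].shape)  # (2, 2) -- masked cells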

from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import tensorflow as tf

import numpy as np

def test1():
    N = 3
    testA = [[[0., 1., 0.],
              [1., 0., 0.],
              [0., 0., 0.]]]
    testE = [[[[1., 2.],
               [3.1, 4.1],
               [5., 6.], ],

              [[7.1, 8.1],
               [9., 1.],
               [9., 2.], ],

              [[8., 3.],
               [7., 4.],
               [6., 5.], ], ]]
    
    testA = np.asarray(testA).astype('float32')
    testE = np.asarray(testE).astype('float32')
        
    part1 = tf.dynamic_partition( testE, testA, 2 )
    print( len( part1 ) )
    print( part1[0] )
    print( part1[1] )
    """
    2
    tf.Tensor(
    [[1. 2.]
    [5. 6.]
    [9. 1.]
    [9. 2.]
    [8. 3.]
    [7. 4.]
    [6. 5.]], shape=(7, 2), dtype=float32)
    tf.Tensor(
    [[3.1 4.1]
    [7.1 8.1]], shape=(2, 2), dtype=float32)
    """
    
    sum1 = tf.math.reduce_sum( part1[1], axis=-1, keepdims=True )
    print( sum1 )
    """
    tf.Tensor(
    [[ 7.2     ]
    [15.200001]], shape=(2, 1), dtype=float32)
    """
    
    indices1 = [
        [ 0 ],
        [ 2 ],
        [ 4 ],
        [ 5 ],
        [ 6 ],
        [ 7 ],
        [ 8 ],
        ]

    indices2 = [
        [ 1 ],
        [ 3 ],
        ]

    
    indices = [ indices1, indices2 ]
    
    partitioned_data = [
        np.zeros( shape=(7,1), dtype=np.float32 ),  # match sum1's dtype for dynamic_stitch
        sum1
    ]
    
    stitch1_flat = tf.dynamic_stitch( indices, partitioned_data )
    print( stitch1_flat )
    """
    tf.Tensor(
    [ 0.        7.2       0.       15.200001  0.        0.        0.
    0.        0.      ], shape=(9,), dtype=float32)
    """

    stitch1 = tf.reshape( stitch1_flat, (N,N,1) )
    print( stitch1 )
    """
    tf.Tensor(
    [[[ 0.      ]
    [ 7.2     ]
    [ 0.      ]]

    [[15.200001]
    [ 0.      ]
    [ 0.      ]]

    [[ 0.      ]
    [ 0.      ]
    [ 0.      ]]], shape=(3, 3, 1), dtype=float32)
    """

    stitch1_np = stitch1.numpy()
    target = np.array([[[ 0.      ],
                        [ 7.2     ],
                        [ 0.      ]],

                       [[15.200001],
                        [ 0.      ],
                        [ 0.      ]],

                       [[ 0.      ],
                        [ 0.      ],
                        [ 0.      ]])
    
    np.testing.assert_almost_equal( stitch1_np, target, decimal=3 )

Where I Need Help

I'm having a very hard time generalizing this into a Keras/TF layer. I was able to get the partition working, but I'm having a hard time calculating the correct indices for stitching everything back together. I also expect to run into trouble calculating the size of the zeros tensor for the other stitch input. Any help with either of these sticking points would be greatly appreciated!

I'm also pretty bad at Python, so don't assume I'm doing anything unconventional on purpose. I probably just don't know any better.

Thanks in advance!

from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import tensorflow as tf

import numpy as np

def test2():

    class TestFlat(tf.keras.layers.Layer):
        def __init__(self):
            super(TestFlat, self).__init__()
            self.N = -1 #size of row, column
            self.S = -1 #size of input channel

        def build(self, input_shape):
            print( "input_shape: ", input_shape )
            # TensorShape([None, 3, 3]), TensorShape([None, 3, 3, 2])]
            assert( len( input_shape ) == 2 )
            assert( len( input_shape[0] ) == 3 )
            assert( len( input_shape[1] ) == 4 )
            assert( input_shape[0][1] == input_shape[1][1] )
            assert( input_shape[0][2] == input_shape[1][2] )
            self.N = input_shape[0][1]
            self.S = input_shape[1][3]
            
        def call(self, inputs):
            print( "inputs: ", inputs )
            #[<tf.Tensor 'A_in:0' shape=(None, 3, 3)    dtype=float32>,
            # <tf.Tensor 'E_in:0' shape=(None, 3, 3, 2) dtype=float32>]
            
            A = inputs[0] # mask
            E = inputs[1] # data

            A_int = tf.cast( A, "int32" )
            part = tf.dynamic_partition( E, A_int, 2 )
            print( len( part ) )
            print( part[0] )
            print( part[1] )
            """
            2
            tf.Tensor(
            [[1. 2.]
            [5. 6.]
            [9. 1.]
            [9. 2.]
            [8. 3.]
            [7. 4.]
            [6. 5.]], shape=(7, 2), dtype=float32)
            tf.Tensor(
            [[3.1 4.1]
            [7.1 8.1]], shape=(2, 2), dtype=float32)
            """

            sum1 = tf.math.reduce_sum( part[1], axis=-1, keepdims=True )

            # Okay so now we're done with the "expensive" calculation
            # and we just need to merge with zeros back into our target shape of (None,N,N,1)

            # Step 1: Calculate indices for stitching
            # none of the rest of this works:
            r = tf.range(self.N*self.N*self.S) #???
            #tf.shape((None,self.N,self.N,1))  #???
            s = tf.shape(E)                    #???
            aa=tf.Variable(s)                  #???
            aa[-1].assign( 1 )                 #??? 
            r = tf.reshape( r, s )             #???
            indices = tf.dynamic_partition( r, A_int, 2 )
            print( indices )
            """
            partitioned_data = [
                np.zeros( shape=(7,1) ),
                sum1
            ]
            """
            
            # Step 2: Create zero tensor

            # Step 3: Stitch sum1 with zero tensor

            return inputs[0] #dummy for now

    N = 3
    S = 2

    A_in = Input(shape=(N, N), name='A_in')
    E_in = Input(shape=(N, N, S), name='E_in')

    out = TestFlat()( [A_in,E_in] )
    
    model = Model(inputs=[A_in,E_in], outputs=out)
    model.compile(optimizer='adam', loss='mean_squared_error')
    model.summary()

    testA = [[[0., 1., 0.],
              [1., 0., 0.],
              [0., 0., 0.]]]
    testE = [[[[1., 2.],
               [3.1, 4.1],
               [5., 6.], ],

              [[7.1, 8.1],
               [9., 1.],
               [9., 2.], ],

              [[8., 3.],
               [7., 4.],
               [6., 5.], ], ]]
    
    testA = np.asarray(testA).astype('float32')
    testE = np.asarray(testE).astype('float32')

    print( model([testA,testE]) )


1 Reply


I eventually found an answer after learning that you don't actually need indices for the zero entries. The intermixed zeros are implied (positions that no stitch index refers to come back as zero), and you just need to pad zeros onto the end to bring the flattened result up to its full length, which covers everything after the last masked cell in the final batch element.

This is messy but it works. Please feel free to offer suggestions on how this could be better.
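First, the core idea in isolation (a minimal eager-mode sketch of the same 3x3 example; note that tf.dynamic_stitch leaving unreferenced positions as zero is observed behaviour rather than a documented guarantee, and it is what the padding trick relies on):

import tensorflow as tf

N = 3
mask = tf.constant([[0, 1, 0],
                    [1, 0, 0],
                    [0, 0, 0]])             # (N, N) int mask
sums = tf.constant([[7.2], [15.2]])         # one "expensive" result per masked cell

# Partition the flat cell indices 0..N*N-1 with the same mask, so
# condition_indices[1] holds the flat positions of the masked cells ([1, 3]).
flat = tf.reshape(tf.range(N * N), (N, N))
condition_indices = tf.dynamic_partition(flat, mask, 2)

# Stitch only the masked values; interior gaps are left as zeros (observed behaviour).
stitched = tf.dynamic_stitch([condition_indices[1]], [sums])  # shape (4, 1)

# Pad zeros onto the end to reach N*N rows, then restore the (N, N, 1) shape.
# (Inside a layer the missing count comes from tf.shape, as in the code below.)
n_missing = N * N - stitched.shape[0]
padded = tf.concat([stitched, tf.zeros((n_missing, 1))], axis=0)
result = tf.reshape(padded, (N, N, 1))
print(result[:, :, 0].numpy())
# approximately:
# [[ 0.   7.2  0. ]
#  [15.2  0.   0. ]
#  [ 0.   0.   0. ]]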

from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import tensorflow as tf

import numpy as np

def test_flat_nbody_layer():

    class TestFlat(tf.keras.layers.Layer):
        def __init__(self):
            super(TestFlat, self).__init__()
            self.N = -1
            self.S = -1

        def build(self, input_shape):
            print( "input_shape: ", input_shape )
            # TensorShape([None, 3, 3]), TensorShape([None, 3, 3, 2])]
            assert( len( input_shape ) == 2 )
            assert( len( input_shape[0] ) == 3 )
            assert( len( input_shape[1] ) == 4 )
            assert( input_shape[0][1] == input_shape[1][1] )
            assert( input_shape[0][2] == input_shape[1][2] )
            self.N = input_shape[0][1]
            self.S = input_shape[1][3]
            pass
            
        def call(self, inputs):
            print( "inputs: ", inputs )
            #[<tf.Tensor 'A_in:0' shape=(None, 3, 3)    dtype=float32>,
            # <tf.Tensor 'E_in:0' shape=(None, 3, 3, 2) dtype=float32>]
            
            A = inputs[0]
            E = inputs[1]
            
            print( A ) #shape=(None, 3, 3)
            print( E ) #shape=(None, 3, 3, 2)
            
            A_int = tf.cast( A, "int32" )
            
            part = tf.dynamic_partition( E, A_int, 2 )
            print( len( part ) )
            print( part[0] ) #shape=(None, 2)
            print( part[1] ) #shape=(None, 2)
            """
            2
            tf.Tensor(
            [[1. 2.]
            [5. 6.]
            [9. 1.]
            [9. 2.]
            [8. 3.]
            [7. 4.]
            [6. 5.]], shape=(7, 2), dtype=float32)
            tf.Tensor(
            [[3.1 4.1]
            [7.1 8.1]], shape=(2, 2), dtype=float32)
            """

            sum1 = tf.math.reduce_sum( part[1], axis=-1, keepdims=True )
            print( sum1.shape )
            
            x=tf.constant(self.N*self.N)
            n=tf.constant(self.N)
            r = tf.range(x*tf.shape(E)[0])
            print( r ) #Tensor("test_flat/range:0", shape=(9,), dtype=int32)

            print( "Batch Size:", tf.shape(E)[0] )
            r2 = tf.reshape( r, shape=[tf.shape(E)[0],n,n] )
            print( r2 ) #Tensor("test_flat/Reshape:0", shape=(1, 3, 3), dtype=int32)
            condition_indices = tf.dynamic_partition( r2, A_int, 2 )
            print( condition_indices )
            #[<tf.Tensor 'test_flat/DynamicPartition_1:0' shape=(None,) dtype=int32>,
            # <tf.Tensor 'test_flat/DynamicPartition_1:1' shape=(None,) dtype=int32>]

            indices = [ condition_indices[ 1 ] ]
            partitioned_data = [ sum1 ]
            stitch_flat = tf.dynamic_stitch( indices, partitioned_data )
            print( "stitch_flat", stitch_flat )
            # Tensor("test_flat/DynamicStitch:0", shape=(None, 1), dtype=float32)

            npad1 = tf.shape(E)[0] * n * n
            print( "npad1", npad1 )
            npad2 = tf.shape(stitch_flat)[0]
            print( "npad2", npad2 )
            nz = npad1 - npad2
            print( "nz", nz )

            zero_padding = tf.zeros(nz, dtype=stitch_flat.dtype)
            print( "zeros", zero_padding )
            zero_padding = tf.reshape( zero_padding, [nz,1] )
            print( "zeros", zero_padding )

            print( "tf.shape(stitch_flat)", tf.shape(stitch_flat) )
            stitch = tf.concat([stitch_flat,zero_padding], -2 )

            stitch = tf.reshape( stitch, [tf.shape(E)[0],n,n,1] )
            
            return stitch

    N = 3
    S = 2

    A_in = Input(shape=(N, N), name='A_in')
    E_in = Input(shape=(N, N, S), name='E_in')

    out = TestFlat()( [A_in,E_in] )
    
    model = Model(inputs=[A_in,E_in], outputs=out)
    model.compile(optimizer='adam', loss='mean_squared_error')
    model.summary()

    testA = [[[0., 1., 0.],
              [1., 0., 0.],
              [0., 0., 0.]]]
    testE = [[[[1., 2.],
               [3.1, 4.1],
               [5., 6.], ],

              [[7.1, 8.1],
               [9., 1.],
               [9., 2.], ],

              [[8., 3.],
               [7., 4.],
               [6., 5.], ], ]]


    target = np.array([[[ 0.      ],
                        [ 7.2     ],
                        [ 0.      ]],

                       [[15.200001],
                        [ 0.      ],
                        [ 0.      ]],

                       [[ 0.      ],
                        [ 0.      ],
                        [ 0.      ]])
    assert_almost_equal = np.testing.assert_almost_equal
    
    testA = np.asarray(testA).astype('float32')
    testE = np.asarray(testE).astype('float32')

    batch1_pred = model([testA,testE])
    print( "test1", batch1_pred )
    """
    tf.Tensor(
    [[ 0.      ]
    [ 7.2     ]
    [ 0.      ]
    [15.200001]], shape=(4, 1), dtype=float32)
    """
    assert_almost_equal( batch1_pred[0], target, decimal=3 )
    
    testA2 = np.asarray([ testA[0], testA[0] ])
    testE2 = np.asarray([ testE[0], testE[0] ])

    print( "testA2", testA2.shape )
    print( "testE2", testE2.shape )
    """
    testA2 (2, 3, 3)
    testE2 (2, 3, 3, 2)
    """

    batch2_pred = model([testA2,testE2])
    print( "test2", batch2_pred )
    
    for output in batch2_pred:
        assert_almost_equal( output, target, decimal=3 )
