Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

Categories

0 votes
188 views
in Technique[技术] by (71.8m points)

python - Why does tf.GradientTape() use less GPU memory when I watch the model variables manually?

When I use `tf.GradientTape()` to automatically watch the trainable variables of a ResNet model, the run fails with an out-of-memory error. Here is the code:

# Excerpt: tape with the default automatic watching (watch_accessed_variables=True).
x_mini = preprocess_input(x_train)  
with tf.GradientTape() as tape:    
    # NOTE(review): with automatic watching the tape decides on its own which
    # variables read during this call to track; whether that includes frozen
    # backbone weights is not visible from this excerpt — verify.
    outputs = model(x_mini, training=True)

However, if I disable automatic watching and watch the trainable variables manually, I can feed in even larger batches without any memory problems. Here is the code:

# Excerpt: automatic watching disabled; only what is passed to tape.watch is watched.
x_mini = preprocess_input(x_train)
with tf.GradientTape(watch_accessed_variables=False) as tape:
    # Watch exactly the variables the gradient will later be taken with respect to.
    # NOTE(review): this list reflects the model's Keras-level trainability, so
    # frozen layers are presumably excluded — confirm against the actual model.
    tape.watch(model.trainable_variables)
    outputs = model(x_mini, training=True)

I am wondering whether the tape misses some variables when I watch them manually.

Below is runnable code (the out-of-memory error appears if you comment out option 1). I use a Tesla T4 GPU with 15 GB of memory and TensorFlow 2.3.

import tensorflow as tf
import numpy as np
from keras.models import Model
import keras.layers as ly
# Synthetic batch: 900 random RGB "images" with pixel values in [0, 255].
# np.random.randint's upper bound is exclusive, so 256 is needed for 255 to occur.
x_train = tf.convert_to_tensor(np.random.randint(0, 256, (900, 224, 224, 3)), dtype=tf.dtypes.float32)
# A single one-hot target; it is compared against the batch-averaged prediction later.
y_train = tf.convert_to_tensor([0, 1, 0], dtype=tf.dtypes.float32)
print(x_train.shape)

tf.keras.backend.clear_session()
# ImageNet ResNet50 backbone (no top) followed by a small 3-way classification head.
resnet_model = tf.keras.applications.resnet.ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# NOTE(review): this freezes the backbone at the Keras-layer level, so its weights
# leave model.trainable_variables. It does not visibly change the `trainable`
# attribute of the underlying tf.Variable objects. That attribute may be what
# tf.GradientTape's automatic watching consults — verify.
resnet_model.trainable = False
inputs = tf.keras.Input(shape=(224, 224, 3))
# training=False is the documented way (Keras transfer-learning guide) to keep the
# frozen backbone's BatchNorm layers in inference mode even when the outer model
# is called with training=True.
x = resnet_model(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(3, activation='softmax')(x)
# Build the model with tf.keras so every layer comes from the same Keras
# implementation as tf.keras.applications / tf.keras.Input above. Mixing the
# standalone `keras` package with tf.keras can break on some version pairs.
model = tf.keras.Model(inputs, outputs)
mcross = tf.keras.losses.categorical_crossentropy
macc = tf.keras.metrics.categorical_accuracy
base_learning_rate = 0.0001
optimizer = tf.keras.optimizers.Adam(base_learning_rate)

def cross_entropy(y_true, y_pred, eps=1e-3):
    """Return the cross-entropy of each row of (renormalized, clipped) predictions.

    Args:
        y_true: target values, multiplied element-wise with log(y_pred) and
            summed along axis 1.
        y_pred: non-negative scores. Each row (axis 1) is divided by its sum,
            then clipped into [eps, 1 - eps] so the logarithm stays finite.
        eps: clipping margin. Defaults to 1e-3, the value previously hard-coded.
            Must satisfy 0 < eps < 0.5 so the lower clip bound stays below the
            upper one.

    Returns:
        Tensor holding -sum(y_true * log(y_pred)) reduced over axis 1.

    Raises:
        ValueError: if eps is not in the open interval (0, 0.5).
    """
    if not 0.0 < eps < 0.5:
        raise ValueError(f"eps must be in (0, 0.5), got {eps!r}")
    # Renormalize along axis 1; keepdims=True lets the row sums broadcast back.
    y_pred = y_pred / tf.reduce_sum(y_pred, 1, True)
    # The lower bound eps > 0 guarantees tf.math.log never sees 0 (no -inf).
    y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)
    return -tf.reduce_sum(y_true * tf.math.log(y_pred), 1)

# Option 1: disable automatic watching and watch only the variables we
# differentiate with respect to.
# NOTE(review): model.trainable_variables follows the Keras-level trainable flag.
# Because resnet_model.trainable = False, it presumably holds only the Dense
# head's kernel and bias. Confirm with [v.name for v in model.trainable_variables].
with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(model.trainable_variables)
    y_pred = model(x_train, training=True)
    # Average the (900, 3) softmax outputs over the batch (axis 0) into one
    # (1, 3) row and compare it with the single target y_train.
    loss = cross_entropy(y_train, tf.reduce_mean(y_pred, 0, keepdims=True))
gradients = tape.gradient(loss, model.trainable_variables)

# Option 2: let the tape automatically watch the variables read in the forward pass.
# NOTE(review): this model call is the one that raises ResourceExhaustedError
# (a [900,56,56,256] float Conv2D activation). Likely cause, to be confirmed:
# `resnet_model.trainable = False` freezes the backbone only at the Keras level.
# Automatic watching appears to check each tf.Variable's own `trainable`
# attribute, so the backbone's conv weights are still watched. The tape then
# keeps every intermediate activation of the 900-image forward pass for backprop.
# In option 1 no watched value flows through the backbone, so nothing there is
# recorded. The gradients w.r.t. model.trainable_variables should be the same in
# both options — compare them to confirm.
with tf.GradientTape() as tape:
    y_pred = model(x_train, training=True)
    loss = cross_entropy(y_train, tf.reduce_mean(y_pred, 0, keepdims=True))
gradients = tape.gradient(loss, model.trainable_variables)

Here is the full error message:

---------------------------------------------------------------------------
ResourceExhaustedError                    Traceback (most recent call last)
<ipython-input-4-42e45caeae41> in <module>
     31 # automatically tapping variable
     32 with tf.GradientTape() as tape:
---> 33     y_pred = model(x_train, training=True)
     34     loss = cross_entropy(y_train, tf.reduce_mean(y_pred, 0, keepdims=True))
     35 gradients = tape.gradient(loss, model.trainable_variables)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):
--> 985           outputs = call_fn(inputs, *args, **kwargs)
    986 
    987         if self._activity_regularizer:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in call(self, inputs, training, mask)
    384     """
    385     return self._run_internal_graph(
--> 386         inputs, training=training, mask=mask)
    387 
    388   def compute_output_shape(self, input_shape):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in _run_internal_graph(self, inputs, training, mask)
    506 
    507         args, kwargs = node.map_arguments(tensor_dict)
--> 508         outputs = node.layer(*args, **kwargs)
    509 
    510         # Update tensor_dict.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):
--> 985           outputs = call_fn(inputs, *args, **kwargs)
    986 
    987         if self._activity_regularizer:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in call(self, inputs, training, mask)
    384     """
    385     return self._run_internal_graph(
--> 386         inputs, training=training, mask=mask)
    387 
    388   def compute_output_shape(self, input_shape):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in _run_internal_graph(self, inputs, training, mask)
    506 
    507         args, kwargs = node.map_arguments(tensor_dict)
--> 508         outputs = node.layer(*args, **kwargs)
    509 
    510         # Update tensor_dict.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):
--> 985           outputs = call_fn(inputs, *args, **kwargs)
    986 
    987         if self._activity_regularizer:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/layers/convolutional.py in call(self, inputs)
    245       inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))
    246 
--> 247     outputs = self._convolution_op(inputs, self.kernel)
    248 
    249     if self.use_bias:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    199     """Call target, and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/nn_ops.py in convolution_v2(input, filters, strides, padding, data_format, dilations, name)
   1016       data_format=data_format,
   1017       dilations=dilations,
-> 1018       name=name)
   1019 
   1020 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/nn_ops.py in convolution_internal(input, filters, strides, padding, data_format, dilations, name, call_from_convolution, num_spatial_dims)
   1146           data_format=data_format,
   1147           dilations=dilations,
-> 1148           name=name)
   1149     else:
   1150       if channel_index == 1:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/nn_ops.py in _conv2d_expanded_batch(input, filters, strides, padding, data_format, dilations, name)
   2590         data_format=data_format,
   2591         dilations=dilations,
-> 2592         name=name)
   2593   return squeeze_batch_dims(
   2594       input,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/gen_nn_ops.py in conv2d(input, filter, strides, padding, use_cudnn_on_gpu, explicit_paddings, data_format, dilations, name)
    936       return _result
    937     except _core._NotOkStatusException as e:
--> 938       _ops.raise_from_not_ok_status(e, name)
    939     except _core._FallbackException:
    940       pass

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6841   message = e.message + (" name: " + name if name is not None else "")
   6842   # pylint: disable=protected-access
-> 6843   six.raise_from(core._status_to_exception(e.code, message), None)
   6844   # pylint: enable=protected-access
   6845 

/opt/conda/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

ResourceExhaustedError: OOM when allocating tensor with shape[900,56,56,256] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:Conv2D]

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
Welcome To Ask or Share your Answers For Others

1 Reply

0 votes
by (71.8m points)
等待大神答复

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
OGeek|极客中国-欢迎来到极客的世界,一个免费开放的程序员编程交流平台!开放,进步,分享!让技术改变生活,让极客改变未来! Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Click Here to Ask a Question

1.4m articles

1.4m replys

5 comments

57.0k users

...