Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

python - Load weights from checkpoint not working in keras model

I am going insane over this.

I define a Sequential model using TensorFlow Keras:

import tensorflow as tf
from tensorflow import keras

model = tf.keras.Sequential([tf.keras.layers.Dense(128, input_shape=(784,), activation="relu"),
                             tf.keras.layers.Dense(10, activation="softmax")])
model.compile(optimizer="adam", loss="mse")
keras.experimental.export_saved_model(model, "keras_model")
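The traceback further down shows that load_from_saved_model looks for the weights at the checkpoint prefix keras_model/variables/variables. A quick first check is to confirm that the C program's files actually sit at that prefix. This sketch assumes the C program writes a standard TensorFlow checkpoint, meaning a `<prefix>.index` file plus one or more `<prefix>.data-XXXXX-of-YYYYY` shards. Both helper names are hypothetical:

```python
import glob
import os


def checkpoint_prefix(saved_model_dir):
    """Return the checkpoint prefix inside a SavedModel directory.

    Mirrors the "variables/variables" path that load_from_saved_model builds.
    """
    return os.path.join(saved_model_dir, "variables", "variables")


def checkpoint_files_present(prefix):
    """True if the prefix has a .index file and at least one .data-*-of-* shard."""
    has_index = os.path.isfile(prefix + ".index")
    # glob.escape guards against glob metacharacters in the directory name.
    shards = glob.glob(glob.escape(prefix) + ".data-*-of-*")
    return has_index and len(shards) > 0
```

For example, checkpoint_files_present(checkpoint_prefix("keras_model")) would be expected to return True once the C program has saved. If it returns False, the Python restore calls are reading a different path from the one the C program wrote.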

I train this model in a C program using the TensorFlow C API (c_api.h).

The C program saves the weights to a checkpoint file.
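The restore later fails on specific variable names, so it can help to compare what the C program actually wrote with what the Keras model expects. tf.train.list_variables(prefix) returns (name, shape) pairs for every tensor in a checkpoint. The sketch below compares those pairs against the model's variables. The function name and report format are hypothetical, and the function only assumes both inputs are lists of (name, shape) pairs:

```python
def compare_variables(checkpoint_vars, model_vars):
    """Compare checkpoint contents with a model's variables by name and shape.

    checkpoint_vars: (name, shape) pairs, e.g. from tf.train.list_variables(prefix).
    model_vars: (name, shape) pairs, e.g.
        [(v.name, v.shape.as_list()) for v in model.variables].
    Keras variable names carry an output index such as ":0" that checkpoint
    keys do not, so it is stripped before matching.
    """
    ckpt = {name: tuple(shape) for name, shape in checkpoint_vars}
    report = {"missing": [], "shape_mismatch": [], "unused": []}
    seen = set()
    for name, shape in model_vars:
        key = name.rsplit(":", 1)[0]  # "a/kernel:0" -> "a/kernel"
        seen.add(key)
        if key not in ckpt:
            report["missing"].append(key)
        elif ckpt[key] != tuple(shape):
            report["shape_mismatch"].append((key, ckpt[key], tuple(shape)))
    # Checkpoint entries no model variable asked for (optimizer slots, step counters, ...).
    report["unused"] = sorted(k for k in ckpt if k not in seen)
    return report
```

A non-empty "missing" list for keys such as a/kernel would be consistent with the AssertionError below. It would suggest that the C program stored the tensors under different names than the Python model uses.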

When I try to restore the weights in Python from the checkpoint file with any of these:

keras.experimental.load_from_saved_model("keras_model/")
# OR
model = tf.keras.Sequential([tf.keras.layers.Dense(128, input_shape=(784,), activation="relu"),
                             tf.keras.layers.Dense(10, activation="softmax")])
model.load_weights("keras_model/variables/variables")
# OR
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore("keras_model/variables/variables")
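One possible workaround is to bypass the object-based restore path and copy each tensor by name. This is a sketch, not something verified against this exact setup. tf.train.load_checkpoint(prefix) returns a reader with has_tensor(name) and get_tensor(name), and tf.Variable.assign accepts the NumPy array that get_tensor returns. The helper is hypothetical and works with any objects that expose those methods. It assumes eager execution, which is the TF 2.x default. In TF 1.x graph mode, assign only builds an op that still has to be run in a session.

```python
def restore_by_name(reader, variables):
    """Copy checkpoint tensors into variables whose names match checkpoint keys.

    reader: object with has_tensor(key) and get_tensor(key),
            e.g. tf.train.load_checkpoint("keras_model/variables/variables").
    variables: objects with .name (like "a/kernel:0") and .assign(value),
               e.g. model.variables of an already-built Keras model.
    Returns (restored_keys, missing_keys) so a partial restore is not silent.
    """
    restored, missing = [], []
    for var in variables:
        key = var.name.rsplit(":", 1)[0]  # "a/kernel:0" -> "a/kernel"
        if reader.has_tensor(key):
            var.assign(reader.get_tensor(key))
            restored.append(key)
        else:
            missing.append(key)
    return restored, missing
```

If missing comes back empty, you could then save the model again with model.save_weights(...) or tf.train.Checkpoint(model=model).save(...). Either call would produce an object-based checkpoint, which is the format the deprecation warning in the log below recommends.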

I end up getting an error and no weights are restored.

Restoring the weights and continuing training from my C program works fine, though. Here is the full output of the first variant:

keras.experimental.load_from_saved_model("keras_model/")
WARNING: Logging before flag parsing goes to stderr.
W0918 15:18:04.350199 140418474760000 deprecation.py:323] From <ipython-input-2-06ea110fdc8e>:1: load_from_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
The experimental save and load functions have been  deprecated. Please switch to `tf.keras.models.load_model`.
2019-09-18 15:18:04.390271: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1696040000 Hz
2019-09-18 15:18:04.390913: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x4bf4790 executing computations on platform Host. Devices:
2019-09-18 15:18:04.390961: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): Host, Default Version
W0918 15:18:04.436281 140418474760000 deprecation.py:323] From /home/jregalado/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py:1249: NameBasedSaverStatus.__init__ (from tensorflow.python.training.tracking.util) is deprecated and will be removed in a future version.
Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-2-06ea110fdc8e> in <module>
----> 1 keras.experimental.load_from_saved_model("keras_model/")

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py in new_func(*args, **kwargs)
322               'in a future version' if date is None else ('after %s' % date),
323               instructions)
--> 324       return func(*args, **kwargs)
325     return tf_decorator.make_decorator(
326         func, new_func, 'deprecated',

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saved_model_experimental.py in load_from_saved_model(saved_model_path, custom_objects)
425       compat.as_text(constants.VARIABLES_DIRECTORY),
426       compat.as_text(constants.VARIABLES_FILENAME))
--> 427   model.load_weights(checkpoint_prefix)
428   return model

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in load_weights(self, filepath, by_name)
179         raise ValueError('Load weights is not yet supported with TPUStrategy '
180                          'with steps_per_run greater than 1.')
--> 181     return super(Model, self).load_weights(filepath, by_name)
182
183   @trackable.no_automatic_dependency_tracking

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in load_weights(self, filepath, by_name)
1372         # streaming restore for any variables created in the future.
1373         trackable_utils.streaming_restore(status=status, session=session)
-> 1374       status.assert_nontrivial_match()
1375       return status
1376     if h5py is None:

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py in assert_nontrivial_match(self)
964     # assert_nontrivial_match and assert_consumed (and both are less
965     # useful since we don't touch Python objects or Python state).
--> 966     return self.assert_consumed()
967
968   def _gather_saveable_objects(self):

~/Projects/tensorflow/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py in assert_consumed(self)
941       raise AssertionError(
942           "Some objects had attributes which were not restored:{}".format(
--> 943               "".join(unused_attribute_strings)))
944     for trackable in self._graph_view.list_objects():
945       # pylint: disable=protected-access

AssertionError: Some objects had attributes which were not restored:
<tf.Variable 'a/kernel:0' shape=(784, 128) dtype=float32, numpy=
array([[-0.03716458, -0.04911711, -0.01023878, ...,  0.0636776 ,
0.02892563, -0.05542086],
[-0.02324755, -0.07362694, -0.0399951 , ...,  0.0680329 ,
0.05201877, -0.05149256],
[ 0.00954343,  0.05673491,  0.05108347, ...,  0.01994208,
-0.01107961,  0.06192174],
...,
[ 0.07091486, -0.07734856, -0.04417738, ...,  0.01921409,
-0.01908814, -0.05070668],
[ 0.01353646, -0.05189713, -0.01391671, ..., -0.05795977,
0.04801518,  0.00801209],
[-0.05304915,  0.01870193,  0.05657425, ..., -0.06819408,
-0.00760372, -0.0106293 ]], dtype=float32)>: ['a/kernel']
<tf.Variable 'a/bias:0' shape=(128,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>: ['a/bias']
<tf.Variable 'b/kernel:0' shape=(128, 10) dtype=float32, numpy=
array([[-0.1759212 , -0.09282549, -0.11045764, ..., -0.13727605,
-0.02849793,  0.14510198],
[ 0.06857841, -0.01459177,  0.08369003, ...,  0.05089156,
-0.05319159, -0.08594933],
[-0.180914  , -0.18932283,  0.20551099, ..., -0.17210156,
-0.10069884,  0.06433241],
...,
[ 0.09097584, -0.03930017, -0.15125516, ...,  0.02359283,
-0.16158347, -0.13176063],
[-0.04145582, -0.03205152,  0.20097663, ..., -0.15124482,
0.16874255, -0.15434337],
[-0.13188484,  0.04145408,  0.05036192, ..., -0.10489662,
0.12316228,  0.08794598]], dtype=float32)>: ['b/kernel']
<tf.Variable 'b/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>: ['b/bias']