Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Using pre-trained inception_resnet_v2 with TensorFlow

I have been trying to use the pre-trained inception_resnet_v2 model released by Google. I am using their model definition (https://github.com/tensorflow/models/blob/master/slim/nets/inception_resnet_v2.py) and the provided checkpoint (http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz) to load the model in TensorFlow as shown below. To test this code, download and extract the checkpoint file, and download the sample images dog.jpg and panda.jpg:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x (tf.contrib.slim does not exist in 2.x)
from PIL import Image
from inception_resnet_v2 import *  # model definition from the slim nets directory

slim = tf.contrib.slim

checkpoint_file = 'inception_resnet_v2_2016_08_30.ckpt'
sample_images = ['dog.jpg', 'panda.jpg']

# Placeholder for a batch of 299x299 RGB images
input_tensor = tf.placeholder(tf.float32, shape=(None, 299, 299, 3), name='input_image')

# Build the model and restore the pre-trained weights
sess = tf.Session()
arg_scope = inception_resnet_v2_arg_scope()
with slim.arg_scope(arg_scope):
  logits, end_points = inception_resnet_v2(input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)

for image in sample_images:
  # Force 3-channel RGB and resize to the network's 299x299 input size
  im = Image.open(image).convert('RGB').resize((299, 299))
  im = np.array(im)
  im = im.reshape(-1, 299, 299, 3)  # add the batch dimension
  predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_tensor: im})
  print(np.max(predict_values), np.max(logit_values))
  print(np.argmax(predict_values), np.argmax(logit_values))

However, this code does not give the expected results: class 918 is predicted regardless of the input image. Can someone help me understand where I am going wrong?

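He who fights with dragons for too long becomes a dragon himself; gaze into the abyss for too long, and the abyss gazes back into you…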

1 Reply

The Inception networks expect the input image to have its color channels scaled to the range [-1, 1], as is done in the TF-slim Inception preprocessing code.

You could either use the existing preprocessing or, in your example, simply scale the images yourself with im = 2*(im/255.0)-1.0 before feeding them to the network.
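A minimal NumPy sketch of that manual scaling, assuming the image has already been converted to a 299x299 RGB array as in the question's loop. The helper name `to_inception_input` is hypothetical (it is not part of slim). The `__main__` part also checks that the formula agrees with the "convert to [0, 1], subtract 0.5, multiply by 2" form used by the slim Inception preprocessing.

```python
import numpy as np


def to_inception_input(im):
    """Scale an HxWx3 RGB array from [0, 255] to [-1, 1] and add a batch axis.

    Hypothetical helper for illustration; it applies im = 2*(im/255.0)-1.0.
    """
    im = np.asarray(im, dtype=np.float32)
    if im.ndim != 3 or im.shape[-1] != 3:
        raise ValueError("expected an HxWx3 RGB array, got shape %r" % (im.shape,))
    scaled = 2.0 * (im / 255.0) - 1.0
    # Add the leading batch dimension expected by the placeholder: (1, H, W, 3).
    return scaled[np.newaxis, ...]


if __name__ == "__main__":
    # Stand-in for np.array(Image.open(path).convert('RGB').resize((299, 299))).
    fake = np.random.randint(0, 256, size=(299, 299, 3), dtype=np.uint8)
    batch = to_inception_input(fake)
    print(batch.shape, batch.min() >= -1.0, batch.max() <= 1.0)
    # Same values as the [0, 1] -> subtract 0.5 -> multiply by 2 formulation.
    slim_style = (fake.astype(np.float32) / 255.0 - 0.5) * 2.0
    print(np.allclose(batch[0], slim_style))
```

In the question's loop, `im = to_inception_input(im)` could take the place of the `reshape` line, so the array fed to `input_tensor` is already in [-1, 1].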

Without scaling, the input range [0, 255] is much larger than the network expects, and the biases all work toward very strongly predicting category 918 (comic books).
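One way to see why the prediction barely depends on the image: a raw pixel equals 127.5·(x' + 1), where x' is the same pixel scaled to [-1, 1], so the first layer's response to a raw image is an image-dependent part plus a constant term that is identical for every image. The sketch below uses a random zero-mean linear layer as a stand-in for the network's first layer (an assumption; these are not the Inception weights, and the real network is nonlinear) and compares how similar its responses to two unrelated random "images" are, with and without scaling.

```python
import numpy as np


def cosine(a, b):
    """Cosine similarity between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


def response_similarity(n_in=3000, n_out=1000, seed=0):
    """Compare a toy linear layer's responses to two random images.

    Returns (similarity with raw [0, 255] input, similarity with [-1, 1] input).
    The random weight matrix is a toy stand-in, not the Inception weights.
    """
    rng = np.random.default_rng(seed)
    w = rng.standard_normal((n_out, n_in)) / np.sqrt(n_in)
    # Two unrelated "images" with uniformly random pixel values in [0, 255].
    raw_a = rng.uniform(0.0, 255.0, n_in)
    raw_b = rng.uniform(0.0, 255.0, n_in)
    # The same pixels after the scaling from the answer: 2*(x/255) - 1.
    scaled_a = 2.0 * (raw_a / 255.0) - 1.0
    scaled_b = 2.0 * (raw_b / 255.0) - 1.0
    # raw = 127.5 * (scaled + 1), so w @ raw contains 127.5 * (w @ ones):
    # a term shared by every image and, here, larger than the image-dependent part.
    return cosine(w @ raw_a, w @ raw_b), cosine(w @ scaled_a, w @ scaled_b)


if __name__ == "__main__":
    raw_sim, scaled_sim = response_similarity()
    print("raw input:    %.2f" % raw_sim)
    print("scaled input: %.2f" % scaled_sim)
```

With raw input the two responses point in nearly the same direction, which is consistent with the network returning the same class for every image; after scaling they are roughly uncorrelated.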

