tensorflow - Model is not learning

I am trying to train a TensorFlow.js model on images coming from my webcam. Basically, I'm trying to recreate the Pac-Man TensorFlow.js game. The model isn't converging and is pretty much useless after training. I have a feeling it's how I'm prepping the data.

Grabbing the image from the canvas

function takePhoto(label) {
  let canv = document.getElementById("canv")
  let cont = canv.getContext("2d")
  cont.drawImage(video, 0, 0, width, height)

  let data = tf.browser.fromPixels(canv, 3)
  data.toFloat().div(tf.scalar(127)).sub(tf.scalar(1))
  return data
}
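One detail worth checking in takePhoto: TensorFlow.js tensors are immutable, so data.toFloat().div(...).sub(...) returns a new tensor instead of changing data. Because that result is never assigned, the function returns the raw 0 to 255 pixel values. The arithmetic the line intends can be sketched in plain JavaScript. normalizePixels is a hypothetical helper, not a TensorFlow.js API, and dividing by 127.5 rather than 127 maps 255 exactly to 1.

```javascript
// Hypothetical helper (not a TensorFlow.js API): maps 8-bit pixel values
// from [0, 255] to [-1, 1], the range takePhoto() appears to aim for.
function normalizePixels(pixels) {
  const out = new Float32Array(pixels.length);
  for (let i = 0; i < pixels.length; i++) {
    // 127.5 (not 127) so that 0 -> -1 and 255 -> +1 exactly
    out[i] = pixels[i] / 127.5 - 1;
  }
  return out;
}

// In TensorFlow.js the same idea only works if the new tensor is kept, e.g.
//   const data = tf.tidy(() =>
//     tf.browser.fromPixels(canv, 3).cast("float32").div(127.5).sub(1));
// because every op returns a new tensor and leaves its input unchanged.

const sample = new Uint8ClampedArray([0, 255, 51]);
console.log(Array.from(normalizePixels(sample)));
```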

function addExample(label){
  let data = takePhoto()

  addData(train_data => train_data.concat(data))
  addLabel(train_labels => train_labels.concat(labels[label]))
}

Train function

export async function train_model(image,label){
    let d = tf.stack(image)

    let l = tf.oneHot(tf.tensor1d(label).toInt(),4)

    let data = await model.fit(d, l, {
        epochs: 10,
        batchSize: label[0].length,
        callbacks: {
            onBatchEnd: async (batch, logs) => {
                console.log(logs.loss.toFixed(5))
            }
        }
    })
    return data
}
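tf.oneHot(indices, 4) turns each class index into a row of four values containing a single 1. The plain-JavaScript sketch below makes that encoding visible; oneHot here is a hypothetical helper, not the TensorFlow.js function, and this version also rejects out-of-range indices. It also shows why batchSize: label[0].length looks suspicious: if label holds class indices, as tf.tensor1d(label) suggests, then label[0] is a number, its length is undefined, and an explicit value such as 16 or 32 is the safer choice.

```javascript
// Hypothetical helper showing the encoding tf.oneHot(indices, depth) produces:
// one row per example, with a 1 in the column of its class index.
function oneHot(indices, depth) {
  return indices.map((idx) => {
    // This sketch rejects invalid class indices outright.
    if (!Number.isInteger(idx) || idx < 0 || idx >= depth) {
      throw new RangeError(`class index ${idx} outside [0, ${depth})`);
    }
    const row = new Array(depth).fill(0);
    row[idx] = 1;
    return row;
  });
}

// Class indices, as passed to tf.tensor1d(label) in train_model.
const label = [0, 3, 1, 2];

// label[0] is a number, so label[0].length is not a batch size.
console.log(label[0].length); // undefined

// Batches per epoch for an explicit batch size (hypothetical value).
const batchSize = 2;
console.log(Math.ceil(label.length / batchSize)); // 2

console.log(oneHot(label, 4));
```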

Model

export function buildModel(){
    model = tf.sequential({layers:[ 
        tf.layers.conv2d({inputShape:[width,height,3],
                            kernelSize:3,
                            filters:5, 
                            activation :"relu"}),
        tf.layers.flatten(),
        tf.layers.dense({units:128, activation:"relu",useBias:true}),
        tf.layers.dense({units:32, activation:"relu"}),
        tf.layers.dense({units:4, activation:"softmax"})
    ]})
    model.compile({metrics:["accuracy"], loss:"categoricalCrossentropy", optimizer:"adam",learningRate:.00001})
    console.log(model.summary())
}
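tf.browser.fromPixels returns a tensor of shape [height, width, channels], whereas buildModel declares inputShape: [width, height, 3]. The two agree only when the image is square, and the same applies to reshape([1, width, height, 3]) in predict. The plain-JavaScript sketch below uses hypothetical helpers and an assumed 320×240 image (the question does not give width and height). It computes the input shape and the output shape of the 3×3, 5-filter conv2d layer, assuming that layer's defaults of stride 1 and 'valid' padding.

```javascript
// Hypothetical helper: shape tf.browser.fromPixels() returns for an image
// of the given size (height comes first).
function pixelTensorShape(width, height, channels = 3) {
  return [height, width, channels];
}

// Hypothetical helper: output shape of a 2D convolution with stride 1 and
// 'valid' padding, channels-last layout.
function conv2dValidShape([h, w], kernelSize, filters) {
  return [h - kernelSize + 1, w - kernelSize + 1, filters];
}

// Example sizes (assumptions, not taken from the question).
const width = 320;
const height = 240;

const inputShape = pixelTensorShape(width, height); // [240, 320, 3]
const convOut = conv2dValidShape(inputShape, 3, 5);  // [238, 318, 5]
console.log(inputShape, convOut);
```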

Predicting

export async function predict(img){

    let pred = await tf.tidy(() => {

        img = img.reshape([1,width,height, 3]);

        const output = model.predict(img);

        let predictions = Array.from(output.dataSync());
        return predictions
    })
    return pred
}
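predict returns the softmax output as an array of four probabilities, one per class, and the predicted label is the index of the largest one. A plain-JavaScript sketch, where argmax is a hypothetical helper (tf.argMax does the same on tensors):

```javascript
// Hypothetical helper: index of the largest value, i.e. the predicted class
// for a softmax output such as the array returned by predict().
function argmax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Example probabilities (made up for illustration).
const probs = [0.1, 0.6, 0.2, 0.1];
console.log(argmax(probs)); // 1
```

A softmax output where every class sits near 0.25 means the model is not distinguishing the four classes at all.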

The callback prints the losses, but they do not converge to anything, and the predictions are way off (essentially random).

1 Reply


Is the model the right one?

The first question to ask is whether the model is the right one. The model in the question mixes convolutional and dense layers, but it does not really follow the usual CNN structure, in which convolutional layers are typically followed by pooling layers. Is that the reason the model is not learning? Not necessarily...
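One practical consequence of the missing pooling layer shows up in the parameter count: flatten hands every activation of the convolution to the 128-unit dense layer, so that layer's weight matrix grows with the full image area. The sketch below compares the dense-layer parameter count with and without a 2×2 max-pooling layer, for a hypothetical 224×224 RGB input (the question does not state the real size).

```javascript
// Parameters of a dense layer: one weight per (input, unit) pair plus one bias per unit.
function denseParams(inputs, units) {
  return inputs * units + units;
}

// Number of values in a flattened [h, w, c] activation.
function flatSize([h, w, c]) {
  return h * w * c;
}

// Hypothetical 224x224 input; 3x3 conv with 5 filters, 'valid' padding -> [222, 222, 5].
const convOut = [222, 222, 5];

// 2x2 max pooling with stride 2 and 'valid' padding halves each spatial size (rounding down).
const pooledOut = [Math.floor(222 / 2), Math.floor(222 / 2), 5]; // [111, 111, 5]

console.log(denseParams(flatSize(convOut), 128));   // 31541888
console.log(denseParams(flatSize(pooledOut), 128)); // 7885568
```

In this example pooling cuts the first dense layer's weights by roughly a factor of four.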

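In classification problems, there are different ways of classifying images, each with its pros and cons. A fully connected network (FCNN) generally does not achieve good accuracy on images, while a CNN does, because its convolutions exploit the spatial structure of the pixels. But training a CNN model can be computationally expensive. This is where transfer learning comes into play.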

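The Pac-Man example uses transfer learning: it takes a pretrained MobileNet, truncated before its final layers, and trains only a small classifier on top of its outputs. So if you want to replicate the example, consider following its code in the tfjs-examples GitHub repository. The model in the question, by contrast, uses only one convolutional layer, trained from scratch. There are good tutorials on the official TensorFlow website on how to write CNNs and transfer-learning models.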
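The core idea of that approach, freezing a pretrained part and training only a small classifier on its outputs, can be sketched without any library. In the plain-JavaScript sketch below, frozenFeatures is a hypothetical stand-in for the pretrained network (a truncated MobileNet in the Pac-Man demo); only the softmax head W is trained, with full-batch gradient descent on cross-entropy. All names and the toy data are illustrative, not TensorFlow.js APIs.

```javascript
// Hypothetical stand-in for a frozen, pretrained feature extractor.
// It is never updated during training.
function frozenFeatures(x) {
  return [x[0] + x[1], x[0] - x[1], 1]; // the constant 1 acts as a bias input
}

// Numerically stable softmax.
function softmax(z) {
  const m = Math.max(...z);
  const e = z.map((v) => Math.exp(v - m));
  const s = e.reduce((a, b) => a + b, 0);
  return e.map((v) => v / s);
}

// Logits of the trainable head for one feature vector.
function headLogits(W, f) {
  return W.map((row) => row.reduce((sum, w, j) => sum + w * f[j], 0));
}

// Trains only the head (one softmax layer) on the frozen features.
function trainHead(inputs, labels, numClasses, lr = 0.5, epochs = 200) {
  const feats = inputs.map(frozenFeatures); // computed once: the base is frozen
  const dim = feats[0].length;
  const W = Array.from({ length: numClasses }, () => new Array(dim).fill(0));
  for (let epoch = 0; epoch < epochs; epoch++) {
    const grad = W.map((row) => row.map(() => 0));
    feats.forEach((f, n) => {
      const p = softmax(headLogits(W, f));
      for (let k = 0; k < numClasses; k++) {
        const err = p[k] - (labels[n] === k ? 1 : 0); // d(loss)/d(logit k)
        for (let j = 0; j < dim; j++) grad[k][j] += err * f[j];
      }
    });
    for (let k = 0; k < numClasses; k++) {
      for (let j = 0; j < dim; j++) W[k][j] -= (lr * grad[k][j]) / feats.length;
    }
  }
  return W;
}

// Predicted class = index of the largest head logit.
function predictClass(W, x) {
  const logits = headLogits(W, frozenFeatures(x));
  return logits.indexOf(Math.max(...logits));
}

// Toy data: four well-separated clusters, one per class (made up).
const xs = [[2, 0], [3, 0], [0, 2], [0, 3], [-2, 0], [-3, 0], [0, -2], [0, -3]];
const ys = [0, 0, 1, 1, 2, 2, 3, 3];
const W = trainHead(xs, ys, 4);
console.log(xs.map((x) => predictClass(W, x)));
```

Because the frozen features are computed once, each training step only touches the small head, which is what makes this approach cheap compared with training a whole CNN.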


How much data did you train your model on?

Deep learning models generally need a lot of data. Unless the model has seen many labelled images, it won't be surprising if its accuracy is very low. How much data is needed is more a question of art and design than of science, but as a general rule of thumb, the more data there is, the better the model predicts.
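Before adding more data, it can help to check how many labelled examples were actually collected per class, since a class with only a handful of webcam frames, or one that dominates the rest, is easy to miss. A small sketch, assuming the labels are stored as class indices like the label array passed to train_model; countPerClass and imbalanceRatio are hypothetical helpers.

```javascript
// Hypothetical helper: number of collected examples for each class index.
function countPerClass(labels, numClasses) {
  const counts = new Array(numClasses).fill(0);
  for (const c of labels) counts[c] += 1;
  return counts;
}

// Hypothetical helper: ratio of the largest to the smallest class count
// (Infinity when some class has no examples at all).
function imbalanceRatio(counts) {
  const min = Math.min(...counts);
  return min === 0 ? Infinity : Math.max(...counts) / min;
}

const collected = [0, 0, 0, 0, 1, 1, 2, 0, 0, 1]; // example data, not from the question
const counts = countPerClass(collected, 4);
console.log(counts);                 // [ 6, 3, 1, 0 ]
console.log(imbalanceRatio(counts)); // Infinity
```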


Tuning the model

Even a good model needs its parameters tuned: number of epochs, batch size, learning rate, optimizer, loss function, and so on. Changing these parameters and observing how they affect accuracy is one step toward good accuracy.

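Note also that there is no learningRate option in the object passed to model.compile, so the value given there has no effect and Adam keeps its default learning rate. To set it, pass an optimizer instance instead, e.g. optimizer: tf.train.adam(0.00001).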
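Trying parameter combinations systematically (a grid search) makes it easier to see which change actually moves the accuracy. The sketch below only enumerates configurations; each one would then be passed to whatever training routine is in use, with the learning rate going into the optimizer (for example tf.train.adam(learningRate)) rather than into compile directly. The candidate values are arbitrary assumptions.

```javascript
// Cartesian product of hyperparameter candidates: every combination exactly once.
function grid(space) {
  return Object.entries(space).reduce(
    (combos, [name, values]) =>
      combos.flatMap((combo) => values.map((v) => ({ ...combo, [name]: v }))),
    [{}]
  );
}

// Candidate values are illustrative assumptions.
const configs = grid({
  learningRate: [1e-2, 1e-3, 1e-4],
  batchSize: [16, 32],
  epochs: [10, 30],
});

console.log(configs.length); // 12
console.log(configs[0]);     // { learningRate: 0.01, batchSize: 16, epochs: 10 }
```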

