Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Add Samples after Partial Training in PyTorch

I have trained a model in PyTorch - an RCNN for text classification. The model has very high precision and recall, but I may eventually receive new documents with text unlike what I used to train, validate, or test the model.

I would like to add new text samples to the model without retraining the model from the beginning. This is desirable because I may lose access to some of the text used for initial training.

If it is not possible to add samples (documents), is it possible to train a new model on only the new samples and then somehow combine the original model and the new model? How?

Here is what my model looks like.

RCNN(
  (embeddings): Embedding(10661, 300)
  (lstm): LSTM(300, 64, bidirectional=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (W): Linear(in_features=428, out_features=64, bias=True)
  (tanh): Tanh()
  (fc): Linear(in_features=64, out_features=3, bias=True)
  (softmax): Softmax(dim=1)
  (loss_op): NLLLoss()
)

I am aware of techniques for saving the model and the corresponding load techniques.

  • State Dictionary: torch.save(model.state_dict(), PATH)
  • Model: torch.save(model, PATH)
  • Checkpoint: torch.save({'epoch': EPOCH, 'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': LOSS}, PATH)

I can continue training based on the original samples, but I do not know how to add samples.

If this is something TensorFlow can do but PyTorch cannot, I might switch to TensorFlow.

question from:https://stackoverflow.com/questions/65852250/add-samples-after-partial-training-in-pytorch

1 Reply

Assuming your model's state is saved in a file PATH, you can load it back into memory with torch.load. By default, each tensor is loaded onto the device it was on (CPU or CUDA) when torch.save was called.

import torch

state_dict = torch.load(PATH)      # parameter tensors written by torch.save
model.load_state_dict(state_dict)  # copy them into an existing model instance

This assumes model is an instance of the same nn.Module class whose state was saved to PATH. After loading, model has the same state (identical weights and biases) as when it was saved with torch.save. From there you can keep training the model on the new data, i.e. fine-tune it; that step needs only the new samples, not the original ones.
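A minimal sketch of that load-then-fine-tune flow might look like the following. TinyClassifier, the synthetic token ids and labels, the temporary file path, and the learning rate are all assumptions made for illustration; in practice you would use your own RCNN class, your real PATH, and a DataLoader over the new documents.

```python
import os
import tempfile

import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for the RCNN; any nn.Module is handled the same way."""

    def __init__(self, vocab_size=50, embed_dim=8, num_classes=3):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, tokens):
        # Mean-pool the token embeddings, then return log-probabilities for NLLLoss.
        pooled = self.embeddings(tokens).mean(dim=1)
        return torch.log_softmax(self.fc(pooled), dim=1)


torch.manual_seed(0)

# Pretend this is the originally trained model, saved to disk earlier.
PATH = os.path.join(tempfile.mkdtemp(), "model_state.pt")
original = TinyClassifier()
torch.save(original.state_dict(), PATH)

# Later: rebuild the same architecture and restore the saved weights.
model = TinyClassifier()
model.load_state_dict(torch.load(PATH))
weights_before = model.fc.weight.detach().clone()

# Synthetic "new samples": 16 documents of 10 token ids each, plus labels.
new_tokens = torch.randint(0, 50, (16, 10))
new_labels = torch.randint(0, 3, (16,))

# Fine-tune on the new samples only (learning rate chosen arbitrarily here).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.NLLLoss()
model.train()
for _ in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(new_tokens), new_labels)
    loss.backward()
    optimizer.step()

# The restored weights have now moved away from the saved ones.
weights_changed = not torch.equal(weights_before, model.fc.weight)
```

If you saved a full checkpoint rather than only the state dictionary, you can restore the optimizer the same way with optimizer.load_state_dict before continuing.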

Note: You can load it directly on the desired device by passing a torch.device to torch.load's map_location argument.
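As a rough illustration of the map_location note, the sketch below saves a small stand-in module and reloads its state onto whichever device is available. The nn.Linear model and the temporary path are placeholders, not the RCNN from the question.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in model; the real one would be the RCNN class.
model = nn.Linear(4, 3)
PATH = os.path.join(tempfile.mkdtemp(), "model_state.pt")
torch.save(model.state_dict(), PATH)

# Pick a device that actually exists on this machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps every stored tensor to `device` while loading.
state_dict = torch.load(PATH, map_location=device)

# Load the remapped state into a fresh instance and move the module as well.
restored = nn.Linear(4, 3)
restored.load_state_dict(state_dict)
restored.to(device)

loaded_device_type = state_dict["weight"].device.type
```

This is mostly useful when the state was saved from a GPU and is later loaded on a machine without one, where passing a CPU device to map_location avoids a failed load.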

