Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - Pytorch: How to unflatten/get back the network from flattened network?

I am using the following function to flatten the network:

import numpy as np

#############################################################################
# Flattening the NET
#############################################################################
def flattenNetwork(net):
    flatNet = []
    shapes = []
    for param in net.parameters():
        # Move each parameter (weights or biases) to the CPU as a NumPy array.
        values = param.detach().cpu().numpy()
        shapes.append(values.shape)
        # ravel() flattens any number of dimensions, so no per-rank branches are needed.
        flatNet.append(values.ravel())
    # Join the per-parameter vectors into a single 1-D array.
    finalNet = np.concatenate(flatNet)
    return finalNet, shapes

The function above returns all the weights as a 1-D NumPy array, finalNet, together with shapes, a list holding the original shape of each parameter. I want to see how modifying the weights affects prediction accuracy, so I change values in finalNet. How can I copy the modified weight vector back into the original network? Please help. Thank you.

question from:https://stackoverflow.com/questions/65941834/pytorch-how-to-unflatten-get-back-the-network-from-flattened-network


1 Reply


There is a difference between a model's definition (its layers and its forward function) and its parameter configuration (the model state, which is easily accessible as a dictionary through state_dict).
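To make the distinction concrete, here is a minimal sketch of what the state dictionary looks like. The toy nn.Sequential model is an assumption made up for illustration; it is not the asker's network.

```python
import torch.nn as nn

# Hypothetical toy model, used only for illustration.
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# state_dict() maps parameter names to tensors; it holds values only,
# not the layer types or the forward logic.
state = net.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

The keys carry the parameter names and the values carry the tensors, but nothing in the dictionary says that a ReLU sits between the two linear layers.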

You can extract a model's state, as you did with flattenNetwork. However, reversing this operation from the weights and layer shapes alone, i.e. rebuilding the model without the original net, is not possible for pretty much any model: the shapes do not record the layer types or how forward connects them.

Now, assuming you still have access to net, my advice is to work with net.state_dict() directly: modify it, then load the dictionary of weights back with load_state_dict. This way you avoid having to serialize the model's parameters yourself.
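As a rough sketch of both routes, assuming net is still available: the first edits a copy of the state dictionary and loads it back, the second writes a modified flat vector back into the parameters. The helper names flatten_network and unflatten_network and the toy model are hypothetical, and the second route assumes the flat vector follows the order of net.parameters().

```python
import numpy as np
import torch
import torch.nn as nn


def flatten_network(net):
    """Return all parameters as one 1-D NumPy array, plus their shapes."""
    arrays = [p.detach().cpu().numpy() for p in net.parameters()]
    shapes = [a.shape for a in arrays]
    return np.concatenate([a.ravel() for a in arrays]), shapes


def unflatten_network(net, flat, shapes):
    """Copy a flat vector back into net's parameters, in place.

    Assumes flat and shapes follow the order of net.parameters().
    """
    offset = 0
    with torch.no_grad():  # in-place writes to parameters must not be tracked
        for param, shape in zip(net.parameters(), shapes):
            size = int(np.prod(shape))
            chunk = np.asarray(flat[offset:offset + size]).reshape(shape)
            param.copy_(torch.as_tensor(chunk, dtype=param.dtype, device=param.device))
            offset += size
    if offset != len(flat):
        raise ValueError("flat vector is longer than the network's parameters")


# Hypothetical toy model standing in for the asker's net.
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Route 1: edit a copy of the state dict, then load it back.
state = {name: t.clone() for name, t in net.state_dict().items()}
state["0.weight"] *= 0.5
net.load_state_dict(state)

# Route 2: round-trip through the flat NumPy vector.
flat, shapes = flatten_network(net)
modified = flat + 0.01
unflatten_network(net, modified, shapes)
```

PyTorch also provides torch.nn.utils.parameters_to_vector and torch.nn.utils.vector_to_parameters, which do the same round trip on tensors instead of NumPy arrays. Note that net.parameters() skips buffers such as BatchNorm running statistics, whereas state_dict() includes them.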

