Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - tensorflow: saving and restoring session

I am trying to implement a suggestion from the answers to "Tensorflow: how to save/restore a model?"

I have a class that wraps a TensorFlow model in scikit-learn style:

import numpy as np
import tensorflow as tf
class tflasso():
    saver = tf.train.Saver()
    def __init__(self,
                 learning_rate=2e-2,
                 training_epochs=5000,
                 display_step=50,
                 BATCH_SIZE=100,
                 ALPHA=1e-5,
                 checkpoint_dir="./"):
        ...

    def _create_network(self):
        ...


    def _load_(self, sess, checkpoint_dir = None):
        if checkpoint_dir:
            self.checkpoint_dir = checkpoint_dir

        print("loading a session")
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise Exception("no checkpoint found")
        return

    def fit(self, train_X, train_Y , load = True):
        self.X = train_X
        self.xlen = train_X.shape[1]
        # n_samples = y.shape[0]

        self._create_network()
        tot_loss = self._create_loss()
        optimizer = tf.train.AdagradOptimizer( self.learning_rate).minimize(tot_loss)

        # Initializing the variables
        init = tf.initialize_all_variables()
        # Training per se
        getb = batchgen( self.BATCH_SIZE)

        yvar = train_Y.var()
        print(yvar)
        # Launch the graph
        NUM_CORES = 3  # Choose how many cores to use.
        sess_config = tf.ConfigProto(inter_op_parallelism_threads=NUM_CORES,
                                     intra_op_parallelism_threads=NUM_CORES)
        with tf.Session(config= sess_config) as sess:
            sess.run(init)
            if load:
                self._load_(sess)
            # Fit all training data
            for epoch in range( self.training_epochs):
                for (_x_, _y_) in getb(train_X, train_Y):
                    _y_ = np.reshape(_y_, [-1, 1])
                    sess.run(optimizer, feed_dict={ self.vars.xx: _x_, self.vars.yy: _y_})
                # Display logs per epoch step
                if (1+epoch) % self.display_step == 0:
                    cost = sess.run(tot_loss,
                            feed_dict={ self.vars.xx: train_X,
                                    self.vars.yy: np.reshape(train_Y, [-1, 1])})
                    rsq =  1 - cost / yvar
                    logstr = "Epoch: {:4d} cost = {:.4f}  R^2 = {:.4f}".format(epoch + 1, cost, rsq)
                    print(logstr )
                    self.saver.save(sess, self.checkpoint_dir + 'model.ckpt',
                       global_step= 1+ epoch)

            print("Optimization Finished!")
        return self

When I run:

tfl = tflasso()
tfl.fit( train_X, train_Y , load = False)

I get this output:

Epoch:   50 cost = 38.4705  R^2 = -1.2036
    b1: 0.118122
Epoch:  100 cost = 26.4506  R^2 = -0.5151
    b1: 0.133597
Epoch:  150 cost = 22.4330  R^2 = -0.2850
    b1: 0.142261
Epoch:  200 cost = 20.0361  R^2 = -0.1477
    b1: 0.147998

However, when I try to recover the parameters (even on the same object, without deleting it):

tfl.fit(train_X, train_Y, load=True)

I get strange results. First of all, the loaded value does not match the saved one:

loading a session
loaded b1: 0.1          <------- Loaded another value than saved
Epoch:   50 cost = 30.8483  R^2 = -0.7670
    b1: 0.137484  

What is the right way to load the saved variables, and ideally to inspect them first?

He who fights with dragons too long becomes a dragon himself; gaze too long into the abyss, and the abyss gazes back into you…

1 Answer

TL;DR: You should rework this class so that self._create_network() is called (i) only once, and (ii) before the tf.train.Saver is constructed.

There are two subtle issues here, caused by the code structure and by the default behavior of the tf.train.Saver constructor. When you construct a saver with no arguments (as in your code), it collects the variables that exist in the graph at that moment and adds ops to the graph for saving and restoring them. In your code, saver = tf.train.Saver() is a class attribute, so it is constructed when the class body executes, before any tflasso object exists and long before _create_network() has been called. At that point there are no variables, so the checkpoint should be empty. (In TensorFlow 1.x, the constructor raises ValueError("No variables to save") in this situation unless you pass allow_empty=True.)
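To see why the order matters, here is a toy pure-Python model of the behavior just described. ToyGraph and ToySaver are hypothetical stand-ins invented for illustration, not TensorFlow classes; the only point is that a saver which snapshots the variable list when it is constructed writes nothing if it is built before the network exists.

```python
# Toy model only: these classes are NOT TensorFlow APIs. They mimic a saver
# built without arguments, which captures whichever variables exist in the
# graph at construction time.


class ToyGraph:
    """Holds named variables, standing in for the default tf.Graph."""

    def __init__(self):
        self.variables = {}  # variable name -> current value


class ToySaver:
    """Captures the variable names present when it is constructed."""

    def __init__(self, graph):
        self.graph = graph
        self.names = list(graph.variables)  # fixed at construction time

    def save(self):
        # Only variables captured in __init__ are written to the "checkpoint".
        return {name: self.graph.variables[name] for name in self.names}


graph = ToyGraph()
early_saver = ToySaver(graph)     # like the class-level saver in tflasso
graph.variables["b1"] = 0.147998  # like _create_network() running later
late_saver = ToySaver(graph)      # a saver built after the network

print(early_saver.save())  # {}
print(late_saver.save())   # {'b1': 0.147998}
```

Restoring from the "early" checkpoint changes nothing, which matches the symptom in the question: b1 keeps its initial value of 0.1.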

The second issue is that, by default, a saved checkpoint is a map from each variable's name to its current value. If you create two variables with the same name, TensorFlow automatically "uniquifies" the second one:

v = tf.Variable(..., name="weights")
assert v.op.name == "weights"
w = tf.Variable(..., name="weights")
assert w.op.name == "weights_1"  # The "_1" is added by TensorFlow.

As a consequence, when the second call to tfl.fit() runs self._create_network() again, all of the new variables get names that differ from the names stored in the checkpoint (or that would have been stored, had the saver been constructed after the network). You can avoid this behavior by passing the saver constructor a dictionary that maps checkpoint names to variables, but this is usually quite awkward.
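The effect of building the network twice in the same graph can be modeled the same way. This is again a simplified, hypothetical pure-Python sketch: ToyNameRegistry and build_network are invented for illustration, and TensorFlow's actual uniquification handles more cases than this counter does. It shows that the names produced by the second build no longer match the keys saved from the first.

```python
# Toy model only: a simplified version of how a graph makes repeated names
# unique. TensorFlow's real rules cover more cases; this is not its API.
from collections import defaultdict


class ToyNameRegistry:
    def __init__(self):
        self._counts = defaultdict(int)

    def unique_name(self, name):
        # First use keeps the name; later uses get "_1", "_2", ... appended.
        count = self._counts[name]
        self._counts[name] += 1
        return name if count == 0 else f"{name}_{count}"


def build_network(registry):
    # Hypothetical stand-in for _create_network(): creates two variables.
    return [registry.unique_name("weights"), registry.unique_name("b1")]


registry = ToyNameRegistry()
first_build = build_network(registry)              # first call to fit()
checkpoint = {name: 0.0 for name in first_build}   # keys are saved by name
second_build = build_network(registry)             # second call, same graph

missing = [name for name in second_build if name not in checkpoint]
print(first_build)   # ['weights', 'b1']
print(second_build)  # ['weights_1', 'b1_1']
print(missing)       # ['weights_1', 'b1_1']
```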

There are two main workarounds:

  1. In each call to tflasso.fit(), create the whole model afresh: define a new tf.Graph, then, inside that graph, build the network and create a tf.train.Saver.

  2. (Recommended) Create the network and then the tf.train.Saver in the tflasso constructor, and reuse this graph on each call to tflasso.fit(). Note that you might need to do some more work to reorganize things (in particular, I'm not sure what you do with self.X and self.xlen), but it should be possible to achieve this with placeholders and feeding.
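The ordering in workaround 2 can be sketched with the same kind of toy model. ToyGraph, ToySaver and ToyLasso are hypothetical pure-Python stand-ins, not TensorFlow classes and not the asker's real tflasso; the "training step" is a placeholder increment. The sketch only shows the structure: the network is built once, the saver is created after it, and fit() reuses both, so a checkpoint written by one object restores cleanly into another.

```python
# Toy model only: ToyGraph, ToySaver and ToyLasso are hypothetical stand-ins.
# They illustrate building the network once, then the saver, in __init__.


class ToyGraph:
    def __init__(self):
        self.variables = {}  # variable name -> current value


class ToySaver:
    def __init__(self, graph):
        self.graph = graph
        self.names = list(graph.variables)  # captured at construction time

    def save(self):
        return {name: self.graph.variables[name] for name in self.names}

    def restore(self, checkpoint):
        for name in self.names:
            self.graph.variables[name] = checkpoint[name]


class ToyLasso:
    def __init__(self, init_b1=0.1):
        self.graph = ToyGraph()
        self._create_network(init_b1)      # (i) build the network once...
        self.saver = ToySaver(self.graph)  # (ii) ...then create the saver

    def _create_network(self, init_b1):
        self.graph.variables["b1"] = init_b1

    def fit(self, steps, checkpoint=None):
        # Reuses the same graph on every call; never rebuilds the network.
        if checkpoint is not None:
            self.saver.restore(checkpoint)
        for _ in range(steps):
            self.graph.variables["b1"] += 0.01  # placeholder training step
        return self.saver.save()


model = ToyLasso()
saved = model.fit(steps=5)
fresh = ToyLasso()  # a new object whose b1 starts at 0.1
resumed = fresh.fit(steps=0, checkpoint=saved)
print(resumed == saved)  # True
```

Because the saver is created after the single network build, the names it saves and restores always match the graph's variables.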

