Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

numpy - Python scikit-learn: Cannot clone object... as the constructor does not seem to set parameter

I modified the BernoulliRBM class of scikit-learn to use groups of softmax visible units. In the process, I added an extra NumPy array, visible_config, as an instance attribute, which is initialized in the constructor as follows:

self.visible_config = np.cumsum(np.concatenate((np.asarray([0]),
                                visible_config), axis=0))

where visible_config is a NumPy array passed to the constructor. The code runs without errors when I call fit() directly to train the model. However, when I use GridSearchCV, I get the following error:

Cannot clone object SoftmaxRBM(batch_size=100, learning_rate=0.01, n_components=100, n_iter=100,
  random_state=0, verbose=True, visible_config=[ 0 21 42 63]), as the constructor does not seem to set parameter visible_config

This seems to be a problem with the equality check that sklearn.base.clone performs between the original instance and its copy, because visible_config does not get copied correctly. I'm not sure how to fix this. The documentation says that sklearn.base.clone uses deepcopy(), so shouldn't visible_config also get copied? Can someone please explain what I can try here? Thanks!
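A minimal reproduction, offered as a sketch: `TransformingEstimator` is a hypothetical stand-in for `SoftmaxRBM` (it does no training), and its constructor applies the same `cumsum` expression as above. It assumes scikit-learn's `BaseEstimator` and `clone`; the exact wording of the `RuntimeError` message depends on the scikit-learn version.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone


class TransformingEstimator(BaseEstimator):
    """Hypothetical stand-in for SoftmaxRBM that breaks the API convention."""

    def __init__(self, visible_config, n_components=100):
        self.n_components = n_components
        # Anti-pattern: the constructor stores a *derived* array instead of
        # the argument it was given (same expression as in the question).
        self.visible_config = np.cumsum(
            np.concatenate((np.asarray([0]), visible_config), axis=0)
        )

    def fit(self, X, y=None):
        # Training omitted; only the constructor matters for clone().
        return self


est = TransformingEstimator(visible_config=np.array([21, 21, 21]))
# get_params() reports the transformed array, not the original input.
print(est.get_params()["visible_config"])  # [ 0 21 42 63]

clone_error = None
try:
    clone(est)  # GridSearchCV calls clone() internally
except RuntimeError as exc:
    clone_error = exc
    print("clone failed:", exc)
```

The input `[21, 21, 21]` is a guess consistent with the `[ 0 21 42 63]` shown in the error message above, not the actual configuration from the question.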

He who fights dragons too long becomes a dragon himself; gaze too long into the abyss, and the abyss gazes back…

1 Answer

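Without seeing your code, it's hard to tell exactly what goes wrong, but you are violating a scikit-learn API convention here. The constructor of an estimator should only set attributes to the values the user passes as arguments, unchanged. All computation should happen in fit, and if fit needs to store the result of a computation, it should do so in an attribute with a trailing underscore (_), e.g. visible_config_. This convention is what makes clone and meta-estimators such as GridSearchCV work: clone does not deep-copy the estimator as a whole. It reads the constructor parameters with get_params(), copies each of them, passes them to the constructor of a new instance, and then checks that the new instance stores them unchanged. Because your constructor replaces visible_config with its cumulative sum, that check fails: the [0 21 42 63] reported by get_params() is transformed again, into [0 0 21 63 126].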

If you ever see an estimator in the main scikit-learn codebase that violates this rule, that is a bug, and patches are welcome.
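Applied to the class from the question, the convention could look roughly like the sketch below. `SoftmaxRBMSketch` is a hypothetical, heavily reduced class (no RBM training), and the name `visible_config_` is just one choice that follows the trailing-underscore rule. The last lines mimic, in simplified form, what GridSearchCV does for each candidate: clone the base estimator, set the candidate's parameters, and fit.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone


class SoftmaxRBMSketch(BaseEstimator):
    """Hypothetical, heavily reduced stand-in for SoftmaxRBM (no training)."""

    def __init__(self, visible_config, n_components=100):
        # Constructor: store the arguments exactly as given, nothing else.
        self.visible_config = visible_config
        self.n_components = n_components

    def fit(self, X, y=None):
        # Derived state is computed in fit and stored with a trailing underscore.
        self.visible_config_ = np.cumsum(
            np.concatenate((np.asarray([0]), self.visible_config), axis=0)
        )
        # ... the actual RBM training would go here ...
        return self


est = SoftmaxRBMSketch(visible_config=np.array([21, 21, 21]))
cloned = clone(est)  # succeeds: the constructor leaves visible_config untouched

# Simplified version of a GridSearchCV candidate: clone, set_params, fit.
candidate = clone(est).set_params(n_components=50)
candidate.fit(np.zeros((4, 63)))  # placeholder data: 3 groups x 21 units
print(candidate.visible_config_)  # [ 0 21 42 63]
```

The same reasoning applies to input validation: scikit-learn's convention is to check parameter values in fit rather than in the constructor, so any check on visible_config would go there as well.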

...