Sklearn Tuner
- Original Link : https://keras.io/api/keras_tuner/tuners/sklearn/
- Last Checked at : 2024-11-25
SklearnTuner class

```python
keras_tuner.SklearnTuner(
    oracle, hypermodel, scoring=None, metrics=None, cv=None, **kwargs
)
```
Tuner for Scikit-learn Models.
Performs cross-validated hyperparameter search for Scikit-learn models.
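The description above can be made concrete with a rough, sklearn-only sketch of what one cross-validated trial involves: build a fresh estimator from a set of hyperparameter values, fit and score it on each fold, then average the fold scores. This is an approximation for illustration, not KerasTuner's actual implementation. `build_model`, `run_cv_trial`, and the candidate `alpha` values are hypothetical, and a plain dict stands in for KerasTuner's `HyperParameters` object.

```python
import numpy as np
from sklearn import datasets, linear_model, metrics, model_selection


def build_model(hparams):
    # Hypothetical stand-in for a hypermodel: it receives a plain dict of
    # hyperparameter values (KerasTuner passes a HyperParameters object
    # instead) and returns a fresh, unfitted estimator.
    return linear_model.RidgeClassifier(alpha=hparams["alpha"])


def run_cv_trial(hparams, X, y, cv, scoring=None, extra_metrics=()):
    """Cross-validate one hyperparameter setting (illustrative sketch)."""
    results = {"score": []}
    results.update({m.__name__: [] for m in extra_metrics})
    for train_idx, val_idx in cv.split(X, y):
        model = build_model(hparams)  # new estimator per fold
        model.fit(X[train_idx], y[train_idx])
        X_val, y_val = X[val_idx], y[val_idx]
        # A scorer is called as scorer(estimator, X, y); without one, fall
        # back to the estimator's own score() method.
        if scoring is None:
            results["score"].append(model.score(X_val, y_val))
        else:
            results["score"].append(scoring(model, X_val, y_val))
        # Extra metrics take (y_true, y_pred) and are only recorded.
        y_pred = model.predict(X_val)
        for m in extra_metrics:
            results[m.__name__].append(m(y_val, y_pred))
    # Average every recorded quantity over the folds.
    return {name: float(np.mean(vals)) for name, vals in results.items()}


X, y = datasets.load_iris(return_X_y=True)
cv = model_selection.StratifiedKFold(5)
scorer = metrics.make_scorer(metrics.accuracy_score)
extra = (metrics.balanced_accuracy_score,)

# Hypothetical candidates; in KerasTuner the oracle proposes these.
candidates = [{"alpha": a} for a in (1e-3, 1e-1, 1.0)]
trials = [(c, run_cv_trial(c, X, y, cv, scorer, extra)) for c in candidates]

# Only "score" decides the winner, mirroring Objective('score', 'max');
# the extra metric is reported but ignored for selection.
best_hparams, best_result = max(trials, key=lambda t: t[1]["score"])
print(best_hparams, best_result)
```

Rebuilding the estimator inside the fold loop keeps each fold's fit independent of the previous one.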
Examples
```python
import keras_tuner
from sklearn import ensemble
from sklearn import datasets
from sklearn import linear_model
from sklearn import metrics
from sklearn import model_selection


def build_model(hp):
    # hp.Choice selects the model family; each branch defines its own
    # hyperparameters, which are only active when that branch is chosen.
    model_type = hp.Choice('model_type', ['random_forest', 'ridge'])
    if model_type == 'random_forest':
        model = ensemble.RandomForestClassifier(
            n_estimators=hp.Int('n_estimators', 10, 50, step=10),
            max_depth=hp.Int('max_depth', 3, 10))
    else:
        model = linear_model.RidgeClassifier(
            alpha=hp.Float('alpha', 1e-3, 1, sampling='log'))
    return model


tuner = keras_tuner.tuners.SklearnTuner(
    oracle=keras_tuner.oracles.BayesianOptimizationOracle(
        # The objective for this Tuner must be Objective('score', 'max').
        objective=keras_tuner.Objective('score', 'max'),
        max_trials=10),
    hypermodel=build_model,
    # Use one explicit scorer so both model families are compared
    # on the same metric.
    scoring=metrics.make_scorer(metrics.accuracy_score),
    cv=model_selection.StratifiedKFold(5),
    directory='.',
    project_name='my_project')

X, y = datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, test_size=0.2)

# Hyperparameter search uses cross-validation on the training split only.
tuner.search(X_train, y_train)
best_model = tuner.get_best_models(num_models=1)[0]

# Evaluate the selected model on the held-out test split.
print(best_model.score(X_test, y_test))
```
Arguments
- oracle: A `keras_tuner.Oracle` instance. Note that for this `Tuner`, the `objective` for the `Oracle` should always be set to `Objective('score', direction='max')`. Also, `Oracle`s that exploit neural-network-specific training (e.g. `Hyperband`) should not be used with this `Tuner`.
- hypermodel: A `HyperModel` instance (or a callable that takes hyperparameters and returns a Model instance).
- scoring: An sklearn `scoring` function. For more information, see `sklearn.metrics.make_scorer`. If not provided, the Model's default scoring will be used via `model.score`. Note that if you are searching across different Model families, the default scoring for these Models will often be different. In this case, you should supply `scoring` here to make sure your Models are scored on the same metric.
- metrics: Additional `sklearn.metrics` functions to monitor during search. Note that these metrics do not affect the search process.
- cv: An `sklearn.model_selection` splitter instance (e.g. `StratifiedKFold(5)`). Used to determine how samples are split into groups for cross-validation.
- **kwargs: Keyword arguments relevant to all `Tuner` subclasses. Please see the docstring for `Tuner`.
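Because the oracle is always told to maximize `score`, any scorer passed as `scoring` should follow scikit-learn's "greater is better" convention. A minimal sketch, assuming a regression task on scikit-learn's bundled diabetes dataset (not part of the original example): wrapping a loss with `make_scorer(..., greater_is_better=False)` negates it, so lower errors produce higher scores.

```python
from sklearn import datasets, linear_model, metrics, model_selection

X, y = datasets.load_diabetes(return_X_y=True)
X_train, X_val, y_train, y_val = model_selection.train_test_split(
    X, y, test_size=0.25, random_state=0)
model = linear_model.Ridge(alpha=1.0).fit(X_train, y_train)

# MSE is "lower is better". greater_is_better=False makes make_scorer negate
# it, so the resulting score is "higher is better" and can be maximized.
neg_mse = metrics.make_scorer(metrics.mean_squared_error,
                              greater_is_better=False)

mse = metrics.mean_squared_error(y_val, model.predict(X_val))
score = neg_mse(model, X_val, y_val)
print(score, -mse)  # the two values match

# Without a scorer, model.score() is used instead; for sklearn regressors
# that is R^2, a different metric on a different scale.
r2 = model.score(X_val, y_val)
print(r2)
```

Passing the unwrapped `mean_squared_error` through `make_scorer` with the default `greater_is_better=True` would make the search favor the worst models, since the oracle maximizes the value.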
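The `cv` argument takes a splitter object whose `split(X, y)` method yields train/validation index arrays. The sketch below, built around the hypothetical helper `fold_class_counts`, counts class labels in each validation fold of the iris data, whose rows are sorted by class. Plain `KFold` without shuffling produces folds dominated by one class, while `StratifiedKFold` preserves the class proportions; this is a plausible reason the example above uses `StratifiedKFold(5)`.

```python
import numpy as np
from sklearn import datasets, model_selection

X, y = datasets.load_iris(return_X_y=True)  # rows are sorted by class label


def fold_class_counts(splitter):
    # The splitter is an instance; split() yields (train, validation) indices.
    return [np.bincount(y[val_idx], minlength=3).tolist()
            for _, val_idx in splitter.split(X, y)]


plain = fold_class_counts(model_selection.KFold(5))
stratified = fold_class_counts(model_selection.StratifiedKFold(5))
print(plain)       # [[30, 0, 0], [20, 10, 0], [0, 30, 0], [0, 10, 20], [0, 0, 30]]
print(stratified)  # [[10, 10, 10], [10, 10, 10], [10, 10, 10], [10, 10, 10], [10, 10, 10]]
```

With `KFold(5)`, the first and last validation folds contain a single class, so a model would be scored on labels it saw few or none of during that fold's training.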