Sklearn Tuner
- Original link: https://keras.io/api/keras_tuner/tuners/sklearn/
- Last checked: 2024-11-25
SklearnTuner class
```python
keras_tuner.SklearnTuner(
    oracle, hypermodel, scoring=None, metrics=None, cv=None, **kwargs
)
```

Tuner for Scikit-learn models.
Performs cross-validated hyperparameter search for Scikit-learn models.
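Roughly, each trial works on one hyperparameter setting. It builds a fresh model for every split produced by `cv`, fits it on the training part, scores it on the held-out part, and averages the fold scores into the `score` value that the oracle maximizes. The sketch below reproduces that loop with plain scikit-learn and checks it against `cross_val_score`. The helper `cv_score` is hypothetical. The code is an approximation of the idea, not keras_tuner's internal implementation.

```python
import numpy as np
from sklearn import base, datasets, linear_model, metrics, model_selection


def cv_score(estimator, X, y, cv, scoring=None):
    """Mean held-out score of `estimator` over the folds produced by `cv`.

    Hypothetical helper that approximates what one tuning trial computes.
    """
    fold_scores = []
    for train_idx, val_idx in cv.split(X, y):
        # Fresh, unfitted copy per fold so no fold sees another fold's fit.
        model = base.clone(estimator)
        model.fit(X[train_idx], y[train_idx])
        if scoring is None:
            # Fall back to the estimator's own default metric.
            fold_scores.append(model.score(X[val_idx], y[val_idx]))
        else:
            # Scorers from make_scorer are called as scorer(estimator, X, y).
            fold_scores.append(scoring(model, X[val_idx], y[val_idx]))
    return float(np.mean(fold_scores))


X, y = datasets.load_iris(return_X_y=True)
cv = model_selection.StratifiedKFold(5)
scorer = metrics.make_scorer(metrics.accuracy_score)

# One "trial": a single, fixed hyperparameter setting.
candidate = linear_model.RidgeClassifier(alpha=0.01)
manual = cv_score(candidate, X, y, cv, scorer)

# Cross-check against scikit-learn's built-in helper using the same splitter.
reference = model_selection.cross_val_score(
    candidate, X, y, cv=cv, scoring=scorer).mean()
print(manual, reference)
```

`StratifiedKFold` without shuffling produces the same folds on every call. This is why the manual loop and `cross_val_score` should agree exactly.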
Examples
```python
import keras_tuner
from sklearn import ensemble
from sklearn import datasets
from sklearn import linear_model
from sklearn import metrics
from sklearn import model_selection


def build_model(hp):
    # Search space: either a random forest or a ridge classifier,
    # each with its own (conditional) hyperparameters.
    model_type = hp.Choice('model_type', ['random_forest', 'ridge'])
    if model_type == 'random_forest':
        model = ensemble.RandomForestClassifier(
            n_estimators=hp.Int('n_estimators', 10, 50, step=10),
            max_depth=hp.Int('max_depth', 3, 10))
    else:
        model = linear_model.RidgeClassifier(
            alpha=hp.Float('alpha', 1e-3, 1, sampling='log'))
    return model


tuner = keras_tuner.tuners.SklearnTuner(
    oracle=keras_tuner.oracles.BayesianOptimizationOracle(
        # SklearnTuner reports its objective under the name 'score'.
        objective=keras_tuner.Objective('score', 'max'),
        max_trials=10),
    hypermodel=build_model,
    # Explicit scorer so both model families are compared on accuracy.
    scoring=metrics.make_scorer(metrics.accuracy_score),
    cv=model_selection.StratifiedKFold(5),
    directory='.',
    project_name='my_project')

X, y = datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, test_size=0.2)

# Cross-validation happens inside search(), on the training split only.
tuner.search(X_train, y_train)
best_model = tuner.get_best_models(num_models=1)[0]
```

Arguments
- oracle: A `keras_tuner.Oracle` instance. Note that for this `Tuner`, the `objective` for the `Oracle` should always be set to `Objective('score', direction='max')`. Also, `Oracle`s that exploit neural-network-specific training (e.g. `Hyperband`) should not be used with this `Tuner`.
- hypermodel: A `HyperModel` instance (or a callable that takes hyperparameters and returns a Model instance).
- scoring: An sklearn `scoring` function. For more information, see `sklearn.metrics.make_scorer`. If not provided, the Model's default scoring will be used via `model.score`. Note that if you are searching across different Model families, the default scoring for these Models will often be different. In this case, supply `scoring` here so that all Models are scored on the same metric.
- metrics: Additional `sklearn.metrics` functions to monitor during search. Note that these metrics do not affect the search process.
- cv: An `sklearn.model_selection` splitter object, such as `StratifiedKFold(5)` in the example. It determines how samples are split into groups for cross-validation.
- `**kwargs`: Keyword arguments relevant to all `Tuner` subclasses. Please see the docstring for `Tuner`.
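The `scoring` and `metrics` arguments take callables with different signatures. A scorer built with `make_scorer` is called as `scorer(estimator, X, y)` and runs `predict` itself. A plain `sklearn.metrics` function takes `(y_true, y_pred)`. The sketch below uses only scikit-learn to show both conventions on a held-out split, which is the kind of final check that the `X_test` split in the example could serve. Several choices here are hypothetical:

- The fixed `RidgeClassifier(alpha=0.1)` stands in for a tuned model.
- `balanced_accuracy_score` is an illustrative extra metric.
- The `random_state` and `stratify` arguments are added only for repeatability.

```python
import numpy as np
from sklearn import datasets, linear_model, metrics, model_selection

X, y = datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y)

# `scoring`-style callable: called as scorer(estimator, X, y).
scorer = metrics.make_scorer(metrics.accuracy_score)

# `metrics`-style callables: called as metric(y_true, y_pred).
extra_metrics = [metrics.balanced_accuracy_score]

# Hypothetical stand-in for the best model found by a search.
model = linear_model.RidgeClassifier(alpha=0.1).fit(X_train, y_train)

test_score = scorer(model, X_test, y_test)
y_pred = model.predict(X_test)

# Keyed by function name purely for display in this sketch.
report = {fn.__name__: fn(y_test, y_pred) for fn in extra_metrics}
report["score"] = test_score

# For scikit-learn classifiers, model.score() is mean accuracy,
# so here it matches the explicit accuracy scorer.
assert np.isclose(test_score, model.score(X_test, y_test))
print(report)
```

For regressors, `model.score` returns R² rather than accuracy. This difference between default metrics is why passing an explicit `scoring` matters when a search mixes model families.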