The base Tuner class
- Original Link : https://keras.io/api/keras_tuner/tuners/base_tuner/
- Last Checked at : 2024-11-25
Tuner class
keras_tuner.Tuner(
    oracle,
    hypermodel=None,
    max_model_size=None,
    optimizer=None,
    loss=None,
    metrics=None,
    distribution_strategy=None,
    directory=None,
    project_name=None,
    logger=None,
    tuner_id=None,
    overwrite=False,
    executions_per_trial=1,
    **kwargs
)
Tuner class for Keras models.
This is the base Tuner class for all tuners for Keras models. It manages the building, training, evaluation and saving of the Keras models. New tuners can be created by subclassing the class.
All Keras-related logic is in Tuner.run_trial() and its subroutines. When subclassing Tuner, if you do not call super().run_trial(), the tuner can tune anything.
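The division of labor described above can be illustrated with a small plain-Python sketch. The base class drives the search loop, and a subclass only supplies run_trial. This does not use keras_tuner. ToyRandomOracle, ToyTuner and QuadraticTuner are hypothetical stand-ins that mimic the pattern, here minimizing the black-box function (x + 1)^2 without building any model.

```python
import random
from typing import Dict, List, Tuple

HyperParams = Dict[str, float]


class ToyRandomOracle:
    """Hypothetical oracle: proposes hyperparameter values uniformly at random."""

    def __init__(self, bounds: Dict[str, Tuple[float, float]], max_trials: int, seed: int = 0) -> None:
        self.bounds = bounds
        self.max_trials = max_trials
        self.rng = random.Random(seed)

    def create_trial(self) -> HyperParams:
        return {name: self.rng.uniform(lo, hi) for name, (lo, hi) in self.bounds.items()}


class ToyTuner:
    """Hypothetical base class: owns the search loop, delegates evaluation to run_trial."""

    def __init__(self, oracle: ToyRandomOracle) -> None:
        self.oracle = oracle
        self.results: List[Tuple[HyperParams, float]] = []

    def run_trial(self, hp: HyperParams) -> float:
        raise NotImplementedError("subclasses decide what a trial evaluates")

    def search(self) -> None:
        for _ in range(self.oracle.max_trials):
            hp = self.oracle.create_trial()
            self.results.append((hp, self.run_trial(hp)))

    def best(self) -> Tuple[HyperParams, float]:
        # This toy objective is minimized.
        return min(self.results, key=lambda r: r[1])


class QuadraticTuner(ToyTuner):
    """Overrides run_trial without building any model: it just evaluates a function."""

    def run_trial(self, hp: HyperParams) -> float:
        x = hp["x"]
        return x * x + 2 * x + 1  # (x + 1)^2, minimized at x = -1


tuner = QuadraticTuner(ToyRandomOracle({"x": (-2.0, 2.0)}, max_trials=50, seed=42))
tuner.search()
best_hp, best_score = tuner.best()
print(f"best x = {best_hp['x']:.3f}, objective = {best_score:.4f}")
```

Keeping the loop in the base class and the evaluation in an overridable method is what lets the same search machinery tune a Keras model or an arbitrary function.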
Arguments
- oracle: Instance of Oracle class.
- hypermodel: Instance of HyperModel class (or callable that takes hyperparameters and returns a Model instance). It is optional when Tuner.run_trial() is overridden and does not use self.hypermodel.
- max_model_size: Integer, maximum number of scalars in the parameters of a model. Models larger than this are rejected.
- optimizer: Optional optimizer. It is used to override the optimizer argument in the compile step for the models. If the hypermodel does not compile the models it generates, then this argument must be specified.
- loss: Optional loss. May be used to override the loss argument in the compile step for the models. If the hypermodel does not compile the models it generates, then this argument must be specified.
- metrics: Optional metrics. May be used to override the metrics argument in the compile step for the models. If the hypermodel does not compile the models it generates, then this argument must be specified.
- distribution_strategy: Optional instance of tf.distribute.Strategy. If specified, each trial will run under this scope. For example, tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1']) will run each trial on two GPUs. Currently only single-worker strategies are supported.
- directory: A string, the relative path to the working directory.
- project_name: A string, the name to use as prefix for files saved by this Tuner.
- tuner_id: Optional string, used as the ID of this Tuner.
- overwrite: Boolean, defaults to False. If False, reloads an existing project of the same name if one is found. Otherwise, overwrites the project.
- executions_per_trial: Integer, the number of executions (training a model from scratch, starting from a new initialization) to run per trial (model configuration). Model metrics may vary greatly depending on random initialization, hence it is often a good idea to run several executions per trial in order to evaluate the performance of a given set of hyperparameter values.
- **kwargs: Arguments for BaseTuner.
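max_model_size is measured in scalars, meaning the total number of parameters of the built model. The sketch below illustrates that arithmetic for a plain stack of fully connected layers. It assumes each layer with n_in inputs and n_out units holds n_in * n_out weights plus n_out biases. It is not keras_tuner's code. dense_param_count and is_too_large are hypothetical helpers that only show what "number of scalars" means and how the documented cap would reject a model.

```python
from typing import Optional, Sequence


def dense_param_count(layer_sizes: Sequence[int]) -> int:
    """Count weights + biases for a stack of fully connected layers.

    layer_sizes[0] is the input width; each following entry is a layer's unit count.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + n_out  # kernel + bias
    return total


def is_too_large(layer_sizes: Sequence[int], max_model_size: Optional[int]) -> bool:
    """Mimic the documented rule: models above max_model_size are rejected."""
    if max_model_size is None:
        return False  # no cap configured
    return dense_param_count(layer_sizes) > max_model_size


# 784 inputs -> 128 hidden units -> 10 outputs
print(dense_param_count([784, 128, 10]))  # 101770  (784*128 + 128 + 128*10 + 10)
print(is_too_large([784, 128, 10], max_model_size=100_000))  # True
print(is_too_large([784, 64, 10], max_model_size=100_000))   # False
```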
Attributes
- remaining_trials: Number of trials remaining, None if max_trials is not set. This is useful when resuming a previously stopped search.
get_best_hyperparameters method
Tuner.get_best_hyperparameters(num_trials=1)
Returns the best hyperparameters, as determined by the objective.
This method can be used to reinstantiate the (untrained) best model found during the search process.
Example
best_hp = tuner.get_best_hyperparameters()[0]
model = tuner.hypermodel.build(best_hp)
Arguments
- num_trials: Optional number of HyperParameters objects to return.
Returns
List of HyperParameters objects sorted from the best to the worst.
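Which trial counts as "best" depends on the objective's direction. Lower is better for a loss, and higher is better for an accuracy. The plain-Python sketch below ranks hypothetical completed trials and returns the top num_trials hyperparameter sets, best first. The (hyperparameters, score) tuples, the made-up scores, and the helper name best_hyperparameters are assumptions for illustration, not keras_tuner's internal data structures. The "min"/"max" strings mirror the direction values used by keras_tuner.Objective.

```python
from typing import Any, Dict, List, Tuple

Trial = Tuple[Dict[str, Any], float]  # (hyperparameter values, objective score)


def best_hyperparameters(trials: List[Trial], direction: str, num_trials: int = 1) -> List[Dict[str, Any]]:
    """Return up to num_trials hyperparameter dicts, sorted from best to worst."""
    if direction not in ("min", "max"):
        raise ValueError("direction must be 'min' or 'max'")
    ranked = sorted(trials, key=lambda t: t[1], reverse=(direction == "max"))
    return [hp for hp, _ in ranked[:num_trials]]


completed = [
    ({"units": 32, "lr": 1e-2}, 0.41),
    ({"units": 64, "lr": 1e-3}, 0.29),
    ({"units": 128, "lr": 1e-4}, 0.35),
]
# Scores are made-up validation losses here, so lower is better.
print(best_hyperparameters(completed, "min", num_trials=2))
# [{'units': 64, 'lr': 0.001}, {'units': 128, 'lr': 0.0001}]
```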
get_best_models method
Tuner.get_best_models(num_models=1)
Returns the best model(s), as determined by the tuner’s objective.
The models are loaded with the weights corresponding to their best checkpoint (at the end of the best epoch of best trial).
This method is for querying the models trained during the search.
For best performance, it is recommended to retrain your Model on the full dataset using the best hyperparameters found during search, which can be obtained using tuner.get_best_hyperparameters().
Arguments
- num_models: Optional number of best models to return. Defaults to 1.
Returns
List of trained model instances sorted from the best to the worst.
get_state method
Tuner.get_state()
Returns the current state of this object.
This method is called during save.
Returns
A dictionary of serializable objects as the state.
load_model method
Tuner.load_model(trial)
Loads a Model from a given trial.
For models that report intermediate results to the Oracle, load_model should generally load the best reported step by relying on trial.best_step.
Arguments
- trial: A Trial instance, the Trial corresponding to the model to load.
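The "best reported step" is the step at which the objective reached its best value. For Keras models, that step is an epoch. The plain-Python sketch below shows that selection, assuming the per-step objective values are available as a list. best_step here is a hypothetical helper, not the Trial.best_step attribute itself, and the loss values are made up.

```python
from typing import Sequence


def best_step(objective_per_step: Sequence[float], direction: str = "min") -> int:
    """Index of the step whose objective value is best (earliest wins on ties)."""
    if not objective_per_step:
        raise ValueError("no steps were reported")
    if direction not in ("min", "max"):
        raise ValueError("direction must be 'min' or 'max'")
    pick = min if direction == "min" else max
    best_value = pick(objective_per_step)
    return list(objective_per_step).index(best_value)


val_loss = [0.62, 0.48, 0.41, 0.44, 0.47]  # made-up per-epoch validation losses
print(best_step(val_loss))  # 2 -> the checkpoint from epoch index 2 is the one to load
```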
on_epoch_begin method
Tuner.on_epoch_begin(trial, model, epoch, logs=None)
Called at the beginning of an epoch.
Arguments
- trial: A Trial instance.
- model: A Keras Model.
- epoch: The current epoch number.
- logs: Additional metrics.
on_batch_begin method
Tuner.on_batch_begin(trial, model, batch, logs)
Called at the beginning of a batch.
Arguments
- trial: A Trial instance.
- model: A Keras Model.
- batch: The current batch number within the current epoch.
- logs: Additional metrics.
on_batch_end method
Tuner.on_batch_end(trial, model, batch, logs=None)
Called at the end of a batch.
Arguments
- trial: A Trial instance.
- model: A Keras Model.
- batch: The current batch number within the current epoch.
- logs: Additional metrics.
on_epoch_end method
Tuner.on_epoch_end(trial, model, epoch, logs=None)
Called at the end of an epoch.
Arguments
- trial: A Trial instance.
- model: A Keras Model.
- epoch: The current epoch number.
- logs: Dict. Metrics for this epoch. This should include the value of the objective for this epoch.
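The four hooks above are called in a nested order during training. Each epoch starts with an epoch-begin call, then a begin/end pair wraps every batch, and the epoch closes with an epoch-end call that carries that epoch's metrics. The sketch below makes that order and the shape of logs concrete. It uses a hypothetical HookRecorder class and a fake training loop with no Keras involved. The metric name val_loss and all metric values are assumptions.

```python
from typing import Any, Dict, List, Optional, Tuple


class HookRecorder:
    """Hypothetical stand-in with the same hook names as Tuner; it only records calls."""

    def __init__(self) -> None:
        self.events: List[Tuple[str, int]] = []
        self.epoch_logs: List[Dict[str, float]] = []

    def on_epoch_begin(self, trial: Any, model: Any, epoch: int, logs: Optional[dict] = None) -> None:
        self.events.append(("epoch_begin", epoch))

    def on_batch_begin(self, trial: Any, model: Any, batch: int, logs: Optional[dict]) -> None:
        self.events.append(("batch_begin", batch))

    def on_batch_end(self, trial: Any, model: Any, batch: int, logs: Optional[dict] = None) -> None:
        self.events.append(("batch_end", batch))

    def on_epoch_end(self, trial: Any, model: Any, epoch: int, logs: Optional[dict] = None) -> None:
        self.events.append(("epoch_end", epoch))
        self.epoch_logs.append(dict(logs or {}))


def fake_training_loop(hooks: HookRecorder, epochs: int, batches_per_epoch: int) -> None:
    """Calls the hooks in the order a Keras fit loop would, with made-up metrics."""
    for epoch in range(epochs):
        hooks.on_epoch_begin(None, None, epoch)
        for batch in range(batches_per_epoch):
            hooks.on_batch_begin(None, None, batch, logs={})
            hooks.on_batch_end(None, None, batch, logs={"loss": 1.0 / (batch + 1)})
        # Epoch-end logs should carry the objective (assumed here to be val_loss).
        hooks.on_epoch_end(None, None, epoch, logs={"val_loss": 0.5 / (epoch + 1)})


recorder = HookRecorder()
fake_training_loop(recorder, epochs=2, batches_per_epoch=2)
print(recorder.events[:6])
# [('epoch_begin', 0), ('batch_begin', 0), ('batch_end', 0), ('batch_begin', 1), ('batch_end', 1), ('epoch_end', 0)]
print(recorder.epoch_logs)  # [{'val_loss': 0.5}, {'val_loss': 0.25}]
```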
run_trial method
Tuner.run_trial(trial, *args, **kwargs)
Evaluates a set of hyperparameter values.
This method is called multiple times during search to build and evaluate the models with different hyperparameters and return the objective value.
Example
You can use it with self.hypermodel to build and fit the model.
def run_trial(self, trial, *args, **kwargs):
    hp = trial.hyperparameters
    model = self.hypermodel.build(hp)
    return self.hypermodel.fit(hp, model, *args, **kwargs)
You can also use it as a black-box optimizer for anything.
def run_trial(self, trial, *args, **kwargs):
    hp = trial.hyperparameters
    x = hp.Float("x", -2.0, 2.0)
    y = x * x + 2 * x + 1
    return y
Arguments
- trial: A Trial instance that contains the information needed to run this trial. Hyperparameters can be accessed via trial.hyperparameters.
- *args: Positional arguments passed by search.
- **kwargs: Keyword arguments passed by search.
Returns
A History object (the return value of model.fit()), a dictionary, a float, or a list of one of these types.
If a dictionary is returned, it should contain the metrics to track. The keys are the metric names and must include the objective name. The values are the metric values.
If a float is returned, it should be the objective value.
If the model is evaluated multiple times, you may return a list of results of any of the types above. The final objective value is the average of the results in the list.
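The return-value rules above amount to a normalization step. A float becomes {objective_name: value}, a dict is used as-is, and a list is reduced by averaging. The sketch below is a plain-Python reading of those documented rules, not keras_tuner's code. It leaves out History objects, which need Keras, and it assumes every metric in a list of results is averaged key by key. to_metrics and objective_name are hypothetical names.

```python
from statistics import mean
from typing import Dict, List, Union

Result = Union[float, Dict[str, float]]


def to_metrics(result: Union[Result, List[Result]], objective_name: str) -> Dict[str, float]:
    """Normalize a run_trial return value into a {metric_name: value} dict."""
    if isinstance(result, list):
        # Several executions: normalize each one, then average metric by metric.
        per_run = [to_metrics(r, objective_name) for r in result]
        return {name: mean(run[name] for run in per_run) for name in per_run[0]}
    if isinstance(result, dict):
        if objective_name not in result:
            raise ValueError(f"returned metrics must include the objective {objective_name!r}")
        return dict(result)
    # A bare number is taken to be the objective value itself.
    return {objective_name: float(result)}


print(to_metrics(0.25, "val_loss"))                                     # {'val_loss': 0.25}
print(to_metrics({"val_loss": 0.3, "val_acc": 0.9}, "val_loss"))        # {'val_loss': 0.3, 'val_acc': 0.9}
print(to_metrics([{"val_loss": 0.25}, {"val_loss": 0.75}], "val_loss"))  # {'val_loss': 0.5}
```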
results_summary method
Tuner.results_summary(num_trials=10)
Display tuning results summary.
The method prints a summary of the search results including the hyperparameter values and evaluation results for each trial.
Arguments
- num_trials: Optional number of trials to display. Defaults to 10.
save_model method
Tuner.save_model(trial_id, model, step=0)
Saves a Model for a given trial.
Arguments
- trial_id: The ID of the Trial corresponding to this Model.
- model: The trained model.
- step: Integer, for models that report intermediate results to the Oracle, the step the saved file corresponds to. For example, for Keras models this is the number of epochs trained.
search method
Tuner.search(*fit_args, **fit_kwargs)
Performs a search for the best hyperparameter configurations.
Arguments
- *fit_args: Positional arguments that should be passed to run_trial, for example the training and validation data.
- **fit_kwargs: Keyword arguments that should be passed to run_trial, for example the training and validation data.
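search() does not interpret *fit_args and **fit_kwargs itself. It passes them through to run_trial on every trial. The stripped-down sketch below shows only that forwarding. It uses a hypothetical ForwardingTuner class, integer trial IDs, and placeholder data. In a Keras setting, these arguments would typically be x, y, epochs and validation_data destined for model.fit().

```python
from typing import Any, Dict, List, Tuple


class ForwardingTuner:
    """Hypothetical tuner that only demonstrates search() -> run_trial() forwarding."""

    def __init__(self, max_trials: int) -> None:
        self.max_trials = max_trials
        self.calls: List[Tuple[int, tuple, Dict[str, Any]]] = []

    def run_trial(self, trial: int, *args: Any, **kwargs: Any) -> float:
        # A real run_trial would typically pass these on to model.fit(*args, **kwargs).
        self.calls.append((trial, args, kwargs))
        return 0.0

    def search(self, *fit_args: Any, **fit_kwargs: Any) -> None:
        for trial in range(self.max_trials):
            self.run_trial(trial, *fit_args, **fit_kwargs)


x_train, y_train = [[0.0], [1.0]], [0, 1]  # placeholder data
tuner = ForwardingTuner(max_trials=2)
tuner.search(x_train, y_train, epochs=5, validation_data=([[0.5]], [1]))
print(tuner.calls[0])
# (0, ([[0.0], [1.0]], [0, 1]), {'epochs': 5, 'validation_data': ([[0.5]], [1])})
```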
search_space_summary method
Tuner.search_space_summary(extended=False)
Print search space summary.
The method prints a summary of the hyperparameters in the search space. It can be called before calling the search method.
Arguments
- extended: Optional boolean, whether to display an extended summary. Defaults to False.
set_state method
Tuner.set_state(state)
Sets the current state of this object.
This method is called during reload.
Arguments
- state: A dictionary of serialized objects as the state to restore.
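get_state() and set_state() form a pair. Whatever get_state() returns must be serializable and must be enough for set_state() to restore the object after a reload. The plain-Python sketch below illustrates that contract. It assumes JSON as the serialization format and uses a hypothetical CounterTuner with two pieces of state. Its save/reload methods are hypothetical too: they return and accept strings instead of writing files.

```python
import json
from typing import Any, Dict, Optional


class CounterTuner:
    """Hypothetical object following the get_state/set_state contract."""

    def __init__(self) -> None:
        self.completed_trials = 0
        self.best_score: Optional[float] = None

    def get_state(self) -> Dict[str, Any]:
        # Only plain, serializable values go into the state dict.
        return {"completed_trials": self.completed_trials, "best_score": self.best_score}

    def set_state(self, state: Dict[str, Any]) -> None:
        self.completed_trials = state["completed_trials"]
        self.best_score = state["best_score"]

    def save(self) -> str:
        return json.dumps(self.get_state())

    def reload(self, saved: str) -> None:
        self.set_state(json.loads(saved))


original = CounterTuner()
original.completed_trials, original.best_score = 7, 0.31
snapshot = original.save()

restored = CounterTuner()
restored.reload(snapshot)
print(restored.completed_trials, restored.best_score)  # 7 0.31
```

Keeping the state as a plain dict is what makes the round trip possible. Anything that cannot survive serialization, such as a live model or an open file handle, would have to be stored by reference instead.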