make_train_function
- FullyConnected.make_train_function(force=False)
Creates a function that executes one step of training.
Override this method to support custom training logic. It is called by Model.fit and Model.train_on_batch.
Typically, this method directly controls tf.function and tf.distribute.Strategy settings, and delegates the actual training logic to Model.train_step.
The function is created and cached the first time Model.fit or Model.train_on_batch is called. The cache is cleared whenever Model.compile is called. To bypass the cache and regenerate the function, pass force=True.
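The caching contract above (build once, reuse on later calls, rebuild on force=True, invalidate on compile) can be illustrated with a minimal plain-Python sketch. This is not the Keras implementation: SketchModel, build_count, and the toy train_step are hypothetical stand-ins, and tf.function / tf.distribute.Strategy handling is deliberately left out.

```python
class SketchModel:
    """Hypothetical stand-in for a Keras model; illustrates caching only."""

    def __init__(self):
        self.train_function = None  # cached step function (None = not built yet)
        self.build_count = 0        # how many times the function was (re)built

    def compile(self):
        # Recompiling invalidates the cached train function.
        self.train_function = None

    def train_step(self, batch):
        # Hypothetical placeholder for the real per-batch training logic.
        return {"loss": float(sum(batch)) / len(batch)}

    def make_train_function(self, force=False):
        # Reuse the cached function unless the caller forces a rebuild.
        if self.train_function is not None and not force:
            return self.train_function

        def train_function(iterator):
            # One training step: pull one batch and delegate to train_step.
            return self.train_step(next(iterator))

        self.build_count += 1
        self.train_function = train_function
        return train_function

    def train_on_batch(self, batch):
        # Mirrors the documented flow: build (or reuse) the function, run one step.
        step = self.make_train_function()
        return step(iter([batch]))


model = SketchModel()
model.compile()
print(model.train_on_batch([1.0, 3.0]))  # {'loss': 2.0}
```

Calling make_train_function() twice returns the same object; make_train_function(force=True) or a new compile() produces a fresh one.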
- Parameters:
force – If True, regenerate the train function even when a cached one is available. Defaults to False.
- Returns:
Function. The function created by this method should accept a tf.data.Iterator and return a dict of values that will be passed to tf.keras.callbacks.Callback.on_train_batch_end, such as {'loss': 0.2, 'accuracy': 0.7}.
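To make the return contract concrete, here is a plain-Python sketch, assuming a toy "model" that predicts 2 * x and a fit-like loop that hands each returned dict to every callback's on_train_batch_end(batch, logs). RecordingCallback, make_toy_train_function, run_epoch, and the dataset are all hypothetical; a real Keras train function operates on tensors from a tf.data.Iterator rather than Python lists.

```python
class RecordingCallback:
    """Hypothetical callback exposing the same hook name as tf.keras.callbacks.Callback."""

    def __init__(self):
        self.seen = []

    def on_train_batch_end(self, batch, logs=None):
        # Store a copy of the logs dict produced by the train function.
        self.seen.append((batch, dict(logs or {})))


def make_toy_train_function():
    # Hypothetical: returns a function that consumes one batch per call.
    def train_function(iterator):
        x, y = next(iterator)
        predictions = [2.0 * xi for xi in x]  # stand-in "model"
        loss = sum((p - t) ** 2 for p, t in zip(predictions, y)) / len(y)
        accuracy = sum(abs(p - t) < 0.5 for p, t in zip(predictions, y)) / len(y)
        # The returned dict is what callbacks receive as `logs`.
        return {"loss": loss, "accuracy": accuracy}

    return train_function


def run_epoch(train_function, dataset, callbacks):
    # Fit-like loop: one train_function call per batch, then notify callbacks.
    iterator = iter(dataset)
    logs = {}
    for step in range(len(dataset)):
        logs = train_function(iterator)
        for cb in callbacks:
            cb.on_train_batch_end(step, logs)
    return logs


dataset = [([1.0, 2.0], [2.0, 4.0]), ([1.0, 1.0], [2.0, 3.0])]
recorder = RecordingCallback()
print(run_epoch(make_toy_train_function(), dataset, [recorder]))
# {'loss': 0.5, 'accuracy': 0.5}
```

Because the train function only promises a dict of scalar logs, callbacks stay decoupled from how the step itself is computed.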