train_on_batch
- FullyConnected.train_on_batch(x, y=None, sample_weight=None, class_weight=None, reset_metrics=True, return_dict=False)
Runs a single gradient update on a single batch of data.
- Parameters:
x – Input data. It could be:
- A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs).
- A dict mapping input names to the corresponding arrays/tensors, if the model has named inputs.
y – Target data. Like the input data x, it could be either Numpy array(s) or TensorFlow tensor(s).
sample_weight – Optional array of the same length as x, containing weights to apply to the model’s loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), to apply a different weight to every timestep of every sample.
class_weight – Optional dictionary mapping class indices (integers) to a weight (float) to apply to the model’s loss for the samples from this class during training. This can be useful to tell the model to “pay more attention” to samples from an under-represented class. When class_weight is specified and targets have a rank of 2 or greater, either y must be one-hot encoded, or an explicit final dimension of 1 must be included for sparse class labels.
reset_metrics – If True, the metrics returned will be only for this batch. If False, the metrics will be statefully accumulated across batches.
return_dict – If True, loss and metric results are returned as a dict, with each key being the name of the metric. If False, they are returned as a list.
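How sample_weight and class_weight combine into one weight per sample can be sketched in NumPy. This is a simplified model, not the library's code. It assumes that class_weight is turned into per-sample weights by looking up each sample's class (argmax for one-hot targets, the label itself for sparse targets with a trailing dimension of 1), that these weights are multiplied with any sample_weight, and that the loss uses a "sum over batch size" style reduction. The helpers effective_sample_weights and weighted_batch_loss are hypothetical, and only a 1D sample_weight is handled (the temporal 2D case is not covered).

```python
import numpy as np


def effective_sample_weights(y, class_weight=None, sample_weight=None):
    """Hypothetical helper: one weight per sample from class_weight and sample_weight."""
    y = np.asarray(y)
    if sample_weight is None:
        weights = np.ones(len(y), dtype=float)
    else:
        weights = np.asarray(sample_weight, dtype=float)
    if class_weight is not None:
        if y.ndim == 2 and y.shape[1] > 1:
            classes = y.argmax(axis=1)           # one-hot targets: column index is the class
        else:
            classes = y.reshape(-1).astype(int)  # sparse labels, shape (n,) or (n, 1)
        weights = weights * np.array([class_weight[int(c)] for c in classes], dtype=float)
    return weights


def weighted_batch_loss(per_sample_loss, weights):
    """Assumed reduction: weighted sum of losses divided by the number of samples."""
    per_sample_loss = np.asarray(per_sample_loss, dtype=float)
    return float(np.sum(per_sample_loss * weights) / per_sample_loss.size)


# Usage: three samples, two classes, class 1 up-weighted, last sample down-weighted.
y_onehot = np.array([[1, 0], [0, 1], [0, 1]])
w = effective_sample_weights(y_onehot, class_weight={0: 1.0, 1: 3.0},
                             sample_weight=[1.0, 1.0, 0.5])
print(w)                                        # weights 1.0, 3.0, 1.5
print(weighted_batch_loss([0.2, 0.4, 0.6], w))  # about 0.767
```

The under-represented class 1 contributes three times as much to the loss as class 0, which is the "pay more attention" effect described for class_weight.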
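The difference between reset_metrics=True and reset_metrics=False can be illustrated with a toy running-mean metric. RunningMean and batch_metric are hypothetical stand-ins written for this sketch; it assumes, as the parameter description implies, that the metric is read after each batch and its state is cleared only when reset_metrics is True.

```python
class RunningMean:
    """Hypothetical stand-in for a stateful mean metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def update(self, values):
        self.total += float(sum(values))
        self.count += len(values)

    def result(self):
        return self.total / self.count if self.count else 0.0


def batch_metric(metric, batch_values, reset_metrics=True):
    """Report the metric after one batch, then optionally clear its state."""
    metric.update(batch_values)
    value = metric.result()
    if reset_metrics:
        metric.reset()  # the next batch starts from a clean state
    return value


per_batch = RunningMean()
print(batch_metric(per_batch, [1.0, 3.0]))  # 2.0
print(batch_metric(per_batch, [5.0, 7.0]))  # 6.0, this batch only

accumulated = RunningMean()
print(batch_metric(accumulated, [1.0, 3.0], reset_metrics=False))  # 2.0
print(batch_metric(accumulated, [5.0, 7.0], reset_metrics=False))  # 4.0, both batches
```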
- Returns:
Scalar training loss (if the model has a single output and no metrics) or list of scalars (if the model has multiple outputs and/or metrics). If return_dict is True, a dict mapping each metric name to its scalar value is returned instead. The attribute model.metrics_names gives the display labels for the scalar outputs.
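Because model.metrics_names labels the scalar outputs, a list-style result can be paired with those labels to get the same name-to-value mapping that return_dict=True yields. A minimal pure-Python sketch: the helper as_dict and the metric names and values below are hypothetical; in practice the names would come from model.metrics_names and the values from a train_on_batch call.

```python
def as_dict(metrics_names, results):
    """Hypothetical helper: label list-style train_on_batch results by metric name."""
    if not isinstance(results, (list, tuple)):
        results = [results]  # single output and no metrics: a bare scalar loss
    if len(metrics_names) != len(results):
        raise ValueError("metrics_names and results differ in length")
    return dict(zip(metrics_names, results))


print(as_dict(["loss", "accuracy"], [0.41, 0.875]))  # {'loss': 0.41, 'accuracy': 0.875}
print(as_dict(["loss"], 0.3))                        # {'loss': 0.3}
```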
- Raises:
RuntimeError – If model.train_on_batch is wrapped in a tf.function.