train

QRNN.train(training_data, validation_data=None, batch_size=256, sigma_noise=None, adversarial_training=False, delta_at=0.01, initial_learning_rate=0.01, momentum=0.0, convergence_epochs=5, learning_rate_decay=2.0, learning_rate_minimum=1e-06, maximum_epochs=200, training_split=0.9, gpu=False)

Train the model on the given training data.

Training uses the provided training data and, optionally, a separate validation set. Training can use the following augmentation methods:

  • Gaussian noise added to input

  • Adversarial training
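Both augmentation methods can be illustrated with a minimal NumPy sketch. It assumes a hypothetical linear model y_hat = x @ w + b trained for a single quantile tau with the pinball (quantile) loss, so the input gradient can be written by hand. It also assumes, as one common variant, that the adversarial copies are appended to the batch. The helpers `quantile_loss_grad_x` and `augment_batch` are illustrative only and are not part of the QRNN API.

```python
import numpy as np


def quantile_loss_grad_x(x, y, w, b, tau):
    """Gradient of the mean pinball loss w.r.t. the inputs x.

    Hypothetical helper (not part of the QRNN API): assumes a linear model
    y_hat = x @ w + b predicting a single quantile tau.
    """
    y_hat = x @ w + b
    err = y - y_hat
    # dL/dy_hat of the pinball loss: -tau if err > 0, 1 - tau if err < 0, 0 at err == 0
    dl_dyhat = np.where(err > 0, -tau, np.where(err < 0, 1.0 - tau, 0.0))
    # Chain rule (dy_hat/dx = w), divided by the batch size for the mean loss
    return dl_dyhat[:, None] * w[None, :] / x.shape[0]


def augment_batch(x, y, w, b, tau, sigma_noise=None, adversarial_training=False,
                  delta_at=0.01, rng=None):
    """Hypothetical helper: apply noise and/or FGSM augmentation to one batch."""
    rng = np.random.default_rng() if rng is None else rng
    x_aug = x
    if sigma_noise is not None:
        # Gaussian noise; sigma_noise may be a scalar or one std. dev. per input feature
        x_aug = x_aug + rng.normal(size=x.shape) * sigma_noise
    if adversarial_training:
        # Fast gradient sign method: a step of size delta_at in the direction
        # of the sign of the input gradient, i.e. the direction that increases the loss
        grad = quantile_loss_grad_x(x_aug, y, w, b, tau)
        x_adv = x_aug + delta_at * np.sign(grad)
        # Assumption: adversarial samples are appended to the batch with the same targets
        x_aug = np.concatenate([x_aug, x_adv])
        y = np.concatenate([y, y])
    return x_aug, y
```

With `delta_at` small, the adversarial inputs stay close to the originals but are pushed toward regions where the model's quantile prediction is worst, which is what makes the perturbation a useful regularizer.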

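The learning rate is divided by learning_rate_decay whenever the validation or training loss has not decreased for convergence_epochs epochs. Training stops once the learning rate reaches learning_rate_minimum or after maximum_epochs epochs.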
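The schedule described above can be sketched as a plain Python function that replays a sequence of per-epoch losses and records the learning rate used in each epoch. This is an approximation under stated assumptions, not the library's code: "no decrease" is measured against the best loss seen so far, the counter is reset after each decay, and training stops as soon as the learning rate drops below `learning_rate_minimum`. The function name `simulate_schedule` is hypothetical.

```python
def simulate_schedule(losses, initial_learning_rate=0.01, convergence_epochs=5,
                      learning_rate_decay=2.0, learning_rate_minimum=1e-6,
                      maximum_epochs=200):
    """Hypothetical helper: return the learning rate used in each epoch.

    losses: sequence of per-epoch (validation or training) losses.
    """
    lr = initial_learning_rate
    best_loss = float("inf")
    epochs_without_improvement = 0
    history = []
    for loss in losses[:maximum_epochs]:
        history.append(lr)
        # Assumption: progress is measured against the best loss so far
        if loss < best_loss:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        # Decay the learning rate after convergence_epochs epochs without improvement
        if epochs_without_improvement >= convergence_epochs:
            lr /= learning_rate_decay
            epochs_without_improvement = 0
            # Abort training once the learning rate falls below the minimum
            if lr < learning_rate_minimum:
                break
    return history
```

With the defaults (initial rate 0.01, decay factor 2.0, minimum 1e-06), this sketch would stop after the 14th decay, since 0.01 / 2**14 is about 6.1e-07.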

Parameters:
  • training_data – Tuple of numpy arrays or a dataset object to use to train the model.

  • validation_data – Optional validation data in the same format as the training data.

  • batch_size – If training data is provided as arrays, the batch size to use for training.

  • sigma_noise – If training data is provided as arrays, the training data is augmented by adding Gaussian noise with the given standard deviations to each input vector before it is presented to the model.

  • adversarial_training (bool) – Whether or not to perform adversarial training using the fast gradient sign method.

  • delta_at – The scaling factor (step size) applied to the perturbations used for adversarial training.

  • initial_learning_rate (float) – The learning rate with which the training is started.

  • momentum (float) – The momentum to use for training.

  • convergence_epochs (int) – The number of epochs with non-decreasing loss before the learning rate is decreased.

  • learning_rate_decay (float) – The factor by which the learning rate is decreased.

  • learning_rate_minimum (float) – The learning rate at which the training is aborted.

  • maximum_epochs (int) – The maximum number of epochs to train for.

  • training_split (float) – If no validation data is provided, the fraction of the training data that is used for training; the remaining samples are used for validation.

  • gpu (bool) – Whether or not to try to run the training on the GPU.
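For array input, the roles of training_split and batch_size can be sketched with NumPy. This assumes the data are shuffled before splitting, that the first round(training_split * n) samples go to training, and that each epoch iterates over reshuffled mini-batches. The helper names `split_training_data` and `batches` are hypothetical and do not describe the library's internal data handling.

```python
import numpy as np


def split_training_data(x, y, training_split=0.9, rng=None):
    """Hypothetical helper: shuffle (x, y) and split into training/validation parts.

    The first round(training_split * n) shuffled samples are used for training,
    the remaining samples for validation.
    """
    rng = np.random.default_rng() if rng is None else rng
    n = x.shape[0]
    indices = rng.permutation(n)
    n_train = int(round(training_split * n))
    train, val = indices[:n_train], indices[n_train:]
    return (x[train], y[train]), (x[val], y[val])


def batches(x, y, batch_size=256, rng=None):
    """Hypothetical helper: yield shuffled mini-batches of at most batch_size samples."""
    rng = np.random.default_rng() if rng is None else rng
    indices = rng.permutation(x.shape[0])
    for start in range(0, len(indices), batch_size):
        batch = indices[start:start + batch_size]
        yield x[batch], y[batch]
```

With the default training_split=0.9, 90 % of the samples would be used for training and 10 % held out for validation; the last batch of an epoch may be smaller than batch_size.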