tf.keras.utils.Sequence

Class Sequence

Defined in tensorflow/python/keras/_impl/keras/utils/data_utils.py.

Base object for fitting to a sequence of data, such as a dataset.

Every Sequence must implement the __getitem__ and __len__ methods.

Examples:

import math

import numpy as np
import tensorflow as tf
from skimage.io import imread
from skimage.transform import resize

# Here, `x_set` is a list of paths to the images
# and `y_set` are the associated classes.

class CIFAR10Sequence(tf.keras.utils.Sequence):
    def __init__(self, x_set, y_set, batch_size):
        self.X, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Round up so the final, possibly smaller, batch is not dropped.
        return int(math.ceil(len(self.X) / float(self.batch_size)))

    def __getitem__(self, idx):
        # Slice out the paths and labels that belong to batch `idx`.
        batch_x = self.X[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        # Load each image from disk and resize it to 200x200.
        return np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x]), np.array(batch_y)

Methods

__getitem__

__getitem__(index)

Gets batch at position index.

Arguments:

  • index: position of the batch in the Sequence.

Returns:

A batch
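
As a minimal sketch of what "a batch" can look like, the hypothetical `ToySequence` below slices two in-memory lists into `(inputs, targets)` pairs. The class name, the toy data, and the fallback to `object` when TensorFlow is not installed are assumptions made for illustration. Plain lists stand in for the NumPy arrays a real Sequence would typically return.

```python
try:
    from tensorflow.keras.utils import Sequence
except ImportError:
    # Assumption for this sketch only: fall back to a plain base class
    # so the indexing logic can be tried without TensorFlow installed.
    Sequence = object


class ToySequence(Sequence):
    """Hypothetical Sequence that serves batches from in-memory lists."""

    def __init__(self, x_set, y_set, batch_size):
        super().__init__()
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Ceiling division so a final partial batch is counted.
        return (len(self.x) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        start = index * self.batch_size
        stop = start + self.batch_size
        # Slicing past the end is safe: the last batch is simply shorter.
        return self.x[start:stop], self.y[start:stop]


seq = ToySequence(list(range(10)), [v % 2 for v in range(10)], batch_size=4)
first_x, first_y = seq[0]  # inputs and targets of the first batch
last_x, last_y = seq[2]    # final batch holds the 2 leftover samples
```

Here `index` counts batches, not samples, so `seq[2]` starts at sample 8.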

__len__

__len__()

Number of batches in the Sequence.

Returns:

The number of batches in the Sequence.
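
The value returned by `__len__` depends on how a final partial batch is treated. The sketch below works through the arithmetic for a hypothetical dataset of 10 samples and a batch size of 4. The helper names are illustrative and are not part of the Keras API.

```python
import math


def num_batches_drop_last(num_samples, batch_size):
    # Floor division: leftover samples that do not fill a batch are skipped.
    return num_samples // batch_size


def num_batches_keep_last(num_samples, batch_size):
    # Ceiling division: the final, smaller batch is counted as well.
    return math.ceil(num_samples / batch_size)


dropped = num_batches_drop_last(10, 4)  # 2 full batches, 2 samples unused
kept = num_batches_keep_last(10, 4)     # 3 batches, the last with 2 samples
```

Whichever convention is chosen, `__getitem__` has to agree with it: every index from 0 to `len(seq) - 1` should yield a non-empty batch.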

on_epoch_end

on_epoch_end()

Method called at the end of every epoch.
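
One possible use of this hook, shown as a sketch rather than as behavior the base class provides, is to reshuffle the sample order so that each epoch sees differently composed batches. `ShuffledSequence`, its `seed` argument, and the fallback to `object` when TensorFlow is unavailable are assumptions for illustration. During training Keras would call `on_epoch_end` itself; here the call is made by hand.

```python
import random

try:
    from tensorflow.keras.utils import Sequence
except ImportError:
    # Assumption for this sketch only: plain base class without TensorFlow.
    Sequence = object


class ShuffledSequence(Sequence):
    """Hypothetical Sequence that reorders its samples after each epoch."""

    def __init__(self, x_set, y_set, batch_size, seed=0):
        super().__init__()
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.rng = random.Random(seed)
        # Batches are drawn through this index list, not from x/y directly.
        self.order = list(range(len(self.x)))

    def __len__(self):
        return (len(self.x) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        start = index * self.batch_size
        picked = self.order[start:start + self.batch_size]
        return [self.x[i] for i in picked], [self.y[i] for i in picked]

    def on_epoch_end(self):
        # Shuffle the indices only, so inputs and targets stay paired.
        self.rng.shuffle(self.order)


seq = ShuffledSequence(list(range(6)), list("abcdef"), batch_size=2)
before = [seq[i] for i in range(len(seq))]
seq.on_epoch_end()  # simulate the end of one epoch
after = [seq[i] for i in range(len(seq))]
```

Shuffling an index list instead of the data keeps each input aligned with its target and leaves the underlying lists untouched.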

© 2017 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence