Into DeepTrain

DeepTrain requires only (1) a compiled model and (2) a data directory to run. This example covers both, plus a few extras that reflect standard use.

[1]:
import os
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D
from tensorflow.keras.layers import Flatten, Activation
from tensorflow.keras.models import Model

from deeptrain import TrainGenerator, DataGenerator

Model maker

Begin by defining a model maker function. Its inputs should specify hyperparameters, optimizer, learning rate, etc.; this is the “blueprint” which is later saved.

[2]:
def make_model(batch_shape, optimizer, loss, metrics, num_classes,
               filters, kernel_size):
    ipt = Input(batch_shape=batch_shape)

    x = Conv2D(filters, kernel_size, activation='relu', padding='same')(ipt)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Flatten()(x)
    x = Dense(num_classes)(x)

    out = Activation('softmax')(x)

    model = Model(ipt, out)
    model.compile(optimizer, loss, metrics=metrics)
    return model

Model configs

Define a configs dictionary to pass as **kwargs to make_model; we’ll also pass it to TrainGenerator, which saves it and shows it in a “report” for easy reference.

[3]:
batch_size = 128
width, height, channels = 28, 28, 1  # MNIST dims (28 x 28 pixels, greyscale)

MODEL_CFG = dict(
    batch_shape=(batch_size, width, height, channels),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    optimizer='adam',
    num_classes=10,
    filters=16,
    kernel_size=(3, 3),
)
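As a sanity check on these configs, the layer shapes and parameter counts that make_model should produce can be worked out by hand. The sketch below is plain Python bookkeeping (no TensorFlow needed). It assumes the Keras defaults used in make_model: stride 1 for Conv2D and pooling strides equal to pool_size.

```python
# Pure-Python shape bookkeeping for the model built by make_model(**MODEL_CFG).
batch_size = 128
width, height, channels = 28, 28, 1
filters, kernel_size, num_classes = 16, (3, 3), 10

# Conv2D with padding='same' and stride 1 keeps the spatial dims.
conv_shape = (batch_size, width, height, filters)
# MaxPooling2D(pool_size=(2, 2)) halves each spatial dim.
pool_shape = (batch_size, width // 2, height // 2, filters)
# Flatten merges all non-batch dims into one.
flat_dim = pool_shape[1] * pool_shape[2] * pool_shape[3]
out_shape = (batch_size, num_classes)

# Trainable parameters: weights plus biases for each layer.
conv_params = kernel_size[0] * kernel_size[1] * channels * filters + filters
dense_params = flat_dim * num_classes + num_classes
total_params = conv_params + dense_params

print(conv_shape, pool_shape, flat_dim, out_shape)
print(conv_params, dense_params, total_params)
```

If the totals printed here disagree with model.summary(), the configs and the model definition have likely drifted apart.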

DataGenerator (train) configs

  • data_path: directory where the image data is located
  • labels_path: path to the labels file
  • batch_size: number of samples to feed to the model at once
  • shuffle: whether to shuffle data at the end of each epoch
  • superbatch_set_nums: which files to load into a superbatch, which holds batches persistently in memory (as opposed to a batch, which is overwritten after use). Since MNIST is small, we can load it all into RAM.
[4]:
datadir = os.path.join("dir", "data", "image")
DATAGEN_CFG = dict(
    data_path=os.path.join(datadir, 'train'),
    labels_path=os.path.join(datadir, 'train', 'labels.h5'),
    batch_size=batch_size,
    shuffle=True,
    superbatch_set_nums='all',
)
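To make the shuffle option concrete, here is a minimal sketch of reordering set numbers at the end of each epoch. It illustrates the idea only and is not DeepTrain's implementation; epoch_orders and its seed argument are hypothetical names introduced for this example.

```python
import random


def epoch_orders(set_nums, epochs, shuffle=True, seed=0):
    """Yield the order in which sets are visited in each epoch.

    The first epoch uses the given order; the order is reshuffled
    at the end of every epoch when shuffle=True.
    """
    rng = random.Random(seed)  # seeded only to make the sketch reproducible
    order = list(set_nums)
    for _ in range(epochs):
        yield list(order)
        if shuffle:
            rng.shuffle(order)


orders = list(epoch_orders(range(1, 49), epochs=3))
print(orders[0][:5])  # first epoch: original order
print(orders[1][:5])  # later epochs: shuffled order
```

Every set is still visited exactly once per epoch; only the order changes.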

DataGenerator (validation) configs

[5]:
VAL_DATAGEN_CFG = dict(
    data_path=os.path.join(datadir, 'val'),
    labels_path=os.path.join(datadir, 'val', 'labels.h5'),
    batch_size=batch_size,
    shuffle=False,
    superbatch_set_nums='all',
)

TrainGenerator configs

  • epochs: number of epochs to train for
  • logs_dir: where to save TrainGenerator state, model, report, and history
  • best_models_dir: where to save the model when it achieves a new best validation performance
  • model_configs: model configurations dict to save & write to report
[6]:
TRAINGEN_CFG = dict(
    epochs=3,
    logs_dir=os.path.join('dir', 'logs'),
    best_models_dir=os.path.join('dir', 'models'),
    model_configs=MODEL_CFG,
)
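The “new best validation performance” rule behind best_models_dir can be pictured as a simple comparison against the best value seen so far. The sketch below is a hypothetical illustration, not DeepTrain's actual logic: update_best is an invented helper, and validation loss is assumed to be the tracked metric. The three loss values are the end-of-epoch validation losses from the run shown in this example.

```python
def update_best(best_loss, val_loss):
    """Return (new_best, improved) for a lower-is-better metric."""
    improved = val_loss < best_loss
    return (val_loss if improved else best_loss), improved


best = float("inf")  # nothing saved yet
saved_epochs = []
for epoch, val_loss in enumerate([0.757068, 0.395200, 0.324875], start=1):
    best, improved = update_best(best, val_loss)
    if improved:
        # In a real run, this is where the model would be written to disk.
        saved_epochs.append(epoch)

print(saved_epochs, best)
```

Because the loss improves every epoch in this run, a save is triggered after each of the three epochs.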

Create training objects

[7]:
model       = make_model(**MODEL_CFG)
datagen     = DataGenerator(**DATAGEN_CFG)
val_datagen = DataGenerator(**VAL_DATAGEN_CFG)
traingen    = TrainGenerator(model, datagen, val_datagen, **TRAINGEN_CFG)
Discovered 48 files with matching format
Discovered dataset with matching format
48 set nums inferred; if more are expected, ensure file names contain a common substring w/ a number (e.g. 'train1.npy', 'train2.npy', etc)
DataGenerator initiated

Discovered 36 files with matching format
Discovered dataset with matching format
36 set nums inferred; if more are expected, ensure file names contain a common substring w/ a number (e.g. 'train1.npy', 'train2.npy', etc)
DataGenerator initiated

NOTE: no existing models detected in dir\logs; starting model_num from '0'
Preloading superbatch ... Discovered 48 files with matching format
................................................ finished, w/ 6144 total samples
Train initial data prepared
Preloading superbatch ... Discovered 36 files with matching format
.................................... finished, w/ 4608 total samples
Val initial data prepared
NOTE: no existing models detected in dir\logs; starting model_num from '0'
Logging ON; directory (new): dir\logs\M0__model-adam__min999.000
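The sample counts in the output above are consistent with each file holding one batch of 128 samples: 48 files give 6144 training samples and 36 files give 4608 validation samples. The sketch below reproduces that arithmetic and estimates how much RAM each superbatch occupies. The one-batch-per-file layout is inferred from the log, and the float32 storage type is an assumption.

```python
batch_size = 128
sample_shape = (28, 28, 1)  # MNIST: width, height, channels
bytes_per_value = 4         # assumption: data held as float32


def superbatch_stats(num_files):
    """Return (total samples, approximate size in MiB) for a superbatch."""
    samples = num_files * batch_size  # assumes one batch per file
    values = samples
    for dim in sample_shape:
        values *= dim
    return samples, values * bytes_per_value / 2**20


for name, num_files in [("train", 48), ("val", 36)]:
    samples, mib = superbatch_stats(num_files)
    print(f"{name}: {samples} samples, ~{mib:.1f} MiB")
```

At this size both superbatches fit comfortably in memory, which is why superbatch_set_nums='all' is a reasonable choice here.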

Train

[8]:
traingen.train()

Fitting set 1...   (Loss, Acc) = (2.297301, 0.062500)
Fitting set 2...   (Loss, Acc) = (2.292228, 0.078125)
Fitting set 3...   (Loss, Acc) = (2.280833, 0.122396)
Fitting set 4...   (Loss, Acc) = (2.268434, 0.152344)
Fitting set 5...   (Loss, Acc) = (2.251584, 0.187500)
Fitting set 6...   (Loss, Acc) = (2.239864, 0.201823)
Fitting set 7...   (Loss, Acc) = (2.228770, 0.229911)
Fitting set 8...   (Loss, Acc) = (2.214706, 0.265625)
Fitting set 9...   (Loss, Acc) = (2.201900, 0.287326)
Fitting set 10...  (Loss, Acc) = (2.189120, 0.307812)
Fitting set 11...  (Loss, Acc) = (2.170205, 0.334517)
Fitting set 12...  (Loss, Acc) = (2.154839, 0.352214)
Fitting set 13...  (Loss, Acc) = (2.142305, 0.362981)
Fitting set 14...  (Loss, Acc) = (2.126470, 0.380580)
Fitting set 15...  (Loss, Acc) = (2.110874, 0.400521)
Fitting set 16...  (Loss, Acc) = (2.096904, 0.412109)
Fitting set 17...  (Loss, Acc) = (2.080665, 0.428768)
Fitting set 18...  (Loss, Acc) = (2.065265, 0.440972)
Fitting set 19...  (Loss, Acc) = (2.049479, 0.448602)
Fitting set 20...  (Loss, Acc) = (2.032223, 0.460547)
Fitting set 21...  (Loss, Acc) = (2.016439, 0.473586)
Fitting set 22...  (Loss, Acc) = (1.996922, 0.488281)
Fitting set 23...  (Loss, Acc) = (1.980086, 0.499321)
Fitting set 24...  (Loss, Acc) = (1.960719, 0.511393)
Fitting set 25...  (Loss, Acc) = (1.947011, 0.519375)
Fitting set 26...  (Loss, Acc) = (1.929373, 0.530349)
Fitting set 27...  (Loss, Acc) = (1.909454, 0.540799)
Fitting set 28...  (Loss, Acc) = (1.894096, 0.546596)
Fitting set 29...  (Loss, Acc) = (1.874574, 0.556034)
Fitting set 30...  (Loss, Acc) = (1.858988, 0.562500)
Fitting set 31...  (Loss, Acc) = (1.840485, 0.569052)
Fitting set 32...  (Loss, Acc) = (1.821290, 0.575928)
Fitting set 33...  (Loss, Acc) = (1.801637, 0.583333)
Fitting set 34...  (Loss, Acc) = (1.783776, 0.588006)
Fitting set 35...  (Loss, Acc) = (1.766005, 0.593080)
Fitting set 36...  (Loss, Acc) = (1.746479, 0.598958)
Fitting set 37...  (Loss, Acc) = (1.727118, 0.604730)
Fitting set 38...  (Loss, Acc) = (1.709816, 0.609375)
Fitting set 39...  (Loss, Acc) = (1.692258, 0.614784)
Fitting set 40...  (Loss, Acc) = (1.674974, 0.618164)
Fitting set 41...  (Loss, Acc) = (1.657619, 0.621570)
Fitting set 42...  (Loss, Acc) = (1.641050, 0.625744)
Fitting set 43...  (Loss, Acc) = (1.624397, 0.629542)
Fitting set 44...  (Loss, Acc) = (1.608801, 0.633523)
Fitting set 45...  (Loss, Acc) = (1.593297, 0.636458)
Fitting set 46...  (Loss, Acc) = (1.577476, 0.639606)
Fitting set 47...  (Loss, Acc) = (1.560293, 0.643451)
Fitting set 48...  (Loss, Acc) = (1.547326, 0.645996)
Data set_nums shuffled


_____________________
 EPOCH 1 -- COMPLETE 



Validating...
Validating set 1...  (Loss, Acc) = (0.783646, 0.835938)
Validating set 2...  (Loss, Acc) = (0.765299, 0.851562)
Validating set 3...  (Loss, Acc) = (0.769534, 0.851562)
Validating set 4...  (Loss, Acc) = (0.764879, 0.853516)
Validating set 5...  (Loss, Acc) = (0.758425, 0.851562)
Validating set 6...  (Loss, Acc) = (0.764723, 0.843750)
Validating set 7...  (Loss, Acc) = (0.764539, 0.842634)
Validating set 8...  (Loss, Acc) = (0.767474, 0.844727)
Validating set 9...  (Loss, Acc) = (0.767794, 0.845486)
Validating set 10... (Loss, Acc) = (0.763187, 0.843750)
Validating set 11... (Loss, Acc) = (0.755877, 0.843750)
Validating set 12... (Loss, Acc) = (0.755090, 0.841797)
Validating set 13... (Loss, Acc) = (0.744053, 0.846154)
Validating set 14... (Loss, Acc) = (0.741029, 0.847098)
Validating set 15... (Loss, Acc) = (0.746417, 0.844271)
Validating set 16... (Loss, Acc) = (0.747256, 0.845215)
Validating set 17... (Loss, Acc) = (0.746406, 0.846048)
Validating set 18... (Loss, Acc) = (0.746663, 0.845486)
Validating set 19... (Loss, Acc) = (0.744822, 0.846628)
Validating set 20... (Loss, Acc) = (0.744538, 0.844922)
Validating set 21... (Loss, Acc) = (0.745785, 0.841890)
Validating set 22... (Loss, Acc) = (0.746459, 0.841619)
Validating set 23... (Loss, Acc) = (0.749001, 0.838995)
Validating set 24... (Loss, Acc) = (0.748322, 0.837565)
Validating set 25... (Loss, Acc) = (0.748221, 0.837500)
Validating set 26... (Loss, Acc) = (0.748553, 0.837740)
Validating set 27... (Loss, Acc) = (0.750275, 0.837384)
Validating set 28... (Loss, Acc) = (0.749730, 0.837612)
Validating set 29... (Loss, Acc) = (0.750421, 0.836476)
Validating set 30... (Loss, Acc) = (0.751461, 0.835677)
Validating set 31... (Loss, Acc) = (0.752642, 0.834677)
Validating set 32... (Loss, Acc) = (0.752270, 0.834473)
Validating set 33... (Loss, Acc) = (0.751130, 0.835227)
Validating set 34... (Loss, Acc) = (0.751729, 0.834559)
Validating set 35... (Loss, Acc) = (0.756417, 0.832143)
Validating set 36... (Loss, Acc) = (0.757068, 0.832031)
TrainGenerator state saved
Model report generated and saved
Best model saved to dir\models\M0__model-adam__min.757
TrainGenerator state saved
Model report generated and saved

Fitting set 21...  (Loss, Acc) = (0.765751, 0.824219)
Fitting set 27...  (Loss, Acc) = (0.727158, 0.841797)
Fitting set 44...  (Loss, Acc) = (0.746188, 0.834635)
Fitting set 3...   (Loss, Acc) = (0.730069, 0.838867)
Fitting set 23...  (Loss, Acc) = (0.728389, 0.839844)
Fitting set 31...  (Loss, Acc) = (0.720288, 0.835286)
Fitting set 45...  (Loss, Acc) = (0.720948, 0.834263)
Fitting set 22...  (Loss, Acc) = (0.704299, 0.837402)
Fitting set 16...  (Loss, Acc) = (0.698179, 0.838108)
Fitting set 38...  (Loss, Acc) = (0.691929, 0.833984)
Fitting set 34...  (Loss, Acc) = (0.692961, 0.832031)
Fitting set 1...   (Loss, Acc) = (0.680281, 0.838216)
Fitting set 4...   (Loss, Acc) = (0.676445, 0.841046)
Fitting set 37...  (Loss, Acc) = (0.663750, 0.845145)
Fitting set 46...  (Loss, Acc) = (0.658366, 0.846094)
Fitting set 12...  (Loss, Acc) = (0.649041, 0.846924)
Fitting set 10...  (Loss, Acc) = (0.645140, 0.846278)
Fitting set 32...  (Loss, Acc) = (0.640361, 0.846137)
Fitting set 8...   (Loss, Acc) = (0.633576, 0.846423)
Fitting set 25...  (Loss, Acc) = (0.634411, 0.843555)
Fitting set 19...  (Loss, Acc) = (0.633591, 0.844308)
Fitting set 36...  (Loss, Acc) = (0.627558, 0.844993)
Fitting set 2...   (Loss, Acc) = (0.619685, 0.847317)
Fitting set 18...  (Loss, Acc) = (0.615301, 0.848796)
Fitting set 42...  (Loss, Acc) = (0.613810, 0.848906)
Fitting set 26...  (Loss, Acc) = (0.606216, 0.850511)
Fitting set 6...   (Loss, Acc) = (0.602131, 0.850550)
Fitting set 48...  (Loss, Acc) = (0.603196, 0.849191)
Fitting set 15...  (Loss, Acc) = (0.598294, 0.850350)
Fitting set 35...  (Loss, Acc) = (0.595915, 0.849349)
Fitting set 39...  (Loss, Acc) = (0.590949, 0.851436)
Fitting set 40...  (Loss, Acc) = (0.587879, 0.851929)
Fitting set 41...  (Loss, Acc) = (0.583732, 0.852391)
Fitting set 29...  (Loss, Acc) = (0.578788, 0.853516)
Fitting set 17...  (Loss, Acc) = (0.576870, 0.853906)
Fitting set 5...   (Loss, Acc) = (0.571825, 0.855360)
Fitting set 11...  (Loss, Acc) = (0.565289, 0.856524)
Fitting set 33...  (Loss, Acc) = (0.563101, 0.856394)
Fitting set 30...  (Loss, Acc) = (0.560148, 0.856671)
Fitting set 9...   (Loss, Acc) = (0.558437, 0.856348)
Fitting set 43...  (Loss, Acc) = (0.555807, 0.856612)
Fitting set 13...  (Loss, Acc) = (0.552969, 0.857422)
Fitting set 20...  (Loss, Acc) = (0.549530, 0.858376)
Fitting set 7...   (Loss, Acc) = (0.547364, 0.858754)
Fitting set 14...  (Loss, Acc) = (0.545547, 0.858767)
Fitting set 47...  (Loss, Acc) = (0.542079, 0.860139)
Fitting set 24...  (Loss, Acc) = (0.539119, 0.861287)
Fitting set 28...  (Loss, Acc) = (0.539579, 0.861247)
Data set_nums shuffled


_____________________
 EPOCH 2 -- COMPLETE 



Validating...
Validating set 1...  (Loss, Acc) = (0.355274, 0.921875)
Validating set 2...  (Loss, Acc) = (0.348840, 0.921875)
Validating set 3...  (Loss, Acc) = (0.387879, 0.914062)
Validating set 4...  (Loss, Acc) = (0.394017, 0.908203)
Validating set 5...  (Loss, Acc) = (0.399065, 0.903125)
Validating set 6...  (Loss, Acc) = (0.397540, 0.903646)
Validating set 7...  (Loss, Acc) = (0.392841, 0.906250)
Validating set 8...  (Loss, Acc) = (0.396953, 0.906250)
Validating set 9...  (Loss, Acc) = (0.396412, 0.907986)
Validating set 10... (Loss, Acc) = (0.399894, 0.908594)
Validating set 11... (Loss, Acc) = (0.389734, 0.909091)
Validating set 12... (Loss, Acc) = (0.388978, 0.908203)
Validating set 13... (Loss, Acc) = (0.378990, 0.910457)
Validating set 14... (Loss, Acc) = (0.379070, 0.908482)
Validating set 15... (Loss, Acc) = (0.388130, 0.907292)
Validating set 16... (Loss, Acc) = (0.391481, 0.905762)
Validating set 17... (Loss, Acc) = (0.385940, 0.906250)
Validating set 18... (Loss, Acc) = (0.386285, 0.904948)
Validating set 19... (Loss, Acc) = (0.382503, 0.905428)
Validating set 20... (Loss, Acc) = (0.380249, 0.906641)
Validating set 21... (Loss, Acc) = (0.385483, 0.903646)
Validating set 22... (Loss, Acc) = (0.383873, 0.904474)
Validating set 23... (Loss, Acc) = (0.386388, 0.903872)
Validating set 24... (Loss, Acc) = (0.383956, 0.902995)
Validating set 25... (Loss, Acc) = (0.384522, 0.903438)
Validating set 26... (Loss, Acc) = (0.385134, 0.903546)
Validating set 27... (Loss, Acc) = (0.388839, 0.902199)
Validating set 28... (Loss, Acc) = (0.388487, 0.902065)
Validating set 29... (Loss, Acc) = (0.389055, 0.903556)
Validating set 30... (Loss, Acc) = (0.389983, 0.902865)
Validating set 31... (Loss, Acc) = (0.391534, 0.902470)
Validating set 32... (Loss, Acc) = (0.391110, 0.902588)
Validating set 33... (Loss, Acc) = (0.391102, 0.902225)
Validating set 34... (Loss, Acc) = (0.388916, 0.902803)
Validating set 35... (Loss, Acc) = (0.395877, 0.900223)
Validating set 36... (Loss, Acc) = (0.395200, 0.900608)
TrainGenerator state saved
Model report generated and saved
Best model saved to dir\models\M0__model-adam__min.395
TrainGenerator state saved
Model report generated and saved

Fitting set 40...  (Loss, Acc) = (0.393554, 0.910156)
Fitting set 30...  (Loss, Acc) = (0.389306, 0.912109)
Fitting set 17...  (Loss, Acc) = (0.409685, 0.904948)
Fitting set 10...  (Loss, Acc) = (0.399228, 0.899414)
Fitting set 19...  (Loss, Acc) = (0.411247, 0.894531)
Fitting set 9...   (Loss, Acc) = (0.411084, 0.888672)
Fitting set 24...  (Loss, Acc) = (0.401326, 0.893415)
Fitting set 14...  (Loss, Acc) = (0.403191, 0.895020)
Fitting set 8...   (Loss, Acc) = (0.395733, 0.897135)
Fitting set 35...  (Loss, Acc) = (0.398654, 0.895703)
Fitting set 45...  (Loss, Acc) = (0.400620, 0.894531)
Fitting set 15...  (Loss, Acc) = (0.395943, 0.894857)
Fitting set 36...  (Loss, Acc) = (0.390992, 0.897536)
Fitting set 4...   (Loss, Acc) = (0.390960, 0.898717)
Fitting set 39...  (Loss, Acc) = (0.386528, 0.899740)
Fitting set 43...  (Loss, Acc) = (0.386909, 0.899170)
Fitting set 6...   (Loss, Acc) = (0.385125, 0.898667)
Fitting set 16...  (Loss, Acc) = (0.383201, 0.899523)
Fitting set 13...  (Loss, Acc) = (0.381569, 0.899877)
Fitting set 11...  (Loss, Acc) = (0.375037, 0.901367)
Fitting set 23...  (Loss, Acc) = (0.376780, 0.899740)
Fitting set 31...  (Loss, Acc) = (0.377788, 0.899680)
Fitting set 32...  (Loss, Acc) = (0.379219, 0.898607)
Fitting set 33...  (Loss, Acc) = (0.379287, 0.897949)
Fitting set 18...  (Loss, Acc) = (0.377729, 0.899844)
Fitting set 21...  (Loss, Acc) = (0.376463, 0.899790)
Fitting set 38...  (Loss, Acc) = (0.372582, 0.900608)
Fitting set 26...  (Loss, Acc) = (0.367968, 0.901088)
Fitting set 22...  (Loss, Acc) = (0.365197, 0.902613)
Fitting set 44...  (Loss, Acc) = (0.365697, 0.902734)
Fitting set 3...   (Loss, Acc) = (0.364059, 0.903604)
Fitting set 37...  (Loss, Acc) = (0.360719, 0.904419)
Fitting set 7...   (Loss, Acc) = (0.359887, 0.904711)
Fitting set 1...   (Loss, Acc) = (0.356046, 0.906135)
Fitting set 29...  (Loss, Acc) = (0.354691, 0.906585)
Fitting set 48...  (Loss, Acc) = (0.357704, 0.905707)
Fitting set 41...  (Loss, Acc) = (0.356611, 0.905933)
Fitting set 28...  (Loss, Acc) = (0.359539, 0.904708)
Fitting set 25...  (Loss, Acc) = (0.360531, 0.903145)
Fitting set 5...   (Loss, Acc) = (0.357812, 0.904395)
Fitting set 12...  (Loss, Acc) = (0.354390, 0.905393)
Fitting set 27...  (Loss, Acc) = (0.353634, 0.905599)
Fitting set 34...  (Loss, Acc) = (0.355857, 0.904524)
Fitting set 20...  (Loss, Acc) = (0.355108, 0.904741)
Fitting set 42...  (Loss, Acc) = (0.356337, 0.904427)
Fitting set 47...  (Loss, Acc) = (0.356071, 0.904806)
Fitting set 46...  (Loss, Acc) = (0.355477, 0.905502)
Fitting set 2...   (Loss, Acc) = (0.353324, 0.905843)
Data set_nums shuffled


_____________________
 EPOCH 3 -- COMPLETE 



Validating...
Validating set 1...  (Loss, Acc) = (0.248495, 0.945312)
Validating set 2...  (Loss, Acc) = (0.256513, 0.933594)
Validating set 3...  (Loss, Acc) = (0.303261, 0.924479)
Validating set 4...  (Loss, Acc) = (0.315434, 0.919922)
Validating set 5...  (Loss, Acc) = (0.323978, 0.917188)
Validating set 6...  (Loss, Acc) = (0.320035, 0.916667)
Validating set 7...  (Loss, Acc) = (0.311358, 0.917411)
Validating set 8...  (Loss, Acc) = (0.318938, 0.911133)
Validating set 9...  (Loss, Acc) = (0.318156, 0.910590)
Validating set 10... (Loss, Acc) = (0.324514, 0.910937)
Validating set 11... (Loss, Acc) = (0.314028, 0.914773)
Validating set 12... (Loss, Acc) = (0.314245, 0.913411)
Validating set 13... (Loss, Acc) = (0.305751, 0.915264)
Validating set 14... (Loss, Acc) = (0.305372, 0.915179)
Validating set 15... (Loss, Acc) = (0.316371, 0.914062)
Validating set 16... (Loss, Acc) = (0.321232, 0.911133)
Validating set 17... (Loss, Acc) = (0.314669, 0.914522)
Validating set 18... (Loss, Acc) = (0.315064, 0.912760)
Validating set 19... (Loss, Acc) = (0.310267, 0.914062)
Validating set 20... (Loss, Acc) = (0.307732, 0.914844)
Validating set 21... (Loss, Acc) = (0.316089, 0.912202)
Validating set 22... (Loss, Acc) = (0.314979, 0.912642)
Validating set 23... (Loss, Acc) = (0.317234, 0.912024)
Validating set 24... (Loss, Acc) = (0.314426, 0.912109)
Validating set 25... (Loss, Acc) = (0.315502, 0.911875)
Validating set 26... (Loss, Acc) = (0.315770, 0.911959)
Validating set 27... (Loss, Acc) = (0.319364, 0.910880)
Validating set 28... (Loss, Acc) = (0.318615, 0.910435)
Validating set 29... (Loss, Acc) = (0.319337, 0.911369)
Validating set 30... (Loss, Acc) = (0.320033, 0.910677)
Validating set 31... (Loss, Acc) = (0.321572, 0.911038)
Validating set 32... (Loss, Acc) = (0.320362, 0.910889)
Validating set 33... (Loss, Acc) = (0.320599, 0.910748)
Validating set 34... (Loss, Acc) = (0.318106, 0.911535)
Validating set 35... (Loss, Acc) = (0.325595, 0.909598)
Validating set 36... (Loss, Acc) = (0.324875, 0.910373)
TrainGenerator state saved
Model report generated and saved
Best model saved to dir\models\M0__model-adam__min.325
TrainGenerator state saved
Model report generated and saved
Training has concluded.
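The (Loss, Acc) pairs in the log above change smoothly from set to set, which suggests they are running means over the sets seen so far in the epoch rather than per-set values. This is an inference from the numbers, not something the example states. The sketch below shows the idea; the per-set counts of correct predictions are hypothetical values chosen so that the running accuracy reproduces the first three lines of epoch 1.

```python
from itertools import accumulate


def running_mean(values):
    """Cumulative mean: element i is the mean of values[0..i]."""
    return [total / n for n, total in enumerate(accumulate(values), start=1)]


batch_size = 128
correct_per_set = [8, 12, 27]  # hypothetical counts of correct predictions
per_set_acc = [c / batch_size for c in correct_per_set]
running_acc = running_mean(per_set_acc)

print([round(a, 6) for a in running_acc])  # [0.0625, 0.078125, 0.122396]
```

Under this reading, the last value printed in each epoch is the mean over the whole epoch, which matches the validation loss used in the saved best-model name.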

Delve deeper

DeepTrain offers much beyond the minimum shown here; it’s suggested to proceed with the advanced example before exploring the others.