Model Health Monitoring

This example assumes you’ve read advanced.py and covers:

  • Monitoring exploding and vanishing gradients
  • Spotting dead weights
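Both checks come down to simple statistics on tensors. As a rough illustration of the gradient side, the sketch below labels each gradient tensor as exploding, vanishing, non-finite, or ok by its L2 (Frobenius) norm. It is plain NumPy, not deeptrain's implementation; the function name `flag_gradients`, the tensor names, and both thresholds are hypothetical choices for illustration.

```python
import numpy as np

def flag_gradients(grads, explode_thresh=1e3, vanish_thresh=1e-7):
    """Label each named gradient array by its L2 norm over all elements.

    The thresholds are illustrative, not values taken from deeptrain.
    """
    report = {}
    for name, g in grads.items():
        norm = float(np.linalg.norm(g))  # Frobenius norm for matrices
        if not np.isfinite(norm):        # NaN or inf anywhere in the tensor
            report[name] = "non-finite"
        elif norm > explode_thresh:
            report[name] = "exploding"
        elif norm < vanish_thresh:
            report[name] = "vanishing"
        else:
            report[name] = "ok"
    return report

grads = {
    "dense/kernel:0": np.full((4, 4), 5e3),  # norm 2e4
    "lstm/bias:0": np.full(8, 1e-12),        # norm ~2.8e-12
    "conv2d/bias:0": np.ones(3),             # norm ~1.7
}
print(flag_gradients(grads))
# {'dense/kernel:0': 'exploding', 'lstm/bias:0': 'vanishing', 'conv2d/bias:0': 'ok'}
```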
[1]:
import deeptrain
deeptrain.util.misc.append_examples_dir_to_sys_path()  # for `from utils import`

from utils import CL_CONFIGS as C
from utils import init_session, make_classifier
from utils import Adam
from see_rnn import rnn_histogram, rnn_heatmap

Case 1: Large weights

We train with a very large learning rate (Adam(6)) to force large weights.
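Why a large learning rate inflates weights with Adam: Adam divides its first-moment estimate by the square root of its second-moment estimate, so on the first step each parameter moves by roughly the learning rate, whatever the gradient's scale. A NumPy sketch of one textbook bias-corrected Adam step from zero state; the hyperparameters are the usual defaults, assumed here rather than read from the `Adam` imported from `utils`.

```python
import numpy as np

def adam_first_step(grad, lr, beta1=0.9, beta2=0.999, eps=1e-7):
    """Parameter update of a single bias-corrected Adam step (t = 1, zero moments)."""
    m = (1 - beta1) * grad        # first moment after one step
    v = (1 - beta2) * grad ** 2   # second moment after one step
    m_hat = m / (1 - beta1)       # bias correction at t = 1 recovers grad
    v_hat = v / (1 - beta2)       # ... and grad ** 2
    return -lr * m_hat / (np.sqrt(v_hat) + eps)

grad = np.array([1e-4, 0.5, -3.0, 250.0])
print(adam_first_step(grad, lr=6))  # every entry has magnitude close to 6
```

With lr=6, each weight jumps by about 6 on the very first update, even where the gradient is tiny.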

[2]:
# We build a model prone to large but not exploding/vanishing gradients
C['model']['optimizer'] = Adam(6)
C['traingen']['epochs'] = 1
tg = init_session(C, make_classifier)
Discovered 48 files with matching format
Discovered dataset with matching format
48 set nums inferred; if more are expected, ensure file names contain a common substring w/ a number (e.g. 'train1.npy', 'train2.npy', etc)
DataGenerator initiated

Discovered 36 files with matching format
Discovered dataset with matching format
36 set nums inferred; if more are expected, ensure file names contain a common substring w/ a number (e.g. 'train1.npy', 'train2.npy', etc)
DataGenerator initiated

Preloading superbatch ... Discovered 48 files with matching format
................................................ finished, w/ 6144 total samples
Train initial data prepared
Preloading superbatch ... Discovered 36 files with matching format
.................................... finished, w/ 4608 total samples
Val initial data prepared
Logging ON; directory (new): C:\deeptrain\examples\dir\logs\M2__model-Adam__min999.000
[3]:
tg.train()

Fitting set 1...   (Loss, Acc) = (2.301934, 0.164062)
Fitting set 2...   (Loss, Acc) = (61413017.150967, 0.105469)
Fitting set 3...   (Loss, Acc) = (50551152.767311, 0.101562)
Fitting set 4...   (Loss, Acc) = (39892719.075484, 0.113281)
Fitting set 5...   (Loss, Acc) = (31946056.266637, 0.115625)
Fitting set 6...   (Loss, Acc) = (26621893.911712, 0.117188)
Fitting set 7...   (Loss, Acc) = (22818785.382816, 0.119420)
Fitting set 8...   (Loss, Acc) = (19966440.713710, 0.116211)
Fitting set 9...   (Loss, Acc) = (17747949.665580, 0.114583)
Fitting set 10...  (Loss, Acc) = (15973156.238238, 0.112500)
Fitting set 11...  (Loss, Acc) = (14521056.922195, 0.113636)
Fitting set 12...  (Loss, Acc) = (13310969.448344, 0.112630)
Fitting set 13...  (Loss, Acc) = (12287049.257492, 0.109976)
Fitting set 14...  (Loss, Acc) = (11409403.477279, 0.105469)
Fitting set 15...  (Loss, Acc) = (10648777.200421, 0.105729)
Fitting set 16...  (Loss, Acc) = (9983229.190327, 0.105469)
Fitting set 17...  (Loss, Acc) = (9395981.031790, 0.105699)
Fitting set 18...  (Loss, Acc) = (8873982.566471, 0.105469)
Fitting set 19...  (Loss, Acc) = (8406931.181455, 0.105674)
Fitting set 20...  (Loss, Acc) = (7986584.963393, 0.105078)
Fitting set 21...  (Loss, Acc) = (7606271.703814, 0.104539)
Fitting set 22...  (Loss, Acc) = (7260532.407605, 0.104048)
Fitting set 23...  (Loss, Acc) = (6944857.373701, 0.101902)
Fitting set 24...  (Loss, Acc) = (6655488.531696, 0.101888)
Fitting set 25...  (Loss, Acc) = (6389269.171751, 0.101250)
Fitting set 26...  (Loss, Acc) = (6143528.244499, 0.100361)
Fitting set 27...  (Loss, Acc) = (5915990.377923, 0.100694)
Fitting set 28...  (Loss, Acc) = (5704705.207857, 0.101283)
Fitting set 29...  (Loss, Acc) = (5507991.401438, 0.101293)
Fitting set 30...  (Loss, Acc) = (5324391.880967, 0.101042)
Fitting set 31...  (Loss, Acc) = (5152637.451208, 0.100554)
Fitting set 32...  (Loss, Acc) = (4991617.706756, 0.100342)
Fitting set 33...  (Loss, Acc) = (4840356.694946, 0.099195)
Fitting set 34...  (Loss, Acc) = (4697993.408865, 0.097886)
Fitting set 35...  (Loss, Acc) = (4563765.141642, 0.098437)
Fitting set 36...  (Loss, Acc) = (4436993.987101, 0.098741)
Fitting set 37...  (Loss, Acc) = (4317075.328605, 0.097973)
Fitting set 38...  (Loss, Acc) = (4203468.178645, 0.098479)
Fitting set 39...  (Loss, Acc) = (4095687.036212, 0.098558)
Fitting set 40...  (Loss, Acc) = (3993294.942857, 0.099023)
Fitting set 41...  (Loss, Acc) = (3895897.572042, 0.099276)
Fitting set 42...  (Loss, Acc) = (3803138.170269, 0.099516)
Fitting set 43...  (Loss, Acc) = (3714693.161694, 0.100291)
Fitting set 44...  (Loss, Acc) = (3630268.393200, 0.100142)
Fitting set 45...  (Loss, Acc) = (3549595.830608, 0.100000)
Fitting set 46...  (Loss, Acc) = (3472430.768177, 0.101053)
Fitting set 47...  (Loss, Acc) = (3398549.320164, 0.100399)
Fitting set 48...  (Loss, Acc) = (3327746.265449, 0.099609)
Data set_nums shuffled


_____________________
 EPOCH 1 -- COMPLETE 



Validating...
Validating set 1...  (Loss, Acc) = (2.804169, 0.117188)
Validating set 2...  (Loss, Acc) = (2.710045, 0.109375)
Validating set 3...  (Loss, Acc) = (2.839425, 0.093750)
Validating set 4...  (Loss, Acc) = (2.752114, 0.093750)
Validating set 5...  (Loss, Acc) = (2.808827, 0.125000)
Validating set 6...  (Loss, Acc) = (2.799275, 0.046875)
Validating set 7...  (Loss, Acc) = (2.512745, 0.078125)
Validating set 8...  (Loss, Acc) = (2.578151, 0.109375)
Validating set 9...  (Loss, Acc) = (2.539990, 0.109375)
Validating set 10... (Loss, Acc) = (2.854088, 0.101562)
Validating set 11... (Loss, Acc) = (2.723956, 0.109375)
Validating set 12... (Loss, Acc) = (2.574893, 0.085938)
Validating set 13... (Loss, Acc) = (2.780102, 0.070312)
Validating set 14... (Loss, Acc) = (2.637559, 0.125000)
Validating set 15... (Loss, Acc) = (2.744786, 0.109375)
Validating set 16... (Loss, Acc) = (2.890051, 0.140625)
Validating set 17... (Loss, Acc) = (2.780263, 0.140625)
Validating set 18... (Loss, Acc) = (2.721423, 0.078125)
Validating set 19... (Loss, Acc) = (2.608651, 0.117188)
Validating set 20... (Loss, Acc) = (2.458480, 0.125000)
Validating set 21... (Loss, Acc) = (2.659738, 0.078125)
Validating set 22... (Loss, Acc) = (2.761351, 0.054688)
Validating set 23... (Loss, Acc) = (2.674613, 0.140625)
Validating set 24... (Loss, Acc) = (2.614646, 0.132812)
Validating set 25... (Loss, Acc) = (2.717063, 0.132812)
Validating set 26... (Loss, Acc) = (2.831150, 0.132812)
Validating set 27... (Loss, Acc) = (2.840220, 0.062500)
Validating set 28... (Loss, Acc) = (2.888196, 0.101562)
Validating set 29... (Loss, Acc) = (2.683285, 0.101562)
Validating set 30... (Loss, Acc) = (2.695441, 0.101562)
Validating set 31... (Loss, Acc) = (2.710979, 0.101562)
Validating set 32... (Loss, Acc) = (2.721793, 0.109375)
Validating set 33... (Loss, Acc) = (2.832050, 0.132812)
Validating set 34... (Loss, Acc) = (2.733249, 0.132812)
Validating set 35... (Loss, Acc) = (2.609143, 0.078125)
Validating set 36... (Loss, Acc) = (2.730237, 0.078125)
TrainGenerator state saved
Model report generated and saved
Best model saved to C:\deeptrain\examples\dir\models\M2__model-Adam__min2.717
TrainGenerator state saved
Model report generated and saved
[Figure: examples_misc_model_health_4_1.png]

100.0% Large -- 'conv2d/kernel:0'
100.0% Large -- 'conv2d/bias:0'
100.0% Large -- 'conv2d_1/kernel:0'
100.0% Large -- 'conv2d_1/bias:0'
85.8% Large -- 'dense/kernel:0'
100.0% Large -- 'dense/bias:0'
99.7% Large -- 'dense_1/kernel:0'
100.0% Large -- 'dense_1/bias:0'
L = layer index, W = weight tensor index
Training has concluded.
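The printed training losses appear to be running means over the epoch rather than per-set values; this is an inference from the arithmetic, not something the log states. Under that assumption, the loss of set i is i·mean_i − (i−1)·mean_(i−1). Applying it to the first eight values from the training log:

```python
# Running (Loss) values for sets 1-8, copied from the training log above.
running = [2.301934, 61413017.150967, 50551152.767311, 39892719.075484,
           31946056.266637, 26621893.911712, 22818785.382816, 19966440.713710]

def per_set_from_running_mean(running):
    """Undo a cumulative mean: x_i = i*m_i - (i-1)*m_(i-1), with 1-based i."""
    values, prev_total = [], 0.0
    for i, mean in enumerate(running, start=1):
        total = i * mean          # sum of the first i per-set values
        values.append(total - prev_total)
        prev_total = total
    return values

for i, loss in enumerate(per_set_from_running_mean(running), start=1):
    print(f"set {i}: {loss:.4g}")
```

Read this way, a single early set spikes to about 1.2e8 and later sets fall back to double digits (the same arithmetic on sets 47 and 48 gives roughly 2.7), so the epoch mean stays in the millions while the validation loss sits near 2.7.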

Case 2: Exploding/vanishing weights

We build RNNs with ReLU activations and a high learning rate to generate extreme activations, and thereby extreme gradients and weights.
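A toy illustration of why unbounded activations are risky in recurrent layers: iterating a scalar recurrence h_t = act(w·h_(t−1) + x) with recurrent weight w > 1, ReLU lets the state grow geometrically, while tanh saturates below 1. This is a simplified single-unit RNN with made-up values, not the LSTM built by make_timeseries_classifier.

```python
import numpy as np

def relu(z):
    return max(z, 0.0)

def run_recurrence(act, w=1.5, x=1.0, steps=50):
    """Iterate h_t = act(w * h_{t-1} + x) from h_0 = 0 and return the final state."""
    h = 0.0
    for _ in range(steps):
        h = act(w * h + x)
    return h

print("relu:", run_recurrence(relu))     # equals 2 * (1.5**50 - 1), about 1.3e9
print("tanh:", run_recurrence(np.tanh))  # stays bounded in (-1, 1)
```

Large states feed large gradients back through time, which is how a run like the one below can end in NaNs.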

[4]:
from utils import TS_CONFIGS as C
from utils import make_timeseries_classifier

C['model']['activation'] = 'relu'
C['model']['optimizer'] = Adam(.3)
C['traingen']['epochs'] = 1
C['traingen']['eval_fn'] = 'predict'
C['traingen']['val_freq'] = {'epoch': 1}
tg = init_session(C, make_timeseries_classifier)
WARNING:tensorflow:Layer lstm will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU
WARNING:tensorflow:Layer lstm_1 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU
Discovered dataset with matching format
Discovered dataset with matching format
103 set nums inferred; if more are expected, ensure file names contain a common substring w/ a number (e.g. 'train1.npy', 'train2.npy', etc)
DataGenerator initiated

Discovered dataset with matching format
Discovered dataset with matching format
12 set nums inferred; if more are expected, ensure file names contain a common substring w/ a number (e.g. 'train1.npy', 'train2.npy', etc)
DataGenerator initiated

Preloading superbatch ... Discovered dataset with matching format
....................................................................................................... finished, w/ 13184 total samples
Train initial data prepared
Preloading superbatch ... Discovered dataset with matching format
............ finished, w/ 1536 total samples
Val initial data prepared
Logging ON; directory (new): C:\deeptrain\examples\dir\logs\M3__model-Adam__min999.000
[5]:
# Training is expected to error when plotting NaN metrics; we don't mind
try:
    tg.train()
except Exception:
    pass

Fitting set 0...   Loss = nan RNNs reset
Fitting set 1...   Loss = nan RNNs reset
Fitting set 10...  Loss = nan RNNs reset
Fitting set 100... Loss = nan RNNs reset
Fitting set 101... Loss = nan RNNs reset
Fitting set 102... Loss = nan RNNs reset
Fitting set 11...  Loss = nan RNNs reset
Fitting set 12...  Loss = nan RNNs reset
Fitting set 13...  Loss = nan RNNs reset
Fitting set 14...  Loss = nan RNNs reset
Fitting set 15...  Loss = nan RNNs reset
Fitting set 16...  Loss = nan RNNs reset
Fitting set 17...  Loss = nan RNNs reset
Fitting set 18...  Loss = nan RNNs reset
Fitting set 19...  Loss = nan RNNs reset
Fitting set 2...   Loss = nan RNNs reset
Fitting set 20...  Loss = nan RNNs reset
Fitting set 21...  Loss = nan RNNs reset
Fitting set 22...  Loss = nan RNNs reset
Fitting set 23...  Loss = nan RNNs reset
Fitting set 24...  Loss = nan RNNs reset
Fitting set 25...  Loss = nan RNNs reset
Fitting set 26...  Loss = nan RNNs reset
Fitting set 27...  Loss = nan RNNs reset
Fitting set 28...  Loss = nan RNNs reset
Fitting set 29...  Loss = nan RNNs reset
Fitting set 3...   Loss = nan RNNs reset
Fitting set 30...  Loss = nan RNNs reset
Fitting set 31...  Loss = nan RNNs reset
Fitting set 32...  Loss = nan RNNs reset
Fitting set 33...  Loss = nan RNNs reset
Fitting set 34...  Loss = nan RNNs reset
Fitting set 35...  Loss = nan RNNs reset
Fitting set 36...  Loss = nan RNNs reset
Fitting set 37...  Loss = nan RNNs reset
Fitting set 38...  Loss = nan RNNs reset
Fitting set 39...  Loss = nan RNNs reset
Fitting set 4...   Loss = nan RNNs reset
Fitting set 40...  Loss = nan RNNs reset
Fitting set 41...  Loss = nan RNNs reset
Fitting set 42...  Loss = nan RNNs reset
Fitting set 43...  Loss = nan RNNs reset
Fitting set 44...  Loss = nan RNNs reset
Fitting set 45...  Loss = nan RNNs reset
Fitting set 46...  Loss = nan RNNs reset
Fitting set 47...  Loss = nan RNNs reset
Fitting set 48...  Loss = nan RNNs reset
Fitting set 49...  Loss = nan RNNs reset
Fitting set 5...   Loss = nan RNNs reset
Fitting set 50...  Loss = nan RNNs reset
Fitting set 51...  Loss = nan RNNs reset
Fitting set 52...  Loss = nan RNNs reset
Fitting set 53...  Loss = nan RNNs reset
Fitting set 54...  Loss = nan RNNs reset
Fitting set 55...  Loss = nan RNNs reset
Fitting set 56...  Loss = nan RNNs reset
Fitting set 57...  Loss = nan RNNs reset
Fitting set 58...  Loss = nan RNNs reset
Fitting set 59...  Loss = nan RNNs reset
Fitting set 6...   Loss = nan RNNs reset
Fitting set 60...  Loss = nan RNNs reset
Fitting set 61...  Loss = nan RNNs reset
Fitting set 62...  Loss = nan RNNs reset
Fitting set 63...  Loss = nan RNNs reset
Fitting set 64...  Loss = nan RNNs reset
Fitting set 65...  Loss = nan RNNs reset
Fitting set 66...  Loss = nan RNNs reset
Fitting set 67...  Loss = nan RNNs reset
Fitting set 68...  Loss = nan RNNs reset
Fitting set 69...  Loss = nan RNNs reset
Fitting set 7...   Loss = nan RNNs reset
Fitting set 70...  Loss = nan RNNs reset
Fitting set 71...  Loss = nan RNNs reset
Fitting set 72...  Loss = nan RNNs reset
Fitting set 73...  Loss = nan RNNs reset
Fitting set 74...  Loss = nan RNNs reset
Fitting set 75...  Loss = nan RNNs reset
Fitting set 76...  Loss = nan RNNs reset
Fitting set 77...  Loss = nan RNNs reset
Fitting set 78...  Loss = nan RNNs reset
Fitting set 79...  Loss = nan RNNs reset
Fitting set 8...   Loss = nan RNNs reset
Fitting set 80...  Loss = nan RNNs reset
Fitting set 81...  Loss = nan RNNs reset
Fitting set 82...  Loss = nan RNNs reset
Fitting set 83...  Loss = nan RNNs reset
Fitting set 84...  Loss = nan RNNs reset
Fitting set 85...  Loss = nan RNNs reset
Fitting set 86...  Loss = nan RNNs reset
Fitting set 87...  Loss = nan RNNs reset
Fitting set 88...  Loss = nan RNNs reset
Fitting set 89...  Loss = nan RNNs reset
Fitting set 9...   Loss = nan RNNs reset
Fitting set 90...  Loss = nan RNNs reset
Fitting set 91...  Loss = nan RNNs reset
Fitting set 92...  Loss = nan RNNs reset
Fitting set 93...  Loss = nan RNNs reset
Fitting set 94...  Loss = nan RNNs reset
Fitting set 95...  Loss = nan RNNs reset
Fitting set 96...  Loss = nan RNNs reset
Fitting set 97...  Loss = nan RNNs reset
Fitting set 98...  Loss = nan RNNs reset
Fitting set 99...  Loss = nan RNNs reset
Data set_nums shuffled


_____________________
 EPOCH 1 -- COMPLETE 



Validating...
Validating set 0...  Loss = nan
RNNs reset Validating set 1...  Loss = nan
RNNs reset Validating set 10...
C:\deeptrain\deeptrain\metrics.py:113: RuntimeWarning: invalid value encountered in greater_equal
  neg_abs_logits = np.where(logits >= 0, -logits, logits)
C:\deeptrain\deeptrain\metrics.py:114: RuntimeWarning: invalid value encountered in greater_equal
  relu_logits    = np.where(logits >= 0, logits, 0)
C:\deeptrain\deeptrain\util\searching.py:70: RuntimeWarning: invalid value encountered in greater
  new_best = (metric > best_metric if max_is_best else
 Loss = nan
RNNs reset Validating set 11... Loss = nan
RNNs reset Validating set 2...  Loss = nan
RNNs reset Validating set 3...  Loss = nan
RNNs reset Validating set 4...  Loss = nan
RNNs reset Validating set 5...  Loss = nan
RNNs reset Validating set 6...  Loss = nan
RNNs reset Validating set 7...  Loss = nan
RNNs reset Validating set 8...  Loss = nan
RNNs reset Validating set 9...  Loss = nan
RNNs reset
D:\Anaconda\envs\tf2_env\lib\site-packages\matplotlib\axes\_axes.py:6630: RuntimeWarning: All-NaN slice encountered
  xmin = min(xmin, np.nanmin(xi))
D:\Anaconda\envs\tf2_env\lib\site-packages\matplotlib\axes\_axes.py:6630: RuntimeWarning: invalid value encountered in less
  xmin = min(xmin, np.nanmin(xi))
D:\Anaconda\envs\tf2_env\lib\site-packages\matplotlib\axes\_axes.py:6631: RuntimeWarning: All-NaN slice encountered
  xmax = max(xmax, np.nanmax(xi))
D:\Anaconda\envs\tf2_env\lib\site-packages\matplotlib\axes\_axes.py:6631: RuntimeWarning: invalid value encountered in greater
  xmax = max(xmax, np.nanmax(xi))
[Figure: examples_misc_model_health_7_4.png]
[6]:
tg.check_health()
3.125% dead -- 'lstm/lstm_cell/bias:0'
1.042% dead -- 'lstm_1/lstm_cell_1/bias:0'
L = layer index, W = weight tensor index

82.3% NaN -- 'lstm/lstm_cell/kernel:0'
100.0% NaN -- 'lstm/lstm_cell/recurrent_kernel:0'
82.3% NaN -- 'lstm/lstm_cell/bias:0'
100.0% NaN -- 'lstm_1/lstm_cell_1/kernel:0'
100.0% NaN -- 'lstm_1/lstm_cell_1/recurrent_kernel:0'
77.1% NaN -- 'lstm_1/lstm_cell_1/bias:0'
100.0% NaN -- 'dense_2/kernel:0'
100.0% NaN -- 'dense_2/bias:0'
L = layer index, W = weight tensor index
C:\deeptrain\deeptrain\introspection.py:405: RuntimeWarning: invalid value encountered in less
  num_dead = np.sum(np.abs(w_value) < dead_threshold)
C:\deeptrain\deeptrain\introspection.py:490: RuntimeWarning: invalid value encountered in greater
  num_large = np.sum(np.abs(w_value) > large_threshold) - num_nan
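check_health reports, per weight tensor, the share of values that are dead (near zero), large, or NaN; the warnings above show it comparing np.abs(w_value) against dead_threshold and large_threshold. A simplified NumPy sketch of that kind of report, with NaNs masked out first so they count toward neither dead nor large; the function name, tensor values, and threshold values are assumptions, not deeptrain's code or defaults.

```python
import numpy as np

def weight_health(weights, dead_threshold=1e-7, large_threshold=3.0):
    """Return {name: (pct_dead, pct_large, pct_nan)} for a dict of weight arrays."""
    report = {}
    for name, w in weights.items():
        w = np.asarray(w, dtype=float)
        nan_mask = np.isnan(w)
        finite_abs = np.abs(w[~nan_mask])  # compare only non-NaN values
        size = w.size
        report[name] = (100.0 * np.sum(finite_abs < dead_threshold) / size,
                        100.0 * np.sum(finite_abs > large_threshold) / size,
                        100.0 * np.sum(nan_mask) / size)
    return report

weights = {"lstm/lstm_cell/bias:0": np.array([0.0, 0.5, np.nan, np.nan]),
           "dense_2/kernel:0": np.array([[10.0, -0.1], [0.2, -8.0]])}
for name, (dead, large, nan) in weight_health(weights).items():
    print(f"{name}: {dead:.1f}% dead, {large:.1f}% large, {nan:.1f}% NaN")
```

As in the report above, one tensor can be partly dead and partly NaN at once.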

Visualize

[7]:
_ = rnn_histogram(tg.model, 1)
_ = rnn_heatmap(tg.model, 1)
_ = rnn_histogram(tg.model, 2)
_ = rnn_heatmap(tg.model, 2)
[Figure: examples_misc_model_health_10_0.png]

KERNEL:
INPUT: 100.0% NaN
FORGET: 100.0% NaN
CELL: 29.2% NaN
OUTPUT: 100.0% NaN

RECURRENT:
INPUT: 100.0% NaN
FORGET: 100.0% NaN
CELL: 100.0% NaN
OUTPUT: 100.0% NaN
D:\Anaconda\envs\tf2_env\lib\site-packages\matplotlib\colors.py:581: RuntimeWarning: invalid value encountered in less
  xa[xa < 0] = -1
[Figure: examples_misc_model_health_10_3.png]
[Figure: examples_misc_model_health_10_4.png]

KERNEL:
INPUT: 100.0% NaN
FORGET: 100.0% NaN
CELL: 100.0% NaN
OUTPUT: 100.0% NaN

RECURRENT:
INPUT: 100.0% NaN
FORGET: 100.0% NaN
CELL: 100.0% NaN
OUTPUT: 100.0% NaN
[Figure: examples_misc_model_health_10_6.png]
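The KERNEL / RECURRENT summaries break each LSTM weight matrix down by gate. Keras stores an LSTM's kernel with shape (input_dim, 4 * units) and its recurrent kernel with shape (units, 4 * units), gate columns ordered input, forget, cell, output. A NumPy sketch of a per-gate NaN report under that layout; this is not see_rnn's code, and the example matrix is synthetic.

```python
import numpy as np

GATES = ("INPUT", "FORGET", "CELL", "OUTPUT")  # Keras LSTM column order

def gate_nan_report(kernel):
    """Percent NaN per gate for an LSTM kernel of shape (rows, 4 * units)."""
    kernel = np.asarray(kernel, dtype=float)
    units = kernel.shape[1] // 4
    return {gate: 100.0 * np.isnan(kernel[:, i * units:(i + 1) * units]).mean()
            for i, gate in enumerate(GATES)}

kernel = np.zeros((3, 4 * 2))  # input_dim=3, units=2
kernel[:, 0:2] = np.nan        # INPUT gate fully NaN
kernel[0, 4:6] = np.nan        # one row of the CELL gate NaN
for gate, pct in gate_nan_report(kernel).items():
    print(f"{gate}: {pct:.1f}% NaN")
```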