In this section, we will show you how to implement a relatively simple CNN architecture and train it to classify images from the CIFAR-10 dataset.
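Training a classifier on CIFAR-10's ten classes usually means minimizing a softmax cross-entropy loss on the network's output logits. This section does not spell out its loss, so the following NumPy sketch assumes that common choice. `softmax_cross_entropy` is a hypothetical helper name, not part of TensorFlow.

```python
import numpy as np


def softmax_cross_entropy(logits, labels):
    """Mean softmax cross-entropy for integer class labels (illustrative sketch).

    logits: float array of shape (batch, num_classes), e.g. num_classes=10 for CIFAR-10.
    labels: int array of shape (batch,) with values in [0, num_classes).
    """
    # Subtract the per-row max before exponentiating for numerical stability;
    # this does not change the softmax result.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each example.
    picked = log_probs[np.arange(len(labels)), labels]
    return -picked.mean()


# Uniform logits over 10 classes give a loss of ln(10), roughly 2.3026.
uniform = np.zeros((4, 10))
print(softmax_cross_entropy(uniform, np.array([0, 3, 7, 9])))
```

In TensorFlow, `tf.nn.sparse_softmax_cross_entropy_with_logits` computes the same per-example quantity from integer labels; averaging its output gives the value above.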
Start by importing all the necessary libraries:
import fire
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
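`cifar10.load_data()` returns `(x_train, y_train), (x_test, y_test)` as uint8 arrays, with images of shape `(N, 32, 32, 3)` and labels of shape `(N, 1)`. A common preprocessing step, assumed here rather than taken from this section, scales pixels to [0, 1] and one-hot encodes the labels. The sketch below uses random arrays of the same shape so it runs without downloading the dataset; `preprocess` is a hypothetical helper.

```python
import numpy as np


def preprocess(images, labels, num_classes=10):
    """Scale uint8 images to [0, 1] and one-hot encode integer labels (sketch)."""
    x = images.astype(np.float32) / 255.0
    # Labels from cifar10.load_data() have shape (N, 1); flatten them to (N,).
    y_idx = labels.reshape(-1).astype(np.int64)
    y = np.zeros((y_idx.shape[0], num_classes), dtype=np.float32)
    y[np.arange(y_idx.shape[0]), y_idx] = 1.0
    return x, y


# Stand-in data with the same shapes and dtypes as a small CIFAR-10 slice.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(8, 32, 32, 3)).astype(np.uint8)
fake_labels = rng.integers(0, 10, size=(8, 1)).astype(np.uint8)

x, y = preprocess(fake_images, fake_labels)
print(x.shape, x.dtype, y.shape)  # (8, 32, 32, 3) float32 (8, 10)
```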
We will define a Python class, Train, that implements the training process. It has two methods, build_graph and train. The train method is the one run (through the fire library imported above) when the script is executed as the main program:
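The `fire` library turns a Python object into a command-line interface. Assuming the truncated `if __name__...` block calls `fire.Fire(Train)`, a command such as `python train.py train --batch_size=100` (the script name is hypothetical) would construct `Train()` and then call its `train` method with `batch_size=100`. To show that mapping without depending on `fire`, here is a toy stand-in; `dispatch` is a hypothetical function and is not how `fire` is actually implemented.

```python
import ast


def dispatch(cls, argv):
    """Toy stand-in for fire.Fire(cls); NOT fire's real implementation.

    argv looks like ['train', '--batch_size=100']: the first item names a
    method, and each --key=value pair becomes a keyword argument.
    """
    method_name, *flags = argv
    kwargs = {}
    for flag in flags:
        key, _, raw = flag.lstrip('-').partition('=')
        try:
            # Like fire, try to read values as Python literals ('100' -> 100).
            kwargs[key] = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            kwargs[key] = raw  # fall back to a plain string, e.g. ./save
    instance = cls()  # the class is instantiated with no arguments first
    return getattr(instance, method_name)(**kwargs)


class Train:
    # Stub with the same train() signature as the real class; it just echoes its arguments.
    def train(self, save_dir='./save', batch_size=500):
        return save_dir, batch_size


print(dispatch(Train, ['train', '--batch_size=100']))      # ('./save', 100)
print(dispatch(Train, ['train', '--save_dir=/tmp/run1']))  # ('/tmp/run1', 500)
```

Because `Train.__init__` takes no arguments, the class can be instantiated before any method is looked up, which is what lets a bare method name like `train` act as a subcommand.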
class Train:
    __x_ = []
    __y_ = []
    __logits = []
    __loss = []
    __train_step = []
    __merged_summary_op = []
    __saver = []
    __session = []
    __writer = []
    __is_training = []
    __loss_val = []
    __train_summary = []
    __val_summary = []

    def __init__(self):
        pass

    def build_graph(self):
        [...]

    def train(self, save_dir='./save', batch_size=500):
        [...]

if __name__...
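The train method takes a `batch_size` argument (default 500). Its body is elided here, but with CIFAR-10's 50,000 training images, a batch size of 500 means 100 optimization steps per pass over the data. The NumPy sketch below shows one common way to produce shuffled mini-batches; `iterate_minibatches` is a hypothetical helper, and reshuffling once per epoch is an assumption rather than something this section specifies.

```python
import numpy as np


def iterate_minibatches(x, y, batch_size=500, rng=None):
    """Yield shuffled (x_batch, y_batch) pairs covering the data exactly once."""
    rng = np.random.default_rng() if rng is None else rng
    order = rng.permutation(len(x))  # fresh random order for this epoch
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]


# Stand-in arrays: 50,000 examples with a tiny feature dimension, so the
# batch count matches CIFAR-10 without allocating full 32x32x3 images.
n_train = 50_000
x = np.arange(n_train).reshape(n_train, 1)
y = np.arange(n_train) % 10
batches = list(iterate_minibatches(x, y, batch_size=500, rng=np.random.default_rng(0)))
print(len(batches), batches[0][0].shape)  # 100 (500, 1)
```

If the number of examples is not a multiple of `batch_size`, the final batch is simply smaller.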