Training API + OneHotEncoder
I cleaned up a lot of the boilerplate in main.py. The training loop now lives inside the network itself, so a single fit() call handles batching, epochs, and logging:
network.fit(X_train, y_train, learning_rate=0.1, batch_size=0.1)
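The call above passes batch_size=0.1, and how fit() interprets that value is not shown here. One plausible reading is a fraction of the training set. A minimal sketch of the batching step under that assumption follows; the helper name iterate_batches and its parameters are hypothetical and not taken from the actual code:

```python
import random


def iterate_batches(n_samples, batch_size, rng):
    """Yield lists of shuffled sample indices, one list per mini-batch.

    Assumption (not confirmed by the post): a float batch_size is a
    fraction of the dataset, an int is an absolute number of samples.
    """
    if isinstance(batch_size, float):
        size = max(1, round(n_samples * batch_size))
    else:
        size = batch_size

    # Shuffle once per pass so every epoch sees the samples in a new order.
    indices = list(range(n_samples))
    rng.shuffle(indices)

    # The last batch may be smaller when size does not divide n_samples.
    for start in range(0, n_samples, size):
        yield indices[start:start + size]
```

A fit() built on this would call the generator once per epoch and run one gradient step per yielded batch.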
I also renamed a few things that had always been named imprecisely. run() is now forward(), and the old fit(), which only performed a single gradient step, became epoch(). The code is much easier to read as a result.
The one-hot encoding is now its own class instead of a one-liner in main.py. This was necessary because predictions have to be decoded back to the original labels afterwards, so the encoder needs to keep the label mapping around:
encoder = OneHotEncoder()
y = encoder.encode(digits.target)
encoder.decode(network.forward(X_test))
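To illustrate why encoding and decoding belong in one object, here is a rough sketch of such an encoder. Only the class name and the encode()/decode() method names come from the snippet above; the internals are an assumption, and it uses plain Python lists rather than arrays so that it runs without dependencies:

```python
class OneHotEncoder:
    """Hypothetical sketch; the real class may be implemented differently."""

    def encode(self, labels):
        # Remember the sorted unique labels so decode() can map back later.
        self.classes = sorted(set(labels))
        index = {label: i for i, label in enumerate(self.classes)}
        n_classes = len(self.classes)
        return [
            [1.0 if i == index[label] else 0.0 for i in range(n_classes)]
            for label in labels
        ]

    def decode(self, outputs):
        # Take the position of the highest score in each row (argmax)
        # and translate it back to the original label.
        return [
            self.classes[max(range(len(row)), key=lambda i: row[i])]
            for row in outputs
        ]
```

Because decode() takes the argmax of each row, it works on raw network outputs as well as on exact one-hot vectors.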
I also added a train_test_split utility so I don’t have to manually slice arrays anymore.
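The signature of the utility is not shown, so the following is only a sketch of what such a helper might look like. The parameters test_size and seed, and the return order (borrowed from scikit-learn's function of the same name), are assumptions:

```python
import random


def train_test_split(X, y, test_size=0.2, seed=None):
    """Shuffle the samples and split them into a training and a test part.

    Hypothetical sketch: test_size is assumed to be the fraction of
    samples held out for testing.
    """
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of samples")

    # Shuffle indices rather than the data so X and y stay aligned.
    indices = list(range(len(X)))
    random.Random(seed).shuffle(indices)

    n_test = round(len(X) * test_size)
    test_idx, train_idx = indices[:n_test], indices[n_test:]

    X_train = [X[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_train = [y[i] for i in train_idx]
    y_test = [y[i] for i in test_idx]
    return X_train, X_test, y_train, y_test
```

Passing a fixed seed makes the split reproducible between runs, which helps when comparing training results.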