
1h 10m 12s logged

Training API + OneHotEncoder

Cleaned up a lot of the boilerplate in main.py. The training loop now lives inside the network itself, and a single fit() call handles batching, epochs, and logging:

network.fit(X_train, y_train, learning_rate=0.1, batch_size=0.1)
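To show what a fit() like this can look like, here is a minimal sketch. It uses a hypothetical one-weight linear model rather than the real network. It follows the post's naming: forward() computes predictions, and epoch() performs a single gradient step, which fit() now calls once per mini-batch. Several parts are assumptions, not taken from the project:

- the `epochs`, `log_every` and `seed` parameters;
- the `loss()` helper;
- reading a float `batch_size` such as 0.1 as a fraction of the training set.

```python
import random


class Network:
    """Hypothetical toy model: a single linear unit y = w * x + b trained with MSE."""

    def __init__(self):
        self.w = 0.0
        self.b = 0.0

    def forward(self, X):
        # Predictions for a list of scalar inputs.
        return [self.w * x + self.b for x in X]

    def loss(self, X, y):
        # Mean squared error over the given samples (hypothetical helper used for logging).
        return sum((p - t) ** 2 for p, t in zip(self.forward(X), y)) / len(X)

    def epoch(self, X, y, learning_rate):
        # One gradient-descent step of mean squared error on the given batch.
        preds = self.forward(X)
        n = len(X)
        grad_w = sum(2 * (p - t) * x for p, t, x in zip(preds, y, X)) / n
        grad_b = sum(2 * (p - t) for p, t in zip(preds, y)) / n
        self.w -= learning_rate * grad_w
        self.b -= learning_rate * grad_b

    def fit(self, X, y, learning_rate=0.1, batch_size=32, epochs=100, log_every=0, seed=0):
        n = len(X)
        # Assumption: a float batch_size in (0, 1] is a fraction of the training set.
        if isinstance(batch_size, float) and 0 < batch_size <= 1:
            batch_size = max(1, int(n * batch_size))
        rng = random.Random(seed)
        indices = list(range(n))
        for e in range(epochs):
            rng.shuffle(indices)  # new sample order every pass over the data
            for start in range(0, n, batch_size):
                batch = indices[start:start + batch_size]
                self.epoch([X[i] for i in batch], [y[i] for i in batch], learning_rate)
            if log_every and (e + 1) % log_every == 0:
                print(f"epoch {e + 1}: loss {self.loss(X, y):.4f}")
        return self
```

With 20 samples, `batch_size=0.1` becomes mini-batches of 2 under this reading. The caller then only passes data and hyperparameters, and the shuffling and logging stay inside the class.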

I also renamed a few methods whose names had always been imprecise: run() is now forward(), and the old fit() (which only did a single gradient step) became epoch(). This made the code much easier to read.

The one-hot encoding is now its own class instead of one-liners scattered through main.py. This was necessary because predictions have to be decoded back to the original labels afterwards, so the encoder needs to keep the label mapping around:

encoder = OneHotEncoder()
y = encoder.encode(digits.target)

encoder.decode(network.forward(X_test))
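Here is a minimal sketch of an encoder with this encode()/decode() shape, written without any libraries. It is a hypothetical version, not the project's code. It is also unrelated to scikit-learn's `OneHotEncoder`, which uses `fit`/`transform`/`inverse_transform` instead. The sketch assumes labels are sortable. It also assumes decode() receives one score vector per sample and picks the highest-scoring class.

```python
class OneHotEncoder:
    """Hypothetical sketch: maps each distinct label to a one-hot row and back."""

    def encode(self, labels):
        # Remember the sorted distinct labels so decode() can invert the mapping.
        self.classes = sorted(set(labels))
        self.index = {label: i for i, label in enumerate(self.classes)}
        rows = []
        for label in labels:
            row = [0.0] * len(self.classes)
            row[self.index[label]] = 1.0
            rows.append(row)
        return rows

    def decode(self, outputs):
        # Each row may be a probability vector; take the argmax and map it back to a label.
        return [self.classes[max(range(len(row)), key=row.__getitem__)] for row in outputs]
```

Because the class stores `classes`, the same instance that encoded the training targets can decode network outputs into the original labels.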

I also added a train_test_split utility so I don’t have to manually slice arrays anymore.
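A train_test_split helper can be as small as the sketch below. The `test_size` fraction, the `seed` parameter and the return order are assumptions, not taken from the project. The return order (X_train, X_test, y_train, y_test) follows scikit-learn's convention. The helper shuffles indices once and slices both arrays with them, so each input stays paired with its label.

```python
import random


def train_test_split(X, y, test_size=0.2, seed=None):
    """Hypothetical sketch: shuffle paired samples and split them into train and test sets."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same length")
    indices = list(range(len(X)))
    random.Random(seed).shuffle(indices)  # one shared permutation keeps X and y aligned
    n_test = int(round(len(X) * test_size))
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    X_train = [X[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_train = [y[i] for i in train_idx]
    y_test = [y[i] for i in test_idx]
    return X_train, X_test, y_train, y_test
```

Passing a fixed `seed` makes the split reproducible between runs, which helps when comparing hyperparameters.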
