I've recently been reexploring the world of Rust & Machine Learning and what the future of it is going to look like. Whilst doing so, I built a simple ONNX inference framework in Rust. ONNX is an open standard for serializing machine learning models, with seemingly strong industry backing. I declared my framework finished when it could perform a forward pass of an EfficientNet Lite.
You can see the demo here:
To achieve the above, I had to implement all of the necessary operators, such as convolutions, transposes and pooling. The most frequent cause of bugs while writing the framework was mismatched shapes.
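To make the shape problem concrete, here is a minimal sketch of the arithmetic behind a 2-D convolution's output shape. It is an illustration rather than the framework's actual code: the function name and signature are hypothetical, and it assumes an NCHW layout with symmetric padding (ONNX also allows asymmetric pads, which this ignores).

```rust
/// Output shape of a 2-D convolution over an NCHW input.
/// Hypothetical helper for illustration; assumes symmetric padding.
fn conv2d_output_shape(
    input: [usize; 4],    // [N, C_in, H, W]
    weight: [usize; 4],   // [C_out, C_in / groups, kH, kW]
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    groups: usize,
) -> Result<[usize; 4], String> {
    let [n, c_in, h, w] = input;
    let [c_out, c_in_per_group, kh, kw] = weight;

    // The classic mismatch: input channels disagree with the weight tensor.
    if groups == 0 || c_in != c_in_per_group * groups {
        return Err(format!(
            "channel mismatch: input has {} channels, weight expects {} x {} groups",
            c_in, c_in_per_group, groups
        ));
    }

    // One spatial axis: floor((size + 2*pad - dilation*(k-1) - 1) / stride) + 1.
    // Returns None when the stride is zero or the kernel does not fit.
    let axis = |size: usize, k: usize, s: usize, p: usize, d: usize| -> Option<usize> {
        if s == 0 {
            return None;
        }
        let effective_kernel = d * k.checked_sub(1)? + 1;
        (size + 2 * p)
            .checked_sub(effective_kernel)
            .map(|span| span / s + 1)
    };

    let h_out = axis(h, kh, stride[0], padding[0], dilation[0])
        .ok_or_else(|| format!("kernel height {} does not fit input height {}", kh, h))?;
    let w_out = axis(w, kw, stride[1], padding[1], dilation[1])
        .ok_or_else(|| format!("kernel width {} does not fit input width {}", kw, w))?;

    Ok([n, c_out, h_out, w_out])
}
```

For example, a 3x3 convolution with 32 filters, stride 2 and padding 1 over a `[1, 3, 224, 224]` input yields `[1, 32, 112, 112]`, while feeding it a 4-channel input returns an error instead of silently producing garbage.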
It would have been great to know exactly what shape I needed at each edge in my network graph. This is where a model visualizer comes in. The most popular visualizer is currently Lutz Roeder's Netron. However, when I opened the EfficientNet model in Netron, I got the following:
...no shapes after the input!
This post isn't to dunk on Netron - it's an amazing tool that has incredibly wide format support and is very robust. However, seeing as I had already written a parser for ONNX files, I decided to scratch my own itch and write a visualizer - with a few additional features along the way - culminating in Steelix. Steelix is ONNX only (which, unlike Netron's broad format support, greatly reduces the scope of this work) and supports a very limited set of operators. However, despite these limitations it can do some very useful things!
Plotting
Steelix can produce a DOT file for any ONNX model, allowing you to explore the structure. You can see a demo below on our EfficientNet:
Steelix also has the added advantage of being composable. You can embed a call to Steelix in your Jupyter Notebook and automatically generate updated model graphs for your documentation. If Steelix doesn't support an operator, you can still visualize your model, just without any shapes as it is unable to do the mock inference pass.
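As a rough illustration of what emitting DOT with shape-annotated edges can look like, here is a small sketch. It is not Steelix's implementation: the `Node` and `Edge` types and the `to_dot` function are hypothetical, and names are written into labels without escaping. An edge whose shape could not be inferred (for instance because an operator is unsupported) is labelled `?`.

```rust
use std::fmt::Write;

// Hypothetical graph types, for illustration only.
struct Node {
    name: String,
    op: String,
}

struct Edge {
    from: usize,               // index of the producing node
    to: usize,                 // index of the consuming node
    shape: Option<Vec<usize>>, // None when shape inference was not possible
}

/// Render the graph as DOT text, labelling each edge with its tensor shape.
fn to_dot(nodes: &[Node], edges: &[Edge]) -> String {
    let mut out = String::from("digraph model {\n");
    for (i, node) in nodes.iter().enumerate() {
        // Writing to a String cannot fail, so unwrap is safe here.
        writeln!(out, "  n{} [label=\"{} ({})\"];", i, node.name, node.op).unwrap();
    }
    for edge in edges {
        let label = match &edge.shape {
            Some(dims) => dims
                .iter()
                .map(|d| d.to_string())
                .collect::<Vec<_>>()
                .join("x"),
            None => "?".to_string(),
        };
        writeln!(out, "  n{} -> n{} [label=\"{}\"];", edge.from, edge.to, label).unwrap();
    }
    out.push_str("}\n");
    out
}
```

The resulting string can be written to a `.dot` file and rendered with Graphviz, e.g. `dot -Tsvg model.dot -o model.svg`.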
Summarization
When playing with models, it is often useful to know some summary statistics about them. Steelix provides a handy summary table with some interesting metrics - check out the demo below on AlexNet:
The main metrics of interest to me are the FLOP count, parameter count and hardware performance. Still to be added are peak activations and other useful metrics (you can learn more in these great slides). These metrics are particularly interesting if you want to run models on microcontrollers - expect to see me exploring that space soon. Thanks to Marius Hobbhahn for some useful insights on FLOP counting.
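For a sense of how such numbers come about, here is a sketch of parameter and FLOP counting for a single convolution layer. It is an assumption-laden illustration, not Steelix's code: the `ConvCost` type and `conv2d_cost` function are hypothetical, one multiply-accumulate is counted as two FLOPs (conventions differ on this), and the bias additions are left out of the FLOP count.

```rust
// Hypothetical result type, for illustration only.
struct ConvCost {
    params: u64,
    macs: u64,
    flops: u64,
}

/// Cost of one 2-D convolution given its output size and weight shape.
/// Assumed convention: 1 multiply-accumulate = 2 FLOPs, bias adds ignored.
fn conv2d_cost(
    batch: u64,
    out_hw: [u64; 2], // [H_out, W_out]
    weight: [u64; 4], // [C_out, C_in / groups, kH, kW]
    has_bias: bool,
) -> ConvCost {
    let [c_out, c_in_per_group, kh, kw] = weight;

    // Every weight element is a parameter, plus one bias per output channel.
    let weight_params = c_out * c_in_per_group * kh * kw;
    let params = weight_params + if has_bias { c_out } else { 0 };

    // Each output element needs one dot product over the kernel window.
    let macs_per_output = c_in_per_group * kh * kw;
    let outputs = batch * c_out * out_hw[0] * out_hw[1];
    let macs = outputs * macs_per_output;

    ConvCost {
        params,
        macs,
        flops: 2 * macs,
    }
}
```

As an illustrative input, a layer with 96 filters of size 11x11 over 3 channels producing a 55x55 output gives 34,944 parameters (with bias) and 105,415,200 multiply-accumulates, i.e. 210,830,400 FLOPs under the convention above.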
Takeaways
Writing Steelix was a good learning experience, but there's a huge number of improvements to be made. Currently, Steelix creates heaps of unnecessary zeroed tensors as some operators require runtime information (i.e. the outputs are computed from the data of a previous operation, not solely from metadata). It should be possible to identify which subgraphs require runtime information and which don't, and only load tensors for those subgraphs. Perhaps I will take this further in the future.
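One possible way to find those subgraphs is a backward walk from every input whose data (rather than just its shape) is read during shape inference. The sketch below is only an idea under stated assumptions, not something Steelix does today: the `GraphNode` type, its fields and the `nodes_needing_data` function are all hypothetical, and nodes are referred to by index.

```rust
// Hypothetical node description, for illustration only.
struct GraphNode {
    /// Indices of the nodes producing this node's inputs.
    inputs: Vec<usize>,
    /// Subset of producers whose *data* this node reads to determine its
    /// output shape (e.g. the shape input of a Reshape).
    value_inputs: Vec<usize>,
    /// True if this node's output data depends only on input metadata
    /// (e.g. a Shape operator), so the walk can stop here.
    value_from_metadata: bool,
}

/// Returns, per node, whether its output tensor must actually be materialized.
fn nodes_needing_data(nodes: &[GraphNode]) -> Vec<bool> {
    let mut needed = vec![false; nodes.len()];

    // Seed the walk with every producer whose data some operator consumes.
    let mut stack: Vec<usize> = nodes
        .iter()
        .flat_map(|n| n.value_inputs.iter().copied())
        .collect();

    while let Some(i) = stack.pop() {
        if needed[i] {
            continue;
        }
        needed[i] = true;
        // To compute this node's data we need its inputs' data too,
        // unless the data can be derived from metadata alone.
        if !nodes[i].value_from_metadata {
            stack.extend(nodes[i].inputs.iter().copied());
        }
    }
    needed
}
```

For a chain like Conv -> Shape -> Gather feeding the shape input of a Reshape that also consumes the Conv output, only the Shape and Gather outputs get marked: the Conv output never has to be allocated, because Shape can answer from metadata alone.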