Added the ability to train PyTorch models (#706)
* Added in PyTorch and PyTorch Lightning to train a DC model Successfully able to train a ResNet18-based model using PyTorch Lightning. * Removed hard-coded max number of epochs (used for debugging) * Added an inference transform to ResNet18 to convert PIL -> tensor * Unsqueezed input tensor during inference for batch dimension * Reshaped ResNet output from (1, 2) to (2,) * Added the ability to resume training from a checkpoint * Added helper print message when tensorboard logging is enabled * Updated docopt arguments for train.py. Made checkpoint optional * Changed TorchTubDataset from sub-classing Dataset to IterableDataset This was done in response to https://github.com/autorope/donkeycar/pull/706#discussion_r548137252 * Renamed load_image_arr to load_image. Updated load_pil_image load_pil_image will now handle converting the image to greyscale (vs. this being done in load_image). * Updated enviroments for Mac and Ubuntu. Set Python=3.7 * Updated installation documentation. Added script to setup Nano Updated the installation instructions for Ubuntu, Mac, and Windows. Clarified a common issue that occurs when running pip install -e .[pc] with ZSH. Also added a script to setup the Jetson Nano and updated the documentation for the Nano (it previously was installing tensorflow 1.x). * Added torch flag to setup.py to install pytorch * Moved pytorch training into base.py and removed from train.py * Moved Jetson Nano python package installation into requirements.txt * Formatted with PEP8 to clean up pytorch code * Updated docs to provide work-around for ZSH pip install -e .[pc] * Removed duplicate dependencies in conda env files * ResNet18 torch model now returns training loss history * Added test file for PyTorch training Still need to make sure this passes Travis CI. * Added lightning_logs to .gitignore * You can now specify the default AI framework to use in config.py This reduces the number of command line arguments you are required to provide. * get_model_by_type for PyTorch now lazy imports ResNet18 * Added help message to torch_train. Got rid of linear model type * Updated pytorch tests and fixed some syntax errors * ResNet18 example input shape updated to be (B, 3, 224, 224) Also now passing output_shape to load_resnet18 to modify how many output classes are used * No longer pinning requirement versions for Jetson Nano * Fixed formatting in setup.py
Please register or sign in to comment