Unverified Commit 94e9d604 authored by Eric Wiener's avatar Eric Wiener Committed by GitHub
Browse files

Added the ability to train PyTorch models (#706)

* Added in PyTorch and PyTorch Lightning to train a DC model

Successfully able to train a ResNet18-based model using PyTorch
Lightning.

* Removed hard-coded max number of epochs (used for debugging)

* Added an inference transform to ResNet18 to convert PIL -> tensor

* Unsqueezed input tensor during inference for batch dimension

* Reshaped ResNet output from (1, 2) to (2,)

* Added the ability to resume training from a checkpoint

* Added helper print message when tensorboard logging is enabled

* Updated docopt arguments for train.py. Made checkpoint optional

* Changed TorchTubDataset from sub-classing Dataset to IterableDataset

This was done in response to https://github.com/autorope/donkeycar/pull/706#discussion_r548137252

* Renamed load_image_arr to load_image. Updated load_pil_image

load_pil_image will now handle converting the image to greyscale
(vs. this being done in load_image).

* Updated enviroments for Mac and Ubuntu. Set Python=3.7

* Updated installation documentation. Added script to setup Nano

Updated the installation instructions for Ubuntu, Mac, and Windows.
Clarified a common issue that occurs when running pip install -e .[pc]
with ZSH.

Also added a script to setup the Jetson Nano and updated the documentation
for the Nano (it previously was installing tensorflow 1.x).

* Added torch flag to setup.py to install pytorch

* Moved pytorch training into base.py and removed from train.py

* Moved Jetson Nano python package installation into requirements.txt

* Formatted with PEP8 to clean up pytorch code

* Updated docs to provide work-around for ZSH pip install -e .[pc]

* Removed duplicate dependencies in conda env files

* ResNet18 torch model now returns training loss history

* Added test file for PyTorch training

Still need to make sure this passes Travis CI.

* Added lightning_logs to .gitignore

* You can now specify the default AI framework to use in config.py

This reduces the number of command line arguments you are required
to provide.

* get_model_by_type for PyTorch now lazy imports ResNet18

* Added help message to torch_train. Got rid of linear model type

* Updated pytorch tests and fixed some syntax errors

* ResNet18 example input shape updated to be (B, 3, 224, 224)

Also now passing output_shape to load_resnet18 to modify how many
output classes are used

* No longer pinning requirement versions for Jetson Nano

* Fixed formatting in setup.py
parent 4ded8750
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment