GoPt is a small library for loading PyTorch image classification models into Go code.
Dependencies:
- Libtorch C++ v1.11.0 library of PyTorch
- Default CUDA version is 11.3 if CUDA is available; otherwise the CPU version is used.
- Default PyTorch C++ API version is 1.11.0
NOTE: libtorch will be installed at /usr/local/lib
wget https://raw.githubusercontent.com/sugarme/gotch/master/setup-libtorch.sh
chmod +x setup-libtorch.sh
export CUDA_VER=cpu && bash setup-libtorch.sh

Update environment: on Debian/Ubuntu, add or update the following lines in your .bashrc file:
export GOTCH_LIBTORCH="/usr/local/lib/libtorch"
export LIBRARY_PATH="$LIBRARY_PATH:$GOTCH_LIBTORCH/lib"
export CPATH="$CPATH:$GOTCH_LIBTORCH/lib:$GOTCH_LIBTORCH/include:$GOTCH_LIBTORCH/include/torch/csrc/api/include"
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$GOTCH_LIBTORCH/lib"

Install gotch:
wget https://raw.githubusercontent.com/sugarme/gotch/master/setup-gotch.sh
chmod +x setup-gotch.sh
export CUDA_VER=cpu && export GOTCH_VER=v0.7.0 && bash setup-gotch.sh

Example usage (main.go):

package main
import (
	"flag"
	"fmt"
	"log"

	"github.com/jbloxsome/gopt"
)
// modelPath is the location of the exported PyTorch model, set by -modelpath.
var modelPath string

// imageFile is the location of the image to classify, set by -image.
var imageFile string

// init registers the command-line flags that populate modelPath and
// imageFile. Each flag.StringVar call also stores the default value into its
// target immediately, before flag.Parse runs.
func init() {
	type stringFlag struct {
		target *string
		name   string
		value  string
		usage  string
	}

	for _, f := range []stringFlag{
		{&modelPath, "modelpath", "./model.pt", "full path to exported pytorch model."},
		{&imageFile, "image", "./image.jpg", "full path to image file."},
	} {
		flag.StringVar(f.target, f.name, f.value, f.usage)
	}
}
// main parses the command-line flags, loads the model at -modelpath with a
// fixed label set, classifies the image at -image, and prints the prediction.
//
// Failing to load the model or to run the prediction is fatal: the error is
// logged and the process exits with a non-zero status rather than printing a
// zero-value prediction.
func main() {
	flag.Parse()

	// NOTE(review): presumably these are indexed by the model's output class
	// index — confirm the order against how the model was trained/exported.
	labels := []string{
		"false",
		"true",
	}

	model, err := gopt.NewGoPt(modelPath, labels)
	if err != nil {
		log.Fatalf("loading model %q: %v", modelPath, err)
	}

	pred, err := model.Predict(imageFile)
	if err != nil {
		// pred carries no meaningful value when err is non-nil, so stop here
		// instead of printing it and exiting successfully.
		log.Fatalf("predicting %q: %v", imageFile, err)
	}
	fmt.Println(pred)
}