diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..2be9c5d --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +**/*.pt filter=lfs diff=lfs merge=lfs -text +**/*.onnx filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index 55ced25..8a30d25 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,398 @@ -__pycache__ -AU_Recognition/__pycache__ -AU_Recognition/models/__pycache__ -AU_Recognition/data -AU_Detection/__pycache__ -AU_Detection/models/__pycache__ -AU_Detection/data -Facial_Expression_Recognition/__pycache__ -Facial_Expression_Recognition/models/__pycache__ -Facial_Expression_Recognition/data +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. +## +## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.tlog +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage Tool +coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be 
unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# NuGet Symbol Packages +*.snupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. +!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) +*.vbw + +# Visual Studio 6 auto-generated project file (contains which files were open etc.) 
+*.vbp + +# Visual Studio 6 workspace and project file (working project files containing files to include in project) +*.dsw +*.dsp + +# Visual Studio 6 technical files +*.ncb +*.aps + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + +# Visual Studio History (VSHistory) files +.vshistory/ + +# BeatPulse healthcheck temp database +healthchecksdb + +# Backup folder for Package Reference Convert tool in Visual Studio 2017 +MigrationBackup/ + +# Ionide (cross platform F# VS Code tools) working folder +.ionide/ + +# Fody - auto-generated XML schema +FodyWeavers.xsd + +# VS Code files for those working on multiple tools +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace + +# Local History for Visual Studio Code +.history/ + +# Windows Installer files from build outputs +*.cab +*.msi +*.msix +*.msm +*.msp + +# JetBrains Rider +*.sln.iml diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..10daee9 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": true + } + ] +} \ No newline at end of file diff --git a/AU_Recognition/LibreFace_emotionnet_mae_20230627125708.onnx b/AU_Recognition/LibreFace_emotionnet_mae_20230627125708.onnx new file mode 100644 index 0000000..a0845c2 --- /dev/null +++ b/AU_Recognition/LibreFace_emotionnet_mae_20230627125708.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17dd5cf088a4ed0ef90c8cbcbcd9e7c7947792979f391a0d7bab29e7ca3de569 +size 423605715 diff --git a/AU_Recognition/LibreFace_resnet_20230627125656.onnx b/AU_Recognition/LibreFace_resnet_20230627125656.onnx new file mode 100644 index 0000000..bcb696e --- /dev/null +++ b/AU_Recognition/LibreFace_resnet_20230627125656.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c996c9ae83b67c293f05766b5fe78741e0891e25f5c2d0bd0ca6d4137790018 +size 45197696 diff --git a/AU_Recognition/LibreFace_swin_20230627125646.onnx b/AU_Recognition/LibreFace_swin_20230627125646.onnx new file mode 100644 index 0000000..30f9fd9 --- /dev/null +++ b/AU_Recognition/LibreFace_swin_20230627125646.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01bca9ff088d617d79f8d0480f925d1bbcc49fb0622ede07dabe136ae4b44e9c +size 194925271 diff --git a/AU_Recognition/export_to_onnx.ipynb b/AU_Recognition/export_to_onnx.ipynb new file mode 100644 index 0000000..7c00a54 --- /dev/null +++ b/AU_Recognition/export_to_onnx.ipynb @@ -0,0 +1,232 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from datetime import datetime\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from data import MyDataset\n", + "from models.resnet18 import ResNet18\n", + "from models.swin import SwinTransformer\n", + "from models.mae import MaskedAutoEncoder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_labels = 12\n", + "aus = [1,2,4,5,6,9,12,15,17,20,25,26]\n", + "batch_size = 512\n", + "num_workers = 1\n", + "train = False\n", + "device = \"cpu\"\n", + "data_root = \"../../LibreFace_TestData\"\n", + "data = \"DISFA\"\n", + "onnx_name = \"LibreFace\"\n", + "test_csv = os.path.join(data_root, data, \"labels_intensity_5\", \"all\", \"test.csv\")\n", + "dropout = 0.1\n", + "fm_distillation = False\n", + "hidden_dim = 128\n", + "\n", + "model_name = \"emotionnet_mae\"\n", + "\n", + "class SwinConfig:\n", + "\tdef __init__(self):\n", + "\t\tself.device = device\n", + "\t\tself.dropout = 0.1\n", + "\t\tself.num_labels = num_labels\n", + "\n", + "class AU2HeatmapConfig:\n", + "\tdef __init__(self):\n", + "\t\tself.data = data\n", + "\t\tself.sigma = 10.0\n", + "\t\tself.num_labels = num_labels\n", + "\n", + "class DatasetConfig(AU2HeatmapConfig):\n", + "\tdef __init__(self):\n", + "\t\tsuper().__init__()\n", + "\t\tself.data = data\n", + "\t\tself.data_root = data_root\n", + "\t\tself.image_size = 256\n", + "\t\tself.crop_size = 224\n", + "\n", + "class ResNet18Config:\n", + "\tdef __init__(self):\n", + "\t\tself.fm_distillation = fm_distillation\n", + "\t\tself.dropout = dropout\n", + "\t\tself.num_labels = num_labels\n", + "\t\n", + "class MaskedAutoEncoderConfig:\n", + "\tdef 
__init__(self):\n", + "\t\tself.fm_distillation = fm_distillation\n", + "\t\tself.dropout = dropout\n", + "\t\tself.num_labels = num_labels\n", + "\t\tself.hidden_dim = hidden_dim\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_config = DatasetConfig()\n", + "dataset = MyDataset(test_csv, train, dataset_config)\n", + "loader = DataLoader(\n", + "\tdataset=dataset,\n", + "\tbatch_size=batch_size,\n", + "\tnum_workers=num_workers,\n", + "\tshuffle=train,\n", + "\tcollate_fn=dataset.collate_fn,\n", + "\tdrop_last=train\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if model_name == \"resnet\":\n", + " model_config = ResNet18Config()\n", + " model = ResNet18(model_config)\n", + " ckpt_name = os.path.join(\"resnet_disfa_all\", data, \"all\", \"resnet.pt\")\n", + "elif model_name == \"swin\":\n", + " model_config = SwinConfig()\n", + " model = SwinTransformer(model_config)\n", + " ckpt_name = os.path.join(\"swin_checkpoint\", data, \"0\", \"swin.pt\")\n", + "elif model_name == \"emotionnet_mae\":\n", + " model_config = MaskedAutoEncoderConfig()\n", + " model = MaskedAutoEncoder(model_config)\n", + " ckpt_name = os.path.join(\"mae_checkpoint\", data, \"0\", \"emotionnet_mae.pt\")\n", + "else:\n", + " raise ValueError(f\"Unsupported model_name: {model_name}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "checkpoints = torch.load(ckpt_name, map_location=torch.device(device))[\"model\"]\n", + "model.load_state_dict(checkpoints, strict=True)\n", + "torch.set_grad_enabled(False) # a bare torch.no_grad() call is a no-op\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\"\"\"\n", + "for images, labels in loader:\n", + "\timages = images.to(device)\n", + "\tlabels = labels.to(device)\n", + "\tlabels_pred = model(images)\n", + "\tlabels_pred = torch.clamp(labels_pred, min=0.0, max=5.0)\n", + "\"\"\"\n", + "\n", + "dummy_input = torch.rand((1, 3, 224, 224), device=device)\n", + "input_names = [ \"image\" ]\n", + "output_names = [ \"AUs\" ]\n", + "onnx_name = \"{0}_{1}_{2}.onnx\".format(onnx_name, model_name, datetime.now().strftime(\"%Y%m%d%H%M%S\"))\n", + "\n", + "torch.onnx.export(\n", + "\tmodel, \n", + "\tdummy_input, \n", + "\tonnx_name, \n", + "\tverbose=True, \n", + "\tinput_names=input_names,\n", + "\toutput_names=output_names\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import onnx\n", + "model = onnx.load(onnx_name)\n", + "onnx.checker.check_model(model)\n", + "print(onnx.helper.printable_graph(model.graph))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import onnxruntime as ort\n", + "import numpy as np\n", + "\n", + "ort_session = ort.InferenceSession(onnx_name)\n", + "\n", + "image, label = next(iter(dataset))\n", + "image = image.unsqueeze(dim=0)\n", + "image = image.numpy()\n", + "\n", + "label_pred = ort_session.run(\n", + " None,\n", + " {\"image\": image},\n", + ")[0]\n", + "label_pred = np.squeeze(label_pred, axis=0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(num_labels):\n", + " gt = label[i]\n", + " pred = label_pred[i]\n", + " print(f\"AU{aus[i]}:\\tdiff={abs(pred - gt):.4f}\\tpred={pred:.4f}\\tgt={gt:.4f}\")\n" + ] + } + ], + 
"metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "4f76a724a0076d5c39752e12bec55adbbf3b081a4d622e794e00deba6a6ff878" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/AU_Recognition/mae_checkpoint/DISFA/0/emotionnet_mae.pt b/AU_Recognition/mae_checkpoint/DISFA/0/emotionnet_mae.pt new file mode 100644 index 0000000..c23867a --- /dev/null +++ b/AU_Recognition/mae_checkpoint/DISFA/0/emotionnet_mae.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:980a9ef30a0d46e1fac84e95f81adb6a9889656112caf106c17c072e2ce5b702 +size 450504439 diff --git a/AU_Recognition/models/mae.py b/AU_Recognition/models/mae.py index d3cecf3..4de3896 100644 --- a/AU_Recognition/models/mae.py +++ b/AU_Recognition/models/mae.py @@ -37,7 +37,7 @@ def forward(self, images): masks = [] for _ in range(B): masks.append(torch.Tensor(self.masked_position_generator()).repeat(F, 1)) - masks = torch.cat(masks, dim=0).to(torch.bool).cuda() + masks = torch.cat(masks, dim=0).to(torch.bool) features = self.encoder(images, masks) features = features.reshape(B, 196*768) diff --git a/AU_Recognition/models/modeling_pretrain.py b/AU_Recognition/models/modeling_pretrain.py index 6035819..11149b5 100644 --- a/AU_Recognition/models/modeling_pretrain.py +++ b/AU_Recognition/models/modeling_pretrain.py @@ -197,6 +197,8 @@ def __init__(self, use_learnable_pos_emb=False, num_classes=0, # avoid the error from create_fn in timm in_chans=0, # avoid the error from create_fn in timm + pretrained_cfg=None, # avoid the error from create_fn in timm + pretrained_cfg_overlay=None, # avoid the error from create_fn in timm ): super().__init__() self.encoder = PretrainVisionTransformerEncoder( diff --git a/AU_Recognition/resnet_disfa_all/DISFA/all/resnet.pt b/AU_Recognition/resnet_disfa_all/DISFA/all/resnet.pt index fd9be02..e9dc46f 100644 Binary files a/AU_Recognition/resnet_disfa_all/DISFA/all/resnet.pt and b/AU_Recognition/resnet_disfa_all/DISFA/all/resnet.pt differ diff --git a/AU_Recognition/swin_checkpoint/DISFA/0/swin.pt b/AU_Recognition/swin_checkpoint/DISFA/0/swin.pt new file mode 100644 index 0000000..5902bc0 --- /dev/null +++ b/AU_Recognition/swin_checkpoint/DISFA/0/swin.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:744457443743945644d4187f75120411a4bcf94792a627e2fc40831df2acc546 +size 191448499 diff --git a/Facial_Expression_Recognition/export_to_onnx_exportonly.ipynb b/Facial_Expression_Recognition/export_to_onnx_exportonly.ipynb new file mode 100644 index 0000000..d178de8 --- /dev/null +++ b/Facial_Expression_Recognition/export_to_onnx_exportonly.ipynb @@ -0,0 +1,358 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from datetime import datetime\n", + "import torch\n", + "from models.resnet18 import ResNet" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "num_labels = 8\n", + "#train = False\n", + "device = \"cpu\"\n", + "onnx_name = \"FacialExpression\"\n", + "dropout = 0.1\n", + "fm_distillation = True\n", + "\n", + "class ResNetConfig:\n", + "\tdef 
__init__(self):\n", + "\t\tself.fm_distillation = fm_distillation\n", + "\t\tself.dropout = dropout\n", + "\t\tself.num_labels = num_labels\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# ResNet18\n", + "model_config = ResNetConfig()\n", + "model = ResNet(model_config)\n", + "ckpt_name = os.path.join(\"checkpoints_fm_resnet\", \"AffectNet\", \"resnet.pt\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ResNet(\n", + " (encoder): Sequential(\n", + " (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", + " (4): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (5): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (6): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(128, 256, kernel_size=(1, 1), 
stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (7): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (8): AdaptiveAvgPool2d(output_size=(1, 1))\n", + " )\n", + " (classifier): Sequential(\n", + " (0): Linear(in_features=512, out_features=128, bias=True)\n", + " (1): ReLU()\n", + " (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Dropout(p=0.1, inplace=False)\n", + " (4): Linear(in_features=128, out_features=8, bias=True)\n", + " (5): Sigmoid()\n", + " )\n", + ")" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "checkpoints = torch.load(ckpt_name, map_location=torch.device(device))[\"model\"]\n", + "model.load_state_dict(checkpoints, strict=True)\n", + "torch.set_grad_enabled(False) # a bare torch.no_grad() call is a no-op\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "================ Diagnostic Run torch.onnx.export version 2.0.1 ================\n", + "verbose: False, log level: Level.ERROR\n", + "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n", + "\n" + ] + } + ], + "source": [ + "\n", + "\"\"\"\n", + "for images, labels in loader:\n", + "\timages = images.to(device)\n", + "\tlabels = labels.to(device)\n", + "\tlabels_pred = model(images)\n", + "\tlabels_pred = torch.clamp(labels_pred, min=0.0, max=5.0)\n", + "\"\"\"\n", + "\n", + "dummy_input = torch.rand((1, 3, 224, 224), device=device)\n", + "input_names = [ \"image\" ]\n", + "output_names = [ \"FEs\" ]\n", + "onnx_name = onnx_name + datetime.now().strftime(\"%Y%m%d%H%M%S\") + \".onnx\"\n", + "\n", + "torch.onnx.export(\n", + "\tmodel, \n", + "\tdummy_input, \n", + "\tonnx_name, \n", + "\tverbose=True, \n", + "\tinput_names=input_names,\n", + 
"\toutput_names=output_names\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "graph torch_jit (\n", + " %image[FLOAT, 1x3x224x224]\n", + ") initializers (\n", + " %classifier.0.weight[FLOAT, 128x512]\n", + " %classifier.0.bias[FLOAT, 128]\n", + " %classifier.2.weight[FLOAT, 128]\n", + " %classifier.2.bias[FLOAT, 128]\n", + " %classifier.2.running_mean[FLOAT, 128]\n", + " %classifier.2.running_var[FLOAT, 128]\n", + " %classifier.4.weight[FLOAT, 8x128]\n", + " %classifier.4.bias[FLOAT, 8]\n", + " %onnx::Conv_211[FLOAT, 64x3x7x7]\n", + " %onnx::Conv_212[FLOAT, 64]\n", + " %onnx::Conv_214[FLOAT, 64x64x3x3]\n", + " %onnx::Conv_215[FLOAT, 64]\n", + " %onnx::Conv_217[FLOAT, 64x64x3x3]\n", + " %onnx::Conv_218[FLOAT, 64]\n", + " %onnx::Conv_220[FLOAT, 64x64x3x3]\n", + " %onnx::Conv_221[FLOAT, 64]\n", + " %onnx::Conv_223[FLOAT, 64x64x3x3]\n", + " %onnx::Conv_224[FLOAT, 64]\n", + " %onnx::Conv_226[FLOAT, 128x64x3x3]\n", + " %onnx::Conv_227[FLOAT, 128]\n", + " %onnx::Conv_229[FLOAT, 128x128x3x3]\n", + " %onnx::Conv_230[FLOAT, 128]\n", + " %onnx::Conv_232[FLOAT, 128x64x1x1]\n", + " %onnx::Conv_233[FLOAT, 128]\n", + " %onnx::Conv_235[FLOAT, 128x128x3x3]\n", + " %onnx::Conv_236[FLOAT, 128]\n", + " %onnx::Conv_238[FLOAT, 128x128x3x3]\n", + " %onnx::Conv_239[FLOAT, 128]\n", + " %onnx::Conv_241[FLOAT, 256x128x3x3]\n", + " %onnx::Conv_242[FLOAT, 256]\n", + " %onnx::Conv_244[FLOAT, 256x256x3x3]\n", + " %onnx::Conv_245[FLOAT, 256]\n", + " %onnx::Conv_247[FLOAT, 256x128x1x1]\n", + " %onnx::Conv_248[FLOAT, 256]\n", + " %onnx::Conv_250[FLOAT, 256x256x3x3]\n", + " %onnx::Conv_251[FLOAT, 256]\n", + " %onnx::Conv_253[FLOAT, 256x256x3x3]\n", + " %onnx::Conv_254[FLOAT, 256]\n", + " %onnx::Conv_256[FLOAT, 512x256x3x3]\n", + " %onnx::Conv_257[FLOAT, 512]\n", + " %onnx::Conv_259[FLOAT, 512x512x3x3]\n", + " %onnx::Conv_260[FLOAT, 512]\n", + " %onnx::Conv_262[FLOAT, 512x256x1x1]\n", + " %onnx::Conv_263[FLOAT, 512]\n", + " %onnx::Conv_265[FLOAT, 512x512x3x3]\n", + " %onnx::Conv_266[FLOAT, 512]\n", + " %onnx::Conv_268[FLOAT, 512x512x3x3]\n", + " %onnx::Conv_269[FLOAT, 512]\n", + ") {\n", + " %/encoder/encoder.0/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]](%image, %onnx::Conv_211, %onnx::Conv_212)\n", + " %/encoder/encoder.2/Relu_output_0 = Relu(%/encoder/encoder.0/Conv_output_0)\n", + " %/encoder/encoder.3/MaxPool_output_0 = MaxPool[ceil_mode = 0, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%/encoder/encoder.2/Relu_output_0)\n", + " %/encoder/encoder.4/encoder.4.0/conv1/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.3/MaxPool_output_0, %onnx::Conv_214, %onnx::Conv_215)\n", + " %/encoder/encoder.4/encoder.4.0/relu/Relu_output_0 = Relu(%/encoder/encoder.4/encoder.4.0/conv1/Conv_output_0)\n", + " %/encoder/encoder.4/encoder.4.0/conv2/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.4/encoder.4.0/relu/Relu_output_0, %onnx::Conv_217, %onnx::Conv_218)\n", + " %/encoder/encoder.4/encoder.4.0/Add_output_0 = Add(%/encoder/encoder.4/encoder.4.0/conv2/Conv_output_0, %/encoder/encoder.3/MaxPool_output_0)\n", + " %/encoder/encoder.4/encoder.4.0/relu_1/Relu_output_0 = Relu(%/encoder/encoder.4/encoder.4.0/Add_output_0)\n", + " 
%/encoder/encoder.4/encoder.4.1/conv1/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.4/encoder.4.0/relu_1/Relu_output_0, %onnx::Conv_220, %onnx::Conv_221)\n", + " %/encoder/encoder.4/encoder.4.1/relu/Relu_output_0 = Relu(%/encoder/encoder.4/encoder.4.1/conv1/Conv_output_0)\n", + " %/encoder/encoder.4/encoder.4.1/conv2/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.4/encoder.4.1/relu/Relu_output_0, %onnx::Conv_223, %onnx::Conv_224)\n", + " %/encoder/encoder.4/encoder.4.1/Add_output_0 = Add(%/encoder/encoder.4/encoder.4.1/conv2/Conv_output_0, %/encoder/encoder.4/encoder.4.0/relu_1/Relu_output_0)\n", + " %/encoder/encoder.4/encoder.4.1/relu_1/Relu_output_0 = Relu(%/encoder/encoder.4/encoder.4.1/Add_output_0)\n", + " %/encoder/encoder.5/encoder.5.0/conv1/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%/encoder/encoder.4/encoder.4.1/relu_1/Relu_output_0, %onnx::Conv_226, %onnx::Conv_227)\n", + " %/encoder/encoder.5/encoder.5.0/relu/Relu_output_0 = Relu(%/encoder/encoder.5/encoder.5.0/conv1/Conv_output_0)\n", + " %/encoder/encoder.5/encoder.5.0/conv2/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.5/encoder.5.0/relu/Relu_output_0, %onnx::Conv_229, %onnx::Conv_230)\n", + " %/encoder/encoder.5/encoder.5.0/downsample/downsample.0/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [2, 2]](%/encoder/encoder.4/encoder.4.1/relu_1/Relu_output_0, %onnx::Conv_232, %onnx::Conv_233)\n", + " %/encoder/encoder.5/encoder.5.0/Add_output_0 = Add(%/encoder/encoder.5/encoder.5.0/conv2/Conv_output_0, %/encoder/encoder.5/encoder.5.0/downsample/downsample.0/Conv_output_0)\n", + " %/encoder/encoder.5/encoder.5.0/relu_1/Relu_output_0 = Relu(%/encoder/encoder.5/encoder.5.0/Add_output_0)\n", + " %/encoder/encoder.5/encoder.5.1/conv1/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.5/encoder.5.0/relu_1/Relu_output_0, %onnx::Conv_235, %onnx::Conv_236)\n", + " %/encoder/encoder.5/encoder.5.1/relu/Relu_output_0 = Relu(%/encoder/encoder.5/encoder.5.1/conv1/Conv_output_0)\n", + " %/encoder/encoder.5/encoder.5.1/conv2/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.5/encoder.5.1/relu/Relu_output_0, %onnx::Conv_238, %onnx::Conv_239)\n", + " %/encoder/encoder.5/encoder.5.1/Add_output_0 = Add(%/encoder/encoder.5/encoder.5.1/conv2/Conv_output_0, %/encoder/encoder.5/encoder.5.0/relu_1/Relu_output_0)\n", + " %/encoder/encoder.5/encoder.5.1/relu_1/Relu_output_0 = Relu(%/encoder/encoder.5/encoder.5.1/Add_output_0)\n", + " %/encoder/encoder.6/encoder.6.0/conv1/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%/encoder/encoder.5/encoder.5.1/relu_1/Relu_output_0, %onnx::Conv_241, %onnx::Conv_242)\n", + " %/encoder/encoder.6/encoder.6.0/relu/Relu_output_0 = Relu(%/encoder/encoder.6/encoder.6.0/conv1/Conv_output_0)\n", + " %/encoder/encoder.6/encoder.6.0/conv2/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.6/encoder.6.0/relu/Relu_output_0, %onnx::Conv_244, 
%onnx::Conv_245)\n", + " %/encoder/encoder.6/encoder.6.0/downsample/downsample.0/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [2, 2]](%/encoder/encoder.5/encoder.5.1/relu_1/Relu_output_0, %onnx::Conv_247, %onnx::Conv_248)\n", + " %/encoder/encoder.6/encoder.6.0/Add_output_0 = Add(%/encoder/encoder.6/encoder.6.0/conv2/Conv_output_0, %/encoder/encoder.6/encoder.6.0/downsample/downsample.0/Conv_output_0)\n", + " %/encoder/encoder.6/encoder.6.0/relu_1/Relu_output_0 = Relu(%/encoder/encoder.6/encoder.6.0/Add_output_0)\n", + " %/encoder/encoder.6/encoder.6.1/conv1/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.6/encoder.6.0/relu_1/Relu_output_0, %onnx::Conv_250, %onnx::Conv_251)\n", + " %/encoder/encoder.6/encoder.6.1/relu/Relu_output_0 = Relu(%/encoder/encoder.6/encoder.6.1/conv1/Conv_output_0)\n", + " %/encoder/encoder.6/encoder.6.1/conv2/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.6/encoder.6.1/relu/Relu_output_0, %onnx::Conv_253, %onnx::Conv_254)\n", + " %/encoder/encoder.6/encoder.6.1/Add_output_0 = Add(%/encoder/encoder.6/encoder.6.1/conv2/Conv_output_0, %/encoder/encoder.6/encoder.6.0/relu_1/Relu_output_0)\n", + " %/encoder/encoder.6/encoder.6.1/relu_1/Relu_output_0 = Relu(%/encoder/encoder.6/encoder.6.1/Add_output_0)\n", + " %/encoder/encoder.7/encoder.7.0/conv1/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%/encoder/encoder.6/encoder.6.1/relu_1/Relu_output_0, %onnx::Conv_256, %onnx::Conv_257)\n", + " %/encoder/encoder.7/encoder.7.0/relu/Relu_output_0 = Relu(%/encoder/encoder.7/encoder.7.0/conv1/Conv_output_0)\n", + " %/encoder/encoder.7/encoder.7.0/conv2/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.7/encoder.7.0/relu/Relu_output_0, %onnx::Conv_259, %onnx::Conv_260)\n", + " %/encoder/encoder.7/encoder.7.0/downsample/downsample.0/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [2, 2]](%/encoder/encoder.6/encoder.6.1/relu_1/Relu_output_0, %onnx::Conv_262, %onnx::Conv_263)\n", + " %/encoder/encoder.7/encoder.7.0/Add_output_0 = Add(%/encoder/encoder.7/encoder.7.0/conv2/Conv_output_0, %/encoder/encoder.7/encoder.7.0/downsample/downsample.0/Conv_output_0)\n", + " %/encoder/encoder.7/encoder.7.0/relu_1/Relu_output_0 = Relu(%/encoder/encoder.7/encoder.7.0/Add_output_0)\n", + " %/encoder/encoder.7/encoder.7.1/conv1/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.7/encoder.7.0/relu_1/Relu_output_0, %onnx::Conv_265, %onnx::Conv_266)\n", + " %/encoder/encoder.7/encoder.7.1/relu/Relu_output_0 = Relu(%/encoder/encoder.7/encoder.7.1/conv1/Conv_output_0)\n", + " %/encoder/encoder.7/encoder.7.1/conv2/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%/encoder/encoder.7/encoder.7.1/relu/Relu_output_0, %onnx::Conv_268, %onnx::Conv_269)\n", + " %/encoder/encoder.7/encoder.7.1/Add_output_0 = Add(%/encoder/encoder.7/encoder.7.1/conv2/Conv_output_0, %/encoder/encoder.7/encoder.7.0/relu_1/Relu_output_0)\n", + " %/encoder/encoder.7/encoder.7.1/relu_1/Relu_output_0 = Relu(%/encoder/encoder.7/encoder.7.1/Add_output_0)\n", + " 
%/encoder/encoder.8/GlobalAveragePool_output_0 = GlobalAveragePool(%/encoder/encoder.7/encoder.7.1/relu_1/Relu_output_0)\n", + " %/Constant_output_0 = Constant[value = ]()\n", + " %onnx::Gemm_204 = Reshape[allowzero = 0](%/encoder/encoder.8/GlobalAveragePool_output_0, %/Constant_output_0)\n", + " %/classifier/classifier.0/Gemm_output_0 = Gemm[alpha = 1, beta = 1, transB = 1](%onnx::Gemm_204, %classifier.0.weight, %classifier.0.bias)\n", + " %/classifier/classifier.1/Relu_output_0 = Relu(%/classifier/classifier.0/Gemm_output_0)\n", + " %/classifier/classifier.2/BatchNormalization_output_0 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142, training_mode = 0](%/classifier/classifier.1/Relu_output_0, %classifier.2.weight, %classifier.2.bias, %classifier.2.running_mean, %classifier.2.running_var)\n", + " %/classifier/classifier.4/Gemm_output_0 = Gemm[alpha = 1, beta = 1, transB = 1](%/classifier/classifier.2/BatchNormalization_output_0, %classifier.4.weight, %classifier.4.bias)\n", + " %FEs = Sigmoid(%/classifier/classifier.4/Gemm_output_0)\n", + " return %FEs, %onnx::Gemm_204\n", + "}\n" + ] + } + ], + "source": [ + "import onnx\n", + "model = onnx.load(onnx_name)\n", + "onnx.checker.check_model(model)\n", + "print(onnx.helper.printable_graph(model.graph))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "4f76a724a0076d5c39752e12bec55adbbf3b081a4d622e794e00deba6a6ff878" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Facial_Expression_Recognition/inference.py b/Facial_Expression_Recognition/inference.py index baac0a2..181d147 100644 --- a/Facial_Expression_Recognition/inference.py +++ b/Facial_Expression_Recognition/inference.py @@ -50,7 +50,7 @@ # Fix random seed set_seed(opts.seed) -train_loader, test_loader = get_data_loaders(opts) +train_loader, test_loader = (None, None) # get_data_loaders(opts) # Setup solver diff --git a/Zongjian's Notes.md b/Zongjian's Notes.md new file mode 100644 index 0000000..f0bb251 --- /dev/null +++ b/Zongjian's Notes.md @@ -0,0 +1,7 @@ +# Python Environment + +`conda create --name Face python=3.11` +`conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia` +Download the mediapipe wheel from `https://pypi.org/project/mediapipe/#files` and install it with `pip install <path to the downloaded .whl file>` +`conda install scipy tqdm jupyter pandas` +`pip install dlib imutils timm onnx onnxruntime` \ No newline at end of file diff --git a/detect_mediapipe.py b/detect_mediapipe.py index d253a1a..b10df41 100644 --- a/detect_mediapipe.py +++ b/detect_mediapipe.py @@ -109,10 +109,10 @@ def image_align(img, face_landmarks, output_size=256, -image_root = '/home/ICT2000/dchang/DISFA_Data/DISFA/images/' -aligned_image_root = '/home/ICT2000/dchang/DISFA_Data/DISFA/aligned_images_new/' -landmark_root = '/home/ICT2000/dchang/DISFA_Data/DISFA/landmark/' -annotated_image_root = '/home/ICT2000/dchang/DISFA_Data/DISFA/detect_images/' +image_root = 'data/DISFA/images/' +aligned_image_root = 'data/DISFA/aligned_images_new/' +landmark_root = 'data/DISFA/landmark/' +annotated_image_root = 'data/DISFA/detect_images/' for folder in os.listdir(image_root): 
os.makedirs(os.path.join(annotated_image_root,folder),exist_ok=True) os.makedirs(os.path.join(aligned_image_root,folder),exist_ok=True)
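
# Sanity-Checking the Exported Models

A minimal sketch, distilled from the export notebooks above, for smoke-testing any of the committed `.onnx` graphs outside Jupyter. It relies only on calls the notebooks already use (`onnx.checker.check_model`, `ort.InferenceSession`); the path below points at one of the models added in this change and assumes the script is run from the repository root, and the random tensor merely exercises the graph, so the printed result is a shape check rather than a meaningful prediction.

```python
import numpy as np
import onnx
import onnxruntime as ort

# One of the exported models committed in this change; the other
# LibreFace_*.onnx files work the same way.
onnx_path = "AU_Recognition/LibreFace_resnet_20230627125656.onnx"

# Structural validation, mirroring the notebooks' onnx.checker cell.
onnx.checker.check_model(onnx.load(onnx_path))

# The exporters fix the input name to "image" and trace a 1x3x224x224 input,
# so the session expects exactly that shape.
session = ort.InferenceSession(onnx_path)
dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {"image": dummy})
print([o.shape for o in outputs])  # e.g. [(1, 12)] for the 12 DISFA AUs
```

Because both notebooks trace a single fixed-size dummy input without passing `dynamic_axes` to `torch.onnx.export`, the batch dimension is baked into the graphs; re-exporting with `dynamic_axes={"image": {0: "batch"}}` would relax that if batched inference is ever needed. The check also runs entirely on CPU, which is what the `models/mae.py` change above (dropping the hard-coded `.cuda()` on the mask tensor) makes possible for the MAE export.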