APIs
MMSA_run is the main function of this project. It runs multimodal sentiment analysis (MSA) experiments on the datasets and models specified by its parameters.
Definition:
def MMSA_run(
    model_name: str, dataset_name: str, config_file: str = "",
    config: dict = None, seeds: list = [], is_tune: bool = False,
    tune_times: int = 50, feature_T: str = "", feature_A: str = "",
    feature_V: str = "", model_save_dir: str = "", res_save_dir: str = "",
    log_dir: str = "", gpu_ids: list = [0], num_workers: int = 4,
    verbose_level: int = 1
)
Args:
- model_name (required): Name of the MSA model. See Supported Models for details.
- dataset_name (required): Name of the MSA dataset. See Supported Datasets for details.
- config_file: Path to the config file. The default config files are used if not specified. See Config Files for details.
- config: Config in the format of a Python dict. Used to override arguments in config_file. Ignored in tune mode.
- seeds: List of seeds. Default: [1111, 1112, 1113, 1114, 1115] (used when the list is left empty)
- is_tune: Tuning mode switch. See Tuning Mode for details. Default: False
- tune_times: Number of hyperparameter sets to try in tuning mode. Default: 50
- feature_T: Path to the text feature file. Provide an empty string to use the default BERT features. Default: ""
- feature_A: Path to the audio feature file. Provide an empty string to use the default features provided by the dataset creators. Default: ""
- feature_V: Path to the video feature file. Provide an empty string to use the default features provided by the dataset creators. Default: ""
- model_save_dir: Path to save trained models. Default: ~/MMSA/saved_models
- res_save_dir: Path to save csv results. Default: ~/MMSA/results
- log_dir: Path to save log files. Default: ~/MMSA/logs
- gpu_ids: GPUs to use. If an empty list is provided, the GPU with the most free memory is assigned. Default: [0]. Currently only a single GPU is supported.
- num_workers: Number of workers used to load data. Default: 4
- verbose_level: Verbosity of stdout output: 0 for error, 1 for info, 2 for debug. Default: 1
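The arguments above can be gathered into one keyword dictionary before the call, which keeps the output directories of an experiment together. The following is a minimal sketch and not part of MMSA: `build_run_kwargs` and the `out_root` directory layout are hypothetical names chosen for illustration, and only the keyword names are taken from the definition above.

```python
from pathlib import Path


def build_run_kwargs(model_name, dataset_name, out_root, seeds=None, gpu_ids=None):
    """Collect MMSA_run keyword arguments for one experiment (hypothetical helper)."""
    root = Path(out_root)
    return {
        "model_name": model_name,
        "dataset_name": dataset_name,
        # An empty list leaves the choice of seeds to MMSA_run's documented default.
        "seeds": list(seeds) if seeds else [],
        # Assumed layout: models, results and logs live under one root directory.
        "model_save_dir": str(root / "saved_models"),
        "res_save_dir": str(root / "results"),
        "log_dir": str(root / "logs"),
        # Only a single GPU is supported, so at most one id is passed on.
        # An explicitly empty list is kept, which asks for automatic selection.
        "gpu_ids": [0] if gpu_ids is None else list(gpu_ids)[:1],
        "num_workers": 4,
        "verbose_level": 1,
    }


kwargs = build_run_kwargs("lmf", "mosi", "./runs/lmf_mosi", seeds=[1111, 1112])
print(kwargs)
# With MMSA installed, the call would then be:
#   from MMSA import MMSA_run
#   MMSA_run(**kwargs)
```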
Example Usage:
from MMSA import MMSA_run
# run lmf on mosi with default params
MMSA_run('lmf', 'mosi')
# tune mult on mosi with default param ranges
MMSA_run('mult', 'mosi', is_tune=True, seeds=[1111])

get_config_regression retrieves the config dict of the given dataset and model from a config file, so that it can be viewed or altered easily.
Definition:
def get_config_regression(
    model_name: str, dataset_name: str, config_file: str = ""
) -> dict:
Args:
- model_name (required): Name of model.
- dataset_name (required): Name of dataset.
- config_file: Path to config file. If an empty string is given, the default config file is used. Default: ""
Returns:
Config of the given dataset and model in the format of Python dict.
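Because the returned config behaves like an ordinary dict, assigning to a misspelled key adds a new entry instead of changing an existing one. A minimal sketch of a guard against that is shown below. `apply_overrides` is a hypothetical helper, not an MMSA function, and it assumes the parameters to change are top-level keys of the config. The stand-in dict uses placeholder values rather than MMSA's defaults.

```python
def apply_overrides(config, overrides):
    """Return a copy of config with overrides applied, rejecting unknown keys."""
    unknown = sorted(set(overrides) - set(config))
    if unknown:
        raise KeyError(f"unknown config keys: {unknown}")
    updated = dict(config)  # shallow copy, the original stays untouched
    updated.update(overrides)
    return updated


# Stand-in for the dict returned by get_config_regression('mult', 'sims');
# the values are placeholders, not MMSA's defaults.
base_config = {"nlevels": 2, "conv1d_kernel_size_l": 1}

new_config = apply_overrides(base_config, {"nlevels": 4})
print(new_config)

try:
    apply_overrides(base_config, {"nlevel": 4})  # misspelled key
except KeyError as err:
    print("rejected:", err)
```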
Example Usage:
from MMSA import MMSA_run, get_config_regression
# get default config of mult on sims
config = get_config_regression('mult', 'sims')
# alter the default config
config['nlevels'] = 4
config['conv1d_kernel_size_l'] = 3
config['conv1d_kernel_size_a'] = 3
config['conv1d_kernel_size_v'] = 3
# check config
print(config)
# run with altered config
MMSA_run('mult', 'sims', config=config)

get_config_tune gets the tuning config of the given dataset and model from a config file. If random_choice is True, the returned config contains only one set of randomly selected parameters.
Definition:
def get_config_tune(
    model_name: str, dataset_name: str, config_file: str = "",
    random_choice: bool = True
) -> dict
Args:
- model_name (required): Name of model.
- dataset_name (required): Name of dataset.
- config_file: Path to config file. If an empty string is given, the default tuning config file is used. Default: ""
- random_choice: If True, one value is randomly selected for each of the tuning parameters. Default: True
Returns:
Config of the given dataset and model in the format of Python dict.
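The effect of random_choice can be pictured as drawing one value per parameter from a set of candidates. The sketch below illustrates the idea only and is not MMSA's implementation. It assumes the tuning ranges map each parameter name to a list of candidate values, and the helper `pick_one_set` and the parameter names are hypothetical.

```python
import random


def pick_one_set(tuning_ranges, rng=None):
    """Pick one candidate value per parameter from a dict of candidate lists."""
    rng = rng or random.Random()
    return {name: rng.choice(values) for name, values in tuning_ranges.items()}


# Hypothetical tuning ranges; parameter names and values are placeholders.
tuning_ranges = {
    "learning_rate": [1e-4, 5e-4, 1e-3],
    "batch_size": [16, 32, 64],
}

# A seeded generator makes the draw repeatable.
chosen = pick_one_set(tuning_ranges, rng=random.Random(0))
print(chosen)
```

Repeating such a draw several times yields several candidate parameter sets, which matches the role the tune_times argument of MMSA_run is described as playing.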
Example Usage:
from MMSA import get_config_tune, MMSA_run
# get default tuning parameters of misa on mosi
config = get_config_tune('misa', 'mosi', random_choice=False)
print(config)
# random select one set of default tuning parameters and run
config = get_config_tune('misa', 'mosi')
MMSA_run('misa', 'mosi', config=config)

get_config_all gets all default configs. It is used to export the default config file.
Definition:
def get_config_all(config_file: str) -> dict
Args:
- config_file: "regression" or "tune".
Returns:
Default regression or tuning config for all models in the format of Python dict.
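An exported config is plain JSON, so it can be written, inspected and read back with the standard library. The sketch below shows such a round trip on a stand-in dict. The keys and values are placeholders and do not reflect the real structure of MMSA's config, and a temporary directory is used so that nothing is left on disk.

```python
import json
import tempfile
from pathlib import Path

# Stand-in for the dict returned by get_config_all("regression");
# the keys and values below are placeholders, not MMSA's real config.
config = {"example_model": {"example_param": 1}}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "config.json"
    # indent=2 keeps the exported file readable and easy to diff.
    path.write_text(json.dumps(config, indent=2), encoding="utf-8")
    reloaded = json.loads(path.read_text(encoding="utf-8"))

# The reloaded dict should be equal to the one that was written.
print(reloaded == config)
```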
Example Usage:
import json
from MMSA import get_config_all
# get default regression config and dump to file
config = get_config_all("regression")
with open("./config.json", "w") as f:
    json.dump(config, f)

get_citations gets paper titles and citations for models and datasets.
Definition:
def get_citations() -> dict
Returns:
cites: {
    models: {
        tfn: {
            title: "xxx",
            paper_url: "xxx",
            citation: "xxx",
            description: "xxx"
        },
        ...
    },
    datasets: {
        ...
    },
}
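Given the nesting shown above, a single entry is reached through two lookups: first the group (models or datasets), then the name. The sketch below formats one entry from a stand-in dict that copies the documented shape. `describe` is a hypothetical helper, and all values are the "xxx" placeholders from the structure above rather than real citation data.

```python
def describe(cites, kind, name):
    """Format one entry of a citations dict shaped like the structure above."""
    entry = cites[kind][name]
    return f"{entry['title']} ({entry['paper_url']})"


# Stand-in for the dict returned by get_citations(); values are placeholders.
cites = {
    "models": {
        "tfn": {
            "title": "xxx",
            "paper_url": "xxx",
            "citation": "xxx",
            "description": "xxx",
        },
    },
    # Dataset entries are elided in the documentation above, so none are made up here.
    "datasets": {},
}

print(describe(cites, "models", "tfn"))
print(sorted(cites["models"]))
```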
Example Usage:
from MMSA import get_citations
# get citations of MOSI dataset
citations = get_citations()
print(citations["datasets"]["mosi"])