To prepare RLtools for mixed-precision training, we introduce a type policy, `numeric_types::Policy`. Until now, switching the floating-point type in RLtools was easy because everything depended on the single type parameter `T`. For modern deep learning this is not sufficient: we would like to configure different types for different parts of the models and algorithms (e.g. bf16 for parameters, fp32 for gradients and optimizer state), and a single type parameter cannot express that.
Hence we created `numeric_types::Policy` to enable flexible type configuration:

```cpp
using namespace rlt::numeric_types;
using PARAMETER_TYPE_RULE = UseCase<categories::Parameter, float>;
using GRADIENT_TYPE_RULE = UseCase<categories::Gradient, float>;
using TYPE_POLICY = Policy<double, PARAMETER_TYPE_RULE, GRADIENT_TYPE_RULE>;
```

The `TYPE_POLICY` is then passed instead of `T`, e.g.:

```cpp
using MODEL_CONFIG = rlt::nn_models::mlp::Configuration<TYPE_POLICY, TI, OUTPUT_DIM, NUM_LAYERS, HIDDEN_DIM, ACTIVATION_FUNCTION, ACTIVATION_FUNCTION_OUTPUT>;
```
In the codebase, the `TYPE_POLICY` is then queried as follows:

```cpp
// `typename` is required when TYPE_POLICY is a template parameter
using PARAMETER_TYPE = typename TYPE_POLICY::template GET<categories::Parameter>;
using GRADIENT_TYPE = typename TYPE_POLICY::template GET<categories::Gradient>;
using OPTIMIZER_TYPE = typename TYPE_POLICY::template GET<categories::Optimizer>;
```
This allows for very flexible configuration. If a tag is not set (like `categories::Optimizer` in this case), `GET` falls back to `TYPE_POLICY::DEFAULT`, which here is `double` (the first template argument of `Policy`). `TYPE_POLICY::DEFAULT` is also the type to use for configuration values and other variables that do not clearly fall under one of the categories. You can also easily define your own category tags. This will be covered in a dedicated section of the documentation at https://docs.rl.tools in the future.
This is a small API change, but it touches many places, so we implemented it as early as possible (even though mixed-precision training itself is not implemented yet). Making the change now should cause less confusion later, when we expect these APIs to be more stable.
For now, the advice is to create a single `using TYPE_POLICY = rlt::numeric_types::Policy<float>;` (or `Policy<double>`) and pass it everywhere. You might encounter errors when trying to access e.g. `SPEC::T`; you should be able to replace it with `SPEC::TYPE_POLICY::DEFAULT` for identical behavior. In general, behavior should be exactly identical as long as you configure the same floating-point type you previously used for `T`.