PyTorch implementation of influence function methods for understanding how individual training samples affect model predictions. Includes the classic influence function (ICML 2017), TracIn (NeurIPS 2020), and EmpiricalIF (NeurIPS 2022) for fast, single-checkpoint influence estimation without computing an inverse Hessian.
- What is an Influence Function?
- How It Works
- Methods Included
- Installation
- Quick Start
- Usage Examples
- Citation
- Author & Contact
## What is an Influence Function?

The influence function (Understanding Black-box Predictions via Influence Functions, ICML 2017) answers:

> How much does a single training point affect the model's prediction or loss on a specific test point?

Instead of retraining after removing or perturbing a training sample, the influence function estimates the change in the model's prediction (or loss) on a test point using gradient and Hessian approximations, so no retraining is required.
## How It Works

For a trained model $\hat{\theta}$, the influence of up-weighting a training point $z$ on the loss at a test point $z_{\text{test}}$ is

$$\mathcal{I}_{\text{up,loss}}(z, z_{\text{test}}) = -\nabla_\theta L(z_{\text{test}}, \hat{\theta})^{\top} H_{\hat{\theta}}^{-1}\, \nabla_\theta L(z, \hat{\theta})$$

where $H_{\hat{\theta}}$ is the Hessian of the training loss at $\hat{\theta}$. Reading the sign:

- Positive influence → keeping this training point increases test loss → harmful for this test point
- Negative influence → keeping this point decreases test loss → helpful
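As a concrete illustration, the ICML 2017 score $-\nabla_\theta L(z_{\text{test}})^{\top} H^{-1} \nabla_\theta L(z_i)$ can be computed exactly on a model small enough to invert the Hessian. The sketch below uses a toy linear model with made-up data; all names are illustrative, not part of this repo's API:

```python
# Toy sketch of the classic influence score on a tiny linear model,
# where the Hessian is small enough to solve against exactly.
import torch

torch.manual_seed(0)
n, d = 8, 3
X, y = torch.randn(n, d), torch.randn(n)
w = torch.zeros(d, requires_grad=True)

# Train briefly with full-batch gradient descent.
for _ in range(200):
    loss = ((X @ w - y) ** 2).mean()
    g, = torch.autograd.grad(loss, w)
    with torch.no_grad():
        w -= 0.1 * g

def grad_loss(x, t):
    """Gradient of the squared loss on a single point w.r.t. w."""
    g, = torch.autograd.grad((x @ w - t) ** 2, w)
    return g

# Hessian of the mean squared-error training loss: (2/n) X^T X,
# plus a small damping term for numerical stability.
H = 2.0 / n * X.T @ X + 1e-3 * torch.eye(d)

x_test, y_test = torch.randn(d), torch.randn(())
g_test = grad_loss(x_test, y_test)

# score_i = -g_test^T H^{-1} g_i, one score per training point
scores = torch.stack([-g_test @ torch.linalg.solve(H, grad_loss(X[i], y[i]))
                      for i in range(n)])
print(scores)  # positive => keeping z_i increases test loss (harmful)
```

With only a handful of parameters the exact solve is cheap; the bottleneck discussed next appears when $H$ has millions of entries.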
## Methods Included

The main bottleneck is estimating the inverse Hessian: conjugate-gradient or damping-based approximations can be expensive and unstable. This repo provides two lightweight alternatives that avoid the full inverse Hessian:
| Paper | Method | Venue |
|---|---|---|
| Estimating Training Data Influence by Tracing Gradient Descent | TracIn | NeurIPS 2020 |
| Debugging and Explaining Metric Learning Approach: An Influence Function Perspective | EmpiricalIF | NeurIPS 2022 |
TracIn traces influence along the training trajectory: given saved checkpoints $\theta_t$ and learning rates $\eta_t$,

$$\text{TracIn}(z_i, z_{\text{test}}) = \sum_{t} \eta_t\, \nabla_\theta L(z_i, \theta_t) \cdot \nabla_\theta L(z_{\text{test}}, \theta_t)$$

With:

- Positive → $z_i$ helped reduce test loss
- Negative → $z_i$ hurt test performance (possibly harmful)

TracIn is first-order (no Hessian), but it needs multiple checkpoints for good estimates.
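The checkpoint sum behind TracIn can be sketched on a toy model. Everything below (model, data, checkpoint schedule) is illustrative, not this repo's API:

```python
# Minimal TracIn sketch: sum over saved checkpoints of
# lr * <grad of z_i's loss, grad of z_test's loss>.
import torch

torch.manual_seed(0)
d = 3
X, y = torch.randn(8, d), torch.randn(8)
x_test, y_test = torch.randn(d), torch.randn(())

w = torch.zeros(d, requires_grad=True)
lr, checkpoints = 0.1, []

for step in range(30):
    if step % 10 == 0:
        checkpoints.append(w.detach().clone())  # periodic checkpoint
    loss = ((X @ w - y) ** 2).mean()
    g, = torch.autograd.grad(loss, w)
    with torch.no_grad():
        w -= lr * g

def grad_at(theta, x, t):
    """Gradient of one sample's squared loss at a given checkpoint."""
    theta = theta.clone().requires_grad_(True)
    g, = torch.autograd.grad((x @ theta - t) ** 2, theta)
    return g

# TracIn(z_i, z_test) = sum_t eta_t * grad_i(theta_t) . grad_test(theta_t)
scores = torch.stack([
    sum(lr * grad_at(th, X[i], y[i]) @ grad_at(th, x_test, y_test)
        for th in checkpoints)
    for i in range(len(X))
])
print(scores)  # positive => z_i helped reduce test loss
```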
EmpiricalIF uses the final checkpoint only. It perturbs the model parameters along the steepest descent and ascent directions of the test loss and measures how each training sample's loss responds:

- Positive → $z_i$ and $z_{\text{test}}$ co-evolve → helpful
- Negative → $z_i$ conflicts with the test point → harmful

EmpiricalIF can be viewed as a single-checkpoint variant of TracIn: in practice, probing the steepest descent and ascent directions of the test loss is enough to compute it.
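A rough sketch of the single-checkpoint idea: step the parameters a small distance along the steepest descent and ascent directions of the test loss and compare each training sample's loss at the two perturbed points. This is a toy linear model under assumed details, not the repo's exact algorithm:

```python
# Single-checkpoint perturbation sketch of the EmpiricalIF idea.
import torch

torch.manual_seed(0)
d = 3
X, y = torch.randn(8, d), torch.randn(8)
x_test, y_test = torch.randn(d), torch.randn(())
w = torch.randn(d, requires_grad=True)  # stands in for a trained model

# Gradient of the test loss: -g_test is the steepest descent direction.
g_test, = torch.autograd.grad((x_test @ w - y_test) ** 2, w)

def per_sample_losses(theta):
    return (X @ theta - y) ** 2  # reduction="none": one loss per sample

eps = 1e-2
with torch.no_grad():
    down = per_sample_losses(w - eps * g_test)  # steepest descent on test loss
    up = per_sample_losses(w + eps * g_test)    # steepest ascent on test loss
    # >0: z_i's loss moves with the test loss (co-evolves => helpful)
    scores = up - down
print(scores)
```

To first order, `up - down` is proportional to the gradient dot product $\nabla_\theta L(z_i) \cdot \nabla_\theta L(z_{\text{test}})$, i.e. a TracIn-style score evaluated at a single checkpoint.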
## Installation

1. Install PyTorch (match your CUDA version): https://pytorch.org/get-started/previous-versions/
2. Install the remaining dependencies:

```
pip install -r requirements.txt
```
## Quick Start

Inputs:
- `dl_train`: `torch.utils.data.DataLoader` for the training data
- `model`: `nn.Module`
- `param_filter_fn`: selects which parameters to use (e.g. last layer only)
- `criterion`: loss with `reduction="none"`

Output:
- `IF.query_influence(test_input, test_target)` returns a list of influence scores of length `|dl_train|`, one per training sample.
## Usage Examples

### EmpiricalIF

```python
from src.IF import EmpiricalIF

IF = EmpiricalIF(dl_train=trainloader,
                 model=resnet18,
                 param_filter_fn=lambda name, param: 'fc' in name,
                 criterion=nn.CrossEntropyLoss(reduction="none"))

for test_sample in testloader:
    test_input, test_target = test_sample
    IF_scores = IF.query_influence(test_input, test_target)
    print(IF_scores)  # shape: (|dl_train|,)
```

Reverse check (perturb the top/bottom influence samples and compare):

```python
most_inf, least_inf = IF.reverse_check(
    query_input=test_input,
    query_target=test_target,
    influence_values=IF_scores,
    check_ratio=0.01  # top and bottom 1%
)

for idx, orig_if, rev_if in most_inf:
    print(f"Top IF sample {idx}: IF={orig_if:.4f}, Reverse IF={rev_if:.4f}")
for idx, orig_if, rev_if in least_inf:
    print(f"Bottom IF sample {idx}: IF={orig_if:.4f}, Reverse IF={rev_if:.4f}")
```

### BaseInfluenceFunction

```python
from src.IF import BaseInfluenceFunction

IF = BaseInfluenceFunction(dl_train=trainloader,
                           model=resnet18,
                           param_filter_fn=lambda name, param: 'fc' in name,
                           criterion=nn.CrossEntropyLoss(reduction="none"))

for test_sample in testloader:
    test_input, test_target = test_sample
    IF_scores = IF.query_influence(test_input, test_target)
    print(IF_scores)
```

### TracIn

```python
from src.IF import TracIn

IF = TracIn(dl_train=trainloader,
            model=resnet18,
            param_filter_fn=lambda name, param: 'fc' in name,
            criterion=nn.CrossEntropyLoss(reduction="none"))

for test_sample in testloader:
    test_input, test_target = test_sample
    IF_scores = IF.query_influence(test_input, test_target)
    print(IF_scores)
```

## Citation

If you use this repository, please cite the EmpiricalIF paper:
```bibtex
@article{liu2022debugging,
  title={Debugging and Explaining Metric Learning Approaches: An Influence Function Based Perspective},
  author={Liu, Ruofan and Lin, Yun and Yang, Xianglin and Dong, Jin Song},
  journal={Advances in Neural Information Processing Systems},
  volume={35},
  pages={7824--7837},
  year={2022}
}
```

## Author & Contact

KuchikiRenji
| | |
|---|---|
| GitHub | github.com/KuchikiRenji |
| Email | [email protected] |
| Discord | kuchiki_renji |
For questions, collaborations, or feedback about this implementation, reach out via the links above.
