Description
Hi, this is an excellent implementation of path patching!
I carefully read and followed the code and the given demo, and things went well most of the time. However, when I try to handle long-text tasks with LLMs, such as reading comprehension with Llama2-7b, I often get an "out of memory" error, which means the model can't process too many data points simultaneously.
It's frustrating that I can only fit two or three data points into one path-patching run. So I wonder whether there is some technique to see the effect of path patching on all of the data points. For example, can I average the results of path patching over batches, as shown below?
total_results = 0
n_batches = len(dataset) // batch_size  # integer division, not /
for i in range(n_batches):
    # Get the batch data
    clean_batch = clean_tokens[i*batch_size:(i+1)*batch_size]
    flipped_batch = flipped_tokens[i*batch_size:(i+1)*batch_size]
    # Do path patching once on this batch
    results = path_patch(
        model,
        orig_input=clean_batch,
        new_input=flipped_batch,
        sender_nodes=IterNode('z'),  # iterate over all heads in all layers
        receiver_nodes=Node('resid_post', 31),  # resid_post at layer 31 (last layer of Llama2-7b)
        patching_metric=ioi_metric_noising,
        verbose=True
    )
    # Accumulate the results
    total_results += results
# Get the average results over all of the data
ave_results = total_results / n_batches
I am unsure whether doing this produces the same result as performing path patching on the entire dataset at once. If so, I think it could save a lot of CUDA memory.
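As a side note on the averaging itself: if the patching metric is a mean over examples, then averaging per-batch means is exact when every batch has the same size, and a size-weighted accumulation (sum of sums divided by total count) is exact even when the last batch is smaller. A minimal sketch with made-up per-example metric values (the numbers here are purely illustrative, not from any real run):

```python
# Hypothetical per-example metric values, standing in for a patching metric
# that is defined as a mean over examples.
metrics = [float(x) for x in range(10)]  # 10 "data points"

full_mean = sum(metrics) / len(metrics)  # mean over the whole dataset

# Accumulate batch sums and counts; this recovers the exact dataset mean
# even though the last batch here has only 2 examples instead of 4.
batch_size = 4
total, count = 0.0, 0
for i in range(0, len(metrics), batch_size):
    batch = metrics[i:i + batch_size]
    total += sum(batch)      # equivalent to batch_mean * len(batch)
    count += len(batch)
weighted_mean = total / count

print(weighted_mean == full_mean)  # True
```

So with equal-sized batches the simple average in the loop above matches the whole-dataset value of a mean-based metric; with a ragged final batch, weighting by batch size keeps it exact. Note this reasoning only covers linear (mean-style) metrics, not metrics with nonlinear aggregation across examples.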
Looking forward to your reply!