Conversation

@kumar-devesh:

Initial code from PR #2568

This PR adds GLoRA to close issue #780.

@BenjaminBossan (Member):

Sorry for the delay, we haven't forgotten, we were just very busy these last weeks. In the meantime, could you please resolve the merge conflicts?

@githubnemo (Collaborator) left a comment:

Hey! I did a first skim and this looks like a good first pass.

I think there's a bit of room for refactoring since BaseTuner and BaseTunerLayer now provide a lot more functionality for free.

To be consistent with other methods, let's rename GLora to Glora.

Let's also add Glora to the test_custom_models test suite to get a good set of early test coverage!
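
A rough sketch of what such entries could look like (the exact TEST_CASES tuple layout in tests/test_custom_models.py may differ, and GloraConfig is a placeholder for whatever the final config class ends up being called):

```python
# Hypothetical test-matrix entries for tests/test_custom_models.py; this only
# illustrates registering a few Glora configurations with the shared
# custom-model test machinery, not the exact format used there.
from peft import GloraConfig  # assumes the config is re-exported from the package root

GLORA_TEST_CASES = [
    ("Vanilla MLP 1 GLora", "MLP", GloraConfig, {"target_modules": "lin0"}),
    ("Vanilla MLP 2 GLora", "MLP", GloraConfig, {"target_modules": ["lin0", "lin1"]}),
]
```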

# GLora

Collaborator:

Let's rename this to GLoRA like in the paper.


## Notes
- GLora is a superset of LoRA: setting all paths to "LoRA" recovers standard LoRA.

Collaborator:

That doesn't seem correct. Setting A to LoRA and everything else to none should recover LoRA, no?

from .c3a import C3AConfig, C3AModel
from .cpt import CPTConfig, CPTEmbedding
from .fourierft import FourierFTConfig, FourierFTModel
from .glora import GLoraConfig, GLoraModel

Collaborator:

Let's use Glora instead of GLora in the code (all occurrences)

Comment on lines +50 to +73
config_A_B: str = field(
default="LoRA",
metadata={
"help": "Configuration for A and B matrices in GLora."
f"Valid values: {', '.join(_VALID_A_B_CONFIGS)}. "
"For LoRA, it will be post-processed to LoRA_<rank>."
},
)

config_C: str = field(
default="LoRA",
metadata={
"help": "Configuration for C matrix in GLora."
f"Valid values: {', '.join(_VALID_C_CONFIGS)}. "
"For LoRA, it will be post-processed to LoRA_<rank>."
},
)

config_D_E: str = field(
default="constant",
metadata={
"help": f"Configuration for D and E matrices in GLora. Valid values: {', '.join(_VALID_D_E_CONFIGS)}."
},
)

Collaborator:

Maybe I'm misunderstanding but to me it doesn't make sense to tie A/B and D/E configurations since they are independent in the formulation. To cite the paper:

A is utilized to scale the weight. B has the role to scale the input and shift the weight. C is the layer-wise prompt serving a similar function of VPT-Deep, D and E are used to scale and shift the bias, respectively.

To me it doesn't seem uncommon to set B=scalar while keeping A=LoRA.
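
For reference, if I'm reading the paper's unified formulation correctly, it is roughly:

```math
f(x) = (W_0 + W_0 A + B)\,x + C W_0 + D b_0 + E + b_0
```

A and B enter through the weight path (A scaling W_0, B added to it), while C, D and E only touch the bias path, so nothing in the formulation forces A/B or D/E to share a support type.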

import torch.nn.functional as F


class GLoraLayer(nn.Module):

Collaborator:

Let's inherit from BaseTunerLayer instead of nn.Module since the latter is already provided in the concrete class (e.g., GloraLinear).

Let's also define the class attributes adapter_layer_names and other_param_names. You can look at lora/layer.py for inspiration.
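
Something along these lines, loosely following lora/layer.py (the glora_* container names, r and eval_config are taken from the current PR code; the rest is only a sketch, not the final implementation):

```python
import torch.nn as nn

from peft.tuners.tuners_utils import BaseTunerLayer


class GloraLayer(BaseTunerLayer):
    # all attribute names that may hold trainable adapter parameters
    adapter_layer_names = (
        "glora_Ad", "glora_Au",
        "glora_Bd", "glora_Bu",
        "glora_Cd", "glora_Cu",
        "glora_D", "glora_E",
    )
    # other per-adapter attributes that are not parameters
    other_param_names = ("r", "eval_config")

    def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        self.base_layer = base_layer
        self.r = {}
        self.eval_config = {}
        # per-adapter parameter containers, filled by an update_layer-style helper
        self.glora_Ad = nn.ParameterDict({})
        self.glora_Au = nn.ParameterDict({})
        self.glora_Bd = nn.ParameterDict({})
        self.glora_Bu = nn.ParameterDict({})
        self.glora_Cd = nn.ParameterDict({})
        self.glora_Cu = nn.ParameterDict({})
        self.glora_D = nn.ParameterDict({})
        self.glora_E = nn.ParameterDict({})
        self.kwargs = kwargs
```

The nn.Module machinery then comes from the concrete class (e.g. GloraLinear), same as with LoraLayer and lora.Linear.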

Comment on lines +326 to +335
class GLoraLinear(GLoraLayer, nn.Linear):
def __init__(self, in_features, out_features, bias=True, **kwargs):
nn.Linear.__init__(self, in_features, out_features, bias=bias)
GLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
self.weight.requires_grad = False
if self.bias is not None:
self.bias.requires_grad = False
self._disable_adapters = False
self.active_adapters = []
self.merged_adapters = []

Collaborator:

Let's swap the inheritance order for GloraLinear. Also, I would refrain from using nn.Linear and inherit from nn.Module instead since a common usage pattern is to check for nn.Linear to distinguish between PEFT and non-PEFT methods.
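
Roughly what I have in mind, following the lora.Linear pattern and building on the GloraLayer sketch from the earlier comment (the frozen-weight handling mirrors the current PR code; everything else is an assumption):

```python
import torch
import torch.nn as nn


class GloraLinear(nn.Module, GloraLayer):  # GloraLayer as sketched above
    def __init__(self, base_layer: nn.Linear, adapter_name: str, **kwargs) -> None:
        super().__init__()
        GloraLayer.__init__(self, base_layer=base_layer, **kwargs)
        # keep the pretrained weights frozen; only the glora_* parameters train
        base_layer.weight.requires_grad_(False)
        if base_layer.bias is not None:
            base_layer.bias.requires_grad_(False)
        self._active_adapter = adapter_name

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # delegate to the wrapped, still plain nn.Linear; the GLoRA terms for
        # the active adapters would be added on top of this result
        return self.base_layer(x)
```

With this layout, isinstance(module, nn.Linear) stays False for the PEFT layer, so code that uses that check to tell PEFT and non-PEFT modules apart keeps working.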

Comment on lines +100 to +120
def delete_adapter(self, adapter_name):
for d in [
self.glora_Ad,
self.glora_Au,
self.glora_Bd,
self.glora_Bu,
self.glora_Cd,
self.glora_Cu,
self.glora_D,
self.glora_E,
]:
if adapter_name in d:
del d[adapter_name]
if adapter_name in self.r:
del self.r[adapter_name]
if adapter_name in self.eval_config:
del self.eval_config[adapter_name]
if adapter_name in self.active_adapters:
self.active_adapters.remove(adapter_name)
if adapter_name in self.merged_adapters:
self.merged_adapters.remove(adapter_name)

Collaborator:

You get {delete,enable,add,...}_adapter for free when inheriting from BaseTunerLayer.
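
For example, with the adapter_layer_names / other_param_names attributes from the earlier sketch in place, the inherited methods should already do the right thing:

```python
# illustrative usage: layer is a GloraLinear wrapping an nn.Linear,
# with an adapter named "default" already added
layer.delete_adapter("default")  # BaseTunerLayer walks adapter_layer_names + other_param_names
layer.enable_adapters(False)     # likewise inherited from BaseTunerLayer
```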

Comment on lines +67 to +69
# Add all adapters after peft_config is set
for name, cfg in self.peft_config.items():
self.add_adapter(name, cfg)

Collaborator:

I don't think that add_adapter is needed since BaseTuner provides inject_adapter.

add_adapter, find_and_replace as well as mark_only_glora_as_trainable are, as far as I can see, all covered by BaseTuner.inject_adapter provided that we use BaseTunerLayer correctly.
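
Roughly, with BaseTuner.inject_adapter driving the traversal and replacement, GloraModel would only need to fill in the standard hooks. GloraLinear and the r field on the config are placeholders from the sketches above, not the final API:

```python
import torch.nn as nn

from peft.tuners.tuners_utils import BaseTuner, check_target_module_exists

from .layer import GloraLinear  # assumed module layout within peft/tuners/glora/


class GloraModel(BaseTuner):
    prefix: str = "glora_"

    @staticmethod
    def _prepare_adapter_config(peft_config, model_config):
        if peft_config.target_modules is None:
            raise ValueError("Please specify `target_modules` in the Glora config.")
        return peft_config

    @staticmethod
    def _check_target_module_exists(glora_config, key):
        return check_target_module_exists(glora_config, key)

    def _create_and_replace(self, glora_config, adapter_name, target, target_name, parent, current_key):
        # wrap the matched module and swap it into its parent; the lora
        # implementation's _replace_module shows the full pattern, including
        # hook transfer and device placement
        new_module = GloraLinear(target, adapter_name, r=glora_config.r)
        setattr(parent, target_name, new_module)

    def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
        for name, param in model.named_parameters():
            if self.prefix not in name:
                param.requires_grad = False
```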
