
# LoRA Done RITE: Robust Invariant Transformation Equilibration for LoRA Optimization

This is a PyTorch reimplementation of the original LoRA-RITE, which was implemented in JAX.

## Usage

Copy `lora_rite.py` into your project directory or install it as a module.

You can then create a standard PyTorch optimizer object as follows:

```python
from lora_rite import LoRARite

lora_params = [p for n, p in model.named_parameters() if "lora" in n]
optimizer = LoRARite(lora_params, lr=learning_rate, betas=(0.9, 0.999),
                     clip_unmagnified_grad=max_grad_norm)
```

Here we assume the LoRA parameters are in the alternating order `lora_a_1, lora_b_1, lora_a_2, lora_b_2, ...`, as in the Hugging Face peft LoRA implementation. In the rare case where this assumption does not hold, reorder the parameter list manually so that it does.
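If your model's `named_parameters()` does not yield that order, a small helper can pair each `lora_A` tensor with its `lora_B` counterpart by module prefix. This is a hypothetical sketch (the `"lora_A"`/`"lora_B"` substrings follow peft's naming and may need adjusting for other models):

```python
def alternating_lora_params(named_params):
    """Return LoRA params as [a_1, b_1, a_2, b_2, ...], paired by module prefix.

    Assumes peft-style names containing "lora_A" / "lora_B"; adjust the
    substrings if your model names its adapter weights differently.
    """
    # Key each tensor by its name with the A/B marker removed, so that the
    # lora_A and lora_B of the same module share a key.
    a = {n.replace("lora_A", ""): p for n, p in named_params if "lora_A" in n}
    b = {n.replace("lora_B", ""): p for n, p in named_params if "lora_B" in n}
    assert a.keys() == b.keys(), "every lora_A needs a matching lora_B"
    out = []
    for key in a:  # dicts preserve insertion order, so module order is kept
        out.append(a[key])
        out.append(b[key])
    return out
```

The resulting list can be passed directly to `LoRARite` in place of `lora_params`.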

To apply gradient clipping correctly, set the `clip_unmagnified_grad` parameter of `LoRARite` and disable clipping elsewhere (e.g., by setting `max_grad_norm=0` for the Hugging Face trainer).
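With the Hugging Face `Trainer`, this wiring might look like the following sketch; `model`, `train_dataset`, `learning_rate`, and `max_grad_norm` are placeholders for your own setup:

```python
from transformers import Trainer, TrainingArguments
from lora_rite import LoRARite

lora_params = [p for n, p in model.named_parameters() if "lora" in n]
optimizer = LoRARite(lora_params, lr=learning_rate, betas=(0.9, 0.999),
                     clip_unmagnified_grad=max_grad_norm)

# max_grad_norm=0 disables the Trainer's own clipping, leaving it to LoRARite.
args = TrainingArguments(output_dir="out", max_grad_norm=0)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset,
                  optimizers=(optimizer, None))  # no scheduler; pass one if needed
```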

## Commonsense Reasoning Evaluation

This setting differs significantly from the one used in the paper, because aligning the PyTorch and JAX environments would require substantial effort. We instead adopt the recipe from the LLM-Adapters paper, whose datasets overlap heavily with our original experiments.

### Gemma-2B Results

| Optimizer | BoolQ | PIQA | SIQA | HellaSwag | WinoGrande | ARC-E | ARC-C | OBQA | Average |
|-----------|-------|-------|-------|-----------|------------|-------|-------|-------|---------|
| LoRA-RITE | 62.91 | 74.86 | 67.50 | 69.30 | 62.04 | 78.45 | 62.29 | 68.80 | 68.27 |
| Adam      | 62.20 | 75.46 | 65.35 | 67.38 | 55.80 | 76.60 | 58.70 | 68.00 | 66.19 |

## Running the Experiments

```shell
cd LLM-Adapters

# evaluate an existing lora_rite checkpoint
bash eval_commonsense.sh

# finetune with lora_rite
bash finetune_commonsense.sh
```
