Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
*.pth
__pycache__/
pretrained_weights/
*.egg-info
19 changes: 8 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
# LoRA for SAM (meta's segment-anything)

## Usage
```
from segment_anything import build_sam, SamAutomaticMaskGenerator
from segment_anything import sam_model_registry
from sam_lora import LoRA_Sam
import torch
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
lora_sam = LoRA_Sam(sam,r = 4)
result = lora_sam.sam.image_encoder(torch.rand(size=(1,3,1024,1024)))
print(result.shape)
```
1. Create a conda environment and install pytorch as described [here](https://pytorch.org/get-started/locally/)
2. run `pip install .` or `pip install -e .`
3. Get the model checkpoints [here](https://github.com/facebookresearch/segment-anything/tree/main?tab=readme-ov-file#model-checkpoints) or run
`mkdir pretrained_weights && wget -P pretrained_weights https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth`

### Example
See the [test](./scripts/test_lora.py)

## Train
Coming soon and welcome pull request.

## Thanks
The code for LoRA ViT comes form
https://github.com/JamesQFreeman/LoRA-ViT
https://github.com/JamesQFreeman/LoRA-ViT
Binary file removed __pycache__/sam_lora.cpython-39.pyc
Binary file not shown.
File renamed without changes.
Empty file added sam_lora/__init__.py
Empty file.
23 changes: 5 additions & 18 deletions sam_lora.py → sam_lora/sam_lora.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
# Sheng Wang at Apr 6 2023
# What a time to be alive (first half of 2023)

from segment_anything import build_sam, SamPredictor
from segment_anything import sam_model_registry

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter
from segment_anything.modeling import Sam
from safetensors import safe_open
Expand Down Expand Up @@ -48,6 +43,7 @@ def forward(self, x):
qkv[:, :, :, -self.dim :] += new_v
return qkv


class LoRA_Sam(nn.Module):
"""Applies low-rank adaptation to a Sam model's image encoder.

Expand Down Expand Up @@ -129,7 +125,7 @@ def save_lora_parameters(self, filename: str) -> None:
r"""Only safetensors is supported now.

pip install safetensor if you do not have one installed yet.

save both lora and fc parameters.
"""

Expand All @@ -138,11 +134,11 @@ def save_lora_parameters(self, filename: str) -> None:
num_layer = len(self.w_As) # actually, it is half
a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}

_in = self.lora_vit.head.in_features
_out = self.lora_vit.head.out_features
fc_tensors = {f"fc_{_in}in_{_out}out": self.lora_vit.head.weight}

merged_dict = {**a_tensors, **b_tensors, **fc_tensors}
save_file(merged_dict, filename)

Expand All @@ -166,7 +162,7 @@ def load_lora_parameters(self, filename: str) -> None:
saved_key = f"w_b_{i:03d}"
saved_tensor = f.get_tensor(saved_key)
w_B_linear.weight = Parameter(saved_tensor)

_in = self.lora_vit.head.in_features
_out = self.lora_vit.head.out_features
saved_key = f"fc_{_in}in_{_out}out"
Expand All @@ -181,12 +177,3 @@ def reset_parameters(self) -> None:
nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
for w_B in self.w_Bs:
nn.init.zeros_(w_B.weight)

# def forward(self, x: Tensor) -> Tensor:
# return self.lora_vit(x)


if __name__ == "__main__":
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
lora_sam = LoRA_Sam(sam,4)
lora_sam.sam.image_encoder(torch.rand(size=(1,3,1024,1024)))
10 changes: 10 additions & 0 deletions scripts/test_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from segment_anything import sam_model_registry
from sam_lora.sam_lora import LoRA_Sam
import torch

if __name__ == "__main__":
sam = sam_model_registry["vit_b"](
checkpoint="pretrained_weights/sam_vit_b_01ec64.pth"
)
lora_sam = LoRA_Sam(sam, 4)
lora_sam.sam.image_encoder(torch.rand(size=(1, 3, 1024, 1024)))
15 changes: 0 additions & 15 deletions segment_anything/__init__.py

This file was deleted.

Binary file removed segment_anything/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file removed segment_anything/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading