13 changes: 12 additions & 1 deletion .gitignore
@@ -164,4 +164,15 @@ dmypy.json
Makefile

# debugging files
scorer_demo.py
scorer_demo.py
models/ViT-L-14.pt
models/ImageReward.pt
models/med_config.json
models/HPS_v2_compressed.pt
models/pytorch_model.bin
TEST.py
/mscoco/
/mscoco.parquet
models/CLIP-ViT-L-14.pt
models/HPS_v2.1.pt
models/Real.pt
3 changes: 3 additions & 0 deletions .idea/.gitignore

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

7 changes: 7 additions & 0 deletions .idea/misc.xml

8 changes: 8 additions & 0 deletions .idea/modules.xml

14 changes: 14 additions & 0 deletions .idea/sd-webui-bayesian-merger.iml

6 changes: 6 additions & 0 deletions .idea/vcs.xml

38 changes: 35 additions & 3 deletions conf/config.tmpl.yaml
@@ -14,7 +14,6 @@ work_device: cpu
threads: 1

wildcards_dir: path/to/wildcards/folder
scorer_model_dir: path/to/scorer/models/folder

model_a: path/to/model_a/file
model_b: path/to/model_b/file
@@ -32,11 +31,44 @@ guided_optimisation: False
batch_size: 1
init_points: 1
n_iters: 1
img_average_type: arithmetic # geometric, arithmetic, quadratic

save_imgs: False

scorer_device: cpu # cuda
scorer_method: chad # chad, laion, manual
# scorer by type:
# Prompt-Image Alignment: blip, clip
# Aesthetic: chad, laion
# Hybrid (PIA + AES): ir, hpsv2, pick
# Anime/Illustration: shadow, cafe, wdaes
# Misc: manual, noai, iqa
#
# !!!! IQA SCORERS ARE NOT IMPLEMENTED YET !!!!
#
# Notes:
# 1) a recommended, tested-safe setup is [laion, chad, clip, blip, ir] with weights 0.5, 0.5, 1, 1, 1

scorer_method: [clip, blip, laion, chad, ir]
scorer_average_type: arithmetic # geometric, arithmetic, quadratic
scorer_weight:
#blip: 0.5
#chad: 2
# example above, default is 1
scorer_default_device: cpu # cuda
scorer_device:
#blip: cpu
#chad: cuda
# example above; the default is scorer_default_device
scorer_model_dir: path/to/scorer/models/folder
scorer_alt_location:
#blip:
#model_name: scorer.pth
#model_dir: path/to/scorer/scorer.pth
#chad:
#model_name: scorer.pt
#model_dir: path/to/scorer/scorer.pt
# example above; by default the models are downloaded into scorer_model_dir (this option is here if you already have them downloaded somewhere else)
scorer_print_individual: False


save_best: False
best_format: safetensors # ckpt
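Note on img_average_type / scorer_average_type: a minimal sketch (not code from this PR; the combine_scores helper is hypothetical) of how the three average types could combine per-scorer scores, assuming a plain dict of scores keyed by scorer name and the optional per-scorer weights from scorer_weight, with missing weights defaulting to 1:

import math

def combine_scores(scores, weights=None, average_type="arithmetic"):
    # scores: e.g. {"laion": 6.1, "chad": 6.4, "clip": 0.31}; weights default to 1 per scorer
    weights = weights or {}
    pairs = [(weights.get(name, 1.0), value) for name, value in scores.items()]
    total_weight = sum(w for w, _ in pairs)
    if average_type == "arithmetic":
        return sum(w * s for w, s in pairs) / total_weight
    if average_type == "geometric":
        # weighted geometric mean; assumes strictly positive scores
        return math.exp(sum(w * math.log(s) for w, s in pairs) / total_weight)
    if average_type == "quadratic":
        # weighted quadratic mean (root mean square)
        return math.sqrt(sum(w * s ** 2 for w, s in pairs) / total_weight)
    raise ValueError(f"unknown average type: {average_type}")

With the weights suggested in the template comments (laion 0.5, chad 0.5, clip/blip/ir 1), the arithmetic variant is simply a weighted mean over the five scorers.
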
1 change: 0 additions & 1 deletion install.py
@@ -8,7 +8,6 @@

with open(Path(extension_dir, "requirements.txt"), "r", encoding="utf-8") as f:
reqs = f.readlines()
print(reqs)

for req in reqs:
req = req.strip()
3 changes: 3 additions & 0 deletions requirements.txt
@@ -17,3 +17,6 @@ sd-meh==0.9.5
lightgbm
scikit-learn
openai-clip
tensordict
timm
fairscale
1 change: 1 addition & 0 deletions sd_webui_bayesian_merger/models/BLIP/__init__.py
@@ -0,0 +1 @@
from .blip_pretrain import *
70 changes: 70 additions & 0 deletions sd_webui_bayesian_merger/models/BLIP/blip.py
@@ -0,0 +1,70 @@
'''
* Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import warnings
warnings.filterwarnings("ignore")

import torch
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from transformers import BertTokenizer
from .vit import VisionTransformer, interpolate_pos_embed


def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing,
                                           ckpt_layer=ckpt_layer,
                                           drop_path_rate=0 or drop_path_rate)
    elif vit == 'large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing,
                                           ckpt_layer=ckpt_layer,
                                           # kept as in upstream BLIP: `0.1 or drop_path_rate` always evaluates to 0.1
                                           drop_path_rate=0.1 or drop_path_rate)
    return visual_encoder, vision_width


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape != model.state_dict()[key].shape:
                print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
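
A hypothetical usage sketch for the helpers above (illustration only, not part of this PR):

from sd_webui_bayesian_merger.models.BLIP.blip import create_vit, init_tokenizer

visual_encoder, vision_width = create_vit("base", image_size=224)  # ViT-B/16 backbone, vision_width == 768
tokenizer = init_tokenizer()  # bert-base-uncased plus the [DEC]/[ENC] special tokens
# load_checkpoint(model, path_or_url) additionally needs a full BLIP-style module,
# so it is exercised in the BLIP_Pretrain sketch further down.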

43 changes: 43 additions & 0 deletions sd_webui_bayesian_merger/models/BLIP/blip_pretrain.py
@@ -0,0 +1,43 @@
'''
* Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import transformers
transformers.logging.set_verbosity_error()

from torch import nn
import os
from .med import BertConfig, BertModel
from .blip import create_vit, init_tokenizer

class BLIP_Pretrain(nn.Module):
    def __init__(self,
                 med_config="med_config.json",
                 image_size=224,
                 vit='base',
                 vit_grad_ckpt=False,
                 vit_ckpt_layer=0,
                 embed_dim=256,
                 queue_size=57600,
                 momentum=0.995,
                 ):
        """
        Args:
            med_config (str): path to the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of the vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, 0)

        self.tokenizer = init_tokenizer()
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)