-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexport_openclip.py
More file actions
442 lines (370 loc) · 14.7 KB
/
export_openclip.py
File metadata and controls
442 lines (370 loc) · 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
"""Export OpenCLIP ViT-L/14 (DFN-2B) visual and text encoders to ExecuTorch .pte format."""
import argparse
import json
import shutil
from datetime import date
from pathlib import Path
import torch
import torch.nn as nn
# Export variants. Each entry drives export_variant(): "precision" and
# "compute_unit" feed CoreML compile specs, "backend" selects the partitioner
# stack, and "description" is only used for log output.
VARIANTS = {
    "fp16_all": {
        "precision": "fp16",
        "compute_unit": "ALL",
        "description": "fp16, CPU+GPU+ANE (CoreML + XNNPACK fallback)",
        "backend": "coreml",
    },
    "fp32_cpu": {
        "precision": "fp32",
        "compute_unit": "CPU_ONLY",
        "description": "fp32, CPU (XNNPACK only — CoreML fp32 segfaults with torch 2.11/coremltools 9)",
        "backend": "xnnpack",
    },
}
# Destination for .pte files, tokenizer files, and config.json.
OUTPUT_DIR = Path("output/openclip")
# OpenCLIP architecture name and pretrained-weights tag passed to
# open_clip.create_model_and_transforms().
MODEL_NAME = "ViT-L-14"
PRETRAINED = "dfn2b"
def load_model():
    """Instantiate OpenCLIP ViT-L/14 with DFN-2B weights, in eval mode.

    Returns:
        Tuple of (model, validation preprocessing transform).
    """
    import open_clip

    clip_model, _, eval_transform = open_clip.create_model_and_transforms(
        MODEL_NAME,
        pretrained=PRETRAINED,
        force_quick_gelu=True,
    )
    clip_model.eval()
    return clip_model, eval_transform
def print_preprocessing_info(preprocess_val):
    """Dump each transform of the validation preprocessing pipeline to stdout."""
    print("\n=== Preprocessing Constants ===")
    for transform in preprocess_val.transforms:
        print(f" {transform}")
    print()
def inspect_model(model, preprocess_val):
    """Print model architecture for identifying the correct submodules.

    Diagnostic helper (used by --inspect): prints submodule/attribute layouts,
    runs a smoke-test forward pass on both encoders, and reports which
    activation the visual MLP uses (QuickGELU vs GELU matters for export
    fidelity). Output goes to stdout only; the model is not modified.
    """
    print("=== Top-level modules ===")
    for name, _ in model.named_children():
        print(f" {name}")
    print("\n=== model.visual submodules ===")
    for name, mod in model.visual.named_children():
        param_count = sum(p.numel() for p in mod.parameters())
        print(f" {name}: {mod.__class__.__name__} ({param_count:,} params)")
    # The text tower is not a single submodule; these are the scattered
    # attributes CLIPTextEncoder reassembles.
    print("\n=== Text encoder attributes ===")
    for attr in [
        "token_embedding",
        "positional_embedding",
        "transformer",
        "ln_final",
        "text_projection",
        "attn_mask",
    ]:
        obj = getattr(model, attr, None)
        if obj is None:
            print(f" {attr}: NOT FOUND")
        elif isinstance(obj, nn.Module):
            param_count = sum(p.numel() for p in obj.parameters())
            print(f" {attr}: {obj.__class__.__name__} ({param_count:,} params)")
        elif isinstance(obj, (nn.Parameter, torch.Tensor)):
            print(f" {attr}: {obj.shape} {obj.dtype}")
        else:
            print(f" {attr}: {type(obj).__name__}")
    # Catch-all sweep over public attributes, in case something relevant is
    # neither a child module nor in the list above.
    print("\n=== All top-level attributes (non-module) ===")
    for attr in dir(model):
        if attr.startswith("_"):
            continue
        obj = getattr(model, attr, None)
        if isinstance(obj, (nn.Parameter, torch.Tensor)):
            print(f" {attr}: {obj.shape} {obj.dtype}")
        elif isinstance(obj, nn.Module) and attr not in dict(model.named_children()):
            print(f" {attr}: {obj.__class__.__name__}")
    print_preprocessing_info(preprocess_val)
    # Test forward pass
    print("=== Forward pass test ===")
    with torch.no_grad():
        dummy_image = torch.randn(1, 3, 224, 224)
        img_features = model.encode_image(dummy_image)
        print(f" encode_image output shape: {img_features.shape}")
        # L2 norm ~1.0 would indicate encode_image normalizes its output.
        print(f" encode_image L2 norm: {torch.norm(img_features, dim=-1).item():.6f}")
        # Compare against the raw tower to see whether encode_image adds a
        # projection/normalization on top of model.visual.
        vis_out = model.visual(dummy_image)
        print(f" model.visual output shape: {vis_out.shape}")
        print(f" model.visual L2 norm: {torch.norm(vis_out, dim=-1).item():.6f}")
        import open_clip
        dummy_text = open_clip.tokenize(["a photo of a cat"])
        txt_features = model.encode_text(dummy_text)
        print(f" encode_text output shape: {txt_features.shape}")
        print(f" encode_text L2 norm: {torch.norm(txt_features, dim=-1).item():.6f}")
    # Check activation function (QuickGELU vs GELU)
    print("\n=== Activation function check ===")
    first_block = model.visual.transformer.resblocks[0]
    # assumes mlp is indexable (e.g. nn.Sequential) with the activation at
    # position 1 — TODO confirm against the open_clip block layout
    act_layer = first_block.mlp[1] if hasattr(first_block.mlp, "__getitem__") else None
    if act_layer:
        print(f" Visual MLP activation: {act_layer.__class__.__name__}")
    else:
        # Try Sequential access
        for name, mod in first_block.mlp.named_children():
            if (
                "gelu" in name.lower()
                or "act" in name.lower()
                or "quick" in mod.__class__.__name__.lower()
            ):
                print(f" Visual MLP activation: {mod.__class__.__name__}")
                break
        else:
            # Just list all MLP submodules
            print(" Visual MLP structure:")
            for name, mod in first_block.mlp.named_children():
                print(f" {name}: {mod.__class__.__name__}")
class CLIPVisualEncoder(nn.Module):
    """Expose the CLIP image tower as a standalone exportable module.

    Features are returned exactly as ``model.visual`` produces them —
    no L2 normalization is applied.
    """

    def __init__(self, clip_model):
        super().__init__()
        # Only the image tower is retained; the rest of the CLIP model
        # (text tower, logit scale, ...) is dropped.
        self.visual = clip_model.visual

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run the image tower on a preprocessed image batch."""
        return self.visual(image)
class CLIPTextEncoder(nn.Module):
    """Standalone CLIP text tower assembled from the parent model's pieces.

    open_clip keeps the text encoder as loose attributes on the top-level
    model; this wrapper gathers them into one module. The causal attention
    mask is registered as a buffer so torch.export treats it as a constant.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.token_embedding = clip_model.token_embedding
        self.positional_embedding = clip_model.positional_embedding
        self.transformer = clip_model.transformer
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.register_buffer("attn_mask", clip_model.attn_mask)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        """Encode a batch of token-id sequences into projected text features."""
        hidden = self.token_embedding(text) + self.positional_embedding
        hidden = self.transformer(hidden, attn_mask=self.attn_mask)
        hidden = self.ln_final(hidden)
        # Pool at the EOT token. EOT has the highest id in the vocabulary,
        # so argmax over token ids finds its position in each row.
        rows = torch.arange(hidden.shape[0])
        eot_positions = text.argmax(dim=-1)
        pooled = hidden[rows, eot_positions]
        return pooled @ self.text_projection
def extract_visual_encoder(model):
    """Wrap the image tower and sanity-check it against model.encode_image.

    Prints cosine similarity and max absolute difference between the wrapper
    and the reference path, warning (not failing) on divergence.
    """
    wrapper = CLIPVisualEncoder(model)
    with torch.no_grad():
        probe = torch.randn(1, 3, 224, 224)
        reference = model.encode_image(probe)
        candidate = wrapper(probe)
        cos_sim = nn.functional.cosine_similarity(reference, candidate, dim=-1).item()
        max_diff = torch.max(torch.abs(reference - candidate)).item()
    print(f" Visual wrapper vs encode_image cosine sim: {cos_sim:.6f}")
    print(f" Max absolute diff: {max_diff:.8f}")
    if cos_sim < 0.999:
        print(" WARNING: Visual wrapper output diverges from encode_image.")
    wrapper.eval()
    return wrapper
def extract_text_encoder(model):
    """Wrap the text tower and sanity-check it against model.encode_text.

    Prints cosine similarity and max absolute difference between the wrapper
    and the reference path, warning (not failing) on divergence.
    """
    import open_clip

    wrapper = CLIPTextEncoder(model)
    with torch.no_grad():
        probe = open_clip.tokenize(["a photo of a cat"])
        reference = model.encode_text(probe)
        candidate = wrapper(probe)
        cos_sim = nn.functional.cosine_similarity(reference, candidate, dim=-1).item()
        max_diff = torch.max(torch.abs(reference - candidate)).item()
    print(f" Text wrapper vs encode_text cosine sim: {cos_sim:.6f}")
    print(f" Max absolute diff: {max_diff:.8f}")
    if cos_sim < 0.999:
        print(" WARNING: Text wrapper output diverges from encode_text.")
    wrapper.eval()
    return wrapper
def export_variant(encoder, variant_name, module_name):
    """Export one variant to .pte file.

    Args:
        encoder: The wrapped encoder module (CLIPVisualEncoder or CLIPTextEncoder).
        variant_name: Key into VARIANTS dict (e.g., "fp16_all").
        module_name: "visual" or "text" — selects the example input shape.

    Returns:
        Path to the written .pte file.
    """
    import coremltools as ct
    from executorch.backends.apple.coreml.compiler import CoreMLBackend
    from executorch.backends.apple.coreml.partition import CoreMLPartitioner
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackPartitioner,
    )
    from executorch.exir import to_edge_transform_and_lower
    variant = VARIANTS[variant_name]
    print(
        f"\n=== Exporting {module_name} {variant_name} ({variant['description']}) ==="
    )
    # Step 1: torch.export
    print(" Exporting with torch.export...")
    if module_name == "visual":
        example_input = (torch.randn(1, 3, 224, 224),)
    else:
        # Token ids in [0, vocab_size) over CLIP's fixed 77-token context.
        example_input = (torch.randint(0, 49408, (1, 77), dtype=torch.long),)
    try:
        exported = torch.export.export(encoder, example_input)
    except Exception as e:
        # Some graphs only export in non-strict mode; retry before giving up.
        print(f" Strict export failed ({e}), trying strict=False...")
        exported = torch.export.export(encoder, example_input, strict=False)
    # Step 2: Lower with backend-specific partitioners
    backend = variant["backend"]
    if backend == "coreml":
        precision = (
            ct.precision.FLOAT16
            if variant["precision"] == "fp16"
            else ct.precision.FLOAT32
        )
        # Explicit mapping. The previous ternary silently mapped any value
        # other than "ALL" (including "CPU_ONLY") to CPU_AND_GPU; an unknown
        # value now raises KeyError instead of exporting the wrong target.
        compute_unit = {
            "ALL": ct.ComputeUnit.ALL,
            "CPU_ONLY": ct.ComputeUnit.CPU_ONLY,
            "CPU_AND_GPU": ct.ComputeUnit.CPU_AND_GPU,
        }[variant["compute_unit"]]
        compile_specs = CoreMLBackend.generate_compile_specs(
            compute_precision=precision,
            compute_unit=compute_unit,
            minimum_deployment_target=ct.target.iOS18,
        )
        partitioners = [
            CoreMLPartitioner(compile_specs=compile_specs),
            XnnpackPartitioner(),  # fallback for unsupported ops
        ]
    else:
        partitioners = [XnnpackPartitioner()]
    print(f" Lowering with {backend} backend...")
    lowered = to_edge_transform_and_lower(
        exported,
        partitioner=partitioners,
    )
    # Step 3: Serialize
    print(" Serializing to .pte...")
    pte = lowered.to_executorch()
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    output_path = OUTPUT_DIR / f"clip_vit_l14_{module_name}_{variant_name}.pte"
    with open(output_path, "wb") as f:
        f.write(pte.buffer)
    size_mb = output_path.stat().st_size / (1024 * 1024)
    print(f" Saved: {output_path} ({size_mb:.1f} MB)")
    return output_path
def extract_tokenizer(output_dir):
    """Download vocab.json and merges.txt from HuggingFace.

    Args:
        output_dir: Directory (Path) the tokenizer files are copied into.
    """
    from huggingface_hub import hf_hub_download
    repo_id = "openai/clip-vit-large-patch14"
    # parents=True so a fresh checkout without output/ doesn't crash
    # (matches the mkdir call used for .pte export).
    output_dir.mkdir(parents=True, exist_ok=True)
    for filename in ["vocab.json", "merges.txt"]:
        # Fixed: previously printed a literal "(unknown)" — the f-string had
        # no placeholder for the file being downloaded.
        print(f" Downloading {filename}...")
        downloaded = hf_hub_download(repo_id=repo_id, filename=filename)
        dest = output_dir / filename
        shutil.copy(downloaded, dest)
        print(f" Saved: {dest}")
def write_config(visual_variants, text_variants, preprocess_val):
    """Write config.json with model metadata for both encoders.

    The output directory is scanned for every existing .pte file, so variants
    produced by earlier runs are listed too, not just this invocation's.
    (The arguments are accepted for interface stability; the directory scan
    is authoritative.)
    """
    def _existing_variant_meta(module_name):
        # One metadata entry per variant whose .pte file is on disk.
        entries = []
        for variant_name, spec in VARIANTS.items():
            pte_path = OUTPUT_DIR / f"clip_vit_l14_{module_name}_{variant_name}.pte"
            if not pte_path.exists():
                continue
            if spec["backend"] == "coreml":
                backends = ["coreml", "xnnpack"]
            else:
                backends = ["xnnpack"]
            entries.append(
                {
                    "filename": pte_path.name,
                    "precision": spec["precision"],
                    "compute_unit": spec["compute_unit"],
                    "backends": backends,
                }
            )
        return entries

    config = {
        "source_model": MODEL_NAME,
        "pretrained": PRETRAINED,
        "executorch_version": "1.1.0",
        "export_date": str(date.today()),
        "visual_encoder": {
            "variants": _existing_variant_meta("visual"),
            "embedding_dim": 768,
            "input_shape": [1, 3, 224, 224],
            "l2_normalized": False,
        },
        "text_encoder": {
            "variants": _existing_variant_meta("text"),
            "embedding_dim": 768,
            "input_shape": [1, 77],
            "input_dtype": "int64",
            "l2_normalized": False,
        },
        "tokenizer": {
            "vocab_file": "vocab.json",
            "merges_file": "merges.txt",
            "vocab_size": 49408,
            "context_length": 77,
            "sot_token_id": 49406,
            "eot_token_id": 49407,
        },
        "preprocessing": {
            "input_size": 224,
            "crop": "center",
            "mean": [0.48145466, 0.4578275, 0.40821073],
            "std": [0.26862954, 0.26130258, 0.27577711],
        },
    }
    config_path = OUTPUT_DIR / "config.json"
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    print(f"\nWrote {config_path}")
def main():
    """CLI entry point: load the model, then inspect or export per flags."""
    parser = argparse.ArgumentParser(
        description="Export OpenCLIP ViT-L/14 (DFN-2B) to ExecuTorch .pte"
    )
    parser.add_argument(
        "--inspect",
        action="store_true",
        help="Print model architecture and exit",
    )
    parser.add_argument(
        "--variant",
        choices=["fp16_all", "fp32_cpu", "all"],
        default="all",
        help="Which variant(s) to export (default: all)",
    )
    parser.add_argument(
        "--module",
        choices=["visual", "text", "all"],
        default="all",
        help="Which encoder(s) to export (default: all)",
    )
    parser.add_argument(
        "--skip-tokenizer",
        action="store_true",
        help="Skip tokenizer file extraction",
    )
    args = parser.parse_args()

    print(f"Loading OpenCLIP {MODEL_NAME} ({PRETRAINED})...")
    model, preprocess_val = load_model()
    # Inspection mode is read-only: print and bail before any export work.
    if args.inspect:
        inspect_model(model, preprocess_val)
        return

    selected_modules = ["visual", "text"] if args.module == "all" else [args.module]
    if args.variant == "all":
        selected_variants = list(VARIANTS.keys())
    else:
        selected_variants = [args.variant]

    visual_exported = {}
    text_exported = {}
    if "visual" in selected_modules:
        print("\nExtracting visual encoder...")
        encoder = extract_visual_encoder(model)
        for vname in selected_variants:
            visual_exported[vname] = export_variant(encoder, vname, "visual")
    if "text" in selected_modules:
        print("\nExtracting text encoder...")
        encoder = extract_text_encoder(model)
        for vname in selected_variants:
            text_exported[vname] = export_variant(encoder, vname, "text")
    if not args.skip_tokenizer:
        print("\nExtracting tokenizer files...")
        extract_tokenizer(OUTPUT_DIR)
    write_config(visual_exported, text_exported, preprocess_val)
    print("\nDone!")


if __name__ == "__main__":
    main()