Skip to content
6 changes: 1 addition & 5 deletions src/vit_prisma/models/base_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,6 @@ def move_model_modules_to_device(self):
def from_pretrained(
cls,
model_name: str,
is_timm: bool = True,
is_clip: bool = False,
fold_ln: Optional[bool] = True,
center_writing_weights: Optional[bool] = True,
refactor_factored_attn_matrices: Optional[bool] = False,
Expand Down Expand Up @@ -728,14 +726,12 @@ def from_pretrained(

cfg = convert_pretrained_model_config(
model_name,
is_timm=is_timm,
is_clip=is_clip,
)



state_dict = get_pretrained_state_dict(
model_name, is_timm, is_clip, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
)

model = cls(cfg, move_to_device=False)
Expand Down
24 changes: 18 additions & 6 deletions src/vit_prisma/prisma_tools/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from transformers import AutoConfig, ViTForImageClassification, VivitForVideoClassification, CLIPModel

import timm

from vit_prisma.configs.HookedViTConfig import HookedViTConfig

import torch
Expand All @@ -22,7 +23,15 @@
import einops


def check_timm(model_name: str) -> bool:
    """Return True if ``model_name`` matches a pretrained timm model.

    The check is case-insensitive and matches on substring, so a base
    name (e.g. ``vit_base_patch16_224``) still matches timm's tagged
    listings (e.g. ``vit_base_patch16_224.augreg_in21k``).
    """
    query = model_name.lower()
    for candidate in timm.list_models(pretrained=True):
        if query in candidate.lower():
            return True
    return False

def check_clip(model_name: str) -> bool:
    """Return True if the Hugging Face config for ``model_name`` is a CLIP model.

    Loads the model's ``AutoConfig`` and inspects its ``model_type``
    field; NOTE(review): this raises if ``model_name`` is not a valid
    Hugging Face identifier — callers are expected to gate on
    ``check_timm`` first.
    """
    hf_config = AutoConfig.from_pretrained(model_name)
    return hf_config.model_type == "clip"

def convert_clip_weights(
old_state_dict,
Expand Down Expand Up @@ -289,8 +298,6 @@ def convert_hf_vit_for_image_classification_weights( old_state_dict,

def get_pretrained_state_dict(
official_model_name: str,
is_timm: bool,
is_clip: bool,
cfg: HookedViTConfig,
hf_model=None,
dtype: torch.dtype = torch.float32,
Expand Down Expand Up @@ -318,7 +325,10 @@ def get_pretrained_state_dict(
# f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
# )
# kwargs["trust_remote_code"] = True


is_timm = check_timm(official_model_name)
is_clip = False if is_timm else check_clip(official_model_name)

try:
if is_timm:
hf_model = hf_model if hf_model is not None else timm.create_model(official_model_name, pretrained=True)
Expand Down Expand Up @@ -390,9 +400,11 @@ def fill_missing_keys(model, state_dict):
state_dict[key] = default_state_dict[key]
return state_dict

def convert_pretrained_model_config(model_name: str, is_timm: bool = True, is_clip: bool = False) -> HookedViTConfig:

def convert_pretrained_model_config(model_name: str) -> HookedViTConfig:

is_timm = check_timm(model_name)
is_clip = False if is_timm else check_clip(model_name)


if is_timm:
model = timm.create_model(model_name)
Expand Down Expand Up @@ -475,4 +487,4 @@ def convert_pretrained_model_config(model_name: str, is_timm: bool = True, is_cl

print(pretrained_config)

return HookedViTConfig.from_dict(pretrained_config)
return HookedViTConfig.from_dict(pretrained_config)
2 changes: 1 addition & 1 deletion tests/test_loading_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_loading_clip():
tinyclip.to(device)
tinyclip_final_proj.to(device)

hooked_model = HookedViT.from_pretrained(model_name, is_timm=False, is_clip=True)
hooked_model = HookedViT.from_pretrained(model_name)
hooked_model.to(device)

with torch.random.fork_rng():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_loading_vit_for_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_loading_vit_for_image_classification():
width = 224
device = "cpu"

hooked_model = HookedViT.from_pretrained(model_name=model_name, is_timm=False)
hooked_model = HookedViT.from_pretrained(model_name=model_name)
hooked_model.to(device)
vit_model = ViTForImageClassification.from_pretrained(model_name)
vit_model.to(device)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_loading_vivet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_loading_vivet():
width = 224
device = "cpu"

hooked_model = HookedViT.from_pretrained(model_name, is_timm=False)
hooked_model = HookedViT.from_pretrained(model_name)
hooked_model.to(device)
google_model = VivitForVideoClassification.from_pretrained(model_name)
google_model.to(device)
Expand Down