diff --git a/src/vit_prisma/models/base_vit.py b/src/vit_prisma/models/base_vit.py
index 31f9de03..9bc185ea 100644
--- a/src/vit_prisma/models/base_vit.py
+++ b/src/vit_prisma/models/base_vit.py
@@ -684,8 +684,6 @@ def move_model_modules_to_device(self):
     def from_pretrained(
         cls,
         model_name: str,
-        is_timm: bool = True,
-        is_clip: bool = False,
         fold_ln: Optional[bool] = True,
         center_writing_weights: Optional[bool] = True,
         refactor_factored_attn_matrices: Optional[bool] = False,
@@ -728,14 +726,12 @@ def from_pretrained(
 
         cfg = convert_pretrained_model_config(
             model_name,
-            is_timm=is_timm,
-            is_clip=is_clip,
         )
 
 
 
         state_dict = get_pretrained_state_dict(
-            model_name, is_timm, is_clip, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
+            model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
         )
 
         model = cls(cfg, move_to_device=False)
diff --git a/src/vit_prisma/prisma_tools/loading_from_pretrained.py b/src/vit_prisma/prisma_tools/loading_from_pretrained.py
index 185f5a17..b6cba0df 100644
--- a/src/vit_prisma/prisma_tools/loading_from_pretrained.py
+++ b/src/vit_prisma/prisma_tools/loading_from_pretrained.py
@@ -13,6 +13,7 @@
 from transformers import AutoConfig, ViTForImageClassification, VivitForVideoClassification, CLIPModel
 
 import timm
+
 from vit_prisma.configs.HookedViTConfig import HookedViTConfig
 
 import torch
@@ -22,7 +23,31 @@
 import einops
 
 
+def check_timm(model_name: str) -> bool:
+    """Return True if ``model_name`` is the name of a pretrained timm model.
+
+    The comparison is case-insensitive and exact: ``model_name`` must equal an
+    entry of ``timm.list_models(pretrained=True)`` either in full or up to its
+    first ".", so both "architecture" and "architecture.pretrained_tag"
+    spellings are accepted. Substring matching is deliberately avoided: a
+    short Hugging Face style name such as "clip" is contained in several timm
+    model names and would otherwise be routed to the timm loader.
+    """
+    requested = model_name.lower()
+    for available_model in timm.list_models(pretrained=True):
+        candidate = available_model.lower()
+        if requested in (candidate, candidate.partition(".")[0]):
+            return True
+    return False
 
+def check_clip(model_name: str) -> bool:
+    """Return True if the config loaded for ``model_name`` has model_type "clip".
+
+    The config is loaded with ``AutoConfig.from_pretrained``; any exception it
+    raises is not caught here and propagates to the caller.
+    """
+    config = AutoConfig.from_pretrained(model_name)
+    return config.model_type == "clip"
 
 def convert_clip_weights(
         old_state_dict,
@@ -289,8 +314,6 @@ def convert_hf_vit_for_image_classification_weights( old_state_dict,
 
 def get_pretrained_state_dict(
     official_model_name: str,
-    is_timm: bool,
-    is_clip: bool,
     cfg: HookedViTConfig,
     hf_model=None,
     dtype: torch.dtype = torch.float32,
@@ -318,7 +341,10 @@ def get_pretrained_state_dict(
     #         f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
     #     )
     #     kwargs["trust_remote_code"] = True
-    
+
+    is_timm = check_timm(official_model_name)
+    is_clip = False if is_timm else check_clip(official_model_name)
+
     try:
         if is_timm:
             hf_model = hf_model if hf_model is not None else timm.create_model(official_model_name, pretrained=True)
@@ -390,9 +416,11 @@ def fill_missing_keys(model, state_dict):
             state_dict[key] = default_state_dict[key]
     return state_dict
 
-def convert_pretrained_model_config(model_name: str, is_timm: bool = True, is_clip: bool = False) -> HookedViTConfig:
-
+def convert_pretrained_model_config(model_name: str) -> HookedViTConfig:
+    is_timm = check_timm(model_name)
+    is_clip = False if is_timm else check_clip(model_name)
 
+
     if is_timm:
         model = timm.create_model(model_name)
 
@@ -475,4 +503,4 @@ def convert_pretrained_model_config(model_name: str, is_timm: bool = True, is_cl
 
     print(pretrained_config)
 
-    return HookedViTConfig.from_dict(pretrained_config)
+    return HookedViTConfig.from_dict(pretrained_config)
\ No newline at end of file
diff --git a/tests/test_loading_clip.py b/tests/test_loading_clip.py
index bb2a0359..c533203a 100644
--- a/tests/test_loading_clip.py
+++ b/tests/test_loading_clip.py
@@ -25,7 +25,7 @@ def test_loading_clip():
     tinyclip.to(device)
     tinyclip_final_proj.to(device)
 
-    hooked_model = HookedViT.from_pretrained(model_name, is_timm=False, is_clip=True)
+    hooked_model = HookedViT.from_pretrained(model_name)
     hooked_model.to(device)
 
     with torch.random.fork_rng():
diff --git a/tests/test_loading_vit_for_image_classification.py b/tests/test_loading_vit_for_image_classification.py
index 51514e79..f2ae53a3 100644
--- a/tests/test_loading_vit_for_image_classification.py
+++ b/tests/test_loading_vit_for_image_classification.py
@@ -17,7 +17,7 @@ def test_loading_vit_for_image_classification():
     width = 224
     device = "cpu"
 
-    hooked_model = HookedViT.from_pretrained(model_name=model_name, is_timm=False)
+    hooked_model = HookedViT.from_pretrained(model_name=model_name)
     hooked_model.to(device)
     vit_model = ViTForImageClassification.from_pretrained(model_name)
     vit_model.to(device)
diff --git a/tests/test_loading_vivet.py b/tests/test_loading_vivet.py
index 165898bb..f4ad2aad 100644
--- a/tests/test_loading_vivet.py
+++ b/tests/test_loading_vivet.py
@@ -19,7 +19,7 @@ def test_loading_vivet():
     width = 224
     device = "cpu"
 
-    hooked_model = HookedViT.from_pretrained(model_name, is_timm=False)
+    hooked_model = HookedViT.from_pretrained(model_name)
     hooked_model.to(device)
     google_model = VivitForVideoClassification.from_pretrained(model_name)
     google_model.to(device)