@brianhou0208 thanks for the work, and it looks like a good job getting it in shape. I took a closer look using your code, but I have some doubts about this model.
For speed comparisons I disabled F.sdpa in the existing vit to be fair. Simpler vits with higher accuracy (ImageNet-1k pretrain, also to be fair) are often 30-40% faster, so I'm not convinced this is worth the add. Was there a particular reason you had interest in the model?
def single_attn(self, x: torch.Tensor) -> torch.Tensor:
    k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)
    if not torch.jit.is_scripting():
        with torch.autocast(device_type=v.device.type, enabled=False):
            y = self._attn_impl(k, q, v)
    else:
        y = self._attn_impl(k, q, v)
    # skip connection
    y = v + self.dp(self.proj(y))  # same as token_transformer in T2T layer, use v as skip connection
    return y

def _attn_impl(self, k, q, v):
    kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)
    D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)  # (B, T, m) * (B, m) -> (B, T, 1)
    kptv = torch.einsum('bin,bim->bnm', v.float(), kp)  # (B, emb, m)
    y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb) / diag
    return y
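For context on the speed comparison above, here is a minimal benchmark sketch, not the exact script that was run; it assumes timm's `TIMM_FUSED_ATTN` env var as the way to disable the F.sdpa path, and the model names are illustrative:

```python
# Rough throughput comparison sketch. Assumptions: TIMM_FUSED_ATTN=0 disables
# the F.scaled_dot_product_attention path in timm vits, and 't2t_vit_14' is
# one of the model names this PR registers. Set the env var before importing timm.
import os
os.environ['TIMM_FUSED_ATTN'] = '0'

import time
import torch
import timm

def bench(name: str, batch_size: int = 32, steps: int = 20) -> float:
    model = timm.create_model(name, pretrained=False).eval()
    x = torch.randn(batch_size, 3, 224, 224)
    if torch.cuda.is_available():
        model, x = model.cuda(), x.cuda()
    with torch.inference_mode():
        for _ in range(3):  # warmup
            model(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(steps):
            model(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / steps

for name in ('t2t_vit_14', 'vit_small_patch16_224'):
    print(f'{name}: {bench(name) * 1000:.1f} ms / batch')
```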
Hi @rwightman, I agree with your observation. The T2T-ViT model does not have clear advantages over other models; the only one might be that it does not use any convolution layers.
Another issue occurs when using pre-trained weights and testing whether first_conv adapts to the number of input channels (C, H, W). In test_model_load_pretrained (pytorch-image-models/tests/test_models.py, lines 371 to 376 in d81da93), if first_conv is an nn.Linear rather than an nn.Conv2d, as in T2T-ViT without a conv stem, the pretrained loading path also reports an error (pytorch-image-models/timm/models/_builder.py, lines 225 to 239 in d81da93). Since this involves modifying ...
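To illustrate the failure mode, a rough sketch of the mismatch; `adapt_input_conv` is timm's existing helper (import path assumed), while `adapt_input_linear` is a hypothetical counterpart I wrote for illustration, not an existing timm API:

```python
# Sketch of the issue: timm adapts the first conv weight for a different
# in_chans, but assumes a 4D Conv2d weight (O, I, kH, kW); T2T-ViT's first
# layer is an nn.Linear over unfolded patches, so its 2D weight breaks that.
import torch
from timm.models import adapt_input_conv  # existing helper; import path assumed

conv_w = torch.randn(64, 3, 7, 7)             # Conv2d weight adapts fine
print(adapt_input_conv(1, conv_w).shape)      # torch.Size([64, 1, 7, 7])

# Hypothetical linear counterpart, NOT an existing timm API. Assumes the
# nn.Unfold channel-first layout, i.e. in_features = C * kH * kW with C slowest.
def adapt_input_linear(in_chans: int, weight: torch.Tensor, patch_area: int) -> torch.Tensor:
    out_dim, in_dim = weight.shape
    old_chans = in_dim // patch_area
    w = weight.reshape(out_dim, old_chans, patch_area)
    if in_chans == 1:
        w = w.sum(dim=1, keepdim=True)                  # collapse RGB -> grayscale
    else:
        repeat = -(-in_chans // old_chans)              # ceil division
        w = w.repeat(1, repeat, 1)[:, :in_chans] * (old_chans / float(in_chans))
    return w.reshape(out_dim, in_chans * patch_area)

linear_w = torch.randn(64, 3 * 7 * 7)                   # Linear weight over unfolded patches
print(adapt_input_linear(1, linear_w, patch_area=7 * 7).shape)  # torch.Size([64, 49])
```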
@brianhou0208 I don't know if not having the input conv is a 'feature'. My very first vit impl here, before the official JAX code (which used the Conv2D trick) was released, was this: pytorch-image-models/timm/models/vision_transformer.py, lines 139 to 169 in 7613094. The conv approach was faster since it is a single optimized kernel rather than a chain of API calls. I suppose torch.compile would rectify most of that, but I still don't see the downside of the conv. Also, the packed vit I started working on (I have yet to pick it back up) has to push patchification further into the data pipeline: https://github.com/huggingface/pytorch-image-models/blob/379780bb6ca3304d63bf8ca789d5bbce5949d0b5/timm/models/vision_transformer_packed.py
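As a small equivalence sketch (mine, not code from either impl): a stride=patch Conv2d and an unfold-then-Linear chain compute the same patch embedding, the conv just does it in one fused kernel:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, C, H, W, P, D = 2, 3, 224, 224, 16, 384
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    # Conv2d weight (D, C, P, P) flattens to the same channel-first layout
    # that F.unfold produces, so the two parameterizations are interchangeable.
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(B, C, H, W)
y_conv = conv(x).flatten(2).transpose(1, 2)         # (B, N, D) via one kernel
patches = F.unfold(x, kernel_size=P, stride=P)      # (B, C*P*P, N) via API chain
y_lin = linear(patches.transpose(1, 2))             # (B, N, D)
print(torch.allclose(y_conv, y_lin, atol=1e-5))     # True
```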
Hi @rwightman, this PR resolves #2364, please check.
Result
- Tested T2T-ViT models and weights on the ImageNet val dataset
  - test code
  - output log
- Calculated FLOPs/MACs/params
  - report from calflops
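For reproducibility, a minimal sketch of how a calflops report like the one above can be generated; `t2t_vit_14` is an assumed model name for this PR:

```python
# Minimal FLOPs/MACs/params report via calflops (the tool cited above).
# Assumption: 't2t_vit_14' is one of the model names this PR registers.
import timm
from calflops import calculate_flops

model = timm.create_model('t2t_vit_14', pretrained=False).eval()
flops, macs, params = calculate_flops(
    model=model,
    input_shape=(1, 3, 224, 224),
    output_as_string=True,
    output_precision=2,
)
print(f'FLOPs: {flops}  MACs: {macs}  Params: {params}')
```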
Reference
- paper: https://arxiv.org/pdf/2101.11986
- code: https://github.com/yitu-opensource/T2T-ViT