Hi! Thanks for the cool idea and great work!
I was trying Coconut on Llama 3.2 1B and found that the norm of the last hidden state is much larger than the norm of the regular input token embeddings.
I was wondering whether that could cause training instability. Have you tried normalizing the last hidden state before feeding it back into the model?
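For concreteness, here is a minimal sketch (plain PyTorch, function name is mine) of the kind of rescaling I had in mind: scale each continuous-thought vector so its norm matches a target, e.g. the average input-embedding norm, before it is fed back as the next input.

```python
import torch

def rescale_hidden(hidden: torch.Tensor, target_norm: float) -> torch.Tensor:
    """Rescale vectors along the last dim so each has norm `target_norm`.

    `hidden` is e.g. the last hidden state [batch, seq, dim]; `target_norm`
    could be the mean norm of the model's input token embeddings.
    """
    norms = hidden.norm(dim=-1, keepdim=True)
    return hidden * (target_norm / (norms + 1e-6))  # eps avoids divide-by-zero
```

This is just an illustration of the idea, not something from the Coconut codebase; I'm curious whether you experimented with anything along these lines.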