
[RFC][Module] Replace init_weights with StateInitializer pattern #2540

Open
fegin wants to merge 2 commits into gh/fegin/92/base from gh/fegin/92/head

Conversation


fegin (Contributor) commented Mar 10, 2026

Stack from ghstack (oldest at bottom):

Key changes:

  1. New StateInitializer(Configurable) base class.
  2. Module.Config gains a required (non-optional) state_initializer field.
  3. Remove **kwargs from init_states/init_weights. All customization must
    go through StateInitializer.Config before a module is constructed,
    so the parent module is responsible for passing the state
    initialization config through before constructing its children.

Item 3 is the key change that needs discussion: whether people agree with this direction, or would rather pass the config as kwargs when calling init_states.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 10, 2026
fegin added a commit that referenced this pull request Mar 10, 2026
ghstack-source-id: e373bb8
Pull-Request: #2540
fegin added a commit that referenced this pull request Mar 10, 2026
ghstack-source-id: b726dfb
Pull-Request: #2540
@fegin fegin requested a review from acisseJZhong March 10, 2026 20:46
acisseJZhong (Contributor) commented:

@claude review this pr

claude bot commented Mar 11, 2026

Claude finished @acisseJZhong's task in 2m 25s


PR Review: [RFC][Module] Replace init_weights with StateInitializer pattern

  • Read CLAUDE.md for repo-specific guidelines
  • Analyze the core new abstraction (StateInitializer)
  • Review Module protocol changes
  • Review model implementations (llama3, deepseek_v3, qwen3, gpt_oss, llama4)
  • Review common module changes (attention, decoder, embedding, linear, feed_forward, etc.)
  • Review test changes
  • Provide comprehensive review feedback

Summary

This RFC introduces a StateInitializer(Configurable) pattern to replace the previous init_weights(**kwargs) approach. The key architectural change is that initialization parameters are now threaded through the Config hierarchy rather than passed as runtime kwargs to init_states. This is a significant refactor touching 49 files across all model families.


Design Assessment

The core idea is sound. Making weight initialization declarative and config-driven rather than imperative/kwargs-driven has clear benefits:

  1. Serializable configs -- The full initialization strategy is captured in the config tree, enabling reproducibility and config-driven experimentation.
  2. No surprise kwargs -- Parents are responsible for setting up child configs before construction, which is more explicit than passing kwargs at init time.
  3. Separation of concerns -- The StateInitializer is a standalone Configurable (not an nn.Module), keeping initialization logic separate from forward computation.

However, there are several concerns and issues worth discussing.


Issues and Concerns

1. Proliferation of StateInitializer classes

Every module now needs its own StateInitializer subclass, even when the init logic is trivial. The PR introduces:

  • LinearStateInitializer, EmbeddingStateInitializer, RMSNormStateInitializer, RoPEStateInitializer
  • GQAttentionStateInitializer, FeedForwardStateInitializer, MoEStateInitializer
  • DecoderStateInitializer, Llama3TransformerBlockStateInitializer, Qwen3TransformerBlockStateInitializer, Llama4TransformerBlockStateInitializer, DeepSeekV3TransformerBlockStateInitializer, DeepSeekV3AttentionStateInitializer
  • GptOssAttentionStateInitializer, GptOssTransformerBlockStateInitializer
  • Qwen3DecoderStateInitializer

Many of these (e.g., RMSNormStateInitializer, RoPEStateInitializer) have empty configs and trivial init_states implementations. This is a lot of boilerplate for what was previously a simple method override. For downstream users adding new models, this increases the barrier to entry.

2. assert isinstance pattern for config downcasting

Multiple modules assert the concrete type of state_initializer at construction time:

# torchtitan/models/llama3/model.py:56-57
assert isinstance(
    config.state_initializer, Llama3TransformerBlockStateInitializer.Config
)

This appears in Llama3TransformerBlock.__init__ (line 56), Qwen3TransformerBlock.__init__ (line 63), GQAttention.__init__ (line 472), FeedForward.__init__ (line 68), DeepSeekV3 Attention.__init__ (line 124), and GptOssAttention.__init__ (line 84).

This assert-based downcasting to read init_std undermines the polymorphism that StateInitializer is supposed to provide. If someone provides a custom StateInitializer.Config that isn't the expected concrete subclass, the module crashes at construction time rather than gracefully falling back. Consider using getattr(config.state_initializer, 'init_std', 0.02) or defining a common protocol/interface for configs that carry init_std.
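The suggested fallback can be sketched with hypothetical config classes (these names are illustrative, not from the PR):

```python
from dataclasses import dataclass

@dataclass
class ExpectedConfig:
    init_std: float = 0.02

@dataclass
class CustomConfig:  # a user-supplied config that does not carry init_std
    scale: float = 1.0

def read_init_std(state_init_cfg, default=0.02):
    # graceful fallback instead of `assert isinstance(...)` downcasting
    return getattr(state_init_cfg, "init_std", default)

print(read_init_std(ExpectedConfig(init_std=0.01)))  # 0.01
print(read_init_std(CustomConfig()))                 # 0.02 (fallback, no crash)
```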

3. replace_state_init_field couples parent modules tightly to child config internals

The replace_state_init_field helper (torchtitan/protocols/module.py:36-41) uses dataclasses.replace to patch the state_initializer field on child configs:

attn_cfg = config.attention.replace_state_init_field(init_std=weight_init_std)

This assumes every child's StateInitializer.Config has an init_std field, but that's not enforced by the type system. If a child uses a different StateInitializer.Config that doesn't have init_std, this will crash at runtime with a TypeError from dataclasses.replace. This is a fragile pattern that could break silently when new modules are added.
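The failure mode is easy to reproduce with plain dataclasses (hypothetical config classes, standing in for child StateInitializer.Configs):

```python
import dataclasses

@dataclasses.dataclass
class ConfigWithStd:
    init_std: float = 0.02

@dataclasses.dataclass
class ConfigWithoutStd:
    scale: float = 1.0

# replace() works when the field exists...
patched = dataclasses.replace(ConfigWithStd(), init_std=0.006)
print(patched.init_std)  # 0.006

# ...but a child whose config lacks init_std crashes at runtime:
try:
    dataclasses.replace(ConfigWithoutStd(), init_std=0.006)
except TypeError as exc:
    print(f"TypeError: {exc}")
```

Nothing in the type system warns about the second case until the parent actually runs.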

4. Qwen3Model manually rebuilds _state_initializer in __init__

In torchtitan/models/qwen3/model.py:204-208:

# Rebuild state_initializer to pick up enable_weight_tying
si_config = Qwen3DecoderStateInitializer.Config(
    enable_weight_tying=config.enable_weight_tying
)
self._state_initializer = Qwen3DecoderStateInitializer(si_config)

This manually reconstructs the state initializer after super().__init__() to sync the enable_weight_tying flag. This breaks the declarative config pattern -- the config's state_initializer field and the actual _state_initializer instance on the module are now out of sync. A cleaner approach would be to ensure Qwen3Model.Config.state_initializer defaults to a factory that reads enable_weight_tying from the parent config, or to handle this in update_from_config.
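One way the suggested fix could look, sketched with hypothetical simplified configs (not the torchtitan API): derive the initializer config from the parent config at dataclass-construction time, so the stored config and the built instance can never drift apart.

```python
from dataclasses import dataclass

@dataclass
class Qwen3DecoderInitConfig:
    enable_weight_tying: bool = False

@dataclass
class Qwen3ModelConfig:
    enable_weight_tying: bool = True
    state_initializer: Qwen3DecoderInitConfig = None

    def __post_init__(self):
        # sync the flag once, when the config itself is constructed
        if self.state_initializer is None:
            self.state_initializer = Qwen3DecoderInitConfig(
                enable_weight_tying=self.enable_weight_tying
            )

cfg = Qwen3ModelConfig(enable_weight_tying=True)
print(cfg.state_initializer.enable_weight_tying)  # True
```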

5. FeedForward.__init__ applies replace_state_init_field(init_std=...) to w3

In torchtitan/models/common/feed_forward.py:76-77:

w3_cfg = config.w3.replace_state_init_field(init_std=init_std)
self.w3 = w3_cfg.build(in_features=config.dim, out_features=config.hidden_dim)

w3 is the gate projection in SwiGLU (F.silu(w1(x)) * w3(x)). In the original Llama3 paper and most implementations, only w2 (the output projection) gets the depth-scaled init_std; w1 and w3 use the default 0.02. Applying the depth-scaled init_std to w3 here might be intentional but it's a behavioral change worth documenting or verifying. Compare with w1 which uses config.w1.build(...) without replace_state_init_field.

6. StateInitializer.__init__ does nothing

In torchtitan/protocols/state_initializer.py:27-28:

def __init__(self, config: Config):
    pass

The base class __init__ discards the config. Subclasses like LinearStateInitializer manually extract fields in their own __init__:

def __init__(self, config: Config):
    self.init_mean = config.init_mean
    self.init_std = config.init_std
    self.cutoff_factor = config.cutoff_factor

This is error-prone -- if a field is added to the config but forgotten in __init__, it's silently lost. Consider storing self.config = config in the base class, similar to how Module subclasses do it.
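The silent-loss failure mode, reduced to a minimal example (hypothetical classes):

```python
from dataclasses import dataclass

@dataclass
class Config:
    init_std: float = 0.02
    cutoff_factor: float = 3.0  # imagine this field was added later

class ManualExtract:
    def __init__(self, config):
        # copies fields one by one; cutoff_factor is silently dropped
        self.init_std = config.init_std

class StoresConfig:
    def __init__(self, config):
        # every current and future field stays reachable
        self.config = config

print(hasattr(ManualExtract(Config()), "cutoff_factor"))  # False
print(StoresConfig(Config()).config.cutoff_factor)        # 3.0
```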

7. GroupedExperts and TokenChoiceTopKRouter don't follow the pattern

GroupedExperts (torchtitan/models/common/moe/moe.py:83-137) and TokenChoiceTopKRouter (lines 140-318) are plain nn.Module subclasses that receive init_std as a constructor argument and have their own init_states() method that doesn't use StateInitializer. This creates inconsistency -- some modules use the pattern, others don't. If the goal is to make all initialization config-driven, these should eventually follow the same pattern.

8. Module.init_states defensive check may mask bugs

In torchtitan/protocols/module.py:49-54:

def init_states(self, *, buffer_device=None) -> None:
    if hasattr(self, "_state_initializer") and self._state_initializer is not None:
        self._state_initializer.init_states(self, buffer_device=buffer_device)
    else:
        raise NotImplementedError(...)

The hasattr check is defensive against cases where __init__ wasn't called (e.g., inject_module_protocol patches). This is fine, but the else branch says "must provide a state_initializer in Config" which is misleading -- it could also mean Module.__init__ wasn't called due to MRO issues.

9. Config verbosity in model registries

The model configs in __init__.py files are now quite verbose. Every Linear.Config() for internal projections (wq, wk, wv, wo, w1, w2, w3) must be explicitly listed. For models with many variants (Qwen3 has 10 configs), this is a lot of repetition. Consider whether a helper function or config inheritance could reduce this.


Regarding the RFC Discussion Point

The PR description asks:

Whether people agree with this direction, or simply want to use kwargs to pass the config when calling init_states.

The config-driven approach has meaningful advantages for serialization and reproducibility. However, the current implementation creates substantial boilerplate. A middle ground might be:

  1. Keep StateInitializer for modules where initialization is genuinely complex (e.g., depth-scaled init, weight tying).
  2. For leaf modules with straightforward init (Linear, Embedding, RMSNorm), consider a lighter-weight approach -- perhaps a callable or a simple config dict rather than a full Configurable subclass.
  3. The replace_state_init_field pattern for threading init_std down the tree works but is fragile. Consider whether init_std should just be a field on Module.Config directly (inherited by all modules) rather than nested inside StateInitializer.Config.
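A sketch of the lighter-weight option in point 2, under the assumption that simple leaf modules could take a plain callable instead of a full Configurable subclass (all names here are hypothetical):

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class LeafLinearConfig:
    in_features: int
    out_features: int
    # simple init is just a function; None means "use the default"
    weight_init: Optional[Callable] = None

def default_trunc_normal(shape, mean=0.0, std=0.02):
    # stand-in for nn.init.trunc_normal_; returns a description for the demo
    return ("trunc_normal", shape, mean, std)

cfg = LeafLinearConfig(in_features=4, out_features=8)
init_fn = cfg.weight_init or default_trunc_normal
print(init_fn((8, 4)))  # ('trunc_normal', (8, 4), 0.0, 0.02)
```

Complex cases (depth-scaled init, weight tying) would still use full StateInitializer subclasses; only the trivial leaves shrink.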

Minor Notes

  • The test coverage in test_module.py and test_linear.py is thorough and covers the diamond inheritance pattern, injection, and verification well.
  • The NoOpStateInitializer default is a good design choice -- modules that don't need custom init (like Flux components) work without any changes.
  • The capture_module_attrs / inject_module_protocol / verify_module_protocol workflow for post-quantization compatibility is well-designed.

nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)
nn.init.trunc_normal_(self.w2, mean=0.0, std=self._init_std)
can we also use self._init_mean?

class MoEStateInitializer(StateInitializer):
    @dataclass(kw_only=True, slots=True)
    class Config(StateInitializer.Config):
        init_std: float = 0.02
Same as the last comment: why not also introduce a mean here?

in_features=dim, out_features=num_experts
self.gate = (
Linear.Config(bias=gate_bias)
.replace_state_init_field(init_std=init_std)
I find it a little unintuitive to have replace_state_init_field here; can we modify the LinearStateInitializer config directly?


Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.
