[RFC][Module] Replace init_weights with StateInitializer pattern#2540
[RFC][Module] Replace init_weights with StateInitializer pattern#2540fegin wants to merge 2 commits intogh/fegin/92/basefrom
Conversation
Key changes: 1. New StateInitializer(Configurable) base class. 2. Module.Config gains state_initializer field, it is not optional. 3. Remove *kwargs from init_states/init_weights. All changes must go through StateInitializer.Config before constructing a module. So the parent module is responsible to pass through the state initialization config before constructing children. Item 3 is the key change that needs to discuss. Whether people agree this direction or people simply want to use kwargs to pass the config when calling init_states. ghstack-source-id: e373bb8 Pull-Request: #2540
Key changes: 1. New StateInitializer(Configurable) base class. 2. Module.Config gains state_initializer field, it is not optional. 3. Remove *kwargs from init_states/init_weights. All changes must go through StateInitializer.Config before constructing a module. So the parent module is responsible to pass through the state initialization config before constructing children. Item 3 is the key change that needs to discuss. Whether people agree this direction or people simply want to use kwargs to pass the config when calling init_states. ghstack-source-id: b726dfb Pull-Request: #2540
|
@claude review this pr |
|
Claude finished @acisseJZhong's task in 2m 25s —— View job PR Review: [RFC][Module] Replace init_weights with StateInitializer pattern
SummaryThis RFC introduces a Design AssessmentThe core idea is sound. Making weight initialization declarative and config-driven rather than imperative/kwargs-driven has clear benefits:
However, there are several concerns and issues worth discussing. Issues and Concerns1. Proliferation of StateInitializer classesEvery module now needs its own
Many of these (e.g., 2.
|
| nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) | ||
| nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) | ||
| nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std) | ||
| nn.init.trunc_normal_(self.w2, mean=0.0, std=self._init_std) |
There was a problem hiding this comment.
can we also use self._init_mean?
| class MoEStateInitializer(StateInitializer): | ||
| @dataclass(kw_only=True, slots=True) | ||
| class Config(StateInitializer.Config): | ||
| init_std: float = 0.02 |
There was a problem hiding this comment.
same to last comment, why not also introducing mean here?
| in_features=dim, out_features=num_experts | ||
| self.gate = ( | ||
| Linear.Config(bias=gate_bias) | ||
| .replace_state_init_field(init_std=init_std) |
There was a problem hiding this comment.
I found a little bit unintuitive to have replace_state_init_field here, can we modify LinearStateInitializer directly?
Stack from ghstack (oldest at bottom):
Key changes:
go through StateInitializer.Config before constructing a module.
So the parent module is responsible to pass through the state
initialization config before constructing children.
Item 3 is the key change that needs to discuss. Whether people agree this direction or people simply want to use kwargs to pass the config when calling init_states.