Commit 7d9f40f

Enable LoRA for TELinear layers (#13929)
* Enable LoRA for TELinear layers

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>

---------

Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>
Co-authored-by: cuichenx <cuichenx@users.noreply.github.com>
1 parent 6dbef01 commit 7d9f40f
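
In short: get_adapter_attributes_from_linear now also reports whether the base linear layer is tensor-parallel, and the LoRA, DoRA, and canonical-LoRA transforms thread that flag into their adapters. When the base layer is not parallel (per the title, a plain TELinear), the adapter gathers its column-parallel output and skips sequence-parallel communication.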

4 files changed: 23 additions & 5 deletions


nemo/collections/llm/peft/canonical_lora.py

Lines changed: 4 additions & 1 deletion
@@ -230,7 +230,9 @@ def transform(self, m: nn.Module, name=None, prefix=None):
                 m, dim=self.dim, alpha=self.alpha, dropout=self.dropout, lora_A_init_method=self.lora_A_init_method
             )
 
-            input_is_parallel, in_features, out_features, disable_sp_comm = get_adapter_attributes_from_linear(m)
+            input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel = (
+                get_adapter_attributes_from_linear(m)
+            )
 
             adapter_kwargs = dict(
                 dim=self.dim,
@@ -247,6 +249,7 @@ def transform(self, m: nn.Module, name=None, prefix=None):
                 alpha=self.alpha,
                 is_expert=is_expert_linear(full_name),
                 disable_sequence_parallel_comm=disable_sp_comm,
+                base_linear_is_parallel=base_linear_is_parallel,
             )
             if name in ['linear_proj', 'linear_fc2']:
                 adapter = ParallelLinearAdapter(in_features, out_features, **adapter_kwargs)
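
Note that the return-arity change makes stale call sites fail fast rather than misbehave: Python tuple unpacking is strict about length, which is why this file, dora.py, and lora.py below are all updated in the same commit. A tiny self-contained illustration (toy values standing in for the real return):

# Old four-name unpacking against the new five-element return raises immediately.
new_return = (False, 1024, 1024, True, False)  # toy stand-in for the new 5-tuple
try:
    input_is_parallel, in_features, out_features, disable_sp_comm = new_return
except ValueError as err:
    print(err)  # too many values to unpack (expected 4)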

nemo/collections/llm/peft/dora.py

Lines changed: 4 additions & 1 deletion
@@ -180,7 +180,9 @@ def transform(self, m: nn.Module, name=None, prefix=None):
         """
         if (ans := self.match(m, name, prefix)) is not None:
             (match, full_name) = ans
-            input_is_parallel, in_features, out_features, disable_sp_comm = get_adapter_attributes_from_linear(m)
+            input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel = (
+                get_adapter_attributes_from_linear(m)
+            )
             logging.info(f"Adding DoRA to: {full_name}")
             adapter = ParallelLinearDoRAAdapter(
                 in_features,
@@ -198,6 +200,7 @@ def transform(self, m: nn.Module, name=None, prefix=None):
                 model_parallel_config=getattr(m, "config", None),
                 alpha=self.alpha,
                 disable_sequence_parallel_comm=disable_sp_comm,
+                base_linear_is_parallel=base_linear_is_parallel,
             )
             return DoRALinear(m, adapter)
         return m

nemo/collections/llm/peft/lora.py

Lines changed: 4 additions & 1 deletion
@@ -462,7 +462,9 @@ def transform(self, m: nn.Module, name=None, prefix=None):
                 lora_dtype=self.lora_dtype,
             )
 
-            input_is_parallel, in_features, out_features, disable_sp_comm = get_adapter_attributes_from_linear(m)
+            input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel = (
+                get_adapter_attributes_from_linear(m)
+            )
             logging.info(f"Adding lora to: {full_name}")
             adapter = ParallelLinearAdapter(
                 in_features,
@@ -483,6 +485,7 @@ def transform(self, m: nn.Module, name=None, prefix=None):
                 a2a_experimental=self.a2a_experimental,
                 disable_sequence_parallel_comm=disable_sp_comm,
                 dropout_recompute=self.dropout_recompute,
+                base_linear_is_parallel=base_linear_is_parallel,
             )
             return LoRALinear(m, adapter)
         return m
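
All three transforms now share the same two-step pattern: unpack the fifth value and forward it to the adapter. A minimal sketch of that pattern, assuming m is a supported Megatron/TE linear module and that the omitted adapter kwargs keep their defaults (the toy dim value is not from the commit):

from nemo.collections.llm.peft.utils import ParallelLinearAdapter, get_adapter_attributes_from_linear

input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel = (
    get_adapter_attributes_from_linear(m)  # now returns five values instead of four
)
adapter = ParallelLinearAdapter(
    in_features,
    out_features,
    dim=8,  # LoRA rank; toy value, the real transforms pass self.dim plus more kwargs
    disable_sequence_parallel_comm=disable_sp_comm,
    base_linear_is_parallel=base_linear_is_parallel,  # the keyword added by this commit
)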

nemo/collections/llm/peft/utils.py

Lines changed: 11 additions & 2 deletions
@@ -73,7 +73,7 @@ def get_adapter_attributes_from_linear(m: nn.Module):
     Return input_is_parallel, in_features, out_feature attributes based on implementation of the base layer.
     """
     disable_sequence_parallel_comm = not m.config.sequence_parallel
-
+    base_linear_is_parallel = True
     if HAVE_TE and any(isinstance(m, te_column_parallel) for te_column_parallel in TECL):
         input_is_parallel = False
         # m.in_features and m.out_features are divided by tp_size already,
@@ -112,6 +112,7 @@ def get_adapter_attributes_from_linear(m: nn.Module):
         input_is_parallel = False
         in_features = m.in_features
         out_features = m.out_features
+        base_linear_is_parallel = False
     elif isinstance(m, ColumnParallelLinear):
         input_is_parallel = False
         in_features = m.input_size
@@ -123,7 +124,7 @@ def get_adapter_attributes_from_linear(m: nn.Module):
     else:
         raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}")
 
-    return input_is_parallel, in_features, out_features, disable_sequence_parallel_comm
+    return input_is_parallel, in_features, out_features, disable_sequence_parallel_comm, base_linear_is_parallel
 
 
 def is_expert_linear(fqn):
@@ -262,6 +263,7 @@ def __init__(
         is_expert: bool = False,
         disable_sequence_parallel_comm: bool = True,
         dropout_recompute: bool = False,
+        base_linear_is_parallel: bool = True,
         **kwargs,
     ):
         super().__init__()
@@ -310,6 +312,10 @@ def __init__(
         lin_out_gather_output = True if input_is_parallel else False
         if self.use_a2a and input_is_parallel and _sequence_parallel:
             lin_out_gather_output = False
+
+        if not base_linear_is_parallel:
+            lin_out_gather_output = True
+
         self.linear_out = ColumnParallelLinear(
             dim,
             out_features,
@@ -344,6 +350,9 @@ def __init__(
         if not _sequence_parallel:
             self.disable_sequence_parallel_comm = True
 
+        if not base_linear_is_parallel:
+            self.disable_sequence_parallel_comm = True
+
     def _get_init_fn(self, init_method: str):
         if init_method == 'xavier':
             init_fn = nn.init.xavier_normal_
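
The two __init__ hunks are where the flag takes effect. linear_out is a ColumnParallelLinear, so each tensor-parallel rank normally produces only out_features / tp_size columns; that sharded layout matches a base layer that is itself tensor-parallel, but not a plain full-width linear, so for a non-parallel base the shards must be gathered before the adapter output is combined with the base output. Sequence-parallel communication is disabled for the same reason: a non-parallel base layer works on full, unscattered sequences. A self-contained toy simulation of the shape argument (pure PyTorch, single process; tp_size and all dimensions are made up):

import torch

tp_size, batch, dim, out_features = 2, 4, 8, 16

x = torch.randn(batch, dim)
weight = torch.randn(out_features, dim)  # full weight of the adapter's linear_out

# A non-parallel base layer produces the full output width everywhere.
base_out = x @ weight.t()  # [batch, out_features]

# A ColumnParallelLinear splits the columns across tp_size ranks.
shards = weight.chunk(tp_size, dim=0)  # tp_size tensors of [out_features // tp_size, dim]
rank0_out = x @ shards[0].t()  # [batch, out_features // tp_size]

# Without gather_output=True, the sharded adapter output cannot be combined
# with the full-width base output:
assert rank0_out.shape != base_out.shape

# gather_output=True reassembles the full width (simulated here with a local
# concat; Megatron performs an all-gather across the tensor-parallel group).
gathered = torch.cat([x @ s.t() for s in shards], dim=-1)
assert torch.allclose(gathered, base_out, atol=1e-5)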
