Skip to content

Reduce mish error by an alternative without softplus op#2618

Open
ChinChangYang wants to merge 1 commit intoapple:mainfrom
ChinChangYang:reduce-mish-error
Open

Reduce mish error by an alternative without softplus op#2618
ChinChangYang wants to merge 1 commit intoapple:mainfrom
ChinChangYang:reduce-mish-error

Conversation

@ChinChangYang
Copy link
Copy Markdown
Contributor

Fix the high numerical error in mish activation #2359.

Algorithm:

e = exp(x)
mish = x / (1 + 2 / (e * (e + 2)))

Evaluation:

In the following experiments, the mean absolute errors are evaluated by the method in #2359 (comment).

Before this change, NE generates high numerical error:

Mean Absolute Errors Across Samples:
  var_17:
    NE:  2.955052
    GPU: 0.000998

With the new algorithm, NE generates low numerical error:

Mean Absolute Errors Across Samples:
  var_17:
    NE:  0.001744
    GPU: 0.001516

A tester reported that the new mish function generates NaN only when x is -Inf in the float16 format.

Performance:

This change has been adopted in KataGo Core ML backend ChinChangYang/KataGo#7. The performance of the KataGo model with the new mish activation (7.15 ms) is similar to the original mish implementation (7.03 ms).

Conclusion:

Overall, the change enhances the accuracy and reliability of the mish activation in Core ML models.

inputs = _get_inputs(context, node, expected=1)
x = inputs[0]

softplus = mb.softplus(x=x)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the PyTorch documentation, it seems the existing implementation is correct:
https://docs.pytorch.org/docs/stable/generated/torch.nn.Mish.html

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the existing (software) implementation is correct, it must be a hardware precision issue in the Neural Engine. This PR provides a (software) workaround to circumvent the precision issue. I anticipate that Apple’s low-level (hardware) developers will investigate this issue.

Copy link
Copy Markdown

@JiwaniZakir JiwaniZakir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The algebraic derivation is correct — x * tanh(ln(1+eˣ)) simplifies to x·eˣ·(eˣ+2) / (e²ˣ+2eˣ+2), which is equivalent to the new formulation. However, there is a numerical stability concern for large negative x values: as x → -∞, e = exp(x) → 0, causing emep2 = e*(e+2) → 0 and thus tdemep2 = 2/emep2 overflowing to infinity. The final real_div(x, inf) does produce the correct limit of 0, but this intermediate overflow may behave inconsistently across backends or hardware, which ironically trades one source of numerical error for another.

The original three-op path (softplus → tanh → mul) avoids this by computing softplus(x) = ln(1+eˣ) ≈ 0 directly for large negative x, never producing an overflow. It would strengthen this PR to include explicit test cases covering the large-negative-x regime (e.g., x = -30, -100) and to document which backends/targets exhibited the original softplus error, so reviewers can assess whether this tradeoff is worthwhile. The intermediate variable names (emep2, tdemep2, optdemep2) in ops.py are also difficult to parse; expanding the comment to label each step with the full subexpression (e.g., # 1 + 2/(e*(e+2))) would make the code far more maintainable.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants