Reduce mish error by an alternative without softplus op #2618
ChinChangYang wants to merge 1 commit into apple:main
Conversation
Inline comment on ops.py:

```python
inputs = _get_inputs(context, node, expected=1)
x = inputs[0]
softplus = mb.softplus(x=x)
```
Looking at the PyTorch documentation, it seems the existing implementation is correct:
https://docs.pytorch.org/docs/stable/generated/torch.nn.Mish.html
If the existing (software) implementation is correct, it must be a hardware precision issue in the Neural Engine. This PR provides a (software) workaround to circumvent the precision issue. I anticipate that Apple’s low-level (hardware) developers will investigate this issue.
JiwaniZakir left a comment:
The algebraic derivation is correct — x * tanh(ln(1+eˣ)) simplifies to x·eˣ·(eˣ+2) / (e²ˣ+2eˣ+2), which is equivalent to the new formulation. However, there is a numerical stability concern for large negative x values: as x → -∞, e = exp(x) → 0, causing emep2 = e*(e+2) to underflow to 0 and thus tdemep2 = 2/emep2 to overflow to infinity. The final real_div(x, inf) does produce the correct limit of 0, but this intermediate overflow may behave inconsistently across backends or hardware, which ironically trades one source of numerical error for another.
The original three-op path (softplus → tanh → mul) avoids this by computing softplus(x) = ln(1+eˣ) ≈ 0 directly for large negative x, never producing an overflow. It would strengthen this PR to include explicit test cases covering the large-negative-x regime (e.g., x = -30, -100) and to document which backends/targets exhibited the original softplus error, so reviewers can assess whether this tradeoff is worthwhile. The intermediate variable names (emep2, tdemep2, optdemep2) in ops.py are also difficult to parse; expanding the comment to label each step with the full subexpression (e.g., # 1 + 2/(e*(e+2))) would make the code far more maintainable.
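The intermediate-overflow concern above can be reproduced outside Core ML. The sketch below evaluates the reformulated path step by step in NumPy float16 — an approximation of a half-precision backend, not the Neural Engine's actual arithmetic — using the intermediate names from the review comment:

```python
import numpy as np

def mish_reformulated_fp16(x):
    # Step-by-step evaluation of x / (1 + 2/(e*(e+2))) in float16.
    # This only approximates a half-precision backend; the Neural
    # Engine's real arithmetic may differ.
    x = np.float16(x)
    with np.errstate(divide="ignore", invalid="ignore"):
        e = np.exp(x)                      # underflows to 0 for x below about -17 in fp16
        emep2 = e * (e + np.float16(2))    # e*(e+2): becomes exactly 0
        tdemep2 = np.float16(2) / emep2    # 2/(e*(e+2)): 2/0 -> inf
        return x / (np.float16(1) + tdemep2)

print(mish_reformulated_fp16(-30.0))          # intermediate inf, final result still -0.0
print(mish_reformulated_fp16(float("-inf")))  # -inf/inf -> nan, matching the tester's report
```

For finite large-negative x the division by infinity recovers the correct limit of 0; only x = -Inf produces NaN, which is consistent with the tester's float16 observation in the PR description.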
Fix the high numerical error in mish activation #2359.
Algorithm:
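The op sequence is not inlined here, but from the derivation discussed in review, the new path rewrites mish(x) = x·tanh(ln(1+eˣ)) as x / (1 + 2/(e·(e+2))) with e = exp(x), eliminating the softplus op. A minimal NumPy sketch (variable names mirror the emep2/tdemep2/optdemep2 intermediates mentioned in review; this is an illustration, not the exact MIL builder code):

```python
import numpy as np

def mish_no_softplus(x):
    # tanh(ln(y)) = (y^2 - 1)/(y^2 + 1); with y = 1 + e^x this gives
    # mish(x) = x * e*(e+2)/(e^2+2e+2) = x / (1 + 2/(e*(e+2)))
    e = np.exp(x)
    emep2 = e * (e + 2.0)        # e*(e+2)
    tdemep2 = 2.0 / emep2        # 2/(e*(e+2))
    optdemep2 = 1.0 + tdemep2    # 1 + 2/(e*(e+2))
    return x / optdemep2

def mish_reference(x):
    # original three-op path: softplus -> tanh -> mul
    return x * np.tanh(np.log1p(np.exp(x)))

x = np.linspace(-5.0, 5.0, 101)
assert np.allclose(mish_no_softplus(x), mish_reference(x), atol=1e-12)
```

In float64 the two formulations agree to machine precision over moderate inputs; the difference only matters in reduced-precision arithmetic.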
Evaluation:
In the following experiments, the mean absolute errors are evaluated by the method in #2359 (comment).
Before this change, NE generates high numerical error:
With the new algorithm, NE generates low numerical error:
A tester reported that the new mish function generates NaN only when x is -Inf in the float16 format.
Performance:
This change has been adopted in KataGo Core ML backend ChinChangYang/KataGo#7. The performance of the KataGo model with the new mish activation (7.15 ms) is similar to the original mish implementation (7.03 ms).
Conclusion:
Overall, the change enhances the accuracy and reliability of the mish activation in Core ML models.