diff --git a/src/__tests__/index.test.js b/src/__tests__/index.test.js index 57b1910..b685fcf 100644 --- a/src/__tests__/index.test.js +++ b/src/__tests__/index.test.js @@ -164,4 +164,36 @@ describe('reactTreeWalker', () => { expect(actual).toMatchObject(expected) }) }) + + it('works with instance-as-result component', () => { + // eslint-disable-next-line react/prefer-stateless-function + class Baz extends Component { + render() { + return ( +
+ + +
+ ) + } + } + const Bar = props => new Baz(props) + const tree = ( +
+ +
+ ) + const actual = [] + // eslint-disable-next-line no-unused-vars + const visitor = (element, instance, context) => { + if (instance && typeof instance.getSomething === 'function') { + const something = instance.getSomething() + actual.push(something) + } + } + return reactTreeWalker(tree, visitor).then(() => { + const expected = [1, 2] + expect(actual).toEqual(expected) + }) + }) }) diff --git a/src/index.js b/src/index.js index 1144205..588e197 100644 --- a/src/index.js +++ b/src/index.js @@ -49,6 +49,11 @@ const pMapSeries = (iterable, iterator) => { ).then(() => ret) } +const ensureChild = child => + child && typeof child.render === 'function' + ? ensureChild(child.render()) + : child + export const isPromise = x => x != null && typeof x.then === 'function' // Recurse an React Element tree, running visitor on each element. @@ -68,7 +73,7 @@ export default function reactTreeWalker( resolve() } - const child = getChildren() + const child = ensureChild(getChildren()) const theChildContext = typeof childContext === 'function' ? childContext() : childContext