diff --git a/src/__tests__/index.test.js b/src/__tests__/index.test.js
index 57b1910..b685fcf 100644
--- a/src/__tests__/index.test.js
+++ b/src/__tests__/index.test.js
@@ -164,4 +164,36 @@ describe('reactTreeWalker', () => {
expect(actual).toMatchObject(expected)
})
})
+
+ it('works with instance-as-result component', () => {
+ // eslint-disable-next-line react/prefer-stateless-function
+ class Baz extends Component {
+ render() {
+ return (
+
+
+
+
+ )
+ }
+ }
+ const Bar = props => new Baz(props)
+ const tree = (
+
+
+
+ )
+ const actual = []
+ // eslint-disable-next-line no-unused-vars
+ const visitor = (element, instance, context) => {
+ if (instance && typeof instance.getSomething === 'function') {
+ const something = instance.getSomething()
+ actual.push(something)
+ }
+ }
+ return reactTreeWalker(tree, visitor).then(() => {
+ const expected = [1, 2]
+ expect(actual).toEqual(expected)
+ })
+ })
})
diff --git a/src/index.js b/src/index.js
index 1144205..588e197 100644
--- a/src/index.js
+++ b/src/index.js
@@ -49,6 +49,11 @@ const pMapSeries = (iterable, iterator) => {
).then(() => ret)
}
+const ensureChild = child =>
+ child && typeof child.render === 'function'
+ ? ensureChild(child.render())
+ : child
+
export const isPromise = x => x != null && typeof x.then === 'function'
// Recurse an React Element tree, running visitor on each element.
@@ -68,7 +73,7 @@ export default function reactTreeWalker(
resolve()
}
- const child = getChildren()
+ const child = ensureChild(getChildren())
const theChildContext =
typeof childContext === 'function' ? childContext() : childContext