Skip to content

Commit 7457c37

Browse files
authored
Fix for flattened unit tests with deeply nested request messages (#346)
Expands the unit-test visible method ref types to include _all_ recursive types. Do NOT use this deep nesting in the client submodule. The client ONLY needs to import, for each method, the request type, response type, and the types of any flattened fields.
1 parent fdef285 commit 7457c37

File tree

8 files changed

+157
-23
lines changed

8 files changed

+157
-23
lines changed

packages/gapic-generator/gapic/schema/api.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@
1919

2020
import collections
2121
import dataclasses
22-
import keyword
2322
import os
2423
import sys
25-
from itertools import chain
2624
from typing import Callable, Container, Dict, FrozenSet, Mapping, Optional, Sequence, Set, Tuple
2725

2826
from google.api_core import exceptions # type: ignore
@@ -37,6 +35,7 @@
3735
from gapic.schema import naming as api_naming
3836
from gapic.utils import cached_property
3937
from gapic.utils import to_snake_case
38+
from gapic.utils import RESERVED_NAMES
4039

4140

4241
@dataclasses.dataclass(frozen=True)
@@ -130,13 +129,13 @@ def names(self) -> FrozenSet[str]:
130129
# from distinct packages.
131130
modules: Dict[str, Set[str]] = collections.defaultdict(set)
132131
for m in self.all_messages.values():
133-
for t in m.field_types:
132+
for t in m.recursive_field_types:
134133
modules[t.ident.module].add(t.ident.package)
135134

136135
answer.update(
137136
module_name
138137
for module_name, packages in modules.items()
139-
if len(packages) > 1
138+
if len(packages) > 1 or module_name in RESERVED_NAMES
140139
)
141140

142141
# Return the set of collision names.
@@ -229,7 +228,7 @@ def disambiguate_keyword_fname(
229228
visited_names: Container[str]) -> str:
230229
path, fname = os.path.split(full_path)
231230
name, ext = os.path.splitext(fname)
232-
if name in keyword.kwlist or full_path in visited_names:
231+
if name in RESERVED_NAMES or full_path in visited_names:
233232
name += "_"
234233
full_path = os.path.join(path, name + ext)
235234
if full_path in visited_names:

packages/gapic-generator/gapic/schema/wrappers.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import dataclasses
3232
import re
3333
from itertools import chain
34-
from typing import (cast, Dict, FrozenSet, List, Mapping, Optional,
34+
from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping, Optional,
3535
Sequence, Set, Union)
3636

3737
from google.api import annotations_pb2 # type: ignore
@@ -225,7 +225,6 @@ def __hash__(self):
225225

226226
@utils.cached_property
227227
def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]:
228-
"""Return all composite fields used in this proto's messages."""
229228
answer = tuple(
230229
field.type
231230
for field in self.fields.values()
@@ -234,6 +233,23 @@ def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]:
234233

235234
return answer
236235

236+
@utils.cached_property
237+
def recursive_field_types(self) -> Sequence[
238+
Union['MessageType', 'EnumType']
239+
]:
240+
"""Return all composite fields used in this proto's messages."""
241+
types: List[Union['MessageType', 'EnumType']] = []
242+
stack = [iter(self.fields.values())]
243+
while stack:
244+
fields_iter = stack.pop()
245+
for field in fields_iter:
246+
if field.message and field.type not in types:
247+
stack.append(iter(field.message.fields.values()))
248+
if not field.is_primitive:
249+
types.append(field.type)
250+
251+
return tuple(types)
252+
237253
@property
238254
def map(self) -> bool:
239255
"""Return True if the given message is a map, False otherwise."""
@@ -654,19 +670,30 @@ def paged_result_field(self) -> Optional[Field]:
654670

655671
@utils.cached_property
656672
def ref_types(self) -> Sequence[Union[MessageType, EnumType]]:
673+
return self._ref_types(True)
674+
675+
@utils.cached_property
676+
def flat_ref_types(self) -> Sequence[Union[MessageType, EnumType]]:
677+
return self._ref_types(False)
678+
679+
def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]:
657680
"""Return types referenced by this method."""
658681
# Begin with the input (request) and output (response) messages.
659-
answer = [self.input]
682+
answer: List[Union[MessageType, EnumType]] = [self.input]
683+
types: Iterable[Union[MessageType, EnumType]] = (
684+
self.input.recursive_field_types if recursive
685+
else (
686+
f.type
687+
for f in self.flattened_fields.values()
688+
if f.message or f.enum
689+
)
690+
)
691+
answer.extend(types)
692+
660693
if not self.void:
661694
answer.append(self.client_output)
662695
answer.extend(self.client_output.field_types)
663696

664-
answer.extend(
665-
field.type
666-
for field in self.flattened_fields.values()
667-
if field.message or field.enum
668-
)
669-
670697
# If this method has LRO, it is possible (albeit unlikely) that
671698
# the LRO messages reside in a different module.
672699
if self.lro:

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ from google.oauth2 import service_account # type: ignore
1414

1515
{% filter sort_lines -%}
1616
{% for method in service.methods.values() -%}
17-
{% for ref_type in method.ref_types -%}
17+
{% for ref_type in method.flat_ref_types -%}
1818
{{ ref_type.ident.python_import }}
1919
{% endfor -%}
2020
{% endfor -%}

packages/gapic-generator/gapic/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from gapic.utils.filename import to_valid_module_name
2222
from gapic.utils.lines import sort_lines
2323
from gapic.utils.lines import wrap
24+
from gapic.utils.reserved_names import RESERVED_NAMES
2425
from gapic.utils.rst import rst
2526

2627

@@ -29,6 +30,7 @@
2930
'doc',
3031
'empty',
3132
'partition',
33+
'RESERVED_NAMES',
3234
'rst',
3335
'sort_lines',
3436
'to_snake_case',
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import keyword
16+
17+
18+
RESERVED_NAMES = frozenset(keyword.kwlist)

packages/gapic-generator/test_utils/test_utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,15 @@ def get_message(dot_path: str, *,
145145

146146

147147
def make_method(
148-
name: str, input_message: wrappers.MessageType = None,
148+
name: str,
149+
input_message: wrappers.MessageType = None,
149150
output_message: wrappers.MessageType = None,
150151
package: typing.Union[typing.Tuple[str], str] = 'foo.bar.v1',
151152
module: str = 'baz',
152153
http_rule: http_pb2.HttpRule = None,
153154
signatures: typing.Sequence[str] = (),
154-
**kwargs) -> wrappers.Method:
155+
**kwargs
156+
) -> wrappers.Method:
155157
# Use default input and output messages if they are not provided.
156158
input_message = input_message or make_message('MethodInput')
157159
output_message = output_message or make_message('MethodOutput')
@@ -229,11 +231,14 @@ def make_field(
229231
)
230232

231233

232-
def make_message(name: str, package: str = 'foo.bar.v1', module: str = 'baz',
233-
fields: typing.Sequence[wrappers.Field] = (),
234-
meta: metadata.Metadata = None,
235-
options: desc.MethodOptions = None,
236-
) -> wrappers.MessageType:
234+
def make_message(
235+
name: str,
236+
package: str = 'foo.bar.v1',
237+
module: str = 'baz',
238+
fields: typing.Sequence[wrappers.Field] = (),
239+
meta: metadata.Metadata = None,
240+
options: desc.MethodOptions = None,
241+
) -> wrappers.MessageType:
237242
message_pb = desc.DescriptorProto(
238243
name=name,
239244
field=[i.field_pb for i in fields],

packages/gapic-generator/tests/unit/schema/wrappers/test_message.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,15 @@ def test_get_field():
7070

7171
def test_field_types():
7272
# Create the inner message.
73-
inner_msg = make_message('InnerMessage', fields=())
73+
inner_msg = make_message(
74+
'InnerMessage',
75+
fields=(
76+
make_field(
77+
'hidden_message',
78+
message=make_message('HiddenMessage'),
79+
),
80+
)
81+
)
7482
inner_enum = make_enum('InnerEnum')
7583

7684
# Create the outer message, which contains an Inner as a field.
@@ -87,6 +95,33 @@ def test_field_types():
8795
assert inner_enum in outer.field_types
8896

8997

98+
def test_field_types_recursive():
99+
enumeration = make_enum('Enumeration')
100+
innest_msg = make_message(
101+
'InnestMessage',
102+
fields=(
103+
make_field('enumeration', enum=enumeration),
104+
)
105+
)
106+
inner_msg = make_message(
107+
'InnerMessage',
108+
fields=(
109+
make_field('innest_message', message=innest_msg),
110+
)
111+
)
112+
topmost_msg = make_message(
113+
'TopmostMessage',
114+
fields=(
115+
make_field('inner_message', message=inner_msg),
116+
make_field('uninteresting')
117+
)
118+
)
119+
120+
actual = {t.name for t in topmost_msg.recursive_field_types}
121+
expected = {t.name for t in (enumeration, innest_msg, inner_msg)}
122+
assert actual == expected
123+
124+
90125
def test_get_field_recursive():
91126
# Create the inner message.
92127
inner_fields = (make_field('zero'), make_field('one'))

packages/gapic-generator/tests/unit/schema/wrappers/test_method.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from gapic.schema import wrappers
2424

2525
from test_utils.test_utils import (
26+
make_enum,
2627
make_field,
2728
make_message,
2829
make_method,
@@ -151,6 +152,53 @@ def test_method_paged_result_ref_types():
151152
}
152153

153154

155+
def test_flattened_ref_types():
156+
method = make_method(
157+
'IdentifyMollusc',
158+
input_message=make_message(
159+
'IdentifyMolluscRequest',
160+
fields=(
161+
make_field(
162+
'cephalopod',
163+
message=make_message(
164+
'Cephalopod',
165+
fields=(
166+
make_field('mass_kg', type='TYPE_INT32'),
167+
make_field(
168+
'squid',
169+
number=2,
170+
message=make_message('Squid'),
171+
),
172+
make_field(
173+
'clam',
174+
number=3,
175+
message=make_message('Clam'),
176+
),
177+
),
178+
),
179+
),
180+
make_field(
181+
'stratum',
182+
enum=make_enum(
183+
'Stratum',
184+
)
185+
),
186+
),
187+
),
188+
signatures=('cephalopod.squid,stratum',),
189+
output_message=make_message('Mollusc'),
190+
)
191+
192+
expected_flat_ref_type_names = {
193+
'IdentifyMolluscRequest',
194+
'Squid',
195+
'Stratum',
196+
'Mollusc',
197+
}
198+
actual_flat_ref_type_names = {t.name for t in method.flat_ref_types}
199+
assert expected_flat_ref_type_names == actual_flat_ref_type_names
200+
201+
154202
def test_method_field_headers_none():
155203
method = make_method('DoSomething')
156204
assert isinstance(method.field_headers, collections.abc.Sequence)

0 commit comments

Comments
 (0)