@@ -128,6 +128,23 @@ def type(self) -> Union['MessageType', 'EnumType', 'PythonType']:
128128 raise TypeError ('Unrecognized protobuf type. This code should '
129129 'not be reachable; please file a bug.' )
130130
131+ def with_context (self , * , collisions : Set [str ]) -> 'Field' :
132+ """Return a derivative of this field with the provided context.
133+
134+ This method is used to address naming collisions. The returned
135+ ``Field`` object aliases module names to avoid naming collisions
136+ in the file being written.
137+ """
138+ return dataclasses .replace (self ,
139+ message = self .message .with_context (
140+ collisions = collisions ,
141+ skip_fields = True ,
142+ ) if self .message else None ,
143+ enum = self .enum .with_context (collisions = collisions )
144+ if self .enum else None ,
145+ meta = self .meta .with_context (collisions = collisions ),
146+ )
147+
131148
132149@dataclasses .dataclass (frozen = True )
133150class MessageType :
@@ -152,7 +169,13 @@ def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]:
152169 answer .append (field .type )
153170 return tuple (answer )
154171
155- def get_field (self , * field_path : Sequence [str ]) -> Field :
172+ @property
173+ def ident (self ) -> metadata .Address :
174+ """Return the identifier data to be used in templates."""
175+ return self .meta .address
176+
177+ def get_field (self , * field_path : Sequence [str ],
178+ collisions : Set [str ] = frozenset ()) -> Field :
156179 """Return a field arbitrarily deep in this message's structure.
157180
158181 This method recursively traverses the message tree to return the
@@ -171,12 +194,21 @@ def get_field(self, *field_path: Sequence[str]) -> Field:
171194 KeyError: If a repeated field is used in the non-terminal position
172195 in the path.
173196 """
197+ # If collisions are not explicitly specified, retrieve them
198+ # from this message's address.
199+ # This ensures that calls to `get_field` will return a field with
200+ # the same context, regardless of the number of levels through the
201+ # chain (in order to avoid infinite recursion on circular references,
202+ # we only shallowly bind message references held by fields; this
203+ # binds deeply in the one spot where that might be a problem).
204+ collisions = collisions or self .meta .address .collisions
205+
174206 # Get the first field in the path.
175207 cursor = self .fields [field_path [0 ]]
176208
177209 # Base case: If this is the last field in the path, return it outright.
178210 if len (field_path ) == 1 :
179- return cursor
211+ return cursor . with_context ( collisions = collisions )
180212
181213 # Sanity check: If cursor is a repeated field, then raise an exception.
182214 # Repeated fields are only permitted in the terminal position.
@@ -191,12 +223,37 @@ def get_field(self, *field_path: Sequence[str]) -> Field:
191223
192224 # Recursion case: Pass the remainder of the path to the sub-field's
193225 # message.
194- return cursor .message .get_field (* field_path [1 :])
226+ return cursor .message .get_field (* field_path [1 :], collisions = collisions )
195227
196- @property
197- def ident (self ) -> metadata .Address :
198- """Return the identifier data to be used in templates."""
199- return self .meta .address
228+ def with_context (self , * ,
229+ collisions : Set [str ],
230+ skip_fields : bool = False ,
231+ ) -> 'MessageType' :
232+ """Return a derivative of this message with the provided context.
233+
234+ This method is used to address naming collisions. The returned
235+ ``MessageType`` object aliases module names to avoid naming collisions
236+ in the file being written.
237+
238+ The ``skip_fields`` argument will omit applying the context to the
239+ underlying fields. This provides for an "exit" in the case of circular
240+ references.
241+ """
242+ return dataclasses .replace (self ,
243+ fields = collections .OrderedDict ([
244+ (k , v .with_context (collisions = collisions ))
245+ for k , v in self .fields .items ()
246+ ]) if not skip_fields else self .fields ,
247+ nested_enums = collections .OrderedDict ([
248+ (k , v .with_context (collisions = collisions ))
249+ for k , v in self .nested_enums .items ()
250+ ]),
251+ nested_messages = collections .OrderedDict ([(k , v .with_context (
252+ collisions = collisions ,
253+ skip_fields = skip_fields ,
254+ )) for k , v in self .nested_messages .items ()]),
255+ meta = self .meta .with_context (collisions = collisions ),
256+ )
200257
201258
202259@dataclasses .dataclass (frozen = True )
@@ -228,6 +285,17 @@ def ident(self) -> metadata.Address:
228285 """Return the identifier data to be used in templates."""
229286 return self .meta .address
230287
288+ def with_context (self , * , collisions : Set [str ]) -> 'EnumType' :
289+ """Return a derivative of this enum with the provided context.
290+
291+ This method is used to address naming collisions. The returned
292+ ``EnumType`` object aliases module names to avoid naming collisions in
293+ the file being written.
294+ """
295+ return dataclasses .replace (self ,
296+ meta = self .meta .with_context (collisions = collisions ),
297+ )
298+
231299
232300@dataclasses .dataclass (frozen = True )
233301class PythonType :
@@ -275,6 +343,7 @@ def meta(self) -> metadata.Metadata:
275343 name = 'Operation' ,
276344 module = 'operation' ,
277345 package = ('google' , 'api_core' ),
346+ collisions = self .lro_response .meta .address .collisions ,
278347 ),
279348 documentation = descriptor_pb2 .SourceCodeInfo .Location (
280349 leading_comments = 'An object representing a long-running '
@@ -298,6 +367,18 @@ def name(self) -> str:
298367 # on google.api_core just to get these strings.
299368 return 'Operation'
300369
370+ def with_context (self , * , collisions : Set [str ]) -> 'OperationType' :
371+ """Return a derivative of this operation with the provided context.
372+
373+ This method is used to address naming collisions. The returned
374+ ``OperationType`` object aliases module names to avoid naming
375+ collisions in the file being written.
376+ """
377+ return dataclasses .replace (self ,
378+ lro_response = self .lro_response .with_context (collisions = collisions ),
379+ lro_metadata = self .lro_metadata .with_context (collisions = collisions ),
380+ )
381+
301382
302383@dataclasses .dataclass (frozen = True )
303384class Method :
@@ -381,6 +462,19 @@ def signatures(self) -> Tuple[signature_pb2.MethodSignature]:
381462 # Done; return a tuple of signatures.
382463 return MethodSignatures (all = tuple (answer ))
383464
465+ def with_context (self , * , collisions : Set [str ]) -> 'Method' :
466+ """Return a derivative of this method with the provided context.
467+
468+ This method is used to address naming collisions. The returned
469+ ``Method`` object aliases module names to avoid naming collisions
470+ in the file being written.
471+ """
472+ return dataclasses .replace (self ,
473+ input = self .input .with_context (collisions = collisions ),
474+ output = self .output .with_context (collisions = collisions ),
475+ meta = self .meta .with_context (collisions = collisions ),
476+ )
477+
384478
385479@dataclasses .dataclass (frozen = True )
386480class MethodSignature :
@@ -519,11 +613,26 @@ def python_modules(self) -> Sequence[imp.Import]:
519613 answer = set ()
520614 for method in self .methods .values ():
521615 for t in method .ref_types :
522- answer .add (t .ident .context ( self ). python_import )
616+ answer .add (t .ident .python_import )
523617 return tuple (sorted (list (answer )))
524618
525619 @property
526620 def has_lro (self ) -> bool :
527621 """Return whether the service has a long-running method."""
528622 return any ([getattr (m .output , 'lro_response' , None )
529623 for m in self .methods .values ()])
624+
625+ def with_context (self , * , collisions : Set [str ]) -> 'Service' :
626+ """Return a derivative of this service with the provided context.
627+
628+ This method is used to address naming collisions. The returned
629+ ``Service`` object aliases module names to avoid naming collisions
630+ in the file being written.
631+ """
632+ return dataclasses .replace (self ,
633+ methods = collections .OrderedDict ([
634+ (k , v .with_context (collisions = collisions ))
635+ for k , v in self .methods .items ()
636+ ]),
637+ meta = self .meta .with_context (collisions = collisions ),
638+ )
0 commit comments