Skip to content

Commit 111985e

Browse files
dandhleetbpg
andauthored
fix: update annotation name extraction logic and add unit tests (#320)
* fix: update annotation name extraction logic * test: update unit test * test: update docstring for test --------- Co-authored-by: Tyler Bui-Palsulich <26876514+tbpg@users.noreply.github.com>
1 parent ff15676 commit 111985e

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

packages/gcp-sphinx-docfx-yaml/docfx_yaml/extension.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -747,17 +747,13 @@ def _extract_type_name(annotation: Any) -> str:
747747
Returns:
748748
The extracted type hint in human-readable string format.
749749
"""
750-
type_name = ""
751-
# Extract names for simple types.
752-
try:
750+
751+
annotation_dir = dir(annotation)
752+
if '__args__' not in annotation_dir:
753753
return annotation.__name__
754-
except AttributeError:
755-
pass
756754

757755
# Try to extract names for more complicated types.
758756
type_name = str(annotation)
759-
if not annotation.__args__:
760-
return type_name
761757

762758
# If ForwardRef references are found, recursively remove them.
763759
prefix_to_remove_start = "ForwardRef('"
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Tests for various inspect related utils."""
2+
from docfx_yaml import extension
3+
4+
import inspect
5+
import unittest
6+
from parameterized import parameterized
7+
8+
from typing import Any, Optional, Union
9+
10+
11+
class TestGenerate(unittest.TestCase):
12+
13+
types_to_test = [
14+
[
15+
# Test for a simple type without __args__.
16+
str,
17+
'str',
18+
],
19+
[
20+
# Test for a more complex type without forward reference.
21+
list[str],
22+
'list[str]',
23+
],
24+
[
25+
# Test for imported type, without forward reference.
26+
dict[str, Any],
27+
'dict[str, typing.Any]',
28+
],
29+
[
30+
# Test for forward reference.
31+
Optional["ForwardClass"],
32+
'typing.Optional[ForwardClass]'
33+
],
34+
[
35+
# Test for multiple forward references.
36+
Union["ForwardClass", "ForwardClass2"],
37+
'typing.Union[ForwardClass, ForwardClass2]'
38+
],
39+
]
40+
@parameterized.expand(types_to_test)
41+
def test_extracts_annotations(self, type_to_test, expected_type_name):
42+
"""Extracts annotations from test method, compares to expected name."""
43+
def test_method(name: type_to_test):
44+
pass
45+
46+
annotations = inspect.getfullargspec(test_method).annotations
47+
annotation_to_use = annotations['name']
48+
49+
extracted_annotation_name = extension._extract_type_name(
50+
annotation_to_use)
51+
52+
self.assertEqual(extracted_annotation_name, expected_type_name)
53+
54+
55+
if __name__ == '__main__':
56+
unittest.main()

0 commit comments

Comments
 (0)