|
1 | 1 | # Copyright 2023 Secure Saurce LLC |
| 2 | +import ast |
| 3 | + |
2 | 4 | from tree_sitter import Node |
3 | 5 |
|
4 | 6 | from precli.core.parser import Parser |
@@ -50,63 +52,102 @@ def import_from_statement(self, nodes: list[Node]) -> dict: |
50 | 52 |
|
51 | 53 | return imports |
52 | 54 |
|
def get_qual_value(self, context: dict, node: Node) -> str:
    """Resolve a node to its import-qualified value.

    Behavior by ``node.type``:

    * ``"keyword_argument"``: returns a one-entry dict mapping the text of
      the first child to the resolved value of the third child.
    * ``"identifier"``: returns ``context["imports"][node.text]`` when the
      name is a known import, otherwise ``None``.
    * ``"attribute"``: the part of the dotted text before the last ``.``
      is looked up in ``context["imports"]``; when found, that leading
      part is substituted with the mapped name. Otherwise ``None``.
    * anything else: ``None``.

    NOTE(review): the ``-> str`` annotation does not match the dict and
    ``None`` results; kept as-is for compatibility with existing callers.

    :param context: parser state; must hold an ``"imports"`` mapping
    :param node: node exposing ``type``, ``text`` (bytes) and ``children``
    """
    if node.type == "keyword_argument":
        keyword = node.children[0].text
        return {keyword: self.get_qual_value(context, node.children[2])}

    if node.type == "identifier":
        return context["imports"].get(node.text)

    if node.type == "attribute":
        text = node.text
        prefix = text.rpartition(b".")[0] if b"." in text else text
        if prefix in context["imports"]:
            # Substitute only the leading prefix: an unbounded replace
            # would also rewrite later occurrences of the alias
            # (e.g. ``np.npv`` -> ``numpy.numpyv``).
            return text.replace(prefix, context["imports"][prefix], 1)

    return None
72 | | - def call(self, context: dict, nodes: list[Node]): |
| 55 | + def get_call_arg(self, context: dict, node: Node) -> str: |
| 56 | + match node.type: |
| 57 | + case "attribute": |
| 58 | + attribute = node |
| 59 | + name = attribute.text |
| 60 | + if b"." in name: |
| 61 | + name = name.rpartition(b".")[0] |
| 62 | + if name in context["imports"]: |
| 63 | + qual_name = context["imports"][name] |
| 64 | + return attribute.text.replace(name, qual_name) |
| 65 | + # TODO: else return attr text? |
| 66 | + case "identifier": |
| 67 | + name = node.text |
| 68 | + if name in context["imports"]: |
| 69 | + return context["imports"][name] |
| 70 | + else: |
| 71 | + return name |
| 72 | + case "dictionary": |
| 73 | + # TODO: need to avoid use of decode |
| 74 | + return ast.literal_eval(node.text.decode()) |
| 75 | + case "list": |
| 76 | + # TODO: need to avoid use of decode |
| 77 | + return ast.literal_eval(node.text.decode()) |
| 78 | + case "tuple": |
| 79 | + # TODO: need to avoid use of decode |
| 80 | + return ast.literal_eval(node.text.decode()) |
| 81 | + case "string": |
| 82 | + # TODO: bytes and f-type strings are messed up |
| 83 | + return node.text |
| 84 | + case "integer": |
| 85 | + # TODO: hex, octal, binary |
| 86 | + return int(node.text) |
| 87 | + case "float": |
| 88 | + return float(node.text) |
| 89 | + case "true": |
| 90 | + return True |
| 91 | + case "false": |
| 92 | + return False |
| 93 | + case "none": |
| 94 | + return None |
| 95 | + case _: |
| 96 | + # TODO: complex |
| 97 | + print(node.type) |
| 98 | + print(node.text) |
| 99 | + |
def get_call_kwarg(self, context: dict, node: Node) -> dict:
    """Turn a keyword-argument node into a single-entry dict.

    The key is the text of the node's first child and the value is the
    third child resolved through ``get_call_arg``.

    :param context: parser state forwarded to ``get_call_arg``
    :param node: node whose ``children`` hold the keyword at index 0 and
        the value at index 2 (presumably with ``=`` in between)
    :return: ``{keyword_text: resolved_value}``
    """
    name_node = node.children[0]
    value_node = node.children[2]
    return {name_node.text: self.get_call_arg(context, value_node)}
def call(self, context: dict, nodes: list[Node]) -> tuple:
    """Extract the qualified name and arguments of a call.

    ``nodes`` is consumed with ``next()``, so it must be an iterator over
    the call node's children: first the callee, then its arguments.

    :param context: parser state forwarded to the argument resolvers
    :param nodes: iterator yielding the callee node and then the node
        holding the arguments
    :return: ``(func_call_qual, func_call_args, func_call_kwargs)`` where
        the args are a list of resolved positional values and the kwargs
        a dict of resolved keyword values
    """
    # Resolve the fully qualified function name
    func_node = next(nodes)
    func_call_qual = self.get_call_arg(context, func_node)

    # Get the arguments of the function call
    func_call_args = []
    func_call_kwargs = {}
    args_node = next(nodes)
    if args_node.type == "argument_list":
        # Punctuation and comments carry no argument value; comments
        # would otherwise be recorded as positional arguments.
        skipped = ("(", ",", ")", "comment")
        for child in args_node.children:
            if child.type in skipped:
                continue
            if child.type == "keyword_argument":
                func_call_kwargs.update(self.get_call_kwarg(context, child))
            else:
                func_call_args.append(self.get_call_arg(context, child))

    return (func_call_qual, func_call_args, func_call_kwargs)
88 | 127 |
|
89 | 128 | def parse(self, data: bytes) -> list[Result]: |
90 | | - results = [] |
| 129 | + results = list() |
91 | 130 | context = dict() |
92 | | - context["imports"] = {} |
| 131 | + context["imports"] = dict() |
93 | 132 | tree = self.parser.parse(data) |
94 | 133 |
|
95 | 134 | for node in Parser.traverse_tree(tree): |
96 | 135 | context["node"] = node |
97 | 136 | match node.type: |
98 | 137 | case "import_statement": |
99 | | - imps = self.import_statement(iter(node.children)) |
| 138 | + children = iter(node.children) |
| 139 | + imps = self.import_statement(children) |
100 | 140 | context["imports"].update(imps) |
101 | | - |
102 | 141 | case "import_from_statement": |
103 | | - imps = self.import_from_statement(iter(node.children)) |
| 142 | + children = iter(node.children) |
| 143 | + imps = self.import_from_statement(children) |
104 | 144 | context["imports"].update(imps) |
105 | | - |
106 | 145 | case "call": |
107 | | - (func, args) = self.call(context, iter(node.children)) |
| 146 | + children = iter(node.children) |
| 147 | + (func, args, kwargs) = self.call(context, children) |
108 | 148 | context["func_call_qual"] = func |
109 | 149 | context["func_call_args"] = args |
| 150 | + context["func_call_kwargs"] = kwargs |
110 | 151 |
|
111 | 152 | for rule in self.rules.values(): |
112 | 153 | result = rule.analyze(context) |
|
0 commit comments