66from . import GraphAcc , LayerAcc , MetaVecAcc , MultiAcc
77
88if TYPE_CHECKING :
9+ from collections .abc import Callable
910 from typing import Literal
1011
11- from . import AdAcc , AdPath , Idx2D
12+ from . import AdAcc , AdPath , GraphVecAcc , Idx2D , LayerVecAcc
1213
1314
1415@overload
@@ -23,21 +24,19 @@ def parse[P: AdPath](a: AdAcc[P], spec: str, *, strict: bool = True) -> P | None
2324 except ValueError :
2425 return None
2526
27+ if spec .startswith ("X[" ):
28+ return _parse_path_2d (lambda _ : a , spec )
2629 if "." not in spec :
2730 msg = f"Cannot parse accessor { spec !r} "
2831 raise ValueError (msg )
29- acc , rest = spec .split ("." , 1 )
30- match getattr (a , acc , None ):
31- # TODO: X
32- case LayerAcc () as layers :
33- return _parse_path_layer (layers , rest )
32+ acc_name , rest = spec .split ("." , 1 )
33+ match getattr (a , acc_name , None ):
34+ case LayerAcc () | GraphAcc () as acc :
35+ return _parse_path_2d (acc .__getitem__ , rest )
3436 case MetaVecAcc () as meta :
3537 return meta [rest ]
3638 case MultiAcc () as multi :
3739 return _parse_path_multi (multi , rest )
38- case GraphAcc ():
39- msg = "TODO"
40- raise NotImplementedError (msg )
4140 case None : # pragma: no cover
4241 msg = (
4342 f"Unknown accessor { spec !r} . "
@@ -48,14 +47,17 @@ def parse[P: AdPath](a: AdAcc[P], spec: str, *, strict: bool = True) -> P | None
4847 raise AssertionError (msg ) # pragma: no cover
4948
5049
51- def _parse_path_layer [P : AdPath ](layers : LayerAcc [P ], spec : str ) -> P :
50+ def _parse_path_2d [P : AdPath ](
51+ get_vec_acc : Callable [[str ], LayerVecAcc [P ] | GraphVecAcc [P ]], spec : str
52+ ) -> P :
5253 if not (
5354 m := re .fullmatch (r"([^\[]+)\[([^,]+),\s?([^\]]+)\]" , spec )
5455 ): # pragma: no cover
55- msg = f"Cannot parse layer accessor { spec !r} : should be `name[i,:]`/`name[:,j]`"
56+ msg = f"Cannot parse accessor { spec !r} : should be `name[i,:]`/`name[:,j]`"
5657 raise ValueError (msg )
57- layer , i , j = m .groups ("" ) # "" just for typing
58- return layers [layer ][_parse_idx_2d (i , j , str )]
58+ name , i , j = m .groups ("" ) # "" just for typing
59+ vec_acc = get_vec_acc (name )
60+ return vec_acc [_parse_idx_2d (i , j , str )]
5961
6062
6163def _parse_path_multi [P : AdPath ](multi : MultiAcc [P ], spec : str ) -> P :
@@ -68,6 +70,8 @@ def _parse_path_multi[P: AdPath](multi: MultiAcc[P], spec: str) -> P:
6870
6971def _parse_idx_2d [Idx : int | str ](i : str , j : str , cls : type [Idx ]) -> Idx2D [Idx ]:
7072 match i , j :
73+ case ":" , ":" :
74+ return slice (None ), slice (None )
7175 case _, ":" :
7276 return cls (i ), slice (None )
7377 case ":" , _:
0 commit comments