2020from io import StringIO
2121import typing as t
2222
23+ from typing_extensions import Self , ParamSpec , Concatenate , TypeAlias
2324import numpy
2425from numpy .typing import ArrayLike , NDArray
2526import polars
2627import polars .dataframe .group_by
2728import polars .datatypes
2829import polars .interchange .dataframe
2930import polars .testing
30- import polars .type_aliases
31+ import polars ._typing
32+ from polars .schema import Schema
3133
32- from .types import to_vec3 , VecLike , ParamSpec , Concatenate , TypeAlias , Self
34+ from .types import to_vec3 , VecLike
3335from .bbox import BBox3D
3436from .elem import get_elem , get_sym , get_mass
3537from .transform import Transform3D , IntoTransform3D , AffineTransform3D
@@ -59,6 +61,7 @@ def _is_abstract(cls: t.Type) -> bool:
5961 return bool (getattr (cls , "__abstractmethods__" , False ))
6062
6163
64+ """
6265def _polars_to_numpy_dtype(dtype: t.Type[polars.DataType]) -> numpy.dtype:
6366 from polars.datatypes import dtype_to_ctype
6467 if dtype == polars.Boolean:
@@ -67,14 +70,15 @@ def _polars_to_numpy_dtype(dtype: t.Type[polars.DataType]) -> numpy.dtype:
6770 return numpy.dtype(dtype_to_ctype(dtype))
6871 except NotImplementedError:
6972 return numpy.dtype(object)
73+ """
7074
7175
7276def _get_symbol_mapping (df : t .Union [polars .DataFrame , HasAtoms ], mapping : t .Mapping [str , t .Any ], ty : t .Type [polars .DataType ]) -> polars .Expr :
7377 syms = df ['symbol' ].unique ()
7478 if (missing := set (syms ) - set (mapping .keys ())):
7579 raise ValueError (f"Could not remap symbols { ', ' .join (map (repr , missing ))} " )
7680
77- return polars .col ('symbol' ).replace (mapping , default = None , return_dtype = ty )
81+ return polars .col ('symbol' ).replace_strict (mapping , default = None , return_dtype = ty )
7882
7983
8084def _values_to_expr (df : t .Union [polars .DataFrame , HasAtoms ], values : AtomValues , ty : t .Type [polars .DataType ]) -> polars .Expr :
@@ -97,13 +101,8 @@ def _values_to_numpy(df: t.Union[polars.DataFrame, HasAtoms], values: AtomValues
97101 # syms = df.select(polars.col('symbol').filter(values.is_null())).unique().to_series().to_list()
98102 # raise ValueError(f"Could not remap symbols {', '.join(map(repr, syms))}")
99103 if isinstance (values , polars .Series ):
100- if ty == polars .Boolean :
101- # force conversion to numpy (unpacked) bool
102- return values .cast (polars .UInt8 ).to_numpy ().astype (numpy .bool_ )
103- return numpy .broadcast_to (values .cast (ty ).to_numpy (), len (df ))
104-
105- dtype = _polars_to_numpy_dtype (ty )
106- return numpy .broadcast_to (numpy .asarray (values , dtype ), len (df ))
104+ values = values .cast (ty )
105+ return numpy .broadcast_to (values , len (df ))
107106
108107
109108def _selection_to_expr (df : t .Union [polars .DataFrame , HasAtoms ], selection : t .Optional [AtomSelection ] = None ) -> polars .Expr :
@@ -127,7 +126,7 @@ def _select_schema(df: t.Union[polars.DataFrame, HasAtoms], schema: SchemaDict)
127126 polars .col (col ).cast (ty , strict = True )
128127 for (col , ty ) in schema .items ()
129128 ])
130- except (polars .ComputeError , polars .ColumnNotFoundError ):
129+ except (polars .exceptions . ComputeError , polars . exceptions .ColumnNotFoundError ):
131130 raise TypeError (f"Failed to cast '{ df .__class__ .__name__ } ' with schema '{ df .schema } ' to schema '{ schema } '." )
132131
133132
@@ -138,9 +137,10 @@ def _with_columns_stacked(df: polars.DataFrame, cols: t.Sequence[str], out_col:
138137 i = df .get_column_index (cols [0 ])
139138 dtype = df [cols [0 ]].dtype
140139
141- arr = numpy .array (tuple (df [c ].to_numpy () for c in cols )).T
140+ # https://github.com/pola-rs/polars/issues/18369
141+ arr = [] if len (df ) == 0 else numpy .array (tuple (df [c ].to_numpy () for c in cols )).T
142142
143- return df .drop (cols ).insert_column (i , polars .Series (out_col , arr , polars .Array (dtype , arr . shape [ - 1 ] )))
143+ return df .drop (cols ).insert_column (i , polars .Series (out_col , arr , polars .Array (dtype , len ( cols ) )))
144144
145145
146146HasAtomsT = t .TypeVar ('HasAtomsT' , bound = 'HasAtoms' )
@@ -250,8 +250,8 @@ def dtypes(self) -> t.List[polars.DataType]:
250250 ...
251251
252252 @property
253- @_fwd_frame (lambda df : df .schema )
254- def schema (self ) -> SchemaDict :
253+ @_fwd_frame (lambda df : df .schema ) # type: ignore
254+ def schema (self ) -> Schema :
255255 """
256256 Return the schema of `self`.
257257
@@ -288,7 +288,7 @@ def with_columns(self,
288288 def insert_column (self , index : int , column : polars .Series ) -> polars .DataFrame :
289289 return self ._get_frame ().insert_column (index , column )
290290
291- @_fwd_frame (polars . DataFrame . get_column )
291+ @_fwd_frame (lambda df , name : df . get_column ( name ) )
292292 def get_column (self , name : str ) -> polars .Series :
293293 """
294294 Get the specified column from `self`, raising [`polars.ColumnNotFoundError`][polars.exceptions.ColumnNotFoundError] if it's not present.
@@ -328,9 +328,9 @@ def clone(self) -> polars.DataFrame:
328328 """Return a copy of `self`."""
329329 return self ._get_frame ().clone ()
330330
331- def drop (self , * columns : t .Union [str , t .Iterable [str ]]) -> polars .DataFrame :
331+ def drop (self , * columns : t .Union [str , t .Iterable [str ]], strict : bool = True ) -> polars .DataFrame :
332332 """Return `self` with the specified columns removed."""
333- return self ._get_frame ().drop (* columns )
333+ return self ._get_frame ().drop (* columns , strict = strict )
334334
335335 # row-wise operations
336336
@@ -341,11 +341,10 @@ def filter(
341341 ) -> Self :
342342 """Filter `self`, removing rows which evaluate to `False`."""
343343 # TODO clean up
344- preds_not_none : t .Tuple [t .Union [IntoExprColumn , t .Iterable [IntoExprColumn ], bool , t .List [bool ], numpy .ndarray ], ...]
345- preds_not_none = tuple (filter (lambda p : p is not None , predicates )) # type: ignore
344+ preds_not_none = tuple (filter (lambda p : p is not None , predicates ))
346345 if not len (preds_not_none ) and not len (constraints ):
347346 return self
348- return self .with_atoms (Atoms (self ._get_frame ().filter (* preds_not_none , ** constraints ), _unchecked = True ))
347+ return self .with_atoms (Atoms (self ._get_frame ().filter (* preds_not_none , ** constraints ), _unchecked = True )) # type: ignore
349348
350349 @_fwd_frame_map
351350 def sort (
@@ -498,7 +497,7 @@ def select_props(
498497 A [`HasAtoms`][atomlib.atoms.HasAtoms] filtered to contain the
499498 specified properties (as well as required columns).
500499 """
501- props = self ._get_frame ().lazy ().select (* exprs , ** named_exprs ).drop (_REQUIRED_COLUMNS ).collect (_eager = True )
500+ props = self ._get_frame ().lazy ().select (* exprs , ** named_exprs ).drop (_REQUIRED_COLUMNS , strict = False ).collect (_eager = True )
502501 return self .with_atoms (
503502 Atoms (self ._get_frame ().select (_REQUIRED_COLUMNS ).hstack (props ), _unchecked = False )
504503 )
@@ -525,7 +524,7 @@ def try_get_column(self, name: str) -> t.Optional[polars.Series]:
525524 """Try to get a column from `self`, returning `None` if it doesn't exist."""
526525 try :
527526 return self .get_column (name )
528- except polars .ColumnNotFoundError :
527+ except polars .exceptions . ColumnNotFoundError :
529528 return None
530529
531530 def assert_equal (self , other : t .Any ):
@@ -555,7 +554,7 @@ def __radd__(self, other: IntoAtoms) -> HasAtoms:
555554 def __getitem__ (self , column : str ) -> polars .Series :
556555 try :
557556 return self .get_column (column )
558- except polars .ColumnNotFoundError :
557+ except polars .exceptions . ColumnNotFoundError :
559558 if column in ('x' , 'y' , 'z' ):
560559 return self .select (_coord_expr (column )).to_series ()
561560 raise
@@ -938,7 +937,8 @@ def with_coords(self, pts: ArrayLike, selection: t.Optional[AtomSelection] = Non
938937 new_pts [selection ] = pts
939938 pts = new_pts
940939
941- pts = numpy .broadcast_to (pts , (len (self ), 3 ))
940+ # https://github.com/pola-rs/polars/issues/18369
941+ pts = numpy .broadcast_to (pts , (len (self ), 3 )) if len (self ) else []
942942 return self .with_columns (polars .Series ('coords' , pts , polars .Array (polars .Float64 , 3 )))
943943
944944 def with_velocity (self , pts : t .Optional [ArrayLike ] = None ,
@@ -963,7 +963,8 @@ def with_velocity(self, pts: t.Optional[ArrayLike] = None,
963963 assert pts .shape [- 1 ] == 3
964964 all_pts [selection ] = pts
965965
966- all_pts = numpy .broadcast_to (all_pts , (len (self ), 3 ))
966+ # https://github.com/pola-rs/polars/issues/18369
967+ all_pts = numpy .broadcast_to (all_pts , (len (self ), 3 )) if len (self ) else []
967968 return self .with_columns (polars .Series ('velocity' , all_pts , polars .Array (polars .Float64 , 3 )))
968969
969970
@@ -1088,11 +1089,11 @@ def _repr_pretty_(self, p, cycle: bool) -> None:
10881089
10891090
10901091SchemaDict : TypeAlias = OrderedDict [str , polars .DataType ]
1091- IntoExprColumn : TypeAlias = polars .type_aliases .IntoExprColumn
1092- IntoExpr : TypeAlias = polars .type_aliases .IntoExpr
1093- UniqueKeepStrategy : TypeAlias = polars .type_aliases .UniqueKeepStrategy
1094- FillNullStrategy : TypeAlias = polars .type_aliases .FillNullStrategy
1095- RollingInterpolationMethod : TypeAlias = polars .type_aliases .RollingInterpolationMethod
1092+ IntoExprColumn : TypeAlias = polars ._typing .IntoExprColumn
1093+ IntoExpr : TypeAlias = polars ._typing .IntoExpr
1094+ UniqueKeepStrategy : TypeAlias = polars ._typing .UniqueKeepStrategy
1095+ FillNullStrategy : TypeAlias = polars ._typing .FillNullStrategy
1096+ RollingInterpolationMethod : TypeAlias = polars ._typing .RollingInterpolationMethod
10961097ConcatMethod : TypeAlias = t .Literal ['horizontal' , 'vertical' , 'diagonal' , 'inner' , 'align' ]
10971098
10981099IntoAtoms = t .Union [t .Dict [str , t .Sequence [t .Any ]], t .Sequence [t .Any ], numpy .ndarray , polars .DataFrame , 'Atoms' ]
@@ -1115,4 +1116,4 @@ def _repr_pretty_(self, p, cycle: bool) -> None:
11151116
11161117__all__ = [
11171118 'Atoms' , 'HasAtoms' , 'IntoAtoms' , 'AtomSelection' , 'AtomValues' ,
1118- ]
1119+ ]
0 commit comments