Skip to content

Commit aba34c1

Browse files
committed
Polars version update
1 parent a28c52f commit aba34c1

File tree

16 files changed

+431
-92
lines changed

16 files changed

+431
-92
lines changed

atomlib/atomcell.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
import copy
1313
import typing as t
1414

15+
from typing_extensions import Self
1516
import numpy
1617
from numpy.typing import NDArray, ArrayLike
1718
import polars
1819
import polars.dataframe.group_by
1920

21+
from typing_extensions import ParamSpec, Concatenate
2022
from .bbox import BBox3D
21-
from .types import VecLike, to_vec3, ParamSpec, Concatenate, Self
23+
from .types import VecLike, to_vec3
2224
from .transform import LinearTransform3D, AffineTransform3D, Transform3D, IntoTransform3D
2325
from .cell import CoordinateFrame, HasCell, Cell
2426
from .atoms import HasAtoms, Atoms, IntoAtoms, AtomSelection, AtomValues

atomlib/atoms.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,18 @@
2020
from io import StringIO
2121
import typing as t
2222

23+
from typing_extensions import Self, ParamSpec, Concatenate, TypeAlias
2324
import numpy
2425
from numpy.typing import ArrayLike, NDArray
2526
import polars
2627
import polars.dataframe.group_by
2728
import polars.datatypes
2829
import polars.interchange.dataframe
2930
import polars.testing
30-
import polars.type_aliases
31+
import polars._typing
32+
from polars.schema import Schema
3133

32-
from .types import to_vec3, VecLike, ParamSpec, Concatenate, TypeAlias, Self
34+
from .types import to_vec3, VecLike
3335
from .bbox import BBox3D
3436
from .elem import get_elem, get_sym, get_mass
3537
from .transform import Transform3D, IntoTransform3D, AffineTransform3D
@@ -59,6 +61,7 @@ def _is_abstract(cls: t.Type) -> bool:
5961
return bool(getattr(cls, "__abstractmethods__", False))
6062

6163

64+
"""
6265
def _polars_to_numpy_dtype(dtype: t.Type[polars.DataType]) -> numpy.dtype:
6366
from polars.datatypes import dtype_to_ctype
6467
if dtype == polars.Boolean:
@@ -67,14 +70,15 @@ def _polars_to_numpy_dtype(dtype: t.Type[polars.DataType]) -> numpy.dtype:
6770
return numpy.dtype(dtype_to_ctype(dtype))
6871
except NotImplementedError:
6972
return numpy.dtype(object)
73+
"""
7074

7175

7276
def _get_symbol_mapping(df: t.Union[polars.DataFrame, HasAtoms], mapping: t.Mapping[str, t.Any], ty: t.Type[polars.DataType]) -> polars.Expr:
7377
syms = df['symbol'].unique()
7478
if (missing := set(syms) - set(mapping.keys())):
7579
raise ValueError(f"Could not remap symbols {', '.join(map(repr, missing))}")
7680

77-
return polars.col('symbol').replace(mapping, default=None, return_dtype=ty)
81+
return polars.col('symbol').replace_strict(mapping, default=None, return_dtype=ty)
7882

7983

8084
def _values_to_expr(df: t.Union[polars.DataFrame, HasAtoms], values: AtomValues, ty: t.Type[polars.DataType]) -> polars.Expr:
@@ -97,13 +101,8 @@ def _values_to_numpy(df: t.Union[polars.DataFrame, HasAtoms], values: AtomValues
97101
# syms = df.select(polars.col('symbol').filter(values.is_null())).unique().to_series().to_list()
98102
# raise ValueError(f"Could not remap symbols {', '.join(map(repr, syms))}")
99103
if isinstance(values, polars.Series):
100-
if ty == polars.Boolean:
101-
# force conversion to numpy (unpacked) bool
102-
return values.cast(polars.UInt8).to_numpy().astype(numpy.bool_)
103-
return numpy.broadcast_to(values.cast(ty).to_numpy(), len(df))
104-
105-
dtype = _polars_to_numpy_dtype(ty)
106-
return numpy.broadcast_to(numpy.asarray(values, dtype), len(df))
104+
values = values.cast(ty)
105+
return numpy.broadcast_to(values, len(df))
107106

108107

109108
def _selection_to_expr(df: t.Union[polars.DataFrame, HasAtoms], selection: t.Optional[AtomSelection] = None) -> polars.Expr:
@@ -127,7 +126,7 @@ def _select_schema(df: t.Union[polars.DataFrame, HasAtoms], schema: SchemaDict)
127126
polars.col(col).cast(ty, strict=True)
128127
for (col, ty) in schema.items()
129128
])
130-
except (polars.ComputeError, polars.ColumnNotFoundError):
129+
except (polars.exceptions.ComputeError, polars.exceptions.ColumnNotFoundError):
131130
raise TypeError(f"Failed to cast '{df.__class__.__name__}' with schema '{df.schema}' to schema '{schema}'.")
132131

133132

@@ -138,9 +137,10 @@ def _with_columns_stacked(df: polars.DataFrame, cols: t.Sequence[str], out_col:
138137
i = df.get_column_index(cols[0])
139138
dtype = df[cols[0]].dtype
140139

141-
arr = numpy.array(tuple(df[c].to_numpy() for c in cols)).T
140+
# https://github.com/pola-rs/polars/issues/18369
141+
arr = [] if len(df) == 0 else numpy.array(tuple(df[c].to_numpy() for c in cols)).T
142142

143-
return df.drop(cols).insert_column(i, polars.Series(out_col, arr, polars.Array(dtype, arr.shape[-1])))
143+
return df.drop(cols).insert_column(i, polars.Series(out_col, arr, polars.Array(dtype, len(cols))))
144144

145145

146146
HasAtomsT = t.TypeVar('HasAtomsT', bound='HasAtoms')
@@ -250,8 +250,8 @@ def dtypes(self) -> t.List[polars.DataType]:
250250
...
251251

252252
@property
253-
@_fwd_frame(lambda df: df.schema)
254-
def schema(self) -> SchemaDict:
253+
@_fwd_frame(lambda df: df.schema) # type: ignore
254+
def schema(self) -> Schema:
255255
"""
256256
Return the schema of `self`.
257257
@@ -288,7 +288,7 @@ def with_columns(self,
288288
def insert_column(self, index: int, column: polars.Series) -> polars.DataFrame:
289289
return self._get_frame().insert_column(index, column)
290290

291-
@_fwd_frame(polars.DataFrame.get_column)
291+
@_fwd_frame(lambda df, name: df.get_column(name))
292292
def get_column(self, name: str) -> polars.Series:
293293
"""
294294
Get the specified column from `self`, raising [`polars.ColumnNotFoundError`][polars.exceptions.ColumnNotFoundError] if it's not present.
@@ -328,9 +328,9 @@ def clone(self) -> polars.DataFrame:
328328
"""Return a copy of `self`."""
329329
return self._get_frame().clone()
330330

331-
def drop(self, *columns: t.Union[str, t.Iterable[str]]) -> polars.DataFrame:
331+
def drop(self, *columns: t.Union[str, t.Iterable[str]], strict: bool = True) -> polars.DataFrame:
332332
"""Return `self` with the specified columns removed."""
333-
return self._get_frame().drop(*columns)
333+
return self._get_frame().drop(*columns, strict=strict)
334334

335335
# row-wise operations
336336

@@ -341,11 +341,10 @@ def filter(
341341
) -> Self:
342342
"""Filter `self`, removing rows which evaluate to `False`."""
343343
# TODO clean up
344-
preds_not_none: t.Tuple[t.Union[IntoExprColumn, t.Iterable[IntoExprColumn], bool, t.List[bool], numpy.ndarray], ...]
345-
preds_not_none = tuple(filter(lambda p: p is not None, predicates)) # type: ignore
344+
preds_not_none = tuple(filter(lambda p: p is not None, predicates))
346345
if not len(preds_not_none) and not len(constraints):
347346
return self
348-
return self.with_atoms(Atoms(self._get_frame().filter(*preds_not_none, **constraints), _unchecked=True))
347+
return self.with_atoms(Atoms(self._get_frame().filter(*preds_not_none, **constraints), _unchecked=True)) # type: ignore
349348

350349
@_fwd_frame_map
351350
def sort(
@@ -498,7 +497,7 @@ def select_props(
498497
A [`HasAtoms`][atomlib.atoms.HasAtoms] filtered to contain the
499498
specified properties (as well as required columns).
500499
"""
501-
props = self._get_frame().lazy().select(*exprs, **named_exprs).drop(_REQUIRED_COLUMNS).collect(_eager=True)
500+
props = self._get_frame().lazy().select(*exprs, **named_exprs).drop(_REQUIRED_COLUMNS, strict=False).collect(_eager=True)
502501
return self.with_atoms(
503502
Atoms(self._get_frame().select(_REQUIRED_COLUMNS).hstack(props), _unchecked=False)
504503
)
@@ -525,7 +524,7 @@ def try_get_column(self, name: str) -> t.Optional[polars.Series]:
525524
"""Try to get a column from `self`, returning `None` if it doesn't exist."""
526525
try:
527526
return self.get_column(name)
528-
except polars.ColumnNotFoundError:
527+
except polars.exceptions.ColumnNotFoundError:
529528
return None
530529

531530
def assert_equal(self, other: t.Any):
@@ -555,7 +554,7 @@ def __radd__(self, other: IntoAtoms) -> HasAtoms:
555554
def __getitem__(self, column: str) -> polars.Series:
556555
try:
557556
return self.get_column(column)
558-
except polars.ColumnNotFoundError:
557+
except polars.exceptions.ColumnNotFoundError:
559558
if column in ('x', 'y', 'z'):
560559
return self.select(_coord_expr(column)).to_series()
561560
raise
@@ -938,7 +937,8 @@ def with_coords(self, pts: ArrayLike, selection: t.Optional[AtomSelection] = Non
938937
new_pts[selection] = pts
939938
pts = new_pts
940939

941-
pts = numpy.broadcast_to(pts, (len(self), 3))
940+
# https://github.com/pola-rs/polars/issues/18369
941+
pts = numpy.broadcast_to(pts, (len(self), 3)) if len(self) else []
942942
return self.with_columns(polars.Series('coords', pts, polars.Array(polars.Float64, 3)))
943943

944944
def with_velocity(self, pts: t.Optional[ArrayLike] = None,
@@ -963,7 +963,8 @@ def with_velocity(self, pts: t.Optional[ArrayLike] = None,
963963
assert pts.shape[-1] == 3
964964
all_pts[selection] = pts
965965

966-
all_pts = numpy.broadcast_to(all_pts, (len(self), 3))
966+
# https://github.com/pola-rs/polars/issues/18369
967+
all_pts = numpy.broadcast_to(all_pts, (len(self), 3)) if len(self) else []
967968
return self.with_columns(polars.Series('velocity', all_pts, polars.Array(polars.Float64, 3)))
968969

969970

@@ -1088,11 +1089,11 @@ def _repr_pretty_(self, p, cycle: bool) -> None:
10881089

10891090

10901091
SchemaDict: TypeAlias = OrderedDict[str, polars.DataType]
1091-
IntoExprColumn: TypeAlias = polars.type_aliases.IntoExprColumn
1092-
IntoExpr: TypeAlias = polars.type_aliases.IntoExpr
1093-
UniqueKeepStrategy: TypeAlias = polars.type_aliases.UniqueKeepStrategy
1094-
FillNullStrategy: TypeAlias = polars.type_aliases.FillNullStrategy
1095-
RollingInterpolationMethod: TypeAlias = polars.type_aliases.RollingInterpolationMethod
1092+
IntoExprColumn: TypeAlias = polars._typing.IntoExprColumn
1093+
IntoExpr: TypeAlias = polars._typing.IntoExpr
1094+
UniqueKeepStrategy: TypeAlias = polars._typing.UniqueKeepStrategy
1095+
FillNullStrategy: TypeAlias = polars._typing.FillNullStrategy
1096+
RollingInterpolationMethod: TypeAlias = polars._typing.RollingInterpolationMethod
10961097
ConcatMethod: TypeAlias = t.Literal['horizontal', 'vertical', 'diagonal', 'inner', 'align']
10971098

10981099
IntoAtoms = t.Union[t.Dict[str, t.Sequence[t.Any]], t.Sequence[t.Any], numpy.ndarray, polars.DataFrame, 'Atoms']
@@ -1115,4 +1116,4 @@ def _repr_pretty_(self, p, cycle: bool) -> None:
11151116

11161117
__all__ = [
11171118
'Atoms', 'HasAtoms', 'IntoAtoms', 'AtomSelection', 'AtomValues',
1118-
]
1119+
]

atomlib/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
import typing as t
77
import logging
88

9+
from typing_extensions import ParamSpec, Concatenate
910
import click
1011

1112
from . import CoordinateFrame, HasAtoms, Atoms, AtomCell, AtomSelection
1213
from . import io
13-
from .types import ParamSpec, Concatenate
1414
from .transform import LinearTransform3D, AffineTransform3D
1515

1616

atomlib/defect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def ellip_pi(n: NDArray[numpy.float64], m: NDArray[numpy.float64]) -> NDArray[nu
2626
2727
[wolfram_ellip_pi]: https://mathworld.wolfram.com/EllipticIntegraloftheThirdKind.html
2828
"""
29-
from scipy.special import elliprf, elliprj
29+
from scipy.special import elliprf, elliprj # type: ignore
3030

3131
y = 1 - m
3232
assert numpy.all(y > 0)

atomlib/elem.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44

55
from importlib_resources import files
66
import polars
7-
from polars.exceptions import PolarsPanicError
87
import numpy
98

9+
try:
10+
from polars.exceptions import PanicException
11+
except ImportError:
12+
from polars.exceptions import PolarsPanicError as PanicException # type: ignore
13+
1014
from .types import ElemLike
1115

1216
ELEMENTS = {
@@ -79,7 +83,7 @@ def get_elem(sym: t.Union[int, str, polars.Series]):
7983

8084
if isinstance(sym, polars.Series):
8185
elem = sym.str.extract(_SYM_RE, 0).str.to_lowercase() \
82-
.replace(ELEMENTS, return_dtype=polars.UInt8, default=255) \
86+
.replace_strict(ELEMENTS, default=255, return_dtype=polars.UInt8) \
8387
.alias('elem')
8488

8589
if (invalid := sym.filter(sym.is_not_null() & (elem > 118)).to_list()):
@@ -123,7 +127,7 @@ def get_sym(elem: t.Union[int, polars.Series]):
123127
try:
124128
return elem.map_elements(_get_sym, return_dtype=polars.Utf8, skip_nulls=True) \
125129
.alias('symbol')
126-
except PolarsPanicError:
130+
except PanicException:
127131
# attempt to recreate the error in Python
128132
_ = [_get_sym(t.cast(int, e)) for e in elem.to_list() if e is not None]
129133
raise

atomlib/io/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def read_cif(f: t.Union[FileOrPath, CIF, CIFDataBlock], block: t.Union[int, str,
4343
raise ValueError("No data present in CIF file.")
4444
if block is None:
4545
if len(cif) > 1:
46-
logging.warn("Multiple blocks present in CIF file. Defaulting to reading first block.")
46+
logging.warning("Multiple blocks present in CIF file. Defaulting to reading first block.")
4747
cif = cif.data_blocks[0]
4848
else:
4949
cif = cif.get_block(block)

atomlib/io/lmp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_atoms(self, type_map: t.Optional[t.Dict[int, t.Union[str, int]]] = None)
5252
def _apply_type_labels(df: polars.DataFrame, section_name: str, labels: t.Optional[polars.DataFrame] = None) -> polars.DataFrame:
5353
if labels is not None:
5454
#df = df.with_columns(polars.col('type').replace(d, default=polars.col('type').cast(polars.Int32, strict=False), return_dtype=polars.Int32))
55-
df = df.with_columns(polars.col('type').replace(labels['symbol'], labels['type'], default=polars.col('type').cast(polars.Int32, strict=False), return_dtype=polars.Int32))
55+
df = df.with_columns(polars.col('type').replace_strict(labels['symbol'], labels['type'], default=polars.col('type').cast(polars.Int32, strict=False), return_dtype=polars.Int32))
5656
if df['type'].is_null().any():
5757
raise ValueError(f"While parsing section {section_name}: Unknown atom label or invalid atom type")
5858
try:

atomlib/io/mslice.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from warnings import warn
1414
import typing as t
1515

16+
from typing_extensions import TypeAlias
1617
from importlib_resources import files
1718
import numpy
1819
from numpy.typing import ArrayLike
@@ -25,9 +26,9 @@
2526
from ..transform import AffineTransform3D, LinearTransform3D
2627

2728

28-
ElementTree = et._ElementTree
29-
Element = et._Element
30-
MSliceFile = t.Union[ElementTree, FileOrPath]
29+
ElementTree: TypeAlias = et._ElementTree
30+
Element: TypeAlias = et._Element
31+
MSliceFile: TypeAlias = t.Union[ElementTree, FileOrPath]
3132

3233

3334
DEFAULT_TEMPLATE_PATH = files('atomlib.data') / 'template.mslice'

atomlib/io/util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import typing as t
66

77
import polars
8-
from polars.type_aliases import SchemaDict, PolarsDataType
8+
from polars._typing import SchemaDict, PolarsDataType
99

1010

1111
class LineBuffer:
@@ -144,15 +144,15 @@ def _parse_rows_whitespace_separated(
144144
assert inner_ty is not None
145145

146146
elem_base_name = _ARRAY_ELEM_NAMES.get(col, col)
147-
suffixes = ('x', 'y', 'z') if ty.width == 3 else range(ty.width)
147+
suffixes = ('x', 'y', 'z') if ty.size == 3 else range(ty.size)
148148
elem_cols = [f"{elem_base_name}_{s}".lstrip('_') for s in suffixes]
149149

150150
expanded_schema.update({elem_col: inner_ty for elem_col in elem_cols})
151151

152152
exprs.append(polars.concat_list(
153153
polars.col('s').struct.field(elem_col)
154154
for elem_col in elem_cols
155-
).list.to_array(ty.width).alias(col))
155+
).list.to_array(ty.size).alias(col))
156156

157157
regex = "".join((
158158
"^",

atomlib/io/xsf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __post_init__(self):
103103
raise ValueError("Error: No coordinates are specified (atoms, primitive, or conventional).")
104104

105105
if self.prim_coords is not None and self.conv_coords is not None:
106-
logging.warn("Warning: Both 'primcoord' and 'convcoord' are specified. 'convcoord' will be ignored.")
106+
logging.warning("Warning: Both 'primcoord' and 'convcoord' are specified. 'convcoord' will be ignored.")
107107
elif self.conv_coords is not None and self.conventional_cell is None:
108108
raise ValueError("If 'convcoord' is specified, 'convvec' must be specified as well.")
109109

@@ -196,7 +196,7 @@ def parse_atoms(self, expected_length: t.Optional[int] = None) -> polars.DataFra
196196

197197
if expected_length is not None:
198198
if not expected_length == len(zs):
199-
logging.warn(f"Warning: List length {len(zs)} doesn't match declared length {expected_length}")
199+
logging.warning(f"Warning: List length {len(zs)} doesn't match declared length {expected_length}")
200200
elif len(zs) == 0:
201201
raise ValueError(f"Expected atom list after keyword 'ATOMS'. Got '{line or 'EOF'}' instead.")
202202

0 commit comments

Comments
 (0)