22 changes: 14 additions & 8 deletions tzrec/datasets/parquet_dataset.py
@@ -18,7 +18,6 @@
 from typing import Any, Dict, Iterator, List, Optional
 
 import pyarrow as pa
-import pyarrow.dataset as ds
 from pyarrow import parquet
 
 from tzrec.constant import Mode
@@ -27,6 +26,16 @@
 from tzrec.protos import data_pb2
 
 
+def _reader_iter(
+    input_files: List[str],
+    batch_size: int,
+) -> Iterator[pa.RecordBatch]:
+    for input_file in input_files:
+        parquet_file = parquet.ParquetFile(input_file)
+        for batch in parquet_file.iter_batches(batch_size):
+            yield batch
+
+
 class ParquetDataset(BaseDataset):
     """Dataset for reading data with parquet format.
 
@@ -93,16 +102,16 @@ def __init__(
         self._input_files.extend(glob.glob(input_path))
         if len(self._input_files) == 0:
             raise RuntimeError(f"No parquet files exist in {self._input_path}.")
-        dataset = ds.dataset(self._input_files[0], format="parquet")
+        parquet_file = parquet.ParquetFile(self._input_files[0])
         if self._selected_cols:
             self._ordered_cols = []
-            for field in dataset.schema:
+            for field in parquet_file.schema_arrow:
                 # pyre-ignore [58]
                 if field.name in selected_cols:
                     self.schema.append(field)
                     self._ordered_cols.append(field.name)
         else:
-            self.schema = dataset.schema
+            self.schema = parquet_file.schema_arrow
 
     def to_batches(
         self, worker_id: int = 0, num_workers: int = 1
@@ -112,10 +121,7 @@ def to_batches(
         if self._shuffle:
             random.shuffle(input_files)
         if len(input_files) > 0:
-            dataset = ds.dataset(input_files, format="parquet")
-            reader = dataset.to_batches(
-                batch_size=self._batch_size, columns=self._ordered_cols
-            )
+            reader = _reader_iter(input_files, self._batch_size)
             yield from self._arrow_reader_iter(reader)
 
     def num_files(self) -> int:
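For readers outside the repo, the following is a minimal standalone sketch of the per-file batch-iteration pattern the diff switches to. The helper name iter_parquet_batches, the optional columns projection, and the sample path are illustrative assumptions, not part of the PR; the PR's own _reader_iter does not take a columns argument.

# A minimal sketch, assuming only that pyarrow is installed; the helper
# name, the `columns` parameter, and the file path are hypothetical.
from typing import Iterator, List, Optional

import pyarrow as pa
from pyarrow import parquet


def iter_parquet_batches(
    input_files: List[str],
    batch_size: int,
    columns: Optional[List[str]] = None,
) -> Iterator[pa.RecordBatch]:
    """Yield record batches of at most batch_size rows, one file at a time."""
    for input_file in input_files:
        # Open each file independently and stream its row groups as batches.
        parquet_file = parquet.ParquetFile(input_file)
        for batch in parquet_file.iter_batches(
            batch_size=batch_size, columns=columns
        ):
            yield batch


if __name__ == "__main__":
    # Hypothetical usage over a local file.
    for batch in iter_parquet_batches(["part-0.parquet"], batch_size=4096):
        print(batch.num_rows, batch.schema.names)

Because ParquetFile.iter_batches also accepts a columns projection, a column selection like the one the removed dataset.to_batches(columns=...) call performed could be pushed down here as well, under the assumption that the downstream reader expects the same column order.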