Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1965,6 +1965,7 @@ def spark() -> SparkSession:
.config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
.config("spark.sql.catalog.integration", "org.apache.iceberg.spark.SparkCatalog")
.config("spark.sql.catalog.integration.catalog-impl", "org.apache.iceberg.rest.RESTCatalog")
.config("spark.sql.catalog.integration.cache-enabled", "false")
.config("spark.sql.catalog.integration.uri", "http://localhost:8181")
.config("spark.sql.catalog.integration.io-impl", "org.apache.iceberg.aws.s3.S3FileIO")
.config("spark.sql.catalog.integration.warehouse", "s3://warehouse/wh/")
Expand Down
20 changes: 20 additions & 0 deletions tests/integration/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,26 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
assert [row.deleted_data_files_count for row in rows] == [0, 0, 1, 0, 0]


@pytest.mark.integration
def test_multiple_spark_sql_queries(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think a better name would be test_python_writes_with_spark_snapshot_reads or something more specific than what it currently is . It's mre verbose but I think it captures the goal of the test better

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sense, added!

identifier = "default.multiple_spark_sql_queries"
tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])

def get_current_snapshot_id(identifier: str) -> int:
return (
spark.sql(f"SELECT snapshot_id FROM {identifier}.snapshots order by committed_at desc limit 1")
.collect()[0]
.snapshot_id
)

tbl.overwrite(arrow_table_with_null)
assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore
tbl.overwrite(arrow_table_with_null)
assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore
tbl.append(arrow_table_with_null)
assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
@pytest.mark.parametrize(
Expand Down