
Commit 076e138

fix(save): make isolated and atomic
1 parent 58735f9 commit 076e138

7 files changed

Lines changed: 319 additions & 174 deletions


src/datachain/catalog/catalog.py

Lines changed: 113 additions & 37 deletions
@@ -85,6 +85,7 @@
 PULL_DATASET_CHUNK_TIMEOUT = 3600
 PULL_DATASET_SLEEP_INTERVAL = 0.1  # sleep time while waiting for chunk to be available
 PULL_DATASET_CHECK_STATUS_INTERVAL = 20  # interval to check export status in Studio
+_MAX_VERSION_CLAIM_RETRIES = 5


 def _round_robin_batch(urls: list[str], num_workers: int) -> list[list[str]]:

@@ -800,7 +801,6 @@ def create_dataset(
         columns: Sequence[Column],
         feature_schema: dict | None = None,
         query_script: str = "",
-        create_rows: bool | None = True,
         validate_version: bool | None = True,
         listing: bool | None = False,
         uuid: str | None = None,

@@ -821,19 +821,12 @@ def create_dataset(
             raise RuntimeError(
                 "Cannot create dataset that starts with source prefix, e.g s3://"
             )
-        default_version = DEFAULT_DATASET_VERSION
         try:
             dataset = self.get_dataset(
                 name,
                 namespace_name=project.namespace.name if project else None,
                 project_name=project.name if project else None,
             )
-            if not version:
-                default_version = dataset.next_version_patch
-                if update_version == "major":
-                    default_version = dataset.next_version_major
-                if update_version == "minor":
-                    default_version = dataset.next_version_minor

             if (description or attrs) and (
                 dataset.description != description or dataset.attrs != attrs

@@ -862,27 +855,114 @@ def create_dataset(
                 attrs=attrs,
             )

-        version = version or default_version
-
-        if dataset.has_version(version):
-            raise DatasetInvalidVersionError(
-                f"Version {version} already exists in dataset {name}"
-            )
-
-        if validate_version and not dataset.is_valid_next_version(version):
-            raise DatasetInvalidVersionError(
-                f"Version {version} must be higher than the current latest one"
-            )
-
-        return self.create_dataset_version(
-            dataset,
-            version,
+        # Claim the version (with retry for auto-versioned saves).
+        return self._try_claim_version(
+            dataset=dataset,
+            name=name,
+            version=version,
+            project=project,
             feature_schema=feature_schema,
             query_script=query_script,
-            create_rows_table=create_rows,
             columns=columns,
             uuid=uuid,
             job_id=job_id,
+            validate_version=validate_version,
+            update_version=update_version,
+        )
+
+    @staticmethod
+    def _next_auto_version(dataset: "DatasetRecord", update_version: str | None) -> str:
+        """Compute the next version for a dataset based on the update strategy.
+
+        Handles brand-new datasets whose versions list may contain a single
+        phantom entry with ``version=None`` (artifact of the LEFT JOIN used
+        by ``get_dataset``).
+        """
+        if not any(v.version for v in dataset.versions):
+            return DEFAULT_DATASET_VERSION
+        if update_version == "major":
+            return dataset.next_version_major
+        if update_version == "minor":
+            return dataset.next_version_minor
+        return dataset.next_version_patch
+
+    def _try_claim_version(
+        self,
+        dataset: "DatasetRecord",
+        name: str,
+        version: str | None,
+        project: Project | None,
+        feature_schema: dict | None,
+        query_script: str,
+        columns: Sequence[Column],
+        uuid: str | None,
+        job_id: str | None,
+        validate_version: bool | None,
+        update_version: str | None,
+    ) -> "DatasetRecord":
+        """
+        Try to claim a dataset version, retrying on conflict.
+
+        When *version* is explicit (not None), a single attempt is made and
+        a conflict raises immediately. When *version* is None the target is
+        auto-computed from the dataset and retried up to
+        ``_MAX_VERSION_CLAIM_RETRIES`` times on conflict.
+        """
+        max_retries = 0 if version else _MAX_VERSION_CLAIM_RETRIES
+        target_version = version or self._next_auto_version(dataset, update_version)
+
+        for attempt in range(1 + max_retries):
+            if dataset.has_version(target_version):
+                raise DatasetInvalidVersionError(
+                    f"Version {target_version} already exists in dataset {name}"
+                )
+
+            if validate_version and not dataset.is_valid_next_version(target_version):
+                raise DatasetInvalidVersionError(
+                    f"Version {target_version} must be higher than"
+                    f" the current latest one"
+                )
+
+            dataset, version_created = self.create_dataset_version(
+                dataset,
+                target_version,
+                feature_schema=feature_schema,
+                query_script=query_script,
+                columns=columns,
+                uuid=uuid,
+                job_id=job_id,
+            )
+
+            if version_created:
+                return dataset
+
+            # Another writer claimed this version between our check and insert.
+            if attempt >= max_retries:
+                break
+
+            logger.debug(
+                "Version %s of dataset %s was claimed by another writer "
+                "(attempt %d/%d), retrying with next version",
+                target_version,
+                name,
+                attempt + 1,
+                1 + max_retries,
+            )
+            dataset = self.get_dataset(
+                name,
+                namespace_name=project.namespace.name if project else None,
+                project_name=project.name if project else None,
+            )
+            target_version = self._next_auto_version(dataset, update_version)
+
+        if version:
+            raise DatasetInvalidVersionError(
+                f"Version {target_version} of dataset {name} was claimed by"
+                " another writer"
+            )
+        raise DatasetInvalidVersionError(
+            f"Failed to claim a version for dataset {name} after"
+            f" {1 + max_retries} attempts due to concurrent writers"
         )

     def create_dataset_version(

@@ -897,20 +977,23 @@ def create_dataset_version(
         error_message="",
         error_stack="",
         script_output="",
-        create_rows_table=True,
         job_id: str | None = None,
         uuid: str | None = None,
-    ) -> DatasetRecord:
+    ) -> tuple[DatasetRecord, bool]:
         """
-        Creates dataset version if it doesn't exist.
-        If create_rows is False, dataset rows table will not be created
+        Creates dataset version metadata (no rows table).
+
+        Returns:
+            A tuple of (dataset_record, version_created) where version_created
+            is True if this call actually created the version, False if the
+            version already existed (when ignore_if_exists applies).
         """
         assert [c.name for c in columns if c.name != "sys__id"], f"got {columns=}"
         schema = {
            c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType)
         }

-        dataset = self.metastore.create_dataset_version(
+        dataset, version_created = self.metastore.create_dataset_version(
             dataset,
             version,
             status=DatasetStatus.CREATED,

@@ -926,12 +1009,7 @@ def create_dataset_version(
             uuid=uuid,
         )

-        if create_rows_table:
-            table_name = self.warehouse.dataset_table_name(dataset, version)
-            self.warehouse.create_dataset_rows_table(table_name, columns=columns)
-            self.update_dataset_version_with_warehouse_info(dataset, version)
-
-        return dataset
+        return dataset, version_created

     def update_dataset_version_with_warehouse_info(
         self, dataset: DatasetRecord, version: str, rows_dropped=False, **kwargs

@@ -1841,8 +1919,6 @@ def pull_dataset(  # noqa: C901, PLR0915, PLR0912
             project,
             local_ds_version,
             query_script=remote_ds_version.query_script,
-            # Don't create table, we'll rename the temp table.
-            create_rows=False,
             columns=columns,
             feature_schema=remote_ds_version.feature_schema,
             validate_version=False,

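The retry loop added here is optimistic version claiming: compute a candidate version, try to insert it, and if another writer got there first, recompute and try again a bounded number of times. Below is a minimal standalone sketch of that pattern, with a hypothetical thread-safe claim() standing in for the metastore insert; the store, error type, and helper names are illustrative and not part of this commit.

import threading

_MAX_VERSION_CLAIM_RETRIES = 5


class VersionClaimError(Exception):
    pass


class FakeVersionStore:
    """Stand-in for the metastore: records which versions already exist."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._versions: set[str] = set()

    def next_patch(self) -> str:
        # Assumes all stored versions look like "x.y.z".
        with self._lock:
            if not self._versions:
                return "1.0.0"
            major, minor, patch = max(
                tuple(map(int, v.split("."))) for v in self._versions
            )
            return f"{major}.{minor}.{patch + 1}"

    def claim(self, version: str) -> bool:
        """Insert-if-absent; returns True only for the winning writer."""
        with self._lock:
            if version in self._versions:
                return False
            self._versions.add(version)
            return True


def claim_version(store: FakeVersionStore, explicit: str | None = None) -> str:
    # Explicit versions get a single attempt; auto-versioned saves retry.
    max_retries = 0 if explicit else _MAX_VERSION_CLAIM_RETRIES
    target = explicit or store.next_patch()
    for attempt in range(1 + max_retries):
        if store.claim(target):
            return target
        if attempt >= max_retries:
            break
        target = store.next_patch()  # another writer won; recompute and retry
    raise VersionClaimError(f"could not claim a version (last tried {target})")


if __name__ == "__main__":
    store = FakeVersionStore()
    print(claim_version(store))  # 1.0.0
    print(claim_version(store))  # 1.0.1
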
src/datachain/data_storage/metastore.py

Lines changed: 28 additions & 6 deletions
@@ -310,8 +310,14 @@ def create_dataset_version(  # noqa: PLR0913
         preview: list[dict] | None = None,
         job_id: str | None = None,
         uuid: str | None = None,
-    ) -> DatasetRecord:
-        """Creates new dataset version."""
+    ) -> tuple[DatasetRecord, bool]:
+        """Creates new dataset version.
+
+        Returns:
+            A tuple of (dataset_record, version_created) where version_created
+            is True if this call actually created the version, False if the
+            version already existed (only possible when ignore_if_exists=True).
+        """

     @abstractmethod
     def remove_dataset(self, dataset: DatasetRecord) -> None:

@@ -1234,16 +1240,24 @@ def create_dataset_version(  # noqa: PLR0913
         job_id: str | None = None,
         uuid: str | None = None,
         conn=None,
-    ) -> DatasetRecord:
-        """Creates new dataset version."""
+    ) -> tuple[DatasetRecord, bool]:
+        """Creates new dataset version.
+
+        Returns:
+            A tuple of (dataset_record, version_created) where version_created
+            is True if this call actually created the version, False if the
+            version already existed (only possible when ignore_if_exists=True).
+        """
         if status in [DatasetStatus.COMPLETE, DatasetStatus.FAILED]:
             finished_at = finished_at or datetime.now(timezone.utc)
         else:
             finished_at = None

+        my_uuid = uuid or str(uuid4())
+
         query = self._datasets_versions_insert().values(
             dataset_id=dataset.id,
-            uuid=uuid or str(uuid4()),
+            uuid=my_uuid,
             version=version,
             status=status,
             feature_schema=json.dumps(feature_schema or {}),

@@ -1268,14 +1282,22 @@ def create_dataset_version(  # noqa: PLR0913
         )
         self.db.execute(query, conn=conn)

-        return self.get_dataset(
+        dataset = self.get_dataset(
             dataset.name,
             namespace_name=dataset.project.namespace.name,
             project_name=dataset.project.name,
             include_incomplete=True,
             conn=conn,
         )

+        # Detect whether this call actually created the version by comparing
+        # the UUID we attempted to insert with the one stored in the DB.
+        # If another writer won the ON CONFLICT race, the stored UUID will
+        # differ from ours.
+        version_created = dataset.get_version(version).uuid == my_uuid
+
+        return dataset, version_created
+
     def remove_dataset(self, dataset: DatasetRecord) -> None:
         """Removes dataset."""
         d = self._datasets

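The metastore change detects the winner of an insert race without explicit locking: each writer stamps its insert with a fresh UUID, the insert ignores conflicts on the version key, and the writer then reads the row back and checks whether the stored UUID is its own. A rough standalone sketch of that pattern against SQLite; the table and column names are illustrative, not DataChain's actual schema.

import sqlite3
from uuid import uuid4

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE dataset_versions (
        dataset_id INTEGER NOT NULL,
        version TEXT NOT NULL,
        uuid TEXT NOT NULL,
        UNIQUE (dataset_id, version)
    )
    """
)


def claim(dataset_id: int, version: str) -> tuple[str, bool]:
    """Try to create (dataset_id, version); report whether this call created it."""
    my_uuid = str(uuid4())
    # INSERT OR IGNORE is SQLite's counterpart of ON CONFLICT DO NOTHING:
    # if another writer already inserted this (dataset_id, version), the
    # statement is silently a no-op.
    conn.execute(
        "INSERT OR IGNORE INTO dataset_versions (dataset_id, version, uuid)"
        " VALUES (?, ?, ?)",
        (dataset_id, version, my_uuid),
    )
    conn.commit()
    stored_uuid = conn.execute(
        "SELECT uuid FROM dataset_versions WHERE dataset_id = ? AND version = ?",
        (dataset_id, version),
    ).fetchone()[0]
    # If the stored UUID is ours, this call won the race and created the row.
    return stored_uuid, stored_uuid == my_uuid


if __name__ == "__main__":
    _, created_first = claim(1, "1.0.0")
    _, created_second = claim(1, "1.0.0")
    print(created_first, created_second)  # True False
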
src/datachain/lib/dc/records.py

Lines changed: 5 additions & 0 deletions
@@ -136,6 +136,11 @@ def generate_records():
     to_insert = []

     warehouse = catalog.warehouse
+
+    # Create the rows table (create_dataset only creates metadata).
+    table_name = warehouse.dataset_table_name(dsr, dsr.latest_version)
+    warehouse.create_dataset_rows_table(table_name, columns=columns)
+
     dr = warehouse.dataset_rows(dsr)
     table = dr.get_table()

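For context, this file backs the public read_records entry point, which materializes a dataset straight from an iterable of dicts; with create_dataset now recording metadata only, read_records creates its own rows table before inserting. A hedged usage sketch, assuming read_records accepts a schema mapping as in the library docs; the dataset name and fields are arbitrary.

import datachain as dc

# read_records builds its own rows table internally now that create_dataset
# only records metadata; from the caller's side nothing changes.
records = [{"id": i, "value": i * i} for i in range(3)]
dc.read_records(records, schema={"id": int, "value": int}).save("squares")
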
src/datachain/query/dataset.py

Lines changed: 44 additions & 14 deletions
@@ -2820,24 +2820,54 @@ def save(
                 "Ensure at least one column (other than 'id') is selected."
             )

-        dataset = self.catalog.create_dataset(
-            name,
-            project,
-            version=version,
-            feature_schema=feature_schema,
-            columns=columns,
-            description=description,
-            attrs=attrs,
-            update_version=update_version,
-            job_id=job_id,
-            **kwargs,
+        # Phase 1: Create a temp staging table and populate it.
+        # If the process dies here, only an orphaned tmp_ table remains,
+        # cleaned up by 'datachain gc'.
+        temp_table_name = self.catalog.warehouse.temp_table_name()
+        self.catalog.warehouse.create_dataset_rows_table(
+            temp_table_name, columns=columns
         )
-        version = version or dataset.latest_version
+        temp_table = self.catalog.warehouse.get_table(temp_table_name)
+        try:
+            self.catalog.warehouse.insert_into(temp_table, query.select())
+        except Exception:
+            with contextlib.suppress(Exception):
+                self.catalog.warehouse.cleanup_tables([temp_table_name])
+            raise
+
+        # Phase 2: Claim the version (metadata only).
+        try:
+            dataset = self.catalog.create_dataset(
+                name,
+                project,
+                version=version,
+                feature_schema=feature_schema,
+                columns=columns,
+                description=description,
+                attrs=attrs,
+                update_version=update_version,
+                job_id=job_id,
+                **kwargs,
+            )
+        except Exception:
+            with contextlib.suppress(Exception):
+                self.catalog.warehouse.cleanup_tables([temp_table_name])
+            raise

-        dr = self.catalog.warehouse.dataset_rows(dataset)
+        version = version or dataset.latest_version

-        self.catalog.warehouse.insert_into(dr.get_table(), query.select())
+        # Phase 3: Rename staging table to the final dataset table name.
+        final_table_name = self.catalog.warehouse.dataset_table_name(
+            dataset, version
+        )
+        try:
+            self.catalog.warehouse.rename_table(temp_table, final_table_name)
+        except Exception:
+            with contextlib.suppress(Exception):
+                self.catalog.warehouse.cleanup_tables([temp_table_name])
+            raise

+        # Phase 4: Finalize metadata and mark COMPLETE.
         self.catalog.update_dataset_version_with_warehouse_info(dataset, version)

         # Link this dataset version to the job that created it

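The reworked save() is a stage-then-publish sequence: rows are written to a temp table nothing references, the version is claimed only once the data exists, and the final table name appears in a single rename, so readers never observe a half-written dataset and a crash leaves only garbage a garbage-collection pass can sweep up. A self-contained sketch of the same sequence against SQLite; the table naming, schema, and function are illustrative, not DataChain's code.

import contextlib
import sqlite3
from uuid import uuid4

conn = sqlite3.connect(":memory:")


def save_rows_atomically(dataset: str, version: str, rows: list[tuple[int, str]]) -> str:
    # Phase 1: stage rows in a temp table; a crash here leaves only tmp_* garbage.
    tmp = f"tmp_{uuid4().hex[:8]}"
    conn.execute(f"CREATE TABLE {tmp} (id INTEGER, value TEXT)")
    try:
        conn.executemany(f"INSERT INTO {tmp} VALUES (?, ?)", rows)
        # Phase 2: claiming the version (metadata only) would happen here.
        # Phase 3: publish by renaming; one DDL statement flips the table in.
        final = f"ds_{dataset}_{version.replace('.', '_')}"
        conn.execute(f"ALTER TABLE {tmp} RENAME TO {final}")
    except Exception:
        # Drop the staging table so failures leave nothing half-published.
        with contextlib.suppress(Exception):
            conn.execute(f"DROP TABLE IF EXISTS {tmp}")
        raise
    # Phase 4: finalizing metadata (row counts, COMPLETE status) would happen here.
    conn.commit()
    return final


if __name__ == "__main__":
    print(save_rows_atomically("squares", "1.0.0", [(1, "one"), (2, "four")]))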