# How long (seconds) to wait for a dataset chunk before giving up.
PULL_DATASET_CHUNK_TIMEOUT = 3600
# Sleep time while waiting for a chunk to become available.
PULL_DATASET_SLEEP_INTERVAL = 0.1
# Interval to check export status in Studio.
PULL_DATASET_CHECK_STATUS_INTERVAL = 20
# Extra attempts allowed when a concurrent writer claims a dataset version first.
_MAX_VERSION_CLAIM_RETRIES = 5
9091def _round_robin_batch (urls : list [str ], num_workers : int ) -> list [list [str ]]:
@@ -800,7 +801,6 @@ def create_dataset(
800801 columns : Sequence [Column ],
801802 feature_schema : dict | None = None ,
802803 query_script : str = "" ,
803- create_rows : bool | None = True ,
804804 validate_version : bool | None = True ,
805805 listing : bool | None = False ,
806806 uuid : str | None = None ,
@@ -821,19 +821,12 @@ def create_dataset(
821821 raise RuntimeError (
822822 "Cannot create dataset that starts with source prefix, e.g s3://"
823823 )
824- default_version = DEFAULT_DATASET_VERSION
825824 try :
826825 dataset = self .get_dataset (
827826 name ,
828827 namespace_name = project .namespace .name if project else None ,
829828 project_name = project .name if project else None ,
830829 )
831- if not version :
832- default_version = dataset .next_version_patch
833- if update_version == "major" :
834- default_version = dataset .next_version_major
835- if update_version == "minor" :
836- default_version = dataset .next_version_minor
837830
838831 if (description or attrs ) and (
839832 dataset .description != description or dataset .attrs != attrs
@@ -862,27 +855,114 @@ def create_dataset(
862855 attrs = attrs ,
863856 )
864857
865- version = version or default_version
866-
867- if dataset .has_version (version ):
868- raise DatasetInvalidVersionError (
869- f"Version { version } already exists in dataset { name } "
870- )
871-
872- if validate_version and not dataset .is_valid_next_version (version ):
873- raise DatasetInvalidVersionError (
874- f"Version { version } must be higher than the current latest one"
875- )
876-
877- return self .create_dataset_version (
878- dataset ,
879- version ,
858+ # Claim the version (with retry for auto-versioned saves).
859+ return self ._try_claim_version (
860+ dataset = dataset ,
861+ name = name ,
862+ version = version ,
863+ project = project ,
880864 feature_schema = feature_schema ,
881865 query_script = query_script ,
882- create_rows_table = create_rows ,
883866 columns = columns ,
884867 uuid = uuid ,
885868 job_id = job_id ,
869+ validate_version = validate_version ,
870+ update_version = update_version ,
871+ )
872+
@staticmethod
def _next_auto_version(dataset: "DatasetRecord", update_version: str | None) -> str:
    """Pick the version string to auto-assign next for *dataset*.

    A brand-new dataset may carry a single phantom entry whose
    ``version`` is ``None`` (an artifact of the LEFT JOIN used by
    ``get_dataset``); such datasets start at ``DEFAULT_DATASET_VERSION``.
    Otherwise *update_version* selects which component is bumped:
    ``"major"``, ``"minor"``, or — for any other value — patch.
    """
    if not any(entry.version for entry in dataset.versions):
        # No real versions recorded yet: this is the dataset's first version.
        return DEFAULT_DATASET_VERSION
    if update_version == "major":
        return dataset.next_version_major
    if update_version == "minor":
        return dataset.next_version_minor
    return dataset.next_version_patch
888+
def _try_claim_version(
    self,
    dataset: "DatasetRecord",
    name: str,
    version: str | None,
    project: Project | None,
    feature_schema: dict | None,
    query_script: str,
    columns: Sequence[Column],
    uuid: str | None,
    job_id: str | None,
    validate_version: bool | None,
    update_version: str | None,
) -> "DatasetRecord":
    """Claim a version of *dataset*, retrying on concurrent conflicts.

    An explicit *version* is attempted exactly once and a conflict raises
    immediately. When *version* is ``None`` the target is computed via
    ``_next_auto_version`` and, if another writer wins the race between
    our existence check and the insert, it is recomputed and retried up
    to ``_MAX_VERSION_CLAIM_RETRIES`` additional times.

    Returns the updated dataset record on success; raises
    ``DatasetInvalidVersionError`` otherwise.
    """
    max_retries = 0 if version else _MAX_VERSION_CLAIM_RETRIES
    total_attempts = 1 + max_retries
    target_version = version or self._next_auto_version(dataset, update_version)

    attempt_no = 0
    while attempt_no < total_attempts:
        attempt_no += 1

        if dataset.has_version(target_version):
            raise DatasetInvalidVersionError(
                f"Version {target_version} already exists in dataset {name}"
            )

        if validate_version and not dataset.is_valid_next_version(target_version):
            raise DatasetInvalidVersionError(
                f"Version {target_version} must be higher than"
                f" the current latest one"
            )

        dataset, version_created = self.create_dataset_version(
            dataset,
            target_version,
            feature_schema=feature_schema,
            query_script=query_script,
            columns=columns,
            uuid=uuid,
            job_id=job_id,
        )
        if version_created:
            return dataset

        # Lost the race: another writer inserted target_version after our
        # check. Stop if this was the last allowed attempt.
        if attempt_no >= total_attempts:
            break

        logger.debug(
            "Version %s of dataset %s was claimed by another writer "
            "(attempt %d/%d), retrying with next version",
            target_version,
            name,
            attempt_no,
            total_attempts,
        )
        # Refresh our view of the dataset and aim at the next free version.
        dataset = self.get_dataset(
            name,
            namespace_name=project.namespace.name if project else None,
            project_name=project.name if project else None,
        )
        target_version = self._next_auto_version(dataset, update_version)

    if version:
        raise DatasetInvalidVersionError(
            f"Version {target_version} of dataset {name} was claimed by"
            " another writer"
        )
    raise DatasetInvalidVersionError(
        f"Failed to claim a version for dataset {name} after"
        f" {total_attempts} attempts due to concurrent writers"
    )
887967
888968 def create_dataset_version (
@@ -897,20 +977,23 @@ def create_dataset_version(
897977 error_message = "" ,
898978 error_stack = "" ,
899979 script_output = "" ,
900- create_rows_table = True ,
901980 job_id : str | None = None ,
902981 uuid : str | None = None ,
903- ) -> DatasetRecord :
982+ ) -> tuple [ DatasetRecord , bool ] :
904983 """
905- Creates dataset version if it doesn't exist.
906- If create_rows is False, dataset rows table will not be created
984+ Creates dataset version metadata (no rows table).
985+
986+ Returns:
987+ A tuple of (dataset_record, version_created) where version_created
988+ is True if this call actually created the version, False if the
989+ version already existed (when ignore_if_exists applies).
907990 """
908991 assert [c .name for c in columns if c .name != "sys__id" ], f"got { columns = } "
909992 schema = {
910993 c .name : c .type .to_dict () for c in columns if isinstance (c .type , SQLType )
911994 }
912995
913- dataset = self .metastore .create_dataset_version (
996+ dataset , version_created = self .metastore .create_dataset_version (
914997 dataset ,
915998 version ,
916999 status = DatasetStatus .CREATED ,
@@ -926,12 +1009,7 @@ def create_dataset_version(
9261009 uuid = uuid ,
9271010 )
9281011
929- if create_rows_table :
930- table_name = self .warehouse .dataset_table_name (dataset , version )
931- self .warehouse .create_dataset_rows_table (table_name , columns = columns )
932- self .update_dataset_version_with_warehouse_info (dataset , version )
933-
934- return dataset
1012+ return dataset , version_created
9351013
9361014 def update_dataset_version_with_warehouse_info (
9371015 self , dataset : DatasetRecord , version : str , rows_dropped = False , ** kwargs
@@ -1841,8 +1919,6 @@ def pull_dataset( # noqa: C901, PLR0915, PLR0912
18411919 project ,
18421920 local_ds_version ,
18431921 query_script = remote_ds_version .query_script ,
1844- # Don't create table, we'll rename the temp table.
1845- create_rows = False ,
18461922 columns = columns ,
18471923 feature_schema = remote_ds_version .feature_schema ,
18481924 validate_version = False ,
0 commit comments