Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions invokeai/app/api/routers/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ async def list_model_installs() -> List[ModelInstallJob]:
* "waiting" -- Job is waiting in the queue to run
* "downloading" -- Model file(s) are downloading
* "running" -- Model has downloaded and the model probing and registration process is running
* "paused" -- Job is paused and can be resumed
* "completed" -- Installation completed successfully
* "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
* "cancelled" -- Job was cancelled before completion.
Expand Down Expand Up @@ -818,6 +819,89 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job)


@model_manager_router.post(
"/install/{id}/pause",
operation_id="pause_model_install_job",
responses={
201: {"description": "The job was paused successfully"},
415: {"description": "No such job"},
},
status_code=201,
)
async def pause_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
"""Pause the model install job corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
raise HTTPException(status_code=415, detail=str(e))
installer.pause_job(job)
return job


@model_manager_router.post(
"/install/{id}/resume",
operation_id="resume_model_install_job",
responses={
201: {"description": "The job was resumed successfully"},
415: {"description": "No such job"},
},
status_code=201,
)
async def resume_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
"""Resume a paused model install job corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
raise HTTPException(status_code=415, detail=str(e))
installer.resume_job(job)
return job


@model_manager_router.post(
"/install/{id}/restart_failed",
operation_id="restart_failed_model_install_job",
responses={
201: {"description": "Failed files restarted successfully"},
415: {"description": "No such job"},
},
status_code=201,
)
async def restart_failed_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
"""Restart failed or non-resumable file downloads for the given job."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
raise HTTPException(status_code=415, detail=str(e))
installer.restart_failed(job)
return job


@model_manager_router.post(
"/install/{id}/restart_file",
operation_id="restart_model_install_file",
responses={
201: {"description": "File restarted successfully"},
415: {"description": "No such job"},
},
status_code=201,
)
async def restart_model_install_file(
id: int = Path(description="Model install job ID"),
file_source: AnyHttpUrl = Body(description="File download URL to restart"),
) -> ModelInstallJob:
"""Restart a specific file download for the given job."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
raise HTTPException(status_code=415, detail=str(e))
installer.restart_file(job, str(file_source))
return job


@model_manager_router.delete(
"/install",
operation_id="prune_model_install_jobs",
Expand Down
28 changes: 28 additions & 0 deletions invokeai/app/services/download/download_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class DownloadJobStatus(str, Enum):

WAITING = "waiting" # not enqueued, will not run
RUNNING = "running" # actively downloading
PAUSED = "paused" # paused, can be resumed
COMPLETED = "completed" # finished running
CANCELLED = "cancelled" # user cancelled
ERROR = "error" # terminated with an error message
Expand Down Expand Up @@ -61,6 +62,7 @@ class DownloadJobBase(BaseModel):

# internal flag
_cancelled: bool = PrivateAttr(default=False)
_paused: bool = PrivateAttr(default=False)

# optional event handlers passed in on creation
_on_start: Optional[DownloadEventHandler] = PrivateAttr(default=None)
Expand All @@ -72,6 +74,12 @@ class DownloadJobBase(BaseModel):
def cancel(self) -> None:
"""Call to cancel the job."""
self._cancelled = True
self._paused = False

def pause(self) -> None:
"""Pause the job, preserving partial downloads."""
self._paused = True
self._cancelled = True

# cancelled and the callbacks are private attributes in order to prevent
# them from being serialized and/or used in the Json Schema
Expand All @@ -80,6 +88,11 @@ def cancelled(self) -> bool:
"""Call to cancel the job."""
return self._cancelled

@property
def paused(self) -> bool:
"""Return true if job is paused."""
return self._paused

@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
Expand Down Expand Up @@ -161,6 +174,17 @@ class DownloadJob(DownloadJobBase):
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
canonical_url: Optional[str] = Field(default=None, description="Canonical URL to request on resume")
etag: Optional[str] = Field(default=None, description="ETag from the remote server, if available")
last_modified: Optional[str] = Field(default=None, description="Last-Modified from the remote server, if available")
final_url: Optional[str] = Field(default=None, description="Final resolved URL after redirects, if available")
expected_total_bytes: Optional[int] = Field(default=None, description="Expected total size of the download")
resume_required: bool = Field(default=False, description="True if server refused resume; restart required")
resume_message: Optional[str] = Field(default=None, description="Message explaining why resume is required")
resume_from_scratch: bool = Field(
default=False,
description="True if resume metadata existed but the partial file was missing and the download restarted from the beginning",
)

def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
Expand Down Expand Up @@ -321,6 +345,10 @@ def cancel_job(self, job: DownloadJobBase) -> None:
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass

def pause_job(self, job: DownloadJobBase) -> None: # noqa D401
"""Pause the job, preserving partial downloads."""
raise NotImplementedError

@abstractmethod
def join(self) -> None:
"""Wait until all jobs are off the queue."""
Expand Down
Loading