
Add inpainting training and sampling support for SD1.5 and SDXL #2309

Open
allanoepping wants to merge 6 commits into kohya-ss:dev from allanoepping:inpainting

Conversation

@allanoepping

Summary

This PR implements inpainting model training support for both SD1.5 and SDXL, based on the approach originally proposed in #173 by @Fannovel16. That PR was never merged; this is a ground-up reimplementation that brings it up to date with the current codebase, extends it to SDXL, and adds inpainting-aware sampling during training.

The core technique: during training the UNet receives a 9-channel input made up of 4 channels of noisy latents, 1 channel of a downsampled binary mask, and 4 channels of the VAE-encoded masked image. At inference, the same concatenation is applied inside the denoising loop so that sample images generated at training checkpoints reflect the inpainting task.

Changes

UNet — 9-channel input support

  • library/original_unet.py, library/sdxl_original_unet.py: added explicit in_channels parameter so an inpainting checkpoint's 9-channel conv_in weight is accepted without shape mismatch errors.

Model loading — auto-detect in_channels from checkpoint

  • library/model_util.py: reads conv_in.weight.shape[1] after conversion and overrides unet_config["in_channels"] when it differs from 4, enabling transparent loading of existing inpainting checkpoints (e.g. sd-v1-5-inpainting.ckpt).
  • library/sdxl_model_util.py, library/sdxl_train_util.py: same for SDXL (input_blocks.0.0.weight).
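
The detection described above can be sketched as follows. This is a minimal illustration, not the PR's code: `detect_unet_in_channels` is a hypothetical helper name, and numpy arrays stand in for torch tensors (only `.shape` matters here).

```python
import numpy as np

def detect_unet_in_channels(state_dict, key="conv_in.weight", default=4):
    """Infer the UNet input channel count from the checkpoint itself.

    An inpainting checkpoint stores a conv_in weight of shape
    [out_ch, in_ch, 3, 3] with in_ch == 9 (4 noisy-latent channels +
    1 mask channel + 4 masked-image-latent channels); a standard
    checkpoint has in_ch == 4.
    """
    w = state_dict.get(key)
    return int(w.shape[1]) if w is not None else default

# Example: a fake inpainting checkpoint with a 9-input-channel conv_in
sd = {"conv_in.weight": np.zeros((320, 9, 3, 3), np.float32)}
in_channels = detect_unet_in_channels(sd)  # -> 9
```

For SDXL the same idea applies, just with the `input_blocks.0.0.weight` key instead of `conv_in.weight`.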

Training loop

  • fine_tune.py, train_db.py, train_network.py, train_textual_inversion.py, sdxl_train.py: when batch["masks"] is present, encodes the masked image via VAE and concatenates [noisy_latents, mask, masked_image_latents] as the UNet input. Uses latents.shape[2:] for mask interpolation (fixes tuple // int error when args.resolution is a tuple) and casts mask to weight_dtype before concatenation.
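
The concatenation and the `latents.shape[2:]` fix can be illustrated with a numpy sketch (the real code uses torch and `F.interpolate`; nearest-neighbour indexing stands in for interpolation here, and the helper name is hypothetical):

```python
import numpy as np

def build_inpainting_unet_input(noisy_latents, mask, masked_image_latents):
    """Assemble the 9-channel UNet input:
    [4 noisy-latent ch | 1 mask ch | 4 masked-image-latent ch].

    The mask's target size is taken from latents.shape[2:] rather than
    computed as args.resolution // 8, which also works when the
    resolution is a (width, height) tuple. The mask is cast to the
    latents' dtype (weight_dtype in the trainers) before concatenation.
    """
    n, _, h, w = noisy_latents.shape
    # nearest-neighbour downsample of the (N, 1, H*8, W*8) pixel-space mask
    ys = np.arange(h) * mask.shape[2] // h
    xs = np.arange(w) * mask.shape[3] // w
    mask_lat = mask[:, :, ys][:, :, :, xs].astype(noisy_latents.dtype)
    return np.concatenate([noisy_latents, mask_lat, masked_image_latents], axis=1)

batch, h, w = 2, 80, 56  # non-square latents, e.g. from a 640x448 image
x = build_inpainting_unet_input(
    np.zeros((batch, 4, h, w), np.float16),
    np.ones((batch, 1, h * 8, w * 8), np.float32),
    np.zeros((batch, 4, h, w), np.float16),
)
# x has shape (2, 9, 80, 56) and the latents' float16 dtype
```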

Dataset / config pipeline

  • library/train_util.py:
    - train_inpainting: bool propagated through BaseDataset, DreamBoothDataset, FineTuningDataset, ControlNetDataset
    - --train_inpainting CLI flag
    - assertion that cache_latents and train_inpainting cannot be used together (masks are generated randomly per step from the source image)
    - --img /path directive in prompt files for sampling; graceful skip-on-missing-image
    - sampling resolution rounded to multiples of 64
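
The resolution snapping mentioned above might look like the following. The PR does not state the rounding mode, so this hypothetical sketch rounds to the nearest multiple of 64 with a floor of 64:

```python
def round_resolution(width, height, base=64):
    """Snap a sampling resolution to multiples of `base`.

    Assumption: nearest-multiple rounding with a minimum of one `base`
    unit; the actual implementation may round up or down instead.
    """
    snap = lambda v: max(base, int(round(v / base)) * base)
    return snap(width), snap(height)

round_resolution(513, 767)  # -> (512, 768)
```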

Sampling pipelines

  • library/lpw_stable_diffusion.py, library/sdxl_lpw_stable_diffusion.py: added inpaint_image/inpaint_mask parameters; encodes masked image before the denoising loop; prepends 9-channel input to latent_model_input each step with proper dtype casting; prepare_latents uses vae.config.latent_channels (4) instead of unet.in_channels (9) so the initial noise tensor is always 4-channel.
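
The `latent_channels` vs. `in_channels` distinction matters because only the 4 noise channels are denoised; the mask and masked-image latents are concatenated onto the model input each step but are never part of the latent state. A minimal sketch (hypothetical helper name, numpy in place of torch):

```python
import numpy as np

def prepare_initial_latents(batch, height, width,
                            latent_channels=4, vae_scale_factor=8, seed=0):
    """Initial noise for the denoising loop.

    The channel count must come from vae.config.latent_channels (4),
    not unet.in_channels (9): the extra 5 channels (mask +
    masked-image latents) are prepended to the model input inside the
    loop, not carried in the noise tensor being denoised.
    """
    rng = np.random.default_rng(seed)
    shape = (batch, latent_channels,
             height // vae_scale_factor, width // vae_scale_factor)
    return rng.standard_normal(shape, dtype=np.float32)

lat = prepare_initial_latents(1, 512, 512)  # shape (1, 4, 64, 64)
```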

Procedural mask generation

  • library/mask_generator.py (new): generates inpainting masks procedurally — fractional Brownian motion cloud masks (layered via cv2), convex polygon masks, and basic shape masks (rect/ellipse); combined and fully random modes. Used both during training and for sampling previews.
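
To give a flavour of the fBm-style cloud masks: the sketch below is a numpy-only approximation of the idea (the PR's mask_generator.py layers noise via cv2). Random grids at doubling frequencies and halving amplitudes are upsampled with nearest-neighbour indexing, summed, and thresholded into a binary mask. Function name and parameters are illustrative.

```python
import numpy as np

def fbm_cloud_mask(h, w, octaves=4, threshold=0.5, seed=0):
    """Layered-noise ("fractional Brownian motion"-style) binary mask.

    Each octave doubles the noise grid resolution and halves its
    amplitude; thresholding the normalised sum yields blobby,
    cloud-like inpainting regions.
    """
    rng = np.random.default_rng(seed)
    acc = np.zeros((h, w), np.float32)
    amp, total = 1.0, 0.0
    for o in range(octaves):
        gh = gw = 2 ** (o + 2)                       # grid doubles each octave
        grid = rng.random((gh, gw), dtype=np.float32)
        ys = np.arange(h) * gh // h                  # nearest-neighbour upsample
        xs = np.arange(w) * gw // w
        acc += amp * grid[np.ix_(ys, xs)]
        total += amp
        amp *= 0.5
    return (acc / total > threshold).astype(np.uint8)

mask = fbm_cloud_mask(256, 256)  # 0/1 mask with cloud-shaped regions
```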

Test utilities

  • tests/generate_inpainting_test_data.py: generates synthetic training images for smoke tests.
  • tests/download_training_data.py: streams real images from common-canvas/commoncatalog-cc-by via HuggingFace datasets with metadata and post-download size filtering.
  • tests/visualize_masks.py: renders a gallery PNG for each mask type for visual inspection of mask quality.
  • tests/run_inpainting_test.sh, tests/run_sdxl_inpainting_test.sh: SD1.5 and SDXL smoke test scripts.
  • tests/sdxl_inpainting_test.toml: memory-efficient SDXL training config (Adafactor, bf16, gradient checkpointing, no cache_latents).

Usage

Train an inpainting model from an existing inpainting checkpoint

accelerate launch train_network.py \
  --pretrained_model_name_or_path sd-v1-5-inpainting.ckpt \
  --train_inpainting \
  ...

Add to a prompt file to get inpainting sample images during training:

--img /path/to/reference.jpg

Notes
--train_inpainting is incompatible with --cache_latents / --cache_latents_to_disk because masks are generated randomly per step.
Existing non-inpainting training is unaffected; the inpainting path is only active when batch["masks"] is not None.

- 9-channel UNet input (noisy_latents + mask + masked_image_latents)
  wired through all training scripts (train_db, train_network,
  fine_tune, train_textual_inversion, sdxl_train)
- Auto-detect in_channels from checkpoint conv_in weight shape in
  model_util.py and sdxl_model_util.py; UNet constructors accept
  explicit in_channels parameter
- Inpainting inference added to lpw_stable_diffusion.py and
  sdxl_lpw_stable_diffusion.py: encodes masked image before denoising
  loop, prepends 9-ch input each step; latent init uses
  vae.config.latent_channels (4) not unet.in_channels (9)
- --train_inpainting CLI flag; cache_latents incompatibility assertion;
  --img prompt directive for sampling source image; missing image
  gracefully skips sample; resolution rounded to multiples of 64
- library/mask_generator.py: procedural cloud (fBm), polygon, shape,
  and combined random mask generation using numpy/cv2/PIL
- tests/: synthetic data generator, mask visualizer, HuggingFace
  training data downloader, SD1.5 and SDXL smoke test scripts and TOML
Fix conflicts from dev branch merge
@kohya-ss
Owner

kohya-ss commented Apr 5, 2026

Thank you for this PR. I think it's very well implemented.

However, with several lightweight image editing models available now, I'm skeptical about how much demand there will be for SDXL inpainting models.

Furthermore, if other training tools such as Diffusers already provide this functionality, it may not be necessary to implement it in sd-scripts as well.

Could you tell me where SDXL inpainting models are being used or where there is demand for them?

@allanoepping
Author

Thank you for your hard work on this project. Inpainting is still used and embedded into many toolsets. Inpainting is preferable when you need visual continuity, such as in medical imaging, but also in many other cases.

I'm not aware of an easy-to-use open-source tool, such as yours, that lets end-users train for this. If you search for it, Gemini even recommends Kohya_ss, which hasn't supported it (until now). Since SDXL is still widely used, and I've seen others ask about derivations of models being adapted for inpainting, I decided to implement this.

I don't think this will be hard to maintain, and most of the changes are isolated. If you have some suggestions on changes I can make to improve maintainability I'd be happy to make them.

@kohya-ss
Owner

kohya-ss commented Apr 5, 2026

Thank you for the detailed explanation; I understand it well.

The code is indeed well-isolated and seems to have low maintenance costs. I will review it and consider merging it.

I will handle the consistency of the code with other parts of the repository after the merge. Could you please add documentation for this feature, even if it's only in English?

@allanoepping
Author

allanoepping commented Apr 5, 2026 via email

@allanoepping
Author

I've added documentation. I don't have a good way to verify the Japanese translation; I checked it myself, but I don't have easy access to a native speaker.

@kohya-ss
Owner

kohya-ss commented Apr 6, 2026

Thank you, I think the documentation, including the Japanese version, is well done.

I will continue the review, but it seems that a minimal inference script for testing is not currently included. If possible, a script like flux_minimal_inference.py would be helpful. Interactive mode is not necessarily required; command-line arguments alone are sufficient.

Also, I would appreciate it if you could provide a URL where I can download the checkpoints (both SD and SDXL) to use for testing.

@allanoepping
Author

I've added an inpainting_minimal_inference.py script and a wobbly-ellipsoid mask generator for better sampling.

Here are the base inpainting models, although the code should also be able to train an inpainting model from a standard base model:

https://huggingface.co/wangqyqq/sd_xl_base_1.0_inpainting_0.1.safetensors/blob/main/sd_xl_base_1.0_inpainting_0.1.safetensors
https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt

A section for the inference script was added to inpainting_training.md. I can remove it if undesired, since it's not really an end-user script.

@kohya-ss
Owner

kohya-ss commented Apr 7, 2026

Thank you again for the update! I will review and merge this soon.

@allanoepping
Author

Thank you, I appreciate it!

@kohya-ss
Owner

kohya-ss commented Apr 9, 2026

I'm testing from SD1.5 finetuning.

The documentation says "a standard model checkpoint if you want to train inpainting from scratch," but when I specify the weights of a standard model (not inpainting), I get the following error at line 1591 of original_unet.py:

return F.conv2d(
RuntimeError: Given groups=1, weight of size [320, 4, 3, 3], expected input[2, 9, 80, 56] to have 4 channels, but got 9 channels instead.

Could you please check this?

It might be necessary to create a model instance with 9 input channels and transform the shape of the weights of conv_in, before loading them using load_state_dict.
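
The weight transformation suggested above could look like this sketch: zero-pad the standard checkpoint's 4-input-channel conv_in filter to 9 channels, so the freshly converted model initially ignores the mask and masked-image inputs. The helper name is hypothetical and numpy stands in for torch tensors.

```python
import numpy as np

def expand_conv_in_weight(weight, target_in=9):
    """Adapt a standard conv_in weight [320, 4, 3, 3] to a
    9-input-channel inpainting UNet.

    The original 4 latent-channel filters are kept; the 5 new
    channels (mask + masked-image latents) are zero-initialised, so
    the converted model's first conv initially contributes nothing
    for them and training can learn their use gradually.
    """
    out_ch, in_ch, kh, kw = weight.shape
    if in_ch == target_in:
        return weight
    expanded = np.zeros((out_ch, target_in, kh, kw), dtype=weight.dtype)
    expanded[:, :in_ch] = weight
    return expanded

w4 = np.random.default_rng(0).standard_normal((320, 4, 3, 3)).astype(np.float32)
w9 = expand_conv_in_weight(w4)  # shape (320, 9, 3, 3); channels 4..8 are zero
```

With torch, the same reshaping would be done on the state dict entry before calling load_state_dict on a UNet constructed with in_channels=9.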

@allanoepping
Author

Will do!

Thanks,
Allan
