4 changes: 1 addition & 3 deletions diffsynth_engine/pipelines/qwen_image.py
@@ -82,10 +82,8 @@ def __init__(
dtype=config.model_dtype,
)
self.config = config
self.tokenizer_max_length = 1024
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.prompt_template_encode_start_idx = 34
self.default_sample_size = 128
# sampler
self.noise_scheduler = RecifitedFlowScheduler(shift=3.0, use_dynamic_shifting=True)
self.sampler = FlowMatchEulerSampler()
@@ -262,7 +260,7 @@ def encode_prompt(
template = self.prompt_template_encode
drop_idx = self.prompt_template_encode_start_idx
texts = [template.format(txt) for txt in prompt]
outputs = self.tokenizer(texts, max_length=min(max_sequence_length, self.tokenizer_max_length) + drop_idx)
outputs = self.tokenizer(texts, max_length=max_sequence_length + drop_idx)
Contributor
Severity: medium

The max_length argument is now used directly, without considering self.tokenizer_max_length. Ensure that max_sequence_length is appropriately constrained, since extremely long sequences can cause out-of-memory errors or other unexpected behavior. Consider adding a check that max_sequence_length does not exceed a reasonable limit, or document the expected behavior when it does.

outputs = self.tokenizer(texts, max_length=min(max_sequence_length + drop_idx, MAX_TOKEN_LENGTH))

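A minimal sketch of one way to restore the clamp, assuming a module-level constant; MAX_TOKEN_LENGTH is a hypothetical name carried over from the suggestion above, standing in for the removed self.tokenizer_max_length = 1024:

# Hypothetical cap on the user-requested prompt length (the removed attribute
# self.tokenizer_max_length served the same purpose, with a default of 1024).
MAX_TOKEN_LENGTH = 1024

def clamped_max_length(max_sequence_length: int, drop_idx: int) -> int:
    # Clamp the requested length first, then add the template prefix tokens
    # (drop_idx), mirroring the behavior of the removed code.
    return min(max_sequence_length, MAX_TOKEN_LENGTH) + drop_idx

At the call site this would read outputs = self.tokenizer(texts, max_length=clamped_max_length(max_sequence_length, drop_idx)), which bounds only the user-supplied portion of the prompt while always keeping the full template prefix.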
input_ids, attention_mask = outputs["input_ids"].to(self.device), outputs["attention_mask"].to(self.device)
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
hidden_states = outputs["hidden_states"]