Pr0f3ssi0n4ln00b committed · verified
Commit db11ff0 · 1 Parent(s): c549fb8

Update app.py

Files changed (1)
  1. app.py +260 -306
app.py CHANGED
@@ -1,19 +1,14 @@
1
- # app.py
2
  import os
3
  import re
4
  import gc
5
- import uuid
6
- import time
7
- import math
8
  import traceback
9
- import random
10
- from typing import Iterable, Optional
11
-
12
  import gradio as gr
13
  import numpy as np
14
  import spaces
15
  import torch
 
16
  from PIL import Image, ImageDraw
 
17
 
18
  from transformers import (
19
  AutoProcessor,
@@ -29,133 +24,6 @@ from safetensors.torch import load_file as safetensors_load_file
29
  from gradio.themes import Soft
30
  from gradio.themes.utils import colors, fonts, sizes
31
 
32
-
33
- # ============================================================
34
- # Process-unique temp dir (helps avoid /tmp collisions)
35
- # ============================================================
36
-
37
- def _ensure_unique_gradio_tmp():
38
- """
39
- ZeroGPU/Spaces can serve multiple users across recycled containers.
40
- Gradio may use /tmp/gradio by default. We force a unique directory per process.
41
- """
42
- if os.environ.get("GRADIO_TEMP_DIR"):
43
- print(f"GRADIO_TEMP_DIR = {os.environ['GRADIO_TEMP_DIR']}")
44
- return
45
- pid = os.getpid()
46
- tmp = f"/tmp/gradio_{pid}_{uuid.uuid4().hex[:8]}"
47
- os.environ["GRADIO_TEMP_DIR"] = tmp
48
- try:
49
- os.makedirs(tmp, exist_ok=True)
50
- except Exception:
51
- pass
52
- print(f"GRADIO_TEMP_DIR = {tmp}")
53
-
54
- _ensure_unique_gradio_tmp()
55
-
56
-
57
- # ============================================================
58
- # Patch: Qwen2.5-VL RoPE (avoid cublas batched GEMM; preserve shapes)
59
- # ============================================================
60
-
61
- def patch_qwen25vl_rope_no_gemm():
62
- """
63
- Patch Qwen2.5-VL rotary embedding to avoid the matmul that can trigger:
64
- CUBLAS_STATUS_INVALID_VALUE (cublasSgemmStridedBatched)
65
- on some ZeroGPU/H200 MIG configurations.
66
-
67
- CRITICAL: Preserve exact output shapes used by apply_multimodal_rotary_pos_emb,
68
- otherwise you get split_with_sizes mismatches.
69
- """
70
- if os.environ.get("DISABLE_ROPE_PATCH", "").strip() == "1":
71
- print("[patch][rope] DISABLE_ROPE_PATCH=1 -> skipping patch.")
72
- return
73
-
74
- try:
75
- from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as qvl
76
- except Exception as e:
77
- print(f"[patch][rope] could not import qwen2_5_vl modeling: {e}")
78
- return
79
-
80
- Rotary = None
81
- for name in ["Qwen2_5_VLRotaryEmbedding", "Qwen2_5RotaryEmbedding", "RotaryEmbedding"]:
82
- Rotary = getattr(qvl, name, None)
83
- if Rotary is not None:
84
- break
85
-
86
- if Rotary is None:
87
- print("[patch][rope] rotary embedding class not found; no patch applied.")
88
- return
89
-
90
- orig_forward = Rotary.forward
91
-
92
- def forward_no_gemm(self, x, position_ids):
93
- # Fallback to original if structure differs
94
- if not hasattr(self, "inv_freq") or position_ids is None:
95
- return orig_forward(self, x, position_ids)
96
-
97
- # Determine rotary dim from module config (NOT x.shape[-1])
98
- if hasattr(self, "dim") and isinstance(self.dim, int):
99
- rope_dim = int(self.dim)
100
- else:
101
- rope_dim = int(self.inv_freq.numel() * 2)
102
-
103
- # Normalize position_ids to (bs, seq)
104
- if position_ids.ndim > 2:
105
- pos = position_ids.reshape(position_ids.shape[0], -1)
106
- else:
107
- pos = position_ids
108
-
109
- # Compute on the same device as inv_freq/x
110
- dev = self.inv_freq.device
111
- pos = pos.to(device=dev)
112
-
113
- # Broadcast multiply instead of matmul:
114
- # inv: (1,1,dim/2,1), pos: (bs,1,1,seq) -> freqs: (bs,1,dim/2,seq)
115
- inv = self.inv_freq[None, None, :, None].float()
116
- posf = pos[:, None, None, :].float()
117
- freqs = (inv * posf).transpose(2, 3) # (bs,1,seq,dim/2)
118
-
119
- # Double to full rotary dim
120
- emb = torch.cat((freqs, freqs), dim=-1) # (bs,1,seq,dim)
121
-
122
- # Enforce exact expected rotary dim
123
- if emb.shape[-1] != rope_dim:
124
- emb = emb[..., :rope_dim]
125
-
126
- cos = emb.cos()
127
- sin = emb.sin()
128
-
129
- # Respect attention scaling if present
130
- attn_scale = getattr(self, "attention_scaling", None)
131
- if attn_scale is not None:
132
- cos = cos * attn_scale
133
- sin = sin * attn_scale
134
-
135
- # Match dtype expectations (upstream typically returns same dtype as x)
136
- cos = cos.to(dtype=x.dtype)
137
- sin = sin.to(dtype=x.dtype)
138
-
139
- # Optional debug (enable by env)
140
- if os.environ.get("DEBUG_ROPE", "").strip() == "1":
141
- ms = getattr(self, "mrope_section", None)
142
- if ms is not None:
143
- try:
144
- ms_list = list(ms)
145
- print(f"[DEBUG][rope] rope_dim={rope_dim} cos_last={cos.shape[-1]} mrope_sum={sum(ms_list)} mrope={ms_list}")
146
- except Exception:
147
- print(f"[DEBUG][rope] rope_dim={rope_dim} cos_last={cos.shape[-1]} (mrope_section unreadable)")
148
- else:
149
- print(f"[DEBUG][rope] rope_dim={rope_dim} cos_last={cos.shape[-1]} (no mrope_section attr)")
150
-
151
- return cos, sin
152
-
153
- Rotary.forward = forward_no_gemm
154
- print("[patch] Patched Qwen2.5-VL RoPE matmul -> broadcast multiply (shape-preserving).")
155
-
156
- patch_qwen25vl_rope_no_gemm()
157
-
158
-
159
  # ============================================================
160
  # Theme
161
  # ============================================================
@@ -175,6 +43,7 @@ colors.orange_red = colors.Color(
175
  c950="#802200",
176
  )
177
 
 
178
  class OrangeRedTheme(Soft):
179
  def __init__(
180
  self,
@@ -230,11 +99,11 @@ class OrangeRedTheme(Soft):
230
  block_label_background_fill="*primary_200",
231
  )
232
 
233
- orange_red_theme = OrangeRedTheme()
234
 
 
235
 
236
  # ============================================================
237
- # Device / Env debug
238
  # ============================================================
239
 
240
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -249,14 +118,6 @@ if torch.cuda.is_available():
249
  print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
250
  print("Using device:", device)
251
 
252
- # Optional: make matmul a bit more stable (doesn't change correctness)
253
- try:
254
- torch.backends.cuda.matmul.allow_tf32 = True
255
- torch.backends.cudnn.allow_tf32 = True
256
- except Exception:
257
- pass
258
-
259
-
260
  # ============================================================
261
  # AIO version (Space variable)
262
  # ============================================================
@@ -267,6 +128,7 @@ DEFAULT_AIO_VERSION = "v19"
267
  _VER_RE = re.compile(r"^v\d+$")
268
  _DIGITS_RE = re.compile(r"^\d+$")
269
 
 
270
  def _normalize_version(raw: str) -> Optional[str]:
271
  if raw is None:
272
  return None
@@ -275,10 +137,12 @@ def _normalize_version(raw: str) -> Optional[str]:
275
  return None
276
  if _VER_RE.fullmatch(s):
277
  return s
 
278
  if _DIGITS_RE.fullmatch(s):
279
  return f"v{s}"
280
  return None
281
 
 
282
  _AIO_ENV_RAW = os.environ.get("AIO_VERSION", "")
283
  _AIO_ENV_NORM = _normalize_version(_AIO_ENV_RAW)
284
 
@@ -289,7 +153,6 @@ print(f"AIO_VERSION (env raw) = {_AIO_ENV_RAW!r}")
289
  print(f"AIO_VERSION (normalized) = {_AIO_ENV_NORM!r}")
290
  print(f"Using AIO_VERSION = {AIO_VERSION} ({AIO_VERSION_SOURCE})")
291
 
292
-
293
  # ============================================================
294
  # Pipeline
295
  # ============================================================
@@ -301,6 +164,7 @@ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
301
 
302
  dtype = torch.bfloat16
303
 
 
304
  def _load_pipe_with_version(version: str) -> QwenImageEditPlusPipeline:
305
  sub = f"{version}/transformer"
306
  print(f"📦 Loading AIO transformer: {AIO_REPO_ID} / {sub}")
@@ -316,9 +180,11 @@ def _load_pipe_with_version(version: str) -> QwenImageEditPlusPipeline:
316
  ).to(device)
317
  return p
318
 
 
 
319
  try:
320
  pipe = _load_pipe_with_version(AIO_VERSION)
321
- except Exception:
322
  print("❌ Failed to load requested AIO_VERSION. Falling back to v19.")
323
  print("---- exception ----")
324
  print(traceback.format_exc())
@@ -336,39 +202,47 @@ except Exception as e:
336
 
337
  MAX_SEED = np.iinfo(np.int32).max
338
 
339
-
340
  # ============================================================
341
- # Derived conditioning (Transformers): Pose + Depth (v1-style)
342
  # ============================================================
 
 
 
 
343
 
344
  POSE_MODEL_ID = "usyd-community/vitpose-base-simple"
345
  POSE_DETECTOR_ID = "PekingU/rtdetr_r50vd_coco_o365"
346
  DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"
347
 
 
348
  _POSE_CACHE = {}
349
  _DEPTH_CACHE = {}
350
 
 
351
  COCO17_EDGES = [
352
- (0, 1), (0, 2), (1, 3), (2, 4),
353
- (5, 6),
354
- (5, 7), (7, 9),
355
- (6, 8), (8, 10),
356
- (5, 11), (6, 12), (11, 12),
357
- (11, 13), (13, 15),
358
- (12, 14), (14, 16),
359
  ]
360
 
361
  def _derived_device(use_gpu: bool) -> torch.device:
362
  return torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")
363
 
 
364
  def _load_pose_models(dev: torch.device):
365
  key = str(dev)
366
  if key in _POSE_CACHE:
367
  return _POSE_CACHE[key]
368
 
 
369
  det_proc = AutoProcessor.from_pretrained(POSE_DETECTOR_ID)
370
  det_model = RTDetrForObjectDetection.from_pretrained(POSE_DETECTOR_ID).to(dev)
371
 
 
372
  pose_proc = AutoProcessor.from_pretrained(POSE_MODEL_ID)
373
  pose_model = VitPoseForPoseEstimation.from_pretrained(POSE_MODEL_ID).to(dev)
374
 
@@ -378,6 +252,7 @@ def _load_pose_models(dev: torch.device):
378
  _POSE_CACHE[key] = (det_proc, det_model, pose_proc, pose_model)
379
  return _POSE_CACHE[key]
380
 
 
381
  def _load_depth_models(dev: torch.device):
382
  key = str(dev)
383
  if key in _DEPTH_CACHE:
@@ -390,6 +265,7 @@ def _load_depth_models(dev: torch.device):
390
  _DEPTH_CACHE[key] = (proc, model)
391
  return _DEPTH_CACHE[key]
392
 
 
393
  def _draw_skeleton_on_blank(
394
  size: tuple[int, int],
395
  persons_keypoints: list[np.ndarray],
@@ -403,6 +279,7 @@ def _draw_skeleton_on_blank(
403
  draw = ImageDraw.Draw(canvas)
404
 
405
  for kps, sc in zip(persons_keypoints, persons_scores):
 
406
  for a, b in COCO17_EDGES:
407
  if a >= len(sc) or b >= len(sc):
408
  continue
@@ -412,13 +289,20 @@ def _draw_skeleton_on_blank(
412
  xb, yb = float(kps[b, 0]), float(kps[b, 1])
413
  draw.line([(xa, ya), (xb, yb)], fill=(255, 255, 255), width=line_w)
414
 
 
415
  for i in range(min(len(sc), len(kps))):
416
  if sc[i] < kp_thresh:
417
  continue
418
  x, y = float(kps[i, 0]), float(kps[i, 1])
419
- draw.ellipse([(x - point_r, y - point_r), (x + point_r, y + point_r)], fill=(255, 255, 255))
 
 
 
 
 
420
  return canvas
421
 
 
422
  def make_pose_map(
423
  img: Image.Image,
424
  *,
@@ -427,6 +311,12 @@ def make_pose_map(
427
  det_thresh: float = 0.30,
428
  max_people: int = 4,
429
  ) -> Image.Image:
 
 
 
 
 
 
430
  img = img.convert("RGB")
431
  dev = _derived_device(use_gpu)
432
  det_proc, det_model, pose_proc, pose_model = _load_pose_models(dev)
@@ -434,8 +324,10 @@ def make_pose_map(
434
  w, h = img.size
435
 
436
  if mode == "fast":
 
437
  boxes = np.array([[0.0, 0.0, float(w), float(h)]], dtype=np.float32)
438
  else:
 
439
  inputs = det_proc(images=img, return_tensors="pt").to(dev)
440
  with torch.no_grad():
441
  outputs = det_model(**inputs)
@@ -446,11 +338,14 @@ def make_pose_map(
446
  threshold=det_thresh,
447
  )[0]
448
 
 
449
  person_boxes = results["boxes"][results["labels"] == 0].detach().cpu().numpy()
450
 
451
  if person_boxes.size == 0:
 
452
  boxes = np.array([[0.0, 0.0, float(w), float(h)]], dtype=np.float32)
453
  else:
 
454
  person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
455
  person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
456
  boxes = person_boxes.astype(np.float32)
@@ -464,17 +359,23 @@ def make_pose_map(
464
 
465
  pose_results = pose_proc.post_process_pose_estimation(pose_outputs, boxes=[boxes])[0]
466
 
467
- persons_kps, persons_sc = [], []
 
468
  for pr in pose_results:
469
- persons_kps.append(pr["keypoints"].detach().cpu().numpy())
470
- persons_sc.append(pr["scores"].detach().cpu().numpy())
 
 
471
 
472
  if not persons_kps:
 
473
  return Image.new("RGB", img.size, (0, 0, 0))
474
 
475
  return _draw_skeleton_on_blank(img.size, persons_kps, persons_sc)
476
 
 
477
  def make_depth_map(img: Image.Image, *, use_gpu: bool) -> Image.Image:
 
478
  img = img.convert("RGB")
479
  dev = _derived_device(use_gpu)
480
  proc, model = _load_depth_models(dev)
@@ -485,7 +386,10 @@ def make_depth_map(img: Image.Image, *, use_gpu: bool) -> Image.Image:
485
  with torch.no_grad():
486
  out = model(**inputs)
487
 
 
488
  pred = out.predicted_depth
 
 
489
  pred = torch.nn.functional.interpolate(
490
  pred.unsqueeze(1),
491
  size=(img.height, img.width),
@@ -499,55 +403,9 @@ def make_depth_map(img: Image.Image, *, use_gpu: bool) -> Image.Image:
499
  arr = arr / denom
500
 
501
  depth8 = (arr * 255.0).clip(0, 255).astype(np.uint8)
502
- return Image.fromarray(depth8, mode="L").convert("RGB")
503
-
504
-
505
- # ============================================================
506
- # Helpers: gallery normalization + debug-friendly PIL conversion
507
- # ============================================================
508
-
509
- def _to_pil_rgb(x) -> Optional[Image.Image]:
510
- """
511
- Accepts PIL / numpy / (image, caption) tuples / gradio dicts and returns PIL RGB.
512
- Also safely ignores broken temp paths.
513
- """
514
- if x is None:
515
- return None
516
 
517
- # Gallery often returns (image, caption)
518
- if isinstance(x, tuple) and len(x) >= 1:
519
- x = x[0]
520
- if x is None:
521
- return None
522
-
523
- # Some gradio versions can return dict with a temp file path
524
- if isinstance(x, dict):
525
- # common keys: 'name' or 'path'
526
- p = x.get("name") or x.get("path")
527
- if isinstance(p, str):
528
- if not os.path.exists(p):
529
- print(f"[WARN] extra image path missing, skipping: {p}")
530
- return None
531
- try:
532
- return Image.open(p).convert("RGB")
533
- except Exception as e:
534
- print(f"[WARN] failed to open extra image path {p}: {e}")
535
- return None
536
-
537
- if isinstance(x, Image.Image):
538
- return x.convert("RGB")
539
-
540
- if isinstance(x, np.ndarray):
541
- try:
542
- return Image.fromarray(x).convert("RGB")
543
- except Exception:
544
- return None
545
-
546
- # last resort
547
- try:
548
- return Image.fromarray(np.array(x)).convert("RGB")
549
- except Exception:
550
- return None
551
 
552
  def _append_to_gallery(existing, new_img: Image.Image):
553
  items = []
@@ -559,29 +417,8 @@ def _append_to_gallery(existing, new_img: Image.Image):
559
  items.append(new_img)
560
  return items
561
 
562
- def build_labeled_images(
563
- img1: Image.Image,
564
- img2: Optional[Image.Image],
565
- extra_imgs: Optional[list[Image.Image]],
566
- ) -> dict[str, Image.Image]:
567
- labeled: dict[str, Image.Image] = {}
568
- idx = 1
569
- labeled[f"image_{idx}"] = img1
570
- idx += 1
571
- if img2 is not None:
572
- labeled[f"image_{idx}"] = img2
573
- idx += 1
574
- if extra_imgs:
575
- for im in extra_imgs:
576
- if im is None:
577
- continue
578
- labeled[f"image_{idx}"] = im
579
- idx += 1
580
- return labeled
581
-
582
-
583
  # ============================================================
584
- # LoRA adapters + presets (your v1 config)
585
  # ============================================================
586
 
587
  NONE_LORA = "None"
@@ -649,7 +486,7 @@ ADAPTER_SPECS = {
649
  "weights": "bfs_head_v5_2511_original.safetensors",
650
  "adapter_name": "BFS-Best-Faceswap",
651
  "strength": 1.0,
652
- "needs_alpha_fix": True,
653
  },
654
  "BFS-Best-FaceSwap-merge": {
655
  "type": "single",
@@ -659,7 +496,7 @@ ADAPTER_SPECS = {
659
  "weights": "bfs_head_v5_2511_merged_version_rank_32_fp32.safetensors",
660
  "adapter_name": "BFS-Best-Faceswap-merge",
661
  "strength": 1.1,
662
- "needs_alpha_fix": True,
663
  },
664
  "F2P": {
665
  "type": "single",
@@ -745,13 +582,16 @@ LORA_PRESET_PROMPTS = {
745
  "BFS-Best-FaceSwap-merge": "head_swap: start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
746
  }
747
 
 
748
  LOADED_ADAPTERS = set()
749
 
750
-
751
  # ============================================================
752
- # Helpers: resolution (area-based sizing)
753
  # ============================================================
754
 
 
 
 
755
  def _round_to_multiple(x: int, m: int) -> int:
756
  return max(m, (int(x) // m) * m)
757
 
@@ -760,9 +600,16 @@ def compute_canvas_dimensions_from_area(
760
  target_area: int,
761
  multiple_of: int,
762
  ) -> tuple[int, int]:
 
 
 
 
763
  w, h = image.size
764
  aspect = w / h if h else 1.0
 
 
765
  from qwenimage.pipeline_qwenimage_edit_plus import calculate_dimensions
 
766
  width, height = calculate_dimensions(int(target_area), float(aspect))
767
  width = _round_to_multiple(int(width), int(multiple_of))
768
  height = _round_to_multiple(int(height), int(multiple_of))
@@ -773,18 +620,29 @@ def get_target_area_for_lora(
773
  lora_adapter: str,
774
  user_target_megapixels: float,
775
  ) -> int:
 
 
 
 
 
 
 
776
  spec = ADAPTER_SPECS.get(lora_adapter, {})
 
777
  if "target_area" in spec:
778
  try:
779
  return int(spec["target_area"])
780
  except Exception:
781
  pass
 
782
  if "target_megapixels" in spec:
783
  try:
784
  mp = float(spec["target_megapixels"])
785
  return int(mp * 1024 * 1024)
786
  except Exception:
787
  pass
 
 
788
  if "target_long_edge" in spec:
789
  try:
790
  long_edge = int(spec["target_long_edge"])
@@ -798,21 +656,99 @@ def get_target_area_for_lora(
798
  return int(new_w * new_h)
799
  except Exception:
800
  pass
801
- return int(float(user_target_megapixels) * 1024 * 1024)
802
 
 
 
803
 
804
  # ============================================================
805
- # Helpers: LoRA routing + BFS alpha fixes (your v1 logic)
806
  # ============================================================
807
 
 
808
  def lora_requires_two_images(lora_adapter: str) -> bool:
809
  return bool(ADAPTER_SPECS.get(lora_adapter, {}).get("requires_two_images", False))
810
 
 
811
  def image2_label_for_lora(lora_adapter: str) -> str:
812
  return str(ADAPTER_SPECS.get(lora_adapter, {}).get("image2_label", "Upload Reference (Image 2)"))
813
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
814
  def _inject_missing_alpha_keys(state_dict: dict) -> dict:
 
 
 
 
 
 
 
 
 
815
  bases = {}
 
816
  for k, v in state_dict.items():
817
  if not isinstance(v, torch.Tensor):
818
  continue
@@ -833,9 +769,27 @@ def _inject_missing_alpha_keys(state_dict: dict) -> dict:
833
  stripped_alpha = f"{stripped_base}.alpha"
834
  if stripped_alpha not in state_dict:
835
  state_dict[stripped_alpha] = alpha_tensor
 
836
  return state_dict
837
 
 
838
  def _filter_to_diffusers_lora_keys(state_dict: dict) -> tuple[dict, dict]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
839
  keep_suffixes = (
840
  ".lora_up.weight",
841
  ".lora_down.weight",
@@ -843,6 +797,7 @@ def _filter_to_diffusers_lora_keys(state_dict: dict) -> tuple[dict, dict]:
843
  ".alpha",
844
  ".lora_alpha",
845
  )
 
846
  dropped_patch = 0
847
  dropped_other = 0
848
  kept = 0
@@ -851,21 +806,28 @@ def _filter_to_diffusers_lora_keys(state_dict: dict) -> tuple[dict, dict]:
851
  out: dict[str, torch.Tensor] = {}
852
  for k, v in state_dict.items():
853
  if not isinstance(v, torch.Tensor):
 
854
  dropped_other += 1
855
  continue
 
 
856
  if k.endswith(".diff") or k.endswith(".diff_b"):
857
  dropped_patch += 1
858
  continue
 
859
  if not k.endswith(keep_suffixes):
860
  dropped_other += 1
861
  continue
 
862
  if k.endswith(".lora_alpha"):
 
863
  base = k[: -len(".lora_alpha")]
864
  k2 = f"{base}.alpha"
865
  out[k2] = v.float() if v.dtype != torch.float32 else v
866
  normalized_alpha += 1
867
  kept += 1
868
  continue
 
869
  out[k] = v
870
  kept += 1
871
 
@@ -877,7 +839,15 @@ def _filter_to_diffusers_lora_keys(state_dict: dict) -> tuple[dict, dict]:
877
  }
878
  return out, stats
879
 
 
880
  def _duplicate_stripped_prefix_keys(state_dict: dict, prefix: str = "diffusion_model.") -> dict:
 
 
 
 
 
 
 
881
  out = dict(state_dict)
882
  for k, v in list(state_dict.items()):
883
  if not k.startswith(prefix):
@@ -887,11 +857,18 @@ def _duplicate_stripped_prefix_keys(state_dict: dict, prefix: str = "diffusion_m
887
  out[stripped] = v
888
  return out
889
 
 
890
  def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name: str, needs_alpha_fix: bool = False):
 
 
 
 
891
  try:
892
  pipe.load_lora_weights(repo, weight_name=weight_name, adapter_name=adapter_name)
893
  return
894
  except (KeyError, ValueError) as e:
 
 
895
  if not needs_alpha_fix:
896
  raise
897
 
@@ -903,8 +880,13 @@ def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name:
903
  local_path = hf_hub_download(repo_id=repo, filename=weight_name)
904
  sd = safetensors_load_file(local_path)
905
 
 
906
  sd = _inject_missing_alpha_keys(sd)
 
 
907
  sd, stats = _filter_to_diffusers_lora_keys(sd)
 
 
908
  sd = _duplicate_stripped_prefix_keys(sd)
909
 
910
  print(
@@ -912,9 +894,16 @@ def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name:
912
  f"kept={stats['kept']} dropped_patch={stats['dropped_patch']} "
913
  f"dropped_other={stats['dropped_other']} normalized_alpha={stats['normalized_alpha']}"
914
  )
 
915
  pipe.load_lora_weights(sd, adapter_name=adapter_name)
916
  return
917
 
 
 
 
 
 
 
918
  def _ensure_loaded_and_get_active_adapters(selected_lora: str):
919
  spec = ADAPTER_SPECS.get(selected_lora)
920
  if not spec:
@@ -952,6 +941,7 @@ def _ensure_loaded_and_get_active_adapters(selected_lora: str):
952
 
953
  adapter_names.append(adapter_name)
954
  adapter_weights.append(strength)
 
955
  else:
956
  repo = spec["repo"]
957
  weights = spec["weights"]
@@ -984,7 +974,9 @@ def _ensure_loaded_and_get_active_adapters(selected_lora: str):
984
  # UI handlers
985
  # ============================================================
986
 
 
987
  def on_lora_change_ui(selected_lora, current_prompt, current_extras_condition_only):
 
988
  if selected_lora != NONE_LORA:
989
  preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
990
  if preset and (current_prompt is None or str(current_prompt).strip() == ""):
@@ -994,19 +986,20 @@ def on_lora_change_ui(selected_lora, current_prompt, current_extras_condition_on
994
  else:
995
  prompt_update = gr.update(value=current_prompt)
996
 
 
997
  if lora_requires_two_images(selected_lora):
998
  img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
999
  else:
1000
  img2_update = gr.update(visible=False, value=None, label="Upload Reference (Image 2)")
1001
 
 
 
1002
  if selected_lora in ("BFS-Best-FaceSwap", "BFS-Best-FaceSwap-merge", "AnyPose"):
1003
  extras_update = gr.update(value=True)
1004
  else:
1005
  extras_update = gr.update(value=current_extras_condition_only)
1006
 
1007
  return prompt_update, img2_update, extras_update
1008
-
1009
-
1010
  # ============================================================
1011
  # UI helpers: output routing + derived conditioning
1012
  # ============================================================
@@ -1016,16 +1009,19 @@ def set_output_as_image1(last):
1016
  raise gr.Error("No output available yet.")
1017
  return gr.update(value=last)
1018
 
 
1019
  def set_output_as_image2(last):
1020
  if last is None:
1021
  raise gr.Error("No output available yet.")
1022
  return gr.update(value=last)
1023
 
 
1024
  def set_output_as_extra(last, existing_extra):
1025
  if last is None:
1026
  raise gr.Error("No output available yet.")
1027
  return _append_to_gallery(existing_extra, last)
1028
 
 
1029
  @spaces.GPU
1030
  def add_derived_ref(img1, existing_extra, derived_type, derived_use_gpu, derived_max_people):
1031
  if img1 is None:
@@ -1054,36 +1050,17 @@ def add_derived_ref(img1, existing_extra, derived_type, derived_use_gpu, derived
1054
  return gr.update(value=new_gallery), gr.update(visible=True, value=derived)
1055
 
1056
 
1057
- # ============================================================
1058
- # Debug helpers (CUDA mem + token count)
1059
- # ============================================================
1060
-
1061
- def _cuda_mem(prefix: str):
1062
- if not torch.cuda.is_available():
1063
- return
1064
- try:
1065
- free, total = torch.cuda.mem_get_info()
1066
- print(f"[DEBUG][cuda][{prefix}] mem free={free/1e9:.2f}GB total={total/1e9:.2f}GB")
1067
- except Exception:
1068
- pass
1069
-
1070
- def _approx_token_count(text: str) -> int:
1071
- # Lightweight: we avoid forcing tokenizer calls here; this is only for debug.
1072
- # Rule-of-thumb: ~4 chars per token in English-ish text.
1073
- if not text:
1074
- return 0
1075
- return max(1, int(math.ceil(len(text) / 4.0)))
1076
-
1077
 
1078
  # ============================================================
1079
  # Inference
1080
  # ============================================================
1081
 
 
1082
  @spaces.GPU
1083
  def infer(
1084
  input_image_1,
1085
  input_image_2,
1086
- input_images_extra,
1087
  prompt,
1088
  lora_adapter,
1089
  seed,
@@ -1099,8 +1076,6 @@ def infer(
1099
  if torch.cuda.is_available():
1100
  torch.cuda.empty_cache()
1101
 
1102
- print("[DEBUG][infer] input types:", type(input_image_1), type(input_image_2), type(input_images_extra))
1103
-
1104
  if input_image_1 is None:
1105
  raise gr.Error("Please upload Image 1.")
1106
 
@@ -1119,23 +1094,15 @@ def infer(
1119
  seed = random.randint(0, MAX_SEED)
1120
 
1121
  generator = torch.Generator(device=device).manual_seed(seed)
1122
-
1123
  negative_prompt = (
1124
  "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, "
1125
  "extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
1126
  )
1127
 
1128
- img1 = input_image_1.convert("RGB") if isinstance(input_image_1, Image.Image) else _to_pil_rgb(input_image_1)
1129
- if img1 is None:
1130
- raise gr.Error("Image 1 could not be read (unexpected input type/path).")
1131
-
1132
- img2 = None
1133
- if input_image_2 is not None:
1134
- img2 = input_image_2.convert("RGB") if isinstance(input_image_2, Image.Image) else _to_pil_rgb(input_image_2)
1135
- if img2 is None:
1136
- raise gr.Error("Image 2 could not be read (unexpected input type/path).")
1137
 
1138
- # Normalize extra images (Gallery)
1139
  extra_imgs: list[Image.Image] = []
1140
  if input_images_extra:
1141
  for item in input_images_extra:
@@ -1143,16 +1110,20 @@ def infer(
1143
  if pil is not None:
1144
  extra_imgs.append(pil)
1145
 
1146
- # Enforce 2-image LoRA behavior
1147
  if lora_requires_two_images(lora_adapter) and img2 is None:
1148
  raise gr.Error("This LoRA needs two images. Please upload Image 2 as well.")
1149
 
 
1150
  labeled = build_labeled_images(img1, img2, extra_imgs)
1151
 
 
1152
  pipe_images = list(labeled.values())
1153
  if len(pipe_images) == 1:
1154
  pipe_images = pipe_images[0]
1155
 
 
 
1156
  target_area = get_target_area_for_lora(img1, lora_adapter, float(target_megapixels))
1157
  width, height = compute_canvas_dimensions_from_area(
1158
  img1,
@@ -1160,48 +1131,33 @@ def infer(
1160
  multiple_of=int(pipe.vae_scale_factor * 2),
1161
  )
1162
 
 
 
1163
  vae_image_indices = None
1164
  if extras_condition_only:
1165
  if isinstance(pipe_images, list) and len(pipe_images) > 2:
1166
  vae_image_indices = [0, 1] if len(pipe_images) >= 2 else [0]
1167
 
1168
- # Debug summary
1169
- n_images = len(pipe_images) if isinstance(pipe_images, list) else 1
1170
- tok_est = _approx_token_count(prompt or "")
1171
- print(
1172
- "[DEBUG][infer] submitting request | "
1173
- f"lora_adapter={lora_adapter!r} seed={seed} prompt_len={len(prompt or '')} "
1174
- f"steps={steps} true_cfg_scale={guidance_scale} target_mp={target_megapixels} "
1175
- f"canvas=({width}x{height}) n_images={n_images} vae_image_indices={vae_image_indices} "
1176
- f"pad_to_canvas={bool(pad_to_canvas)}"
1177
- )
1178
- print(f"[DEBUG][infer] image_1 size: {img1.size} image_2 size: {img2.size if img2 else None}")
1179
- print(f"[DEBUG][infer] prompt token_estimate: {tok_est}")
1180
- _cuda_mem("before")
1181
-
1182
  try:
 
 
 
 
 
1183
  result = pipe(
1184
  image=pipe_images,
1185
  prompt=prompt,
1186
  negative_prompt=negative_prompt,
1187
  height=height,
1188
  width=width,
1189
- num_inference_steps=int(steps),
1190
  generator=generator,
1191
- true_cfg_scale=float(guidance_scale),
1192
  vae_image_indices=vae_image_indices,
1193
  pad_to_canvas=bool(pad_to_canvas),
1194
  ).images[0]
1195
  return result, seed, result
1196
-
1197
- except Exception as e:
1198
- print("---- [ERROR][infer] exception ----")
1199
- print(traceback.format_exc())
1200
- print("---------------------------------")
1201
- raise
1202
-
1203
  finally:
1204
- _cuda_mem("after")
1205
  gc.collect()
1206
  if torch.cuda.is_available():
1207
  torch.cuda.empty_cache()
@@ -1214,13 +1170,8 @@ def infer_example(input_image, prompt, lora_adapter):
1214
  input_pil = input_image.convert("RGB")
1215
  guidance_scale = 1.0
1216
  steps = 4
1217
- result, seed, last = infer(
1218
- input_pil, None, None,
1219
- prompt, lora_adapter,
1220
- 0, True,
1221
- guidance_scale, steps, 1.0,
1222
- True, True
1223
- )
1224
  return result, seed, last
1225
 
1226
 
@@ -1247,7 +1198,7 @@ with gr.Blocks() as demo:
1247
  gr.Markdown(
1248
  "Perform diverse image edits using specialized "
1249
  "[LoRA](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image-Edit-2511) adapters for the "
1250
- "[Qwen-Image-Edit](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) model. Uses a Diffusers compatible extraction of the transformers from Phr00t's Rapid AIO merge."
1251
  )
1252
  gr.Markdown(aio_status_line)
1253
 
@@ -1341,6 +1292,7 @@ with gr.Blocks() as demo:
1341
  value=True,
1342
  )
1343
 
 
1344
  lora_adapter.change(
1345
  fn=on_lora_change_ui,
1346
  inputs=[lora_adapter, prompt, extras_condition_only],
@@ -1400,16 +1352,18 @@ with gr.Blocks() as demo:
1400
  outputs=[output_image, seed, last_output],
1401
  )
1402
 
 
1403
  btn_out_to_img1.click(fn=set_output_as_image1, inputs=[last_output], outputs=[input_image_1])
1404
  btn_out_to_img2.click(fn=set_output_as_image2, inputs=[last_output], outputs=[input_image_2])
1405
  btn_out_to_extra.click(fn=set_output_as_extra, inputs=[last_output, input_images_extra], outputs=[input_images_extra])
1406
-
 
1407
  add_derived_btn.click(
1408
  fn=add_derived_ref,
1409
  inputs=[input_image_1, input_images_extra, derived_type, derived_use_gpu, derived_max_people],
1410
  outputs=[input_images_extra, derived_preview],
1411
  )
1412
-
1413
  if __name__ == "__main__":
1414
  demo.queue(max_size=30).launch(
1415
  css=css,
 
 
1
  import os
2
  import re
3
  import gc
 
 
 
4
  import traceback
 
 
 
5
  import gradio as gr
6
  import numpy as np
7
  import spaces
8
  import torch
9
+ import random
10
  from PIL import Image, ImageDraw
11
+ from typing import Iterable, Optional
12
 
13
  from transformers import (
14
  AutoProcessor,
 
24
  from gradio.themes import Soft
25
  from gradio.themes.utils import colors, fonts, sizes
26

27
  # ============================================================
28
  # Theme
29
  # ============================================================
 
43
  c950="#802200",
44
  )
45
 
46
+
47
  class OrangeRedTheme(Soft):
48
  def __init__(
49
  self,
 
99
  block_label_background_fill="*primary_200",
100
  )
101
 
 
102
 
103
+ orange_red_theme = OrangeRedTheme()
104
 
105
  # ============================================================
106
+ # Device
107
  # ============================================================
108
 
109
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
118
  print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
119
  print("Using device:", device)
120

121
  # ============================================================
122
  # AIO version (Space variable)
123
  # ============================================================
 
128
  _VER_RE = re.compile(r"^v\d+$")
129
  _DIGITS_RE = re.compile(r"^\d+$")
130
 
131
+
132
  def _normalize_version(raw: str) -> Optional[str]:
133
  if raw is None:
134
  return None
 
137
  return None
138
  if _VER_RE.fullmatch(s):
139
  return s
140
+ # forgiving: allow "21" -> "v21"
141
  if _DIGITS_RE.fullmatch(s):
142
  return f"v{s}"
143
  return None
144
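# Illustrative sanity check of the behavior above (example inputs are assumptions;
# the elided preprocessing is only expected to strip/normalize the raw string):
assert _normalize_version("v21") == "v21"    # already matches ^v\d+$
assert _normalize_version("21") == "v21"     # digits-only input gains the "v" prefix
assert _normalize_version("beta") is None    # rejected; callers fall back to the default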
 
145
+
146
  _AIO_ENV_RAW = os.environ.get("AIO_VERSION", "")
147
  _AIO_ENV_NORM = _normalize_version(_AIO_ENV_RAW)
148
 
 
153
  print(f"AIO_VERSION (normalized) = {_AIO_ENV_NORM!r}")
154
  print(f"Using AIO_VERSION = {AIO_VERSION} ({AIO_VERSION_SOURCE})")
155
 
 
156
  # ============================================================
157
  # Pipeline
158
  # ============================================================
 
164
 
165
  dtype = torch.bfloat16
166
 
167
+
168
  def _load_pipe_with_version(version: str) -> QwenImageEditPlusPipeline:
169
  sub = f"{version}/transformer"
170
  print(f"📦 Loading AIO transformer: {AIO_REPO_ID} / {sub}")
 
180
  ).to(device)
181
  return p
182
 
183
+
184
+ # Forgiving load: try env/default version, fallback to v19 if it fails
185
  try:
186
  pipe = _load_pipe_with_version(AIO_VERSION)
187
+ except Exception as e:
188
  print("❌ Failed to load requested AIO_VERSION. Falling back to v19.")
189
  print("---- exception ----")
190
  print(traceback.format_exc())
 
202
 
203
  MAX_SEED = np.iinfo(np.int32).max
204
 
 
205
  # ============================================================
206
+ # Derived conditioning (Transformers): Pose + Depth
207
  # ============================================================
208
+ # Pose estimation uses ViTPose (top-down). Official docs show RT-DETR -> ViTPose flow:
209
+ # https://huggingface.co/docs/transformers/model_doc/vitpose
210
+ # Depth uses Depth Anything V2 Small (Transformers-compatible):
211
+ # https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf
212
 
213
  POSE_MODEL_ID = "usyd-community/vitpose-base-simple"
214
  POSE_DETECTOR_ID = "PekingU/rtdetr_r50vd_coco_o365"
215
  DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"
216
 
217
+ # Lazy caches keyed by device string ("cpu" / "cuda")
218
  _POSE_CACHE = {}
219
  _DEPTH_CACHE = {}
220
 
221
+ # COCO-17 skeleton connections (approx "OpenPose-like" stick figure)
222
  COCO17_EDGES = [
223
+ (0, 1), (0, 2), (1, 3), (2, 4), # head
224
+ (5, 6), # shoulders
225
+ (5, 7), (7, 9), # left arm
226
+ (6, 8), (8, 10), # right arm
227
+ (5, 11), (6, 12), (11, 12), # torso
228
+ (11, 13), (13, 15), # left leg
229
+ (12, 14), (14, 16), # right leg
230
  ]
231
 
232
  def _derived_device(use_gpu: bool) -> torch.device:
233
  return torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")
234
 
235
+
236
  def _load_pose_models(dev: torch.device):
237
  key = str(dev)
238
  if key in _POSE_CACHE:
239
  return _POSE_CACHE[key]
240
 
241
+ # Detector (optional but used for multi-person boxes)
242
  det_proc = AutoProcessor.from_pretrained(POSE_DETECTOR_ID)
243
  det_model = RTDetrForObjectDetection.from_pretrained(POSE_DETECTOR_ID).to(dev)
244
 
245
+ # Pose model
246
  pose_proc = AutoProcessor.from_pretrained(POSE_MODEL_ID)
247
  pose_model = VitPoseForPoseEstimation.from_pretrained(POSE_MODEL_ID).to(dev)
248
 
 
252
  _POSE_CACHE[key] = (det_proc, det_model, pose_proc, pose_model)
253
  return _POSE_CACHE[key]
254
 
255
+
256
  def _load_depth_models(dev: torch.device):
257
  key = str(dev)
258
  if key in _DEPTH_CACHE:
 
265
  _DEPTH_CACHE[key] = (proc, model)
266
  return _DEPTH_CACHE[key]
267
 
268
+
269
  def _draw_skeleton_on_blank(
270
  size: tuple[int, int],
271
  persons_keypoints: list[np.ndarray],
 
279
  draw = ImageDraw.Draw(canvas)
280
 
281
  for kps, sc in zip(persons_keypoints, persons_scores):
282
+ # Draw edges
283
  for a, b in COCO17_EDGES:
284
  if a >= len(sc) or b >= len(sc):
285
  continue
 
289
  xb, yb = float(kps[b, 0]), float(kps[b, 1])
290
  draw.line([(xa, ya), (xb, yb)], fill=(255, 255, 255), width=line_w)
291
 
292
+ # Draw keypoints
293
  for i in range(min(len(sc), len(kps))):
294
  if sc[i] < kp_thresh:
295
  continue
296
  x, y = float(kps[i, 0]), float(kps[i, 1])
297
+ draw.ellipse(
298
+ [(x - point_r, y - point_r), (x + point_r, y + point_r)],
299
+ fill=(255, 255, 255),
300
+ outline=None,
301
+ )
302
+
303
  return canvas
304
 
305
+
306
  def make_pose_map(
307
  img: Image.Image,
308
  *,
 
311
  det_thresh: float = 0.30,
312
  max_people: int = 4,
313
  ) -> Image.Image:
314
+ """Return an OpenPose-like skeleton map (RGB) using Transformers models.
315
+
316
+ mode:
317
+ - "fast": full-frame box (no detector). Good when Image 1 is already a single subject.
318
+ - "detect": RT-DETR person boxes -> ViTPose. Better for multi-person scenes.
319
+ """
320
  img = img.convert("RGB")
321
  dev = _derived_device(use_gpu)
322
  det_proc, det_model, pose_proc, pose_model = _load_pose_models(dev)
 
324
  w, h = img.size
325
 
326
  if mode == "fast":
327
+ # Single box covering whole image, COCO format [x, y, w, h]
328
  boxes = np.array([[0.0, 0.0, float(w), float(h)]], dtype=np.float32)
329
  else:
330
+ # Detect people
331
  inputs = det_proc(images=img, return_tensors="pt").to(dev)
332
  with torch.no_grad():
333
  outputs = det_model(**inputs)
 
338
  threshold=det_thresh,
339
  )[0]
340
 
341
+ # COCO label 0 is "person" for COCO-trained detectors
342
  person_boxes = results["boxes"][results["labels"] == 0].detach().cpu().numpy()
343
 
344
  if person_boxes.size == 0:
345
+ # Fallback to full-frame
346
  boxes = np.array([[0.0, 0.0, float(w), float(h)]], dtype=np.float32)
347
  else:
348
+ # Convert VOC x1,y1,x2,y2 to COCO x,y,w,h
349
  person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
350
  person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
351
  boxes = person_boxes.astype(np.float32)
 
359
 
360
  pose_results = pose_proc.post_process_pose_estimation(pose_outputs, boxes=[boxes])[0]
361
 
362
+ persons_kps = []
363
+ persons_sc = []
364
  for pr in pose_results:
365
+ kps = pr["keypoints"].detach().cpu().numpy()
366
+ sc = pr["scores"].detach().cpu().numpy()
367
+ persons_kps.append(kps)
368
+ persons_sc.append(sc)
369
 
370
  if not persons_kps:
371
+ # No pose found; return black canvas
372
  return Image.new("RGB", img.size, (0, 0, 0))
373
 
374
  return _draw_skeleton_on_blank(img.size, persons_kps, persons_sc)
375
 
376
+
377
  def make_depth_map(img: Image.Image, *, use_gpu: bool) -> Image.Image:
378
+ """Return a grayscale (RGB) depth map using Depth Anything V2 Small."""
379
  img = img.convert("RGB")
380
  dev = _derived_device(use_gpu)
381
  proc, model = _load_depth_models(dev)
 
386
  with torch.no_grad():
387
  out = model(**inputs)
388
 
389
+ # predicted_depth: (B, H, W)
390
  pred = out.predicted_depth
391
+
392
+ # Upsample to original image size
393
  pred = torch.nn.functional.interpolate(
394
  pred.unsqueeze(1),
395
  size=(img.height, img.width),
 
403
  arr = arr / denom
404
 
405
  depth8 = (arr * 255.0).clip(0, 255).astype(np.uint8)
406
+ depth_img = Image.fromarray(depth8, mode="L").convert("RGB")
407
+ return depth_img
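# Standalone sketch of the normalization step above, assuming the elided lines do a
# standard min-max rescale before the clip/astype (values below are illustrative):
_raw = np.array([[0.2, 1.7], [3.4, 0.2]], dtype=np.float32)
_denom = max(float(_raw.max() - _raw.min()), 1e-8)       # guard against a flat depth map
_demo8 = (((_raw - _raw.min()) / _denom) * 255.0).clip(0, 255).astype(np.uint8)
# _demo8 == [[  0 119]
#            [255   0]]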
 
 
 
 
 
 
 
 
 
 
 
 
408

409
 
410
  def _append_to_gallery(existing, new_img: Image.Image):
411
  items = []
 
417
  items.append(new_img)
418
  return items
419

420
  # ============================================================
421
+ # LoRA adapters + presets
422
  # ============================================================
423
 
424
  NONE_LORA = "None"
 
486
  "weights": "bfs_head_v5_2511_original.safetensors",
487
  "adapter_name": "BFS-Best-Faceswap",
488
  "strength": 1.0,
489
+ "needs_alpha_fix": True, # <-- fixes KeyError 'img_in.alpha'
490
  },
491
  "BFS-Best-FaceSwap-merge": {
492
  "type": "single",
 
496
  "weights": "bfs_head_v5_2511_merged_version_rank_32_fp32.safetensors",
497
  "adapter_name": "BFS-Best-Faceswap-merge",
498
  "strength": 1.1,
499
+ "needs_alpha_fix": True, # <-- fixes KeyError 'img_in.alpha'
500
  },
501
  "F2P": {
502
  "type": "single",
 
582
  "BFS-Best-FaceSwap-merge": "head_swap: start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
583
  }
584
 
585
+ # Track what is currently loaded in memory (adapter_name values)
586
  LOADED_ADAPTERS = set()
587
 
 
588
  # ============================================================
589
+ # Helpers: resolution
590
  # ============================================================
591
 
592
+ # We prefer *area-based* sizing (≈ megapixels) over long-edge sizing.
593
+ # This aligns better with Qwen-Image-Edit's internal assumptions and reduces FOV drift.
594
+
595
  def _round_to_multiple(x: int, m: int) -> int:
596
  return max(m, (int(x) // m) * m)
597
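# Minimal sketch of the area-based idea (approximation only; the app defers to the
# pipeline's own calculate_dimensions helper, whose exact rounding may differ):
def _approx_canvas(target_area: int, aspect: float, multiple_of: int = 16) -> tuple[int, int]:
    import math
    # Solve w * h ≈ target_area with w / h = aspect, then floor both to the multiple.
    w = math.sqrt(target_area * aspect)
    h = w / aspect
    return _round_to_multiple(int(w), multiple_of), _round_to_multiple(int(h), multiple_of)

# _approx_canvas(1024 * 1024, 1.5) -> (1248, 832), i.e. ~1.0 MP at a 3:2 aspect ratio.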
 
 
600
  target_area: int,
601
  multiple_of: int,
602
  ) -> tuple[int, int]:
603
+ """Compute (width, height) that matches image aspect ratio and approximates target_area.
604
+
605
+ The result is floored to be divisible by multiple_of (typically vae_scale_factor*2).
606
+ """
607
  w, h = image.size
608
  aspect = w / h if h else 1.0
609
+
610
+ # Use the pipeline's own area->(w,h) helper for consistency.
611
  from qwenimage.pipeline_qwenimage_edit_plus import calculate_dimensions
612
+
613
  width, height = calculate_dimensions(int(target_area), float(aspect))
614
  width = _round_to_multiple(int(width), int(multiple_of))
615
  height = _round_to_multiple(int(height), int(multiple_of))
 
620
  lora_adapter: str,
621
  user_target_megapixels: float,
622
  ) -> int:
623
+ """Return target pixel area for the canvas.
624
+
625
+ Priority:
626
+ 1) Adapter spec: target_area (pixels) or target_megapixels
627
+ 2) Adapter spec: target_long_edge (legacy) -> converted to area using image aspect
628
+ 3) User slider target megapixels
629
+ """
630
  spec = ADAPTER_SPECS.get(lora_adapter, {})
631
+
632
  if "target_area" in spec:
633
  try:
634
  return int(spec["target_area"])
635
  except Exception:
636
  pass
637
+
638
  if "target_megapixels" in spec:
639
  try:
640
  mp = float(spec["target_megapixels"])
641
  return int(mp * 1024 * 1024)
642
  except Exception:
643
  pass
644
+
645
+ # Legacy support (e.g. Upscale2K)
646
  if "target_long_edge" in spec:
647
  try:
648
  long_edge = int(spec["target_long_edge"])
 
656
  return int(new_w * new_h)
657
  except Exception:
658
  pass
 
659
 
660
+ # User default
661
+ return int(float(user_target_megapixels) * 1024 * 1024)
662
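# Illustrative call: a hypothetical adapter name with no spec entry falls through to
# the user slider, so 1.0 megapixel maps to a 1024 * 1024 pixel target area.
_probe = Image.new("RGB", (1536, 1024))
assert get_target_area_for_lora(_probe, "some-unlisted-adapter", 1.0) == 1024 * 1024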
 
663
  # ============================================================
664
+ # Helpers: multi-input routing + gallery normalization
665
  # ============================================================
666
 
667
+
668
  def lora_requires_two_images(lora_adapter: str) -> bool:
669
  return bool(ADAPTER_SPECS.get(lora_adapter, {}).get("requires_two_images", False))
670
 
671
+
672
  def image2_label_for_lora(lora_adapter: str) -> str:
673
  return str(ADAPTER_SPECS.get(lora_adapter, {}).get("image2_label", "Upload Reference (Image 2)"))
674
 
675
+
676
+ def _to_pil_rgb(x) -> Optional[Image.Image]:
677
+ """
678
+ Accepts PIL / numpy / (image, caption) tuples from gr.Gallery and returns PIL RGB.
679
+ Gradio Gallery commonly yields tuples like (image, caption).
680
+ """
681
+ if x is None:
682
+ return None
683
+
684
+ # Gallery often returns (image, caption)
685
+ if isinstance(x, tuple) and len(x) >= 1:
686
+ x = x[0]
687
+ if x is None:
688
+ return None
689
+
690
+ if isinstance(x, Image.Image):
691
+ return x.convert("RGB")
692
+
693
+ if isinstance(x, np.ndarray):
694
+ return Image.fromarray(x).convert("RGB")
695
+
696
+ # Best-effort fallback
697
+ try:
698
+ return Image.fromarray(np.array(x)).convert("RGB")
699
+ except Exception:
700
+ return None
701
+
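# Quick illustration: Gallery items may arrive as (image, caption) tuples or arrays;
# both normalize to PIL RGB, and None passes through unchanged.
_arr = np.zeros((4, 4, 3), dtype=np.uint8)
assert _to_pil_rgb((_arr, "caption")).size == (4, 4)
assert _to_pil_rgb(Image.new("L", (4, 4))).mode == "RGB"
assert _to_pil_rgb(None) is None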
702
+
703
+ def build_labeled_images(
704
+ img1: Image.Image,
705
+ img2: Optional[Image.Image],
706
+ extra_imgs: Optional[list[Image.Image]],
707
+ ) -> dict[str, Image.Image]:
708
+ """
709
+ Creates labels image_1, image_2, image_3... based on what is actually uploaded:
710
+ - img1 is always image_1
711
+ - img2 becomes image_2 only if present
712
+ - extras start immediately after the last present base box
713
+ The pipeline receives images in this exact order.
714
+ """
715
+ labeled: dict[str, Image.Image] = {}
716
+ idx = 1
717
+
718
+ labeled[f"image_{idx}"] = img1
719
+ idx += 1
720
+
721
+ if img2 is not None:
722
+ labeled[f"image_{idx}"] = img2
723
+ idx += 1
724
+
725
+ if extra_imgs:
726
+ for im in extra_imgs:
727
+ if im is None:
728
+ continue
729
+ labeled[f"image_{idx}"] = im
730
+ idx += 1
731
+
732
+ return labeled
733
+
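# Illustration of the labeling convention: keys stay contiguous (image_1, image_2, ...)
# whether or not Image 2 is present, and extras slot in right after the last base image.
_a, _b = Image.new("RGB", (8, 8)), Image.new("RGB", (8, 8))
assert list(build_labeled_images(_a, None, [_b])) == ["image_1", "image_2"]
assert list(build_labeled_images(_a, _b, [_b])) == ["image_1", "image_2", "image_3"]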
734
+
735
+ # ============================================================
736
+ # Helpers: BFS alpha key fix
737
+ # ============================================================
738
+
739
+
740
  def _inject_missing_alpha_keys(state_dict: dict) -> dict:
741
+ """
742
+ Diffusers' Qwen LoRA converter expects '<module>.alpha' keys.
743
+ BFS safetensors omits them. We inject alpha = rank (neutral scaling).
744
+
745
+ IMPORTANT: diffusers may strip 'diffusion_model.' before lookup, so we
746
+ inject BOTH:
747
+ - diffusion_model.xxx.alpha
748
+ - xxx.alpha
749
+ """
750
  bases = {}
751
+
752
  for k, v in state_dict.items():
753
  if not isinstance(v, torch.Tensor):
754
  continue
 
769
  stripped_alpha = f"{stripped_base}.alpha"
770
  if stripped_alpha not in state_dict:
771
  state_dict[stripped_alpha] = alpha_tensor
772
+
773
  return state_dict
774
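# Minimal illustration with a hypothetical two-tensor LoRA entry. It assumes the elided
# base-detection keys off the lora_up/lora_down naming and sets alpha to the rank, as the
# docstring describes; exact alpha values depend on that elided code.
_sd = {
    "diffusion_model.img_in.lora_down.weight": torch.zeros(4, 64),  # rank 4
    "diffusion_model.img_in.lora_up.weight": torch.zeros(64, 4),
}
_sd = _inject_missing_alpha_keys(_sd)
# Expected: both "diffusion_model.img_in.alpha" and "img_in.alpha" now exist in _sd.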
 
775
+
776
  def _filter_to_diffusers_lora_keys(state_dict: dict) -> tuple[dict, dict]:
777
+ """Return (filtered_state_dict, stats).
778
+
779
+ Some ComfyUI/Qwen safetensors (especially "merged" variants) include non-LoRA
780
+ delta/patch keys like `*.diff` and `*.diff_b` alongside real LoRA tensors.
781
+ Diffusers' internal Qwen LoRA converter is strict: any leftover keys cause an
782
+ error (`state_dict should be empty...`).
783
+
784
+ This helper keeps only the keys Diffusers can consume as a LoRA:
785
+ - `*.lora_up.weight`
786
+ - `*.lora_down.weight`
787
+ - (rare) `*.lora_mid.weight`
788
+ - alpha keys: `*.alpha` (or `*.lora_alpha` which we normalize to `*.alpha`)
789
+
790
+ It also drops known patch keys (`*.diff`, `*.diff_b`) and everything else.
791
+ """
792
+
793
  keep_suffixes = (
794
  ".lora_up.weight",
795
  ".lora_down.weight",
 
797
  ".alpha",
798
  ".lora_alpha",
799
  )
800
+
801
  dropped_patch = 0
802
  dropped_other = 0
803
  kept = 0
 
806
  out: dict[str, torch.Tensor] = {}
807
  for k, v in state_dict.items():
808
  if not isinstance(v, torch.Tensor):
809
+ # Ignore non-tensor entries if any.
810
  dropped_other += 1
811
  continue
812
+
813
+ # Drop ComfyUI "delta" keys that Diffusers' LoRA loader will never consume.
814
  if k.endswith(".diff") or k.endswith(".diff_b"):
815
  dropped_patch += 1
816
  continue
817
+
818
  if not k.endswith(keep_suffixes):
819
  dropped_other += 1
820
  continue
821
+
822
  if k.endswith(".lora_alpha"):
823
+ # Normalize common alt name to what Diffusers expects.
824
  base = k[: -len(".lora_alpha")]
825
  k2 = f"{base}.alpha"
826
  out[k2] = v.float() if v.dtype != torch.float32 else v
827
  normalized_alpha += 1
828
  kept += 1
829
  continue
830
+
831
  out[k] = v
832
  kept += 1
833
 
 
839
  }
840
  return out, stats
841
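# Tiny illustration of the filtering rules above (tensor shapes are arbitrary):
_sd = {
    "blocks.0.attn.lora_up.weight": torch.zeros(8, 4),
    "blocks.0.attn.lora_down.weight": torch.zeros(4, 8),
    "blocks.0.attn.lora_alpha": torch.tensor(4.0),  # normalized to "blocks.0.attn.alpha"
    "blocks.0.attn.diff": torch.zeros(8),           # ComfyUI patch key -> dropped
}
_filtered, _stats = _filter_to_diffusers_lora_keys(_sd)
# Expected: the two LoRA weights plus "blocks.0.attn.alpha" are kept, giving stats of
# roughly kept=3, dropped_patch=1, dropped_other=0, normalized_alpha=1.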
 
842
+
843
  def _duplicate_stripped_prefix_keys(state_dict: dict, prefix: str = "diffusion_model.") -> dict:
844
+ """Ensure both prefixed and unprefixed variants exist for LoRA-related keys.
845
+
846
+ Diffusers' Qwen LoRA conversion may strip `diffusion_model.` when looking up
847
+ modules. Some exports only include prefixed keys. To be maximally compatible,
848
+ we duplicate LoRA keys (and alpha) in stripped form when missing.
849
+ """
850
+
851
  out = dict(state_dict)
852
  for k, v in list(state_dict.items()):
853
  if not k.startswith(prefix):
 
857
  out[stripped] = v
858
  return out
859
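# Illustration (per the docstring): a prefixed LoRA key gains an unprefixed twin,
# while already-stripped keys are left untouched.
_sd = _duplicate_stripped_prefix_keys({"diffusion_model.img_in.lora_down.weight": torch.zeros(4, 64)})
assert "img_in.lora_down.weight" in _sd and "diffusion_model.img_in.lora_down.weight" in _sd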
 
860
+
861
  def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name: str, needs_alpha_fix: bool = False):
862
+ """
863
+ Normal path: pipe.load_lora_weights(repo, weight_name=..., adapter_name=...)
864
+ BFS fallback: download safetensors, inject missing alpha keys, then load from dict.
865
+ """
866
  try:
867
  pipe.load_lora_weights(repo, weight_name=weight_name, adapter_name=adapter_name)
868
  return
869
  except (KeyError, ValueError) as e:
870
+ # KeyError: missing required alpha keys (common in BFS)
871
+ # ValueError: Diffusers Qwen converter found leftover keys (e.g. .diff/.diff_b)
872
  if not needs_alpha_fix:
873
  raise
874
 
 
880
  local_path = hf_hub_download(repo_id=repo, filename=weight_name)
881
  sd = safetensors_load_file(local_path)
882
 
883
+ # 1) Inject required `<module>.alpha` keys (neutral scaling alpha=rank).
884
  sd = _inject_missing_alpha_keys(sd)
885
+
886
+ # 2) Keep only LoRA + alpha keys; drop ComfyUI patch/delta keys.
887
  sd, stats = _filter_to_diffusers_lora_keys(sd)
888
+
889
+ # 3) Duplicate stripped keys (remove `diffusion_model.`) for compatibility.
890
  sd = _duplicate_stripped_prefix_keys(sd)
891
 
892
  print(
 
894
  f"kept={stats['kept']} dropped_patch={stats['dropped_patch']} "
895
  f"dropped_other={stats['dropped_other']} normalized_alpha={stats['normalized_alpha']}"
896
  )
897
+
898
  pipe.load_lora_weights(sd, adapter_name=adapter_name)
899
  return
900
 
901
+
902
+ # ============================================================
903
+ # LoRA loader: single/package + strengths
904
+ # ============================================================
905
+
906
+
907
  def _ensure_loaded_and_get_active_adapters(selected_lora: str):
908
  spec = ADAPTER_SPECS.get(selected_lora)
909
  if not spec:
 
941
 
942
  adapter_names.append(adapter_name)
943
  adapter_weights.append(strength)
944
+
945
  else:
946
  repo = spec["repo"]
947
  weights = spec["weights"]
 
974
  # UI handlers
975
  # ============================================================
976
 
977
+
978
  def on_lora_change_ui(selected_lora, current_prompt, current_extras_condition_only):
979
+ # Preset prompt (fill only if empty)
980
  if selected_lora != NONE_LORA:
981
  preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
982
  if preset and (current_prompt is None or str(current_prompt).strip() == ""):
 
986
  else:
987
  prompt_update = gr.update(value=current_prompt)
988
 
989
+ # Image2 visibility/label
990
  if lora_requires_two_images(selected_lora):
991
  img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
992
  else:
993
  img2_update = gr.update(visible=False, value=None, label="Upload Reference (Image 2)")
994
 
995
+ # Extra references routing default:
996
+ # For BFS/AnyPose-like adapters, it's usually safer to keep extra refs as conditioning-only.
997
  if selected_lora in ("BFS-Best-FaceSwap", "BFS-Best-FaceSwap-merge", "AnyPose"):
998
  extras_update = gr.update(value=True)
999
  else:
1000
  extras_update = gr.update(value=current_extras_condition_only)
1001
 
1002
  return prompt_update, img2_update, extras_update
 
 
1003
  # ============================================================
1004
  # UI helpers: output routing + derived conditioning
1005
  # ============================================================
 
1009
  raise gr.Error("No output available yet.")
1010
  return gr.update(value=last)
1011
 
1012
+
1013
  def set_output_as_image2(last):
1014
  if last is None:
1015
  raise gr.Error("No output available yet.")
1016
  return gr.update(value=last)
1017
 
1018
+
1019
  def set_output_as_extra(last, existing_extra):
1020
  if last is None:
1021
  raise gr.Error("No output available yet.")
1022
  return _append_to_gallery(existing_extra, last)
1023
 
1024
+
1025
  @spaces.GPU
1026
  def add_derived_ref(img1, existing_extra, derived_type, derived_use_gpu, derived_max_people):
1027
  if img1 is None:
 
1050
  return gr.update(value=new_gallery), gr.update(visible=True, value=derived)
1051
 
1052

1053
 
1054
  # ============================================================
1055
  # Inference
1056
  # ============================================================
1057
 
1058
+
1059
  @spaces.GPU
1060
  def infer(
1061
  input_image_1,
1062
  input_image_2,
1063
+ input_images_extra, # gallery multi-image box
1064
  prompt,
1065
  lora_adapter,
1066
  seed,
 
1076
  if torch.cuda.is_available():
1077
  torch.cuda.empty_cache()
1078
 
 
 
1079
  if input_image_1 is None:
1080
  raise gr.Error("Please upload Image 1.")
1081
 
 
1094
  seed = random.randint(0, MAX_SEED)
1095
 
1096
  generator = torch.Generator(device=device).manual_seed(seed)
 
1097
  negative_prompt = (
1098
  "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, "
1099
  "extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
1100
  )
1101
 
1102
+ img1 = input_image_1.convert("RGB")
1103
+ img2 = input_image_2.convert("RGB") if input_image_2 is not None else None
 
 
 
 
 
 
 
1104
 
1105
+ # Normalize extra images (Gallery) to PIL RGB (handles tuples from Gallery)
1106
  extra_imgs: list[Image.Image] = []
1107
  if input_images_extra:
1108
  for item in input_images_extra:
 
1110
  if pil is not None:
1111
  extra_imgs.append(pil)
1112
 
1113
+ # Enforce existing 2-image LoRA behavior (image_1 + image_2 required)
1114
  if lora_requires_two_images(lora_adapter) and img2 is None:
1115
  raise gr.Error("This LoRA needs two images. Please upload Image 2 as well.")
1116
 
1117
+ # Label images as image_1, image_2, image_3...
1118
  labeled = build_labeled_images(img1, img2, extra_imgs)
1119
 
1120
+ # Pass to pipeline in labeled order. Keep single-image call when only one is present.
1121
  pipe_images = list(labeled.values())
1122
  if len(pipe_images) == 1:
1123
  pipe_images = pipe_images[0]
1124
 
1125
+ # Resolution derived from Image 1 (base/body/target)
1126
+ # Use target *area* (≈ megapixels) rather than long-edge sizing to reduce FOV drift.
1127
  target_area = get_target_area_for_lora(img1, lora_adapter, float(target_megapixels))
1128
  width, height = compute_canvas_dimensions_from_area(
1129
  img1,
 
1131
  multiple_of=int(pipe.vae_scale_factor * 2),
1132
  )
1133
 
1134
+ # Decide which images participate in the VAE latent stream.
1135
+ # If enabled, extra references beyond (Img_1, Img_2) become conditioning-only.
1136
  vae_image_indices = None
1137
  if extras_condition_only:
1138
  if isinstance(pipe_images, list) and len(pipe_images) > 2:
1139
  vae_image_indices = [0, 1] if len(pipe_images) >= 2 else [0]
1140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1141
  try:
1142
+ print(
1143
+ "[DEBUG][infer] submitting request | "
1144
+ f"lora_adapter={lora_adapter!r} seed={seed} prompt={prompt!r}"
1145
+ )
1146
+
1147
  result = pipe(
1148
  image=pipe_images,
1149
  prompt=prompt,
1150
  negative_prompt=negative_prompt,
1151
  height=height,
1152
  width=width,
1153
+ num_inference_steps=steps,
1154
  generator=generator,
1155
+ true_cfg_scale=guidance_scale,
1156
  vae_image_indices=vae_image_indices,
1157
  pad_to_canvas=bool(pad_to_canvas),
1158
  ).images[0]
1159
  return result, seed, result
 
 
 
 
 
 
 
1160
  finally:
 
1161
  gc.collect()
1162
  if torch.cuda.is_available():
1163
  torch.cuda.empty_cache()
 
1170
  input_pil = input_image.convert("RGB")
1171
  guidance_scale = 1.0
1172
  steps = 4
1173
+ # Examples don't supply Image 2 or extra images, and the example list doesn't include AnyPose/BFS.
1174
+ result, seed, last = infer(input_pil, None, None, prompt, lora_adapter, 0, True, guidance_scale, steps, 1.0, True, True)
 
 
 
 
 
1175
  return result, seed, last
1176
 
1177
 
 
1198
  gr.Markdown(
1199
  "Perform diverse image edits using specialized "
1200
  "[LoRA](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image-Edit-2511) adapters for the "
1201
+ "[Qwen-Image-Edit](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) model. Uses a Diffusers-compatible extraction of the transformer from Phr00t's Rapid AIO merge. To use a different AIO version, duplicate the Space and set the AIO_VERSION Space variable."
1202
  )
1203
  gr.Markdown(aio_status_line)
1204
 
 
1292
  value=True,
1293
  )
1294
 
1295
+ # On LoRA selection: preset prompt + toggle Image 2
1296
  lora_adapter.change(
1297
  fn=on_lora_change_ui,
1298
  inputs=[lora_adapter, prompt, extras_condition_only],
 
1352
  outputs=[output_image, seed, last_output],
1353
  )
1354
 
1355
+ # Output routing buttons
1356
  btn_out_to_img1.click(fn=set_output_as_image1, inputs=[last_output], outputs=[input_image_1])
1357
  btn_out_to_img2.click(fn=set_output_as_image2, inputs=[last_output], outputs=[input_image_2])
1358
  btn_out_to_extra.click(fn=set_output_as_extra, inputs=[last_output, input_images_extra], outputs=[input_images_extra])
1359
+
1360
+ # Derived conditioning: append pose/depth map as extra ref (UI shows preview)
1361
  add_derived_btn.click(
1362
  fn=add_derived_ref,
1363
  inputs=[input_image_1, input_images_extra, derived_type, derived_use_gpu, derived_max_people],
1364
  outputs=[input_images_extra, derived_preview],
1365
  )
1366
+
1367
  if __name__ == "__main__":
1368
  demo.queue(max_size=30).launch(
1369
  css=css,