Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -1,19 +1,14 @@
-# app.py
 import os
 import re
 import gc
-import uuid
-import time
-import math
 import traceback
-import random
-from typing import Iterable, Optional
-
 import gradio as gr
 import numpy as np
 import spaces
 import torch
 from PIL import Image, ImageDraw
 
 from transformers import (
     AutoProcessor,
@@ -29,133 +24,6 @@ from safetensors.torch import load_file as safetensors_load_file
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
-
-# ============================================================
-# Process-unique temp dir (helps avoid /tmp collisions)
-# ============================================================
-
-def _ensure_unique_gradio_tmp():
-    """
-    ZeroGPU/Spaces can serve multiple users across recycled containers.
-    Gradio may use /tmp/gradio by default. We force a unique directory per process.
-    """
-    if os.environ.get("GRADIO_TEMP_DIR"):
-        print(f"GRADIO_TEMP_DIR = {os.environ['GRADIO_TEMP_DIR']}")
-        return
-    pid = os.getpid()
-    tmp = f"/tmp/gradio_{pid}_{uuid.uuid4().hex[:8]}"
-    os.environ["GRADIO_TEMP_DIR"] = tmp
-    try:
-        os.makedirs(tmp, exist_ok=True)
-    except Exception:
-        pass
-    print(f"GRADIO_TEMP_DIR = {tmp}")
-
-_ensure_unique_gradio_tmp()
-
-
-# ============================================================
-# Patch: Qwen2.5-VL RoPE (avoid cublas batched GEMM; preserve shapes)
-# ============================================================
-
-def patch_qwen25vl_rope_no_gemm():
-    """
-    Patch Qwen2.5-VL rotary embedding to avoid the matmul that can trigger:
-        CUBLAS_STATUS_INVALID_VALUE (cublasSgemmStridedBatched)
-    on some ZeroGPU/H200 MIG configurations.
-
-    CRITICAL: Preserve exact output shapes used by apply_multimodal_rotary_pos_emb,
-    otherwise you get split_with_sizes mismatches.
-    """
-    if os.environ.get("DISABLE_ROPE_PATCH", "").strip() == "1":
-        print("[patch][rope] DISABLE_ROPE_PATCH=1 -> skipping patch.")
-        return
-
-    try:
-        from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as qvl
-    except Exception as e:
-        print(f"[patch][rope] could not import qwen2_5_vl modeling: {e}")
-        return
-
-    Rotary = None
-    for name in ["Qwen2_5_VLRotaryEmbedding", "Qwen2_5RotaryEmbedding", "RotaryEmbedding"]:
-        Rotary = getattr(qvl, name, None)
-        if Rotary is not None:
-            break
-
-    if Rotary is None:
-        print("[patch][rope] rotary embedding class not found; no patch applied.")
-        return
-
-    orig_forward = Rotary.forward
-
-    def forward_no_gemm(self, x, position_ids):
-        # Fallback to original if structure differs
-        if not hasattr(self, "inv_freq") or position_ids is None:
-            return orig_forward(self, x, position_ids)
-
-        # Determine rotary dim from module config (NOT x.shape[-1])
-        if hasattr(self, "dim") and isinstance(self.dim, int):
-            rope_dim = int(self.dim)
-        else:
-            rope_dim = int(self.inv_freq.numel() * 2)
-
-        # Normalize position_ids to (bs, seq)
-        if position_ids.ndim > 2:
-            pos = position_ids.reshape(position_ids.shape[0], -1)
-        else:
-            pos = position_ids
-
-        # Compute on the same device as inv_freq/x
-        dev = self.inv_freq.device
-        pos = pos.to(device=dev)
-
-        # Broadcast multiply instead of matmul:
-        # inv: (1,1,dim/2,1), pos: (bs,1,1,seq) -> freqs: (bs,1,dim/2,seq)
-        inv = self.inv_freq[None, None, :, None].float()
-        posf = pos[:, None, None, :].float()
-        freqs = (inv * posf).transpose(2, 3)  # (bs,1,seq,dim/2)
-
-        # Double to full rotary dim
-        emb = torch.cat((freqs, freqs), dim=-1)  # (bs,1,seq,dim)
-
-        # Enforce exact expected rotary dim
-        if emb.shape[-1] != rope_dim:
-            emb = emb[..., :rope_dim]
-
-        cos = emb.cos()
-        sin = emb.sin()
-
-        # Respect attention scaling if present
-        attn_scale = getattr(self, "attention_scaling", None)
-        if attn_scale is not None:
-            cos = cos * attn_scale
-            sin = sin * attn_scale
-
-        # Match dtype expectations (upstream typically returns same dtype as x)
-        cos = cos.to(dtype=x.dtype)
-        sin = sin.to(dtype=x.dtype)
-
-        # Optional debug (enable by env)
-        if os.environ.get("DEBUG_ROPE", "").strip() == "1":
-            ms = getattr(self, "mrope_section", None)
-            if ms is not None:
-                try:
-                    ms_list = list(ms)
-                    print(f"[DEBUG][rope] rope_dim={rope_dim} cos_last={cos.shape[-1]} mrope_sum={sum(ms_list)} mrope={ms_list}")
-                except Exception:
-                    print(f"[DEBUG][rope] rope_dim={rope_dim} cos_last={cos.shape[-1]} (mrope_section unreadable)")
-            else:
-                print(f"[DEBUG][rope] rope_dim={rope_dim} cos_last={cos.shape[-1]} (no mrope_section attr)")
-
-        return cos, sin
-
-    Rotary.forward = forward_no_gemm
-    print("[patch] Patched Qwen2.5-VL RoPE matmul -> broadcast multiply (shape-preserving).")
-
-patch_qwen25vl_rope_no_gemm()
-
-
 # ============================================================
 # Theme
 # ============================================================
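The hunk above drops the custom RoPE patch. For reference, here is a minimal standalone check (not part of app.py; tensor sizes are illustrative) showing why the broadcast-multiply form used by the removed patch is value- and shape-equivalent to the batched outer-product matmul it replaced:

```python
import torch

inv_freq = torch.rand(64)                                  # hypothetical (head_dim/2,) inverse frequencies
pos = torch.arange(128, dtype=torch.float32)[None, :]      # hypothetical (bs=1, seq) position ids

# Outer product via batched matmul (the op the patch avoided on ZeroGPU)
freqs_matmul = (inv_freq[None, :, None] @ pos[:, None, :]).transpose(1, 2)              # (1, seq, dim/2)

# Broadcast multiply, as in the removed forward_no_gemm
freqs_bcast = (inv_freq[None, None, :, None] * pos[:, None, None, :]).transpose(2, 3)   # (1, 1, seq, dim/2)

torch.testing.assert_close(freqs_matmul, freqs_bcast.squeeze(1))
print("broadcast multiply matches matmul:", tuple(freqs_matmul.shape))
```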
@@ -175,6 +43,7 @@ colors.orange_red = colors.Color(
     c950="#802200",
 )
 
 class OrangeRedTheme(Soft):
     def __init__(
         self,
@@ -230,11 +99,11 @@ class OrangeRedTheme(Soft):
             block_label_background_fill="*primary_200",
         )
 
-orange_red_theme = OrangeRedTheme()
 
 
 # ============================================================
-# Device
 # ============================================================
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -249,14 +118,6 @@ if torch.cuda.is_available():
     print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
 print("Using device:", device)
 
-# Optional: make matmul a bit more stable (doesn't change correctness)
-try:
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-except Exception:
-    pass
-
-
 # ============================================================
 # AIO version (Space variable)
 # ============================================================
@@ -267,6 +128,7 @@ DEFAULT_AIO_VERSION = "v19"
 _VER_RE = re.compile(r"^v\d+$")
 _DIGITS_RE = re.compile(r"^\d+$")
 
 def _normalize_version(raw: str) -> Optional[str]:
     if raw is None:
         return None
@@ -275,10 +137,12 @@ def _normalize_version(raw: str) -> Optional[str]:
         return None
     if _VER_RE.fullmatch(s):
         return s
     if _DIGITS_RE.fullmatch(s):
         return f"v{s}"
     return None
 
 _AIO_ENV_RAW = os.environ.get("AIO_VERSION", "")
 _AIO_ENV_NORM = _normalize_version(_AIO_ENV_RAW)
 
@@ -289,7 +153,6 @@ print(f"AIO_VERSION (env raw) = {_AIO_ENV_RAW!r}")
 print(f"AIO_VERSION (normalized) = {_AIO_ENV_NORM!r}")
 print(f"Using AIO_VERSION = {AIO_VERSION} ({AIO_VERSION_SOURCE})")
 
-
 # ============================================================
 # Pipeline
 # ============================================================
@@ -301,6 +164,7 @@ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
 
 dtype = torch.bfloat16
 
 def _load_pipe_with_version(version: str) -> QwenImageEditPlusPipeline:
     sub = f"{version}/transformer"
     print(f"📦 Loading AIO transformer: {AIO_REPO_ID} / {sub}")
@@ -316,9 +180,11 @@ def _load_pipe_with_version(version: str) -> QwenImageEditPlusPipeline:
     ).to(device)
     return p
 
 try:
     pipe = _load_pipe_with_version(AIO_VERSION)
-except Exception:
     print("❌ Failed to load requested AIO_VERSION. Falling back to v19.")
     print("---- exception ----")
     print(traceback.format_exc())
@@ -336,39 +202,47 @@ except Exception as e:
 
 MAX_SEED = np.iinfo(np.int32).max
 
-
 # ============================================================
-# Derived conditioning (Transformers): Pose + Depth
 # ============================================================
 
 POSE_MODEL_ID = "usyd-community/vitpose-base-simple"
 POSE_DETECTOR_ID = "PekingU/rtdetr_r50vd_coco_o365"
 DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"
 
 _POSE_CACHE = {}
 _DEPTH_CACHE = {}
 
 COCO17_EDGES = [
-    (0, 1), (0, 2), (1, 3), (2, 4),
-    (5, 6),
-    (5, 7), (7, 9),
-    (6, 8), (8, 10),
-    (5, 11), (6, 12), (11, 12),
-    (11, 13), (13, 15),
-    (12, 14), (14, 16),
 ]
 
 def _derived_device(use_gpu: bool) -> torch.device:
     return torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")
 
 def _load_pose_models(dev: torch.device):
     key = str(dev)
     if key in _POSE_CACHE:
         return _POSE_CACHE[key]
 
     det_proc = AutoProcessor.from_pretrained(POSE_DETECTOR_ID)
     det_model = RTDetrForObjectDetection.from_pretrained(POSE_DETECTOR_ID).to(dev)
 
     pose_proc = AutoProcessor.from_pretrained(POSE_MODEL_ID)
     pose_model = VitPoseForPoseEstimation.from_pretrained(POSE_MODEL_ID).to(dev)
 
@@ -378,6 +252,7 @@ def _load_pose_models(dev: torch.device):
     _POSE_CACHE[key] = (det_proc, det_model, pose_proc, pose_model)
     return _POSE_CACHE[key]
 
 def _load_depth_models(dev: torch.device):
     key = str(dev)
     if key in _DEPTH_CACHE:
@@ -390,6 +265,7 @@ def _load_depth_models(dev: torch.device):
     _DEPTH_CACHE[key] = (proc, model)
     return _DEPTH_CACHE[key]
 
 def _draw_skeleton_on_blank(
     size: tuple[int, int],
     persons_keypoints: list[np.ndarray],
@@ -403,6 +279,7 @@ def _draw_skeleton_on_blank(
     draw = ImageDraw.Draw(canvas)
 
     for kps, sc in zip(persons_keypoints, persons_scores):
         for a, b in COCO17_EDGES:
             if a >= len(sc) or b >= len(sc):
                 continue
@@ -412,13 +289,20 @@
             xb, yb = float(kps[b, 0]), float(kps[b, 1])
             draw.line([(xa, ya), (xb, yb)], fill=(255, 255, 255), width=line_w)
 
         for i in range(min(len(sc), len(kps))):
             if sc[i] < kp_thresh:
                 continue
             x, y = float(kps[i, 0]), float(kps[i, 1])
-            draw.ellipse(
     return canvas
 
 def make_pose_map(
     img: Image.Image,
     *,
@@ -427,6 +311,12 @@ def make_pose_map(
     det_thresh: float = 0.30,
     max_people: int = 4,
 ) -> Image.Image:
     img = img.convert("RGB")
     dev = _derived_device(use_gpu)
     det_proc, det_model, pose_proc, pose_model = _load_pose_models(dev)
@@ -434,8 +324,10 @@
     w, h = img.size
 
     if mode == "fast":
         boxes = np.array([[0.0, 0.0, float(w), float(h)]], dtype=np.float32)
     else:
         inputs = det_proc(images=img, return_tensors="pt").to(dev)
         with torch.no_grad():
             outputs = det_model(**inputs)
@@ -446,11 +338,14 @@
             threshold=det_thresh,
         )[0]
 
         person_boxes = results["boxes"][results["labels"] == 0].detach().cpu().numpy()
 
         if person_boxes.size == 0:
             boxes = np.array([[0.0, 0.0, float(w), float(h)]], dtype=np.float32)
         else:
             person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
             person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
             boxes = person_boxes.astype(np.float32)
@@ -464,17 +359,23 @@
 
     pose_results = pose_proc.post_process_pose_estimation(pose_outputs, boxes=[boxes])[0]
 
-    persons_kps
     for pr in pose_results:
-
-
 
     if not persons_kps:
         return Image.new("RGB", img.size, (0, 0, 0))
 
     return _draw_skeleton_on_blank(img.size, persons_kps, persons_sc)
 
 def make_depth_map(img: Image.Image, *, use_gpu: bool) -> Image.Image:
     img = img.convert("RGB")
     dev = _derived_device(use_gpu)
     proc, model = _load_depth_models(dev)
@@ -485,7 +386,10 @@ def make_depth_map(img: Image.Image, *, use_gpu: bool) -> Image.Image:
     with torch.no_grad():
         out = model(**inputs)
 
     pred = out.predicted_depth
     pred = torch.nn.functional.interpolate(
         pred.unsqueeze(1),
         size=(img.height, img.width),
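The next hunk keeps the depth post-processing, which min-max normalizes the predicted depth into an 8-bit grayscale map. A standalone illustration of that normalization step (depth_to_image is a hypothetical helper, not part of app.py):

```python
import numpy as np
from PIL import Image

def depth_to_image(pred: np.ndarray) -> Image.Image:
    # Hypothetical stand-in for the normalization done in make_depth_map:
    # scale to [0, 1], convert to 8-bit grayscale, then promote to RGB.
    arr = pred.astype(np.float32)
    denom = max(float(arr.max() - arr.min()), 1e-6)
    arr = (arr - arr.min()) / denom
    depth8 = (arr * 255.0).clip(0, 255).astype(np.uint8)
    return Image.fromarray(depth8, mode="L").convert("RGB")

print(depth_to_image(np.random.rand(64, 64)).size)  # (64, 64)
```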
@@ -499,55 +403,9 @@
     arr = arr / denom
 
     depth8 = (arr * 255.0).clip(0, 255).astype(np.uint8)
-
-
-
-# ============================================================
-# Helpers: gallery normalization + debug-friendly PIL conversion
-# ============================================================
-
-def _to_pil_rgb(x) -> Optional[Image.Image]:
-    """
-    Accepts PIL / numpy / (image, caption) tuples / gradio dicts and returns PIL RGB.
-    Also safely ignores broken temp paths.
-    """
-    if x is None:
-        return None
-
-    # Gallery often returns (image, caption)
-    if isinstance(x, tuple) and len(x) >= 1:
-        x = x[0]
-        if x is None:
-            return None
-
-    # Some gradio versions can return dict with a temp file path
-    if isinstance(x, dict):
-        # common keys: 'name' or 'path'
-        p = x.get("name") or x.get("path")
-        if isinstance(p, str):
-            if not os.path.exists(p):
-                print(f"[WARN] extra image path missing, skipping: {p}")
-                return None
-            try:
-                return Image.open(p).convert("RGB")
-            except Exception as e:
-                print(f"[WARN] failed to open extra image path {p}: {e}")
-                return None
-
-    if isinstance(x, Image.Image):
-        return x.convert("RGB")
-
-    if isinstance(x, np.ndarray):
-        try:
-            return Image.fromarray(x).convert("RGB")
-        except Exception:
-            return None
-
-    # last resort
-    try:
-        return Image.fromarray(np.array(x)).convert("RGB")
-    except Exception:
-        return None
 
 def _append_to_gallery(existing, new_img: Image.Image):
     items = []
@@ -559,29 +417,8 @@ def _append_to_gallery(existing, new_img: Image.Image):
     items.append(new_img)
     return items
 
-def build_labeled_images(
-    img1: Image.Image,
-    img2: Optional[Image.Image],
-    extra_imgs: Optional[list[Image.Image]],
-) -> dict[str, Image.Image]:
-    labeled: dict[str, Image.Image] = {}
-    idx = 1
-    labeled[f"image_{idx}"] = img1
-    idx += 1
-    if img2 is not None:
-        labeled[f"image_{idx}"] = img2
-        idx += 1
-    if extra_imgs:
-        for im in extra_imgs:
-            if im is None:
-                continue
-            labeled[f"image_{idx}"] = im
-            idx += 1
-    return labeled
-
-
 # ============================================================
-# LoRA adapters + presets
 # ============================================================
 
 NONE_LORA = "None"
@@ -649,7 +486,7 @@ ADAPTER_SPECS = {
         "weights": "bfs_head_v5_2511_original.safetensors",
         "adapter_name": "BFS-Best-Faceswap",
         "strength": 1.0,
-        "needs_alpha_fix": True,
     },
     "BFS-Best-FaceSwap-merge": {
         "type": "single",
@@ -659,7 +496,7 @@ ADAPTER_SPECS = {
         "weights": "bfs_head_v5_2511_merged_version_rank_32_fp32.safetensors",
         "adapter_name": "BFS-Best-Faceswap-merge",
         "strength": 1.1,
-        "needs_alpha_fix": True,
     },
     "F2P": {
         "type": "single",
@@ -745,13 +582,16 @@ LORA_PRESET_PROMPTS = {
     "BFS-Best-FaceSwap-merge": "head_swap: start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
 }
 
 LOADED_ADAPTERS = set()
 
-
 # ============================================================
-# Helpers: resolution
 # ============================================================
 
 def _round_to_multiple(x: int, m: int) -> int:
     return max(m, (int(x) // m) * m)
 
@@ -760,9 +600,16 @@ def compute_canvas_dimensions_from_area(
     target_area: int,
     multiple_of: int,
 ) -> tuple[int, int]:
     w, h = image.size
     aspect = w / h if h else 1.0
     from qwenimage.pipeline_qwenimage_edit_plus import calculate_dimensions
     width, height = calculate_dimensions(int(target_area), float(aspect))
     width = _round_to_multiple(int(width), int(multiple_of))
     height = _round_to_multiple(int(height), int(multiple_of))
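For readers comparing the area-based sizing in compute_canvas_dimensions_from_area: a rough standalone sketch of the same idea (area_to_canvas and its internals are hypothetical, not the pipeline's calculate_dimensions):

```python
import math

def area_to_canvas(aspect: float, target_area: int, multiple_of: int = 16) -> tuple[int, int]:
    # Pick w*h ≈ target_area with w/h ≈ aspect, then floor both to a multiple.
    height = math.sqrt(target_area / aspect)
    width = height * aspect

    def snap(v: float) -> int:
        return max(multiple_of, (int(v) // multiple_of) * multiple_of)

    return snap(width), snap(height)

print(area_to_canvas(16 / 9, int(1.0 * 1024 * 1024)))  # (1360, 768)
```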
@@ -773,18 +620,29 @@ def get_target_area_for_lora(
     lora_adapter: str,
     user_target_megapixels: float,
 ) -> int:
     spec = ADAPTER_SPECS.get(lora_adapter, {})
     if "target_area" in spec:
         try:
             return int(spec["target_area"])
         except Exception:
             pass
     if "target_megapixels" in spec:
         try:
             mp = float(spec["target_megapixels"])
             return int(mp * 1024 * 1024)
         except Exception:
             pass
     if "target_long_edge" in spec:
         try:
             long_edge = int(spec["target_long_edge"])
@@ -798,21 +656,99 @@
             return int(new_w * new_h)
         except Exception:
             pass
-    return int(float(user_target_megapixels) * 1024 * 1024)
 
 
 # ============================================================
-# Helpers:
 # ============================================================
 
 def lora_requires_two_images(lora_adapter: str) -> bool:
     return bool(ADAPTER_SPECS.get(lora_adapter, {}).get("requires_two_images", False))
 
 def image2_label_for_lora(lora_adapter: str) -> str:
     return str(ADAPTER_SPECS.get(lora_adapter, {}).get("image2_label", "Upload Reference (Image 2)"))
 
 def _inject_missing_alpha_keys(state_dict: dict) -> dict:
     bases = {}
     for k, v in state_dict.items():
         if not isinstance(v, torch.Tensor):
             continue
@@ -833,9 +769,27 @@ def _inject_missing_alpha_keys(state_dict: dict) -> dict:
         stripped_alpha = f"{stripped_base}.alpha"
         if stripped_alpha not in state_dict:
             state_dict[stripped_alpha] = alpha_tensor
     return state_dict
 
 def _filter_to_diffusers_lora_keys(state_dict: dict) -> tuple[dict, dict]:
     keep_suffixes = (
         ".lora_up.weight",
         ".lora_down.weight",
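As context for the alpha-key handling above: a toy sketch of what injecting neutral alpha keys amounts to (the state-dict keys and shapes here are made up; the real helper also mirrors keys with the diffusion_model. prefix stripped):

```python
import torch

# Hypothetical ComfyUI-style LoRA state dict that ships without alpha keys
sd = {
    "diffusion_model.img_in.lora_down.weight": torch.zeros(32, 64),  # rank 32
    "diffusion_model.img_in.lora_up.weight": torch.zeros(64, 32),
}

# alpha == rank gives a neutral LoRA scale (alpha / rank == 1.0)
for key, tensor in list(sd.items()):
    if key.endswith(".lora_down.weight"):
        base = key[: -len(".lora_down.weight")]
        sd.setdefault(f"{base}.alpha", torch.tensor(float(tensor.shape[0])))

print(sorted(sd))  # 'diffusion_model.img_in.alpha' is now present
```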
@@ -843,6 +797,7 @@ def _filter_to_diffusers_lora_keys(state_dict: dict) -> tuple[dict, dict]:
         ".alpha",
         ".lora_alpha",
     )
     dropped_patch = 0
     dropped_other = 0
     kept = 0
@@ -851,21 +806,28 @@
     out: dict[str, torch.Tensor] = {}
     for k, v in state_dict.items():
         if not isinstance(v, torch.Tensor):
             dropped_other += 1
             continue
         if k.endswith(".diff") or k.endswith(".diff_b"):
             dropped_patch += 1
             continue
         if not k.endswith(keep_suffixes):
             dropped_other += 1
             continue
         if k.endswith(".lora_alpha"):
             base = k[: -len(".lora_alpha")]
             k2 = f"{base}.alpha"
             out[k2] = v.float() if v.dtype != torch.float32 else v
             normalized_alpha += 1
             kept += 1
             continue
         out[k] = v
         kept += 1
 
@@ -877,7 +839,15 @@
     }
     return out, stats
 
 def _duplicate_stripped_prefix_keys(state_dict: dict, prefix: str = "diffusion_model.") -> dict:
     out = dict(state_dict)
     for k, v in list(state_dict.items()):
         if not k.startswith(prefix):
@@ -887,11 +857,18 @@ def _duplicate_stripped_prefix_keys(state_dict: dict, prefix: str = "diffusion_m
         out[stripped] = v
     return out
 
 def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name: str, needs_alpha_fix: bool = False):
     try:
         pipe.load_lora_weights(repo, weight_name=weight_name, adapter_name=adapter_name)
         return
     except (KeyError, ValueError) as e:
         if not needs_alpha_fix:
             raise
 
@@ -903,8 +880,13 @@ def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name:
     local_path = hf_hub_download(repo_id=repo, filename=weight_name)
     sd = safetensors_load_file(local_path)
 
     sd = _inject_missing_alpha_keys(sd)
     sd, stats = _filter_to_diffusers_lora_keys(sd)
     sd = _duplicate_stripped_prefix_keys(sd)
 
     print(
@@ -912,9 +894,16 @@ def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name:
         f"kept={stats['kept']} dropped_patch={stats['dropped_patch']} "
         f"dropped_other={stats['dropped_other']} normalized_alpha={stats['normalized_alpha']}"
     )
     pipe.load_lora_weights(sd, adapter_name=adapter_name)
     return
 
 def _ensure_loaded_and_get_active_adapters(selected_lora: str):
     spec = ADAPTER_SPECS.get(selected_lora)
     if not spec:
@@ -952,6 +941,7 @@ def _ensure_loaded_and_get_active_adapters(selected_lora: str):
 
             adapter_names.append(adapter_name)
             adapter_weights.append(strength)
     else:
         repo = spec["repo"]
         weights = spec["weights"]
@@ -984,7 +974,9 @@ def _ensure_loaded_and_get_active_adapters(selected_lora: str):
 # UI handlers
 # ============================================================
 
 def on_lora_change_ui(selected_lora, current_prompt, current_extras_condition_only):
     if selected_lora != NONE_LORA:
         preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
         if preset and (current_prompt is None or str(current_prompt).strip() == ""):
@@ -994,19 +986,20 @@ def on_lora_change_ui(selected_lora, current_prompt, current_extras_condition_on
     else:
         prompt_update = gr.update(value=current_prompt)
 
     if lora_requires_two_images(selected_lora):
         img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
     else:
         img2_update = gr.update(visible=False, value=None, label="Upload Reference (Image 2)")
 
     if selected_lora in ("BFS-Best-FaceSwap", "BFS-Best-FaceSwap-merge", "AnyPose"):
         extras_update = gr.update(value=True)
     else:
         extras_update = gr.update(value=current_extras_condition_only)
 
     return prompt_update, img2_update, extras_update
-
-
 # ============================================================
 # UI helpers: output routing + derived conditioning
 # ============================================================
@@ -1016,16 +1009,19 @@ def set_output_as_image1(last):
         raise gr.Error("No output available yet.")
     return gr.update(value=last)
 
 def set_output_as_image2(last):
     if last is None:
         raise gr.Error("No output available yet.")
     return gr.update(value=last)
 
 def set_output_as_extra(last, existing_extra):
     if last is None:
         raise gr.Error("No output available yet.")
     return _append_to_gallery(existing_extra, last)
 
 @spaces.GPU
 def add_derived_ref(img1, existing_extra, derived_type, derived_use_gpu, derived_max_people):
     if img1 is None:
@@ -1054,36 +1050,17 @@ def add_derived_ref(img1, existing_extra, derived_type, derived_use_gpu, derived
     return gr.update(value=new_gallery), gr.update(visible=True, value=derived)
 
 
-# ============================================================
-# Debug helpers (CUDA mem + token count)
-# ============================================================
-
-def _cuda_mem(prefix: str):
-    if not torch.cuda.is_available():
-        return
-    try:
-        free, total = torch.cuda.mem_get_info()
-        print(f"[DEBUG][cuda][{prefix}] mem free={free/1e9:.2f}GB total={total/1e9:.2f}GB")
-    except Exception:
-        pass
-
-def _approx_token_count(text: str) -> int:
-    # Lightweight: we avoid forcing tokenizer calls here; this is only for debug.
-    # Rule-of-thumb: ~4 chars per token in English-ish text.
-    if not text:
-        return 0
-    return max(1, int(math.ceil(len(text) / 4.0)))
-
 
 # ============================================================
 # Inference
 # ============================================================
 
 @spaces.GPU
 def infer(
     input_image_1,
     input_image_2,
-    input_images_extra,
     prompt,
     lora_adapter,
     seed,
@@ -1099,8 +1076,6 @@ def infer(
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    print("[DEBUG][infer] input types:", type(input_image_1), type(input_image_2), type(input_images_extra))
-
     if input_image_1 is None:
         raise gr.Error("Please upload Image 1.")
 
@@ -1119,23 +1094,15 @@
         seed = random.randint(0, MAX_SEED)
 
     generator = torch.Generator(device=device).manual_seed(seed)
-
     negative_prompt = (
         "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, "
         "extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
     )
 
-    img1 = input_image_1.convert("RGB")
-    if
-        raise gr.Error("Image 1 could not be read (unexpected input type/path).")
-
-    img2 = None
-    if input_image_2 is not None:
-        img2 = input_image_2.convert("RGB") if isinstance(input_image_2, Image.Image) else _to_pil_rgb(input_image_2)
-        if img2 is None:
-            raise gr.Error("Image 2 could not be read (unexpected input type/path).")
 
-    # Normalize extra images (Gallery)
     extra_imgs: list[Image.Image] = []
     if input_images_extra:
         for item in input_images_extra:
@@ -1143,16 +1110,20 @@
             if pil is not None:
                 extra_imgs.append(pil)
 
-    # Enforce 2-image LoRA behavior
     if lora_requires_two_images(lora_adapter) and img2 is None:
         raise gr.Error("This LoRA needs two images. Please upload Image 2 as well.")
 
     labeled = build_labeled_images(img1, img2, extra_imgs)
 
     pipe_images = list(labeled.values())
     if len(pipe_images) == 1:
         pipe_images = pipe_images[0]
 
     target_area = get_target_area_for_lora(img1, lora_adapter, float(target_megapixels))
     width, height = compute_canvas_dimensions_from_area(
         img1,
@@ -1160,48 +1131,33 @@
         multiple_of=int(pipe.vae_scale_factor * 2),
     )
 
     vae_image_indices = None
     if extras_condition_only:
         if isinstance(pipe_images, list) and len(pipe_images) > 2:
             vae_image_indices = [0, 1] if len(pipe_images) >= 2 else [0]
 
-    # Debug summary
-    n_images = len(pipe_images) if isinstance(pipe_images, list) else 1
-    tok_est = _approx_token_count(prompt or "")
-    print(
-        "[DEBUG][infer] submitting request | "
-        f"lora_adapter={lora_adapter!r} seed={seed} prompt_len={len(prompt or '')} "
-        f"steps={steps} true_cfg_scale={guidance_scale} target_mp={target_megapixels} "
-        f"canvas=({width}x{height}) n_images={n_images} vae_image_indices={vae_image_indices} "
-        f"pad_to_canvas={bool(pad_to_canvas)}"
-    )
-    print(f"[DEBUG][infer] image_1 size: {img1.size} image_2 size: {img2.size if img2 else None}")
-    print(f"[DEBUG][infer] prompt token_estimate: {tok_est}")
-    _cuda_mem("before")
-
     try:
         result = pipe(
             image=pipe_images,
             prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
-            num_inference_steps=
            generator=generator,
-            true_cfg_scale=
            vae_image_indices=vae_image_indices,
            pad_to_canvas=bool(pad_to_canvas),
        ).images[0]
        return result, seed, result
-
-    except Exception as e:
-        print("---- [ERROR][infer] exception ----")
-        print(traceback.format_exc())
-        print("---------------------------------")
-        raise
-
     finally:
-        _cuda_mem("after")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
@@ -1214,13 +1170,8 @@ def infer_example(input_image, prompt, lora_adapter):
     input_pil = input_image.convert("RGB")
     guidance_scale = 1.0
     steps = 4
-
-
-        prompt, lora_adapter,
-        0, True,
-        guidance_scale, steps, 1.0,
-        True, True
-    )
     return result, seed, last
 
 
@@ -1247,7 +1198,7 @@ with gr.Blocks() as demo:
     gr.Markdown(
         "Perform diverse image edits using specialized "
         "[LoRA](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image-Edit-2511) adapters for the "
-        "[Qwen-Image-Edit](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) model. Uses a Diffusers compatible extraction of the transformers from Phr00t's Rapid AIO merge."
     )
     gr.Markdown(aio_status_line)
 
@@ -1341,6 +1292,7 @@
         value=True,
     )
 
     lora_adapter.change(
         fn=on_lora_change_ui,
         inputs=[lora_adapter, prompt, extras_condition_only],
@@ -1400,16 +1352,18 @@
         outputs=[output_image, seed, last_output],
     )
 
     btn_out_to_img1.click(fn=set_output_as_image1, inputs=[last_output], outputs=[input_image_1])
     btn_out_to_img2.click(fn=set_output_as_image2, inputs=[last_output], outputs=[input_image_2])
     btn_out_to_extra.click(fn=set_output_as_extra, inputs=[last_output, input_images_extra], outputs=[input_images_extra])
-
     add_derived_btn.click(
         fn=add_derived_ref,
         inputs=[input_image_1, input_images_extra, derived_type, derived_use_gpu, derived_max_people],
         outputs=[input_images_extra, derived_preview],
     )
-
 if __name__ == "__main__":
     demo.queue(max_size=30).launch(
         css=css,
@@ -1,19 +1,14 @@ (new file side)
 import os
 import re
 import gc
 import traceback
 import gradio as gr
 import numpy as np
 import spaces
 import torch
+import random
 from PIL import Image, ImageDraw
+from typing import Iterable, Optional
 
 from transformers import (
     AutoProcessor,
@@ -29,133 +24,6 @@ from safetensors.torch import load_file as safetensors_load_file (new file side)
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
| 27 |
# ============================================================
|
| 28 |
# Theme
|
| 29 |
# ============================================================
|
|
|
|
| 43 |
c950="#802200",
|
| 44 |
)
|
| 45 |
|
| 46 |
+
|
| 47 |
class OrangeRedTheme(Soft):
|
| 48 |
def __init__(
|
| 49 |
self,
|
|
|
|
| 99 |
block_label_background_fill="*primary_200",
|
| 100 |
)
|
| 101 |
|
|
|
|
| 102 |
|
| 103 |
+
orange_red_theme = OrangeRedTheme()
|
| 104 |
|
| 105 |
# ============================================================
|
| 106 |
+
# Device
|
| 107 |
# ============================================================
|
| 108 |
|
| 109 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 118 |
print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
|
| 119 |
print("Using device:", device)
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
# ============================================================
|
| 122 |
# AIO version (Space variable)
|
| 123 |
# ============================================================
|
|
|
|
| 128 |
_VER_RE = re.compile(r"^v\d+$")
|
| 129 |
_DIGITS_RE = re.compile(r"^\d+$")
|
| 130 |
|
| 131 |
+
|
| 132 |
def _normalize_version(raw: str) -> Optional[str]:
|
| 133 |
if raw is None:
|
| 134 |
return None
|
|
|
|
| 137 |
return None
|
| 138 |
if _VER_RE.fullmatch(s):
|
| 139 |
return s
|
| 140 |
+
# forgiving: allow "21" -> "v21"
|
| 141 |
if _DIGITS_RE.fullmatch(s):
|
| 142 |
return f"v{s}"
|
| 143 |
return None
|
| 144 |
|
| 145 |
+
|
| 146 |
_AIO_ENV_RAW = os.environ.get("AIO_VERSION", "")
|
| 147 |
_AIO_ENV_NORM = _normalize_version(_AIO_ENV_RAW)
|
| 148 |
|
|
|
|
| 153 |
print(f"AIO_VERSION (normalized) = {_AIO_ENV_NORM!r}")
|
| 154 |
print(f"Using AIO_VERSION = {AIO_VERSION} ({AIO_VERSION_SOURCE})")
|
| 155 |
|
|
|
|
| 156 |
# ============================================================
|
| 157 |
# Pipeline
|
| 158 |
# ============================================================
|
|
|
|
| 164 |
|
| 165 |
dtype = torch.bfloat16
|
| 166 |
|
| 167 |
+
|
| 168 |
def _load_pipe_with_version(version: str) -> QwenImageEditPlusPipeline:
|
| 169 |
sub = f"{version}/transformer"
|
| 170 |
print(f"📦 Loading AIO transformer: {AIO_REPO_ID} / {sub}")
|
|
|
|
| 180 |
).to(device)
|
| 181 |
return p
|
| 182 |
|
| 183 |
+
|
| 184 |
+
# Forgiving load: try env/default version, fallback to v19 if it fails
|
| 185 |
try:
|
| 186 |
pipe = _load_pipe_with_version(AIO_VERSION)
|
| 187 |
+
except Exception as e:
|
| 188 |
print("❌ Failed to load requested AIO_VERSION. Falling back to v19.")
|
| 189 |
print("---- exception ----")
|
| 190 |
print(traceback.format_exc())
|
|
|
|
| 202 |
|
| 203 |
MAX_SEED = np.iinfo(np.int32).max
|
| 204 |
|
|
|
|
| 205 |
# ============================================================
|
| 206 |
+
# Derived conditioning (Transformers): Pose + Depth
|
| 207 |
# ============================================================
|
| 208 |
+
# Pose estimation uses ViTPose (top-down). Official docs show RT-DETR -> ViTPose flow:
|
| 209 |
+
# https://huggingface.co/docs/transformers/model_doc/vitpose
|
| 210 |
+
# Depth uses Depth Anything V2 Small (Transformers-compatible):
|
| 211 |
+
# https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf
|
| 212 |
|
| 213 |
POSE_MODEL_ID = "usyd-community/vitpose-base-simple"
|
| 214 |
POSE_DETECTOR_ID = "PekingU/rtdetr_r50vd_coco_o365"
|
| 215 |
DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"
|
| 216 |
|
| 217 |
+
# Lazy caches keyed by device string ("cpu" / "cuda")
|
| 218 |
_POSE_CACHE = {}
|
| 219 |
_DEPTH_CACHE = {}
|
| 220 |
|
| 221 |
+
# COCO-17 skeleton connections (approx "OpenPose-like" stick figure)
|
| 222 |
COCO17_EDGES = [
|
| 223 |
+
(0, 1), (0, 2), (1, 3), (2, 4), # head
|
| 224 |
+
(5, 6), # shoulders
|
| 225 |
+
(5, 7), (7, 9), # left arm
|
| 226 |
+
(6, 8), (8, 10), # right arm
|
| 227 |
+
(5, 11), (6, 12), (11, 12), # torso
|
| 228 |
+
(11, 13), (13, 15), # left leg
|
| 229 |
+
(12, 14), (14, 16), # right leg
|
| 230 |
]
|
| 231 |
|
| 232 |
def _derived_device(use_gpu: bool) -> torch.device:
|
| 233 |
return torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")
|
| 234 |
|
| 235 |
+
|
| 236 |
def _load_pose_models(dev: torch.device):
|
| 237 |
key = str(dev)
|
| 238 |
if key in _POSE_CACHE:
|
| 239 |
return _POSE_CACHE[key]
|
| 240 |
|
| 241 |
+
# Detector (optional but used for multi-person boxes)
|
| 242 |
det_proc = AutoProcessor.from_pretrained(POSE_DETECTOR_ID)
|
| 243 |
det_model = RTDetrForObjectDetection.from_pretrained(POSE_DETECTOR_ID).to(dev)
|
| 244 |
|
| 245 |
+
# Pose model
|
| 246 |
pose_proc = AutoProcessor.from_pretrained(POSE_MODEL_ID)
|
| 247 |
pose_model = VitPoseForPoseEstimation.from_pretrained(POSE_MODEL_ID).to(dev)
|
| 248 |
|
|
|
|
| 252 |
_POSE_CACHE[key] = (det_proc, det_model, pose_proc, pose_model)
|
| 253 |
return _POSE_CACHE[key]
|
| 254 |
|
| 255 |
+
|
| 256 |
def _load_depth_models(dev: torch.device):
|
| 257 |
key = str(dev)
|
| 258 |
if key in _DEPTH_CACHE:
|
|
|
|
| 265 |
_DEPTH_CACHE[key] = (proc, model)
|
| 266 |
return _DEPTH_CACHE[key]
|
| 267 |
|
| 268 |
+
|
| 269 |
def _draw_skeleton_on_blank(
|
| 270 |
size: tuple[int, int],
|
| 271 |
persons_keypoints: list[np.ndarray],
|
|
|
|
| 279 |
draw = ImageDraw.Draw(canvas)
|
| 280 |
|
| 281 |
for kps, sc in zip(persons_keypoints, persons_scores):
|
| 282 |
+
# Draw edges
|
| 283 |
for a, b in COCO17_EDGES:
|
| 284 |
if a >= len(sc) or b >= len(sc):
|
| 285 |
continue
|
|
|
|
| 289 |
xb, yb = float(kps[b, 0]), float(kps[b, 1])
|
| 290 |
draw.line([(xa, ya), (xb, yb)], fill=(255, 255, 255), width=line_w)
|
| 291 |
|
| 292 |
+
# Draw keypoints
|
| 293 |
for i in range(min(len(sc), len(kps))):
|
| 294 |
if sc[i] < kp_thresh:
|
| 295 |
continue
|
| 296 |
x, y = float(kps[i, 0]), float(kps[i, 1])
|
| 297 |
+
draw.ellipse(
|
| 298 |
+
[(x - point_r, y - point_r), (x + point_r, y + point_r)],
|
| 299 |
+
fill=(255, 255, 255),
|
| 300 |
+
outline=None,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
return canvas
|
| 304 |
|
| 305 |
+
|
| 306 |
def make_pose_map(
|
| 307 |
img: Image.Image,
|
| 308 |
*,
|
|
|
|
| 311 |
det_thresh: float = 0.30,
|
| 312 |
max_people: int = 4,
|
| 313 |
) -> Image.Image:
|
| 314 |
+
"""Return an OpenPose-like skeleton map (RGB) using Transformers models.
|
| 315 |
+
|
| 316 |
+
mode:
|
| 317 |
+
- "fast": full-frame box (no detector). Good when Image 1 is already a single subject.
|
| 318 |
+
- "detect": RT-DETR person boxes -> ViTPose. Better for multi-person scenes.
|
| 319 |
+
"""
|
| 320 |
img = img.convert("RGB")
|
| 321 |
dev = _derived_device(use_gpu)
|
| 322 |
det_proc, det_model, pose_proc, pose_model = _load_pose_models(dev)
|
|
|
|
| 324 |
w, h = img.size
|
| 325 |
|
| 326 |
if mode == "fast":
|
| 327 |
+
# Single box covering whole image, COCO format [x, y, w, h]
|
| 328 |
boxes = np.array([[0.0, 0.0, float(w), float(h)]], dtype=np.float32)
|
| 329 |
else:
|
| 330 |
+
# Detect people
|
| 331 |
inputs = det_proc(images=img, return_tensors="pt").to(dev)
|
| 332 |
with torch.no_grad():
|
| 333 |
outputs = det_model(**inputs)
|
|
|
|
| 338 |
threshold=det_thresh,
|
| 339 |
)[0]
|
| 340 |
|
| 341 |
+
# COCO label 0 is "person" for COCO-trained detectors
|
| 342 |
person_boxes = results["boxes"][results["labels"] == 0].detach().cpu().numpy()
|
| 343 |
|
| 344 |
if person_boxes.size == 0:
|
| 345 |
+
# Fallback to full-frame
|
| 346 |
boxes = np.array([[0.0, 0.0, float(w), float(h)]], dtype=np.float32)
|
| 347 |
else:
|
| 348 |
+
# Convert VOC x1,y1,x2,y2 to COCO x,y,w,h
|
| 349 |
person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
|
| 350 |
person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
|
| 351 |
boxes = person_boxes.astype(np.float32)
|
|
|
|
| 359 |
|
| 360 |
pose_results = pose_proc.post_process_pose_estimation(pose_outputs, boxes=[boxes])[0]
|
| 361 |
|
| 362 |
+
persons_kps = []
|
| 363 |
+
persons_sc = []
|
| 364 |
for pr in pose_results:
|
| 365 |
+
kps = pr["keypoints"].detach().cpu().numpy()
|
| 366 |
+
sc = pr["scores"].detach().cpu().numpy()
|
| 367 |
+
persons_kps.append(kps)
|
| 368 |
+
persons_sc.append(sc)
|
| 369 |
|
| 370 |
if not persons_kps:
|
| 371 |
+
# No pose found; return black canvas
|
| 372 |
return Image.new("RGB", img.size, (0, 0, 0))
|
| 373 |
|
| 374 |
return _draw_skeleton_on_blank(img.size, persons_kps, persons_sc)
|
| 375 |
|
| 376 |
+
|
| 377 |
def make_depth_map(img: Image.Image, *, use_gpu: bool) -> Image.Image:
|
| 378 |
+
"""Return a grayscale (RGB) depth map using Depth Anything V2 Small."""
|
| 379 |
img = img.convert("RGB")
|
| 380 |
dev = _derived_device(use_gpu)
|
| 381 |
proc, model = _load_depth_models(dev)
|
|
|
|
| 386 |
with torch.no_grad():
|
| 387 |
out = model(**inputs)
|
| 388 |
|
| 389 |
+
# predicted_depth: (B, H, W)
|
| 390 |
pred = out.predicted_depth
|
| 391 |
+
|
| 392 |
+
# Upsample to original image size
|
| 393 |
pred = torch.nn.functional.interpolate(
|
| 394 |
pred.unsqueeze(1),
|
| 395 |
size=(img.height, img.width),
|
|
|
|
| 403 |
arr = arr / denom
|
| 404 |
|
| 405 |
depth8 = (arr * 255.0).clip(0, 255).astype(np.uint8)
|
| 406 |
+
depth_img = Image.fromarray(depth8, mode="L").convert("RGB")
|
| 407 |
+
return depth_img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
def _append_to_gallery(existing, new_img: Image.Image):
|
| 411 |
items = []
|
|
|
|
| 417 |
items.append(new_img)
|
| 418 |
return items
|
| 419 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
# ============================================================
|
| 421 |
+
# LoRA adapters + presets
|
| 422 |
# ============================================================
|
| 423 |
|
| 424 |
NONE_LORA = "None"
|
|
|
|
| 486 |
"weights": "bfs_head_v5_2511_original.safetensors",
|
| 487 |
"adapter_name": "BFS-Best-Faceswap",
|
| 488 |
"strength": 1.0,
|
| 489 |
+
"needs_alpha_fix": True, # <-- fixes KeyError 'img_in.alpha'
|
| 490 |
},
|
| 491 |
"BFS-Best-FaceSwap-merge": {
|
| 492 |
"type": "single",
|
|
|
|
| 496 |
"weights": "bfs_head_v5_2511_merged_version_rank_32_fp32.safetensors",
|
| 497 |
"adapter_name": "BFS-Best-Faceswap-merge",
|
| 498 |
"strength": 1.1,
|
| 499 |
+
"needs_alpha_fix": True, # <-- fixes KeyError 'img_in.alpha'
|
| 500 |
},
|
| 501 |
"F2P": {
|
| 502 |
"type": "single",
|
|
|
|
| 582 |
"BFS-Best-FaceSwap-merge": "head_swap: start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
|
| 583 |
}
|
| 584 |
|
| 585 |
+
# Track what is currently loaded in memory (adapter_name values)
|
| 586 |
LOADED_ADAPTERS = set()
|
| 587 |
|
|
|
|
| 588 |
# ============================================================
|
| 589 |
+
# Helpers: resolution
|
| 590 |
# ============================================================
|
| 591 |
|
| 592 |
+
# We prefer *area-based* sizing (≈ megapixels) over long-edge sizing.
|
| 593 |
+
# This aligns better with Qwen-Image-Edit's internal assumptions and reduces FOV drift.
|
| 594 |
+
|
| 595 |
def _round_to_multiple(x: int, m: int) -> int:
|
| 596 |
return max(m, (int(x) // m) * m)
|
| 597 |
|
|
|
|
| 600 |
target_area: int,
|
| 601 |
multiple_of: int,
|
| 602 |
) -> tuple[int, int]:
|
| 603 |
+
"""Compute (width, height) that matches image aspect ratio and approximates target_area.
|
| 604 |
+
|
| 605 |
+
The result is floored to be divisible by multiple_of (typically vae_scale_factor*2).
|
| 606 |
+
"""
|
| 607 |
w, h = image.size
|
| 608 |
aspect = w / h if h else 1.0
|
| 609 |
+
|
| 610 |
+
# Use the pipeline's own area->(w,h) helper for consistency.
|
| 611 |
from qwenimage.pipeline_qwenimage_edit_plus import calculate_dimensions
|
| 612 |
+
|
| 613 |
width, height = calculate_dimensions(int(target_area), float(aspect))
|
| 614 |
width = _round_to_multiple(int(width), int(multiple_of))
|
| 615 |
height = _round_to_multiple(int(height), int(multiple_of))
|
|
|
|
| 620 |
lora_adapter: str,
|
| 621 |
user_target_megapixels: float,
|
| 622 |
) -> int:
|
| 623 |
+
"""Return target pixel area for the canvas.
|
| 624 |
+
|
| 625 |
+
Priority:
|
| 626 |
+
1) Adapter spec: target_area (pixels) or target_megapixels
|
| 627 |
+
2) Adapter spec: target_long_edge (legacy) -> converted to area using image aspect
|
| 628 |
+
3) User slider target megapixels
|
| 629 |
+
"""
|
| 630 |
spec = ADAPTER_SPECS.get(lora_adapter, {})
|
| 631 |
+
|
| 632 |
if "target_area" in spec:
|
| 633 |
try:
|
| 634 |
return int(spec["target_area"])
|
| 635 |
except Exception:
|
| 636 |
pass
|
| 637 |
+
|
| 638 |
if "target_megapixels" in spec:
|
| 639 |
try:
|
| 640 |
mp = float(spec["target_megapixels"])
|
| 641 |
return int(mp * 1024 * 1024)
|
| 642 |
except Exception:
|
| 643 |
pass
|
| 644 |
+
|
| 645 |
+
# Legacy support (e.g. Upscale2K)
|
| 646 |
if "target_long_edge" in spec:
|
| 647 |
try:
|
| 648 |
long_edge = int(spec["target_long_edge"])
|
|
|
|
| 656 |
return int(new_w * new_h)
|
| 657 |
except Exception:
|
| 658 |
pass
|
|
|
|
| 659 |
|
| 660 |
+
# User default
|
| 661 |
+
return int(float(user_target_megapixels) * 1024 * 1024)
|
| 662 |
|
| 663 |
# ============================================================
|
| 664 |
+
# Helpers: multi-input routing + gallery normalization
|
| 665 |
# ============================================================
|
| 666 |
|
| 667 |
+
|
| 668 |
def lora_requires_two_images(lora_adapter: str) -> bool:
|
| 669 |
return bool(ADAPTER_SPECS.get(lora_adapter, {}).get("requires_two_images", False))
|
| 670 |
|
| 671 |
+
|
| 672 |
def image2_label_for_lora(lora_adapter: str) -> str:
|
| 673 |
return str(ADAPTER_SPECS.get(lora_adapter, {}).get("image2_label", "Upload Reference (Image 2)"))
|
| 674 |
|
| 675 |
+
|
| 676 |
+
def _to_pil_rgb(x) -> Optional[Image.Image]:
|
| 677 |
+
"""
|
| 678 |
+
Accepts PIL / numpy / (image, caption) tuples from gr.Gallery and returns PIL RGB.
|
| 679 |
+
Gradio Gallery commonly yields tuples like (image, caption).
|
| 680 |
+
"""
|
| 681 |
+
if x is None:
|
| 682 |
+
return None
|
| 683 |
+
|
| 684 |
+
# Gallery often returns (image, caption)
|
| 685 |
+
if isinstance(x, tuple) and len(x) >= 1:
|
| 686 |
+
x = x[0]
|
| 687 |
+
if x is None:
|
| 688 |
+
return None
|
| 689 |
+
|
| 690 |
+
if isinstance(x, Image.Image):
|
| 691 |
+
return x.convert("RGB")
|
| 692 |
+
|
| 693 |
+
if isinstance(x, np.ndarray):
|
| 694 |
+
return Image.fromarray(x).convert("RGB")
|
| 695 |
+
|
| 696 |
+
# Best-effort fallback
|
| 697 |
+
try:
|
| 698 |
+
return Image.fromarray(np.array(x)).convert("RGB")
|
| 699 |
+
except Exception:
|
| 700 |
+
return None
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
def build_labeled_images(
|
| 704 |
+
img1: Image.Image,
|
| 705 |
+
img2: Optional[Image.Image],
|
| 706 |
+
extra_imgs: Optional[list[Image.Image]],
|
| 707 |
+
) -> dict[str, Image.Image]:
|
| 708 |
+
"""
|
| 709 |
+
Creates labels image_1, image_2, image_3... based on what is actually uploaded:
|
| 710 |
+
- img1 is always image_1
|
| 711 |
+
- img2 becomes image_2 only if present
|
| 712 |
+
- extras start immediately after the last present base box
|
| 713 |
+
The pipeline receives images in this exact order.
|
| 714 |
+
"""
|
| 715 |
+
labeled: dict[str, Image.Image] = {}
|
| 716 |
+
idx = 1
|
| 717 |
+
|
| 718 |
+
labeled[f"image_{idx}"] = img1
|
| 719 |
+
idx += 1
|
| 720 |
+
|
| 721 |
+
if img2 is not None:
|
| 722 |
+
labeled[f"image_{idx}"] = img2
|
| 723 |
+
idx += 1
|
| 724 |
+
|
| 725 |
+
if extra_imgs:
|
| 726 |
+
for im in extra_imgs:
|
| 727 |
+
if im is None:
|
| 728 |
+
continue
|
| 729 |
+
labeled[f"image_{idx}"] = im
|
| 730 |
+
idx += 1
|
| 731 |
+
|
| 732 |
+
return labeled
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
# ============================================================
|
| 736 |
+
# Helpers: BFS alpha key fix
|
| 737 |
+
# ============================================================
|
| 738 |
+
|
| 739 |
+
|
| 740 |
def _inject_missing_alpha_keys(state_dict: dict) -> dict:
|
| 741 |
+
"""
|
| 742 |
+
Diffusers' Qwen LoRA converter expects '<module>.alpha' keys.
|
| 743 |
+
BFS safetensors omits them. We inject alpha = rank (neutral scaling).
|
| 744 |
+
|
| 745 |
+
IMPORTANT: diffusers may strip 'diffusion_model.' before lookup, so we
|
| 746 |
+
inject BOTH:
|
| 747 |
+
- diffusion_model.xxx.alpha
|
| 748 |
+
- xxx.alpha
|
| 749 |
+
"""
|
| 750 |
bases = {}
|
| 751 |
+
|
| 752 |
for k, v in state_dict.items():
|
| 753 |
if not isinstance(v, torch.Tensor):
|
| 754 |
continue
|
|
|
|
| 769 |
stripped_alpha = f"{stripped_base}.alpha"
|
| 770 |
if stripped_alpha not in state_dict:
|
| 771 |
state_dict[stripped_alpha] = alpha_tensor
|
| 772 |
+
|
| 773 |
return state_dict
|
| 774 |
|
| 775 |
+
|
| 776 |
def _filter_to_diffusers_lora_keys(state_dict: dict) -> tuple[dict, dict]:
|
| 777 |
+
"""Return (filtered_state_dict, stats).
|
| 778 |
+
|
| 779 |
+
Some ComfyUI/Qwen safetensors (especially "merged" variants) include non-LoRA
|
| 780 |
+
delta/patch keys like `*.diff` and `*.diff_b` alongside real LoRA tensors.
|
| 781 |
+
Diffusers' internal Qwen LoRA converter is strict: any leftover keys cause an
|
| 782 |
+
error (`state_dict should be empty...`).
|
| 783 |
+
|
| 784 |
+
This helper keeps only the keys Diffusers can consume as a LoRA:
|
| 785 |
+
- `*.lora_up.weight`
|
| 786 |
+
- `*.lora_down.weight`
|
| 787 |
+
- (rare) `*.lora_mid.weight`
|
| 788 |
+
- alpha keys: `*.alpha` (or `*.lora_alpha` which we normalize to `*.alpha`)
|
| 789 |
+
|
| 790 |
+
It also drops known patch keys (`*.diff`, `*.diff_b`) and everything else.
|
| 791 |
+
"""
|
| 792 |
+
|
| 793 |
keep_suffixes = (
|
| 794 |
".lora_up.weight",
|
| 795 |
".lora_down.weight",
|
|
|
|
| 797 |
".alpha",
|
| 798 |
".lora_alpha",
|
| 799 |
)
|
| 800 |
+
|
| 801 |
dropped_patch = 0
|
| 802 |
dropped_other = 0
|
| 803 |
kept = 0
|
|
|
|
| 806 |
out: dict[str, torch.Tensor] = {}
|
| 807 |
for k, v in state_dict.items():
|
| 808 |
if not isinstance(v, torch.Tensor):
|
| 809 |
+
# Ignore non-tensor entries if any.
|
| 810 |
dropped_other += 1
|
| 811 |
continue
|
| 812 |
+
|
| 813 |
+
# Drop ComfyUI "delta" keys that Diffusers' LoRA loader will never consume.
|
| 814 |
if k.endswith(".diff") or k.endswith(".diff_b"):
|
| 815 |
dropped_patch += 1
|
| 816 |
continue
|
| 817 |
+
|
| 818 |
if not k.endswith(keep_suffixes):
|
| 819 |
dropped_other += 1
|
| 820 |
continue
|
| 821 |
+
|
| 822 |
if k.endswith(".lora_alpha"):
|
| 823 |
+
# Normalize common alt name to what Diffusers expects.
|
| 824 |
base = k[: -len(".lora_alpha")]
|
| 825 |
k2 = f"{base}.alpha"
|
| 826 |
out[k2] = v.float() if v.dtype != torch.float32 else v
|
| 827 |
normalized_alpha += 1
|
| 828 |
kept += 1
|
| 829 |
continue
|
| 830 |
+
|
| 831 |
out[k] = v
|
| 832 |
kept += 1
|
| 833 |
|
|
|
|
| 839 |
}
|
| 840 |
return out, stats
|
| 841 |
|
| 842 |
+
|
| 843 |


def _duplicate_stripped_prefix_keys(state_dict: dict, prefix: str = "diffusion_model.") -> dict:
    """Ensure both prefixed and unprefixed variants exist for LoRA-related keys.

    Diffusers' Qwen LoRA conversion may strip `diffusion_model.` when looking up
    modules. Some exports only include prefixed keys. To be maximally compatible,
    we duplicate LoRA keys (and alpha) in stripped form when missing.
    """

    out = dict(state_dict)
    for k, v in list(state_dict.items()):
        if not k.startswith(prefix):
            continue
        stripped = k[len(prefix):]
        if stripped not in out:
            out[stripped] = v
    return out


def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name: str, needs_alpha_fix: bool = False):
    """
    Normal path: pipe.load_lora_weights(repo, weight_name=..., adapter_name=...)
    BFS fallback: download safetensors, inject missing alpha keys, then load from dict.
    """
    try:
        pipe.load_lora_weights(repo, weight_name=weight_name, adapter_name=adapter_name)
        return
    except (KeyError, ValueError) as e:
        # KeyError: missing required alpha keys (common in BFS)
        # ValueError: Diffusers Qwen converter found leftover keys (e.g. .diff/.diff_b)
        if not needs_alpha_fix:
            raise
        # … (unchanged lines not shown in the diff) …

    local_path = hf_hub_download(repo_id=repo, filename=weight_name)
    sd = safetensors_load_file(local_path)

    # 1) Inject required `<module>.alpha` keys (neutral scaling alpha=rank).
    sd = _inject_missing_alpha_keys(sd)

    # 2) Keep only LoRA + alpha keys; drop ComfyUI patch/delta keys.
    sd, stats = _filter_to_diffusers_lora_keys(sd)

    # 3) Duplicate stripped keys (remove `diffusion_model.`) for compatibility.
    sd = _duplicate_stripped_prefix_keys(sd)

    print(
        # … (unchanged line not shown in the diff) …
        f"kept={stats['kept']} dropped_patch={stats['dropped_patch']} "
        f"dropped_other={stats['dropped_other']} normalized_alpha={stats['normalized_alpha']}"
    )

    pipe.load_lora_weights(sd, adapter_name=adapter_name)
    return
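

# Hypothetical call (the repo id and file name below are placeholders, not real
# entries from this app's adapter table):
#
#   _load_lora_weights_with_fallback(
#       "some-user/bfs-face-swap-lora",   # assumed repo id
#       "bfs_face_swap.safetensors",      # assumed file name
#       adapter_name="bfs",
#       needs_alpha_fix=True,             # enables the sanitize-and-retry path
#   )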


# ============================================================
# LoRA loader: single/package + strengths
# ============================================================


def _ensure_loaded_and_get_active_adapters(selected_lora: str):
    spec = ADAPTER_SPECS.get(selected_lora)
    if not spec:
        # … (unchanged lines not shown in the diff) …

            adapter_names.append(adapter_name)
            adapter_weights.append(strength)
    else:
        repo = spec["repo"]
        weights = spec["weights"]
        # … (unchanged lines not shown in the diff) …
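
    # (The remainder of this helper is not shown in the diff. A typical way to
    # apply the collected names/strengths with diffusers' PEFT integration,
    # which is an assumption rather than something visible here, would be:
    #     pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
    # and then returning the list of active adapters.)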


# ============================================================
# UI handlers
# ============================================================


def on_lora_change_ui(selected_lora, current_prompt, current_extras_condition_only):
    # Preset prompt (fill only if empty)
    if selected_lora != NONE_LORA:
        preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
        if preset and (current_prompt is None or str(current_prompt).strip() == ""):
            prompt_update = gr.update(value=preset)
        else:
            prompt_update = gr.update(value=current_prompt)
    else:
        prompt_update = gr.update(value=current_prompt)

    # Image 2 visibility/label
    if lora_requires_two_images(selected_lora):
        img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
    else:
        img2_update = gr.update(visible=False, value=None, label="Upload Reference (Image 2)")

    # Extra references routing default:
    # For BFS/AnyPose-like adapters, it's usually safer to keep extra refs as conditioning-only.
    if selected_lora in ("BFS-Best-FaceSwap", "BFS-Best-FaceSwap-merge", "AnyPose"):
        extras_update = gr.update(value=True)
    else:
        extras_update = gr.update(value=current_extras_condition_only)

    return prompt_update, img2_update, extras_update


# ============================================================
# UI helpers: output routing + derived conditioning
# ============================================================


def set_output_as_image1(last):
    if last is None:
        raise gr.Error("No output available yet.")
    return gr.update(value=last)


def set_output_as_image2(last):
    if last is None:
        raise gr.Error("No output available yet.")
    return gr.update(value=last)


def set_output_as_extra(last, existing_extra):
    if last is None:
        raise gr.Error("No output available yet.")
    return _append_to_gallery(existing_extra, last)


@spaces.GPU
def add_derived_ref(img1, existing_extra, derived_type, derived_use_gpu, derived_max_people):
    if img1 is None:
        # … (unchanged lines not shown in the diff) …

    return gr.update(value=new_gallery), gr.update(visible=True, value=derived)


# ============================================================
# Inference
# ============================================================


@spaces.GPU
def infer(
    input_image_1,
    input_image_2,
    input_images_extra,  # gallery multi-image box
    prompt,
    lora_adapter,
    seed,
    # … (remaining parameters not shown in the diff) …
):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if input_image_1 is None:
        raise gr.Error("Please upload Image 1.")

    # … (unchanged lines not shown in the diff) …
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator(device=device).manual_seed(seed)

    negative_prompt = (
        "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, "
        "extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
    )

    img1 = input_image_1.convert("RGB")
    img2 = input_image_2.convert("RGB") if input_image_2 is not None else None

    # Normalize extra images (Gallery) to PIL RGB (handles tuples from Gallery)
    extra_imgs: list[Image.Image] = []
    if input_images_extra:
        for item in input_images_extra:
            # … (unchanged line not shown in the diff) …
            if pil is not None:
                extra_imgs.append(pil)

    # Enforce existing 2-image LoRA behavior (image_1 + image_2 required)
    if lora_requires_two_images(lora_adapter) and img2 is None:
        raise gr.Error("This LoRA needs two images. Please upload Image 2 as well.")

    # Label images as image_1, image_2, image_3...
    labeled = build_labeled_images(img1, img2, extra_imgs)

    # Pass to pipeline in labeled order. Keep single-image call when only one is present.
    pipe_images = list(labeled.values())
    if len(pipe_images) == 1:
        pipe_images = pipe_images[0]

    # Resolution derived from Image 1 (base/body/target)
    # Use target *area* (≈ megapixels) rather than long-edge sizing to reduce FOV drift.
    target_area = get_target_area_for_lora(img1, lora_adapter, float(target_megapixels))
    width, height = compute_canvas_dimensions_from_area(
        img1,
        # … (unchanged line not shown in the diff) …
        multiple_of=int(pipe.vae_scale_factor * 2),
    )
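
    # Rough worked example (illustrative numbers; the exact rounding lives in
    # compute_canvas_dimensions_from_area): with a 3:2 input such as 1536x1024
    # and a 1.0 MP target area, the canvas aims for
    #   width  ≈ sqrt(1_000_000 * 3 / 2) ≈ 1224
    #   height ≈ sqrt(1_000_000 * 2 / 3) ≈ 816
    # and both are snapped to a multiple of `pipe.vae_scale_factor * 2`, so the
    # output keeps Image 1's aspect ratio at roughly the requested area.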

    # Decide which images participate in the VAE latent stream.
    # If enabled, extra references beyond (Img_1, Img_2) become conditioning-only.
    vae_image_indices = None
    if extras_condition_only:
        if isinstance(pipe_images, list) and len(pipe_images) > 2:
            vae_image_indices = [0, 1] if len(pipe_images) >= 2 else [0]
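
    # For example (hypothetical counts): with Image 1, Image 2 and two extra
    # references, pipe_images has 4 entries and vae_image_indices becomes [0, 1],
    # so only Image 1 and Image 2 are encoded into the VAE latent stream while
    # the extras act as conditioning-only references.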

    try:
        print(
            "[DEBUG][infer] submitting request | "
            f"lora_adapter={lora_adapter!r} seed={seed} prompt={prompt!r}"
        )

        result = pipe(
            image=pipe_images,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            generator=generator,
            true_cfg_scale=guidance_scale,
            vae_image_indices=vae_image_indices,
            pad_to_canvas=bool(pad_to_canvas),
        ).images[0]
        return result, seed, result
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# … (unchanged lines not shown in the diff) …
    input_pil = input_image.convert("RGB")
    guidance_scale = 1.0
    steps = 4
    # Examples don't supply Image 2 or extra images, and the example list doesn't include AnyPose/BFS.
    result, seed, last = infer(input_pil, None, None, prompt, lora_adapter, 0, True, guidance_scale, steps, 1.0, True, True)
    return result, seed, last


# … (unchanged lines not shown in the diff) …
    gr.Markdown(
        "Perform diverse image edits using specialized "
        "[LoRA](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image-Edit-2511) adapters for the "
        "[Qwen-Image-Edit](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) model. Uses a Diffusers-compatible "
        "extraction of the transformer from Phr00t's Rapid AIO merge. If a different AIO version is desired, "
        "duplicate the Space and set the Space variable to change the version."
    )
    gr.Markdown(aio_status_line)

    # … (unchanged lines not shown in the diff) …
        value=True,
    )

    # On LoRA selection: preset prompt + toggle Image 2
    lora_adapter.change(
        fn=on_lora_change_ui,
        inputs=[lora_adapter, prompt, extras_condition_only],
        # … (unchanged lines not shown in the diff) …
        outputs=[output_image, seed, last_output],
    )

    # Output routing buttons
    btn_out_to_img1.click(fn=set_output_as_image1, inputs=[last_output], outputs=[input_image_1])
    btn_out_to_img2.click(fn=set_output_as_image2, inputs=[last_output], outputs=[input_image_2])
    btn_out_to_extra.click(fn=set_output_as_extra, inputs=[last_output, input_images_extra], outputs=[input_images_extra])

    # Derived conditioning: append pose/depth map as extra ref (UI shows preview)
    add_derived_btn.click(
        fn=add_derived_ref,
        inputs=[input_image_1, input_images_extra, derived_type, derived_use_gpu, derived_max_people],
        outputs=[input_images_extra, derived_preview],
    )


if __name__ == "__main__":
    demo.queue(max_size=30).launch(
        css=css,