-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
43 lines (33 loc) · 1.9 KB
/
utils.py
File metadata and controls
43 lines (33 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# -------------------------------------------
#
# File Name: utils.py
# Author: WANG Yiyang
# Created: Feb.18, 2025
# Description: Utility functions for DiffCamera.
#
# -------------------------------------------
import os
import argparse
def parse_args(argv=None):
    """Parse the DiffCamera command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse reads the real command line, exactly as the
            zero-argument call always did.

    Returns:
        argparse.Namespace holding every option declared below, including
        ``local_rank``. ``local_rank`` is replaced by the integer value of the
        ``LOCAL_RANK`` environment variable when that variable is set to
        something other than -1.

    Raises:
        SystemExit: raised by argparse (``parser.error``) when a required
            option is missing, a value cannot be converted, or ``bokeh_level``
            / ``focus_point_x`` / ``focus_point_y`` lie outside the ranges
            stated in their help texts.
        ValueError: if ``LOCAL_RANK`` is set to a non-integer string.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): these help texts tie focus_point_x to the vertical axis
    # (top/bottom) and focus_point_y to the horizontal axis (left/right),
    # which is the reverse of the common x=column / y=row image convention.
    # Confirm against the code that consumes these values before relying on it.
    parser.add_argument('--bokeh_level', type=int, default=10, help="Bokeh(blur) level for the image. Higher values mean more blur. Ranging from 0 to 30")
    parser.add_argument('--focus_point_x', type=float, default=0.5, help="Normalized focus point x-coordinate in the image. Ranging from 0 to 1, where 0 is the top edge and 1 is the bottom edge.")
    parser.add_argument('--focus_point_y', type=float, default=0.5, help="Normalized focus point y-coordinate in the image. Ranging from 0 to 1, where 0 is the left edge and 1 is the right edge.")
    parser.add_argument('--depth_model_path', type=str, default=None)
    parser.add_argument('--mixed_precision', type=str, default='bf16')
    parser.add_argument('--pretrained_model_name_or_path', type=str, required=True)
    parser.add_argument('--revision', type=str, default=None)
    parser.add_argument('--variant', type=str, default=None)
    parser.add_argument('--camera_emb_token_num', type=int, default=1)
    parser.add_argument('--image_size', type=int, default=512)
    parser.add_argument('--include_depth', action='store_true')
    parser.add_argument('--include_focus_depth', action='store_true')
    parser.add_argument('--zero_depth', action='store_true')
    parser.add_argument('--cfg_scale', type=float, default=1.5)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--resume_from_checkpoint', type=str, default=None)
    # Declared explicitly so the LOCAL_RANK override below has an attribute to
    # update. Without it, hasattr(args, "local_rank") was always False and the
    # environment override could never take effect.
    parser.add_argument('--local_rank', type=int, default=-1)
    args = parser.parse_args(argv)

    # Enforce the ranges documented in the help texts so bad values fail fast
    # at the CLI instead of deep inside the pipeline.
    if not 0 <= args.bokeh_level <= 30:
        parser.error(f"--bokeh_level must be in [0, 30], got {args.bokeh_level}")
    for option in ('focus_point_x', 'focus_point_y'):
        value = getattr(args, option)
        # The chained comparison is also False for NaN, so NaN is rejected too.
        if not 0.0 <= value <= 1.0:
            parser.error(f"--{option} must be in [0, 1], got {value}")

    # A LOCAL_RANK environment variable (other than -1) takes precedence over
    # the --local_rank flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1:
        args.local_rank = env_local_rank
    return args