Spaces:
Sleeping
Sleeping
| import functools | |
| import importlib | |
| import os | |
| from functools import partial | |
| from inspect import isfunction | |
| import fsspec | |
| import numpy as np | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| from safetensors.torch import load_file as load_safetensors | |
def disabled_train(self, mode=True):
    """Stand-in for ``model.train`` that ignores the requested mode.

    Binding this over a module's ``train`` method pins its train/eval state:
    ``model.train(...)`` becomes a no-op that hands the module straight back.
    """
    # ``mode`` exists only so the signature matches ``nn.Module.train``.
    del mode
    return self
def get_string_from_tuple(s):
    """Return the first element of ``s`` when ``s`` spells a tuple literal.

    For example ``"('a', 'b')"`` yields ``'a'``. Everything else is returned
    unchanged:

    - non-strings;
    - strings not wrapped in parentheses;
    - literals that are not tuples;
    - empty tuples;
    - text that is not a valid Python literal.

    The text is parsed with ``ast.literal_eval`` instead of ``eval`` so an
    arbitrary expression inside the parentheses is never executed.
    """
    import ast

    # Cheap structural checks first; non-strings were never convertible.
    if not isinstance(s, str) or not (s.startswith("(") and s.endswith(")")):
        return s
    try:
        t = ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a plain literal (or pathologically nested): leave input as-is.
        return s
    # ``()`` parses to an empty tuple, which has no first element.
    if isinstance(t, tuple) and t:
        return t[0]
    return s
def is_power_of_two(n):
    """Return True when ``n`` is a positive power of two, else False.

    Zero and negative values are rejected up front. A positive power of two
    has exactly one bit set in binary; ``n & (n - 1)`` clears the lowest set
    bit, so the result is zero precisely in that single-bit case.
    """
    if n <= 0:
        # Zero and negatives can never be powers of two.
        return False
    lowest_bit_cleared = n & (n - 1)
    return lowest_bit_cleared == 0
def autocast(f, enabled=True):
    """Wrap ``f`` so every call runs inside a CUDA autocast region.

    The autocast dtype and cache setting are read from torch's current
    autocast configuration at call time; only ``enabled`` is overridden.

    Args:
        f: Callable to wrap.
        enabled: Whether autocast is active while ``f`` runs.

    Returns:
        A wrapper accepting the same arguments as ``f``. ``functools.wraps``
        copies ``f``'s ``__name__``, ``__doc__`` and related metadata onto it,
        so introspection and logging still see the original function.
    """

    @functools.wraps(f)
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast
def load_partial_from_config(config):
    """Build a ``functools.partial`` from a target/params config mapping.

    ``config["target"]`` is a dotted import path resolved with
    ``get_obj_from_str``. The optional ``config["params"]`` mapping (empty
    when absent) is pre-bound as keyword arguments on the returned partial.
    """
    target = get_obj_from_str(config["target"])
    bound_kwargs = config.get("params", dict())
    return partial(target, **bound_kwargs)
def log_txt_as_img(wh, xc, size=10):
    """Render captions as black-on-white images, e.g. for logging.

    Args:
        wh: ``(width, height)`` pair giving the size of each rendered image.
        xc: Sequence of captions, one per output image. An item that is a
            list contributes only its first element.
        size: Font size passed to ``ImageFont.truetype``.

    Returns:
        A ``torch.Tensor`` of shape ``(len(xc), 3, height, width)`` whose
        pixel values are rescaled from ``[0, 255]`` to ``[-1, 1]``.
    """
    # The font and the wrap width are identical for every caption, so build
    # them once rather than reopening the font file on each iteration.
    font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
    # Roughly 40 characters per line at width 256, scaled with the width.
    # Clamp to 1 so very narrow images don't make range() step by zero.
    nc = max(1, int(40 * (wh[0] / 256)))
    txts = []
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        text_seq = caption[0] if isinstance(caption, list) else caption
        # Hard-wrap the caption into chunks of at most ``nc`` characters.
        lines = "\n".join(
            text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
        )
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("无法对字符串进行编码以记录日志。跳过。")
        # HWC uint8 -> CHW float in [-1, 1].
        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts
def partialclass(cls, *args, **kwargs):
    """Return a subclass of ``cls`` whose constructor has arguments pre-bound.

    ``args`` and ``kwargs`` are bound onto ``cls.__init__`` through
    ``functools.partialmethod``. Arguments given when instantiating the
    subclass are combined with them as ``functools.partial`` would.
    """
    bound_init = functools.partialmethod(cls.__init__, *args, **kwargs)

    class NewCls(cls):
        __init__ = bound_init

    return NewCls
def make_path_absolute(path):
    """Return an absolute path for local inputs; leave remote URLs untouched.

    ``fsspec.core.url_to_fs`` resolves ``path`` to a filesystem plus a
    protocol-stripped path. If that filesystem is the local one, the
    stripped path is made absolute. Otherwise the original ``path`` is
    returned as given.
    """
    fs, p = fsspec.core.url_to_fs(path)
    # ``protocol`` may be a single string or a tuple of aliases (fsspec's
    # local filesystem declares e.g. ("file", "local")). Comparing it with
    # ``== "file"`` would then never match, so normalise to a tuple and test
    # membership.
    protocol = fs.protocol
    protocols = (protocol,) if isinstance(protocol, str) else tuple(protocol)
    if "file" in protocols:
        return os.path.abspath(p)
    return path
def ismap(x):
    """True for a 4-D ``torch.Tensor`` whose dim-1 size exceeds 3."""
    if isinstance(x, torch.Tensor):
        return len(x.shape) == 4 and x.shape[1] > 3
    return False
def isimage(x):
    """True for a 4-D ``torch.Tensor`` whose dim-1 size is 1 or 3."""
    if isinstance(x, torch.Tensor):
        return len(x.shape) == 4 and x.shape[1] in (1, 3)
    return False
def isheatmap(x):
    """True for a ``torch.Tensor`` with exactly two dimensions."""
    return isinstance(x, torch.Tensor) and x.ndim == 2
def isneighbors(x):
    """True for a 5-D ``torch.Tensor`` whose dim-2 size is 1 or 3."""
    if not isinstance(x, torch.Tensor):
        return False
    if x.ndim != 5:
        return False
    return x.shape[2] in (1, 3)
def exists(x):
    """Return True unless ``x`` is ``None``; falsy values such as 0 count."""
    return not (x is None)
def expand_dims_like(x, y):
    """Append trailing size-1 dims to ``x`` until it has as many dims as ``y``.

    Args:
        x: Tensor to expand; must not have more dimensions than ``y``.
        y: Tensor whose number of dimensions is the target.

    Returns:
        ``x`` with ``y.dim() - x.dim()`` singleton dims added at the end
        (``x`` itself when the counts already match).

    Raises:
        ValueError: if ``x`` already has more dimensions than ``y``. Before
            this check the loop could never terminate in that case, because
            unsqueezing only adds dimensions.
    """
    extra = y.dim() - x.dim()
    if extra < 0:
        raise ValueError(
            f"x 的维数 ({x.dim()}) 大于 y 的维数 ({y.dim()}),无法扩展"
        )
    for _ in range(extra):
        x = x.unsqueeze(-1)
    return x
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    On fallback, a plain Python function ``d`` (as judged by
    ``inspect.isfunction``) is called with no arguments and its result is
    returned. Any other ``d`` is returned as-is, including classes, builtins
    and other callables.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
def mean_flat(tensor):
    """Average over every dimension except the first (batch) dimension.

    Adapted from OpenAI guided-diffusion:
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    """
    non_batch_dims = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_dims)
def count_params(model, verbose=False):
    """Return the total element count over ``model.parameters()``.

    When ``verbose`` is set, the count is also printed in millions.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    if verbose:
        print(f"{model.__class__.__name__} 拥有 {total_params * 1.e-6:.2f} 个 M 参数。")
    return total_params
def instantiate_from_config(config):
    """Instantiate ``config["target"]`` with ``config["params"]`` as kwargs.

    The sentinel strings ``"__is_first_stage__"`` and
    ``"__is_unconditional__"`` map to ``None``. Any other config lacking a
    ``"target"`` key raises ``KeyError``.
    """
    if "target" in config:
        target_cls = get_obj_from_str(config["target"])
        return target_cls(**config.get("params", dict()))
    if config in ("__is_first_stage__", "__is_unconditional__"):
        return None
    raise KeyError("预期键 `target` 将被实例化。")
def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    Args:
        string: Dotted path. Everything before the last ``.`` is imported as
            a module, and the final component is looked up on it.
        reload: If true, the module is re-executed with ``importlib.reload``
            before the lookup.
        invalidate_cache: If true, ``importlib.invalidate_caches()`` is called
            before importing.

    Returns:
        The attribute named by the last path component.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        importlib.reload(importlib.import_module(module_path))
    module = importlib.import_module(module_path, package=None)
    return getattr(module, attr_name)
def append_zero(x):
    """Return ``x`` with a single zero concatenated along dim 0.

    The zero comes from ``x.new_zeros``, so it takes ``x``'s dtype and device.
    """
    zero = x.new_zeros([1])
    return torch.cat((x, zero))
def append_dims(x, target_dims):
    """Append trailing singleton dims to ``x`` until ``x.ndim == target_dims``.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        index = (Ellipsis,) + (None,) * missing
        return x[index]
    raise ValueError(
        f"输入有 {x.ndim} 个尺寸,但 target_dims 是 {target_dims},最小"
    )
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
    """Instantiate ``config.model`` and load its weights from a checkpoint.

    Args:
        config: Config object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path ending in ``ckpt`` (read with ``torch.load``; its
            ``"state_dict"`` entry is used) or ending in ``safetensors``.
        verbose: Print the missing/unexpected keys reported by
            ``load_state_dict(strict=False)``.
        freeze: Set ``requires_grad = False`` on every parameter.

    Returns:
        The model. It is always switched to eval mode, whatever ``freeze`` is.

    Raises:
        NotImplementedError: if ``ckpt`` has neither supported suffix. The
            message names the rejected path.
    """
    print(f"从 {ckpt} 加载模型")
    if ckpt.endswith("ckpt"):
        # NOTE(review): unless torch's weights_only mode is in effect,
        # torch.load unpickles arbitrary objects. Load trusted files only.
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"全局步骤:{pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
    elif ckpt.endswith("safetensors"):
        sd = load_safetensors(ckpt)
    else:
        raise NotImplementedError(
            f"不支持的检查点格式:{ckpt}(仅支持 .ckpt 和 .safetensors)"
        )
    model = instantiate_from_config(config.model)
    # strict=False: report key mismatches instead of failing on them.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing and verbose:
        print("缺失 keys:")
        print(missing)
    if unexpected and verbose:
        print("意料之外的 keys:")
        print(unexpected)
    if freeze:
        for param in model.parameters():
            param.requires_grad = False
    model.eval()
    return model
def get_configs_path() -> str:
    """Locate the ``configs`` directory.

    In a working checkout it sits at the repository root. In an installed
    copy it lives inside the ``sgm`` package (see pyproject.toml). Both
    locations, relative to this file, are tried in that order.

    Raises:
        FileNotFoundError: if neither candidate is an existing directory.
    """
    here = os.path.dirname(__file__)
    candidates = tuple(
        os.path.join(here, *parts) for parts in (("configs",), ("..", "configs"))
    )
    found = next(
        (path for path in map(os.path.abspath, candidates) if os.path.isdir(path)),
        None,
    )
    if found is None:
        raise FileNotFoundError(f"无法在 {candidates} 找到 SGM 配置")
    return found