from transformers import set_seed
from tqdm.auto import trange
from PIL import Image
import numpy as np
import random
import utils
import torch
CONFIG_SPEC = [
    ("General", [
        ("text", "A cloud at dawn", str),
        ("iterations", 5000, (0, 7500)),
        ("seed", 12, int),
        ("show_every", 10, int),
    ]),
    ("Rendering", [
        ("w", 224, [224, 252]),
        ("h", 224, [224, 252]),
        ("showoff", 5000, (0, 10000)),
        ("turns", 4, int),
        ("focal_length", 0.1, float),
        ("plane_width", 0.1, float),
        ("shade_strength", 0.25, float),
        ("gamma", 0.5, float),
        ("max_depth", 7, float),
        ("offset", 5, float),
        ("offset_random", 0.75, float),
        ("xyz_random", 0.25, float),
        ("altitude_range", 0.3, float),
        ("augments", 4, int),
    ]),
    ("Optimization", [
        ("epochs", 6, int),
        ("lr", 0.6, float),
        # CLIP loss type; changing it may improve results
        ("loss_type", "spherical", ["spherical", "cosine"]),
        # CLIP loss weight
        ("clip_weight", 1.0, float),
    ]),
    ("Elements", [
        ("num_objects", 256, int),
        # Number of dimensions: 0 is for point clouds (default), 1 makes
        # strokes, 2 makes planes, 3 produces little cubes
        ("ndim", 0, [0, 1, 2, 3]),
        # Opacity scale:
        ("min_opacity", 1e-4, float),
        ("max_opacity", 1.0, float),
        ("log_opacity", False, bool),
        ("min_radius", 0.030, float),
        ("max_radius", 0.170, float),
        ("log_radius", False, bool),
        # TODO dynamically decide bezier_res
        # Bezier resolution: how many points a stroke/plane/cube has along
        # each dimension. Not applicable to points (ndim=0)
        ("bezier_res", 8, int),
        # Maximum scale of parameters: position, velocity, acceleration
        ("pos_scale", 0.4, float),
        ("vel_scale", 0.15, float),
        ("acc_scale", 0.15, float),
        # Scale of each individual 3D object. Master control for the
        # velocity and acceleration scales.
        ("scale", 1, float),
    ]),
]
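

# Hypothetical convenience helper (not part of the original interface):
# flattens CONFIG_SPEC into the plain {name: default} dict that
# PulsarCLIP.__init__ expects.
def default_config(spec=CONFIG_SPEC):
    return {name: default for _, options in spec for name, default, _ in options}
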
# TODO: one day separate the config into multiple parts and split this megaobject into multiple objects
# 2022/08/09: halfway done
class PulsarCLIP(object):
    def __init__(self, args):
        args = DotDict(**args)
        set_seed(args.seed)
        self.args = args
        self.device = args.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        # Defer the import so that `pulsar_clip` can be imported before `pytorch3d` is installed
        import pytorch3d.renderer.points.pulsar as ps
        self.ndim = int(self.args.ndim)
        # Every element expands into bezier_res ** ndim spheres for the renderer
        self.renderer = ps.Renderer(self.args.w, self.args.h,
                                    self.args.num_objects * (self.args.bezier_res ** self.ndim)).to(self.device)
        # Per-element parameters: position (xyz + radius), velocity and acceleration
        # (3 components per curve dimension), and color (RGBA, plus RGBA dynamics per dimension)
        self.bezier_pos = torch.nn.Parameter(torch.randn((args.num_objects, 4)).to(self.device))
        self.bezier_vel = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
        self.bezier_acc = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
        self.bezier_col = torch.nn.Parameter(torch.randn((args.num_objects, 4 * (1 + self.ndim))).to(self.device))
        # Colors get the largest learning rate, positions a smaller one, curve shapes the smallest
        self.optimizer = torch.optim.Adam([dict(params=[self.bezier_col], lr=5e-1 * args.lr),
                                           dict(params=[self.bezier_pos], lr=1e-1 * args.lr),
                                           dict(params=[self.bezier_vel, self.bezier_acc], lr=5e-2 * args.lr),
                                           ])
        self.model_clip, self.preprocess_clip = utils.load_clip()
        self.model_clip.visual.requires_grad_(False)
        # One cosine warm restart per epoch; an optimizer step happens once every `augments` iterations
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer,
                                                                              int(self.args.iterations
                                                                                  / self.args.augments
                                                                                  / self.args.epochs),
                                                                              eta_min=args.lr / 100)
        import clip
        self.txt_emb = self.model_clip.encode_text(clip.tokenize([self.args.text]).to(self.device))[0].detach()
        self.txt_emb = torch.nn.functional.normalize(self.txt_emb, dim=-1)
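
    # Each element is sampled on a bezier_res ** ndim grid of parameters t in [0, 1]:
    #   point(t) = tanh(pos) * pos_scale
    #              + scale * sum_d (tanh(vel_d) * vel_scale * t_d + tanh(acc_d) * acc_scale * t_d ** 2)
    # so ndim=1 gives quadratic strokes, ndim=2 curved planes, ndim=3 cubes;
    # ndim=0 is a single point per element (pos * pos_scale, no tanh).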
    def get_points(self):
        if self.ndim > 0:
            # Parameter grid, shape: (ndim, num_objects, bezier_res, ..., 1)
            bezier_ts = torch.stack(torch.meshgrid(
                (torch.linspace(0, 1, self.args.bezier_res, device=self.device),) * self.ndim), dim=0
            ).unsqueeze(1).repeat((1, self.args.num_objects) + (1,) * self.ndim).unsqueeze(-1)

        def interpolate_3D(pos, vel=0.0, acc=0.0, pos_scale=None, vel_scale=None, acc_scale=None, scale=None):
            pos_scale = self.args.pos_scale if pos_scale is None else pos_scale
            vel_scale = self.args.vel_scale if vel_scale is None else vel_scale
            acc_scale = self.args.acc_scale if acc_scale is None else acc_scale
            scale = self.args.scale if scale is None else scale
            if self.ndim == 0:
                return pos * pos_scale
            result = 0.0
            s = pos.shape[-1]
            assert s * self.ndim == vel.shape[-1] == acc.shape[-1]
            # O(ndim) sequential loop; TODO replace with a fused dimension operation
            for d, bezier_t in zip(range(self.ndim), bezier_ts):
                result = (result
                          + torch.tanh(vel[..., d * s:(d + 1) * s]).view(
                              (-1,) + (1,) * self.ndim + (s,)) * vel_scale * bezier_t
                          + torch.tanh(acc[..., d * s:(d + 1) * s]).view(
                              (-1,) + (1,) * self.ndim + (s,)) * acc_scale * bezier_t.pow(2))
            result = (result * scale
                      + torch.tanh(pos[..., :s]).view((-1,) + (1,) * self.ndim + (s,)) * pos_scale).view(-1, s)
            return result

        vert_pos = interpolate_3D(self.bezier_pos[..., :3], self.bezier_vel, self.bezier_acc)
        # Note: bezier_col holds 4 * (1 + ndim) channels, so the velocity and
        # acceleration slices below alias the same parameters
        vert_col = interpolate_3D(self.bezier_col[..., :4],
                                  self.bezier_col[..., 4:4 + 4 * self.ndim],
                                  self.bezier_col[..., -4 * self.ndim:])
        # Broadcast a per-element value to every point generated from that element
        to_bezier = lambda x: x.view((-1,) + (1,) * self.ndim + (x.shape[-1],)).repeat(
            (1,) + (self.args.bezier_res,) * self.ndim + (1,)).reshape(-1, x.shape[-1])
        # Map x in [0, 1] to [a, b], linearly or log-linearly (a * (b / a) ** x)
        rescale = lambda x, a, b, is_log=False: (torch.exp(x
                                                           * np.log(b / a)
                                                           + np.log(a))) if is_log else x * (b - a) + a
        return (
            vert_pos,
            torch.sigmoid(vert_col[..., :3]),
            rescale(
                torch.sigmoid(to_bezier(self.bezier_pos[..., -1:])[..., 0]),
                self.args.min_radius, self.args.max_radius, is_log=self.args.log_radius
            ),
            rescale(torch.sigmoid(vert_col[..., -1]),
                    self.args.min_opacity, self.args.max_opacity, is_log=self.args.log_opacity))
    def camera(self, angle, altitude=0.0, offset=None, use_random=True, offset_random=None,
               xyz_random=None, focal_length=None, plane_width=None):
        if offset is None:
            offset = self.args.offset
        if xyz_random is None:
            xyz_random = self.args.xyz_random
        if focal_length is None:
            focal_length = self.args.focal_length
        if plane_width is None:
            plane_width = self.args.plane_width
        if offset_random is None:
            offset_random = self.args.offset_random
        device = self.device
        # Jitter the camera distance, then orbit the origin by altitude (pitch) and angle (yaw)
        offset = offset + np.random.normal() * offset_random * int(use_random)
        position = torch.tensor([0, 0, -offset], dtype=torch.float)
        position = utils.rotate_axis(position, altitude, 0)
        position = utils.rotate_axis(position, angle, 1)
        position = position + torch.randn(3) * xyz_random * int(use_random)
        # Pulsar camera parameters: position, rotation, focal length, sensor width
        return torch.tensor([position[0], position[1], position[2],
                             altitude, angle, 0,
                             focal_length, plane_width], dtype=torch.float, device=device)
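
    # Example (sketch): a fixed three-quarter view with no jitter,
    #   cam = self.camera(angle=np.pi / 4, altitude=0.1, use_random=False)
    # can be passed directly to `render` below.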
    def render(self, cam_params=None):
        if cam_params is None:
            cam_params = self.camera(0, 0)
        vert_pos, vert_col, radius, opacity = self.get_points()
        rgb = self.renderer(vert_pos, vert_col, radius, cam_params,
                            self.args.gamma, self.args.max_depth, opacity=opacity)
        # Render the same spheres in black to estimate how much background shows through per pixel
        opacity = self.renderer(vert_pos, vert_col * 0, radius, cam_params,
                                self.args.gamma, self.args.max_depth, opacity=opacity)
        return rgb, opacity
    def random_view_render(self):
        # Sample a random viewpoint around the scene
        angle = random.uniform(0, np.pi * 2)
        altitude = random.uniform(-self.args.altitude_range / 2, self.args.altitude_range / 2)
        cam_params = self.camera(angle, altitude)
        result, alpha = self.render(cam_params)
        # Composite over a random Perlin-noise background so CLIP cannot latch onto a flat backdrop
        back = torch.zeros_like(result)
        s = back.shape
        for j in range(s[-1]):
            n = random.choice([7, 14, 28])
            back[..., j] = utils.rand_perlin_2d_octaves(s[:-1], (n, n)).clip(-0.5, 0.5) + 0.5
        result = result * (1 - alpha) + back * alpha
        return result
    def generate(self):
        self.optimizer.zero_grad()
        try:
            for i in trange(self.args.iterations + self.args.showoff):
                if i < self.args.iterations:
                    result = self.random_view_render()
                    img_emb = self.model_clip.encode_image(
                        self.preprocess_clip(result.permute(2, 0, 1)).unsqueeze(0).clamp(0., 1.))
                    img_emb = torch.nn.functional.normalize(img_emb, dim=-1)
                    if self.args.loss_type == "spherical":
                        # Squared great-circle distance between the unit embeddings
                        clip_loss = (img_emb - self.txt_emb).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()
                    elif self.args.loss_type == "cosine":
                        clip_loss = (1 - img_emb @ self.txt_emb.T).mean()
                    else:
                        raise NotImplementedError(f"CLIP loss type not supported: {self.args.loss_type}")
                    loss = clip_loss * self.args.clip_weight  # TODO add more loss types
                    loss.backward()
                    # Accumulate gradients over `augments` random views per optimizer step
                    if i % self.args.augments == self.args.augments - 1:
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                        try:
                            self.scheduler.step()
                        except AttributeError:
                            pass
                if i % self.args.show_every == 0:
                    # Deterministic orbit preview; after `iterations` steps only this branch runs
                    cam_params = self.camera(i / self.args.iterations * np.pi * 2 * self.args.turns, use_random=False)
                    img_show, _ = self.render(cam_params)
                    # Clamp to [0, 1] so the uint8 cast cannot wrap around
                    img = Image.fromarray((img_show.clamp(0., 1.).cpu().detach().numpy() * 255).astype(np.uint8))
                    yield img
        except KeyboardInterrupt:
            pass
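
    # `generate` is a generator: it optimizes for `iterations` steps, then
    # orbits the finished scene for `showoff` more, yielding a PIL preview
    # every `show_every` iterations; iterate it to drive training.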
    def save_obj(self, fn):
        utils.save_obj(self.get_points(), fn)

class DotDict(dict):
    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(item)  # so getattr()/hasattr() behave as expected
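
# Minimal usage sketch, assuming a working pytorch3d + CLIP install;
# `default_config` is the hypothetical helper defined above, and the
# output file names are illustrative, not part of the original interface.
if __name__ == "__main__":
    cfg = default_config()
    cfg["iterations"] = 500  # short demo run
    cfg["showoff"] = 0
    pc = PulsarCLIP(cfg)
    for i, img in enumerate(pc.generate()):
        img.save(f"frame_{i:04d}.png")
    pc.save_obj("cloud.obj")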