Update opensora/models/ae/videobase/causal_vae/modeling_causalvae.py
opensora/models/ae/videobase/causal_vae/modeling_causalvae.py
CHANGED
@@ -316,7 +316,8 @@ class CausalVAEModel(VideoBaseAE_PL):
         self.tile_sample_min_size = 256
         self.tile_sample_min_size_t = 65
         self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
-
+        t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0]
+        self.tile_latent_min_size_t = int((self.tile_sample_min_size_t-1) / (2 ** len(t_down_ratio))) + 1
         self.tile_overlap_factor = 0.25
         self.use_tiling = False

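For reference, a minimal sketch of what the new tile_latent_min_size_t works out to. The stage count below is an assumption: it equals the number of non-empty entries in encoder_temporal_downsample, which this diff does not show; two temporal downsampling stages (4x total) are assumed.

# Hypothetical illustration of the new latent tile-size computation.
tile_sample_min_size_t = 65      # pixel-space tile length in frames (from the diff)
num_t_downsample = 2             # assumed; equals len(t_down_ratio) in the model
tile_latent_min_size_t = int((tile_sample_min_size_t - 1) / (2 ** num_t_downsample)) + 1
print(tile_latent_min_size_t)    # 17 latent frames per 65-frame tile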
@@ -374,8 +375,9 @@ class CausalVAEModel(VideoBaseAE_PL):
         if self.use_tiling and (
             x.shape[-1] > self.tile_sample_min_size
             or x.shape[-2] > self.tile_sample_min_size
+            or x.shape[-3] > self.tile_sample_min_size_t
         ):
-            return self.
+            return self.tiled_encode(x)
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
@@ -385,8 +387,9 @@ class CausalVAEModel(VideoBaseAE_PL):
         if self.use_tiling and (
             z.shape[-1] > self.tile_latent_min_size
             or z.shape[-2] > self.tile_latent_min_size
+            or z.shape[-3] > self.tile_latent_min_size_t
         ):
-            return self.
+            return self.tiled_decode(z)
         z = self.post_quant_conv(z)
         dec = self.decoder(z)
         return dec
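With the extra x.shape[-3] / z.shape[-3] terms, encode() and decode() now fall back to tiling when the clip is too long in time even if it fits spatially. A small, hypothetical check mirroring the updated condition (the needs_tiling helper and its hard-coded thresholds are illustrative only; the model reads them from its attributes):

import torch

def needs_tiling(x, spatial_min=256, temporal_min=65):
    # x is (B, C, T, H, W); tile if either spatial side or the frame count exceeds one tile
    return (x.shape[-1] > spatial_min
            or x.shape[-2] > spatial_min
            or x.shape[-3] > temporal_min)

x = torch.randn(1, 3, 129, 256, 256)  # 129 frames at 256x256
print(needs_tiling(x))                # True: spatial dims fit, but 129 > 65 frames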
@@ -554,7 +557,54 @@ class CausalVAEModel(VideoBaseAE_PL):
         ) + b[:, :, :, :, x] * (x / blend_extent)
         return b

-    def
+    def tiled_encode(self, x):
+        t = x.shape[2]
+        t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t-1)]
+        if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
+            t_chunk_start_end = [[0, t]]
+        else:
+            t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1] for i in range(len(t_chunk_idx)-1)]
+            if t_chunk_start_end[-1][-1] > t:
+                t_chunk_start_end[-1][-1] = t
+            elif t_chunk_start_end[-1][-1] < t:
+                last_start_end = [t_chunk_idx[-1], t]
+                t_chunk_start_end.append(last_start_end)
+        moments = []
+        for idx, (start, end) in enumerate(t_chunk_start_end):
+            chunk_x = x[:, :, start: end]
+            if idx != 0:
+                moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:]
+            else:
+                moment = self.tiled_encode2d(chunk_x, return_moments=True)
+            moments.append(moment)
+        moments = torch.cat(moments, dim=2)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def tiled_decode(self, x):
+        t = x.shape[2]
+        t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t-1)]
+        if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
+            t_chunk_start_end = [[0, t]]
+        else:
+            t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1] for i in range(len(t_chunk_idx)-1)]
+            if t_chunk_start_end[-1][-1] > t:
+                t_chunk_start_end[-1][-1] = t
+            elif t_chunk_start_end[-1][-1] < t:
+                last_start_end = [t_chunk_idx[-1], t]
+                t_chunk_start_end.append(last_start_end)
+        dec_ = []
+        for idx, (start, end) in enumerate(t_chunk_start_end):
+            chunk_x = x[:, :, start: end]
+            if idx != 0:
+                dec = self.tiled_decode2d(chunk_x)[:, :, 1:]
+            else:
+                dec = self.tiled_decode2d(chunk_x)
+            dec_.append(dec)
+        dec_ = torch.cat(dec_, dim=2)
+        return dec_
+
+    def tiled_encode2d(self, x, return_moments=False):
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
         row_limit = self.tile_latent_min_size - blend_extent
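The new tiled_encode / tiled_decode above split the clip along the time axis into chunks that overlap by exactly one frame, so every chunk still begins on a frame the causal convolutions treat as a first frame; the duplicated leading latent frame of each chunk after the first is then dropped with [:, :, 1:]. A standalone sketch of the boundary computation, assuming a 129-frame clip and the 65-frame tile from the diff:

# Hypothetical worked example of the temporal chunking used by tiled_encode.
t = 129                                      # assumed clip length in frames
min_t = 65                                   # tile_sample_min_size_t
t_chunk_idx = list(range(0, t, min_t - 1))   # [0, 64, 128]
chunks = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
if chunks[-1][-1] > t:
    chunks[-1][-1] = t
elif chunks[-1][-1] < t:
    chunks.append([t_chunk_idx[-1], t])
print(chunks)  # [[0, 65], [64, 129]]: consecutive chunks share one frame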
@@ -590,7 +640,8 @@ class CausalVAEModel(VideoBaseAE_PL):

         moments = torch.cat(result_rows, dim=3)
         posterior = DiagonalGaussianDistribution(moments)
-
+        if return_moments:
+            return moments
         return posterior

     def tiled_decode2d(self, z):