import os
import warnings

import hydra
import numpy as np
import torch
import tqdm
import imageio
from omegaconf import DictConfig
from PIL import Image
from pytorch3d.renderer import (
    PerspectiveCameras, look_at_view_transform
)
import matplotlib.pyplot as plt

from implicit import implicit_dict
from sampler import sampler_dict
from renderer import renderer_dict

from ray_utils import (
    sample_images_at_xy,
    get_pixels_from_image,
    get_random_pixels_from_image,
    get_rays_from_pixels
)
from data_utils import (
    dataset_from_config,
    create_surround_cameras,
    vis_grid,
    vis_rays,
)
from dataset import (
    get_nerf_datasets,
    trivial_collate,
)
from render_functions import render_points


# Model class containing:
#   1) Implicit volume defining the scene
#   2) Sampling scheme which generates sample points along rays
#   3) Renderer which can render an implicit volume given a sampling scheme
class Model(torch.nn.Module):
    def __init__(
        self,
        cfg
    ):
        super().__init__()

        # Get implicit function from config
        self.implicit_fn = implicit_dict[cfg.implicit_function.type](
            cfg.implicit_function
        )

        # Point sampling (raymarching) scheme
        self.sampler = sampler_dict[cfg.sampler.type](
            cfg.sampler
        )

        # Initialize volume renderer
        self.renderer = renderer_dict[cfg.renderer.type](
            cfg.renderer
        )

    def forward(
        self,
        ray_bundle
    ):
        # Call renderer with
        #  a) Implicit volume
        #  b) Sampling routine
        return self.renderer(
            self.sampler,
            self.implicit_fn,
            ray_bundle
        )
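# The CoarseFineModel below assumes that the sampler registered under
# cfg.sampler_fine.type performs hierarchical importance sampling from the
# weights of the coarse pass. The real implementation belongs in sampler.py and
# is selected through the config; the following is only a minimal inverse-CDF
# sketch for reference (function name, argument names, and shapes are
# assumptions, and nothing in this file calls it).
def _sample_pdf_sketch(bin_midpoints, weights, n_fine_samples):
    """Draw n_fine_samples depths per ray from the piecewise-constant PDF given
    by the coarse weights. Assumed shapes: bin_midpoints (R, N), weights (R, N)."""
    weights = weights + 1e-5                                         # avoid division by zero
    pdf = weights / weights.sum(dim=-1, keepdim=True)                # (R, N)
    cdf = torch.cumsum(pdf, dim=-1)                                  # (R, N)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)   # (R, N + 1)

    # One set of uniform samples per ray, then invert the CDF to pick a bin.
    u = torch.rand(cdf.shape[0], n_fine_samples, device=cdf.device)
    idx = torch.searchsorted(cdf, u, right=True).clamp(1, bin_midpoints.shape[-1])

    # Return the midpoint of the selected bin (the full NeRF version additionally
    # interpolates linearly within the bin).
    return torch.gather(bin_midpoints, -1, idx - 1)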
# 4. NeRF Extras (CHOOSE ONE! More than one is extra credit)
class CoarseFineModel(torch.nn.Module):
    """
    Model with separate coarse and fine networks for hierarchical sampling.
    """
    def __init__(self, cfg):
        super().__init__()

        # Coarse network
        self.coarse_implicit_fn = implicit_dict[cfg.implicit_function.type](
            cfg.implicit_function
        )
        self.coarse_sampler = sampler_dict[cfg.sampler_coarse.type](
            cfg.sampler_coarse
        )
        self.coarse_renderer = renderer_dict[cfg.renderer.type](
            cfg.renderer
        )

        # Fine network (same architecture as coarse, separate parameters)
        self.fine_implicit_fn = implicit_dict[cfg.implicit_function.type](
            cfg.implicit_function
        )
        self.fine_sampler = sampler_dict[cfg.sampler_fine.type](
            cfg.sampler_fine
        )
        self.fine_renderer = renderer_dict[cfg.renderer.type](
            cfg.renderer
        )

    def forward(self, ray_bundle, return_coarse=True):
        """
        Two-pass rendering: coarse then fine.

        Args:
            ray_bundle: RayBundle with ray origins and directions
            return_coarse: If True, return both coarse and fine outputs

        Returns:
            dict with 'coarse' and 'fine' keys if return_coarse=True,
            otherwise just the fine output
        """
        # Coarse pass
        coarse_ray_bundle = self.coarse_sampler(ray_bundle)
        n_pts_coarse = coarse_ray_bundle.sample_shape[1]

        # Render with coarse network
        coarse_output = self.coarse_renderer._render_with_implicit(
            coarse_ray_bundle, self.coarse_implicit_fn, n_pts_coarse
        )

        # Fine pass with importance sampling:
        # use the coarse weights for hierarchical sampling
        coarse_weights = coarse_output['weights']  # (B, n_pts_coarse, 1)

        # Sample fine points using the hierarchical sampler
        fine_ray_bundle = self.fine_sampler(
            coarse_ray_bundle,
            coarse_weights=coarse_weights.squeeze(-1)
        )
        n_pts_fine = fine_ray_bundle.sample_shape[1]

        # Render with fine network
        fine_output = self.fine_renderer._render_with_implicit(
            fine_ray_bundle, self.fine_implicit_fn, n_pts_fine
        )

        if return_coarse:
            return {
                'coarse': coarse_output,
                'fine': fine_output
            }
        else:
            return fine_output


def render_images(
    model,
    cameras,
    image_size,
    save=False,
    file_prefix=''
):
    all_images = []
    device = list(model.parameters())[0].device

    for cam_idx, camera in enumerate(cameras):
        print(f'Rendering image {cam_idx}')
        torch.cuda.empty_cache()
        camera = camera.to(device)

        xy_grid = get_pixels_from_image(image_size, camera)  # TODO (Q1.3): implement in ray_utils.py
        ray_bundle = get_rays_from_pixels(xy_grid, image_size, camera)  # TODO (Q1.3): implement in ray_utils.py

        # TODO (Q1.3): Visualize xy grid using vis_grid
        if cam_idx == 0 and file_prefix == '':
            xy_img = vis_grid(xy_grid, image_size)
            plt.imshow(xy_img)
            plt.axis("off")
            plt.show()

        # TODO (Q1.3): Visualize rays using vis_rays
        if cam_idx == 0 and file_prefix == '':
            ray_img = vis_rays(ray_bundle, image_size)
            plt.imshow(ray_img)
            plt.axis("off")
            plt.show()

        # TODO (Q1.4): Implement point sampling along rays in sampler.py
        ray_bundle = model.sampler(ray_bundle)

        # TODO (Q1.4): Visualize sample points as point cloud
        if cam_idx == 0 and file_prefix == '':
            # points = ray_bundle.sample_points.reshape(-1, 3).unsqueeze(0).to(device)  # (1, R*n, 3)
            points = ray_bundle.sample_points.reshape(-1, 3).unsqueeze(0).detach().cpu()
            render_points(
                filename='images/sample_points.png',
                points=points,
                image_size=image_size,
                color=[0.7, 0.7, 1],
                device=torch.device("cpu"),
            )

        # TODO (Q1.5): Implement rendering in renderer.py
        out = model(ray_bundle)

        # Return rendered features (colors)
        image = np.array(
            out['feature'].view(
                image_size[1], image_size[0], 3
            ).detach().cpu()
        )
        all_images.append(image)

        # TODO (Q1.5): Visualize depth
        if cam_idx == 2 and file_prefix == '':
            depth = out['depth'].view(image_size[1], image_size[0], 1).detach().cpu().numpy()
            depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
            plt.imsave('images/depth.png', depth.squeeze(), cmap='viridis')

        # Save
        if save:
            plt.imsave(
                f'{file_prefix}_{cam_idx}.png',
                image
            )

    return all_images
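# render_images above expects model(ray_bundle) to return at least 'feature'
# (per-ray RGB) and 'depth', and CoarseFineModel additionally reads 'weights'
# from the renderer output. Under the emission-absorption model these come from
# per-sample weights w_i = T_i * (1 - exp(-sigma_i * delta_i)) with transmittance
# T_i = exp(-sum_{j<i} sigma_j * delta_j). The graded implementation belongs in
# renderer.py; this is only a minimal compositing sketch (names and shapes are
# assumptions, and nothing in this file calls it).
def _composite_sketch(densities, colors, deltas):
    """Assumed shapes: densities (R, N, 1), colors (R, N, 3), deltas (R, N, 1)."""
    alpha = 1.0 - torch.exp(-densities * deltas)                        # (R, N, 1)
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(alpha[:, :1]), 1.0 - alpha + 1e-10], dim=1),
        dim=1,
    )[:, :-1]                                                           # T_i, (R, N, 1)
    weights = transmittance * alpha                                     # (R, N, 1)
    rgb = (weights * colors).sum(dim=1)                                 # (R, 3)
    return rgb, weights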
def render(
    cfg,
):
    # Create model
    model = Model(cfg)
    if torch.cuda.is_available():
        model = model.cuda()
    else:
        model = model.to('cpu')
        print("Running on CPU - this will be slower!")
    model.eval()

    # Render spiral
    cameras = create_surround_cameras(3.0, n_poses=20)
    all_images = render_images(
        model, cameras, cfg.data.image_size
    )
    imageio.mimsave('images/part_1.gif', [np.uint8(im * 255) for im in all_images], loop=0)
    print("Rendered images successfully. Saved at images/part_1.gif")

    # Display a few sample images in the notebook
    fig, axes = plt.subplots(1, min(4, len(all_images)), figsize=(15, 4))
    if len(all_images) == 1:
        axes = [axes]
    for i, ax in enumerate(axes):
        ax.imshow(all_images[i])
        ax.axis('off')
        ax.set_title(f'View {i}')
    plt.tight_layout()
    plt.show()


def train(
    cfg
):
    # Create model
    model = Model(cfg)
    if torch.cuda.is_available():
        model = model.cuda()
    else:
        model = model.to('cpu')
        print("Running on CPU - this will be slower!")
    model.train()

    # Create dataset
    train_dataset = dataset_from_config(cfg.data)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=lambda batch: batch,
    )
    image_size = cfg.data.image_size

    # Create optimizer
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.training.lr
    )

    # Render images before training
    cameras = [item['camera'] for item in train_dataset]
    render_images(
        model, cameras, image_size,
        save=True, file_prefix='images/part_2_before_training'
    )

    # Train
    t_range = tqdm.tqdm(range(cfg.training.num_epochs))

    for epoch in t_range:
        for iteration, batch in enumerate(train_dataloader):
            image, camera, camera_idx = batch[0].values()
            device = next(model.parameters()).device
            image = image.to(device)
            camera = camera.to(device)

            # Sample rays
            xy_grid = get_random_pixels_from_image(cfg.training.batch_size, image_size, camera)  # TODO (Q2.1): implement in ray_utils.py
            ray_bundle = get_rays_from_pixels(xy_grid, image_size, camera)
            rgb_gt = sample_images_at_xy(image, xy_grid)

            # Run model forward
            out = model(ray_bundle)

            # TODO (Q2.2): Calculate loss
            loss = torch.nn.functional.mse_loss(out['feature'], rgb_gt)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (epoch % 10) == 0:
            t_range.set_description(f'Epoch: {epoch:04d}, Loss: {loss:.06f}')
            t_range.refresh()

    # Print center and side lengths
    print("Box center:", tuple(np.array(model.implicit_fn.sdf.center.data.detach().cpu()).tolist()[0]))
    print("Box side lengths:", tuple(np.array(model.implicit_fn.sdf.side_lengths.data.detach().cpu()).tolist()[0]))

    # Render images after training
    render_images(
        model, cameras, image_size,
        save=True, file_prefix='images/part_2_after_training'
    )
    all_images = render_images(
        model, create_surround_cameras(3.0, n_poses=20), image_size, file_prefix='part_2'
    )
    imageio.mimsave('images/part_2.gif', [np.uint8(im * 255) for im in all_images], loop=0)
    print("Rendered images successfully. Saved at images/part_2.gif")

    # Display a few sample images in the notebook
    fig, axes = plt.subplots(1, min(4, len(all_images)), figsize=(15, 4))
    if len(all_images) == 1:
        axes = [axes]
    for i, ax in enumerate(axes):
        ax.imshow(all_images[i])
        ax.axis('off')
        ax.set_title(f'View {i}')
    plt.tight_layout()
    plt.show()
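# create_model below decays the learning rate continuously with a LambdaLR
# multiplier of lr_scheduler_gamma ** (epoch / lr_scheduler_step_size). For
# example (hypothetical values), gamma = 0.1 with step_size = 50 scales the base
# lr by about 0.316 after 25 epochs, 0.1 after 50 epochs, and 0.01 after 100.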
def create_model(cfg):
    # Create model
    model = Model(cfg)
    if torch.cuda.is_available():
        model = model.cuda()
    else:
        model = model.to('cpu')
        print("Running on CPU - this will be slower!")
    model.train()

    # Load checkpoints
    optimizer_state_dict = None
    start_epoch = 0

    checkpoint_path = os.path.join(
        hydra.utils.get_original_cwd(),
        cfg.training.checkpoint_path
    )

    if len(cfg.training.checkpoint_path) > 0:
        # Make the root of the experiment directory.
        checkpoint_dir = os.path.split(checkpoint_path)[0]
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Resume training if requested.
        if cfg.training.resume and os.path.isfile(checkpoint_path):
            print(f"Resuming from checkpoint {checkpoint_path}.")
            loaded_data = torch.load(checkpoint_path)
            model.load_state_dict(loaded_data["model"])
            start_epoch = loaded_data["epoch"]

            print(f" => resuming from epoch {start_epoch}.")
            optimizer_state_dict = loaded_data["optimizer"]

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.training.lr,
    )

    # Load the optimizer state dict in case we are resuming.
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)
        optimizer.last_epoch = start_epoch

    # The learning rate scheduling is implemented with the LambdaLR PyTorch scheduler.
    def lr_lambda(epoch):
        return cfg.training.lr_scheduler_gamma ** (
            epoch / cfg.training.lr_scheduler_step_size
        )

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1
    )

    return model, optimizer, lr_scheduler, start_epoch, checkpoint_path


def train_nerf(
    cfg
):
    # Create model
    print("[DEBUG] Start training NeRF...")
    model, optimizer, lr_scheduler, start_epoch, checkpoint_path = create_model(cfg)

    # Load the training/validation data.
    train_dataset, val_dataset, _ = get_nerf_datasets(
        dataset_name=cfg.data.dataset_name,
        image_size=[cfg.data.image_size[1], cfg.data.image_size[0]],
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    # Run the main training loop.
    for epoch in range(start_epoch, cfg.training.num_epochs):
        t_range = tqdm.tqdm(enumerate(train_dataloader))

        for iteration, batch in t_range:
            image, camera, camera_idx = batch[0].values()
            device = next(model.parameters()).device
            image = image.to(device).unsqueeze(0)
            camera = camera.to(device)

            # Sample rays
            xy_grid = get_random_pixels_from_image(
                cfg.training.batch_size, cfg.data.image_size, camera
            )
            ray_bundle = get_rays_from_pixels(
                xy_grid, cfg.data.image_size, camera
            )
            rgb_gt = sample_images_at_xy(image, xy_grid)

            # Run model forward
            out = model(ray_bundle)

            # TODO (Q3.1): Calculate loss
            loss = torch.nn.functional.mse_loss(out['feature'], rgb_gt)

            # Take the training step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t_range.set_description(f'Epoch: {epoch:04d}, Loss: {loss:.06f}')
            t_range.refresh()

        # Adjust the learning rate.
        lr_scheduler.step()

        # Checkpoint.
        if (
            epoch % cfg.training.checkpoint_interval == 0
            and len(cfg.training.checkpoint_path) > 0
            and epoch > 0
        ):
            print(f"Storing checkpoint {checkpoint_path}.")

            data_to_store = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
            }

            torch.save(data_to_store, checkpoint_path)

        # Render
        if (
            epoch % cfg.training.render_interval == 0
            and epoch > 0
        ):
            with torch.no_grad():
                test_images = render_images(
                    model,
                    create_surround_cameras(4.0, n_poses=20, up=(0.0, 0.0, 1.0), focal_length=2.0),
                    cfg.data.image_size, file_prefix='nerf'
                )
                imageio.mimsave('images/part_3.gif', [np.uint8(im * 255) for im in test_images], loop=0)
                print("Saved nerf rendering to images/part_3.gif")

                # Display a few sample images in the notebook
                fig, axes = plt.subplots(1, min(4, len(test_images)), figsize=(15, 4))
                if len(test_images) == 1:
                    axes = [axes]
                for i, ax in enumerate(axes):
                    ax.imshow(test_images[i])
                    ax.axis('off')
                    ax.set_title(f'View {i}')
                plt.tight_layout()
                plt.show()
def train_nerf_coarse_fine(cfg):
    """
    Training function for coarse-fine NeRF (Q4.2).

    Uses hierarchical importance sampling with separate coarse and fine networks.
    """
    print("[DEBUG] Start training Coarse-Fine NeRF...")

    # Create coarse-fine model
    model = CoarseFineModel(cfg)
    if torch.cuda.is_available():
        model = model.cuda()
    else:
        model = model.to('cpu')
        print("Running on CPU - this will be slower!")
    model.train()

    # Load training data
    train_dataset, val_dataset, _ = get_nerf_datasets(
        dataset_name=cfg.data.dataset_name,
        image_size=[cfg.data.image_size[1], cfg.data.image_size[0]],
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    # Optimizer for both networks
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.training.lr,
    )

    # Learning rate scheduler
    def lr_lambda(epoch):
        return cfg.training.lr_scheduler_gamma ** (
            epoch / cfg.training.lr_scheduler_step_size
        )

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda
    )

    # Loss weights
    coarse_weight = cfg.training.coarse_weight if 'coarse_weight' in cfg.training else 0.5
    fine_weight = cfg.training.fine_weight if 'fine_weight' in cfg.training else 1.0

    # Training loop
    for epoch in range(cfg.training.num_epochs):
        t_range = tqdm.tqdm(enumerate(train_dataloader))
        epoch_loss_coarse = 0.0
        epoch_loss_fine = 0.0
        epoch_loss_total = 0.0

        for iteration, batch in t_range:
            image, camera, camera_idx = batch[0].values()
            device = next(model.parameters()).device
            image = image.to(device).unsqueeze(0)
            camera = camera.to(device)

            # Sample rays
            xy_grid = get_random_pixels_from_image(
                cfg.training.batch_size, cfg.data.image_size, camera
            )
            ray_bundle = get_rays_from_pixels(
                xy_grid, cfg.data.image_size, camera
            )
            rgb_gt = sample_images_at_xy(image, xy_grid)

            # Forward pass: get both coarse and fine outputs
            outputs = model(ray_bundle, return_coarse=True)

            # Compute losses
            loss_coarse = torch.nn.functional.mse_loss(
                outputs['coarse']['feature'], rgb_gt
            )
            loss_fine = torch.nn.functional.mse_loss(
                outputs['fine']['feature'], rgb_gt
            )

            # Total loss is a weighted sum
            loss = coarse_weight * loss_coarse + fine_weight * loss_fine

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track losses
            epoch_loss_coarse += loss_coarse.item()
            epoch_loss_fine += loss_fine.item()
            epoch_loss_total += loss.item()

            t_range.set_description(
                f'Epoch: {epoch:04d}, Loss: {loss:.06f} '
                f'(Coarse: {loss_coarse:.06f}, Fine: {loss_fine:.06f})'
            )
            t_range.refresh()

        # Update learning rate
        lr_scheduler.step()

        # Average losses for the epoch
        n_batches = len(train_dataloader)
        print(f"Epoch {epoch}: Avg Loss = {epoch_loss_total / n_batches:.6f}, "
              f"Coarse = {epoch_loss_coarse / n_batches:.6f}, "
              f"Fine = {epoch_loss_fine / n_batches:.6f}")

        # Render test images
        if (epoch % cfg.training.render_interval == 0 and epoch > 0):
            model.eval()
            with torch.no_grad():
                test_cameras = create_surround_cameras(
                    4.0, n_poses=20, up=(0.0, 0.0, 1.0), focal_length=2.0
                )
                test_images = []
                for camera in test_cameras:
                    camera = camera.to(device)
                    xy_grid = get_pixels_from_image(cfg.data.image_size, camera)
                    ray_bundle = get_rays_from_pixels(xy_grid, cfg.data.image_size, camera)

                    # Use only the fine network for rendering
                    output = model(ray_bundle, return_coarse=False)

                    image = np.array(
                        output['feature'].view(
                            cfg.data.image_size[1], cfg.data.image_size[0], 3
                        ).detach().cpu()
                    )
                    test_images.append(image)

                imageio.mimsave(
                    f'images/part_4_2_epoch_{epoch}.gif',
                    [np.uint8(im * 255) for im in test_images],
                    loop=0
                )
                print(f"Saved rendering to images/part_4_2_epoch_{epoch}.gif")

                # Display a few sample images in the notebook
                fig, axes = plt.subplots(1, min(4, len(test_images)), figsize=(15, 4))
                if len(test_images) == 1:
                    axes = [axes]
                for i, ax in enumerate(axes):
                    ax.imshow(test_images[i])
                    ax.axis('off')
                    ax.set_title(f'View {i}')
                plt.tight_layout()
                plt.show()
            model.train()

    print("Training complete!")
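# Entry point. The run mode is selected via cfg.type ('render', 'train',
# 'train_nerf', or 'train_nerf_coarse_fine'). With Hydra, both the config and the
# mode can be overridden on the command line; for example (the script name and
# any config other than the default 'sphere' are assumptions here):
#   python <this_script>.py type=render
#   python <this_script>.py --config-name=<other_config> type=train_nerf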
@hydra.main(config_path='./configs', config_name='sphere')
def main(cfg: DictConfig):
    os.chdir(hydra.utils.get_original_cwd())

    if cfg.type == 'render':
        render(cfg)
    elif cfg.type == 'train':
        train(cfg)
    elif cfg.type == 'train_nerf':
        train_nerf(cfg)
    elif cfg.type == 'train_nerf_coarse_fine':
        train_nerf_coarse_fine(cfg)


if __name__ == "__main__":
    main()