feat: Integrate CorrDiffSolar model with MultiDiffusion support #545

Open
sunjingan wants to merge 1 commit into NVIDIA:main from sunjingan:feat/MDsolar-integration

@sunjingan

Earth2Studio Pull Request

Hi @Charlelie,

As discussed over email, this PR integrates the retrained CorrDiffSolar model into Earth2Studio.

Key Changes

  • Added CorrDiffSolarMD Model (earth2studio/models/dx/corrdiffMD.py).
  • Added MultiDiffusion Sampler (earth2studio/utils/multidiffusion.py).
  • Merged diagnostic_solar() logic into earth2studio/run.py.

Usage Example

  • An example inference script has been added at examples/19_multidiff_solar.py.
  • Note: We need your help uploading the model package to NGC; once it is there, users can download it with load_default_package(). (The package is in the Docker image, as I described by email.)

@greptile-apps

greptile-apps bot commented Nov 11, 2025

Greptile Overview

Greptile Summary

This PR integrates the CorrDiffSolar model with MultiDiffusion support for high-resolution solar radiation downscaling.

Key additions:

  • New CorrDiffMD base class and CorrDiffSolarMD subclass in corrdiffMD.py
  • New MultiDiffusion sampler for windowed diffusion processing
  • New diagnostic_solar workflow function in run.py
  • Example script demonstrating usage

Critical issues found:

  • Multiple import errors and syntax bugs that will cause runtime failures
  • Access to the missing self.device attribute will crash during inference
  • Undefined solarcorrdiffic.srx attribute access in run.py
  • Incorrect class name import (corrdiffMD vs CorrDiffMD) in __init__.py
  • Use of datetime.utcfromtimestamp, which is deprecated as of Python 3.12
  • Hardcoded img_shape override defeats dynamic calculation
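
For the datetime.utcfromtimestamp item, one possible replacement is sketched below. It assumes the model only needs a UTC datetime built from a POSIX timestamp; the helper names are hypothetical and do not come from the PR.

```python
from datetime import datetime, timezone


def utc_from_timestamp(seconds: float) -> datetime:
    """Return a timezone-aware UTC datetime for a POSIX timestamp.

    Hypothetical helper: stands in for datetime.utcfromtimestamp(seconds).
    """
    return datetime.fromtimestamp(seconds, tz=timezone.utc)


def naive_utc_from_timestamp(seconds: float) -> datetime:
    """Return a naive datetime in UTC, matching the old function's return type.

    Hypothetical helper for code paths that compare against naive datetimes.
    """
    return utc_from_timestamp(seconds).replace(tzinfo=None)
```

The aware variant is usually the safer default; the naive variant is only useful if downstream code cannot handle tzinfo.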

Architecture concerns:

  • The hardcoded 320x320 window size may limit flexibility for different resolutions
  • MultiDiffusion implementation appears sound for handling large images via overlapping windows
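
To illustrate the flexibility concern, here is a minimal sketch of how window offsets could be derived from a configurable window size and stride rather than a fixed 320x320 shape. It is not the PR's get_windows implementation: window_starts, window_corners, and their parameters are hypothetical names, and the edge policy (adding a final window flush with the border) is an assumption.

```python
def window_starts(size: int, window: int, stride: int) -> list[int]:
    """Offsets of windows of length `window` that cover an axis of length `size`."""
    if window > size:
        raise ValueError("window must not exceed the axis size")
    if stride < 1:
        raise ValueError("stride must be positive")
    starts = list(range(0, size - window + 1, stride))
    # Assumed edge policy: add one last window flush with the border so that
    # no trailing pixels are left uncovered.
    if starts[-1] != size - window:
        starts.append(size - window)
    return starts


def window_corners(height: int, width: int, window: int, stride: int) -> list[tuple[int, int]]:
    """(top, left) corners of square windows tiling a height x width image."""
    return [
        (top, left)
        for top in window_starts(height, window, stride)
        for left in window_starts(width, window, stride)
    ]
```

With a stride smaller than the window, neighbouring windows overlap, which is what allows the outputs to be averaged across seams.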

Confidence Score: 1/5

  • This PR has critical runtime errors that will cause immediate failures.
  • Critical bugs include missing attributes (self.device, solarcorrdiffic.srx), incorrect imports, syntax errors in type annotations, and deprecated API usage. These issues will raise AttributeError and ImportError exceptions at runtime.
  • Primary attention is needed on earth2studio/models/dx/corrdiffMD.py (missing self.device), earth2studio/run.py (undefined srx attribute), and earth2studio/models/dx/__init__.py (incorrect import name).

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| earth2studio/models/dx/corrdiffMD.py | 2/5 | New CorrDiffMD and CorrDiffSolarMD model classes with MultiDiffusion support. Critical bugs: duplicate imports, hardcoded img_shape, missing self.device attribute, deprecated datetime function, incorrect type hints |
| earth2studio/utils/multidiffusion.py | 4/5 | New MultiDiffusion sampler implementing windowed diffusion with overlapping regions. Minor style issues with Optional type hints, but logic appears sound |
| earth2studio/run.py | 2/5 | New diagnostic_solar workflow function. Critical bugs: incorrect type annotation for plt_lr parameter, undefined attribute access (solarcorrdiffic.srx) |
| earth2studio/models/dx/__init__.py | 2/5 | Added imports for corrdiffMD and CorrDiffSolarMD. Issue: importing non-existent corrdiffMD class (lowercase) will cause ImportError |
| examples/19_multidiff_solar.py | 4/5 | Example script for CorrDiffSolarMD inference. Minor style issue with commented environment variable, but overall implementation is clean |

Sequence Diagram

sequenceDiagram
    participant User
    participant diagnostic_solar
    participant PrognosticModel
    participant CorrDiffSolarMD
    participant MultiDiffusion
    participant RegressionModel
    participant ResidualModel

    User->>diagnostic_solar: Run inference with time, nsteps
    diagnostic_solar->>DataSource: Fetch initial conditions
    DataSource-->>diagnostic_solar: Return x, coords
    
    loop For each timestep (nsteps)
        diagnostic_solar->>PrognosticModel: Forward pass
        PrognosticModel-->>diagnostic_solar: pro_out, pro_out_coord
        
        diagnostic_solar->>CorrDiffSolarMD: __call__(pro_out, pro_out_coord)
        
        CorrDiffSolarMD->>CorrDiffSolarMD: Interpolate to output grid
        CorrDiffSolarMD->>CorrDiffSolarMD: Compute solar zenith angle
        CorrDiffSolarMD->>CorrDiffSolarMD: Preprocess input (normalize)
        
        CorrDiffSolarMD->>CorrDiffSolarMD: get_windows(stride)
        
        loop For each window
            CorrDiffSolarMD->>RegressionModel: regression_step(window)
            RegressionModel-->>CorrDiffSolarMD: regression output
        end
        
        CorrDiffSolarMD->>CorrDiffSolarMD: Average overlapping regions
        
        alt inference_mode == "both"
            CorrDiffSolarMD->>MultiDiffusion: __call__(net, img_lr, regression_output, windows)
            
            loop For each diffusion step
                loop For each window
                    MultiDiffusion->>ResidualModel: Forward pass (denoising)
                    ResidualModel-->>MultiDiffusion: denoised output
                    MultiDiffusion->>ResidualModel: Forward pass (second order)
                    ResidualModel-->>MultiDiffusion: refined output
                end
                MultiDiffusion->>MultiDiffusion: Average overlapping regions
            end
            
            MultiDiffusion-->>CorrDiffSolarMD: residual output
            CorrDiffSolarMD->>CorrDiffSolarMD: Add regression + residual
        end
        
        CorrDiffSolarMD->>CorrDiffSolarMD: Postprocess output (denormalize)
        CorrDiffSolarMD-->>diagnostic_solar: Solar radiation output
        
        diagnostic_solar->>IOBackend: Write output
    end
    
    diagnostic_solar-->>User: Return IOBackend
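
The "Average overlapping regions" steps in the diagram can be sketched as an accumulate-and-divide pass. This is a plain-Python illustration under assumed names (blend_windows, predict), not the PR's tensor implementation. It assumes square windows and that every pixel is covered by at least one window.

```python
from typing import Callable

Patch = list[list[float]]


def blend_windows(
    height: int,
    width: int,
    window: int,
    starts_y: list[int],
    starts_x: list[int],
    predict: Callable[[int, int], Patch],
) -> Patch:
    """Average per-window outputs into one height x width field.

    `predict(top, left)` is a hypothetical stand-in for a model call that
    returns a window x window patch for the window anchored at (top, left).
    """
    total = [[0.0] * width for _ in range(height)]
    counts = [[0] * width for _ in range(height)]
    for top in starts_y:
        for left in starts_x:
            patch = predict(top, left)
            for i in range(window):
                for j in range(window):
                    total[top + i][left + j] += patch[i][j]
                    counts[top + i][left + j] += 1
    # Each pixel is divided by the number of windows that contributed to it.
    return [
        [total[i][j] / counts[i][j] for j in range(width)]
        for i in range(height)
    ]
```

In a tensor implementation the same idea is typically expressed with two accumulators of the full output shape, one for the sums and one for the counts.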

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 14 comments


import zipfile
from collections import OrderedDict
from collections.abc import Callable,Sequence

syntax: Missing space after comma in import

Suggested change
- from collections.abc import Callable,Sequence
+ from collections.abc import Callable, Sequence

Comment on lines +51 to +52
from earth2studio.utils.type import CoordSystem
from earth2studio.utils.coords import CoordSystem, map_coords

syntax: Duplicate import of CoordSystem - already imported from earth2studio.utils.type on line 51

Suggested change
- from earth2studio.utils.type import CoordSystem
- from earth2studio.utils.coords import CoordSystem, map_coords
+ from earth2studio.utils.coords import map_coords

OptionalDependencyFailure,
check_optional_dependencies,
)
from earth2studio.utils.type import CoordSystem

syntax: Duplicate import of CoordSystem from earth2studio.utils.type - already imported on line 51

Suggested change
- from earth2studio.utils.type import CoordSystem

Comment on lines +185 to +186
self.img_shape = (lat_output_grid.shape[0], lon_output_grid.shape[0])
self.img_shape = (320,320)

logic: Hardcoded img_shape overrides computed value - line 185 calculates from grid shape but line 186 hardcodes to (320,320), making the computation on line 185 pointless

Suggested change
- self.img_shape = (lat_output_grid.shape[0], lon_output_grid.shape[0])
- self.img_shape = (320,320)
+ self.img_shape = (320, 320)

Comment on lines +628 to +629
image_reg_full = torch.zeros((1,len(self.output_variables),len(self.lat_output_numpy),len(self.lon_output_numpy))).to(self.device)
counts = torch.zeros_like(image_reg_full).to(self.device)

logic: self.device is not defined - these lines access self.device without it ever being set, which will raise AttributeError. The device should be obtained from the module's buffers or passed as a parameter
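
A minimal sketch of the buffer-based option, assuming the model is a torch.nn.Module. The class name, the device_buffer name, and the allocate_accumulators helper are hypothetical and not taken from the PR; only register_buffer, torch.zeros, and torch.zeros_like are standard PyTorch calls.

```python
import torch


class WindowedModel(torch.nn.Module):
    """Hypothetical module that exposes its device through a registered buffer."""

    def __init__(self) -> None:
        super().__init__()
        # An empty buffer follows the module when .to(device) is called.
        self.register_buffer("device_buffer", torch.empty(0))

    @property
    def device(self) -> torch.device:
        return self.device_buffer.device

    def allocate_accumulators(
        self, channels: int, height: int, width: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Allocate directly on the module's device instead of calling .to() afterwards.
        total = torch.zeros((1, channels, height, width), device=self.device)
        counts = torch.zeros_like(total)  # inherits dtype and device from total
        return total, counts
```

Because the property reads the buffer, moving the module with .to("cuda") updates self.device without any extra bookkeeping.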

variable=prognositc_ic["variable"],
lead_time=prognositc_ic["lead_time"],
device=device,
)

logic: Undefined attribute srx - solarcorrdiffic.srx is accessed but never set in the CorrDiffSolarMD class (will cause AttributeError)

regression_output: Tensor,
class_labels: Optional[Tensor] = None,
randn_like: Callable[[Tensor], Tensor] = torch.randn_like,
windows: Optional[Tensor] = None,

style: Type annotation uses the older Optional syntax - Optional[Tensor] requires from typing import Optional, whereas Tensor | None is the Python 3.10+ union syntax already used elsewhere in the codebase

Suggested change
- windows: Optional[Tensor] = None,
+ windows: Tensor | None = None,

class_labels: Optional[Tensor] = None,
randn_like: Callable[[Tensor], Tensor] = torch.randn_like,
windows: Optional[Tensor] = None,
lead_time_label: Optional[Tensor] = None,

style: Type annotation uses the older Optional syntax - use Tensor | None instead of Optional[Tensor]

Suggested change
- lead_time_label: Optional[Tensor] = None,
+ lead_time_label: Tensor | None = None,

net: torch.nn.Module,
img_lr: Tensor,
regression_output: Tensor,
class_labels: Optional[Tensor] = None,

style: Type annotation uses the older Optional syntax - use Tensor | None instead of Optional[Tensor]

Suggested change
- class_labels: Optional[Tensor] = None,
+ class_labels: Tensor | None = None,


import os
os.environ["EARTH2STUDIO_PACKAGE_TIMEOUT"] = "10000"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"

style: Commented-out environment variable could cause confusion - either remove it or document why it is commented out

Suggested change
- #os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "1" # Uncomment to specify GPU device

@NickGeneva

Thanks for opening this @sunjingan!
