feat: Integrate CorrDiffSolar model with MultiDiffusion support#545
feat: Integrate CorrDiffSolar model with MultiDiffusion support#545sunjingan wants to merge 1 commit intoNVIDIA:mainfrom
Conversation
Greptile OverviewGreptile SummaryThis PR integrates the CorrDiffSolar model with MultiDiffusion support for high-resolution solar radiation downscaling. Key additions:
Critical issues found:
Architecture concerns:
Confidence Score: 1/5
Important Files ChangedFile Analysis
Sequence DiagramsequenceDiagram
participant User
participant diagnostic_solar
participant PrognosticModel
participant CorrDiffSolarMD
participant MultiDiffusion
participant RegressionModel
participant ResidualModel
User->>diagnostic_solar: Run inference with time, nsteps
diagnostic_solar->>DataSource: Fetch initial conditions
DataSource-->>diagnostic_solar: Return x, coords
loop For each timestep (nsteps)
diagnostic_solar->>PrognosticModel: Forward pass
PrognosticModel-->>diagnostic_solar: pro_out, pro_out_coord
diagnostic_solar->>CorrDiffSolarMD: __call__(pro_out, pro_out_coord)
CorrDiffSolarMD->>CorrDiffSolarMD: Interpolate to output grid
CorrDiffSolarMD->>CorrDiffSolarMD: Compute solar zenith angle
CorrDiffSolarMD->>CorrDiffSolarMD: Preprocess input (normalize)
CorrDiffSolarMD->>CorrDiffSolarMD: get_windows(stride)
loop For each window
CorrDiffSolarMD->>RegressionModel: regression_step(window)
RegressionModel-->>CorrDiffSolarMD: regression output
end
CorrDiffSolarMD->>CorrDiffSolarMD: Average overlapping regions
alt inference_mode == "both"
CorrDiffSolarMD->>MultiDiffusion: __call__(net, img_lr, regression_output, windows)
loop For each diffusion step
loop For each window
MultiDiffusion->>ResidualModel: Forward pass (denoising)
ResidualModel-->>MultiDiffusion: denoised output
MultiDiffusion->>ResidualModel: Forward pass (second order)
ResidualModel-->>MultiDiffusion: refined output
end
MultiDiffusion->>MultiDiffusion: Average overlapping regions
end
MultiDiffusion-->>CorrDiffSolarMD: residual output
CorrDiffSolarMD->>CorrDiffSolarMD: Add regression + residual
end
CorrDiffSolarMD->>CorrDiffSolarMD: Postprocess output (denormalize)
CorrDiffSolarMD-->>diagnostic_solar: Solar radiation output
diagnostic_solar->>IOBackend: Write output
end
diagnostic_solar-->>User: Return IOBackend
|
|
|
||
| import zipfile | ||
| from collections import OrderedDict | ||
| from collections.abc import Callable,Sequence |
There was a problem hiding this comment.
syntax: Missing space after comma in import
| from collections.abc import Callable,Sequence | |
| from collections.abc import Callable, Sequence |
| from earth2studio.utils.type import CoordSystem | ||
| from earth2studio.utils.coords import CoordSystem, map_coords |
There was a problem hiding this comment.
syntax: Duplicate import of CoordSystem - already imported from earth2studio.utils.type on line 51
| from earth2studio.utils.type import CoordSystem | |
| from earth2studio.utils.coords import CoordSystem, map_coords | |
| from earth2studio.utils.coords import map_coords |
| OptionalDependencyFailure, | ||
| check_optional_dependencies, | ||
| ) | ||
| from earth2studio.utils.type import CoordSystem |
There was a problem hiding this comment.
syntax: Duplicate import of CoordSystem from earth2studio.utils.type - already imported on line 51
| from earth2studio.utils.type import CoordSystem |
| self.img_shape = (lat_output_grid.shape[0], lon_output_grid.shape[0]) | ||
| self.img_shape = (320,320) |
There was a problem hiding this comment.
logic: Hardcoded img_shape overrides computed value - line 185 calculates from grid shape but line 186 hardcodes to (320,320), making the computation on line 185 pointless
| self.img_shape = (lat_output_grid.shape[0], lon_output_grid.shape[0]) | |
| self.img_shape = (320,320) | |
| self.img_shape = (320, 320) |
| image_reg_full = torch.zeros((1,len(self.output_variables),len(self.lat_output_numpy),len(self.lon_output_numpy))).to(self.device) | ||
| counts = torch.zeros_like(image_reg_full).to(self.device) |
There was a problem hiding this comment.
logic: self.device not defined - accessing self.device without setting it (will cause AttributeError). The device should be obtained from buffers or passed as parameter
| variable=prognositc_ic["variable"], | ||
| lead_time=prognositc_ic["lead_time"], | ||
| device=device, | ||
| ) |
There was a problem hiding this comment.
logic: Undefined variable srx - accessing solarcorrdiffic.srx which is never set in CorrDiffSolarMD class (will cause AttributeError)
| regression_output: Tensor, | ||
| class_labels: Optional[Tensor] = None, | ||
| randn_like: Callable[[Tensor], Tensor] = torch.randn_like, | ||
| windows: Optional[Tensor] = None, |
There was a problem hiding this comment.
style: Type annotation uses deprecated syntax - Optional[Tensor] requires from typing import Optional, but Tensor | None is the modern Python 3.10+ syntax already used elsewhere in the codebase
| windows: Optional[Tensor] = None, | |
| windows: Tensor | None = None, |
| class_labels: Optional[Tensor] = None, | ||
| randn_like: Callable[[Tensor], Tensor] = torch.randn_like, | ||
| windows: Optional[Tensor] = None, | ||
| lead_time_label: Optional[Tensor] = None, |
There was a problem hiding this comment.
style: Type annotation uses deprecated syntax - use Tensor | None instead of Optional[Tensor]
| lead_time_label: Optional[Tensor] = None, | |
| lead_time_label: Tensor | None = None, |
| net: torch.nn.Module, | ||
| img_lr: Tensor, | ||
| regression_output: Tensor, | ||
| class_labels: Optional[Tensor] = None, |
There was a problem hiding this comment.
style: Type annotation uses deprecated syntax - use Tensor | None instead of Optional[Tensor]
| class_labels: Optional[Tensor] = None, | |
| class_labels: Tensor | None = None, |
|
|
||
| import os | ||
| os.environ["EARTH2STUDIO_PACKAGE_TIMEOUT"] = "10000" | ||
| #os.environ["CUDA_VISIBLE_DEVICES"] = "1" |
There was a problem hiding this comment.
style: Commented out environment variable could cause confusion - either remove it or document why it's commented
| #os.environ["CUDA_VISIBLE_DEVICES"] = "1" | |
| # os.environ["CUDA_VISIBLE_DEVICES"] = "1" # Uncomment to specify GPU device |
|
Thanks for opening this @sunjingan ! |
Earth2Studio Pull Request
Hi @Charlelie,
As discussed over email, this PR integrates the retrained CorrDiffSolar model into Earth2Studio.
Key Changes
CorrDiffSolarMDModel (earth2studio/models/dx/corrdiffMD.py).MultiDiffusionSampler (earth2studio/utils/multidiffusion.py).diagnostic_solar()logic intoearth2studio/run.py.Usage Example
examples/19_multidiff_solar.py.load_default_package()to download it. (The package is in the docker image as I described in e-mail).