Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e1166c8
wip: add remapping subroutine
TomMelt Jun 16, 2025
fddd7c9
chore: remove unused code
TomMelt Jun 27, 2025
d2b9ad6
wip: regridding debug
TomMelt Jul 30, 2025
1c2f05c
wip: use 192x288 grid instead
TomMelt Aug 20, 2025
aeb59af
wip: try to fix corner issue in regrid
TomMelt Aug 20, 2025
1bf238c
chore: remove unnecessary vars
TomMelt Sep 2, 2025
4885935
wip: gathering ps onto masterproc
TomMelt Sep 2, 2025
4cea511
wip: 2d field gather works, need to do 3d now
TomMelt Sep 3, 2025
243fad3
untested code for lonlat2phys regrid
fvitt Sep 3, 2025
17d7e5e
wip: gather 3d fields working now
TomMelt Sep 4, 2025
7ec6c6a
feat: make grid arrays available to module
TomMelt Sep 10, 2025
953d8a7
Revert "wip: try to fix corner issue in regrid"
TomMelt Sep 10, 2025
5cd9bf1
feat: change polemethod back to ESMF_POLEMETHOD_ALLAVG
TomMelt Sep 10, 2025
e169f0f
feat: regridding works both ways!
TomMelt Sep 11, 2025
3acd066
chore: restructure gw_nlgw ready for gw_nlgw_unet
TomMelt Sep 12, 2025
9c6be08
chore: tidy up remap.F90
TomMelt Sep 12, 2025
78a697e
feat: (wip) add UNet model to CAM
TomMelt Oct 29, 2025
e192d2f
feat: compute tendencies and tidy up
TomMelt Nov 3, 2025
85109cc
wip: output utgw and vtgw
TomMelt Nov 3, 2025
ecc9d5b
Fix output to initialise tends earlier
ma595 Nov 3, 2025
2d7d0bc
feat: update tendencies using Wills approach
TomMelt Nov 12, 2025
3803151
chore: change name to nlgw_unet when init ptend
TomMelt Nov 13, 2025
426c1c1
chore: remove hardcoded path for unet and add paths to module
TomMelt Nov 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,26 @@ gw_convect_dp_ml_norms='/path/to/norms'

To run CAM using the non local gravity wave ML model to replace all parameterisations use the following configuration
```fortran
use_gw_nlgw=.true.
gw_nlgw_model_path='/path/to/nlgw-scripted-model.pt'
use_gw_nlgw_ann=.true.
use_gw_nlgw_unet=.true.
gw_nlgw_model_path_ann='/path/to/ann-scripted-model.pt'
gw_nlgw_model_path_unet='/path/to/unet-scripted-model.pt'
```

* `use_gw_nlgw` (`logical`)
* `use_gw_nlgw_ann` (`logical`)

Whether or not to use the ML scheme for non local gravity waves. Default: `.false.`
Whether or not to use the ANN ML scheme for non local gravity waves. Default: `.false.`

* `gw_nlgw_model_path`
* `gw_nlgw_model_path_ann`

Absolute filepath to the non local gravity wave neural net used when `use_gw_nlgw` is set to `.true.` (`.pt`
extension).

* `use_gw_nlgw_unet` (`logical`)

Whether or not to use the UNET ML scheme for non local gravity waves. Default: `.false.`

* `gw_nlgw_model_path_unet`

Absolute filepath to the non local gravity wave neural net used when `use_gw_nlgw` is set to `.true.` (`.pt`
extension).
Expand Down
3 changes: 2 additions & 1 deletion bld/build-namelist
Original file line number Diff line number Diff line change
Expand Up @@ -3606,7 +3606,8 @@ if (!$simple_phys) {
add_default($nl, 'use_gw_rdg_gamma' , 'val'=>'.false.');
add_default($nl, 'use_gw_front_igw' , 'val'=>'.false.');
add_default($nl, 'use_gw_convect_sh', 'val'=>'.false.');
add_default($nl, 'use_gw_nlgw' , 'val'=>'.false.');
add_default($nl, 'use_gw_nlgw_ann' , 'val'=>'.false.');
add_default($nl, 'use_gw_nlgw_unet' , 'val'=>'.false.');
add_default($nl, 'gw_lndscl_sgh');
add_default($nl, 'gw_oro_south_fac');
add_default($nl, 'gw_limit_tau_without_eff');
Expand Down
21 changes: 17 additions & 4 deletions bld/namelist_files/namelist_definition.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1332,10 +1332,17 @@ Whether or not to enable gravity waves produced by shallow convection.
Default: .false.
</entry>

<entry id="use_gw_nlgw" type="logical" category="gw_drag"
<entry id="use_gw_nlgw_ann" type="logical" category="gw_drag"
group="phys_ctl_nl" valid_values="" >
Whether or not to enable gravity waves produced by non-local gravity
wave ML model.
wave ANN ML model.
Default: set by build-namelist.
</entry>

<entry id="use_gw_nlgw_unet" type="logical" category="gw_drag"
group="phys_ctl_nl" valid_values="" >
Whether or not to enable gravity waves produced by non-local gravity
wave UNet ML model.
Default: set by build-namelist.
</entry>

Expand Down Expand Up @@ -1428,10 +1435,16 @@ Absolute filepath to the deep convection gravity wave neural net used when
Default: .false.
</entry>

<entry id="gw_nlgw_model_path" type="char*132" input_pathname="abs" category="gw_drag"
<entry id="gw_nlgw_model_path_ann" type="char*132" input_pathname="abs" category="gw_drag"
group="gw_drag_nl" valid_values="" >
Absolute filepath to the non local gravity wave traced model (.pt)
used when `use_gw_nlgw_ann` is set to `.true.`.
</entry>

<entry id="gw_nlgw_model_path_unet" type="char*132" input_pathname="abs" category="gw_drag"
group="gw_drag_nl" valid_values="" >
Absolute filepath to the non local gravity wave traced model (.pt)
used when `use_gw_nlgw` is set to `.true.`.
used when `use_gw_nlgw_unet` is set to `.true.`.
</entry>

<entry id="effgw_cm" type="real" category="gw_drag"
Expand Down
29 changes: 17 additions & 12 deletions src/physics/cam/gw_drag.F90
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ module gw_drag
! These are the actual switches for different gravity wave sources.
use phys_control, only: use_gw_oro, use_gw_front, use_gw_front_igw, &
use_gw_convect_dp, use_gw_convect_sh, &
use_simple_phys, use_gw_nlgw
use_simple_phys, &
use_gw_nlgw_ann, use_gw_nlgw_unet

use gw_common, only: GWBand
use gw_convect, only: BeresSourceDesc
use gw_front, only: CMSourceDesc
use gw_ml, only: gw_drag_convect_dp_ml_init, gw_drag_convect_dp_ml_final, &
gw_drag_convect_dp_ml
use gw_nlgw, only: gw_nlgw_dp_ml, gw_nlgw_dp_init, gw_nlgw_dp_finalize
use gw_nlgw_ann, only: gw_nlgw_ann_infer, gw_nlgw_ann_init, gw_nlgw_ann_finalize

! Typical module header
implicit none
Expand Down Expand Up @@ -203,7 +204,6 @@ module gw_drag
logical :: gw_convect_dp_ml_compare = .false.
character(len=132) :: gw_convect_dp_ml_net_path
character(len=132) :: gw_convect_dp_ml_norms
character(len=132) :: gw_nlgw_model_path

!==========================================================================
contains
Expand All @@ -216,6 +216,7 @@ subroutine gw_drag_readnl(nlfile)
use spmd_utils, only: mpicom, mstrid=>masterprocid, mpi_real8, &
mpi_character, mpi_logical, mpi_integer
use gw_rdg, only: gw_rdg_readnl
use gw_nlgw_utils, only: gw_nlgw_model_path_ann, gw_nlgw_model_path_unet

! File containing namelist input.
character(len=*), intent(in) :: nlfile
Expand Down Expand Up @@ -248,7 +249,7 @@ subroutine gw_drag_readnl(nlfile)
gw_top_taper, front_gaussian_width, &
gw_convect_dp_ml, gw_convect_dp_ml_compare, &
gw_convect_dp_ml_net_path, gw_convect_dp_ml_norms, &
gw_nlgw_model_path
gw_nlgw_model_path_ann, gw_nlgw_model_path_unet
!----------------------------------------------------------------------

if (use_simple_phys) return
Expand Down Expand Up @@ -364,8 +365,11 @@ subroutine gw_drag_readnl(nlfile)
call mpi_bcast(gw_convect_dp_ml_norms, len(gw_convect_dp_ml_norms), mpi_character, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_convect_dp_ml_norms")

call mpi_bcast(gw_nlgw_model_path, len(gw_nlgw_model_path), mpi_character, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path")
call mpi_bcast(gw_nlgw_model_path_ann, len(gw_nlgw_model_path_ann), mpi_character, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path_ann")

call mpi_bcast(gw_nlgw_model_path_unet, len(gw_nlgw_model_path_unet), mpi_character, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path_unet")

! Check if fcrit2 was set.
call shr_assert(fcrit2 /= unset_r8, &
Expand Down Expand Up @@ -425,6 +429,7 @@ subroutine gw_init()

use gw_common, only: gw_common_init
use gw_front, only: gaussian_cm_desc
use gw_nlgw_utils, only: gw_nlgw_model_path_ann

!---------------------------Local storage-------------------------------

Expand Down Expand Up @@ -577,8 +582,8 @@ subroutine gw_init()
call shr_assert(trim(errstring) == "", "gw_common_init: "//errstring// &
errMsg(__FILE__, __LINE__))

if ( use_gw_nlgw ) then
call gw_nlgw_dp_init(gw_nlgw_model_path)
if ( use_gw_nlgw_ann ) then
call gw_nlgw_ann_init(gw_nlgw_model_path_ann)
end if

if ( use_gw_oro ) then
Expand Down Expand Up @@ -1298,8 +1303,8 @@ subroutine gw_final()
if ((gw_convect_dp_ml) .or. (gw_convect_dp_ml_compare)) then
call gw_drag_convect_dp_ml_final()
endif
if ( use_gw_nlgw ) then
call gw_nlgw_dp_finalize()
if ( use_gw_nlgw_ann ) then
call gw_nlgw_ann_finalize()
end if
end subroutine gw_final

Expand Down Expand Up @@ -1548,8 +1553,8 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
egwdffi_tot = 0._r8
flx_heat = 0._r8

if ( use_gw_nlgw ) then
call gw_nlgw_dp_ml(state1,ptend,lchnk)
if ( use_gw_nlgw_ann ) then
call gw_nlgw_ann_infer(state1,ptend,lchnk)
end if

if (use_gw_convect_dp) then
Expand Down
55 changes: 15 additions & 40 deletions src/physics/cam/gw_nlgw.F90 → src/physics/cam/gw_nlgw_ann.F90
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module gw_nlgw
module gw_nlgw_ann

!
! This module predicts gravity wave forcings via PyTorch NNs trained to include non-local gravity wave effects
Expand All @@ -11,19 +11,18 @@ module gw_nlgw
use cam_abortutils, only: endrun
use cam_logfile, only: iulog
use physconst, only: cappa, pi
use gw_nlgw_utils, only: p0
use interpolate_data, only: lininterp
use cam_history, only: outfld, addfld

use ftorch

implicit none

public :: gw_nlgw_dp_ml, gw_nlgw_dp_init, gw_nlgw_dp_finalize
public :: gw_nlgw_ann_infer, gw_nlgw_ann_init, gw_nlgw_ann_finalize

private

integer, parameter :: p0 = 100000 ! 1000 hPa (Pa)

type(torch_model) :: nlgw_model ! pytorch model

integer :: ncol ! number of vertical columns
Expand Down Expand Up @@ -105,7 +104,9 @@ module gw_nlgw

!==========================================================================

subroutine gw_nlgw_dp_ml(state_in, ptend, lchnk)
subroutine gw_nlgw_ann_infer(state_in, ptend, lchnk)

use gw_nlgw_utils, only: flux_to_forcing

! inputs
type(physics_state), intent(in) :: state_in
Expand Down Expand Up @@ -172,8 +173,8 @@ subroutine gw_nlgw_dp_ml(state_in, ptend, lchnk)
call extract_output()
call denormalise_data()

call flux_to_forcing(uflux, utgw)
call flux_to_forcing(vflux, vtgw)
call flux_to_forcing(uflux, utgw, pmid, ncol)
call flux_to_forcing(vflux, vtgw, pmid, ncol)

! Write UTGW and VTGW to file
call outfld('UTGW_NL', utgw, ncol, lchnk)
Expand Down Expand Up @@ -209,10 +210,10 @@ subroutine gw_nlgw_dp_ml(state_in, ptend, lchnk)
deallocate(net_inputs)
deallocate(net_outputs)

end subroutine gw_nlgw_dp_ml
end subroutine gw_nlgw_ann_infer


subroutine gw_nlgw_dp_init(model_path)
subroutine gw_nlgw_ann_init(model_path)

character(len=*), intent(in) :: model_path ! Filepath to PyTorch Torchscript net
integer :: device_id
Expand All @@ -233,17 +234,17 @@ subroutine gw_nlgw_dp_init(model_path)
call addfld('UFLUX_NL', (/ 'lev' /), 'A', 'm/s', 'Nonlinear GW zonal wind flux')
call addfld('VFLUX_NL', (/ 'lev' /), 'A', 'm/s', 'Nonlinear GW meridional wind flux')

end subroutine gw_nlgw_dp_init
end subroutine gw_nlgw_ann_init


subroutine gw_nlgw_dp_finalize()
subroutine gw_nlgw_ann_finalize()

deallocate(net_inputs)
deallocate(net_outputs)
! free model memory
call torch_delete(nlgw_model)

end subroutine gw_nlgw_dp_finalize
end subroutine gw_nlgw_ann_finalize


subroutine read_norms()
Expand Down Expand Up @@ -273,6 +274,7 @@ subroutine read_norms()
end subroutine read_norms

subroutine normalise_data()
use gw_nlgw_utils, only: cbrt

! lat lon are in radians (convert to degrees first)
lat = lat * 180. / pi
Expand Down Expand Up @@ -377,31 +379,4 @@ subroutine denormalise_data()

end subroutine denormalise_data

elemental function cbrt(a) result(root)
real(r8), intent(in) :: a
real(r8), parameter :: one_third = 1._r8/3._r8
real(r8) :: root
root = sign(abs(a)**one_third, a)
end function cbrt

subroutine flux_to_forcing(flux, forcing)

real(r8), intent(in), dimension(:,:) :: flux
real(r8), intent(out), dimension(:,:) :: forcing ! forcing = -d(u'\omega')/d(p), units = m/s^2

integer :: level, col

! convert fluxes to tendencies
! pressure profile must be in Pascals

do col = 1, ncol
forcing(col,1) = -1*(flux(col,2) - flux(col,1))/(pmid(col,2) - pmid(col,1))
do level = 2, pver-1
forcing(col,level) = -1*(flux(col,level+1) - flux(col,level-1)) / (pmid(col,level)*(log(pmid(col,level+1)) - log(pmid(col,level-1))))
end do
forcing(col,pver) = -1*(flux(col,pver) - flux(col,pver-1)) / (pmid(col,pver) - pmid(col,pver-1))
end do

end subroutine flux_to_forcing

end module gw_nlgw
end module gw_nlgw_ann
Loading