@@ -691,8 +691,12 @@ module function get_params(self) result(params)
691691 params = this_layer % get_params()
692692 type is (embedding_layer)
693693 params = this_layer % get_params()
694+
694695 type is (layernorm_layer)
695- params = this_layer % get_params()
696+ call this_layer % get_params_ptr(w_ptr, b_ptr)
697+ allocate (params(size (w_ptr) + size (b_ptr)))
698+ params(1 :size (w_ptr)) = w_ptr
699+ params(size (w_ptr)+ 1 :) = b_ptr
696700 class default
697701 error stop ' Unknown layer type.'
698702 end select
@@ -703,6 +707,8 @@ end function get_params
703707 module subroutine set_params (self , params )
704708 class(layer), intent (in out ) :: self
705709 real , intent (in ) :: params(:)
710+ real , pointer :: w_ptr(:)
711+ real , pointer :: b_ptr(:)
706712
707713 ! Check that the number of parameters is correct.
708714 ! This check will still pass if the size(params) == 0 and the layer is a
@@ -736,37 +742,55 @@ module subroutine set_params(self, params)
736742 // ' on a zero-parameter layer; nothing to do.'
737743
738744 type is (dense_layer)
739- call this_layer % set_params(params)
745+ call this_layer % get_params_ptr(w_ptr, b_ptr)
746+
747+ w_ptr = params(1 :size (w_ptr))
748+ b_ptr = params(size (w_ptr)+ 1 :)
740749
741750 type is (dropout_layer)
742751 ! No parameters to set.
743752 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
744753 // ' on a zero-parameter layer; nothing to do.'
745754
746755 type is (conv1d_layer)
747- call this_layer % set_params(params)
756+ call this_layer % get_params_ptr(w_ptr, b_ptr)
757+
758+ w_ptr = params(1 :size (w_ptr))
759+ b_ptr = params(size (w_ptr)+ 1 :)
748760
749761 type is (conv2d_layer)
750- call this_layer % set_params(params)
762+ call this_layer % get_params_ptr(w_ptr, b_ptr)
763+
764+ w_ptr = params(1 :size (w_ptr))
765+ b_ptr = params(size (w_ptr)+ 1 :)
751766
752767 type is (locally_connected2d_layer)
753- call this_layer % set_params(params)
768+ call this_layer % get_params_ptr(w_ptr, b_ptr)
769+
770+ w_ptr = params(1 :size (w_ptr))
771+ b_ptr = params(size (w_ptr)+ 1 :)
754772
755773 type is (maxpool1d_layer)
756774 ! No parameters to set.
757775 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
758776 // ' on a zero-parameter layer; nothing to do.'
759777
760778 type is (linear2d_layer)
761- call this_layer % set_params(params)
779+ call this_layer % get_params_ptr(w_ptr, b_ptr)
780+
781+ w_ptr = params(1 :size (w_ptr))
782+ b_ptr = params(size (w_ptr)+ 1 :)
762783
763784 type is (self_attention_layer)
764785 call this_layer % set_params(params)
765786 type is (embedding_layer)
766787 call this_layer % set_params(params)
767788
768789 type is (layernorm_layer)
769- call this_layer % set_params(params)
790+ call this_layer % get_params_ptr(w_ptr, b_ptr)
791+
792+ w_ptr = params(1 :size (w_ptr))
793+ b_ptr = params(size (w_ptr)+ 1 :)
770794
771795 type is (maxpool2d_layer)
772796 ! No parameters to set.
0 commit comments