@@ -73,7 +73,7 @@ def json_to_labeled_instances(self, inputs: JsonDict) -> List[Instance]:
 
     def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """
-        Gets the gradients of the loss with respect to the model inputs.
+        Gets the gradients of the logits with respect to the model inputs.
 
         # Parameters
 
@@ -91,7 +91,7 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict
         Takes a `JsonDict` representing the inputs of the model and converts
         them to [`Instances`](../data/instance.md), sends these through
         the model [`forward`](../models/model.md#forward) function after registering hooks on the embedding
-        layer of the model. Calls `backward` on the loss and then removes the
+        layer of the model. Calls `backward` on the logits and then removes the
         hooks.
         """
         # set requires_grad to true for all parameters, but save original values to
@@ -113,13 +113,13 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict
                 self._model.forward(**dataset_tensor_dict)  # type: ignore
             )
 
-            loss = outputs["loss"]
+            predicted_logit = outputs["logits"].squeeze(0)[int(torch.argmax(outputs["probs"]))]
             # Zero gradients.
             # NOTE: this is actually more efficient than calling `self._model.zero_grad()`
             # because it avoids a read op when the gradients are first updated below.
             for p in self._model.parameters():
                 p.grad = None
-            loss.backward()
+            predicted_logit.backward()
 
         for hook in hooks:
             hook.remove()
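
To make the new behavior concrete, here is a minimal, self-contained sketch of the same technique: register a hook on the embedding layer, run `forward`, and call `backward` on the logit of the predicted class instead of the loss. This is not AllenNLP's actual `Predictor` code; `ToyClassifier`, its shapes, and the hook wiring are illustrative assumptions.

```python
import torch
import torch.nn as nn


class ToyClassifier(nn.Module):
    """Hypothetical stand-in for a model whose outputs include logits and probs."""

    def __init__(self, vocab_size: int = 10, dim: int = 8, num_classes: int = 3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.linear = nn.Linear(dim, num_classes)

    def forward(self, tokens: torch.Tensor) -> dict:
        embedded = self.embedding(tokens)           # (batch, seq_len, dim)
        logits = self.linear(embedded.mean(dim=1))  # (batch, num_classes)
        return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}


model = ToyClassifier()
gradients = []


# Capture the gradient flowing back into the embedding layer's output: a forward
# hook attaches a tensor hook to `output`, which fires during the backward pass.
def forward_hook(module, inputs, output):
    output.register_hook(gradients.append)


hook = model.embedding.register_forward_hook(forward_hook)

outputs = model(torch.tensor([[1, 4, 2]]))

# The change in this commit: differentiate the logit of the *predicted* class,
# so no gold label (and hence no loss) is needed.
predicted_logit = outputs["logits"].squeeze(0)[int(torch.argmax(outputs["probs"]))]

# Clearing `.grad` by assigning None, as in the diff above, skips the read of
# the old gradient buffers that `zero_grad()` would perform.
for p in model.parameters():
    p.grad = None
predicted_logit.backward()

hook.remove()
print(gradients[0].shape)  # torch.Size([1, 3, 8]): one gradient per input token
```

Backpropagating from the predicted logit rather than the loss means the captured gradients describe what drove the model's own prediction, which is typically what gradient-based saliency methods need.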