This repository was archived by the owner on May 3, 2022. It is now read-only.
26 changes: 22 additions & 4 deletions rinokeras/core/v1x/train/RinokerasGraph.py
@@ -2,6 +2,7 @@
 from typing import Sequence, Union, Any, Optional, Dict
 import pickle as pkl

+import h5py
 import tensorflow as tf
 from tensorflow.python.client import timeline
 from tqdm import tqdm
@@ -159,19 +160,36 @@ def run_epoch(self,
                   data_len: Optional[int] = None,
                   epoch_num: Optional[int] = None,
                   summary_writer: Optional[tf.summary.FileWriter] = None,
-                  save_outputs: Optional[str] = None) -> MetricsAccumulator:
-        all_outputs = []
+                  save_outputs: Optional[str] = None,
+                  save_format: Optional[str] = 'pkl') -> MetricsAccumulator:
+
+        if save_format == 'pkl':
+            all_outputs = []
+        elif save_format == 'h5':
+            h5_outf = h5py.File(save_outputs, 'w')
+            i = 0
+        else:
+            raise Exception('Unsupported save format: {}'.format(save_format))

         with self.add_progress_bar(data_len, epoch_num).initialize():
             assert self.epoch_metrics is not None
             while True:
                 if save_outputs is not None:
                     loss, outputs = self.run('default', return_outputs=True)
-                    all_outputs.append(outputs)
+                    if save_format == 'pkl':
+                        all_outputs.append(outputs)
+                    elif save_format == 'h5':
+                        grp = h5_outf.create_group(str(i))
+                        outputs = outputs[0]  # can we rely on this being a tuple of length 1?
Author

@DavidMChan I'm not sure about this, but can I rely on the run output always being a tuple of length 1?

Member

Effectively, the outputs are the forward pass of your model. This means that "outputs" can be whatever you want it to be (a numpy array, a dict of numpy arrays, a tuple of arrays, etc.). This is probably why @rmrao hijacked it for use in TAPE. It also makes it tricky to write a generic saving function for the outputs, since you have no guarantees on the data format. You do know, however, that the outputs will be the result of a forward run of the model (so they are convertible to tensors).
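
To make the "no guarantees on the data format" point concrete, here is a minimal sketch of one way to normalize arbitrarily nested outputs before saving. The helper name `flatten_outputs` and the `"output"` placeholder key are hypothetical and not part of rinokeras. It assumes dicts, tuples, and lists are containers and everything else (arrays, scalars) is a leaf. The resulting keys are shaped like paths, which should map naturally onto HDF5 dataset names, though that is not checked here.

```python
from typing import Any, Dict


def flatten_outputs(outputs: Any, prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts/tuples/lists into {"path/to/leaf": leaf}."""
    if isinstance(outputs, dict):
        items = [(str(k), v) for k, v in outputs.items()]
    elif isinstance(outputs, (list, tuple)):
        items = [(str(i), v) for i, v in enumerate(outputs)]
    else:
        # Leaf value (e.g. an array or scalar). An empty prefix means the
        # whole output was a single leaf, so use a placeholder name.
        return {prefix or "output": outputs}

    flat: Dict[str, Any] = {}
    for name, value in items:
        path = "{}/{}".format(prefix, name) if prefix else name
        flat.update(flatten_outputs(value, path))
    return flat


# The case from the diff: a tuple of length 1 wrapping a dict.
flat = flatten_outputs(({"logits": 0.9, "loss": 0.5},))
# flat == {"0/logits": 0.9, "0/loss": 0.5}
```

With something like this, the saving code would not need to assume `outputs[0]` exists or that it is a dict; a bare array would simply come back under the single placeholder key.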

Perhaps it makes sense to add a callback function instead of adding options? Not entirely sure, but this is why we used pickle.

Contributor

I like the idea of a callback function. The save_outputs option is really more about quick-and-dirty debugging than it is a real feature at the moment.

Contributor

Otherwise there isn't really a good, general way of saving things. It'll vary hugely. Plus a callback would let us do things other than saving them.
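
A rough sketch of what the callback idea from this thread could look like. Everything here is hypothetical: `run_epoch_sketch` is a stand-in for the real `run_epoch`, the `batches` iterable plays the role of repeated `self.run('default', return_outputs=True)` calls, and `output_callback` and `PickleCollector` are illustrative names, not existing rinokeras APIs.

```python
import pickle
from typing import Any, Callable, Iterable, List, Optional

# A callback receives the step index and whatever the forward pass returned.
OutputCallback = Callable[[int, Any], None]


def run_epoch_sketch(batches: Iterable[Any],
                     output_callback: Optional[OutputCallback] = None) -> int:
    """Stand-in epoch loop: hand each step's outputs to the callback, if any."""
    num_steps = 0
    for step, outputs in enumerate(batches):
        if output_callback is not None:
            output_callback(step, outputs)
        num_steps += 1
    return num_steps


class PickleCollector:
    """Reproduces the current pickle behaviour as a callback."""

    def __init__(self) -> None:
        self.all_outputs: List[Any] = []

    def __call__(self, step: int, outputs: Any) -> None:
        self.all_outputs.append(outputs)

    def dump(self, path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump(self.all_outputs, f)


collector = PickleCollector()
steps = run_epoch_sketch([({"loss": 0.5},), ({"loss": 0.4},)], collector)
```

The loop itself then needs no knowledge of the output format. An HDF5 writer, a logger, or an in-memory collector would each be a separate callable, and the caller decides when to call something like `collector.dump(...)`.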

+                        for key in outputs.keys():
+                            grp.create_dataset(key, data=outputs[key])
+                        i += 1
                 else:
                     self.run('default')

-        if save_outputs is not None:
+        if save_format == 'h5':
+            h5_outf.close()
+        if save_outputs is not None and save_format == 'pkl':
             with open(save_outputs, 'wb') as f:
                 pkl.dump(all_outputs, f)
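
One thing that stands out in the hunk above: the `'h5'` branch calls `h5py.File(save_outputs, 'w')` before anything checks that `save_outputs` is set, so `save_format='h5'` with `save_outputs=None` would hand `None` to `h5py.File`. A minimal sketch of resolving both arguments up front, assuming that is not the intended behaviour. The helper `resolve_save_format` and the `SUPPORTED_FORMATS` tuple are hypothetical names, not part of this PR.

```python
from typing import Optional

SUPPORTED_FORMATS = ("pkl", "h5")


def resolve_save_format(save_outputs: Optional[str],
                        save_format: str = "pkl") -> Optional[str]:
    """Return the format to use, or None when nothing should be saved."""
    if save_outputs is None:
        # No output path given: skip saving regardless of the format.
        return None
    if save_format not in SUPPORTED_FORMATS:
        raise ValueError("Unsupported save format: {}".format(save_format))
    return save_format
```

Unlike the diff, this sketch does not raise on an unknown format when no output path is given. Whether that is preferable is a design choice for the PR author.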
