@@ -29,6 +29,7 @@
     Iterable,
     Optional,
     ContextManager,
+    Tuple,
 )
 import uuid
 
@@ -195,6 +196,7 @@ def __init__(
         self._logdir = logdir
         self._allowed_plugins = frozenset(allowed_plugins)
         self._run_name_prefix = run_name_prefix
+        self._is_brand_new_experiment = False
 
         self._upload_limits = upload_limits
         if not self._upload_limits:
@@ -265,6 +267,9 @@ def active_filter(secs):
         self._logdir_loader = logdir_loader.LogdirLoader(
             self._logdir, directory_loader_factory
         )
+        self._logdir_loader_pre_create = logdir_loader.LogdirLoader(
+            self._logdir, directory_loader_factory
+        )
         self._tracker = upload_tracker.UploadTracker(verbosity=self._verbosity)
 
         self._create_additional_senders()
@@ -290,6 +295,7 @@ def _create_or_get_experiment(self) -> tensorboard_experiment.TensorboardExperim
                 tensorboard_experiment=tb_experiment,
                 tensorboard_experiment_id=self._experiment_name,
             )
+            self._is_brand_new_experiment = True
         except exceptions.AlreadyExists:
             logger.info("Creating experiment failed. Retrieving experiment.")
             experiment_name = os.path.join(
@@ -303,7 +309,11 @@ def create_experiment(self):
 
         experiment = self._create_or_get_experiment()
         self._experiment = experiment
-        request_sender = _BatchedRequestSender(
+        self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
+            self._experiment.name, self._api
+        )
+
+        self._request_sender = _BatchedRequestSender(
             self._experiment.name,
             self._api,
             allowed_plugins=self._allowed_plugins,
@@ -313,6 +323,7 @@ def create_experiment(self):
             blob_rpc_rate_limiter=self._blob_rpc_rate_limiter,
             blob_storage_bucket=self._blob_storage_bucket,
             blob_storage_folder=self._blob_storage_folder,
+            one_platform_resource_manager=self._one_platform_resource_manager,
             tracker=self._tracker,
         )
 
@@ -323,7 +334,8 @@ def create_experiment(self):
         )
 
         self._dispatcher = _Dispatcher(
-            request_sender=request_sender, additional_senders=self._additional_senders,
+            request_sender=self._request_sender,
+            additional_senders=self._additional_senders,
         )
 
     def _create_additional_senders(self) -> Dict[str, uploader_utils.RequestSender]:
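Note on the wiring above: the `OnePlatformResourceManager` is now built once in `create_experiment()` and injected into `_BatchedRequestSender` (see the constructor change further down), so the pre-creation pass and the per-value senders share a single view of which runs and time series already exist. A minimal, self-contained sketch of this shared-manager pattern, using hypothetical stub classes rather than the real ones:

```python
# Hypothetical stubs illustrating the injected, shared resource manager.
class ResourceManagerStub:
    def __init__(self):
        self.created_runs = set()

    def batch_create_runs(self, run_names):
        # Record runs as created; the real manager issues batched RPCs.
        self.created_runs.update(run_names)


class SenderStub:
    def __init__(self, manager):
        # Injected rather than constructed internally, as in this diff.
        self.manager = manager


manager = ResourceManagerStub()
sender = SenderStub(manager)
manager.batch_create_runs(["train", "eval"])
# The sender sees the same pre-created runs, so it need not re-create them.
assert sender.manager.created_runs == {"train", "eval"}
```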
@@ -366,6 +378,17 @@ def start_uploading(self):
         """
         if self._dispatcher is None:
             raise RuntimeError("Must call create_experiment() before start_uploading()")
+
+        if self._one_shot:
+            if self._is_brand_new_experiment:
+                self._pre_create_runs_and_time_series()
+            else:
+                logger.warning(
+                    "Please consider uploading to a new experiment instead of "
+                    "an existing one, as the former allows for better upload "
+                    "performance."
+                )
+
         while True:
             self._logdir_poll_rate_limiter.tick()
             self._upload_once()
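The new one-shot gating can be summarized as follows; a runnable sketch of the decision logic (the helper name is hypothetical, the branches mirror the code above):

```python
def upload_plan(one_shot: bool, is_brand_new_experiment: bool) -> str:
    """Hypothetical helper mirroring the gating in start_uploading()."""
    if not one_shot:
        return "poll the logdir in a loop; no pre-creation"
    if is_brand_new_experiment:
        return "batch pre-create runs/time series, then upload once"
    return "upload once, warning that a fresh experiment uploads faster"


assert upload_plan(True, True) == "batch pre-create runs/time series, then upload once"
```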
@@ -377,6 +400,58 @@ def start_uploading(self):
                 "without any uploadable data" % self._logdir
             )
 
+    def _pre_create_runs_and_time_series(self):
+        """Iterates through the logdir to collect the TensorboardRuns and
+        TensorboardTimeSeries that need to be created, and creates them in
+        batch to speed up uploading later on.
+        """
+        self._logdir_loader_pre_create.synchronize_runs()
+        run_to_events = self._logdir_loader_pre_create.get_run_events()
+        if self._run_name_prefix:
+            run_to_events = {
+                self._run_name_prefix + k: v for k, v in run_to_events.items()
+            }
+
+        run_names = []
+        run_tag_name_to_time_series_proto = {}
+        for (run_name, events) in run_to_events.items():
+            run_names.append(run_name)
+            for event in events:
+                _filter_graph_defs(event)
+                for value in event.summary.value:
+                    metadata, is_valid = self._request_sender.get_metadata_and_validate(
+                        run_name, value
+                    )
+                    if not is_valid:
+                        continue
+                    if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
+                        value_type = (
+                            tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR
+                        )
+                    elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
+                        value_type = (
+                            tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR
+                        )
+                    elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
+                        value_type = (
+                            tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE
+                        )
+                    else:
+                        # Guard against an unhandled data class leaving
+                        # `value_type` unset (or stale) below.
+                        continue
+
+                    run_tag_name_to_time_series_proto[
+                        (run_name, value.tag)
+                    ] = tensorboard_time_series.TensorboardTimeSeries(
+                        display_name=value.tag,
+                        value_type=value_type,
+                        plugin_name=metadata.plugin_data.plugin_name,
+                        plugin_data=metadata.plugin_data.content,
+                    )
+
+        self._one_platform_resource_manager.batch_create_runs(run_names)
+        self._one_platform_resource_manager.batch_create_time_series(
+            run_tag_name_to_time_series_proto
+        )
+
     def _upload_once(self):
         """Runs one upload cycle, sending zero or more RPCs."""
         logger.info("Starting an upload cycle")
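For reference, the two structures handed to the batch-create calls look like this; an illustrative example with plain-dict stand-ins for the `TensorboardTimeSeries` protos (the run and tag names are hypothetical):

```python
# Hypothetical contents after scanning a logdir with two runs.
run_names = ["train", "eval"]
run_tag_name_to_time_series_proto = {
    # Keyed by (run_name, tag); real values are TensorboardTimeSeries protos.
    ("train", "loss"): {
        "display_name": "loss",
        "value_type": "SCALAR",
        "plugin_name": "scalars",
        "plugin_data": b"",
    },
    ("eval", "accuracy"): {
        "display_name": "accuracy",
        "value_type": "SCALAR",
        "plugin_name": "scalars",
        "plugin_data": b"",
    },
}
```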
@@ -439,6 +514,7 @@ def __init__(
         blob_rpc_rate_limiter: util.RateLimiter,
         blob_storage_bucket: storage.Bucket,
         blob_storage_folder: str,
+        one_platform_resource_manager: uploader_utils.OnePlatformResourceManager,
         tracker: upload_tracker.UploadTracker,
     ):
         """Constructs _BatchedRequestSender for the given experiment resource.
@@ -456,16 +532,16 @@ def __init__(
             Note the chunk stream is internally rate-limited by backpressure from
             the server, so it is not a concern that we do not explicitly rate-limit
             within the stream here.
+          one_platform_resource_manager: An instance of the One Platform
+            resource management class.
           tracker: Upload tracker to track information about uploads.
         """
         self._experiment_resource_name = experiment_resource_name
         self._api = api
         self._tag_metadata = {}
         self._allowed_plugins = frozenset(allowed_plugins)
         self._tracker = tracker
-        self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
-            self._experiment_resource_name, self._api
-        )
+        self._one_platform_resource_manager = one_platform_resource_manager
         self._scalar_request_sender = _ScalarBatchedRequestSender(
             experiment_resource_id=experiment_resource_name,
             api=api,
@@ -516,6 +592,37 @@ def send_request(
             RuntimeError: If no progress can be made because even a single
               point is too large (say, due to a gigabyte-long tag name).
         """
+        metadata, is_valid = self.get_metadata_and_validate(run_name, value)
+        if not is_valid:
+            return
+        plugin_name = metadata.plugin_data.plugin_name
+        self._tracker.add_plugin_name(plugin_name)
+
+        if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
+            self._scalar_request_sender.add_event(run_name, event, value, metadata)
+        elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
+            self._tensor_request_sender.add_event(run_name, event, value, metadata)
+        elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
+            self._blob_request_sender.add_event(run_name, event, value, metadata)
+
+    def flush(self):
+        """Flushes any events that have been stored."""
+        self._scalar_request_sender.flush()
+        self._tensor_request_sender.flush()
+        self._blob_request_sender.flush()
+
+    def get_metadata_and_validate(
+        self, run_name: str, value: tf.compat.v1.Summary.Value
+    ) -> Tuple[tf.compat.v1.SummaryMetadata, bool]:
+        """Gets the metadata for a summary value and checks whether it is valid.
+
+        :param run_name: Name of the run retrieved by
+          `LogdirLoader.get_run_events`.
+        :param value: A single `tf.compat.v1.Summary.Value` from the event,
+          where there can be multiple values per event.
+        :return: A `(metadata, is_valid)` tuple: the metadata derived from the
+          value, and whether the value itself is valid.
+        """
 
         time_series_key = (run_name, value.tag)
 
@@ -539,29 +646,16 @@ def send_request(
                     metadata.plugin_data.plugin_name,
                     value.metadata.plugin_data.plugin_name,
                 )
-            return
+            return metadata, False
         if plugin_name not in self._allowed_plugins:
             if first_in_time_series:
                 logger.info(
                     "Skipping time series %r with unsupported plugin name %r",
                     time_series_key,
                     plugin_name,
                 )
-            return
-        self._tracker.add_plugin_name(plugin_name)
-
-        if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
-            self._scalar_request_sender.add_event(run_name, event, value, metadata)
-        elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
-            self._tensor_request_sender.add_event(run_name, event, value, metadata)
-        elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
-            self._blob_request_sender.add_event(run_name, event, value, metadata)
-
-    def flush(self):
-        """Flushes any events that have been stored."""
-        self._scalar_request_sender.flush()
-        self._tensor_request_sender.flush()
-        self._blob_request_sender.flush()
+            return metadata, False
+        return metadata, True
 
 
 class _Dispatcher(object):
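Both call sites in this diff (`send_request` and `_pre_create_runs_and_time_series`) consume the new `(metadata, is_valid)` contract the same way; a minimal sketch, assuming `sender` is a `_BatchedRequestSender`:

```python
def classify_value(sender, run_name, value):
    """Hypothetical helper showing the intended consumption pattern."""
    metadata, is_valid = sender.get_metadata_and_validate(run_name, value)
    if not is_valid:
        return None  # mismatched or unsupported plugin: skip this value
    return metadata.data_class  # dispatch on scalar / tensor / blob sequence
```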