Content Overview

- Use tf.distribute.Strategy with custom training loops
- What's supported now?
- Examples and tutorials
- Other topics
- Setting up the TF_CONFIG environment variable
- What's next?

## Use tf.distribute.Strategy with custom training loops

As demonstrated above, using `tf.distribute.Strategy` with Keras `Model.fit` requires changing only a couple of lines of your code. With a little more effort, you can also use `tf.distribute.Strategy` with custom training loops.

If you need more flexibility and control over your training loops than Estimator or Keras allow, you can write custom training loops. For instance, when training a GAN, you may want to take a different number of generator or discriminator steps each round. Similarly, the high-level frameworks are not well suited to reinforcement learning training.

The `tf.distribute.Strategy` classes provide a core set of methods to support custom training loops. Using these may require some restructuring of the code initially. Once that is done, you should be able to switch between GPUs, TPUs, and multiple machines simply by changing the strategy instance.

The snippets below illustrate this use case with a simple training example that uses the same Keras model as before.

First, create the model and optimizer inside the strategy's scope. This ensures that any variables created with the model and optimizer are mirrored variables.

```python
import tensorflow as tf

# Same strategy as in the earlier Keras example.
mirrored_strategy = tf.distribute.MirroredStrategy()

with mirrored_strategy.scope():
    # Variables created inside the scope become mirrored variables.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(1, input_shape=(1,),
                              kernel_regularizer=tf.keras.regularizers.L2(1e-4))])
    optimizer = tf.keras.optimizers.SGD()
```

Next, create the input dataset and call `tf.distribute.Strategy.experimental_distribute_dataset` to distribute the dataset based on the strategy.

```python
# `global_batch_size` is the global batch size defined in the earlier Keras section.
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(1000).batch(
    global_batch_size)
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
```

Then, define one step of the training.
Use `tf.GradientTape` to compute gradients and `optimizer` to apply those gradients to update your model's variables. To distribute this training step, put it in a function `train_step` and pass it to `tf.distribute.Strategy.run` along with the dataset inputs you got from the `dist_dataset` created before:

```python
# Sets `reduction=NONE` to leave it to tf.nn.compute_average_loss() below.
loss_object = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
    reduction=tf.keras.losses.Reduction.NONE)

def train_step(inputs):
    features, labels = inputs

    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        per_example_loss = loss_object(labels, predictions)
        # Sum of per-example losses divided by the global batch size.
        loss = tf.nn.compute_average_loss(per_example_loss)
        model_losses = model.losses
        if model_losses:
            # Regularization losses are scaled by 1/num_replicas_in_sync.
            loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

@tf.function
def distributed_train_step(dist_inputs):
    # Run train_step on every replica, then sum the per-replica losses.
    per_replica_losses = mirrored_strategy.run(train_step, args=(dist_inputs,))
    return mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                                    axis=None)
```

A few other things to note in the code above:

You used `tf.nn.compute_average_loss` to reduce the per-example prediction losses to a scalar. `tf.nn.compute_average_loss` sums the per-example losses and divides the sum by the global batch size. This matters because, after the gradients are calculated on each replica, they are aggregated across the replicas by summing them. Dividing by the global batch size, rather than by the per-replica batch size, makes that sum equal to the gradient of the average loss over the whole global batch.

By default, the global batch size is taken to be `tf.distribute.get_strategy().num_replicas_in_sync * tf.shape(per_example_loss)[0]`. It can also be specified explicitly with the keyword argument `global_batch_size=`. When there are no short batches (that is, every replica receives a full per-replica batch), the default is equivalent to `tf.nn.compute_average_loss(..., global_batch_size=global_batch_size)` with the `global_batch_size` defined above.
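The effect of scaling by the global batch size can be checked with plain arithmetic. The sketch below avoids TensorFlow entirely: `average_loss` is a hypothetical stand-in that mimics the documented behavior of `tf.nn.compute_average_loss` (sum of per-example losses divided by the global batch size), and the two-replica loss values are made up. Summing the per-replica results reproduces the single-device mean. Dividing by the per-replica batch size instead inflates it by the number of replicas. Because gradients are linear in the loss, the same argument carries over to the summed gradients.

```python
def average_loss(per_example_loss, global_batch_size):
    """Hypothetical stand-in for tf.nn.compute_average_loss."""
    return sum(per_example_loss) / global_batch_size

# Made-up per-example losses: two replicas, two examples each.
replica_losses = [[0.25, 0.75], [0.5, 1.5]]
num_replicas = len(replica_losses)
per_replica_batch = len(replica_losses[0])
# Mirrors the default: num_replicas_in_sync * per-replica batch size.
global_batch_size = num_replicas * per_replica_batch

# Each replica divides by the *global* batch size; results are then summed
# across replicas, as gradients are during synchronous training.
summed = sum(average_loss(l, global_batch_size) for l in replica_losses)

# Reference: the plain mean over the whole global batch on one device.
all_losses = [x for l in replica_losses for x in l]
single_device_mean = sum(all_losses) / len(all_losses)

# Pitfall: dividing by the per-replica batch size overcounts by num_replicas.
per_replica_scaled = sum(sum(l) / per_replica_batch for l in replica_losses)

print(summed, single_device_mean, per_replica_scaled)  # 0.75 0.75 1.5
```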

For more on short batches and how to avoid or handle them, see the Custom Training tutorial.

You used `tf.nn.scale_regularization_loss` to scale regularization losses registered with the `Model` object, if any, by `1/num_replicas_in_sync` as well. For regularization losses that are input-dependent, it falls on the modeling code, not the custom training loop, to perform the averaging over the per-replica (not global) batch size. That way, the modeling code can remain agnostic of replication while the training loop remains agnostic of how regularization losses are computed.

When you call `apply_gradients` within a distribution strategy scope, its behavior is modified. Specifically, before applying gradients on each parallel instance during synchronous training, it sums the gradients over all replicas.

You also used the `tf.distribute.Strategy.reduce` API to aggregate the results returned by `tf.distribute.Strategy.run` for reporting. `tf.distribute.Strategy.run` returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can reduce them to get an aggregated value.
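The same reasoning covers regularization and the reported loss. The plain-Python sketch below again avoids TensorFlow. `average_loss` and `scale_regularization` are hypothetical stand-ins for `tf.nn.compute_average_loss` and `tf.nn.scale_regularization_loss`, and the numbers are made up. Each of two replicas adds the same input-independent penalty scaled by `1/num_replicas`. A `SUM` reduction of the per-replica values then counts the penalty exactly once and matches the single-device loss. Skipping the scaling would count the penalty once per replica.

```python
def average_loss(per_example_loss, global_batch_size):
    # Hypothetical stand-in for tf.nn.compute_average_loss.
    return sum(per_example_loss) / global_batch_size

def scale_regularization(reg_loss, num_replicas):
    # Hypothetical stand-in for tf.nn.scale_regularization_loss.
    return reg_loss / num_replicas

# Made-up inputs: two replicas, two examples each, one L2 penalty that is
# identical on every replica because the variables are mirrored.
replica_losses = [[0.25, 0.75], [0.5, 1.5]]
reg_penalty = 0.5
num_replicas = len(replica_losses)
global_batch_size = num_replicas * len(replica_losses[0])

# The value each replica's train_step would return (one entry per replica).
per_replica = [
    average_loss(l, global_batch_size) + scale_regularization(reg_penalty, num_replicas)
    for l in replica_losses
]

# A SUM reduction across replicas, analogous to tf.distribute.ReduceOp.SUM.
reported = sum(per_replica)

# Single-device reference: mean data loss plus the penalty counted once.
all_losses = [x for l in replica_losses for x in l]
reference = sum(all_losses) / len(all_losses) + reg_penalty

# Without scaling, the penalty would be added once per replica.
unscaled = sum(average_loss(l, global_batch_size) + reg_penalty for l in replica_losses)

print(per_replica, reported, reference, unscaled)  # [0.5, 0.75] 1.25 1.25 1.75
```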

Instead of reducing the per-replica results, you can also call `tf.distribute.Strategy.experimental_local_results` to get the list of values contained in the result, one per local replica.

Finally, once you have defined the training step, you can iterate over `dist_dataset` and run the training in a loop:

```python
for dist_inputs in dist_dataset:
    print(distributed_train_step(dist_inputs))
```

```
tf.Tensor(0.9024367, shape=(), dtype=float32)
tf.Tensor(0.8953863, shape=(), dtype=float32)
tf.Tensor(0.8884038, shape=(), dtype=float32)
tf.Tensor(0.88148874, shape=(), dtype=float32)
tf.Tensor(0.87464076, shape=(), dtype=float32)
tf.Tensor(0.86785895, shape=(), dtype=float32)
tf.Tensor(0.86114323, shape=(), dtype=float32)
tf.Tensor(0.8544927, shape=(), dtype=float32)
tf.Tensor(0.84790725, shape=(), dtype=float32)
tf.Tensor(0.841386, shape=(), dtype=float32)
tf.Tensor(0.83492863, shape=(), dtype=float32)
tf.Tensor(0.8285344, shape=(), dtype=float32)
tf.Tensor(0.82220304, shape=(), dtype=float32)
tf.Tensor(0.8159339, shape=(), dtype=float32)
tf.Tensor(0.8097264, shape=(), dtype=float32)
tf.Tensor(0.8035801, shape=(), dtype=float32)
tf.Tensor(0.79749453, shape=(), dtype=float32)
tf.Tensor(0.79146886, shape=(), dtype=float32)
tf.Tensor(0.785503, shape=(), dtype=float32)
tf.Tensor(0.779596, shape=(), dtype=float32)
tf.Tensor(0.77374756, shape=(), dtype=float32)
tf.Tensor(0.7679571, shape=(), dtype=float32)
tf.Tensor(0.7622242, shape=(), dtype=float32)
tf.Tensor(0.7565481, shape=(), dtype=float32)
tf.Tensor(0.75092846, shape=(), dtype=float32)
tf.Tensor(0.7453647, shape=(), dtype=float32)
tf.Tensor(0.73985624, shape=(), dtype=float32)
tf.Tensor(0.7344028, shape=(), dtype=float32)
tf.Tensor(0.7290035, shape=(), dtype=float32)
tf.Tensor(0.723658, shape=(), dtype=float32)
tf.Tensor(0.7183659, shape=(), dtype=float32)
tf.Tensor(0.71312654, shape=(), dtype=float32)
tf.Tensor(0.7079393, shape=(), dtype=float32)
tf.Tensor(0.70280397, shape=(), dtype=float32)
tf.Tensor(0.6977197, shape=(), dtype=float32)
tf.Tensor(0.69268626, shape=(), dtype=float32)
tf.Tensor(0.687703, shape=(), dtype=float32)
tf.Tensor(0.68276954, shape=(), dtype=float32)
tf.Tensor(0.67788523, shape=(), dtype=float32)
tf.Tensor(0.6730496, shape=(), dtype=float32)
tf.Tensor(0.66826224, shape=(), dtype=float32)
tf.Tensor(0.66352266, shape=(), dtype=float32)
tf.Tensor(0.6588302, shape=(), dtype=float32)
tf.Tensor(0.6541846, shape=(), dtype=float32)
tf.Tensor(0.6495853, shape=(), dtype=float32)
tf.Tensor(0.64503175, shape=(), dtype=float32)
tf.Tensor(0.6405235, shape=(), dtype=float32)
tf.Tensor(0.6360602, shape=(), dtype=float32)
tf.Tensor(0.6316412, shape=(), dtype=float32)
tf.Tensor(0.62726617, shape=(), dtype=float32)
tf.Tensor(0.6229345, shape=(), dtype=float32)
tf.Tensor(0.61864597, shape=(), dtype=float32)
tf.Tensor(0.6143999, shape=(), dtype=float32)
tf.Tensor(0.6101959, shape=(), dtype=float32)
tf.Tensor(0.60603356, shape=(), dtype=float32)
tf.Tensor(0.60191244, shape=(), dtype=float32)
tf.Tensor(0.597832, shape=(), dtype=float32)
tf.Tensor(0.5937919, shape=(), dtype=float32)
tf.Tensor(0.5897917, shape=(), dtype=float32)
tf.Tensor(0.585831, shape=(), dtype=float32)
tf.Tensor(0.58190924, shape=(), dtype=float32)
tf.Tensor(0.5780261, shape=(), dtype=float32)
tf.Tensor(0.57418114, shape=(), dtype=float32)
tf.Tensor(0.57037395, shape=(), dtype=float32)
tf.Tensor(0.5666041, shape=(), dtype=float32)
tf.Tensor(0.56287116, shape=(), dtype=float32)
tf.Tensor(0.55917484, shape=(), dtype=float32)
tf.Tensor(0.5555145, shape=(), dtype=float32)
tf.Tensor(0.55189, shape=(), dtype=float32)
tf.Tensor(0.54830086, shape=(), dtype=float32)
tf.Tensor(0.54474664, shape=(), dtype=float32)
tf.Tensor(0.54122704, shape=(), dtype=float32)
tf.Tensor(0.5377416, shape=(), dtype=float32)
tf.Tensor(0.5342899, shape=(), dtype=float32)
tf.Tensor(0.5308717, shape=(), dtype=float32)
tf.Tensor(0.5274865, shape=(), dtype=float32)
tf.Tensor(0.52413404, shape=(), dtype=float32)
tf.Tensor(0.52081394, shape=(), dtype=float32)
tf.Tensor(0.51752573, shape=(), dtype=float32)
tf.Tensor(0.5142692, shape=(), dtype=float32)
tf.Tensor(0.51104385, shape=(), dtype=float32)
tf.Tensor(0.50784945, shape=(), dtype=float32)
tf.Tensor(0.50468564, shape=(), dtype=float32)
tf.Tensor(0.50155205, shape=(), dtype=float32)
tf.Tensor(0.49844825, shape=(), dtype=float32)
tf.Tensor(0.4953741, shape=(), dtype=float32)
tf.Tensor(0.49232918, shape=(), dtype=float32)
tf.Tensor(0.4893132, shape=(), dtype=float32)
tf.Tensor(0.48632562, shape=(), dtype=float32)
tf.Tensor(0.4833664, shape=(), dtype=float32)
tf.Tensor(0.4804351, shape=(), dtype=float32)
tf.Tensor(0.47753143, shape=(), dtype=float32)
tf.Tensor(0.47465506, shape=(), dtype=float32)
tf.Tensor(0.47180572, shape=(), dtype=float32)
tf.Tensor(0.46898302, shape=(), dtype=float32)
tf.Tensor(0.4661867, shape=(), dtype=float32)
tf.Tensor(0.46341658, shape=(), dtype=float32)
tf.Tensor(0.4606722, shape=(), dtype=float32)
tf.Tensor(0.4579534, shape=(), dtype=float32)
tf.Tensor(0.4552598, shape=(), dtype=float32)
tf.Tensor(0.45259115, shape=(), dtype=float32)
tf.Tensor(0.44994718, shape=(), dtype=float32)
tf.Tensor(0.44732755, shape=(), dtype=float32)
tf.Tensor(0.44473216, shape=(), dtype=float32)
tf.Tensor(0.44216052, shape=(), dtype=float32)
tf.Tensor(0.4396125, shape=(), dtype=float32)
tf.Tensor(0.43708783, shape=(), dtype=float32)
tf.Tensor(0.4345862, shape=(), dtype=float32)
tf.Tensor(0.4321074, shape=(), dtype=float32)
tf.Tensor(0.42965108, shape=(), dtype=float32)
tf.Tensor(0.4272171, shape=(), dtype=float32)
tf.Tensor(0.42480516, shape=(), dtype=float32)
tf.Tensor(0.42241505, shape=(), dtype=float32)
tf.Tensor(0.42004645, shape=(), dtype=float32)
tf.Tensor(0.41769922, shape=(), dtype=float32)
tf.Tensor(0.41537297, shape=(), dtype=float32)
tf.Tensor(0.41306767, shape=(), dtype=float32)
tf.Tensor(0.41078293, shape=(), dtype=float32)
tf.Tensor(0.4085186, shape=(), dtype=float32)
tf.Tensor(0.4062744, shape=(), dtype=float32)
tf.Tensor(0.4040502, shape=(), dtype=float32)
tf.Tensor(0.40184572, shape=(), dtype=float32)
tf.Tensor(0.39966068, shape=(), dtype=float32)
tf.Tensor(0.3974949, shape=(), dtype=float32)
tf.Tensor(0.39534825, shape=(), dtype=float32)
tf.Tensor(0.39322042, shape=(), dtype=float32)
tf.Tensor(0.39111122, shape=(), dtype=float32)
tf.Tensor(0.3890205, shape=(), dtype=float32)
tf.Tensor(0.38694802, shape=(), dtype=float32)
tf.Tensor(0.38489357, shape=(), dtype=float32)
tf.Tensor(0.38285697, shape=(), dtype=float32)
tf.Tensor(0.38083804, shape=(), dtype=float32)
tf.Tensor(0.3788365, shape=(), dtype=float32)
tf.Tensor(0.37685227, shape=(), dtype=float32)
tf.Tensor(0.3748851, shape=(), dtype=float32)
tf.Tensor(0.37293482, shape=(), dtype=float32)
tf.Tensor(0.37100127, shape=(), dtype=float32)
tf.Tensor(0.36908418, shape=(), dtype=float32)
tf.Tensor(0.36718345, shape=(), dtype=float32)
tf.Tensor(0.3652989, shape=(), dtype=float32)
tf.Tensor(0.36343032, shape=(), dtype=float32)
tf.Tensor(0.36157757, shape=(), dtype=float32)
tf.Tensor(0.35974047, shape=(), dtype=float32)
tf.Tensor(0.3579188, shape=(), dtype=float32)
tf.Tensor(0.35611248, shape=(), dtype=float32)
tf.Tensor(0.3543213, shape=(), dtype=float32)
tf.Tensor(0.35254508, shape=(), dtype=float32)
tf.Tensor(0.3507837, shape=(), dtype=float32)
tf.Tensor(0.34903696, shape=(), dtype=float32)
tf.Tensor(0.34730473, shape=(), dtype=float32)
tf.Tensor(0.3455869, shape=(), dtype=float32)
tf.Tensor(0.3438832, shape=(), dtype=float32)
tf.Tensor(0.34219357, shape=(), dtype=float32)
tf.Tensor(0.3405178, shape=(), dtype=float32)
tf.Tensor(0.3388558, shape=(), dtype=float32)
tf.Tensor(0.3372074, shape=(), dtype=float32)
tf.Tensor(0.33557245, shape=(), dtype=float32)
tf.Tensor(0.33395082, shape=(), dtype=float32)
tf.Tensor(0.33234236, shape=(), dtype=float32)
tf.Tensor(0.33074695, shape=(), dtype=float32)
tf.Tensor(0.32916442, shape=(), dtype=float32)
tf.Tensor(0.3275946, shape=(), dtype=float32)
tf.Tensor(0.3260375, shape=(), dtype=float32)
tf.Tensor(0.3244928, shape=(), dtype=float32)
tf.Tensor(0.3229605, shape=(), dtype=float32)
tf.Tensor(0.32144046, shape=(), dtype=float32)
tf.Tensor(0.31993246, shape=(), dtype=float32)
tf.Tensor(0.3184365, shape=(), dtype=float32)
tf.Tensor(0.31695238, shape=(), dtype=float32)
tf.Tensor(0.31548, shape=(), dtype=float32)
tf.Tensor(0.31401917, shape=(), dtype=float32)
tf.Tensor(0.3125699, shape=(), dtype=float32)
tf.Tensor(0.31113195, shape=(), dtype=float32)
tf.Tensor(0.30970532, shape=(), dtype=float32)
tf.Tensor(0.3082898, shape=(), dtype=float32)
tf.Tensor(0.30688527, shape=(), dtype=float32)
tf.Tensor(0.3054917, shape=(), dtype=float32)
tf.Tensor(0.30410892, shape=(), dtype=float32)
tf.Tensor(0.3027368, shape=(), dtype=float32)
tf.Tensor(0.30137527, shape=(), dtype=float32)
tf.Tensor(0.3000242, shape=(), dtype=float32)
tf.Tensor(0.29868355, shape=(), dtype=float32)
tf.Tensor(0.29735315, shape=(), dtype=float32)
tf.Tensor(0.29603288, shape=(), dtype=float32)
tf.Tensor(0.29472268, shape=(), dtype=float32)
tf.Tensor(0.2934224, shape=(), dtype=float32)
tf.Tensor(0.29213202, shape=(), dtype=float32)
tf.Tensor(0.29085135, shape=(), dtype=float32)
tf.Tensor(0.28958035, shape=(), dtype=float32)
tf.Tensor(0.2883189, shape=(), dtype=float32)
tf.Tensor(0.28706694, shape=(), dtype=float32)
tf.Tensor(0.28582436, shape=(), dtype=float32)
tf.Tensor(0.28459102, shape=(), dtype=float32)
tf.Tensor(0.28336692, shape=(), dtype=float32)
tf.Tensor(0.2821518, shape=(), dtype=float32)
tf.Tensor(0.28094578, shape=(), dtype=float32)
tf.Tensor(0.27974862, shape=(), dtype=float32)
tf.Tensor(0.2785603, shape=(), dtype=float32)
tf.Tensor(0.27738073, shape=(), dtype=float32)
tf.Tensor(0.2762098, shape=(), dtype=float32)
```

In the example above, you iterated over `dist_dataset` to provide input to your training. `tf.distribute.Strategy.make_experimental_numpy_dataset` is also provided to support NumPy inputs. You can use this API to create a dataset before calling `tf.distribute.Strategy.experimental_distribute_dataset`.

Another way to iterate over your data is to use iterators explicitly. This is useful when you want to run for a given number of steps rather than iterating over the entire dataset.
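In plain Python terms, the choice is between letting a `for` loop exhaust an iterable and pulling a fixed number of items with `next()`. A persistent iterator remembers its position, so a later call resumes where the previous one stopped. The sketch below assumes only the standard Python iterator protocol, which the `iter`/`next` usage in this guide relies on. It uses a list of hypothetical string batches in place of a distributed dataset and a hypothetical `run_steps` helper.

```python
# Stand-in for a distributed dataset: six hypothetical batches.
batches = [f"batch-{i}" for i in range(6)]

def run_steps(iterator, num_steps):
    """Hypothetical helper: pull exactly num_steps items from an existing iterator."""
    return [next(iterator) for _ in range(num_steps)]

iterator = iter(batches)
first = run_steps(iterator, 2)   # ['batch-0', 'batch-1']
second = run_steps(iterator, 3)  # resumes at 'batch-2'
remaining = list(iterator)       # whatever is left: ['batch-5']

# Calling next() on an exhausted Python iterator raises StopIteration.
try:
    next(iterator)
    exhausted = False
except StopIteration:
    exhausted = True

print(first, second, remaining, exhausted)
```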

The `for` loop over `dist_dataset` shown earlier would then be modified to first create an iterator and then explicitly call `next` on it to get the input data:

```python
iterator = iter(dist_dataset)
for _ in range(10):
    print(distributed_train_step(next(iterator)))
```

```
tf.Tensor(0.27504745, shape=(), dtype=float32)
tf.Tensor(0.2738936, shape=(), dtype=float32)
tf.Tensor(0.2727481, shape=(), dtype=float32)
tf.Tensor(0.27161098, shape=(), dtype=float32)
tf.Tensor(0.27048206, shape=(), dtype=float32)
tf.Tensor(0.26936132, shape=(), dtype=float32)
tf.Tensor(0.26824862, shape=(), dtype=float32)
tf.Tensor(0.26714393, shape=(), dtype=float32)
tf.Tensor(0.26604718, shape=(), dtype=float32)
tf.Tensor(0.26495826, shape=(), dtype=float32)
```

This covers the simplest case of using the `tf.distribute.Strategy` API to distribute custom training loops.

### What's supported now?

| Training API | MirroredStrategy | TPUStrategy | MultiWorkerMirroredStrategy | ParameterServerStrategy | CentralStorageStrategy |
|----|----|----|----|----|----|
| Custom training loop | Supported | Supported | Supported | Experimental support | Experimental support |

### Examples and tutorials

Here are some examples of using distribution strategies with custom training loops:

- Tutorial: Training with a custom training loop and `MirroredStrategy`.
- Tutorial: Training with a custom training loop and `MultiWorkerMirroredStrategy`.
- Guide: Contains an example of a custom training loop with `TPUStrategy`.
- Tutorial: Parameter server training with a custom training loop and `ParameterServerStrategy`.
- TensorFlow Model Garden repository containing collections of state-of-the-art models implemented using various strategies.

## Other topics

This section covers some topics that are relevant to multiple use cases.

### Setting up the TF_CONFIG environment variable

For multi-worker training, as mentioned before, you need to set up the `'TF_CONFIG'` environment variable for each binary running in your cluster.
The `'TF_CONFIG'` environment variable is a JSON string that specifies which tasks constitute a cluster, their addresses, and each task's role in the cluster. The tensorflow/ecosystem repo provides a Kubernetes template that sets up `'TF_CONFIG'` for your training tasks.

There are two components of `'TF_CONFIG'`: a cluster and a task.

A cluster provides information about the training cluster. It is a dict consisting of different types of jobs, such as workers. In multi-worker training, there is usually one worker that takes on a little more responsibility than a regular worker, such as saving checkpoints and writing summary files for TensorBoard. Such a worker is referred to as the "chief" worker, and it is customary for the worker with index 0 to be appointed as the chief worker (in fact, this is how `tf.distribute.Strategy` is implemented).

A task, on the other hand, provides information about the current task. The first component, cluster, is the same for all workers. The second component, task, is different on each worker and specifies the type and index of that worker.

One example of `'TF_CONFIG'` is:

```python
import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host1:port", "host2:port", "host3:port"],
        "ps": ["host4:port", "host5:port"]
    },
    # This process is worker 1, i.e. the second entry in the "worker" list.
    "task": {"type": "worker", "index": 1}
})
```

This `'TF_CONFIG'` specifies that there are three workers and two `"ps"` tasks in the `"cluster"`, along with their hosts and ports. The `"task"` part specifies the role of the current task in the `"cluster"`: worker 1 (the second worker). Valid roles in a cluster are `"chief"`, `"worker"`, `"ps"`, and `"evaluator"`. There should be no `"ps"` job except when using `tf.distribute.experimental.ParameterServerStrategy`.

## What's next?

`tf.distribute.Strategy` is actively under development. Try it out and provide your feedback using GitHub issues.

:::info
Originally published on the TensorFlow website, this article appears here under a new headline and is licensed under CC BY 4.0. Code samples shared under the Apache 2.0 License.
:::