Keras is the approachable, productive, high-level API for fast experimentation on the TensorFlow platform. Keras focuses on modern deep learning and lets you fully use TensorFlow's scalability and cross-platform features.
If you use TensorFlow, use Keras APIs by default for each step of the machine learning workflow, from data processing to hyperparameter tuning to deployment. Only a few use cases require the low-level TensorFlow Core APIs instead.
There are multiple ways to run a Keras model:
- Run on a TPU Pod or large clusters of GPUs.
- Export to run in a browser or mobile device.
- Serve via a web API.
Keras makes ML workflows easier:
- Simple, consistent interfaces.
- Few steps for common tasks.
- Clear, helpful error messages.
- Gradual learning curve: easy to start, advanced options later.
- Concise, readable code.
Keras has two core data structures: layers and models.
The tf.keras.layers.Layer class is the main abstraction in Keras. A Layer encapsulates a state (weights) and some computation defined in the tf.keras.layers.Layer.call method.
Weights created by layers can be trainable or non-trainable. Layers are recursively composable: when a layer instance is assigned as an attribute of another layer, the outer layer tracks the weights created by the inner layer.
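The weight tracking described above can be sketched as follows. This is a minimal illustration, not a canonical implementation: the class names Linear and TwoLayerBlock, the attribute names, and the layer sizes are all hypothetical choices made for the example.

```python
import tensorflow as tf


class Linear(tf.keras.layers.Layer):
    """Hypothetical dense-style layer computing y = x @ w + b."""

    def __init__(self, units, input_dim):
        super().__init__()
        # Trainable state: updated by the optimizer during training.
        self.w = self.add_weight(
            shape=(input_dim, units), initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)
        # Non-trainable state: tracked by the layer but ignored by the optimizer.
        self.calls = self.add_weight(shape=(), initializer="zeros", trainable=False)

    def call(self, inputs):
        # The computation lives in call(); here it also counts invocations.
        self.calls.assign_add(1.0)
        return tf.matmul(inputs, self.w) + self.b


class TwoLayerBlock(tf.keras.layers.Layer):
    """Hypothetical outer layer that holds two inner layers as attributes."""

    def __init__(self):
        super().__init__()
        self.first = Linear(8, input_dim=4)
        self.second = Linear(2, input_dim=8)

    def call(self, inputs):
        return self.second(tf.nn.relu(self.first(inputs)))


block = TwoLayerBlock()
outputs = block(tf.ones((3, 4)))

# The outer layer reports the weights created by both inner layers.
print(len(block.trainable_weights), len(block.non_trainable_weights))
```

Because the inner layers are plain attributes, no explicit registration step is needed for the outer layer to see their weights.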
Layers can also handle data preprocessing tasks like normalization and text vectorization. Preprocessing layers can be included directly in a model, either during or after training, which makes the model portable.
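As a rough sketch of a preprocessing layer placed inside a model, the example below adapts a tf.keras.layers.Normalization layer to a small array and uses it as the first layer. The toy data and the single Dense layer are assumptions made purely for illustration.

```python
import numpy as np
import tensorflow as tf

# Toy feature matrix (made up for this example): 3 samples, 2 features.
data = np.array([[0.0, 10.0], [2.0, 20.0], [4.0, 30.0]], dtype="float32")

# adapt() computes the per-feature mean and variance from the data.
normalizer = tf.keras.layers.Normalization(axis=-1)
normalizer.adapt(data)

# Including the layer in the model means raw inputs can be passed directly,
# so the exported model carries its own preprocessing.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    normalizer,
    tf.keras.layers.Dense(1),
])

normalized = np.asarray(normalizer(data))
raw_predictions = model(data)
```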
A model is an object that groups layers and can be trained on data. There are a few main model types:
- Sequential model - The simplest model, which is a linear stack of layers.
- Keras functional API - Lets you build arbitrary graphs of layers for more complex architectures.
You can also use subclassing to write models from scratch.
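The three approaches can be compared side by side. The sketch below builds the same small network each way; the layer sizes, variable names, and the SubclassedModel class are hypothetical and chosen only to keep the example short.

```python
import tensorflow as tf

# 1. Sequential: a linear stack of layers.
sequential_model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])

# 2. Functional API: layers are called on tensors to form a graph.
inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(1)(hidden)
functional_model = tf.keras.Model(inputs=inputs, outputs=outputs)


# 3. Subclassing: layers are created in __init__ and wired up in call().
class SubclassedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(8, activation="relu")
        self.out = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.out(self.hidden(inputs))


subclassed_model = SubclassedModel()

# All three accept the same batch of inputs.
batch = tf.ones((2, 4))
results = [m(batch) for m in (sequential_model, functional_model, subclassed_model)]
```

The Sequential and functional forms describe the architecture declaratively, while the subclassed form trades that for full control over the forward pass.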
The tf.keras.Model class has built-in training and evaluation methods:
- tf.keras.Model.fit - Trains the model for a fixed number of epochs.
- tf.keras.Model.predict - Generates output predictions for the input samples.
- tf.keras.Model.evaluate - Returns the loss and metrics values for the model, as configured via the tf.keras.Model.compile method.
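A minimal sketch of how these methods fit together is shown below. The synthetic regression data, the single-layer model, and the choice of optimizer, loss, and metric are assumptions made for the example.

```python
import numpy as np
import tensorflow as tf

# Synthetic data (made up for this example): the target is the sum of the features.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])

# compile() configures the loss and metrics that fit() and evaluate() report.
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

# fit() trains for a fixed number of epochs and returns a History object.
history = model.fit(x, y, epochs=3, batch_size=16, verbose=0)

# evaluate() returns the loss followed by the compiled metrics.
loss, mae = model.evaluate(x, y, verbose=0)

# predict() generates output predictions for the input samples.
predictions = model.predict(x[:5], verbose=0)
```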
These methods provide built-in training features:
- Callbacks - Leverage built-in callbacks for early stopping, model checkpointing, and TensorBoard monitoring. You can also implement custom callbacks.
- Distributed training - Easily scale training to multiple GPUs, TPUs, or devices.
- Step fusing - The steps_per_execution argument in tf.keras.Model.compile lets you process multiple batches in a single tf.function call, improving device utilization on TPUs.
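Two of the features above, callbacks and step fusing, can be sketched in a few lines. The data, the hypothetical EpochCounter callback, and the specific values (patience=2, steps_per_execution=4) are illustrative assumptions; on CPU the step fusing setting is unlikely to show a measurable benefit.

```python
import numpy as np
import tensorflow as tf

# Synthetic data (made up for this example).
rng = np.random.default_rng(0)
x = rng.normal(size=(128, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)


class EpochCounter(tf.keras.callbacks.Callback):
    """Hypothetical custom callback that counts completed epochs."""

    def __init__(self):
        super().__init__()
        self.epochs = 0

    def on_epoch_end(self, epoch, logs=None):
        self.epochs += 1


model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])

# steps_per_execution=4 runs four batches per tf.function call.
model.compile(optimizer="sgd", loss="mse", steps_per_execution=4)

# Built-in callback: stop when val_loss has not improved for 2 epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=2, restore_best_weights=True
)
counter = EpochCounter()

history = model.fit(
    x, y,
    validation_split=0.25,
    epochs=20,
    batch_size=8,
    callbacks=[early_stopping, counter],
    verbose=0,
)
```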
For details on using fit, see the training and evaluation guide. To learn how to customize the built-in training and evaluation loops, see customizing what happens in fit().
Keras provides many other APIs and tools for deep learning.
For a full list of available APIs, see the Keras API reference. See the Keras ecosystem to learn more about other Keras projects.
To get started using Keras with TensorFlow, see the following topics:
- The Sequential model
- The Functional API
- Training & evaluation with the built-in methods
- Making new layers and models via subclassing
- Serialization and saving
- Working with preprocessing layers
- Customizing what happens in fit()
- Writing a training loop from scratch
- Working with RNNs
- Understanding masking & padding
- Writing your own callbacks
- Transfer learning & fine-tuning
- Multi-GPU and distributed training
To learn more about Keras, see the following topics at keras.io: