Determining Model Size and Training Horizon through Scaling Laws
A framework for optimizing model size and training token count from scaling-law coefficients — accounting for inference cost, data repetition, and system efficiency.
1. Introduction
How big should the model be, and how long should we train it? This post walks through how to answer that from scaling-law coefficients, with the usual twists: inference budget, limited data, system efficiency.
Three variants come up in practice: pick the compute-optimal size and horizon, go bigger than optimal to squeeze out more accuracy, or go smaller and overtrain it to save on inference.
The framework below handles the whole picture — training budget, inference budget, limited data, intentional repetition. Give it a compute budget and a rough sense of how much inference you’ll serve, and it returns the optimal size and horizon, plus the % compute overhead of deviating from them.
1.1. Problem formulation and notation
The approach builds on a Chinchilla-style scaling law fitted to the model family. We define the following variables:
- $N$: Number of model parameters (model size).
- $D$: Number of training tokens (total tokens processed during training, i.e. training steps times batch size in tokens).
- $C$: Available training compute budget, measured in floating-point operations (FLOPs) or an equivalent unit. This constrains how large $N$ and $D$ can be, since more parameters and more tokens both consume more compute.
- $I$: Projected number of inference tokens the model will serve over its lifetime. For example, if we expect ~1 billion queries averaging 1,000 tokens each, $I \approx 10^{12}$ tokens.
- $U$: Number of unique tokens in the available training corpus (the dataset size). If $D > U$, the dataset must be repeated to supply $D$ training tokens.
- Loss function $L(N, D)$: A proxy for model quality after training, given by the scaling law. Lower loss corresponds to a better model. We use a parametric form informed by the Chinchilla paper:
$$L(N, D) = E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}$$
where $E$, $A$, $B$, $\alpha$, $\beta$ are scaling-law coefficients determined by the data and model architecture. This equation captures the empirically observed power-law improvement as model and data scale, with $E$ representing the irreducible loss (approached as $N, D \to \infty$).
These coefficients can be derived by regression on a scaling ladder, following the procedure in [Juntei] How to run a scaling ladder ablation.
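The loss model above can be sketched in Python as follows. The coefficient values are illustrative only. They are the fit reported in the Chinchilla paper (Hoffmann et al., 2022), standing in for the coefficients your own ladder regression would produce. The names `loss` and `CHINCHILLA` are hypothetical.

```python
# Illustrative coefficients only: the published Chinchilla fit, NOT values
# from this document. Swap in the output of your own scaling-ladder regression.
CHINCHILLA = dict(E=1.69, A=406.4, B=410.7, alpha=0.34, beta=0.28)


def loss(N: float, D: float, E: float, A: float, B: float,
         alpha: float, beta: float) -> float:
    """Chinchilla-style scaling law: L(N, D) = E + A / N**alpha + B / D**beta."""
    return E + A / N**alpha + B / D**beta


if __name__ == "__main__":
    # Predicted loss for a 70B-parameter model trained on 1.4T tokens.
    print(f"L = {loss(70e9, 1.4e12, **CHINCHILLA):.3f}")
```

Both power-law terms shrink toward zero as $N$ and $D$ grow, so the prediction approaches $E$ from above.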
2. Compute Optimality Approach
We’ll build this up in layers: start with vanilla compute-optimal training, then add data constraints, inference cost, and system efficiency one at a time.
2.1. Training compute only
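To set up the optimization problem, recall that the classic compute-optimal training problem (the Chinchilla setting) can be expressed as follows. For a given training compute budget $C$, choose $N$ and $D$, subject to $C = k_t N D$ where $k_t$ is usually 6, such that $L(N, D)$ is minimized. In other words,
$$\min_{N,\,D}\; L(N, D) \quad \text{s.t.} \quad k_t N D = C$$
From this formulation, one can derive that the compute-optimal parameter count $N^*$ and token count $D^*$ follow power laws in $C$:
$$N^*(C) = G\left(\frac{C}{k_t}\right)^{a}, \qquad D^*(C) = G^{-1}\left(\frac{C}{k_t}\right)^{b},$$
$$G = \left(\frac{\alpha A}{\beta B}\right)^{\frac{1}{\alpha+\beta}}, \qquad a = \frac{\beta}{\alpha+\beta}, \qquad b = \frac{\alpha}{\alpha+\beta}.$$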
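Below is a sketch of the standard Chinchilla closed-form allocation: $N^* = G\,(C/k_t)^{\beta/(\alpha+\beta)}$ with $G = (\alpha A/(\beta B))^{1/(\alpha+\beta)}$, and $D^* = G^{-1}(C/k_t)^{\alpha/(\alpha+\beta)}$. It assumes $k_t = 6$, the illustrative Chinchilla-paper coefficients, and a hypothetical budget. It also checks that moving along the budget line in either direction raises the loss. The function names are hypothetical.

```python
def loss(N, D, E, A, B, alpha, beta):
    """Chinchilla-style loss L(N, D)."""
    return E + A / N**alpha + B / D**beta


def compute_optimal(C, A, B, alpha, beta, k_t=6.0):
    """Closed-form compute-optimal (N*, D*) for a training budget C = k_t * N * D."""
    G = (alpha * A / (beta * B)) ** (1.0 / (alpha + beta))
    a = beta / (alpha + beta)
    b = alpha / (alpha + beta)
    return G * (C / k_t) ** a, (C / k_t) ** b / G


if __name__ == "__main__":
    # Illustrative coefficients (published Chinchilla fit), not this document's.
    coeffs = dict(E=1.69, A=406.4, B=410.7, alpha=0.34, beta=0.28)
    C = 5.88e23  # hypothetical training budget in FLOPs
    n_opt, d_opt = compute_optimal(C, coeffs["A"], coeffs["B"],
                                   coeffs["alpha"], coeffs["beta"])
    print(f"N* = {n_opt:.3g}, D* = {d_opt:.3g}, D*/N* = {d_opt / n_opt:.1f}")
    # Trading parameters for tokens at fixed C should only raise the loss.
    for s in (0.5, 1.0, 2.0):
        print(s, round(loss(n_opt * s, d_opt / s, **coeffs), 4))
```

Because $D^*/N^* = G^{-2}(C/k_t)^{(\alpha-\beta)/(\alpha+\beta)}$, the optimal tokens-per-parameter ratio grows with the budget whenever $\alpha > \beta$.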
2.2. Data availability consideration
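If $D > U$, not all training tokens are unique. The scaling law assumes i.i.d. data and does not distinguish unique from repeated tokens. To capture the diminishing returns of repetition, we modify the data term: the actual token count $D$ is discounted to an effective token count $D_{\text{eff}} \le D$, the amount of fresh data that would be worth the same. One form that fits empirical data well is parameterized by the half-life $h$ of data utility. This is the number of repeated epochs after which an additional epoch is worth half as much as the first. With $R = D/U - 1$ repeated epochs, the effective token count is
$$D_{\text{eff}}(D, U) = \begin{cases} D, & D \le U \\ U + U\,\dfrac{h}{\ln 2}\left(1 - 2^{-R/h}\right), & D > U \end{cases}$$
so the $R$-th repeated epoch contributes a fraction $2^{-R/h}$ of a fresh epoch, and $D_{\text{eff}} \to D$ as $h \to \infty$. This is the data-constrained form of Muennighoff et al. (2023), with their decay constant $R^* = h/\ln 2$.
This changes the compute-optimality formulation to
$$\min_{N,\,D}\; E + \frac{A}{N^{\alpha}} + \frac{B}{D_{\text{eff}}(D, U)^{\beta}} \quad \text{s.t.} \quad k_t N D = C \quad (k_t \approx 6)$$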
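With repetition in play, the closed form no longer applies. However, the problem along the budget line is one-dimensional and cheap to solve numerically. The sketch below makes these assumptions:

- The half-life form $D_{\text{eff}} = U + U\,(h/\ln 2)(1 - 2^{-R/h})$ for $D > U$.
- The illustrative Chinchilla-paper coefficients and $k_t = 6$.
- A plain log-spaced grid search over $N$. This is a deliberately simple choice, not a prescribed solver.

The half-life `h = 10.0` epochs, the budget, and all function names are hypothetical.

```python
import math


def effective_tokens(D, U, h):
    """Effective token count when D training tokens come from U unique ones.

    Assumes the half-life form: the R-th repeated epoch is worth 2**(-R/h)
    of a fresh epoch, where h is a (hypothetical, fitted) half-life in epochs.
    """
    if D <= U:
        return D  # no repetition: every token is fresh
    R = D / U - 1.0  # repeated epochs beyond the first pass
    return U + U * (h / math.log(2)) * (1.0 - 2.0 ** (-R / h))


def data_constrained_optimum(C, U, h, E, A, B, alpha, beta, k_t=6.0,
                             n_min=1e6, n_max=1e13, points=4001):
    """Return (loss, N, D) minimizing E + A/N^alpha + B/D_eff^beta s.t. k_t*N*D = C.

    Uses a log-spaced grid over N (about 0.4% spacing with these defaults).
    """
    best = None
    for i in range(points):
        N = n_min * (n_max / n_min) ** (i / (points - 1))
        D = C / (k_t * N)  # tokens implied by the compute constraint
        L = E + A / N**alpha + B / effective_tokens(D, U, h) ** beta
        if best is None or L < best[0]:
            best = (L, N, D)
    return best


if __name__ == "__main__":
    coeffs = dict(E=1.69, A=406.4, B=410.7, alpha=0.34, beta=0.28)  # illustrative
    C, h = 5.88e23, 10.0  # hypothetical budget (FLOPs) and half-life (epochs)
    for U in (1e15, 1e11, 1e10):
        L, N, D = data_constrained_optimum(C, U, h, **coeffs)
        print(f"U={U:.0e}: N={N:.3g}, D={D:.3g} ({D / U:.2f} epochs), loss={L:.4f}")
```

In the usage loop, shrinking $U$ should push the optimum toward larger models trained on fewer tokens, since each extra repeated epoch buys less effective data.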
2.3. Inference compute consideration
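We want to account for the compute required to serve $I$ inference tokens. If the model has $N$ parameters, inference on a single token costs roughly $k_i N$ FLOPs (one forward pass). The total inference cost is therefore approximately $k_i N I$, where $k_i$ is another proportionality constant. It is usually $k_i \approx 2$, compared with about 6 for training, since inference needs only the forward pass.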
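Here is a back-of-the-envelope comparison of the two costs. It assumes $k_t = 6$, $k_i = 2$, and the ~$10^{12}$-token traffic example from Section 1.1. The 70B-parameter, 1.4T-token model is illustrative, and the function names are hypothetical.

```python
def training_flops(N, D, k_t=6.0):
    """Training compute: about k_t FLOPs per parameter per training token."""
    return k_t * N * D


def inference_flops(N, I, k_i=2.0):
    """Lifetime serving compute: about k_i FLOPs per parameter per served token."""
    return k_i * N * I


if __name__ == "__main__":
    N, D = 70e9, 1.4e12   # illustrative model size and training horizon
    I = 1e9 * 1_000       # ~1B queries x ~1,000 tokens each = 1e12 tokens
    c_train, c_inf = training_flops(N, D), inference_flops(N, I)
    print(f"training {c_train:.3g} FLOPs, inference {c_inf:.3g} FLOPs, "
          f"inference/training = {c_inf / c_train:.2f}")
```

With these constants, lifetime serving FLOPs overtake training FLOPs once $I > (k_t/k_i)\,D = 3D$.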
Next, we consider two scenarios for the compute-optimality problem that take inference into account.
Scenario 2.3.1. Fixed total compute budget between training and inference
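Here a fixed total compute budget $C$ is split between training and inference, for a number of inference tokens $I$ known beforehand. This only requires adding the inference compute to the constraint, where $k_t \approx 6$ and $k_i \approx 2$ are the training and inference FLOPs per parameter per token:
$$\min_{N,\,D}\; L(N, D) \quad \text{s.t.} \quad k_t N D + k_i N I = C$$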
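Substituting $D = (C - k_i N I)/(k_t N)$ turns Scenario 2.3.1 into a search over $N$ alone. This is a minimal sketch with several assumptions: the illustrative Chinchilla-paper coefficients, $k_t = 6$, $k_i = 2$, a hypothetical budget and traffic, and a log-spaced grid search as a simple stand-in for any 1-D minimizer. The function name is hypothetical.

```python
def train_plus_inference_optimum(C, I, E, A, B, alpha, beta, k_t=6.0, k_i=2.0,
                                 n_min=1e6, n_max=1e13, points=4001):
    """Return (loss, N, D) minimizing L(N, D) s.t. k_t*N*D + k_i*N*I = C."""
    if I > 0:
        # Serving alone must not exhaust the budget: need k_i * N * I < C.
        n_max = min(n_max, 0.999 * C / (k_i * I))
    if n_max <= n_min:
        raise ValueError("budget too small to serve I tokens with any model >= n_min")
    best = None
    for j in range(points):
        N = n_min * (n_max / n_min) ** (j / (points - 1))
        D = (C - k_i * N * I) / (k_t * N)  # training tokens left after serving
        L = E + A / N**alpha + B / D**beta
        if best is None or L < best[0]:
            best = (L, N, D)
    return best


if __name__ == "__main__":
    coeffs = dict(E=1.69, A=406.4, B=410.7, alpha=0.34, beta=0.28)  # illustrative
    C = 5.88e23  # hypothetical total budget (training + inference)
    for I in (0.0, 1e12, 1e13):
        L, N, D = train_plus_inference_optimum(C, I, **coeffs)
        print(f"I={I:.0e}: N={N:.3g}, D={D:.3g}, loss={L:.4f}")
```

Raising $I$ should shift the optimum toward smaller models, because each parameter is now paid for on every served token as well as during training.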
Scenario 2.3.2. Under fixed training compute, overtrain an undersized model to reduce inference cost
In this scenario, we spend extra training compute to overtrain a smaller-than-optimal model until it reaches the same loss as the compute-optimal size and horizon. This trades extra training compute for cheaper inference down the line.
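In this setting, we scale the parameters by a factor $f_N < 1$ and the training tokens by a factor $f_D > 1$ so that the loss matches that of the compute-optimal parameter and token counts $N^*$ and $D^*$. In other words:
$$E + \frac{A}{(f_N N^*)^{\alpha}} + \frac{B}{(f_D D^*)^{\beta}} = E + \frac{A}{(N^*)^{\alpha}} + \frac{B}{(D^*)^{\beta}}$$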
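We solve for the token multiplier $f_D$ given the size multiplier $f_N$. At the compute-optimal point the optimality condition $\alpha A (N^*)^{-\alpha} = \beta B (D^*)^{-\beta}$ holds, and using it gives
$$f_D = \left[1 - \frac{\beta}{\alpha}\left(f_N^{-\alpha} - 1\right)\right]^{-1/\beta}$$
This is finite only while $f_N > (1 + \alpha/\beta)^{-1/\alpha}$. Below that size, no amount of training reaches the target loss.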
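The formula $f_D = [1 - (\beta/\alpha)(f_N^{-\alpha} - 1)]^{-1/\beta}$ can be checked directly. Plug $(f_N N^*, f_D D^*)$ back into the loss and compare it with the compute-optimal loss. This sketch assumes the illustrative Chinchilla-paper coefficients, $k_t = 6$, and a hypothetical budget. `token_multiplier` is a hypothetical name.

```python
import math


def loss(N, D, E, A, B, alpha, beta):
    """Chinchilla-style loss L(N, D)."""
    return E + A / N**alpha + B / D**beta


def token_multiplier(f_N, alpha, beta):
    """f_D that keeps (f_N * N*, f_D * D*) at the compute-optimal loss."""
    inner = 1.0 - (beta / alpha) * (f_N ** (-alpha) - 1.0)
    if inner <= 0.0:
        # Below the critical size (1 + alpha/beta) ** (-1/alpha): loss unreachable.
        return math.inf
    return inner ** (-1.0 / beta)


if __name__ == "__main__":
    coeffs = dict(E=1.69, A=406.4, B=410.7, alpha=0.34, beta=0.28)  # illustrative
    alpha, beta = coeffs["alpha"], coeffs["beta"]
    C = 5.88e23  # hypothetical budget in FLOPs
    G = (alpha * coeffs["A"] / (beta * coeffs["B"])) ** (1 / (alpha + beta))
    n_star = G * (C / 6) ** (beta / (alpha + beta))
    d_star = (C / 6) ** (alpha / (alpha + beta)) / G
    target = loss(n_star, d_star, **coeffs)
    for f_N in (0.9, 0.75, 0.5, 0.3):
        f_D = token_multiplier(f_N, alpha, beta)
        gap = loss(f_N * n_star, f_D * d_star, **coeffs) - target
        print(f"f_N={f_N}: f_D={f_D:.3f}, loss gap={gap:.1e}")
```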
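With that, we can derive the new training compute and its overhead relative to the compute-optimal setting. We can also derive the reduced inference compute, where $k_i N^* I$ is the inference cost of the compute-optimal model:
$$C_{\text{new}} = k_t (f_N N^*)(f_D D^*) = f_N f_D\, C, \qquad \text{overhead} = f_N f_D - 1, \qquad C_{\text{inf,new}} = f_N\, k_i N^* I$$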
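To see when overtraining pays off, compare lifetime compute (training plus serving) across a few size multipliers. The sketch assumes the illustrative Chinchilla-paper coefficients, $k_t = 6$, $k_i = 2$, and a hypothetical budget and serving volume. It restates $f_D = [1 - (\beta/\alpha)(f_N^{-\alpha} - 1)]^{-1/\beta}$ so it runs on its own. The function names are hypothetical.

```python
def token_multiplier(f_N, alpha, beta):
    """f_D keeping the compute-optimal loss; inf below the critical size."""
    inner = 1.0 - (beta / alpha) * (f_N ** (-alpha) - 1.0)
    return inner ** (-1.0 / beta) if inner > 0.0 else float("inf")


def lifetime_compute(f_N, C, n_star, I, alpha, beta, k_i=2.0):
    """(training, inference) FLOPs for a model f_N * N* trained to the optimal loss."""
    f_D = token_multiplier(f_N, alpha, beta)
    c_train = f_N * f_D * C          # k_t (f_N N*)(f_D D*) = f_N f_D C
    c_inf = k_i * f_N * n_star * I   # serving cost scales linearly with model size
    return c_train, c_inf


if __name__ == "__main__":
    alpha, beta = 0.34, 0.28   # illustrative exponents
    A, B = 406.4, 410.7        # illustrative coefficients
    C, I = 5.88e23, 1e13       # hypothetical budget (FLOPs) and traffic (tokens)
    G = (alpha * A / (beta * B)) ** (1 / (alpha + beta))
    n_star = G * (C / 6) ** (beta / (alpha + beta))
    base = sum(lifetime_compute(1.0, C, n_star, I, alpha, beta))
    for f_N in (1.0, 0.75, 0.5, 0.3):
        c_train, c_inf = lifetime_compute(f_N, C, n_star, I, alpha, beta)
        print(f"f_N={f_N}: training overhead {c_train / C - 1:+.1%}, "
              f"lifetime compute {(c_train + c_inf) / base:.2f}x of baseline")
```

The training overhead grows as $f_N$ shrinks, while the inference term falls linearly. With little traffic, the compute-optimal model wins. With heavy traffic, a smaller overtrained model does.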
3. System Efficiency Consideration
Compute (FLOPs) is a convenient abstraction, but what you actually pay for is machine time. For a fixed amount of compute, higher system efficiency means less machine time and lower training cost. For inference, higher efficiency means more throughput per node, and therefore a smaller fleet for the same traffic. Folding efficiency in turns the compute-optimal story into something you can actually use for production decisions.
3.1. Training efficiency
To translate training compute into training cost accurately, we need to account for model FLOPs utilization (MFU) and goodput. MFU is the ratio of the observed throughput (tokens per second) to the theoretical maximum throughput of the system running at peak FLOPs. Goodput is loosely defined as the time spent computing useful new steps divided by the elapsed time of the training job.
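Machine time then follows from three divisions. First, divide the training compute $C$ by MFU to get the FLOPs the system effectively has to push through. Next, divide that by the system's peak throughput $P_{\text{peak}}$ (FLOP/s) to get the ideal machine time $T_{\text{ideal}}$. Finally, divide by goodput to get the real elapsed time:
$$T_{\text{ideal}} = \frac{C}{\text{MFU} \cdot P_{\text{peak}}}, \qquad T_{\text{elapsed}} = \frac{T_{\text{ideal}}}{\text{goodput}}$$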
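The definitions and divisions above can be sketched as a small calculator. It assumes $k_t = 6$ FLOPs per parameter per token when converting observed tokens/s into MFU. The cluster size, per-accelerator peak, MFU, and goodput in the usage example are hypothetical placeholders rather than measurements, and the function names are hypothetical too.

```python
def mfu(tokens_per_sec, N, peak_flops_per_sec, k_t=6.0):
    """Observed tokens/s over the theoretical max tokens/s at peak FLOP/s."""
    max_tokens_per_sec = peak_flops_per_sec / (k_t * N)
    return tokens_per_sec / max_tokens_per_sec


def training_time(C, mfu_value, peak_flops_per_sec, goodput):
    """Return (ideal machine time, real elapsed time) in seconds for C FLOPs."""
    t_ideal = C / (mfu_value * peak_flops_per_sec)
    return t_ideal, t_ideal / goodput


if __name__ == "__main__":
    peak = 1024 * 1e15  # hypothetical: 1,024 accelerators at 1 PFLOP/s each
    t_ideal, t_real = training_time(5.88e23, mfu_value=0.40,
                                    peak_flops_per_sec=peak, goodput=0.90)
    day = 86_400
    print(f"ideal {t_ideal / day:.1f} days, elapsed {t_real / day:.1f} days")
```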
3.2. Inference efficiency
Inference efficiency works the same way, but the details differ. Serving setups are much smaller than training clusters, so goodput is modeled differently. The workload is forward-only, which changes both the performance engineering and what MFU looks like. Moreover, Scenario 2.3.2 (overtraining an undersized model) is the more common case anyway, and it has no predetermined inference token count. There, the natural target is throughput per host.
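For the throughput-per-host target, the same MFU-style accounting can be applied to forward-only serving. This is a minimal sketch with several assumptions:

- About $k_i = 2$ FLOPs per parameter per token.
- A single measured serving MFU standing in for batching, KV-cache, and memory-bandwidth effects. Serving MFU is typically far below training MFU, so it should be measured, not guessed.
- Hypothetical host, traffic, and model-size numbers.

The function names are hypothetical.

```python
import math


def tokens_per_sec_per_host(N, peak_flops_per_host, serving_mfu, k_i=2.0):
    """Forward-only throughput: achieved FLOP/s divided by ~k_i * N FLOPs per token."""
    return serving_mfu * peak_flops_per_host / (k_i * N)


def hosts_needed(traffic_tokens_per_sec, N, peak_flops_per_host, serving_mfu, k_i=2.0):
    """Smallest whole number of hosts that sustains the target traffic."""
    per_host = tokens_per_sec_per_host(N, peak_flops_per_host, serving_mfu, k_i)
    return math.ceil(traffic_tokens_per_sec / per_host)


if __name__ == "__main__":
    # Hypothetical: an 8-accelerator host at 1 PFLOP/s each, 10% serving MFU,
    # and a steady 1M tokens/s of traffic.
    host_peak, mfu_serve, traffic = 8e15, 0.10, 1e6
    for N in (32e9, 16e9):  # e.g. a compute-optimal size vs. a half-size overtrained model
        print(f"N={N:.0e}: {tokens_per_sec_per_host(N, host_peak, mfu_serve):,.0f} tok/s/host, "
              f"{hosts_needed(traffic, N, host_peak, mfu_serve)} hosts")
```

Under these assumptions, per-host throughput scales as $1/N$. Halving the model via Scenario 2.3.2 therefore roughly halves the fleet for the same traffic, which is the saving that the extra training compute buys.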