Case Study: Building a Domain-Specific Foundation Model for Healthcare
Table of contents
- Problem Statement
- Step 0: Why Train From Scratch?
- Step 1: Requirements
- Step 2: Architecture - Designing the Transformer
- Step 3: Tokenizer Design
- Step 4: Optimization - Making Training Converge
- Step 5: Distributed Training - Scaling Across 256 GPUs
- Step 6: Compute Budget - What This Costs
- Failure Modes
- Operational Concerns
- Going Deeper
- References
This post applies the 9-step case study structure from the GenAI System Design Framework.
Problem Statement
A healthcare AI company builds clinical decision support tools, medical coding assistants, and patient communication systems. All of these products depend on a language model that understands medical text: clinical notes, discharge summaries, radiology reports, pharmaceutical literature, and insurance documentation.
Today the company fine-tunes general-purpose models (LLaMA, Mistral) for each downstream task. This works reasonably well for tasks like summarization, but it falls apart in specific ways. The tokenizer splits “acetaminophen” into 4-5 subword tokens, which wastes context window and fragments the model’s internal representation of a single concept. The model hallucinates drug interactions that don’t exist. It confuses ICD-10 codes that differ by a single character. When asked to interpret lab values alongside clinical notes, it produces answers that are syntactically fluent but clinically dangerous.
The core problem is not that general-purpose models are bad. They are remarkably capable. The problem is that medical language has structure, vocabulary, and safety constraints that a model trained primarily on web text has no reason to encode deeply. Fine-tuning can adapt surface behavior, but it cannot retroactively fix how the model tokenizes medical terminology, or fill in knowledge the pre-training data never contained at sufficient depth.
What we’re building: a 7-billion parameter decoder-only language model, pre-trained from scratch on 200 billion tokens of medical and clinical text, with a custom tokenizer designed for healthcare vocabulary. This model serves as the foundation for 50+ downstream applications across the company’s product suite.
Primary users: ML engineers who fine-tune the foundation model for specific clinical applications.
Secondary users: product teams building on the fine-tuned models, clinical validation teams evaluating safety, and the infrastructure team operating the training cluster.
What This System Is Not
This is not a chatbot. There is no conversational interface on the foundation model itself. It produces raw next-token predictions. Downstream teams fine-tune it and wrap it with application-specific interfaces.
This is also not a medical device. The foundation model does not make clinical decisions. It is a component in systems that may support clinical decisions, but those systems have their own validation, human-in-the-loop requirements, and regulatory considerations. We are building the engine, not the car.
It is not a retrieval system. The model does not look up facts at inference time (that’s what RAG is for in the downstream applications). The knowledge is baked into the weights during pre-training. If the training data doesn’t contain information about a rare condition, the model won’t know about it, and no amount of prompting will fix that.
Step 0: Why Train From Scratch?
This is the most important question in the entire case study, because training a foundation model from scratch is expensive, slow, and operationally painful. Fine-tuning an existing open model is 50-100x cheaper. You need a strong reason to take on that cost.
The tokenizer tax. General-purpose tokenizers (like LLaMA’s SentencePiece or GPT’s tiktoken) are trained on web-scale corpora. Medical terminology is underrepresented in that training data, so medical terms get split into many subword tokens. “Acetaminophen” becomes something like ["ace", "tam", "in", "ophen"]. “Electrocardiogram” becomes 5 tokens. “Methylprednisolone” becomes 6.
This matters for two reasons. First, it wastes context window. A clinical note that would be 800 tokens with a medical tokenizer might be 1,400 tokens with a general tokenizer. You are paying 75% more compute per document for the same information. Second, and more importantly, when a word is split across multiple tokens, the model must learn to compose the meaning across those token boundaries. It can do this (transformers are good at it), but the representation is less efficient than having a single token for a single concept. For a model that needs to reason about 50,000+ medical terms, this tax compounds.
You cannot swap in a custom tokenizer after pre-training. The tokenizer and the model’s embedding layer are tightly coupled. The embedding matrix has one row per vocabulary token. If you change the vocabulary, you need to retrain the embeddings, which means retraining the model.
The knowledge ceiling. Fine-tuning adapts a model’s behavior on top of its existing knowledge. If the base model was pre-trained on a corpus that contained 0.5% medical text, the model’s internal representations of medical concepts are shallow. Fine-tuning on medical question-answer pairs can teach the model to format answers like a clinician, but it cannot inject deep understanding of pharmacokinetics, differential diagnosis reasoning, or the relationship between lab values and disease progression. Those patterns need to be learned during pre-training, when the model is forming its core representations.
Think of it this way: RAG lets you give the model facts at inference time, and fine-tuning lets you adjust its behavior. But neither changes what the model “understands” at a representational level. If you need a model that reasons about medical concepts the way a general model reasons about common English, you need medical concepts in the pre-training data at sufficient volume and diversity.
The cost math. Training a 7B model on 200B tokens costs roughly $500K-$1M in cloud compute (we’ll break this down precisely in Step 6). That sounds like a lot, and it is. But the company has 50+ downstream applications, each of which currently fine-tunes a general model separately. If a purpose-built foundation model improves each downstream task by even a few percentage points, the aggregate value across 50 applications easily justifies the upfront cost. The foundation model is infrastructure, amortized across every product that uses it.
When NOT to train from scratch. If you have fewer than 5 downstream applications, fine-tuning a strong open model is almost certainly the right choice. If your domain vocabulary overlaps heavily with general English (legal text, for example, has an unusual register but uses standard English words), the tokenizer tax is minimal. If your primary bottleneck is factual recall rather than reasoning (the model needs to know current drug interactions), RAG is a better investment than pre-training. Training from scratch makes sense when the tokenizer tax is high, the domain has deep specialized structure, and the model serves as shared infrastructure across many products.
Step 1: Requirements
Functional Requirements
- Pre-train a decoder-only transformer language model on medical and clinical text
- Produce a model that can be fine-tuned for downstream tasks: clinical note summarization, medical coding (ICD-10, CPT), patient communication generation, clinical question answering, radiology report interpretation
- Support context lengths of at least 4,096 tokens (sufficient for most clinical notes; longer contexts are a post-training concern)
- Generate coherent medical text with correct terminology, drug names, dosage formats, and lab value references
- Serve as the single foundation model for 50+ downstream applications
Non-Functional Requirements
- Training time: Complete pre-training within 14 days on the allocated cluster. Longer training runs increase the risk of hardware failures, cluster scheduling conflicts, and opportunity cost.
- Reproducibility: Given the same data, code, and random seed, the training must produce the same model. This requires deterministic data loading order and controlled randomness in initialization.
- Checkpointing: Save full model state every 1,000 steps. At our expected rate of roughly 2.5 optimization steps per minute, a hardware failure costs at most ~7 hours of recomputation; critical runs can checkpoint more frequently.
- Cost: Stay within $800K total compute budget including experimentation, failed runs, and the final production run.
- Safety: The model must not memorize and reproduce protected health information (PHI) from the training data. De-identification of training data is a pre-processing requirement, not a model requirement, but we validate this during evaluation.
Scale Assumptions
| Parameter | Value | Rationale |
|---|---|---|
| Model size | 7B parameters | Sweet spot for domain models: large enough for complex reasoning, small enough to fine-tune on a single 80GB GPU |
| Training tokens | 200B | Chinchilla-optimal for 7B is ~140B tokens, but we over-train following the LLaMA approach for better downstream performance |
| Vocabulary size | 64,000 tokens | Larger than typical (32K) to accommodate medical terminology without excessive splitting |
| Context length | 4,096 tokens | Covers 95%+ of clinical notes. Longer context is addressed post-training via RoPE scaling |
| Training cluster | 256x NVIDIA A100 80GB | Minimum viable cluster for completing training in 14 days |
| Batch size | 4M tokens per step | Standard for 7B-class models. Achieved via gradient accumulation across GPUs |
Quality Metrics
Quality for a foundation model is measured differently than for a fine-tuned task model. We cannot evaluate “accuracy” directly because the model is not trained for any specific task. Instead, we use proxy metrics:
| Metric | Target | Measurement |
|---|---|---|
| Validation perplexity | <8.0 on held-out medical text | Lower is better. General models score 12-15 on medical text |
| Medical term fertility | <1.5 tokens per medical term (average) | Measures tokenizer efficiency. General tokenizers score 2.5-3.5 |
| USMLE QA (zero-shot) | >45% accuracy | Proxy for medical knowledge. General 7B models score 35-40% |
| MedNLI (zero-shot) | >70% accuracy | Natural language inference on clinical text |
| Downstream fine-tuning lift | >3% over fine-tuned LLaMA 7B | Measured across 5 representative downstream tasks |
| PHI leakage rate | 0% on probing tests | Tested with known-template attacks against training data |
Step 2: Architecture - Designing the Transformer
The transformer architecture for a 7B model is not exotic. It follows a well-established template, with a few specific decisions that matter at this scale. This section walks through each component.
Decoder-Only vs. Encoder-Only vs. Encoder-Decoder
Three families of transformer architectures exist. Encoder-only models (BERT and its variants) produce bidirectional representations and are strong for classification and extraction tasks, but they cannot generate text. Encoder-decoder models (T5, BART) can generate text but require separate encoder and decoder stacks, which roughly doubles the parameter count for the same effective capacity. Decoder-only models (GPT, LLaMA, Mistral) generate text autoregressively and can be adapted to almost any downstream task through fine-tuning or prompting.
We use decoder-only. The reasoning is practical: all 50+ downstream applications require text generation (summaries, reports, answers, codes). A decoder-only model handles all of these. Encoder-only would require a separate model for generation tasks. Encoder-decoder would work but the two-stack architecture adds complexity for pre-training and serving without a clear quality advantage at the 7B scale.
Attention Mechanism
Standard multi-head attention (MHA) computes separate query, key, and value projections for each attention head. For a model with 32 heads and a hidden dimension of 4,096, each head operates on a 128-dimensional subspace. This works well but creates a serving problem: during autoregressive generation, the KV cache stores key and value tensors for all previous tokens across all heads. For 32 heads at 4,096 context length with FP16 precision, the KV cache per layer is:
2 (K and V) x 32 (heads) x 4096 (sequence length) x 128 (head dimension) x 2 (bytes for FP16) = 64 MB per layer
With 32 layers, that is 2 GB per sequence just for the KV cache. At batch size 32, you need 64 GB of memory solely for KV caches, which consumes almost an entire A100’s memory before the model weights are even loaded.
Grouped Query Attention (GQA) addresses this by sharing key-value heads across multiple query heads. Instead of 32 KV heads, you use 8 KV groups, each shared by 4 query heads. This reduces the KV cache by 4x (to 16 GB at batch 32) with minimal quality loss. GQA was introduced by Ainslie et al. (2023), adopted at scale in LLaMA 2 70B, and subsequent models have used it widely. For deeper coverage of attention and memory tradeoffs, see the posts on attention mechanisms and paged attention.
We use GQA with 8 KV groups. The quality difference from full MHA is negligible at 7B scale, and the 4x reduction in KV cache memory directly translates to higher serving throughput and lower inference cost for every downstream application.
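The KV-cache arithmetic above is easy to check with a few lines of Python. This is an illustrative sketch (the helper name is ours, not from any library):

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, head_dim, bytes_per_el=2):
    """KV cache for one sequence: K and V tensors per layer, per KV head, in FP16/BF16."""
    return 2 * n_layers * n_kv_heads * seq_len * head_dim * bytes_per_el

# Full MHA: 32 KV heads -> 64 MiB per layer, 2 GiB per sequence at 4,096 context
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, seq_len=4096, head_dim=128)
# GQA with 8 KV groups: 4x smaller
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, seq_len=4096, head_dim=128)

print(mha / 2**30)       # 2.0 (GiB per sequence)
print(gqa / 2**30)       # 0.5
print(32 * mha / 2**30)  # 64.0 (GiB for a batch of 32, as in the text)
```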
Position Encoding: RoPE
Transformers have no inherent sense of token position. Position information must be injected explicitly. The original transformer used fixed sinusoidal embeddings. GPT-2 used learned absolute position embeddings. Both approaches have a hard ceiling: the model cannot generalize to sequence lengths it didn’t see during training.
Rotary Position Embeddings (RoPE) encode position by rotating the query and key vectors in pairs of dimensions. The rotation angle depends on the position, so the dot product between a query at position i and a key at position j depends only on the relative distance i - j, not on the absolute positions. This gives the model relative position awareness.
The practical benefit: RoPE models can be extended to longer contexts after pre-training by adjusting the rotation frequencies (NTK-aware scaling, YaRN, or simply training on a few billion tokens at the extended length). We pre-train at 4,096 context length and leave the option open for post-training extension to 8K or 16K for applications that need longer clinical documents.
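A minimal NumPy sketch of the rotation (our own illustrative implementation, not the exact kernel used in training) makes the relative-position property concrete: the query-key score is unchanged when both positions shift by the same amount.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate consecutive dimension pairs of x by position-dependent angles.
    Pair i rotates with frequency base**(-2i/d), as in the RoPE formulation."""
    d = x.shape[-1]
    freqs = base ** (-np.arange(0, d, 2) / d)
    angles = pos * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    out = np.empty_like(x)
    x1, x2 = x[0::2], x[1::2]
    out[0::2] = x1 * cos - x2 * sin   # standard 2D rotation of each pair
    out[1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.normal(size=64), rng.normal(size=64)
s_near = rope(q, pos=10) @ rope(k, pos=3)      # relative offset 7
s_far = rope(q, pos=110) @ rope(k, pos=103)    # same offset, shifted by 100
print(np.allclose(s_near, s_far))  # True: the score depends only on i - j
```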
Activation Function: SwiGLU
The feed-forward network (FFN) in each transformer block transforms the hidden representation through two linear layers with a non-linearity in between. The original transformer used ReLU. GPT-2 used GELU. Current best practice, following PaLM and LLaMA, uses SwiGLU.
SwiGLU is a gated activation: it takes two linear projections of the input, applies a Swish activation to one, and multiplies them element-wise. This gating mechanism gives the model more expressive power per parameter in the FFN block. The cost is that SwiGLU requires three weight matrices in the FFN instead of two (the gate projection, the up projection, and the down projection), so the FFN intermediate dimension is adjusted downward to keep the total parameter count constant. With a standard FFN you might use an intermediate size of 4 * hidden_dim. With SwiGLU, you use (8/3) * hidden_dim rounded to the nearest multiple of 256 for hardware efficiency.
Concrete Architecture
| Component | Value | Notes |
|---|---|---|
| Architecture | Decoder-only transformer | Autoregressive, causal attention mask |
| Parameters | 6.7B | Commonly called “7B” |
| Hidden dimension | 4,096 | Standard for 7B class |
| Number of layers | 32 | |
| Attention heads (query) | 32 | 128 dimensions per head |
| KV heads (GQA) | 8 | 4 query heads per KV group |
| FFN intermediate dimension | 11,008 | (8/3) * 4096, rounded to multiple of 256 |
| Vocabulary size | 64,000 | Custom medical tokenizer |
| Context length | 4,096 tokens | RoPE allows post-training extension |
| Position encoding | RoPE | Relative position, extensible |
| Activation | SwiGLU | Gated activation in FFN |
| Normalization | RMSNorm (pre-norm) | Applied before attention and FFN, not after |
| Tied embeddings | No | Input and output embeddings are separate |
The parameter count breaks down as follows. The input embedding is 64,000 * 4,096 = 262M parameters. Each transformer block contains the attention projections (Q and output at 4096 x 4096 each, plus K and V reduced by GQA to 4096 x 1024 each, roughly 42M per layer) and the SwiGLU FFN (3 * 4096 * 11,008, roughly 135M per layer). Across 32 layers, that is approximately 32 * 177M = 5.7B parameters. Add the untied output projection (another 262M) and the RMSNorm weights, and you arrive at roughly 6.2B total. The GQA sharing trims a few hundred million parameters relative to the 6.7B of a full-MHA LLaMA-style model; we use 6.7B as the round working number in the cost and memory estimates that follow, and both counts fall in the conventional "7B" class.
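The accounting can be checked with a short script (our own helper; norms are omitted as negligible):

```python
def param_count(vocab=64000, d=4096, n_layers=32, n_heads=32, n_kv_heads=8,
                d_ffn=11008, tied_embeddings=False):
    """Approximate decoder-only parameter count for the architecture table above."""
    head_dim = d // n_heads
    attn = 2 * d * d + 2 * d * (n_kv_heads * head_dim)  # Wq, Wo full; Wk, Wv GQA-reduced
    ffn = 3 * d * d_ffn                                 # gate, up, down projections
    embeddings = vocab * d * (1 if tied_embeddings else 2)
    return n_layers * (attn + ffn) + embeddings

print(round(param_count() / 1e9, 2))                          # 6.19: ~6.2B with GQA, 64K vocab
print(round(param_count(n_kv_heads=32, vocab=32000) / 1e9, 2))  # 6.74: the LLaMA-7B-style count
```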

Scaling Laws: How We Chose 7B and 200B Tokens
Scaling laws give us a principled way to decide model size and training data quantity for a given compute budget.
The Chinchilla scaling law (Hoffmann et al., 2022) found that the compute-optimal ratio is roughly 20 tokens per parameter. For a 7B model, that suggests 140B tokens. Training beyond this point yields diminishing returns per FLOP.
However, the LLaMA approach (Touvron et al., 2023) demonstrated that over-training smaller models beyond the Chinchilla-optimal point produces models that perform better at inference time per parameter. LLaMA 7B was trained on 1T tokens (140x the parameters, vs Chinchilla’s 20x). The logic: inference cost scales with model size, but training cost is paid once. If you are going to serve the model millions of times, spending extra on training to get a better small model is worth it.
We train on 200B tokens (roughly 29x the parameter count). This is a compromise: we don’t have 1T tokens of high-quality medical data, and the returns from over-training diminish beyond a point. 200B tokens gives us meaningful over-training benefits while staying within our data budget.
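The token budget translates into compute via the standard ~6 FLOPs per parameter per token approximation. A rough estimate, assuming 40% model FLOPs utilization (our assumption, not a measured number):

```python
def train_flops(n_params, n_tokens):
    """Standard approximation: ~6 FLOPs per parameter per token (forward + backward)."""
    return 6 * n_params * n_tokens

flops = train_flops(6.7e9, 200e9)      # ~8e21 FLOPs for the full run
a100_bf16 = 312e12                     # A100 peak BF16 throughput, FLOP/s
mfu = 0.4                              # assumed model FLOPs utilization
gpu_hours = flops / (a100_bf16 * mfu) / 3600
print(round(gpu_hours))  # ~17,900 single-A100 hours, i.e. the "roughly 18,000" cited in Step 5
```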
Step 3: Tokenizer Design
The tokenizer converts raw text into integer token IDs that the model processes. This is the very first step in the pipeline, and mistakes here propagate through everything downstream. A bad tokenizer means wasted context window, fragmented representations, and a model that works harder to learn what should be simple concepts.
How BPE Works
Byte-Pair Encoding (BPE) is the standard tokenization algorithm for modern language models. The process:
1. Start with a base vocabulary of individual bytes (or characters). This guarantees that any input text can be encoded, even if it contains rare characters.
2. Count all adjacent pairs of tokens in the training corpus.
3. Merge the most frequent pair into a single new token. “t” + “h” becomes “th”.
4. Repeat steps 2-3 for the desired number of merge operations (vocabulary size minus base vocabulary size).
After training, common words and subwords become single tokens (“the”, “ing”, “tion”), while rare words are composed from smaller pieces. The vocabulary size controls the granularity: larger vocabulary means more words are single tokens (less splitting), but the embedding table grows, and rare tokens have fewer training examples.
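A toy implementation of the merge loop shows the mechanics (illustrative only; production training uses SentencePiece, as described below):

```python
from collections import Counter

def bpe_train(words, num_merges):
    """Toy BPE: repeatedly merge the most frequent adjacent symbol pair.
    `words` maps a word (as a tuple of symbols) to its corpus frequency."""
    vocab = dict(words)
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for word, freq in vocab.items():
            for a, b in zip(word, word[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        (a, b), _ = pairs.most_common(1)[0]
        merges.append((a, b))
        new_vocab = {}
        for word, freq in vocab.items():
            merged, i = [], 0
            while i < len(word):
                if i + 1 < len(word) and (word[i], word[i + 1]) == (a, b):
                    merged.append(a + b)
                    i += 2
                else:
                    merged.append(word[i])
                    i += 1
            new_vocab[tuple(merged)] = freq
        vocab = new_vocab
    return merges

corpus = {tuple("acetaminophen"): 50, tuple("the"): 1000, tuple("then"): 200}
print(bpe_train(corpus, 2))  # [('h', 'e'), ('t', 'he')]: "he" first, then "the"
```

With enough merges on a medical corpus, frequent terms like "acetaminophen" eventually become single tokens, which is exactly the effect the next section exploits.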

Why General Tokenizers Fail on Medical Text
A general-purpose BPE tokenizer trained on web text (Common Crawl, Wikipedia, books) sees medical terms infrequently. “Acetaminophen” appears far less often than “the” or “and”, so BPE never merges it into a single token. Instead, it gets split:
General tokenizer (LLaMA SentencePiece):
"The patient was prescribed acetaminophen 500mg for pain management"
→ ["The", " patient", " was", " prescribed", " acet", "amin", "ophen", " 500", "mg", " for", " pain", " management"]
→ 12 tokens
Medical tokenizer (ours):
"The patient was prescribed acetaminophen 500mg for pain management"
→ ["The", " patient", " was", " prescribed", " acetaminophen", " 500mg", " for", " pain", " management"]
→ 9 tokens
Three fewer tokens for one sentence. Now multiply this across a 2,000-word clinical note. The general tokenizer might produce 2,800 tokens where the medical tokenizer produces 1,900. That is a 47% overhead. At 4,096 context length, you fit the entire note with the medical tokenizer but must truncate with the general one.
A more aggressive example with a radiology report fragment:
General tokenizer:
"Findings: Bilateral pleural effusions with associated atelectasis.
Cardiomediastinal silhouette is within normal limits.
No pneumothorax."
→ 38 tokens (splits "pleural", "effusions", "atelectasis",
"cardiomediastinal", "silhouette", "pneumothorax")
Medical tokenizer:
→ 24 tokens (all medical terms are single tokens or 2-token compounds)
Building the Medical Tokenizer
We train a BPE tokenizer from scratch on the medical corpus (the same 200B tokens used for model pre-training). The training process:
1. Corpus preparation: Sample a representative subset (10B tokens worth of raw text) spanning all document types: clinical notes, PubMed abstracts, pharmaceutical literature, radiology reports, discharge summaries, ICD/CPT code descriptions. The sampling must be proportional to the desired distribution, not the raw availability. PubMed has billions of tokens, but clinical notes are more important for our use case and are harder to source.
2. Pre-tokenization: Before BPE training, split on whitespace and punctuation to establish word boundaries. This prevents BPE from merging across word boundaries (you don’t want “pain management” to become a single token because the model needs to understand each word independently).
3. BPE training: Train with a target vocabulary of 64,000 tokens using SentencePiece. The training runs for about 2 hours on a single machine.
4. Special tokens: Reserve tokens for padding (<pad>), beginning of sequence (<s>), end of sequence (</s>), and unknown (<unk>). Also reserve tokens for section headers common in clinical notes ([HISTORY], [ASSESSMENT], [PLAN], etc.) so they are always single tokens.
5. Validation: Measure fertility (average tokens per word) on held-out medical text. Target: <1.5 for medical terms, <1.2 for common English. General tokenizers score 2.5-3.5 on medical terms.
Vocabulary Size Trade-offs
| Vocab Size | Pros | Cons |
|---|---|---|
| 32,000 | Smaller embedding table (131M params), every token has many training examples | Medical terms heavily split, high fertility |
| 64,000 | Most medical terms are single tokens, good fertility | Embedding table doubles (262M params), rare tokens have fewer examples |
| 128,000 | Nearly all medical terms single tokens | Embedding table is 524M params (8% of total model), many tokens appear <100 times in training data |
We chose 64,000 as the balance point. The 262M parameter embedding table is 3.9% of total model parameters, which is reasonable. The fertility improvement over 32K is significant (going from 2.5 average to 1.4 average on medical terms). Going to 128K gives marginal fertility improvement but doubles the embedding cost and creates many undertrained tokens.
The Fertility Metric
Fertility measures tokenizer efficiency: how many tokens does the tokenizer produce per word in the input text? Lower is better.
Fertility = (number of tokens produced) / (number of whitespace-separated words in input)
For a medical tokenizer to be worth the cost of training from scratch, it needs to demonstrate clear fertility improvement on medical text:
| Text Type | General Tokenizer Fertility | Medical Tokenizer Fertility | Improvement |
|---|---|---|---|
| Clinical notes | 1.65 | 1.18 | 28% |
| Radiology reports | 1.82 | 1.22 | 33% |
| Pharmaceutical text | 1.91 | 1.25 | 35% |
| General English | 1.15 | 1.17 | -2% (slightly worse) |
| Combined medical corpus | 1.72 | 1.20 | 30% |
The medical tokenizer is slightly worse on general English because some of its vocabulary budget is allocated to medical terms instead of general English subwords. This is an acceptable trade-off: the model is built for medical text.
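The metric itself is a one-liner. Here is a sketch with two hypothetical tokenizers (crude stand-ins, purely for illustration, not real tokenizer behavior):

```python
def fertility(tokenize, text):
    """Average tokens produced per whitespace-separated word."""
    words = text.split()
    return sum(len(tokenize(w)) for w in words) / len(words)

# Hypothetical tokenizers: the "general" one splits long medical words into
# 4-character chunks; the "medical" one keeps every word whole.
general = lambda w: [w[i:i + 4] for i in range(0, len(w), 4)] if len(w) > 6 else [w]
medical = lambda w: [w]

note = "patient prescribed acetaminophen for postoperative pain"
print(fertility(general, note))  # 2.5
print(fertility(medical, note))  # 1.0
```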
Step 4: Optimization - Making Training Converge
Pre-training a language model is an optimization problem. You are minimizing next-token prediction loss over 200 billion tokens. The optimization choices determine whether training converges to a good solution, diverges, or gets stuck.
Training Objective
The objective is straightforward: next-token prediction (causal language modeling). Given a sequence of tokens [t1, t2, ..., tn], the model predicts each t_i given all preceding tokens [t1, ..., t_{i-1}]. The loss is the cross-entropy between the model’s predicted probability distribution over the vocabulary and the actual next token.
This single objective, applied over 200B diverse medical tokens, teaches the model everything: medical terminology, clinical reasoning patterns, drug interactions, anatomy, diagnostic criteria, and even the formatting conventions of different document types. No task-specific labels needed. The data IS the supervision.
Optimizer Selection: Why AdamW
SGD (stochastic gradient descent) is simple and memory-efficient but converges slowly on transformer training due to the highly non-convex, variable-curvature loss landscape. Training would take 5-10x longer.
Adam maintains per-parameter moving averages of the gradient (first moment) and the squared gradient (second moment). This adaptive learning rate helps each parameter converge at its own pace. Adam works well, but has a known issue: the L2 regularization term interacts poorly with the adaptive learning rates, making weight decay less effective.
AdamW decouples weight decay from the gradient update. Instead of adding the L2 penalty to the gradient (which then gets scaled by Adam’s adaptive learning rate), AdamW applies weight decay directly to the parameters after the Adam update. This produces more consistent regularization and better generalization.
AdamW is the universal standard for transformer pre-training. We use it with:
- `beta1 = 0.9` (first moment decay)
- `beta2 = 0.95` (second moment decay; 0.95 instead of the default 0.999, following LLaMA, which stabilizes training)
- `weight_decay = 0.1`
- `epsilon = 1e-8`
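A single AdamW update, written out in NumPy (a sketch of the standard algorithm with our hyperparameters, not the fused kernel used in practice):

```python
import numpy as np

def adamw_step(p, g, m, v, t, lr=3e-4, beta1=0.9, beta2=0.95,
               eps=1e-8, weight_decay=0.1):
    """One AdamW update. Weight decay is applied directly to the parameters,
    decoupled from the adaptive gradient step (the 'W' in AdamW)."""
    m = beta1 * m + (1 - beta1) * g              # first moment (mean of gradients)
    v = beta2 * v + (1 - beta2) * g * g          # second moment (mean of squared gradients)
    m_hat = m / (1 - beta1 ** t)                 # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (np.sqrt(v_hat) + eps)  # adaptive gradient step
    p = p - lr * weight_decay * p                # decoupled weight decay
    return p, m, v

p, m, v = np.ones(4), np.zeros(4), np.zeros(4)
g = np.full(4, 0.5)
p_new, m, v = adamw_step(p, g, m, v, t=1)
print(p_new)  # every parameter moved against the gradient, then slightly decayed
```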
Memory cost of AdamW: for each parameter, AdamW stores the first moment and second moment, both the same size as the parameter itself. So a 7B parameter model requires 7B (params) + 7B (first moment) + 7B (second moment) = 21B values in FP32. At 4 bytes each, that is 84 GB just for the optimizer states, which exceeds a single A100’s 80 GB. This is one of the reasons we need distributed training (covered in Step 5).
Learning Rate Schedule
The learning rate schedule has two phases:
Warmup (first 2,000 steps): The learning rate increases linearly from 0 to the peak value. This prevents early instability when the model’s parameters are randomly initialized and gradients are noisy and large. Starting with a high learning rate on random weights can cause irreversible divergence.
Cosine decay (remaining steps): After warmup, the learning rate follows a cosine curve from the peak value down to 10% of peak. This gradual reduction lets the model make large updates early (exploring the loss landscape) and fine-grained updates later (settling into a good minimum).
Peak learning rate for 7B: 3e-4. This is higher than what you would use for fine-tuning (typically 1e-5 to 5e-5) because pre-training starts from random initialization and needs to make large parameter updates.
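The full schedule is a few lines of code. A sketch with our constants (2,000 warmup steps, 50,000 total steps, peak 3e-4, 10% floor):

```python
import math

def lr_at(step, peak=3e-4, warmup=2000, total=50000, floor_frac=0.1):
    """Linear warmup to `peak`, then cosine decay down to `floor_frac` * peak."""
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / (total - warmup)   # 0.0 at peak, 1.0 at end
    floor = peak * floor_frac
    return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * progress))

print(round(lr_at(1000), 8))   # 0.00015: halfway through warmup
print(round(lr_at(2000), 8))   # 0.0003: peak
print(round(lr_at(50000), 8))  # 3e-05: 10% of peak at the end
```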
Batch Size and Gradient Accumulation
The global batch size is 4 million tokens per optimization step. This is not a single forward pass. Instead:
- Each GPU processes a micro-batch of 4 sequences, each 4,096 tokens long: 4 x 4,096 = 16,384 tokens per GPU per micro-batch.
- We accumulate gradients over 16 micro-batches before performing an optimizer step: 16 x 16,384 = 262,144 tokens per GPU per step.
- Summed across the data-parallel replicas: 262,144 x ~16 = ~4M tokens per step (the exact number depends on how the 256 GPUs are split between data and model parallelism, which we detail in Step 5).
Why 4M tokens? Larger batch sizes improve hardware utilization (more parallel work per step) and produce more stable gradient estimates. But they also reduce the total number of optimization steps for a given token budget. At 200B tokens with 4M tokens per step, we get 50,000 optimization steps. Empirically, 50K-100K steps is a good range for stable 7B training. Fewer steps (larger batches) risk under-optimization. More steps (smaller batches) mean slower wall-clock training.
Mixed Precision Training: BF16
Training in full FP32 precision wastes memory and compute. Modern GPUs have specialized hardware (tensor cores) that operate at 2x speed on reduced precision formats.
FP16 (16-bit floating point) halves memory and doubles throughput, but has a narrow dynamic range. Gradient values that are very small (common in deep transformers) can underflow to zero. This requires loss scaling: artificially multiplying the loss by a large factor to keep gradients in the representable range, then dividing back after the backward pass. Loss scaling works, but adds complexity and occasional training instabilities.
BF16 (bfloat16) has the same 16-bit storage as FP16 but allocates more bits to the exponent and fewer to the mantissa. This gives it the same dynamic range as FP32 (no underflow issues) at the cost of slightly less precision. You don’t need loss scaling with BF16.
We use BF16 for forward and backward passes, with FP32 for the optimizer states and master copy of the weights. This is the standard “mixed precision” setup: compute in BF16, accumulate in FP32. Memory savings: model weights in BF16 are 14 GB instead of 28 GB (for 7B params), and activations are similarly halved. The A100 achieves 312 TFLOPS in BF16 versus 156 TFLOPS in TF32 (and only 19.5 TFLOPS in standard FP32), so we get at least a 2x throughput improvement with negligible quality loss.
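The underflow problem is easy to demonstrate. NumPy has FP16 but not BF16 (frameworks like PyTorch and JAX provide bfloat16), so this sketch shows the FP16 half of the story:

```python
import numpy as np

# FP16's smallest positive subnormal is ~6e-8; tinier gradients underflow to zero.
tiny_grad = np.float32(1e-8)
print(np.float16(tiny_grad))       # 0.0: the gradient signal is lost

# Loss scaling rescues FP16: scale up before the cast, divide back after.
scale = np.float32(1024.0)
scaled = np.float16(tiny_grad * scale)
print(np.float32(scaled) / scale)  # ~1e-8: the signal survives

# BF16 keeps FP32's 8 exponent bits, so 1e-8 is representable without any scaling.
```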
Gradient Clipping
Gradient norms occasionally spike during training, especially when the model encounters unusual data batches (a sequence of rare medical terms, or a malformed document that slipped through data cleaning). A single large gradient update can destabilize the entire training run.
We clip the global gradient norm to 1.0. Before each optimizer step, compute the L2 norm of the entire gradient vector across all parameters. If the norm exceeds 1.0, scale all gradients down proportionally so the norm equals 1.0. This prevents catastrophic updates while preserving the gradient direction.
Gradient clipping is not optional at this scale. Without it, loss spikes from bad data batches can cascade into divergence that wastes days of compute.
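Global-norm clipping is a few lines of NumPy (a sketch of what `torch.nn.utils.clip_grad_norm_` does in practice):

```python
import numpy as np

def clip_global_norm(grads, max_norm=1.0):
    """Scale all gradient tensors so the global L2 norm is at most max_norm,
    preserving the overall gradient direction."""
    total = float(np.sqrt(sum(np.sum(g * g) for g in grads)))
    if total > max_norm:
        scale = max_norm / total
        grads = [g * scale for g in grads]
    return grads, total

grads = [np.array([3.0, 4.0]), np.array([12.0])]        # global norm = sqrt(9+16+144) = 13
clipped, norm = clip_global_norm(grads)
print(norm)                                              # 13.0
print(np.sqrt(sum(np.sum(g * g) for g in clipped)))      # 1.0 (up to float rounding)
```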
Training Stability: What Goes Wrong
Even with the right optimizer, schedule, and clipping, 7B-parameter training over 200B tokens is a long optimization run (50,000 steps over 14 days). Several failure modes:
- Loss spikes: Sudden increases in loss, often caused by a data batch with unusual characteristics. Usually the model recovers within a few hundred steps if gradient clipping is working. If the spike doesn’t recover, you roll back to the last checkpoint and skip the problematic data.
- Slow divergence: Loss gradually increases over thousands of steps. Usually indicates the learning rate is too high for the current phase of training, or weight decay is too low. Hard to catch without continuous monitoring.
- NaN/Inf values: Numerical overflow, usually in attention logits when attention scores become very large. This is more common with FP16 than BF16, which is one reason we prefer BF16.
- Loss plateau: Loss stops decreasing but hasn’t reached the expected level. Can indicate insufficient learning rate, data quality issues, or a data loader bug that is repeating the same data.
We monitor training loss, gradient norm, learning rate, and per-layer activation statistics every 10 steps. Alert thresholds are set for loss spikes (>2x running average), gradient norm spikes (>10x running average), and NaN detection.
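The loss-spike alert can be as simple as comparing each value to a running average. An illustrative sketch (the real pipeline would use a windowed average over recent steps; we use an exponential average here for brevity):

```python
def spike_alerts(values, window=100, factor=2.0):
    """Flag steps where a metric exceeds `factor` times its running average."""
    alerts, running = [], None
    for step, v in enumerate(values):
        if running is not None and v > factor * running:
            alerts.append(step)
        # exponential running average as a cheap stand-in for a windowed mean
        running = v if running is None else running + (v - running) / window
    return alerts

losses = [2.0] * 50 + [5.5] + [2.0] * 49   # a single loss spike at step 50
print(spike_alerts(losses))                 # [50]
```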
Step 5: Distributed Training - Scaling Across 256 GPUs
A 7B parameter model with AdamW optimizer states and activations cannot fit on a single GPU. Even if it could, training on 200B tokens on a single A100 would take roughly 18,000 GPU-hours, which is over 2 years of continuous operation. We need to distribute across 256 GPUs for two reasons: memory (the model doesn’t fit on one GPU) and time (we need to finish in 14 days).
Memory Math: Why You Can’t Fit on One GPU
Let’s account for every byte:
| Component | Size (FP32) | Size (BF16/mixed) | Notes |
|---|---|---|---|
| Model parameters | 26.8 GB | 13.4 GB (BF16) | 6.7B params x 4 bytes (FP32) or 2 bytes (BF16) |
| Optimizer states (AdamW) | 53.6 GB | 53.6 GB (always FP32) | 2 states x 6.7B x 4 bytes |
| Gradients | 26.8 GB | 13.4 GB (BF16) | Same size as parameters |
| Master weights (FP32 copy) | 26.8 GB | 26.8 GB | Required for mixed precision |
| Activations (per micro-batch) | ~20-40 GB | ~10-20 GB | Depends on sequence length and batch size |
| Total | ~154 GB | ~117-127 GB | Exceeds A100 80 GB |
Even with BF16, the total exceeds a single A100’s 80 GB memory. This is before accounting for CUDA kernel overhead, memory fragmentation, and the peak memory during the backward pass (which temporarily stores additional intermediate values).
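The table's arithmetic is easy to reproduce; a small sketch in decimal GB, with activation memory assumed at the midpoint of the table's 10-20 GB range:

```python
def training_memory_gb(n_params, activations_gb=15.0):
    """Per-component training memory in decimal GB for mixed-precision AdamW.

    Matches the table's accounting; activation memory is an assumed midpoint.
    """
    mem = {
        "params_bf16": n_params * 2 / 1e9,      # 2 bytes per parameter
        "grads_bf16": n_params * 2 / 1e9,
        "optimizer_fp32": n_params * 8 / 1e9,   # 2 Adam moments x 4 bytes each
        "master_fp32": n_params * 4 / 1e9,      # FP32 copy for mixed precision
        "activations": activations_gb,
    }
    mem["total"] = sum(mem.values())
    return mem
```

For 6.7B parameters this lands around 122 GB, squarely inside the table's 117-127 GB range and well over a single A100's 80 GB.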
Parallelism Strategies
There are four main approaches to distributing training across multiple GPUs. Most production training systems use a combination.
Data Parallelism (DP): Each GPU holds a complete copy of the model and processes a different data batch. After the forward and backward pass, gradients are averaged across all GPUs (via all-reduce), and each GPU updates its local copy identically. Simple and effective, but each GPU must hold the full model, optimizer states, and gradients, so it doesn’t solve the memory problem for large models.
Fully Sharded Data Parallelism (FSDP/ZeRO): This extends data parallelism by sharding the model, optimizer states, and gradients across GPUs instead of replicating them. ZeRO (from DeepSpeed) introduced three stages:
- Stage 1: Shard optimizer states. Each GPU stores optimizer states for only `1/N` of the parameters (where N is the number of GPUs), reducing optimizer memory by Nx. For 256 GPUs, optimizer states drop from 53.6 GB to 0.21 GB per GPU.
- Stage 2: Shard optimizer states AND gradients. Gradient memory drops from 13.4 GB to 0.05 GB per GPU.
- Stage 3: Shard optimizer states, gradients, AND parameters. Each GPU stores only `1/N` of the model. Parameter memory drops from 13.4 GB to 0.05 GB per GPU. The trade-off: parameters must be gathered from other GPUs before each forward/backward computation, increasing communication.
PyTorch FSDP implements a similar approach natively in PyTorch, without requiring the DeepSpeed library. The concept is the same: shard everything, gather when needed, reduce after computation.
Tensor Parallelism (TP): Splits individual layer computations across GPUs. For example, a linear layer with weight matrix [4096, 11008] can be split column-wise across 4 GPUs, each holding [4096, 2752]. The matrix multiplication runs in parallel, and results are combined with an all-reduce. This reduces per-GPU memory for the layer and distributes computation, but requires high-bandwidth interconnects (NVLink) because communication happens within every layer forward/backward pass.
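A toy demonstration of the column split, with tiny matrices standing in for the [4096, 11008] layer. Note that a column split recombines partial outputs by concatenation (an all-gather); a row-split layer would sum partial results with an all-reduce instead:

```python
def matmul(x, w):
    """Naive matmul: x is [m][k], w is [k][n]."""
    return [[sum(x[i][t] * w[t][j] for t in range(len(w)))
             for j in range(len(w[0]))] for i in range(len(x))]

def split_columns(w, parts):
    """Split weight w column-wise into `parts` shards, one per 'GPU'."""
    n = len(w[0]) // parts
    return [[row[p * n:(p + 1) * n] for row in w] for p in range(parts)]

x = [[1.0, 2.0], [3.0, 4.0]]                      # [batch=2, hidden=2]
w = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 2.0]]  # [2, 4], split across 2 "GPUs"

shards = split_columns(w, 2)
partials = [matmul(x, s) for s in shards]         # independent, no communication
y_tp = [pa + pb for pa, pb in zip(*partials)]     # concatenation = all-gather

assert y_tp == matmul(x, w)                       # identical to unsharded result
```

Each shard's matmul runs with no communication at all; the cost is paid once per layer when the partial outputs are recombined, which is why TP wants the fastest interconnect available.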
Pipeline Parallelism (PP): Assigns different layers to different GPUs. GPU 0 handles layers 0-7, GPU 1 handles layers 8-15, etc. Requires careful scheduling (micro-batching) to keep all GPUs busy, because a naive implementation has each GPU idle while waiting for the previous stage. The “1F1B” (one forward, one backward) schedule minimizes pipeline bubbles but doesn’t eliminate them.
What We Actually Use
For a 7B model on 256 A100s, the configuration is:
- FSDP (ZeRO Stage 3) across all 256 GPUs for parameter, gradient, and optimizer sharding
- Tensor Parallelism with degree 2 within each node (pairs of GPUs connected by NVLink handle TP, which has the highest communication volume)
- No Pipeline Parallelism (7B is not large enough to benefit; PP adds bubble overhead that isn’t justified)
With this configuration, each GPU’s memory footprint:
| Component | Per-GPU Size | Calculation |
|---|---|---|
| Parameters (sharded, BF16) | ~0.1 GB | 13.4 GB / 128 FSDP shards (256 GPUs / 2 TP degree) |
| Optimizer states (sharded, FP32) | ~0.4 GB | 53.6 GB / 128 |
| Gradients (sharded, BF16) | ~0.1 GB | 13.4 GB / 128 |
| Unsharded parameters (gathered for compute) | ~6.7 GB | Full layer params gathered temporarily during forward/backward |
| Activations | ~15 GB | Micro-batch of 4 sequences x 4096 tokens |
| Communication buffers | ~5 GB | Staging area for all-gather and reduce-scatter |
| Total per GPU | ~27 GB | Well within A100 80 GB, leaving headroom |
The headroom is important. It allows activation checkpointing to be used selectively (recomputing some activations during the backward pass instead of storing them) and provides buffer for memory fragmentation.

Communication Overhead
Distributed training is bottlenecked by communication, not computation. Every optimizer step requires:
- All-gather of parameters: Before each layer’s forward pass, FSDP gathers the full parameters from all shards. For 7B params in BF16, this is 13.4 GB total, but it happens layer-by-layer (overlapped with computation). Each layer is ~0.4 GB.
- Reduce-scatter of gradients: After each layer’s backward pass, gradients are reduced (summed) and scattered (each GPU keeps its shard). Same volume as the all-gather.
- All-reduce for tensor parallelism: Within each TP group (2 GPUs), partial results are combined. This uses NVLink (600 GB/s bidirectional on A100 nodes), so it is fast.
Total communication per step: roughly 2 * 13.4 GB (all-gather + reduce-scatter for FSDP) plus TP all-reduces within each node. On a cluster with 400 Gbps InfiniBand between nodes, the FSDP communication takes approximately:
26.8 GB / (400 Gbps / 8 bits per byte) = 26.8 / 50 = 0.54 seconds
This is overlapped with computation, but not perfectly. In practice, communication overhead reduces effective GPU utilization to 40-55% of peak TFLOPS. This is typical for large-scale training and is accounted for in our compute budget.
Training Throughput
With 256 A100s at ~45% utilization (accounting for communication overhead, pipeline bubbles, and memory operations):
- Per-GPU throughput: `312 TFLOPS (BF16 peak) * 0.45 = 140 TFLOPS` effective
- Cluster throughput: `256 * 140 = 35,840 TFLOPS`
- Tokens per second (we'll derive this from FLOPs per token in Step 6): approximately 85,000 tokens per second across the cluster
- Time to process 200B tokens: `200e9 / 85,000 = ~2.35 million seconds = ~27 days`
That is nearly double our 14-day target. There are two levers:
- Increase utilization to 55% through better communication overlap (achievable with optimized FSDP configurations and NVLink-aware TP): throughput increases to ~104,000 tokens/sec, bringing training time to ~22 days. Still over budget.
- Activation checkpointing: Trade compute for memory. Recompute activations during the backward pass instead of storing them. This frees memory, allowing larger micro-batch sizes (8 instead of 4), which improves GPU utilization by reducing the ratio of communication to computation. With both optimizations, we can target 55-60% utilization and ~130,000 tokens/sec, which brings training to approximately 18 days.
To hit 14 days reliably, we accept that we may need to run 18 days at 130K tokens/sec, or optimize further. In practice, the throughput numbers improve as the engineering team tunes the training configuration in the first few days. Most large-scale training runs report 50-60% MFU (Model FLOPs Utilization) after optimization, which puts us in the 14-18 day range. We budget for 18 days and treat 14 days as an optimistic target.
Checkpointing
At 130,000 tokens per second, 1,000 optimization steps (each consuming 4M tokens) take roughly 8.5 hours. We save a full checkpoint every 1,000 steps. Each checkpoint includes:
- Sharded model parameters (13.4 GB total across all shards)
- Sharded optimizer states (53.6 GB total)
- Learning rate scheduler state
- Data loader state (which samples have been seen)
- Random number generator states for all GPUs
Total checkpoint size: approximately 70 GB. Writing to a distributed filesystem (or S3) at each checkpoint takes 2-5 minutes, during which training pauses. Over a 14-day training run with ~50,000 steps, that’s 50 checkpoints and roughly 2-4 hours of total checkpoint time, or 1% overhead.
We keep the last 5 checkpoints and delete older ones (except every 10th checkpoint, which is kept permanently for analysis). A hardware failure that kills the training run loses at most 8.5 hours of compute.
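The retention policy can be expressed as a small pure function (names and defaults are mine, mirroring the 1,000-step interval, last-5 window, and every-10th rule above):

```python
def checkpoints_to_keep(latest_step, interval=1000, window=5, permanent_every=10):
    """Return the checkpoint steps that survive the retention policy.

    Keep the most recent `window` checkpoints, plus every
    `permanent_every`-th checkpoint permanently for later analysis.
    """
    all_steps = list(range(interval, latest_step + 1, interval))
    keep = set()
    for i, step in enumerate(all_steps, start=1):
        if i % permanent_every == 0:   # every 10th checkpoint: kept forever
            keep.add(step)
    keep.update(all_steps[-window:])   # sliding window of recent saves
    return sorted(keep)
```

Run after each save, this tells the checkpoint manager which files to delete; everything not in the returned list is garbage-collected.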
Step 6: Compute Budget - What This Costs
This section does the math. No hand-waving.
FLOPs Per Token
The standard approximation for transformer training FLOPs per token is 6N, where N is the number of parameters. The factor of 6 comes from:
- Forward pass: ~2N FLOPs per token (each parameter is involved in one multiply-add, and there are roughly N multiply-adds in the forward pass)
- Backward pass: ~4N FLOPs per token (roughly 2x the forward pass: computing gradients with respect to both the activations and the parameters)
For our 6.7B parameter model: 6 * 6.7e9 = 40.2e9 FLOPs per token (approximately 40 GFLOPs per token).
Total Training FLOPs
Total FLOPs = FLOPs per token * number of tokens
= 40.2e9 * 200e9
= 8.04e21 FLOPs
That’s 8 zettaFLOPs. For reference, training LLaMA 7B on 1T tokens required approximately 40 zettaFLOPs.
GPU-Hours Calculation
An NVIDIA A100 80GB achieves 312 TFLOPS peak in BF16. With realistic utilization (accounting for communication, memory operations, and idle time):
- Conservative MFU: 45% of peak = 140 TFLOPS effective
- Optimistic MFU: 55% of peak = 172 TFLOPS effective
Conservative estimate:
GPU-hours = Total FLOPs / (effective TFLOPS per GPU * 3600 seconds/hour)
= 8.04e21 / (140e12 * 3600)
= 8.04e21 / 5.04e17
= 15,952 GPU-hours
Optimistic estimate:
= 8.04e21 / (172e12 * 3600)
= 8.04e21 / 6.19e17
= 12,988 GPU-hours
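The same arithmetic as a reusable helper (peak TFLOPS defaults to the A100 BF16 figure; MFU is the assumption you vary):

```python
def train_gpu_hours(n_params, n_tokens, mfu, peak_tflops=312.0):
    """GPU-hours for one training run using the 6*N*D approximation."""
    total_flops = 6 * n_params * n_tokens          # forward (2N) + backward (4N)
    effective_flops = peak_tflops * 1e12 * mfu     # sustained FLOPs per GPU
    return total_flops / (effective_flops * 3600)  # seconds -> hours
```

The small differences from the hand calculations above come from rounding (140 vs 140.4 effective TFLOPS); the estimates agree to within a few hundred GPU-hours.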
Wall-Clock Time
With 256 GPUs:
- Conservative: `15,952 / 256 = 62.3 hours = ~2.6 days` of pure compute
- But this doesn't account for checkpointing overhead (~1%), data loading stalls, and occasional restarts: multiply by 1.15 for operational overhead
- Conservative with overhead: `62.3 * 1.15 = 71.6 hours = ~3 days`
Wait, that’s much less than the 14-18 days I estimated in Step 5. Let me reconcile.
The discrepancy is between “pure FLOP throughput” and “actual token throughput.” The 45% MFU already accounts for communication overhead, but the tokens-per-second calculation in Step 5 was more conservative because it accounted for the full communication and synchronization pattern. Let me recalculate using the FLOP-based approach, which is more standard:
At 45% MFU with 256 A100s:
Cluster effective TFLOPS = 256 * 312 * 0.45 = 35,942 TFLOPS
Time = 8.04e21 / 35,942e12 = 223,749 seconds = 62 hours = 2.6 days
And tokens per second at that throughput:
200e9 tokens / 223,749 seconds = 894,000 tokens/sec
The Step 5 estimate of 85,000-130,000 tokens/sec was too conservative. With proper FSDP overlap and the full cluster, 800K-900K tokens/sec at 45% MFU is the right ballpark. A 7B model is relatively small for 256 GPUs, which means the computation-to-communication ratio is actually not great (a larger model would utilize the cluster more efficiently). More realistically, MFU for a 7B model on 256 GPUs might be 35-40% due to the communication overhead being a larger fraction of total time for a smaller model.
Revised at 35% MFU:
Cluster effective TFLOPS = 256 * 312 * 0.35 = 27,955 TFLOPS
Time = 8.04e21 / 27,955e12 = 287,592 seconds = 80 hours = 3.3 days
With operational overhead (checkpointing, restarts, data loading stalls): ~4-5 days of actual wall-clock time. The 14-day budget provides ample margin for failed runs, hyperparameter tuning experiments, and unexpected issues.
Cloud Cost
On-demand A100 80GB pricing (as of early 2026):
| Provider | Per-GPU-hour | 16,000 GPU-hours | Notes |
|---|---|---|---|
| AWS (p4d.24xlarge) | ~$10-12 | $160K-192K | 8x A100 per instance |
| GCP (a2-ultragpu-8g) | ~$10-11 | $160K-176K | 8x A100 per instance |
| CoreWeave / Lambda | ~$2.50-3.50 | $40K-56K | GPU cloud providers |
| Reserved/spot instances | ~$4-6 | $64K-96K | 1-year commitment or spot |
At $3 per GPU-hour on a GPU cloud provider, the production training run costs:
16,000 GPU-hours * $3 = $48,000
But this is just the final run. The total budget includes:
| Expense | GPU-hours | Cost ($3/hr) |
|---|---|---|
| Final production run | 16,000 | $48,000 |
| Hyperparameter tuning (10 short runs at 10% scale) | 16,000 | $48,000 |
| Tokenizer experiments | 2,000 | $6,000 |
| Failed/restarted runs | 8,000 | $24,000 |
| Evaluation runs | 4,000 | $12,000 |
| Cluster idle time (scheduling gaps) | 5,000 | $15,000 |
| Total | 51,000 | $153,000 |
This is well under the $800K budget. The $500K-$1M estimate from Step 0 assumed on-demand pricing at a major cloud provider. At GPU cloud pricing, the cost is significantly lower. However, you should budget 3-5x the cost of the final training run for experimentation, which is what we see here.

Memory Budget Per GPU (Detailed)
For a single A100 80GB during training with FSDP:
| Component | Memory | Notes |
|---|---|---|
| Sharded parameters (BF16) | 0.1 GB | 13.4 GB / 128 shards |
| Gathered parameters (active layer, BF16) | 0.8 GB | Largest layer, fully gathered |
| Sharded optimizer states (FP32) | 0.6 GB | 80.4 GB / 128 shards (master params + 2 moments in FP32) |
| Sharded gradients (BF16) | 0.1 GB | 13.4 GB / 128 shards |
| Activations (with checkpointing) | 8-12 GB | Depends on micro-batch size and checkpointing granularity |
| Communication buffers | 3-5 GB | All-gather and reduce-scatter staging |
| CUDA context and fragmentation | 3-5 GB | Overhead from CUDA runtime |
| Total | 16-24 GB | ~20-30% of A100 80GB |
This leaves 56-64 GB free, which is more than necessary. For a 7B model, 256 A100s is over-provisioned from a memory perspective. The cluster size is driven by the wall-clock time requirement (finish in 14 days), not memory. You could train this model on 32-64 GPUs if you were willing to train for 30-60 days.
Scaling Laws for Cost Prediction
Want to estimate cost for a different model size? The scaling relationship is approximately:
Cost proportional to N * D
Where N is parameter count and D is token count. If you double the model to 14B and keep 200B tokens: cost roughly doubles. If you also scale tokens to 400B (to maintain the Chinchilla ratio): cost quadruples.
| Model Size | Chinchilla-Optimal Tokens | Estimated GPU-hours | Cost (at $3/hr) |
|---|---|---|---|
| 1B | 20B | 250 | $750 |
| 7B | 140B | 12,000 | $36,000 |
| 7B (200B, our setup) | 200B | 16,000 | $48,000 |
| 13B | 260B | 42,000 | $126,000 |
| 30B | 600B | 225,000 | $675,000 |
| 70B | 1.4T | 1,200,000 | $3,600,000 |
These are rough estimates for the final training run only. Multiply by 3-5x for the full project cost including experimentation.
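A quick sketch of the cost model behind this table, using the 6ND approximation at an assumed 45% MFU and $3/GPU-hour (final-run cost only; multiply by 3-5x for the full project, as noted):

```python
def run_cost(n_params_b, n_tokens_b, gpu_hour_rate=3.0, mfu=0.45,
             peak_tflops=312.0):
    """(GPU-hours, dollars) for a final training run via the 6*N*D rule.

    Inputs are in billions of parameters and billions of tokens.
    """
    flops = 6 * (n_params_b * 1e9) * (n_tokens_b * 1e9)
    gpu_hours = flops / (peak_tflops * 1e12 * mfu * 3600)
    return gpu_hours, gpu_hours * gpu_hour_rate

for size, tokens in [(7, 200), (13, 260), (70, 1400)]:
    hours, dollars = run_cost(size, tokens)
    print(f"{size}B / {tokens}B tokens: ~{hours:,.0f} GPU-hours, ~${dollars:,.0f}")
```

The outputs track the table to within rounding; larger models would also sustain somewhat higher MFU, which this constant-MFU sketch ignores.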
Failure Modes
Training a foundation model has failure modes that don’t exist in fine-tuning or inference. Several are subtle enough to waste weeks of compute before detection.
Loss Spikes
The most common failure during training. The loss suddenly jumps by 2-10x, then gradually recovers (or doesn’t). Causes include:
- Bad data batch: A sequence of corrupted text, extremely rare tokens, or a document in an unexpected language. The gradient from this batch is large and unusual, pushing parameters in a bad direction.
- Learning rate too high: More common in the mid-training phase when the model has already learned some structure and a large update disrupts it.
- Numerical instability in attention: Attention logits can grow very large (especially with long sequences), causing softmax to produce near-zero or near-one values. With FP16, this causes NaN propagation. BF16 is more robust but not immune.
Mitigation: Gradient clipping (already in place), checkpoint every 1,000 steps so you can roll back, and a data filtering pipeline that catches obvious corruption before training. If a loss spike doesn’t recover within 500 steps, roll back to the last checkpoint and skip the problematic data range.
Data Contamination
Medical training data comes from multiple sources: PubMed, clinical note repositories, textbooks, drug databases. If the evaluation benchmarks (USMLE questions, MedNLI) appear in the training data, evaluation metrics will be inflated. The model hasn’t learned medical reasoning; it has memorized the test.
Mitigation: Before training, deduplicate the training corpus against all evaluation sets using n-gram overlap detection (13-gram overlap threshold). Remove any training document that has >70% n-gram overlap with any evaluation sample. This is standard practice but easy to skip under deadline pressure.
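A minimal sketch of the overlap check (word-level n-grams; a production pipeline would normalize text and index the evaluation n-grams once, rather than comparing documents pairwise):

```python
def ngram_set(text, n=13):
    """All word-level n-grams in a document (13-gram, as in the text)."""
    words = text.lower().split()
    return {tuple(words[i:i + n]) for i in range(len(words) - n + 1)}

def contamination_ratio(train_doc, eval_doc, n=13):
    """Fraction of the training doc's n-grams that also appear in the eval sample."""
    train_grams = ngram_set(train_doc, n)
    if not train_grams:
        return 0.0
    overlap = train_grams & ngram_set(eval_doc, n)
    return len(overlap) / len(train_grams)
```

A training document is dropped when `contamination_ratio` exceeds 0.7 against any evaluation sample, matching the >70% threshold above.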
Tokenizer Misalignment
The tokenizer is trained on a sample of the training corpus, not the full corpus. If the sample is not representative (for example, it over-represents PubMed abstracts and under-represents clinical notes), the tokenizer will have poor fertility on the under-represented document types.
Worse, if the tokenizer training data contains formatting artifacts (HTML tags, PDF extraction noise) that are cleaned from the model training data, the tokenizer will allocate vocabulary budget to tokens that never appear during training. Those embedding rows receive no gradient updates and remain at their random initialization, wasting both vocabulary capacity and embedding parameters.
Mitigation: Ensure the tokenizer training sample has the same distribution as the model training data. Run fertility analysis on each document type separately before committing to the tokenizer.
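Fertility (average tokens per whitespace word) is simple to measure per document type. A sketch, where `tokenize` is any callable mapping text to a token list:

```python
def fertility(tokenize, documents):
    """Average tokens per whitespace word over a corpus slice.

    Run separately for each document type (PubMed, clinical notes, ...)
    and compare: a type with much higher fertility is being fragmented
    by the tokenizer.
    """
    total_tokens = sum(len(tokenize(doc)) for doc in documents)
    total_words = sum(len(doc.split()) for doc in documents)
    return total_tokens / max(total_words, 1)
```

With a real tokenizer you would pass its encode function; the degenerate tokenizers in the test below just illustrate the two extremes (one token per word vs. one token per character).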
Hardware Failures
256 GPUs running for 5+ days. The probability of at least one GPU failure is high. Common issues:
- GPU memory errors (ECC failures): The GPU produces incorrect computation results. If detected, the training run crashes. If undetected (silent data corruption), the model trains on corrupted gradients. Silent data corruption is rare but devastating.
- Network link failures: A single failed InfiniBand link can partition the cluster or degrade all-reduce performance by orders of magnitude.
- Node failures: An entire 8-GPU node goes down. Training crashes, must restart from checkpoint.
Mitigation: Elastic training frameworks (like TorchElastic) can detect node failures and restart training with the remaining nodes, automatically adjusting the parallelism configuration. Save checkpoints frequently. Monitor GPU health metrics (temperature, ECC error count, memory utilization) and proactively replace GPUs showing early signs of failure.
Scaling Law Mismatch
You ran scaling law experiments at small scale (1B model, 5B tokens) to predict performance at 7B/200B. But scaling laws are empirical fits, not physical laws. They can break down when:
- The data distribution changes between scales (you have enough high-quality data for a 1B model but are padding with lower-quality data for the 7B run)
- The architecture choices that work at 1B don’t transfer to 7B (less common with standard transformer architectures, but possible with exotic modifications)
- The optimizer hyperparameters need re-tuning at different scales
Mitigation: Run a 1B model to completion as a “proof of concept” before committing the full cluster to the 7B run. Validate that the 1B model’s learning curve matches scaling law predictions. If it deviates significantly, investigate before scaling up.
Operational Concerns
Training Infrastructure
The training stack:
- Framework: PyTorch with native FSDP. We avoid DeepSpeed not because it’s bad, but because PyTorch FSDP is now mature enough and having the parallelism logic in the same framework as the model code reduces integration complexity.
- Cluster manager: Slurm for job scheduling on a bare-metal or cloud GPU cluster. Kubernetes can work but adds overhead for tightly-coupled distributed training jobs.
- Storage: A parallel filesystem (Lustre or GPFS) for checkpoint storage. S3 for long-term checkpoint archival. The data pipeline reads from S3 into a prefetch buffer.
- Monitoring: Weights & Biases for training metrics (loss, gradient norm, learning rate, throughput). Prometheus + Grafana for cluster health (GPU utilization, memory, temperature, network throughput). Custom alerts for loss spikes and NaN detection.
Data Pipeline
The data pipeline deserves its own architecture, but the key components:
- Source collection: PubMed abstracts and full-text articles (~30B tokens), de-identified clinical notes from partner health systems (~40B tokens), medical textbooks and reference materials (~20B tokens), drug databases and pharmaceutical literature (~15B tokens), clinical trial reports (~10B tokens), medical coding documentation (ICD-10, CPT, SNOMED) (~5B tokens), high-quality general English text (~80B tokens, for maintaining general language ability).
- De-identification: All clinical notes pass through a de-identification pipeline that removes names, dates, locations, medical record numbers, and other PHI markers. This is done using a combination of rule-based systems (regex for known patterns like MRN formats) and a fine-tuned NER model for names and locations. The de-identification pipeline runs before tokenization.
- Quality filtering: Remove documents shorter than 50 tokens (after tokenization). Remove documents with perplexity >1000 under a small reference model (indicates garbage text). Remove documents with >50% non-alphanumeric characters (likely OCR errors or formatting artifacts). Deduplicate using MinHash with Jaccard similarity threshold of 0.8.
- Tokenization: Apply the trained medical tokenizer. Pack multiple documents into sequences of 4,096 tokens separated by `</s>` tokens. Packing avoids wasting compute on padding.
- Shuffling: Shuffle the packed sequences. This is important: if the model sees all PubMed abstracts first and then all clinical notes, it will partially forget PubMed by the time it finishes clinical notes. Shuffling ensures each batch contains a mix of document types.
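The cheap per-document checks from the quality-filtering step can be sketched directly. Perplexity filtering and MinHash dedup are omitted since they need a reference model and an index; the thresholds mirror the text:

```python
def passes_quality_filters(doc, token_count, min_tokens=50, max_symbol_ratio=0.5):
    """Cheap per-document quality checks: length and non-alphanumeric ratio.

    `token_count` is the post-tokenization length; the symbol-ratio check
    catches OCR debris and formatting artifacts.
    """
    if token_count < min_tokens:
        return False
    if not doc:
        return False
    non_alnum = sum(1 for ch in doc if not (ch.isalnum() or ch.isspace()))
    return non_alnum / len(doc) <= max_symbol_ratio
```

These filters are deliberately dumb and fast; they run on every document, while the perplexity filter runs only on what survives them.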

Evaluation During Training
You cannot wait until the end of training to discover the model is underperforming. Evaluation happens at three frequencies:
Every 10 steps: Training loss and gradient norm. These are cheap to compute (they come from the training loop itself). Look for loss spikes, gradient explosions, and convergence rate.
Every 1,000 steps (each checkpoint): Validation perplexity on a held-out set of 10M tokens spanning all document types. This measures generalization (is the model overfitting to the training data?). Also run a small few-shot evaluation on 100 USMLE questions to track medical knowledge acquisition.
Every 10,000 steps: Full evaluation suite. USMLE (zero-shot and 5-shot), MedNLI, MedQA, PubMedQA, and a custom clinical note understanding benchmark. This takes 2-3 GPU-hours and provides a comprehensive view of model capability at that training stage.
Plot all evaluation metrics against training tokens consumed. The learning curve should show steady improvement with diminishing returns. If a metric plateaus early or regresses, investigate.
When to Stop Training
The plan says 200B tokens, but you might stop earlier or continue longer based on:
- Validation perplexity plateau: If perplexity hasn’t improved in 20B tokens, additional training is unlikely to help. You may be data-constrained (repeating data) or at the model’s capacity limit.
- Downstream task performance: If the 5 representative downstream tasks stop improving on intermediate checkpoints, the foundation model has converged for practical purposes.
- Budget exhaustion: If hardware failures have consumed the compute margin and you are at risk of exceeding budget, stop at the best checkpoint you have.
- Overfitting: If training loss continues to decrease but validation perplexity increases, the model is memorizing. Stop and use the checkpoint with the best validation perplexity.
In practice, for a well-curated 200B token dataset with a 7B model, training will not overfit. The model sees each token only once (single epoch). Overfitting is primarily a concern when you repeat data.
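The perplexity-plateau rule above can be made concrete; a sketch, where the 20B-token window matches the text and the history format is an assumption:

```python
def perplexity_plateaued(history, window_tokens=20e9):
    """Check the 'no improvement in 20B tokens' stopping rule.

    `history` is a list of (tokens_consumed, validation_perplexity) pairs
    in training order.
    """
    if not history:
        return False
    latest_tokens, _ = history[-1]
    cutoff = latest_tokens - window_tokens
    recent = [ppl for tokens, ppl in history if tokens >= cutoff]
    older = [ppl for tokens, ppl in history if tokens < cutoff]
    if not older:
        return False  # not enough history to judge a plateau
    # Plateau: nothing in the recent window beats the best pre-window value.
    return min(recent) >= min(older)
```

Evaluated at each checkpoint, this is one input to the stop decision alongside downstream-task trends and remaining budget, not an automatic kill switch.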
Model Release
After training completes, the model goes through a release process:
- Select the best checkpoint based on validation perplexity and downstream task performance. This is usually the final checkpoint, but occasionally an earlier checkpoint performs better on specific tasks.
- Convert to inference format: Consolidate FSDP shards into a single model file. Convert to the format expected by the inference framework (HuggingFace format, or a custom format for vLLM/TGI serving).
- Run the full evaluation suite: All medical benchmarks, general language benchmarks (to verify the model hasn’t lost basic language ability), safety evaluations (PHI leakage probing, toxicity, refusal of dangerous medical advice in zero-shot).
- Red-teaming: Internal clinical team attempts to elicit dangerous outputs (incorrect drug dosages, hallucinated contraindications, fabricated clinical guidelines). Document findings and assess whether post-training alignment (RLHF, DPO) is needed before deployment.
- Publish internal model card: Document the training data composition, known limitations, evaluation results, and recommended use cases. This is not a public release; it is an internal document for downstream teams who will fine-tune the model.
The foundation model is the starting point. Downstream teams fine-tune it for specific applications, add their own safety layers, and validate in their clinical context. The foundation model team’s job is to hand off the best possible starting point with clear documentation of what the model can and cannot do.
Going Deeper
Why not use a Mixture of Experts (MoE)? MoE architectures (like Mixtral) use conditional computation: only a subset of parameters are active for each token. This allows a model with 46B total parameters to run at the cost of a 12B dense model. For a domain model, MoE has a specific downside: the router (which selects which experts to activate) must learn which expert is relevant for which token. In a domain with very different sub-specialties (radiology vs. pharmacology vs. surgery), the router might learn to specialize experts, which is great for perplexity but can cause problems when fine-tuning for tasks that span specialties. We chose dense architecture for simplicity and uniform representation quality across medical sub-domains. MoE is a viable alternative if serving cost is the primary constraint.
Curriculum learning: Instead of shuffling all data uniformly, present data in a structured order. Start with well-formatted medical textbooks (clean, educational text), then introduce PubMed abstracts (more varied quality), then clinical notes (noisy, abbreviated, domain-specific formatting). The hypothesis is that learning structured medical language first provides a better foundation for learning from noisier sources. Empirical results are mixed. Some teams report 2-3% improvements in downstream performance. Others find that random shuffling with domain mixing works just as well. We shuffle uniformly for simplicity but allocate one of our hyperparameter tuning runs to test a curriculum approach.
Continual pre-training vs. training from scratch: An intermediate option between fine-tuning and training from scratch. Take an existing open model (LLaMA 7B), keep its architecture and most of its weights, but replace the tokenizer and embedding layer with a medical tokenizer. Initialize the new embeddings randomly, and continue pre-training on the medical corpus. The non-embedding layers retain their general language knowledge, and the new embeddings adapt to the medical tokenizer. This typically requires only 20-50B tokens of continued training (10-25% of training from scratch). The risk is that the old representations and the new tokenizer are misaligned early in training, causing a temporary performance drop that may not fully recover. We evaluated this approach in a pilot study and found it reached 90% of the from-scratch model’s quality at 25% of the cost. For teams with tighter budgets, this is the recommended approach.
Data mixing ratios: The 200B token budget is split across data sources. The mixing ratio affects what the model learns. If 90% is PubMed abstracts, the model will be excellent at scientific medical writing but poor at understanding clinical notes (which use abbreviations, shorthand, and a completely different writing style). Our ratio: 20% PubMed, 20% clinical notes, 10% textbooks, 7.5% pharmaceutical, 5% clinical trials, 2.5% coding documentation, 35% general English. The 35% general English allocation is intentional: it maintains the model’s ability to generate fluent, grammatically correct text. Without it, the model tends to produce text that reads like a clinical note even when a conversational tone is appropriate.
The relationship between pre-training and prefill/decode phases at inference: During pre-training, the model processes all tokens in parallel (teacher forcing). During inference, the model generates tokens one at a time (autoregressive decoding). This means the model is trained in a mode that resembles the prefill phase of inference (processing many tokens at once) but deployed in the decode phase (generating one token at a time). The decode phase’s sequential nature and KV cache management are inference-time concerns that don’t affect training, but understanding this distinction helps downstream teams make serving decisions.
Post-training alignment: The foundation model produces raw next-token predictions. It will complete any prompt, including harmful ones. Before deployment in any user-facing application, downstream teams must apply alignment training (supervised fine-tuning on instruction-following data, followed by RLHF or DPO). This is a separate project from foundation model training but is a critical safety requirement. The foundation model team provides the pre-trained weights; the application teams are responsible for alignment and safety layers appropriate to their specific use case.
References
[1] Hoffmann et al. — Training Compute-Optimal Large Language Models (Chinchilla)
[2] Touvron et al. — LLaMA: Open and Efficient Foundation Language Models
[3] Su et al. — RoFormer: Enhanced Transformer with Rotary Position Embedding
[4] Shazeer — GLU Variants Improve Transformer (SwiGLU)
[5] Rajbhandari et al. — ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
[6] Zhao et al. — PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel
[7] Sennrich et al. — Neural Machine Translation of Rare Words with Subword Units (BPE)
[8] Kudo and Richardson — SentencePiece: A Simple and Language Independent Subword Tokenizer
[9] Anil et al. — GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
[10] Zhang et al. — OPT: Open Pre-trained Transformer Language Models (Training Instability Analysis)
[11] Loshchilov and Hutter — Decoupled Weight Decay Regularization (AdamW)
[12] Micikevicius et al. — Mixed Precision Training
[13] Kaplan et al. — Scaling Laws for Neural Language Models
[14] Zhang and Sennrich — Revisiting Few-Sample BERT Fine-Tuning (on Tokenizer Impact)
[15] Singhal et al. — Large Language Models Encode Clinical Knowledge (Med-PaLM)
[16] Narayanan et al. — Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM
Note: This blog represents my technical views and production experience. I use AI-based tools to help with drafting and formatting to keep these posts coming daily.