Answer to: How can we train an LLM from scratch in R with the R package torch?
Score: 0
Here is an example of how an LLM can be trained from scratch in R with torch:
library(torch)
# =========================================================
# 0) General config
# =========================================================
# Seed both random number generators: set.seed() covers base R sampling
# (shuffling, batch selection), while torch keeps its own generator for
# weight initialisation and dropout and must be seeded separately.
set.seed(123)
torch_manual_seed(123)

# Hyper-parameters of the (deliberately tiny) demo model and training run.
cfg <- list(
  block_size = 12L,      # context length in tokens
  batch_size = 12L,      # sequences per optimisation step
  n_embed = 12L,         # embedding width (must be divisible by n_head)
  n_head = 6L,           # attention heads per block
  n_layer = 6L,          # number of transformer blocks
  dropout = 0.1,
  lr_max = 3e-4,         # peak learning rate reached after warm-up
  lr_min = 3e-5,         # floor of the cosine decay
  warmup_iters = 300L,
  max_iters = 3000L,
  eval_interval = 200L,  # evaluate every this many iterations
  eval_batches = 30L,    # batches averaged per evaluation
  grad_clip = 1.0,
  weight_decay = 0.01,
  train_frac = 0.9
)

# Larger alternative configuration, kept for reference:
# cfg <- list(
# block_size = 64L,
# batch_size = 32L,
# n_embed = 192L,
# n_head = 6L,
# n_layer = 6L,
# dropout = 0.1,
# lr_max = 3e-4,
# lr_min = 3e-5,
# warmup_iters = 300L,
# max_iters = 3000L,
# eval_interval = 200L,
# eval_batches = 30L,
# grad_clip = 1.0,
# weight_decay = 0.01,
# train_frac = 0.9
# )

# Each attention head receives n_embed / n_head channels, so the division
# must be exact.
stopifnot(cfg$n_embed %% cfg$n_head == 0)

# Use the GPU when one is available, otherwise fall back to the CPU.
device <- if (cuda_is_available()) torch_device("cuda") else torch_device("cpu")
cat("Device:", device$type, "\n")
# =========================================================
# 1) Text / vocab
# =========================================================
# Reserved tokens that build_vocab() places at the front of the vocabulary:
# padding, unknown word, beginning-of-sequence and end-of-sequence.
SPECIAL_TOKENS <- c("<pad>", "<unk>", "<bos>", "<eos>")
# Split one string into lower-case word and punctuation tokens.
#
# Every punctuation character becomes its own token; runs of whitespace
# are collapsed. Returns character(0) for blank input.
tokenize_text <- function(text) {
  lowered <- tolower(text)
  # Surround punctuation with spaces so it is split off from words.
  spaced <- gsub("([[:punct:]])", " \\1 ", lowered)
  squeezed <- gsub("[[:space:]]+", " ", spaced)
  cleaned <- trimws(squeezed)
  if (nchar(cleaned) == 0) {
    return(character(0))
  }
  pieces <- strsplit(cleaned, " ", fixed = TRUE)
  pieces[[1]]
}
# Build the vocabulary from a character vector of texts.
#
# Tokens occurring fewer than `min_freq` times are dropped. The special
# tokens come first, followed by the kept tokens in decreasing frequency.
# Returns a list with:
#   vocab - character vector of tokens
#   stoi  - named integer vector, token -> 1-based id
#   itos  - named character vector, id (as string) -> token
build_vocab <- function(texts, min_freq = 1L) {
  all_tokens <- unlist(lapply(texts, tokenize_text), use.names = FALSE)
  counts <- sort(table(all_tokens), decreasing = TRUE)
  kept <- names(counts)[counts >= min_freq]
  vocab <- unique(c(SPECIAL_TOKENS, kept))
  ids <- seq_along(vocab)
  list(
    vocab = vocab,
    stoi = setNames(ids, vocab),
    itos = setNames(vocab, as.character(ids))
  )
}
# Map tokens to integer ids through the named lookup vector `stoi`.
# Tokens that are absent from `stoi` are replaced by `unk_id`.
encode_tokens <- function(tokens, stoi, unk_id) {
  looked_up <- unname(stoi[tokens])
  is_unknown <- is.na(looked_up)
  looked_up[is_unknown] <- unk_id
  as.integer(looked_up)
}
# Tokenize one text and turn it into an id sequence framed by the
# beginning-of-sequence and end-of-sequence ids.
encode_text <- function(text, stoi, bos_id, eos_id, unk_id) {
  body_ids <- encode_tokens(tokenize_text(text), stoi, unk_id)
  c(bos_id, body_ids, eos_id)
}
# Convert a vector of ids back into a single string.
#
# Ids without an entry in `itos` and the <pad>/<bos>/<eos> markers are
# dropped; whitespace in front of punctuation is removed afterwards.
decode_ids <- function(ids, itos) {
  hidden <- c("<pad>", "<bos>", "<eos>")
  words <- unname(itos[as.character(ids)])
  keep <- !is.na(words) & !(words %in% hidden)
  sentence <- paste(words[keep], collapse = " ")
  gsub("\\s+([[:punct:]])", "\\1", sentence)
}
# Read every *.txt file directly inside `path`.
#
# Each file is returned as one string, its lines joined with single spaces.
# Stops with an error when the directory holds no .txt file.
read_corpus_dir <- function(path) {
  txt_files <- list.files(path, pattern = "\\.txt$", full.names = TRUE)
  if (length(txt_files) == 0) {
    stop("Aucun fichier .txt trouvé dans le dossier.")
  }
  read_as_one_string <- function(file) {
    file_lines <- readLines(file, warn = FALSE, encoding = "UTF-8")
    paste(file_lines, collapse = " ")
  }
  vapply(txt_files, read_as_one_string, character(1), USE.NAMES = FALSE)
}
# =========================================================
# 2) Data preparation
# =========================================================
# Build the vocabulary, encode every text and split the encoded sequences
# at random into a training and a validation set.
#
# Arguments:
#   raw_texts  - character vector, one document per element
#   train_frac - fraction (between 0 and 1) of documents used for training
#   min_freq   - minimum token frequency forwarded to build_vocab()
#
# Returns a list with the two sequence lists, the lookup tables
# (stoi / itos / vocab), the vocabulary size and the special-token ids.
prepare_dataset <- function(raw_texts, train_frac = 0.9, min_freq = 1L) {
  stopifnot(
    is.numeric(train_frac), length(train_frac) == 1L,
    train_frac >= 0, train_frac <= 1
  )

  vocab_obj <- build_vocab(raw_texts, min_freq = min_freq)
  stoi <- vocab_obj$stoi
  itos <- vocab_obj$itos
  pad_id <- unname(stoi["<pad>"])
  unk_id <- unname(stoi["<unk>"])
  bos_id <- unname(stoi["<bos>"])
  eos_id <- unname(stoi["<eos>"])

  encoded <- lapply(raw_texts, function(txt) {
    encode_text(txt, stoi, bos_id, eos_id, unk_id)
  })

  n_total <- length(encoded)
  idx <- sample.int(n_total)
  n_train <- floor(train_frac * n_total)

  # seq_len() instead of 1:n so that an empty split really is empty:
  # `1:0` is c(1, 0) and `(n + 1):n` is c(n + 1, n), which would leak a
  # training sequence into validation and introduce an NA index.
  train_idx <- idx[seq_len(n_train)]
  val_idx <- idx[n_train + seq_len(n_total - n_train)]
  train_sequences <- encoded[train_idx]
  val_sequences <- encoded[val_idx]

  if (length(train_sequences) == 0L || length(val_sequences) == 0L) {
    warning(
      "The train/validation split produced an empty set; ",
      "provide more texts or adjust `train_frac`.",
      call. = FALSE
    )
  }

  list(
    train_sequences = train_sequences,
    val_sequences = val_sequences,
    stoi = stoi,
    itos = itos,
    vocab = vocab_obj$vocab,
    vocab_size = length(vocab_obj$vocab),
    pad_id = as.integer(pad_id),
    unk_id = as.integer(unk_id),
    bos_id = as.integer(bos_id),
    eos_id = as.integer(eos_id)
  )
}
# Cut one training example out of an encoded sequence.
#
# A window of block_size + 1 ids is taken at a random offset, or the
# sequence is right-padded with `pad_id` when it is too short. Sequences
# with fewer than two ids are replaced by c(bos_id, eos_id).
# Returns list(x, y): the first block_size ids of the window and the same
# window shifted by one position (the next-token targets).
sample_subsequence <- function(seq_ids, block_size, pad_id, bos_id, eos_id) {
  if (length(seq_ids) < 2L) {
    seq_ids <- c(bos_id, eos_id)
  }
  window_len <- block_size + 1L
  n_ids <- length(seq_ids)

  if (n_ids < window_len) {
    window <- c(seq_ids, rep(pad_id, window_len - n_ids))
  } else {
    first <- sample.int(n_ids - window_len + 1L, 1)
    last <- first + window_len - 1L
    window <- seq_ids[first:last]
  }

  inputs <- window[1:block_size]
  targets <- window[2:(block_size + 1L)]
  list(x = inputs, y = targets)
}
# Draw a random mini-batch of (input, target) windows.
#
# `batch_size` sequences are picked with replacement and one window is
# sampled from each. Returns list(x, y) of long tensors on `device`, one
# row per sampled window.
get_batch <- function(sequences, batch_size, block_size, device,
                      pad_id, bos_id, eos_id) {
  picks <- sample(seq_along(sequences), batch_size, replace = TRUE)

  windows <- lapply(picks, function(p) {
    sample_subsequence(
      seq_ids = sequences[[p]],
      block_size = block_size,
      pad_id = pad_id,
      bos_id = bos_id,
      eos_id = eos_id
    )
  })

  # Stack one field of every window into a matrix and move it to `device`.
  stack_field <- function(field) {
    rows <- lapply(windows, function(w) w[[field]])
    torch_tensor(do.call(rbind, rows), dtype = torch_long(), device = device)
  }

  inputs <- stack_field("x")
  targets <- stack_field("y")
  list(x = inputs, y = targets)
}
# =========================================================
# 3) Model module
# =========================================================
# Multi-head self-attention with a causal (lower-triangular) mask.
#
# initialize(n_embed, n_head, block_size, dropout):
#   creates the q/k/v/output projections, two dropout layers and a
#   block_size x block_size lower-triangular mask stored as a buffer.
# forward(x):
#   x is indexed as (batch, time, channels); the result has the same shape.
causal_multihead_attention <- nn_module(
  "causal_multihead_attention",
  initialize = function(n_embed, n_head, block_size, dropout = 0.1) {
    self$n_embed <- n_embed
    self$n_head <- n_head
    self$head_dim <- as.integer(n_embed / n_head)

    self$q_proj <- nn_linear(n_embed, n_embed, bias = TRUE)
    self$k_proj <- nn_linear(n_embed, n_embed, bias = TRUE)
    self$v_proj <- nn_linear(n_embed, n_embed, bias = TRUE)
    self$out_proj <- nn_linear(n_embed, n_embed, bias = TRUE)

    self$attn_dropout <- nn_dropout(dropout)
    self$resid_dropout <- nn_dropout(dropout)

    # Ones on and below the diagonal: position i may look at positions <= i.
    lower_tri <- torch_tril(torch_ones(block_size, block_size))
    self$register_buffer("mask", lower_tri)
  },
  forward = function(x) {
    dims <- x$size()
    n_batch <- dims[1]
    n_steps <- dims[2]
    n_chan <- dims[3]

    # (batch, time, channels) -> (batch, head, time, head_dim)
    split_heads <- function(projected) {
      per_head <- projected$view(c(n_batch, n_steps, self$n_head, self$head_dim))
      per_head$permute(c(1, 3, 2, 4))
    }
    queries <- split_heads(self$q_proj(x))
    keys <- split_heads(self$k_proj(x))
    values <- split_heads(self$v_proj(x))

    # Scaled dot-product scores between every pair of positions.
    scores <- torch_matmul(queries, keys$transpose(-2, -1)) / sqrt(self$head_dim)

    # Hide future positions before the softmax.
    visible <- self$mask[1:n_steps, 1:n_steps]$unsqueeze(1)$unsqueeze(1)
    scores <- scores$masked_fill(visible == 0, -1e9)

    weights <- nnf_softmax(scores, dim = -1)
    weights <- self$attn_dropout(weights)

    # Weighted sum of values, then merge the heads back into the channels.
    context <- torch_matmul(weights, values)
    merged <- context$permute(c(1, 3, 2, 4))$contiguous()
    merged <- merged$view(c(n_batch, n_steps, n_chan))

    self$resid_dropout(self$out_proj(merged))
  }
)
# Position-wise feed-forward network of a transformer block:
# linear (n_embed -> 4 * n_embed), GELU, linear back to n_embed, dropout.
feed_forward <- nn_module(
  "feed_forward",
  initialize = function(n_embed, dropout = 0.1) {
    hidden <- 4L * n_embed
    expand <- nn_linear(n_embed, hidden)
    activation <- nn_gelu()
    project <- nn_linear(hidden, n_embed)
    regularise <- nn_dropout(dropout)
    self$net <- nn_sequential(expand, activation, project, regularise)
  },
  forward = function(x) {
    self$net(x)
  }
)
# One pre-norm transformer block: layer norm + causal attention and
# layer norm + feed-forward, each wrapped in a residual connection.
transformer_block <- nn_module(
  "transformer_block",
  initialize = function(n_embed, n_head, block_size, dropout = 0.1) {
    self$ln1 <- nn_layer_norm(n_embed)
    self$attn <- causal_multihead_attention(n_embed, n_head, block_size, dropout)
    self$ln2 <- nn_layer_norm(n_embed)
    self$ffn <- feed_forward(n_embed, dropout)
  },
  forward = function(x) {
    attended <- self$attn(self$ln1(x))
    after_attention <- x + attended

    transformed <- self$ffn(self$ln2(after_attention))
    after_attention + transformed
  }
)
# Decoder-only (GPT-style) language model.
#
# initialize(vocab_size, block_size, n_embed, n_head, n_layer, dropout):
#   token and position embeddings, `n_layer` transformer blocks, a final
#   layer norm and an output bias. The output projection reuses the token
#   embedding matrix (weight tying).
# forward(idx, targets = NULL, pad_id = NULL):
#   idx     - long tensor of token ids, indexed as (batch, time)
#   targets - optional long tensor shaped like idx; when given, the mean
#             cross-entropy loss is returned as well
#   pad_id  - target id excluded from the loss; when NULL, -100 (the
#             documented default of nnf_cross_entropy) is used
#   Returns list(logits, loss); loss is NULL without targets.
#   Stops when the time dimension exceeds block_size.
gpt_model <- nn_module(
  "gpt_model",
  initialize = function(vocab_size, block_size, n_embed, n_head, n_layer, dropout = 0.1) {
    self$vocab_size <- vocab_size
    self$block_size <- block_size
    self$n_embed <- n_embed
    self$tok_emb <- nn_embedding(vocab_size, n_embed)
    self$pos_emb <- nn_embedding(block_size, n_embed)
    self$drop <- nn_dropout(dropout)
    blocks <- vector("list", n_layer)
    for (i in seq_len(n_layer)) {
      blocks[[i]] <- transformer_block(n_embed, n_head, block_size, dropout)
    }
    self$blocks <- do.call(nn_sequential, blocks)
    self$ln_f <- nn_layer_norm(n_embed)
    # Weight tying: forward() projects onto the transposed token embedding
    # matrix, so the output head only needs its own bias.
    self$lm_bias <- nn_parameter(torch_zeros(vocab_size))
  },
  forward = function(idx, targets = NULL, pad_id = NULL) {
    n_batch <- idx$size(1)
    n_steps <- idx$size(2)
    if (n_steps > self$block_size) {
      stop("Sequence greater than block_size.")
    }

    tok <- self$tok_emb(idx)
    pos_idx <- torch_tensor(seq_len(n_steps), dtype = torch_long(), device = idx$device)
    # (time, channels) -> (1, time, channels) so that the same positional
    # vectors broadcast over the batch dimension. Shaping it as
    # (time, 1, channels) would instead broadcast over time, which only
    # runs when batch size equals sequence length and then adds the
    # embedding of "position b" to every token of batch row b.
    pos <- self$pos_emb(pos_idx)$unsqueeze(1)

    x <- tok + pos
    x <- self$drop(x)
    x <- self$blocks(x)
    x <- self$ln_f(x)

    # Tied output head: reuse the embedding matrix as the projection.
    logits <- torch_matmul(x, self$tok_emb$weight$t()) + self$lm_bias

    loss <- NULL
    if (!is.null(targets)) {
      logits_flat <- logits$view(c(n_batch * n_steps, self$vocab_size))
      targets_flat <- targets$view(c(n_batch * n_steps))
      ignored <- if (is.null(pad_id)) -100L else as.integer(pad_id)
      loss <- nnf_cross_entropy(
        logits_flat,
        targets_flat,
        ignore_index = ignored
      )
    }
    list(logits = logits, loss = loss)
  }
)
# =========================================================
# 4) Scheduler
# =========================================================
# Learning rate for iteration `iter`: linear warm-up followed by a cosine
# decay from `lr_max` to `lr_min`.
#
#   iter <  warmup_iters : lr_max * iter / warmup_iters
#   iter >= max_iters    : lr_min
#   otherwise            : cosine interpolation between lr_max and lr_min
get_lr <- function(iter, warmup_iters, max_iters, lr_max, lr_min) {
  if (iter < warmup_iters) {
    return(lr_max * iter / warmup_iters)
  }
  # `>=` rather than `>`: at iter == max_iters the cosine term is exactly
  # zero anyway, and returning early avoids a 0 / 0 (NaN learning rate)
  # when warmup_iters == max_iters.
  if (iter >= max_iters) {
    return(lr_min)
  }
  progress <- (iter - warmup_iters) / (max_iters - warmup_iters)
  coeff <- 0.5 * (1 + cos(pi * progress))
  lr_min + coeff * (lr_max - lr_min)
}
# =========================================================
# 5) Checkpoints
# =========================================================
# Write a training checkpoint to `path` with torch_save().
#
# The saved list holds the model and optimizer state dicts together with
# the iteration number, the best validation loss, the config and any
# caller-supplied metadata.
save_checkpoint <- function(path, model, optimizer, iter, best_val_loss, cfg, meta) {
  model_state <- model$state_dict()
  optimizer_state <- optimizer$state_dict()
  payload <- list(
    model_state = model_state,
    optimizer_state = optimizer_state,
    iter = iter,
    best_val_loss = best_val_loss,
    cfg = cfg,
    meta = meta
  )
  torch_save(payload, path)
}
# Restore a checkpoint written by save_checkpoint().
#
# Arguments:
#   path      - checkpoint file
#   model     - module whose architecture matches the saved state dict
#   optimizer - optional optimizer; its state is restored when both it and
#               a saved optimizer state are available
#   device    - device to move the model to; when NULL the model stays on
#               the device it is already on
#
# Returns a list with the model, the optimizer and the saved iter,
# best_val_loss, cfg and meta entries.
#
# The default used to be `device = device`, a self-referential promise
# that fails with "promise already under evaluation" whenever the argument
# is omitted; NULL is used as the sentinel instead.
load_checkpoint <- function(path, model, optimizer = NULL, device = NULL) {
  ckpt <- torch_load(path)
  model$load_state_dict(ckpt$model_state)
  if (!is.null(device)) {
    model <- model$to(device = device)
  }
  if (!is.null(optimizer) && !is.null(ckpt$optimizer_state)) {
    optimizer$load_state_dict(ckpt$optimizer_state)
  }
  list(
    model = model,
    optimizer = optimizer,
    iter = ckpt$iter,
    best_val_loss = ckpt$best_val_loss,
    cfg = ckpt$cfg,
    meta = ckpt$meta
  )
}
# =========================================================
# 6) Evaluation
# =========================================================
# Estimate the current training and validation loss.
#
# The model is switched to eval mode, `cfg$eval_batches` random batches are
# scored per split with gradient tracking disabled, and the model is put
# back into train mode. Returns the mean losses (`train`, `val`) and the
# matching perplexities (`train_ppl`, `val_ppl`).
estimate_loss <- function(model, train_sequences, val_sequences, cfg, device,
                          pad_id, bos_id, eos_id) {
  model$eval()

  # Average loss over cfg$eval_batches random batches drawn from `seqs`.
  mean_loss_on <- function(seqs) {
    batch_losses <- numeric(cfg$eval_batches)
    with_no_grad({
      for (b in seq_len(cfg$eval_batches)) {
        eval_batch <- get_batch(
          seqs, cfg$batch_size, cfg$block_size, device,
          pad_id, bos_id, eos_id
        )
        result <- model(eval_batch$x, eval_batch$y, pad_id = pad_id)
        batch_losses[b] <- as.numeric(result$loss$item())
      }
    })
    mean(batch_losses)
  }

  # Training split first, then validation.
  losses <- lapply(
    list(train = train_sequences, val = val_sequences),
    mean_loss_on
  )
  model$train()

  list(
    train = losses$train,
    val = losses$val,
    train_ppl = exp(losses$train),
    val_ppl = exp(losses$val)
  )
}
# =========================================================
# 7) Generation
# =========================================================
# ... (generation code not shown)
# =========================================================
# 8) Training
# =========================================================
# Run the training loop for cfg$max_iters iterations.
#
# Each iteration sets the scheduled learning rate on the first parameter
# group, draws a batch, back-propagates, clips the gradient norm and takes
# an optimizer step. Every cfg$eval_interval iterations the losses are
# estimated, logged and appended to the history; a checkpoint is written
# to `checkpoint_path` whenever the validation loss improves.
#
# Returns a list with the model, the optimizer, the best validation loss
# and the evaluation history.
train_gpt <- function(model,
                      optimizer,
                      train_sequences,
                      val_sequences,
                      cfg,
                      device,
                      pad_id,
                      bos_id,
                      eos_id,
                      checkpoint_path = "checkpoint_best.pt") {
  best_val_loss <- Inf
  history <- list()

  for (iter in seq_len(cfg$max_iters)) {
    # --- optimisation step ---
    lr_now <- get_lr(
      iter, cfg$warmup_iters, cfg$max_iters, cfg$lr_max, cfg$lr_min
    )
    optimizer$param_groups[[1]]$lr <- lr_now

    batch <- get_batch(
      train_sequences, cfg$batch_size, cfg$block_size, device,
      pad_id, bos_id, eos_id
    )

    optimizer$zero_grad()
    step_loss <- model(batch$x, batch$y, pad_id = pad_id)$loss
    step_loss$backward()
    nn_utils_clip_grad_norm_(model$parameters, max_norm = cfg$grad_clip)
    optimizer$step()

    # --- periodic evaluation ---
    if (iter %% cfg$eval_interval != 0) {
      next
    }

    stats <- estimate_loss(
      model, train_sequences, val_sequences, cfg, device,
      pad_id, bos_id, eos_id
    )

    record <- list(
      iter = iter,
      lr = lr_now,
      train_loss = stats$train,
      val_loss = stats$val,
      train_ppl = stats$train_ppl,
      val_ppl = stats$val_ppl
    )
    history <- c(history, list(record))

    cat(sprintf(
      "iter=%d | lr=%.6f | train_loss=%.4f | val_loss=%.4f | train_ppl=%.2f | val_ppl=%.2f\n",
      iter, lr_now, stats$train, stats$val, stats$train_ppl, stats$val_ppl
    ))

    # Keep only the checkpoint with the lowest validation loss so far.
    improved <- stats$val < best_val_loss
    if (improved) {
      best_val_loss <- stats$val
      save_checkpoint(
        path = checkpoint_path,
        model = model,
        optimizer = optimizer,
        iter = iter,
        best_val_loss = best_val_loss,
        cfg = cfg,
        meta = list()
      )
      cat(" -> best checkpoint saved\n")
    }
  }

  list(
    model = model,
    optimizer = optimizer,
    best_val_loss = best_val_loss,
    history = history
  )
}
# =========================================================
# 9) Corpus example
# =========================================================
# Tiny demo corpus: eight French sentences about language models.
raw_texts <- c(
"Les modèles de langage autoregressifs prédisent le prochain token.",
"Un transformer causal utilise un masque triangulaire pour empêcher l attention vers le futur.",
"Les embeddings transforment les tokens en vecteurs continus.",
"Chaque bloc transformer contient une attention multi-têtes et un réseau feed forward.",
"La normalisation de couche aide à stabiliser l apprentissage.",
"La génération de texte peut être contrôlée avec la température, le top k et le top p.",
"Le pré entraînement sur un corpus massif améliore fortement les performances des modèles.",
"Avec torch en R, il est possible de construire un petit GPT pédagogique."
)
# Repeat the eight sentences 500 times (4000 documents). Because the
# train/validation split is drawn over these repeated documents, both
# splits contain the same sentences, so the validation loss does not
# measure generalisation to unseen text.
raw_texts <- rep(raw_texts, 500)
# To train on your own .txt files instead, replace the corpus with:
# raw_texts <- read_corpus_dir("my_dir_text")
# =========================================================
# 10) Prepare data and model
# =========================================================
# Encode the corpus and split it into training and validation sequences.
data_obj <- prepare_dataset(raw_texts, train_frac = cfg$train_frac, min_freq = 1L)

# Instantiate the model with the configured sizes and move it to `device`.
model <- gpt_model(
  vocab_size = data_obj$vocab_size,
  block_size = cfg$block_size,
  n_embed = cfg$n_embed,
  n_head = cfg$n_head,
  n_layer = cfg$n_layer,
  dropout = cfg$dropout
)$to(device = device)

# AdamW over all model parameters; the learning rate given here is
# overwritten by the schedule at every training iteration.
optimizer <- optim_adamw(
  params = model$parameters,
  lr = cfg$lr_max,
  betas = c(0.9, 0.95),
  eps = 1e-8,
  weight_decay = cfg$weight_decay
)
# =========================================================
# 11) Training
# =========================================================
# Single source of truth for the checkpoint location, shared by training
# and reloading. A path relative to the working directory is used so the
# script does not depend on a Windows "D:" drive being present.
checkpoint_path <- "simple_gpt_best.pt"

fit <- train_gpt(
  model = model,
  optimizer = optimizer,
  train_sequences = data_obj$train_sequences,
  val_sequences = data_obj$val_sequences,
  cfg = cfg,
  device = device,
  pad_id = data_obj$pad_id,
  bos_id = data_obj$bos_id,
  eos_id = data_obj$eos_id,
  checkpoint_path = checkpoint_path
)
# =========================================================
# 12) Load the best model
# =========================================================
# A checkpoint is only written when an evaluation step improved the
# validation loss, so make sure one exists before trying to load it.
if (!file.exists(checkpoint_path)) {
  stop("No checkpoint was written to '", checkpoint_path, "'.", call. = FALSE)
}

# Fresh model with the same architecture to receive the saved weights.
best_model <- gpt_model(
  vocab_size = data_obj$vocab_size,
  block_size = cfg$block_size,
  n_embed = cfg$n_embed,
  n_head = cfg$n_head,
  n_layer = cfg$n_layer,
  dropout = cfg$dropout
)
best_model <- best_model$to(device = device)
reloaded <- load_checkpoint(
  path = checkpoint_path,
  model = best_model,
  optimizer = NULL,
  device = device
)
best_model <- reloaded$model
# Inference mode: disables dropout.
best_model$eval()
# =========================================================
# 13) Text generation
# =========================================================
# ... (generation code not shown)
View Question ↗
Question
Parent Entity
Score: 4 • Views: 64
Site: stackoverflow
Other Comments / Reviews
SaaS Metrics