ROIpad ← Back to Search
stackoverflow › answer

Answer to: How can we train an LLM from scratch in R with the R package torch?

Score: 0
Answered: Mar 16, 2026
User Rep: 2,323
# Here is an example of how an LLM can be trained from scratch in R with
# the torch package. The author of the original answer elided two parts
# (the generation helpers in section 7 and the generation demo in
# section 13); they are marked below as placeholders.

library(torch)

# =========================================================
# 0) General config
# =========================================================

# set.seed() only seeds R's RNG (used here for the train/val shuffle and
# batch sampling); torch has its own generator (weight init, dropout).
set.seed(123)
torch_manual_seed(123)

cfg <- list(
  block_size = 12L,
  batch_size = 12L,
  n_embed = 12L,
  n_head = 6L,
  n_layer = 6L,
  dropout = 0.1,
  lr_max = 3e-4,
  lr_min = 3e-5,
  warmup_iters = 300L,
  max_iters = 3000L,
  eval_interval = 200L,
  eval_batches = 30L,
  grad_clip = 1.0,
  weight_decay = 0.01,
  train_frac = 0.9
)

# Larger alternative configuration:
# cfg <- list(
#   block_size = 64L,
#   batch_size = 32L,
#   n_embed = 192L,
#   n_head = 6L,
#   n_layer = 6L,
#   dropout = 0.1,
#   lr_max = 3e-4,
#   lr_min = 3e-5,
#   warmup_iters = 300L,
#   max_iters = 3000L,
#   eval_interval = 200L,
#   eval_batches = 30L,
#   grad_clip = 1.0,
#   weight_decay = 0.01,
#   train_frac = 0.9
# )

# Each attention head gets n_embed / n_head channels.
stopifnot(cfg$n_embed %% cfg$n_head == 0)

device <- if (cuda_is_available()) torch_device("cuda") else torch_device("cpu")
cat("Device:", device$type, "\n")

# =========================================================
# 1) Text / vocab
# =========================================================

SPECIAL_TOKENS <- c("<pad>", "<unk>", "<bos>", "<eos>")

# Word-level tokenizer for a single string: lower-cases, isolates every
# punctuation character as its own token, then splits on whitespace.
# Returns character(0) for empty, blank or NA input.
tokenize_text <- function(text) {
  if (length(text) == 0L || is.na(text)) {
    return(character(0))
  }
  text <- tolower(text)
  text <- gsub("([[:punct:]])", " \\1 ", text)
  text <- gsub("[[:space:]]+", " ", text)
  text <- trimws(text)
  if (nchar(text) == 0) {
    return(character(0))
  }
  strsplit(text, " ", fixed = TRUE)[[1]]
}

# Build a vocabulary from a character vector of documents.
# The special tokens come first (ids 1-4), followed by every word seen at
# least `min_freq` times, most frequent first.
# Returns list(vocab, stoi = word -> id, itos = as.character(id) -> word).
build_vocab <- function(texts, min_freq = 1L) {
  toks <- unlist(lapply(texts, tokenize_text), use.names = FALSE)
  freq <- sort(table(toks), decreasing = TRUE)
  vocab_words <- names(freq)[freq >= min_freq]
  vocab <- unique(c(SPECIAL_TOKENS, vocab_words))

  stoi <- setNames(seq_along(vocab), vocab)
  itos <- setNames(vocab, as.character(seq_along(vocab)))

  list(vocab = vocab, stoi = stoi, itos = itos)
}

# Map tokens to integer ids; tokens missing from `stoi` become `unk_id`.
encode_tokens <- function(tokens, stoi, unk_id) {
  ids <- unname(stoi[tokens])
  ids[is.na(ids)] <- unk_id
  as.integer(ids)
}

# Tokenize one document and wrap its ids as <bos> ... <eos>.
encode_text <- function(text, stoi, bos_id, eos_id, unk_id) {
  toks <- tokenize_text(text)
  c(bos_id, encode_tokens(toks, stoi, unk_id), eos_id)
}

# Turn ids back into a string. Unknown ids and <pad>/<bos>/<eos> are
# dropped, and the space in front of punctuation is removed.
decode_ids <- function(ids, itos) {
  toks <- unname(itos[as.character(ids)])
  toks <- toks[!is.na(toks)]
  toks <- toks[!(toks %in% c("<pad>", "<bos>", "<eos>"))]
  txt <- paste(toks, collapse = " ")
  gsub("\\s+([[:punct:]])", "\\1", txt)
}

# Read every *.txt file in `path`, one document (lines joined by spaces)
# per file. The error message means "No .txt file found in the folder."
read_corpus_dir <- function(path) {
  files <- list.files(path, pattern = "\\.txt$", full.names = TRUE)
  if (length(files) == 0) stop("Aucun fichier .txt trouvé dans le dossier.")
  vapply(
    files,
    function(f) {
      paste(readLines(f, warn = FALSE, encoding = "UTF-8"), collapse = " ")
    },
    character(1),
    USE.NAMES = FALSE
  )
}

# =========================================================
# 2) Data preparation
# =========================================================

# Build the vocabulary, encode every document, and split the documents
# (not tokens) into train/validation sets after a random shuffle.
# Each side of the split always receives at least one document, so at
# least 2 texts are required.
prepare_dataset <- function(raw_texts, train_frac = 0.9, min_freq = 1L) {
  vocab_obj <- build_vocab(raw_texts, min_freq = min_freq)
  stoi <- vocab_obj$stoi
  itos <- vocab_obj$itos

  pad_id <- unname(stoi["<pad>"])
  unk_id <- unname(stoi["<unk>"])
  bos_id <- unname(stoi["<bos>"])
  eos_id <- unname(stoi["<eos>"])

  encoded <- lapply(raw_texts, function(txt) {
    encode_text(txt, stoi, bos_id, eos_id, unk_id)
  })

  n_seq <- length(encoded)
  if (n_seq < 2L) {
    stop("prepare_dataset() needs at least 2 texts for a train/validation split.",
         call. = FALSE)
  }
  idx <- sample.int(n_seq)
  # Clamp so neither split is empty (1:0 or (n + 1):n would silently
  # produce wrong/NA indices).
  n_train <- min(max(floor(train_frac * n_seq), 1L), n_seq - 1L)

  train_sequences <- encoded[idx[seq_len(n_train)]]
  val_sequences <- encoded[idx[seq.int(n_train + 1L, n_seq)]]

  list(
    train_sequences = train_sequences,
    val_sequences = val_sequences,
    stoi = stoi,
    itos = itos,
    vocab = vocab_obj$vocab,
    vocab_size = length(vocab_obj$vocab),
    pad_id = as.integer(pad_id),
    unk_id = as.integer(unk_id),
    bos_id = as.integer(bos_id),
    eos_id = as.integer(eos_id)
  )
}

# Draw one (input, target) training pair of length `block_size` from a
# sequence. Long sequences give a random window; short ones are
# right-padded with `pad_id`. The target is the input shifted by one.
sample_subsequence <- function(seq_ids, block_size, pad_id, bos_id, eos_id) {
  if (length(seq_ids) < 2L) {
    seq_ids <- c(bos_id, eos_id)
  }

  needed <- block_size + 1L

  if (length(seq_ids) >= needed) {
    start <- sample.int(length(seq_ids) - needed + 1L, 1)
    chunk <- seq_ids[start:(start + needed - 1L)]
  } else {
    chunk <- c(seq_ids, rep(pad_id, needed - length(seq_ids)))
  }

  list(x = chunk[1:block_size], y = chunk[2:(block_size + 1L)])
}

# Sample `batch_size` sequences (with replacement) and stack their
# sub-sequences into (batch_size, block_size) long tensors on `device`.
get_batch <- function(sequences, batch_size, block_size, device,
                      pad_id, bos_id, eos_id) {
  # sample.int() avoids sample()'s surprise when given a length-1 vector.
  chosen <- sample.int(length(sequences), batch_size, replace = TRUE)

  xs <- vector("list", batch_size)
  ys <- vector("list", batch_size)

  for (i in seq_len(batch_size)) {
    out <- sample_subsequence(
      seq_ids = sequences[[chosen[i]]],
      block_size = block_size,
      pad_id = pad_id,
      bos_id = bos_id,
      eos_id = eos_id
    )
    xs[[i]] <- out$x
    ys[[i]] <- out$y
  }

  x_mat <- do.call(rbind, xs)
  y_mat <- do.call(rbind, ys)

  list(
    x = torch_tensor(x_mat, dtype = torch_long(), device = device),
    y = torch_tensor(y_mat, dtype = torch_long(), device = device)
  )
}

# =========================================================
# 3) Model module
# =========================================================

# Multi-head self-attention with a causal mask.
# forward(x): x is (batch, time, n_embed); returns the same shape.
causal_multihead_attention <- nn_module(
  "causal_multihead_attention",
  initialize = function(n_embed, n_head, block_size, dropout = 0.1) {
    self$n_embed <- n_embed
    self$n_head <- n_head
    self$head_dim <- as.integer(n_embed / n_head)

    self$q_proj <- nn_linear(n_embed, n_embed, bias = TRUE)
    self$k_proj <- nn_linear(n_embed, n_embed, bias = TRUE)
    self$v_proj <- nn_linear(n_embed, n_embed, bias = TRUE)
    self$out_proj <- nn_linear(n_embed, n_embed, bias = TRUE)

    self$attn_dropout <- nn_dropout(dropout)
    self$resid_dropout <- nn_dropout(dropout)

    # Lower-triangular ones: position i may attend to positions 1..i.
    mask <- torch_tril(torch_ones(block_size, block_size))
    self$register_buffer("mask", mask)
  },

  forward = function(x) {
    n_b <- x$size(1)
    n_t <- x$size(2)
    n_c <- x$size(3)

    # (batch, time, n_embed) -> (batch, head, time, head_dim)
    split_heads <- function(h) {
      h$view(c(n_b, n_t, self$n_head, self$head_dim))$permute(c(1, 3, 2, 4))
    }
    q <- split_heads(self$q_proj(x))
    k <- split_heads(self$k_proj(x))
    v <- split_heads(self$v_proj(x))

    att <- torch_matmul(q, k$transpose(-2, -1)) / sqrt(self$head_dim)
    causal_mask <- self$mask[1:n_t, 1:n_t]$unsqueeze(1)$unsqueeze(1)
    att <- att$masked_fill(causal_mask == 0, -1e9)
    att <- nnf_softmax(att, dim = -1)
    att <- self$attn_dropout(att)

    y <- torch_matmul(att, v)
    # Merge the heads back: (batch, time, n_embed).
    y <- y$permute(c(1, 3, 2, 4))$contiguous()$view(c(n_b, n_t, n_c))
    y <- self$out_proj(y)
    self$resid_dropout(y)
  }
)

# Position-wise MLP: n_embed -> 4 * n_embed -> n_embed.
feed_forward <- nn_module(
  "feed_forward",
  initialize = function(n_embed, dropout = 0.1) {
    self$net <- nn_sequential(
      nn_linear(n_embed, 4L * n_embed),
      nn_gelu(),
      nn_linear(4L * n_embed, n_embed),
      nn_dropout(dropout)
    )
  },
  forward = function(x) {
    self$net(x)
  }
)

# Pre-norm transformer block: x + attn(ln(x)), then x + ffn(ln(x)).
transformer_block <- nn_module(
  "transformer_block",
  initialize = function(n_embed, n_head, block_size, dropout = 0.1) {
    self$ln1 <- nn_layer_norm(n_embed)
    self$attn <- causal_multihead_attention(n_embed, n_head, block_size, dropout)
    self$ln2 <- nn_layer_norm(n_embed)
    self$ffn <- feed_forward(n_embed, dropout)
  },
  forward = function(x) {
    x <- x + self$attn(self$ln1(x))
    x + self$ffn(self$ln2(x))
  }
)

# Decoder-only GPT.
# forward(idx, targets, pad_id):
#   - idx and targets are (batch, time) long tensors with time <= block_size.
#   - Returns list(logits = (batch, time, vocab_size), loss).
#   - loss is NULL when `targets` is NULL; otherwise it is the
#     cross-entropy, ignoring targets equal to `pad_id`.
gpt_model <- nn_module(
  "gpt_model",
  initialize = function(vocab_size, block_size, n_embed, n_head, n_layer,
                        dropout = 0.1) {
    self$vocab_size <- vocab_size
    self$block_size <- block_size
    self$n_embed <- n_embed

    self$tok_emb <- nn_embedding(vocab_size, n_embed)
    self$pos_emb <- nn_embedding(block_size, n_embed)
    self$drop <- nn_dropout(dropout)

    blocks <- lapply(seq_len(n_layer), function(i) {
      transformer_block(n_embed, n_head, block_size, dropout)
    })
    self$blocks <- do.call(nn_sequential, blocks)
    self$ln_f <- nn_layer_norm(n_embed)

    # Weight tying: the output projection reuses tok_emb$weight (see
    # forward), so only an output bias is added here.
    self$lm_bias <- nn_parameter(torch_zeros(vocab_size))
  },

  forward = function(idx, targets = NULL, pad_id = NULL) {
    n_b <- idx$size(1)
    n_t <- idx$size(2)

    if (n_t > self$block_size) {
      stop("Sequence greater than block_size.")
    }

    tok <- self$tok_emb(idx)
    pos_idx <- torch_tensor(seq_len(n_t), dtype = torch_long(), device = idx$device)
    # (1, time, n_embed) so it broadcasts over the batch axis. Do not
    # transpose: (time, 1, n_embed) would line up with the batch axis
    # and only "work" when batch_size == time.
    pos <- self$pos_emb(pos_idx)$unsqueeze(1)

    x <- self$drop(tok + pos)
    x <- self$blocks(x)
    x <- self$ln_f(x)

    logits <- torch_matmul(x, self$tok_emb$weight$t()) + self$lm_bias

    loss <- NULL
    if (!is.null(targets)) {
      # -100 is nnf_cross_entropy's own default "ignore nothing real".
      ignore <- if (is.null(pad_id)) -100L else as.integer(pad_id)
      loss <- nnf_cross_entropy(
        logits$view(c(n_b * n_t, self$vocab_size)),
        targets$view(c(n_b * n_t)),
        ignore_index = ignore
      )
    }

    list(logits = logits, loss = loss)
  }
)

# =========================================================
# 4) Scheduler
# =========================================================

# Learning rate for iteration `iter`:
#   - linear warm-up from 0 to lr_max over `warmup_iters`;
#   - then cosine decay from lr_max to lr_min, reached at `max_iters`;
#   - lr_min afterwards, or when there are no decay steps at all.
get_lr <- function(iter, warmup_iters, max_iters, lr_max, lr_min) {
  if (iter < warmup_iters) {
    return(lr_max * iter / warmup_iters)
  }
  decay_iters <- max_iters - warmup_iters
  # decay_iters <= 0 would otherwise give 0 / 0 = NaN below.
  if (iter > max_iters || decay_iters <= 0) {
    return(lr_min)
  }
  progress <- (iter - warmup_iters) / decay_iters
  coeff <- 0.5 * (1 + cos(pi * progress))
  lr_min + coeff * (lr_max - lr_min)
}

# =========================================================
# 5) Checkpoints
# =========================================================

# Save model + optimizer state and training metadata with torch_save().
save_checkpoint <- function(path, model, optimizer, iter, best_val_loss,
                            cfg, meta) {
  obj <- list(
    model_state = model$state_dict(),
    optimizer_state = optimizer$state_dict(),
    iter = iter,
    best_val_loss = best_val_loss,
    cfg = cfg,
    meta = meta
  )
  torch_save(obj, path)
}

# Restore a checkpoint written by save_checkpoint() into `model` (and
# `optimizer`, if given). The model is moved to `device` only when one is
# supplied. (The old default `device = device` referred to itself and
# errored whenever the argument was omitted.)
load_checkpoint <- function(path, model, optimizer = NULL, device = NULL) {
  ckpt <- torch_load(path)
  model$load_state_dict(ckpt$model_state)
  if (!is.null(device)) {
    model <- model$to(device = device)
  }

  if (!is.null(optimizer) && !is.null(ckpt$optimizer_state)) {
    optimizer$load_state_dict(ckpt$optimizer_state)
  }

  list(
    model = model,
    optimizer = optimizer,
    iter = ckpt$iter,
    best_val_loss = ckpt$best_val_loss,
    cfg = ckpt$cfg,
    meta = ckpt$meta
  )
}

# =========================================================
# 6) Evaluation
# =========================================================

# Mean loss over `cfg$eval_batches` random batches of each split, computed
# in eval mode without gradients. The model is put back in train mode on
# exit, even if an error occurs.
estimate_loss <- function(model, train_sequences, val_sequences, cfg, device,
                          pad_id, bos_id, eos_id) {
  model$eval()
  on.exit(model$train(), add = TRUE)

  eval_split <- function(seqs) {
    losses <- numeric(cfg$eval_batches)
    with_no_grad({
      for (k in seq_len(cfg$eval_batches)) {
        batch <- get_batch(
          sequences = seqs,
          batch_size = cfg$batch_size,
          block_size = cfg$block_size,
          device = device,
          pad_id = pad_id,
          bos_id = bos_id,
          eos_id = eos_id
        )
        out <- model(batch$x, batch$y, pad_id = pad_id)
        losses[k] <- as.numeric(out$loss$item())
      }
    })
    mean(losses)
  }

  train_loss <- eval_split(train_sequences)
  val_loss <- eval_split(val_sequences)

  list(
    train = train_loss,
    val = val_loss,
    train_ppl = exp(train_loss),
    val_ppl = exp(val_loss)
  )
}

# =========================================================
# 7) Generation
# =========================================================

# (Generation helpers were omitted in the original answer.)

# =========================================================
# 8) Training
# =========================================================

# Training loop:
#   - Sets the scheduled learning rate each step and clips gradients.
#   - Evaluates every `cfg$eval_interval` iterations and on the final
#     iteration.
#   - Writes a checkpoint to `checkpoint_path` whenever the validation
#     loss improves.
# Returns list(model, optimizer, best_val_loss, history).
train_gpt <- function(model, optimizer, train_sequences, val_sequences, cfg,
                      device, pad_id, bos_id, eos_id,
                      checkpoint_path = "checkpoint_best.pt") {
  best_val_loss <- Inf
  history <- list()

  for (iter in seq_len(cfg$max_iters)) {
    lr_now <- get_lr(
      iter = iter,
      warmup_iters = cfg$warmup_iters,
      max_iters = cfg$max_iters,
      lr_max = cfg$lr_max,
      lr_min = cfg$lr_min
    )
    for (g in seq_along(optimizer$param_groups)) {
      optimizer$param_groups[[g]]$lr <- lr_now
    }

    batch <- get_batch(
      sequences = train_sequences,
      batch_size = cfg$batch_size,
      block_size = cfg$block_size,
      device = device,
      pad_id = pad_id,
      bos_id = bos_id,
      eos_id = eos_id
    )

    optimizer$zero_grad()
    out <- model(batch$x, batch$y, pad_id = pad_id)
    out$loss$backward()
    nn_utils_clip_grad_norm_(model$parameters, max_norm = cfg$grad_clip)
    optimizer$step()

    # Also evaluate on the final iteration so a checkpoint always exists,
    # even when max_iters is not a multiple of eval_interval.
    is_eval_step <- iter %% cfg$eval_interval == 0 || iter == cfg$max_iters
    if (!is_eval_step) {
      next
    }

    stats <- estimate_loss(
      model = model,
      train_sequences = train_sequences,
      val_sequences = val_sequences,
      cfg = cfg,
      device = device,
      pad_id = pad_id,
      bos_id = bos_id,
      eos_id = eos_id
    )

    history[[length(history) + 1L]] <- list(
      iter = iter,
      lr = lr_now,
      train_loss = stats$train,
      val_loss = stats$val,
      train_ppl = stats$train_ppl,
      val_ppl = stats$val_ppl
    )

    cat(sprintf(
      "iter=%d | lr=%.6f | train_loss=%.4f | val_loss=%.4f | train_ppl=%.2f | val_ppl=%.2f\n",
      iter, lr_now, stats$train, stats$val, stats$train_ppl, stats$val_ppl
    ))

    if (stats$val < best_val_loss) {
      best_val_loss <- stats$val
      save_checkpoint(
        path = checkpoint_path,
        model = model,
        optimizer = optimizer,
        iter = iter,
        best_val_loss = best_val_loss,
        cfg = cfg,
        meta = list()
      )
      cat(" -> best checkpoint saved\n")
    }
  }

  list(
    model = model,
    optimizer = optimizer,
    best_val_loss = best_val_loss,
    history = history
  )
}

# =========================================================
# 9) Corpus example
# =========================================================

raw_texts <- c(
  "Les modèles de langage autoregressifs prédisent le prochain token.",
  "Un transformer causal utilise un masque triangulaire pour empêcher l attention vers le futur.",
  "Les embeddings transforment les tokens en vecteurs continus.",
  "Chaque bloc transformer contient une attention multi-têtes et un réseau feed forward.",
  "La normalisation de couche aide à stabiliser l apprentissage.",
  "La génération de texte peut être contrôlée avec la température, le top k et le top p.",
  "Le pré entraînement sur un corpus massif améliore fortement les performances des modèles.",
  "Avec torch en R, il est possible de construire un petit GPT pédagogique."
)

raw_texts <- rep(raw_texts, 500)

# raw_texts <- read_corpus_dir("my_dir_text")

# =========================================================
# 10) Prepare data and model
# =========================================================

data_obj <- prepare_dataset(
  raw_texts = raw_texts,
  train_frac = cfg$train_frac,
  min_freq = 1L
)

model <- gpt_model(
  vocab_size = data_obj$vocab_size,
  block_size = cfg$block_size,
  n_embed = cfg$n_embed,
  n_head = cfg$n_head,
  n_layer = cfg$n_layer,
  dropout = cfg$dropout
)
model <- model$to(device = device)

optimizer <- optim_adamw(
  params = model$parameters,
  lr = cfg$lr_max,
  betas = c(0.9, 0.95),
  eps = 1e-8,
  weight_decay = cfg$weight_decay
)

# One path for both saving and reloading. It is relative to the working
# directory rather than a hard-coded Windows drive, so it works on any OS.
ckpt_path <- "simple_gpt_best.pt"

# =========================================================
# 11) Training
# =========================================================

fit <- train_gpt(
  model = model,
  optimizer = optimizer,
  train_sequences = data_obj$train_sequences,
  val_sequences = data_obj$val_sequences,
  cfg = cfg,
  device = device,
  pad_id = data_obj$pad_id,
  bos_id = data_obj$bos_id,
  eos_id = data_obj$eos_id,
  checkpoint_path = ckpt_path
)

# =========================================================
# 12) Load the best model
# =========================================================

best_model <- gpt_model(
  vocab_size = data_obj$vocab_size,
  block_size = cfg$block_size,
  n_embed = cfg$n_embed,
  n_head = cfg$n_head,
  n_layer = cfg$n_layer,
  dropout = cfg$dropout
)
best_model <- best_model$to(device = device)

reloaded <- load_checkpoint(
  path = ckpt_path,
  model = best_model,
  optimizer = NULL,
  device = device
)

best_model <- reloaded$model
best_model$eval()

# =========================================================
# 13) Text generation
# =========================================================

# (Generation example was omitted in the original answer.)
r artificial-intelligence large-language-model
View Question ↗
Question
Parent Entity
Score: 4 • Views: 64
Site: stackoverflow