Skip to main content

r/interpreter/builtins/
random.rs

//! Random number generation builtins: set.seed, RNGkind, runif, rnorm, rbinom, sample, etc.
//! All draws use the per-interpreter RNG state, accessed through `BuiltinContext`.
3
4use derive_more::{Display, Error};
5use rand::RngExt;
6use rand::SeedableRng;
7use rand_distr::Distribution;
8
9use crate::interpreter::coerce::{f64_to_i64, f64_to_u64, i64_to_f64};
10use crate::interpreter::value::*;
11use crate::interpreter::BuiltinContext;
12use minir_macros::interpreter_builtin;
13
// region: RandomError
15
16#[derive(Debug, Display, Error)]
17pub enum RandomError {
18    #[display("invalid '{}' argument", param)]
19    InvalidParam { param: &'static str },
20
21    #[display("invalid argument: '{}' must be non-negative", param)]
22    NonNegative { param: &'static str },
23
24    #[display("invalid arguments: 'min' must not be greater than 'max'")]
25    MinGreaterThanMax,
26
27    #[display("invalid distribution parameters: {}", reason)]
28    InvalidDistribution { reason: String },
29
30    #[display("invalid first argument: must be a positive integer or a vector")]
31    InvalidSampleInput,
32
33    #[display(
34        "cannot take a sample larger than the population ({} > {}) when 'replace = FALSE'",
35        size,
36        pop_len
37    )]
38    SampleTooLarge { size: usize, pop_len: usize },
39
40    #[display("argument '{}' is missing, with no default", param)]
41    MissingParam { param: &'static str },
42
43    #[display("NA in probability vector")]
44    NaInProb,
45
46    #[display("negative probability")]
47    NegativeProb,
48
49    #[display(
50        "'prob' must have the same length as the population ({} != {})",
51        prob_len,
52        pop_len
53    )]
54    ProbLengthMismatch { prob_len: usize, pop_len: usize },
55
56    #[display("too few positive probabilities")]
57    TooFewPositiveProbs,
58}
59
60impl RandomError {
61    fn invalid_dist(e: impl std::fmt::Display) -> Self {
62        RandomError::InvalidDistribution {
63            reason: e.to_string(),
64        }
65    }
66}
67
68impl From<RandomError> for RError {
69    fn from(e: RandomError) -> Self {
70        RError::from_source(RErrorKind::Argument, e)
71    }
72}
73
// endregion
75
76/// Extract a positive integer `n` from `args[0]`.
77///
78/// Dispatch-level `reorder_builtin_args` ensures `args[0]` is always the first
79/// formal parameter regardless of how the user passed it (named or positional).
80fn extract_n(args: &[RValue]) -> Result<usize, RandomError> {
81    let n = args
82        .first()
83        .and_then(|v| v.as_vector())
84        .and_then(|v| v.as_integer_scalar())
85        .ok_or(RandomError::InvalidParam { param: "n" })?;
86    if n < 0 {
87        return Err(RandomError::NonNegative { param: "n" });
88    }
89    usize::try_from(n).map_err(|_| RandomError::InvalidParam { param: "n" })
90}
91
92/// Extract an optional f64 parameter from named args or positional index.
93///
94/// Checks named args first (handles gap cases where the arg was named but
95/// earlier formals were omitted), then falls back to positional index.
96fn extract_param(
97    args: &[RValue],
98    named: &[(String, RValue)],
99    name: &str,
100    positional_index: usize,
101    default: f64,
102) -> f64 {
103    for (k, v) in named {
104        if k == name {
105            if let Some(rv) = v.as_vector() {
106                if let Some(d) = rv.as_double_scalar() {
107                    return d;
108                }
109            }
110        }
111    }
112    args.get(positional_index)
113        .and_then(|v| v.as_vector())
114        .and_then(|v| v.as_double_scalar())
115        .unwrap_or(default)
116}
117
118/// Extract a required f64 parameter from named args or positional index.
119fn require_param(
120    args: &[RValue],
121    named: &[(String, RValue)],
122    name: &'static str,
123    positional_index: usize,
124) -> Result<f64, RandomError> {
125    for (k, v) in named {
126        if k == name {
127            if let Some(rv) = v.as_vector() {
128                if let Some(d) = rv.as_double_scalar() {
129                    return Ok(d);
130                }
131            }
132        }
133    }
134    args.get(positional_index)
135        .and_then(|v| v.as_vector())
136        .and_then(|v| v.as_double_scalar())
137        .ok_or(RandomError::MissingParam { param: name })
138}
139
// region: set.seed
141
142/// Set the random number generator seed for reproducibility.
143///
144/// Seeds the per-interpreter RNG deterministically so that subsequent random
145/// draws produce the same sequence. The RNG algorithm seeded depends on the
146/// current `RNGkind()` setting — either Xoshiro (default) or ChaCha20.
147///
148/// Also stores the seed value in `.Random.seed` in the global environment
149/// (as an integer vector whose first element is the seed), matching R's
150/// convention of exposing RNG state there.
151///
152/// @param seed integer seed value (or NULL to re-seed from system entropy)
153/// @return NULL, invisibly
154#[interpreter_builtin(name = "set.seed", min_args = 1)]
155fn interp_set_seed(
156    args: &[RValue],
157    _named: &[(String, RValue)],
158    context: &BuiltinContext,
159) -> Result<RValue, RError> {
160    use crate::interpreter::{InterpreterRng, RngKind};
161
162    // set.seed(NULL) re-seeds from system entropy (like a fresh interpreter)
163    if matches!(args[0], RValue::Null) {
164        context.with_interpreter(|interp| {
165            let mut thread_rng = rand::rng();
166            let new_rng = match interp.rng_kind.get() {
167                RngKind::Xoshiro => {
168                    InterpreterRng::Fast(rand::rngs::SmallRng::from_rng(&mut thread_rng))
169                }
170                RngKind::ChaCha20 => InterpreterRng::Deterministic(Box::new(
171                    rand_chacha::ChaCha20Rng::from_rng(&mut thread_rng),
172                )),
173            };
174            *interp.rng().borrow_mut() = new_rng;
175            // Remove .Random.seed when re-seeding from entropy
176            interp.global_env.remove(".Random.seed");
177        });
178        return Ok(RValue::Null);
179    }
180
181    let seed_f64 = args[0]
182        .as_vector()
183        .and_then(|v| v.as_double_scalar())
184        .ok_or(RandomError::InvalidParam { param: "seed" })?;
185    let seed = f64_to_u64(seed_f64)?;
186    context.with_interpreter(|interp| {
187        let kind = interp.rng_kind.get();
188        let new_rng = match kind {
189            RngKind::Xoshiro => InterpreterRng::Fast(rand::rngs::SmallRng::seed_from_u64(seed)),
190            RngKind::ChaCha20 => InterpreterRng::Deterministic(Box::new(
191                rand_chacha::ChaCha20Rng::seed_from_u64(seed),
192            )),
193        };
194        *interp.rng().borrow_mut() = new_rng;
195        // Store the seed in .Random.seed in the global env.
196        // R's .Random.seed is an integer vector; we store the u64 seed as two
197        // i64 values: a "kind" marker (0 = Xoshiro, 1 = ChaCha20) and the seed.
198        // This is a simplified version of R's full .Random.seed protocol.
199        let kind_code = match kind {
200            RngKind::Xoshiro => 0i64,
201            RngKind::ChaCha20 => 1i64,
202        };
203        let seed_i64 = i64::try_from(seed).unwrap_or(i64::MAX);
204        interp.global_env.set(
205            ".Random.seed".to_string(),
206            RValue::vec(Vector::Integer(
207                vec![Some(kind_code), Some(seed_i64)].into(),
208            )),
209        );
210    });
211    Ok(RValue::Null)
212}
213
// endregion
215
// region: RNGkind
217
218/// Query or set the RNG algorithm.
219///
220/// With no arguments, returns the name of the current RNG kind as a character
221/// vector. With a `kind` argument, switches to the specified algorithm.
222///
223/// Supported kinds:
224/// - `"Xoshiro"` (default) — fast, non-cryptographic (`SmallRng` / Xoshiro256++)
225/// - `"ChaCha20"` — deterministic across platforms and Rust versions
226///
227/// After switching the RNG kind, call `set.seed()` to seed the new algorithm.
228/// The switch itself does NOT re-seed — the new RNG starts from system entropy.
229///
230/// @param kind character string naming the RNG algorithm (optional)
231/// @return character vector with the previous RNG kind (invisibly when setting)
232#[interpreter_builtin(name = "RNGkind")]
233fn interp_rng_kind(
234    args: &[RValue],
235    named: &[(String, RValue)],
236    context: &BuiltinContext,
237) -> Result<RValue, RError> {
238    use crate::interpreter::{InterpreterRng, RngKind};
239
240    let old_kind = context.with_interpreter(|interp| interp.rng_kind.get());
241    let old_kind_str = old_kind.to_string();
242
243    // Extract kind argument (positional or named)
244    let kind_arg = named
245        .iter()
246        .find(|(k, _)| k == "kind")
247        .map(|(_, v)| v)
248        .or(args.first());
249
250    if let Some(kind_val) = kind_arg {
251        // NULL means query-only (same as no argument)
252        if matches!(kind_val, RValue::Null) {
253            return Ok(RValue::vec(Vector::Character(
254                vec![Some(old_kind_str)].into(),
255            )));
256        }
257
258        let kind_str = kind_val
259            .as_vector()
260            .and_then(|v| v.as_character_scalar())
261            .ok_or_else(|| {
262                RError::new(
263                    RErrorKind::Argument,
264                    "RNGkind() requires a character string argument".to_string(),
265                )
266            })?;
267
268        let new_kind = match kind_str.as_str() {
269            "Xoshiro" | "xoshiro" => RngKind::Xoshiro,
270            "ChaCha20" | "chacha20" | "ChaCha" | "chacha" => RngKind::ChaCha20,
271            other => {
272                return Err(RError::new(
273                    RErrorKind::Argument,
274                    format!(
275                        "RNGkind(\"{other}\") is not a recognized RNG kind.\n  \
276                         Valid choices: \"Xoshiro\" (default, fast) or \"ChaCha20\" (deterministic, cross-platform)."
277                    ),
278                ));
279            }
280        };
281
282        context.with_interpreter(|interp| {
283            interp.rng_kind.set(new_kind);
284            // Replace the RNG with a fresh instance of the new kind, seeded from entropy.
285            let mut thread_rng = rand::rng();
286            let new_rng = match new_kind {
287                RngKind::Xoshiro => {
288                    InterpreterRng::Fast(rand::rngs::SmallRng::from_rng(&mut thread_rng))
289                }
290                RngKind::ChaCha20 => InterpreterRng::Deterministic(Box::new(
291                    rand_chacha::ChaCha20Rng::from_rng(&mut thread_rng),
292                )),
293            };
294            *interp.rng().borrow_mut() = new_rng;
295        });
296    }
297
298    Ok(RValue::vec(Vector::Character(
299        vec![Some(old_kind_str)].into(),
300    )))
301}
302
// endregion
304
// region: Continuous distributions
306
307/// Random uniform deviates.
308///
309/// Generates n random values from a uniform distribution.
310///
311/// @param n number of observations
312/// @param min lower limit of the distribution (default 0)
313/// @param max upper limit of the distribution (default 1)
314/// @return numeric vector of length n
315#[interpreter_builtin(min_args = 1)]
316fn interp_runif(
317    args: &[RValue],
318    named: &[(String, RValue)],
319    context: &BuiltinContext,
320) -> Result<RValue, RError> {
321    let n = extract_n(args)?;
322    let min = extract_param(args, named, "min", 1, 0.0);
323    let max = extract_param(args, named, "max", 2, 1.0);
324    if min > max {
325        return Err(RandomError::MinGreaterThanMax.into());
326    }
327    let dist = rand_distr::Uniform::new(min, max).map_err(RandomError::invalid_dist)?;
328    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
329        let mut rng = interp.rng().borrow_mut();
330        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
331    });
332    Ok(RValue::vec(Vector::Double(values.into())))
333}
334
335/// Random normal deviates.
336///
337/// Generates n random values from a normal distribution.
338#[derive(minir_macros::FromArgs)]
339#[builtin(name = "rnorm")]
340struct RnormArgs {
341    /// number of observations
342    n: i64,
343    /// mean of the distribution
344    #[default(0.0)]
345    mean: f64,
346    /// standard deviation
347    #[default(1.0)]
348    sd: f64,
349}
350
351impl crate::interpreter::value::Builtin for RnormArgs {
352    fn call(self, ctx: &BuiltinContext) -> Result<RValue, RError> {
353        let n = usize::try_from(self.n).map_err(|_| RandomError::NonNegative { param: "n" })?;
354        let dist =
355            rand_distr::Normal::new(self.mean, self.sd).map_err(RandomError::invalid_dist)?;
356        let values: Vec<Option<f64>> = ctx.with_interpreter(|interp| {
357            let mut rng = interp.rng().borrow_mut();
358            (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
359        });
360        Ok(RValue::vec(Vector::Double(values.into())))
361    }
362}
363
364/// Random exponential deviates.
365///
366/// Generates n random values from an exponential distribution.
367///
368/// @param n number of observations
369/// @param rate rate parameter (default 1)
370/// @return numeric vector of length n
371#[interpreter_builtin(min_args = 1)]
372fn interp_rexp(
373    args: &[RValue],
374    named: &[(String, RValue)],
375    context: &BuiltinContext,
376) -> Result<RValue, RError> {
377    let n = extract_n(args)?;
378    let rate = extract_param(args, named, "rate", 1, 1.0);
379    let dist = rand_distr::Exp::new(rate).map_err(RandomError::invalid_dist)?;
380    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
381        let mut rng = interp.rng().borrow_mut();
382        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
383    });
384    Ok(RValue::vec(Vector::Double(values.into())))
385}
386
387/// Random gamma deviates.
388///
389/// Generates n random values from a gamma distribution.
390///
391/// @param n number of observations
392/// @param shape shape parameter (default 1)
393/// @param rate rate parameter (default 1); scale = 1/rate
394/// @return numeric vector of length n
395#[interpreter_builtin(min_args = 1)]
396fn interp_rgamma(
397    args: &[RValue],
398    named: &[(String, RValue)],
399    context: &BuiltinContext,
400) -> Result<RValue, RError> {
401    let n = extract_n(args)?;
402    let shape = extract_param(args, named, "shape", 1, 1.0);
403    let rate = extract_param(args, named, "rate", 2, 1.0);
404    // R uses rate, rand_distr::Gamma uses scale = 1/rate
405    let dist = rand_distr::Gamma::new(shape, 1.0 / rate).map_err(RandomError::invalid_dist)?;
406    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
407        let mut rng = interp.rng().borrow_mut();
408        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
409    });
410    Ok(RValue::vec(Vector::Double(values.into())))
411}
412
413/// Random beta deviates.
414///
415/// Generates n random values from a beta distribution.
416///
417/// @param n number of observations
418/// @param shape1 first shape parameter (default 1)
419/// @param shape2 second shape parameter (default 1)
420/// @return numeric vector of length n
421#[interpreter_builtin(min_args = 1)]
422fn interp_rbeta(
423    args: &[RValue],
424    named: &[(String, RValue)],
425    context: &BuiltinContext,
426) -> Result<RValue, RError> {
427    let n = extract_n(args)?;
428    let shape1 = extract_param(args, named, "shape1", 1, 1.0);
429    let shape2 = extract_param(args, named, "shape2", 2, 1.0);
430    let dist = rand_distr::Beta::new(shape1, shape2).map_err(RandomError::invalid_dist)?;
431    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
432        let mut rng = interp.rng().borrow_mut();
433        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
434    });
435    Ok(RValue::vec(Vector::Double(values.into())))
436}
437
438/// Random Cauchy deviates.
439///
440/// Generates n random values from a Cauchy distribution.
441///
442/// @param n number of observations
443/// @param location location parameter (default 0)
444/// @param scale scale parameter (default 1)
445/// @return numeric vector of length n
446#[interpreter_builtin(min_args = 1)]
447fn interp_rcauchy(
448    args: &[RValue],
449    named: &[(String, RValue)],
450    context: &BuiltinContext,
451) -> Result<RValue, RError> {
452    let n = extract_n(args)?;
453    let location = extract_param(args, named, "location", 1, 0.0);
454    let scale = extract_param(args, named, "scale", 2, 1.0);
455    let dist = rand_distr::Cauchy::new(location, scale).map_err(RandomError::invalid_dist)?;
456    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
457        let mut rng = interp.rng().borrow_mut();
458        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
459    });
460    Ok(RValue::vec(Vector::Double(values.into())))
461}
462
463/// Random Weibull deviates.
464///
465/// Generates n random values from a Weibull distribution.
466///
467/// @param n number of observations
468/// @param shape shape parameter (default 1)
469/// @param scale scale parameter (default 1)
470/// @return numeric vector of length n
471#[interpreter_builtin(min_args = 1)]
472fn interp_rweibull(
473    args: &[RValue],
474    named: &[(String, RValue)],
475    context: &BuiltinContext,
476) -> Result<RValue, RError> {
477    let n = extract_n(args)?;
478    let shape = extract_param(args, named, "shape", 1, 1.0);
479    let scale = extract_param(args, named, "scale", 2, 1.0);
480    let dist = rand_distr::Weibull::new(scale, shape).map_err(RandomError::invalid_dist)?;
481    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
482        let mut rng = interp.rng().borrow_mut();
483        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
484    });
485    Ok(RValue::vec(Vector::Double(values.into())))
486}
487
488/// Random log-normal deviates.
489///
490/// Generates n random values from a log-normal distribution.
491///
492/// @param n number of observations
493/// @param meanlog mean of the distribution on the log scale (default 0)
494/// @param sdlog standard deviation on the log scale (default 1)
495/// @return numeric vector of length n
496#[interpreter_builtin(min_args = 1)]
497fn interp_rlnorm(
498    args: &[RValue],
499    named: &[(String, RValue)],
500    context: &BuiltinContext,
501) -> Result<RValue, RError> {
502    let n = extract_n(args)?;
503    let meanlog = extract_param(args, named, "meanlog", 1, 0.0);
504    let sdlog = extract_param(args, named, "sdlog", 2, 1.0);
505    let dist = rand_distr::LogNormal::new(meanlog, sdlog).map_err(RandomError::invalid_dist)?;
506    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
507        let mut rng = interp.rng().borrow_mut();
508        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
509    });
510    Ok(RValue::vec(Vector::Double(values.into())))
511}
512
// endregion
514
// region: Discrete distributions (rchisq, rt and rf below are continuous but also live here)
516
517/// Random binomial deviates.
518///
519/// Generates n random values from a binomial distribution.
520///
521/// @param n number of observations
522/// @param size number of trials (default 1)
523/// @param prob probability of success on each trial (default 0.5)
524/// @return integer vector of length n
525#[interpreter_builtin(min_args = 2)]
526fn interp_rbinom(
527    args: &[RValue],
528    named: &[(String, RValue)],
529    context: &BuiltinContext,
530) -> Result<RValue, RError> {
531    let n = extract_n(args)?;
532    let size = f64_to_u64(extract_param(args, named, "size", 1, 1.0))?;
533    let prob = extract_param(args, named, "prob", 2, 0.5);
534    let dist = rand_distr::Binomial::new(size, prob).map_err(RandomError::invalid_dist)?;
535    let values: Vec<Option<i64>> = context.with_interpreter(|interp| {
536        let mut rng = interp.rng().borrow_mut();
537        (0..n)
538            .map(|_| i64::try_from(dist.sample(&mut *rng)).map(Some))
539            .collect::<Result<_, _>>()
540    })?;
541    Ok(RValue::vec(Vector::Integer(values.into())))
542}
543
544/// Random Poisson deviates.
545///
546/// Generates n random values from a Poisson distribution.
547///
548/// @param n number of observations
549/// @param lambda mean rate parameter (default 1)
550/// @return integer vector of length n
551#[interpreter_builtin(min_args = 1)]
552fn interp_rpois(
553    args: &[RValue],
554    named: &[(String, RValue)],
555    context: &BuiltinContext,
556) -> Result<RValue, RError> {
557    let n = extract_n(args)?;
558    let lambda = extract_param(args, named, "lambda", 1, 1.0);
559    let dist = rand_distr::Poisson::new(lambda).map_err(RandomError::invalid_dist)?;
560    let values: Vec<Option<i64>> = context.with_interpreter(|interp| {
561        let mut rng = interp.rng().borrow_mut();
562        (0..n)
563            .map(|_| f64_to_i64(dist.sample(&mut *rng)).map(Some))
564            .collect::<Result<_, _>>()
565    })?;
566    Ok(RValue::vec(Vector::Integer(values.into())))
567}
568
569/// Random geometric deviates.
570///
571/// Generates n random values from a geometric distribution.
572///
573/// @param n number of observations
574/// @param prob probability of success (default 0.5)
575/// @return integer vector of length n
576#[interpreter_builtin(min_args = 1)]
577fn interp_rgeom(
578    args: &[RValue],
579    named: &[(String, RValue)],
580    context: &BuiltinContext,
581) -> Result<RValue, RError> {
582    let n = extract_n(args)?;
583    let prob = extract_param(args, named, "prob", 1, 0.5);
584    let dist = rand_distr::Geometric::new(prob).map_err(RandomError::invalid_dist)?;
585    let values: Vec<Option<i64>> = context.with_interpreter(|interp| {
586        let mut rng = interp.rng().borrow_mut();
587        (0..n)
588            .map(|_| i64::try_from(dist.sample(&mut *rng)).map(Some))
589            .collect::<Result<_, _>>()
590    })?;
591    Ok(RValue::vec(Vector::Integer(values.into())))
592}
593
594/// Random chi-squared deviates.
595///
596/// Generates n random values from a chi-squared distribution.
597///
598/// @param n number of observations
599/// @param df degrees of freedom (default 1)
600/// @return numeric vector of length n
601#[interpreter_builtin(min_args = 2)]
602fn interp_rchisq(
603    args: &[RValue],
604    named: &[(String, RValue)],
605    context: &BuiltinContext,
606) -> Result<RValue, RError> {
607    let n = extract_n(args)?;
608    let df = extract_param(args, named, "df", 1, 1.0);
609    let dist = rand_distr::ChiSquared::new(df).map_err(RandomError::invalid_dist)?;
610    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
611        let mut rng = interp.rng().borrow_mut();
612        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
613    });
614    Ok(RValue::vec(Vector::Double(values.into())))
615}
616
617/// Random Student's t deviates.
618///
619/// Generates n random values from a Student's t distribution.
620///
621/// @param n number of observations
622/// @param df degrees of freedom (default 1)
623/// @return numeric vector of length n
624#[interpreter_builtin(min_args = 2)]
625fn interp_rt(
626    args: &[RValue],
627    named: &[(String, RValue)],
628    context: &BuiltinContext,
629) -> Result<RValue, RError> {
630    let n = extract_n(args)?;
631    let df = extract_param(args, named, "df", 1, 1.0);
632    let dist = rand_distr::StudentT::new(df).map_err(RandomError::invalid_dist)?;
633    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
634        let mut rng = interp.rng().borrow_mut();
635        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
636    });
637    Ok(RValue::vec(Vector::Double(values.into())))
638}
639
640/// Random F deviates.
641///
642/// Generates n random values from an F distribution.
643///
644/// @param n number of observations
645/// @param df1 numerator degrees of freedom (default 1)
646/// @param df2 denominator degrees of freedom (default 1)
647/// @return numeric vector of length n
648#[interpreter_builtin(min_args = 2)]
649fn interp_rf(
650    args: &[RValue],
651    named: &[(String, RValue)],
652    context: &BuiltinContext,
653) -> Result<RValue, RError> {
654    let n = extract_n(args)?;
655    let df1 = extract_param(args, named, "df1", 1, 1.0);
656    let df2 = extract_param(args, named, "df2", 2, 1.0);
657    let dist = rand_distr::FisherF::new(df1, df2).map_err(RandomError::invalid_dist)?;
658    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
659        let mut rng = interp.rng().borrow_mut();
660        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
661    });
662    Ok(RValue::vec(Vector::Double(values.into())))
663}
664
665/// Random hypergeometric deviates.
666///
667/// Generates nn random values from a hypergeometric distribution.
668///
669/// @param nn number of observations
670/// @param m number of white balls in the urn
671/// @param n number of black balls in the urn
672/// @param k number of balls drawn from the urn
673/// @return integer vector of length nn
674#[interpreter_builtin(min_args = 4)]
675fn interp_rhyper(
676    args: &[RValue],
677    named: &[(String, RValue)],
678    context: &BuiltinContext,
679) -> Result<RValue, RError> {
680    let nn = extract_n(args)?;
681    let m = f64_to_u64(extract_param(args, named, "m", 1, 1.0))?; // white balls
682    let n = f64_to_u64(extract_param(args, named, "n", 2, 1.0))?; // black balls
683    let k = f64_to_u64(extract_param(args, named, "k", 3, 1.0))?; // draws
684    let dist = rand_distr::Hypergeometric::new(m + n, m, k).map_err(RandomError::invalid_dist)?;
685    let values: Vec<Option<i64>> = context.with_interpreter(|interp| {
686        let mut rng = interp.rng().borrow_mut();
687        (0..nn)
688            .map(|_| i64::try_from(dist.sample(&mut *rng)).map(Some))
689            .collect::<Result<_, _>>()
690    })?;
691    Ok(RValue::vec(Vector::Integer(values.into())))
692}
693
// endregion
695
// region: sample
697
698/// Random sampling with or without replacement.
699///
700/// @param x vector to sample from, or a positive integer n (sample from 1:n)
701/// @param size number of items to draw (default: length of x)
702/// @param replace if TRUE, sample with replacement (default FALSE)
703/// @param prob optional vector of probability weights
704/// @return vector of sampled elements
705#[interpreter_builtin(min_args = 1)]
706fn interp_sample(
707    args: &[RValue],
708    named: &[(String, RValue)],
709    context: &BuiltinContext,
710) -> Result<RValue, RError> {
711    // sample(x, size, replace = FALSE, prob = NULL)
712    // If x is a single positive integer n, sample from 1:n
713    let x_vec = match &args[0] {
714        RValue::Vector(rv) => rv.clone(),
715        _ => return Err(RandomError::InvalidSampleInput.into()),
716    };
717
718    // Check if x is a single integer n — if so, sample from 1:n
719    let population: Vec<i64> = if x_vec.len() == 1 {
720        if let Some(n) = x_vec.inner.as_integer_scalar() {
721            if n >= 1 {
722                (1..=n).collect()
723            } else {
724                return Err(RandomError::InvalidSampleInput.into());
725            }
726        } else if let Some(d) = x_vec.inner.as_double_scalar() {
727            let n = f64_to_i64(d)?;
728            if n >= 1 && (d - i64_to_f64(n)).abs() < 1e-10 {
729                (1..=n).collect()
730            } else {
731                return Err(RandomError::InvalidSampleInput.into());
732            }
733        } else {
734            // Single-element character/logical vector — sample from the element itself
735            vec![1]
736        }
737    } else {
738        // x is a vector of length > 1 — return indices
739        (1..=i64::try_from(x_vec.len())?).collect()
740    };
741
742    let pop_len = population.len();
743
744    // size defaults to length of population
745    let size = usize::try_from(
746        named
747            .iter()
748            .find(|(k, _)| k == "size")
749            .map(|(_, v)| v)
750            .or(args.get(1))
751            .and_then(|v| v.as_vector())
752            .and_then(|v| v.as_integer_scalar())
753            .map_or_else(|| i64::try_from(pop_len), Ok)?,
754    )?;
755
756    // replace defaults to FALSE
757    let replace = named
758        .iter()
759        .find(|(k, _)| k == "replace")
760        .map(|(_, v)| v)
761        .or(args.get(2))
762        .and_then(|v| v.as_vector())
763        .and_then(|v| v.as_logical_scalar())
764        .unwrap_or(false);
765
766    // Extract prob weights (positional 3 or named "prob")
767    let prob_arg = named
768        .iter()
769        .find(|(k, _)| k == "prob")
770        .map(|(_, v)| v)
771        .or(args.get(3));
772
773    let prob_weights = match prob_arg {
774        Some(RValue::Null) | None => None,
775        Some(v) => {
776            let rv = v
777                .as_vector()
778                .ok_or(RandomError::InvalidParam { param: "prob" })?;
779            let doubles = rv.to_doubles();
780            if doubles.len() != pop_len {
781                return Err(RandomError::ProbLengthMismatch {
782                    prob_len: doubles.len(),
783                    pop_len,
784                }
785                .into());
786            }
787            // Validate: no NA, no negative
788            let mut weights = Vec::with_capacity(pop_len);
789            for w in &doubles {
790                match w {
791                    None => return Err(RandomError::NaInProb.into()),
792                    Some(p) if *p < 0.0 => return Err(RandomError::NegativeProb.into()),
793                    Some(p) => weights.push(*p),
794                }
795            }
796            Some(weights)
797        }
798    };
799
800    if !replace && size > pop_len {
801        return Err(RandomError::SampleTooLarge { size, pop_len }.into());
802    }
803
804    // For weighted without replacement, additionally check that enough items have nonzero weight
805    if !replace {
806        if let Some(ref weights) = prob_weights {
807            let nonzero_count = weights.iter().filter(|&&w| w > 0.0).count();
808            if nonzero_count < size {
809                return Err(RandomError::TooFewPositiveProbs.into());
810            }
811        }
812    }
813
814    let result: Vec<Option<i64>> = context.with_interpreter(|interp| {
815        let mut rng = interp.rng().borrow_mut();
816        match prob_weights {
817            None => {
818                // Unweighted sampling
819                if replace {
820                    (0..size)
821                        .map(|_| Some(population[rng.random_range(0..pop_len)]))
822                        .collect()
823                } else {
824                    // Fisher-Yates partial shuffle
825                    let mut pool = population;
826                    for i in 0..size {
827                        let j = rng.random_range(i..pool.len());
828                        pool.swap(i, j);
829                    }
830                    pool.into_iter().take(size).map(Some).collect()
831                }
832            }
833            Some(weights) => {
834                if replace {
835                    // Weighted sampling with replacement using cumulative probabilities
836                    weighted_sample_with_replacement(&population, &weights, size, &mut *rng)
837                } else {
838                    // Weighted sampling without replacement: sequential draws
839                    weighted_sample_without_replacement(&population, &weights, size, &mut *rng)
840                }
841            }
842        }
843    });
844
845    // If the input was a vector with >1 elements, index into it (1-based)
846    if x_vec.len() > 1 {
847        context
848            .with_interpreter(|interp| interp.index_by_integer(&x_vec.inner, &result))
849            .map_err(RError::from)
850    } else {
851        Ok(RValue::vec(Vector::Integer(result.into())))
852    }
853}
854
855/// Weighted sampling with replacement using cumulative probability + binary search.
856fn weighted_sample_with_replacement(
857    population: &[i64],
858    weights: &[f64],
859    size: usize,
860    rng: &mut impl rand::Rng,
861) -> Vec<Option<i64>> {
862    // Normalize weights to cumulative probabilities
863    let total: f64 = weights.iter().sum();
864    if total <= 0.0 {
865        // All weights are zero — return empty or repeat first non-zero? R errors here.
866        return vec![None; size];
867    }
868
869    let mut cumulative = Vec::with_capacity(weights.len());
870    let mut acc = 0.0;
871    for &w in weights {
872        acc += w / total;
873        cumulative.push(acc);
874    }
875    // Fix rounding: ensure last entry is exactly 1.0
876    if let Some(last) = cumulative.last_mut() {
877        *last = 1.0;
878    }
879
880    let dist = rand_distr::Uniform::new(0.0, 1.0).expect("valid uniform range");
881    (0..size)
882        .map(|_| {
883            let u: f64 = dist.sample(rng);
884            let idx = cumulative.partition_point(|&c| c < u);
885            let idx = idx.min(population.len() - 1);
886            Some(population[idx])
887        })
888        .collect()
889}
890
891/// Weighted sampling without replacement: sequential weighted draws, removing selected items.
892fn weighted_sample_without_replacement(
893    population: &[i64],
894    weights: &[f64],
895    size: usize,
896    rng: &mut impl rand::Rng,
897) -> Vec<Option<i64>> {
898    let mut remaining: Vec<(i64, f64)> = population
899        .iter()
900        .copied()
901        .zip(weights.iter().copied())
902        .collect();
903    let mut result = Vec::with_capacity(size);
904    let dist = rand_distr::Uniform::new(0.0, 1.0).expect("valid uniform range");
905
906    for _ in 0..size {
907        // Compute total weight of remaining items
908        let total: f64 = remaining.iter().map(|(_, w)| w).sum();
909        if total <= 0.0 {
910            break;
911        }
912
913        // Pick a random point in [0, total)
914        let u: f64 = dist.sample(rng) * total;
915        let mut acc = 0.0;
916        let mut chosen_idx = remaining.len() - 1;
917        for (i, (_, w)) in remaining.iter().enumerate() {
918            acc += w;
919            if acc > u {
920                chosen_idx = i;
921                break;
922            }
923        }
924
925        let (val, _) = remaining.remove(chosen_idx);
926        result.push(Some(val));
927    }
928
929    result
930}
931
932// endregion
933
934// region: miniR extension distributions
935//
936// Distributions available via rand_distr that are NOT part of standard R.
937// These are miniR extensions, registered in the "collections" namespace.
938
939/// Random Frechet (Type II extreme value) deviates.
940///
941/// **miniR extension** -- not available in base R.
942///
943/// The Frechet distribution models the maximum of many random variables.
944/// It is parameterised by shape `alpha`, scale `s`, and location `m`.
945///
946/// @param n number of observations
947/// @param alpha shape parameter (positive)
948/// @param s scale parameter (positive, default 1)
949/// @param m location parameter (default 0)
950/// @return numeric vector of length n
951#[interpreter_builtin(min_args = 2, namespace = "collections")]
952fn interp_rfrechet(
953    args: &[RValue],
954    named: &[(String, RValue)],
955    context: &BuiltinContext,
956) -> Result<RValue, RError> {
957    let n = extract_n(args)?;
958    let alpha = require_param(args, named, "alpha", 1)?;
959    let s = extract_param(args, named, "s", 2, 1.0);
960    let m = extract_param(args, named, "m", 3, 0.0);
961    // Frechet::new(location, scale, shape)
962    let dist = rand_distr::Frechet::new(m, s, alpha).map_err(RandomError::invalid_dist)?;
963    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
964        let mut rng = interp.rng().borrow_mut();
965        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
966    });
967    Ok(RValue::vec(Vector::Double(values.into())))
968}
969
970/// Random Gumbel (Type I extreme value) deviates.
971///
972/// **miniR extension** -- not available in base R.
973///
974/// The Gumbel distribution models the maximum (or minimum) of many samples.
975/// It is parameterised by location `mu` and scale `beta`.
976///
977/// @param n number of observations
978/// @param mu location parameter (default 0)
979/// @param beta scale parameter (positive, default 1)
980/// @return numeric vector of length n
981#[interpreter_builtin(min_args = 1, namespace = "collections")]
982fn interp_rgumbel(
983    args: &[RValue],
984    named: &[(String, RValue)],
985    context: &BuiltinContext,
986) -> Result<RValue, RError> {
987    let n = extract_n(args)?;
988    let mu = extract_param(args, named, "mu", 1, 0.0);
989    let beta = extract_param(args, named, "beta", 2, 1.0);
990    let dist = rand_distr::Gumbel::new(mu, beta).map_err(RandomError::invalid_dist)?;
991    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
992        let mut rng = interp.rng().borrow_mut();
993        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
994    });
995    Ok(RValue::vec(Vector::Double(values.into())))
996}
997
998/// Random inverse Gaussian (Wald) deviates.
999///
1000/// **miniR extension** -- not available in base R.
1001///
1002/// The inverse Gaussian distribution is a continuous distribution defined for
1003/// x > 0, parameterised by mean `mu` and shape `lambda`.
1004///
1005/// @param n number of observations
1006/// @param mu mean parameter (positive)
1007/// @param lambda shape parameter (positive)
1008/// @return numeric vector of length n
1009#[interpreter_builtin(min_args = 3, namespace = "collections")]
1010fn interp_rinvgauss(
1011    args: &[RValue],
1012    named: &[(String, RValue)],
1013    context: &BuiltinContext,
1014) -> Result<RValue, RError> {
1015    let n = extract_n(args)?;
1016    let mu = require_param(args, named, "mu", 1)?;
1017    let lambda = require_param(args, named, "lambda", 2)?;
1018    let dist = rand_distr::InverseGaussian::new(mu, lambda).map_err(RandomError::invalid_dist)?;
1019    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
1020        let mut rng = interp.rng().borrow_mut();
1021        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
1022    });
1023    Ok(RValue::vec(Vector::Double(values.into())))
1024}
1025
1026/// Random Pareto deviates.
1027///
1028/// **miniR extension** -- not available in base R.
1029///
1030/// The Pareto distribution is a power-law distribution parameterised by
1031/// scale (minimum value) and shape (tail index).
1032///
1033/// @param n number of observations
1034/// @param scale scale parameter (positive, minimum value of the distribution)
1035/// @param shape shape parameter (positive, controls tail heaviness)
1036/// @return numeric vector of length n
1037#[interpreter_builtin(min_args = 3, namespace = "collections")]
1038fn interp_rpareto(
1039    args: &[RValue],
1040    named: &[(String, RValue)],
1041    context: &BuiltinContext,
1042) -> Result<RValue, RError> {
1043    let n = extract_n(args)?;
1044    let scale = require_param(args, named, "scale", 1)?;
1045    let shape = require_param(args, named, "shape", 2)?;
1046    let dist = rand_distr::Pareto::new(scale, shape).map_err(RandomError::invalid_dist)?;
1047    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
1048        let mut rng = interp.rng().borrow_mut();
1049        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
1050    });
1051    Ok(RValue::vec(Vector::Double(values.into())))
1052}
1053
1054/// Random PERT deviates.
1055///
1056/// **miniR extension** -- not available in base R.
1057///
1058/// The PERT distribution is similar to the triangular distribution but with
1059/// a smooth (beta-shaped) PDF. It is parameterised by min, max, and mode.
1060///
1061/// @param n number of observations
1062/// @param min minimum value
1063/// @param max maximum value
1064/// @param mode most likely value (must be in [min, max])
1065/// @return numeric vector of length n
1066#[interpreter_builtin(min_args = 4, namespace = "collections")]
1067fn interp_rpert(
1068    args: &[RValue],
1069    named: &[(String, RValue)],
1070    context: &BuiltinContext,
1071) -> Result<RValue, RError> {
1072    let n = extract_n(args)?;
1073    let min = require_param(args, named, "min", 1)?;
1074    let max = require_param(args, named, "max", 2)?;
1075    let mode = require_param(args, named, "mode", 3)?;
1076    let dist = rand_distr::Pert::new(min, max)
1077        .with_mode(mode)
1078        .map_err(RandomError::invalid_dist)?;
1079    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
1080        let mut rng = interp.rng().borrow_mut();
1081        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
1082    });
1083    Ok(RValue::vec(Vector::Double(values.into())))
1084}
1085
1086/// Random skew-normal deviates.
1087///
1088/// **miniR extension** -- not available in base R.
1089///
1090/// The skew-normal distribution generalises the normal distribution to allow
1091/// non-zero skewness. When shape = 0 it reduces to the normal distribution.
1092///
1093/// @param n number of observations
1094/// @param location location parameter (default 0)
1095/// @param scale scale parameter (positive, default 1)
1096/// @param shape skewness parameter (default 0; 0 = normal)
1097/// @return numeric vector of length n
1098#[interpreter_builtin(min_args = 1, namespace = "collections")]
1099fn interp_rskewnorm(
1100    args: &[RValue],
1101    named: &[(String, RValue)],
1102    context: &BuiltinContext,
1103) -> Result<RValue, RError> {
1104    let n = extract_n(args)?;
1105    let location = extract_param(args, named, "location", 1, 0.0);
1106    let scale = extract_param(args, named, "scale", 2, 1.0);
1107    let shape = extract_param(args, named, "shape", 3, 0.0);
1108    let dist =
1109        rand_distr::SkewNormal::new(location, scale, shape).map_err(RandomError::invalid_dist)?;
1110    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
1111        let mut rng = interp.rng().borrow_mut();
1112        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
1113    });
1114    Ok(RValue::vec(Vector::Double(values.into())))
1115}
1116
1117/// Random triangular deviates.
1118///
1119/// **miniR extension** -- not available in base R.
1120///
1121/// The triangular distribution has a piecewise linear PDF defined by min, max,
1122/// and mode. For a smooth alternative, see `rpert()`.
1123///
1124/// @param n number of observations
1125/// @param min minimum value
1126/// @param max maximum value
1127/// @param mode most likely value (must be in [min, max])
1128/// @return numeric vector of length n
1129#[interpreter_builtin(min_args = 4, namespace = "collections")]
1130fn interp_rtriangular(
1131    args: &[RValue],
1132    named: &[(String, RValue)],
1133    context: &BuiltinContext,
1134) -> Result<RValue, RError> {
1135    let n = extract_n(args)?;
1136    let min = require_param(args, named, "min", 1)?;
1137    let max = require_param(args, named, "max", 2)?;
1138    let mode = require_param(args, named, "mode", 3)?;
1139    let dist = rand_distr::Triangular::new(min, max, mode).map_err(RandomError::invalid_dist)?;
1140    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
1141        let mut rng = interp.rng().borrow_mut();
1142        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
1143    });
1144    Ok(RValue::vec(Vector::Double(values.into())))
1145}
1146
1147/// Random Zeta deviates.
1148///
1149/// **miniR extension** -- not available in base R.
1150///
1151/// The Zeta distribution is a discrete power-law distribution on positive
1152/// integers. It is the limit of the Zipf distribution as n -> infinity.
1153/// The parameter `s` must be strictly greater than 1.
1154///
1155/// @param n number of observations
1156/// @param s shape parameter (must be > 1)
1157/// @return numeric vector of length n (values are positive integers stored as doubles)
1158#[interpreter_builtin(min_args = 2, namespace = "collections")]
1159fn interp_rzeta(
1160    args: &[RValue],
1161    named: &[(String, RValue)],
1162    context: &BuiltinContext,
1163) -> Result<RValue, RError> {
1164    let n = extract_n(args)?;
1165    let s = require_param(args, named, "s", 1)?;
1166    let dist = rand_distr::Zeta::new(s).map_err(RandomError::invalid_dist)?;
1167    let values: Vec<Option<f64>> = context.with_interpreter(|interp| {
1168        let mut rng = interp.rng().borrow_mut();
1169        (0..n).map(|_| Some(dist.sample(&mut *rng))).collect()
1170    });
1171    Ok(RValue::vec(Vector::Double(values.into())))
1172}
1173
1174// endregion