Clean up generalizer + add comments

This commit is contained in:
Jared Ramirez 2025-11-05 09:51:56 -05:00
parent 9ef6e28718
commit 54fdf4d154
No known key found for this signature in database
GPG key ID: 41158983F521D68C

View file

@ -1,8 +1,41 @@
//! Type instantiation for Hindley-Milner type inference.
//! Type generalization for Hindley-Milner type inference.
//!
//! This module provides functionality to instantiate polymorphic types with fresh
//! type variables while preserving type aliases and structure. This is a critical
//! component for proper handling of annotated functions in the type system.
//! This module implements the generalization phase of Hindley-Milner type inference,
//! which determines which type variables can be made polymorphic (generalized).
//!
//! ## Generalization Overview
//!
//! In Hindley-Milner type systems, we use "ranks" to track the scope level where
//! type variables are introduced. When we finish inferring a let-binding, we attempt
//! to generalize its type - converting concrete type variables into polymorphic ones
//! that can be instantiated differently at each use site.
//!
//! **Key idea:** A variable can only be generalized if it doesn't "escape" its scope
//! by being referenced by variables from outer (lower-ranked) scopes.
//!
//! ## Ranks
//!
//! - **Rank 0 (generalized):** Polymorphic type variables (post-generalization)
//! - **Rank 1 (top_level):** Built-in types, constants
//! - **Rank 2:** Variables introduced at the outermost let-binding
//! - **Rank 3+:** Variables introduced in nested let-bindings
//!
//! ## Example
//!
//! ```roc
//! id = |x| x # Can generalize: forall a. a -> a
//!
//! apply = |f, x|
//! result = f(x) # Cannot fully generalize
//! result
//! ```
//!
//! When generalizing the inner `result`, we discover it references `f` and `x` from
//! the outer scope, so it "escapes" and cannot be generalized.
//!
//! ## Main entry point
//!
//! - `Generalizer.generalize()` - Generalize all variables at a given rank
const std = @import("std");
const base = @import("base");
@ -25,18 +58,45 @@ const Tuple = @import("types.zig").Tuple;
const Rank = @import("types.zig").Rank;
const Ident = base.Ident;
/// Type to manage instantiation.
/// Manages the generalization process for type variables.
///
/// Entry point is `instantiateVar`
/// The Generalizer is responsible for determining which type variables at a given
/// rank can be safely generalized (made polymorphic) and which have "escaped" their
/// scope by being referenced from outer scopes.
///
/// This type does not own any of it's fields it's a convenience wrapper to
/// making threading it's field through all the recursive functions easier
/// ## Algorithm Overview
///
/// 1. **Build temporary rank table:** Copy vars at the rank we're generalizing into
/// a temporary pool for processing
///
/// 2. **Adjust ranks:** Walk through all variables and adjust their ranks to maintain
/// the invariant that ranks never increase as you go deeper into types. This phase
/// detects which variables have escaped.
///
/// 3. **Categorize variables:** After rank adjustment:
/// - If var.rank < rank_to_generalize: Variable escaped (move to lower rank pool)
/// - If var.rank == rank_to_generalize: Can generalize (set rank to .generalized)
///
/// 4. **Update pools:** Move escaped variables to their correct rank pools and set
/// generalized variables to rank .generalized
///
/// ## Entry point
///
/// - `generalize()` - Main function that performs generalization for a given rank
///
/// ## Internal state
///
/// This type holds temporary state during generalization and should be reset between
/// uses. Fields are not owned - store is borrowed, and temporary structures are reused.
pub const Generalizer = struct {
// not owned
/// Borrowed reference to the type store
store: *TypesStore,
/// Tracks which variables we've already adjusted (for handling recursive types)
rank_adjusted_vars: std.AutoHashMap(Var, void),
/// Temporary pool for processing variables during rank adjustment
tmp_var_pool: VarPool,
escaped_vars: std.ArrayList(EscapedVar),
/// Map of which variables we are generalizing this pass
vars_to_generalized: std.AutoHashMap(Var, void),
const EscapedVar = struct { var_: Var, rank: Rank };
@ -47,49 +107,76 @@ pub const Generalizer = struct {
pub fn init(gpa: std.mem.Allocator, store: *TypesStore) std.mem.Allocator.Error!Self {
return .{
.store = store,
.rank_adjusted_vars = std.AutoHashMap(Var, void).init(gpa),
.tmp_var_pool = try VarPool.init(gpa),
.escaped_vars = try std.ArrayList(EscapedVar).initCapacity(gpa, 32),
.rank_adjusted_vars = std.AutoHashMap(Var, void).init(gpa),
.vars_to_generalized = std.AutoHashMap(Var, void).init(gpa),
};
}
/// Reset the state of the generalizer
pub fn reset(self: *Self) void {
self.rank_adjusted_vars.clearRetainingCapacity();
self.tmp_var_pool.clearRetainingCapacity();
self.escaped_vars.clearRetainingCapacity();
self.rank_adjusted_vars.clearRetainingCapacity();
self.vars_to_generalized.clearRetainingCapacity();
}
pub fn deinit(self: *Self, gpa: std.mem.Allocator) void {
self.rank_adjusted_vars.deinit();
pub fn deinit(self: *Self, _: std.mem.Allocator) void {
self.tmp_var_pool.deinit();
self.escaped_vars.deinit(gpa);
self.rank_adjusted_vars.deinit();
self.vars_to_generalized.deinit();
}
/// Performs generalization for all variables at the given rank.
///
/// This is the main entry point for the generalization algorithm. It processes all
/// type variables introduced at `rank_to_generalize` and determines which can be
/// generalized (made polymorphic) and which have escaped to outer scopes.
///
/// ## Algorithm steps:
///
/// 1. **Copy to temporary pool:** Move all vars at this rank into a temporary pool
/// for processing, preserving their current ranks
///
/// 2. **Adjust ranks:** Process vars from lowest to highest rank, adjusting each
/// var's rank based on the ranks of variables it references. This enforces the
/// invariant that ranks never increase as you traverse deeper into types.
///
/// 3. **Separate escaped from generalizable:** After rank adjustment:
/// - Vars with rank < rank_to_generalize have "escaped" (reference outer vars)
/// - Vars with rank == rank_to_generalize can be safely generalized
///
/// 4. **Update var pool:**
/// - Move escaped vars to their (now lower) rank pools
/// - Set generalizable vars to rank (Rank.generalized)
/// - Clear the original rank pool
///
/// ## Parameters
/// - `var_pool`: The main variable pool tracking all vars by rank
/// - `rank_to_generalize`: The rank level to generalize (must be var_pool.current_rank)
pub fn generalize(self: *Self, _: std.mem.Allocator, var_pool: *VarPool, rank_to_generalize: Rank) std.mem.Allocator.Error!void {
std.debug.assert(var_pool.current_rank == rank_to_generalize);
const rank_to_generalize_int = @intFromEnum(rank_to_generalize);
// Reset internal state
// Reset internal state from any previous generalization
self.reset();
// Ensure the tmp pool has enough ranks
// Prepare temporary pool to hold variables during processing
try self.tmp_var_pool.ensureRanksThrough(rank_to_generalize);
self.tmp_var_pool.current_rank = rank_to_generalize;
const vars_to_generalize = var_pool.getVarsForRank(rank_to_generalize);
try self.vars_to_generalized.ensureUnusedCapacity(@intCast(vars_to_generalize.len));
// Build tmp rank table based on the vars at this level to generalize
// Copy all variables at this rank into the temporary pool, resolving redirects
for (vars_to_generalize) |var_| {
// if (!self.store.isRedirect(var_)) {
const resolved = self.store.resolveVar(var_);
try self.tmp_var_pool.addVarToRank(resolved.var_, resolved.desc.rank);
try self.vars_to_generalized.put(resolved.var_, {});
}
// Adjust ranks such that the rank can never increase as you unwrap
// through the structurd
//
// Process from lowest to highest so lower ranks are finalized first
// Adjust ranks to maintain invariant: ranks never increase going deeper.
// Process from lowest to highest rank so that lower ranks are finalized first,
// ensuring we have accurate rank information when processing higher ranks.
for (self.tmp_var_pool.slice(), 0..) |vars_at_rank, group_rank_int| {
const group_rank: Rank = @enumFromInt(group_rank_int);
for (vars_at_rank.items) |var_| {
@ -97,77 +184,109 @@ pub const Generalizer = struct {
}
}
// For ranks 0 through (rank_to_generalize - 1), move to the correct pool
// Move variables from lower ranks (generalized through rank_to_generalize-1) back to main pool.
// These are vars that were initially at rank_to_generalize but had their ranks
// lowered during adjustment because they reference outer-scope variables.
for (self.tmp_var_pool.sliceExceptCurrentRank()) |vars_at_rank| {
// Skip redundant vars
for (vars_at_rank.items) |var_| {
if (!self.store.isRedirect(var_)) {
const resolved = self.store.resolveVar(var_);
// After adjustRank, the variable might have a different rank
try var_pool.addVarToRank(resolved.var_, resolved.desc.rank);
}
}
}
// Iterate over the rank to generalize
// Process variables still at rank_to_generalize after adjustment.
// These either escaped (rank lowered) or can be generalized (rank unchanged).
for (self.tmp_var_pool.ranks.items[rank_to_generalize_int].items) |rank_var| {
// Skip redundant vars
if (!self.store.isRedirect(rank_var)) {
const resolved = self.store.resolveVar(rank_var);
if (@intFromEnum(resolved.desc.rank) < rank_to_generalize_int) {
// Var escaped, move to the correct pool
// Rank was lowered during adjustment - variable escaped
try var_pool.addVarToRank(resolved.var_, resolved.desc.rank);
} else {
// Didn't escape, generalize it
// Rank unchanged - safe to generalize
self.store.setDescRank(resolved.desc_idx, Rank.generalized);
}
}
}
// Clear the rank we just processed
// Clear the rank we just processed from the main pool
var_pool.ranks.items[rank_to_generalize_int].clearRetainingCapacity();
}
// adjust rank //
/// Adjust the rank of a type such that the rank never increase as you move deeper
/// This way, the outermost rank is representative of of the entire structure
/// Adjusts type variable ranks to prepare for generalization.
///
/// This implements the rank adjustment phase of Hindley-Milner generalization.
/// The key insight is that a type can only be generalized if all the type variables
/// it references are also being generalized at the same time (are at the same rank).
///
/// **Core Invariant:** Ranks never increase as you traverse deeper into a type structure.
/// This means the outermost rank represents the maximum rank of the entire type,
/// making it easy to determine which variables can be generalized.
///
/// ## Two classes of variables:
///
/// 1. **Variables being generalized** (in `vars_to_generalize`):
/// - Start at `group_rank` (the rank we're trying to generalize)
/// - Ranks can be INCREASED to the max rank found in their contents
/// - Final rank = max(group_rank, ranks of all nested variables)
/// - If final rank > group_rank, the variable "escaped" and cannot be generalized
///
/// 2. **Other variables** (not in `vars_to_generalize`):
/// - Already introduced at some earlier (lower) rank
/// - Ranks can only be LOWERED to maintain the invariant
/// - Final rank = min(current_rank, group_rank)
/// - This ensures outer types don't incorrectly claim to be "more general" than their contents
///
/// ## Example:
/// ```
/// let outer = |x| # rank 1, introduces var 'x' at rank 1
/// let inner = |y| # rank 2, introduces var 'y' at rank 2
/// (x, y) # references 'x' from rank 1
/// inner
/// ```
/// When generalizing rank 2, we process `inner`'s type. We find it references `x`
/// which is at rank 1 (lower/outer scope). Since `x` is NOT in vars_to_generalize
/// for rank 2, `inner`'s effective rank becomes max(2, 1) = 2, but `x` stays at rank 1.
/// This creates an "escape" - we cannot generalize `inner` because it captures
/// a not-yet-generalized variable from an outer scope.
///
/// ## Recursion handling:
/// - `rank_adjusted_vars` tracks variables we've already processed to handle cycles
/// - For recursive types like `type List a = [Nil, Cons a (List a)]`, we mark the
/// variable as "seen" immediately before recursing, preventing infinite loops
fn adjustRank(self: *Self, var_: Var, group_rank: Rank, vars_to_generalize: []Var) std.mem.Allocator.Error!Rank {
// Resolve the var
const resolved = self.store.resolveVar(var_);
const is_var_to_adjust = blk: {
for (vars_to_generalize) |var_to_generalize| {
if (var_to_generalize == resolved.var_) {
break :blk true;
}
}
break :blk false;
// Check if this variable is one we're trying to generalize at this rank
const is_var_to_generalize = self.vars_to_generalized.contains(resolved.var_);
// Early return for already-processed vars to handle recursive types
if (is_var_to_generalize and self.rank_adjusted_vars.contains(resolved.var_)) {
return resolved.desc.rank;
}
// Calculate the new rank based on whether we're generalizing this var
const new_rank = if (is_var_to_generalize) blk: {
// Mark as seen before recursing to handle cycles
_ = try self.rank_adjusted_vars.put(resolved.var_, {});
// For vars being generalized: rank INCREASES to max of nested vars
// This allows us to detect when a variable "escapes" by referencing
// variables from outer scopes (lower ranks)
break :blk try self.adjustRankContent(resolved.desc.content, group_rank, vars_to_generalize);
} else blk: {
// For other vars: rank can only DECREASE (maintain invariant)
// This ensures that if an outer type references an inner variable,
// the outer type's rank is lowered to match
break :blk resolved.desc.rank.min(group_rank);
};
if (is_var_to_adjust) {
if (self.rank_adjusted_vars.contains(resolved.var_)) {
return resolved.desc.rank;
} else {
// Add the resolved var to the list of seen vars immediately, in case
// this is a recursive type
_ = try self.rank_adjusted_vars.put(resolved.var_, {});
// Get the max rank of this vars
const max_rank = try self.adjustRankContent(resolved.desc.content, group_rank, vars_to_generalize);
// Set the rank
self.store.setDescRank(resolved.desc_idx, max_rank);
return max_rank;
}
} else {
const next_rank = resolved.desc.rank.min(group_rank);
self.store.setDescRank(resolved.desc_idx, next_rank);
return next_rank;
}
self.store.setDescRank(resolved.desc_idx, new_rank);
return new_rank;
}
fn adjustRankContent(self: *Self, content: Content, group_rank: Rank, vars_to_generalize: []Var) std.mem.Allocator.Error!Rank {