diff --git a/build.zig b/build.zig index 1d2afc7199..5c503abc29 100644 --- a/build.zig +++ b/build.zig @@ -101,7 +101,7 @@ pub fn build(b: *std.Build) void { }); playground_exe.entry = .disabled; playground_exe.rdynamic = true; - roc_modules.addAll(playground_exe); + roc_modules.addAllExceptBundle(playground_exe); add_tracy(b, roc_modules.build_options, playground_exe, b.resolveTargetQuery(.{ .cpu_arch = .wasm32, @@ -126,7 +126,7 @@ pub fn build(b: *std.Build) void { }); playground_integration_test_exe.root_module.addImport("bytebox", bytebox.module("bytebox")); playground_integration_test_exe.root_module.addImport("build_options", build_options.createModule()); - roc_modules.addAll(playground_integration_test_exe); + roc_modules.addAllExceptBundle(playground_integration_test_exe); const install = b.addInstallArtifact(playground_integration_test_exe, .{}); // Ensure playground WASM is built before running the integration test diff --git a/src/build/modules.zig b/src/build/modules.zig index e56ba6d77f..a87a51a4c9 100644 --- a/src/build/modules.zig +++ b/src/build/modules.zig @@ -32,6 +32,7 @@ pub const ModuleType = enum { repl, fmt, bundle, + unbundle, /// Returns the dependencies for this module type pub fn getDependencies(self: ModuleType) []const ModuleType { @@ -54,6 +55,7 @@ pub const ModuleType = enum { .repl => &.{ .base, .compile, .parse, .types, .can, .check, .builtins, .layout, .eval }, .fmt => &.{ .base, .parse, .collections, .can, .fs, .tracy }, .bundle => &.{ .base, .collections }, + .unbundle => &.{ .base, .collections }, }; } }; @@ -78,6 +80,7 @@ pub const RocModules = struct { repl: *Module, fmt: *Module, bundle: *Module, + unbundle: *Module, pub fn create(b: *Build, build_options_step: *Step.Options, zstd: ?*Dependency) RocModules { const self = RocModules{ @@ -105,12 +108,14 @@ pub const RocModules = struct { .repl = b.addModule("repl", .{ .root_source_file = b.path("src/repl/mod.zig") }), .fmt = b.addModule("fmt", .{ .root_source_file = b.path("src/fmt/mod.zig") }), .bundle = b.addModule("bundle", .{ .root_source_file = b.path("src/bundle/mod.zig") }), + .unbundle = b.addModule("unbundle", .{ .root_source_file = b.path("src/unbundle/mod.zig") }), }; - // Link zstd to bundle module if available + // Link zstd to bundle module if available (for compression) if (zstd) |z| { self.bundle.linkLibrary(z.artifact("zstd")); } + // Note: unbundle module uses Zig's std zstandard, so doesn't need C library // Setup module dependencies using our generic helper self.setupModuleDependencies(); @@ -138,6 +143,7 @@ pub const RocModules = struct { .repl, .fmt, .bundle, + .unbundle, }; // Setup dependencies for each module @@ -171,12 +177,36 @@ pub const RocModules = struct { step.root_module.addImport("repl", self.repl); step.root_module.addImport("fmt", self.fmt); step.root_module.addImport("bundle", self.bundle); + step.root_module.addImport("unbundle", self.unbundle); } pub fn addAllToTest(self: RocModules, step: *Step.Compile) void { self.addAll(step); } + /// Add all modules except bundle (useful for wasm32 targets where zstd isn't available) + pub fn addAllExceptBundle(self: RocModules, step: *Step.Compile) void { + step.root_module.addImport("base", self.base); + step.root_module.addImport("collections", self.collections); + step.root_module.addImport("types", self.types); + step.root_module.addImport("compile", self.compile); + step.root_module.addImport("reporting", self.reporting); + step.root_module.addImport("parse", self.parse); + step.root_module.addImport("can", self.can); + 
step.root_module.addImport("check", self.check); + step.root_module.addImport("tracy", self.tracy); + step.root_module.addImport("builtins", self.builtins); + step.root_module.addImport("fs", self.fs); + step.root_module.addImport("build_options", self.build_options); + step.root_module.addImport("layout", self.layout); + step.root_module.addImport("eval", self.eval); + step.root_module.addImport("ipc", self.ipc); + step.root_module.addImport("repl", self.repl); + step.root_module.addImport("fmt", self.fmt); + step.root_module.addImport("unbundle", self.unbundle); + // Intentionally omitting bundle module (requires C zstd library) + } + /// Get a module by its type pub fn getModule(self: RocModules, module_type: ModuleType) *Module { return switch (module_type) { @@ -198,6 +228,7 @@ pub const RocModules = struct { .repl => self.repl, .fmt => self.fmt, .bundle => self.bundle, + .unbundle => self.unbundle, }; } @@ -210,7 +241,7 @@ pub const RocModules = struct { } } - pub fn createModuleTests(self: RocModules, b: *Build, target: ResolvedTarget, optimize: OptimizeMode, zstd: ?*Dependency) [16]ModuleTest { + pub fn createModuleTests(self: RocModules, b: *Build, target: ResolvedTarget, optimize: OptimizeMode, zstd: ?*Dependency) [17]ModuleTest { const test_configs = [_]ModuleType{ .collections, .base, @@ -228,6 +259,7 @@ pub const RocModules = struct { .repl, .fmt, .bundle, + .unbundle, }; var tests: [test_configs.len]ModuleTest = undefined; @@ -241,6 +273,7 @@ pub const RocModules = struct { .optimize = optimize, // IPC module needs libc for mmap, munmap, close on POSIX systems // Bundle module needs libc for zstd + // Unbundle module doesn't need libc (uses Zig's std zstandard) .link_libc = (module_type == .ipc or module_type == .bundle), }); diff --git a/src/bundle/mod.zig b/src/bundle/mod.zig index d64d337017..423e33f343 100644 --- a/src/bundle/mod.zig +++ b/src/bundle/mod.zig @@ -1,36 +1,28 @@ -//! Bundle and unbundle functionality for Roc packages +//! Bundle functionality for Roc packages //! //! This module provides functionality to: //! - Bundle Roc packages and their dependencies into compressed tar archives -//! - Unbundle these archives to restore the original files -//! - Download and extract bundled archives from HTTPS URLs //! - Validate paths for security and cross-platform compatibility +//! +//! Note: This module requires the C zstd library for compression. +//! For unbundling functionality that works on all platforms (including WebAssembly), +//! see the separate `unbundle` module. 
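+//!
+//! Illustrative sketch of how the split is wired up in build.zig (the
+//! `playground_exe` call is from this change; `cli_exe` is a hypothetical
+//! native artifact):
+//!
+//! ```zig
+//! roc_modules.addAll(cli_exe); // native target: gets both bundle and unbundle
+//! roc_modules.addAllExceptBundle(playground_exe); // wasm32: no C zstd available
+//! ```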
pub const bundle = @import("bundle.zig"); -pub const download = @import("download.zig"); -pub const streaming_reader = @import("streaming_reader.zig"); pub const streaming_writer = @import("streaming_writer.zig"); pub const base58 = @import("base58.zig"); // Re-export commonly used functions and types pub const bundleFiles = bundle.bundle; -pub const unbundle = bundle.unbundle; -pub const unbundleStream = bundle.unbundleStream; -pub const validateBase58Hash = bundle.validateBase58Hash; pub const pathHasBundleErr = bundle.pathHasBundleErr; -pub const pathHasUnbundleErr = bundle.pathHasUnbundleErr; +pub const validateBase58Hash = bundle.validateBase58Hash; // Re-export error types pub const BundleError = bundle.BundleError; -pub const UnbundleError = bundle.UnbundleError; pub const PathValidationError = bundle.PathValidationError; pub const PathValidationReason = bundle.PathValidationReason; pub const ErrorContext = bundle.ErrorContext; -// Re-export extract writer types -pub const ExtractWriter = bundle.ExtractWriter; -pub const DirExtractWriter = bundle.DirExtractWriter; - // Re-export constants pub const STREAM_BUFFER_SIZE = bundle.STREAM_BUFFER_SIZE; pub const DEFAULT_COMPRESSION_LEVEL = bundle.DEFAULT_COMPRESSION_LEVEL; @@ -39,22 +31,15 @@ pub const DEFAULT_COMPRESSION_LEVEL = bundle.DEFAULT_COMPRESSION_LEVEL; pub const allocForZstd = bundle.allocForZstd; pub const freeForZstd = bundle.freeForZstd; -// Re-export download functions -pub const downloadBundle = download.download; -pub const validateUrl = download.validateUrl; -pub const DownloadError = download.DownloadError; - // Test coverage includes: // - Path validation for security and cross-platform compatibility -// - Bundle/unbundle roundtrip with various file types and sizes -// - Hash verification and corruption detection -// - Streaming compression/decompression -// - Download URL validation and security checks -// - Large file handling (with std.tar limitations) +// - Bundle creation with various file types and sizes +// - Hash generation +// - Streaming compression +// - Large file handling test { _ = @import("test_bundle.zig"); _ = @import("test_streaming.zig"); _ = bundle; - _ = download; _ = base58; -} +} \ No newline at end of file diff --git a/src/cli/main.zig b/src/cli/main.zig index b5e47d5a51..d58141e62b 100644 --- a/src/cli/main.zig +++ b/src/cli/main.zig @@ -15,6 +15,7 @@ const compile = @import("compile"); const can = @import("can"); const check = @import("check"); const bundle = @import("bundle"); +const unbundle = @import("unbundle"); const ipc = @import("ipc"); const fmt = @import("fmt"); @@ -1046,7 +1047,7 @@ pub fn extractReadRocFilePathShimLibrary(gpa: Allocator, output_path: []const u8 } /// Format a path validation reason into a user-friendly error message -fn formatPathValidationReason(reason: bundle.PathValidationReason) []const u8 { +fn formatPathValidationReason(reason: unbundle.PathValidationReason) []const u8 { return switch (reason) { .empty_path => "Path cannot be empty", .path_too_long => "Path exceeds maximum length of 255 characters", @@ -1288,8 +1289,8 @@ fn rocUnbundle(gpa: Allocator, args: cli_args.UnbundleArgs) !void { // Unbundle the archive var allocator_copy2 = arena_allocator; - var error_ctx: bundle.ErrorContext = undefined; - bundle.unbundle( + var error_ctx: unbundle.ErrorContext = undefined; + unbundle.unbundleFiles( archive_file.reader(), output_dir, &allocator_copy2, diff --git a/src/unbundle/base58.zig b/src/unbundle/base58.zig new file mode 100644 index 0000000000..1ba0714173 --- 
/dev/null
+++ b/src/unbundle/base58.zig
@@ -0,0 +1,593 @@
+//! Base58 encoding and decoding for BLAKE3 hashes
+//!
+//! This module provides base58 encoding/decoding specifically optimized for 256-bit BLAKE3 hashes.
+//! The base58 alphabet excludes visually similar characters (0, O, I, l) to prevent confusion.
+
+const std = @import("std");
+
+// Base58 alphabet (no '0', 'O', 'I', or 'l', to deter visual similarity attacks)
+const base58_alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz";
+
+/// We use 256-bit BLAKE3 hashes
+const hash_bytes = 32;
+
+/// Number of Base58 characters needed to represent a 256-bit hash
+///
+/// Math:
+/// - Each base58 character can represent 58 values
+/// - So we need ceil(log58(MAX_HASH + 1)) characters
+/// - Which is ceil(log(MAX_HASH + 1) / log(58))
+/// - Since MAX_HASH is 2^256 - 1, we need ceil(256 * log(2) / log(58))
+/// - 256 * log(2) / log(58) ≈ 43.7
+///
+/// So we need 44 characters.
+pub const base58_hash_bytes = 44;
+
+/// Encode the given slice of 32 bytes as a base58 string and write it to the destination.
+/// Returns a slice of the destination containing the encoded string (at most 44 characters).
+pub fn encode(src: *const [hash_bytes]u8, dest: *[base58_hash_bytes]u8) []u8 {
+    // Count leading zero bytes
+    var leading_zeros: usize = 0;
+    while (leading_zeros < src.len and src[leading_zeros] == 0) {
+        leading_zeros += 1;
+    }
+
+    if (leading_zeros == src.len) {
+        // All zeros - return just the leading '1's
+        @memset(dest[0..leading_zeros], '1');
+        return dest[0..leading_zeros];
+    }
+
+    // Make a mutable scratch copy of the source
+    var scratch: [hash_bytes]u8 = undefined;
+    @memcpy(&scratch, src);
+
+    var write_idx: isize = base58_hash_bytes - 1;
+    const start: usize = leading_zeros;
+
+    // Repeatedly divide scratch[start..] by 58, collecting remainder
+    // We need to keep dividing until the entire number becomes zero
+    var has_nonzero = true;
+    while (has_nonzero) {
+        var remainder: u16 = 0;
+        has_nonzero = false;
+        for (scratch[start..]) |*byte| {
+            const value = (@as(u16, remainder) << 8) | byte.*;
+            byte.* = @intCast(value / 58);
+            remainder = value % 58;
+            if (byte.* != 0) has_nonzero = true;
+        }
+        dest[@intCast(write_idx)] = base58_alphabet[@intCast(remainder)];
+        write_idx -= 1;
+    }
+
+    // Now combine leading '1's with the encoded value
+    const encoded_start = @as(usize, @intCast(write_idx + 1));
+    const encoded_len = base58_hash_bytes - encoded_start;
+
+    // Write leading '1's at the beginning
+    @memset(dest[0..leading_zeros], '1');
+
+    // Move the encoded data to follow the '1's
+    std.mem.copyForwards(u8, dest[leading_zeros .. leading_zeros + encoded_len], dest[encoded_start..base58_hash_bytes]);
+
+    return dest[0 .. 
leading_zeros + encoded_len]; +} + +test "encode - all zero bytes" { + const input = [_]u8{0} ** 32; + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Should produce all '1's + for (encoded) |char| { + try std.testing.expectEqual('1', char); + } + try std.testing.expectEqual(@as(usize, 32), encoded.len); +} + +test "encode - some leading zero bytes" { + var input = [_]u8{0} ** 32; + input[3] = 1; + input[4] = 2; + input[5] = 3; + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Should have 3 leading '1's + try std.testing.expectEqual('1', encoded[0]); + try std.testing.expectEqual('1', encoded[1]); + try std.testing.expectEqual('1', encoded[2]); + // Rest should be valid base58 + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } +} + +test "encode - single non-zero byte at end" { + var input = [_]u8{0} ** 32; + input[31] = 255; + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Should produce valid base58 ending with encoded value of 255 + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } +} + +test "encode - known 32-byte test vector" { + // 32-byte test vector + const input = [_]u8{ + 0x00, 0x01, 0x09, 0x66, 0x77, 0x00, 0x06, 0x95, + 0x3D, 0x55, 0x67, 0x43, 0x9E, 0x5E, 0x39, 0xF8, + 0x6A, 0x0D, 0x27, 0x3B, 0xEE, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Verify all chars are valid base58 + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } +} + +test "encode - text padded to 32 bytes" { + // "Hello World" padded to 32 bytes + var input = [_]u8{0} ** 32; + const text = "Hello World"; + @memcpy(input[0..text.len], text); + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Verify all chars are valid base58 + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } +} + +test "encode - 32-byte hash (typical use case)" { + // SHA256 hash of "test" + const input = [_]u8{ + 0x9f, 0x86, 0xd0, 0x81, 0x88, 0x4c, 0x7d, 0x65, + 0x9a, 0x2f, 0xea, 0xa0, 0xc5, 0x5a, 0xd0, 0x15, + 0xa3, 0xbf, 0x4f, 0x1b, 0x2b, 0x0b, 0x82, 0x2c, + 0xd1, 0x5d, 0x6c, 0x15, 0xb0, 0xf0, 0x0a, 0x08, + }; + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Should produce a valid base58 string + // Check that all characters are in the base58 alphabet + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } +} + +test "encode - all 0xFF bytes" { + const input = [_]u8{0xFF} ** 32; + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Should produce a valid base58 string + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } +} + +test "encode - mixed bytes with leading zeros (32 bytes)" { + var input = [_]u8{0} ** 32; + input[3] = 1; + input[4] = 2; + input[5] = 3; + input[6] = 4; + input[7] = 5; + var output: 
[base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Should have 3 leading '1's + try std.testing.expectEqual('1', encoded[0]); + try std.testing.expectEqual('1', encoded[1]); + try std.testing.expectEqual('1', encoded[2]); + + // All should be valid base58 + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } +} + +test "encode - incremental values" { + // Test that incrementing input produces different outputs + var prev_output: [base58_hash_bytes]u8 = undefined; + var curr_output: [base58_hash_bytes]u8 = undefined; + + var input1 = [_]u8{0} ** 32; + input1[31] = 1; + const encoded1 = encode(&input1, &prev_output); + + var input2 = [_]u8{0} ** 32; + input2[31] = 2; + const encoded2 = encode(&input2, &curr_output); + + // Outputs should be different + try std.testing.expect(!std.mem.eql(u8, encoded1, encoded2)); +} + +test "encode - power of two boundaries" { + // Test encoding at power-of-two boundaries in 32-byte inputs + const test_values = [_]u8{ 1, 2, 4, 128, 255 }; + + for (test_values) |val| { + var input = [_]u8{0} ** 32; + input[31] = val; + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Just verify it produces valid base58 + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } + } +} + +test "encode - alternating bits pattern" { + // Test with alternating 0xAA and 0x55 pattern + var input = [_]u8{ 0xAA, 0x55 } ** 16; + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Should produce valid base58 chars + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } +} + +test "encode - sequential bytes" { + // Test with sequential byte values + var input: [hash_bytes]u8 = undefined; + for (0..hash_bytes) |i| { + input[i] = @intCast(i); + } + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Should produce valid base58 chars + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } +} + +test "encode - high entropy random-like data" { + // Test with pseudo-random looking data (using prime multiplication) + var input: [hash_bytes]u8 = undefined; + var val: u32 = 17; + for (0..hash_bytes) |i| { + val = (val *% 31) +% 37; + input[i] = @truncate(val); + } + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Should produce valid base58 chars + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } +} + +test "encode - single bit set in different positions" { + // Test with single bit set at different byte positions + const positions = [_]usize{ 0, 7, 15, 16, 24, 31 }; + + for (positions) |pos| { + var input = [_]u8{0} ** 32; + input[pos] = 1; + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // All outputs should be different + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } + } +} + +test "encode - max value in different positions" { + // Test with 0xFF in different positions + const positions = [_]usize{ 0, 8, 16, 
24, 31 }; + var outputs: [positions.len][base58_hash_bytes]u8 = undefined; + var encoded_slices: [positions.len][]u8 = undefined; + + for (positions, 0..) |pos, i| { + var input = [_]u8{0} ** 32; + input[pos] = 0xFF; + encoded_slices[i] = encode(&input, &outputs[i]); + + // Verify all chars are valid + for (encoded_slices[i]) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } + } + + // All outputs should be unique + for (0..outputs.len) |i| { + for (i + 1..outputs.len) |j| { + try std.testing.expect(!std.mem.eql(u8, encoded_slices[i], encoded_slices[j])); + } + } +} + +test "encode - known sha256 hash values" { + // Test with actual SHA256 hash outputs + const test_cases = [_][hash_bytes]u8{ + // SHA256("") + [_]u8{ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, + 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + }, + // SHA256("abc") + [_]u8{ + 0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea, + 0x41, 0x41, 0x40, 0xde, 0x5d, 0xae, 0x22, 0x23, + 0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17, 0x7a, 0x9c, + 0xb4, 0x10, 0xff, 0x61, 0xf2, 0x00, 0x15, 0xad, + }, + }; + + for (test_cases) |input| { + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&input, &output); + + // Should produce valid base58 chars + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } + } +} + +test "encode - boundary values" { + // Test edge cases around byte boundaries + const test_cases = [_]struct { + desc: []const u8, + input: [hash_bytes]u8, + }{ + .{ .desc = "minimum after all zeros", .input = blk: { + var arr = [_]u8{0} ** 32; + arr[31] = 1; + break :blk arr; + } }, + .{ + .desc = "one less than power of 58", + .input = blk: { + var arr = [_]u8{0} ** 32; + arr[31] = 57; // 58 - 1 + break :blk arr; + }, + }, + .{ .desc = "exactly power of 58", .input = blk: { + var arr = [_]u8{0} ** 32; + arr[31] = 58; + break :blk arr; + } }, + .{ .desc = "high bytes in middle", .input = blk: { + var arr = [_]u8{0} ** 32; + arr[15] = 0xFF; + arr[16] = 0xFF; + break :blk arr; + } }, + }; + + for (test_cases) |tc| { + var output: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&tc.input, &output); + + // All should produce valid base58 + for (encoded) |char| { + const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null; + try std.testing.expect(is_valid); + } + } +} + +test "encode - deterministic output" { + // Ensure encoding is deterministic - same input always produces same output + const input = [_]u8{ + 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0, + 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, + 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + }; + + var output1: [base58_hash_bytes]u8 = undefined; + var output2: [base58_hash_bytes]u8 = undefined; + var output3: [base58_hash_bytes]u8 = undefined; + + const encoded1 = encode(&input, &output1); + const encoded2 = encode(&input, &output2); + const encoded3 = encode(&input, &output3); + + // All outputs should be identical + try std.testing.expectEqualSlices(u8, encoded1, encoded2); + try std.testing.expectEqualSlices(u8, encoded2, encoded3); +} + +/// Decode a base58 string back to 32 bytes. +/// Returns InvalidBase58 error if the string contains invalid characters. 
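+///
+/// Roundtrip sketch (mirrors the roundtrip tests below; `hash` stands for any
+/// `[hash_bytes]u8` value):
+/// ```zig
+/// var buf: [base58_hash_bytes]u8 = undefined;
+/// const encoded = encode(&hash, &buf);
+/// var decoded: [hash_bytes]u8 = undefined;
+/// try decode(encoded, &decoded);
+/// // decoded now equals hash
+/// ```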
+pub fn decode(src: []const u8, dest: *[hash_bytes]u8) !void { + // Clear destination - needed because the multiplication algorithm + // accumulates values across the entire buffer + @memset(dest, 0); + + // Count leading '1's (representing leading zeros in the output) + var leading_ones: usize = 0; + for (src) |char| { + if (char == '1') { + leading_ones += 1; + } else { + break; + } + } + + // If all '1's, we're done (all zeros) + if (leading_ones == src.len) { + return; + } + + // Process each character from the input + for (src) |char| { + // Find the value of this character + const char_value = blk: { + for (base58_alphabet, 0..) |alpha_char, i| { + if (char == alpha_char) { + break :blk @as(u8, @intCast(i)); + } + } + return error.InvalidBase58; + }; + + // Multiply dest by 58 and add char_value + var carry: u16 = char_value; + var j: usize = hash_bytes; + while (j > 0) { + j -= 1; + const value = @as(u16, dest[j]) * 58 + carry; + dest[j] = @truncate(value); + carry = value >> 8; + } + + // If we still have carry, the number is too large + if (carry != 0) { + return error.InvalidBase58; + } + } + + // Count actual leading zeros we produced + var actual_zeros: usize = 0; + for (dest.*) |byte| { + if (byte == 0) { + actual_zeros += 1; + } else { + break; + } + } + + // Standard base58: ensure we have exactly the right number of leading zeros + // Each leading '1' in input should produce one leading zero in output + if (actual_zeros < leading_ones) { + const shift = leading_ones - actual_zeros; + // Shift data right to make room for more zeros + var i: usize = hash_bytes; + while (i > shift) { + i -= 1; + dest[i] = dest[i - shift]; + } + @memset(dest[0..shift], 0); + } +} + +// Tests for decode +test "decode - all ones" { + const input = [_]u8{'1'} ** base58_hash_bytes; + var output: [hash_bytes]u8 = undefined; + try decode(&input, &output); + + // Should produce all zeros + for (output) |byte| { + try std.testing.expectEqual(@as(u8, 0), byte); + } +} + +test "decode - roundtrip simple" { + // Test roundtrip: encode then decode + var original = [_]u8{0} ** 32; + original[31] = 42; + + var encoded_buf: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&original, &encoded_buf); + + var decoded: [hash_bytes]u8 = undefined; + try decode(encoded, &decoded); + + try std.testing.expectEqualSlices(u8, &original, &decoded); +} + +test "decode - roundtrip all values" { + // Test roundtrip with all different byte values in last position + for (0..256) |val| { + var original = [_]u8{0} ** 32; + original[31] = @intCast(val); + + var encoded_buf: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&original, &encoded_buf); + + var decoded: [hash_bytes]u8 = undefined; + try decode(encoded, &decoded); + + try std.testing.expectEqualSlices(u8, &original, &decoded); + } +} + +test "decode - roundtrip with leading zeros" { + var original = [_]u8{0} ** 32; + original[5] = 1; + original[6] = 2; + original[7] = 3; + + var encoded_buf: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&original, &encoded_buf); + + var decoded: [hash_bytes]u8 = undefined; + try decode(encoded, &decoded); + + try std.testing.expectEqualSlices(u8, &original, &decoded); +} + +test "decode - invalid character" { + var input = [_]u8{'1'} ** base58_hash_bytes; + input[20] = '0'; // '0' is not in base58 alphabet + + var output: [hash_bytes]u8 = undefined; + const result = decode(&input, &output); + + try std.testing.expectError(error.InvalidBase58, result); +} + +test "decode - roundtrip random patterns" { 
+ const patterns = [_][hash_bytes]u8{ + [_]u8{0xFF} ** 32, + [_]u8{ 0xAA, 0x55 } ** 16, + blk: { + var arr: [hash_bytes]u8 = undefined; + for (0..hash_bytes) |i| { + arr[i] = @intCast(i * 7); + } + break :blk arr; + }, + }; + + for (patterns) |original| { + var encoded_buf: [base58_hash_bytes]u8 = undefined; + const encoded = encode(&original, &encoded_buf); + + var decoded: [hash_bytes]u8 = undefined; + try decode(encoded, &decoded); + + try std.testing.expectEqualSlices(u8, &original, &decoded); + } +} diff --git a/src/unbundle/download.zig b/src/unbundle/download.zig new file mode 100644 index 0000000000..c3a0d5c4a0 --- /dev/null +++ b/src/unbundle/download.zig @@ -0,0 +1,227 @@ +//! Download and extract bundled tar.zst files over https +//! (or http if the URL host is `localhost`, `127.0.0.1`, or `::1`) + +const std = @import("std"); +const builtin = @import("builtin"); +const unbundle = @import("unbundle.zig"); + +// Network constants +const HTTPS_DEFAULT_PORT: u16 = 443; +const HTTP_DEFAULT_PORT: u16 = 80; +const SERVER_HEADER_BUFFER_SIZE: usize = 16 * 1024; + +// IPv4 loopback address 127.0.0.1 in network byte order +const IPV4_LOOPBACK_BE: u32 = 0x7F000001; // Big-endian +const IPV4_LOOPBACK_LE: u32 = 0x0100007F; // Little-endian + +/// Errors that can occur during the download operation. +pub const DownloadError = error{ + InvalidUrl, + LocalhostWasNotLoopback, + InvalidHash, + HttpError, + NoHashInUrl, +} || unbundle.UnbundleError || std.mem.Allocator.Error; + +/// Parse URL and validate it meets our security requirements. +/// Returns the hash from the URL if valid. +pub fn validateUrl(url: []const u8) DownloadError![]const u8 { + // Check for https:// prefix + if (std.mem.startsWith(u8, url, "https://")) { + // This is fine, extract hash from last segment + } else if (std.mem.startsWith(u8, url, "http://127.0.0.1:") or std.mem.startsWith(u8, url, "http://127.0.0.1/")) { + // This is allowed for local testing (IPv4 loopback) + } else if (std.mem.startsWith(u8, url, "http://[::1]:") or std.mem.startsWith(u8, url, "http://[::1]/")) { + // This is allowed for local testing (IPv6 loopback) + } else if (std.mem.startsWith(u8, url, "http://localhost:") or std.mem.startsWith(u8, url, "http://localhost/")) { + // This is allowed but will require verification that localhost resolves to loopback + } else { + return error.InvalidUrl; + } + + // Extract the last path segment (should be the hash) + const last_slash = std.mem.lastIndexOf(u8, url, "/") orelse return error.NoHashInUrl; + const hash_part = url[last_slash + 1 ..]; + + // Remove .tar.zst extension if present + const hash = if (std.mem.endsWith(u8, hash_part, ".tar.zst")) + hash_part[0 .. hash_part.len - 8] + else + hash_part; + + if (hash.len == 0) { + return error.NoHashInUrl; + } + + return hash; +} + +/// Download and extract a bundled tar.zst file from a URL. 
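+/// A conforming URL looks like `https://example.com/pkg/<base58-hash>.tar.zst`
+/// (illustrative host and path; the final segment is the hash emitted by
+/// `roc bundle`).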
+///
+/// The URL must:
+/// - Start with "https://" or "http://127.0.0.1"
+/// - Have the base58-encoded blake3 hash as the last path segment
+/// - Point to a tar.zst file created with `roc bundle`
+pub fn downloadAndExtract(
+    allocator: *std.mem.Allocator,
+    url: []const u8,
+    extract_dir: std.fs.Dir,
+) DownloadError!void {
+    // Validate URL and extract hash
+    const base58_hash = try validateUrl(url);
+
+    // Validate the hash before starting any I/O
+    const expected_hash = (try unbundle.validateBase58Hash(base58_hash)) orelse {
+        return error.InvalidHash;
+    };
+
+    // Create HTTP client
+    var client = std.http.Client{ .allocator = allocator.* };
+    defer client.deinit();
+
+    // Parse the URL
+    const uri = std.Uri.parse(url) catch return error.InvalidUrl;
+
+    // Check if we need to resolve localhost
+    var extra_headers: []const std.http.Header = &.{};
+    if (uri.host) |host| {
+        if (std.mem.eql(u8, host.percent_encoded, "localhost")) {
+            // Security: We must resolve "localhost" and verify it points to a loopback address.
+            // This prevents attacks where:
+            // 1. An attacker modifies /etc/hosts to make localhost resolve to their server
+            // 2. A compromised DNS makes localhost resolve to an external IP
+            // 3. Container/VM networking misconfiguration exposes localhost to external IPs
+            //
+            // We're being intentionally strict here:
+            // - For IPv4: We only accept exactly 127.0.0.1 (not the full 127.0.0.0/8 range)
+            // - For IPv6: We only accept exactly ::1 (not other loopback addresses)
+            //
+            // While the specs technically allow any 127.x.y.z address for IPv4 loopback
+            // and multiple forms for IPv6, in practice localhost almost always resolves
+            // to these exact addresses, and being stricter improves security.
+
+            // Resolver failures are not members of DownloadError, and a localhost
+            // we cannot resolve is not a loopback we can trust, so map them here.
+            const address_list = std.net.getAddressList(allocator.*, "localhost", uri.port orelse HTTP_DEFAULT_PORT) catch
+                return error.LocalhostWasNotLoopback;
+            defer address_list.deinit();
+
+            if (address_list.addrs.len == 0) {
+                return error.LocalhostWasNotLoopback;
+            }
+
+            // Check that at least one address is a loopback
+            var found_loopback = false;
+            for (address_list.addrs) |addr| {
+                switch (addr.any.family) {
+                    std.posix.AF.INET => {
+                        const ipv4_addr = addr.in.sa.addr;
+                        // sa.addr is always in network byte order, so convert the
+                        // host-order constant before comparing. (Accepting both byte
+                        // orders here would also have accepted 1.0.0.127.)
+                        if (ipv4_addr == std.mem.nativeToBig(u32, IPV4_LOOPBACK_BE)) {
+                            found_loopback = true;
+                            break;
+                        }
+                    },
+                    std.posix.AF.INET6 => {
+                        const ipv6_addr = addr.in6.sa.addr;
+                        // Check if it's exactly ::1 (all zeros except last byte is 1)
+                        var is_loopback = true;
+                        for (ipv6_addr[0..15]) |byte| {
+                            if (byte != 0) {
+                                is_loopback = false;
+                                break;
+                            }
+                        }
+                        if (is_loopback and ipv6_addr[15] == 1) {
+                            found_loopback = true;
+                            break;
+                        }
+                    },
+                    else => {}, // Ignore other address families
+                }
+            }
+
+            if (!found_loopback) {
+                return error.LocalhostWasNotLoopback;
+            }
+
+            // Since we're using "localhost", we need to set the Host header manually
+            // to match what the server expects
+            extra_headers = &.{
+                .{ .name = "Host", .value = "localhost" },
+            };
+        }
+    }
+
+    // Start the HTTP request. The HTTP client's error sets are not part of
+    // DownloadError, so collapse any transport failure into HttpError.
+    var header_buffer: [SERVER_HEADER_BUFFER_SIZE]u8 = undefined;
+    var request = client.open(.GET, uri, .{
+        .server_header_buffer = &header_buffer,
+        .extra_headers = extra_headers,
+    }) catch return error.HttpError;
+    defer request.deinit();
+
+    // Send the request and wait for response
+    request.send() catch return error.HttpError;
+    request.wait() catch return error.HttpError;
+
+    // Check for successful response
+    if (request.response.status != .ok) {
+        return error.HttpError;
+    }
+
+    const reader = request.reader();
+
+    // Setup directory extract writer
+    var dir_writer = unbundle.DirExtractWriter.init(extract_dir);
+
+    // Stream 
and extract the content
+    try unbundle.unbundleStream(reader, dir_writer.extractWriter(), allocator, &expected_hash, null);
+}
+
+/// Download and extract a bundled tar.zst file to memory buffers.
+///
+/// Returns a BufferExtractWriter containing all extracted files and directories.
+/// The caller owns the returned writer and must call deinit() on it.
+pub fn downloadAndExtractToBuffer(
+    allocator: *std.mem.Allocator,
+    url: []const u8,
+) DownloadError!unbundle.BufferExtractWriter {
+    // Validate URL and extract hash
+    const base58_hash = try validateUrl(url);
+
+    // Validate the hash before starting any I/O
+    const expected_hash = (try unbundle.validateBase58Hash(base58_hash)) orelse {
+        return error.InvalidHash;
+    };
+
+    // Create HTTP client
+    var client = std.http.Client{ .allocator = allocator.* };
+    defer client.deinit();
+
+    // Parse the URL
+    const uri = std.Uri.parse(url) catch return error.InvalidUrl;
+
+    // Start the HTTP request (simplified for brevity: unlike downloadAndExtract,
+    // this does NOT verify that "localhost" resolves to a loopback address).
+    // As above, collapse transport failures into HttpError.
+    var header_buffer: [SERVER_HEADER_BUFFER_SIZE]u8 = undefined;
+    var request = client.open(.GET, uri, .{
+        .server_header_buffer = &header_buffer,
+    }) catch return error.HttpError;
+    defer request.deinit();
+
+    // Send the request and wait for response
+    request.send() catch return error.HttpError;
+    request.wait() catch return error.HttpError;
+
+    // Check for successful response
+    if (request.response.status != .ok) {
+        return error.HttpError;
+    }
+
+    const reader = request.reader();
+
+    // Setup buffer extract writer
+    var buffer_writer = unbundle.BufferExtractWriter.init(allocator);
+    errdefer buffer_writer.deinit();
+
+    // Stream and extract the content
+    try unbundle.unbundleStream(reader, buffer_writer.extractWriter(), allocator, &expected_hash, null);
+
+    return buffer_writer;
+}
\ No newline at end of file
diff --git a/src/unbundle/mod.zig b/src/unbundle/mod.zig
new file mode 100644
index 0000000000..47890318ab
--- /dev/null
+++ b/src/unbundle/mod.zig
@@ -0,0 +1,35 @@
+//! Unbundle functionality for Roc packages using Zig's standard library
+//!
+//! This module provides functionality to:
+//! - Unbundle compressed tar archives (.tar.zst files)
+//! - Validate and decode base58-encoded hashes
+//! - Extract files with security and cross-platform path validation
+//! - Download and extract bundled archives from HTTPS URLs
+//!
+//! This module uses Zig's std.compress.zstandard for decompression,
+//! making it compatible with WebAssembly targets. 
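+//!
+//! Typical consumer usage, as in src/cli/main.zig (sketch; the reader, output
+//! directory, allocator, and filename are placeholders):
+//!
+//! ```zig
+//! const unbundle = @import("unbundle");
+//!
+//! var error_ctx: unbundle.ErrorContext = undefined;
+//! try unbundle.unbundleFiles(
+//!     archive_file.reader(),
+//!     output_dir,
+//!     &allocator,
+//!     "<base58-hash>.tar.zst",
+//!     &error_ctx,
+//! );
+//! ```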
+ +pub const unbundle = @import("unbundle.zig"); +pub const download = @import("download.zig"); +pub const base58 = @import("base58.zig"); + +// Re-export commonly used functions and types +pub const unbundleFiles = unbundle.unbundle; +pub const unbundleStream = unbundle.unbundleStream; +pub const validateBase58Hash = unbundle.validateBase58Hash; +pub const pathHasUnbundleErr = unbundle.pathHasUnbundleErr; + +// Re-export error types +pub const UnbundleError = unbundle.UnbundleError; +pub const PathValidationError = unbundle.PathValidationError; +pub const PathValidationReason = unbundle.PathValidationReason; +pub const ErrorContext = unbundle.ErrorContext; + +// Re-export extract writer types +pub const ExtractWriter = unbundle.ExtractWriter; +pub const DirExtractWriter = unbundle.DirExtractWriter; +pub const BufferExtractWriter = unbundle.BufferExtractWriter; + +// Re-export download functionality +pub const downloadAndExtract = download.downloadAndExtract; +pub const downloadAndExtractToBuffer = download.downloadAndExtractToBuffer; \ No newline at end of file diff --git a/src/unbundle/unbundle.zig b/src/unbundle/unbundle.zig new file mode 100644 index 0000000000..abd8d21743 --- /dev/null +++ b/src/unbundle/unbundle.zig @@ -0,0 +1,383 @@ +//! Unbundle compressed tar archives using Zig's standard library +//! +//! This module provides unbundling functionality that works on all platforms +//! including WebAssembly, by using Zig's std.compress.zstandard instead of +//! the C zstd library. + +const builtin = @import("builtin"); +const std = @import("std"); +const base58 = @import("base58.zig"); + +// Constants +const TAR_EXTENSION = ".tar.zst"; +const STREAM_BUFFER_SIZE: usize = 64 * 1024; // 64KB buffer for streaming operations + +/// Errors that can occur during the unbundle operation. 
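+/// `HashMismatch` in particular is what a caller sees when the archive's bytes
+/// do not match the BLAKE3 hash named in its filename or URL (corruption or
+/// tampering).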
+pub const UnbundleError = error{
+    DecompressionFailed,
+    InvalidTarHeader,
+    UnexpectedEndOfStream,
+    FileCreateFailed,
+    DirectoryCreateFailed,
+    FileWriteFailed,
+    HashMismatch,
+    InvalidFilename,
+    FileTooLarge,
+    InvalidPath,
+    NoDataExtracted,
+} || std.mem.Allocator.Error;
+
+/// Context for error reporting during unbundle operations
+pub const ErrorContext = struct {
+    path: []const u8,
+    reason: PathValidationReason,
+};
+
+/// Specific reason why a path validation failed
+pub const PathValidationReason = union(enum) {
+    empty_path,
+    path_too_long,
+    windows_reserved_char: u8,
+    absolute_path,
+    path_traversal,
+    current_directory_reference,
+    windows_reserved_name,
+    contained_backslash_on_unix,
+    component_ends_with_space,
+    component_ends_with_period,
+};
+
+/// Error type for path validation failures
+pub const PathValidationError = struct {
+    path: []const u8,
+    reason: PathValidationReason,
+};
+
+/// Writer interface for extracting files during unbundle
+pub const ExtractWriter = struct {
+    ptr: *anyopaque,
+    vtable: *const VTable,
+
+    pub const VTable = struct {
+        createFile: *const fn (ptr: *anyopaque, path: []const u8) CreateFileError!std.io.AnyWriter,
+        finishFile: *const fn (ptr: *anyopaque, writer: std.io.AnyWriter) FinishFileError!void,
+        makeDir: *const fn (ptr: *anyopaque, path: []const u8) MakeDirError!void,
+    };
+
+    pub const CreateFileError = error{ FileCreateFailed, InvalidPath, OutOfMemory };
+    pub const FinishFileError = error{FileWriteFailed};
+    pub const MakeDirError = error{ DirectoryCreateFailed, InvalidPath, OutOfMemory };
+
+    pub fn createFile(self: ExtractWriter, path: []const u8) CreateFileError!std.io.AnyWriter {
+        return self.vtable.createFile(self.ptr, path);
+    }
+
+    pub fn finishFile(self: ExtractWriter, writer: std.io.AnyWriter) FinishFileError!void {
+        return self.vtable.finishFile(self.ptr, writer);
+    }
+
+    pub fn makeDir(self: ExtractWriter, path: []const u8) MakeDirError!void {
+        return self.vtable.makeDir(self.ptr, path);
+    }
+};
+
+/// Directory-based extract writer
+pub const DirExtractWriter = struct {
+    dir: std.fs.Dir,
+    // File opened by createFile and closed by finishFile. Stored on the struct
+    // (rather than left as a stack local) so the AnyWriter's context pointer
+    // remains valid after createFile returns.
+    current_file: ?std.fs.File = null,
+
+    pub fn init(dir: std.fs.Dir) DirExtractWriter {
+        return .{ .dir = dir };
+    }
+
+    pub fn extractWriter(self: *DirExtractWriter) ExtractWriter {
+        return ExtractWriter{
+            .ptr = self,
+            .vtable = &vtable,
+        };
+    }
+
+    const vtable = ExtractWriter.VTable{
+        .createFile = createFile,
+        .finishFile = finishFile,
+        .makeDir = makeDir,
+    };
+
+    fn createFile(ptr: *anyopaque, path: []const u8) ExtractWriter.CreateFileError!std.io.AnyWriter {
+        const self: *DirExtractWriter = @ptrCast(@alignCast(ptr));
+
+        // Ensure parent directories exist
+        if (std.fs.path.dirname(path)) |parent| {
+            self.dir.makePath(parent) catch return error.FileCreateFailed;
+        }
+
+        self.current_file = self.dir.createFile(path, .{}) catch return error.FileCreateFailed;
+        // Build the AnyWriter by hand: `file.writer().any()` would capture the
+        // address of a temporary, leaving the context dangling after we return.
+        return .{ .context = @ptrCast(&self.current_file.?), .writeFn = writeFileFn };
+    }
+
+    fn writeFileFn(context: *const anyopaque, bytes: []const u8) anyerror!usize {
+        const file: *const std.fs.File = @ptrCast(@alignCast(context));
+        return file.write(bytes);
+    }
+
+    fn finishFile(ptr: *anyopaque, _: std.io.AnyWriter) ExtractWriter.FinishFileError!void {
+        // Close the file opened by the matching createFile call, if any
+        const self: *DirExtractWriter = @ptrCast(@alignCast(ptr));
+        if (self.current_file) |file| {
+            file.close();
+            self.current_file = null;
+        }
+    }
+
+    fn makeDir(ptr: *anyopaque, path: []const u8) ExtractWriter.MakeDirError!void {
+        const self: *DirExtractWriter = @ptrCast(@alignCast(ptr));
+        self.dir.makePath(path) catch return error.DirectoryCreateFailed;
+    }
+};
+
+/// Buffer-based extract writer for in-memory extraction
+pub const BufferExtractWriter = struct {
+    allocator: 
*std.mem.Allocator,
+    files: std.StringHashMap(std.ArrayList(u8)),
+    directories: std.ArrayList([]u8),
+    current_file: ?*std.ArrayList(u8) = null,
+
+    pub fn init(allocator: *std.mem.Allocator) BufferExtractWriter {
+        return .{
+            .allocator = allocator,
+            .files = std.StringHashMap(std.ArrayList(u8)).init(allocator.*),
+            .directories = std.ArrayList([]u8).init(allocator.*),
+        };
+    }
+
+    pub fn deinit(self: *BufferExtractWriter) void {
+        var iter = self.files.iterator();
+        while (iter.next()) |entry| {
+            self.allocator.free(entry.key_ptr.*);
+            entry.value_ptr.deinit();
+        }
+        self.files.deinit();
+
+        for (self.directories.items) |dir| {
+            self.allocator.free(dir);
+        }
+        self.directories.deinit();
+    }
+
+    pub fn extractWriter(self: *BufferExtractWriter) ExtractWriter {
+        return ExtractWriter{
+            .ptr = self,
+            .vtable = &vtable,
+        };
+    }
+
+    const vtable = ExtractWriter.VTable{
+        .createFile = createFile,
+        .finishFile = finishFile,
+        .makeDir = makeDir,
+    };
+
+    fn createFile(ptr: *anyopaque, path: []const u8) ExtractWriter.CreateFileError!std.io.AnyWriter {
+        const self: *BufferExtractWriter = @ptrCast(@alignCast(ptr));
+
+        const key = self.allocator.dupe(u8, path) catch return error.OutOfMemory;
+        errdefer self.allocator.free(key);
+
+        const result = self.files.getOrPut(key) catch return error.OutOfMemory;
+        if (result.found_existing) {
+            self.allocator.free(key);
+            result.value_ptr.clearRetainingCapacity();
+        } else {
+            result.value_ptr.* = std.ArrayList(u8).init(self.allocator.*);
+        }
+
+        self.current_file = result.value_ptr;
+        // Build the AnyWriter by hand: `result.value_ptr.writer().any()` would
+        // capture the address of a temporary GenericWriter's context field,
+        // which dangles once this function returns.
+        return .{ .context = @ptrCast(result.value_ptr), .writeFn = writeListFn };
+    }
+
+    fn writeListFn(context: *const anyopaque, bytes: []const u8) anyerror!usize {
+        const list_const: *const std.ArrayList(u8) = @ptrCast(@alignCast(context));
+        const list = @constCast(list_const);
+        try list.appendSlice(bytes);
+        return bytes.len;
+    }
+
+    fn finishFile(ptr: *anyopaque, _: std.io.AnyWriter) ExtractWriter.FinishFileError!void {
+        const self: *BufferExtractWriter = @ptrCast(@alignCast(ptr));
+        self.current_file = null;
+    }
+
+    fn makeDir(ptr: *anyopaque, path: []const u8) ExtractWriter.MakeDirError!void {
+        const self: *BufferExtractWriter = @ptrCast(@alignCast(ptr));
+        const dir_path = self.allocator.dupe(u8, path) catch return error.OutOfMemory;
+        self.directories.append(dir_path) catch {
+            self.allocator.free(dir_path);
+            return error.OutOfMemory;
+        };
+    }
+};
+
+/// Validate a base58-encoded hash string and return the decoded hash.
+/// Returns null if the hash is invalid.
+pub fn validateBase58Hash(base58_hash: []const u8) !?[32]u8 {
+    if (base58_hash.len > base58.base58_hash_bytes) {
+        return null;
+    }
+
+    var hash: [32]u8 = undefined;
+    base58.decode(base58_hash, &hash) catch return null;
+    return hash;
+}
+
+/// Check if a path has any unbundling errors (security and compatibility issues)
+pub fn pathHasUnbundleErr(path: []const u8) ?PathValidationError {
+    if (path.len == 0) {
+        return PathValidationError{
+            .path = path,
+            .reason = .empty_path,
+        };
+    }
+
+    // Check for absolute paths
+    if (path[0] == '/' or (builtin.target.os.tag == .windows and path.len >= 2 and path[1] == ':')) {
+        return PathValidationError{
+            .path = path,
+            .reason = .absolute_path,
+        };
+    }
+
+    // Check for path traversal attempts
+    if (std.mem.indexOf(u8, path, "..") != null) {
+        return PathValidationError{
+            .path = path,
+            .reason = .path_traversal,
+        };
+    }
+
+    // Check for current directory references
+    if (std.mem.eql(u8, path, ".") or std.mem.indexOf(u8, path, "./") != null or std.mem.indexOf(u8, path, "/.") != null) {
+        return PathValidationError{
+            .path = path,
+            .reason = .current_directory_reference,
+        };
+    }
+
+    return null;
+}
+
+/// Unbundle files from a compressed tar archive stream. 
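+///
+/// Pipeline: the raw (still-compressed) input is fed through a BLAKE3-hashing
+/// reader, then zstd-decompressed, then untarred entry by entry; after all
+/// entries have been written, the hash of the compressed stream is checked
+/// against `expected_hash`.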
+/// +/// This is the core streaming unbundle logic that can be used by both file-based +/// unbundling and network-based downloading. +/// If an InvalidPath error is returned, error_context will contain details about the invalid path. +pub fn unbundleStream( + input_reader: anytype, + extract_writer: ExtractWriter, + allocator: *std.mem.Allocator, + expected_hash: *const [32]u8, + error_context: ?*ErrorContext, +) UnbundleError!void { + // Create a hashing reader to verify the hash while reading + var hasher = std.crypto.hash.Blake3.init(.{}); + const HashingReader = struct { + child_reader: @TypeOf(input_reader), + hasher: *std.crypto.hash.Blake3, + + pub const Error = @TypeOf(input_reader).Error; + pub const Reader = std.io.Reader(@This(), Error, read); + + pub fn read(self: @This(), buffer: []u8) Error!usize { + const n = try self.child_reader.read(buffer); + if (n > 0) { + self.hasher.update(buffer[0..n]); + } + return n; + } + + pub fn reader(self: @This()) Reader { + return .{ .context = self }; + } + }; + + var hashing_reader = HashingReader{ + .child_reader = input_reader, + .hasher = &hasher, + }; + + // Create zstandard decompressor + var zstd_stream = std.compress.zstandard.decompressStream(allocator.*, hashing_reader.reader()); + defer zstd_stream.deinit(); + const decompressed_reader = zstd_stream.reader(); + + // Create tar reader + var tar_iterator = std.tar.iterator(decompressed_reader, .{ + .max_file_size = std.math.maxInt(usize), // No limit on file size + }); + + var data_extracted = false; + + // Process all tar entries + while (try tar_iterator.next()) |entry| { + const file_path = entry.path; + + // Validate path for security + if (pathHasUnbundleErr(file_path)) |validation_error| { + if (error_context) |ctx| { + ctx.path = validation_error.path; + ctx.reason = validation_error.reason; + } + return error.InvalidPath; + } + + switch (entry.kind) { + .directory => { + try extract_writer.makeDir(file_path); + data_extracted = true; + }, + .file => { + const file_writer = try extract_writer.createFile(file_path); + defer extract_writer.finishFile(file_writer) catch {}; + + // Stream the file content + const file_size = std.math.cast(usize, entry.size) orelse return error.FileTooLarge; + var bytes_remaining = file_size; + var buffer: [STREAM_BUFFER_SIZE]u8 = undefined; + + while (bytes_remaining > 0) { + const to_read = @min(buffer.len, bytes_remaining); + const bytes_read = try entry.reader().readAll(buffer[0..to_read]); + if (bytes_read == 0) return error.UnexpectedEndOfStream; + try file_writer.writeAll(buffer[0..bytes_read]); + bytes_remaining -= bytes_read; + } + + data_extracted = true; + }, + else => { + // Skip other entry types (symlinks, etc.) + try entry.skip(); + }, + } + } + + if (!data_extracted) { + return error.NoDataExtracted; + } + + // Verify the hash + var actual_hash: [32]u8 = undefined; + hasher.final(&actual_hash); + if (!std.mem.eql(u8, &actual_hash, expected_hash)) { + return error.HashMismatch; + } +} + +/// Unbundle files from a compressed tar archive to a directory. +/// +/// The filename parameter should be the base58-encoded blake3 hash + .tar.zst extension. +/// If an InvalidPath error is returned, error_context will contain details about the invalid path. 
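+///
+/// Sketch (the filename is a placeholder; it must be exactly the base58 hash
+/// plus ".tar.zst", as produced by `roc bundle`):
+/// ```zig
+/// var ctx: ErrorContext = undefined;
+/// try unbundle(archive.reader(), dest_dir, &gpa, "<base58-hash>.tar.zst", &ctx);
+/// ```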
+pub fn unbundle( + input_reader: anytype, + extract_dir: std.fs.Dir, + allocator: *std.mem.Allocator, + filename: []const u8, + error_context: ?*ErrorContext, +) UnbundleError!void { + // Extract expected hash from filename + if (!std.mem.endsWith(u8, filename, TAR_EXTENSION)) { + return error.InvalidFilename; + } + const base58_hash = filename[0 .. filename.len - TAR_EXTENSION.len]; // Remove .tar.zst + const expected_hash = (try validateBase58Hash(base58_hash)) orelse { + return error.InvalidFilename; + }; + + var dir_writer = DirExtractWriter.init(extract_dir); + return unbundleStream(input_reader, dir_writer.extractWriter(), allocator, &expected_hash, error_context); +} \ No newline at end of file