mirror of https://github.com/roc-lang/roc.git, synced 2025-10-17 15:17:12 +00:00

commit a8dadea7dd (parent 4e6aa10280)
Split out unbundle, use Zig's zstd for wasm

8 changed files with 1290 additions and 33 deletions
@@ -101,7 +101,7 @@ pub fn build(b: *std.Build) void {
     });
     playground_exe.entry = .disabled;
     playground_exe.rdynamic = true;
-    roc_modules.addAll(playground_exe);
+    roc_modules.addAllExceptBundle(playground_exe);

     add_tracy(b, roc_modules.build_options, playground_exe, b.resolveTargetQuery(.{
         .cpu_arch = .wasm32,
@@ -126,7 +126,7 @@ pub fn build(b: *std.Build) void {
     });
     playground_integration_test_exe.root_module.addImport("bytebox", bytebox.module("bytebox"));
     playground_integration_test_exe.root_module.addImport("build_options", build_options.createModule());
-    roc_modules.addAll(playground_integration_test_exe);
+    roc_modules.addAllExceptBundle(playground_integration_test_exe);

     const install = b.addInstallArtifact(playground_integration_test_exe, .{});
     // Ensure playground WASM is built before running the integration test
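The swap from addAll to addAllExceptBundle exists because the playground targets wasm32, where the C zstd library that the bundle module links against is unavailable. A minimal sketch of the idea (not from the commit; `target`, `roc_modules`, and `exe` stand in for the values built up in build.zig above):

// Illustrative sketch: pick the module set based on the resolved target.
// wasm32 cannot link the C zstd library, so the bundle module is skipped
// there; the pure-Zig unbundle module is always safe to include.
if (target.result.cpu.arch == .wasm32) {
    roc_modules.addAllExceptBundle(exe);
} else {
    roc_modules.addAll(exe);
}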
@@ -32,6 +32,7 @@ pub const ModuleType = enum {
     repl,
     fmt,
     bundle,
+    unbundle,

     /// Returns the dependencies for this module type
     pub fn getDependencies(self: ModuleType) []const ModuleType {
@@ -54,6 +55,7 @@ pub const ModuleType = enum {
             .repl => &.{ .base, .compile, .parse, .types, .can, .check, .builtins, .layout, .eval },
             .fmt => &.{ .base, .parse, .collections, .can, .fs, .tracy },
             .bundle => &.{ .base, .collections },
+            .unbundle => &.{ .base, .collections },
         };
     }
 };
@@ -78,6 +80,7 @@ pub const RocModules = struct {
     repl: *Module,
     fmt: *Module,
     bundle: *Module,
+    unbundle: *Module,

     pub fn create(b: *Build, build_options_step: *Step.Options, zstd: ?*Dependency) RocModules {
         const self = RocModules{
@@ -105,12 +108,14 @@ pub const RocModules = struct {
             .repl = b.addModule("repl", .{ .root_source_file = b.path("src/repl/mod.zig") }),
             .fmt = b.addModule("fmt", .{ .root_source_file = b.path("src/fmt/mod.zig") }),
             .bundle = b.addModule("bundle", .{ .root_source_file = b.path("src/bundle/mod.zig") }),
+            .unbundle = b.addModule("unbundle", .{ .root_source_file = b.path("src/unbundle/mod.zig") }),
         };

-        // Link zstd to bundle module if available
+        // Link zstd to bundle module if available (for compression)
         if (zstd) |z| {
             self.bundle.linkLibrary(z.artifact("zstd"));
         }
+        // Note: unbundle module uses Zig's std zstandard, so doesn't need C library

         // Setup module dependencies using our generic helper
         self.setupModuleDependencies();
@@ -138,6 +143,7 @@ pub const RocModules = struct {
         .repl,
         .fmt,
         .bundle,
+        .unbundle,
     };

     // Setup dependencies for each module
@@ -171,12 +177,36 @@ pub const RocModules = struct {
         step.root_module.addImport("repl", self.repl);
         step.root_module.addImport("fmt", self.fmt);
         step.root_module.addImport("bundle", self.bundle);
+        step.root_module.addImport("unbundle", self.unbundle);
     }

     pub fn addAllToTest(self: RocModules, step: *Step.Compile) void {
         self.addAll(step);
     }

+    /// Add all modules except bundle (useful for wasm32 targets where zstd isn't available)
+    pub fn addAllExceptBundle(self: RocModules, step: *Step.Compile) void {
+        step.root_module.addImport("base", self.base);
+        step.root_module.addImport("collections", self.collections);
+        step.root_module.addImport("types", self.types);
+        step.root_module.addImport("compile", self.compile);
+        step.root_module.addImport("reporting", self.reporting);
+        step.root_module.addImport("parse", self.parse);
+        step.root_module.addImport("can", self.can);
+        step.root_module.addImport("check", self.check);
+        step.root_module.addImport("tracy", self.tracy);
+        step.root_module.addImport("builtins", self.builtins);
+        step.root_module.addImport("fs", self.fs);
+        step.root_module.addImport("build_options", self.build_options);
+        step.root_module.addImport("layout", self.layout);
+        step.root_module.addImport("eval", self.eval);
+        step.root_module.addImport("ipc", self.ipc);
+        step.root_module.addImport("repl", self.repl);
+        step.root_module.addImport("fmt", self.fmt);
+        step.root_module.addImport("unbundle", self.unbundle);
+        // Intentionally omitting bundle module (requires C zstd library)
+    }
+
     /// Get a module by its type
     pub fn getModule(self: RocModules, module_type: ModuleType) *Module {
         return switch (module_type) {
@@ -198,6 +228,7 @@ pub const RocModules = struct {
             .repl => self.repl,
             .fmt => self.fmt,
             .bundle => self.bundle,
+            .unbundle => self.unbundle,
         };
     }

@@ -210,7 +241,7 @@ pub const RocModules = struct {
         }
     }

-    pub fn createModuleTests(self: RocModules, b: *Build, target: ResolvedTarget, optimize: OptimizeMode, zstd: ?*Dependency) [16]ModuleTest {
+    pub fn createModuleTests(self: RocModules, b: *Build, target: ResolvedTarget, optimize: OptimizeMode, zstd: ?*Dependency) [17]ModuleTest {
         const test_configs = [_]ModuleType{
             .collections,
             .base,
@@ -228,6 +259,7 @@ pub const RocModules = struct {
             .repl,
             .fmt,
             .bundle,
+            .unbundle,
         };

         var tests: [test_configs.len]ModuleTest = undefined;
@@ -241,6 +273,7 @@ pub const RocModules = struct {
             .optimize = optimize,
             // IPC module needs libc for mmap, munmap, close on POSIX systems
             // Bundle module needs libc for zstd
+            // Unbundle module doesn't need libc (uses Zig's std zstandard)
             .link_libc = (module_type == .ipc or module_type == .bundle),
         });
@@ -1,36 +1,28 @@
-//! Bundle and unbundle functionality for Roc packages
+//! Bundle functionality for Roc packages
 //!
 //! This module provides functionality to:
 //! - Bundle Roc packages and their dependencies into compressed tar archives
-//! - Unbundle these archives to restore the original files
-//! - Download and extract bundled archives from HTTPS URLs
 //! - Validate paths for security and cross-platform compatibility
+//!
+//! Note: This module requires the C zstd library for compression.
+//! For unbundling functionality that works on all platforms (including WebAssembly),
+//! see the separate `unbundle` module.

 pub const bundle = @import("bundle.zig");
-pub const download = @import("download.zig");
 pub const streaming_reader = @import("streaming_reader.zig");
 pub const streaming_writer = @import("streaming_writer.zig");
 pub const base58 = @import("base58.zig");

 // Re-export commonly used functions and types
 pub const bundleFiles = bundle.bundle;
-pub const unbundle = bundle.unbundle;
-pub const unbundleStream = bundle.unbundleStream;
-pub const validateBase58Hash = bundle.validateBase58Hash;
 pub const pathHasBundleErr = bundle.pathHasBundleErr;
-pub const pathHasUnbundleErr = bundle.pathHasUnbundleErr;
+pub const validateBase58Hash = bundle.validateBase58Hash;

 // Re-export error types
 pub const BundleError = bundle.BundleError;
-pub const UnbundleError = bundle.UnbundleError;
 pub const PathValidationError = bundle.PathValidationError;
 pub const PathValidationReason = bundle.PathValidationReason;
-pub const ErrorContext = bundle.ErrorContext;

-// Re-export extract writer types
-pub const ExtractWriter = bundle.ExtractWriter;
-pub const DirExtractWriter = bundle.DirExtractWriter;
-
 // Re-export constants
 pub const STREAM_BUFFER_SIZE = bundle.STREAM_BUFFER_SIZE;
 pub const DEFAULT_COMPRESSION_LEVEL = bundle.DEFAULT_COMPRESSION_LEVEL;
@@ -39,22 +31,15 @@ pub const DEFAULT_COMPRESSION_LEVEL = bundle.DEFAULT_COMPRESSION_LEVEL;
 pub const allocForZstd = bundle.allocForZstd;
 pub const freeForZstd = bundle.freeForZstd;

-// Re-export download functions
-pub const downloadBundle = download.download;
-pub const validateUrl = download.validateUrl;
-pub const DownloadError = download.DownloadError;
-
 // Test coverage includes:
 // - Path validation for security and cross-platform compatibility
-// - Bundle/unbundle roundtrip with various file types and sizes
-// - Hash verification and corruption detection
-// - Streaming compression/decompression
-// - Download URL validation and security checks
-// - Large file handling (with std.tar limitations)
+// - Bundle creation with various file types and sizes
+// - Hash generation
+// - Streaming compression
+// - Large file handling
 test {
     _ = @import("test_bundle.zig");
     _ = @import("test_streaming.zig");
     _ = bundle;
-    _ = download;
     _ = base58;
-}
+}
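The note above is the heart of the split: compression still goes through the C zstd library, while decompression can use Zig's std.compress.zstandard, which builds for wasm32. A minimal sketch of the pure-Zig decompression path (mirroring the decompressStream call used in src/unbundle/unbundle.zig below; the function name and buffer are illustrative):

const std = @import("std");

// Sketch: decompress a .tar.zst stream with Zig's std zstandard decoder.
// No C library is required, so this also compiles for wasm32 targets.
fn readDecompressed(gpa: std.mem.Allocator, compressed_reader: anytype, out: []u8) !usize {
    var zstd_stream = std.compress.zstandard.decompressStream(gpa, compressed_reader);
    defer zstd_stream.deinit();
    return zstd_stream.reader().readAll(out);
}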
@@ -15,6 +15,7 @@ const compile = @import("compile");
 const can = @import("can");
 const check = @import("check");
 const bundle = @import("bundle");
+const unbundle = @import("unbundle");
 const ipc = @import("ipc");
 const fmt = @import("fmt");

@@ -1046,7 +1047,7 @@ pub fn extractReadRocFilePathShimLibrary(gpa: Allocator, output_path: []const u8
 }

 /// Format a path validation reason into a user-friendly error message
-fn formatPathValidationReason(reason: bundle.PathValidationReason) []const u8 {
+fn formatPathValidationReason(reason: unbundle.PathValidationReason) []const u8 {
     return switch (reason) {
         .empty_path => "Path cannot be empty",
         .path_too_long => "Path exceeds maximum length of 255 characters",
@@ -1288,8 +1289,8 @@ fn rocUnbundle(gpa: Allocator, args: cli_args.UnbundleArgs) !void {

     // Unbundle the archive
     var allocator_copy2 = arena_allocator;
-    var error_ctx: bundle.ErrorContext = undefined;
-    bundle.unbundle(
+    var error_ctx: unbundle.ErrorContext = undefined;
+    unbundle.unbundleFiles(
         archive_file.reader(),
         output_dir,
         &allocator_copy2,
src/unbundle/base58.zig (new file, 593 lines)
@@ -0,0 +1,593 @@

//! Base58 encoding and decoding for BLAKE3 hashes
//!
//! This module provides base58 encoding/decoding specifically optimized for 256-bit BLAKE3 hashes.
//! The base58 alphabet excludes visually similar characters (0, O, I, l) to prevent confusion.

const std = @import("std");

// Base58 alphabet (no '0', 'O', 'I', or 'l', to deter visual similarity attacks)
const base58_alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz";

/// We use 256-bit BLAKE3 hashes
const hash_bytes = 32;

/// Number of Base58 characters needed to represent a 256-bit hash
///
/// Math:
/// - Each base58 character can represent 58 values
/// - So we need ceil(log58(MAX_HASH + 1)) characters
/// - Which is ceil(log(MAX_HASH + 1) / log(58))
/// - Since MAX_HASH is 2^256 - 1, we need ceil(256 * log(2) / log(58))
/// - 256 * log(2) / log(58) ≈ 43.7
///
/// So we need 44 characters.
pub const base58_hash_bytes = 44;

/// Encode the given slice of 32 bytes as a base58 string and write it to the destination.
/// Returns a slice of the destination containing the encoded string (32-44 characters).
pub fn encode(src: *const [hash_bytes]u8, dest: *[base58_hash_bytes]u8) []u8 {
    // Count leading zero bytes
    var leading_zeros: usize = 0;
    while (leading_zeros < src.len and src[leading_zeros] == 0) {
        leading_zeros += 1;
    }

    if (leading_zeros == src.len) {
        // All zeros - return just the leading '1's
        @memset(dest[0..leading_zeros], '1');
        return dest[0..leading_zeros];
    }

    // Make a mutable scratch copy of the source
    var scratch: [hash_bytes]u8 = undefined;
    @memcpy(&scratch, src);

    var write_idx: isize = base58_hash_bytes - 1;
    const start: usize = leading_zeros;

    // Repeatedly divide scratch[start..] by 58, collecting remainders.
    // We need to keep dividing until the entire number becomes zero.
    var has_nonzero = true;
    while (has_nonzero) {
        var remainder: u16 = 0;
        has_nonzero = false;
        for (scratch[start..]) |*byte| {
            const value = (@as(u16, remainder) << 8) | byte.*;
            byte.* = @intCast(value / 58);
            remainder = value % 58;
            if (byte.* != 0) has_nonzero = true;
        }
        dest[@intCast(write_idx)] = base58_alphabet[@intCast(remainder)];
        write_idx -= 1;
    }

    // Now combine leading '1's with the encoded value
    const encoded_start = @as(usize, @intCast(write_idx + 1));
    const encoded_len = base58_hash_bytes - encoded_start;

    // Write leading '1's at the beginning
    @memset(dest[0..leading_zeros], '1');

    // Move the encoded data to follow the '1's
    std.mem.copyForwards(u8, dest[leading_zeros .. leading_zeros + encoded_len], dest[encoded_start..base58_hash_bytes]);

    return dest[0 .. leading_zeros + encoded_len];
}

test "encode - all zero bytes" {
    const input = [_]u8{0} ** 32;
    var output: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&input, &output);

    // Should produce all '1's
    for (encoded) |char| {
        try std.testing.expectEqual('1', char);
    }
    try std.testing.expectEqual(@as(usize, 32), encoded.len);
}

test "encode - some leading zero bytes" {
    var input = [_]u8{0} ** 32;
    input[3] = 1;
    input[4] = 2;
    input[5] = 3;
    var output: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&input, &output);

    // Should have 3 leading '1's
    try std.testing.expectEqual('1', encoded[0]);
    try std.testing.expectEqual('1', encoded[1]);
    try std.testing.expectEqual('1', encoded[2]);
    // Rest should be valid base58
    for (encoded) |char| {
        const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
        try std.testing.expect(is_valid);
    }
}

test "encode - single non-zero byte at end" {
    var input = [_]u8{0} ** 32;
    input[31] = 255;
    var output: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&input, &output);

    // Should produce valid base58 ending with encoded value of 255
    for (encoded) |char| {
        const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
        try std.testing.expect(is_valid);
    }
}

test "encode - known 32-byte test vector" {
    // 32-byte test vector
    const input = [_]u8{
        0x00, 0x01, 0x09, 0x66, 0x77, 0x00, 0x06, 0x95,
        0x3D, 0x55, 0x67, 0x43, 0x9E, 0x5E, 0x39, 0xF8,
        0x6A, 0x0D, 0x27, 0x3B, 0xEE, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    };
    var output: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&input, &output);

    // Verify all chars are valid base58
    for (encoded) |char| {
        const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
        try std.testing.expect(is_valid);
    }
}

test "encode - text padded to 32 bytes" {
    // "Hello World" padded to 32 bytes
    var input = [_]u8{0} ** 32;
    const text = "Hello World";
    @memcpy(input[0..text.len], text);
    var output: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&input, &output);

    // Verify all chars are valid base58
    for (encoded) |char| {
        const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
        try std.testing.expect(is_valid);
    }
}

test "encode - 32-byte hash (typical use case)" {
    // SHA256 hash of "test"
    const input = [_]u8{
        0x9f, 0x86, 0xd0, 0x81, 0x88, 0x4c, 0x7d, 0x65,
        0x9a, 0x2f, 0xea, 0xa0, 0xc5, 0x5a, 0xd0, 0x15,
        0xa3, 0xbf, 0x4f, 0x1b, 0x2b, 0x0b, 0x82, 0x2c,
        0xd1, 0x5d, 0x6c, 0x15, 0xb0, 0xf0, 0x0a, 0x08,
    };
    var output: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&input, &output);

    // Should produce a valid base58 string
    // Check that all characters are in the base58 alphabet
    for (encoded) |char| {
        const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
        try std.testing.expect(is_valid);
    }
}

test "encode - all 0xFF bytes" {
    const input = [_]u8{0xFF} ** 32;
    var output: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&input, &output);

    // Should produce a valid base58 string
    for (encoded) |char| {
        const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
        try std.testing.expect(is_valid);
    }
}

test "encode - mixed bytes with leading zeros (32 bytes)" {
    var input = [_]u8{0} ** 32;
    input[3] = 1;
    input[4] = 2;
    input[5] = 3;
    input[6] = 4;
    input[7] = 5;
    var output: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&input, &output);

    // Should have 3 leading '1's
    try std.testing.expectEqual('1', encoded[0]);
    try std.testing.expectEqual('1', encoded[1]);
    try std.testing.expectEqual('1', encoded[2]);

    // All should be valid base58
    for (encoded) |char| {
        const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
        try std.testing.expect(is_valid);
    }
}

test "encode - incremental values" {
    // Test that incrementing input produces different outputs
    var prev_output: [base58_hash_bytes]u8 = undefined;
    var curr_output: [base58_hash_bytes]u8 = undefined;

    var input1 = [_]u8{0} ** 32;
    input1[31] = 1;
    const encoded1 = encode(&input1, &prev_output);

    var input2 = [_]u8{0} ** 32;
    input2[31] = 2;
    const encoded2 = encode(&input2, &curr_output);

    // Outputs should be different
    try std.testing.expect(!std.mem.eql(u8, encoded1, encoded2));
}

test "encode - power of two boundaries" {
    // Test encoding at power-of-two boundaries in 32-byte inputs
    const test_values = [_]u8{ 1, 2, 4, 128, 255 };

    for (test_values) |val| {
        var input = [_]u8{0} ** 32;
        input[31] = val;
        var output: [base58_hash_bytes]u8 = undefined;
        const encoded = encode(&input, &output);

        // Just verify it produces valid base58
        for (encoded) |char| {
            const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
            try std.testing.expect(is_valid);
        }
    }
}

test "encode - alternating bits pattern" {
    // Test with alternating 0xAA and 0x55 pattern
    var input = [_]u8{ 0xAA, 0x55 } ** 16;
    var output: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&input, &output);

    // Should produce valid base58 chars
    for (encoded) |char| {
        const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
        try std.testing.expect(is_valid);
    }
}

test "encode - sequential bytes" {
    // Test with sequential byte values
    var input: [hash_bytes]u8 = undefined;
    for (0..hash_bytes) |i| {
        input[i] = @intCast(i);
    }
    var output: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&input, &output);

    // Should produce valid base58 chars
    for (encoded) |char| {
        const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
        try std.testing.expect(is_valid);
    }
}

test "encode - high entropy random-like data" {
    // Test with pseudo-random looking data (using prime multiplication)
    var input: [hash_bytes]u8 = undefined;
    var val: u32 = 17;
    for (0..hash_bytes) |i| {
        val = (val *% 31) +% 37;
        input[i] = @truncate(val);
    }
    var output: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&input, &output);

    // Should produce valid base58 chars
    for (encoded) |char| {
        const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
        try std.testing.expect(is_valid);
    }
}

test "encode - single bit set in different positions" {
    // Test with single bit set at different byte positions
    const positions = [_]usize{ 0, 7, 15, 16, 24, 31 };

    for (positions) |pos| {
        var input = [_]u8{0} ** 32;
        input[pos] = 1;
        var output: [base58_hash_bytes]u8 = undefined;
        const encoded = encode(&input, &output);

        // Each output should be valid base58
        for (encoded) |char| {
            const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
            try std.testing.expect(is_valid);
        }
    }
}

test "encode - max value in different positions" {
    // Test with 0xFF in different positions
    const positions = [_]usize{ 0, 8, 16, 24, 31 };
    var outputs: [positions.len][base58_hash_bytes]u8 = undefined;
    var encoded_slices: [positions.len][]u8 = undefined;

    for (positions, 0..) |pos, i| {
        var input = [_]u8{0} ** 32;
        input[pos] = 0xFF;
        encoded_slices[i] = encode(&input, &outputs[i]);

        // Verify all chars are valid
        for (encoded_slices[i]) |char| {
            const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
            try std.testing.expect(is_valid);
        }
    }

    // All outputs should be unique
    for (0..outputs.len) |i| {
        for (i + 1..outputs.len) |j| {
            try std.testing.expect(!std.mem.eql(u8, encoded_slices[i], encoded_slices[j]));
        }
    }
}

test "encode - known sha256 hash values" {
    // Test with actual SHA256 hash outputs
    const test_cases = [_][hash_bytes]u8{
        // SHA256("")
        [_]u8{
            0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14,
            0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24,
            0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c,
            0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55,
        },
        // SHA256("abc")
        [_]u8{
            0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea,
            0x41, 0x41, 0x40, 0xde, 0x5d, 0xae, 0x22, 0x23,
            0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17, 0x7a, 0x9c,
            0xb4, 0x10, 0xff, 0x61, 0xf2, 0x00, 0x15, 0xad,
        },
    };

    for (test_cases) |input| {
        var output: [base58_hash_bytes]u8 = undefined;
        const encoded = encode(&input, &output);

        // Should produce valid base58 chars
        for (encoded) |char| {
            const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
            try std.testing.expect(is_valid);
        }
    }
}

test "encode - boundary values" {
    // Test edge cases around byte boundaries
    const test_cases = [_]struct {
        desc: []const u8,
        input: [hash_bytes]u8,
    }{
        .{ .desc = "minimum after all zeros", .input = blk: {
            var arr = [_]u8{0} ** 32;
            arr[31] = 1;
            break :blk arr;
        } },
        .{
            .desc = "one less than power of 58",
            .input = blk: {
                var arr = [_]u8{0} ** 32;
                arr[31] = 57; // 58 - 1
                break :blk arr;
            },
        },
        .{ .desc = "exactly power of 58", .input = blk: {
            var arr = [_]u8{0} ** 32;
            arr[31] = 58;
            break :blk arr;
        } },
        .{ .desc = "high bytes in middle", .input = blk: {
            var arr = [_]u8{0} ** 32;
            arr[15] = 0xFF;
            arr[16] = 0xFF;
            break :blk arr;
        } },
    };

    for (test_cases) |tc| {
        var output: [base58_hash_bytes]u8 = undefined;
        const encoded = encode(&tc.input, &output);

        // All should produce valid base58
        for (encoded) |char| {
            const is_valid = std.mem.indexOfScalar(u8, base58_alphabet, char) != null;
            try std.testing.expect(is_valid);
        }
    }
}

test "encode - deterministic output" {
    // Ensure encoding is deterministic - same input always produces same output
    const input = [_]u8{
        0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0,
        0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
        0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00,
        0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
    };

    var output1: [base58_hash_bytes]u8 = undefined;
    var output2: [base58_hash_bytes]u8 = undefined;
    var output3: [base58_hash_bytes]u8 = undefined;

    const encoded1 = encode(&input, &output1);
    const encoded2 = encode(&input, &output2);
    const encoded3 = encode(&input, &output3);

    // All outputs should be identical
    try std.testing.expectEqualSlices(u8, encoded1, encoded2);
    try std.testing.expectEqualSlices(u8, encoded2, encoded3);
}

/// Decode a base58 string back to 32 bytes.
/// Returns InvalidBase58 error if the string contains invalid characters.
pub fn decode(src: []const u8, dest: *[hash_bytes]u8) !void {
    // Clear destination - needed because the multiplication algorithm
    // accumulates values across the entire buffer
    @memset(dest, 0);

    // Count leading '1's (representing leading zeros in the output)
    var leading_ones: usize = 0;
    for (src) |char| {
        if (char == '1') {
            leading_ones += 1;
        } else {
            break;
        }
    }

    // If all '1's, we're done (all zeros)
    if (leading_ones == src.len) {
        return;
    }

    // Process each character from the input
    for (src) |char| {
        // Find the value of this character
        const char_value = blk: {
            for (base58_alphabet, 0..) |alpha_char, i| {
                if (char == alpha_char) {
                    break :blk @as(u8, @intCast(i));
                }
            }
            return error.InvalidBase58;
        };

        // Multiply dest by 58 and add char_value
        var carry: u16 = char_value;
        var j: usize = hash_bytes;
        while (j > 0) {
            j -= 1;
            const value = @as(u16, dest[j]) * 58 + carry;
            dest[j] = @truncate(value);
            carry = value >> 8;
        }

        // If we still have carry, the number is too large
        if (carry != 0) {
            return error.InvalidBase58;
        }
    }

    // Count actual leading zeros we produced
    var actual_zeros: usize = 0;
    for (dest.*) |byte| {
        if (byte == 0) {
            actual_zeros += 1;
        } else {
            break;
        }
    }

    // Standard base58: ensure we have exactly the right number of leading zeros
    // Each leading '1' in input should produce one leading zero in output
    if (actual_zeros < leading_ones) {
        const shift = leading_ones - actual_zeros;
        // Shift data right to make room for more zeros
        var i: usize = hash_bytes;
        while (i > shift) {
            i -= 1;
            dest[i] = dest[i - shift];
        }
        @memset(dest[0..shift], 0);
    }
}

// Tests for decode
test "decode - all ones" {
    const input = [_]u8{'1'} ** base58_hash_bytes;
    var output: [hash_bytes]u8 = undefined;
    try decode(&input, &output);

    // Should produce all zeros
    for (output) |byte| {
        try std.testing.expectEqual(@as(u8, 0), byte);
    }
}

test "decode - roundtrip simple" {
    // Test roundtrip: encode then decode
    var original = [_]u8{0} ** 32;
    original[31] = 42;

    var encoded_buf: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&original, &encoded_buf);

    var decoded: [hash_bytes]u8 = undefined;
    try decode(encoded, &decoded);

    try std.testing.expectEqualSlices(u8, &original, &decoded);
}

test "decode - roundtrip all values" {
    // Test roundtrip with all different byte values in last position
    for (0..256) |val| {
        var original = [_]u8{0} ** 32;
        original[31] = @intCast(val);

        var encoded_buf: [base58_hash_bytes]u8 = undefined;
        const encoded = encode(&original, &encoded_buf);

        var decoded: [hash_bytes]u8 = undefined;
        try decode(encoded, &decoded);

        try std.testing.expectEqualSlices(u8, &original, &decoded);
    }
}

test "decode - roundtrip with leading zeros" {
    var original = [_]u8{0} ** 32;
    original[5] = 1;
    original[6] = 2;
    original[7] = 3;

    var encoded_buf: [base58_hash_bytes]u8 = undefined;
    const encoded = encode(&original, &encoded_buf);

    var decoded: [hash_bytes]u8 = undefined;
    try decode(encoded, &decoded);

    try std.testing.expectEqualSlices(u8, &original, &decoded);
}

test "decode - invalid character" {
    var input = [_]u8{'1'} ** base58_hash_bytes;
    input[20] = '0'; // '0' is not in base58 alphabet

    var output: [hash_bytes]u8 = undefined;
    const result = decode(&input, &output);

    try std.testing.expectError(error.InvalidBase58, result);
}

test "decode - roundtrip random patterns" {
    const patterns = [_][hash_bytes]u8{
        [_]u8{0xFF} ** 32,
        [_]u8{ 0xAA, 0x55 } ** 16,
        blk: {
            var arr: [hash_bytes]u8 = undefined;
            for (0..hash_bytes) |i| {
                arr[i] = @intCast(i * 7);
            }
            break :blk arr;
        },
    };

    for (patterns) |original| {
        var encoded_buf: [base58_hash_bytes]u8 = undefined;
        const encoded = encode(&original, &encoded_buf);

        var decoded: [hash_bytes]u8 = undefined;
        try decode(encoded, &decoded);

        try std.testing.expectEqualSlices(u8, &original, &decoded);
    }
}
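The 44-character figure derived in the doc comment above (256 · log 2 / log 58 ≈ 43.7) can be sanity-checked with wide integers. A standalone test along these lines (not part of the commit) would be:

const std = @import("std");

test "44 base58 digits cover all 256-bit values, 43 do not" {
    // Compute 58^43 and 58^44 exactly using a 512-bit integer.
    var pow43: u512 = 1;
    for (0..43) |_| pow43 *= 58;
    const pow44: u512 = pow43 * 58;
    const num_hashes = @as(u512, 1) << 256; // 2^256 possible BLAKE3 hashes
    try std.testing.expect(pow43 < num_hashes); // 43 digits are too few
    try std.testing.expect(pow44 >= num_hashes); // 44 digits suffice
}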
src/unbundle/download.zig (new file, 227 lines)
@@ -0,0 +1,227 @@

//! Download and extract bundled tar.zst files over https
//! (or http if the URL host is `localhost`, `127.0.0.1`, or `::1`)

const std = @import("std");
const builtin = @import("builtin");
const unbundle = @import("unbundle.zig");

// Network constants
const HTTPS_DEFAULT_PORT: u16 = 443;
const HTTP_DEFAULT_PORT: u16 = 80;
const SERVER_HEADER_BUFFER_SIZE: usize = 16 * 1024;

// IPv4 loopback address 127.0.0.1 in both byte orders
const IPV4_LOOPBACK_BE: u32 = 0x7F000001; // Big-endian
const IPV4_LOOPBACK_LE: u32 = 0x0100007F; // Little-endian

/// Errors that can occur during the download operation.
pub const DownloadError = error{
    InvalidUrl,
    LocalhostWasNotLoopback,
    InvalidHash,
    HttpError,
    NoHashInUrl,
} || unbundle.UnbundleError || std.mem.Allocator.Error;

/// Parse URL and validate it meets our security requirements.
/// Returns the hash from the URL if valid.
pub fn validateUrl(url: []const u8) DownloadError![]const u8 {
    // Check for https:// prefix
    if (std.mem.startsWith(u8, url, "https://")) {
        // This is fine, extract hash from last segment
    } else if (std.mem.startsWith(u8, url, "http://127.0.0.1:") or std.mem.startsWith(u8, url, "http://127.0.0.1/")) {
        // This is allowed for local testing (IPv4 loopback)
    } else if (std.mem.startsWith(u8, url, "http://[::1]:") or std.mem.startsWith(u8, url, "http://[::1]/")) {
        // This is allowed for local testing (IPv6 loopback)
    } else if (std.mem.startsWith(u8, url, "http://localhost:") or std.mem.startsWith(u8, url, "http://localhost/")) {
        // This is allowed but will require verification that localhost resolves to loopback
    } else {
        return error.InvalidUrl;
    }

    // Extract the last path segment (should be the hash)
    const last_slash = std.mem.lastIndexOf(u8, url, "/") orelse return error.NoHashInUrl;
    const hash_part = url[last_slash + 1 ..];

    // Remove .tar.zst extension if present
    const hash = if (std.mem.endsWith(u8, hash_part, ".tar.zst"))
        hash_part[0 .. hash_part.len - 8]
    else
        hash_part;

    if (hash.len == 0) {
        return error.NoHashInUrl;
    }

    return hash;
}

/// Download and extract a bundled tar.zst file from a URL.
///
/// The URL must:
/// - Start with "https://" or "http://127.0.0.1"
/// - Have the base58-encoded blake3 hash as the last path segment
/// - Point to a tar.zst file created with `roc bundle`
pub fn downloadAndExtract(
    allocator: *std.mem.Allocator,
    url: []const u8,
    extract_dir: std.fs.Dir,
) DownloadError!void {
    // Validate URL and extract hash
    const base58_hash = try validateUrl(url);

    // Validate the hash before starting any I/O
    const expected_hash = (try unbundle.validateBase58Hash(base58_hash)) orelse {
        return error.InvalidHash;
    };

    // Create HTTP client
    var client = std.http.Client{ .allocator = allocator.* };
    defer client.deinit();

    // Parse the URL
    const uri = std.Uri.parse(url) catch return error.InvalidUrl;

    // Check if we need to resolve localhost
    var extra_headers: []const std.http.Header = &.{};
    if (uri.host) |host| {
        if (std.mem.eql(u8, host.percent_encoded, "localhost")) {
            // Security: We must resolve "localhost" and verify it points to a loopback address.
            // This prevents attacks where:
            // 1. An attacker modifies /etc/hosts to make localhost resolve to their server
            // 2. A compromised DNS makes localhost resolve to an external IP
            // 3. Container/VM networking misconfiguration exposes localhost to external IPs
            //
            // We're being intentionally strict here:
            // - For IPv4: We only accept exactly 127.0.0.1 (not the full 127.0.0.0/8 range)
            // - For IPv6: We only accept exactly ::1 (not other loopback addresses)
            //
            // While the specs technically allow any 127.x.y.z address for IPv4 loopback
            // and multiple forms for IPv6, in practice localhost almost always resolves
            // to these exact addresses, and being stricter improves security.

            const address_list = try std.net.getAddressList(allocator.*, "localhost", uri.port orelse HTTP_DEFAULT_PORT);
            defer address_list.deinit();

            if (address_list.addrs.len == 0) {
                return error.LocalhostWasNotLoopback;
            }

            // Check that at least one address is a loopback
            var found_loopback = false;
            for (address_list.addrs) |addr| {
                switch (addr.any.family) {
                    std.posix.AF.INET => {
                        const ipv4_addr = addr.in.sa.addr;
                        if (ipv4_addr == IPV4_LOOPBACK_BE or ipv4_addr == IPV4_LOOPBACK_LE) {
                            found_loopback = true;
                            break;
                        }
                    },
                    std.posix.AF.INET6 => {
                        const ipv6_addr = addr.in6.sa.addr;
                        // Check if it's exactly ::1 (all zeros except last byte is 1)
                        var is_loopback = true;
                        for (ipv6_addr[0..15]) |byte| {
                            if (byte != 0) {
                                is_loopback = false;
                                break;
                            }
                        }
                        if (is_loopback and ipv6_addr[15] == 1) {
                            found_loopback = true;
                            break;
                        }
                    },
                    else => {}, // Ignore other address families
                }
            }

            if (!found_loopback) {
                return error.LocalhostWasNotLoopback;
            }

            // Since we're using "localhost", we need to set the Host header manually
            // to match what the server expects
            extra_headers = &.{
                .{ .name = "Host", .value = "localhost" },
            };
        }
    }

    // Start the HTTP request
    var header_buffer: [SERVER_HEADER_BUFFER_SIZE]u8 = undefined;
    var request = try client.open(.GET, uri, .{
        .server_header_buffer = &header_buffer,
        .extra_headers = extra_headers,
    });
    defer request.deinit();

    // Send the request and wait for response
    try request.send();
    try request.wait();

    // Check for successful response
    if (request.response.status != .ok) {
        return error.HttpError;
    }

    const reader = request.reader();

    // Setup directory extract writer
    var dir_writer = unbundle.DirExtractWriter.init(extract_dir);

    // Stream and extract the content
    try unbundle.unbundleStream(reader, dir_writer.extractWriter(), allocator, &expected_hash, null);
}

/// Download and extract a bundled tar.zst file to memory buffers.
///
/// Returns a BufferExtractWriter containing all extracted files and directories.
/// The caller owns the returned writer and must call deinit() on it.
pub fn downloadAndExtractToBuffer(
    allocator: *std.mem.Allocator,
    url: []const u8,
) DownloadError!unbundle.BufferExtractWriter {
    // Validate URL and extract hash
    const base58_hash = try validateUrl(url);

    // Validate the hash before starting any I/O
    const expected_hash = (try unbundle.validateBase58Hash(base58_hash)) orelse {
        return error.InvalidHash;
    };

    // Create HTTP client
    var client = std.http.Client{ .allocator = allocator.* };
    defer client.deinit();

    // Parse the URL
    const uri = std.Uri.parse(url) catch return error.InvalidUrl;

    // Start the HTTP request (simplified version without localhost resolution for brevity)
    var header_buffer: [SERVER_HEADER_BUFFER_SIZE]u8 = undefined;
    var request = try client.open(.GET, uri, .{
        .server_header_buffer = &header_buffer,
    });
    defer request.deinit();

    // Send the request and wait for response
    try request.send();
    try request.wait();

    // Check for successful response
    if (request.response.status != .ok) {
        return error.HttpError;
    }

    const reader = request.reader();

    // Setup buffer extract writer
    var buffer_writer = unbundle.BufferExtractWriter.init(allocator);
    errdefer buffer_writer.deinit();

    // Stream and extract the content
    try unbundle.unbundleStream(reader, buffer_writer.extractWriter(), allocator, &expected_hash, null);

    return buffer_writer;
}
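An illustrative test of validateUrl (not in the commit; the hash string and host are made up) showing the hash extraction and extension stripping implemented above:

test "validateUrl extracts the hash segment" {
    // https URLs are accepted; the last path segment minus ".tar.zst" is the hash.
    const hash = try validateUrl("https://example.com/pkg/4ukZyCRuiVzBDZaSx5BSEYqYGnQz5XfK7SJzJ7xUNUbr.tar.zst");
    try std.testing.expectEqualStrings("4ukZyCRuiVzBDZaSx5BSEYqYGnQz5XfK7SJzJ7xUNUbr", hash);

    // Plain http to a non-loopback host is rejected outright.
    try std.testing.expectError(error.InvalidUrl, validateUrl("http://example.com/abc.tar.zst"));
}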
src/unbundle/mod.zig (new file, 35 lines)
@@ -0,0 +1,35 @@

//! Unbundle functionality for Roc packages using Zig's standard library
//!
//! This module provides functionality to:
//! - Unbundle compressed tar archives (.tar.zst files)
//! - Validate and decode base58-encoded hashes
//! - Extract files with security and cross-platform path validation
//! - Download and extract bundled archives from HTTPS URLs
//!
//! This module uses Zig's std.compress.zstandard for decompression,
//! making it compatible with WebAssembly targets.

pub const unbundle = @import("unbundle.zig");
pub const download = @import("download.zig");
pub const base58 = @import("base58.zig");

// Re-export commonly used functions and types
pub const unbundleFiles = unbundle.unbundle;
pub const unbundleStream = unbundle.unbundleStream;
pub const validateBase58Hash = unbundle.validateBase58Hash;
pub const pathHasUnbundleErr = unbundle.pathHasUnbundleErr;

// Re-export error types
pub const UnbundleError = unbundle.UnbundleError;
pub const PathValidationError = unbundle.PathValidationError;
pub const PathValidationReason = unbundle.PathValidationReason;
pub const ErrorContext = unbundle.ErrorContext;

// Re-export extract writer types
pub const ExtractWriter = unbundle.ExtractWriter;
pub const DirExtractWriter = unbundle.DirExtractWriter;
pub const BufferExtractWriter = unbundle.BufferExtractWriter;

// Re-export download functionality
pub const downloadAndExtract = download.downloadAndExtract;
pub const downloadAndExtractToBuffer = download.downloadAndExtractToBuffer;
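Usage from a caller mirrors the rocUnbundle change in src/main.zig above; a minimal sketch (function and variable names are illustrative, not from the commit):

const std = @import("std");
const unbundle = @import("unbundle");

// Sketch: extract "<hash>.tar.zst" into out_dir, verifying the hash while streaming.
fn extractArchive(gpa: *std.mem.Allocator, archive: std.fs.File, out_dir: std.fs.Dir, filename: []const u8) !void {
    var error_ctx: unbundle.ErrorContext = undefined;
    unbundle.unbundleFiles(archive.reader(), out_dir, gpa, filename, &error_ctx) catch |err| {
        if (err == error.InvalidPath) {
            // error_ctx.path / error_ctx.reason describe the offending archive entry
            std.log.err("invalid path in archive: {s}", .{error_ctx.path});
        }
        return err;
    };
}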
src/unbundle/unbundle.zig (new file, 383 lines)
@@ -0,0 +1,383 @@

//! Unbundle compressed tar archives using Zig's standard library
//!
//! This module provides unbundling functionality that works on all platforms
//! including WebAssembly, by using Zig's std.compress.zstandard instead of
//! the C zstd library.

const builtin = @import("builtin");
const std = @import("std");
const base58 = @import("base58.zig");

// Constants
const TAR_EXTENSION = ".tar.zst";
const STREAM_BUFFER_SIZE: usize = 64 * 1024; // 64KB buffer for streaming operations

/// Errors that can occur during the unbundle operation.
pub const UnbundleError = error{
    DecompressionFailed,
    InvalidTarHeader,
    UnexpectedEndOfStream,
    FileCreateFailed,
    DirectoryCreateFailed,
    FileWriteFailed,
    HashMismatch,
    InvalidFilename,
    FileTooLarge,
    InvalidPath,
    NoDataExtracted,
} || std.mem.Allocator.Error;

/// Context for error reporting during unbundle operations
pub const ErrorContext = struct {
    path: []const u8,
    reason: PathValidationReason,
};

/// Specific reason why a path validation failed
pub const PathValidationReason = union(enum) {
    empty_path,
    path_too_long,
    windows_reserved_char: u8,
    absolute_path,
    path_traversal,
    current_directory_reference,
    windows_reserved_name,
    contained_backslash_on_unix,
    component_ends_with_space,
    component_ends_with_period,
};

/// Error type for path validation failures
pub const PathValidationError = struct {
    path: []const u8,
    reason: PathValidationReason,
};

/// Writer interface for extracting files during unbundle
pub const ExtractWriter = struct {
    ptr: *anyopaque,
    vtable: *const VTable,

    pub const VTable = struct {
        createFile: *const fn (ptr: *anyopaque, path: []const u8) CreateFileError!std.io.AnyWriter,
        finishFile: *const fn (ptr: *anyopaque, writer: std.io.AnyWriter) FinishFileError!void,
        makeDir: *const fn (ptr: *anyopaque, path: []const u8) MakeDirError!void,
    };

    pub const CreateFileError = error{ FileCreateFailed, InvalidPath, OutOfMemory };
    pub const FinishFileError = error{FileWriteFailed};
    pub const MakeDirError = error{ DirectoryCreateFailed, InvalidPath, OutOfMemory };

    pub fn createFile(self: ExtractWriter, path: []const u8) CreateFileError!std.io.AnyWriter {
        return self.vtable.createFile(self.ptr, path);
    }

    pub fn finishFile(self: ExtractWriter, writer: std.io.AnyWriter) FinishFileError!void {
        return self.vtable.finishFile(self.ptr, writer);
    }

    pub fn makeDir(self: ExtractWriter, path: []const u8) MakeDirError!void {
        return self.vtable.makeDir(self.ptr, path);
    }
};

/// Directory-based extract writer
pub const DirExtractWriter = struct {
    dir: std.fs.Dir,

    pub fn init(dir: std.fs.Dir) DirExtractWriter {
        return .{ .dir = dir };
    }

    pub fn extractWriter(self: *DirExtractWriter) ExtractWriter {
        return ExtractWriter{
            .ptr = self,
            .vtable = &vtable,
        };
    }

    const vtable = ExtractWriter.VTable{
        .createFile = createFile,
        .finishFile = finishFile,
        .makeDir = makeDir,
    };

    fn createFile(ptr: *anyopaque, path: []const u8) ExtractWriter.CreateFileError!std.io.AnyWriter {
        const self: *DirExtractWriter = @ptrCast(@alignCast(ptr));

        // Ensure parent directories exist
        if (std.fs.path.dirname(path)) |parent| {
            self.dir.makePath(parent) catch return error.FileCreateFailed;
        }

        const file = self.dir.createFile(path, .{}) catch return error.FileCreateFailed;
        return file.writer().any();
    }

    fn finishFile(_: *anyopaque, writer: std.io.AnyWriter) ExtractWriter.FinishFileError!void {
        // For file writers, we need to close the file
        // In Zig 0.14, we need to properly cast the context
        const file_writer = writer.context;
        const file = @as(*std.fs.File, @ptrCast(@alignCast(file_writer)));
        file.close();
    }

    fn makeDir(ptr: *anyopaque, path: []const u8) ExtractWriter.MakeDirError!void {
        const self: *DirExtractWriter = @ptrCast(@alignCast(ptr));
        self.dir.makePath(path) catch return error.DirectoryCreateFailed;
    }
};

/// Buffer-based extract writer for in-memory extraction
pub const BufferExtractWriter = struct {
    allocator: *std.mem.Allocator,
    files: std.StringHashMap(std.ArrayList(u8)),
    directories: std.ArrayList([]u8),
    current_file: ?*std.ArrayList(u8) = null,

    pub fn init(allocator: *std.mem.Allocator) BufferExtractWriter {
        return .{
            .allocator = allocator,
            .files = std.StringHashMap(std.ArrayList(u8)).init(allocator.*),
            .directories = std.ArrayList([]u8).init(allocator.*),
        };
    }

    pub fn deinit(self: *BufferExtractWriter) void {
        var iter = self.files.iterator();
        while (iter.next()) |entry| {
            self.allocator.free(entry.key_ptr.*);
            entry.value_ptr.deinit();
        }
        self.files.deinit();

        for (self.directories.items) |dir| {
            self.allocator.free(dir);
        }
        self.directories.deinit();
    }

    pub fn extractWriter(self: *BufferExtractWriter) ExtractWriter {
        return ExtractWriter{
            .ptr = self,
            .vtable = &vtable,
        };
    }

    const vtable = ExtractWriter.VTable{
        .createFile = createFile,
        .finishFile = finishFile,
        .makeDir = makeDir,
    };

    fn createFile(ptr: *anyopaque, path: []const u8) ExtractWriter.CreateFileError!std.io.AnyWriter {
        const self: *BufferExtractWriter = @ptrCast(@alignCast(ptr));

        const key = self.allocator.dupe(u8, path) catch return error.OutOfMemory;
        errdefer self.allocator.free(key);

        const result = self.files.getOrPut(key) catch return error.OutOfMemory;
        if (result.found_existing) {
            self.allocator.free(key);
            result.value_ptr.clearRetainingCapacity();
        } else {
            result.value_ptr.* = std.ArrayList(u8).init(self.allocator.*);
        }

        self.current_file = result.value_ptr;
        return result.value_ptr.writer().any();
    }

    fn finishFile(ptr: *anyopaque, _: std.io.AnyWriter) ExtractWriter.FinishFileError!void {
        const self: *BufferExtractWriter = @ptrCast(@alignCast(ptr));
        self.current_file = null;
    }

    fn makeDir(ptr: *anyopaque, path: []const u8) ExtractWriter.MakeDirError!void {
        const self: *BufferExtractWriter = @ptrCast(@alignCast(ptr));
        const dir_path = self.allocator.dupe(u8, path) catch return error.OutOfMemory;
        self.directories.append(dir_path) catch {
            self.allocator.free(dir_path);
            return error.OutOfMemory;
        };
    }
};

/// Validate a base58-encoded hash string and return the decoded hash.
/// Returns null if the hash is invalid.
pub fn validateBase58Hash(base58_hash: []const u8) !?[32]u8 {
    if (base58_hash.len > base58.base58_hash_bytes) {
        return null;
    }

    var hash: [32]u8 = undefined;
    base58.decode(base58_hash, &hash) catch return null;
    return hash;
}

/// Check if a path has any unbundling errors (security and compatibility issues)
pub fn pathHasUnbundleErr(path: []const u8) ?PathValidationError {
    if (path.len == 0) {
        return PathValidationError{
            .path = path,
            .reason = .empty_path,
        };
    }

    // Check for absolute paths
    if (path[0] == '/' or (builtin.target.os.tag == .windows and path.len >= 2 and path[1] == ':')) {
        return PathValidationError{
            .path = path,
            .reason = .absolute_path,
        };
    }

    // Check for path traversal attempts
    if (std.mem.indexOf(u8, path, "..") != null) {
        return PathValidationError{
            .path = path,
            .reason = .path_traversal,
        };
    }

    // Check for current directory references
    if (std.mem.eql(u8, path, ".") or std.mem.indexOf(u8, path, "./") != null or std.mem.indexOf(u8, path, "/.") != null) {
        return PathValidationError{
            .path = path,
            .reason = .current_directory_reference,
        };
    }

    return null;
}

/// Unbundle files from a compressed tar archive stream.
///
/// This is the core streaming unbundle logic that can be used by both file-based
/// unbundling and network-based downloading.
/// If an InvalidPath error is returned, error_context will contain details about the invalid path.
pub fn unbundleStream(
    input_reader: anytype,
    extract_writer: ExtractWriter,
    allocator: *std.mem.Allocator,
    expected_hash: *const [32]u8,
    error_context: ?*ErrorContext,
) UnbundleError!void {
    // Create a hashing reader to verify the hash while reading
    var hasher = std.crypto.hash.Blake3.init(.{});
    const HashingReader = struct {
        child_reader: @TypeOf(input_reader),
        hasher: *std.crypto.hash.Blake3,

        pub const Error = @TypeOf(input_reader).Error;
        pub const Reader = std.io.Reader(@This(), Error, read);

        pub fn read(self: @This(), buffer: []u8) Error!usize {
            const n = try self.child_reader.read(buffer);
            if (n > 0) {
                self.hasher.update(buffer[0..n]);
            }
            return n;
        }

        pub fn reader(self: @This()) Reader {
            return .{ .context = self };
        }
    };

    var hashing_reader = HashingReader{
        .child_reader = input_reader,
        .hasher = &hasher,
    };

    // Create zstandard decompressor
    var zstd_stream = std.compress.zstandard.decompressStream(allocator.*, hashing_reader.reader());
    defer zstd_stream.deinit();
    const decompressed_reader = zstd_stream.reader();

    // Create tar reader
    var tar_iterator = std.tar.iterator(decompressed_reader, .{
        .max_file_size = std.math.maxInt(usize), // No limit on file size
    });

    var data_extracted = false;

    // Process all tar entries
    while (try tar_iterator.next()) |entry| {
        const file_path = entry.path;

        // Validate path for security
        if (pathHasUnbundleErr(file_path)) |validation_error| {
            if (error_context) |ctx| {
                ctx.path = validation_error.path;
                ctx.reason = validation_error.reason;
            }
            return error.InvalidPath;
        }

        switch (entry.kind) {
            .directory => {
                try extract_writer.makeDir(file_path);
                data_extracted = true;
            },
            .file => {
                const file_writer = try extract_writer.createFile(file_path);
                defer extract_writer.finishFile(file_writer) catch {};

                // Stream the file content
                const file_size = std.math.cast(usize, entry.size) orelse return error.FileTooLarge;
                var bytes_remaining = file_size;
                var buffer: [STREAM_BUFFER_SIZE]u8 = undefined;

                while (bytes_remaining > 0) {
                    const to_read = @min(buffer.len, bytes_remaining);
                    const bytes_read = try entry.reader().readAll(buffer[0..to_read]);
                    if (bytes_read == 0) return error.UnexpectedEndOfStream;
                    try file_writer.writeAll(buffer[0..bytes_read]);
                    bytes_remaining -= bytes_read;
                }

                data_extracted = true;
            },
            else => {
                // Skip other entry types (symlinks, etc.)
                try entry.skip();
            },
        }
    }

    if (!data_extracted) {
        return error.NoDataExtracted;
    }

    // Verify the hash
    var actual_hash: [32]u8 = undefined;
    hasher.final(&actual_hash);
    if (!std.mem.eql(u8, &actual_hash, expected_hash)) {
        return error.HashMismatch;
    }
}

/// Unbundle files from a compressed tar archive to a directory.
///
/// The filename parameter should be the base58-encoded blake3 hash + .tar.zst extension.
/// If an InvalidPath error is returned, error_context will contain details about the invalid path.
pub fn unbundle(
    input_reader: anytype,
    extract_dir: std.fs.Dir,
    allocator: *std.mem.Allocator,
    filename: []const u8,
    error_context: ?*ErrorContext,
) UnbundleError!void {
    // Extract expected hash from filename
    if (!std.mem.endsWith(u8, filename, TAR_EXTENSION)) {
        return error.InvalidFilename;
    }
    const base58_hash = filename[0 .. filename.len - TAR_EXTENSION.len]; // Remove .tar.zst
    const expected_hash = (try validateBase58Hash(base58_hash)) orelse {
        return error.InvalidFilename;
    };

    var dir_writer = DirExtractWriter.init(extract_dir);
    return unbundleStream(input_reader, dir_writer.extractWriter(), allocator, &expected_hash, error_context);
}
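A quick illustrative check of pathHasUnbundleErr (not in the commit) against the rules implemented above:

test "pathHasUnbundleErr rejects unsafe archive paths" {
    // Absolute paths and traversal are refused; ordinary relative paths pass.
    try std.testing.expect(pathHasUnbundleErr("/etc/passwd") != null);
    try std.testing.expect(pathHasUnbundleErr("../outside") != null);
    try std.testing.expect(pathHasUnbundleErr("src/main.roc") == null);
}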