initial round-trip using cache, just NodeStore for now

This commit is contained in:
Luke Boswell 2025-06-24 14:37:11 +10:00
parent 2c53d4deb9
commit fd7bb1bf79
No known key found for this signature in database
GPG key ID: 54A7324B1B975757
3 changed files with 160 additions and 0 deletions

View file

@@ -260,3 +260,80 @@ test "readCacheInto after writeToCache" {
const data_bytes = read_buffer[@sizeOf(CacheHeader)..expected_total_bytes];
try std.testing.expectEqualStrings(test_data, data_bytes);
}
// TODO expand this test gradually to more of our Can IR until
// we can round-trip a whole type-checked module from cache
test "NodeStore cache round-trip" {
    const NodeStore = @import("check/canonicalize/NodeStore.zig");
    const Node = @import("check/canonicalize/Node.zig");

    var tmp_dir = std.testing.tmpDir(.{});
    defer tmp_dir.cleanup();

    var abs_path_buf: [std.fs.max_path_bytes]u8 = undefined;
    const abs_cache_dir = try tmp_dir.dir.realpath(".", &abs_path_buf);

    const fs = Filesystem.default();
    const allocator = std.testing.allocator;
    const test_hash = "0123456789abcdef";

    // Build a small store with one node and two extra_data entries.
    var store = NodeStore.initCapacity(allocator, 10);
    defer store.deinit();

    const expr_node = Node{
        .data_1 = 42,
        .data_2 = 100,
        .data_3 = 200,
        .region = .{ .start = .{ .offset = 0 }, .end = .{ .offset = 10 } },
        .tag = .expr_string,
    };
    const expr_idx = store.nodes.append(store.gpa, expr_node);
    try store.extra_data.append(store.gpa, 1234);
    try store.extra_data.append(store.gpa, 5678);

    // Serialize the store into a Node-aligned buffer and check the
    // reported size matches what was actually written.
    const store_size = store.serializedSize();
    const store_buffer = try allocator.alignedAlloc(u8, @alignOf(Node), store_size);
    defer allocator.free(store_buffer);
    const serialized = try store.serializeInto(store_buffer);
    try std.testing.expectEqual(store_size, serialized.len);

    // Lay out [CacheHeader | padding | store bytes] for the cache file.
    const header_size = @sizeOf(CacheHeader);
    const aligned_header_size = std.mem.alignForward(usize, header_size, @alignOf(Node));
    const total_size = aligned_header_size + store_size;
    // `const`: the slice binding itself is never reassigned; the bytes are
    // written through its pointer below. (A `var` here is rejected by the
    // compiler as a never-mutated local.)
    const write_buffer = try allocator.alignedAlloc(u8, @alignOf(Node), total_size);
    defer allocator.free(write_buffer);

    const header = @as(*CacheHeader, @ptrCast(write_buffer.ptr));
    header.* = .{
        .total_cached_bytes = @intCast(store_size),
    };
    @memcpy(write_buffer[aligned_header_size..total_size], serialized);

    try writeToCache(abs_cache_dir, test_hash, header, fs, allocator);

    // Read the cache entry back and verify the header survived.
    var read_buffer: [4096]u8 align(@alignOf(Node)) = undefined;
    const bytes_read = try readCacheInto(&read_buffer, abs_cache_dir, test_hash, fs, allocator);
    const parsed_header = try CacheHeader.initFromBytes(read_buffer[0..bytes_read]);
    try std.testing.expectEqual(header.total_cached_bytes, parsed_header.total_cached_bytes);

    // Deserialize the store from the bytes after the (aligned) header.
    const data_start = std.mem.alignForward(usize, @sizeOf(CacheHeader), @alignOf(Node));
    const data_end = data_start + parsed_header.total_cached_bytes;
    var restored_store = try NodeStore.deserializeFrom(@as([]align(@alignOf(Node)) const u8, @alignCast(read_buffer[data_start..data_end])), allocator);
    defer restored_store.deinit();

    // The restored store must match what we serialized.
    try std.testing.expectEqual(store.nodes.len(), restored_store.nodes.len());
    try std.testing.expectEqual(store.extra_data.items.len, restored_store.extra_data.items.len);

    const restored_node = restored_store.nodes.get(expr_idx);
    try std.testing.expectEqual(expr_node.data_1, restored_node.data_1);
    try std.testing.expectEqual(expr_node.data_2, restored_node.data_2);
    try std.testing.expectEqual(expr_node.data_3, restored_node.data_3);
    try std.testing.expectEqual(expr_node.tag, restored_node.tag);
    try std.testing.expectEqual(@as(u32, 1234), restored_store.extra_data.items[0]);
    try std.testing.expectEqual(@as(u32, 5678), restored_store.extra_data.items[1]);
}

View file

@@ -1802,3 +1802,85 @@ pub fn addTypeVarSlot(store: *NodeStore, parent_node_idx: Node.Idx, region: base
});
return @enumFromInt(@intFromEnum(nid));
}
/// Calculate the number of bytes required to serialize this NodeStore.
/// Only nodes and extra_data are counted; the scratch arrays are
/// transient and are never written to the cache.
pub fn serializedSize(self: *const NodeStore) usize {
    const nodes_bytes = self.nodes.serializedSize();
    const extra_data_len_bytes = @sizeOf(u32);
    const extra_data_item_bytes = self.extra_data.items.len * @sizeOf(u32);
    return nodes_bytes + extra_data_len_bytes + extra_data_item_bytes;
}
/// Serialize this NodeStore into the provided buffer
/// Buffer must be at least serializedSize() bytes and properly aligned
///
/// Layout written: [nodes][u32 extra_data length][extra_data items].
/// Returns the prefix of `buffer` that was actually written; the caller
/// retains ownership of the buffer.
pub fn serializeInto(self: *const NodeStore, buffer: []align(@alignOf(Node)) u8) ![]u8 {
// Up-front size check so every write below is in-bounds.
const size = self.serializedSize();
if (buffer.len < size) return error.BufferTooSmall;
var offset: usize = 0;
// Serialize nodes - cast to proper alignment for Node type
const nodes_buffer = @as([]align(@alignOf(Node)) u8, @alignCast(buffer[offset..]));
const nodes_slice = try self.nodes.serializeInto(nodes_buffer);
offset += nodes_slice.len;
// Serialize extra_data length
// NOTE(review): assumes `offset` is u32-aligned here (i.e. the nodes
// serialization size is a multiple of @alignOf(u32)); otherwise this
// @alignCast would trip a safety check — TODO confirm.
const extra_len_ptr = @as(*u32, @ptrCast(@alignCast(buffer.ptr + offset)));
extra_len_ptr.* = @intCast(self.extra_data.items.len);
offset += @sizeOf(u32);
// Serialize extra_data items
if (self.extra_data.items.len > 0) {
// Raw u32 copy in native endianness (deserializeFrom reads it back the same way).
const extra_ptr = @as([*]u32, @ptrCast(@alignCast(buffer.ptr + offset)));
@memcpy(extra_ptr, self.extra_data.items);
offset += self.extra_data.items.len * @sizeOf(u32);
}
return buffer[0..offset];
}
/// Deserialize a NodeStore from the provided buffer
///
/// The buffer layout must match what serializeInto produces:
/// [nodes][u32 extra_data length][extra_data items].
/// On success the returned NodeStore owns its allocations and all scratch
/// arrays start empty; on error nothing is leaked.
pub fn deserializeFrom(buffer: []align(@alignOf(Node)) const u8, allocator: std.mem.Allocator) !NodeStore {
    var offset: usize = 0;
    // Deserialize nodes - cast to proper alignment for Node type
    const nodes_buffer = @as([]align(@alignOf(Node)) const u8, @alignCast(buffer[offset..]));
    var nodes = try Node.List.deserializeFrom(nodes_buffer, allocator);
    // Free the already-deserialized nodes if any later step fails —
    // previously the error paths below leaked this allocation.
    errdefer nodes.deinit(allocator);
    offset += nodes.serializedSize();
    // Deserialize extra_data length
    if (buffer.len < offset + @sizeOf(u32)) return error.BufferTooSmall;
    const extra_len = @as(*const u32, @ptrCast(@alignCast(buffer.ptr + offset))).*;
    offset += @sizeOf(u32);
    // Deserialize extra_data items
    var extra_data = try std.ArrayListUnmanaged(u32).initCapacity(allocator, extra_len);
    // Free the list capacity if the buffer turns out to be truncated.
    errdefer extra_data.deinit(allocator);
    if (extra_len > 0) {
        const remaining = buffer.len - offset;
        const expected = extra_len * @sizeOf(u32);
        if (remaining < expected) return error.BufferTooSmall;
        // Raw u32 read in native endianness (mirrors serializeInto's write).
        const extra_ptr = @as([*]const u32, @ptrCast(@alignCast(buffer.ptr + offset)));
        extra_data.appendSliceAssumeCapacity(extra_ptr[0..extra_len]);
    }
    // Create NodeStore with empty scratch arrays
    return NodeStore{
        .gpa = allocator,
        .nodes = nodes,
        .extra_data = extra_data,
        // All scratch arrays start empty
        .scratch_statements = base.Scratch(CIR.Statement.Idx){ .items = .{} },
        .scratch_exprs = base.Scratch(CIR.Expr.Idx){ .items = .{} },
        .scratch_record_fields = base.Scratch(CIR.RecordField.Idx){ .items = .{} },
        .scratch_when_branches = base.Scratch(CIR.WhenBranch.Idx){ .items = .{} },
        .scratch_where_clauses = base.Scratch(CIR.WhereClause.Idx){ .items = .{} },
        .scratch_patterns = base.Scratch(CIR.Pattern.Idx){ .items = .{} },
        .scratch_pattern_record_fields = base.Scratch(CIR.PatternRecordField.Idx){ .items = .{} },
        .scratch_type_annos = base.Scratch(CIR.TypeAnno.Idx){ .items = .{} },
        .scratch_anno_record_fields = base.Scratch(CIR.AnnoRecordField.Idx){ .items = .{} },
        .scratch_exposed_items = base.Scratch(CIR.ExposedItem.Idx){ .items = .{} },
        .scratch_defs = base.Scratch(CIR.Def.Idx){ .items = .{} },
        .scratch_diagnostics = base.Scratch(CIR.Diagnostic.Idx){ .items = .{} },
    };
}

View file

@@ -4,6 +4,7 @@ const testing = std.testing;
test {
testing.refAllDeclsRecursive(@import("main.zig"));
testing.refAllDeclsRecursive(@import("builtins/main.zig"));
testing.refAllDeclsRecursive(@import("cache.zig"));
// TODO: Remove after hooking up
testing.refAllDeclsRecursive(@import("reporting.zig"));