Thanks to visit codestin.com
Credit goes to Github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 183 additions & 2 deletions zml2/io/vfs/file.zig
Original file line number Diff line number Diff line change
@@ -1,12 +1,72 @@
const builtin = @import("builtin");
const std = @import("std");

const log = std.log.scoped(.@"zml/io/vfs/file");

/// Errors produced by the direct I/O helpers in this file, merged with the
/// errors `std.posix.fcntl` can return.
pub const DirectIoError = error{
// Returned by the helpers below when `canUseDirectIO()` is false.
UnsupportedPlatform,
// NOTE(review): not returned by any code visible in this section.
MisalignedBuffer,
// NOTE(review): not returned by any code visible in this section.
BufferTooSmall,
// Returned when an `F.SETFL` call succeeds but yields a non-zero result.
UnexpectedFcntlResult,
} || std.posix.FcntlError;

/// Reports whether the build target can use `O_DIRECT`: only Linux targets
/// whose `std.posix.O` flag set declares a `DIRECT` field qualify.
fn canUseDirectIO() bool {
    return switch (builtin.target.os.tag) {
        .linux => @hasField(std.posix.O, "DIRECT"),
        else => false,
    };
}

/// Queries the status flags of `file` and reports whether the `DIRECT` bit
/// is currently set.
///
/// Returns `DirectIoError.UnsupportedPlatform` when `canUseDirectIO()` is
/// false, and propagates any `fcntl` failure.
fn useDirectIO(file: std.Io.File) DirectIoError!bool {
    if (!canUseDirectIO()) return DirectIoError.UnsupportedPlatform;

    const direct_flag: c_int = @bitCast(std.posix.O{ .DIRECT = true });
    const flags = try std.posix.fcntl(file.handle, std.posix.F.GETFL, 0);
    return (flags & direct_flag) != 0;
}

/// Clears the `DIRECT` status flag on `file` if it is set; does nothing when
/// the flag is already clear.
///
/// Returns `DirectIoError.UnsupportedPlatform` when `canUseDirectIO()` is
/// false, `DirectIoError.UnexpectedFcntlResult` when `F.SETFL` yields a
/// non-zero result, and propagates any `fcntl` failure.
///
/// NOTE(review): this helper takes `std.fs.File` while the sibling helpers
/// take `std.Io.File` — confirm which type callers are expected to pass.
fn switchToBufferedIO(file: std.fs.File) DirectIoError!void {
    if (!canUseDirectIO()) return DirectIoError.UnsupportedPlatform;

    const direct_flag: c_int = @bitCast(std.posix.O{ .DIRECT = true });
    const flags = try std.posix.fcntl(file.handle, std.posix.F.GETFL, 0);

    // Already buffered: nothing to change.
    if ((flags & direct_flag) == 0) return;

    // Mask with every bit set except the DIRECT bit.
    const clear_mask = ~@as(c_uint, @bitCast(@as(u32, @intCast(direct_flag))));
    const result = try std.posix.fcntl(file.handle, std.posix.F.SETFL, flags & clear_mask);
    if (result != 0) return DirectIoError.UnexpectedFcntlResult;
}

/// Sets the `DIRECT` status flag on `file`, keeping every flag that is
/// already set.
///
/// Returns `DirectIoError.UnsupportedPlatform` when `canUseDirectIO()` is
/// false, `DirectIoError.UnexpectedFcntlResult` when `F.SETFL` yields a
/// non-zero result, and propagates any `fcntl` failure.
fn switchToDirectIO(file: std.Io.File) DirectIoError!void {
    if (!canUseDirectIO()) return DirectIoError.UnsupportedPlatform;

    const direct_flag: c_int = @bitCast(std.posix.O{ .DIRECT = true });
    const flags = try std.posix.fcntl(file.handle, std.posix.F.GETFL, 0);

    const result = try std.posix.fcntl(file.handle, std.posix.F.SETFL, flags | direct_flag);
    if (result != 0) return DirectIoError.UnexpectedFcntlResult;
}

pub const File = struct {
pub const Config = struct {
direct_io: bool = false,
direct_io_alignment: std.mem.Alignment = .fromByteUnits(4 * 1024),
};

allocator: std.mem.Allocator,
direct_io_map: std.AutoHashMapUnmanaged(std.Io.File.Handle, std.Io.File.Handle),
mutex: std.Io.Mutex,
config: Config,

inner: std.Io,
vtable: std.Io.VTable,

pub fn init(inner: std.Io) File {
pub fn init(allocator: std.mem.Allocator, inner: std.Io, config: Config) File {
var vtable = inner.vtable.*;
vtable.dirMake = dirMake;
vtable.dirMakePath = dirMakePath;
Expand All @@ -15,14 +75,25 @@ pub const File = struct {
vtable.dirAccess = dirAccess;
vtable.dirCreateFile = dirCreateFile;
vtable.dirOpenFile = dirOpenFile;
vtable.fileClose = fileClose;
vtable.dirOpenDir = dirOpenDir;
vtable.fileReadPositional = fileReadPositional;
vtable.fileReadStreaming = fileReadStreaming;

return .{
.allocator = allocator,
.direct_io_map = .{},
.mutex = .init,
.config = config,
.inner = inner,
.vtable = vtable,
};
}

/// Frees the storage of `direct_io_map` using the allocator given to `init`.
///
/// NOTE(review): handles still stored in the map are not closed here; this
/// presumably relies on every opened file having gone through `fileClose`
/// first — confirm.
pub fn deinit(self: *File) void {
self.direct_io_map.deinit(self.allocator);
}

pub fn io(self: *File) std.Io {
return .{
.userdata = self,
Expand All @@ -37,6 +108,51 @@ pub const File = struct {
return path;
}

/// Decides whether a freshly opened `file` should get a direct I/O shadow
/// handle.
///
/// Direct I/O is used only when it is enabled in the config, the file was
/// opened read-only, and the file is at least one alignment unit long.
/// The cheap config and mode checks run first, so files that can never use
/// direct I/O are not stat'ed and a stat failure cannot fail their open.
///
/// Returns `OpenError.Unexpected` when the size of an eligible file cannot
/// be queried through the inner `fileStat`.
fn canOpenWithDirectIO(self: *File, file: std.Io.File, flags: std.Io.File.OpenFlags) std.Io.File.OpenError!bool {
    if (!self.config.direct_io or flags.mode != .read_only) return false;

    const file_stat = self.inner.vtable.fileStat(self.inner.userdata, file) catch {
        return std.Io.File.OpenError.Unexpected;
    };

    return file_stat.size >= self.config.direct_io_alignment.toByteUnits();
}

/// Picks the handle a read should be issued on.
///
/// Returns a file wrapping the direct I/O shadow handle registered for
/// `file` when one exists and the request meets the configured alignment:
/// `position` is a multiple of the alignment, and every buffer is non-empty,
/// starts on an aligned address, and has an aligned length. In every other
/// case the original `file` is returned unchanged.
///
/// The map lookup happens first so files without a shadow handle skip the
/// alignment scan, and the position check is done in `u64` so it does not
/// need a narrowing conversion to `usize` on 32-bit targets.
fn innerFile(
    self: *File,
    file: std.Io.File,
    buffers: [][]u8,
    position: u64,
) std.Io.File {
    self.mutex.lockUncancelable(self.inner);
    const direct_io_handle = self.direct_io_map.get(file.handle);
    self.mutex.unlock(self.inner);

    const handle = direct_io_handle orelse return file;

    const alignment_bytes: usize = self.config.direct_io_alignment.toByteUnits();
    if (position % alignment_bytes != 0) return file;

    for (buffers) |buf| {
        // The explicit length check rejects empty buffers, which would
        // otherwise count as aligned.
        const usable = buf.len >= alignment_bytes and
            std.mem.isAligned(@intFromPtr(buf.ptr), alignment_bytes) and
            std.mem.isAligned(buf.len, alignment_bytes);
        if (!usable) return file;
    }

    return std.Io.File{ .handle = handle };
}

fn dirMake(
userdata: ?*anyopaque,
dir: std.Io.Dir,
Expand Down Expand Up @@ -104,7 +220,48 @@ pub const File = struct {
flags: std.Io.File.OpenFlags,
) std.Io.File.OpenError!std.Io.File {
const self: *File = @ptrCast(@alignCast(userdata orelse unreachable));
return self.inner.vtable.dirOpenFile(self.inner.userdata, dir, stripScheme(sub_path), flags);
const file = try self.inner.vtable.dirOpenFile(self.inner.userdata, dir, stripScheme(sub_path), flags);

const use_direct_io = try self.canOpenWithDirectIO(file, flags);
errdefer file.close(self.inner);

if (use_direct_io) {
const direct_io_file = try self.inner.vtable.dirOpenFile(self.inner.userdata, dir, stripScheme(sub_path), flags);
errdefer direct_io_file.close(self.inner);

switchToDirectIO(direct_io_file) catch |err| {
log.err("Failed to switch to Direct I/O mode: {any}", .{err});
return std.Io.File.OpenError.Unexpected;
};

self.mutex.lockUncancelable(self.inner);
defer self.mutex.unlock(self.inner);

self.direct_io_map.put(self.allocator, file.handle, direct_io_file.handle) catch |err| {
log.err("Failed to insert Direct I/O file into map: {any}", .{err});
return std.Io.File.OpenError.Unexpected;
};
}

return file;
}

/// Closes `file` through the inner implementation. If a direct I/O shadow
/// handle is registered for it, that entry is removed from the map and the
/// shadow handle is closed first.
fn fileClose(userdata: ?*anyopaque, file: std.Io.File) void {
    const self: *File = @ptrCast(@alignCast(userdata orelse unreachable));

    // Hold the lock only for the map update, not for the close calls.
    const removed = blk: {
        self.mutex.lockUncancelable(self.inner);
        defer self.mutex.unlock(self.inner);
        break :blk self.direct_io_map.fetchRemove(file.handle);
    };

    if (removed) |entry| {
        const shadow: std.Io.File = .{ .handle = entry.value };
        self.inner.vtable.fileClose(self.inner.userdata, shadow);
    }

    self.inner.vtable.fileClose(self.inner.userdata, file);
}

fn dirOpenDir(
Expand All @@ -116,4 +273,28 @@ pub const File = struct {
const self: *File = @ptrCast(@alignCast(userdata orelse unreachable));
return self.inner.vtable.dirOpenDir(self.inner.userdata, dir, stripScheme(sub_path), options);
}

/// Positional read that is routed to whichever handle `innerFile` selects
/// for this buffer set and offset, then forwarded to the inner
/// implementation with the same buffers and position.
fn fileReadPositional(
    userdata: ?*anyopaque,
    file: std.Io.File,
    buffer: [][]u8,
    position: u64,
) std.Io.File.ReadPositionalError!usize {
    const self: *File = @ptrCast(@alignCast(userdata orelse unreachable));
    const inner = self.inner;

    return inner.vtable.fileReadPositional(
        inner.userdata,
        self.innerFile(file, buffer, position),
        buffer,
        position,
    );
}

/// Streaming read that may be served by the direct I/O shadow handle.
///
/// The shadow handle is selected by `innerFile` from the current offset of
/// `file` and the alignment of `buffer`. When the shadow handle is chosen,
/// its offset is first moved to the offset of `file`, and after the read the
/// offset of `file` is advanced by the number of bytes returned. This keeps
/// both handles in step, so a later read that falls back to `file` continues
/// from the right place.
///
/// When the current offset of `file` cannot be queried, the read is
/// forwarded to the inner implementation on `file` unchanged.
///
/// Returns `Reader.Error.Unexpected` when repositioning either handle fails,
/// and propagates errors from the inner streaming read.
fn fileReadStreaming(
    userdata: ?*anyopaque,
    file: std.Io.File,
    buffer: [][]u8,
) std.Io.File.Reader.Error!usize {
    const self: *File = @ptrCast(@alignCast(userdata orelse unreachable));

    const position = std.posix.lseek_CUR_get(file.handle) catch {
        // No offset to align against: behave like a plain pass-through.
        return self.inner.vtable.fileReadStreaming(self.inner.userdata, file, buffer);
    };

    const inner_file = self.innerFile(file, buffer, position);
    if (inner_file.handle == file.handle) {
        return self.inner.vtable.fileReadStreaming(self.inner.userdata, file, buffer);
    }

    // The shadow handle has its own offset; bring it to where `file` is.
    std.posix.lseek_SET(inner_file.handle, position) catch return std.Io.File.Reader.Error.Unexpected;

    const n = try self.inner.vtable.fileReadStreaming(self.inner.userdata, inner_file, buffer);

    // Reflect the consumed bytes on the handle the caller actually holds.
    if (n != 0) {
        std.posix.lseek_CUR(file.handle, @intCast(n)) catch return std.Io.File.Reader.Error.Unexpected;
    }

    return n;
}
};
115 changes: 69 additions & 46 deletions zml2/main_io.zig
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub fn main() !void {

const allocator = debug_allocator.allocator();

// const allocator = std.heap.c_allocator;
// const allocator = std.heap.smp_allocator;

var threaded: std.Io.Threaded = .init(allocator);
defer threaded.deinit();
Expand All @@ -44,7 +44,15 @@ pub fn main() !void {
try http_client.initDefaultProxies(allocator);
defer http_client.deinit();

var vfs_file: zml.io.VFS.File = .init(threaded.io());
var vfs_file: zml.io.VFS.File = .init(
allocator,
threaded.io(),
.{
.direct_io = std.process.hasEnvVarConstant("DIRECT"),
.direct_io_alignment = .fromByteUnits(4 * 1024),
},
);
defer vfs_file.deinit();

var vfs_http: zml.io.VFS.HTTP = try .init(allocator, threaded.io(), &http_client, .{});
defer vfs_http.deinit();
Expand Down Expand Up @@ -88,21 +96,70 @@ pub fn main() !void {
break :blk default_uri;
};

// {
// // const initial_path = "/Users/hugo/Developer/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct";
// // const initial_path = "file:///Users/hugo/Developer/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct";
// const initial_path = "https://storage.googleapis.com/zig-vfs/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct";

// const llama_dir = try std.Io.Dir.openDir(.cwd(), io, initial_path, .{});
// defer llama_dir.close(io);

// const filepath = "../model.safetensors.index.json";

// const index_file = try llama_dir.openFile(io, filepath, .{});
// defer index_file.close(io);

// const stat = try index_file.stat(io);
// log.info("Opened local index file with size: {d} bytes", .{stat.size});
// }

{
// const initial_path = "/Users/hugo/Developer/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct";
// const initial_path = "file:///Users/hugo/Developer/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct";
const initial_path = "https://storage.googleapis.com/zig-vfs/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct";
const repo = try std.Io.Dir.openDir(.cwd(), io, "file:///home/hugo/Llama-3.1-8B-Instruct/", .{});
defer repo.close(io);

const file = try repo.openFile(io, "model-00001-of-00004.safetensors", .{});
defer file.close(io);

const llama_dir = try std.Io.Dir.openDir(.cwd(), io, initial_path, .{});
defer llama_dir.close(io);
const buf_size = 256 * 1024 * 1024;
const buf = try allocator.alignedAlloc(u8, .fromByteUnits(4 * 1024), buf_size);
defer allocator.free(buf);

const filepath = "../model.safetensors.index.json";
var sha256: std.crypto.hash.sha2.Sha256 = .init(.{});
const compute_sha = false;

const index_file = try llama_dir.openFile(io, filepath, .{});
defer index_file.close(io);
var total_read: usize = 1 * 1024;
var bufs = [_][]u8{buf};

var timer: std.time.Timer = try .start();
defer {
const elapsed = timer.read();
log.info("File read in {d} ms", .{elapsed / std.time.ns_per_ms});
}

const stat = try index_file.stat(io);
log.info("Opened local index file with size: {d} bytes", .{stat.size});
while (true) {
const n = try io.vtable.fileReadStreaming(io.userdata, file, &bufs);
if (n == 0) break;
if (compute_sha) sha256.update(buf[0..n]);
total_read += n;
}

const elapsed = timer.read();
const read_mb = @as(f64, @floatFromInt(total_read)) / (1024.0 * 1024.0);
const read_time_s = @as(f64, @floatFromInt(elapsed)) / @as(f64, std.time.ns_per_s);
const throughput_mb_s = if (read_time_s > 0) read_mb / read_time_s else 0;
const throughput_gbps = if (read_time_s > 0) (read_mb * 8.0) / 1024.0 / read_time_s else 0;
log.info("Read {d:.2} MB in {d:.2} s = {d:.2} MB/s | {d:.2} Gbps", .{
read_mb,
read_time_s,
throughput_mb_s,
throughput_gbps,
});

if (compute_sha) {
var hash: [32]u8 = undefined;
sha256.final(&hash);
log.info("SHA256: {x}", .{hash});
}
}

{
Expand All @@ -123,40 +180,6 @@ pub fn main() !void {

log.info("Parsed {d} tensors", .{registry.tensors.count()});

// {
// const tensor_name = "model.layers.31.mlp.down_proj.weight";
// // const tensor_name = "model.layers.31.input_layernorm.weight";
// // const tensor_name = "lm_head.weight";
// const tensor = registry.tensors.get(tensor_name) orelse return error.TensorNotFound;
// log.info("Reading {f}...", .{tensor});

// const tensor_reader_buf = try allocator.alloc(u8, tensor.byteSize());
// defer allocator.free(tensor_reader_buf);

// const writer_buffer = try allocator.alloc(u8, tensor.byteSize());
// defer allocator.free(writer_buffer);

// var tensor_reader = try registry.reader(io, &vfs, tensor_name, tensor_reader_buf);
// defer tensor_reader.deinit();

// var writer: std.Io.Writer = .fixed(writer_buffer);

// var timer_read: std.time.Timer = try .start();

// const read = try tensor_reader.interface.streamRemaining(&writer);

// const elapsed = timer_read.read();
// const read_mb = @as(f64, @floatFromInt(read)) / (1024.0 * 1024.0);
// const read_time_s = @as(f64, @floatFromInt(elapsed)) / @as(f64, std.time.ns_per_s);
// const throughput_gbps = (read_mb * 8.0) / 1024.0 / read_time_s;

// log.info("Read completed in {d} ms | {d:.2} MB/s | {d:.2} Gbps", .{
// elapsed / std.time.ns_per_ms,
// read_mb / read_time_s,
// throughput_gbps,
// });
// }

{
var err: ?anyerror = null;

Expand Down
Loading