diff options
Diffstat (limited to 'mcproto/decode.ha')
-rw-r--r-- | mcproto/decode.ha | 193 |
1 files changed, 193 insertions, 0 deletions
diff --git a/mcproto/decode.ha b/mcproto/decode.ha new file mode 100644 index 0000000..bd020c4 --- /dev/null +++ b/mcproto/decode.ha @@ -0,0 +1,193 @@ +use encoding::utf8; +use endian; +use fmt; +use io; +use strings; +use trace; +use uuid; + +export type Decoder = struct { + input: []u8, + pos: size, + tracer: *trace::tracer, +}; + +export type Context = struct { + dec: *Decoder, + pos: size, + fmt: str, + fields: []fmt::field, + up: nullable *Context, +}; + +export fn log( + ctx: *Context, + lvl: trace::level, + fmt: str, + fields: fmt::field... +) void = { + log_(ctx, null, lvl, fmt, fields...); +}; + +fn log_( + ctx: *Context, + trace_ctx: nullable *trace::context, + lvl: trace::level, + fmt: str, + fields: fmt::field... +) void = { + let s = ""; + defer free(s); + if (len(ctx.fmt) != 0) { + s = fmt::asprintf(ctx.fmt, ctx.fields...); + // TODO: is this legal? works at the moment due to qbe + // semantics, but who knows... + trace_ctx = &trace::context { + fmt = "{} (offset {})", + fields = [s, ctx.pos], + next = trace_ctx, + }; + }; + match (ctx.up) { + case let ctx_: *Context => + log_(ctx_, trace_ctx, lvl, fmt, fields...); + case null => + trace::log(ctx.dec.tracer, trace_ctx, lvl, fmt, fields...); + }; +}; + +export fn error(ctx: *Context, fmt: str, fields: fmt::field...) trace::failed = { + log(ctx, trace::level::ERROR, fmt, fields...); + return trace::failed; +}; + +export fn root(dec: *Decoder) Context = { + return Context { + dec = dec, + pos = dec.pos, + ... + }; +}; + +export fn context(ctx: *Context, fmt: str, fields: fmt::field...) Context = { + return Context { + dec = ctx.dec, + pos = ctx.dec.pos, + fmt = fmt, + fields = fields, + up = ctx, + }; +}; + +export fn decode_nbytes(ctx: *Context, length: size) + ([]u8 | trace::failed) = { + const dec = ctx.dec; + if (len(dec.input) - dec.pos < length) { + return error(ctx, "Expected {} bytes, found only {}", + length, len(dec.input) - dec.pos); + }; + const res = dec.input[dec.pos..dec.pos + length]; + dec.pos += length; + return res; +}; + +export fn decode_byte(ctx: *Context) (u8 | trace::failed) = { + const b = decode_nbytes(ctx, 1)?; + return b[0]; +}; +export fn decode_short(ctx: *Context) (u16 | trace::failed) = { + const b = decode_nbytes(ctx, 2)?; + return endian::begetu16(b); +}; +export fn decode_int(ctx: *Context) (u32 | trace::failed) = { + const b = decode_nbytes(ctx, 4)?; + return endian::begetu32(b); +}; +export fn decode_long(ctx: *Context) (u64 | trace::failed) = { + const b = decode_nbytes(ctx, 8)?; + return endian::begetu64(b); +}; + +export fn decode_bool(ctx: *Context) (bool | trace::failed) = { + const b = decode_byte(ctx)?; + if (b >= 2) { + return error(ctx, "Invalid boolean"); + }; + return b != 0; +}; + +export fn decode_float(ctx: *Context) (f32 | trace::failed) = { + const v = decode_int(ctx)?; + return *(&v: *f32); +}; +export fn decode_double(ctx: *Context) (f64 | trace::failed) = { + const v = decode_long(ctx)?; + return *(&v: *f64); +}; + +export fn try_decode_varint(ctx: *Context) (i32 | !(trace::failed | TooShort)) = { + let res = 0u32; + const dec = ctx.dec; + + for (let i = 0u32; dec.pos + i < len(dec.input); i += 1) { + const b = dec.input[dec.pos + i]; + + if (i == 4 && b & 0xf0 != 0) { + return error(ctx, "VarInt too long"); + }; + + res |= (b & 0x7f): u32 << (7 * i); + + if (b & 0x80 == 0) { + dec.pos += i + 1; + return res: i32; + }; + }; + + return TooShort; +}; + +export fn decode_varint(ctx: *Context) (i32 | trace::failed) = { + match (try_decode_varint(ctx)) { + case let res: i32 => + return res; + case => + return error(ctx, "VarInt too short"); + }; +}; + +export fn decode_string(ctx: *Context, maxlen: size) (str | trace::failed) = { + const length = decode_varint(&context(ctx, "string length"))?; + const length = length: size; + + if (length >= maxlen * 4) { + return error(ctx, + "String length {} exceeds limit of {} bytes", + length, maxlen * 4); + }; + + const ctx_ = context(ctx, "string data ({} bytes)", length); + const bytes = decode_nbytes(&ctx_, length)?; + match (strings::fromutf8(bytes)) { + case let string: str => + // don't bother checking length in code points. doesn't seem + // very useful. + return string; + case utf8::invalid => + return error(&ctx_, "Invalid UTF-8"); + }; +}; + +export fn decode_uuid(ctx: *Context) (uuid::uuid | trace::failed) = { + let uuid: [16]u8 = [0...]; + uuid[..] = decode_nbytes(ctx, 16)?; + return uuid; +}; + +export fn expect_end(ctx: *Context) (void | trace::failed) = { + if (ctx.dec.pos != len(ctx.dec.input)) { + return error(ctx, + "Expected end of input, but found {} extra bytes starting at {}", + len(ctx.dec.input) - ctx.dec.pos, ctx.dec.pos); + }; +}; |