From 7c5ae90954fc1fafe15a4b302a368d65044c00fd Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Wed, 31 Mar 2021 16:09:26 +0200 Subject: [PATCH 01/12] benchmark/common: using argument-provided location for MEM_AREAs Using a lot the stack to save these values... --- src/pipedream/benchmark/common.py | 221 +++++++++++++++++------------- 1 file changed, 128 insertions(+), 93 deletions(-) diff --git a/src/pipedream/benchmark/common.py b/src/pipedream/benchmark/common.py index e901ece..76abeef 100644 --- a/src/pipedream/benchmark/common.py +++ b/src/pipedream/benchmark/common.py @@ -209,6 +209,8 @@ class Benchmark_Lib: (1, "num_papi_events", 0), (1, "papi_result_array", None), (1, "reg_values", None), + (1, "load_area", None), + (1, "store_area", None), ), ) fn.__name__ = fn_name @@ -239,6 +241,8 @@ class Benchmark_Lib: ctypes.c_ssize_t, ctypes.POINTER(ctypes.c_longlong), ctypes.POINTER(ctypes.c_byte), + ctypes.POINTER(ctypes.c_byte), + ctypes.POINTER(ctypes.c_byte), ) def __del__(self): @@ -1158,6 +1162,8 @@ class _Benchmark_Runner: num_events = out.allocate_argument(1) results = out.allocate_argument(2) reg_values = out.allocate_argument(3) + mem_load_area = out.allocate_argument(4) + mem_store_area = out.allocate_argument(5) out.newline() out.comment("*" * 70) @@ -1174,11 +1180,21 @@ class _Benchmark_Runner: out.comment("ARG num_events ", num_events) out.comment("ARG results ", results) out.comment("ARG reg_values ", reg_values) + out.comment("ARG mem_load_area ", mem_load_area) + out.comment("ARG mem_store_area ", mem_store_area) out.comment("free callee-saves for kernel") out.push_callee_saves() out.newline() + need_memory: bool = any(i.has_memory_operand() for i in benchmark.instructions) + if need_memory: + out.comment("saving memory areas") + out.push_to_stack(mem_load_area) + out.push_to_stack(mem_store_area) + else: + out.free_reg(mem_load_area) + out.free_reg(mem_store_area) SCRATCH_REG_1 = out.scratch_register(0) SCRATCH_REG_2 = out.scratch_register(1) @@ -1186,9 +1202,8 @@ class _Benchmark_Runner: SCRATCH_REG_4 = out.scratch_register(3) SCRATCH_REG_5 = out.scratch_register(4) # Set of registers that are currently used, but will not during the execution - # of the kernel (used during generation), and reverse + # of the kernel (=> used only during initialisation), and reverse unused_registers_kernel = set() - used_registers_kernel = set() from pipedream.asm.x86 import RDX @@ -1213,57 +1228,74 @@ class _Benchmark_Runner: out.comment("size of one row of results table in bytes") STRIDE = out.mul_reg_with_const(num_events, 8) - need_memory: bool = any(i.has_memory_operand() for i in benchmark.instructions) - if need_memory: - MEMORY_ARENA_LD = ir.Label(self._MEMORY_ARENA_LD + "@GOTPCREL(%rip)") - MEMORY_ARENA_ST = ir.Label(self._MEMORY_ARENA_ST + "@GOTPCREL(%rip)") out.newline() - out.comment("clear memory arena") + out.comment("clear mem_store_area") + # FIXME: other address sizes ## void *memset(void *s, int c, size_t n); s = out.get_argument_register(0) c = out.get_argument_register(1) n = out.get_argument_register(2) - out.put_const_in_register(MEMORY_ARENA_LD, s) + MEMORY_REG_STORE = out.ir_builder.select_memory_base_register( + benchmark.instructions, + set(out.iter_free_registers()), + 64, + ) + out.comment("mem_store -> ", MEMORY_REG_STORE) + # popping mem_store_area from the stack + out.pop_from_stack(s) + # 1) saves `mem_store_area` 2) frees `s` and 3) takes `MEMORY_REG_STORE` + out.move_to(s, MEMORY_REG_STORE) + # Saving `MEMORY_REG_STORE` and `reg_values` as they may be overwritten by + # the call + out.push_to_stack(MEMORY_REG_STORE) + out.push_to_stack(reg_values) out.put_const_in_register(0, c) out.put_const_in_register(self.memory_size(benchmark), n) out.call("memset@PLT", s, c, n) out.free_reg(s) out.free_reg(c) out.free_reg(n) - out.put_const_in_register(MEMORY_ARENA_ST, s) + # Restoring `reg_values` to leave the stack unchanged + out.pop_from_stack(reg_values) + out.pop_from_stack(MEMORY_REG_STORE) + + out.newline() + out.comment("clear mem_load_area") + MEMORY_REG_LOAD = out.ir_builder.select_memory_base_register( + benchmark.instructions, + set(out.iter_free_registers()) - set([MEMORY_REG_STORE]), + 64, + ) + out.comment("mem_store -> ", MEMORY_REG_LOAD) + # Popping `mem_load_area` from the stack + out.pop_from_stack(s) + # 1) saves `mem_load_area` 2) frees `s` and 3) takes `MEMORY_REG_LOAD` + out.move_to(s, MEMORY_REG_LOAD) + # saving `MEM_REG_LOAD`, `MEM_REG_STORE` and `reg_values` as they may be + # overwritten by the call + out.push_to_stack(MEMORY_REG_LOAD) + out.push_to_stack(MEMORY_REG_STORE) + out.push_to_stack(reg_values) out.put_const_in_register(0, c) out.put_const_in_register(self.memory_size(benchmark), n) out.call("memset@PLT", s, c, n) out.free_reg(s) out.free_reg(c) out.free_reg(n) + # Restoring `reg_values` to leave the stack unchanged + out.pop_from_stack(reg_values) + out.newline() + else: + MEMORY_REG_LOAD = None + MEMORY_REG_STORE = None with out.counting_loop("measurement", LOOP_COUNTER, num_iterations) as loop: out.comment("push loop counter") - out.push_to_stack(LOOP_COUNTER) - - ## Reserve registers for future use if need_memory: - # FIXME: other address sizes - ## We may use `reg_values` as it will be free before codegen - out.free_reg(reg_values) - MEMORY_REG_LOAD = out.ir_builder.select_memory_base_register( - benchmark.instructions, - set(out.iter_free_registers()), - 64, - ) - MEMORY_REG_STORE = out.ir_builder.select_memory_base_register( - benchmark.instructions, - set(out.iter_free_registers()) - set([MEMORY_REG_LOAD]), - 64, - ) - used_registers_kernel.add(MEMORY_REG_LOAD) - used_registers_kernel.add(MEMORY_REG_STORE) - out.take_reg(reg_values) - else: - MEMORY_REG_LOAD = None - MEMORY_REG_STORE = None + out.pop_from_stack(MEMORY_REG_STORE) + out.pop_from_stack(MEMORY_REG_LOAD) + out.push_to_stack(LOOP_COUNTER) ## Pre-allocate kernel and its related variables, as some register information may ## be used by prologue and/or pre-prologue @@ -1282,7 +1314,7 @@ class _Benchmark_Runner: MEMORY_REG_LOAD, MEMORY_REG_STORE, unused_registers_kernel, - used_registers_kernel, + set(), gen_iaca_markers=gen_iaca_markers, ) @@ -1298,6 +1330,9 @@ class _Benchmark_Runner: if gen_papi_calls: out.comment("push loop stride") out.push_to_stack(STRIDE) + if need_memory: + out.push_to_stack(MEMORY_REG_LOAD) + out.push_to_stack(MEMORY_REG_STORE) out.sequentialize_cpu() @@ -1309,26 +1344,27 @@ class _Benchmark_Runner: out.branch_if_not_zero(ret, loop.exit) # TODO: test this - out.comment("pop papi_event_set") + out.comment("push papi_event_set and papi_results") + if need_memory: + out.pop_from_stack(MEMORY_REG_STORE) + out.pop_from_stack(MEMORY_REG_LOAD) out.push_to_stack(papi_event_set) - out.comment("push papi_results") out.push_to_stack(results) + if need_memory: + out.push_to_stack(MEMORY_REG_LOAD) + out.push_to_stack(MEMORY_REG_STORE) out.sequentialize_cpu() ## allow prologue generator to see real instructions with allocated registers, etc. + if need_memory: + out.comment("pop load/store regions") + out.pop_from_stack(MEMORY_REG_STORE) + out.pop_from_stack(MEMORY_REG_LOAD) out.emit_benchmark_prologue( fully_allocated_kernel, ) - ## Actually use reserved registers - if need_memory: - out.comment("init pointers location") - out.take_reg(MEMORY_REG_LOAD) - out.take_reg(MEMORY_REG_STORE) - out.put_const_in_register(MEMORY_ARENA_LD, MEMORY_REG_LOAD) - out.put_const_in_register(MEMORY_ARENA_ST, MEMORY_REG_STORE) - ## free registers stolen by backend out.free_stolen_benchmark_registers(stolen_regs) @@ -1338,10 +1374,6 @@ class _Benchmark_Runner: ## actually emit kernel out.splice_in_code(kernel_code) - if MEMORY_REG_LOAD is not None and MEMORY_REG_STORE is not None: - out.free_reg(MEMORY_REG_LOAD) - out.free_reg(MEMORY_REG_STORE) - out.comment("END BENCHMARK") out.comment("*" * 40) @@ -1354,6 +1386,9 @@ class _Benchmark_Runner: out.pop_from_stack(results) out.comment("pop papi_event_set") out.pop_from_stack(papi_event_set) + if need_memory: + out.push_to_stack(MEMORY_REG_LOAD) + out.push_to_stack(MEMORY_REG_STORE) out.comment("stop & read PAPI counters") ret = out.call("PAPI_stop@PLT", papi_event_set, results) @@ -1365,15 +1400,34 @@ class _Benchmark_Runner: out.sequentialize_cpu() out.comment("pop stride") + if need_memory: + out.pop_from_stack(MEMORY_REG_STORE) + out.pop_from_stack(MEMORY_REG_LOAD) out.pop_from_stack(STRIDE) + if need_memory: + out.push_to_stack(MEMORY_REG_LOAD) + out.push_to_stack(MEMORY_REG_STORE) SCRATCH_REG_5 = out.scratch_register(4) + if need_memory: + out.pop_from_stack(MEMORY_REG_STORE) + out.pop_from_stack(MEMORY_REG_LOAD) out.emit_loop_epilogue(fully_allocated_kernel, SCRATCH_REG_5) + if need_memory: + out.push_to_stack(MEMORY_REG_LOAD) + out.push_to_stack(MEMORY_REG_STORE) out.comment("pop loop counter") + if need_memory: + out.pop_from_stack(MEMORY_REG_STORE) + out.pop_from_stack(MEMORY_REG_LOAD) + out.pop_from_stack(LOOP_COUNTER) - out.newline() + if need_memory: + out.push_to_stack(MEMORY_REG_LOAD) + out.push_to_stack(MEMORY_REG_STORE) + out.newline() out.add_registers(STRIDE, results) out.newline() @@ -1381,6 +1435,12 @@ class _Benchmark_Runner: with out.with_register(out.return_register()) as ret: out.put_const_in_register(0, ret) + if need_memory: + out.comment("Memory area are not needed anymore, discarding") + out.pop_from_stack(MEMORY_REG_STORE) + out.pop_from_stack(MEMORY_REG_LOAD) + out.free_reg(MEMORY_REG_LOAD, MEMORY_REG_STORE) + out.free_reg( papi_event_set, results, @@ -1492,8 +1552,6 @@ class _Benchmark_Runner: return out.take_code(), kernel_instructions - _MEMORY_ARENA_LD = "_memory_arena_ld" - _MEMORY_ARENA_ST = "_memory_arena_st" _PAGE_SIZE = 4096 def _gen_benchmark_lib( @@ -1572,50 +1630,7 @@ class _Benchmark_Runner: benchmark_functions[benchmark] = fn_name - ## calculate of size of memory arena - - memory_arena_size = 0 - - for b in benchmark_specs: - memory_size = ( - allocator.Maximize_Deps_Register_Allocator.memory_arena_size( - b.instructions - ) - * b.unroll_factor - ) - - memory_arena_size = max(memory_arena_size, memory_size) - - PAGE_SIZE = self._PAGE_SIZE - - ## round up to a multiple of page size - memory_arena_size = ( - memory_arena_size + PAGE_SIZE - memory_arena_size % PAGE_SIZE - ) - - ## add a padding page - memory_arena_size += PAGE_SIZE - - ## why not - memory_arena_size *= 2 - - asm_writer.global_byte_array( - self._MEMORY_ARENA_LD + "pad_before_", memory_arena_size, 4096 - ) - asm_writer.global_byte_array(self._MEMORY_ARENA_LD, memory_arena_size, 4096) - asm_writer.global_byte_array( - self._MEMORY_ARENA_LD + "pad_after_", memory_arena_size, 4096 - ) - - asm_writer.global_byte_array( - self._MEMORY_ARENA_ST + "pad_before_", memory_arena_size, 4096 - ) - asm_writer.global_byte_array(self._MEMORY_ARENA_ST, memory_arena_size, 4096) - asm_writer.global_byte_array( - self._MEMORY_ARENA_ST + "pad_after_", memory_arena_size, 4096 - ) - - asm_writer.end_file(asm_file) + asm_writer.end_file(asm_file) self.info("assemble benchmark library") @@ -1729,6 +1744,22 @@ class _Benchmark_Runner: dtype=ctypes.c_byte, order="C", ) + load_area = numpy.ndarray( + shape=[ + self.memory_size(benchmark), + 1, + ], + dtype=ctypes.c_byte, + order="C", + ) + store_area = numpy.ndarray( + shape=[ + self.memory_size(benchmark), + 1, + ], + dtype=ctypes.c_byte, + order="C", + ) random.seed(42) for i in range(benchmark.arch.nb_vector_reg * ALIGNEMENT - 1): @@ -1736,6 +1767,8 @@ class _Benchmark_Runner: init_values[i, j] = random.getrandbits(8) data = init_values.ctypes.data_as(ctypes.POINTER(ctypes.c_byte)) + load_area_data = load_area.ctypes.data_as(ctypes.POINTER(ctypes.c_byte)) + store_area_data = store_area.ctypes.data_as(ctypes.POINTER(ctypes.c_byte)) addr = ctypes.addressof(data) offset = 0 if addr % ALIGNEMENT == 0 else ALIGNEMENT - (addr % ALIGNEMENT) aligned_init_values = ctypes.POINTER(ctypes.c_byte).from_address( @@ -1748,6 +1781,8 @@ class _Benchmark_Runner: num_events, result_array.ctypes.data_as(ctypes.POINTER(ctypes.c_longlong)), aligned_init_values, + load_area_data, + store_area_data, ) time_after = time.perf_counter() -- GitLab From bc43d91d2d4262c88065b1e4313446e7f9ae95e7 Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Thu, 1 Apr 2021 14:30:50 +0200 Subject: [PATCH 02/12] benchmark/common: updating according to reviews --- src/pipedream/benchmark/common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/pipedream/benchmark/common.py b/src/pipedream/benchmark/common.py index 76abeef..dab5ecb 100644 --- a/src/pipedream/benchmark/common.py +++ b/src/pipedream/benchmark/common.py @@ -1202,7 +1202,7 @@ class _Benchmark_Runner: SCRATCH_REG_4 = out.scratch_register(3) SCRATCH_REG_5 = out.scratch_register(4) # Set of registers that are currently used, but will not during the execution - # of the kernel (=> used only during initialisation), and reverse + # of the kernel (=> used only during initialisation) unused_registers_kernel = set() from pipedream.asm.x86 import RDX @@ -1346,11 +1346,14 @@ class _Benchmark_Runner: out.comment("push papi_event_set and papi_results") if need_memory: + out.comment("poping first mem_area regs") out.pop_from_stack(MEMORY_REG_STORE) out.pop_from_stack(MEMORY_REG_LOAD) + out.comment('"real" push') out.push_to_stack(papi_event_set) out.push_to_stack(results) if need_memory: + out.comment("pushing back mem_area regs") out.push_to_stack(MEMORY_REG_LOAD) out.push_to_stack(MEMORY_REG_STORE) -- GitLab From 7fc1bb54b1458e2e8a713238a7c6e72a6b370792 Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Thu, 1 Apr 2021 18:53:21 +0200 Subject: [PATCH 03/12] armv8a/registers: first implementation + corrections on x86 registers --- src/pipedream/asm/armv8a/__init__.py | 21 + src/pipedream/asm/armv8a/registers.py | 438 +++++++++ src/pipedream/asm/x86/registers.py | 1179 +++++++++++++------------ 3 files changed, 1053 insertions(+), 585 deletions(-) create mode 100644 src/pipedream/asm/armv8a/__init__.py create mode 100644 src/pipedream/asm/armv8a/registers.py diff --git a/src/pipedream/asm/armv8a/__init__.py b/src/pipedream/asm/armv8a/__init__.py new file mode 100644 index 0000000..a7622d5 --- /dev/null +++ b/src/pipedream/asm/armv8a/__init__.py @@ -0,0 +1,21 @@ +from pipedream.utils import * +from pipedream.asm.ir import * +from pipedream.asm.allocator import * +from pipedream.asm.asmwriter import * + +from pipedream.benchmark.types import Loop_Overhead + + +from . import registers +from . import operands + +from .asmwriter import * +from .operands import * +from .registers import * +from .instructions import * + +__all__ = [ + *registers.__all__, + # *operands.__all__, + # *instructions.__all__, +] diff --git a/src/pipedream/asm/armv8a/registers.py b/src/pipedream/asm/armv8a/registers.py new file mode 100644 index 0000000..ff9f2b9 --- /dev/null +++ b/src/pipedream/asm/armv8a/registers.py @@ -0,0 +1,438 @@ +from enum import Enum, auto +from typing import Optional, Set, Union, Tuple, List, FrozenSet, Dict, cast +from pipedream.utils import abc +from pipedream.asm.ir import * + +__all__ = [ + "ARMv8_Register_Set", + "ARMv8_Register", + "ANY_REGISTER", + "ARGUMENT_REGISTER", + "CALLER_SAVED", + "CALLEE_SAVED", +] + + +# Wrapper to allow multiple values in `Enum` (i.e. `Enum.A == Enum.B` but +# `Enum.A is not Enum.B`) +class Unique: + def __init__(self, value: int): + self.value = value + + +class _ARMv8_REGISTER_TYPE(Enum): + # Int + W = Unique(32) + X = Unique(64) + # FP + B = Unique(8) + H = Unique(16) + S = Unique(32) + D = Unique(64) + Q = Unique(128) + # Vect elements + VB = Unique(8) + VH = Unique(16) + VS = Unique(32) + VD = Unique(64) + # ARMv8.2 only (dot product) + V4B = Unique(32) + # Vects + V_8B = Unique(64) + V_4H = Unique(64) + V_2S = Unique(64) + V_16B = Unique(128) + V_8H = Unique(128) + V_4S = Unique(128) + V_2D = Unique(128) + # Other + WSP = Unique(32) + WZR = Unique(32) + SP = Unique(64) + ZXR = Unique(64) + LR = Unique(64) + + @property + def value(self): + return super().value.value + + +for type_ in _ARMv8_REGISTER_TYPE: + __all__.append(type_.name) + +_ARMv8_BASE_SIZES = [ + _ARMv8_REGISTER_TYPE.B, + _ARMv8_REGISTER_TYPE.H, + _ARMv8_REGISTER_TYPE.S, + _ARMv8_REGISTER_TYPE.D, + _ARMv8_REGISTER_TYPE.Q, +] + +_ARMv8_WIDTH_TO_V_ELMT_TYPE = { + _ARMv8_REGISTER_TYPE.B: _ARMv8_REGISTER_TYPE.VB, + _ARMv8_REGISTER_TYPE.H: _ARMv8_REGISTER_TYPE.VH, + _ARMv8_REGISTER_TYPE.S: _ARMv8_REGISTER_TYPE.VS, + _ARMv8_REGISTER_TYPE.D: _ARMv8_REGISTER_TYPE.VD, +} + +_ARMv8_V_ELMT_TYPE_TO_WIDTH = { + _ARMv8_REGISTER_TYPE.VB: _ARMv8_REGISTER_TYPE.B, + _ARMv8_REGISTER_TYPE.VH: _ARMv8_REGISTER_TYPE.H, + _ARMv8_REGISTER_TYPE.VS: _ARMv8_REGISTER_TYPE.S, + _ARMv8_REGISTER_TYPE.VD: _ARMv8_REGISTER_TYPE.D, +} + +_ARMv8_WIDTH_TO_VECT = { + (_ARMv8_REGISTER_TYPE.D, _ARMv8_REGISTER_TYPE.B): _ARMv8_REGISTER_TYPE.V_8B, + (_ARMv8_REGISTER_TYPE.D, _ARMv8_REGISTER_TYPE.H): _ARMv8_REGISTER_TYPE.V_4H, + (_ARMv8_REGISTER_TYPE.D, _ARMv8_REGISTER_TYPE.S): _ARMv8_REGISTER_TYPE.V_2S, + (_ARMv8_REGISTER_TYPE.Q, _ARMv8_REGISTER_TYPE.B): _ARMv8_REGISTER_TYPE.V_16B, + (_ARMv8_REGISTER_TYPE.Q, _ARMv8_REGISTER_TYPE.H): _ARMv8_REGISTER_TYPE.V_8H, + (_ARMv8_REGISTER_TYPE.Q, _ARMv8_REGISTER_TYPE.S): _ARMv8_REGISTER_TYPE.V_4S, + (_ARMv8_REGISTER_TYPE.Q, _ARMv8_REGISTER_TYPE.D): _ARMv8_REGISTER_TYPE.V_2D, +} + +_ARMv8_VECT_TO_SUBTYPE = { + _ARMv8_REGISTER_TYPE.V4B: _ARMv8_REGISTER_TYPE.B, + _ARMv8_REGISTER_TYPE.V_8B: _ARMv8_REGISTER_TYPE.B, + _ARMv8_REGISTER_TYPE.V_4H: _ARMv8_REGISTER_TYPE.H, + _ARMv8_REGISTER_TYPE.V_2S: _ARMv8_REGISTER_TYPE.S, + _ARMv8_REGISTER_TYPE.V_16B: _ARMv8_REGISTER_TYPE.B, + _ARMv8_REGISTER_TYPE.V_8H: _ARMv8_REGISTER_TYPE.H, + _ARMv8_REGISTER_TYPE.V_4S: _ARMv8_REGISTER_TYPE.S, + _ARMv8_REGISTER_TYPE.V_2D: _ARMv8_REGISTER_TYPE.D, +} + + +class ARMv8_Register(Register): + def __init__( + self, + idx: Optional[int], + sub_idx: Optional[int], + name: Optional[str], + type_: _ARMv8_REGISTER_TYPE, + subs: Set["ARMv8_Register"], + ): + assert idx is not None or name is not None + self._type = type_ + self._widest: "ARMv8_Register" = self + self._as_width: Dict[int, "ARMv8_Register"] = {self.width: self} + self._subs: Union[Set["ARMv8_Register"], Tuple] = subs.copy() + self._aliases: Union[Set["ARMv8_Register"], Tuple] = set() + self._supers: Union[Set["ARMv8_Register"], Tuple] = set() + if name: + self._name = name + else: + if self._type in _ARMv8_V_ELMT_TYPE_TO_WIDTH.keys(): + assert sub_idx is not None + subtype = _ARMv8_V_ELMT_TYPE_TO_WIDTH[self._type] + self._name = f"V{idx}.{subtype.name}[{sub_idx}]" + elif self._type in _ARMv8_VECT_TO_SUBTYPE.keys(): + assert sub_idx is None + subtype = _ARMv8_VECT_TO_SUBTYPE[self._type] + self._name = f"V{idx}.{self.width//subtype.value}{subtype.name}" + else: + assert sub_idx is None + self._name = f"{self._type.name}{idx}" + + @property # type:ignore + @abc.override + def name(self) -> str: + return self._name + + @property # type:ignore + @abc.override + def att_asm_name(self) -> str: + return self._name + + @property # type:ignore + @abc.override + def width(self): + """Bit width of register.""" + return self._type.value + + @abc.override + def as_width(self, bits: int): + return self._as_width[bits] + + @property # type:ignore + @abc.override + def widest(self) -> "ARMv8_Register": + return self._widest + + @property # type:ignore + @abc.override + def sub_registers(self): + return iter(self._subs) + + @property # type:ignore + @abc.override + def super_registers(self): + return iter(self._supers) + + @property # type:ignore + @abc.override + def aliases(self): + return iter(self._aliases) + + def __str__(self): + return type(self).__name__ + "(" + self.name + ")" + + def __repr__(self): + return type(self).__name__ + "(" + self.name + ")" + + def __deepcopy__(self, memo): + return self + + def freeze(self): + self._aliases = tuple(self._aliases | self._subs | self._supers) + self._subs = tuple(self._subs) + self._supers = tuple(self._supers) + + +_ALL_REGISTERS_: List[ARMv8_Register] = [] +_ALL_REGISTER_CLASSES_: List[Register_Class] = [] + +## hack to allow putting register objects in module scope +self = vars() + + +def _def_ARMv8_reg( + idx_or_name: Union[int, str, Tuple[int, int]], + type_: _ARMv8_REGISTER_TYPE, + sub_regs: List[ARMv8_Register] = [], + widest=False, +) -> ARMv8_Register: + sub_idx: Optional[int] = None + if isinstance(idx_or_name, int): + idx: Optional[int] = idx_or_name + name: Optional[str] = None + elif isinstance(idx_or_name, str): + idx = None + name = idx_or_name + elif isinstance(idx_or_name, tuple): + idx, sub_idx = idx_or_name + name = None + else: + assert False, "Wrong `idx_or_name` type" + reg = ARMv8_Register(idx, sub_idx, name, type_, set(sub_regs)) + assert reg.name not in self, "duplicate register def" + + self[reg.name] = reg + _ALL_REGISTERS_.append(reg) + __all__.append(reg.name) + if widest: + for sreg in sub_regs: + sreg._widest = reg + sreg._as_width[reg.width] = reg + return reg + + +def _def_ARMv8_register_class(name, *registers, **aliases): + ## put register aliases in module scope + for k, v in aliases.items(): + assert k not in globals() or globals()[k] is v + + globals()[k] = v + __all__.append(k) + + clss = Register_Class(name, registers, aliases) + _ALL_REGISTER_CLASSES_.append(clss) + globals()[name] = clss + __all__.append(name) + + +# All taken from https://developer.arm.com/-/media/Arm%20Developer%20Community/PDF/Learn%20the%20Architecture/Armv8-A%20Instruction%20Set%20Architecture.pdf?revision=ebf53406-04fd-4c67-a485-1b329febfb3e + +# 31 registers X0..X30 and their 32-bit part, W0..W30 +for idx in range(31): + w_part = _def_ARMv8_reg(idx, _ARMv8_REGISTER_TYPE.W) + _ = _def_ARMv8_reg(idx, _ARMv8_REGISTER_TYPE.X, [w_part], widest=True) + +# Link Register, which is a pure alias of X30 +LR = _def_ARMv8_reg("LR", _ARMv8_REGISTER_TYPE.LR) +assert isinstance(LR._aliases, set) +LR._aliases.add(self["X30"]) +assert isinstance(self["X30"]._aliases, set) +self["X30"]._aliases.add(LR) + + +def recurse_vect_fp_reg_add( + n: int, size_idx: int, vect_idx: int +) -> Tuple[List[ARMv8_Register], List[ARMv8_Register], List[ARMv8_Register]]: + # Returns (fp_regs, vect_regs, combined_vect) + if size_idx < 0: + return [], [], [] + + ret_fp_1, ret_vect_1, ret_combined = recurse_vect_fp_reg_add( + n, + size_idx - 1, + 2 * vect_idx, + ) + as_width_fp = ret_fp_1[-1]._as_width if len(ret_fp_1) > 0 else {} + as_width_vect: Dict[int, ARMv8_Register] = ( + ret_vect_1[-1]._as_width if len(ret_vect_1) > 0 else {} + ) + # no `ret_combined` on odd `vect_idx` + ret_fp_2, ret_vect_2, _ = recurse_vect_fp_reg_add( + n, + size_idx - 1, + 2 * vect_idx + 1, + ) + + ret_fp: List[ARMv8_Register] = ret_fp_1 + ret_fp_2 + ret_vect: List[ARMv8_Register] = ret_vect_1 + ret_vect_2 + # Used for aliases + same_size_fp: List[ARMv8_Register] = [] + same_size_vect: List[ARMv8_Register] = [] + same_size_combined: List[ARMv8_Register] = [] + + fp_type = _ARMv8_BASE_SIZES[size_idx] + # Adding Vn.Y[x] with Y = `vect_type` and x = `vect_idx` + # On the first iteration, there is no Rx.Q[0] (Vx.Q[0]) + vect_type_l: List[_ARMv8_REGISTER_TYPE] = [] + if size_idx != 4: + vect_type_l.append(_ARMv8_WIDTH_TO_V_ELMT_TYPE[fp_type]) + if size_idx == 2 and vect_idx == 0: + # ^^^ Sweet ARMv8.2 dot product: Vn.4B + vect_type_l.append(_ARMv8_REGISTER_TYPE.V4B) + for vect_type in vect_type_l: + new_vect = _def_ARMv8_reg( + (n, vect_idx) if vect_type != _ARMv8_REGISTER_TYPE.V4B else n, + vect_type, + ret_vect, + ) + # As the dict is a pointer, it is shared by the others subregisters, no need to + # propagate + as_width_vect[vect_type.value] = new_vect + new_vect._as_width = as_width_vect + for reg in as_width_vect.values(): + reg._widest = new_vect + same_size_vect.append(new_vect) + + # Adding Vn.xY: combined vector composed of vect of size Y = `fp_type2` (x + # integer automatically computed) + if vect_idx == 0: + if size_idx in [3, 4]: + for idx, subvect_type_idx in enumerate(range(size_idx)): + vect_type = _ARMv8_WIDTH_TO_VECT[ + fp_type, _ARMv8_BASE_SIZES[subvect_type_idx] + ] + new_combined = _def_ARMv8_reg(n, vect_type, ret_combined) + # Adding `as_width`, only for 128-bit combined vector regs (it includes + # only 64-bit combined vector regs) + # Still the same hack of shared dicts + if size_idx == 4 and idx < size_idx: + ret_combined[idx]._widest = new_combined + as_width = ret_combined[idx]._as_width + as_width[new_combined.width] = new_combined + new_combined._as_width = as_width + ret_combined.append(new_combined) + same_size_combined.append(new_combined) + + # Adding Yn: FP register of size Y = `fp_type` + new_fp = _def_ARMv8_reg(n, fp_type, ret_fp) + as_width_fp[fp_type.value] = new_fp + new_fp._as_width = as_width_fp + for reg in as_width_fp.values(): + reg._widest = new_fp + same_size_fp.append(new_fp) + # Set aliases: current reg aliases with former reg... + for reg in same_size_fp + same_size_vect + same_size_combined: + # Adding all aliases to `reg` except itself + assert isinstance(reg._aliases, set) + reg._aliases |= set( + same_size_fp + + same_size_vect + + same_size_combined + + ret_fp + + ret_vect + + ret_combined + ) - set([reg]) + # ... and reverse + for reg in ret_fp + ret_vect + ret_combined: + assert isinstance(reg._aliases, set) + reg._aliases |= set(same_size_fp + same_size_vect + same_size_combined) + + return ( + ret_fp + same_size_fp, + ret_vect + same_size_vect, + ret_combined + same_size_combined, + ) + + +# stack pointer +WSP = _def_ARMv8_reg("WSP", _ARMv8_REGISTER_TYPE.WSP) +SP = _def_ARMv8_reg("SP", _ARMv8_REGISTER_TYPE.SP, [WSP], widest=True) + +# zero registers +WZR = _def_ARMv8_reg("WZR", _ARMv8_REGISTER_TYPE.WZR) +ZXR = _def_ARMv8_reg( + "ZXR", _ARMv8_REGISTER_TYPE.ZXR, [WZR], widest=True # technically... +) + +# 32 FP / vector regs and their derivatives +for idx in range(32): + # 4 is `len(_ARMv8_FP_SIZES) - 1` + fp_regs, vect_regs, combined_vect_regs = recurse_vect_fp_reg_add(idx, 4, 0) + # Strange 4B vector element, ARMv8.2 dotproduct only. + +# Fill in super registers +_ALL_REGISTERS_.sort(key=lambda reg: reg.width) +for reg in _ALL_REGISTERS_: + for sub in reg._subs: + assert isinstance(sub._supers, set) + sub._supers.add(reg) + +# Set `as_width` and (which will set correctly `_alias`) +for reg in _ALL_REGISTERS_: + reg.freeze() + + +_def_ARMv8_register_class("ANY_REGISTER", *_ALL_REGISTERS_) + +# See "Procedure Call Standard for the Arm 64-bit Architecture" +_def_ARMv8_register_class( + "ARGUMENT_REGISTER", + *[self[f"X{i}"] for i in range(0, 7)], + **{f"ARG_{i}": self[f"X{i}"] for i in range(0, 7)}, +) + + +_def_ARMv8_register_class( + "CALLER_SAVED", + *[self[f"X{i}"] for i in range(0, 19)], + *[self[f"X{i}"] for i in range(28, 31)], +) + +_def_ARMv8_register_class("CALLEE_SAVED", *[self[f"X{i}"] for i in range(19, 28)]) + +for type_ in _ARMv8_REGISTER_TYPE: + _def_ARMv8_register_class( + type_.name, *filter(lambda reg: reg._type == type_, _ALL_REGISTERS_) + ) + +# End of hack +del self + + +class ARMv8_Register_Set(Register_Set): + @abc.override + def stack_pointer_register(self) -> Register: + return SP + + @abc.override + def register_classes(self) -> List[Register_Class]: + return _ALL_REGISTER_CLASSES_ + + @abc.override + def all_registers(self) -> Register_Class: + return ANY_REGISTER # type: ignore + + @abc.override + def argument_registers(self) -> Register_Class: + return ARGUMENT_REGISTER # type: ignore + + @abc.override + def callee_save_registers(self) -> Register_Class: + return CALLEE_SAVED # type:ignore diff --git a/src/pipedream/asm/x86/registers.py b/src/pipedream/asm/x86/registers.py index a54f596..e3ad6d6 100644 --- a/src/pipedream/asm/x86/registers.py +++ b/src/pipedream/asm/x86/registers.py @@ -1,106 +1,105 @@ +from typing import List -from pipedream.utils import abc -from pipedream.asm.ir import * +from pipedream.utils import abc +from pipedream.asm.ir import * __all__ = [ - 'X86_Register_Set', - - 'X86_Register', - - 'ANY_REGISTER', - 'ARGUMENT_REGISTER', - 'CALLER_SAVED', - 'CALLEE_SAVED', - 'GPR', - 'GPR64', - 'GPR32', - 'GPR16', - 'GPR8', - 'GPR8NOREX', - 'VR', - 'VR64', - 'VR128', - 'VRX128', - 'VR256', - 'VRX256', - 'VRX512', - 'CL_REGISTER', - 'FLAGS_REGISTER', - 'PC_REGISTER', - 'STACK_POINTER_REGISTER', - 'BASE_REGISTER_64', - 'BASE_REGISTER_32', - 'BASE_REGISTER_16', - 'INDEX_REGISTER_64', - 'INDEX_REGISTER_32', - 'INDEX_REGISTER_16', + "X86_Register_Set", + "X86_Register", + "ANY_REGISTER", + "ARGUMENT_REGISTER", + "CALLER_SAVED", + "CALLEE_SAVED", + "GPR", + "GPR64", + "GPR32", + "GPR16", + "GPR8", + "GPR8NOREX", + "VR", + "VR64", + "VR128", + "VRX128", + "VR256", + "VRX256", + "VRX512", + "CL_REGISTER", + "FLAGS_REGISTER", + "PC_REGISTER", + "STACK_POINTER_REGISTER", + "BASE_REGISTER_64", + "BASE_REGISTER_32", + "BASE_REGISTER_16", + "INDEX_REGISTER_64", + "INDEX_REGISTER_32", + "INDEX_REGISTER_16", ] class X86_Register(Register): - def __init__(self, name: str, att_asm_name: str, idx: int, width: int, subs: set): - self._idx = idx - self._name = name - self._att_asm_name = att_asm_name - self._width = width + def __init__(self, name: str, att_asm_name: str, idx: int, width: int, subs: set): + self._idx = idx + self._name = name + self._att_asm_name = att_asm_name + self._width = width - self._subs = set(subs) - self._supers = set() + self._subs = set(subs) + self._supers = set() - @property - @abc.override - def name(self) -> str: - return self._name + @property + @abc.override + def name(self) -> str: + return self._name - @property - @abc.override - def att_asm_name(self) -> str: - return self._att_asm_name + @property + @abc.override + def att_asm_name(self) -> str: + return self._att_asm_name - @property - @abc.override - def width(self): - """Bit width of register.""" - return self._width + @property + @abc.override + def width(self): + """Bit width of register.""" + return self._width - @abc.override - def as_width(self, bits: int) -> 'X86_Register': - return self._as_width[bits] + @abc.override + def as_width(self, bits: int) -> "X86_Register": + return self._as_width[bits] - @property - @abc.override - def widest(self) -> 'X86_Register': - return self._widest + @property + @abc.override + def widest(self) -> "X86_Register": + return self._widest - @property - @abc.override - def sub_registers(self): - return iter(self._subs) + @property + @abc.override + def sub_registers(self): + return iter(self._subs) - @property - @abc.override - def super_registers(self): - return iter(self._supers) + @property + @abc.override + def super_registers(self): + return iter(self._supers) - @property - @abc.override - def aliases(self): - return iter(self._aliases) + @property + @abc.override + def aliases(self): + return iter(self._aliases) - def __lt__(self, other): - if type(self) is not type(other): - return NotImplemented + def __lt__(self, other): + if type(self) is not type(other): + return NotImplemented - return self._idx < other._idx + return self._idx < other._idx - def __str__(self): - return type(self).__name__ + '(' + self.name + ')' + def __str__(self): + return type(self).__name__ + "(" + self.name + ")" - def __repr__(self): - return type(self).__name__ + '(' + self.name + ')' + def __repr__(self): + return type(self).__name__ + "(" + self.name + ")" - def __deepcopy__(self, memo): - return self + def __deepcopy__(self, memo): + return self _ALL_REGISTERS_ = [] @@ -110,17 +109,17 @@ self = vars() def _def_x86_reg(name: str, width: int, *immediate_sub_regs, att_asm_name: str = None): - assert name not in self, 'duplicate register def' + assert name not in self, "duplicate register def" - if not att_asm_name: - att_asm_name = '%' + name.lower() + if not att_asm_name: + att_asm_name = "%" + name.lower() - idx = len(_ALL_REGISTERS_) - reg = X86_Register(name, att_asm_name, idx, width, frozenset(immediate_sub_regs)) + idx = len(_ALL_REGISTERS_) + reg = X86_Register(name, att_asm_name, idx, width, frozenset(immediate_sub_regs)) - self[name] = reg - _ALL_REGISTERS_.append(reg) - __all__.append(reg.name) + self[name] = reg + _ALL_REGISTERS_.append(reg) + __all__.append(reg.name) ## in order of encoding @@ -129,181 +128,181 @@ def _def_x86_reg(name: str, width: int, *immediate_sub_regs, att_asm_name: str = # TODO: AVX write mask registers k0-k7 # Return value, caller-saved -_def_x86_reg('RAX', 64, 'EAX') -_def_x86_reg('EAX', 32, 'AX') -_def_x86_reg('AX', 16, 'AL', 'AH') -_def_x86_reg('AL', 8) -_def_x86_reg('AH', 8) +_def_x86_reg("RAX", 64, "EAX") +_def_x86_reg("EAX", 32, "AX") +_def_x86_reg("AX", 16, "AL", "AH") +_def_x86_reg("AL", 8) +_def_x86_reg("AH", 8) # 4th argument, caller-saved -_def_x86_reg('RCX', 64, 'ECX') -_def_x86_reg('ECX', 32, 'CX') -_def_x86_reg('CX', 16, 'CL', 'AH') -_def_x86_reg('CL', 8) -_def_x86_reg('CH', 8) +_def_x86_reg("RCX", 64, "ECX") +_def_x86_reg("ECX", 32, "CX") +_def_x86_reg("CX", 16, "CL", "CH") +_def_x86_reg("CL", 8) +_def_x86_reg("CH", 8) # 3rd argument, caller-saved -_def_x86_reg('RDX', 64, 'EDX') -_def_x86_reg('EDX', 32, 'DX') -_def_x86_reg('DX', 16, 'DL', 'DH') -_def_x86_reg('DL', 8) -_def_x86_reg('DH', 8) +_def_x86_reg("RDX", 64, "EDX") +_def_x86_reg("EDX", 32, "DX") +_def_x86_reg("DX", 16, "DL", "DH") +_def_x86_reg("DL", 8) +_def_x86_reg("DH", 8) # Local variable, callee-saved -_def_x86_reg('RBX', 64, 'EBX') -_def_x86_reg('EBX', 32, 'BX') -_def_x86_reg('BX', 16, 'BL', 'BH') -_def_x86_reg('BL', 8) -_def_x86_reg('BH', 8) +_def_x86_reg("RBX", 64, "EBX") +_def_x86_reg("EBX", 32, "BX") +_def_x86_reg("BX", 16, "BL", "BH") +_def_x86_reg("BL", 8) +_def_x86_reg("BH", 8) # Stack pointer, callee-saved -_def_x86_reg('RSP', 64, 'ESP') -_def_x86_reg('ESP', 32, 'SP') -_def_x86_reg('SP', 16, 'SPL') -_def_x86_reg('SPL', 8) +_def_x86_reg("RSP", 64, "ESP") +_def_x86_reg("ESP", 32, "SP") +_def_x86_reg("SP", 16, "SPL") +_def_x86_reg("SPL", 8) # Local variable, callee-saved -_def_x86_reg('RBP', 64, 'EBP') -_def_x86_reg('EBP', 32, 'BP') -_def_x86_reg('BP', 16, 'BPL') -_def_x86_reg('BPL', 8) +_def_x86_reg("RBP", 64, "EBP") +_def_x86_reg("EBP", 32, "BP") +_def_x86_reg("BP", 16, "BPL") +_def_x86_reg("BPL", 8) # 2nd argument, caller-saved -_def_x86_reg('RSI', 64, 'ESI') -_def_x86_reg('ESI', 32, 'SI') -_def_x86_reg('SI', 16, 'SIL') -_def_x86_reg('SIL', 8) +_def_x86_reg("RSI", 64, "ESI") +_def_x86_reg("ESI", 32, "SI") +_def_x86_reg("SI", 16, "SIL") +_def_x86_reg("SIL", 8) # 1st argument, caller-saved -_def_x86_reg('RDI', 64, 'EDI') -_def_x86_reg('EDI', 32, 'DI') -_def_x86_reg('DI', 16, 'DIL') -_def_x86_reg('DIL', 8) +_def_x86_reg("RDI", 64, "EDI") +_def_x86_reg("EDI", 32, "DI") +_def_x86_reg("DI", 16, "DIL") +_def_x86_reg("DIL", 8) # 5th argument, caller-saved -_def_x86_reg('R8', 64, 'R8D') -_def_x86_reg('R8D', 32, 'R8W') -_def_x86_reg('R8W', 16, 'R8B') -_def_x86_reg('R8B', 8) +_def_x86_reg("R8", 64, "R8D") +_def_x86_reg("R8D", 32, "R8W") +_def_x86_reg("R8W", 16, "R8B") +_def_x86_reg("R8B", 8) # 6th argument, caller-saved -_def_x86_reg('R9', 64, 'R9D') -_def_x86_reg('R9D', 32, 'R9W') -_def_x86_reg('R9W', 16, 'R9B') -_def_x86_reg('R9B', 8) +_def_x86_reg("R9", 64, "R9D") +_def_x86_reg("R9D", 32, "R9W") +_def_x86_reg("R9W", 16, "R9B") +_def_x86_reg("R9B", 8) # Scratch/temporary, caller-saved -_def_x86_reg('R10', 64, 'R10D') -_def_x86_reg('R10D', 32, 'R10W') -_def_x86_reg('R10W', 16, 'R10B') -_def_x86_reg('R10B', 8) +_def_x86_reg("R10", 64, "R10D") +_def_x86_reg("R10D", 32, "R10W") +_def_x86_reg("R10W", 16, "R10B") +_def_x86_reg("R10B", 8) # Scratch/temporary, caller-saved -_def_x86_reg('R11', 64, 'R11D') -_def_x86_reg('R11D', 32, 'R11W') -_def_x86_reg('R11W', 16, 'R11B') -_def_x86_reg('R11B', 8) +_def_x86_reg("R11", 64, "R11D") +_def_x86_reg("R11D", 32, "R11W") +_def_x86_reg("R11W", 16, "R11B") +_def_x86_reg("R11B", 8) # Local variable, callee-saved -_def_x86_reg('R12', 64, 'R12D') -_def_x86_reg('R12D', 32, 'R12W') -_def_x86_reg('R12W', 16, 'R12B') -_def_x86_reg('R12B', 8) +_def_x86_reg("R12", 64, "R12D") +_def_x86_reg("R12D", 32, "R12W") +_def_x86_reg("R12W", 16, "R12B") +_def_x86_reg("R12B", 8) # Local variable, callee-saved -_def_x86_reg('R13', 64, 'R13D') -_def_x86_reg('R13D', 32, 'R13W') -_def_x86_reg('R13W', 16, 'R13B') -_def_x86_reg('R13B', 8) +_def_x86_reg("R13", 64, "R13D") +_def_x86_reg("R13D", 32, "R13W") +_def_x86_reg("R13W", 16, "R13B") +_def_x86_reg("R13B", 8) # Local variable, callee-saved -_def_x86_reg('R14', 64, 'R14D') -_def_x86_reg('R14D', 32, 'R14W') -_def_x86_reg('R14W', 16, 'R14B') -_def_x86_reg('R14B', 8) +_def_x86_reg("R14", 64, "R14D") +_def_x86_reg("R14D", 32, "R14W") +_def_x86_reg("R14W", 16, "R14B") +_def_x86_reg("R14B", 8) # Local variable, callee-saved -_def_x86_reg('R15', 64, 'R15D') -_def_x86_reg('R15D', 32, 'R15W') -_def_x86_reg('R15W', 16, 'R15B') -_def_x86_reg('R15B', 8) +_def_x86_reg("R15", 64, "R15D") +_def_x86_reg("R15D", 32, "R15W") +_def_x86_reg("R15W", 16, "R15B") +_def_x86_reg("R15B", 8) # Instruction pointer -_def_x86_reg('RIP', 64, 'EIP') -_def_x86_reg('EIP', 32, 'IP') -_def_x86_reg('IP', 16) +_def_x86_reg("RIP", 64, "EIP") +_def_x86_reg("EIP", 32, "IP") +_def_x86_reg("IP", 16) # Status/condition code bits -_def_x86_reg('RFLAGS', 64, 'EFLAGS') -_def_x86_reg('EFLAGS', 32, 'FLAGS') -_def_x86_reg('FLAGS', 16) +_def_x86_reg("RFLAGS", 64, "EFLAGS") +_def_x86_reg("EFLAGS", 32, "FLAGS") +_def_x86_reg("FLAGS", 16) # weird system registers # TODO: fix bit widths -_def_x86_reg('CS', 16) -_def_x86_reg('DS', 16) -_def_x86_reg('ES', 16) -_def_x86_reg('FS', 16) -_def_x86_reg('GS', 16) -_def_x86_reg('SS', 16) -_def_x86_reg('FSBASE', 16) -_def_x86_reg('GSBASE', 16) -_def_x86_reg('TR', 16) -_def_x86_reg('LDTR', 16) -_def_x86_reg('GDTR', 16) -_def_x86_reg('SSP', 16) -_def_x86_reg('TSC', 32) -_def_x86_reg('TSCAUX', 16) -_def_x86_reg('X87STATUS', 16) -_def_x86_reg('X87TAG', 16) -_def_x86_reg('X87CONTROL', 16) -_def_x86_reg('MXCSR', 32) -_def_x86_reg('MSRS', 64) +_def_x86_reg("CS", 16) +_def_x86_reg("DS", 16) +_def_x86_reg("ES", 16) +_def_x86_reg("FS", 16) +_def_x86_reg("GS", 16) +_def_x86_reg("SS", 16) +_def_x86_reg("FSBASE", 16) +_def_x86_reg("GSBASE", 16) +_def_x86_reg("TR", 16) +_def_x86_reg("LDTR", 16) +_def_x86_reg("GDTR", 16) +_def_x86_reg("SSP", 16) +_def_x86_reg("TSC", 32) +_def_x86_reg("TSCAUX", 16) +_def_x86_reg("X87STATUS", 16) +_def_x86_reg("X87TAG", 16) +_def_x86_reg("X87CONTROL", 16) +_def_x86_reg("MXCSR", 32) +_def_x86_reg("MSRS", 64) # CPU control registers # TODO: fix bit widths for i in range(16): - _def_x86_reg(f'CR{i}', 16) -_def_x86_reg('XCR0', 32) + _def_x86_reg(f"CR{i}", 16) +_def_x86_reg("XCR0", 32) for i in range(8): - ## FIXME: this is probably wrong - _def_x86_reg(f'MM{i}', 64) - _def_x86_reg(f'ST{i}', 80, f'MM{i}', att_asm_name=f'%st({i})') + ## FIXME: this is probably wrong + _def_x86_reg(f"MM{i}", 64) + _def_x86_reg(f"ST{i}", 80, f"MM{i}", att_asm_name=f"%st({i})") for i in range(32): - _def_x86_reg(f'XMM{i}', 128) + _def_x86_reg(f"XMM{i}", 128) for i in range(32): - _def_x86_reg(f'YMM{i}', 256, f'XMM{i}') + _def_x86_reg(f"YMM{i}", 256, f"XMM{i}") for i in range(32): - _def_x86_reg(f'ZMM{i}', 512, f'YMM{i}') + _def_x86_reg(f"ZMM{i}", 512, f"YMM{i}") ## fill in sub/super registers for reg in _ALL_REGISTERS_: - subs = set() + subs = set() - for sub in reg._subs: - if type(sub) is str: - sub = self[sub] + for sub in reg._subs: + if type(sub) is str: + sub = self[sub] - assert isinstance(sub, X86_Register) - subs.add(sub) - sub._supers.add(reg) + assert isinstance(sub, X86_Register) + subs.add(sub) + sub._supers.add(reg) - reg._subs = subs + reg._subs = subs for reg in _ALL_REGISTERS_: - reg._supers = tuple(sorted(reg._supers)) - reg._subs = tuple(sorted(reg._subs)) - reg._aliases = tuple(sorted(set(reg.all_sub_registers) | set(reg.all_super_registers))) + reg._supers = tuple(sorted(reg._supers)) + reg._subs = tuple(sorted(reg._subs)) + reg._aliases = tuple( + sorted(set(reg.all_sub_registers) | set(reg.all_super_registers)) + ) - as_width = { - reg.width: reg - } - widest = reg + as_width = {reg.width: reg} + widest = reg - for alias in reg._aliases: - if alias in [AH, BH, CH, DH]: - continue + for alias in reg._aliases: + if alias in [AH, BH, CH, DH]: + continue - alias.width not in as_width, (reg, alias, as_width) + assert alias.width not in as_width, (reg, alias, as_width) - as_width[alias.width] = alias + as_width[alias.width] = alias - if alias.width > widest.width: - widest = alias + if alias.width > widest.width: + widest = alias - reg._as_width = as_width - reg._widest = widest + reg._as_width = as_width + reg._widest = widest del self @@ -315,424 +314,434 @@ _ALL_REGISTER_CLASSES_ = [] def _def_x86_register_class(name, *registers, **aliases): - ## put register aliases in module scope - for k, v in aliases.items(): - assert k not in globals() or globals()[k] is v + ## put register aliases in module scope + for k, v in aliases.items(): + assert k not in globals() or globals()[k] is v - globals()[k] = v - __all__.append(k) + globals()[k] = v + __all__.append(k) - clss = Register_Class(name, registers, aliases) - _ALL_REGISTER_CLASSES_.append(clss) - globals()[name] = clss - __all__.append(name) + clss = Register_Class(name, registers, aliases) + _ALL_REGISTER_CLASSES_.append(clss) + globals()[name] = clss + __all__.append(name) -_def_x86_register_class( - 'ANY_REGISTER', - *_ALL_REGISTERS_ -) +_def_x86_register_class("ANY_REGISTER", *_ALL_REGISTERS_) _def_x86_register_class( - 'ARGUMENT_REGISTER', - RDI, RSI, RDX, RCX, R8, R9, - ARG_1=RDI, - ARG_2=RSI, - ARG_3=RDX, - ARG_4=RCX, - ARG_5=R8, - ARG_6=R9, + "ARGUMENT_REGISTER", + RDI, + RSI, + RDX, + RCX, + R8, + R9, + ARG_1=RDI, + ARG_2=RSI, + ARG_3=RDX, + ARG_4=RCX, + ARG_5=R8, + ARG_6=R9, ) -_def_x86_register_class( - 'CALLER_SAVED', - RAX, RDI, RSI, RDX, RCX, R8, R9, R10, R11 -) +_def_x86_register_class("CALLER_SAVED", RAX, RDI, RSI, RDX, RCX, R8, R9, R10, R11) _def_x86_register_class( - 'CALLEE_SAVED', - RBX, RBP, R12, R13, R14, R15, + "CALLEE_SAVED", + RBX, + RBP, + R12, + R13, + R14, + R15, ) _def_x86_register_class( - 'GPR', - RAX, - EAX, - AX, - AL, - # 1st argument, caller-saved - RDI, - EDI, - DI, - DIL, - # 2nd argument, caller-saved - RSI, - ESI, - SI, - SIL, - # 3rd argument, caller-saved - RDX, - EDX, - DX, - DL, - # 4th argument, caller-saved - RCX, - ECX, - CX, - CL, - # 5th argument, caller-saved - R8, - R8D, - R8W, - R8B, - # 6th argument, caller-saved - R9, - R9D, - R9W, - R9B, - # Scratch/temporary, caller-saved - R10, - R10D, - R10W, - R10B, - # Scratch/temporary, caller-saved - R11, - R11D, - R11W, - R11B, - # Stack pointer, callee-saved - RSP, - ESP, - SP, - SPL, - # Local variable, callee-saved - RBX, - EBX, - BX, - BL, - # Local variable, callee-saved - RBP, - EBP, - BP, - BPL, - # Local variable, callee-saved - R12, - R12D, - R12W, - R12B, - # Local variable, callee-saved - R13, - R13D, - R13W, - R13B, - # Local variable, callee-saved - R14, - R14D, - R14W, - R14B, - # Local variable, callee-saved - R15, - R15D, - R15W, - R15B, + "GPR", + RAX, + EAX, + AX, + AL, + # 1st argument, caller-saved + RDI, + EDI, + DI, + DIL, + # 2nd argument, caller-saved + RSI, + ESI, + SI, + SIL, + # 3rd argument, caller-saved + RDX, + EDX, + DX, + DL, + # 4th argument, caller-saved + RCX, + ECX, + CX, + CL, + # 5th argument, caller-saved + R8, + R8D, + R8W, + R8B, + # 6th argument, caller-saved + R9, + R9D, + R9W, + R9B, + # Scratch/temporary, caller-saved + R10, + R10D, + R10W, + R10B, + # Scratch/temporary, caller-saved + R11, + R11D, + R11W, + R11B, + # Stack pointer, callee-saved + RSP, + ESP, + SP, + SPL, + # Local variable, callee-saved + RBX, + EBX, + BX, + BL, + # Local variable, callee-saved + RBP, + EBP, + BP, + BPL, + # Local variable, callee-saved + R12, + R12D, + R12W, + R12B, + # Local variable, callee-saved + R13, + R13D, + R13W, + R13B, + # Local variable, callee-saved + R14, + R14D, + R14W, + R14B, + # Local variable, callee-saved + R15, + R15D, + R15W, + R15B, ) -_def_x86_register_class( - 'GPR64', - *[r for r in GPR if r.width == 64] -) +_def_x86_register_class("GPR64", *[r for r in GPR if r.width == 64]) -_def_x86_register_class( - 'GPR32', - *[r for r in GPR if r.width == 32] -) +_def_x86_register_class("GPR32", *[r for r in GPR if r.width == 32]) -_def_x86_register_class( - 'GPR16', - *[r for r in GPR if r.width == 16] -) +_def_x86_register_class("GPR16", *[r for r in GPR if r.width == 16]) _def_x86_register_class( - 'GPR8', - *[r for r in GPR if r.width == 8 if r not in (AH, BH, CH, DH)] + "GPR8", *[r for r in GPR if r.width == 8 if r not in (AH, BH, CH, DH)] ) _def_x86_register_class( - 'GPR8NOREX', - # AL, BL, CL, DL, - AH, BH, CH, DH, + "GPR8NOREX", + # AL, BL, CL, DL, + AH, + BH, + CH, + DH, ) -_def_x86_register_class( - 'FPST', - ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7 -) +_def_x86_register_class("FPST", ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7) _def_x86_register_class( - 'VR', - MM0, - MM1, - MM2, - MM3, - MM4, - MM5, - MM6, - MM7, - - XMM0, - XMM1, - XMM2, - XMM3, - XMM4, - XMM5, - XMM6, - XMM7, - XMM8, - XMM9, - XMM10, - XMM11, - XMM12, - XMM13, - XMM14, - XMM15, - XMM16, - XMM17, - XMM18, - XMM19, - XMM20, - XMM21, - XMM22, - XMM23, - XMM24, - XMM25, - XMM26, - XMM27, - XMM28, - XMM29, - XMM30, - XMM31, - - YMM0, - YMM1, - YMM2, - YMM3, - YMM4, - YMM5, - YMM6, - YMM7, - YMM8, - YMM9, - YMM10, - YMM11, - YMM12, - YMM13, - YMM14, - YMM15, - YMM16, - YMM17, - YMM18, - YMM19, - YMM20, - YMM21, - YMM22, - YMM23, - YMM24, - YMM25, - YMM26, - YMM27, - YMM28, - YMM29, - YMM30, - YMM31, - - ZMM0, - ZMM1, - ZMM2, - ZMM3, - ZMM4, - ZMM5, - ZMM6, - ZMM7, - ZMM8, - ZMM9, - ZMM10, - ZMM11, - ZMM12, - ZMM13, - ZMM14, - ZMM15, - ZMM16, - ZMM17, - ZMM18, - ZMM19, - ZMM20, - ZMM21, - ZMM22, - ZMM23, - ZMM24, - ZMM25, - ZMM26, - ZMM27, - ZMM28, - ZMM29, - ZMM30, - ZMM31, + "VR", + MM0, + MM1, + MM2, + MM3, + MM4, + MM5, + MM6, + MM7, + XMM0, + XMM1, + XMM2, + XMM3, + XMM4, + XMM5, + XMM6, + XMM7, + XMM8, + XMM9, + XMM10, + XMM11, + XMM12, + XMM13, + XMM14, + XMM15, + XMM16, + XMM17, + XMM18, + XMM19, + XMM20, + XMM21, + XMM22, + XMM23, + XMM24, + XMM25, + XMM26, + XMM27, + XMM28, + XMM29, + XMM30, + XMM31, + YMM0, + YMM1, + YMM2, + YMM3, + YMM4, + YMM5, + YMM6, + YMM7, + YMM8, + YMM9, + YMM10, + YMM11, + YMM12, + YMM13, + YMM14, + YMM15, + YMM16, + YMM17, + YMM18, + YMM19, + YMM20, + YMM21, + YMM22, + YMM23, + YMM24, + YMM25, + YMM26, + YMM27, + YMM28, + YMM29, + YMM30, + YMM31, + ZMM0, + ZMM1, + ZMM2, + ZMM3, + ZMM4, + ZMM5, + ZMM6, + ZMM7, + ZMM8, + ZMM9, + ZMM10, + ZMM11, + ZMM12, + ZMM13, + ZMM14, + ZMM15, + ZMM16, + ZMM17, + ZMM18, + ZMM19, + ZMM20, + ZMM21, + ZMM22, + ZMM23, + ZMM24, + ZMM25, + ZMM26, + ZMM27, + ZMM28, + ZMM29, + ZMM30, + ZMM31, ) -_def_x86_register_class( - 'VR64', - *[r for r in VR if r.width == 64] -) +_def_x86_register_class("VR64", *[r for r in VR if r.width == 64]) -_def_x86_register_class( - 'VRX128', - *[r for r in VR if r.width == 128] -) +_def_x86_register_class("VRX128", *[r for r in VR if r.width == 128]) -_def_x86_register_class( - 'VRX256', - *[r for r in VR if r.width == 256] -) +_def_x86_register_class("VRX256", *[r for r in VR if r.width == 256]) -_def_x86_register_class( - 'VRX512', - *[r for r in VR if r.width == 512] -) +_def_x86_register_class("VRX512", *[r for r in VR if r.width == 512]) -_def_x86_register_class( - 'VR128', - *VRX128[0:16] -) -_def_x86_register_class( - 'VR256', - *VRX256[0:16] -) +_def_x86_register_class("VR128", *VRX128[0:16]) +_def_x86_register_class("VR256", *VRX256[0:16]) -_def_x86_register_class( - 'CL_REGISTER', - CL -) +_def_x86_register_class("CL_REGISTER", CL) _def_x86_register_class( - 'STACK_POINTER_REGISTER', - SP, ESP, RSP, + "STACK_POINTER_REGISTER", + SP, + ESP, + RSP, ) _def_x86_register_class( - 'FLAGS_REGISTER', - FLAGS, EFLAGS, RFLAGS, + "FLAGS_REGISTER", + FLAGS, + EFLAGS, + RFLAGS, ) _def_x86_register_class( - 'PC_REGISTER', - IP, EIP, RIP, + "PC_REGISTER", + IP, + EIP, + RIP, ) ## singleton register classes for reg in [ - FLAGS, EFLAGS, RFLAGS, - IP, EIP, RIP, - SP, ESP, RSP, - - CS, DS, ES, FS, GS, SS, - FSBASE, GSBASE, - TR, LDTR, GDTR, - - SSP, - - TSCAUX, - - CR0, CR1, CR2, CR3, CR4, CR5, CR6, CR7, CR8, CR9, CR10, CR11, CR12, CR13, CR14, CR15, - XCR0, - - AL, - CL, - - AH, - - AX, - BX, - CX, - DX, - DI, - SI, - BP, - SP, - - EAX, - EBX, - ECX, - EDX, - EDI, - ESI, - EBP, - ESP, - - RAX, - RBX, - RCX, - RDX, - RDI, - RSI, - RBP, - RSP, - R8, - R9, - R10, - R11, - R12, - R13, - R14, - R15, - - XMM0, - - ST0, - ST1, - ST2, - ST3, - ST4, - ST5, - ST6, - ST7, - - X87STATUS, - X87CONTROL, - X87TAG, - MXCSR, - TSC, - MSRS, + FLAGS, + EFLAGS, + RFLAGS, + IP, + EIP, + RIP, + SP, + ESP, + RSP, + CS, + DS, + ES, + FS, + GS, + SS, + FSBASE, + GSBASE, + TR, + LDTR, + GDTR, + SSP, + TSCAUX, + CR0, + CR1, + CR2, + CR3, + CR4, + CR5, + CR6, + CR7, + CR8, + CR9, + CR10, + CR11, + CR12, + CR13, + CR14, + CR15, + XCR0, + AL, + CL, + AH, + AX, + BX, + CX, + DX, + DI, + SI, + BP, + SP, + EAX, + EBX, + ECX, + EDX, + EDI, + ESI, + EBP, + ESP, + RAX, + RBX, + RCX, + RDX, + RDI, + RSI, + RBP, + RSP, + R8, + R9, + R10, + R11, + R12, + R13, + R14, + R15, + XMM0, + ST0, + ST1, + ST2, + ST3, + ST4, + ST5, + ST6, + ST7, + X87STATUS, + X87CONTROL, + X87TAG, + MXCSR, + TSC, + MSRS, ]: - _def_x86_register_class('RC_' + reg.name, reg) + _def_x86_register_class("RC_" + reg.name, reg) ### memory addressing ## https://en.wikipedia.org/wiki/X86#Addressing_modes _def_x86_register_class( - 'BASE_REGISTER_16', - BX, BP, + "BASE_REGISTER_16", + BX, + BP, ) _def_x86_register_class( - 'BASE_REGISTER_32', - EAX, EBX, ECX, EDX, ESP, EBP, ESI, EDI, + "BASE_REGISTER_32", + EAX, + EBX, + ECX, + EDX, + ESP, + EBP, + ESI, + EDI, ) _def_x86_register_class( - 'BASE_REGISTER_64', - *GPR64, RSP, RIP, + "BASE_REGISTER_64", + *GPR64, + RSP, + RIP, ) _def_x86_register_class( - 'INDEX_REGISTER_16', - SI, DI, + "INDEX_REGISTER_16", + SI, + DI, ) _def_x86_register_class( - 'INDEX_REGISTER_32', - EAX, EBX, ECX, EDX, EBP, ESI, EDI, + "INDEX_REGISTER_32", + EAX, + EBX, + ECX, + EDX, + EBP, + ESI, + EDI, ) _def_x86_register_class( - 'INDEX_REGISTER_64', - *GPR64, + "INDEX_REGISTER_64", + *GPR64, ) @@ -762,22 +771,22 @@ _def_x86_register_class( class X86_Register_Set(Register_Set): - @abc.override - def stack_pointer_register(self) -> 'Register': - return RSP + @abc.override + def stack_pointer_register(self) -> "Register": + return RSP - @abc.override - def register_classes(self) -> ['Register_Class']: - return _ALL_REGISTER_CLASSES_ + @abc.override + def register_classes(self) -> List["Register_Class"]: + return _ALL_REGISTER_CLASSES_ - @abc.override - def all_registers(self) -> 'Register_Class': - return _ALL_REGISTERS_ + @abc.override + def all_registers(self) -> "Register_Class": + return ANY_REGISTER - @abc.override - def argument_registers(self) -> 'Register_Class': - return ARGUMENT_REGISTER + @abc.override + def argument_registers(self) -> "Register_Class": + return ARGUMENT_REGISTER - @abc.override - def callee_save_registers(self) -> 'Register_Class': - return CALLEE_SAVED + @abc.override + def callee_save_registers(self) -> "Register_Class": + return CALLEE_SAVED -- GitLab From 1f2230a6f5184701489d9c5f6d7dcc7579513cd8 Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Tue, 6 Apr 2021 16:45:49 +0200 Subject: [PATCH 04/12] armv8a/flags: first implementation --- src/pipedream/asm/armv8a/__init__.py | 8 +- src/pipedream/asm/armv8a/flags.py | 156 +++++++++++++++++++++++++++ src/pipedream/asm/x86/flags.py | 66 ++++++------ 3 files changed, 194 insertions(+), 36 deletions(-) create mode 100644 src/pipedream/asm/armv8a/flags.py diff --git a/src/pipedream/asm/armv8a/__init__.py b/src/pipedream/asm/armv8a/__init__.py index a7622d5..70e5ba8 100644 --- a/src/pipedream/asm/armv8a/__init__.py +++ b/src/pipedream/asm/armv8a/__init__.py @@ -7,12 +7,14 @@ from pipedream.benchmark.types import Loop_Overhead from . import registers -from . import operands + +# from . import operands from .asmwriter import * -from .operands import * from .registers import * -from .instructions import * + +# from .operands import * +# from .instructions import * __all__ = [ *registers.__all__, diff --git a/src/pipedream/asm/armv8a/flags.py b/src/pipedream/asm/armv8a/flags.py new file mode 100644 index 0000000..5e44c0e --- /dev/null +++ b/src/pipedream/asm/armv8a/flags.py @@ -0,0 +1,156 @@ +from enum import IntFlag +from typing import Iterable + +__all__ = ["ARMv8_NZCV", "ARMv8_FPSCR", "ARMv8_CPSR"] + + +class ARMv8_NZCV(IntFlag): + """ + Condition Flags (64-bit, but upper 32 are reserved) + """ + + # RES0 + V = 0x1000_0000 # Overflow condition flag + C = 0x2000_0000 # Carry condition flag + Z = 0x4000_0000 # Zero condition flag + N = 0x8000_0000 # Negative condition flag + # RES0 + + def __iter__(self) -> Iterable["ARMv8_NZCV"]: + for flag in ARMv8_NZCV: + if self & flag: + yield flag + + +# Not useful as of now +# class ARMv8_FPCR(IntFlag): +# """ +# Floating-Point Control Register +# """ +# +# TODO + + +class ARMv8_FPSR(IntFlag): + """ + Floating-Point Status Register (64-bit, but upper 32 are reserved) + """ + + IOC = 0x0000_0001 # Invalid Operation cumulative floating-point exception bit + DZC = 0x0000_0002 # Overflow cumulative floating-point exception bit + OFC = 0x0000_0004 # Overflow cumulative floating-point exception bit + UFC = 0x0000_0008 # Underflow cumulative floating-point exception bit + IXC = 0x0000_0010 # Inexact cumulative floating-point exception bit + # RES0 + IDC = 0x0000_0080 # Input Denormal cumulative floating-point exception bit + # RES0 + QC = 0x0800_0000 # Cumulative sturation bit + + # "Inherited" from `ARMv8_NZCV` + # RES0 + V = 0x1000_0000 # Overflow condition flag + C = 0x2000_0000 # Carry condition flag + Z = 0x4000_0000 # Zero condition flag + N = 0x8000_0000 # Negative condition flag + # RES0 + + def __iter__(self) -> Iterable["ARMv8_FPSR"]: + for flag in ARMv8_FPSR: + if self & flag: + yield flag + + +class ARMv8_FPSCR(IntFlag): + """ + Floating-Point Status and Control Register) (32-bit) + """ + + # Inherit from `ARMv8_FPSR` + IOC = 0x0000_0001 # Invalid Operation cumulative floating-point exception bit + DZC = 0x0000_0002 # Overflow cumulative floating-point exception bit + OFC = 0x0000_0004 # Overflow cumulative floating-point exception bit + UFC = 0x0000_0008 # Underflow cumulative floating-point exception bit + IXC = 0x0000_0010 # Inexact cumulative floating-point exception bit + # RES0 + IDC = 0x0000_0080 # Input Denormal cumulative floating-point exception bit + + IOE = 0x0000_0100 # Invalid Operation floating_point exception trap enable + DZE = 0x0000_0200 # Divide by Zero floating-point exception trap enable + OFE = 0x0000_0400 # Overflow floating-point exception trap enable + UFE = 0x0000_0800 # Underflow floating-point exception trap enable + IXE = 0x0000_1000 # Inexact floating-point exception trap enable + # RES0 + IDE = 0x0000_8000 # Input Denormal floating-point exeption trap enable + LEN = 0x0007_0000 # Implementation Defined + FZ16 = 0x0008_0000 + # Flush-to-zero mode control bit on half-precision data-processing instructions + Stride = 0x0030_0000 # Implementation defined + RMode = 0x00C0_0000 # Rounding Mode control bit + FZ = 0x0100_0000 # Flush-to-zerp mode control bit + DN = 0x0200_0000 # Default NaN mode control bit + AHP = 0x0400_0000 # Alternate half-precision control bit + + # Inherit from `ARMv8_FPSR` + QC = 0x0800_0000 # Cumulative sturation bit + # RES0 + + # "Inherited" from `ARMv8_NZCV` + V = 0x1000_0000 # Overflow condition flag + C = 0x2000_0000 # Carry condition flag + Z = 0x4000_0000 # Zero condition flag + N = 0x8000_0000 # Negative condition flag + # RES0 + + +class ARMv8_APSR(IntFlag): + """ + Application Program Status Register (32-bit) + """ + + # RES0: 0-3 + # RES1: 4 + # RES0 + GE = 0x000F_0000 # PSTATE greater than or equal flags + + # "Inherited" from `ARMv8_NZCV` + V = 0x1000_0000 # Overflow condition flag + C = 0x2000_0000 # Carry condition flag + Z = 0x4000_0000 # Zero condition flag + N = 0x8000_0000 # Negative condition flag + # RES0 + + def __iter__(self) -> Iterable["ARMv8_APSR"]: + for flag in ARMv8_APSR: + if self & flag: + yield flag + + +class ARMv8_CPSR(IntFlag): + """ + Current Program Status Register (32-bit) + """ + + M = 0x0000_001F # PSTATE mode bits + # RES0 + F = 0x0000_0040 # PSTATE FIQ interrupt mask bit + I = 0x0000_0080 # PSTATE IRQ interrupt mask bit + A = 0x0000_0100 # PSTATE SError interrupt mask bit + E = 0x0000_0200 # PSTATE endianness bit + + # "Inherited" from `ARMv8_APSR` + # RES0: 0-3 + # RES1: 4 + # RES0 + GE = 0x000F_0000 # PSTATE greater than or equal flags + + # "Inherited" from `ARMv8_NZCV` + V = 0x1000_0000 # Overflow condition flag + C = 0x2000_0000 # Carry condition flag + Z = 0x4000_0000 # Zero condition flag + N = 0x8000_0000 # Negative condition flag + # RES0 + + def __iter__(self) -> Iterable["ARMv8_CPSR"]: + for flag in ARMv8_CPSR: + if self & flag: + yield flag diff --git a/src/pipedream/asm/x86/flags.py b/src/pipedream/asm/x86/flags.py index 0eeb460..76d0098 100644 --- a/src/pipedream/asm/x86/flags.py +++ b/src/pipedream/asm/x86/flags.py @@ -1,44 +1,44 @@ - import enum import typing as ty __all__ = [ - 'X86_Flags', + "X86_Flags", ] class X86_Flags(enum.IntFlag): - """ + """ flags for CPU flags register - """ - ## FLAGS - CF = 0x0000_0000_0001 # Carry flag - PF = 0x0000_0000_0004 # Parity flag - AF = 0x0000_0000_0010 # Adjust flag - ZF = 0x0000_0000_0040 # Zero flag - SF = 0x0000_0000_0080 # Sign flag - TF = 0x0000_0000_0100 # Trap flag - IF = 0x0000_0000_0200 # Interrupt enable flag - DF = 0x0000_0000_0400 # Direction flag - OF = 0x0000_0000_0800 # Overflow flag - IOPL = 0x0000_0000_3000 # I/O privilege level (286+ only) - NT = 0x0000_0000_4000 # Nested task flag (286+ only) - ## EFLAGS - RF = 0x0000_0001_0000 # Resume flag (386+ only) - VM = 0x0000_0002_0000 # Virtual 8086 mode flag (386+ only) - AC = 0x0000_0004_0000 # Alignment check (486SX+ only) - VIF = 0x0000_0008_0000 # Virtual interrupt flag (Pentium+) - VIP = 0x0000_0010_0000 # Virtual interrupt pending (Pentium+) - ID = 0x0000_0020_0000 # Able to use CPUID instruction (Pentium+) - ## RFLAGS + """ + + ## FLAGS + CF = 0x0000_0000_0001 # Carry flag + PF = 0x0000_0000_0004 # Parity flag + AF = 0x0000_0000_0010 # Adjust flag + ZF = 0x0000_0000_0040 # Zero flag + SF = 0x0000_0000_0080 # Sign flag + TF = 0x0000_0000_0100 # Trap flag + IF = 0x0000_0000_0200 # Interrupt enable flag + DF = 0x0000_0000_0400 # Direction flag + OF = 0x0000_0000_0800 # Overflow flag + IOPL = 0x0000_0000_3000 # I/O privilege level (286+ only) + NT = 0x0000_0000_4000 # Nested task flag (286+ only) + ## EFLAGS + RF = 0x0000_0001_0000 # Resume flag (386+ only) + VM = 0x0000_0002_0000 # Virtual 8086 mode flag (386+ only) + AC = 0x0000_0004_0000 # Alignment check (486SX+ only) + VIF = 0x0000_0008_0000 # Virtual interrupt flag (Pentium+) + VIP = 0x0000_0010_0000 # Virtual interrupt pending (Pentium+) + ID = 0x0000_0020_0000 # Able to use CPUID instruction (Pentium+) + ## RFLAGS - ## X87 FLAGS - FC0 = 0x0001_0000_0000 - FC1 = 0x0002_0000_0000 - FC2 = 0x0004_0000_0000 - FC3 = 0x0008_0000_0000 + ## X87 FLAGS + FC0 = 0x0001_0000_0000 + FC1 = 0x0002_0000_0000 + FC2 = 0x0004_0000_0000 + FC3 = 0x0008_0000_0000 - def __iter__(self) -> ty.Iterable['X86_Flags']: - for flag in X86_Flags: - if self & flag: - yield flag + def __iter__(self) -> ty.Iterable["X86_Flags"]: + for flag in X86_Flags: + if self & flag: + yield flag -- GitLab From 6d697a3b6412d844391294816d925307a0102c05 Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Tue, 6 Apr 2021 19:13:46 +0200 Subject: [PATCH 05/12] armv8a/operands: first implementation --- src/pipedream/asm/armv8a/__init__.py | 8 +- src/pipedream/asm/armv8a/operands.py | 491 ++++++++++++ src/pipedream/asm/x86/operands.py | 1098 +++++++++++++------------- 3 files changed, 1065 insertions(+), 532 deletions(-) create mode 100644 src/pipedream/asm/armv8a/operands.py diff --git a/src/pipedream/asm/armv8a/__init__.py b/src/pipedream/asm/armv8a/__init__.py index 70e5ba8..2a10e74 100644 --- a/src/pipedream/asm/armv8a/__init__.py +++ b/src/pipedream/asm/armv8a/__init__.py @@ -7,17 +7,17 @@ from pipedream.benchmark.types import Loop_Overhead from . import registers - -# from . import operands +from . import operands from .asmwriter import * from .registers import * -# from .operands import * +from .operands import * + # from .instructions import * __all__ = [ *registers.__all__, - # *operands.__all__, + *operands.__all__, # *instructions.__all__, ] diff --git a/src/pipedream/asm/armv8a/operands.py b/src/pipedream/asm/armv8a/operands.py new file mode 100644 index 0000000..376405f --- /dev/null +++ b/src/pipedream/asm/armv8a/operands.py @@ -0,0 +1,491 @@ +from pipedream.asm import ir +from pipedream.utils import abc + +from .flags import ARMv8_NZCV, ARMv8_CPSR +from .registers import * +from typing import Optional, Union + +__all__ = [ + "ARMv8_Operand", + "ARMv8_Register_Operand", + "ARMv8_Immediate_Operand", + "ImmLo", + "ImmHi", + "Imm12", + "UImm6", + "UImm4", + "UImm5", + "Imm16", + "Imm26", + "Imm14", + "Imm9", + "Imm7", + "Imm3", + "ARMv8_Flags_Operand", +] + + +class ARMv8_Operand(ir.Operand): + def __init__(self, name: str, visibility: ir.Operand_Visibility): + self._name = name + self._visibility = visibility + + @property # type: ignore + @abc.override + def name(self) -> str: + return self._name + + @property # type: ignore + @abc.override + def visibility(self) -> ir.Operand_Visibility: + return self._visibility + + +class ARMv8_Register_Operand(ARMv8_Operand, ir.Register_Operand): + def __init__( + self, + name: str, + visibility: ir.Operand_Visibility, + use_def: ir.Use_Def, + reg_class: ir.Register_Class, + reg: ARMv8_Register, + ): + super().__init__(name, visibility) + self._use_def = use_def + self._reg_class = reg_class + self._reg = reg + + @property # type: ignore + @abc.override + def short_name(self) -> str: + if self._reg: + return self._reg.name + else: + return self._reg_class.name + + @property # type: ignore + @abc.override + def register_class(self) -> ir.Register_Class: + return self._reg_class + + @property # type: ignore + @abc.override + def register(self) -> Optional[ARMv8_Register]: + return self._reg + + @abc.override + def with_register(self, reg: ir.Register) -> "ARMv8_Register_Operand": + assert isinstance(reg, ARMv8_Register) + + if reg not in self.register_class: + raise TypeError( + f"Register {reg.name} is not a member of {self.register_class.name}" + ) + + return ARMv8_Register_Operand( + self.name, self.visibility, self.use_def, self.register_class, reg + ) + + @property # type: ignore + @abc.override + def use_def(self) -> ir.Use_Def: + return self._use_def + + +class ARMv8_Flags_Operand(ARMv8_Operand, ir.Flags_Operand): + def __init__( + self, + name: str, + visibility: ir.Operand_Visibility, + reg: ARMv8_Register, + flags_read: ARMv8_NZCV, + flags_written: ARMv8_NZCV, + ): + assert isinstance(flags_read, ARMv8_NZCV), flags_read + assert isinstance(flags_written, ARMv8_NZCV), flags_written + + super().__init__(name, visibility) + self._reg = reg + self._flags_read = flags_read + self._flags_written = flags_written + + @property # type: ignore + @abc.override + def is_virtual(self): + return False + + @property # type: ignore + @abc.override + def flags_read(self) -> ARMv8_NZCV: + return self._flags_read + + @property # type: ignore + @abc.override + def flags_written(self) -> ARMv8_NZCV: + return self._flags_written + + @property # type: ignore + @abc.override + def register(self) -> ARMv8_Register: + return self._reg + + @property # type: ignore + @abc.override + def short_name(self) -> str: + return self._reg.name + + @property # type: ignore + @abc.override + def use_def(self) -> ir.Use_Def: + if self._flags_read and self._flags_written: + return ir.Use_Def.USE_DEF + if self._flags_read: + return ir.Use_Def.USE + if self._flags_written: + return ir.Use_Def.DEF + + # The register in an operand is either used or defined + assert False, [self.name, self._reg, self._flags_read, self._flags_written] + + def __repr__(self): + txt = self.name + ":" + + rw = self.flags_read | self.flags_written + # Taking the biggest PSTATE register: APSR + flags = [f for f in ARMv8_CPSR if f & rw] + + if flags: + for f in flags: + txt += f.name + + if f & self.flags_read: + txt += "?" + if f & self.flags_written: + txt += "!" + else: + txt += self._reg.name + + return txt + + +class ARMv8_Immediate_Operand(ARMv8_Operand, ir.Immediate_Operand): + def __init__(self, name: str, visibility: ir.Operand_Visibility, value: int = None): + super().__init__(name, visibility) + + assert value is None or isinstance( + value, (int, ir.Label) + ), f"want int or Label, have {value!r}" + + self._value = value + + @property # type: ignore + @abc.override + def short_name(self) -> str: + return type(self).__name__.upper() + + @property # type: ignore + @abc.override + def value(self) -> Union[int, float, None]: + return self._value + + @abc.override + def with_value(self, value): + clss = type(self) + + if not self._is_valid_value(value): + raise TypeError( + f"{value} is not a valid value for immediates of type {clss.__name__}" + ) + + return clss(self.name, self.visibility, value) + + @classmethod + @abc.abstractmethod + def _is_valid_value(clss, value) -> bool: + """ + Check if value is a valid value for immediates of this type. + """ + + @classmethod + @abc.abstractmethod + def _arbitrary(clss, random) -> Union[int, float]: + """ + Generate a random int/float that can fit in an immediate of this type. + """ + + +def _signed_min_max(num_bits: int): + min = -(2 ** (num_bits - 1)) + max = +(2 ** (num_bits - 1)) - 1 + return min, max + + +def _unsigned_min_max(num_bits: int): + min = 0 + max = +(2 ** (num_bits - 1)) + return min, max + + +# Immediate operands of the A64 Instruction Set, in the order of the "Architecture +# Reference Manual" + +# Jump operands are probably useless, as Pipedream's codegen rely on GNU assembly +# syntax (with labels instead of immediate offsets) + + +class ImmLo(ARMv8_Immediate_Operand): + """ + Class of PC-relative operands: immlo + """ + + _NUM_BITS = 2 + MIN, MAX = _unsigned_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + +class ImmHi(ARMv8_Immediate_Operand): + """ + Class of PC-relative operands: immhi + """ + + _NUM_BITS = 19 + MIN, MAX = _unsigned_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + +class Imm12(ARMv8_Immediate_Operand): + """ + Class of 12-bit signed immediates + """ + + _NUM_BITS = 12 + MIN, MAX = _signed_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MAX + + +class UImm6(ARMv8_Immediate_Operand): + """ + Class of 6-bit unsigned immediates + """ + + _NUM_BITS = 6 + MIN, MAX = _unsigned_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MAX + + +class UImm4(ARMv8_Immediate_Operand): + """ + Class of 4-bit unsigned immediates + """ + + _NUM_BITS = 4 + MIN, MAX = _unsigned_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MAX + + +class UImm5(ARMv8_Immediate_Operand): + """ + Class of 5-bit unsigned immediates (bitfields) + """ + + _NUM_BITS = 5 + MIN, MAX = _unsigned_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MAX + + +class Imm16(ARMv8_Immediate_Operand): + """ + Class of 16-bit signed immediates + """ + + _NUM_BITS = 16 + MIN, MAX = _signed_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MAX + + +class Imm26(ARMv8_Immediate_Operand): + """ + Class of 26-bit signed immediates (used for unconditional branches) + """ + + _NUM_BITS = 26 + MIN, MAX = _signed_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + +class Imm14(ARMv8_Immediate_Operand): + """ + Class of 14-bit signed immediates (used for test and branch) + """ + + _NUM_BITS = 14 + MIN, MAX = _signed_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + +class Imm9(ARMv8_Immediate_Operand): + """ + Class of 9-bit signed immediates (used for offsets on unscaled immediate) + """ + + _NUM_BITS = 9 + MIN, MAX = _signed_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + +class Imm7(ARMv8_Immediate_Operand): + """ + Class of 7-bit signed immediates (used for offsets on load/store register pair) + """ + + _NUM_BITS = 7 + MIN, MAX = _signed_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + +class Imm3(ARMv8_Immediate_Operand): + """ + Class of 3-bit signed immediates + """ + + _NUM_BITS = 3 + MIN, MAX = _signed_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MAX diff --git a/src/pipedream/asm/x86/operands.py b/src/pipedream/asm/x86/operands.py index f248947..9344914 100644 --- a/src/pipedream/asm/x86/operands.py +++ b/src/pipedream/asm/x86/operands.py @@ -1,6 +1,5 @@ - from pipedream.asm.x86.flags import X86_Flags -from pipedream.asm import ir +from pipedream.asm import ir from pipedream.utils import abc from .registers import * @@ -8,679 +7,722 @@ import copy import typing as ty __all__ = [ - 'X86_Operand', - 'X86_Register_Operand', - 'X86_Immediate_Operand', - 'X86_Flags_Operand', - - 'X86_Memory_Operand', - 'X86_Address_Operand', - - 'X86_Base_Displacement_Operand', - 'X86_Base_Displacement_Memory_Operand', - 'X86_Base_Displacement_Address_Operand', - - 'Imm8', - 'Imm16', - 'Imm32', - 'Imm64', - 'ImmU8', - 'ImmU16', - 'ImmU32', - 'ImmU64', - 'Scale_Imm', - 'Shift_Imm', + "X86_Operand", + "X86_Register_Operand", + "X86_Immediate_Operand", + "X86_Flags_Operand", + "X86_Memory_Operand", + "X86_Address_Operand", + "X86_Base_Displacement_Operand", + "X86_Base_Displacement_Memory_Operand", + "X86_Base_Displacement_Address_Operand", + "Imm8", + "Imm16", + "Imm32", + "Imm64", + "ImmU8", + "ImmU16", + "ImmU32", + "ImmU64", + "Scale_Imm", + "Shift_Imm", ] class X86_Operand(ir.Operand): - def __init__(self, name: str, visibility: ir.Operand_Visibility): - self._name = name - self._visibility = visibility + def __init__(self, name: str, visibility: ir.Operand_Visibility): + self._name = name + self._visibility = visibility - @property - @abc.override - def name(self) -> str: - return self._name + @property + @abc.override + def name(self) -> str: + return self._name - @property - @abc.override - def visibility(self) -> str: - return self._visibility + @property + @abc.override + def visibility(self) -> ir.Operand_Visibility: + return self._visibility ################################################################################ ##### REGISTERS + class X86_Register_Operand(X86_Operand, ir.Register_Operand): - def __init__(self, name: str, visibility: ir.Operand_Visibility, use_def: ir.Use_Def, - reg_class: ir.Register_Class, reg: X86_Register): - super().__init__(name, visibility) - self._use_def = use_def - self._reg_class = reg_class - self._reg = reg - - @property - @abc.override - def short_name(self) -> str: - if self._reg: - return self._reg.name - else: - return self._reg_class.name - - @property - @abc.override - def register_class(self) -> ir.Register_Class: - return self._reg_class - - @property - @abc.override - def register(self) -> ty.Optional[X86_Register]: - return self._reg - - @abc.override - def with_register(self, reg: ir.Register) -> 'X86_Register_Operand': - assert reg is not None - - if reg not in self.register_class: - raise TypeError(f'Register {reg.name} is not a member of {self.register_class.name}') - - return X86_Register_Operand(self.name, self.visibility, self.use_def, self.register_class, reg) - - @property - @abc.override - def use_def(self) -> ir.Use_Def: - return self._use_def + def __init__( + self, + name: str, + visibility: ir.Operand_Visibility, + use_def: ir.Use_Def, + reg_class: ir.Register_Class, + reg: X86_Register, + ): + super().__init__(name, visibility) + self._use_def = use_def + self._reg_class = reg_class + self._reg = reg + + @property + @abc.override + def short_name(self) -> str: + if self._reg: + return self._reg.name + else: + return self._reg_class.name + + @property + @abc.override + def register_class(self) -> ir.Register_Class: + return self._reg_class + + @property + @abc.override + def register(self) -> ty.Optional[X86_Register]: + return self._reg + + @abc.override + def with_register(self, reg: ir.Register) -> "X86_Register_Operand": + assert isinstance(reg, X86_Register) + + if reg not in self.register_class: + raise TypeError( + f"Register {reg.name} is not a member of {self.register_class.name}" + ) + + return X86_Register_Operand( + self.name, self.visibility, self.use_def, self.register_class, reg + ) + + @property + @abc.override + def use_def(self) -> ir.Use_Def: + return self._use_def class X86_Flags_Operand(X86_Operand, ir.Flags_Operand): - def __init__(self, name: str, visibility: ir.Operand_Visibility, - reg: X86_Register, flags_read: X86_Flags, flags_written: X86_Flags): - assert isinstance(flags_read, X86_Flags), flags_read - assert isinstance(flags_written, X86_Flags), flags_written - - super().__init__(name, visibility) - self._reg = reg - self._flags_read = flags_read - self._flags_written = flags_written - - @property - @abc.override - def is_virtual(self): - return False - - @property - @abc.override - def flags_read(self) -> X86_Flags: - return self._flags_read - - @property - @abc.override - def flags_written(self) -> X86_Flags: - return self._flags_written - - @property - @abc.override - def register(self) -> X86_Register: - return self._reg - - @property - @abc.override - def short_name(self) -> str: - return self._reg.name - - @property - @abc.override - def use_def(self) -> ir.Use_Def: - if self._flags_read and self._flags_written: - return ir.Use_Def.USE_DEF - if self._flags_read: - return ir.Use_Def.USE - if self._flags_written: - return ir.Use_Def.DEF - - ## FIXME: should never happen - # assert False, [self.name, self._reg, self._flags_read, self._flags_written] - return ir.Use_Def.USE - - def __repr__(self): - txt = self.name + ":" - - rw = self.flags_read | self.flags_written - flags = [f for f in X86_Flags if f & rw] - - if flags: - for f in flags: - txt += f.name - - if f & self.flags_read: - txt += '?' - if f & self.flags_written: - txt += '!' - else: - txt += self._reg.name - - return txt + def __init__( + self, + name: str, + visibility: ir.Operand_Visibility, + reg: X86_Register, + flags_read: X86_Flags, + flags_written: X86_Flags, + ): + assert isinstance(flags_read, X86_Flags), flags_read + assert isinstance(flags_written, X86_Flags), flags_written + + super().__init__(name, visibility) + self._reg = reg + self._flags_read = flags_read + self._flags_written = flags_written + + @property + @abc.override + def is_virtual(self): + return False + + @property + @abc.override + def flags_read(self) -> X86_Flags: + return self._flags_read + + @property + @abc.override + def flags_written(self) -> X86_Flags: + return self._flags_written + + @property + @abc.override + def register(self) -> X86_Register: + return self._reg + + @property + @abc.override + def short_name(self) -> str: + return self._reg.name + + @property + @abc.override + def use_def(self) -> ir.Use_Def: + if self._flags_read and self._flags_written: + return ir.Use_Def.USE_DEF + if self._flags_read: + return ir.Use_Def.USE + if self._flags_written: + return ir.Use_Def.DEF + + ## FIXME: should never happen + # assert False, [self.name, self._reg, self._flags_read, self._flags_written] + return ir.Use_Def.USE + + def __repr__(self): + txt = self.name + ":" + + rw = self.flags_read | self.flags_written + flags = [f for f in X86_Flags if f & rw] + + if flags: + for f in flags: + txt += f.name + + if f & self.flags_read: + txt += "?" + if f & self.flags_written: + txt += "!" + else: + txt += self._reg.name + + return txt ################################################################################ ##### IMMEDIATES + class X86_Immediate_Operand(X86_Operand, ir.Immediate_Operand): - def __init__(self, name: str, visibility: ir.Operand_Visibility, value: int = None): - super().__init__(name, visibility) + def __init__(self, name: str, visibility: ir.Operand_Visibility, value: int = None): + super().__init__(name, visibility) - assert value is None or isinstance(value, (int, ir.Label)), f'want int or Label, have {value!r}' + assert value is None or isinstance( + value, (int, ir.Label) + ), f"want int or Label, have {value!r}" - self._value = value + self._value = value - @property - @abc.override - def short_name(self) -> str: - return type(self).__name__.upper() + @property + @abc.override + def short_name(self) -> str: + return type(self).__name__.upper() - @property - @abc.override - def value(self) -> ty.Union[int, float, None]: - return self._value + @property + @abc.override + def value(self) -> ty.Union[int, float, None]: + return self._value - @abc.override - def with_value(self, value): - clss = type(self) + @abc.override + def with_value(self, value): + clss = type(self) - if not self._is_valid_value(value): - raise TypeError(f'{value} is not a valid value for immediates of type {clss.__name__}') + if not self._is_valid_value(value): + raise TypeError( + f"{value} is not a valid value for immediates of type {clss.__name__}" + ) - return clss(self.name, self.visibility, value) + return clss(self.name, self.visibility, value) - @classmethod - @abc.abstractmethod - def _is_valid_value(clss, value) -> bool: - """ - Check if value is a valid value for immediates of this type. - """ + @classmethod + @abc.abstractmethod + def _is_valid_value(clss, value) -> bool: + """ + Check if value is a valid value for immediates of this type. + """ - @classmethod - @abc.abstractmethod - def _arbitrary(clss, random) -> ty.Union[int, float]: - """ - Generate a random int/float that can fit in an immediate of this type. - """ + @classmethod + @abc.abstractmethod + def _arbitrary(clss, random) -> ty.Union[int, float]: + """ + Generate a random int/float that can fit in an immediate of this type. + """ def _signed_min_max(num_bits: int): - min = - 2 ** (num_bits - 1) - max = + 2 ** (num_bits - 1) - 1 - return min, max + min = -(2 ** (num_bits - 1)) + max = +(2 ** (num_bits - 1)) - 1 + return min, max def _unsigned_min_max(num_bits: int): - min = 0 - max = + 2 ** (num_bits - 1) - return min, max + min = 0 + max = +(2 ** (num_bits - 1)) + return min, max class Imm8(X86_Immediate_Operand): - """ + """ class of 8 bit signed immediates - """ + """ - _NUM_BITS = 8 - MIN, MAX = _signed_min_max(_NUM_BITS) + _NUM_BITS = 8 + MIN, MAX = _signed_min_max(_NUM_BITS) - @property - @abc.override - def num_bits(self): - return self._NUM_BITS + @property + @abc.override + def num_bits(self): + return self._NUM_BITS - @classmethod - @abc.override - def _is_valid_value(clss, value): - return type(value) is int and clss.MIN <= value <= clss.MAX + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX - @classmethod - @abc.override - def _arbitrary(clss, random): - return clss.MIN - # return random.randint(clss.MIN, clss.MAX) + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MIN + # return random.randint(clss.MIN, clss.MAX) class ImmU8(X86_Immediate_Operand): - """ + """ class of 8 bit unsigned immediates - """ + """ - _NUM_BITS = 8 - MIN, MAX = _unsigned_min_max(_NUM_BITS) + _NUM_BITS = 8 + MIN, MAX = _unsigned_min_max(_NUM_BITS) - @property - @abc.override - def num_bits(self): - return self._NUM_BITS + @property + @abc.override + def num_bits(self): + return self._NUM_BITS - @classmethod - @abc.override - def _is_valid_value(clss, value): - return type(value) is int and clss.MIN <= value <= clss.MAX + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX - @classmethod - @abc.override - def _arbitrary(clss, random): - return clss.MAX - # return random.randint(clss.MIN, clss.MAX) + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MAX + # return random.randint(clss.MIN, clss.MAX) class Imm16(X86_Immediate_Operand): - """ + """ class of 16 bit signed immediates - """ + """ - _NUM_BITS = 16 - MIN, MAX = _signed_min_max(_NUM_BITS) + _NUM_BITS = 16 + MIN, MAX = _signed_min_max(_NUM_BITS) - @property - @abc.override - def num_bits(self): - return self._NUM_BITS + @property + @abc.override + def num_bits(self): + return self._NUM_BITS - @classmethod - @abc.override - def _is_valid_value(clss, value): - if type(value) is ir.Label: - return True + @classmethod + @abc.override + def _is_valid_value(clss, value): + if type(value) is ir.Label: + return True - if type(value) is int: - return clss.MIN <= value <= clss.MAX + if type(value) is int: + return clss.MIN <= value <= clss.MAX - return False + return False - @classmethod - @abc.override - def _arbitrary(clss, random): - return clss.MIN - # return random.randint(clss.MIN, clss.MAX) + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MIN + # return random.randint(clss.MIN, clss.MAX) class ImmU16(X86_Immediate_Operand): - """ + """ class of 16 bit unsigned immediates - """ + """ - _NUM_BITS = 16 - MIN, MAX = _unsigned_min_max(_NUM_BITS) + _NUM_BITS = 16 + MIN, MAX = _unsigned_min_max(_NUM_BITS) - @property - @abc.override - def num_bits(self): - return self._NUM_BITS + @property + @abc.override + def num_bits(self): + return self._NUM_BITS - @classmethod - @abc.override - def _is_valid_value(clss, value): - if type(value) is ir.Label: - return True + @classmethod + @abc.override + def _is_valid_value(clss, value): + if type(value) is ir.Label: + return True - if type(value) is int: - return clss.MIN <= value <= clss.MAX + if type(value) is int: + return clss.MIN <= value <= clss.MAX - return False + return False - @classmethod - @abc.override - def _arbitrary(clss, random): - return clss.MAX - # return random.randint(clss.MIN, clss.MAX) + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MAX + # return random.randint(clss.MIN, clss.MAX) class Imm32(X86_Immediate_Operand): - """ + """ class of 32 bit signed immediates - """ + """ - _NUM_BITS = 32 - MIN, MAX = _signed_min_max(_NUM_BITS) + _NUM_BITS = 32 + MIN, MAX = _signed_min_max(_NUM_BITS) - @property - @abc.override - def num_bits(self): - return self._NUM_BITS + @property + @abc.override + def num_bits(self): + return self._NUM_BITS - @classmethod - @abc.override - def _is_valid_value(clss, value): - if type(value) is ir.Label: - return True + @classmethod + @abc.override + def _is_valid_value(clss, value): + if type(value) is ir.Label: + return True - if type(value) is int: - return clss.MIN <= value <= clss.MAX + if type(value) is int: + return clss.MIN <= value <= clss.MAX - return False + return False - @classmethod - @abc.override - def _arbitrary(clss, random): - return clss.MIN - # return random.randint(clss.MIN, clss.MAX) + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MIN + # return random.randint(clss.MIN, clss.MAX) class ImmU32(X86_Immediate_Operand): - """ + """ class of 32 bit unsigned immediates - """ + """ - _NUM_BITS = 32 - MIN, MAX = _unsigned_min_max(_NUM_BITS) + _NUM_BITS = 32 + MIN, MAX = _unsigned_min_max(_NUM_BITS) - @property - @abc.override - def num_bits(self): - return self._NUM_BITS + @property + @abc.override + def num_bits(self): + return self._NUM_BITS - @classmethod - @abc.override - def _is_valid_value(clss, value): - if type(value) is ir.Label: - return True + @classmethod + @abc.override + def _is_valid_value(clss, value): + if type(value) is ir.Label: + return True - if type(value) is int: - return clss.MIN <= value <= clss.MAX + if type(value) is int: + return clss.MIN <= value <= clss.MAX - return False + return False - @classmethod - @abc.override - def _arbitrary(clss, random): - return clss.MAX - # return random.randint(clss.MIN, clss.MAX) + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MAX + # return random.randint(clss.MIN, clss.MAX) class Imm64(X86_Immediate_Operand): - """ + """ class of 64 bit signed immediates - """ + """ - _NUM_BITS = 64 - MIN, MAX = _signed_min_max(_NUM_BITS) + _NUM_BITS = 64 + MIN, MAX = _signed_min_max(_NUM_BITS) - @property - @abc.override - def num_bits(self): - return self._NUM_BITS + @property + @abc.override + def num_bits(self): + return self._NUM_BITS - @classmethod - @abc.override - def _is_valid_value(clss, value): - if type(value) is ir.Label: - return True + @classmethod + @abc.override + def _is_valid_value(clss, value): + if type(value) is ir.Label: + return True - if type(value) is int: - return clss.MIN <= value <= clss.MAX + if type(value) is int: + return clss.MIN <= value <= clss.MAX - return False + return False - @classmethod - @abc.override - def _arbitrary(clss, random): - return clss.MIN - # return random.randint(clss.MIN, clss.MAX) + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MIN + # return random.randint(clss.MIN, clss.MAX) class ImmU64(X86_Immediate_Operand): - """ + """ class of 64 bit unsigned immediates - """ + """ - _NUM_BITS = 64 - MIN, MAX = _unsigned_min_max(_NUM_BITS) + _NUM_BITS = 64 + MIN, MAX = _unsigned_min_max(_NUM_BITS) - @property - @abc.override - def num_bits(self): - return self._NUM_BITS + @property + @abc.override + def num_bits(self): + return self._NUM_BITS - @classmethod - @abc.override - def _is_valid_value(clss, value): - if type(value) is ir.Label: - return True + @classmethod + @abc.override + def _is_valid_value(clss, value): + if type(value) is ir.Label: + return True - if type(value) is int: - return clss.MIN <= value <= clss.MAX + if type(value) is int: + return clss.MIN <= value <= clss.MAX - return False + return False - @classmethod - @abc.override - def _arbitrary(clss, random): - return clss.MAX - # return random.randint(clss.MIN, clss.MAX) + @classmethod + @abc.override + def _arbitrary(clss, random): + return clss.MAX + # return random.randint(clss.MIN, clss.MAX) class Scale_Imm(X86_Immediate_Operand): - """ + """ Immediate for `scale` operand of a memory access in X86. - """ + """ - @property - @abc.override - def short_name(self) -> str: - return 'SCALE' + @property + @abc.override + def short_name(self) -> str: + return "SCALE" - @classmethod - @abc.override - def _is_valid_value(clss, value): - return type(value) is int and value in (1, 2, 4, 8) + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and value in (1, 2, 4, 8) - @classmethod - @abc.override - def _arbitrary(clss, random): - return 1 - # return 2 ** random.randint(0, 3) + @classmethod + @abc.override + def _arbitrary(clss, random): + return 1 + # return 2 ** random.randint(0, 3) class Shift_Imm(X86_Immediate_Operand): - """ + """ Immediate operand for a shit instruction. - """ + """ - MIN = 0 - MAX = 31 + MIN = 0 + MAX = 31 - @property - @abc.override - def short_name(self) -> str: - return 'SHIFT' + @property + @abc.override + def short_name(self) -> str: + return "SHIFT" - @classmethod - @abc.override - def _is_valid_value(clss, value): - return type(value) is int and clss.MIN <= value <= clss.MAX + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX - @classmethod - @abc.override - def _arbitrary(clss, random): - return 3 - # return random.randint(clss.MIN, clss.MAX) + @classmethod + @abc.override + def _arbitrary(clss, random): + return 3 + # return random.randint(clss.MIN, clss.MAX) ################################################################################ ##### MEMORY + class X86_Memory_Or_Address_Operand(X86_Operand, ir.Composite_Operand): - """ + """ Mixin for defining memory & address operands - """ - - def __init__(self, - name: str, - visibility: ir.Operand_Visibility, - address_width: int, - ): - super().__init__(name, visibility) - - assert type(address_width) is int - assert address_width in (16, 32, 64) - - # how many bits are in address (16/32/64) - self._address_width = address_width - - @property - def address_width(self) -> int: - return self._address_width - - @property - @abc.override - def is_virtual(self) -> bool: - return any(op.is_virtual for op in self.sub_operands) - - @abc.override - def update_sub_operand(self, idx_or_name: ty.Union[int, str], - fn: ty.Callable[['Operand'], 'Operand']): - if type(idx_or_name) is int: - name = self.get_operand_name(idx_or_name) - else: - name = idx_or_name + """ - new = copy.copy(self) - sub = fn(getattr(new, name)) - assert isinstance(sub, ir.Operand) - setattr(new, name, sub) - return new + def __init__( + self, + name: str, + visibility: ir.Operand_Visibility, + address_width: int, + ): + super().__init__(name, visibility) + + assert type(address_width) is int + assert address_width in (16, 32, 64) + + # how many bits are in address (16/32/64) + self._address_width = address_width + + @property + def address_width(self) -> int: + return self._address_width + + @property + @abc.override + def is_virtual(self) -> bool: + return any(op.is_virtual for op in self.sub_operands) + + @abc.override + def update_sub_operand( + self, idx_or_name: ty.Union[int, str], fn: ty.Callable[["Operand"], "Operand"] + ): + if type(idx_or_name) is int: + name = self.get_operand_name(idx_or_name) + else: + name = idx_or_name + + new = copy.copy(self) + sub = fn(getattr(new, name)) + assert isinstance(sub, ir.Operand) + setattr(new, name, sub) + return new class X86_Memory_Operand(X86_Memory_Or_Address_Operand, ir.Memory_Operand): - def __init__(self, - name: str, - visibility: ir.Operand_Visibility, - use_def: ir.Use_Def, - address_width: int, memory_width: int, - ): - super().__init__(name, visibility, address_width) + def __init__( + self, + name: str, + visibility: ir.Operand_Visibility, + use_def: ir.Use_Def, + address_width: int, + memory_width: int, + ): + super().__init__(name, visibility, address_width) - assert type(use_def) is ir.Use_Def - assert type(memory_width) is int + assert type(use_def) is ir.Use_Def + assert type(memory_width) is int - # how many bits are loaded/stored - self._memory_width = memory_width + # how many bits are loaded/stored + self._memory_width = memory_width - self._use_def = use_def + self._use_def = use_def - @property - @abc.override - def memory_width(self) -> int: - return self._memory_width + @property + @abc.override + def memory_width(self) -> int: + return self._memory_width - @property - @abc.override - def use_def(self) -> ir.Use_Def: - return self._use_def + @property + @abc.override + def use_def(self) -> ir.Use_Def: + return self._use_def class X86_Address_Operand(X86_Memory_Or_Address_Operand, ir.Address_Operand): - @property - @abc.override - def use_def(self) -> ir.Use_Def: - return ir.Use_Def.USE + @property + @abc.override + def use_def(self) -> ir.Use_Def: + return ir.Use_Def.USE class X86_Base_Displacement_Operand(ir.Base_Displacement_Operand): - """ + """ Mixin for defining base/displacement operands - """ - - def __init__(self, - base: ty.Optional[X86_Register_Operand], - displacement: ty.Optional[Imm32], - ): - assert isinstance(displacement, X86_Immediate_Operand) - assert isinstance(base, X86_Register_Operand), base - - self._displacement = displacement - self._base = base - - @property - def base(self) -> ir.Register_Operand: - return self._base - - @property - def displacement(self) -> ir.Immediate_Operand: - return self._displacement - - def with_base(self, base_reg: ir.Register) -> 'X86_Base_Displacement_Operand': - new = copy.copy(self) - new._base = new._base.with_register(base_reg) - return new - - def with_displacement(self, disp: int) -> 'X86_Base_Displacement_Operand': - new = copy.copy(self) - new._displacement = new._displacement.with_value(disp) - return new - - @abc.abstractproperty - def _short_short_name(self) -> str: - pass - - @property - @abc.override - def short_name(self) -> str: - return f'{self._short_short_name}BD{self.address_width}_{self.memory_width}' - - @property - @abc.override - def sub_operands(self): - yield self.base - yield self.displacement - - def get_operand_name(self, idx: int) -> str: - return { - 0: 'base', - 1: 'displacement', - }[idx] - - def __repr__(self): - tmp = [ - self.name, ':', - self._short_short_name, 'BD', - str(self.address_width), '/', - ] - if hasattr(self, 'memory_width'): - tmp += [ - str(self.memory_width), '/', - ] - tmp += [ - '(', repr(self.base), ', ', repr(self.displacement), ')' - ] - - return ''.join(tmp) - - -class X86_Base_Displacement_Memory_Operand(X86_Base_Displacement_Operand, X86_Memory_Operand, - ir.Base_Displacement_Memory_Operand): - def __init__(self, name: str, visibility: ir.Operand_Visibility, - use_def: ir.Use_Def, address_width: int, memory_width: int, - base: ty.Optional[X86_Register_Operand], - displacement: ty.Optional[Imm32], - ): - X86_Memory_Operand.__init__(self, name, visibility, use_def, address_width, memory_width) - X86_Base_Displacement_Operand.__init__(self, base, displacement) - - @property - @abc.override - def _short_short_name(self) -> str: - return 'mem' - - -class X86_Base_Displacement_Address_Operand(X86_Base_Displacement_Operand, X86_Address_Operand, - ir.Base_Displacement_Address_Operand): - def __init__(self, name: str, visibility: ir.Operand_Visibility, - address_width: int, - base: ty.Optional[X86_Register_Operand], - displacement: ty.Optional[Imm32], - ): - X86_Address_Operand.__init__(self, name, visibility, address_width) - X86_Base_Displacement_Operand.__init__(self, base, displacement) - - @property - @abc.override - def _short_short_name(self) -> str: - return 'addr' + """ + + def __init__( + self, + base: ty.Optional[X86_Register_Operand], + displacement: ty.Optional[Imm32], + ): + assert isinstance(displacement, X86_Immediate_Operand) + assert isinstance(base, X86_Register_Operand), base + + self._displacement = displacement + self._base = base + + @property + def base(self) -> ir.Register_Operand: + return self._base + + @property + def displacement(self) -> ir.Immediate_Operand: + return self._displacement + + def with_base(self, base_reg: ir.Register) -> "X86_Base_Displacement_Operand": + new = copy.copy(self) + new._base = new._base.with_register(base_reg) + return new + + def with_displacement(self, disp: int) -> "X86_Base_Displacement_Operand": + new = copy.copy(self) + new._displacement = new._displacement.with_value(disp) + return new + + @abc.abstractproperty + def _short_short_name(self) -> str: + pass + + @property + @abc.override + def short_name(self) -> str: + return f"{self._short_short_name}BD{self.address_width}_{self.memory_width}" + + @property + @abc.override + def sub_operands(self): + yield self.base + yield self.displacement + + def get_operand_name(self, idx: int) -> str: + return { + 0: "base", + 1: "displacement", + }[idx] + + def __repr__(self): + tmp = [ + self.name, + ":", + self._short_short_name, + "BD", + str(self.address_width), + "/", + ] + if hasattr(self, "memory_width"): + tmp += [ + str(self.memory_width), + "/", + ] + tmp += ["(", repr(self.base), ", ", repr(self.displacement), ")"] + + return "".join(tmp) + + +class X86_Base_Displacement_Memory_Operand( + X86_Base_Displacement_Operand, + X86_Memory_Operand, + ir.Base_Displacement_Memory_Operand, +): + def __init__( + self, + name: str, + visibility: ir.Operand_Visibility, + use_def: ir.Use_Def, + address_width: int, + memory_width: int, + base: ty.Optional[X86_Register_Operand], + displacement: ty.Optional[Imm32], + ): + X86_Memory_Operand.__init__( + self, name, visibility, use_def, address_width, memory_width + ) + X86_Base_Displacement_Operand.__init__(self, base, displacement) + + @property + @abc.override + def _short_short_name(self) -> str: + return "mem" + + +class X86_Base_Displacement_Address_Operand( + X86_Base_Displacement_Operand, + X86_Address_Operand, + ir.Base_Displacement_Address_Operand, +): + def __init__( + self, + name: str, + visibility: ir.Operand_Visibility, + address_width: int, + base: ty.Optional[X86_Register_Operand], + displacement: ty.Optional[Imm32], + ): + X86_Address_Operand.__init__(self, name, visibility, address_width) + X86_Base_Displacement_Operand.__init__(self, base, displacement) + + @property + @abc.override + def _short_short_name(self) -> str: + return "addr" -- GitLab From 26c772e75eb34ee6eab79805415040704419cbed Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Thu, 8 Apr 2021 10:35:53 +0200 Subject: [PATCH 06/12] armv8/instructions: first implementation + add extract-instructions-binutils --- src/pipedream/asm/armv8a/__init__.py | 22 +- src/pipedream/asm/armv8a/instructions.py | 137 ++ src/pipedream/asm/x86/instructions.py | 1433 +++++++++-------- .../README.md | 30 + .../extract_arm_db/extract_binutils.py | 376 +++++ .../requirements.txt | 3 + .../setup.py | 35 + 7 files changed, 1392 insertions(+), 644 deletions(-) create mode 100644 src/pipedream/asm/armv8a/instructions.py create mode 100644 tools/extract-binutils-instruction-database/README.md create mode 100644 tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py create mode 100644 tools/extract-binutils-instruction-database/requirements.txt create mode 100755 tools/extract-binutils-instruction-database/setup.py diff --git a/src/pipedream/asm/armv8a/__init__.py b/src/pipedream/asm/armv8a/__init__.py index 2a10e74..3c1dc26 100644 --- a/src/pipedream/asm/armv8a/__init__.py +++ b/src/pipedream/asm/armv8a/__init__.py @@ -8,16 +8,30 @@ from pipedream.benchmark.types import Loop_Overhead from . import registers from . import operands +from . import instructions from .asmwriter import * from .registers import * - from .operands import * - -# from .instructions import * +from .instructions import * __all__ = [ *registers.__all__, *operands.__all__, - # *instructions.__all__, + *instructions.__all__, ] + + +# TODO +class ARMv8_Architecture(Architecture): + pass + + +# TODO +class ARMv8_ASM_Dialect(ASM_Dialect): + pass + + +# TODO +class ARMv8_IR_Builder(IR_Builder): + pass diff --git a/src/pipedream/asm/armv8a/instructions.py b/src/pipedream/asm/armv8a/instructions.py new file mode 100644 index 0000000..05955e6 --- /dev/null +++ b/src/pipedream/asm/armv8a/instructions.py @@ -0,0 +1,137 @@ +import pipedream.utils.abc as abc + +from pipedream.asm import ir + +from pipedream.asm.armv8a.operands import * +from pipedream.asm.armv8a.registers import * +from pipedream.asm.armv8a.flags import * + +from typing import Dict, Sequence, Union + +# from pipedream.asm.armv8a import instructions_binutils + + +__all__ = [ + "ARMv8_Instruction", + "ARMv8_Instruction_Set", + "MNEMONICS", + "INSTRUCTIONS", + # "Harness", +] + +Instruction_Name = str +Operand_Name = str + +ALL_TAGS = frozenset() +INSTRUCTIONS: Dict[Instruction_Name, ir.Machine_Instruction] = {} +MNEMONICS: Dict[Instruction_Name, str] = {} + + +class ARMv8_Instruction(ir.Machine_Instruction): + def __init__(self, name, mnemonic, isa_set, operands, tags, can_benchmark): + self._name = name + self._mnemonic = mnemonic + self._isa_set = isa_set + self._operands = operands + self._tags = tags + self._can_benchmark = can_benchmark + + assert len(operands) == len( + set(o.name for o in operands) + ), "duplicate operand name in " + str(self) + + @property + @abc.override + def name(self) -> str: + return self._name + + @property + def att_mnemonic(self) -> str: + return self._att_mnemonic + + @property + def intel_mnemonic(self) -> str: + return self._intel_mnemonic + + @property + @abc.override + def isa_set(self) -> str: + return self._isa_set + + @property + @abc.override + def tags(self) -> Sequence[str]: + return self._tags + + @property + @abc.override + def operands(self) -> ["Operand"]: + return self._operands + + @abc.override + def update_operand( + self, idx_or_name: Union[int, str], fn + ) -> ir.Machine_Instruction: + if type(idx_or_name) is str: + idx = self.get_operand_idx(idx_or_name) + else: + idx = idx_or_name + + ops = list(self.operands) + ops[idx] = fn(ops[idx]) + new = copy.copy(self) + new._operands = ops + return new + + @abc.override + def encodings(self) -> ["Instruction_Encoding"]: + raise NotImplementedError() + + @property + @abc.override + def can_benchmark(self) -> bool: + return self._can_benchmark + + +class ARMv8_Instruction_Set(ir.Instruction_Set): + @abc.override + def instruction_groups(self) -> ["Instruction_Group"]: + return [] + + @abc.override + def instructions(self) -> [ir.Machine_Instruction]: + return list(INSTRUCTIONS.values()) + + @abc.override + def benchmark_instructions(self) -> [ir.Machine_Instruction]: + return [I for I in INSTRUCTIONS.values() if I._can_benchmark] + + @property + @abc.override + def all_tags(self): + tags = collections.OrderedDict() + + for inst in INSTRUCTIONS.values(): + for tag in inst.tags: + tags[tag] = True + + yield from tags + + @abc.override + def instruction_for_name(self, name: str) -> ir.Machine_Instruction: + return INSTRUCTIONS[name] + + @abc.override + def instructions_for_tags(self, *tags) -> Sequence[ir.Machine_Instruction]: + tags = set(tags) + + for inst in INSTRUCTIONS.values(): + if tags <= inst.tags: + yield inst + + def __getitem__(self, name): + return INSTRUCTIONS[name] + + @abc.override + def __iter__(self): + return iter(INSTRUCTIONS.values()) diff --git a/src/pipedream/asm/x86/instructions.py b/src/pipedream/asm/x86/instructions.py index 4f5ddeb..fca9440 100644 --- a/src/pipedream/asm/x86/instructions.py +++ b/src/pipedream/asm/x86/instructions.py @@ -1,4 +1,3 @@ - import collections import copy import types @@ -8,705 +7,859 @@ import pipedream.utils.abc as abc from pipedream.asm import ir -from pipedream.asm.x86.operands import * +from pipedream.asm.x86.operands import * from pipedream.asm.x86.registers import * -from pipedream.asm.x86.flags import * +from pipedream.asm.x86.flags import * from pipedream.asm.x86 import instructions_xed __all__ = [ - 'X86_Instruction', - 'X86_Instruction_Set', - - 'ATT_MNEMONICS', - 'INSTRUCTIONS', - - 'Harness', + "X86_Instruction", + "X86_Instruction_Set", + "ATT_MNEMONICS", + "INSTRUCTIONS", + "Harness", ] Instruction_Name = str -Operand_Name = str +Operand_Name = str -ALL_TAGS = frozenset() +ALL_TAGS = frozenset() INSTRUCTIONS: ty.Dict[Instruction_Name, ir.Machine_Instruction] = {} ATT_MNEMONICS: ty.Dict[Instruction_Name, str] = {} class X86_Instruction(ir.Machine_Instruction): - def __init__(self, name, att_mnemonic, intel_mnemonic, isa_set, operands, tags, can_benchmark): - self._name = name - self._att_mnemonic = att_mnemonic - self._intel_mnemonic = intel_mnemonic - self._isa_set = isa_set - self._operands = operands - self._tags = tags - self._can_benchmark = can_benchmark - - assert len(operands) == len(set(o.name for o in operands)), 'duplicate operand name in ' + str(self) - - @property - @abc.override - def name(self) -> str: - return self._name - - @property - @abc.override - def att_mnemonic(self) -> str: - return self._att_mnemonic - - @property - @abc.override - def intel_mnemonic(self) -> str: - return self._intel_mnemonic - - @property - @abc.override - def isa_set(self) -> str: - return self._isa_set - - @property - @abc.override - def tags(self) -> ty.Sequence[str]: - return self._tags - - @property - @abc.override - def operands(self) -> ['Operand']: - return self._operands - - @abc.override - def update_operand(self, idx_or_name: ty.Union[int, str], fn) -> ir.Machine_Instruction: - if type(idx_or_name) is str: - idx = self.get_operand_idx(idx_or_name) - else: - idx = idx_or_name - - ops = list(self.operands) - ops[idx] = fn(ops[idx]) - new = copy.copy(self) - new._operands = ops - return new - - @abc.override - def encodings(self) -> ['Instruction_Encoding']: - raise NotImplementedError() - - @property - @abc.override - def can_benchmark(self) -> bool: - return self._can_benchmark + def __init__( + self, name, att_mnemonic, intel_mnemonic, isa_set, operands, tags, can_benchmark + ): + self._name = name + self._att_mnemonic = att_mnemonic + self._intel_mnemonic = intel_mnemonic + self._isa_set = isa_set + self._operands = operands + self._tags = tags + self._can_benchmark = can_benchmark + + assert len(operands) == len( + set(o.name for o in operands) + ), "duplicate operand name in " + str(self) + + @property + @abc.override + def name(self) -> str: + return self._name + + @property + def att_mnemonic(self) -> str: + return self._att_mnemonic + + @property + def intel_mnemonic(self) -> str: + return self._intel_mnemonic + + @property + @abc.override + def isa_set(self) -> str: + return self._isa_set + + @property + @abc.override + def tags(self) -> ty.Sequence[str]: + return self._tags + + @property + @abc.override + def operands(self) -> ["Operand"]: + return self._operands + + @abc.override + def update_operand( + self, idx_or_name: ty.Union[int, str], fn + ) -> ir.Machine_Instruction: + if type(idx_or_name) is str: + idx = self.get_operand_idx(idx_or_name) + else: + idx = idx_or_name + + ops = list(self.operands) + ops[idx] = fn(ops[idx]) + new = copy.copy(self) + new._operands = ops + return new + + @abc.override + def encodings(self) -> ["Instruction_Encoding"]: + raise NotImplementedError() + + @property + @abc.override + def can_benchmark(self) -> bool: + return self._can_benchmark class X86_Instruction_Set(ir.Instruction_Set): - @abc.override - def instruction_groups(self) -> ['Instruction_Group']: - return [] + @abc.override + def instruction_groups(self) -> ["Instruction_Group"]: + return [] - @abc.override - def instructions(self) -> [ir.Machine_Instruction]: - return list(INSTRUCTIONS.values()) + @abc.override + def instructions(self) -> [ir.Machine_Instruction]: + return list(INSTRUCTIONS.values()) - @abc.override - def benchmark_instructions(self) -> [ir.Machine_Instruction]: - return [I for I in INSTRUCTIONS.values() if I._can_benchmark] + @abc.override + def benchmark_instructions(self) -> [ir.Machine_Instruction]: + return [I for I in INSTRUCTIONS.values() if I._can_benchmark] - @property - @abc.override - def all_tags(self): - tags = collections.OrderedDict() + @property + @abc.override + def all_tags(self): + tags = collections.OrderedDict() - for inst in INSTRUCTIONS.values(): - for tag in inst.tags: - tags[tag] = True + for inst in INSTRUCTIONS.values(): + for tag in inst.tags: + tags[tag] = True - yield from tags + yield from tags - @abc.override - def instruction_for_name(self, name: str) -> ir.Machine_Instruction: - return INSTRUCTIONS[name] + @abc.override + def instruction_for_name(self, name: str) -> ir.Machine_Instruction: + return INSTRUCTIONS[name] - @abc.override - def instructions_for_tags(self, *tags) -> ty.Sequence[ir.Machine_Instruction]: - tags = set(tags) + @abc.override + def instructions_for_tags(self, *tags) -> ty.Sequence[ir.Machine_Instruction]: + tags = set(tags) - for inst in INSTRUCTIONS.values(): - if tags <= inst.tags: - yield inst + for inst in INSTRUCTIONS.values(): + if tags <= inst.tags: + yield inst - def __getitem__(self, name): - return INSTRUCTIONS[name] + def __getitem__(self, name): + return INSTRUCTIONS[name] - @abc.override - def __iter__(self): - return iter(INSTRUCTIONS.values()) + @abc.override + def __iter__(self): + return iter(INSTRUCTIONS.values()) ### FIXME: many thousands of missing instructions -i64 = 'i64' -i32 = 'i32' -i8 = 'i8' +i64 = "i64" +i32 = "i32" +i8 = "i8" USE, DEF, USE_DEF = ir.Use_Def R, W, RW = USE, DEF, USE_DEF -EXPLICIT = ir.Operand_Visibility.EXPLICIT -IMPLICIT = ir.Operand_Visibility.IMPLICIT +EXPLICIT = ir.Operand_Visibility.EXPLICIT +IMPLICIT = ir.Operand_Visibility.IMPLICIT SUPPRESSED = ir.Operand_Visibility.SUPPRESSED -TAGS_CANNOT_BENCHMARK = frozenset([ - ## TODO: add support (really big stack space for benchmark??) - 'stack', - ## :/ branches - 'branch', - ## write kernel benchmark harness :P - 'ring0', - ## add support for old school i386 :P - 'segmentation', - ## get an AMD machine? (check CPU flags?) - 'amdonly', - ## get a machine that supports it? - 'waitpkg', - ## add protected-mode benchmark harness - 'protected-mode', -]) -INTEL_MNEMONIC_CANNOT_BENCHMARK = frozenset([ - ## always raises SIGILL (that's the whole point) - 'ud0', 'ud1', 'ud2', - ## interrupt handling - 'int', 'int1', 'int3', 'cli', - ## priviledge: control register access - 'xgetbv', - ## fix harness: stack access - 'leave', - ## fix harness: weird fixed memory operand - 'insb', 'insd', 'insw', - 'outsb', 'outsd', 'outsw', - 'lods', 'lodsb', 'lodsd', 'lodsw', 'lodsq', - 'stos', 'stosb', 'stosd', 'stosw', 'stosq', - 'xlat', 'xlatb' - ## fix harness: string copy/move instructions - 'movs', 'movsb', 'movsd', 'movsw', 'movsq', - 'scas', 'scasb', 'scasd', 'scasw', 'scasq', - ## direct I/O port access - 'in', 'out', - ## MXCSR access (updated harness) - 'ldmxcsr', 'stmxcsr', 'vldmxcsr', 'vstmxcsr', - ## read/write protection keys (privileged) - 'rdpkru', 'wrpkru', - ## read PMU registers (privileged) - 'rdpmc', - ## VM instructions (privileged) - 'vmfunc', - ## prefetch nop (get newer CPU for testing) - 'prefetch_exclusive', - ## fix harness: direct write to flags register: causes segfault - 'std', 'sti', -]) -INST_NAME_CANNOT_BENCHMARK = frozenset([ - # legacy string instructions with weird address operand - # we can benchmark the vector versions same mnemonic and normal memory operand - 'CMPSB', 'CMPSD', 'CMPSW', 'CMPSQ', - # fix harness: divide by zero? - 'DIV_GPR8NOREXi8', - 'DIV_MEM64i16', - 'DIV_MEM64i32', - 'DIV_MEM64i64', - 'DIV_MEM64u8', - 'IDIV_GPR8NOREXi8', - 'IDIV_MEM64i16', - 'IDIV_MEM64i32', - 'IDIV_MEM64i64', - 'IDIV_MEM64u8', - # fix harness: FP error? - 'FLDCW_MEM64OTHER', - 'FLDCW_MEM32OTHER', -]) - -def mk_inst(*, name: str, att_mnemonic: str, intel_mnemonic: str, - operands: ty.List, tags: ty.Set[str], - isa_set: str, isa_extension: str = None, - can_benchmark = None): - # print('mk_inst', name, att_mnemonic, intel_mnemonic, - # list(op() for op in operands), - # tags, isa_set, isa_extension, - # can_benchmark) - - global INSTRUCTIONS, ALL_TAGS - - tags = frozenset(tags) | frozenset([name, att_mnemonic, intel_mnemonic]) - - operands = tuple(mk_op() for mk_op in operands) - - ## filter out instructions we currently do not support - if can_benchmark is None: - can_benchmark = pipedream_asm_backend_can_handle(name, isa_set, intel_mnemonic, att_mnemonic, tags, operands) - - inst = X86_Instruction(name, att_mnemonic, intel_mnemonic, isa_set, operands, tags, can_benchmark) - ATT_MNEMONICS[name] = att_mnemonic - - if name in INSTRUCTIONS: - raise ValueError('\n'.join([ - f'Duplicate instruction:', - f' ({INSTRUCTIONS[name]}', - f' vs', - f' {inst})', - ])) - - INSTRUCTIONS[name] = inst - ALL_TAGS = ALL_TAGS | tags - - return inst - - -def pipedream_asm_backend_can_handle(name: str, isa_set, intel_mnemonic: str, att_mnemonic: str, tags: ty.Set[str], operands: ty.List[ir.Operand]): - if tags & TAGS_CANNOT_BENCHMARK: - return False - - if intel_mnemonic in INTEL_MNEMONIC_CANNOT_BENCHMARK: - return False - - ##### need to update benchmark harness (currently crash) - - if name in INST_NAME_CANNOT_BENCHMARK: - return False - - assert name != 'DIV_MEM64i16', [name, INST_NAME_CANNOT_BENCHMARK] - - ##### modern extensions not available on my machine - - ISA = instructions_xed.ISA - - ## VIA CPU instructions - if isa_set in (ISA.VIA_PADLOCK_RNG, ): - return False - - ## Total Memory Encryption (TME) - if isa_set in (ISA.PCONFIG, ): - return False - - ## AVX Galois Field instructions - if isa_set in (ISA.GFNI, ISA.AVX_GFNI, ): - return False - - ## Software Guard Extensions (SGX) - if isa_set in (ISA.SGX, ISA.SGX_ENCLV): - return False - - ## Cache Write Back - if isa_set is ISA.CLWB: - return False - - ## Supervisor Mode Access Prevention - if isa_set is ISA.SMAP: - return False - - ## Control-Flow Enforcement Technology (CET) - if isa_set is ISA.CET: - return False - - ## Intel Processor Trace (PT) - if isa_set is ISA.PT: - return False +TAGS_CANNOT_BENCHMARK = frozenset( + [ + ## TODO: add support (really big stack space for benchmark??) + "stack", + ## :/ branches + "branch", + ## write kernel benchmark harness :P + "ring0", + ## add support for old school i386 :P + "segmentation", + ## get an AMD machine? (check CPU flags?) + "amdonly", + ## get a machine that supports it? + "waitpkg", + ## add protected-mode benchmark harness + "protected-mode", + ] +) +INTEL_MNEMONIC_CANNOT_BENCHMARK = frozenset( + [ + ## always raises SIGILL (that's the whole point) + "ud0", + "ud1", + "ud2", + ## interrupt handling + "int", + "int1", + "int3", + "cli", + ## priviledge: control register access + "xgetbv", + ## fix harness: stack access + "leave", + ## fix harness: weird fixed memory operand + "insb", + "insd", + "insw", + "outsb", + "outsd", + "outsw", + "lods", + "lodsb", + "lodsd", + "lodsw", + "lodsq", + "stos", + "stosb", + "stosd", + "stosw", + "stosq", + "xlat", + "xlatb" + ## fix harness: string copy/move instructions + "movs", + "movsb", + "movsd", + "movsw", + "movsq", + "scas", + "scasb", + "scasd", + "scasw", + "scasq", + ## direct I/O port access + "in", + "out", + ## MXCSR access (updated harness) + "ldmxcsr", + "stmxcsr", + "vldmxcsr", + "vstmxcsr", + ## read/write protection keys (privileged) + "rdpkru", + "wrpkru", + ## read PMU registers (privileged) + "rdpmc", + ## VM instructions (privileged) + "vmfunc", + ## prefetch nop (get newer CPU for testing) + "prefetch_exclusive", + ## fix harness: direct write to flags register: causes segfault + "std", + "sti", + ] +) +INST_NAME_CANNOT_BENCHMARK = frozenset( + [ + # legacy string instructions with weird address operand + # we can benchmark the vector versions same mnemonic and normal memory operand + "CMPSB", + "CMPSD", + "CMPSW", + "CMPSQ", + # fix harness: divide by zero? + "DIV_GPR8NOREXi8", + "DIV_MEM64i16", + "DIV_MEM64i32", + "DIV_MEM64i64", + "DIV_MEM64u8", + "IDIV_GPR8NOREXi8", + "IDIV_MEM64i16", + "IDIV_MEM64i32", + "IDIV_MEM64i64", + "IDIV_MEM64u8", + # fix harness: FP error? + "FLDCW_MEM64OTHER", + "FLDCW_MEM32OTHER", + ] +) - ## Read Processor ID - if isa_set is ISA.RDPID: - return False - ## Intel SHA extensions - if isa_set is ISA.SHA: - return False +def mk_inst( + *, + name: str, + att_mnemonic: str, + intel_mnemonic: str, + operands: ty.List, + tags: ty.Set[str], + isa_set: str, + isa_extension: str = None, + can_benchmark=None, +): + # print('mk_inst', name, att_mnemonic, intel_mnemonic, + # list(op() for op in operands), + # tags, isa_set, isa_extension, + # can_benchmark) + + global INSTRUCTIONS, ALL_TAGS + + tags = frozenset(tags) | frozenset([name, att_mnemonic, intel_mnemonic]) + + operands = tuple(mk_op() for mk_op in operands) + + ## filter out instructions we currently do not support + if can_benchmark is None: + can_benchmark = pipedream_asm_backend_can_handle( + name, isa_set, intel_mnemonic, att_mnemonic, tags, operands + ) + + inst = X86_Instruction( + name, att_mnemonic, intel_mnemonic, isa_set, operands, tags, can_benchmark + ) + ATT_MNEMONICS[name] = att_mnemonic + + if name in INSTRUCTIONS: + raise ValueError( + "\n".join( + [ + f"Duplicate instruction:", + f" ({INSTRUCTIONS[name]}", + f" vs", + f" {inst})", + ] + ) + ) + + INSTRUCTIONS[name] = inst + ALL_TAGS = ALL_TAGS | tags + + return inst + + +def pipedream_asm_backend_can_handle( + name: str, + isa_set, + intel_mnemonic: str, + att_mnemonic: str, + tags: ty.Set[str], + operands: ty.List[ir.Operand], +): + if tags & TAGS_CANNOT_BENCHMARK: + return False - ## AVX & AVX-512 AES instructions - if isa_set in (ISA.VAES, ISA.AVXAES, ISA.AVX512_VAES_128, ISA.AVX512_VAES_256, ISA.AVX512_VAES_512): - return False + if intel_mnemonic in INTEL_MNEMONIC_CANNOT_BENCHMARK: + return False - ## AVX 512 in general - if isa_set.name.startswith('AVX512'): - return False + ##### need to update benchmark harness (currently crash) - ## Intel Virtualization Extensions (VT-x) - if isa_set is ISA.VTX: - return False + if name in INST_NAME_CANNOT_BENCHMARK: + return False - ## Direct store (MOVDIRI extension) - if isa_set is ISA.MOVDIR: - return False + assert name != "DIV_MEM64i16", [name, INST_NAME_CANNOT_BENCHMARK] - ## TSX Load-tracking (Spectre/Meltdown mitigation instructions) - if isa_set in (ISA.TSX_LDTRK, ISA.SERIALIZE): - return False + ##### modern extensions not available on my machine - ## ??? - if isa_set is ISA.VPCLMULQDQ: - return False + ISA = instructions_xed.ISA - ##### some funky instruction operand-size problems + ## VIA CPU instructions + if isa_set in (ISA.VIA_PADLOCK_RNG,): + return False - def match(name, noperands): - if att_mnemonic != name: - return False - if len(operands) != noperands: - return False - return True + ## Total Memory Encryption (TME) + if isa_set in (ISA.PCONFIG,): + return False - def is_reg(idx: int, reg_class): - op = operands[idx] - if not isinstance(operands[idx], ir.Register_Operand): - return False - if op.register_class is not reg_class: - return False - return True + ## AVX Galois Field instructions + if isa_set in ( + ISA.GFNI, + ISA.AVX_GFNI, + ): + return False - def is_imm(idx: int): - op = operands[idx] - if not isinstance(operands[idx], ir.Immediate_Operand): - return False - return True + ## Software Guard Extensions (SGX) + if isa_set in (ISA.SGX, ISA.SGX_ENCLV): + return False + + ## Cache Write Back + if isa_set is ISA.CLWB: + return False + + ## Supervisor Mode Access Prevention + if isa_set is ISA.SMAP: + return False + + ## Control-Flow Enforcement Technology (CET) + if isa_set is ISA.CET: + return False + + ## Intel Processor Trace (PT) + if isa_set is ISA.PT: + return False + + ## Read Processor ID + if isa_set is ISA.RDPID: + return False + + ## Intel SHA extensions + if isa_set is ISA.SHA: + return False + + ## AVX & AVX-512 AES instructions + if isa_set in ( + ISA.VAES, + ISA.AVXAES, + ISA.AVX512_VAES_128, + ISA.AVX512_VAES_256, + ISA.AVX512_VAES_512, + ): + return False + + ## AVX 512 in general + if isa_set.name.startswith("AVX512"): + return False + + ## Intel Virtualization Extensions (VT-x) + if isa_set is ISA.VTX: + return False + + ## Direct store (MOVDIRI extension) + if isa_set is ISA.MOVDIR: + return False + + ## TSX Load-tracking (Spectre/Meltdown mitigation instructions) + if isa_set in (ISA.TSX_LDTRK, ISA.SERIALIZE): + return False + + ## ??? + if isa_set is ISA.VPCLMULQDQ: + return False + + ##### some funky instruction operand-size problems + + def match(name, noperands): + if att_mnemonic != name: + return False + if len(operands) != noperands: + return False + return True + + def is_reg(idx: int, reg_class): + op = operands[idx] + if not isinstance(operands[idx], ir.Register_Operand): + return False + if op.register_class is not reg_class: + return False + return True + + def is_imm(idx: int): + op = operands[idx] + if not isinstance(operands[idx], ir.Immediate_Operand): + return False + return True + + def is_mem(idx: int): + op = operands[idx] + if not isinstance(operands[idx], ir.Memory_Operand): + return False + return True + + for op in operands: + if isinstance(op, ir.Register_Operand): + if op.register_class is VRX512: + return False + + if isinstance(op, ir.Memory_Operand): + if op.address_width != 64: + return False + + ## produces: "BSWAP_GPR16i16" (with a 66 prefix) + ## which is a valid instruction (and objdump can disassemble it), but gas says it is invalid. + ## + ## >> echo "bswap %ax # 66 0f c8" | as + ## Error: operand size mismatch for `bswap' + if match("bswap", 1) and is_reg(0, GPR16): + return False + + ## produces: "MOVSXD_GPR32i32_GPR32i32" + ## which is not a valid instruction according to GAS. + if match("movsll", 2) and is_reg(0, GPR32) and is_reg(1, GPR32): + return False + if match("movsll", 2) and is_reg(0, GPR32) and is_mem(1): + return False + + ## produces: "MOVSXD_GPR16i16_GPR32i32" + ## which is not a valid instruction according to GAS. + if match("movslw", 2) and is_reg(0, GPR16) and is_reg(1, GPR32): + return False + if match("movslw", 2) and is_reg(0, GPR16) and is_mem(1): + return False + if match("movslw", 2) and is_reg(0, GPR16) and is_reg(1, GPR16): + return False + + ## produces: "MOVSX_GPR16i16_GPR16i16" + ## which is not a valid instruction according to GAS. + if match("movsww", 2) and is_reg(0, GPR16) and is_reg(1, GPR16): + return False + if match("movsww", 2) and is_reg(0, GPR16) and is_mem(1): + return False + + ## produces: "MOVZX_GPR16i16_GPR16i16" (with a 66 prefix) + ## which is a valid instruction (and objdump can disassemble it), but gas says it is invalid. + ## + ## >> echo "movzww %ax, %ax # 66 0f b7 c0" | as + ## Error: invalid instruction suffix for `movzw' + if match("movzww", 2) and is_reg(0, GPR16) and is_reg(1, GPR16): + return False + if match("movzww", 2) and is_reg(0, GPR16) and is_mem(1): + return False + + ## FIXME: there are some XMM variants GAS does not accept + if match("vcvtpd2dq", 2) or match("vcvtpd2ps", 2) or match("vcvttpd2dq", 2): + return False + + if match("maskmovq", 3) or match("maskmovdqu", 3) or match("vmaskmovdqu", 3): + if all(o.visibility is ir.Operand_Visibility.EXPLICIT for o in operands): + return True + else: + ## FIXME: fixed memory base register DS:DI/EDI/RDI (currently can't handle it) + return False + + ## FIXME: register/imm operand is actually second displacement/index operand + ## need to tell this to the register allocator + if (match("btw", 3) or match("btl", 3) or match("btq", 3)) and is_mem(0): + return False + if (match("btcw", 3) or match("btcl", 3) or match("btcq", 3)) and is_mem(0): + return False + if (match("btsw", 3) or match("btsl", 3) or match("btsq", 3)) and is_mem(0): + return False + if (match("btrw", 3) or match("btrl", 3) or match("btrq", 3)) and is_mem(0): + return False - def is_mem(idx: int): - op = operands[idx] - if not isinstance(operands[idx], ir.Memory_Operand): - return False return True - for op in operands: - if isinstance(op, ir.Register_Operand): - if op.register_class is VRX512: - return False - - if isinstance(op, ir.Memory_Operand): - if op.address_width != 64: - return False - - ## produces: "BSWAP_GPR16i16" (with a 66 prefix) - ## which is a valid instruction (and objdump can disassemble it), but gas says it is invalid. - ## - ## >> echo "bswap %ax # 66 0f c8" | as - ## Error: operand size mismatch for `bswap' - if match('bswap', 1) and is_reg(0, GPR16): - return False - - ## produces: "MOVSXD_GPR32i32_GPR32i32" - ## which is not a valid instruction according to GAS. - if match('movsll', 2) and is_reg(0, GPR32) and is_reg(1, GPR32): - return False - if match('movsll', 2) and is_reg(0, GPR32) and is_mem(1): - return False - - ## produces: "MOVSXD_GPR16i16_GPR32i32" - ## which is not a valid instruction according to GAS. - if match('movslw', 2) and is_reg(0, GPR16) and is_reg(1, GPR32): - return False - if match('movslw', 2) and is_reg(0, GPR16) and is_mem(1): - return False - if match('movslw', 2) and is_reg(0, GPR16) and is_reg(1, GPR16): - return False - - ## produces: "MOVSX_GPR16i16_GPR16i16" - ## which is not a valid instruction according to GAS. - if match('movsww', 2) and is_reg(0, GPR16) and is_reg(1, GPR16): - return False - if match('movsww', 2) and is_reg(0, GPR16) and is_mem(1): - return False - - ## produces: "MOVZX_GPR16i16_GPR16i16" (with a 66 prefix) - ## which is a valid instruction (and objdump can disassemble it), but gas says it is invalid. - ## - ## >> echo "movzww %ax, %ax # 66 0f b7 c0" | as - ## Error: invalid instruction suffix for `movzw' - if match('movzww', 2) and is_reg(0, GPR16) and is_reg(1, GPR16): - return False - if match('movzww', 2) and is_reg(0, GPR16) and is_mem(1): - return False - - ## FIXME: there are some XMM variants GAS does not accept - if match('vcvtpd2dq', 2) or match('vcvtpd2ps', 2) or match('vcvttpd2dq', 2): - return False - - if match('maskmovq', 3) or match('maskmovdqu', 3) or match('vmaskmovdqu', 3): - if all(o.visibility is ir.Operand_Visibility.EXPLICIT for o in operands): - return True + +def make_reg_op( + *, + name: str, + reg_class: ir.Register_Class, + reg: X86_Register = None, + action: ir.Use_Def, + type, + elems: int, + visibility: ir.Operand_Visibility, +): + return lambda: X86_Register_Operand(name, visibility, action, reg_class, reg) + + +def make_imm_op( + *, + name: str, + imm_bits: int, + type, + elems: int, + value: int = None, + visibility: ir.Operand_Visibility, +): + if type[0] == "i": + clss = { + 8: Imm8, + 16: Imm16, + 32: Imm32, + 64: Imm64, + }[imm_bits] + elif type[0] == "u": + clss = { + 8: ImmU8, + 16: ImmU16, + 32: ImmU32, + 64: ImmU64, + }[imm_bits] + else: + raise ValueError("invalid type for immediate operand: " + repr(type)) + + return lambda: clss(name, visibility, value) + + +def make_brdisp_op( + *, name: str, disp_bits: int, type, elems: int, visibility: ir.Operand_Visibility +): + ## TODO: add a proper branch displacement op type + + return make_imm_op( + name=name, + imm_bits=disp_bits, + type=type, + elems=elems, + visibility=visibility, + value=0, + ) + + +def make_flags_op( + *, + name: str, + reg: X86_Register, + read: X86_Flags, + write: X86_Flags, + visibility: ir.Operand_Visibility, +): + if not read: + read = X86_Flags(0) + if not write: + write = X86_Flags(0) + + return lambda: X86_Flags_Operand(name, visibility, reg, read, write) + + +def make_addr_op( + name: str, + addr_bits: int, + type, + elems: int, + visibility: ir.Operand_Visibility, + base: ty.Optional[X86_Register] = None, +): + # TODO: more addressing modes + + if addr_bits == 64: + base_reg_class = BASE_REGISTER_64 + index_reg_class = INDEX_REGISTER_64 + disp_imm_class = Imm32 + scale_imm_class = Scale_Imm + elif addr_bits == 32: + base_reg_class = BASE_REGISTER_32 + index_reg_class = INDEX_REGISTER_32 + disp_imm_class = Imm32 + scale_imm_class = Scale_Imm + elif addr_bits == 16: + base_reg_class = BASE_REGISTER_16 + index_reg_class = INDEX_REGISTER_16 + disp_imm_class = Imm32 + scale_imm_class = lambda name: Scale_Imm(1) + else: + raise ValueError("invalid width for LEA src operand", src_width) + + EXPLICIT = ir.Operand_Visibility.EXPLICIT + USE = ir.Use_Def.USE + + assert base is None or (isinstance(base, X86_Register) and base in base_reg_class) + + base = X86_Register_Operand("base", EXPLICIT, USE, base_reg_class, base) + displacement = disp_imm_class("displacement", EXPLICIT) + + # base = X86_Register_Operand('base', EXPLICIT, USE, base_reg_class, None) if 'B' in mode else None + # index = X86_Register_Operand('index', EXPLICIT, USE, index_reg_class, None) if 'I' in mode else None + # scale = scale_imm_class('scale', EXPLICIT,) if 'S' in mode else None + # displacement = disp_imm_class('displacement', EXPLICIT,) if 'D' in mode else None + + return lambda: X86_Base_Displacement_Address_Operand( + name, visibility, addr_bits, base, displacement + ) + + +def make_mem_op( + name: str, + addr_bits: str, + mem_bits: int, + action: ir.Use_Def, + type, + elems: int, + visibility: ir.Operand_Visibility, + base: ty.Optional[X86_Register] = None, +): + # TODO: more addressing modes + + if addr_bits == 64: + base_reg_class = BASE_REGISTER_64 + index_reg_class = INDEX_REGISTER_64 + disp_imm_class = Imm32 + scale_imm_class = Scale_Imm + elif addr_bits == 32: + base_reg_class = BASE_REGISTER_32 + index_reg_class = INDEX_REGISTER_32 + disp_imm_class = Imm32 + scale_imm_class = Scale_Imm + elif addr_bits == 16: + base_reg_class = BASE_REGISTER_16 + index_reg_class = INDEX_REGISTER_16 + disp_imm_class = Imm32 + scale_imm_class = lambda name: Scale_Imm(1) else: - ## FIXME: fixed memory base register DS:DI/EDI/RDI (currently can't handle it) - return False - - ## FIXME: register/imm operand is actually second displacement/index operand - ## need to tell this to the register allocator - if (match('btw', 3) or match('btl', 3) or match('btq', 3)) and is_mem(0): - return False - if (match('btcw', 3) or match('btcl', 3) or match('btcq', 3)) and is_mem(0): - return False - if (match('btsw', 3) or match('btsl', 3) or match('btsq', 3)) and is_mem(0): - return False - if (match('btrw', 3) or match('btrl', 3) or match('btrq', 3)) and is_mem(0): - return False - - return True - - -def make_reg_op(*, name: str, reg_class: ir.Register_Class, reg: X86_Register = None, - action: ir.Use_Def, type, elems: int, visibility: ir.Operand_Visibility): - return lambda: X86_Register_Operand(name, visibility, action, reg_class, reg) - - -def make_imm_op(*, name: str, imm_bits: int, - type, elems: int, value: int = None, visibility: ir.Operand_Visibility): - if type[0] == 'i': - clss = { - 8: Imm8, - 16: Imm16, - 32: Imm32, - 64: Imm64, - }[imm_bits] - elif type[0] == 'u': - clss = { - 8: ImmU8, - 16: ImmU16, - 32: ImmU32, - 64: ImmU64, - }[imm_bits] - else: - raise ValueError('invalid type for immediate operand: ' + repr(type)) - - return lambda: clss(name, visibility, value) - - -def make_brdisp_op(*, name: str, disp_bits: int, - type, elems: int, visibility: ir.Operand_Visibility): - ## TODO: add a proper branch displacement op type - - return make_imm_op(name=name, imm_bits=disp_bits, type=type, elems=elems, - visibility=visibility, value=0) - - -def make_flags_op(*, name: str, reg: X86_Register, read: X86_Flags, write: X86_Flags, - visibility: ir.Operand_Visibility): - if not read: - read = X86_Flags(0) - if not write: - write = X86_Flags(0) - - return lambda: X86_Flags_Operand(name, visibility, reg, read, write) - - -def make_addr_op(name: str, addr_bits: int, type, elems: int, visibility: ir.Operand_Visibility, - base: ty.Optional[X86_Register] = None,): - # TODO: more addressing modes - - if addr_bits == 64: - base_reg_class = BASE_REGISTER_64 - index_reg_class = INDEX_REGISTER_64 - disp_imm_class = Imm32 - scale_imm_class = Scale_Imm - elif addr_bits == 32: - base_reg_class = BASE_REGISTER_32 - index_reg_class = INDEX_REGISTER_32 - disp_imm_class = Imm32 - scale_imm_class = Scale_Imm - elif addr_bits == 16: - base_reg_class = BASE_REGISTER_16 - index_reg_class = INDEX_REGISTER_16 - disp_imm_class = Imm32 - scale_imm_class = lambda name: Scale_Imm(1) - else: - raise ValueError('invalid width for LEA src operand', src_width) - - EXPLICIT = ir.Operand_Visibility.EXPLICIT - USE = ir.Use_Def.USE - - assert base is None or (isinstance(base, X86_Register) and base in base_reg_class) - - base = X86_Register_Operand('base', EXPLICIT, USE, base_reg_class, base) - displacement = disp_imm_class('displacement', EXPLICIT) - - # base = X86_Register_Operand('base', EXPLICIT, USE, base_reg_class, None) if 'B' in mode else None - # index = X86_Register_Operand('index', EXPLICIT, USE, index_reg_class, None) if 'I' in mode else None - # scale = scale_imm_class('scale', EXPLICIT,) if 'S' in mode else None - # displacement = disp_imm_class('displacement', EXPLICIT,) if 'D' in mode else None - - return lambda: X86_Base_Displacement_Address_Operand( - name, - visibility, - addr_bits, - base, - displacement - ) - - -def make_mem_op(name: str, addr_bits: str, mem_bits: int, - action: ir.Use_Def, type, elems: int, visibility: ir.Operand_Visibility, - base: ty.Optional[X86_Register] = None,): - # TODO: more addressing modes - - if addr_bits == 64: - base_reg_class = BASE_REGISTER_64 - index_reg_class = INDEX_REGISTER_64 - disp_imm_class = Imm32 - scale_imm_class = Scale_Imm - elif addr_bits == 32: - base_reg_class = BASE_REGISTER_32 - index_reg_class = INDEX_REGISTER_32 - disp_imm_class = Imm32 - scale_imm_class = Scale_Imm - elif addr_bits == 16: - base_reg_class = BASE_REGISTER_16 - index_reg_class = INDEX_REGISTER_16 - disp_imm_class = Imm32 - scale_imm_class = lambda name: Scale_Imm(1) - else: - raise ValueError('invalid width for LEA src operand', src_width) - - if base is not None: - base = base.as_width(addr_bits) - - EXPLICIT = ir.Operand_Visibility.EXPLICIT - USE = ir.Use_Def.USE - - assert base is None or (isinstance(base, X86_Register) and base in base_reg_class), [base, *base_reg_class] - - displacement = None - - if base.widest is RSP: - displacement = 0 - - if base.widest is RAX: - # FIXME: extract-xed-database puts AX/EAX/RAX as default base register everywhere - base = None - - base = X86_Register_Operand('base', EXPLICIT, USE, base_reg_class, base) - displacement = disp_imm_class('displacement', EXPLICIT, displacement) - - # base = X86_Register_Operand('base', EXPLICIT, USE, base_reg_class, None) if 'B' in mode else None - # index = X86_Register_Operand('index', EXPLICIT, USE, index_reg_class, None) if 'I' in mode else None - # scale = scale_imm_class('scale', EXPLICIT,) if 'S' in mode else None - # displacement = disp_imm_class('displacement', EXPLICIT,) if 'D' in mode else None - - return lambda: X86_Base_Displacement_Memory_Operand( - name, - visibility, - action, - addr_bits, - mem_bits, - base, - displacement - ) + raise ValueError("invalid width for LEA src operand", src_width) + + if base is not None: + base = base.as_width(addr_bits) + + EXPLICIT = ir.Operand_Visibility.EXPLICIT + USE = ir.Use_Def.USE + + assert base is None or ( + isinstance(base, X86_Register) and base in base_reg_class + ), [base, *base_reg_class] + + displacement = None + + if base.widest is RSP: + displacement = 0 + + if base.widest is RAX: + # FIXME: extract-xed-database puts AX/EAX/RAX as default base register everywhere + base = None + + base = X86_Register_Operand("base", EXPLICIT, USE, base_reg_class, base) + displacement = disp_imm_class("displacement", EXPLICIT, displacement) + + # base = X86_Register_Operand('base', EXPLICIT, USE, base_reg_class, None) if 'B' in mode else None + # index = X86_Register_Operand('index', EXPLICIT, USE, index_reg_class, None) if 'I' in mode else None + # scale = scale_imm_class('scale', EXPLICIT,) if 'S' in mode else None + # displacement = disp_imm_class('displacement', EXPLICIT,) if 'D' in mode else None + + return lambda: X86_Base_Displacement_Memory_Operand( + name, visibility, action, addr_bits, mem_bits, base, displacement + ) instructions_xed.make_instruction_database( - make_instruction = mk_inst, - reg_op = make_reg_op, - mem_op = make_mem_op, - addr_op = make_addr_op, - imm_op = make_imm_op, - brdisp_op = make_brdisp_op, - flags_op = make_flags_op, + make_instruction=mk_inst, + reg_op=make_reg_op, + mem_op=make_mem_op, + addr_op=make_addr_op, + imm_op=make_imm_op, + brdisp_op=make_brdisp_op, + flags_op=make_flags_op, ) class Harness: - """ + """ Well known instructions used in benchmark harness - """ - - ADD_GPR64 = INSTRUCTIONS['ADD_GPR64i64_GPR64i64'] - IMUL_IMM_GPR64 = INSTRUCTIONS['IMUL_GPR64i64_GPR64i64_IMMi32'] - MOV_GPR64 = INSTRUCTIONS['MOV_GPR64i64_GPR64i64'] - MOV_IMM32_GPR32 = INSTRUCTIONS['MOV_GPR32i32_IMMi32'] - MOV_IMM64_GPR64 = INSTRUCTIONS['MOV_GPR64i64_IMMi64'] - MOV_IMM64_GPR64 = INSTRUCTIONS['MOV_GPR64i64_IMMi64'] - SUB_IMM8_GPR64 = INSTRUCTIONS['SUB_GPR64i64_IMMi8'] - TEST_GPR64 = INSTRUCTIONS['TEST_GPR64i64_GPR64i64'] - XOR_GPR32 = INSTRUCTIONS['XOR_GPR32i32_GPR32i32'] - - LOAD_MEM64_TO_REG64 = INSTRUCTIONS['MOV_GPR64i64_MEM64i64'] - STORE_REG64_TO_MEM64 = INSTRUCTIONS['MOV_MEM64i64_GPR64i64'] - - POP_GPR64 = INSTRUCTIONS['POP_GPR64i64'] - PUSH_GPR64 = INSTRUCTIONS['PUSH_GPR64i64'] - - XOR_GPR8i8_GPR8i8 = INSTRUCTIONS['XOR_GPR8i8_GPR8i8'] - XOR_GPR16i16_GPR16i16 = INSTRUCTIONS['XOR_GPR16i16_GPR16i16'] - XOR_GPR16i16_GPR16i16 = INSTRUCTIONS['XOR_GPR16i16_GPR16i16'] - XOR_GPR32i32_GPR32i32 = INSTRUCTIONS['XOR_GPR32i32_GPR32i32'] - XOR_GPR64i64_GPR64i64 = INSTRUCTIONS['XOR_GPR64i64_GPR64i64'] - - VMOVDQA_VR256_MEM64 = INSTRUCTIONS['VMOVDQA_VR256i32x8_MEM64i32x8'] - - LEAVE = INSTRUCTIONS['LEAVE_MEM64i64'] - CPUID = INSTRUCTIONS['CPUID'] - VZEROALL = INSTRUCTIONS['VZEROALL'] # TODO: supressed write to all XMM/YMM/ZMM vector registers - JNE_32 = INSTRUCTIONS['JNZ_BRDISP32'] - JEQ_32 = INSTRUCTIONS['JZ_BRDISP32'] - - RET = mk_inst( - name = 'RET', - att_mnemonic = 'ret', - intel_mnemonic = 'ret', - isa_set = instructions_xed.ISA.I86, - operands = [ - ## TODO: add memory operand support - ## TODO: add stackrel memory operand support - # make_mem_op(name='stack', addr_bits=64, mem_bits=64, type=i64, elems=1, action=R, visibility=SUPPRESSED), - make_reg_op(name='sp', reg=RSP, reg_class=RC_RSP, type=i64, elems=1, action=RW, visibility=SUPPRESSED), - make_reg_op(name='ip', reg=RIP, reg_class=RC_RIP, type=i64, elems=1, action=W, visibility=SUPPRESSED), - ], - tags = ['stack', 'scalar'], - can_benchmark = False, - ) - CALL = mk_inst( - name = 'CALL', - att_mnemonic = 'call', - intel_mnemonic = 'call', - isa_set = instructions_xed.ISA.I86, - operands = [ - make_imm_op(name='dst', imm_bits=32, type=i32, elems=1, visibility=EXPLICIT), - make_reg_op(name='ip', reg=RIP, reg_class=RC_RIP, action=RW, type=i64, elems=1, visibility=SUPPRESSED), - make_mem_op(name='stack', addr_bits=64, mem_bits=64, base=RSP, action=W, type=i64, elems=1, visibility=SUPPRESSED) - ], - tags = ['branch', 'conditional-branch', 'relative-branch'], - can_benchmark = False, - ) - # zero byte jump - JMP_E9_0 = mk_inst( - name = 'JMP_0', - # att_mnemonic = '.byte 0xeb, 0', - att_mnemonic = '.byte 0xe9, 0, 0, 0, 0', - intel_mnemonic = 'XXX', - isa_set = instructions_xed.ISA.I86, - operands = [ - # TODO: relbr operand type - # make_imm_op(name='dst', imm_bits=8, type=i8, elems=1, value=ir.Label('0'), visibility=EXPLICIT), - make_reg_op(name='ip', reg=RIP, reg_class=RC_RIP, action=RW, type=i64, elems=1, visibility=SUPPRESSED), - ], - tags = ['branch', 'conditional-branch', 'relative-branch'], - can_benchmark = True, - ) - - ## leave function (essentially `RSP = RBP`) - # LEAVE = mk_inst( - # name = 'LEAVE', - # att_mnemonic = 'leave', - # intel_mnemonic = 'leave', - # isa_set = instructions_xed.ISA.I86, - # operands = [ - # make_reg_op(name='src', reg=RBP, reg_class=RC_RBP, action=R, type=i64, elems=1, visibility=SUPPRESSED), - # make_reg_op(name='dst', reg=RSP, reg_class=RC_RSP, action=W, type=i64, elems=1, visibility=SUPPRESSED), - # ], - # tags = ['branch', 'conditional-branch', 'relative-branch'], - # can_benchmark = False, - # ) - - # special nop instruction used for IACA start/stop markers - IACA_START_STOP_NOP = mk_inst( - name = 'IACA_START_STOP_NOP', - att_mnemonic = 'fs addr32 nop', - intel_mnemonic = 'fs addr32 nop', - isa_set = instructions_xed.ISA.I86, - operands = [], - tags = ['nop'], - can_benchmark = False, - ) + """ + + ADD_GPR64 = INSTRUCTIONS["ADD_GPR64i64_GPR64i64"] + IMUL_IMM_GPR64 = INSTRUCTIONS["IMUL_GPR64i64_GPR64i64_IMMi32"] + MOV_GPR64 = INSTRUCTIONS["MOV_GPR64i64_GPR64i64"] + MOV_IMM32_GPR32 = INSTRUCTIONS["MOV_GPR32i32_IMMi32"] + MOV_IMM64_GPR64 = INSTRUCTIONS["MOV_GPR64i64_IMMi64"] + MOV_IMM64_GPR64 = INSTRUCTIONS["MOV_GPR64i64_IMMi64"] + SUB_IMM8_GPR64 = INSTRUCTIONS["SUB_GPR64i64_IMMi8"] + TEST_GPR64 = INSTRUCTIONS["TEST_GPR64i64_GPR64i64"] + XOR_GPR32 = INSTRUCTIONS["XOR_GPR32i32_GPR32i32"] + + LOAD_MEM64_TO_REG64 = INSTRUCTIONS["MOV_GPR64i64_MEM64i64"] + STORE_REG64_TO_MEM64 = INSTRUCTIONS["MOV_MEM64i64_GPR64i64"] + + POP_GPR64 = INSTRUCTIONS["POP_GPR64i64"] + PUSH_GPR64 = INSTRUCTIONS["PUSH_GPR64i64"] + + XOR_GPR8i8_GPR8i8 = INSTRUCTIONS["XOR_GPR8i8_GPR8i8"] + XOR_GPR16i16_GPR16i16 = INSTRUCTIONS["XOR_GPR16i16_GPR16i16"] + XOR_GPR16i16_GPR16i16 = INSTRUCTIONS["XOR_GPR16i16_GPR16i16"] + XOR_GPR32i32_GPR32i32 = INSTRUCTIONS["XOR_GPR32i32_GPR32i32"] + XOR_GPR64i64_GPR64i64 = INSTRUCTIONS["XOR_GPR64i64_GPR64i64"] + + VMOVDQA_VR256_MEM64 = INSTRUCTIONS["VMOVDQA_VR256i32x8_MEM64i32x8"] + + LEAVE = INSTRUCTIONS["LEAVE_MEM64i64"] + CPUID = INSTRUCTIONS["CPUID"] + VZEROALL = INSTRUCTIONS[ + "VZEROALL" + ] # TODO: supressed write to all XMM/YMM/ZMM vector registers + JNE_32 = INSTRUCTIONS["JNZ_BRDISP32"] + JEQ_32 = INSTRUCTIONS["JZ_BRDISP32"] + + RET = mk_inst( + name="RET", + att_mnemonic="ret", + intel_mnemonic="ret", + isa_set=instructions_xed.ISA.I86, + operands=[ + ## TODO: add memory operand support + ## TODO: add stackrel memory operand support + # make_mem_op(name='stack', addr_bits=64, mem_bits=64, type=i64, elems=1, action=R, visibility=SUPPRESSED), + make_reg_op( + name="sp", + reg=RSP, + reg_class=RC_RSP, + type=i64, + elems=1, + action=RW, + visibility=SUPPRESSED, + ), + make_reg_op( + name="ip", + reg=RIP, + reg_class=RC_RIP, + type=i64, + elems=1, + action=W, + visibility=SUPPRESSED, + ), + ], + tags=["stack", "scalar"], + can_benchmark=False, + ) + CALL = mk_inst( + name="CALL", + att_mnemonic="call", + intel_mnemonic="call", + isa_set=instructions_xed.ISA.I86, + operands=[ + make_imm_op( + name="dst", imm_bits=32, type=i32, elems=1, visibility=EXPLICIT + ), + make_reg_op( + name="ip", + reg=RIP, + reg_class=RC_RIP, + action=RW, + type=i64, + elems=1, + visibility=SUPPRESSED, + ), + make_mem_op( + name="stack", + addr_bits=64, + mem_bits=64, + base=RSP, + action=W, + type=i64, + elems=1, + visibility=SUPPRESSED, + ), + ], + tags=["branch", "conditional-branch", "relative-branch"], + can_benchmark=False, + ) + # zero byte jump + JMP_E9_0 = mk_inst( + name="JMP_0", + # att_mnemonic = '.byte 0xeb, 0', + att_mnemonic=".byte 0xe9, 0, 0, 0, 0", + intel_mnemonic="XXX", + isa_set=instructions_xed.ISA.I86, + operands=[ + # TODO: relbr operand type + # make_imm_op(name='dst', imm_bits=8, type=i8, elems=1, value=ir.Label('0'), visibility=EXPLICIT), + make_reg_op( + name="ip", + reg=RIP, + reg_class=RC_RIP, + action=RW, + type=i64, + elems=1, + visibility=SUPPRESSED, + ), + ], + tags=["branch", "conditional-branch", "relative-branch"], + can_benchmark=True, + ) + + ## leave function (essentially `RSP = RBP`) + # LEAVE = mk_inst( + # name = 'LEAVE', + # att_mnemonic = 'leave', + # intel_mnemonic = 'leave', + # isa_set = instructions_xed.ISA.I86, + # operands = [ + # make_reg_op(name='src', reg=RBP, reg_class=RC_RBP, action=R, type=i64, elems=1, visibility=SUPPRESSED), + # make_reg_op(name='dst', reg=RSP, reg_class=RC_RSP, action=W, type=i64, elems=1, visibility=SUPPRESSED), + # ], + # tags = ['branch', 'conditional-branch', 'relative-branch'], + # can_benchmark = False, + # ) + + # special nop instruction used for IACA start/stop markers + IACA_START_STOP_NOP = mk_inst( + name="IACA_START_STOP_NOP", + att_mnemonic="fs addr32 nop", + intel_mnemonic="fs addr32 nop", + isa_set=instructions_xed.ISA.I86, + operands=[], + tags=["nop"], + can_benchmark=False, + ) + tmp = collections.OrderedDict() for inst in sorted(INSTRUCTIONS.values(), key=lambda i: i.name): - tmp[inst.name] = inst + tmp[inst.name] = inst INSTRUCTIONS = types.MappingProxyType(tmp) diff --git a/tools/extract-binutils-instruction-database/README.md b/tools/extract-binutils-instruction-database/README.md new file mode 100644 index 0000000..24b712d --- /dev/null +++ b/tools/extract-binutils-instruction-database/README.md @@ -0,0 +1,30 @@ +# Extract-binutils-instruction-database + +Simple tool to create a Pipedream representation from the ARMv8a description written in `binutils` +source files. + +## Authors +* Nicolas Derumigny +* Théophile Bastian + +## Dependencies +* `cpp` the C Preprocessor +* `Python3` with all packages listed in `requirements.txt` + +## Installing +It is recommended to install the python dependencies and modules inside a +virtualenv: + +```bash +virtualenv -p python3 venv +source venv/bin/activate +pip install -r requirements.txt +pip install -e . +``` + +## Usage +Execute +```bash +extract-arm-db +``` +This will create the `instructions-binutils.py` file required for ARMv8a compatibility of Pipedream. diff --git a/tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py b/tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py new file mode 100644 index 0000000..f2b7843 --- /dev/null +++ b/tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py @@ -0,0 +1,376 @@ +import subprocess +import wget # type:ignore +import yaml +from enum import Enum +from pathlib import Path +from typing import List, Dict, Any, Tuple + +URL = "https://sourceware.org/git/?p=binutils-gdb.git;a=blob_plain;f=opcodes/aarch64-tbl.h" +# https://sourceware.org/git/?p=binutils-gdb.git;a=blob_plain;f=opcodes/aarch64-opc-2.c +# may be usefull too (signification of `op_types` field) +INFILE = "aarch64-tbl.h" +OUTFILE = "instructions-binutils.py" + + +def brace_split(text, split_on=","): + """ A variant of `str.split` that respects braces """ + + def events(text): + pos = 0 + + while True: + next_obrace = text.find("{", pos) + next_cbrace = text.find("}", pos) + next_split = text.find(split_on, pos) + res_set = set([next_obrace, next_cbrace, next_split]) + res_set.discard(-1) + if not res_set: # no next event + break + next_event = min(res_set) + + if next_event == next_obrace: + yield (next_event, "{") + elif next_event == next_cbrace: + yield (next_event, "}") + else: + yield (next_event, "") + pos = next_event + 1 + + last_split_pos = 0 + open_brace = 0 + + for (pos, ev_type) in events(text): + if ev_type == "{": + open_brace += 1 + elif ev_type == "}": + open_brace -= 1 + if open_brace < 0: + raise Exception("Unmatched brace (col {}) in <{}>".format(pos, text)) + elif open_brace == 0: + yield text[last_split_pos:pos] + last_split_pos = pos + len(split_on) + last_chunk = text[last_split_pos:] + if last_chunk: + yield last_chunk + + +def convert_register_name(binop_name: str) -> str: + # `S_x` correspond (apparently) to an SIMD vector element of type `x` + if binop_name.startswith("S_"): + return f"V{binop_name[len('S_'):]}" + return binop_name + + +def parse_operands_str(operands_str, is_type: bool) -> List[List[str]]: + """ Parse an instruction's operands """ + + operands_str = operands_str.strip() + assert operands_str[0] == "{" and operands_str[-1] == "}" + op_body = operands_str[1:-1].strip() + if len(op_body) > 0 and op_body[-1] == ",": + op_body = op_body[:-1] + if not op_body: + return [] + ops = op_body.split(", ") + new_ops: List[List[str]] = [] + for op in ops: + real_op = op + if op.startswith("{"): + new_ops.append([]) + real_op = real_op[1:] + + if op.endswith("}"): + real_op = real_op[:-1] + if is_type: + real_op = real_op[len("AARCH64_OPND_") :] + else: + real_op = real_op[len("AARCH64_OPND_QLF_") :] + real_op = "registers." + convert_register_name(real_op) + if is_type: + new_ops.append(real_op) + else: + new_ops[-1].append(real_op) + return new_ops + + +class OP_TYPE(Enum): + REG = 0 + IMM = 1 + ADDR = 2 + FLAG = 3 + NO_OP = 10 + OTHER = 11 + + +def parse_operands_role( + list_op_roles: List[str], +) -> List[Tuple[OP_TYPE, List[str]]]: + # returns (create_func, name, op_role, action, visibility) + list_carac: List[Tuple[OP_TYPE, List[str]]] = [] + for role in list_op_roles: + if role in { + "Rd", + "Ed", + "Fd", + "Vd", + "Sd", + "VdD1", + "SVE_Pd", + "SVE_Zd", + "Rd_SP", + }: + list_carac.append((OP_TYPE.REG, [role, "W", "EXPLICIT"])) + elif "IMM" in role or role in { + "IDX", + "WIDTH", + "NZCV", + "COND", + "COND1", + "FBITS", + "HALF", + "BIT_NUM", + "MASK", + "LIMM", + # FIXME: this would actually need separate instruction entries + # depending on the prefetch operation + "PRFOP", + }: + list_carac.append((OP_TYPE.IMM, [role, "EXPLICIT"])) + elif ( + (role[0] in {"R", "E", "V", "F", "S", "L"} and len(role) < 5) + or role.startswith("SYSREG") + or role.startswith("SVE") + or role + in { + "LVt_AL", + "Rn_SP", + "Rm_SP", + "Rm_SFT", + "Rm_EXT", + "Rt_SP", + "Rt_LS64", + "Rt_SYS", + # Do not care, unsupported + "CRn", + "CRm", + } + ): + list_carac.append((OP_TYPE.REG, [role, "R", "EXPLICIT"])) + elif role in {"PAIRREG"}: + # FIXME: this will not work because RegAlloc will never take this register + # into account and R/W are broken + list_carac.append((OP_TYPE.REG, [role, "R", "IMPLICIT"])) + elif "ADDR" in role: + list_carac.append((OP_TYPE.ADDR, [role, "EXPLICIT"])) + elif role in {}: + list_carac.append((OP_TYPE.FLAG, [role, "IMPLICIT"])) + else: + list_carac.append((OP_TYPE.OTHER, [role])) + return list_carac + + +def parse_operands(list_op_classes, op_roles) -> List[List[str]]: + op_carac = parse_operands_role(op_roles) + if not all( + [len(list_op_classes[i]) == len(op_carac) for i in range(len(list_op_classes))] + ): + assert all( + [ + len(list_op_classes[i]) == len(list_op_classes[0]) + for i in range(len(list_op_classes)) + ] + ) + for idx, list_ops in enumerate(list_op_classes): + if (len(op_carac) - len(list_op_classes[idx])) > 0: + list_ops += ["Unknown (not in reg)"] * ( + len(op_carac) - len(list_op_classes[idx]) + ) + else: + op_carac += [(OP_TYPE.NO_OP, ["SUPPRESSED"])] * ( + len(list_op_classes[idx]) - len(op_carac) + ) + break + + assert all( + [len(list_op_classes[i]) == len(op_carac) for i in range(len(list_op_classes))] + ) + + ret: List[List[str]] = [] + for op_classes in list_op_classes: + new_operands: List[str] = [] + for op_class, (op_type, list_carac) in zip(op_classes, op_carac): + if op_type == OP_TYPE.REG: + if op_class[-len("NIL") :] == "NIL" and list_carac[0].startswith( + "SYSREG" + ): + op_class = "registers.SYSREG" + new_operands.append( + f'reg_op(name="{list_carac[0]}", reg_class={op_class}, ' + f"action={list_carac[1]}, visibility={list_carac[2]})" + ) + elif op_type == OP_TYPE.ADDR: + addr_class = op_class[len("registers.") :] + if addr_class == "NIL": + addr_class = list_carac[0] + addr_class = f"operands.{addr_class}" + new_operands.append( + f'addr_op(name="{list_carac[0]}", addr_class={addr_class}, ' + f"visibility={list_carac[1]})" + ) + elif op_type == OP_TYPE.IMM: + imm_class = op_class[len("registers.") :] + if imm_class == "NIL": + imm_class = f"operands.{list_carac[0]}" + new_operands.append( + f'imm_op(name="{list_carac[0]}", imm_class={imm_class}, ' + f"visibility={list_carac[1]})" + ) + elif ( + imm_class.startswith("imm_") and imm_class[len("imm_")].isnumeric() + ): + # Under the form imm_min_max + min_val = imm_class[len("imm_")] + max_val = imm_class[len("imm_X_") :] + new_operands.append( + f'imm_op(name="{list_carac[0]}",' + f"visibility={list_carac[1]}, min_val={min_val}, " + f"max_val={max_val})" + ) + + elif op_type == OP_TYPE.FLAG: + new_operands.append( + f'flag_op(name="{list_carac[0]}", visibility={list_carac[1]})' + ) + elif op_type == OP_TYPE.OTHER: + raise Exception(f"Unsupported operand type: {list_carac[0]}") + elif op_type == OP_TYPE.NO_OP: + pass + else: + raise Exception("Wrong OP_TYPE") + ret.append(new_operands) + return ret + + +def clean_instruction_dict(instruction) -> Dict[str, Any]: + instruction["extension"] = instruction["extension"][len("&aarch64_feature_") :] + if instruction["insn_class"][0].isnumeric(): + instruction["insn_class"] = "_" + instruction["insn_class"] + + operands = instruction["operands"] + op_roles = instruction.pop("op_types") + + instruction["operands"] = parse_operands(operands, op_roles) + + return instruction + + +def parse_line(line): + line = line.strip() + if not line: + return None + if not line.endswith("},"): + raise Exception("Invalid end of line: <{}>".format(line)) + if not line.startswith("{"): + raise Exception("Invalid start of line: <{}>".format(line)) + + line_split = list(map(lambda x: x.strip(), brace_split(line[1:-2]))) + assert len(line_split) == 12 + mnemonic = line_split[0][1:-1] # strip quotes + return line_split, clean_instruction_dict( + { + "mnemonic": mnemonic, + "insn_class": line_split[3], + "extension": line_split[5], + "op_types": parse_operands_str(line_split[6], True), + "operands": parse_operands_str(line_split[7], False), + } + ) + + +def generate_header(file, instruction_list: List[Dict]) -> None: + file.write( + "## This file is derived from the original ARM database of GNU binutils,\n" + "## available at https://www.gnu.org/software/binutils/\n\n" + ) + file.write( + "from pipedream.asm.armv8a import registers, flags, operands\n" + "from pipedream.asm.ir import Use_Def, Operand_Visibility\n\n" + "import enum\n" + "from typing import *\n\n" + ) + file.write( + "## read/write\n" + "R = Use_Def.USE\n" + "W = Use_Def.DEF\n" + "RW = Use_Def.USE_DEF\n\n" + "## operand visibility\n" + "EXPLICIT = Operand_Visibility.EXPLICIT\n" + "IMPLICIT = Operand_Visibility.IMPLICIT\n" + "SUPPRESSED = Operand_Visibility.SUPPRESSED\n\n" + ) + file.write("class ISA(enum.Enum):\n") + isa = set([inst["insn_class"] for inst in instruction_list]) + for ext in isa: + file.write(f"\t{ext} = '{ext if ext[0] != '_' else ext[1:]}'\n") + + file.write("\nclass ISA_Extension(enum.Enum):\n") + isa = set([inst["extension"] for inst in instruction_list]) + for ext in isa: + file.write(f"\t{ext} = '{ext}'\n") + file.write("\n\n") + file.write( + "def make_instruction_database(*, make_instruction, reg_op, " + "addr_op, imm_op, flag_op):\n" + ) + + +def generate_instruction_file() -> None: + if not Path(INFILE).is_file(): + print("Downloading binutil's description") + assert wget.download(URL) == INFILE + print("\nDone") + assert Path(INFILE).is_file + + with open(INFILE, "r") as f: + lines = f.readlines() + + # Removing includes and defines of "OPx" as it will output a preprocessor + with open(INFILE, "w") as f: + for l in lines: + if not l.startswith("#include") and not l.startswith("#error"): + f.write(l) + + # Running the preprocessor to handle #defines + infile_processed = INFILE + ".processed" + subprocess.run(["cpp", INFILE, "-o", infile_processed]) + instruction_list: List = [] + with open(infile_processed, "r") as f: + # 0 = no, 1 = first line ("{"), 2 = yes + inside_aarch64_decl = 0 + for line in f: + if inside_aarch64_decl == 2: + # Last line is `{0, 0, 0, 0, 0, 0, {}, {}, 0},\n` + if line.startswith(" {0, 0, 0"): + inside_aarch64_decl = 0 + continue + ret = parse_line(line) + if ret is not None: + _, instruction = ret + instruction_list.append(instruction) + + elif inside_aarch64_decl == 1 or line.startswith("struct aarch64_opcode"): + inside_aarch64_decl += 1 + + with open(OUTFILE, "w") as f: + generate_header(f, instruction_list) + + for arguments in instruction_list: + for ops in arguments["operands"]: + operands = "".join([f"\t\t\t{op},\n" for op in ops]) + f.write( + "\tmake_instruction(\n" + f'\t\tmnemonic = "{arguments["mnemonic"]}",\n' + f'\t\tisa_set = ISA.{arguments["insn_class"]},\n' + f'\t\tisa_extension = ISA_Extention.{arguments["extension"]},\n' + f"\t\toperands = [\n{operands}\t\t],\n\t)\n" + ) diff --git a/tools/extract-binutils-instruction-database/requirements.txt b/tools/extract-binutils-instruction-database/requirements.txt new file mode 100644 index 0000000..d654c71 --- /dev/null +++ b/tools/extract-binutils-instruction-database/requirements.txt @@ -0,0 +1,3 @@ +PyYAML==5.4.1 +wget==3.2 +mypy==0.812 diff --git a/tools/extract-binutils-instruction-database/setup.py b/tools/extract-binutils-instruction-database/setup.py new file mode 100755 index 0000000..4c54264 --- /dev/null +++ b/tools/extract-binutils-instruction-database/setup.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +from setuptools import setup, find_packages +import sys + + +def parse_requirements(): + reqs = [] + with open("requirements.txt", "r") as handle: + for line in handle: + if line.startswith("-"): + continue + reqs.append(line) + return reqs + + +setup( + name="extract-binutils-db", + version="0.0.1", + description="", + author="CORSE", + license="LICENSE", + url="https://gitlab.inria.fr/fgruber/pipedream", + packages=find_packages(), + include_package_data=True, + long_description=open("README.md").read(), + install_requires=parse_requirements(), + entry_points={ + "console_scripts": [ + ( + "extract-arm-instructions=extract_arm_db.extract_binutils:generate_instruction_file" + ), + ] + }, +) -- GitLab From 9c0b57a5509c6f65f0431be6c58b520e71a5636a Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Fri, 7 May 2021 18:47:50 +0200 Subject: [PATCH 07/12] armv8a: adapting code to extract-binutils syntax --- src/pipedream/asm/armv8a/__init__.py | 16 +- src/pipedream/asm/armv8a/instructions.py | 227 ++++- src/pipedream/asm/armv8a/operands.py | 955 ++++++++++++++++-- src/pipedream/asm/armv8a/registers.py | 26 +- src/pipedream/asm/ir.py | 8 +- src/pipedream/asm/x86/instructions.py | 2 +- src/pipedream/utils/statistics.py | 793 ++++++++------- src/pipedream/utils/yaml.py | 818 +++++++-------- .../extract_arm_db/extract_binutils.py | 106 +- 9 files changed, 1987 insertions(+), 964 deletions(-) diff --git a/src/pipedream/asm/armv8a/__init__.py b/src/pipedream/asm/armv8a/__init__.py index 3c1dc26..55d2f79 100644 --- a/src/pipedream/asm/armv8a/__init__.py +++ b/src/pipedream/asm/armv8a/__init__.py @@ -1,3 +1,9 @@ +from typing import List +import random + + +import pipedream.utils.abc as abc + from pipedream.utils import * from pipedream.asm.ir import * from pipedream.asm.allocator import * @@ -24,7 +30,15 @@ __all__ = [ # TODO class ARMv8_Architecture(Architecture): - pass + @property # type: ignore + @abc.override + def name(self) -> str: + return "ARMv8a" + + @property # type: ignore + @abc.override + def nb_vector_reg(self) -> int: + return len(V_8B) # type: ignore # TODO diff --git a/src/pipedream/asm/armv8a/instructions.py b/src/pipedream/asm/armv8a/instructions.py index 05955e6..9bfba30 100644 --- a/src/pipedream/asm/armv8a/instructions.py +++ b/src/pipedream/asm/armv8a/instructions.py @@ -1,3 +1,5 @@ +import collections +import types import pipedream.utils.abc as abc from pipedream.asm import ir @@ -5,10 +7,10 @@ from pipedream.asm import ir from pipedream.asm.armv8a.operands import * from pipedream.asm.armv8a.registers import * from pipedream.asm.armv8a.flags import * +from pipedream.asm.armv8a import instructions_binutils # type: ignore -from typing import Dict, Sequence, Union - -# from pipedream.asm.armv8a import instructions_binutils +from copy import copy +from typing import Dict, Sequence, Union, List, Tuple, FrozenSet, Optional, Any __all__ = [ @@ -22,116 +24,233 @@ __all__ = [ Instruction_Name = str Operand_Name = str -ALL_TAGS = frozenset() INSTRUCTIONS: Dict[Instruction_Name, ir.Machine_Instruction] = {} MNEMONICS: Dict[Instruction_Name, str] = {} class ARMv8_Instruction(ir.Machine_Instruction): - def __init__(self, name, mnemonic, isa_set, operands, tags, can_benchmark): + def __init__(self, name, mnemonic, isa_set, operands, can_benchmark): self._name = name self._mnemonic = mnemonic self._isa_set = isa_set self._operands = operands - self._tags = tags self._can_benchmark = can_benchmark assert len(operands) == len( set(o.name for o in operands) ), "duplicate operand name in " + str(self) - @property + @property # type: ignore @abc.override def name(self) -> str: return self._name - @property - def att_mnemonic(self) -> str: - return self._att_mnemonic + @property # type: ignore + @abc.override + def tags(self) -> Sequence[str]: + return [] - @property - def intel_mnemonic(self) -> str: - return self._intel_mnemonic + @property # type: ignore + def mnemonic(self) -> str: + return self._mnemonic - @property + @property # type: ignore @abc.override def isa_set(self) -> str: return self._isa_set - @property - @abc.override - def tags(self) -> Sequence[str]: - return self._tags - - @property + @property # type: ignore @abc.override - def operands(self) -> ["Operand"]: + def operands(self) -> List[ir.Operand]: return self._operands @abc.override def update_operand( self, idx_or_name: Union[int, str], fn ) -> ir.Machine_Instruction: - if type(idx_or_name) is str: + if isinstance(idx_or_name, str): idx = self.get_operand_idx(idx_or_name) else: idx = idx_or_name ops = list(self.operands) ops[idx] = fn(ops[idx]) - new = copy.copy(self) + new = copy(self) new._operands = ops return new - @abc.override - def encodings(self) -> ["Instruction_Encoding"]: - raise NotImplementedError() - - @property + @property # type: ignore @abc.override def can_benchmark(self) -> bool: return self._can_benchmark - -class ARMv8_Instruction_Set(ir.Instruction_Set): @abc.override - def instruction_groups(self) -> ["Instruction_Group"]: - return [] + def encodings(self) -> ["Instruction_Encoding"]: + raise NotImplementedError() + +class ARMv8_Instruction_Set(ir.Instruction_Set): @abc.override - def instructions(self) -> [ir.Machine_Instruction]: + def instructions(self) -> List[ir.Machine_Instruction]: return list(INSTRUCTIONS.values()) @abc.override - def benchmark_instructions(self) -> [ir.Machine_Instruction]: - return [I for I in INSTRUCTIONS.values() if I._can_benchmark] - - @property - @abc.override - def all_tags(self): - tags = collections.OrderedDict() - - for inst in INSTRUCTIONS.values(): - for tag in inst.tags: - tags[tag] = True - - yield from tags + def benchmark_instructions(self) -> List[ir.Machine_Instruction]: + return [I for I in INSTRUCTIONS.values() if I.can_benchmark] @abc.override def instruction_for_name(self, name: str) -> ir.Machine_Instruction: return INSTRUCTIONS[name] - @abc.override - def instructions_for_tags(self, *tags) -> Sequence[ir.Machine_Instruction]: - tags = set(tags) - - for inst in INSTRUCTIONS.values(): - if tags <= inst.tags: - yield inst - def __getitem__(self, name): return INSTRUCTIONS[name] @abc.override def __iter__(self): return iter(INSTRUCTIONS.values()) + + +MNEMONIC_CANNOT_BENCHMARK: FrozenSet[str] = frozenset([]) +INST_NAME_CANNOT_BENCHMARK: FrozenSet[str] = frozenset([]) + + +def pipedream_asm_backend_can_handle( + name: str, + isa_set, + mnemonic: str, + operands: List[ir.Operand], +): + if mnemonic in MNEMONIC_CANNOT_BENCHMARK: + return False + + if name in INST_NAME_CANNOT_BENCHMARK: + return False + + # TODO: check it! + return True + + +def mk_inst( + *, + mnemonic: str, + isa_set: str, + operands: Union[Tuple, List], + isa_extension: str = None, + can_benchmark=None, +): + global INSTRUCTIONS + + operands = tuple(mk_op() for mk_op in operands) + name = mnemonic.upper() + ( + "".join(map(lambda op: "_" + op.name + "_" + op.classname, operands)).upper() + ) + + ## filter out instructions we currently do not support + if can_benchmark is None: + can_benchmark = pipedream_asm_backend_can_handle( + name, isa_set, mnemonic, list(operands) + ) + + inst = ARMv8_Instruction(name, mnemonic, isa_set, operands, can_benchmark) + MNEMONICS[name] = mnemonic + + if name in INSTRUCTIONS: + raise ValueError( + "\n".join( + [ + f"Duplicate instruction:", + f" ({INSTRUCTIONS[name]}", + f" vs", + f" {inst})", + ] + ) + ) + + INSTRUCTIONS[name] = inst + + return inst + + +def make_reg_op( + *, + name: str, + reg_class: ir.Register_Class, + action: ir.Use_Def, + visibility: ir.Operand_Visibility, +): + return lambda: ARMv8_Register_Operand(name, visibility, action, reg_class, None) + + +def make_addr_op( + name: str, + visibility: ir.Operand_Visibility, + reg_class: Optional[Any] = None, +): + if reg_class is None: + # FIXME: add SP register! + return lambda: ARMv8_Base_Operand(name, base=X, visibility=visibility) + else: + # FIXME: check real IMM width + return lambda: ARMv8_Base_Offset_Operand( + name, base=reg_class, offset=IMM0, visibility=visibility + ) + + +def make_imm_op( + *, + name: str, + visibility: ir.Operand_Visibility, + imm_class: Optional[Any] = None, + reg_class: Optional[Any] = None, + min_val: Optional[int] = None, + max_val: Optional[int] = None, +): + if imm_class is None: + assert isinstance(min_val, int) and isinstance(max_val, int) + return lambda: ImmMinMax(name, visibility, min_val, max_val) + elif reg_class is not None: + return lambda: imm_class(name, visibility=visibility, reg_class=reg_class) + else: + return lambda: imm_class(name, visibility) + + +# FIXME: this is currently never used in `instruction-binutils` +def make_flags_op( + *, + name: str, + read: ARMv8_NZCV, + write: ARMv8_NZCV, + visibility: ir.Operand_Visibility, +): + if not read: + read = ARMv8_NZCV(0) + if not write: + write = ARMv8_NZCV(0) + + return lambda: ARMv8_Flags_Operand( + name, visibility, FLAGS, read, write # type:ignore + ) + + +instructions_binutils.make_instruction_database( + make_instruction=mk_inst, + reg_op=make_reg_op, + addr_op=make_addr_op, + imm_op=make_imm_op, + flags_op=make_flags_op, +) + + +class Harness: + """ + Well known instructions used in benchmark harness + """ + + # TODO at codegen + pass + + +tmp = collections.OrderedDict() +for inst in sorted(INSTRUCTIONS.values(), key=lambda inst: inst.name): # type: ignore + tmp[inst.name] = inst +INSTRUCTIONS = types.MappingProxyType(tmp) # type: ignore diff --git a/src/pipedream/asm/armv8a/operands.py b/src/pipedream/asm/armv8a/operands.py index 376405f..8f20f32 100644 --- a/src/pipedream/asm/armv8a/operands.py +++ b/src/pipedream/asm/armv8a/operands.py @@ -3,31 +3,26 @@ from pipedream.utils import abc from .flags import ARMv8_NZCV, ARMv8_CPSR from .registers import * -from typing import Optional, Union +from typing import Optional, Union, Any + +import math __all__ = [ "ARMv8_Operand", + "ARMv8_Base_Operand", + "ARMv8_Base_Offset_Operand", "ARMv8_Register_Operand", "ARMv8_Immediate_Operand", - "ImmLo", - "ImmHi", - "Imm12", - "UImm6", - "UImm4", - "UImm5", - "Imm16", - "Imm26", - "Imm14", - "Imm9", - "Imm7", - "Imm3", "ARMv8_Flags_Operand", + "ImmMinMax", + "IMM0", ] class ARMv8_Operand(ir.Operand): - def __init__(self, name: str, visibility: ir.Operand_Visibility): + def __init__(self, name: str, classname: str, visibility: ir.Operand_Visibility): self._name = name + self._classname = classname self._visibility = visibility @property # type: ignore @@ -40,6 +35,11 @@ class ARMv8_Operand(ir.Operand): def visibility(self) -> ir.Operand_Visibility: return self._visibility + @property # type: ignore + @abc.override + def classname(self) -> str: + return self._classname + class ARMv8_Register_Operand(ARMv8_Operand, ir.Register_Operand): def __init__( @@ -48,21 +48,14 @@ class ARMv8_Register_Operand(ARMv8_Operand, ir.Register_Operand): visibility: ir.Operand_Visibility, use_def: ir.Use_Def, reg_class: ir.Register_Class, - reg: ARMv8_Register, + reg: Optional[ARMv8_Register], ): - super().__init__(name, visibility) + classname = reg_class.name + super().__init__(name, classname, visibility) self._use_def = use_def self._reg_class = reg_class self._reg = reg - @property # type: ignore - @abc.override - def short_name(self) -> str: - if self._reg: - return self._reg.name - else: - return self._reg_class.name - @property # type: ignore @abc.override def register_class(self) -> ir.Register_Class: @@ -104,7 +97,7 @@ class ARMv8_Flags_Operand(ARMv8_Operand, ir.Flags_Operand): assert isinstance(flags_read, ARMv8_NZCV), flags_read assert isinstance(flags_written, ARMv8_NZCV), flags_written - super().__init__(name, visibility) + super().__init__(name, "NZCV", visibility) self._reg = reg self._flags_read = flags_read self._flags_written = flags_written @@ -168,14 +161,162 @@ class ARMv8_Flags_Operand(ARMv8_Operand, ir.Flags_Operand): return txt -class ARMv8_Immediate_Operand(ARMv8_Operand, ir.Immediate_Operand): - def __init__(self, name: str, visibility: ir.Operand_Visibility, value: int = None): - super().__init__(name, visibility) +class ARMv8_Base_Operand(ARMv8_Operand): + """ + base addressing mode + """ + + def __init__( + self, + name: str, + base: ir.Register_Class, + visibility: ir.Operand_Visibility, + ): + assert isinstance(base, ir.Register_Class), base + super().__init__(name, name, visibility) + self._base = base + + @property + def name(self) -> str: + return self._name + + @property + def base(self) -> ir.Register_Operand: + return self._base + + def with_base(self, base_reg: ir.Register) -> "ARMv8_Base_Operand": + new = copy.copy(self) + new._base = new._base.with_register(base_reg) + return new + + @property # type: ignore + @abc.override + def is_virtual(self): + return False + + @property # type: ignore + @abc.override + def use_def(self) -> ir.Use_Def: + return ir.Use_Def.USE + + @property + def _short_short_name(self) -> str: + return "mem" + + @property + @abc.override + def short_name(self) -> str: + return f"{self._name}" + + def __repr__(self): + tmp = [ + self.name, + ":", + self._short_short_name, + "/", + "(", + repr(self.base), + ")", + ] + + return "".join(tmp) + + +# TODO: add other addressing mode (pre-indexed / post-indexed) +class ARMv8_Base_Offset_Operand(ir.Base_Displacement_Operand): + """ + Mixing for defining base/offset operands + """ + + def __init__( + self, + name, + base: Optional[Any], + offset: Optional[Any], + visibility: ir.Operand_Visibility, + ): + super().__init__() + + self._name = name + self._offset = offset + self._base_class = base + + @property + def name(self): + return self._name + + @property + def classname(self): + return self._name + + # FIXME: Broken: mismatch between addr type and instances + @property + def base(self) -> ir.Register_Operand: + return self._base + + @property + def offset(self) -> ir.Immediate_Operand: + return self._offset + + def with_base(self, base_reg: ir.Register) -> "ARMv8_Base_Offset_Operand": + new = copy.copy(self) + new._base = new._base.with_register(base_reg) + return new + + def with_offset(self, disp: int) -> "ARMv8_Base_Offset_Operand": + new = copy.copy(self) + new._offset = new._offset.with_value(disp) + return new + + @property # type: ignore + @abc.override + def is_virtual(self): + return False + + @property + def _short_short_name(self) -> str: + return "mem" + + @property + @abc.override + def short_name(self) -> str: + return f"{self._short_short_name}BO{self.address_width}_{self.memory_width}" - assert value is None or isinstance( - value, (int, ir.Label) - ), f"want int or Label, have {value!r}" + @property + @abc.override + def sub_operands(self): + yield self.base + yield self.offset + + def get_operand_name(self, idx: int) -> str: + return { + 0: "base", + 1: "offset", + }[idx] + + def __repr__(self): + tmp = [ + self.name, + ":", + self._short_short_name, + "BO", + str(self.address_width), + "/", + ] + if hasattr(self, "memory_width"): + tmp += [ + str(self.memory_width), + "/", + ] + tmp += ["(", repr(self.base), ", ", repr(self.offset), ")"] + + return "".join(tmp) + +class ARMv8_Immediate_Operand(ARMv8_Operand, ir.Immediate_Operand): + def __init__(self, name: str, visibility: ir.Operand_Visibility, value: int = None): + super().__init__(name, str(type(self)), visibility) + assert value is None or isinstance(value, int) self._value = value @property # type: ignore @@ -226,19 +367,44 @@ def _unsigned_min_max(num_bits: int): return min, max -# Immediate operands of the A64 Instruction Set, in the order of the "Architecture -# Reference Manual" +class ImmMinMax(ARMv8_Immediate_Operand): + """ + Every operands with specific min/max values + """ + + def __init__( + self, name: str, visibility: ir.Operand_Visibility, min_val: int, max_val: int + ): + self._min = min_val + self._max = max_val + super().__init__(name, visibility) + + @property # type: ignore + @abc.override + def num_bits(self) -> int: + return int(math.log(self._max - self._min + 1, 2)) + + @property # type: ignore + @abc.override + def _is_valid_value(self, value): # type: ignore + return type(value) is int and self._min <= value <= self._max + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val -# Jump operands are probably useless, as Pipedream's codegen rely on GNU assembly -# syntax (with labels instead of immediate offsets) +### Operands as described by binutils ### -class ImmLo(ARMv8_Immediate_Operand): +# TODO: +# Add support for shifts +class AIMM(ARMv8_Immediate_Operand): """ - Class of PC-relative operands: immlo + 12-bit unsigned immediate with optional left shift of 12 bits """ - _NUM_BITS = 2 + _NUM_BITS = 12 MIN, MAX = _unsigned_min_max(_NUM_BITS) @property # type: ignore @@ -251,14 +417,19 @@ class ImmLo(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val -class ImmHi(ARMv8_Immediate_Operand): + +class ADDR_SIMM9(ARMv8_Immediate_Operand): """ - Class of PC-relative operands: immhi + 9-bit signed immediate """ - _NUM_BITS = 19 - MIN, MAX = _unsigned_min_max(_NUM_BITS) + _NUM_BITS = 9 + MIN, MAX = _signed_min_max(_NUM_BITS) @property # type: ignore @abc.override @@ -270,13 +441,18 @@ class ImmHi(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val + -class Imm12(ARMv8_Immediate_Operand): +class ADDR_SIMM13(ARMv8_Immediate_Operand): """ - Class of 12-bit signed immediates + address with 13-bit signed immediate (multiple of 16) offset """ - _NUM_BITS = 12 + _NUM_BITS = 13 MIN, MAX = _signed_min_max(_NUM_BITS) @property # type: ignore @@ -289,18 +465,18 @@ class Imm12(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX - @classmethod + @property # type: ignore @abc.override - def _arbitrary(clss, random): - return clss.MAX + def _arbitrary(self, random): # type: ignore + return self.max_val -class UImm6(ARMv8_Immediate_Operand): +class ADDR_UIMM12(ARMv8_Immediate_Operand): """ - Class of 6-bit unsigned immediates + address with 12-bit unsigned immediate """ - _NUM_BITS = 6 + _NUM_BITS = 12 MIN, MAX = _unsigned_min_max(_NUM_BITS) @property # type: ignore @@ -313,18 +489,18 @@ class UImm6(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX - @classmethod + @property # type: ignore @abc.override - def _arbitrary(clss, random): - return clss.MAX + def _arbitrary(self, random): # type: ignore + return self.max_val -class UImm4(ARMv8_Immediate_Operand): +class UIMM3_OP1(ARMv8_Immediate_Operand): """ - Class of 4-bit unsigned immediates + a 3-bit unsigned immediate """ - _NUM_BITS = 4 + _NUM_BITS = 3 MIN, MAX = _unsigned_min_max(_NUM_BITS) @property # type: ignore @@ -337,18 +513,21 @@ class UImm4(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX - @classmethod + @property # type: ignore @abc.override - def _arbitrary(clss, random): - return clss.MAX + def _arbitrary(self, random): # type: ignore + return self.max_val + + +UIMM3_OP2 = UIMM3_OP1 -class UImm5(ARMv8_Immediate_Operand): +class SVE_FPIMM8(ARMv8_Immediate_Operand): """ - Class of 5-bit unsigned immediates (bitfields) + 8-bit floating-point immediate """ - _NUM_BITS = 5 + _NUM_BITS = 8 MIN, MAX = _unsigned_min_max(_NUM_BITS) @property # type: ignore @@ -361,19 +540,19 @@ class UImm5(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX - @classmethod + @property # type: ignore @abc.override - def _arbitrary(clss, random): - return clss.MAX + def _arbitrary(self, random): # type: ignore + return self.max_val -class Imm16(ARMv8_Immediate_Operand): +class TME_UIMM16(ARMv8_Immediate_Operand): """ - Class of 16-bit signed immediates + a 16-bit unsigned immediate for TME tcancel """ _NUM_BITS = 16 - MIN, MAX = _signed_min_max(_NUM_BITS) + MIN, MAX = _unsigned_min_max(_NUM_BITS) @property # type: ignore @abc.override @@ -385,18 +564,42 @@ class Imm16(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val + + +class ADDR_SIMM10(ARMv8_Immediate_Operand): + """ + address with 10-bit signed immediate + """ + + _NUM_BITS = 10 + MIN, MAX = _signed_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + @classmethod @abc.override - def _arbitrary(clss, random): - return clss.MAX + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val -class Imm26(ARMv8_Immediate_Operand): +class ADDR_SIMM11(ARMv8_Immediate_Operand): """ - Class of 26-bit signed immediates (used for unconditional branches) + address with 10-bit signed immediate """ - _NUM_BITS = 26 + _NUM_BITS = 11 MIN, MAX = _signed_min_max(_NUM_BITS) @property # type: ignore @@ -409,13 +612,18 @@ class Imm26(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val + -class Imm14(ARMv8_Immediate_Operand): +class ADDR_SIMM7(ARMv8_Immediate_Operand): """ - Class of 14-bit signed immediates (used for test and branch) + address with 7-bit signed immediate """ - _NUM_BITS = 14 + _NUM_BITS = 7 MIN, MAX = _signed_min_max(_NUM_BITS) @property # type: ignore @@ -428,13 +636,18 @@ class Imm14(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val + -class Imm9(ARMv8_Immediate_Operand): +class LIMM(ARMv8_Immediate_Operand): """ - Class of 9-bit signed immediates (used for offsets on unscaled immediate) + Logical Immediate, 10-bit """ - _NUM_BITS = 9 + _NUM_BITS = 10 MIN, MAX = _signed_min_max(_NUM_BITS) @property # type: ignore @@ -447,13 +660,18 @@ class Imm9(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val + -class Imm7(ARMv8_Immediate_Operand): +class IMM_MOV(ARMv8_Immediate_Operand): """ - Class of 7-bit signed immediates (used for offsets on load/store register pair) + Immediate (?) """ - _NUM_BITS = 7 + _NUM_BITS = 16 MIN, MAX = _signed_min_max(_NUM_BITS) @property # type: ignore @@ -466,13 +684,19 @@ class Imm7(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val + -class Imm3(ARMv8_Immediate_Operand): +# TODO left shift +class HALF(ARMv8_Immediate_Operand): """ - Class of 3-bit signed immediates + 16-bit immediate with optional left shift """ - _NUM_BITS = 3 + _NUM_BITS = 16 MIN, MAX = _signed_min_max(_NUM_BITS) @property # type: ignore @@ -485,7 +709,564 @@ class Imm3(ARMv8_Immediate_Operand): def _is_valid_value(clss, value): return type(value) is int and clss.MIN <= value <= clss.MAX + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val + + +class UIMM10(ARMv8_Immediate_Operand): + """ + 10-bit unsigned multiple of 16 + """ + + _NUM_BITS = 10 + MIN = 0 + MAX = 16 * _unsigned_min_max(_NUM_BITS)[1] + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val + + +# TODO: this needs further support: +# - link values to other operand width +# - other arbitrary values (duplicate instructions) +class IMM_ROT1(ARMv8_Immediate_Operand): + """ + 2-bit rotation specifier for complex arithmetic operations + """ + + _NUM_BITS = 2 + MIN = 0 + MAX = 270 + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return value in [0, 90, 180, 270] + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return 270 + + +IMM_ROT2 = IMM_ROT1 + + +# TODO: this needs further support: +# - other arbitrary values (duplicate instructions) +class IMM_ROT3(ARMv8_Immediate_Operand): + """ + 1-bit rotation specifier for complex arithmetic operations + """ + + _NUM_BITS = 1 + MIN = 90 + MAX = 270 + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return value in [90, 270] + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return 270 + + +# TODO: this needs further support: +# - link values to other operand width +# - other arbitrary values (duplicate instructions) +class IMM_LSL(ARMv8_Immediate_Operand): + """ + left shift amount for an AdvSIMD register + """ + + _NUM_BITS = 2 + MIN = 0 + MAX = 24 + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return value in {0, 8, 16, 24} + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return 0 + + +# TODO: this needs further support: +# - look at what this is for? +class IMM_MSL(ARMv8_Immediate_Operand): + """ + left shift ones for an AdvSIMD register + """ + + # ? + _NUM_BITS = 1 + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + # ??? + return value in {0} + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return 0 + + +# TODO: further support of all possible values +class SIMD_FPIMM(ARMv8_Immediate_Operand): + """ + an 8-bit floating-point constant + """ + + _NUM_BITS = 8 + MIN, MAX = 0.125, 31.0 + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) in {0.125, 10, 31} + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return 10.0 + + +FPIMM = SIMD_FPIMM + +# TODO: further support of all possible values +class SIMD_IMM(ARMv8_Immediate_Operand): + """ + An immediate + """ + + _NUM_BITS = 8 + MIN, MAX = _unsigned_min_max(64) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + # TODO: change this to be more accurate @classmethod @abc.override - def _arbitrary(clss, random): + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore return clss.MAX + + +class IMM0(ARMv8_Immediate_Operand): + """ + 0 + """ + + _NUM_BITS = 1 + MIN, MAX = 0, 0 + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + # TODO: change this to be more accurate + @classmethod + @abc.override + def _is_valid_value(clss, value): + return value == 0 + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return 0 + + +class FPIMM0(ARMv8_Immediate_Operand): + """ + 0.0 + """ + + _NUM_BITS = 1 + MIN, MAX = 0, 0 + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return value == 0 + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return 0.0 + + +class SHLL_IMM(ARMv8_Immediate_Operand): + """ + An immediate shift amount of 8, 16 or 32 + """ + + _NUM_BITS = 2 + MIN, MAX = 0, 32 + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) in {8, 16, 32} + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return 8 + + +class IMM_VLSR(ARMv8_Immediate_Operand): + """ + a right shift amount for an AdvSIMD register + """ + + _NUM_BITS = 7 + MIN = 1 + MAX = 64 + + def __init__( + self, + name: str, + visibility: ir.Operand_Visibility, + value: int = None, + reg_class: Optional[Any] = None, + ): + super().__init__(name, visibility, value) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + # Why not? + return 5 + + +# TODO: change operand width to match other operands? +class IMM_VLSL(ARMv8_Immediate_Operand): + """ + a left shift amount for an AdvSIMD register + """ + + _NUM_BITS = 7 + MIN = 1 + MAX = 64 + + def __init__( + self, + name: str, + visibility: ir.Operand_Visibility, + value: int = None, + reg_class: Optional[Any] = None, + ): + super().__init__(name, visibility, value) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + # Why not? + return 5 + + +class CCMP_IMM(ARMv8_Immediate_Operand): + """ + 5-bit unsigned immediate + """ + + _NUM_BITS = 5 + MIN, MAX = _unsigned_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val + + +# TODO: might be buggy +class COND(ARMv8_Operand): + """ + One of the following condition flag: + EQ Equal + NE Not equal + CS Carry set (identical to HS) + HS Unsigned higher or same (identical to CS) + CC Carry clear (identical to LO) + LO Unsigned lower (identical to CC) + MI Minus or negative result + PL Positive or zero result + VS Overflow + VC No overflow + HI Unsigned higher + LS Unsigned lower or same + GE Signed greater than or equal + LT Signed less than + GT Signed greater than + LE Signed less than or equal + AL Always (this is the default) + """ + + def __init__(self, name: str, visibility: ir.Operand_Visibility, value: int = None): + super().__init__(name, str(type(self)), visibility) + assert value is None or isinstance(value, str) + self._value = value + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return value in { + "EQ", + "NE", + "CS", + "HS", + "CC", + "LO", + "MI", + "PL", + "VS", + "VC", + "HI", + "LS", + "GE", + "LT", + "GT", + "LE", + "AL", + } + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return "EQ" + + @property # type: ignore + @abc.override + def short_name(self) -> str: + return type(self).__name__.upper() + + @property # type: ignore + @abc.override + def value(self) -> Union[str, None]: + return self._value + + @abc.override + def with_value(self, value): + clss = type(self) + + if not self._is_valid_value(value): + raise TypeError( + f"{value} is not a valid value for immediates of type {clss.__name__}" + ) + + return clss(self.name, self.visibility, value) + + @property # type: ignore + @abc.override + def is_virtual(self, random): # type: ignore + return value is None + + @property # type: ignore + @abc.override + def use_def(self, random): # type: ignore + return ir.Use_Def.USE + + +# TODO: might be buggy +class COND1(ARMv8_Operand): + """ + One of the following condition flag: + EQ Equal + NE Not equal + CS Carry set (identical to HS) + HS Unsigned higher or same (identical to CS) + CC Carry clear (identical to LO) + LO Unsigned lower (identical to CC) + MI Minus or negative result + PL Positive or zero result + VS Overflow + VC No overflow + HI Unsigned higher + LS Unsigned lower or same + GE Signed greater than or equal + LT Signed less than + GT Signed greater than + LE Signed less than or equal + """ + + def __init__(self, name: str, visibility: ir.Operand_Visibility, value: int = None): + super().__init__(name, str(type(self)), visibility) + assert value is None or isinstance(value, str) + self._value = value + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return value in { + "EQ", + "NE", + "CS", + "HS", + "CC", + "LO", + "MI", + "PL", + "VS", + "VC", + "HI", + "LS", + "GE", + "LT", + "GT", + "LE", + } + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return "EQ" + + @property # type: ignore + @abc.override + def short_name(self) -> str: + return type(self).__name__.upper() + + @property # type: ignore + @abc.override + def value(self) -> Union[str, None]: + return self._value + + @abc.override + def with_value(self, value): + clss = type(self) + + if not self._is_valid_value(value): + raise TypeError( + f"{value} is not a valid value for immediates of type {clss.__name__}" + ) + + return clss(self.name, self.visibility, value) + + @property # type: ignore + @abc.override + def is_virtual(self, random): # type: ignore + return value is None + + @property # type: ignore + @abc.override + def use_def(self, random): # type: ignore + return ir.Use_Def.USE + + +class NZCV(ARMv8_Immediate_Operand): + """ + flag bit specifier giving an alternative value for each flag (0 to 15) + """ + + _NUM_BITS = 4 + MIN, MAX = _unsigned_min_max(_NUM_BITS) + + @property # type: ignore + @abc.override + def num_bits(self): + return self._NUM_BITS + + @classmethod + @abc.override + def _is_valid_value(clss, value): + return type(value) is int and clss.MIN <= value <= clss.MAX + + @property # type: ignore + @abc.override + def _arbitrary(self, random): # type: ignore + return self.max_val diff --git a/src/pipedream/asm/armv8a/registers.py b/src/pipedream/asm/armv8a/registers.py index ff9f2b9..9af9033 100644 --- a/src/pipedream/asm/armv8a/registers.py +++ b/src/pipedream/asm/armv8a/registers.py @@ -35,22 +35,30 @@ class _ARMv8_REGISTER_TYPE(Enum): VH = Unique(16) VS = Unique(32) VD = Unique(64) - # ARMv8.2 only (dot product) + VQ = Unique(128) + # ARMv8.2 only V4B = Unique(32) + V2H = Unique(32) + V_2H = Unique(32) # Vects V_8B = Unique(64) V_4H = Unique(64) V_2S = Unique(64) + V_1D = Unique(64) V_16B = Unique(128) V_8H = Unique(128) V_4S = Unique(128) V_2D = Unique(128) + V_1Q = Unique(128) # Other WSP = Unique(32) WZR = Unique(32) + FLAGS = Unique(32) SP = Unique(64) ZXR = Unique(64) LR = Unique(64) + CR = Unique(64) + SYSREG = Unique(64) @property def value(self): @@ -80,27 +88,34 @@ _ARMv8_V_ELMT_TYPE_TO_WIDTH = { _ARMv8_REGISTER_TYPE.VH: _ARMv8_REGISTER_TYPE.H, _ARMv8_REGISTER_TYPE.VS: _ARMv8_REGISTER_TYPE.S, _ARMv8_REGISTER_TYPE.VD: _ARMv8_REGISTER_TYPE.D, + _ARMv8_REGISTER_TYPE.VQ: _ARMv8_REGISTER_TYPE.Q, } _ARMv8_WIDTH_TO_VECT = { (_ARMv8_REGISTER_TYPE.D, _ARMv8_REGISTER_TYPE.B): _ARMv8_REGISTER_TYPE.V_8B, (_ARMv8_REGISTER_TYPE.D, _ARMv8_REGISTER_TYPE.H): _ARMv8_REGISTER_TYPE.V_4H, (_ARMv8_REGISTER_TYPE.D, _ARMv8_REGISTER_TYPE.S): _ARMv8_REGISTER_TYPE.V_2S, + (_ARMv8_REGISTER_TYPE.D, _ARMv8_REGISTER_TYPE.D): _ARMv8_REGISTER_TYPE.V_1D, (_ARMv8_REGISTER_TYPE.Q, _ARMv8_REGISTER_TYPE.B): _ARMv8_REGISTER_TYPE.V_16B, (_ARMv8_REGISTER_TYPE.Q, _ARMv8_REGISTER_TYPE.H): _ARMv8_REGISTER_TYPE.V_8H, (_ARMv8_REGISTER_TYPE.Q, _ARMv8_REGISTER_TYPE.S): _ARMv8_REGISTER_TYPE.V_4S, (_ARMv8_REGISTER_TYPE.Q, _ARMv8_REGISTER_TYPE.D): _ARMv8_REGISTER_TYPE.V_2D, + (_ARMv8_REGISTER_TYPE.Q, _ARMv8_REGISTER_TYPE.Q): _ARMv8_REGISTER_TYPE.V_1Q, } _ARMv8_VECT_TO_SUBTYPE = { _ARMv8_REGISTER_TYPE.V4B: _ARMv8_REGISTER_TYPE.B, + _ARMv8_REGISTER_TYPE.V2H: _ARMv8_REGISTER_TYPE.H, + _ARMv8_REGISTER_TYPE.V_2H: _ARMv8_REGISTER_TYPE.H, _ARMv8_REGISTER_TYPE.V_8B: _ARMv8_REGISTER_TYPE.B, _ARMv8_REGISTER_TYPE.V_4H: _ARMv8_REGISTER_TYPE.H, _ARMv8_REGISTER_TYPE.V_2S: _ARMv8_REGISTER_TYPE.S, _ARMv8_REGISTER_TYPE.V_16B: _ARMv8_REGISTER_TYPE.B, _ARMv8_REGISTER_TYPE.V_8H: _ARMv8_REGISTER_TYPE.H, _ARMv8_REGISTER_TYPE.V_4S: _ARMv8_REGISTER_TYPE.S, + _ARMv8_REGISTER_TYPE.V_1D: _ARMv8_REGISTER_TYPE.D, _ARMv8_REGISTER_TYPE.V_2D: _ARMv8_REGISTER_TYPE.D, + _ARMv8_REGISTER_TYPE.V_1Q: _ARMv8_REGISTER_TYPE.Q, } @@ -314,7 +329,7 @@ def recurse_vect_fp_reg_add( # integer automatically computed) if vect_idx == 0: if size_idx in [3, 4]: - for idx, subvect_type_idx in enumerate(range(size_idx)): + for idx, subvect_type_idx in enumerate(range(size_idx + 1)): vect_type = _ARMv8_WIDTH_TO_VECT[ fp_type, _ARMv8_BASE_SIZES[subvect_type_idx] ] @@ -361,6 +376,13 @@ def recurse_vect_fp_reg_add( ) +# CR, unused in generation +_def_ARMv8_reg("CR", _ARMv8_REGISTER_TYPE.CR) +_def_ARMv8_reg("SYSREG", _ARMv8_REGISTER_TYPE.SYSREG) + +# flags handeling +_def_ARMv8_reg("_NZCV", _ARMv8_REGISTER_TYPE.FLAGS) + # stack pointer WSP = _def_ARMv8_reg("WSP", _ARMv8_REGISTER_TYPE.WSP) SP = _def_ARMv8_reg("SP", _ARMv8_REGISTER_TYPE.SP, [WSP], widest=True) diff --git a/src/pipedream/asm/ir.py b/src/pipedream/asm/ir.py index 9d2f76f..82016e5 100644 --- a/src/pipedream/asm/ir.py +++ b/src/pipedream/asm/ir.py @@ -883,14 +883,14 @@ class Instruction(abc.ABC): ... @property - def defs(self) -> ["Operand"]: + def defs(self) -> ty.List["Operand"]: """ Operands written to """ for op in self.operands: if op.is_def: yield op @property - def uses(self) -> ["Operand"]: + def uses(self) -> ty.List["Operand"]: """ Operands read from """ for op in self.operands: if op.is_use: @@ -903,7 +903,7 @@ class Instruction(abc.ABC): raise KeyError(name) - def get_operand_idx(self, name: str) -> "Operand": + def get_operand_idx(self, name: str) -> int: for idx, op in enumerate(self.operands): if op.name == name: return idx @@ -926,7 +926,7 @@ class Instruction(abc.ABC): return False @abc.abstractmethod - def encodings(self) -> ["Instruction_Encoding"]: + def encodings(self) -> ty.List["Instruction_Encoding"]: """ List of ways that an instruction can be encoded. """ diff --git a/src/pipedream/asm/x86/instructions.py b/src/pipedream/asm/x86/instructions.py index fca9440..095e3a0 100644 --- a/src/pipedream/asm/x86/instructions.py +++ b/src/pipedream/asm/x86/instructions.py @@ -10,7 +10,7 @@ from pipedream.asm import ir from pipedream.asm.x86.operands import * from pipedream.asm.x86.registers import * from pipedream.asm.x86.flags import * -from pipedream.asm.x86 import instructions_xed +from pipedream.asm.x86 import instructions_xed # type: ignore __all__ = [ "X86_Instruction", diff --git a/src/pipedream/utils/statistics.py b/src/pipedream/utils/statistics.py index b68dd60..6fb6e7c 100644 --- a/src/pipedream/utils/statistics.py +++ b/src/pipedream/utils/statistics.py @@ -1,4 +1,3 @@ - """ Helpers for dealing with statistics. """ @@ -12,394 +11,426 @@ import math import numpy import typing as ty -import scipy.stats +import scipy.stats # type: ignore __all__ = [ - 'Statistics', + "Statistics", ] class Statistics(yaml.YAML_Struct): - """ + """ Statistics for a set of measurements. Mean, standard deviation, some percentiles, ... - """ - raw_data = yaml.Slot(ty.List[float], list) - num_samples = yaml.Slot(float, math.nan) - ## arithmetic mean - mean = yaml.Slot(float, math.nan) - ## sample standard deviation - stddev = yaml.Slot(float, math.nan) - ## sample variance - variance = yaml.Slot(float, math.nan) - ## median absolute deviation - MAD = yaml.Slot(float, math.nan) - min = yaml.Slot(float, math.nan) - max = yaml.Slot(float, math.nan) - percentiles = yaml.Slot(ty.Dict[int, float], dict) - histogram = yaml.Slot(ty.List[int], list) - - def __init__(self, *, - mean: float, stddev: float, variance: float = math.nan, - num_samples: float = math.nan, - MAD: float = math.nan, - min: float = math.nan, max: float = math.nan, - percentiles: ty.Dict[int, float] = None, - histogram: ty.Sequence[float] = None, - raw_data: ty.List[float] = None, - ): - self.raw_data = raw_data if raw_data is not None else [] - self.num_samples = float(num_samples) - self.mean = float(mean) - self.stddev = float(stddev) - self.variance = self.stddev ** 2 if math.isnan(float(variance)) else float(variance) - self.MAD = float(MAD) - self.min = float(min) - self.max = float(max) - self.percentiles = percentiles or {} - self.histogram = histogram or [] - - assert not percentiles or 0 not in percentiles or percentiles[0] == min - assert not percentiles or 100 not in percentiles or percentiles[100] == max - - assert not percentiles or 10 in percentiles - assert not percentiles or 25 in percentiles - assert not percentiles or 50 in percentiles - assert not percentiles or 75 in percentiles - assert not percentiles or 90 in percentiles - - def drop_details(self): - """ - Throw away the histogram and all percentiles except 10,25,50,75,90. - This is intended to reduce memory usage. - """ - self.raw_data = [] - self.histogram = [] - self.percentiles = { - 10: self.percentiles[10], - 25: self.percentiles[25], - 50: self.percentiles[50], - 75: self.percentiles[75], - 90: self.percentiles[90], - } - - def prediction_interval(self, probability: float): - assert 0 <= probability <= 1 - - return scipy.stats.norm.interval(probability, loc=self.mean, scale=self.stddev) - - @staticmethod - def from_array(array: numpy.ndarray, keep_raw_data: bool = False): - """ - compute values from a numpy array. - """ - - if len(array) == 0: - raise ValueError('Cannot create Statistics from empty array') - - # https://en.wikipedia.org/wiki/Median_absolute_deviation - median = numpy.median(array) - MAD = numpy.median([abs(x - median) for x in array]) - - percentiles = numpy.percentile(array, range(1, 100)) - percentiles = {n + 1: percentiles[n] for n in range(len(percentiles))} - - if not math.isinf(median) and not math.isnan(median): - hist, hist_bins = numpy.histogram(array, 100) - - assert sum(hist) == len(array), f'{sum(hist)} != {len(array)}' - - hist = [int(n) for n in hist] - else: - hist = None - - if keep_raw_data: - return Statistics( - raw_data = array, - num_samples = len(array), - mean = array.mean(), - stddev = array.std(ddof=1), - variance = array.var(ddof=1), - MAD = MAD, - min = array.min(), - max = array.max(), - percentiles = percentiles, - histogram = hist, - ) - - else: - return Statistics( - num_samples = len(array), - mean = array.mean(), - stddev = array.std(ddof=1), - variance = array.var(ddof=1), - MAD = MAD, - min = array.min(), - max = array.max(), - percentiles = percentiles, - histogram = hist, - ) - - @property - def IQR(self): - """ - Interquartile range (https://en.wikipedia.org/wiki/Interquartile_range) - """ - return self.p75 - self.p25 - - @property - def p0(self): - """ - 0th percentile (minimum value) - """ - return self.min - - @property - def p100(self): - """ - 100th percentile (maximum value) - """ - return self.max - - @property - def p10(self): - return self.percentile(10) - - @property - def p25(self): - return self.percentile(25) - - @property - def p50(self): - return self.percentile(50) - - @property - def p75(self): - return self.percentile(75) - - @property - def p90(self): - return self.percentile(90) - - def percentile(self, q: int) -> float: - """ - Retrieve q-th percentile value. - Returns NaN if that percentile is not available - """ - - return self.percentiles.get(q, math.nan) - - def scale(self, scalar: float) -> 'Statistics': - """ - return new Statistics where mean and stddev are multiplied with scalar. - All other values are dropped. - - I guess this is basically only safe for results of __truediv__ - """ - - return Statistics(mean = self.mean * scalar, stddev = self.stddev * scalar) - - def __truediv__(self, that): - if type(that) is not Statistics: - return NotImplemented - - ## https://stats.stackexchange.com/questions/49399/standard-deviation-of-a-ratio-percentage-change - ## - ## comment by Eekhorn: - ## ... if you want to normalize your data y ± Δy to z ± Δz - ## x = y / z - ## you have to calculate the standard deviation Δx as follows: - ## Δx = x * sqrt( (Δy / y)^2 + (Δz / z)^2 ) - ## NOTE: I removed the scaling by 100 to get percentages - ## it appears in a linear factor in both equations so you can scale later - ## just by multiplying mean & stddev if desired. - - y_mean = self.mean - y_dev = self.stddev - - z_mean = that.mean - z_dev = that.stddev - - x_mean = y_mean / z_mean - x_dev = x_mean * math.sqrt( - (y_dev / y_mean) ** 2 + - (z_dev / z_mean) ** 2 - ) - - return Statistics(mean=x_mean, stddev=x_dev) - - def __sub__(self, that): - assert type(that) is Statistics - - assert False, "TODO: need covariance, and for that we need to store all values :(" - - @classmethod - def stddev_of_difference(clss, data1: numpy.ndarray, data2: numpy.ndarray): - return math.sqrt(clss.variance_of_difference(data1, data2)) - - @classmethod - def variance_of_difference(clss, data1: numpy.ndarray, data2: numpy.ndarray): - # https://en.wikipedia.org/wiki/Variance#Sum_of_correlated_variables - # https://stats.stackexchange.com/questions/142745/what-is-the-demonstration-of-the-variance-of-the-difference-of-two-dependent-var - # Var[X - Y] = Var[X] + Var[Y] - 2 * Cov[X, Y] - - cov = numpy.cov(numpy.vstack(data1, data2)) - - return numpy.var(data1) + numpy.var(data2) - 2 * cov[0, 1] - - def __str__(self): - out = io.StringIO() - yaml.dump(self, out) - return out.getvalue() - - def whisker_plot(self, width: int) -> str: - assert width > 2 - width -= 2 - - def frac(n, default = 0): - if math.isnan(n): - n = default - return fractions.Fraction(n) - - min = fractions.Fraction(self.min) - p10 = frac(self.p10, min) - p25 = frac(self.p25, p10) - p50 = fractions.Fraction(self.p50) - p75 = frac(self.p75, p50) - p90 = frac(self.p90, p75) - max = fractions.Fraction(self.max) - - span = (max - min) or 1 - - def w(hi, lo): - # assert lo <= hi - return round(((hi - lo) * width) / span) - - w10 = w(p10, min) - w25 = w(p25, p10) - w50 = w(p50, p25) - w75 = w(p75, p50) - w90 = w(p90, p75) - w100 = w(max, p90) - - # print('%10s %10s' % (p10, w10)) - # print('%10s %10s' % (p25, w25)) - # print('%10s %10s' % (p50, w50)) - # print('%10s %10s' % (p75, w75)) - # print('%10s %10s' % (p90, w90)) - # print('%10s %10s' % (max, w100)) - - txt = '' - txt += '[' - txt += ' ' * round(w10) - txt += ('|' + '-' * width)[:round(w25)] - txt += '=' * round(w50) - txt += '0' - txt += '=' * round(w75) - txt += ('|' + '-' * width)[:round(w90)][::-1] - txt += ' ' * round(w100) - txt += ']' - - # draw outliers (lower 10 and higher 10 percentlies) - if self.percentiles: - tmp = list(txt) - - def draw_outlier(pos: int): - if pos >= len(tmp): - return - if tmp[pos] != ' ': - return - - tmp[pos] = '.' - - for p in range(1, 10): - if p in self.percentiles: - draw_outlier(w(self.percentiles[p], self.min)) - - for p in range(91, 99): - if p in self.percentiles: - draw_outlier(w(self.percentiles[p], self.min)) - - txt = ''.join(tmp) - - return txt - - def histogram_plot(self, width: int) -> str: - """ - Visualizes distribution of values using terminal colors. - - Produces a 'histogram' of width :width: - blue -> very many values - red -> many values - yellow -> some values - green -> few values """ - assert width > 2 - width -= 2 - - assert self.histogram - - BLACK = (0, 0, 0) - GRAY = (55, 55, 55) - WHITE = (255, 255, 255) - SOFT_GREEN = (171, 252, 133) - GREEN = (122, 244, 66) - YELLOW = (247, 243, 32) - RED = (244, 19, 19) - CYAN = (133, 246, 252) - BLUE = (11, 19, 239) - PURPLE = (150, 0, 170) - - N = max(self.histogram) - - COLORS = (BLACK, GRAY, SOFT_GREEN, GREEN, YELLOW, RED, CYAN, BLUE, PURPLE, WHITE) - STEPS = (0, N * 0.025, N * 0.05, N * 0.10, N * 0.25, N * 0.50, N * 0.75, N * 0.9, N * 0.95, N) - GRADIENT = tuple(tuple(C[i] for C in COLORS) for i in range(3)) - - def color_gradient(n: float): - return tuple(int(round(numpy.interp(n, STEPS, GRADIENT[i]))) for i in range(3)) - - hist = '' - hist += '[' - - for n in self.histogram: - rgb = color_gradient(n) - hist += terminal.Colors.bg_rgb(*rgb)(' ') - - hist += ']' - - assert len(hist) == 102 - - return hist - - def to_jsonish(self) -> dict: - d = {} - for slot in self._yaml_slots_: - assert slot.yaml_name not in d - - val = getattr(self, slot.py_name) - d[slot.yaml_name] = val - return d - - @staticmethod - def from_jsonish(src: dict) -> 'Statistics': - out = {} - - for slot in Statistics._yaml_slots_: - try: - val = src[slot.yaml_name] - - if slot.type is ty.Dict[int, float]: - val = {int(k): v for k, v in val.items()} - except KeyError: - val = slot.default - - out[slot.py_name] = val - - try: - return Statistics(**out) - except TypeError: - print(src) - raise + raw_data = yaml.Slot(ty.List[float], list) + num_samples = yaml.Slot(float, math.nan) + ## arithmetic mean + mean = yaml.Slot(float, math.nan) + ## sample standard deviation + stddev = yaml.Slot(float, math.nan) + ## sample variance + variance = yaml.Slot(float, math.nan) + ## median absolute deviation + MAD = yaml.Slot(float, math.nan) + min = yaml.Slot(float, math.nan) + max = yaml.Slot(float, math.nan) + percentiles = yaml.Slot(ty.Dict[int, float], dict) + histogram = yaml.Slot(ty.List[int], list) + + def __init__( + self, + *, + mean: ty.Union[numpy.number, float], + stddev: ty.Union[numpy.number, float], + variance: ty.Union[numpy.number, float] = math.nan, + num_samples: ty.Union[numpy.number, float] = math.nan, + MAD: ty.Union[numpy.number, float] = math.nan, + min: ty.Union[numpy.number, float] = math.nan, + max: ty.Union[numpy.number, float] = math.nan, + percentiles: ty.Dict[int, float] = None, + histogram: ty.Sequence[float] = None, + raw_data: ty.Union[numpy.ndarray, ty.List[float]] = None, + ): + self.raw_data = raw_data if raw_data is not None else [] + self.num_samples = float(num_samples) + self.mean = float(mean) + self.stddev = float(stddev) + self.variance = ( + self.stddev ** 2 if math.isnan(float(variance)) else float(variance) + ) + self.MAD = float(MAD) + self.min = float(min) + self.max = float(max) + self.percentiles = percentiles or {} + self.histogram = histogram or [] + + assert not percentiles or 0 not in percentiles or percentiles[0] == min + assert not percentiles or 100 not in percentiles or percentiles[100] == max + + assert not percentiles or 10 in percentiles + assert not percentiles or 25 in percentiles + assert not percentiles or 50 in percentiles + assert not percentiles or 75 in percentiles + assert not percentiles or 90 in percentiles + + def drop_details(self): + """ + Throw away the histogram and all percentiles except 10,25,50,75,90. + This is intended to reduce memory usage. + """ + self.raw_data = [] + self.histogram = [] + self.percentiles = { + 10: self.percentiles[10], + 25: self.percentiles[25], + 50: self.percentiles[50], + 75: self.percentiles[75], + 90: self.percentiles[90], + } + + def prediction_interval(self, probability: float): + assert 0 <= probability <= 1 + + return scipy.stats.norm.interval(probability, loc=self.mean, scale=self.stddev) + + @staticmethod + def from_array(array: numpy.ndarray, keep_raw_data: bool = False): + """ + compute values from a numpy array. + """ + + if len(array) == 0: + raise ValueError("Cannot create Statistics from empty array") + + # https://en.wikipedia.org/wiki/Median_absolute_deviation + median = numpy.median(array) + MAD = numpy.median([abs(x - median) for x in array]) + + percentiles = numpy.percentile(array, range(1, 100)) + percentiles = {n + 1: percentiles[n] for n in range(len(percentiles))} + + if not math.isinf(median) and not math.isnan(median): + hist, hist_bins = numpy.histogram(array, 100) + + assert sum(hist) == len(array), f"{sum(hist)} != {len(array)}" + + hist = [int(n) for n in hist] + else: + hist = None + + if keep_raw_data: + return Statistics( + raw_data=array, + num_samples=len(array), + mean=array.mean(), + stddev=array.std(ddof=1), + variance=array.var(ddof=1), + MAD=MAD, + min=array.min(), + max=array.max(), + percentiles=percentiles, + histogram=hist, + ) + + else: + return Statistics( + num_samples=len(array), + mean=array.mean(), + stddev=array.std(ddof=1), + variance=array.var(ddof=1), + MAD=MAD, + min=array.min(), + max=array.max(), + percentiles=percentiles, + histogram=hist, + ) + + @property + def IQR(self): + """ + Interquartile range (https://en.wikipedia.org/wiki/Interquartile_range) + """ + return self.p75 - self.p25 + + @property + def p0(self): + """ + 0th percentile (minimum value) + """ + return self.min + + @property + def p100(self): + """ + 100th percentile (maximum value) + """ + return self.max + + @property + def p10(self): + return self.percentile(10) + + @property + def p25(self): + return self.percentile(25) + + @property + def p50(self): + return self.percentile(50) + + @property + def p75(self): + return self.percentile(75) + + @property + def p90(self): + return self.percentile(90) + + def percentile(self, q: int) -> float: + """ + Retrieve q-th percentile value. + Returns NaN if that percentile is not available + """ + + return self.percentiles.get(q, math.nan) + + def scale(self, scalar: float) -> "Statistics": + """ + return new Statistics where mean and stddev are multiplied with scalar. + All other values are dropped. + + I guess this is basically only safe for results of __truediv__ + """ + + return Statistics(mean=self.mean * scalar, stddev=self.stddev * scalar) + + def __truediv__(self, that): + if type(that) is not Statistics: + return NotImplemented + + ## https://stats.stackexchange.com/questions/49399/standard-deviation-of-a-ratio-percentage-change + ## + ## comment by Eekhorn: + ## ... if you want to normalize your data y ± Δy to z ± Δz + ## x = y / z + ## you have to calculate the standard deviation Δx as follows: + ## Δx = x * sqrt( (Δy / y)^2 + (Δz / z)^2 ) + ## NOTE: I removed the scaling by 100 to get percentages + ## it appears in a linear factor in both equations so you can scale later + ## just by multiplying mean & stddev if desired. + + y_mean = self.mean + y_dev = self.stddev + + z_mean = that.mean + z_dev = that.stddev + + x_mean = y_mean / z_mean + x_dev = x_mean * math.sqrt((y_dev / y_mean) ** 2 + (z_dev / z_mean) ** 2) + + return Statistics(mean=x_mean, stddev=x_dev) + + def __sub__(self, that): + assert type(that) is Statistics + + assert ( + False + ), "TODO: need covariance, and for that we need to store all values :(" + + @classmethod + def stddev_of_difference(clss, data1: numpy.ndarray, data2: numpy.ndarray): + return math.sqrt(clss.variance_of_difference(data1, data2)) + + @classmethod + def variance_of_difference(clss, data1: numpy.ndarray, data2: numpy.ndarray): + # https://en.wikipedia.org/wiki/Variance#Sum_of_correlated_variables + # https://stats.stackexchange.com/questions/142745/what-is-the-demonstration-of-the-variance-of-the-difference-of-two-dependent-var + # Var[X - Y] = Var[X] + Var[Y] - 2 * Cov[X, Y] + + # ND 07/05/2021: Mypy is showinf an error, not sure this still works + cov = numpy.cov(numpy.vstack(data1, data2)) # type: ignore + + return numpy.var(data1) + numpy.var(data2) - 2 * cov[0, 1] + + def __str__(self): + out = io.StringIO() + yaml.dump(self, out) + return out.getvalue() + + def whisker_plot(self, width: int) -> str: + assert width > 2 + width -= 2 + + def frac(n, default=0): + if math.isnan(n): + n = default + return fractions.Fraction(n) + + min = fractions.Fraction(self.min) + p10 = frac(self.p10, min) + p25 = frac(self.p25, p10) + p50 = fractions.Fraction(self.p50) + p75 = frac(self.p75, p50) + p90 = frac(self.p90, p75) + max = fractions.Fraction(self.max) + + span = (max - min) or 1 + + def w(hi, lo): + # assert lo <= hi + return round(((hi - lo) * width) / span) + + w10 = w(p10, min) + w25 = w(p25, p10) + w50 = w(p50, p25) + w75 = w(p75, p50) + w90 = w(p90, p75) + w100 = w(max, p90) + + # print('%10s %10s' % (p10, w10)) + # print('%10s %10s' % (p25, w25)) + # print('%10s %10s' % (p50, w50)) + # print('%10s %10s' % (p75, w75)) + # print('%10s %10s' % (p90, w90)) + # print('%10s %10s' % (max, w100)) + + txt = "" + txt += "[" + txt += " " * round(w10) + txt += ("|" + "-" * width)[: round(w25)] + txt += "=" * round(w50) + txt += "0" + txt += "=" * round(w75) + txt += ("|" + "-" * width)[: round(w90)][::-1] + txt += " " * round(w100) + txt += "]" + + # draw outliers (lower 10 and higher 10 percentlies) + if self.percentiles: + tmp = list(txt) + + def draw_outlier(pos: int): + if pos >= len(tmp): + return + if tmp[pos] != " ": + return + + tmp[pos] = "." + + for p in range(1, 10): + if p in self.percentiles: + draw_outlier(w(self.percentiles[p], self.min)) + + for p in range(91, 99): + if p in self.percentiles: + draw_outlier(w(self.percentiles[p], self.min)) + + txt = "".join(tmp) + + return txt + + def histogram_plot(self, width: int) -> str: + """ + Visualizes distribution of values using terminal colors. + + Produces a 'histogram' of width :width: + blue -> very many values + red -> many values + yellow -> some values + green -> few values + """ + + assert width > 2 + width -= 2 + + assert self.histogram + + BLACK = (0, 0, 0) + GRAY = (55, 55, 55) + WHITE = (255, 255, 255) + SOFT_GREEN = (171, 252, 133) + GREEN = (122, 244, 66) + YELLOW = (247, 243, 32) + RED = (244, 19, 19) + CYAN = (133, 246, 252) + BLUE = (11, 19, 239) + PURPLE = (150, 0, 170) + + N = max(self.histogram) + + COLORS = ( + BLACK, + GRAY, + SOFT_GREEN, + GREEN, + YELLOW, + RED, + CYAN, + BLUE, + PURPLE, + WHITE, + ) + STEPS = ( + 0, + N * 0.025, + N * 0.05, + N * 0.10, + N * 0.25, + N * 0.50, + N * 0.75, + N * 0.9, + N * 0.95, + N, + ) + GRADIENT = tuple(tuple(C[i] for C in COLORS) for i in range(3)) + + def color_gradient(n: float): + return tuple( + int(round(numpy.interp(n, STEPS, GRADIENT[i]))) for i in range(3) + ) + + hist = "" + hist += "[" + + for n in self.histogram: + rgb = color_gradient(n) + hist += terminal.Colors.bg_rgb(*rgb)(" ") + + hist += "]" + + assert len(hist) == 102 + + return hist + + def to_jsonish(self) -> dict: + d = {} + for slot in self._yaml_slots_: + assert slot.yaml_name not in d + + val = getattr(self, slot.py_name) + d[slot.yaml_name] = val + return d + + @staticmethod + def from_jsonish(src: dict) -> "Statistics": + out = {} + + for slot in Statistics._yaml_slots_: + try: + val = src[slot.yaml_name] + + if slot.type is ty.Dict[int, float]: + val = {int(k): v for k, v in val.items()} + except KeyError: + val = slot.default + + out[slot.py_name] = val + + try: + return Statistics(**out) + except TypeError: + print(src) + raise diff --git a/src/pipedream/utils/yaml.py b/src/pipedream/utils/yaml.py index 70a124d..856d0dd 100644 --- a/src/pipedream/utils/yaml.py +++ b/src/pipedream/utils/yaml.py @@ -13,469 +13,472 @@ import types import typing as ty __all__ = [ - 'YAML_Serializable', - 'YAML_Struct', - 'Node', + "YAML_Serializable", + "YAML_Struct", + "Node", ] Representer = yaml.representer.SafeRepresenter Constructor = yaml.constructor.SafeConstructor -YAMLError = yaml.error.YAMLError +YAMLError = yaml.error.YAMLError ConstructorError = yaml.constructor.ConstructorError -ScannerError = yaml.scanner.ScannerError +ScannerError = yaml.scanner.ScannerError -Node = yaml.Node +Node = yaml.Node SequenceNode = yaml.SequenceNode -MappingNode = yaml.MappingNode -ScalarNode = yaml.ScalarNode +MappingNode = yaml.MappingNode +ScalarNode = yaml.ScalarNode -SEQUENCE_TAG = 'tag:yaml.org,2002:seq' -MAPPING_TAG = 'tag:yaml.org,2002:map' +SEQUENCE_TAG = "tag:yaml.org,2002:seq" +MAPPING_TAG = "tag:yaml.org,2002:map" -T = ty.TypeVar('T') -K = ty.TypeVar('K') -V = ty.TypeVar('V') -E = ty.TypeVar('E', bound=enum.Enum) +T = ty.TypeVar("T") +K = ty.TypeVar("K") +V = ty.TypeVar("V") +E = ty.TypeVar("E", bound=enum.Enum) class YAML_Serializer(abc.ABC, ty.Generic[T]): - """ + """ turn an object into a YAML Node - """ + """ - @classmethod - @ty.no_type_check - def for_type(clss, want: type) -> 'YAML_Serializer[T]': - ## FIXME: as of Python 3.8 there are public typing.get_origin and typing.get_args functions. + @classmethod + @ty.no_type_check + def for_type(clss, want: type) -> "YAML_Serializer[T]": + ## FIXME: as of Python 3.8 there are public typing.get_origin and typing.get_args functions. - origin = getattr(want, '__origin__', None) + origin = getattr(want, "__origin__", None) - # python typing types - if origin is not None: + # python typing types + if origin is not None: - # python 3.6: __origin__ is typing.List - # python 3.7: __origin__ is list - if origin in (ty.List, list): - args = want.__args__ + # python 3.6: __origin__ is typing.List + # python 3.7: __origin__ is list + if origin in (ty.List, list): + args = want.__args__ - assert len(args) == 1 + assert len(args) == 1 - return List_Serializer( - clss.for_type(args[0]) - ) + return List_Serializer(clss.for_type(args[0])) - if origin in (ty.Dict, dict): - args = want.__args__ + if origin in (ty.Dict, dict): + args = want.__args__ - assert len(args) == 2 + assert len(args) == 2 - return Dict_Serializer( - clss.for_type(args[0]), - clss.for_type(args[1]), - ) + return Dict_Serializer( + clss.for_type(args[0]), + clss.for_type(args[1]), + ) - if origin is ty.Union and len(want.__args__) == 2: - args = want.__args__ + if origin is ty.Union and len(want.__args__) == 2: + args = want.__args__ - assert args[0] is not type(None) - assert args[1] is type(None) + assert args[0] is not type(None) + assert args[1] is type(None) - return Optional_Serializer( - clss.for_type(args[0]) - ) + return Optional_Serializer(clss.for_type(args[0])) + + # base types + if want is str: + return Str_Serializer() + if want is bool: + return Bool_Serializer() + if want is int: + return Int_Serializer() + if want is float: + return Float_Serializer() + if want is datetime.datetime: + return Date_Time_Serializer() + if want is datetime.timedelta: + return Time_Delta_Serializer() + + # normal classes + if issubclass(want, YAML_Serializable): + out = want.yaml_serializer() + assert isinstance(out, YAML_Serializer), repr(out) + return out - # base types - if want is str: - return Str_Serializer() - if want is bool: - return Bool_Serializer() - if want is int: - return Int_Serializer() - if want is float: - return Float_Serializer() - if want is datetime.datetime: - return Date_Time_Serializer() - if want is datetime.timedelta: - return Time_Delta_Serializer() - - # normal classes - if issubclass(want, YAML_Serializable): - out = want.yaml_serializer() - assert isinstance(out, YAML_Serializer), repr(out) - return out - - # not supported - raise TypeError('No serializer for type ' + ty._type_repr(want)) - - @abc.abstractmethod - def to_yaml(self, obj: T) -> yaml.Node: - pass - - @abc.abstractmethod - def from_yaml(self, node: yaml.Node) -> T: - pass + # not supported + raise TypeError("No serializer for type " + ty._type_repr(want)) + + @abc.abstractmethod + def to_yaml(self, obj: T) -> yaml.Node: + pass + + @abc.abstractmethod + def from_yaml(self, node: yaml.Node) -> T: + pass class Str_Serializer(YAML_Serializer[str]): - def to_yaml(self, obj: str) -> yaml.Node: - return represent_str(obj) + def to_yaml(self, obj: str) -> yaml.Node: + return represent_str(obj) - def from_yaml(self, node: yaml.Node) -> str: - return construct_str(node) + def from_yaml(self, node: yaml.Node) -> str: + return construct_str(node) class Bool_Serializer(YAML_Serializer[int]): - def to_yaml(self, obj): - return represent_bool(obj) + def to_yaml(self, obj): + return represent_bool(obj) - def from_yaml(self, node): - return construct_bool(node) + def from_yaml(self, node): + return construct_bool(node) class Int_Serializer(YAML_Serializer[int]): - def to_yaml(self, obj): - return represent_int(obj) + def to_yaml(self, obj): + return represent_int(obj) - def from_yaml(self, node): - return construct_int(node) + def from_yaml(self, node): + return construct_int(node) class Float_Serializer(YAML_Serializer[float]): - def to_yaml(self, obj): - return represent_float(obj) + def to_yaml(self, obj): + return represent_float(obj) - def from_yaml(self, node): - return construct_float(node) + def from_yaml(self, node): + return construct_float(node) class Date_Time_Serializer(YAML_Serializer[datetime.datetime]): - def to_yaml(self, obj): - return represent_datetime(obj) + def to_yaml(self, obj): + return represent_datetime(obj) - def from_yaml(self, node): - return construct_datetime(node) + def from_yaml(self, node): + return construct_datetime(node) class Time_Delta_Serializer(YAML_Serializer[datetime.timedelta]): - def to_yaml(self, obj): - return represent_timedelta(obj) + def to_yaml(self, obj): + return represent_timedelta(obj) - def from_yaml(self, node): - return construct_timedelta(node) + def from_yaml(self, node): + return construct_timedelta(node) class List_Serializer(ty.Generic[T], YAML_Serializer[ty.List[T]]): - def __init__(self, item: YAML_Serializer[T]): - self.item = item + def __init__(self, item: YAML_Serializer[T]): + self.item = item - def to_yaml(self, obj): - return represent_list(obj, self.item) + def to_yaml(self, obj): + return represent_list(obj, self.item) - def from_yaml(self, node): - return construct_list(node, self.item) + def from_yaml(self, node): + return construct_list(node, self.item) class Dict_Serializer(ty.Generic[K, V], YAML_Serializer[ty.Dict[K, V]]): - def __init__(self, key: YAML_Serializer[K], val: YAML_Serializer[V]): - self.key = key - self.val = val + def __init__(self, key: YAML_Serializer[K], val: YAML_Serializer[V]): + self.key = key + self.val = val - def to_yaml(self, obj): - return represent_dict(obj, self.key, self.val) + def to_yaml(self, obj): + return represent_dict(obj, self.key, self.val) - def from_yaml(self, node): - return construct_dict(node, self.key, self.val) + def from_yaml(self, node): + return construct_dict(node, self.key, self.val) class Optional_Serializer(ty.Generic[E], YAML_Serializer[ty.Optional[E]]): - """ + """ Serializer for a typing.Optional[E] value (i.e. either None or an E). - """ + """ - def __init__(self, value: YAML_Serializer[E]): - self.value = value + def __init__(self, value: YAML_Serializer[E]): + self.value = value - def to_yaml(self, obj): - if obj is None: - return Representer().represent_none(None) - else: - return self.value.to_yaml(obj) + def to_yaml(self, obj): + if obj is None: + return Representer().represent_none(None) + else: + return self.value.to_yaml(obj) - def from_yaml(self, node): - if type(node) is yaml.ScalarNode and node.tag == 'tag:yaml.org,2002:null': - assert node.value == 'null' - return None - else: - return self.value.from_yaml(node) + def from_yaml(self, node): + if type(node) is yaml.ScalarNode and node.tag == "tag:yaml.org,2002:null": + assert node.value == "null" + return None + else: + return self.value.from_yaml(node) class Enum_Serializer(ty.Generic[E], YAML_Serializer[E]): - """ + """ Serializer for a enum.Enum enumeration class. Enum values are simply represented by their name as a str - """ + """ - # TODO: use a YAML tag? + # TODO: use a YAML tag? - def __init__(self, enum_class: ty.Type[E]): - assert issubclass(enum_class, enum.Enum) - self.enum_class = enum_class + def __init__(self, enum_class: ty.Type[E]): + assert issubclass(enum_class, enum.Enum) + self.enum_class = enum_class - def to_yaml(self, obj: E): - return represent_enum(self.enum_class, obj) + def to_yaml(self, obj: E): + return represent_enum(self.enum_class, obj) - def from_yaml(self, node): - return construct_enum(node, self.enum_class) + def from_yaml(self, node): + return construct_enum(node, self.enum_class) class YAML_Serializable(abc.ABC): - """ + """ Helper class for serializing to/from YAML. A bit less heavy-weight than yaml.YAMLObject, does not do any metaclass magic. - """ + """ - @abc.abstractclassmethod - def yaml_serializer(clss) -> YAML_Serializer: - raise NotImplementedError('abstract') + @abc.abstractclassmethod + def yaml_serializer(clss) -> YAML_Serializer: + raise NotImplementedError("abstract") - def from_yaml(clss, node: Node) -> ty.Type['YAML_Serializable']: - return clss.yaml_serializer().from_yaml(node) + def from_yaml(clss, node: Node) -> ty.Type["YAML_Serializable"]: + return clss.yaml_serializer().from_yaml(node) - def to_yaml(self) -> Node: - """ - serialize to yaml.Node - """ - return self.yaml_serializer().to_yaml(self) + def to_yaml(self) -> Node: + """ + serialize to yaml.Node + """ + return self.yaml_serializer().to_yaml(self) class YAML_Struct(YAML_Serializable): - """ + """ Object that is represented by a mapping in YAML. Default implementations of methods from YAML_Serializable tries to create a mapping from attributes in __slots__ if it is present ('_' in attribute names will be replaced by '-'). - """ + """ - _yaml_slots_: ty.ClassVar[ty.Tuple['Slot']] + _yaml_slots_: ty.ClassVar[ty.Tuple["Slot"]] - @classmethod - def yaml_flow_style(clss) -> bool: - """ - return true iff you want to use flow style when serializing this type - """ - return False + @classmethod + def yaml_flow_style(clss) -> bool: + """ + return true iff you want to use flow style when serializing this type + """ + return False - @classmethod - def yaml_serializer(clss): - return YAML_Struct_Serializer(clss) + @classmethod + def yaml_serializer(clss): + return YAML_Struct_Serializer(clss) - def __new__(clss, *args, **kwargs): - obj = super().__new__(clss) + def __new__(clss, *args, **kwargs): + obj = super().__new__(clss) - for slot in clss._yaml_default_slots_: - slot.set_default(obj) + for slot in clss._yaml_default_slots_: + slot.set_default(obj) - return obj + return obj - def __init_subclass__(clss): - super().__init_subclass__() + def __init_subclass__(clss): + super().__init_subclass__() - ## FIXME: forbid inheritance - assert clss.mro()[1] is YAML_Struct + ## FIXME: forbid inheritance + assert clss.mro()[1] is YAML_Struct - yaml_slots = [] + yaml_slots = [] - for v in clss.__dict__.values(): - if not isinstance(v, Slot): - continue + for v in clss.__dict__.values(): + if not isinstance(v, Slot): + continue - yaml_slots.append(v) + yaml_slots.append(v) - clss._yaml_slots_ = tuple(yaml_slots) - clss._yaml_default_slots_ = tuple(s for s in yaml_slots if s.has_default) - clss._yaml_slot_dict_ = types.MappingProxyType({s.yaml_name: s for s in yaml_slots}) + clss._yaml_slots_ = tuple(yaml_slots) + clss._yaml_default_slots_ = tuple(s for s in yaml_slots if s.has_default) + clss._yaml_slot_dict_ = types.MappingProxyType( + {s.yaml_name: s for s in yaml_slots} + ) class YAML_Struct_Serializer(YAML_Serializer): - """ + """ Default serializer for YAML_Struct objects - """ + """ - def __init__(self, struct): - self.struct = struct + def __init__(self, struct): + self.struct = struct - def to_yaml(self, obj): - def yaml_items(): - for slot in self.struct._yaml_slots_: - key = represent_str(slot.yaml_name) - val = slot.serializer.to_yaml(slot.__get__(obj)) + def to_yaml(self, obj): + def yaml_items(): + for slot in self.struct._yaml_slots_: + key = represent_str(slot.yaml_name) + val = slot.serializer.to_yaml(slot.__get__(obj)) - assert val is not None, repr(slot.type) + assert val is not None, repr(slot.type) - yield key, val + yield key, val - return make_mapping_node( - yaml_items(), - flow_style=self.struct.yaml_flow_style(), - ) + return make_mapping_node( + yaml_items(), + flow_style=self.struct.yaml_flow_style(), + ) - def from_yaml(self, node: yaml.Node) -> list: - check_type(MappingNode, node) + def from_yaml(self, node: yaml.Node) -> list: + check_type(MappingNode, node) - kwargs = {} - slots = self.struct._yaml_slot_dict_ + kwargs = {} + slots = self.struct._yaml_slot_dict_ - for k, v in node.value: - k = construct_str(k) + for k, v in node.value: + k = construct_str(k) - slot = slots.get(k) + slot = slots.get(k) - if slot is None: - raise yaml.constructor.ConstructorError('invalid field ' + repr(k) + ' in ' + self.struct.__name__) + if slot is None: + raise yaml.constructor.ConstructorError( + "invalid field " + repr(k) + " in " + self.struct.__name__ + ) - v = slot.serializer.from_yaml(v) + v = slot.serializer.from_yaml(v) - kwargs[slot.py_name] = v + kwargs[slot.py_name] = v - return self.struct(**kwargs) + return self.struct(**kwargs) class Slot: - """ - Descriptor for declaring fields in a YAML_Struct. - """ - - NO_DEFAULT = object() - - def __init__(self, type_, default = NO_DEFAULT): - self._type = type_ - self._serializer = YAML_Serializer.for_type(type_) - - if default is self.NO_DEFAULT: - self._default = self._fail_no_default - self._has_default = False - else: - if type(default) is type: - self._default = default - else: - self._default = lambda: default - self._has_default = True - - def __set_name__(self, owner, name): - self._py_name = name - self._yaml_name = name.replace('_', '-') - - def __get__(self, instance, owner=None): - if instance is None: - return self - else: - return instance.__dict__[self._py_name] - - def __set__(self, instance, value): - instance.__dict__[self._py_name] = value - - @property - def has_default(self) -> bool: """ - Check if this descriptor has a default value. - """ - - return self._has_default - - def set_default(self, instance): - """ - Set property in *instance* to default value. + Descriptor for declaring fields in a YAML_Struct. """ - assert self.has_default, (repr(self.py_name) + ' ' + - ' (' + repr(self.yaml_name) + ') has no default') + NO_DEFAULT = object() + + def __init__(self, type_, default=NO_DEFAULT): + self._type = type_ + self._serializer = YAML_Serializer.for_type(type_) + + if default is self.NO_DEFAULT: + self._default = self._fail_no_default + self._has_default = False + else: + if type(default) is type: + self._default = default + else: + self._default = lambda: default + self._has_default = True + + def __set_name__(self, owner, name): + self._py_name = name + self._yaml_name = name.replace("_", "-") + + def __get__(self, instance, owner=None): + if instance is None: + return self + else: + return instance.__dict__[self._py_name] + + def __set__(self, instance, value): + instance.__dict__[self._py_name] = value + + @property + def has_default(self) -> bool: + """ + Check if this descriptor has a default value. + """ + + return self._has_default + + def set_default(self, instance): + """ + Set property in *instance* to default value. + """ + + assert self.has_default, ( + repr(self.py_name) + " " + " (" + repr(self.yaml_name) + ") has no default" + ) - setattr(instance, self._py_name, self.default) + setattr(instance, self._py_name, self.default) - @property - def py_name(self): - return self._py_name + @property + def py_name(self): + return self._py_name - @property - def yaml_name(self): - return self._yaml_name + @property + def yaml_name(self): + return self._yaml_name - @property - def type(self): - return self._type + @property + def type(self): + return self._type - @property - def serializer(self): - return self._serializer + @property + def serializer(self): + return self._serializer - @property - def default(self): - return self._default() + @property + def default(self): + return self._default() - def _fail_no_default(self): - raise ValueError(f'Slot {self.yaml_name!r} has no default') + def _fail_no_default(self): + raise ValueError(f"Slot {self.yaml_name!r} has no default") def load(serializer: YAML_Serializer, stream: ty.IO[str]) -> YAML_Serializable: - """ + """ Load and deserialize one YAML document into an object. - """ + """ - loader = yaml.SafeLoader(stream) + loader = yaml.SafeLoader(stream) - try: - node = loader.get_single_node() + try: + node = loader.get_single_node() - if node is not None: - return serializer.from_yaml(node) + if node is not None: + return serializer.from_yaml(node) - raise yaml.YAMLError('Empty document') - finally: - loader.dispose() + raise yaml.YAMLError("Empty document") + finally: + loader.dispose() -def load_all(serializer: YAML_Serializer, stream: ty.IO[str]) -> ty.Iterable[YAML_Serializable]: - """ +def load_all( + serializer: YAML_Serializer, stream: ty.IO[str] +) -> ty.Iterable[YAML_Serializable]: + """ Load and deserialize one YAML document into an object. - """ + """ - loader = yaml.SafeLoader(stream) + loader = yaml.SafeLoader(stream) - try: - while loader.check_data(): - if loader.check_node(): - node = loader.get_node() + try: + while loader.check_data(): + if loader.check_node(): + node = loader.get_node() - yield serializer.from_yaml(node) - finally: - loader.dispose() + yield serializer.from_yaml(node) + finally: + loader.dispose() def dump(obj: YAML_Serializable, stream: ty.IO[str] = sys.stdout): - dumper = yaml.SafeDumper(stream) + dumper = yaml.SafeDumper(stream) - try: - dumper.open() - stream.write('\n') + try: + dumper.open() + stream.write("\n") - dumper.represent_data('') + dumper.represent_data("") - node = represent_object(obj) + node = represent_object(obj) - dumper.serialize(node) + dumper.serialize(node) - dumper.close() - finally: - dumper.dispose() + dumper.close() + finally: + dumper.dispose() def dump_all(seq: ty.List[YAML_Serializable], stream: ty.IO[str] = sys.stdout): - for obj in seq: - dump(obj, stream) + for obj in seq: + dump(obj, stream) ################################################################################ @@ -483,117 +486,121 @@ def dump_all(seq: ty.List[YAML_Serializable], stream: ty.IO[str] = sys.stdout): def construct_str(node: yaml.Node) -> str: - check_type(ScalarNode, node) + check_type(ScalarNode, node) - return node.value + return node.value def construct_none(node: yaml.Node) -> None: - return Constructor().construct_yaml_null(node) + return Constructor().construct_yaml_null(node) def construct_bool(node: yaml.Node) -> bool: - return Constructor().construct_yaml_bool(node) + return Constructor().construct_yaml_bool(node) def construct_int(node: yaml.Node) -> int: - return Constructor().construct_yaml_int(node) + return Constructor().construct_yaml_int(node) def construct_float(node: yaml.Node) -> float: - return Constructor().construct_yaml_float(node) + return Constructor().construct_yaml_float(node) def construct_datetime(node: yaml.Node) -> datetime.datetime: - return Constructor().construct_yaml_timestamp(node) + return Constructor().construct_yaml_timestamp(node) def construct_timedelta(node: yaml.Node) -> datetime.timedelta: - txt = construct_str(node) + txt = construct_str(node) - if not txt.endswith('s'): - raise yaml.constructor.ConstructorError(f'expected a string like "\d+([.]\d+)?s", got {txt!r}') + if not txt.endswith("s"): + raise yaml.constructor.ConstructorError( + f'expected a string like "\d+([.]\d+)?s", got {txt!r}' + ) - txt = txt[:-1] + txt = txt[:-1] - return datetime.timedelta(seconds=float(txt)) + return datetime.timedelta(seconds=float(txt)) def construct_list(node: yaml.Node, item: YAML_Serializer[T]) -> ty.List[T]: - check_type(SequenceNode, node) - check_tag(SEQUENCE_TAG, node) + check_type(SequenceNode, node) + check_tag(SEQUENCE_TAG, node) - return [ - item.from_yaml(n) for n in node.value - ] + return [item.from_yaml(n) for n in node.value] -def construct_dict(node: yaml.Node, key: YAML_Serializer[K], val: YAML_Serializer[V]) -> ty.Dict[K, V]: - check_type(MappingNode, node) - check_tag(MAPPING_TAG, node) +def construct_dict( + node: yaml.Node, key: YAML_Serializer[K], val: YAML_Serializer[V] +) -> ty.Dict[K, V]: + check_type(MappingNode, node) + check_tag(MAPPING_TAG, node) - return { - key.from_yaml(kv[0]): val.from_yaml(kv[1]) for kv in node.value - } + return {key.from_yaml(kv[0]): val.from_yaml(kv[1]) for kv in node.value} def construct_enum(node: yaml.Node, enum_class: ty.Type[E]): - assert issubclass(enum_class, enum.Enum) - check_type(ScalarNode, node) + assert issubclass(enum_class, enum.Enum) + check_type(ScalarNode, node) - txt = construct_str(node) + txt = construct_str(node) - return enum_class[txt] + return enum_class[txt] def construct_field(node: yaml.MappingNode, field: str, serializer: YAML_Serializer): - """ + """ Helper for deserialzing structs. Takes a MappingNode, all keys must be strings. Find and pop key from *mapping.value*, then de-serialize and return value. - """ + """ - assert isinstance(node, MappingNode) + assert isinstance(node, MappingNode) - val = None + val = None - for i, kv in enumerate(node.value): - k, v = kv + for i, kv in enumerate(node.value): + k, v = kv - k = construct_str(k) + k = construct_str(k) - if k == field: - node.value.pop(i) - val = v + if k == field: + node.value.pop(i) + val = v - if val is None: - raise ValueError('Key ' + repr(field) + ' not in mapping') + if val is None: + raise ValueError("Key " + repr(field) + " not in mapping") - return serializer.from_yaml(val) + return serializer.from_yaml(val) def check_tag(expected_tag: str, node: Node): - """ + """ Check if *node* has expected tag, raise exception otherwise. - """ + """ - assert node.tag + assert node.tag - if node.tag != expected_tag: - raise yaml.constructor.ConstructorError('expected a ' + expected_tag + ', got a ' + node.tag) + if node.tag != expected_tag: + raise yaml.constructor.ConstructorError( + "expected a " + expected_tag + ", got a " + node.tag + ) def check_type(expected_clss: ty.Type[Node], node: Node): - """ + """ Check if *node* has expected type (scalar, mapping, ...), raise exception otherwise. - """ + """ - ## according to typeshed's pyyaml type annotations all subtypes of 'Node' have an 'id' field. - ## But the 'Node' type does not :/ - assert node.id + ## according to typeshed's pyyaml type annotations all subtypes of 'Node' have an 'id' field. + ## But the 'Node' type does not :/ + assert node.id # type: ignore - if type(node) is not expected_clss: - raise yaml.constructor.ConstructorError('expected a ' + expected_clss.id + ', got a ' + node.id) + if type(node) is not expected_clss: + raise yaml.constructor.ConstructorError( + "expected a " + expected_clss.id + ", got a " + node.id # type:ignore + ) ################################################################################ @@ -601,105 +608,110 @@ def check_type(expected_clss: ty.Type[Node], node: Node): def represent_object(obj): - serializer = YAML_Serializer.for_type(type(obj)) + serializer = YAML_Serializer.for_type(type(obj)) - return serializer.to_yaml(obj) + return serializer.to_yaml(obj) def represent_none(): - return Representer().represent_none(None) + return Representer().represent_none(None) def represent_str(obj: str): - return Representer().represent_str(obj) + return Representer().represent_str(obj) def represent_bool(obj: str): - return Representer().represent_bool(obj) + return Representer().represent_bool(obj) def represent_int(obj: int): - return Representer().represent_int(obj) + return Representer().represent_int(obj) def represent_float(obj: float): - return Representer().represent_float(obj) + return Representer().represent_float(obj) def represent_datetime(obj: datetime.datetime): - return Representer().represent_datetime(obj) + return Representer().represent_datetime(obj) def represent_timedelta(obj: datetime.timedelta): - assert type(obj) is datetime.timedelta, obj - return represent_str('%fs' % obj.total_seconds()) + assert type(obj) is datetime.timedelta, obj + return represent_str("%fs" % obj.total_seconds()) def represent_list(sequence: ty.Iterable, itemser: YAML_Serializer, flow_style=None): - value = [] + value = [] - best_style = True + best_style = True - for item in sequence: - node_item = itemser.to_yaml(item) + for item in sequence: + node_item = itemser.to_yaml(item) - if not (isinstance(node_item, ScalarNode) and not node_item.style): - best_style = False + if not (isinstance(node_item, ScalarNode) and not node_item.style): + best_style = False - value.append(node_item) + value.append(node_item) - if flow_style is None: - flow_style = best_style + if flow_style is None: + flow_style = best_style - return SequenceNode(SEQUENCE_TAG, value, flow_style=flow_style) + return SequenceNode(SEQUENCE_TAG, value, flow_style=flow_style) -def represent_dict(mapping: ty.Dict[ty.Any, ty.Any], key: YAML_Serializer, val: YAML_Serializer, flow_style=None): - if not isinstance(mapping, dict): - raise TypeError(type(mapping)) +def represent_dict( + mapping: ty.Dict[ty.Any, ty.Any], + key: YAML_Serializer, + val: YAML_Serializer, + flow_style=None, +): + if not isinstance(mapping, dict): + raise TypeError(type(mapping)) - def items(): - for item_key, item_value in mapping.items(): - node_key = key.to_yaml(item_key) - node_value = val.to_yaml(item_value) + def items(): + for item_key, item_value in mapping.items(): + node_key = key.to_yaml(item_key) + node_value = val.to_yaml(item_value) - yield node_key, node_value + yield node_key, node_value - return make_mapping_node(items(), flow_style) + return make_mapping_node(items(), flow_style) def represent_enum(enum_class: ty.Type[E], enum_val: E): - assert issubclass(enum_class, enum.Enum) - assert enum_val in enum_class + assert issubclass(enum_class, enum.Enum) + assert enum_val in enum_class - return represent_str(enum_val.name) + return represent_str(enum_val.name) def make_mapping_node(mapping: ty.Iterable[ty.Tuple[Node, Node]], flow_style=None): - """ + """ low level helper for creating MappingNode objects - """ + """ - tag = MAPPING_TAG - best_style = True + tag = MAPPING_TAG + best_style = True - value: ty.List[ty.Tuple[Node, Node]] = [] - key: Node - val: Node + value: ty.List[ty.Tuple[Node, Node]] = [] + key: Node + val: Node - for key, val in mapping: - assert isinstance(key, Node), repr(key) - assert isinstance(val, Node), repr(val) + for key, val in mapping: + assert isinstance(key, Node), repr(key) + assert isinstance(val, Node), repr(val) - if not (isinstance(key, ScalarNode) and not key.style): - best_style = False + if not (isinstance(key, ScalarNode) and not key.style): + best_style = False - if not (isinstance(val, ScalarNode) and not val.style): - best_style = False + if not (isinstance(val, ScalarNode) and not val.style): + best_style = False - value.append((key, val)) + value.append((key, val)) - if flow_style is None: - flow_style = best_style + if flow_style is None: + flow_style = best_style - return MappingNode(tag, value, flow_style=flow_style) + return MappingNode(tag, value, flow_style=flow_style) diff --git a/tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py b/tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py index f2b7843..168a0d6 100644 --- a/tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py +++ b/tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py @@ -9,7 +9,10 @@ URL = "https://sourceware.org/git/?p=binutils-gdb.git;a=blob_plain;f=opcodes/aar # https://sourceware.org/git/?p=binutils-gdb.git;a=blob_plain;f=opcodes/aarch64-opc-2.c # may be usefull too (signification of `op_types` field) INFILE = "aarch64-tbl.h" -OUTFILE = "instructions-binutils.py" +OUTFILE = "instructions_binutils.py" + +UNSUPPORTED_SET = {"ldst_unscaled", "ldstpair_indexed"} +UNSUPPORTED_INS = {"prfm", "casp", "caspa", "caspl", "caspal"} def brace_split(text, split_on=","): @@ -96,8 +99,9 @@ def parse_operands_str(operands_str, is_type: bool) -> List[List[str]]: class OP_TYPE(Enum): REG = 0 IMM = 1 - ADDR = 2 - FLAG = 3 + IMM_VLS = 2 + ADDR = 3 + FLAG = 4 NO_OP = 10 OTHER = 11 @@ -105,7 +109,6 @@ class OP_TYPE(Enum): def parse_operands_role( list_op_roles: List[str], ) -> List[Tuple[OP_TYPE, List[str]]]: - # returns (create_func, name, op_role, action, visibility) list_carac: List[Tuple[OP_TYPE, List[str]]] = [] for role in list_op_roles: if role in { @@ -120,6 +123,11 @@ def parse_operands_role( "Rd_SP", }: list_carac.append((OP_TYPE.REG, [role, "W", "EXPLICIT"])) + elif role in { + "IMM_VLSR", + "IMM_VLSL", + }: + list_carac.append((OP_TYPE.IMM_VLS, [role, "EXPLICIT"])) elif "IMM" in role or role in { "IDX", "WIDTH", @@ -182,9 +190,7 @@ def parse_operands(list_op_classes, op_roles) -> List[List[str]]: ) for idx, list_ops in enumerate(list_op_classes): if (len(op_carac) - len(list_op_classes[idx])) > 0: - list_ops += ["Unknown (not in reg)"] * ( - len(op_carac) - len(list_op_classes[idx]) - ) + list_ops += [None] * (len(op_carac) - len(list_op_classes[idx])) else: op_carac += [(OP_TYPE.NO_OP, ["SUPPRESSED"])] * ( len(list_op_classes[idx]) - len(op_carac) @@ -200,46 +206,74 @@ def parse_operands(list_op_classes, op_roles) -> List[List[str]]: new_operands: List[str] = [] for op_class, (op_type, list_carac) in zip(op_classes, op_carac): if op_type == OP_TYPE.REG: - if op_class[-len("NIL") :] == "NIL" and list_carac[0].startswith( - "SYSREG" - ): + if op_class.endswith("NIL") and list_carac[0].startswith("SYSREG"): op_class = "registers.SYSREG" new_operands.append( f'reg_op(name="{list_carac[0]}", reg_class={op_class}, ' f"action={list_carac[1]}, visibility={list_carac[2]})" ) elif op_type == OP_TYPE.ADDR: - addr_class = op_class[len("registers.") :] - if addr_class == "NIL": - addr_class = list_carac[0] - addr_class = f"operands.{addr_class}" - new_operands.append( - f'addr_op(name="{list_carac[0]}", addr_class={addr_class}, ' - f"visibility={list_carac[1]})" - ) - elif op_type == OP_TYPE.IMM: - imm_class = op_class[len("registers.") :] - if imm_class == "NIL": - imm_class = f"operands.{list_carac[0]}" + reg_class = op_class[len("registers.") :] + if reg_class == "NIL": new_operands.append( - f'imm_op(name="{list_carac[0]}", imm_class={imm_class}, ' + f'addr_op(name="{list_carac[0]}", visibility={list_carac[1]})' + ) + else: + new_operands.append( + f'addr_op(name="{list_carac[0]}", reg_class={op_class}, ' f"visibility={list_carac[1]})" ) - elif ( - imm_class.startswith("imm_") and imm_class[len("imm_")].isnumeric() + elif op_type == OP_TYPE.IMM: + imm_class = ( + op_class[len("registers.") :] if op_class is not None else None + ) + # Binop using another encoding here + if ( + list_carac[0] + in { + "COND1", + "FPIMM0", + "ADDR_SIMM9", + "ADDR_SIMM13", + "ADDR_SIMM19", + "ADDR_UIMM12", + "ADDR_SIMM7", + } + or imm_class == "imm_tag" + ): + imm_class = list_carac[0] + if ( + imm_class is not None + and imm_class.startswith("imm_") + and imm_class[len("imm_")].isnumeric() ): # Under the form imm_min_max min_val = imm_class[len("imm_")] max_val = imm_class[len("imm_X_") :] new_operands.append( - f'imm_op(name="{list_carac[0]}",' + f'imm_op(name="{list_carac[0]}", ' f"visibility={list_carac[1]}, min_val={min_val}, " f"max_val={max_val})" ) - + else: + if imm_class is None or imm_class == "NIL": + imm_class = f"{list_carac[0]}" + if imm_class in {"LSL", "LSR", "MSL"}: + imm_class = f"IMM_{imm_class}" + imm_class = f"operands.{imm_class}" + new_operands.append( + f'imm_op(name="{list_carac[0]}", imm_class={imm_class}, ' + f"visibility={list_carac[1]})" + ) + elif op_type == OP_TYPE.IMM_VLS: + new_operands.append( + f'imm_op(name="{list_carac[0]}", ' + f"imm_class=operands.{list_carac[0]}, reg_class={op_class}, " + f"visibility={list_carac[1]})" + ) elif op_type == OP_TYPE.FLAG: new_operands.append( - f'flag_op(name="{list_carac[0]}", visibility={list_carac[1]})' + f'flags_op(name="{list_carac[0]}", visibility={list_carac[1]})' ) elif op_type == OP_TYPE.OTHER: raise Exception(f"Unsupported operand type: {list_carac[0]}") @@ -264,7 +298,13 @@ def clean_instruction_dict(instruction) -> Dict[str, Any]: return instruction -def parse_line(line): +def parse_line(line: str) -> Dict: + # move is present in two variants, discarding one + if line.startswith(' { "mov", 0x52800000, 0x7f800000'): + return None + # FIXME: SVE support + if "SVE" in line: + return None line = line.strip() if not line: return None @@ -274,8 +314,12 @@ def parse_line(line): raise Exception("Invalid start of line: <{}>".format(line)) line_split = list(map(lambda x: x.strip(), brace_split(line[1:-2]))) + if line_split[3] in UNSUPPORTED_SET: + return None assert len(line_split) == 12 mnemonic = line_split[0][1:-1] # strip quotes + if mnemonic in UNSUPPORTED_INS: + return None return line_split, clean_instruction_dict( { "mnemonic": mnemonic, @@ -320,7 +364,7 @@ def generate_header(file, instruction_list: List[Dict]) -> None: file.write("\n\n") file.write( "def make_instruction_database(*, make_instruction, reg_op, " - "addr_op, imm_op, flag_op):\n" + "addr_op, imm_op, flags_op):\n" ) @@ -371,6 +415,6 @@ def generate_instruction_file() -> None: "\tmake_instruction(\n" f'\t\tmnemonic = "{arguments["mnemonic"]}",\n' f'\t\tisa_set = ISA.{arguments["insn_class"]},\n' - f'\t\tisa_extension = ISA_Extention.{arguments["extension"]},\n' + f'\t\tisa_extension = ISA_Extension.{arguments["extension"]},\n' f"\t\toperands = [\n{operands}\t\t],\n\t)\n" ) -- GitLab From a1217b55fdeab76d8aec21189ce24dc312c30826 Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Tue, 18 May 2021 16:45:01 +0200 Subject: [PATCH 08/12] asm: adding preliminary asmwriter support --- src/pipedream/asm/armv8a/__init__.py | 51 ++++++- src/pipedream/asm/armv8a/asmwriter.py | 183 ++++++++++++++++++++++++++ src/pipedream/asm/armv8a/registers.py | 3 +- src/pipedream/benchmark/common.py | 5 +- 4 files changed, 231 insertions(+), 11 deletions(-) create mode 100644 src/pipedream/asm/armv8a/asmwriter.py diff --git a/src/pipedream/asm/armv8a/__init__.py b/src/pipedream/asm/armv8a/__init__.py index 55d2f79..83cbddf 100644 --- a/src/pipedream/asm/armv8a/__init__.py +++ b/src/pipedream/asm/armv8a/__init__.py @@ -1,7 +1,6 @@ -from typing import List -import random - +from typing import List, Sequence, IO +import random import pipedream.utils.abc as abc from pipedream.utils import * @@ -14,7 +13,6 @@ from pipedream.benchmark.types import Loop_Overhead from . import registers from . import operands -from . import instructions from .asmwriter import * from .registers import * @@ -28,7 +26,6 @@ __all__ = [ ] -# TODO class ARMv8_Architecture(Architecture): @property # type: ignore @abc.override @@ -40,10 +37,50 @@ class ARMv8_Architecture(Architecture): def nb_vector_reg(self) -> int: return len(V_8B) # type: ignore + @property + @abc.override + def max_vector_size(self) -> int: + return 128 // 8 + + @abc.override + def instruction_set(self) -> Instruction_Set: + return self._instruction_set + + @abc.override + def register_set(self) -> Register_Set: + return self._register_set + + @abc.override + def asm_dialects(self) -> List[ASM_Dialect]: + return self._asm_dialects + + @abc.override + def make_asm_writer(self, dialect: ASM_Dialect, file: IO[str]) -> ASM_Writer: + return ARMv8_ASM_Writer(dialect, file) + + @abc.override + def make_ir_builder(self) -> "ARMv8_IR_Builder": + return ARMv8_IR_Builder(self) + + @abc.override + def make_register_allocator(self) -> Register_Liveness_Tracker: + regs = self.register_set() + + return Register_Liveness_Tracker( + all_registers=regs.all_registers(), + callee_save_registers=regs.callee_save_registers(), + ) + + # TODO + @abc.override + def loop_overhead(self, num_iterations: int) -> Loop_Overhead: + pass + -# TODO class ARMv8_ASM_Dialect(ASM_Dialect): - pass + @abc.override + def _mnemonic(self, inst: "X86_Instruction") -> str: + return MNEMONICS[instr.name] # TODO diff --git a/src/pipedream/asm/armv8a/asmwriter.py b/src/pipedream/asm/armv8a/asmwriter.py new file mode 100644 index 0000000..8452691 --- /dev/null +++ b/src/pipedream/asm/armv8a/asmwriter.py @@ -0,0 +1,183 @@ +from pipedream.utils import abc +from pipedream.asm.ir import * +from pipedream.asm.asmwriter import * +from pipedream.asm.armv8a.registers import * +from pipedream.asm.armv8a.operands import * + +import functools +import math + +from typing import IO, Sequence + +__all__ = [ + "ARMv8_ASM_Writer", +] + + +class ARMv8_ASM_Writer(ASM_Writer): + def __init__(self, dialect: "ARMv8_ASM_Dialect", file: IO[str]): + super().__init__(file) + self.dialect = dialect + + @abc.override + def begin_file(self, name): + self.print(1, '.file 1 "' + name + '"') + self.print(1, ".text") + + @abc.override + def end_file(self, name): + self.print(1, '.ident "pipe-dream"') + + @abc.override + def begin_function(self, function_name: str): + self.print(1, ".globl " + function_name) + self.print(1, ".type " + function_name + ", @function") + self.align() + self.print(0, function_name + ":") + self.print(1, ".cfi_startproc") + + @abc.override + def end_function(self, function_name: str): + self.print(1, ".cfi_endproc") + self.print(1, ".size " + function_name + ", .-" + function_name) + + @abc.override + def emit_label(self, label): + self.print(0, label.name + ":") + + @abc.override + def align(self): + self.print(1, ".align 16, 0x90") + + @abc.override + def insts(self, insts: Sequence[Instruction]): + for inst in insts: + assert isinstance(inst, Instruction) + assert not any(o.is_virtual for o in inst.operands), [ + inst, + [o for o in inst.operands if o.is_virtual], + ] + + try: + mnem = self.dialect._mnemonic(inst) + ops = reversed( + [ + op + for op in inst.operands + if op.visibility is not Operand_Visibility.SUPPRESSED + ] + ) + + if ops: + self.print( + 1, mnem.ljust(12), ", ".join(self.emit_operand(o) for o in ops) + ) + else: + self.print(1, mnem) + + except AssertionError: + print("#", inst) + raise + + # FIXME: this is x86 code, but may be compatible to ARM as well + @abc.override + def global_byte_array(self, name: str, size: int, alignment: int): + ## this is just copy/adapted from ASM emitted by GCC + + assert type(name) is str + assert type(size) is int + assert type(alignment) is int + + assert name and name.isprintable() + assert size >= 0 + assert alignment >= 0 + assert math.log2(alignment).is_integer() + + ## assume we are in text section + # self.print(1, '.text') + ## switch to BSS & emit array + self.print(1, ".bss") + self.print(1, ".global", name) + self.print(1, ".align", alignment) + self.print(1, ".type", name + ",", "@object") + self.print(1, ".size", name + ",", size) + self.print(0, name + ":") + self.print(1, ".zero", size) + ## switch back to text section + self.print(1, ".text") + self.print(0, "") + + @abc.override + def comment(self, *args): + self.print(1, "#", *args) + + @abc.override + def newline(self): + self.print(0) + + @abc.override + def emit_operand(self, op) -> str: + return renderARMv8_arg(op) + + +@functools.singledispatch +def renderARMv8_arg(arg): + raise ValueError( + "Cannot render argument " + repr(arg) + " of type " + type(arg).__name__ + ) + + +@renderARMv8_arg.register(Label) +def _(arg): + return arg.name + + +@renderARMv8_arg.register(ARMv8_Register_Operand) +def _(arg): + return arg.register.att_asm_name + + +@renderARMv8_arg.register(ARMv8_Immediate_Operand) +def _(arg): + if type(arg.value) is Label: + return arg.value.name + else: + return "$" + str(arg.value) + + +@renderARMv8_arg.register(str) +def _(arg): + # TODO: introduce symbol type + return arg + + +@renderARMv8_arg.register(ARMv8_Base_Offset_Operand) +def _(arg): + return renderARMv8_mem_arg(arg.offset) + "(" + renderARMv8_mem_arg(arg.base) + ")" + + +@functools.singledispatch +def renderARMv8_mem_arg(arg): + raise ValueError( + "Cannot render memory argument " + repr(arg) + " of type " + type(arg).__name__ + ) + + +@renderARMv8_mem_arg.register(ARMv8_Immediate_Operand) +def _(arg) -> str: + assert arg.value is not None + + if arg.value == 0: + return "" + + if type(arg.value) is Label: + return arg.value.name + + assert type(arg.value) is int + return str(arg.value) + + +@renderARMv8_mem_arg.register(ARMv8_Register_Operand) +def _(arg) -> str: + assert arg.register is not None + return arg.register.asm_name diff --git a/src/pipedream/asm/armv8a/registers.py b/src/pipedream/asm/armv8a/registers.py index 9af9033..3743e36 100644 --- a/src/pipedream/asm/armv8a/registers.py +++ b/src/pipedream/asm/armv8a/registers.py @@ -156,8 +156,7 @@ class ARMv8_Register(Register): return self._name @property # type:ignore - @abc.override - def att_asm_name(self) -> str: + def asm_name(self) -> str: return self._name @property # type:ignore diff --git a/src/pipedream/benchmark/common.py b/src/pipedream/benchmark/common.py index dab5ecb..baeaf8d 100644 --- a/src/pipedream/benchmark/common.py +++ b/src/pipedream/benchmark/common.py @@ -1205,9 +1205,10 @@ class _Benchmark_Runner: # of the kernel (=> used only during initialisation) unused_registers_kernel = set() - from pipedream.asm.x86 import RDX + if arch.name == "x86": + from pipedream.asm.x86 import RDX - assert results is RDX + assert results is RDX if gen_papi_calls: out.comment("papi_event_set -> ", SCRATCH_REG_1) -- GitLab From 1589ebd6782a0d177890355f44e40b83acedaa21 Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Tue, 1 Jun 2021 16:37:23 +0200 Subject: [PATCH 09/12] arm/operands: correcting classnames --- src/pipedream/asm/armv8a/operands.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/pipedream/asm/armv8a/operands.py b/src/pipedream/asm/armv8a/operands.py index 8f20f32..4c7b755 100644 --- a/src/pipedream/asm/armv8a/operands.py +++ b/src/pipedream/asm/armv8a/operands.py @@ -30,6 +30,9 @@ class ARMv8_Operand(ir.Operand): def name(self) -> str: return self._name + def __repr__(self) -> str: + return self.name + @property # type: ignore @abc.override def visibility(self) -> ir.Operand_Visibility: @@ -215,7 +218,7 @@ class ARMv8_Base_Operand(ARMv8_Operand): self._short_short_name, "/", "(", - repr(self.base), + repr(self.base.name), ")", ] @@ -239,7 +242,7 @@ class ARMv8_Base_Offset_Operand(ir.Base_Displacement_Operand): self._name = name self._offset = offset - self._base_class = base + self._base = base @property def name(self): @@ -280,7 +283,7 @@ class ARMv8_Base_Offset_Operand(ir.Base_Displacement_Operand): @property @abc.override def short_name(self) -> str: - return f"{self._short_short_name}BO{self.address_width}_{self.memory_width}" + return f"{self._short_short_name}BO_{self.memory_width}" @property @abc.override @@ -300,7 +303,6 @@ class ARMv8_Base_Offset_Operand(ir.Base_Displacement_Operand): ":", self._short_short_name, "BO", - str(self.address_width), "/", ] if hasattr(self, "memory_width"): @@ -308,21 +310,21 @@ class ARMv8_Base_Offset_Operand(ir.Base_Displacement_Operand): str(self.memory_width), "/", ] - tmp += ["(", repr(self.base), ", ", repr(self.offset), ")"] + tmp += ["(", repr(self.base.name), ", ", repr(self.offset), ")"] return "".join(tmp) class ARMv8_Immediate_Operand(ARMv8_Operand, ir.Immediate_Operand): def __init__(self, name: str, visibility: ir.Operand_Visibility, value: int = None): - super().__init__(name, str(type(self)), visibility) + super().__init__(name, self.__class__.__name__, visibility) assert value is None or isinstance(value, int) self._value = value @property # type: ignore @abc.override def short_name(self) -> str: - return type(self).__name__.upper() + return self.__class__.__name__.upper() @property # type: ignore @abc.override @@ -1098,7 +1100,7 @@ class COND(ARMv8_Operand): """ def __init__(self, name: str, visibility: ir.Operand_Visibility, value: int = None): - super().__init__(name, str(type(self)), visibility) + super().__init__(name, self.__class__.__name__, visibility) assert value is None or isinstance(value, str) self._value = value @@ -1185,7 +1187,7 @@ class COND1(ARMv8_Operand): """ def __init__(self, name: str, visibility: ir.Operand_Visibility, value: int = None): - super().__init__(name, str(type(self)), visibility) + super().__init__(name, self.__class__.__name__, visibility) assert value is None or isinstance(value, str) self._value = value -- GitLab From ea74b1ee1e9adea3a2403e541861f0e0d7015dcd Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Tue, 1 Jun 2021 16:38:54 +0200 Subject: [PATCH 10/12] asm/ir: also load ARM ISA --- src/pipedream/asm/armv8a/__init__.py | 187 ++++++++++++++++++++++- src/pipedream/asm/armv8a/instructions.py | 19 ++- src/pipedream/asm/armv8a/registers.py | 9 +- src/pipedream/asm/ir.py | 6 +- 4 files changed, 210 insertions(+), 11 deletions(-) diff --git a/src/pipedream/asm/armv8a/__init__.py b/src/pipedream/asm/armv8a/__init__.py index 83cbddf..0a189fd 100644 --- a/src/pipedream/asm/armv8a/__init__.py +++ b/src/pipedream/asm/armv8a/__init__.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, IO +from typing import List, Sequence, IO, Set, Tuple import random import pipedream.utils.abc as abc @@ -27,6 +27,11 @@ __all__ = [ class ARMv8_Architecture(Architecture): + def __init__(self): + self._instruction_set = ARMv8_Instruction_Set() + self._register_set = ARMv8_Register_Set() + self._asm_dialects = [ARMv8_ASM_Dialect()] + @property # type: ignore @abc.override def name(self) -> str: @@ -85,4 +90,182 @@ class ARMv8_ASM_Dialect(ASM_Dialect): # TODO class ARMv8_IR_Builder(IR_Builder): - pass + def __init__(self, arch): + self._arch = arch + self._insts = arch.instruction_set() + + @abc.override + def get_return_register(self) -> Register: + return X0 + + @abc.override + def get_argument_register(self, idx: int) -> Register: + return {0: ARG_1, 1: ARG_2, 2: ARG_3, 3: ARG_4, 4: ARG_5, 5: ARG_6, 6: ARG_7}[ + idx + ] + + @abc.override + def get_scratch_register(self, idx: int) -> Register: + return {0: X9, 1: X10, 2: X11, 3: X12, 4: X13, 5: X14, 6: X15}[idx] + + @abc.override + def select_memory_base_register( + self, + insts: List[Instruction], + free_regs: Set[Register], + address_width: int, + ) -> Register: + """ + Select a register that can be used as a memory base register for all memory accesses + in the given list of instructions :insts:. + """ + pass + + @abc.override + def preallocate_benchmark( + self, alloc: Register_Liveness_Tracker, instructions: Sequence[Instruction] + ) -> Tuple[object, List[Instruction]]: + stolen_regs = [] + preallocated_insts = [] + + # TODO + + return stolen_regs, preallocated_insts + + @abc.override + def free_stolen_benchmark_registers( + self, alloc: Register_Liveness_Tracker, stolen_regs: object + ): + for reg in stolen_regs: + alloc.free(reg) + + @abc.override + def emit_loop_prologue( + self, + kernel: Sequence[Instruction], + free_regs: Sequence[Register], + reg_values: Register, + ) -> List[Instruction]: + """`reg_values` is a register containing a pointer to a memory region from + where vector registers can be initialized""" + + out: List[Instruction] = [] + # TODO: initialize vect reg values + + return out + + @abc.override + def emit_benchmark_prologue( + self, + kernel: Sequence[Instruction], + free_regs: Sequence[Register], + ) -> List[Instruction]: + out = [] + # TODO: initialize reg values + + return out + + @abc.override + def emit_benchmark_epilogue( + self, + kernel: Sequence[Instruction], + free_regs: Sequence[Register], + ) -> List[Instruction]: + return [] + + @abc.override + def emit_loop_epilogue( + self, + kernel: Sequence[Instruction], + free_regs: Sequence[Register], + reg_values: Register, + ) -> List[Instruction]: + out = [] + out += self.emit_pop_from_stack(reg_values) + return out + + def emit_dependency_breaker(self, reg: Register) -> List[Instruction]: + # TODO: instert a zero-idiom on the register + assert False + return [] + + @abc.override + def emit_sequentialize_cpu( + self, alloc: Register_Liveness_Tracker + ) -> List[Instruction]: + out: List[Instruction] = [] + # TODO: is there something to do this on ARM CPUs? a branch, maybe? + + return out + + @abc.override + def emit_copy(self, src: Register, dst: Register) -> List[Instruction]: + # TODO: mov + print("WARNING: copy not supported") + return [] + + @abc.override + def emit_push_to_stack(self, src: Register) -> List[Instruction]: + # TODO: PUSH? STR? STP? + print("WARNING: PUSH TO STACK UNSUPPORTED") + return [] + + @abc.override + def emit_pop_from_stack(self, dst: Register) -> List[Instruction]: + # TODO: POP? STR? + print("WARNING: POP FROM STACK UNSUPPORTED") + return [] + + @abc.override + def emit_branch_if_not_zero(self, reg: Register, dst: Label) -> List[Instruction]: + assert False + return [] + + @abc.override + def emit_branch_if_zero(self, reg: Register, dst: Label) -> List[Instruction]: + assert False + return [] + + @abc.override + def emit_call(self, dst: Label) -> List[Instruction]: + print("ERROR: emit_call not supported") + return [] + + @abc.override + def emit_return(self, reg: Register) -> List[Instruction]: + print("ERROR: emit_return not supported") + return [] + + @abc.override + def emit_put_const_in_register( + self, const: int, reg: Register + ) -> List[Instruction]: + print("WARNING: emit_put_const_in_register not supported") + return [] + + @abc.override + def emit_mul_reg_const(self, reg: Register, const: int) -> List[Instruction]: + print("WARNING: emit_mul_reg_const not supported") + return [] + + @abc.override + def emit_add_registers( + self, src_reg: Register, src_dst_reg: Register + ) -> List[Instruction]: + print("warning: emit_add_registers not supported") + return [] + + @abc.override + def emit_put_basedisplacement_in_register( + self, base: "Register", displacement: int, reg: "Register" + ) -> List[Instruction]: + print("Warning: emit_put_basedisplacement_in_register not supoprted") + assert False + return [] + + @abc.override + def emit_substract_one_from_reg_and_branch_if_not_zero( + self, loop_counter: Register, dst: "Label" + ) -> List[Instruction]: + print("ERROR: emit_substract_one_from_reg_and_branch_if_not_zero not supported") + return [] diff --git a/src/pipedream/asm/armv8a/instructions.py b/src/pipedream/asm/armv8a/instructions.py index 9bfba30..e126eed 100644 --- a/src/pipedream/asm/armv8a/instructions.py +++ b/src/pipedream/asm/armv8a/instructions.py @@ -18,7 +18,7 @@ __all__ = [ "ARMv8_Instruction_Set", "MNEMONICS", "INSTRUCTIONS", - # "Harness", + "Harness", ] Instruction_Name = str @@ -48,7 +48,7 @@ class ARMv8_Instruction(ir.Machine_Instruction): @property # type: ignore @abc.override def tags(self) -> Sequence[str]: - return [] + return set([self._isa_set]) @property # type: ignore def mnemonic(self) -> str: @@ -90,6 +90,10 @@ class ARMv8_Instruction(ir.Machine_Instruction): class ARMv8_Instruction_Set(ir.Instruction_Set): + @abc.override + def instruction_groups(self) -> List["Instruction_Group"]: + return [] + @abc.override def instructions(self) -> List[ir.Machine_Instruction]: return list(INSTRUCTIONS.values()) @@ -102,6 +106,14 @@ class ARMv8_Instruction_Set(ir.Instruction_Set): def instruction_for_name(self, name: str) -> ir.Machine_Instruction: return INSTRUCTIONS[name] + @abc.override + def instructions_for_tags(self, *tags) -> Sequence[ir.Machine_Instruction]: + tags = set(tags) + + for inst in INSTRUCTIONS.values(): + if tags <= inst.tags: + yield inst + def __getitem__(self, name): return INSTRUCTIONS[name] @@ -136,7 +148,7 @@ def mk_inst( isa_set: str, operands: Union[Tuple, List], isa_extension: str = None, - can_benchmark=None, + can_benchmark: Optional[bool] = None, ): global INSTRUCTIONS @@ -246,7 +258,6 @@ class Harness: Well known instructions used in benchmark harness """ - # TODO at codegen pass diff --git a/src/pipedream/asm/armv8a/registers.py b/src/pipedream/asm/armv8a/registers.py index 3743e36..e65b0a9 100644 --- a/src/pipedream/asm/armv8a/registers.py +++ b/src/pipedream/asm/armv8a/registers.py @@ -386,6 +386,8 @@ _def_ARMv8_reg("_NZCV", _ARMv8_REGISTER_TYPE.FLAGS) WSP = _def_ARMv8_reg("WSP", _ARMv8_REGISTER_TYPE.WSP) SP = _def_ARMv8_reg("SP", _ARMv8_REGISTER_TYPE.SP, [WSP], widest=True) +r_SP = SP # hack to avoid elision with RSP register class + # zero registers WZR = _def_ARMv8_reg("WZR", _ARMv8_REGISTER_TYPE.WZR) ZXR = _def_ARMv8_reg( @@ -415,15 +417,14 @@ _def_ARMv8_register_class("ANY_REGISTER", *_ALL_REGISTERS_) # See "Procedure Call Standard for the Arm 64-bit Architecture" _def_ARMv8_register_class( "ARGUMENT_REGISTER", - *[self[f"X{i}"] for i in range(0, 7)], - **{f"ARG_{i}": self[f"X{i}"] for i in range(0, 7)}, + *[self[f"X{i}"] for i in range(0, 8)], + **{f"ARG_{i}": self[f"X{i}"] for i in range(0, 8)}, ) _def_ARMv8_register_class( "CALLER_SAVED", *[self[f"X{i}"] for i in range(0, 19)], - *[self[f"X{i}"] for i in range(28, 31)], ) _def_ARMv8_register_class("CALLEE_SAVED", *[self[f"X{i}"] for i in range(19, 28)]) @@ -440,7 +441,7 @@ del self class ARMv8_Register_Set(Register_Set): @abc.override def stack_pointer_register(self) -> Register: - return SP + return r_SP @abc.override def register_classes(self) -> List[Register_Class]: diff --git a/src/pipedream/asm/ir.py b/src/pipedream/asm/ir.py index 82016e5..c683da9 100644 --- a/src/pipedream/asm/ir.py +++ b/src/pipedream/asm/ir.py @@ -42,11 +42,15 @@ class Architecture(abc.ABC): @staticmethod def for_name(name: str) -> "Architecture": import pipedream.asm.x86 + import pipedream.asm.armv8a if not hasattr(Architecture, "_REGISTRY_"): X86 = pipedream.asm.x86.X86_Architecture() + ARMv8a = pipedream.asm.armv8a.ARMv8_Architecture() - Architecture._REGISTRY_ = types.MappingProxyType({X86.name: X86}) + Architecture._REGISTRY_ = types.MappingProxyType( + {X86.name: X86, ARMv8a.name: ARMv8a} + ) try: return Architecture._REGISTRY_[name] -- GitLab From 7616577d4010387d2a0384ec82e2fb6df420bbf6 Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Fri, 9 Jul 2021 15:00:46 +0200 Subject: [PATCH 11/12] arm/*: various bugfixes, mainly codegen + benchmark/common --- src/pipedream/asm/armv8a/__init__.py | 13 +++++----- src/pipedream/asm/armv8a/asmwriter.py | 2 +- src/pipedream/asm/armv8a/registers.py | 4 +-- src/pipedream/benchmark/common.py | 26 ++++++++++++------- .../extract_arm_db/extract_binutils.py | 8 ++++-- 5 files changed, 32 insertions(+), 21 deletions(-) diff --git a/src/pipedream/asm/armv8a/__init__.py b/src/pipedream/asm/armv8a/__init__.py index 0a189fd..08dc2fb 100644 --- a/src/pipedream/asm/armv8a/__init__.py +++ b/src/pipedream/asm/armv8a/__init__.py @@ -84,8 +84,8 @@ class ARMv8_Architecture(Architecture): class ARMv8_ASM_Dialect(ASM_Dialect): @abc.override - def _mnemonic(self, inst: "X86_Instruction") -> str: - return MNEMONICS[instr.name] + def _mnemonic(self, inst: "ARMv8_Instruction") -> str: + return MNEMONICS[inst.name] # TODO @@ -126,10 +126,8 @@ class ARMv8_IR_Builder(IR_Builder): self, alloc: Register_Liveness_Tracker, instructions: Sequence[Instruction] ) -> Tuple[object, List[Instruction]]: stolen_regs = [] - preallocated_insts = [] - - # TODO - + # Sweet copy here, just in case + preallocated_insts = list(instructions) return stolen_regs, preallocated_insts @abc.override @@ -152,6 +150,8 @@ class ARMv8_IR_Builder(IR_Builder): out: List[Instruction] = [] # TODO: initialize vect reg values + out += self.emit_push_to_stack(reg_values) + return out @abc.override @@ -161,6 +161,7 @@ class ARMv8_IR_Builder(IR_Builder): free_regs: Sequence[Register], ) -> List[Instruction]: out = [] + # TODO: initialize reg values return out diff --git a/src/pipedream/asm/armv8a/asmwriter.py b/src/pipedream/asm/armv8a/asmwriter.py index 8452691..6159641 100644 --- a/src/pipedream/asm/armv8a/asmwriter.py +++ b/src/pipedream/asm/armv8a/asmwriter.py @@ -134,7 +134,7 @@ def _(arg): @renderARMv8_arg.register(ARMv8_Register_Operand) def _(arg): - return arg.register.att_asm_name + return arg.register.asm_name @renderARMv8_arg.register(ARMv8_Immediate_Operand) diff --git a/src/pipedream/asm/armv8a/registers.py b/src/pipedream/asm/armv8a/registers.py index e65b0a9..1df411e 100644 --- a/src/pipedream/asm/armv8a/registers.py +++ b/src/pipedream/asm/armv8a/registers.py @@ -199,7 +199,7 @@ class ARMv8_Register(Register): return self def freeze(self): - self._aliases = tuple(self._aliases | self._subs | self._supers) + self._aliases = tuple((self._aliases | self._subs | self._supers) - set([self])) self._subs = tuple(self._subs) self._supers = tuple(self._supers) @@ -324,7 +324,7 @@ def recurse_vect_fp_reg_add( reg._widest = new_vect same_size_vect.append(new_vect) - # Adding Vn.xY: combined vector composed of vect of size Y = `fp_type2` (x + # Adding Vn.xY: combined vector composed of vect of size Y = `vect_type` (x # integer automatically computed) if vect_idx == 0: if size_idx in [3, 4]: diff --git a/src/pipedream/benchmark/common.py b/src/pipedream/benchmark/common.py index baeaf8d..e0a3d0e 100644 --- a/src/pipedream/benchmark/common.py +++ b/src/pipedream/benchmark/common.py @@ -1306,7 +1306,6 @@ class _Benchmark_Runner: ## first generate code benchmark (but don't actually put it in ASM yet) kernel_code: ty.List[_ASM_Builder.ASM_Statement] - kernel_insts: ty.List[ir.Instruction] kernel_code, fully_allocated_kernel = self._gen_inner_benchmark_loop( out, benchmark, @@ -1459,7 +1458,6 @@ class _Benchmark_Runner: out.comment("*" * 3, "END FUNCTION", fn_name) out.comment("*" * 70) out.newline() - return out.take_code() def _gen_inner_benchmark_loop( @@ -1482,7 +1480,7 @@ class _Benchmark_Runner: goes invertedly for used_register. WARNING: `used_register` only cancels `unused_registers` (i.e. a register being in both sets is supposed to be currently used and will not be used - when generating + when generating) """ for reg in unused_registers: out.free_reg(reg) @@ -1628,31 +1626,40 @@ class _Benchmark_Runner: gen_papi_calls=gen_papi_calls, debug=debug, ) - + i = 0 for stmt in asm: stmt.emit(asm_writer) + i += 1 benchmark_functions[benchmark] = fn_name asm_writer.end_file(asm_file) + # subprocess.check_call(["cat", asm_file]) + with open(asm_file, "r") as file: + for line in file: + print(line, end="") + self.info("assemble benchmark library") + # Temporary setup for aarch64 testing subprocess.check_call( - ["as", asm_file, "-o", obj_file], + ["aarch64-linux-gnu-as", asm_file, "-o", obj_file], ) + # Temporary setup for aarch64 testing ld_args = [ - "ld", + "aarch64-linux-gnu-ld", "-shared", obj_file, ] - if gen_papi_calls: - ld_args += ["-lpapi"] + # Temporary setup for aarch64 testing + # if gen_papi_calls: + # ld_args += ["-lpapi"] ld_args += ["-o", lib_file] subprocess.check_call(ld_args) - + exit(0) shared_lib = ctypes.CDLL(lib_file) os.unlink(lib_file) if is_lib_tmp: @@ -2162,7 +2169,6 @@ class _ASM_Builder: """ Splice in code generated by another ASM builder. """ - self._code.extend(code) def emit_asm(self, writer: ASM_Writer): diff --git a/tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py b/tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py index 168a0d6..ed12638 100644 --- a/tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py +++ b/tools/extract-binutils-instruction-database/extract_arm_db/extract_binutils.py @@ -58,9 +58,13 @@ def brace_split(text, split_on=","): def convert_register_name(binop_name: str) -> str: - # `S_x` correspond (apparently) to an SIMD vector element of type `x` + # `S_x` corresponds (apparently) to an FP element of type `x` if binop_name.startswith("S_"): - return f"V{binop_name[len('S_'):]}" + # ... Probably... + if binop_name[2].isdigit(): + return f"V{binop_name[len('S_'):]}" + else: + return f"{binop_name[len('S_'):]}" return binop_name -- GitLab From 978bee75ac0f5eab70d4f3ace70d5b4a5cc1a24f Mon Sep 17 00:00:00 2001 From: Nicolas Derumigny <nderumigny@gmail.com> Date: Mon, 2 Aug 2021 11:28:06 +0200 Subject: [PATCH 12/12] asmwriter, registers: updated according to reviews --- src/pipedream/asm/armv8a/asmwriter.py | 2 +- src/pipedream/asm/armv8a/registers.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/pipedream/asm/armv8a/asmwriter.py b/src/pipedream/asm/armv8a/asmwriter.py index 6159641..594242b 100644 --- a/src/pipedream/asm/armv8a/asmwriter.py +++ b/src/pipedream/asm/armv8a/asmwriter.py @@ -47,7 +47,7 @@ class ARMv8_ASM_Writer(ASM_Writer): @abc.override def align(self): - self.print(1, ".align 16, 0x90") + self.print(1, ".p2align 4") @abc.override def insts(self, insts: Sequence[Instruction]): diff --git a/src/pipedream/asm/armv8a/registers.py b/src/pipedream/asm/armv8a/registers.py index 1df411e..84a0872 100644 --- a/src/pipedream/asm/armv8a/registers.py +++ b/src/pipedream/asm/armv8a/registers.py @@ -1,4 +1,5 @@ from enum import Enum, auto +from dataclass import dataclass from typing import Optional, Set, Union, Tuple, List, FrozenSet, Dict, cast from pipedream.utils import abc from pipedream.asm.ir import * @@ -15,9 +16,9 @@ __all__ = [ # Wrapper to allow multiple values in `Enum` (i.e. `Enum.A == Enum.B` but # `Enum.A is not Enum.B`) +@dataclass(frozen=True) class Unique: - def __init__(self, value: int): - self.value = value + value: int class _ARMv8_REGISTER_TYPE(Enum): -- GitLab