-- bitset.lus: bitsets over {0, ..., N-1}, represented as bool^N arrays.
include "utils.lus"

-- The empty bitset: `false^N`

-- A bitset with all elements: `true^N`

-- Builds the one-hot bitset of size N whose single element is x.
-- x must lie in [0, N); this precondition is stated as an assertion.
function one_hot<<const N:int>>(x : int) returns (s : bool^N);
let
	assert(0 <= x and x < N);
	-- Bit i is set exactly when i = x. range<<N>>(true) comes from utils.lus
	-- and is assumed to yield the indices [0, 1, ..., N-1] (TODO confirm).
	s = map<<=,N>>(range<<N>>(true), x^N);
	-- Sanity check: exactly one bit of the result is set.
	assert(boolred<<1,1,N>>(s));
tel;

-- Complement of s: bit i of the result is set iff bit i of s is clear.
function complement<<const N:int>>(s : bool^N) returns (not_s : bool^N);
let
	-- A bit belongs to the complement exactly when it is false in s.
	not_s = map<<=,N>>(s, false^N);
tel;

-- Number of elements in s, i.e. its population count.
function pop_count<<const N:int>>(s : bool^N) returns (c : int);
var
	bits : int^N;
let
	-- Convert each bit with int_of_bool from utils.lus
	-- (presumably true -> 1, false -> 0; verify in utils.lus).
	bits = map<<int_of_bool,N>>(s);
	-- Sum all the converted bits, starting from 0.
	c = red<<+,N>>(0, bits);
tel;

-- True iff s contains no element at all.
function is_empty<<const N:int>>(s : bool^N) returns (y : bool);
let
	-- The empty bitset is false^N; `=` compares bool arrays element-wise.
	y = (s = false^N);
tel;

-- Union of two bitsets: an element is present if it is in a or in b.
function union<<const N:int>>(a, b : bool^N) returns (y : bool^N);
let
	-- Element-wise disjunction (or is commutative, so operand order is free).
	y = map<<or,N>>(b, a);
tel;

-- Intersection of two bitsets: an element is present if it is in both a and b.
function inter<<const N:int>>(a, b : bool^N) returns (y : bool^N);
let
	-- Element-wise conjunction (and is commutative, so operand order is free).
	y = map<<and,N>>(b, a);
tel;

-- Set difference s1 \ s2: the elements of s1 that are absent from s2.
function diff<<const N:int>>(s1, s2 : bool^N) returns (y : bool^N);
var
	not_s2 : bool^N;
let
	-- Mask of the positions that s2 does not contain.
	not_s2 = map<<not,N>>(s2);
	-- Keep a bit of s1 only where that mask allows it.
	y = map<<and,N>>(s1, not_s2);
tel;

-- Adds x to s: the result holds every element of s, plus x.
-- x must lie in [0, N).
function set<<const N:int>>(s : bool^N; x : int) returns (y : bool^N);
let
	assert(0 <= x and x < N);
	-- OR the one-hot mask for x into s.
	y = map<<or,N>>(one_hot<<N>>(x), s);
tel;

-- Removes x from s: the result holds every element of s except x.
-- x must lie in [0, N).
function unset<<const N:int>>(s : bool^N; x : int) returns (y : bool^N);
let
	assert(0 <= x and x < N);
	-- Clear bit x by masking s with the complement of x's one-hot mask.
	y = inter<<N>>(s, complement<<N>>(one_hot<<N>>(x)));
tel;

-- No `equal` function is needed: the `=` operator already compares bool arrays.

-- True iff a and b have no element in common.
function disjoint<<const N:int>>(a, b : bool^N) returns (y : bool);
let
	-- The sets are disjoint exactly when their element-wise AND is all false.
	y = (map<<and,N>>(a, b) = false^N);
tel;

-- True iff every element of s1 also belongs to s2 (s1 is a subset of s2).
function subset<<const N:int>>(s1, s2 : bool^N) returns (y : bool);
let
	-- Membership must imply membership, bit by bit, across the whole array.
	y = (map<<=>,N>>(s1, s2) = true^N);
tel;

/*
	Builds a bitset of size N from a list of M indices. Every index in the
	list becomes a set bit of the result, and duplicate indices are harmless.
	Each index must lie in [0, N); one_hot asserts this.
*/
function bitset_of_list<<const N:int; const M:int>>(list : int^M)
returns (set : bool^N);
var
	singletons : bool^N^M;
let
	-- Turn each listed index into its own one-hot bitset...
	singletons = map<<one_hot<<N>>,M>>(list);
	-- ...then fold them together with union, starting from the empty set.
	set = red<<union<<N>>,M>>(false^N, singletons);
tel;

/*
	Returns the index of the first (lowest-index) element present in s.
	Returns -1 when s is empty.
*/
function first_set<<const N:int>>(s : bool^N) returns (x : int);
var
	tail_first : int;
let
	-- First set index within the tail s[1 .. N-1], found by static
	-- recursion on N. When N = 1 the tail is empty, so the answer is -1.
	tail_first = with (N = 1) then -1
	             else first_set<<N-1>>(s[1 .. N-1]);
	-- s[0] wins if it is set. Otherwise shift the tail index by one to
	-- account for s[0], keeping -1 for "not found".
	x = if s[0] then 0
	    else if tail_first < 0 then -1
	    else tail_first + 1;
tel;

/*
	Returns the index of the last (highest-index) element present in s.
	Returns -1 when s is empty.
*/
function last_set<<const N:int>>(s : bool^N) returns (x : int);
var
	prefix_last : int;
let
	-- Last set index within the prefix s[0 .. N-2], found by static
	-- recursion on N. When N = 1 the prefix is empty, so the answer is -1.
	prefix_last = with (N = 1) then -1
	              else last_set<<N-1>>(s[0 .. N-2]);
	-- The top bit takes priority; otherwise fall back to the prefix result.
	x = if s[N-1] then N-1 else prefix_last;
tel;