Jump to content

Module:Lua class

Permanently protected module
From Wikipedia, the free encyclopedia
This is an old revision of this page, as edited by Alexiscoutinho (talk | contribs) at 18:48, 9 June 2021 (Created beta version of module to handle object-oriented Lua classes (based on Python and C++)). The present address (URL) is a permanent link to this revision, which may differ significantly from the current revision.
(diff) ← Previous revision | Latest revision (diff) | Newer revision → (diff)

-- Registries mapping wrapper objects to the internal tables behind them.
--   classes:   class wrapper (public, or the private classmethod wrapper) -> internal class table
--   instances: public instance wrapper -> private wrapper;
--              private instance wrapper -> internal instance table
local classes = {}
local instances = {}
-- Per-class caches of instance-wrapper metatables; classes are immutable, so one of each per class suffices.
local inst_private_mts = {}
local inst_public_mts = {}

-- One-operand metamethods, forwarded by both private and public instance wrappers.
local una_metamethods = {
	__ipairs = 1, __pairs = 1, __tostring = 1, __unm = 1,
}
-- Two-operand metamethods, forwarded by public instance wrappers.
local bin_metamethods = {
	__add = 1, __concat = 1, __div = 1, __eq = 1, __le = 1,
	__lt = 1, __mod = 1, __mul = 1, __pow = 1, __sub = 1,
}
-- Remaining dunder keys that a class body is allowed to define.
local oth_metamethods = {
	__call = 1, __index = 1, __newindex = 1, __init = 1,
}
-- Internal bookkeeping fields that the private wrappers refuse to expose.
local not_metamethods = {
	__name = 1, __bases = 1, __methods = 1, __protected = 1, __class = 1,
}


-- __index handler for private instance wrappers whose class has no custom __index.
-- Reads are forwarded to the internal instance table, except for the internal
-- bookkeeping fields listed in not_metamethods, which raise an error.
local function private_read(self_private, key)
	if not_metamethods[key] then
		error('unauthorized read attempt of internal "' .. key .. '"')
	end
	local internal = instances[self_private]
	return internal[key]
end

-- __index handler for private instance wrappers whose class defines __index.
-- The class's __index is asked first (with the private wrapper as receiver);
-- only when it yields nil does the read fall back to the internal instance.
-- Internal bookkeeping fields (not_metamethods) are refused.
local function private_read_custom(self_private, key)
	if not_metamethods[key] then
		error('unauthorized read attempt of internal "' .. key .. '"')
	end
	local internal = instances[self_private]
	local found = internal.__class.__index(self_private, key)
	if found ~= nil then
		return found
	end
	return internal[key]
end

-- __newindex handler for private instance wrappers whose class has no custom __newindex.
-- Stores `value` on the internal instance table. Two kinds of write are rejected:
-- internal dunder keys ("__..."), and names of instance methods, which are immutable.
local function private_write(self_private, key, value)
	local self = instances[self_private]
	-- `key:sub`, not `key.sub`: the latter calls string.sub(1, 2) and never matches.
	if type(key) == 'string' and key:sub(1, 2) == '__' then
		error('unauthorized write attempt of internal "' .. key .. '"')
	end
	if self.__class.__methods[key] then
		-- tostring: value may be a table/function/nil, which `..` cannot concatenate.
		error('forbidden write attempt {' .. tostring(key) .. ': ' .. tostring(value) .. '} to immutable instance method')
	end
	self[key] = value
end

-- __newindex handler for private instance wrappers whose class defines __newindex.
-- Internal dunder keys and instance-method names are rejected. Otherwise the
-- class's __newindex is called with the private wrapper; a truthy return means
-- it handled the write, else the value is stored on the internal instance table.
local function private_write_custom(self_private, key, value)
	local self = instances[self_private]
	local cls = self.__class
	if type(key) == 'string' and key:sub(1, 2) == '__' then
		error('unauthorized write attempt of internal "' .. key .. '"')
	end
	if cls.__methods[key] then
		error('forbidden write attempt {' .. tostring(key) .. ': ' .. tostring(value) .. '} to immutable instance method')
	end
	if not cls.__newindex(self_private, key, value) then
		self[key] = value
	end
end

-- Metatable of internal instance tables: a member missing on the instance is
-- looked up on its internal class (whose own metatable continues into the bases).
local inst_mt = {}
function inst_mt.__index(instance, key)
	local cls = instance.__class
	return cls[key]
end

-- __index handler for public instance wrappers whose class has no custom __index.
-- Names starting with "_" (private, protected and internal members) are not
-- readable from outside. Every other key is read from the internal instance,
-- which is reached through the private wrapper.
local function public_read(self_public, key)
	-- `key:sub`, not `key.sub`: the latter calls string.sub(1, 1) == "1" and never matches.
	if type(key) == 'string' and key:sub(1, 1) == '_' then
		error('unauthorized read attempt of nonpublic "' .. key .. '"')
	end
	return instances[instances[self_public]][key]
end

-- __index handler for public instance wrappers whose class defines __index.
-- Names starting with "_" are refused. Otherwise the class's __index is asked
-- first (with the private wrapper as receiver), falling back to the internal
-- instance table when it yields nil.
local function public_read_custom(self_public, key)
	if type(key) == 'string' and key:sub(1, 1) == '_' then
		error('unauthorized read attempt of nonpublic "' .. key .. '"')
	end
	local self_private = instances[self_public]
	local self = instances[self_private]
	local value = self.__class.__index(self_private, key)
	if value == nil then
		value = self[key]
	end
	return value
end

-- __newindex handler for public instance wrappers whose class has no custom __newindex.
-- From outside, only existing public items can be reassigned. The following are
-- rejected:
--   names starting with "_";
--   instance-method names, which are immutable;
--   keys with no current value (item creation is not allowed).
local function public_write(self_public, key, value)
	if type(key) == 'string' and key:sub(1, 1) == '_' then
		error('unauthorized write attempt of nonpublic "' .. key .. '"')
	end
	local self = instances[instances[self_public]]
	-- tostring: key/value may be tables, functions or nil, which `..` cannot concatenate.
	if self.__class.__methods[key] then
		error('forbidden write attempt {' .. tostring(key) .. ': ' .. tostring(value) .. '} to immutable instance method')
	end
	-- self[key] also sees class attributes through inst_mt.
	if self[key] == nil then
		error('public item creation attempt {' .. tostring(key) .. ': ' .. tostring(value) .. '} (currently not allowed)')
	end
	self[key] = value
end

-- __newindex handler for public instance wrappers whose class defines __newindex.
-- Names starting with "_" and instance-method names are rejected. The class's
-- __newindex is then called with the private wrapper. A truthy return means it
-- handled the write. Otherwise only an existing item may be reassigned; item
-- creation is not allowed.
local function public_write_custom(self_public, key, value)
	if type(key) == 'string' and key:sub(1, 1) == '_' then
		error('unauthorized write attempt of nonpublic "' .. key .. '"')
	end
	local self_private = instances[self_public]
	local self = instances[self_private]
	if self.__class.__methods[key] then
		error('forbidden write attempt {' .. tostring(key) .. ': ' .. tostring(value) .. '} to immutable instance method')
	end
	if self.__class.__newindex(self_private, key, value) then
		return
	end
	if self[key] == nil then
		error('public item creation attempt {' .. tostring(key) .. ': ' .. tostring(value) .. '} (currently not allowed)')
	end
	self[key] = value
end

-- Returns (building and caching it on first use) the metatable shared by all
-- private instance wrappers of internal class `cls`. The private wrapper is the
-- object that methods, __init and metamethods receive as `self`.
local function private_instance_mt(cls)
	local mt = inst_private_mts[cls]
	if mt then
		return mt
	end
	mt = {}
	mt.__index = cls.__index and private_read_custom or private_read
	mt.__newindex = cls.__newindex and private_write_custom or private_write
	for key in pairs(una_metamethods) do
		mt[key] = cls[key]
	end
	mt.__call = cls.__call
	mt.__metatable = 'unauthorized access attempt of wrapper object metatable'
	inst_private_mts[cls] = mt
	return mt
end

-- Maps a binary-metamethod operand to what class metamethods expect. A public
-- instance wrapper becomes its private wrapper. Anything else (a private
-- wrapper, a number in `obj + 1`, ...) is passed through unchanged.
local function unwrap_public(operand)
	local self_private = instances[operand]
	-- Only public wrappers map to something that is itself registered.
	if self_private ~= nil and instances[self_private] ~= nil then
		return self_private
	end
	return operand
end

-- Returns (building and caching it on first use) the metatable shared by all
-- public instance wrappers of internal class `cls`. This is the object returned
-- to callers of the constructor.
local function public_instance_mt(cls)
	local mt = inst_public_mts[cls]
	if mt then
		return mt
	end
	mt = {}
	mt.__index = cls.__index and public_read_custom or public_read
	mt.__newindex = cls.__newindex and public_write_custom or public_write
	for key in pairs(una_metamethods) do
		if cls[key] then
			mt[key] = function (a) return cls[key](instances[a]) end
		end
	end
	for key in pairs(bin_metamethods) do
		if cls[key] then
			mt[key] = function (a, b) return cls[key](unwrap_public(a), unwrap_public(b)) end
		end
	end
	-- Only make instances callable when the class actually defines __call.
	if cls.__call then
		mt.__call = function (self_public, ...) return cls.__call(instances[self_public], ...) end
	end
	mt.__metatable = 'unauthorized access attempt of wrapper object metatable'
	inst_public_mts[cls] = mt
	return mt
end

-- __call handler of class wrappers: `Class{...}` creates an instance.
-- Takes the class wrapper and at most one argument, a table passed to __init.
-- Returns the public instance wrapper.
-- Raises if more than one argument is given or if __init returns a value.
local function constructor(wrapper, ...)
	if select('#', ...) > 1 then
		error('incorrect instance constructor syntax, should be: Class{arg1, arg2..., kw1=kwarg1, kw2=kwarg2...}')
	end
	local kwargs = (...)
	local cls = classes[wrapper]
	local self = {__class = cls} -- internal instance (__new)

	local self_private = setmetatable({}, private_instance_mt(cls))
	instances[self_private] = self

	-- Bind methods and the class fallback BEFORE __init, so __init can call
	-- instance methods and read class attributes through `self`.
	for key in pairs(cls.__methods) do
		self[key] = function (...) return cls[key](self_private, ...) end
	end
	setmetatable(self, inst_mt)

	local __init = cls.__init
	if __init and __init(self_private, kwargs) then
		error('__init must not return a var-list')
	end

	local self_public = setmetatable({}, public_instance_mt(cls))
	instances[self_public] = self_private -- because metamethod wrappers require it
	return self_public
end


-- __index handler for internal class tables: resolves a member missing on `cls`
-- by searching cls.__bases in declaration order. Because each base runs this
-- same lookup, the search is depth-first.
-- Inheritance rules:
--   public names and dunder ("__...") names are inherited;
--   single-underscore (private) names are inherited only when the base lists
--   them in __protected.
-- NOTE(review): a base method that calls a private, non-protected helper through
-- `self` will not find it on a subclass instance; such helpers must be
-- declared in __protected.
local function multi_inheritance(cls, key)
	-- `key:sub`, not `key.sub`: the latter calls string.sub(1, n) and made this filter dead code.
	local is_private = type(key) == 'string' and key:sub(1, 1) == '_' and key:sub(1, 2) ~= '__'
	for _, base in ipairs(cls.__bases) do
		if not is_private or base.__protected[key] then
			local value = base[key]
			if value ~= nil then
				return value
			end
		end
	end
end

-- Metatable of internal class tables: unresolved members are looked up in the bases.
local cls_mt = {}
cls_mt.__index = multi_inheritance

-- __newindex handler shared by class wrappers. Classes are immutable once
-- created, so every assignment raises an error.
local function forbidden_write(wrapper, key, value)
	-- tostring: key/value may be tables, functions or nil, which `..` cannot concatenate.
	error('forbidden write attempt {' .. tostring(key) .. ': ' .. tostring(value) .. '} to immutable class')
end

-- __index of the private class wrapper, which is passed to classmethods as
-- their first argument.
-- Every member except the internal bookkeeping fields (not_metamethods) is
-- readable. Table values are returned as mw.clone copies so the class cannot be
-- mutated through them.
local function read_class_private(cls_private, key)
	if not_metamethods[key] then
		error('unauthorized read attempt of internal "' .. key .. '"')
	end
	local value = classes[cls_private][key]
	if type(value) ~= 'table' then
		return value
	end
	return mw.clone(value) -- because classes are immutable
end

-- Metatable of the private class wrapper. Calling it constructs an instance;
-- reads go through read_class_private; writes are rejected.
local cls_private_mt = {
	__call = constructor,
	__index = read_class_private,
	__newindex = forbidden_write,
	__metatable = 'unauthorized access attempt of wrapper object metatable',
}

-- Metatable of the public class wrapper returned by `class`. Calling it
-- constructs an instance, and writes are rejected.
-- Reads refuse names starting with "_" (private, protected and internal
-- members). Table values are returned as mw.clone copies so the class cannot be
-- mutated through them.
local cls_public_mt = {
	__call = constructor,
	__index = function (cls_public, key)
		-- `key:sub`, not `key.sub`: the latter calls string.sub(1, 1) == "1" and never matches.
		if type(key) == 'string' and key:sub(1, 1) == '_' then
			error('unauthorized read attempt of nonpublic "' .. key .. '"')
		end
		local value = classes[cls_public][key]
		if type(value) == 'table' then
			return mw.clone(value) -- because classes are immutable
		end
		return value
	end,
	__newindex = forbidden_write,
	__metatable = 'unauthorized access attempt of wrapper object metatable'
}

-- Creates a class and returns its public wrapper.
-- Usage: class([name,] [Base1, Base2, ...,] body)
--   name   optional string, stored as __name
--   BaseN  optional class wrappers to inherit from, searched in this order
--   body   table of members. It may also carry these special lists of names:
--          __protected (names subclasses may access), __classmethods and
--          __staticmethods (names of functions in body).
-- Raises when:
--   the body is missing or is not a plain table;
--   a base is not a class;
--   the body defines an unrecognized dunder key;
--   the body has no function other than __init.
function class(...)
	local nargs = select('#', ...)
	local args = {...}
	local cls = {} -- internal

	local idx = 1
	if type(args[1]) == 'string' then
		cls.__name = args[1]
		idx = 2
	end

	-- The body is the last argument and must be a plain table, not a class wrapper.
	local kwargs = args[nargs]
	assert(nargs >= idx and type(kwargs) == 'table' and not classes[kwargs], 'a (sub)class must have at least one method')

	cls.__bases = {}
	for i = idx, nargs - 1 do
		local base = classes[args[i]]
		if not base then
			error('arg ' .. i .. ' of "class" is not a class')
		end
		cls.__bases[#cls.__bases + 1] = base
	end

	if kwargs.__name or kwargs.__bases then
		error('__name and unpacked __bases must be passed as optional first args to "class"')
	end

	cls.__protected = {}
	if kwargs.__protected then
		for _, key in ipairs(kwargs.__protected) do
			cls.__protected[key] = 1
		end
		kwargs.__protected = nil
	end
	-- A name protected in any base stays protected here.
	setmetatable(cls.__protected, {
		__index = function (_, key)
			for _, base in ipairs(cls.__bases) do
				if base.__protected[key] then
					return 1
				end
			end
		end
	})

	if kwargs.__methods then
		error('__classmethods and __staticmethods should be passed as optional items instead of __methods')
	end

	if kwargs.__classmethods then
		local cls_private = setmetatable({}, cls_private_mt) -- wrapper passed as first argument
		classes[cls_private] = cls
		for _, key in ipairs(kwargs.__classmethods) do
			-- Capture the function now; kwargs[key] is cleared just below, so a
			-- call-time lookup of kwargs[key] would find nil.
			local method = kwargs[key]
			cls[key] = function (...) return method(cls_private, ...) end
			kwargs[key] = nil
		end
		kwargs.__classmethods = nil
	end

	local staticmethods = {}
	if kwargs.__staticmethods then
		for _, key in ipairs(kwargs.__staticmethods) do
			staticmethods[key] = 1
		end
		kwargs.__staticmethods = nil
	end

	cls.__methods = {}
	for _, base in ipairs(cls.__bases) do
		for key in pairs(base.__methods) do
			cls.__methods[key] = 1
		end
	end

	local valid = false
	for key, val in pairs(kwargs) do
		-- `key:sub`, not `key.sub`: the latter calls string.sub(1, 2) and never matches.
		local is_dunder = type(key) == 'string' and key:sub(1, 2) == '__'
		if is_dunder and not una_metamethods[key] and not bin_metamethods[key] and not oth_metamethods[key] then
			error('unauthorized or unrecognized metamethod "' .. key .. '"')
		end
		cls[key] = val
		if type(val) == 'function' then
			-- Metamethods and static methods are not bound per instance.
			if not staticmethods[key] and not is_dunder then
				cls.__methods[key] = 1
			end
			if key ~= '__init' then -- __init does not qualify to a functional/proper class
				valid = true
			end
		end
	end
	assert(valid, 'a (sub)class must have at least one method')

	setmetatable(cls, cls_mt)

	local cls_public = setmetatable({}, cls_public_mt) -- wrapper
	classes[cls_public] = cls
	return cls_public
end


-- Reports whether internal class `class` is `classinfo` itself or inherits from
-- it through any chain of bases. Both arguments are internal class tables.
-- The search is an explicit-stack, depth-first walk. Bases are pushed in reverse
-- so they are visited first-to-last, like a recursive walk.
local function rissubclass1(class, classinfo)
	local pending = {class}
	while #pending > 0 do
		local current = table.remove(pending)
		if current == classinfo then
			return true
		end
		local bases = current.__bases
		for i = #bases, 1, -1 do
			pending[#pending + 1] = bases[i]
		end
	end
	return false
end

-- Like rissubclass1, but `classinfo` may be a class wrapper or a (possibly
-- nested) sequence of class wrappers. Returns true if internal class `class`
-- matches any of them.
local function rissubclass2(class, classinfo)
	if type(classinfo) ~= 'table' then
		error('arg2 is neither a class nor recursive sequence of classes')
	end
	local target = classes[classinfo]
	if target then
		return rissubclass1(class, target)
	end
	for i = 1, #classinfo do
		if rissubclass2(class, classinfo[i]) then
			return true
		end
	end
	return false
end

-- Python-style issubclass. Reports whether class wrapper `class` is, or derives
-- from, `classinfo`, which is a class wrapper or a nested sequence of them.
-- Raises if `class` is not a class wrapper.
function issubclass(class, classinfo)
	local internal = classes[class]
	if not internal then
		error('arg1 is not a class')
	end
	return rissubclass2(internal, classinfo)
end

-- Python-style isinstance. Accepts a public or private instance wrapper and
-- reports whether its class is, or derives from, `classinfo` (a class wrapper or
-- a nested sequence of them). Raises if `instance` is not a registered wrapper.
function isinstance(instance, classinfo)
	local inner = instances[instance]
	if not inner then
		error('arg1 is not an instance')
	end
	-- A public wrapper maps to its private wrapper, which maps to the internal
	-- instance. A private wrapper maps straight to the internal instance.
	local internal = instances[inner] or inner
	return rissubclass2(internal.__class, classinfo)
end

-- Class-aware variant of the builtin `type`. It returns:
--   'class' for class wrappers;
--   the class's __name (or 'instance') for instance wrappers;
--   the builtin result for anything else.
-- The builtin is kept under its own name. Previously `local type = type` followed
-- by `function type(value)` rebound that same local, so the body's `type(value)`
-- called itself forever.
-- NOTE(review): as before, this `type` is file-local and nothing below uses it.
-- If it was meant to replace the global `type`, earlier code that tests for
-- 'table' (e.g. rissubclass2) would need the builtin instead. Confirm intent.
local builtin_type = type
local function type(value)
	local t = builtin_type(value)
	if t == 'table' then
		if classes[value] then
			return 'class'
		end
		local inner = instances[value]
		if inner then
			local instance = instances[inner] or inner
			return instance.__class.__name or 'instance'
		end
	end
	return t
end