-- Module:CineMol/geometry
-- From Wikipedia, the free encyclopedia
-- This is a port of CineMol to lua
-- CineMol https://github.com/moltools/CineMol was written by David Meijer, Marnix H. Medema & Justin J. J. van der Hooft and is MIT licensed
-- Please consider any edits I make to this page also dual licensed MIT & CC-BY-SA 4.0

-- Note: because "end" is a reserved keyword in Lua, the "end" field used by the original CineMol code is renamed to "endp" here.
local libraryUtil = require( 'libraryUtil' )
local p = {}

-- Validate the type of a function argument, raising an error on mismatch.
-- The built-in Lua types 'number', 'table' and 'string' are delegated to
-- libraryUtil.checkType (nil is not accepted). Any other argType is treated
-- as the name of one of this module's custom types: the argument must be a
-- table whose _TYPE field equals argType.
--   name     - function name used in the error message
--   argIndex - 1-based position of the argument being checked
--   arg      - the value to check
--   argType  - expected type name
local function checkType( name, argIndex, arg, argType )
	local isBuiltin = argType == 'number' or argType == 'table' or argType == 'string'
	if isBuiltin then
		libraryUtil.checkType( name, argIndex, arg, argType, false )
		return
	end

	-- Custom types are plain tables tagged with a _TYPE field.
	libraryUtil.checkType( name, argIndex, arg, "table", false )
	local actualType = arg._TYPE
	-- Only build the (relatively costly) message once the check has failed.
	if actualType ~= argType then
		assert(
			actualType == argType,
			string.format( "Argument #%d of %s expected to be of type %s but got type %s.", argIndex, name, argType, tostring(actualType) )
		)
	end
end

-- Exposed on the module table as a general utility, even though it is not
-- geometry-specific.
p.checkType = checkType

-- Vector3D: a 3D vector with numeric fields x, y, z, tagged _TYPE = 'Vector3D'.
-- Instances share their methods through Vector3D_meta.__index.
local vector3D_methods = {}

-- Return the Euclidean length (magnitude) of the vector.
function vector3D_methods:length()
	checkType( "Vector3D:length", 1, self, "Vector3D" )
	return math.sqrt(self.x^2 + self.y^2 + self.z^2)
end

-- Return a new unit vector pointing in the same direction.
-- The zero vector has no direction, so a new zero vector is returned for it
-- instead of dividing by zero.
function vector3D_methods:normalize()
	checkType( "Vector3D:normalize", 1, self, "Vector3D")
	-- Compute the length (a square root plus a type check) only once.
	local len = self:length()
	if len == 0 then
		return p.Vector3D(0, 0, 0)
	end
	return self:multiply(1 / len)
end

-- Return the dot product self . other.
function vector3D_methods:dot(other)
	-- do not check types in hot func.
	--checkType( "Vector3D:dot", 1, self, "Vector3D")
	--checkType( "Vector3D:dot", 2, other, "Vector3D")
	return self.x * other.x + self.y * other.y + self.z * other.z
end

-- Return the cross product self x other as a new Vector3D.
function vector3D_methods:cross(other)
	checkType( "Vector3D:cross", 1, self, "Vector3D")
	checkType( "Vector3D:cross", 2, other, "Vector3D")
	return p.Vector3D(
		self.y * other.z - self.z * other.y,
		self.z * other.x - self.x * other.z,
		self.x * other.y - self.y * other.x
	)
end

-- Return the component-wise difference self - other as a new Vector3D.
function vector3D_methods:subtract(other)
	checkType( "Vector3D:subtract", 1, self, "Vector3D")
	checkType( "Vector3D:subtract", 2, other, "Vector3D")
	return p.Vector3D(self.x - other.x, self.y - other.y, self.z - other.z)
end

-- Return a new Vector3D equal to this vector scaled by the number `scalar`.
function vector3D_methods:multiply(scalar)
	checkType( "Vector3D:multiply", 1, self, "Vector3D")
	libraryUtil.checkType( "Vector3D:multiply", 2, scalar, "number", false )
	return p.Vector3D(self.x * scalar, self.y * scalar, self.z * scalar)
end

local Vector3D_meta = {__index = vector3D_methods}

-- Construct a Vector3D from its three components.
function p.Vector3D(x, y, z)
	-- Type checking was taking too long on profile for some heavily used functions
	--checkType( "Vector3D", 1, x, "number", false )
	--checkType( "Vector3D", 2, y, "number", false )
	--checkType( "Vector3D", 3, z, "number", false )
	local obj = {
		x = x,
		y = y,
		z = z,
		_TYPE = 'Vector3D'
	}
	setmetatable( obj, Vector3D_meta )
	return obj
end

-- "Static" constructor returning a Vector3D whose components come from
-- math.random(). A metatable with __call could express this as a static
-- method, but a plain module function is simpler.
-- Note: per the original author, the generator always starts from the same
-- seed, so results repeat between runs rather than being truly random.
function p.Vector3D_create_random()
	local rx = math.random()
	local ry = math.random()
	local rz = math.random()
	return p.Vector3D( rx, ry, rz )
	-- For debugging, a fixed vector makes output diffable against the Python
	-- implementation:
	--return p.Vector3D( 0.1, 0.2, 0.3 )
end

-- Point2D: a point in the plane with fields x and y, tagged _TYPE = 'Point2D'.
-- Argument type checks are deliberately omitted in these methods and in the
-- constructor because they are hot paths.
local Point2D_methods = {}

-- Return the component-wise difference self - other as a new Point2D.
function Point2D_methods:subtract_point(other)
	local dx = self.x - other.x
	local dy = self.y - other.y
	return p.Point2D(dx, dy)
end

-- Return the scalar 2D cross product self.x * other.y - self.y * other.x
-- (the z-component of the corresponding 3D cross product).
function Point2D_methods:cross(other)
	local lhs = self.x * other.y
	local rhs = self.y * other.x
	return lhs - rhs
end

local Point2D_metatable = { __index = Point2D_methods }

-- Construct a Point2D from its two coordinates.
function p.Point2D(x, y)
	return setmetatable( { _TYPE = 'Point2D', x = x, y = y }, Point2D_metatable )
end

-- Point3D: a point in 3D space with numeric fields x, y, z,
-- tagged _TYPE = 'Point3D'. Methods are shared via Point3D_metatable.__index.
local Point3D_methods = {}

-- Return the vector pointing from this point to `other` (other - self).
function Point3D_methods:create_vector(other)
	-- Disable type checking on hot funcs
	--checkType( "Point3D:create_vector", 1, self, "Point3D" )
	--checkType( "Point3D:create_vector", 2, other, "Point3D" )
	return p.Vector3D(other.x - self.x, other.y - self.y, other.z - self.z)
end

-- Return the Euclidean distance between this point and `other`.
function Point3D_methods:calculate_distance(other)
	checkType( "Point3D:calculate_distance", 1, self, "Point3D" )
	checkType( "Point3D:calculate_distance", 2, other, "Point3D" )
	return self:create_vector(other):length()
end

-- Return a new Point3D halfway between this point and `other`.
function Point3D_methods:midpoint(other)
	checkType( "Point3D:midpoint", 1, self, "Point3D" )
	checkType( "Point3D:midpoint", 2, other, "Point3D" )
	-- The constructor lives on the module table; a bare `Point3D` would be an
	-- undefined global here.
	return p.Point3D((self.x + other.x) / 2, (self.y + other.y) / 2, (self.z + other.z) / 2)
end

-- Return a new Point3D equal to this point translated by `vector`.
function Point3D_methods:translate(vector)
	checkType( "Point3D:translate", 1, self, "Point3D" )
	checkType( "Point3D:translate", 2, vector, "Vector3D" )
	return p.Point3D(self.x + vector.x, self.y + vector.y, self.z + vector.z)
end

-- Return this point rotated about the origin: first by `x` radians around
-- the x-axis, then by `y` around the y-axis, then by `z` around the z-axis.
-- Omitted (nil) angles default to 0.
function Point3D_methods:rotate(x, y, z)
	if x == nil then x = 0 end
	if y == nil then y = 0 end
	if z == nil then z = 0 end
	checkType( "Point3D:rotate", 1, self, "Point3D" )
	-- self is argument #1, so the angles are arguments #2, #3 and #4.
	libraryUtil.checkType( "Point3D:rotate", 2, x, "number", false )
	libraryUtil.checkType( "Point3D:rotate", 3, y, "number", false )
	libraryUtil.checkType( "Point3D:rotate", 4, z, "number", false )

	local cos_x, sin_x = math.cos(x), math.sin(x)
	local cos_y, sin_y = math.cos(y), math.sin(y)
	local cos_z, sin_z = math.cos(z), math.sin(z)

	-- Rotate around x-axis (x is unchanged).
	local y1 = self.y * cos_x - self.z * sin_x
	local z1 = self.y * sin_x + self.z * cos_x

	-- Rotate around y-axis (y1 is unchanged).
	local x2 = self.x * cos_y + z1 * sin_y
	local z2 = -self.x * sin_y + z1 * cos_y

	-- Rotate around z-axis (z2 is unchanged).
	local x3 = x2 * cos_z - y1 * sin_z
	local y3 = x2 * sin_z + y1 * cos_z

	return p.Point3D(x3, y3, z2)
end

local Point3D_metatable = { __index = Point3D_methods }

-- Construct a Point3D from its three coordinates.
function p.Point3D(x,y,z)
	-- For performance, disable type checking on highly used funcs
	--libraryUtil.checkType( "Point3D", 1, x, "number", false )
	--libraryUtil.checkType( "Point3D", 2, y, "number", false )
	--libraryUtil.checkType( "Point3D", 3, z, "number", false )
	local obj = {
		x = x,
		y = y,
		z = z,
		_TYPE = 'Point3D'
	}
	setmetatable(obj, Point3D_metatable)
	return obj
end

-- ==================
-- Helper functions
-- ==================

-- Return the sign of a number: 1 if positive, -1 if negative, and 0
-- otherwise (zero, and also NaN since both comparisons are false for it).
local function sign(x)
	checkType( 'sign', 1, x, 'number' )
	if x > 0 then
		return 1
	end
	if x < 0 then
		return -1
	end
	return 0
end

-- Generate two vectors v, w orthogonal to `n` with one Gram-Schmidt step.
-- A pseudo-random vector has its n-component removed and is then normalized
-- to give v; w = n x v.
-- The projection seed - n*(seed . n) only removes the n-component when n has
-- unit length, so callers should pass a normalized n
-- (get_points_on_circumference_circle_3d does). w is unit length only when
-- n is. If the random vector happens to be parallel to n, v and w are zero.
function p.gram_schmidt(n)
	checkType( 'gram_schmidt', 1, n, 'Vector3D' )
	local seed = p.Vector3D_create_random()
	local along_n = n:multiply( seed:dot(n) )
	local v = seed:subtract( along_n ):normalize()
	return v, n:cross( v )
end

-- =======
-- Shape definitions
-- =======

-- Construct a line segment running from `start` to `endp` (both Point3D).
function p.Line3D(start, endp)
	checkType("Line3D", 1, start, "Point3D")
	checkType("Line3D", 2, endp, "Point3D")
	local segment = { _TYPE = 'Line3D', start = start, endp = endp }
	return segment
end

-- Construct a plane through the Point3D `center` with the Vector3D `normal`.
function p.Plane3D(center, normal)
	checkType("Plane3D", 1, center, "Point3D")
	checkType("Plane3D", 2, normal, "Vector3D")
	local plane = { _TYPE = 'Plane3D', center = center, normal = normal }
	return plane
end

-- Construct a circle with the given Point3D `center` and numeric `radius`,
-- lying in the plane whose normal is the Vector3D `normal`.
function p.Circle3D(center, radius, normal)
	checkType("Circle3D", 1, center, "Point3D")
	checkType("Circle3D", 2, radius, "number")
	checkType("Circle3D", 3, normal, "Vector3D")
	local circle = { _TYPE = 'Circle3D', center = center, radius = radius, normal = normal }
	return circle
end


-- Construct a sphere with the given Point3D `center` and numeric `radius`.
function p.Sphere(center, radius)
	checkType("Sphere", 1, center, "Point3D")
	checkType("Sphere", 2, radius, "number")
	local sphere = { _TYPE = 'Sphere', radius = radius, center = center }
	return sphere
end

-- Cap styles for the ends of a Cylinder, as used by get_points_on_surface_cap
-- and point_is_inside_cylinder:
--   NO_CAP - no cap points are generated
--   FLAT   - points on a circle at each end
--   ROUND  - points on a hemisphere at each end
p.CylinderCapType = {
	NO_CAP = 1,
	FLAT = 2,
	ROUND = 3
}

-- Construct a cylinder whose axis runs from `start` to `endp` (Point3D), with
-- numeric `radius` and `cap_type` (one of the p.CylinderCapType values).
-- Raises an error if cap_type is not exactly one of those values.
function p.Cylinder(start, endp, radius, cap_type)
	checkType( "Cylinder", 1, start, "Point3D" )
	checkType( "Cylinder", 2, endp, "Point3D" )
	checkType( "Cylinder", 3, radius, "number" )
	checkType( "Cylinder", 4, cap_type, "number" )
	-- Require an exact enum member. A range check alone (1 <= cap_type <= 3)
	-- lets values such as 1.5 through; the cap-type dispatches later in this
	-- module would then reject them with a less helpful error.
	local caps = p.CylinderCapType
	local is_valid = cap_type == caps.NO_CAP
		or cap_type == caps.FLAT
		or cap_type == caps.ROUND
	assert( is_valid, "Argument 4 of Cylinder must be a cap type got " .. cap_type )
	return {
		start = start,
		endp = endp,
		radius = radius,
		cap_type = cap_type,
		_TYPE = "Cylinder"
	}
end

-- Return true when points p1 and p2 lie on the same side of `plane`.
-- A point's side is the sign of (plane.center - point) . plane.normal. A point
-- exactly on the plane gets sign 0, so it only matches another point that is
-- also exactly on the plane.
function p.same_side_of_plane(plane, p1, p2)
	checkType( "same_side_of_plane", 1, plane, "Plane3D" )
	checkType( "same_side_of_plane", 2, p1, "Point3D" )
	checkType( "same_side_of_plane", 3, p2, "Point3D" )
	local center, normal = plane.center, plane.normal
	local side1 = sign( p1:create_vector(center):dot(normal) )
	local side2 = sign( p2:create_vector(center):dot(normal) )
	return side1 == side2
end

-- Return the shortest distance from `point` to the line SEGMENT `line`
-- (not the infinite line). For points beyond either end, this is the
-- distance to the nearer endpoint.
--
-- With d the unit direction from endp to start:
--   s = overshoot of the point past `start` along the axis
--   t = overshoot of the point past `endp` along the axis
-- Since s + t equals minus the segment length, at most one is positive, so
-- h = max(s, t, 0) is the along-axis distance outside the segment.
-- c is the perpendicular distance to the infinite line, and the result is
-- sqrt(h^2 + c^2).
function p.distance_to_line( line, point )
	checkType( "distance_to_line", 1, line, "Line3D" )
	checkType( "distance_to_line", 2, point, "Point3D" )

	local axis = line.endp:create_vector(line.start)
	if axis:length() == 0 then
		-- Degenerate segment (start == endp). normalize() would return the zero
		-- vector and every term below would collapse to 0, reporting distance 0
		-- for any point. Treat the segment as a single point instead.
		return line.start:calculate_distance(point)
	end

	local d = axis:normalize()
	local s = line.start:create_vector(point):dot(d)
	local t = point:create_vector(line.endp):dot(d)
	local h = math.max(s, t, 0.0)
	local c = point:create_vector(line.start):cross(d):length()
	return math.sqrt(h * h + c * c)
end

-- Split `line` into `num_lines` parallel copies spaced `width` apart, offset
-- sideways so that the set of copies is centred on the original line.
-- The offset direction is perpendicular to the line's projection onto the
-- xy-plane (the z-axis is ignored). If the line runs parallel to the z-axis
-- that direction is the zero vector, and every copy coincides with the
-- original. With num_lines == 1, the original line is returned in a
-- one-element list.
function p.get_perpendicular_lines( line, width, num_lines )
	checkType( "get_perpendicular_lines", 1, line, "Line3D" )
	checkType( "get_perpendicular_lines", 2, width, "number" )
	checkType( "get_perpendicular_lines", 3, num_lines, "number" )

	assert( num_lines >= 1, "Number of lines must be greater than 0." )
	assert( width > 0, "Width must be greater than 0." )

	if num_lines == 1 then
		return { line }
	end

	-- (dy, -dx, 0) is perpendicular to the line's xy-projection.
	local offset_dir = p.Vector3D( line.endp.y - line.start.y, line.start.x - line.endp.x, 0.0 ):normalize()

	-- Shift back by half the total spread, so the copies straddle the original.
	local back = offset_dir:multiply( -width * (num_lines - 1) / 2 )
	local step = offset_dir:multiply( width )
	local cur_start = line.start:translate( back )
	local cur_end = line.endp:translate( back )

	local result = {}
	for idx = 1, num_lines do
		result[idx] = p.Line3D( cur_start, cur_end )
		cur_start = cur_start:translate( step )
		cur_end = cur_end:translate( step )
	end
	return result
end

-- Generate `num_points` + 1 evenly spaced points along `line` as new Point3D
-- objects. Point i (for i = 0 .. num_points) is at line.start + i/num_points
-- of the way towards line.endp.
-- When num_points is 0, the single returned point is a copy of line.start;
-- without this guard the interpolation factor would be 0/0 = NaN.
-- A negative num_points yields an empty list.
function p.get_points_on_line_3d(line, num_points)
	checkType( "get_points_on_line_3d", 1, line, 'Line3D' )
	checkType( "get_points_on_line_3d", 2, num_points, 'number' )
	local s_cx, s_cy, s_cz = line.start.x, line.start.y, line.start.z
	local e_cx, e_cy, e_cz = line.endp.x, line.endp.y, line.endp.z

	if num_points == 0 then
		return { p.Point3D( s_cx, s_cy, s_cz ) }
	end

	local points = {}
	for i = 0, num_points do
		local interpolation_factor = i / num_points
		points[#points + 1] = p.Point3D(
			s_cx + (e_cx - s_cx) * interpolation_factor,
			s_cy + (e_cy - s_cy) * interpolation_factor,
			s_cz + (e_cz - s_cz) * interpolation_factor
		)
	end
	return points
end

-- Return `num_points` points evenly spaced once around the circumference of
-- `circle`, at angles 2*pi*k/num_points for k = 0 .. num_points - 1 (so the
-- 2*pi point is not repeated). The in-plane basis comes from gram_schmidt
-- applied to the normalized circle normal, so the direction of the first
-- point depends on the pseudo-random vector gram_schmidt draws.
function p.get_points_on_circumference_circle_3d( circle, num_points )
	checkType( "get_points_on_circle3d", 1, circle, 'Circle3D' )
	checkType( "get_points_on_circle3d", 2, num_points, 'number' )

	local v, w = p.gram_schmidt( circle.normal:normalize() )
	local center = circle.center
	local cx, cy, cz = center.x, center.y, center.z
	local r = circle.radius

	local points = {}
	for k = 0, num_points - 1 do
		local ca = math.cos( 2 * math.pi * k / num_points )
		local sa = math.sin( 2 * math.pi * k / num_points )
		points[#points + 1] = p.Point3D(
			cx + r * ca * v.x + r * sa * w.x,
			cy + r * ca * v.y + r * sa * w.y,
			cz + r * ca * v.z + r * sa * w.z
		)
	end
	return points
end

-- Approximate the filled disc of `circle` with `num_radii` concentric rings of
-- `num_points` points each. Ring radii are circle.radius * i / num_radii for
-- i = 0 .. num_radii - 1. The innermost ring therefore has radius 0 (all its
-- points coincide with the centre), and the outer rim itself is not sampled.
function p.get_points_on_surface_circle_3d( circle, num_radii, num_points )
	checkType( "get_points_on_surface_circle_3d", 1, circle, 'Circle3D' )
	checkType( "get_points_on_surface_circle_3d", 2, num_radii, 'number' )
	checkType( "get_points_on_surface_circle_3d", 3, num_points, 'number' )

	local all_points = {}
	for ring = 0, num_radii - 1 do
		local ring_circle = p.Circle3D( circle.center, circle.radius * ring / num_radii, circle.normal )
		for _, pt in ipairs( p.get_points_on_circumference_circle_3d( ring_circle, num_points ) ) do
			all_points[#all_points + 1] = pt
		end
	end
	return all_points
end

-- Sample points on the surface of `sphere` on a latitude/longitude grid:
-- phi = 2*pi*i/num_phi for i = 0 .. num_phi, and theta = pi*j/num_theta for
-- j = 0 .. num_theta. Both endpoints are included, so points repeat at
-- phi = 0/2*pi and at the poles.
-- When filter_for_pov is true (the default when nil), only points with
-- z >= the centre's z (the hemisphere facing +z) are kept.
function p.get_points_on_surface_sphere( sphere, num_phi, num_theta, filter_for_pov )
	if filter_for_pov == nil then
		filter_for_pov = true
	end
	checkType( "get_points_on_surface_sphere", 1, sphere, 'Sphere' )
	checkType( "get_points_on_surface_sphere", 2, num_phi, 'number' )
	checkType( "get_points_on_surface_sphere", 3, num_theta, 'number' )
	libraryUtil.checkType( "get_points_on_surface_sphere", 4, filter_for_pov, 'boolean' )

	local phis = {}
	for i = 0, num_phi do
		phis[#phis + 1] = 2 * math.pi * i / num_phi
	end
	local thetas = {}
	for j = 0, num_theta do
		thetas[#thetas + 1] = math.pi * j / num_theta
	end

	local center = sphere.center
	local cx, cy, cz = center.x, center.y, center.z
	local r = sphere.radius

	local points = {}
	for ti = 1, #thetas do
		local theta = thetas[ti]
		for pj = 1, #phis do
			local phi = phis[pj]
			local x = cx + r * math.sin(theta) * math.cos(phi)
			local y = cy + r * math.sin(theta) * math.sin(phi)
			local z = cz + r * math.cos(theta)
			-- Keep everything, or only the visible (+z facing) hemisphere.
			if not filter_for_pov or z >= cz then
				points[#points + 1] = p.Point3D(x, y, z)
			end
		end
	end
	return points
end

-- Generate points on the surface of one cylinder cap.
--   cap_type        - a p.CylinderCapType value
--   center_cap      - Point3D at the centre of this cap
--   radius_cap      - cap radius
--   normal_cap      - Vector3D normal of the cap plane
--   center_cylinder - Point3D on the cylinder's side of the cap (the opposite
--                     endpoint, as called from get_points_on_surface_cylinder);
--                     for ROUND caps, only sphere points NOT on this point's
--                     side of the cap plane are kept
--   resolution      - sampling density for the circle/sphere generators
--   filter_for_pov  - optional, default true; forwarded to
--                     get_points_on_surface_sphere
-- NO_CAP yields {}, FLAT yields a ring of points, and ROUND yields the outward
-- hemisphere. Any other cap_type raises "Invalid CAP type".
function p.get_points_on_surface_cap( cap_type, center_cap, radius_cap, normal_cap, center_cylinder, resolution, filter_for_pov)
	if filter_for_pov == nil then
		filter_for_pov = true
	end
	checkType( "get_points_on_surface_cap", 1, cap_type, 'number' )
	checkType( "get_points_on_surface_cap", 2, center_cap, 'Point3D' )
	checkType( "get_points_on_surface_cap", 3, radius_cap, 'number' )
	checkType( "get_points_on_surface_cap", 4, normal_cap, 'Vector3D' )
	checkType( "get_points_on_surface_cap", 5, center_cylinder, 'Point3D' )
	checkType( "get_points_on_surface_cap", 6, resolution, 'number' )
	libraryUtil.checkType( "get_points_on_surface_cap", 7, filter_for_pov, 'boolean' )

	local caps = p.CylinderCapType
	if cap_type == caps.NO_CAP then
		return {}
	end
	if cap_type == caps.FLAT then
		local rim = p.Circle3D(center_cap, radius_cap, normal_cap)
		return p.get_points_on_circumference_circle_3d(rim, resolution)
	end
	if cap_type ~= caps.ROUND then
		error( "Invalid CAP type: " .. cap_type )
	end

	-- ROUND: sample a full sphere, then keep only the half that lies on the
	-- far side of the cap plane from the cylinder body.
	local ball = p.Sphere(center_cap, radius_cap)
	local cap_plane = p.Plane3D(center_cap, normal_cap)
	local candidates = p.get_points_on_surface_sphere( ball, resolution, resolution, filter_for_pov )

	local kept = {}
	for _, candidate in ipairs( candidates ) do
		if not p.same_side_of_plane(cap_plane, center_cylinder, candidate) then
			kept[#kept + 1] = candidate
		end
	end
	return kept
end

-- Generate points on the surface of `cylinder`: `resolution` + 1 rings of
-- `resolution` points along the axis, followed by the points of the cap at
-- `start` and then the cap at `endp`. The caps are generated with
-- filter_for_pov = false, so ROUND caps are not restricted to the +z-facing
-- hemisphere.
function p.get_points_on_surface_cylinder( cylinder, resolution )
	checkType( "get_points_on_surface_cylinder", 1, cylinder, 'Cylinder' )
	checkType( "get_points_on_surface_cylinder", 2, resolution, 'number' )

	local start, finish = cylinder.start, cylinder.endp
	local axis_normal = finish:create_vector(start):normalize()

	local points = {}
	local function append_all( list )
		for k = 1, #list do
			points[#points + 1] = list[k]
		end
	end

	-- Body: one ring per sample point along the axis.
	local centers = p.get_points_on_line_3d( p.Line3D(start, finish), resolution )
	for _, center in ipairs( centers ) do
		local ring = p.Circle3D( center, cylinder.radius, axis_normal )
		append_all( p.get_points_on_circumference_circle_3d( ring, resolution ) )
	end

	-- Caps: each one is told which endpoint lies on the cylinder's side.
	local cap_type = cylinder.cap_type
	append_all( p.get_points_on_surface_cap( cap_type, start, cylinder.radius, axis_normal, finish, resolution, false ) )
	append_all( p.get_points_on_surface_cap( cap_type, finish, cylinder.radius, axis_normal, start, resolution, false ) )

	return points
end

-- ===================
-- Check if points are inside a shape
-- ===================

-- Return true when `point` lies inside `sphere` or on its surface.
function p.point_is_inside_sphere( sphere, point )
	checkType( 'point_is_inside_sphere', 1, sphere, 'Sphere' )
	checkType( 'point_is_inside_sphere', 2, point, 'Point3D' )
	local dist = sphere.center:calculate_distance(point)
	return dist <= sphere.radius
end

-- Return true when `point` lies inside `cylinder` (or on its boundary).
-- ROUND caps: the solid is every point within `radius` of the axis segment.
-- FLAT / NO_CAP: the point must also lie between the two end planes.
-- Any other cap type raises "Unknown cap type".
function p.point_is_inside_cylinder( cylinder, point )
	checkType( 'point_is_inside_cylinder', 1, cylinder, 'Cylinder' )
	checkType( 'point_is_inside_cylinder', 2, point, 'Point3D' )

	local axis_segment = p.Line3D(cylinder.start, cylinder.endp)
	local dist = p.distance_to_line(axis_segment, point)
	local cap_type = cylinder.cap_type
	local caps = p.CylinderCapType

	if cap_type == caps.ROUND then
		return dist <= cylinder.radius
	end
	if cap_type ~= caps.FLAT and cap_type ~= caps.NO_CAP then
		error("Unknown cap type: " .. cap_type)
	end

	local axis = cylinder.endp:create_vector(cylinder.start):normalize()
	local start_plane = p.Plane3D(cylinder.start, axis)
	local end_plane = p.Plane3D(cylinder.endp, axis)
	local between = p.same_side_of_plane(start_plane, point, cylinder.endp)
		and p.same_side_of_plane(end_plane, point, cylinder.start)
	return dist <= cylinder.radius and between
end

-- ===============
-- Check if shapes intersect
-- ===============

-- Return true when two spheres overlap or touch, i.e. when the distance
-- between their centres is at most the sum of their radii.
function p.sphere_intersects_with_sphere( sphere1, sphere2 )
	checkType( 'sphere_intersects_with_sphere', 1, sphere1, "Sphere" )
	checkType( 'sphere_intersects_with_sphere', 2, sphere2, "Sphere" )

	local gap = sphere1.center:calculate_distance(sphere2.center)
	local reach = sphere1.radius + sphere2.radius
	return gap <= reach
end

-- Approximate test for whether a sphere and a cylinder intersect.
-- Only the cylinder's two endpoints are checked: the result is true when
-- either endpoint is within sphere.radius + cylinder.radius of the sphere's
-- centre. A sphere touching only the middle of a long cylinder is therefore
-- not reported.
function p.sphere_intersects_with_cylinder( sphere, cylinder )
	checkType( 'sphere_intersects_with_cylinder', 1, sphere, "Sphere" )
	checkType( 'sphere_intersects_with_cylinder', 2, cylinder, "Cylinder" )
	local reach = sphere.radius + cylinder.radius
	local c = sphere.center
	if c:calculate_distance(cylinder.start) <= reach then
		return true
	end
	return c:calculate_distance(cylinder.endp) <= reach
end

-- Approximate test for whether two cylinders intersect.
-- Only endpoint pairs are compared: the result is true when any endpoint of
-- cylinder1 is within the sum of the radii of any endpoint of cylinder2.
-- Cylinders that cross mid-length are therefore not reported.
function p.cylinder_intersects_with_cylinder( cylinder1, cylinder2 )
	checkType( 'cylinder_intersects_with_cylinder', 1, cylinder1, "Cylinder" )
	checkType( 'cylinder_intersects_with_cylinder', 2, cylinder2, "Cylinder" )

	local reach = cylinder1.radius + cylinder2.radius
	local a_start, a_end = cylinder1.start, cylinder1.endp
	local b_start, b_end = cylinder2.start, cylinder2.endp
	if a_start:calculate_distance(b_start) <= reach then return true end
	if a_start:calculate_distance(b_end) <= reach then return true end
	if a_end:calculate_distance(b_start) <= reach then return true end
	return a_end:calculate_distance(b_end) <= reach
end

return p