-- Jump to content
--
-- Module:CineMol/api
--
-- From Wikipedia, the free encyclopedia
-- This is a port of CineMol to Lua.
-- CineMol https://github.com/moltools/CineMol was written by David Meijer, Marnix H. Medema & Justin J. J. van der Hooft and is MIT licensed
-- Please consider any edits made to this page as dual licensed MIT & CC-BY-SA 4.0

-- This module provides functions for drawing molecules using atoms and bonds.

local p = {}

local geometry = require( 'Module:CineMol/geometry' )
local style = require( 'Module:CineMol/style' )
local cinemolsvg = require( 'Module:CineMol/svg' )
local model = require( 'Module:CineMol/model' )

local Cylinder = geometry.Cylinder
local CylinderCapType = geometry.CylinderCapType
local Line3D = geometry.Line3D 
local Point3D = geometry.Point3D 
local Sphere = geometry.Sphere
local get_perpendicular_lines = geometry.get_perpendicular_lines
local checkType = geometry.checkType


local Cartoon = style.Cartoon
local Color = style.Color
local Glossy = style.Glossy
local CoreyPaulingKoltungAtomColor = style.CoreyPaulingKoltungAtomColor
local PubChemAtomRadius = style.PubChemAtomRadius

local Svg = cinemolsvg.Svg
local ViewBox = cinemolsvg.ViewBox

local ModelCylinder = model.ModelCylinder
local ModelSphere = model.ModelSphere
local ModelWire = model.ModelWire
local Scene = model.Scene

-- Drawing styles accepted as the `style` argument of p.draw_molecule.
-- The numbers are opaque identifiers; this module only compares them for equality.
p.Style = {
	-- Atoms only, drawn as spheres.
	SPACEFILLING = 1,
	-- Atoms as spheres at one third of their radius, plus bonds as cylinders.
	BALL_AND_STICK = 2,
	-- Bonds only, drawn as round-capped cylinders.
	TUBE = 3,
	-- Bonds only, drawn as wires.
	WIREFRAME = 4
}
local Style = p.Style

-- Depiction looks accepted as the `look` argument of the drawing functions.
p.Look = {
	-- Uses style.Cartoon (fill color plus stroke color and stroke width).
	CARTOON = 1,
	-- Uses style.Glossy (fill color only).
	GLOSSY = 2
}
local Look = p.Look

-- Represents an atom in a molecule.
--
-- index: number identifying the atom; bonds refer to atoms by this value
-- symbol: element symbol (string)
-- coordinates: table with exactly three entries, read as x, y, z
-- radius: optional; stored as given without a type check (same as p.Bond)
-- color: optional plain {r, g, b} table -- NOT a Color object
-- opacity: optional number, defaults to 1.0
--
-- Returns a plain table tagged with _TYPE = 'Atom'.
function p.Atom( index, symbol, coordinates, radius, color, opacity )
	checkType( 'Atom', 1, index, 'number' )
	checkType( 'Atom', 2, symbol, 'string' )
	checkType( 'Atom', 3, coordinates, 'table' )
	assert( #coordinates == 3, "Expected 3 coordinates" )
	opacity = opacity == nil and 1.0 or opacity
	checkType( 'Atom', 6, opacity, 'number' )

	return {
		_TYPE = 'Atom',
		index = index,
		symbol = symbol,
		coordinates = coordinates,
		-- The radius argument used to be accepted but silently discarded;
		-- keep it on the atom so it is available to consumers.
		radius = radius,
		color = color,
		opacity = opacity
	}
end

-- Describes a bond connecting two atoms of a molecule.
--
-- start_index, end_index: numbers matching the `index` of the bonded atoms
-- order: bond order (number)
-- radius: optional, stored as given
-- color: optional plain {r, g, b} table -- not a Color object
-- opacity: optional number; 1.0 is used when nil is passed
--
-- Returns a plain table tagged with _TYPE = 'Bond'.
function p.Bond( start_index, end_index, order, radius, color, opacity )
	if opacity == nil then
		opacity = 1.0
	end

	checkType( 'Bond', 1, start_index, 'number' )
	checkType( 'Bond', 2, end_index, 'number' )
	checkType( 'Bond', 3, order, 'number' )
	checkType( 'Bond', 6, opacity, 'number' )

	local bond = { _TYPE = 'Bond' }
	bond.start_index = start_index
	bond.end_index = end_index
	bond.order = order
	bond.radius = radius
	bond.color = color
	bond.opacity = opacity
	return bond
end

-- Reports whether `needle` occurs in the sequence `haystack`.
--
-- The sequence is walked with ipairs (so the scan ends at the first nil
-- entry) and elements are compared with ==.
--
-- Returns true when a match is found and false otherwise. A strict boolean
-- is returned rather than the matched value, so that finding the value
-- `false` is not mistaken for "not found".
local function findInTable( haystack, needle )
	for _, value in ipairs( haystack ) do
		if needle == value then
			return true
		end
	end
	return false
end

-- Filter atoms and bonds based on the given exclude atoms.
--
-- atoms: sequence of atom tables (fields read here: symbol, index)
-- bonds: sequence of bond tables (fields read here: start_index, end_index)
-- exclude_atoms: optional sequence of element symbols to remove; nil turns
--                filtering off
--
-- Returns two values, atoms and bonds. When exclude_atoms is nil the input
-- tables themselves are returned. Otherwise two new sequences are returned:
-- the atoms whose symbol is not excluded, and the bonds that touch no
-- excluded atom at either end.
function p.filter_atoms_and_bonds( atoms, bonds, exclude_atoms )
	checkType( 'filter_atoms_and_bonds', 1, atoms, 'table' )
	checkType( 'filter_atoms_and_bonds', 2, bonds, 'table' )

	if exclude_atoms == nil then
		return atoms, bonds
	end
	checkType( 'filter_atoms_and_bonds', 3, exclude_atoms, 'table' )

	-- Membership sets give constant-time lookups instead of a linear scan
	-- for every atom and every bond endpoint.
	local excluded_symbols = {}
	for _, symbol in ipairs( exclude_atoms ) do
		excluded_symbols[symbol] = true
	end

	local filtered_atoms, excluded_indices = {}, {}
	for _, atom in ipairs( atoms ) do
		if excluded_symbols[atom.symbol] then
			-- nil cannot be used as a table key; an atom without an index
			-- cannot be matched by a bond endpoint anyway.
			if atom.index ~= nil then
				excluded_indices[atom.index] = true
			end
		else
			filtered_atoms[#filtered_atoms + 1] = atom
		end
	end

	local filtered_bonds = {}
	for _, bond in ipairs( bonds ) do
		if not excluded_indices[bond.start_index] and not excluded_indices[bond.end_index] then
			filtered_bonds[#filtered_bonds + 1] = bond
		end
	end

	return filtered_atoms, filtered_bonds
end

-- Draw a molecule using the given atoms and bonds in wireframe style.
--
-- scene: Scene that receives the ModelWire nodes
-- atoms: sequence of atom tables; bonds look them up by their `index` field
-- bonds: sequence of bond tables (start_index, end_index, color, opacity)
-- wire_width: optional number, defaults to 0.05
--
-- Color precedence for each bond: an explicit bond.color wins; otherwise
-- each half uses its atom's own color if set, else the color that
-- CoreyPaulingKoltungAtomColor returns for the atom's symbol.
--
-- Raises an error when a bond refers to an atom index that is not in atoms.
function p.draw_bonds_in_wireframe_style( scene, atoms, bonds, wire_width )
	wire_width = wire_width == nil and 0.05 or wire_width
	checkType( 'draw_bonds_in_wireframe_style', 1, scene, 'Scene' )
	checkType( 'draw_bonds_in_wireframe_style', 2, atoms, 'table' )
	checkType( 'draw_bonds_in_wireframe_style', 3, bonds, 'table' )
	checkType( 'draw_bonds_in_wireframe_style', 4, wire_width, 'number' )

	-- Index atoms by their `index` field so bond endpoints resolve directly.
	local atom_map = {}
	for _, atom in ipairs( atoms ) do
		atom_map[atom.index] = atom
	end

	for _, bond in ipairs( bonds ) do
		local start_atom = atom_map[bond.start_index]
		local end_atom = atom_map[bond.end_index]
		if start_atom == nil or end_atom == nil then
			-- Fail with a readable message instead of "attempt to index a nil value".
			error( string.format(
				"Bond %s-%s references an atom index that is not present in atoms",
				tostring( bond.start_index ),
				tostring( bond.end_index )
			) )
		end

		-- Default colors and positions of both endpoints.
		local start_color = CoreyPaulingKoltungAtomColor:get_color( start_atom.symbol )
		local start_pos = Point3D( start_atom.coordinates[1], start_atom.coordinates[2], start_atom.coordinates[3] )
		local end_color = CoreyPaulingKoltungAtomColor:get_color( end_atom.symbol )
		local end_pos = Point3D( end_atom.coordinates[1], end_atom.coordinates[2], end_atom.coordinates[3] )

		-- Determine color of bond.
		if bond.color ~= nil then
			-- Share one Color object between both ends. Two separately
			-- constructed objects compare unequal under reference equality,
			-- which would split a uniformly colored bond into two wires.
			start_color = Color( bond.color[1], bond.color[2], bond.color[3] )
			end_color = start_color
		else
			if start_atom.color ~= nil then
				start_color = Color( start_atom.color[1], start_atom.color[2], start_atom.color[3] )
			end
			if end_atom.color ~= nil then
				end_color = Color( end_atom.color[1], end_atom.color[2], end_atom.color[3] )
			end
		end

		-- If the colors differ, split the bond down the middle and draw one
		-- wire per half.
		-- NOTE(review): `~=` on Color objects is a reference comparison unless
		-- Color defines an __eq metamethod -- confirm in Module:CineMol/style.
		if start_color ~= end_color then
			local middle_pos = Point3D(
				( start_pos.x + end_pos.x ) / 2,
				( start_pos.y + end_pos.y ) / 2,
				( start_pos.z + end_pos.z ) / 2
			)
			scene:add_node(
				ModelWire( Line3D( start_pos, middle_pos ), start_color, wire_width, bond.opacity )
			)
			scene:add_node(
				ModelWire( Line3D( middle_pos, end_pos ), end_color, wire_width, bond.opacity )
			)
		else
			scene:add_node(
				ModelWire( Line3D( start_pos, end_pos ), start_color, wire_width, bond.opacity )
			)
		end
	end
end

-- Draw the given atoms as spheres.
--
-- scene: Scene that receives one ModelSphere node per atom
-- atoms: sequence of atom tables (symbol, coordinates, radius, color, opacity)
-- look: a p.Look value selecting the Cartoon or Glossy depiction
-- stroke_color: Color of the outline; only used by the CARTOON look
-- stroke_width: number, outline width; only used by the CARTOON look
-- radius_scale: optional number multiplied into every radius, defaults to 1.0
--
-- Per atom, a numeric atom.radius overrides the radius that PubChemAtomRadius
-- gives for the symbol, and atom.color (a plain {r, g, b} table) overrides
-- the CoreyPaulingKoltungAtomColor color. This mirrors how bond.radius and
-- bond.color are honored when drawing bonds in tube style.
--
-- Raises an error if look is not a known p.Look value.
function p.draw_atoms_in_spacefilling_style( scene, atoms, look, stroke_color, stroke_width, radius_scale )
	radius_scale = radius_scale == nil and 1.0 or radius_scale
	checkType( 'draw_atoms_in_spacefilling_style', 1, scene, 'Scene' )
	checkType( 'draw_atoms_in_spacefilling_style', 2, atoms, 'table' )
	checkType( 'draw_atoms_in_spacefilling_style', 3, look, 'number' )
	checkType( 'draw_atoms_in_spacefilling_style', 4, stroke_color, 'Color' )
	checkType( 'draw_atoms_in_spacefilling_style', 5, stroke_width, 'number' )
	checkType( 'draw_atoms_in_spacefilling_style', 6, radius_scale, 'number' )

	for _, atom in ipairs( atoms ) do
		-- Get atom specifications.
		local atom_symbol = atom.symbol
		local atom_color = (
			atom.color == nil
				and CoreyPaulingKoltungAtomColor:get_color( atom_symbol )
				or Color( atom.color[1], atom.color[2], atom.color[3] )
		)

		-- Only a numeric override is used; anything else falls back to the
		-- per-element radius, which is what was always used before.
		local atom_radius
		if type( atom.radius ) == 'number' then
			atom_radius = atom.radius
		else
			atom_radius = PubChemAtomRadius:to_angstrom( atom_symbol )
		end
		atom_radius = atom_radius * radius_scale

		local atom_pos = Point3D( atom.coordinates[1], atom.coordinates[2], atom.coordinates[3] )

		-- Determine atom look.
		local depiction
		if look == Look.CARTOON then
			depiction = Cartoon( atom_color, stroke_color, stroke_width, atom.opacity )
		elseif look == Look.GLOSSY then
			depiction = Glossy( atom_color, atom.opacity )
		else
			error( "Unknown look: " .. look )
		end

		-- Add atom to scene.
		scene:add_node( ModelSphere( Sphere( atom_pos, atom_radius ), depiction ) )
	end
end

-- Draw a molecule using the given atoms and bonds in tube style.
--
-- scene: Scene that receives the ModelCylinder nodes
-- atoms: sequence of atom tables; bonds look them up by their `index` field
-- bonds: sequence of bond tables (start_index, end_index, order, radius,
--        color, opacity)
-- tube_bond_style: a p.Style value; bond.order is only taken into account
--                  for Style.BALL_AND_STICK, every other value uses 1
-- look: a p.Look value selecting the Cartoon or Glossy depiction
-- cap_type: number passed through to Cylinder
-- stroke_color, stroke_width: outline settings; only used by the CARTOON look
--
-- Color precedence for each bond: an explicit bond.color wins; otherwise
-- each half uses its atom's own color if set, else the color that
-- CoreyPaulingKoltungAtomColor returns for the atom's symbol.
--
-- Raises an error when a bond refers to an atom index that is not in atoms,
-- or when look is not a known p.Look value.
function p.draw_bonds_in_tube_style( scene, atoms, bonds, tube_bond_style, look, cap_type, stroke_color, stroke_width )
	checkType( 'draw_bonds_in_tube_style', 1, scene, 'Scene' )
	checkType( 'draw_bonds_in_tube_style', 2, atoms, 'table' )
	checkType( 'draw_bonds_in_tube_style', 3, bonds, 'table' )
	checkType( 'draw_bonds_in_tube_style', 4, tube_bond_style, 'number' )
	checkType( 'draw_bonds_in_tube_style', 5, look, 'number' )
	checkType( 'draw_bonds_in_tube_style', 6, cap_type, 'number' )
	checkType( 'draw_bonds_in_tube_style', 7, stroke_color, 'Color' )
	checkType( 'draw_bonds_in_tube_style', 8, stroke_width, 'number' )

	-- Builds the depiction for one cylinder according to the requested look.
	local function make_depiction( color, opacity )
		if look == Look.CARTOON then
			return Cartoon( color, stroke_color, stroke_width, opacity )
		elseif look == Look.GLOSSY then
			return Glossy( color, opacity )
		end
		error( "Unknown look: " .. look )
	end

	-- Adds a single cylinder between two points to the scene.
	local function add_cylinder( from_pos, to_pos, radius, color, opacity )
		local depiction = make_depiction( color, opacity )
		scene:add_node(
			ModelCylinder( Cylinder( from_pos, to_pos, radius, cap_type ), depiction )
		)
	end

	-- Index atoms by their `index` field so bond endpoints resolve directly.
	local atom_map = {}
	for _, atom in ipairs( atoms ) do
		atom_map[atom.index] = atom
	end

	for _, bond in ipairs( bonds ) do
		local start_atom = atom_map[bond.start_index]
		local end_atom = atom_map[bond.end_index]
		if start_atom == nil or end_atom == nil then
			-- Fail with a readable message instead of "attempt to index a nil value".
			error( string.format(
				"Bond %s-%s references an atom index that is not present in atoms",
				tostring( bond.start_index ),
				tostring( bond.end_index )
			) )
		end

		-- Default colors and positions of both endpoints.
		local start_color = CoreyPaulingKoltungAtomColor:get_color( start_atom.symbol )
		local start_pos = Point3D( start_atom.coordinates[1], start_atom.coordinates[2], start_atom.coordinates[3] )
		local end_color = CoreyPaulingKoltungAtomColor:get_color( end_atom.symbol )
		local end_pos = Point3D( end_atom.coordinates[1], end_atom.coordinates[2], end_atom.coordinates[3] )

		-- Determine color of bond.
		if bond.color ~= nil then
			-- Share one Color object between both ends. Two separately
			-- constructed objects compare unequal under reference equality,
			-- which would split a uniformly colored bond into two cylinders.
			start_color = Color( bond.color[1], bond.color[2], bond.color[3] )
			end_color = start_color
		else
			if start_atom.color ~= nil then
				start_color = Color( start_atom.color[1], start_atom.color[2], start_atom.color[3] )
			end
			if end_atom.color ~= nil then
				end_color = Color( end_atom.color[1], end_atom.color[2], end_atom.color[3] )
			end
		end

		-- Determine number of cylinders to draw for each bond.
		-- Bond order is only used for ball-and-stick style.
		-- NOTE(review): a bond.order of 0 makes the radius below infinite -- confirm
		-- callers never pass it.
		local bond_order = tube_bond_style == Style.BALL_AND_STICK and bond.order or 1
		local bond_radius = bond.radius ~= nil and bond.radius or 0.2
		-- The total radius is shared between the cylinders of one bond.
		local temp_bond_radius = bond_radius / bond_order
		local line_tmp = Line3D( start_pos, end_pos )
		-- bond_order is passed as the third argument; presumably the number of
		-- lines to produce -- verify in Module:CineMol/geometry.
		local lines = get_perpendicular_lines( line_tmp, temp_bond_radius * ( bond_order + 1 ), bond_order )

		-- Add cylinders to scene.
		-- NOTE(review): `~=` on Color objects is a reference comparison unless
		-- Color defines an __eq metamethod -- confirm in Module:CineMol/style.
		for _, line in ipairs( lines ) do
			if start_color ~= end_color then
				-- Two-colored bond: one cylinder per half, meeting at the midpoint.
				local middle_pos = Point3D(
					( line.start.x + line.endp.x ) / 2,
					( line.start.y + line.endp.y ) / 2,
					( line.start.z + line.endp.z ) / 2
				)
				add_cylinder( line.start, middle_pos, temp_bond_radius, start_color, bond.opacity )
				add_cylinder( middle_pos, line.endp, temp_bond_radius, end_color, bond.opacity )
			else
				-- Bond as a whole.
				add_cylinder( line.start, line.endp, temp_bond_radius, start_color, bond.opacity )
			end
		end
	end
end

-- Builds a scene for the molecule in the requested style and draws it.
--
-- atoms, bonds: sequences of atom and bond tables
-- style: a p.Style value; look: a p.Look value
-- resolution: number, forwarded to scene:draw
-- window: forwarded to scene:draw unchanged
-- view_box: optional table whose first four entries are forwarded to ViewBox
-- rotation_over_{x,y,z}_axis: optional numbers, each defaulting to 0
-- scale: optional number, defaults to 1
-- focal_length: forwarded to scene:draw unchanged
-- exclude_atoms: optional sequence of element symbols to leave out
--
-- Returns the first value returned by scene:draw.
-- Raises an error for an unknown style.
function p.draw_molecule(
	atoms,
	bonds,
	style,
	look,
	resolution,
	window, -- Possibly doesn't work
	view_box,
	rotation_over_x_axis,
	rotation_over_y_axis,
	rotation_over_z_axis,
	scale,
	focal_length,
	exclude_atoms
)
	-- Fill in defaults for the optional numeric arguments.
	if rotation_over_x_axis == nil then rotation_over_x_axis = 0 end
	if rotation_over_y_axis == nil then rotation_over_y_axis = 0 end
	if rotation_over_z_axis == nil then rotation_over_z_axis = 0 end
	if scale == nil then scale = 1 end

	checkType( 'draw_molecule', 1, atoms, 'table' )
	checkType( 'draw_molecule', 2, bonds, 'table' )
	checkType( 'draw_molecule', 3, style, 'number' )
	checkType( 'draw_molecule', 4, look, 'number' )
	checkType( 'draw_molecule', 5, resolution, 'number' )
	checkType( 'draw_molecule', 8, rotation_over_x_axis, 'number' )
	checkType( 'draw_molecule', 9, rotation_over_y_axis, 'number' )
	checkType( 'draw_molecule', 10, rotation_over_z_axis, 'number' )
	checkType( 'draw_molecule', 11, scale, 'number' )

	-- Drop excluded atoms and every bond attached to them.
	atoms, bonds = p.filter_atoms_and_bonds( atoms, bonds, exclude_atoms )

	local scene = Scene( {} )

	-- Which node kinds and intersection calculations the renderer should
	-- handle; each style switches on only what it needs.
	local with_spheres, with_cylinders, with_wires = false, false, false
	local sphere_sphere, cylinder_sphere, cylinder_cylinder = false, false, false
	-- No style enables sphere-cylinder intersections.
	local sphere_cylinder = false

	-- Outline settings for the cartoon look.
	local outline_color = Color( 0, 0, 0 )
	local outline_width = 0.05

	if style == Style.WIREFRAME then
		-- Wires have their own drawing routine.
		with_wires = true
		p.draw_bonds_in_wireframe_style( scene, atoms, bonds )

	elseif style == Style.SPACEFILLING then
		with_spheres = true
		sphere_sphere = true
		p.draw_atoms_in_spacefilling_style( scene, atoms, look, outline_color, outline_width )

	elseif style == Style.BALL_AND_STICK then
		-- Atoms as shrunken spheres, bonds as uncapped cylinders.
		with_spheres = true
		with_cylinders = true
		cylinder_sphere = true
		local cap_type = CylinderCapType.NO_CAP
		p.draw_atoms_in_spacefilling_style(
			scene, atoms, look, outline_color, outline_width,
			1.0 / 3.0  -- Atom radii are reduced to one third for ball-and-stick.
		)
		p.draw_bonds_in_tube_style(
			scene, atoms, bonds, style, look, cap_type, outline_color, outline_width
		)

	elseif style == Style.TUBE then
		-- Bonds only, as round-capped cylinders; only cylinder-cylinder
		-- intersections are needed.
		with_cylinders = true
		cylinder_cylinder = true
		local cap_type = CylinderCapType.ROUND
		p.draw_bonds_in_tube_style(
			scene, atoms, bonds, style, look, cap_type, outline_color, outline_width
		)

	else
		error( "Unknown style: " .. style )
	end

	-- Turn the optional plain view box table into a ViewBox object.
	local viewBox
	if view_box ~= nil then
		viewBox = ViewBox( view_box[1], view_box[2], view_box[3], view_box[4] )
	end

	-- Draw scene. Arguments are positional, so their order must match Scene:draw.
	local rendered = scene:draw(
		resolution,
		window,
		viewBox,
		rotation_over_x_axis,
		rotation_over_y_axis,
		rotation_over_z_axis,
		with_spheres,
		with_cylinders,
		with_wires,
		sphere_sphere,
		sphere_cylinder,
		cylinder_sphere,
		cylinder_cylinder,
		true, -- NOTE(review): meaning of this flag is defined by Scene:draw -- see Module:CineMol/model
		scale,
		focal_length
	)

	return rendered
end

return p