lilatomic

Writing a Vars Plugin in Ansible

Writing a Vars Plugin in Ansible

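Often there are facts you wish were just ambiently available, especially if you are working with cloud infrastructure. For example, you might want your subscription ID, tenant ID, user object ID, the contents of a resource group, and many other things. You could have tasks which set these as facts, but that severely limits your ability to resume at a task: if you restart a playbook partway through (for example with `--start-at-task`), the tasks that set those facts are skipped and the facts are never defined. A vars plugin makes them available without any task having to run.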

Basic outline

The most basic version of a vars plugin is just a class with the name Ansible looks for (`VarsModule`) and a `get_vars` method:

from ansible.plugins.vars import BaseVarsPlugin

class VarsModule(BaseVarsPlugin):
	"""Smallest possible vars plugin; Ansible finds it by the class name ``VarsModule``."""

	def get_vars(self, loader, path, entities, cache=None):
		"""Contribute no variables for any of *entities*."""
		no_vars = dict()
		return no_vars

Let's get our documentation in order. We want to include the `vars_plugin_staging` fragment so that users can configure when the plugin runs through its `stage` option (see Execution for why):

# Plugin documentation. The `stage` option, together with the
# vars_plugin_staging fragment, lets users choose when the plugin runs.
vars: gitroot
version_added: "0.2" # for collections, use the collection version, not the Ansible version
short_description: Finds the git root
description: Finds the git root
options:
  stage:
    ini:
      - key: stage
        section: lilatomic.alpacloud.gitroot
    env:
      - name: LILATOMIC_ALPACLOUD_GITROOT
extends_documentation_fragment:
  - vars_plugin_staging

Combining these, we get a sample like:

# Plugin documentation (YAML). It declares the `stage` option, configurable via the
# ini section or env var below, and pulls in the vars_plugin_staging fragment.
DOCUMENTATION = r"""
vars: gitroot
version_added: "0.2"
short_description: Finds the git root
description: Finds the git root
options:
  stage:
    ini:
      - key: stage
        section: lilatomic.alpacloud.gitroot
    env:
      - name: LILATOMIC_ALPACLOUD_GITROOT
extends_documentation_fragment:
  - vars_plugin_staging
"""

import subprocess

from ansible.plugins.vars import BaseVarsPlugin


class VarsModule(BaseVarsPlugin):  # Ansible requires exactly this class name
	def __init__(self):
		# Only delegates to the base class; shown to mark where setup would go.
		super(VarsModule, self).__init__()

	def get_vars(self, loader, path, entities, cache=None):
		"""Return ``{"src": <git root>}``; the function that actually does the lookup.

		Nothing is cached here, so git is run on every single call.
		"""
		# The base-class call records `self._basedir = basedir(path)`.
		super().get_vars(loader, path, entities)
		git_root = _get_git_root()
		return {"src": git_root}


def _get_git_root(cwd=None):
	"""Return the top-level directory of the git repository containing *cwd*.

	:param cwd: directory to run ``git`` in; ``None`` (default) uses the
		current working directory of the process.
	:returns: the repository root path without the trailing newline that
		``git`` prints, or ``""`` if *cwd* is not inside a repository or
		``git`` cannot be executed.
	"""
	try:
		res = subprocess.run(
			["git", "rev-parse", "--show-toplevel"],
			stdout=subprocess.PIPE,
			# Capture stderr so "fatal: not a git repository" doesn't spam the
			# console on every one of the many plugin invocations.
			stderr=subprocess.PIPE,
			universal_newlines=True,
			cwd=cwd,
		)
	except FileNotFoundError:
		# git is not installed / not on PATH (or cwd does not exist).
		return ""
	if res.returncode != 0:
		return ""
	# git terminates its output with "\n"; strip it so the var is a usable path.
	return res.stdout.rstrip("\n")

This sample is really bad, because `get_vars` gets executed a lot (see Execution) and every call spawns a new `git` process. So we want to build some caching into it. This is patterned after how `host_group_vars` does it (the only vars plugin included in ansible-base), but other vars plugins do a similar thing:

# Plugin documentation (YAML). It declares the `stage` option, configurable via the
# ini section or env var below, and pulls in the vars_plugin_staging fragment.
DOCUMENTATION = r"""
vars: gitroot
version_added: "0.2"
short_description: Finds the git root
description: Finds the git root
options:
  stage:
    ini:
      - key: stage
        section: lilatomic.alpacloud.gitroot
    env:
      - name: LILATOMIC_ALPACLOUD_GITROOT
extends_documentation_fragment:
  - vars_plugin_staging
"""

import subprocess

from ansible.plugins.vars import BaseVarsPlugin


# Module-level memo shared by every VarsModule.get_vars call; holds the git root under "src".
FOUND = {}


class VarsModule(BaseVarsPlugin):
	def get_vars(self, loader, path, entities, cache=None):
		"""Return ``{"src": <git root>}``, running git only on the first call.

		The value is memoised in the module-level ``FOUND`` dict, so later calls
		reuse it for as long as this module stays loaded. ``cache`` is accepted
		for signature compatibility but is not consulted.
		"""
		# Accept a lone entity as well as a list of them.
		entities = entities if isinstance(entities, list) else [entities]

		super().get_vars(loader, path, entities)

		try:
			git_root = FOUND["src"]
		except KeyError:
			git_root = FOUND["src"] = _get_git_root()
		return {"src": git_root}


def _get_git_root(cwd=None):
	"""Return the top-level directory of the git repository containing *cwd*.

	:param cwd: directory to run ``git`` in; ``None`` (default) uses the
		current working directory of the process.
	:returns: the repository root path without the trailing newline that
		``git`` prints, or ``""`` if *cwd* is not inside a repository or
		``git`` cannot be executed.
	"""
	try:
		res = subprocess.run(
			["git", "rev-parse", "--show-toplevel"],
			stdout=subprocess.PIPE,
			# Capture stderr so "fatal: not a git repository" doesn't spam the
			# console on every one of the many plugin invocations.
			stderr=subprocess.PIPE,
			universal_newlines=True,
			cwd=cwd,
		)
	except FileNotFoundError:
		# git is not installed / not on PATH (or cwd does not exist).
		return ""
	if res.returncode != 0:
		return ""
	# git terminates its output with "\n"; strip it so the var is a usable path.
	return res.stdout.rstrip("\n")

Using hosts and groups

In the above example, we added a general var (one which isn't attached to a particular host). We may also want to add data associated with each host or group. There are a couple of moving pieces, detailed in the following example:

# Plugin documentation (YAML). It declares the `stage` option, configurable via the
# ini section or env var below, and pulls in the vars_plugin_staging fragment.
DOCUMENTATION = r"""
vars: knownhostentry
version_added: "0.2"
short_description: Adds a known-hosts entry for each host, if you have one
description:
  - Adds a known-hosts entry for each host, if you have one in your local known-hosts
  - Useful for not having to construct it every time
options:
  stage:
    ini:
      - key: stage
        section: lilatomic.alpacloud.knownhostentry
    env:
      - name: LILATOMIC_ALPACLOUD_KNOWNHOSTENTRY
extends_documentation_fragment:
  - vars_plugin_staging
"""
import json
import os
import subprocess

from ansible.errors import AnsibleParserError
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.plugins.vars import BaseVarsPlugin
from ansible.inventory.host import Host
from ansible.inventory.group import Group
from ansible.utils.vars import combine_vars


from ansible.utils.display import Display

# Shared Display instance, used below for verbose (-v) output and warnings.
display = Display()

# Module-level memo of known_hosts lookups, keyed by host name (str(entity.name)).
FOUND = {}


class VarsModule(BaseVarsPlugin):
	"""Expose each host's entries from the local known_hosts file as ``known_hosts``."""

	def get_vars(self, loader, path, entities, cache=True):
		"""Return ``{"known_hosts": [...]}`` for the Host entities passed in.

		:param loader: Ansible DataLoader (unused here beyond the base-class call).
		:param path: inventory source path or play basedir, as given by Ansible.
		:param entities: a Host/Group or a list of them; Groups are skipped.
		:param cache: when truthy (default), reuse lookups memoised in ``FOUND``;
			when falsy, always re-run ssh-keygen (the result is still stored).
		:returns: ``{"known_hosts": lines}`` when a host has entries, else ``{}``.
		:raises AnsibleParserError: for an entity that is neither Host nor Group,
			or if the known_hosts lookup fails.
		"""
		# Callers may pass a single entity rather than a list (not observed in
		# current runs; presumably a compatibility path) -- normalise either way.
		if not isinstance(entities, list):
			entities = [entities]

		super(VarsModule, self).get_vars(loader, path, entities)

		data = {}
		for entity in entities:
			# Entities come as either Hosts or Groups. This example ignores groups.
			if isinstance(entity, Group):
				continue
			if not isinstance(entity, Host):
				raise AnsibleParserError(
					"Supplied entity must be Host or Group, got %s instead" % (type(entity))
				)

			if entity.name.startswith(os.path.sep):
				# avoid 'chroot' type inventory hostnames /path/to/chroot
				# this is from the sample plugin
				continue

			key = str(entity.name)
			try:
				if not (cache and key in FOUND):
					FOUND[key] = _get_known_host(entity.name)
				known_hosts = FOUND[key]
			except Exception as e:
				# Say which host failed and keep the original traceback.
				raise AnsibleParserError(
					"Failed to look up known_hosts entry for %s: %s" % (key, to_native(e))
				) from e

			# NOTE(review): with several hosts in one call, the last host that has
			# entries wins; hosts appear to be passed one at a time -- verify.
			if known_hosts:
				data["known_hosts"] = known_hosts

		display.v(json.dumps(data))
		return data


def _get_known_host(host, file="~/.ssh/known_hosts"):
	"""Return the lines of *file* that match *host*, as found by ``ssh-keygen -F``.

	:param host: host name to search for.
	:param file: known_hosts file to search; ``~`` is expanded.
	:returns: list of matching known_hosts lines; ssh-keygen's ``#`` comment
		lines and blank lines are dropped. Empty if nothing matched.
	:raises FileNotFoundError: if ``/usr/bin/ssh-keygen`` does not exist
		(raised by ``subprocess.run``).
	"""
	file = os.path.expanduser(file)

	cmd = ["/usr/bin/ssh-keygen", "-f", file, "-F", host]

	display.v(f"cmd : {cmd}")
	res = subprocess.run(
		cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True,
	)
	display.v(f"stdout {res.stdout}")
	display.v(f"stderr {res.stderr}")
	# Only warn when ssh-keygen actually complained; an unconditional warning
	# prints an empty "[WARNING]:" line for every successful lookup.
	if res.stderr.strip():
		display.warning(res.stderr)
	return [
		line for line in res.stdout.splitlines() if line and not line.startswith("#")
	]

Loading files

You can probably figure out how to load files by cribbing from the `host_group_vars` plugin, but I've included an example here for completeness. The main difference is that `loader.load_from_file` only supports YAML/JSON, so we write our own small helper function to parse Dhall.

# Plugin documentation (YAML). It declares the `stage` option, configurable via the
# ini section or env var below, and pulls in the vars_plugin_staging fragment.
DOCUMENTATION = r"""
vars: dhall_vars
version_added: "0.2"
short_description: Loads vars from a Dhall file
description: 
  - Loads vars from a Dhall file
options:
  stage:
    ini:
      - key: stage
        section: lilatomic.alpacloud.dhall_vars
    env:
      - name: LILATOMIC_ALPACLOUD_DHALL_VARS
extends_documentation_fragment:
  - vars_plugin_staging
"""

import os
import subprocess

from ansible.errors import AnsibleParserError
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.plugins.vars import BaseVarsPlugin
from ansible.inventory.host import Host
from ansible.inventory.group import Group
from ansible.utils.vars import combine_vars

import dhall


# Module-level memo of located vars files, keyed by "<entity name>.<directory>".
FOUND = {}


class VarsModule(BaseVarsPlugin):
	"""Load host_vars/ and group_vars/ files written in Dhall, like host_group_vars."""

	def get_vars(self, loader, path, entities, cache=None):
		"""Return vars from ``host_vars/<name>[.dhall]`` or ``group_vars/<name>[.dhall]``.

		:param loader: Ansible DataLoader, used to locate candidate files.
		:param path: inventory source or play basedir; the base class derives
			``self._basedir`` from it.
		:param entities: a Host/Group or a list of them.
		:param cache: ``None`` (default) or any truthy value reuses file lists
			memoised in ``FOUND``; an explicit falsy value (e.g. ``False``)
			forces a fresh lookup. Fresh lookups are always stored.
		:returns: dict of vars, merged with ``combine_vars`` across all files found.
		:raises AnsibleParserError: for an entity that is neither Host nor Group,
			or if locating/reading a file fails.
		"""
		# Callers may pass a single entity rather than a list.
		if not isinstance(entities, list):
			entities = [entities]

		# call to super. `self._basedir = basedir(path)`
		super(VarsModule, self).get_vars(loader, path, entities)

		# Treat the default None as "use the cache"; otherwise the memo is never hit.
		use_cache = cache is None or bool(cache)

		data = {}

		for entity in entities:
			# we set the subdir here
			if isinstance(entity, Host):
				subdir = "host_vars"
			elif isinstance(entity, Group):
				subdir = "group_vars"
			else:
				raise AnsibleParserError(
					"Supplied entity must be Host or Group, got %s instead" % (type(entity))
				)

			# avoid 'chroot' type inventory hostnames /path/to/chroot
			if entity.name.startswith(os.path.sep):
				continue

			try:
				# mostly copied from host_group_vars
				b_opath = os.path.realpath(to_bytes(os.path.join(self._basedir, subdir)))
				opath = to_text(b_opath)
				self._display.v("\tprocessing dir %s" % opath)

				# We set the cache key to be specific to both entity and file
				key = "%s.%s" % (entity.name, opath)
				if use_cache and key in FOUND:
					# cache hit
					found_files = FOUND[key]
				else:
					found_files = []
					if os.path.exists(b_opath):
						if os.path.isdir(b_opath):
							self._display.debug("\tprocessing dir %s" % opath)
							# use the file loader to locate the files
							# NOTE(review): DataLoader.find_vars_files appears to default
							# allow_dir to True, so a <name>/ directory is searched too -- confirm.
							found_files = loader.find_vars_files(
								path=opath,
								name=entity.name,
								extensions=["", ".dhall"],
							)
						else:
							self._display.warning(
								"Found %s that is not a directory, skipping: %s" % (subdir, opath)
							)
					# Store the result so later invocations can skip the filesystem scan.
					FOUND[key] = found_files

				for found in found_files:
					new_data = _read_dhall_file(found)
					if new_data:  # ignore empty files
						data = combine_vars(data, new_data)

			except Exception as e:
				# Chain so the underlying parse/IO traceback is not lost.
				raise AnsibleParserError(to_native(e)) from e

		return data


def _read_dhall_file(filename):
	"""Parse the Dhall file at *filename* and return the value ``dhall.loads`` produces.

	The caller ignores a falsy result and otherwise merges it with
	``combine_vars``, so a Dhall record is expected here.

	:param filename: path to the file to read.
	"""
	# Dhall source text is UTF-8; don't depend on the locale's default encoding.
	with open(filename, "r", encoding="utf-8") as f:
		source = f.read()
	# Parse outside the `with` so the file is closed before evaluation.
	return dhall.loads(source)

Execution

Vars plugins are invoked at 2 stages:

  1. Inventory: upon initial inventory parsing
     - once for every group (including "all" and "ungrouped")
     - once for every host
  2. Task: on every task
     - once per element of the Cartesian product of:
       - entity: each host or group involved in the play
       - path:
         - each inventory source path
         - the basedir for the play (or the nested play)

Which of these stages a plugin runs at can be configured with the `stage` option declared in its documentation.

This is why you want your plugin to implement a load phase and some form of caching!