helpers.py 7.14 KB
Newer Older
BERJON Matthieu's avatar
BERJON Matthieu committed
1 2
import base64
import hashlib
3
import itertools
4
import logging
BERJON Matthieu's avatar
BERJON Matthieu committed
5
import os
6
import re
BERJON Matthieu's avatar
BERJON Matthieu committed
7

8
import redis
9
import IPy
10
from django.conf import settings
11
import django.db
12
from django.db.models import Q
13

14
import config
15 16 17 18 19 20
from .models import (
        AllgoUser,
        Job,
        Webapp,
        WebappVersion,
        )
21

BERJON Matthieu's avatar
BERJON Matthieu committed
22

23
# logger used by these helpers (the 'allgo' logger)
log = logging.getLogger('allgo')
BERJON Matthieu's avatar
BERJON Matthieu committed
24 25
# NOTE(review): not referenced in the visible part of this file; presumably
#               the default size of randomly generated values — confirm with
#               the callers.
DEFAULT_ENTROPY = 32 # number of bytes to return by default

26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50

##################################################
# redis keys, pubsub channels and pubsub messages
#
# Templates containing "%d" must be formatted with an integer id.


# per-job keys (job log and job state), presumably formatted with the job id
REDIS_KEY_JOB_LOG = "log:job:%d"
REDIS_KEY_JOB_STATE = "state:job:%d"

# pubsub channels used to wake up allgo.aio (the frontend) and the
# controller
REDIS_CHANNEL_AIO = "notify:aio"
REDIS_CHANNEL_CONTROLLER = "notify:controller"

# pubsub messages announcing that a job / a webapp was updated
REDIS_MESSAGE_JOB_UPDATED = "job:%d"
REDIS_MESSAGE_WEBAPP_UPDATED = "webapp:%d"

##################################################


# module-wide redis connection pool (stays None until first created)
_redis_connection_pool = None


BERJON Matthieu's avatar
BERJON Matthieu committed
51 52 53 54 55 56
def get_ssh_data(key):
    """
    Return the fingerprint and comment of a given SSH public key.

    The key is expected as a single line "<type> <base64-blob> [comment]";
    the comment may itself contain spaces.

    It has been tested only on RSA keys

    Args:
        key (str): the public key line

    Returns:
        (fingerprint, comment): fingerprint is the MD5 digest of the decoded
        key blob, written as colon-separated hex pairs ("5d:41:...:92");
        comment is None when the key has no comment.

    Raises:
        IndexError: if the key has no base64 part
        ValueError: (binascii.Error / UnicodeEncodeError) if the blob is not
                    valid ASCII base64
    """
    # FIXME: this implementation computes an MD5 hash, which was superseded a
    #        long time ago. Current OpenSSH fingerprints are based on
    #        SHA256; the output looks like:
    #           2048 SHA256:sjsPbfDzfuylskauytlylfpaltjufjhqnphYvVYnhbI
    #        We should use this format too
    key_parts = key.strip().split(None, 2)
    comment = key_parts[2] if len(key_parts) == 3 else None

    # key_parts[1] is the base64-encoded key blob (IndexError if absent)
    blob = base64.b64decode(key_parts[1].encode('ascii'))
    fp_plain = hashlib.md5(blob).hexdigest()
    fp_encoded = ':'.join(fp_plain[i:i + 2] for i in range(0, len(fp_plain), 2))

    return fp_encoded, comment
73

74
def upload_data(uploaded_files, job):
    """
    Upload any data according to a specific job id

    Every uploaded file is written into the job data directory
    (``job.data_dir``), which this function creates. Filenames are sanitised
    so that nothing can be written outside of that directory.

    Args:
        uploaded_files:   iterable that yields
                          django.core.files.uploadedfile.UploadedFile objects
        job (Job):        job

    Examples:

        >>> upload_data(self.request.FILES.getlist('files'), job)

    Returns:
        Nothing

    Raises:
        FileExistsError: if the job data directory already exists
    """
    job_dir = job.data_dir
    # os.makedirs() without exist_ok=True fails if job_dir already exists
    os.makedirs(job_dir)

    for file_data in uploaded_files:
        filename = file_data.name

        # sanitise the filename to prevent directory escape
        #
        # The filename is provided by the user submitting the job, it cannot be
        # trusted. Dangerous characters are replaced with "_" so as to
        # guarantee that we do not write anything outside the job dir.
        #
        # This is a security feature, do not remove it.
        #
        if filename in ("", ".", ".."):
            # these names would resolve to the job dir itself or its parent;
            # the empty name has no dot to replace, so it becomes "_"
            filename = filename.replace(".", "_") or "_"
        # "/" would allow writing into another directory; a NUL byte makes
        # open() raise ValueError, which would abort the whole upload
        filename = filename.replace("/", "_").replace("\0", "_")

        filepath = os.path.join(job_dir, filename)
        with open(filepath, 'wb+') as destination:
            for chunk in file_data.chunks():
                destination.write(chunk)
113

114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
def lookup_job_file(job_id, filename):
    """Return the real path of a job data file, or None if it is not usable

    The lookup is restricted so that callers cannot escape from the job data
    directory (settings.DATASTORE/<job_id>). None is returned when:
    - the filename contains a "/" (subdirectories, other job directories)
    - the path is not a regular file
    - the path is a symbolic link
    """
    candidate = os.path.join(settings.DATASTORE, str(job_id), filename)

    if "/" in filename:
        return None
    if not os.path.isfile(candidate):
        return None
    if os.path.islink(candidate):
        return None
    return candidate
132 133 134 135 136 137 138 139 140 141 142 143 144 145

def get_redis_connection():
    """Return a redis client that draws its connections from the module pool

    The pool is created on the first call, targeting the host configured in
    config.env.ALLGO_REDIS_HOST, and reused by every later call.
    """
    global _redis_connection_pool

    pool = _redis_connection_pool
    if pool is None:
        pool = _redis_connection_pool = redis.ConnectionPool(
                host=config.env.ALLGO_REDIS_HOST)
    return redis.Redis(connection_pool=pool)


def notify_controller(obj):
    """Tell the controller that a database entry was updated

    A message "job:<id>" or "webapp:<id>" is published on the redis pubsub
    channel REDIS_CHANNEL_CONTROLLER.

    Args:
        obj (Job|Webapp): the updated entry

    Raises:
        TypeError: if obj is neither a Job nor a Webapp
    """
    conn = get_redis_connection()

    if isinstance(obj, Job):
        template = REDIS_MESSAGE_JOB_UPDATED
    elif isinstance(obj, Webapp):
        template = REDIS_MESSAGE_WEBAPP_UPDATED
    else:
        raise TypeError(obj)

    conn.publish(REDIS_CHANNEL_CONTROLLER, template % obj.id)

159 160

# networks from which admin actions are allowed, parsed once at import time
# from the comma-separated ALLGO_ALLOWED_IP_ADMIN setting; whitespace around
# entries and empty entries (e.g. from a trailing comma) are ignored
_ALLOWED_IP_NETWORKS = [
    IPy.IP(entry.strip())
    for entry in config.env.ALLGO_ALLOWED_IP_ADMIN.split(",")
    if entry.strip()
]


def is_allowed_ip_admin(ip_address):
    """Return true if admin actions are allowed from this IP address

    The function returns true if the provided ip address is included in at
    least one network listed in ALLGO_ALLOWED_IP_ADMIN (and false when that
    setting lists no network at all).

    Args:
        ip_address: the address to check; presumably a str or IPy.IP, as
                    accepted by the ``in`` operator of IPy.IP — confirm
    """
    return any(ip_address in net for net in _ALLOWED_IP_NETWORKS)
168 169 170



171 172 173 174 175 176 177 178 179 180 181 182 183 184
def get_base_url(request):
    """Extract the base url from this django request object

        typically this will be "https://allgo.inria.fr"

    The scheme and host are taken from the X-Forwarded-Proto and
    X-Forwarded-Host headers when these are present and non-empty, otherwise
    from the request itself. When a header holds a comma-separated list of
    values (one per proxy in the chain), only the first (leftmost) value is
    used.

    NOTE(review): the X-Forwarded-* headers are trusted as-is, which assumes
    they are always set by a trusted reverse proxy.
    """
    meta = request.META

    scheme = meta.get("HTTP_X_FORWARDED_PROTO", "").split(",")[0].strip()
    if not scheme:
        scheme = request.scheme

    # NOTE: django's request.get_host()/.get_port() are kind of broken because
    # they do not expect the port to be provided in the Host/X-Forwarded-Host
    # headers (which is quite common)
    host = meta.get("HTTP_X_FORWARDED_HOST", "").split(",")[0].strip()
    if not host:
        host = request.get_host()

    return "%s://%s" % (scheme, host)

185

186 187 188
def get_request_user(request):
    """Return the user authenticated by this request, or None

    The authentication method is selected from the request path:
    - /api/ urls: the token given in the HTTP Authorization header
      ("Token token=<value>")
    - any other url: the session cookie

    For /auth requests, the path taken into account is the one found in the
    'X-Original-URI' header, which is assumed to be the path of the request
    being authorized.

    Args:
        request

    Returns:
        a User or None
    """
    requested = request.path
    if requested == "/auth":
        requested = request.META['HTTP_X_ORIGINAL_URI']

    if not requested.startswith("/api/"):
        # authenticated by cookie for non-API requests
        return request.user if request.user.is_authenticated else None

    # authenticated by token for API requests
    #
    # NOTE: we must NOT authenticate by cookie because the CORS
    #       configuration in the nginx.conf allows all origins
    authorization = request.META.get('HTTP_AUTHORIZATION', '')
    match = re.match(r"Token token=(\S+)", authorization)
    if match is None:
        return None

    # FIXME: user token should have a unicity constraint
    allgo_user = AllgoUser.objects.filter(token=match.group(1)).first()
    return getattr(allgo_user, "user", None)
221 222


223 224 225 226 227
def query_webapps_for_user(user):
    """Return a queryset of the webapps that ``user`` is allowed to see

    Superusers see every webapp. Any other user sees the webapps they own
    plus the public ones: webapps that are not private and that have at
    least one version which is both published and in the READY state.
    """
    if user.is_superuser:
        return Webapp.objects.all()

    # collect the ids of the public webapps (one row per webapp thanks to the
    # GROUP BY, each row holding a single webapp_id column)
    with django.db.connection.cursor() as cur:
        cur.execute("""SELECT webapp_id FROM dj_webapp_versions WHERE webapp_id IN (
            SELECT id FROM dj_webapps WHERE private != 1
        ) AND published=1 AND state=%s GROUP BY webapp_id""", (WebappVersion.READY,))
        public_ids = [webapp_id for (webapp_id,) in cur.fetchall()]

    owned = Q(user_id=user.id)
    public = Q(id__in=public_ids)
    return Webapp.objects.filter(owned | public)
238