セッションを管理するクラス

昨日はHTMLの生成に役立つクラスを作っているといっていましたが、今日は少しHTMLから離れて、セッションを管理するクラスを作っていました。Perlのモジュール、CGI::Session日本語訳)を参考にしつつ、またmysql & python でセッションを管理も参照しながら、とりあえず最低限動くものができました。

まだできてないところも多い、本当に基本的なものでしかないのだけど、セッションIDを発行するくらいなら可能です。データの管理にはSQLiteを使っていて、これ、一歩間違えると、データベースを丸ごと持ってかれたりしそうで怖いんだけど(公開ディレクトリにファイルを置かないとか、用心しないといけませんね)、まあそのへんは割り切るということで。

セッションIDの生成は、プロセスIDと日時とランダムな数値を連結したものをMD5ハッシュ値にすることで得ています。これ、まんまCGI::Sessionを参考にしています。

あとは、まあ、読んでみてください。

# -*- coding: utf-8 -*-
'''
Class for provide session with sqlite3
'''

import os
import datetime
import sys
import random
import md5
import sqlite3

class Session:
    '''Class to provide session.'''

    def __init__(self, dbpath, sid=None, validity=u'3 hours'):
        self.sid = sid
        self.dbpath = dbpath
        self.dbtablename = u'sessiontable'
        self.validity = validity
        
        connection = self._open_db()
        cursor = connection.cursor()
        cursor.execute('select * from sqlite_master \
        where type=\'table\' and name=?;', \
                       (self.dbtablename, ))
        tablecount = cursor.fetchall()
        if len(tablecount) == 0:
            cursor.execute('create table %s (id, data, created_time, \
            accessed_time, expire_time, remote_addr);' % self.dbtablename)

        cursor.execute('delete from %s where expire_time<datetime(\'now\');' \
                       % self.dbtablename)

        if isinstance(self.sid, basestring):
            cursor.execute('select id from %s where id=\'%s\';' \
                           % (self.dbtablename, self.sid))
            idcount = cursor.fetchall()
            if len(idcount) == 0:
                self._create_session_id()
                self._insert_session_record(cursor)
            else:
                self._update_session_record(cursor)
        else:
            self._create_session_id()
            self._insert_session_record(cursor)

        cursor.close()
        connection.commit()
        connection.close()



    def get_id(self):
        return self.sid


    # internal methods

    def _open_db(self):
        return sqlite3.connect(self.dbpath)

    def _create_session_id(self):
        connection = self._open_db()
        cursor = connection.cursor()
        while True:
            now = datetime.datetime.today()
            seed = str(os.getpid()) + \
                   str(now.isoformat()) + \
                   str(random.randint(0, sys.maxint - 1))
            sid = md5.new(seed).hexdigest()
            cursor.execute('select id from %s where id=\'%s\';' \
                           % (self.dbtablename, sid))
            idcount = cursor.fetchall()
            if len(idcount) == 0:
                self.sid = sid
                break
        cursor.close()
        connection.close()

    def _insert_session_record(self, cursor):
        cursor.execute('insert into %s (id, created_time, accessed_time, \
        expire_time, remote_addr) values(\'%s\', datetime(\'now\'), \
        datetime(\'now\'), datetime(\'now\', \'%s\'), \'%s\');' \
                       % (self.dbtablename, self.sid, self.validity, \
                          os.environ['REMOTE_HOST']))

    def _update_session_record(self, cursor):
        cursor.execute('update %s set accessed_time=datetime(\'now\'), \
        expire_time=datetime(\'now\', \'%s\'), remote_addr=\'%s\' \
        where id=\'%s\';' \
                       % (self.dbtablename, self.validity, \
                          os.environ['REMOTE_HOST'], self.sid))