class Session: 一段落

セッションをあれやこれやしてくれるクラス、Sessionがとりあえず一段落しました。エラー処理とか、あとドキュメンテーション(docstring)とか、そのへん放ったらかしだけど、それはまた追い追い。

ドキュメンテーション文字列に関しては、PEP: 257を参考にしたらいいのかな。まあ、ある程度落ち着いたら、追い追い書いていきます。

Sessionのやることは、SQLiteを使ってのセッション管理です。セッションIDを受け取って、それが有効なら同じIDを、無効(存在しなかったり期限切れだったり)なら新規にIDを作成して返してくれます。

Sessionのインスタンスを作る際には、SQLiteの作成するファイルのパスを渡してやる必要があります。あとは、まあ適当に。セッションIDとセッション有効期間、リモートアドレス(IP)のマッチを行なうかどうか、それを指定してやります。

get_id()メソッドを使うと、セッションIDが返されます。そのIDを、クライアントから受けとったIDと比べることで、セッションが有効かどうかを判定します。

とりあえず、こんな感じです。以下にソースをのせておきます。

# -*- coding: utf-8 -*-
'''
Class for provide session with sqlite3
'''

import os
import datetime
import sys
import random
import hashlib
import sqlite3
import pickle
import base64

class Session:
    '''Class to provide session.'''

    def __init__(self, dbpath, sid=None, validity=u'3 hours', ipmatch=False):
        self.sid = sid
        self.dbpath = dbpath
        self.dbtablename = u'sessiontable'
        self.validity = validity
        self.ipmatch = ipmatch
        self.data = None
        
        connection = self._open_db()
        cursor = connection.cursor()
        cursor.execute('select * from sqlite_master \
        where type=\'table\' and name=?;', \
                       (self.dbtablename, ))
        tablecount = cursor.fetchall()
        if len(tablecount) == 0:
            cursor.execute('create table %s (id primary key, data, \
            created_time, accessed_time, expire_time, remote_addr);' \
                           % self.dbtablename)

        cursor.execute('delete from %s where expire_time<datetime(\'now\');' \
                       % self.dbtablename)

        if isinstance(self.sid, basestring):
            cursor.execute('select id from %s where id=\'%s\';' \
                           % (self.dbtablename, self.sid))
            idcount = cursor.fetchall()
            if len(idcount) == 0:
                self._create_session_id()
                self._insert_session_record(cursor)
            else:
                if self.ipmatch:
                    current_addr = os.environ.get('REMOTE_ADDR', u'')
                    past_addr = self.get_remote_addr()
                    if current_addr == past_addr:
                        self._update_session_record(cursor)
                    else:
                        self._create_session_id()
                        self._insert_session_record(cursor)
                else:
                    self._update_session_record(cursor)
        else:
            self._create_session_id()
            self._insert_session_record(cursor)

        cursor.close()
        connection.commit()
        connection.close()


    def get_id(self):
        return self.sid

    def get_created_time(self):
        connection = self._open_db()
        cursor = connection.cursor()
        cursor.execute('select created_time from %s where id=\'%s\';' \
                       % (self.dbtablename, self.sid))
        created_time = cursor.fetchone()
        cursor.close()
        connection.close()
        return created_time[0]

    def get_accessed_time(self):
        connection = self._open_db()
        cursor = connection.cursor()
        cursor.execute('select accessed_time from %s where id=\'%s\';' \
                       % (self.dbtablename, self.sid))
        accessed_time = cursor.fetchone()
        cursor.close()
        connection.close()
        return accessed_time[0]
        
    def get_expire_time(self):
        connection = self._open_db()
        cursor = connection.cursor()
        cursor.execute('select expire_time from %s where id=\'%s\';' \
                       % (self.dbtablename, self.sid))
        expire_time = cursor.fetchone()
        cursor.close()
        connection.close()
        return expire_time[0]

    def get_remote_addr(self):
        connection = self._open_db()
        cursor = connection.cursor()
        cursor.execute('select remote_addr from %s where id=\'%s\';' \
                       % (self.dbtablename, self.sid))
        remote_addr = cursor.fetchone()
        cursor.close()
        connection.close()
        return remote_addr[0]

    def get_data(self):
        if self.data == None:
            connection = self._open_db()
            cursor = connection.cursor()
            cursor.execute('select data from %s where id=\'%s\';' \
                           % (self.dbtablename, self.sid))
            data = cursor.fetchone()
            data = data[0]
            if data != None:
                self.data = pickle.loads(base64.decodestring(data))
        return self.data


    def set_data(self, data):
        self.data = data

    def reset_data(self):
        self.data = None

    def save_data(self):
        data = self.data
        if data != None:
            data = base64.encodestring(pickle.dumps(data))
            connection = self._open_db()
            cursor = connection.cursor()
            cursor.execute('update %s set data=\'%s\' where id=\'%s\';' \
                           % (self.dbtablename, data, self.sid))
            cursor.close()
            connection.commit()
            connection.close()

    def delete(self):
        connection = self._open_db()
        cursor = connection.cursor()
        cursor.execute('delete from %s where id=\'%s\';' \
                       % (self.dbtablename, self.sid))
        cursor.close()
        connection.commit()
        connection.close()


    # internal methods

    def _open_db(self):
        return sqlite3.connect(self.dbpath)

    def _create_session_id(self):
        connection = self._open_db()
        cursor = connection.cursor()
        while True:
            now = datetime.datetime.today()
            seed = str(os.getpid()) + \
                   str(now.isoformat()) + \
                   str(random.randint(0, sys.maxint - 1))
            message = hashlib.new('sha256')
            message.update(seed)
            sid = message.hexdigest()
            cursor.execute('select id from %s where id=\'%s\';' \
                           % (self.dbtablename, sid))
            idcount = cursor.fetchall()
            if len(idcount) == 0:
                self.sid = sid
                break
        cursor.close()
        connection.close()

    def _insert_session_record(self, cursor):
        cursor.execute('insert into %s (id, created_time, accessed_time, \
        expire_time, remote_addr) values(\'%s\', datetime(\'now\'), \
        datetime(\'now\'), datetime(\'now\', \'%s\'), \'%s\');' \
                       % (self.dbtablename, self.sid, self.validity, \
                          os.environ.get('REMOTE_HOST', u'')))

    def _update_session_record(self, cursor):
        cursor.execute('update %s set accessed_time=datetime(\'now\'), \
        expire_time=datetime(\'now\', \'%s\'), remote_addr=\'%s\' \
        where id=\'%s\';' \
                       % (self.dbtablename, self.validity, \
                          os.environ.get('REMOTE_HOST', u''), self.sid))