class Session: 一段落
セッションをあれやこれやしてくれるクラス、Sessionがとりあえず一段落しました。エラー処理とか、あとドキュメンテーション(docstring)とか、そのへん放ったらかしだけど、それはまた追い追い。
ドキュメンテーション文字列に関しては、PEP: 257を参考にしたらいいのかな。まあ、ある程度落ち着いたら、追い追い書いていきます。
Sessionのやることは、SQLiteを使ってのセッション管理です。セッションIDを受け取って、それが有効なら同じIDを、無効(存在しなかったり期限切れだったり)なら新規にIDを作成して返してくれます。
Sessionのインスタンスを作る際には、SQLiteの作成するファイルのパスを渡してやる必要があります。あとは、まあ適当に。セッションIDとセッション有効期間、リモートアドレス(IP)のマッチを行なうかどうか、それを指定してやります。
get_id()メソッドを使うと、セッションIDが返されます。そのIDを、クライアントから受けとったIDと比べることで、セッションが有効かどうかを判定します。
とりあえず、こんな感じです。以下にソースをのせておきます。
# -*- coding: utf-8 -*- ''' Class for provide session with sqlite3 ''' import os import datetime import sys import random import hashlib import sqlite3 import pickle import base64 class Session: '''Class to provide session.''' def __init__(self, dbpath, sid=None, validity=u'3 hours', ipmatch=False): self.sid = sid self.dbpath = dbpath self.dbtablename = u'sessiontable' self.validity = validity self.ipmatch = ipmatch self.data = None connection = self._open_db() cursor = connection.cursor() cursor.execute('select * from sqlite_master \ where type=\'table\' and name=?;', \ (self.dbtablename, )) tablecount = cursor.fetchall() if len(tablecount) == 0: cursor.execute('create table %s (id primary key, data, \ created_time, accessed_time, expire_time, remote_addr);' \ % self.dbtablename) cursor.execute('delete from %s where expire_time<datetime(\'now\');' \ % self.dbtablename) if isinstance(self.sid, basestring): cursor.execute('select id from %s where id=\'%s\';' \ % (self.dbtablename, self.sid)) idcount = cursor.fetchall() if len(idcount) == 0: self._create_session_id() self._insert_session_record(cursor) else: if self.ipmatch: current_addr = os.environ.get('REMOTE_ADDR', u'') past_addr = self.get_remote_addr() if current_addr == past_addr: self._update_session_record(cursor) else: self._create_session_id() self._insert_session_record(cursor) else: self._update_session_record(cursor) else: self._create_session_id() self._insert_session_record(cursor) cursor.close() connection.commit() connection.close() def get_id(self): return self.sid def get_created_time(self): connection = self._open_db() cursor = connection.cursor() cursor.execute('select created_time from %s where id=\'%s\';' \ % (self.dbtablename, self.sid)) created_time = cursor.fetchone() cursor.close() connection.close() return created_time[0] def get_accessed_time(self): connection = self._open_db() cursor = connection.cursor() cursor.execute('select accessed_time from %s where id=\'%s\';' \ % (self.dbtablename, self.sid)) accessed_time = cursor.fetchone() cursor.close() connection.close() return accessed_time[0] def get_expire_time(self): connection = self._open_db() cursor = connection.cursor() cursor.execute('select expire_time from %s where id=\'%s\';' \ % (self.dbtablename, self.sid)) expire_time = cursor.fetchone() cursor.close() connection.close() return expire_time[0] def get_remote_addr(self): connection = self._open_db() cursor = connection.cursor() cursor.execute('select remote_addr from %s where id=\'%s\';' \ % (self.dbtablename, self.sid)) remote_addr = cursor.fetchone() cursor.close() connection.close() return remote_addr[0] def get_data(self): if self.data == None: connection = self._open_db() cursor = connection.cursor() cursor.execute('select data from %s where id=\'%s\';' \ % (self.dbtablename, self.sid)) data = cursor.fetchone() data = data[0] if data != None: self.data = pickle.loads(base64.decodestring(data)) return self.data def set_data(self, data): self.data = data def reset_data(self): self.data = None def save_data(self): data = self.data if data != None: data = base64.encodestring(pickle.dumps(data)) connection = self._open_db() cursor = connection.cursor() cursor.execute('update %s set data=\'%s\' where id=\'%s\';' \ % (self.dbtablename, data, self.sid)) cursor.close() connection.commit() connection.close() def delete(self): connection = self._open_db() cursor = connection.cursor() cursor.execute('delete from %s where id=\'%s\';' \ % (self.dbtablename, self.sid)) cursor.close() connection.commit() connection.close() # internal methods def _open_db(self): return sqlite3.connect(self.dbpath) def _create_session_id(self): connection = self._open_db() cursor = connection.cursor() while True: now = datetime.datetime.today() seed = str(os.getpid()) + \ str(now.isoformat()) + \ str(random.randint(0, sys.maxint - 1)) message = hashlib.new('sha256') message.update(seed) sid = message.hexdigest() cursor.execute('select id from %s where id=\'%s\';' \ % (self.dbtablename, sid)) idcount = cursor.fetchall() if len(idcount) == 0: self.sid = sid break cursor.close() connection.close() def _insert_session_record(self, cursor): cursor.execute('insert into %s (id, created_time, accessed_time, \ expire_time, remote_addr) values(\'%s\', datetime(\'now\'), \ datetime(\'now\'), datetime(\'now\', \'%s\'), \'%s\');' \ % (self.dbtablename, self.sid, self.validity, \ os.environ.get('REMOTE_HOST', u''))) def _update_session_record(self, cursor): cursor.execute('update %s set accessed_time=datetime(\'now\'), \ expire_time=datetime(\'now\', \'%s\'), remote_addr=\'%s\' \ where id=\'%s\';' \ % (self.dbtablename, self.validity, \ os.environ.get('REMOTE_HOST', u''), self.sid))