support ad-hoc models

This commit is contained in:
Itai Shirav 2016-06-26 15:11:16 +03:00
parent 9262f0eae6
commit 92ea9d413e
3 changed files with 51 additions and 8 deletions

View File

@ -1,4 +1,6 @@
import requests import requests
from models import ModelBase
from utils import escape, parse_tsv
class DatabaseException(Exception): class DatabaseException(Exception):
@ -47,10 +49,14 @@ class Database(object):
return int(r.text) if r.text else 0 return int(r.text) if r.text else 0
def select(self, query, model_class=None, settings=None): def select(self, query, model_class=None, settings=None):
query += ' FORMAT TabSeparated' query += ' FORMAT TabSeparatedWithNamesAndTypes'
r = self._send(query, settings) r = self._send(query, settings)
for line in r.iter_lines(): lines = r.iter_lines()
yield model_class.from_tsv(line) field_names = parse_tsv(next(lines))
field_types = parse_tsv(next(lines))
model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types))
for line in lines:
yield model_class.from_tsv(line, field_names)
def _send(self, data, settings=None): def _send(self, data, settings=None):
params = self._build_params(settings) params = self._build_params(settings)

View File

@ -1,6 +1,6 @@
from fields import *
from utils import escape, parse_tsv from utils import escape, parse_tsv
from engines import * from engines import *
from fields import Field
class ModelBase(type): class ModelBase(type):
@ -16,6 +16,18 @@ class ModelBase(type):
setattr(new_cls, '_fields', fields) setattr(new_cls, '_fields', fields)
return new_cls return new_cls
@classmethod
def create_ad_hoc_model(cls, fields):
# fields is a list of tuples (name, db_type)
import fields as orm_fields
attrs = {}
for name, db_type in fields:
field_class = db_type + 'Field'
if not hasattr(orm_fields, field_class):
raise NotImplementedError('No field class for %s' % db_type)
attrs[name] = getattr(orm_fields, field_class)()
return cls.__new__(cls, 'AdHocModel', (Model,), attrs)
class Model(object): class Model(object):
''' '''
@ -77,16 +89,18 @@ class Model(object):
return 'DROP TABLE IF EXISTS %s.%s' % (db_name, cls.table_name()) return 'DROP TABLE IF EXISTS %s.%s' % (db_name, cls.table_name())
@classmethod @classmethod
def from_tsv(cls, line): def from_tsv(cls, line, field_names=None):
''' '''
Create a model instance from a tab-separated line. The line may or may not include a newline. Create a model instance from a tab-separated line. The line may or may not include a newline.
The field_names list must match the fields defined in the model, but does not have to include all of them.
If omitted, it is assumed to be the names of all fields in the model, in order of definition.
''' '''
field_names = field_names or [name for name, field in cls._fields]
values = iter(parse_tsv(line)) values = iter(parse_tsv(line))
kwargs = {} kwargs = {}
for name, field in cls._fields: for name in field_names:
kwargs[name] = values.next() kwargs[name] = values.next()
return cls(**kwargs) return cls(**kwargs)
# TODO verify that the number of values matches the number of fields
def to_tsv(self): def to_tsv(self):
''' '''

View File

@ -37,7 +37,7 @@ class DatabaseTestCase(unittest.TestCase):
self.assertEquals(self.database.count(Person), 100) self.assertEquals(self.database.count(Person), 100)
self.assertEquals(self.database.count(Person, "first_name = 'Courtney'"), 2) self.assertEquals(self.database.count(Person, "first_name = 'Courtney'"), 2)
self.assertEquals(self.database.count(Person, "birthday > '2000-01-01'"), 22) self.assertEquals(self.database.count(Person, "birthday > '2000-01-01'"), 22)
self.assertEquals(self.database.count(Person, "birthday < '1900-01-01'"), 0) self.assertEquals(self.database.count(Person, "birthday < '1970-03-01'"), 0)
def test_select(self): def test_select(self):
self._insert_and_check(self._sample_data(), len(data)) self._insert_and_check(self._sample_data(), len(data))
@ -45,7 +45,30 @@ class DatabaseTestCase(unittest.TestCase):
results = list(self.database.select(query, Person)) results = list(self.database.select(query, Person))
self.assertEquals(len(results), 2) self.assertEquals(len(results), 2)
self.assertEquals(results[0].last_name, 'Durham') self.assertEquals(results[0].last_name, 'Durham')
self.assertEquals(results[0].height, 1.72)
self.assertEquals(results[1].last_name, 'Scott') self.assertEquals(results[1].last_name, 'Scott')
self.assertEquals(results[1].height, 1.70)
def test_select_partial_fields(self):
self._insert_and_check(self._sample_data(), len(data))
query = "SELECT first_name, last_name FROM test_db.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = list(self.database.select(query, Person))
self.assertEquals(len(results), 2)
self.assertEquals(results[0].last_name, 'Durham')
self.assertEquals(results[0].height, 0) # default value
self.assertEquals(results[1].last_name, 'Scott')
self.assertEquals(results[1].height, 0) # default value
def test_select_ad_hoc_model(self):
self._insert_and_check(self._sample_data(), len(data))
query = "SELECT * FROM test_db.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = list(self.database.select(query))
self.assertEquals(len(results), 2)
self.assertEquals(results[0].__class__.__name__, 'AdHocModel')
self.assertEquals(results[0].last_name, 'Durham')
self.assertEquals(results[0].height, 1.72)
self.assertEquals(results[1].last_name, 'Scott')
self.assertEquals(results[1].height, 1.70)
def _sample_data(self): def _sample_data(self):
for entry in data: for entry in data: