2021-07-30 15:52:01 +02:00
package sqldb
import (
"database/sql"
2021-12-03 11:19:53 +01:00
"encoding/json"
2021-07-30 15:52:01 +02:00
"fmt"
2021-12-03 11:19:53 +01:00
"html/template"
"io/ioutil"
2021-07-30 15:52:01 +02:00
"log"
2021-12-03 11:19:53 +01:00
"os"
2021-07-30 15:52:01 +02:00
"strconv"
"strings"
"github.com/lib/pq"
)
2021-10-28 10:34:10 +02:00
// Db holds the driver name and connection URL given to Open, together with
// the *sql.DB handle that Open obtained from sql.Open. The handle is nil if
// sql.Open failed.
type Db struct {
	Driver string
	Url    string
	conn   *sql.DB
}
2021-07-30 15:52:01 +02:00
// AssRow is an associative row: a map from column name to value.
// QueryAssociativeArray fills it with every column value converted to a string.
type AssRow map[string]interface{}
2021-10-28 10:34:10 +02:00
// Rows is the result set of a query: one AssRow per returned row.
type Rows []AssRow
2021-07-30 15:52:01 +02:00
// TableInfo describes a database table: its name and a map from column name
// to column type. A type may carry a column comment after a '|' separator
// (e.g. "varchar(50)|user name"); GetSchema produces that form and
// CreateTable consumes it. db is the connection used by the table's methods.
// It is unexported, so it is not part of the JSON encoding.
type TableInfo struct {
	Name    string            `json:"name"`
	Columns map[string]string `json:"columns"`

	db *Db
}
2021-12-03 11:19:53 +01:00
// Link is a reference between two tables, as inferred by buildLinks from
// columns whose name ends in "_id". Source is the table holding the column,
// and Destination is the name segment just before "_id".
type Link struct {
	Source      string
	Destination string
}
2021-07-30 15:52:01 +02:00
// Open the database
2021-10-28 10:34:10 +02:00
func Open ( driver string , url string ) * Db {
var database Db
2021-07-30 15:52:01 +02:00
var err error
2021-10-28 10:34:10 +02:00
database . Driver = driver
database . Url = url
database . conn , err = sql . Open ( driver , url )
2021-07-30 15:52:01 +02:00
if err != nil {
log . Println ( err )
}
2021-10-28 10:34:10 +02:00
return & database
2021-07-30 15:52:01 +02:00
}
// Close the database connection
2021-10-28 10:34:10 +02:00
func ( db * Db ) Close ( ) {
db . conn . Close ( )
}
func ( db * Db ) Table ( name string ) * TableInfo {
var ti TableInfo
ti . Name = name
ti . db = db
return & ti
2021-07-30 15:52:01 +02:00
}
// GetAssociativeArray : Provide results as an associative array
2021-10-28 10:34:10 +02:00
func ( t * TableInfo ) GetAssociativeArray ( columns [ ] string , restriction string , sortkeys [ ] string , dir string ) ( [ ] AssRow , error ) {
return t . db . QueryAssociativeArray ( t . buildSelect ( "" , columns , restriction , sortkeys , dir ) )
2021-07-30 15:52:01 +02:00
}
// QueryAssociativeArray : Provide results as an associative array
2021-10-28 10:34:10 +02:00
func ( db * Db ) QueryAssociativeArray ( query string ) ( Rows , error ) {
rows , err := db . conn . Query ( query )
2021-07-30 15:52:01 +02:00
if err != nil {
log . Println ( err )
log . Println ( query )
2021-10-28 10:34:10 +02:00
return nil , err
2021-07-30 15:52:01 +02:00
}
defer rows . Close ( )
2023-01-23 11:29:46 +01:00
results := Rows { }
2021-10-28 10:34:10 +02:00
cols , err := rows . Columns ( )
if err != nil {
log . Println ( err )
log . Println ( query )
return nil , err
}
2021-07-30 15:52:01 +02:00
for rows . Next ( ) {
// Create a slice of interface{}'s to represent each column,
// and a second slice to contain pointers to each item in the columns slice.
columns := make ( [ ] interface { } , len ( cols ) )
columnPointers := make ( [ ] interface { } , len ( cols ) )
for i := range columns {
columnPointers [ i ] = & columns [ i ]
}
// Scan the result into the column pointers...
if err := rows . Scan ( columnPointers ... ) ; err != nil {
}
// Create our map, and retrieve the value for each column from the pointers slice,
// storing it in the map with the name of the column as the key.
m := make ( AssRow )
for i , colName := range cols {
val := columnPointers [ i ] . ( * interface { } )
2021-10-28 10:34:10 +02:00
m [ colName ] = fmt . Sprintf ( "%v" , * val )
2021-07-30 15:52:01 +02:00
}
results = append ( results , m )
}
2021-10-28 10:34:10 +02:00
return results , nil
2021-07-30 15:52:01 +02:00
}
// GetSchema : Provide results as an associative array
2021-10-28 10:34:10 +02:00
func ( t * TableInfo ) GetSchema ( ) ( * TableInfo , error ) {
var ti TableInfo
ti . Name = t . Name
ti . db = t . db
2021-10-30 22:21:54 +02:00
cols , err := t . db . QueryAssociativeArray ( "SELECT column_name :: varchar as name, REPLACE(REPLACE(data_type,'character varying','varchar'),'character','char') || COALESCE('(' || character_maximum_length || ')', '') as type, col_description('public." + t . Name + "'::regclass, ordinal_position) as comment from INFORMATION_SCHEMA.COLUMNS where table_name ='" + t . Name + "';" )
2021-10-28 10:34:10 +02:00
if err != nil {
log . Println ( err )
return nil , err
}
ti . Columns = make ( map [ string ] string )
2021-07-30 15:52:01 +02:00
for _ , row := range cols {
2021-10-30 22:21:54 +02:00
var name , rowtype , comment string
2021-07-30 15:52:01 +02:00
for key , element := range row {
if key == "name" {
name = fmt . Sprintf ( "%v" , element )
}
if key == "type" {
rowtype = fmt . Sprintf ( "%v" , element )
}
2021-10-30 22:21:54 +02:00
if key == "comment" {
comment = fmt . Sprintf ( "%v" , element )
}
2021-07-30 15:52:01 +02:00
}
2021-10-28 10:34:10 +02:00
ti . Columns [ name ] = rowtype
2021-10-30 22:21:54 +02:00
if comment != "<nil>" && strings . TrimSpace ( comment ) != "" {
ti . Columns [ name ] = ti . Columns [ name ] + "|" + comment
}
2021-07-30 15:52:01 +02:00
}
2021-10-28 10:34:10 +02:00
return & ti , nil
2021-07-30 15:52:01 +02:00
}
2021-12-03 11:19:53 +01:00
// GetSchema returns the column description (see TableInfo.GetSchema) of every
// table reported by ListTables, in the order ListTables returns them. The
// first error is logged and returned.
func (db *Db) GetSchema() ([]TableInfo, error) {
	tables, err := db.ListTables()
	if err != nil {
		log.Println(err.Error())
		return nil, err
	}
	var schema []TableInfo
	for _, row := range tables {
		for _, value := range row {
			table := TableInfo{Name: fmt.Sprintf("%v", value), db: db}
			described, err := table.GetSchema()
			if err != nil {
				log.Println(err.Error())
				return nil, err
			}
			schema = append(schema, *described)
		}
	}
	return schema, nil
}
2021-10-28 10:34:10 +02:00
func ( db * Db ) ListTables ( ) ( Rows , error ) {
2021-11-09 13:49:52 +01:00
return db . QueryAssociativeArray ( "SELECT table_name :: varchar as name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name;" )
2021-07-30 15:52:01 +02:00
}
2021-10-28 10:34:10 +02:00
func ( db * Db ) CreateTable ( t TableInfo ) error {
t . db = db
2021-07-30 15:52:01 +02:00
query := "create table " + t . Name + " ( "
columns := ""
for name , rowtype := range t . Columns {
if fmt . Sprintf ( "%v" , name ) == "id" {
columns += fmt . Sprintf ( "%v" , name ) + " " + "SERIAL PRIMARY KEY,"
} else {
2021-10-30 22:21:54 +02:00
desc := strings . Split ( fmt . Sprintf ( "%v" , rowtype ) , "|" )
columns += fmt . Sprintf ( "%v" , name ) + " " + desc [ 0 ]
2021-07-30 15:52:01 +02:00
columns += ","
}
}
query += columns
query = query [ : len ( query ) - 1 ] + " )"
2021-10-28 10:34:10 +02:00
_ , err := t . db . conn . Query ( query )
2021-07-30 15:52:01 +02:00
if err != nil {
2021-10-28 10:34:10 +02:00
log . Println ( err . Error ( ) )
return err
2021-07-30 15:52:01 +02:00
}
2021-10-30 22:21:54 +02:00
for name , rowtype := range t . Columns {
desc := strings . Split ( fmt . Sprintf ( "%v" , rowtype ) , "|" )
if len ( desc ) > 1 {
query = "COMMENT ON COLUMN " + t . Name + "." + fmt . Sprintf ( "%v" , name ) + " IS '" + desc [ 1 ] + "'"
_ , err := t . db . conn . Query ( query )
if err != nil {
log . Println ( err . Error ( ) )
return err
}
}
}
2021-10-28 10:34:10 +02:00
return nil
2021-07-30 15:52:01 +02:00
}
2021-10-28 10:34:10 +02:00
func ( t * TableInfo ) DeleteTable ( ) error {
query := "drop table " + t . Name
_ , err := t . db . conn . Query ( query )
2021-07-30 15:52:01 +02:00
if err != nil {
2021-10-28 10:34:10 +02:00
log . Println ( err . Error ( ) )
return err
2021-07-30 15:52:01 +02:00
}
2021-10-28 10:34:10 +02:00
query = "drop sequence if exists sq_" + t . Name
_ , err = t . db . conn . Query ( query )
2021-07-30 15:52:01 +02:00
if err != nil {
2021-10-28 10:34:10 +02:00
log . Println ( err . Error ( ) )
return err
2021-07-30 15:52:01 +02:00
}
2021-10-28 10:34:10 +02:00
return nil
2021-07-30 15:52:01 +02:00
}
2021-10-30 22:21:54 +02:00
func ( t * TableInfo ) AddColumn ( name string , sqltype string , comment string ) error {
2021-10-28 10:34:10 +02:00
query := "alter table " + t . Name + " add " + name + " " + sqltype
rows , err := t . db . conn . Query ( query )
2021-07-30 15:52:01 +02:00
if err != nil {
log . Println ( err )
2021-10-28 10:34:10 +02:00
return err
2021-07-30 15:52:01 +02:00
}
2021-10-30 22:21:54 +02:00
if strings . TrimSpace ( comment ) != "" {
query = "COMMENT ON COLUMN " + t . Name + "." + name + " IS '" + comment + "'"
_ , err = t . db . conn . Query ( query )
if err != nil {
log . Println ( err . Error ( ) )
return err
}
}
2021-07-30 15:52:01 +02:00
defer rows . Close ( )
2021-10-28 10:34:10 +02:00
return nil
2021-07-30 15:52:01 +02:00
}
2021-10-28 10:34:10 +02:00
func ( t * TableInfo ) DeleteColumn ( name string ) error {
query := "alter table " + t . Name + " drop " + name
rows , err := t . db . conn . Query ( query )
2021-07-30 15:52:01 +02:00
if err != nil {
log . Println ( err )
2021-10-28 10:34:10 +02:00
return err
2021-07-30 15:52:01 +02:00
}
defer rows . Close ( )
2021-10-28 10:34:10 +02:00
return nil
2021-07-30 15:52:01 +02:00
}
2021-12-03 11:19:53 +01:00
// ImportSchema reads a JSON array of TableInfo from filename (the format
// written by SaveSchema) and creates each table. Errors are logged. If the
// file cannot be read or decoded, nothing is created. A table that fails to
// be created does not stop the remaining ones.
func (db *Db) ImportSchema(filename string) {
	tables, err := readSchemaFile(filename)
	if err != nil {
		log.Println(err)
		return
	}
	for _, ti := range tables {
		if err := db.CreateTable(ti); err != nil {
			log.Println(err.Error())
		}
	}
}

// ClearImportSchema reads the same JSON schema file as ImportSchema and drops
// every table it lists. Errors are logged. If the file cannot be read or
// decoded, nothing is dropped. A failed drop does not stop the remaining ones.
func (db *Db) ClearImportSchema(filename string) {
	tables, err := readSchemaFile(filename)
	if err != nil {
		log.Println(err)
		return
	}
	for _, ti := range tables {
		ti.db = db
		if err := ti.DeleteTable(); err != nil {
			log.Println(err.Error())
		}
	}
}

// readSchemaFile decodes filename as a JSON array of TableInfo.
func readSchemaFile(filename string) ([]TableInfo, error) {
	data, err := ioutil.ReadFile(filename)
	if err != nil {
		return nil, err
	}
	var tables []TableInfo
	if err := json.Unmarshal(data, &tables); err != nil {
		return nil, fmt.Errorf("parsing schema file %q: %w", filename, err)
	}
	return tables, nil
}
2021-10-28 10:34:10 +02:00
func ( db * Db ) ListSequences ( ) ( Rows , error ) {
return db . QueryAssociativeArray ( "SELECT sequence_name :: varchar FROM information_schema.sequences WHERE sequence_schema = 'public' ORDER BY sequence_name;" )
2021-07-30 15:52:01 +02:00
}
2021-10-28 10:34:10 +02:00
func ( t * TableInfo ) buildSelect ( key string , columns [ ] string , restriction string , sortkeys [ ] string , dir ... string ) string {
2021-07-30 15:52:01 +02:00
if key != "" {
columns = append ( columns , key )
}
2021-10-28 10:34:10 +02:00
query := "select " + strings . Join ( columns , "," ) + " from " + t . Name
2021-07-30 15:52:01 +02:00
if restriction != "" {
query += " where " + restriction
}
2021-11-03 15:11:48 +01:00
if len ( sortkeys ) > 0 && len ( sortkeys [ 0 ] ) > 0 {
2021-07-30 15:52:01 +02:00
query += " order by " + strings . Join ( sortkeys , "," )
}
if len ( dir ) > 0 {
query += " " + dir [ 0 ]
}
return query
}
2021-10-28 10:34:10 +02:00
func ( t * TableInfo ) Insert ( record AssRow ) ( int , error ) {
2021-07-30 15:52:01 +02:00
columns := ""
values := ""
2021-10-28 10:34:10 +02:00
t , err := t . GetSchema ( )
if err != nil {
log . Println ( err )
return - 1 , err
}
2021-07-30 15:52:01 +02:00
var id int
for key , element := range record {
2022-05-17 17:03:38 +02:00
columns += key + ","
values += FormatForSQL ( t . Columns [ key ] , element ) + ","
2021-07-30 15:52:01 +02:00
}
2022-05-19 10:49:19 +02:00
err = t . db . conn . QueryRow ( "INSERT INTO " + t . Name + "(" + removeLastChar ( columns ) + ") VALUES (" + removeLastChar ( values ) + ") RETURNING id" ) . Scan ( & id )
return id , err
2021-07-30 15:52:01 +02:00
}
2021-10-28 10:34:10 +02:00
func ( t * TableInfo ) Update ( record AssRow ) error {
2021-07-30 15:52:01 +02:00
2021-10-28 10:34:10 +02:00
t , err := t . GetSchema ( )
if err != nil {
log . Println ( err )
return err
}
2021-07-30 15:52:01 +02:00
id := ""
stack := ""
for key , element := range record {
2022-05-17 17:03:38 +02:00
if key == "id" {
id = fmt . Sprintf ( "%v" , element )
2021-07-30 15:52:01 +02:00
} else {
2022-05-17 17:03:38 +02:00
stack = stack + " " + key + " = " + FormatForSQL ( t . Columns [ key ] , element ) + ","
2021-07-30 15:52:01 +02:00
}
2022-05-17 17:03:38 +02:00
2021-07-30 15:52:01 +02:00
}
stack = removeLastChar ( stack )
2021-10-28 10:34:10 +02:00
query := ( "UPDATE " + t . Name + " SET " + stack + " WHERE id = " + id )
rows , err := t . db . conn . Query ( query )
2021-07-30 15:52:01 +02:00
if err != nil {
log . Println ( query )
log . Println ( err )
2021-10-28 10:34:10 +02:00
return err
2021-07-30 15:52:01 +02:00
}
defer rows . Close ( )
2021-10-28 10:34:10 +02:00
return nil
2021-07-30 15:52:01 +02:00
}
2021-10-28 10:34:10 +02:00
func ( t * TableInfo ) Delete ( record AssRow ) error {
2021-07-30 15:52:01 +02:00
id := ""
values := ""
for key , element := range record {
if key == "id" {
values += fmt . Sprintf ( "%v" , element ) + ","
id = removeLastChar ( values )
2023-01-30 12:16:57 +01:00
break
2021-07-30 15:52:01 +02:00
}
}
2021-10-28 10:34:10 +02:00
query := ( "DELETE FROM " + t . Name + " WHERE id = " + id )
rows , err := t . db . conn . Query ( query )
2021-07-30 15:52:01 +02:00
if err != nil {
log . Println ( query )
log . Println ( err )
2021-10-28 10:34:10 +02:00
return err
2021-07-30 15:52:01 +02:00
}
defer rows . Close ( )
2021-10-28 10:34:10 +02:00
return nil
2021-07-30 15:52:01 +02:00
}
2021-10-28 10:34:10 +02:00
func ( t * TableInfo ) UpdateOrInsert ( record AssRow ) ( int , error ) {
id := - 1
2021-07-30 15:52:01 +02:00
for key , element := range record {
if key == "id" {
sid := fmt . Sprintf ( "%v" , element )
id , _ = strconv . Atoi ( sid )
2023-01-30 12:16:57 +01:00
break
2021-07-30 15:52:01 +02:00
}
}
2021-10-28 10:34:10 +02:00
if id == - 1 {
return t . Insert ( record )
2021-07-30 15:52:01 +02:00
} else {
2021-10-28 10:34:10 +02:00
t . Update ( record )
return id , nil
2021-07-30 15:52:01 +02:00
}
}
// removeLastChar returns s without its final character (rune). An empty
// string is returned unchanged instead of causing a slice-bounds panic.
func removeLastChar(s string) string {
	r := []rune(s)
	if len(r) == 0 {
		return s
	}
	return string(r[:len(r)-1])
}
2021-10-28 15:12:53 +02:00
// GetString returns the value stored under column, formatted with "%v".
// A missing column or a nil value yields "<nil>".
func (ar *AssRow) GetString(column string) string {
	row := *ar
	return fmt.Sprintf("%v", row[column])
}

// GetInt parses GetString(column) with strconv.Atoi and ignores the parse
// error, so non-numeric text yields 0.
func (ar *AssRow) GetInt(column string) int {
	n, _ := strconv.Atoi(ar.GetString(column))
	return n
}

// GetFloat parses GetString(column) with strconv.ParseFloat (64-bit) and
// ignores the parse error, so non-numeric text yields 0.
func (ar *AssRow) GetFloat(column string) float64 {
	f, _ := strconv.ParseFloat(ar.GetString(column), 64)
	return f
}
2021-10-28 15:31:36 +02:00
// Quote returns str as a quoted PostgreSQL string literal, delegating to
// pq.QuoteLiteral. The result is suitable for splicing into SQL text where a
// bind parameter cannot be used.
func Quote(str string) string {
	quoted := pq.QuoteLiteral(str)
	return quoted
}
2021-12-03 11:19:53 +01:00
func ( db * Db ) SaveSchema ( generatedFilename string ) error {
schema , err := db . GetSchema ( )
if err != nil {
log . Println ( err )
return err
}
2022-04-20 15:22:31 +02:00
// file, _ := json.Marshal(schema)
2021-12-03 11:19:53 +01:00
file , _ := json . MarshalIndent ( schema , "" , " " )
_ = ioutil . WriteFile ( generatedFilename , file , 0644 )
return nil
}
// buildLinks infers references between tables from column names. Every
// column ending in "_id" yields a Link from the table holding it to the
// underscore-separated segment just before "_id":
// "customer_id" -> "customer", "parent_order_id" -> "order".
// A column whose segment is empty (e.g. "_id") names no table and is skipped.
// Link order follows map iteration order and is not deterministic.
func buildLinks(schema []TableInfo) []Link {
	var links []Link
	for _, ti := range schema {
		for column := range ti.Columns {
			if !strings.HasSuffix(column, "_id") {
				continue
			}
			tokens := strings.Split(column, "_")
			target := tokens[len(tokens)-2]
			if target == "" {
				continue
			}
			links = append(links, Link{Source: ti.Name, Destination: target})
		}
	}
	return links
}
// GenerateTemplate renders the html/template file templateFilename into
// generatedFilename. The template receives a struct with two fields: Tbl, the
// schema of every table (see Db.GetSchema), and Lnk, the table references
// found by buildLinks. Errors are logged and returned.
func (db *Db) GenerateTemplate(templateFilename string, generatedFilename string) error {
	schema, err := db.GetSchema()
	if err != nil {
		log.Println(err)
		return err
	}
	data := struct {
		Tbl []TableInfo
		Lnk []Link
	}{
		schema,
		buildLinks(schema),
	}
	t, err := template.ParseFiles(templateFilename)
	if err != nil {
		log.Println(err)
		return err
	}
	return executeToFile(t, generatedFilename, data)
}

// GenerateTableTemplates renders templateFilename once per table, passing
// that table's TableInfo. Output goes to outputFolder/<table name>.<extension>.
// The first error is logged and returned.
func (db *Db) GenerateTableTemplates(templateFilename string, outputFolder string, extension string) error {
	schema, err := db.GetSchema()
	if err != nil {
		log.Println(err)
		return err
	}
	// The template does not depend on the table, so parse it once.
	t, err := template.ParseFiles(templateFilename)
	if err != nil {
		log.Println(err)
		return err
	}
	for _, ti := range schema {
		path := outputFolder + string(os.PathSeparator) + ti.Name + "." + extension
		if err := executeToFile(t, path, ti); err != nil {
			return err
		}
	}
	return nil
}

// executeToFile creates (or truncates) path, executes t with data into it, and
// always closes the file. The first error is logged and returned; a failing
// Close is reported too, since some file systems only surface write errors
// there.
func executeToFile(t *template.Template, path string, data interface{}) error {
	f, err := os.Create(path)
	if err != nil {
		log.Println("create file: ", err)
		return err
	}
	if err := t.Execute(f, data); err != nil {
		f.Close()
		log.Println(err)
		return err
	}
	if err := f.Close(); err != nil {
		log.Println(err)
		return err
	}
	return nil
}
2022-05-17 17:03:38 +02:00
func FormatForSQL ( datatype string , value interface { } ) string {
2022-05-20 15:12:04 +02:00
if value == nil {
return "NULL"
}
2022-05-17 17:03:38 +02:00
strval := fmt . Sprintf ( "%v" , value )
if ! strings . Contains ( datatype , "char" ) && len ( strval ) == 0 {
return "NULL"
}
2022-05-19 10:49:19 +02:00
if strings . Contains ( datatype , "char" ) || strings . Contains ( datatype , "date" ) || strings . Contains ( datatype , "timestamp" ) {
2022-05-17 17:03:38 +02:00
return fmt . Sprint ( pq . QuoteLiteral ( strval ) )
}
return fmt . Sprint ( strval )
}