#### Source code for train.py ####


#!/usr/bin/python
# -*- coding: iso-8859-2 -*-

from rgbimg import*
from snns import krui, util
from sys import argv

ImageBoundsError = "Image width or size not divisible by 8"
InternalError = "Internal error"

class image: 

	"""Nahrává obrázek ve formátu RGBA (SGI .rgb). Čísla jsou uložena v
	řetězci ARGB"""
	def load(self, file): 
		self._data = []
		(self._width, self._height) = sizeofimage(file)
		print "Obrázek načten, rozměry:", self._width, self._height
		if self._width%8 != 0 or self._height%8 != 0: 
			raise ImageBoundsError

		rawdata = longimagedata(file)
		i = 0
		for x in rawdata: 
			x = ord(x)
			if (i+1)%4 == 0: 
				self._data.append(x)
			i += 1
	
	def store(self, file): 
		data = ""
		for x in self._data: 
			data += chr(255)+chr(x)*3
		if len(data) != self._width*self._height*4: 
			raise InternalError
		longstoimage(data, self._width, self._height, 1, file)

	def blank(self): 
		for i in xrange(len(self._data)): 
			self._data[i] = 0

	def getBlock(self, index): 
		result = []
		wb = self._width/8
		v = (index/wb)*8*self._width # nad blokem
		h = (index%wb)*8 # zleva
		for y in xrange(8): 
			for x in xrange(8): 
				result.append(self._data[v+h+x+self._width*y])
		return result

	def getBlockCount(self): 
		return len(self._data)/64

	def setBlock(self, index, block): 
		wb = self._width/8
		v = (index/wb)*8*self._width # nad blokem
		h = (index%wb)*8 # zleva
		b = 0
		for y in xrange(8): 
			for x in xrange(8): 
				self._data[v+h+x+self._width*y] = block[b]
				b += 1

	def __iter__(self): 
		self._ix = 0
		self._blocks = self.getBlockCount()
		return self

	def next(self): 
		if self._ix>=self._blocks: 
			raise StopIteration
		block = self.getBlock(self._ix)
		self._ix = self._ix+1
		return block

# pomocne funkce
def pix2real(p): 
	return p*(2.0/255.0)-1.0

def real2pix(r): 
	return int(r*128.0+127.5)

krui.setLearnFunc('BackpropBatch')
krui.setUpdateFunc('Topological_Order')
krui.setUnitDefaults(1, 0, krui.INPUT, 0, 1, 'Act_TanH', 'Out_Identity')

print "Nahrávám obrázek"

im = image()
im.load("beerfox2.rgb")
print "Velikost:", len(im._data)

print "Konstruuji síť"

vnejsi_vrstvy = 8*8
vnitrni_vrstva = 4*4

# vstupni vrstva 8x8 (64 neuronu)
pos = [0, 0, 0]
inputs = []
for i in range(1, vnejsi_vrstvy+1): 
	pos[0] = i
	num = krui.createDefaultUnit()
	inputs.append(num)
	krui.setUnitName(num, 'Input_%i'%i)
	krui.setUnitPosition(num, pos)

# skryta vrstva 4x4 (16 neuronu)
pos[1] = 2
hidden = []
for i in range(1, vnitrni_vrstva+1): 
	pos[0] = i+3
	num = krui.createDefaultUnit()
	hidden.append(num)
	krui.setUnitName(num, 'Hidden_%i'%i)
	krui.setUnitTType(num, krui.HIDDEN)
	krui.setUnitPosition(num, pos)
	krui.setCurrentUnit(num)
	for src in inputs: 
		krui.createLink(src, 0)

# vystupni vrstva 8x8 (64)
pos[1] = 4
outputs = []
for i in range(1, vnejsi_vrstvy+1): 
	pos[0] = i
	num = krui.createDefaultUnit()
	outputs.append(num)
	krui.setUnitName(num, 'Output_%i'%i)
	krui.setUnitTType(num, krui.OUTPUT)
	krui.setUnitPosition(num, pos)
	krui.setCurrentUnit(num)
	for src in hidden: 
		krui.createLink(src, 0)

print "Vytvářím vzorky pro SNNS"

krui.deleteAllPatterns()
patset = krui.allocNewPatternSet()
for block in im: 
	for i in xrange(vnejsi_vrstvy): 
		krui.setUnitActivation(inputs[i], pix2real(block[i]))
		krui.setUnitActivation(outputs[i], pix2real(block[i]))
	krui.newPattern()

krui.initializeNet(-1, 1)
krui.shufflePatterns(1)
krui.DefTrainSubPat()

pruch_1 = int(argv[1])
pruch_2 = int(argv[2])

print "Fáze učení (%d+%d průchodů)"%(pruch_1, pruch_2)

i = 0
# první fáze učení
while i<pruch_1: 
	res = krui.learnAllPatterns(0.3, 0.1)
	if not i%100:  print "Fáze 1, chyba v cyklu %d:"%i, res[0]
	i = i+1

# druhá fáze (jemnější)
i = 0
while i<pruch_2: 
	res = krui.learnAllPatterns(0.03, 0.1)
	if not i%100:  print "Fáze 2, chyba v cyklu %d:"%i, res[0]
	i = i+1

print "Rekonstruuji původní obrázek"

im.blank()

for p in xrange(krui.getNoOfPatterns()): 
	krui.setPatternNo(p+1)
	krui.showPattern(1); 
	krui.updateNet()
	block = []
	for u in xrange(64+16+1, 64*2+16+1): 
		block.append(real2pix(krui.getUnitActivation(u)))
	im.setBlock(p, block)

print "Zapisuji výsledný obrázek na disk"

im.store("compressed.rgb")

print "Vytvářím soubor vzorků pro SNNS"
krui.saveNewPatterns('image.pat', patset)
print "Vytvářím soubor sítě pro SNNS"
krui.saveNet('image.net', 'image')

# konec

[Created with py2html Ver:0.62]

Valid HTML 4.01!