Implementing the Closest Pair Algorithm in OpenCL and WebCL

I was looking for a fun geometry problem to solve and came across the closest pair of points problem.

The problem statement is:
Given a set of points, in 2D, compute the closest pair of points.

The brute force algorithm takes O(n^2) time. There’s a better solution described at the Wikipedia page Closest pair of points problem, which takes O(n \log n) time.
As far as I know, there aren’t any posted solutions that use OpenCL to compute the closest pair, so I’ve implemented one. It’s posted at closest-pair.
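
For reference, the brute-force baseline is short enough to write out. The sketch below is plain JavaScript rather than the OpenCL code in the repository; the function name and the [x, y] array representation of points are assumptions made for illustration.

```javascript
// Brute-force closest pair: compare every pair of points, O(n^2) distance checks.
// Points are assumed to be [x, y] arrays; this is an illustrative sketch,
// not the OpenCL implementation.
function closestPairBruteForce(points) {
  let best = { distance: Infinity, pair: null };
  for (let i = 0; i < points.length; i++) {
    for (let j = i + 1; j < points.length; j++) {
      const dx = points[i][0] - points[j][0];
      const dy = points[i][1] - points[j][1];
      const distance = Math.hypot(dx, dy);
      if (distance < best.distance) {
        best = { distance, pair: [points[i], points[j]] };
      }
    }
  }
  return best;
}
```

With fewer than two points the function returns a distance of Infinity and a null pair.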

There are a few challenges in adapting the algorithm for OpenCL. First, OpenCL kernels can’t use recursion, so the recursive algorithm must be converted to an iterative one. In this case it’s not too complicated, because the top-down recursive algorithm converts easily to a bottom-up parallel algorithm. There are some tricky issues when the number of points is not a power of 2; I’ve commented the code for those cases.
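
To show the shape of the bottom-up conversion, here is a serial sketch in JavaScript. It is not the OpenCL kernels from the repository, and the names and structure are assumptions made for illustration. Points are sorted by x, and each pass merges adjacent blocks of `width` points. The blocks within a pass are independent of one another, which is the part a parallel version would spread across work-items. The Math.min clamps handle point counts that are not a power of 2.

```javascript
function dist(a, b) {
  return Math.hypot(a[0] - b[0], a[1] - b[1]);
}

// Bottom-up (non-recursive) closest pair distance for [x, y] points.
function closestPairDistanceBottomUp(points) {
  const n = points.length;
  if (n < 2) return Infinity;
  const pts = points.slice().sort((a, b) => a[0] - b[0]);

  // best[lo] is the smallest distance found inside the block that starts at lo.
  let best = new Array(n).fill(Infinity);
  for (let width = 1; width < n; width *= 2) {
    const next = new Array(n).fill(Infinity);
    // Each iteration of this loop is independent of the others.
    for (let lo = 0; lo < n; lo += 2 * width) {
      const mid = Math.min(lo + width, n);     // start of the right half (clamped)
      const hi = Math.min(lo + 2 * width, n);  // end of the merged block (clamped)
      let d = best[lo];
      if (mid < hi) {
        d = Math.min(d, best[mid]);
        // Collect the points within d of the dividing line, sorted by y.
        const midX = pts[mid][0];
        const strip = [];
        for (let i = lo; i < hi; i++) {
          if (Math.abs(pts[i][0] - midX) < d) strip.push(pts[i]);
        }
        strip.sort((a, b) => a[1] - b[1]);
        for (let i = 0; i < strip.length; i++) {
          for (let j = i + 1; j < strip.length && strip[j][1] - strip[i][1] < d; j++) {
            d = Math.min(d, dist(strip[i], strip[j]));
          }
        }
      }
      next[lo] = d;
    }
    best = next;
  }
  return best[0];
}
```

When a block has no right half (the leftover points at the end of a pass), its best distance is simply carried forward to the next pass.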

The second challenge is using WebCL. WebCL has an additional restriction: you can’t pass structures between Javascript and OpenCL kernels. Because of this I had to use plain arrays of uints instead of arrays of uint2, uint3, and uint4, which made the code more verbose. To reduce the verbosity I added macros in the OpenCL ckStrips kernel. The WebCL version is posted at closest_pair_ocl.html.

I hope the solution is useful to someone. Enjoy reading the code; it requires careful thought.
Check it out at closest-pair or view the WebCL one online at closest_pair_ocl.html.

Computing the Smallest Enclosing Disk

In chapter 4 of Computational Geometry by de Berg et al. I recently read about the problem of computing the smallest enclosing disk for a set of points.

I’ve shamelessly stolen the algorithm from there and done a simple conversion to Javascript.

The code is in the geometry.js file of the canvas-geolib GitHub repository, and there’s an example at enclosingdisk.html. The example initially consists of three points and the smallest disk enclosing them. Click anywhere on the canvas and a new point will be added and the disk redrawn.

I don’t want to go over the general algorithm, but I do want to explain how to compute the unique disk that has three given points on its boundary. In geometry.js the function “enclosingDisk3Points” takes three points and returns the unique disk that has those points on its boundary.

The figure below shows the two defining characteristics of the disk: (1) the disk is centered at (x,y) and (2) the distance from the center to each of the three points (p0, p1, & p2) is the same value, d.

[Figure: center point]

From this, computing (x,y) and d is simple. For simplicity we assume p_0 = (0, 0) and p_1 = (0, {p_y}), and we let p_2 = q = ({q}_x, {q}_y).

The set of equations to solve is:

d^2 = x^2 + y^2, d^2 = {(x - 0)}^2 + {(y - p_y)}^2 = x^2 + y^2 + {p_y}^2 - 2 p_y y, and d^2 = {(x - q_x)}^2 + {(y - q_y)}^2 = x^2 + y^2 + {q_x}^2 + {q_y}^2 - 2 q_x x - 2 q_y y.

Subtracting the first equation from the second leaves {p_y}^2 - 2 p_y y = 0, so y = p_y/2. Subtracting the first from the third leaves {q_x}^2 + {q_y}^2 - 2 q_x x - 2 q_y y = 0, and substituting y = p_y/2 yields x = \frac{{\|q\|}^2 - q_y p_y}{2 q_x}, where {\|q\|}^2 = {q_x}^2 + {q_y}^2. Finally, d = \sqrt{x^2 + y^2}, which finishes the computation.
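
As a quick sanity check of these formulas, the sketch below evaluates them for p_y = 4 and q = (3, 1) and compares the distance from the computed center to each point. The function name is hypothetical and is not part of geometry.js.

```javascript
// Center and radius in the normalized frame where
// p0 = (0, 0), p1 = (0, py), and p2 = q = (qx, qy).
// Assumes qx !== 0, i.e. the three points are not collinear.
function circumcircleNormalized(py, qx, qy) {
  const y = py / 2;
  const x = (qx * qx + qy * qy - qy * py) / (2 * qx);
  const d = Math.hypot(x, y); // distance from the center to p0 = (0, 0)
  return { x, y, d };
}

const disk = circumcircleNormalized(4, 3, 1);
// Distances from the center to p1 = (0, 4) and p2 = (3, 1); both should match disk.d.
const toP1 = Math.hypot(disk.x - 0, disk.y - 4);
const toP2 = Math.hypot(disk.x - 3, disk.y - 1);
console.log(disk, toP1, toP2);
```

For this input the center works out to (1, 2), and all three distances equal the square root of 5.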

The below Javascript code implements the above computation and also adds the preprocessing steps of (1) translating the origin to p_0 and (2) rotating the coordinate system so that p_1 is of the form p_1 = (0, {p_1}_y).

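// return the unique disk with p1, p2, and p3 as boundary points.
// dist(a, b) is the distance between two points and d(v) is the length of
// the vector v; both are assumed to be helpers defined elsewhere in geometry.js.
function enclosingDisk3Points(_p1, _p2, _p3){

    var p1 = [_p1[0], _p1[1]];
    var p2 = [_p2[0], _p2[1]];
    var p3 = [_p3[0], _p3[1]];
    // make p2 the farther of the two points from p1
    if (dist(p1, p3) > dist(p1, p2)){
        var p = p2;
        p2 = p3;
        p3 = p;
    }
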
    // make p1 the origin
    p2[0] = p2[0] - p1[0];
    p2[1] = p2[1] - p1[1];
    p3[0] = p3[0] - p1[0];
    p3[1] = p3[1] - p1[1];
    
    // apply rotation matrix to make p2.x = 0
    // the rotation matrix is
    // | p2[1]/dist(p2), -1 * p2[0]/dist(p2) |
    // | p2[0]/dist(p2), p2[1]/dist(p2) |
    //
    var original_p2 = [p2[0], p2[1]];
    p2[0] = 0;
    p2[1] = d(original_p2);

    // apply rotation matrix to p3
    var original_p3 = [p3[0], p3[1]];
    p3[0] = original_p2[1]/d(original_p2) * original_p3[0] - original_p2[0]/d(original_p2) * original_p3[1];
    p3[1] = original_p2[0]/d(original_p2) * original_p3[0] + original_p2[1]/d(original_p2) * original_p3[1];

    // the unique disk with the points p1, p2, and p3 as boundary points is
    // defined by the equation y = p2.y/2 & x = (d(p3)^2 - p3.y * p2.y)/(2 * p3.x)
    // (p3.x is 0 only when the three points are collinear, in which case no such disk exists)
    var y = p2[1]/2.0;
    var x = (d(p3) * d(p3) - p3[1] * p2[1])/(2 * p3[0]);

    // apply inverse of rotation matrix
    var rotated_x = original_p2[1]/d(original_p2) * x + original_p2[0]/d(original_p2) * y;
    var rotated_y = -1 * original_p2[0]/d(original_p2) * x + original_p2[1]/d(original_p2) * y;

    // translate back
    rotated_x = rotated_x + p1[0];
    rotated_y = rotated_y + p1[1];
    
    var radius = d([rotated_x - p1[0], rotated_y - p1[1]]);
    return [[rotated_x, rotated_y], radius];
    
}

Solving Tangrams Using JTS

The project, 2dfit, solves Tangram puzzles using the Java Topology Suite (JTS). The algorithm implementation is based on what I outlined in item 3 of the post http://bloggingmath.wordpress.com/2007/05/28/tangram-puzzle/.

The implementation difficulties come from floating point arithmetic, which is not robust for geometric operations. The JTS library attempts to mitigate this with a coordinate snapping technique, but for the operations used in solving Tangrams the provided snapping was not sufficient.

There’s an option in JTS to specify the snapping tolerance (it has a fairly small default). I added small wrapper functions around the Boolean operations (difference and union in the listing below) and the covers predicate. The wrappers start from the overlay snap tolerance that JTS computes and double it after each TopologyException, giving up once the tolerance reaches EPSILON/10, where EPSILON = 1e-5. The code below shows the wrapper functions (in the code, g1 is the Tangram and g2 is a puzzle piece).

    public static Geometry SemiRobustGeoOp(Geometry g1, Geometry g2, int op) throws Exception {
        double e1 = EPSILON/10;
        double snapTolerance = GeometrySnapper.computeOverlaySnapTolerance(g1, g2);
        while (snapTolerance < e1) {
            try {
                Geometry[] geos = GeometrySnapper.snap(g1, g2, snapTolerance);
                switch (op) {                    
                case DIF_OP:   // difference
                    return geos[0].difference(geos[1]);
                case UNION_OP: // union
                    return geos[0].union(geos[1]);
                default:
                    throw new Exception("unhandled semirobustgeoop: " + op);
                }
            } catch (TopologyException e){
                snapTolerance *= 2;
            }
        }
        return null;
    }

    public static boolean SemiRobustGeoPred(Geometry g1, Geometry g2, int pred) throws Exception {
        double e1 = EPSILON/10;
        double snapTolerance = GeometrySnapper.computeOverlaySnapTolerance(g1, g2);
        while (snapTolerance < e1) {
            try {
                Geometry[] geos = GeometrySnapper.snap(g1, g2, snapTolerance);
                switch (pred) {                    
                case COVER_PRED: // 
                    return geos[0].covers(geos[1]);
                default:
                    throw new Exception("unhandled semirobustgeopred: " + pred);
                }
            } catch (TopologyException e){
                snapTolerance *= 2;
            }
        }
        return false;
    }
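
Both wrappers share the same retry pattern, which can be factored out. The sketch below shows that pattern in JavaScript with a hypothetical helper name and no JTS dependency. Unlike the Java code, which retries only on TopologyException, this sketch retries on any thrown error.

```javascript
// Hypothetical helper: call operation(tolerance) with a doubling tolerance
// until it succeeds or the tolerance reaches maxTolerance.
function withGrowingTolerance(operation, startTolerance, maxTolerance) {
  if (!(startTolerance > 0)) {
    throw new Error("startTolerance must be positive, otherwise doubling never grows it");
  }
  let tolerance = startTolerance;
  while (tolerance < maxTolerance) {
    try {
      return { ok: true, value: operation(tolerance), tolerance };
    } catch (err) {
      tolerance *= 2; // the operation was not robust at this tolerance; loosen it
    }
  }
  // No tolerance below the limit worked (or the start was already at the limit).
  return { ok: false, value: null, tolerance };
}
```

Returning an explicit ok flag avoids the ambiguity in the Java predicate wrapper, where false can mean either "does not cover" or "gave up".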

Using the wrapper functions was key to a more robust implementation. The figure below shows a solved Tangram puzzle, the result of running FitTest_ToSingleLargeTriangle() in Test.java and plotting the output with gnuplot. In the figure the puzzle pieces are labeled l1.dat, l2.dat, …, l7.dat (I was lazy in naming the files).

To reduce the number of permutations tried when solving a Tangram, I used the symmetry of each puzzle piece and heuristics for choosing which puzzle piece to fit next.

For the puzzle piece symmetry, I used that the square is completely symmetric so only its first line segment needs to be used when fitting it. The triangle pieces are only partially symmetric so two of their three line segments need to be tried.

The below figure illustrates the symmetry of the square:

I used two heuristics for which pieces to try fitting: first, try larger pieces before smaller ones; second, skip a piece if another identical piece has already failed to fit (there are two identical small triangles and, similarly, two identical large triangles).
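
The two heuristics combine into a simple selection loop. The sketch below is in JavaScript with hypothetical piece records and a caller-supplied tryFit function; it illustrates the idea and is not the 2dfit code.

```javascript
// Hypothetical piece records: { name, shape, area }, where `shape` identifies
// congruent pieces (both small triangles share one shape, as do both large ones).
// tryFit(piece) is supplied by the caller and returns true when the piece fits.
function firstFittingPiece(remainingPieces, tryFit) {
  const failedShapes = new Set();
  // Heuristic 1: try larger pieces before smaller ones.
  const bySize = remainingPieces.slice().sort((a, b) => b.area - a.area);
  for (const piece of bySize) {
    // Heuristic 2: skip a piece when an identical piece has already failed.
    if (failedShapes.has(piece.shape)) continue;
    if (tryFit(piece)) return piece;
    failedShapes.add(piece.shape);
  }
  return null;
}
```

A caller would invoke this at each step of the search with the pieces not yet placed.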

With the above two optimizations it takes ~1min to solve a Tangram. Without the optimizations the algorithm did not complete for the Tangrams I tried.

River Flow Forecasting Using Support Vector Machines

Over the past few months a colleague (Brian Wallace) and I have been working on a river flow forecasting paper. A draft version is available at River Flow Paper.

The goal of our work was to beat the current forecast methods used by the Department of Water Resources for the April through July American River flow. The Department of Water Resources uses an aggregation of human judgement and linear regression equations for generating their forecasts. Given their methods they are surprisingly hard to beat!

We spent a few months trying different Machine Learning methods with little success. Many of the methods we tried resulted in forecasts that were significantly worse than the current forecasts. A few methods, such as a properly trained neural network, gave forecasts that were comparable to the current ones. Finally, I decided to use a Support Vector Machine (SVM) for producing forecasts. After testing a large combination of parameters, the forecasts started being significantly better than the current ones.

The data we used for generating forecasts is available online at https://github.com/bjwbell/California-Water-Runoff-Forecasting. The takeaway message is that we improved the forecast relative error from ~64% to ~49%. The table below shows the forecasts for the last 10 years.

SVM Forecasts 2001-2010

Year    Actual (AcreFt)    Predicted (AcreFt)    |Error| (AcreFt)
2001            552,626               689,472             136,846
2002            973,817             1,028,681              54,864
2003          1,354,434               459,476             894,957
2004            632,159               713,440              81,281
2005          2,003,878             1,844,360             159,517
2006          2,622,387             2,315,193             307,193
2007            522,651               293,256             229,394
2008            674,287               800,080             125,793
2009          1,068,327             1,253,523             185,196
2010          1,486,780             1,023,649             463,130
Mean          1,189,135             1,042,113             263,817

Root mean squared error (AcreFt): 355,856
Relative absolute error: 48.65%
Root relative squared error: 54.14%

The forecasts currently used by the Department of Water Resources produced a relative absolute error of 63.82% and a root relative squared error of 69.15%. Using modern methods for SVMs reduced the relative error by over 15 percentage points! This was a fantastic result and shows the large payoff in keeping up with the state of the art for something as ordinary as river flow forecasting.
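
The summary rows of the table can be recomputed from the yearly values. The sketch below assumes the usual definitions, in which the relative errors compare the forecast against always predicting the mean of the actual flows. The definitions are an assumption, not something stated in the paper draft, but with them the results agree with the table to within rounding.

```javascript
// Actual and predicted April-July flows (AcreFt), copied from the table above.
const actual = [552626, 973817, 1354434, 632159, 2003878, 2622387, 522651, 674287, 1068327, 1486780];
const predicted = [689472, 1028681, 459476, 713440, 1844360, 2315193, 293256, 800080, 1253523, 1023649];

function forecastErrors(actualValues, predictedValues) {
  const n = actualValues.length;
  const mean = actualValues.reduce((sum, v) => sum + v, 0) / n;
  let absErr = 0, sqErr = 0, absDev = 0, sqDev = 0;
  for (let i = 0; i < n; i++) {
    const err = predictedValues[i] - actualValues[i];
    const dev = actualValues[i] - mean; // error of always predicting the mean
    absErr += Math.abs(err);
    sqErr += err * err;
    absDev += Math.abs(dev);
    sqDev += dev * dev;
  }
  return {
    rootMeanSquaredError: Math.sqrt(sqErr / n),             // AcreFt
    relativeAbsoluteError: (100 * absErr) / absDev,         // percent
    rootRelativeSquaredError: 100 * Math.sqrt(sqErr / sqDev) // percent
  };
}

console.log(forecastErrors(actual, predicted));
```

A relative error of 100% would mean the forecast is no better than predicting the mean flow every year.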

Haskell & LLVM Talk

I gave a talk at the Linux Users’ Group of Davis yesterday evening that went very well. The Q&A and discussion afterwards were much better than at previous talks I’ve given to students; the attendees from the Bay Area Haskell Users’ Group were very sharp and had good comments.

Photos and slides are posted here and the slides are also viewable below.

The Basics of LLVM (Compiler Series Part X)

For the first pass at using LLVM for code generation we’re only generating code for simple expressions of the form “1 + 2”. This frees us from having to consider many issues, such as how to deal with the memory layout of objects. In future passes we will relax the conditions on code generation.

I advise you to read the LLVM Haskell documentation at http://hackage.haskell.org/package/llvm (it’s under the modules section) to familiarize yourself with the API. There are two parts to the LLVM Haskell bindings, a low level interface to the C API, and a high level wrapper. I tried using the high level wrapper but the type declarations confused me. I’m not that fluent in Haskell and the high level wrapper makes extensive use of type classes, monads, and a few other constructs. I decided to use the low level wrapper (Core/Util.hs) around the C API, which uses less of Haskell’s advanced type features.

With the above said, I want to dive in and create some LLVM IR code. A simple NewL program is:

class Main {
    public static void main(string[] args) {
        System.out.println("");
    }
}
class Foo 
{
   public int foo()
   { 
      return 1 + 2;
   }
}

For now I’m ignoring the “Main” class and concentrating only on:

   public int foo()
   { 
      return 1 + 2;
   }

Our plan is to traverse the parse tree and generate code at each node. The parse tree for “foo” is shown below:

To generate code with LLVM we need to know some LLVM concepts. First, we need to create a module by calling “Util.createModule”. Util is a wrapper that provides some syntactic sugar for calling the C API of LLVM (the listing at the end of this post imports it as U). Note that almost every call into LLVM returns “IO returnType”; that is, all the return types are in the IO monad. The reason is that most LLVM calls are not pure functions.

To create a function in LLVM two things are needed: the module, m, that the function will reside in, and the function type definition. The code below creates the function type definition.

getLLVMType :: TypeNames.Type -> FFI.TypeRef
getLLVMType TypeInt = FFI.int64Type
getLLVMType TypeBool = FFI.int1Type
getLLVMType _ = error("Unimplemented type")

getArgTypes :: FormalList -> [FFI.TypeRef]
getArgTypes FEmpty = []
getArgTypes (FormalList theType ident FEmpty) = [(getLLVMType theType)]
getArgTypes (FormalList theType ident rest) =
    [(getLLVMType theType)] ++ (getArgTypes rest) 

methodType = U.functionType False (getLLVMType theType) (getArgTypes formalList)

With “methodType” defined we add the function (which is so far empty) to the module, m, by calling

  method <- U.addFunction m FFI.ExternalLinkage methodName methodType

We now have an empty function called “method.” We create a builder to emit instructions and add a basic block to the method using the below code:

  bld <- U.createBuilder
  entry <- U.appendBasicBlock method "entry"
  U.positionAtEnd bld entry

With the basic block added it is time to emit the code to compute the expression “return 1+2;” We do this with the following code.

  withBuilder bld $ \bldRef ->
      do
        expVal <- codeGenExp exp bld context
        FFI.buildRet bldRef expVal 

The call to “withBuilder” takes some explaining. It takes a “Builder” object and a function of type “BuilderRef -> IO a” and returns “IO a”, where “a” is any type. We need “withBuilder” because the function “FFI.buildRet” has type “BuilderRef -> ValueRef -> IO ValueRef”. That is, we can’t pass a “Builder” object to “FFI.buildRet” directly; we have to pass a “BuilderRef” to it.

If you’re paying attention you’ll notice that only the line “FFI.buildRet bldRef expVal” actually emits any code. Emitting the code to generate the value of the expression “exp” is left to the “codeGenExp” function. That code is not interesting and you can view it in the code listing at the end of this post. You are probably interested in what kind of LLVM IR code we generate from the function. See the listing below for what it looks like.

define i64 @foo()
{
entry:
  ret i64 3
}

The “i64” symbol means that the function returns a 64 bit integer, and “ret i64 3” means return “3” as a 64 bit integer.
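
Note that the listing returns the constant 3 rather than an add instruction: both operands are constants, so LLVM’s builder folds the addition when the instruction is built. To make the traverse-and-emit idea concrete without that folding, here is a minimal sketch in JavaScript that walks a tiny expression tree and writes textual IR. The node shapes and function names are hypothetical and are not part of the NewL compiler.

```javascript
// Hypothetical expression nodes:
//   { kind: "int", value: 1 }
//   { kind: "op", op: "+", left: <node>, right: <node> }
const OPCODES = { "+": "add", "-": "sub", "*": "mul", "/": "sdiv" };

// Emits instructions for `exp` into `out` and returns the operand
// (a literal or a temporary name) that holds its value.
function codeGenExp(exp, out, counter) {
  if (exp.kind === "int") return String(exp.value);
  if (exp.kind !== "op") throw new Error("Unhandled expression kind " + exp.kind);
  const left = codeGenExp(exp.left, out, counter);
  const right = codeGenExp(exp.right, out, counter);
  const opcode = OPCODES[exp.op];
  if (!opcode) throw new Error("Unrecognized operator " + exp.op);
  const name = "%t" + counter.next++;
  out.push("  " + name + " = " + opcode + " i64 " + left + ", " + right);
  return name;
}

// Wraps the expression in a function that returns its value as an i64.
function codeGenMethod(methodName, exp) {
  const out = [];
  const result = codeGenExp(exp, out, { next: 0 });
  return ["define i64 @" + methodName + "() {", "entry:", ...out, "  ret i64 " + result, "}"].join("\n");
}

const onePlusTwo = { kind: "op", op: "+", left: { kind: "int", value: 1 }, right: { kind: "int", value: 2 } };
console.log(codeGenMethod("foo", onePlusTwo));
```

For “1 + 2” this emits an explicit add into a temporary followed by a ret of that temporary, which is the unfolded form of the listing above.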

In my next post I’ll continue exploring how to use LLVM to generate code. Specifically I want to show how to represent strings and other composite objects in LLVM.

The code listing for the entire code generation module is shown below. If you want to compile the code check it out of the github repository at http://github.com/bjwbell/NewL-Compiler and read the README file in the llvm_code_generation1 directory.

module CodeGen where
import TypeNames
import Data.Word
import Data.Int(Int32)
import Data.Typeable as T
import Foreign.C.String as CS
import System.IO.Unsafe
import LLVM.FFI.Core as FFI
import LLVM.Core.Util as U

codeGen :: Program -> IO String
codeGen (Program mainClass classDeclList) = codeGenClassDeclList classDeclList

codeGenClassDeclList :: ClassDeclList -> IO String
codeGenClassDeclList CEmpty = return "Ok"
codeGenClassDeclList (ClassDeclList classDecl classDeclList)  = 
    do
      result <- codeGenClassDecl classDecl
      -- the if expression is the last statement, so a failure is no longer discarded
      if result /= "Ok"
        then return "fail: codeGenClassDeclList"
        else codeGenClassDeclList classDeclList

codeGenClassDecl :: ClassDecl -> IO String
codeGenClassDecl (ClassDecl className varDeclList methodDeclList) = 
    do
      m <- U.createModule className
      codeGenMethodDeclList methodDeclList m [("this", className)]

codeGenMethodDeclList :: MethodDeclList -> U.Module -> [(String, String)] -> IO String 
codeGenMethodDeclList MEmpty m context = return "Ok" -- no methods so automatically successful
codeGenMethodDeclList (MethodDeclList methodDecl methodDeclList) m context =
    do
      result <- codeGenMethodDecl methodDecl m context
      if result /= "Ok"
        then return "fail: codeGenMethodDeclList"
        else codeGenMethodDeclList methodDeclList m context


getLLVMType :: TypeNames.Type -> FFI.TypeRef
getLLVMType TypeInt = FFI.int64Type
getLLVMType TypeBool = int1Type
getLLVMType _ = error("Unimplemented type")

getArgTypes :: FormalList -> [FFI.TypeRef]
getArgTypes FEmpty = []
getArgTypes (FormalList theType ident FEmpty) = [(getLLVMType theType)]
getArgTypes (FormalList theType ident rest) =
    [(getLLVMType theType)] ++ (getArgTypes rest) 

codeGenMethodDecl :: MethodDecl -> U.Module -> [(String, String)] -> IO String
codeGenMethodDecl (MethodDecl theType methodName formalList varDeclList statementList exp) m context = do
  let methodType = U.functionType False (getLLVMType theType) (getArgTypes formalList)
  method <- U.addFunction m FFI.ExternalLinkage methodName methodType
  bld <- U.createBuilder
  entry <- U.appendBasicBlock method "entry"
  U.positionAtEnd bld entry
  withBuilder bld $ \bldRef ->
      do
        expVal <- codeGenExp exp bld context
        FFI.buildRet bldRef expVal 
  FFI.dumpValue method
  return "Ok"

codeGenExp :: Exp -> U.Builder -> [(String, String)] -> IO FFI.ValueRef
codeGenExp (ExpOp exp1 char exp2) bld context = 
    do
      val1 <- codeGenExp exp1 bld context
      val2 <- codeGenExp exp2 bld context
      case char of 
        '+' -> U.withBuilder bld $ \ bldRef ->
               withCString "" $ \ cString -> buildAdd bldRef val1 val2 cString
        '-' -> U.withBuilder bld $ \ bldRef ->
               withCString "" $ \ cString -> buildSub bldRef val1 val2 cString
        '*' -> U.withBuilder bld $ \ bldRef ->
               withCString "" $ \ cString -> buildMul bldRef val1 val2 cString
        '/' -> U.withBuilder bld $ \ bldRef -> -- signed integer division
               withCString "" $ \ cString -> buildSDiv bldRef val1 val2 cString
        _ -> error ("Unrecognized operator, " ++ (show char) ++ " in ExpOp")

codeGenExp (ExpInt value) bld context = 
    do
      return (constInt int64Type (fromIntegral value) (fromIntegral 64)) 
codeGenExp (ExpBool True) bld context = 
    do
      return (constInt int1Type (fromIntegral 1) (fromIntegral 1)) 
codeGenExp (ExpBool False) bld context = 
    do
      return (constInt int1Type (fromIntegral 0) (fromIntegral 1)) 

Code Generation With LLVM (Compiler Series Part IX)

With type checking done, it’s time to generate code from the parse tree. Code generation typically consists of the following steps:
1. Generate intermediate code.
2. Generate the machine or byte code from the intermediate code.

The following are options for doing the above two tasks:

  • Generate intermediate and machine code for x86 or RISC using our own methods (tedious and error prone).
  • Generate the intermediate code and then byte code for either the JVM or CLI.
  • Outsource the problem to LLVM.

I’m a lazy fellow so I pick choice #3 of outsourcing the code generation to LLVM. This is a particularly apt choice given the exciting intersection of Haskell with LLVM recently via David Terei’s honors thesis on using LLVM as the backend for the GHC compiler. This work is ongoing at http://hackage.haskell.org/trac/ghc/wiki/Commentary/Compiler/Backends/LLVM and for an explanation of why LLVM is used as the backend see http://blog.llvm.org/2010/05/glasgow-haskell-compiler-and-llvm.html.

Bryan O’Sullivan put together bindings to LLVM at http://code.haskell.org/llvm/. Bryan and the other author, Lennart Augustsson, also posted examples of using the bindings at http://www.serpentine.com/blog/2008/01/03/llvm-bindings-for-haskell/, http://augustss.blogspot.com/2009/01/llvm-llvm-low-level-virtual-machine-is.html, and http://augustss.blogspot.com/2009/01/llvm-arithmetic-so-we-want-to-compute-x.html.

I’m reading up on LLVM and will post more as I learn more about LLVM in general and the Haskell LLVM bindings in particular.

Type Checking (Compiler Series Part VIII)

In the last post we finished adding the functions to create the symbol tables needed to start the type checking. With the symbol table functions completed it’s time to use them.

To type check a program we traverse the parse tree and check the type of each variable reference. Sounds simple, and it is for our language, once we have the right framework set up.

Our framework consists of (1) creating the class symbol table to easily look up class definitions and (2) keeping track of the current context (to track the current context we use an association list called context). We also need helper functions to outsource some tasks such as checking that a class doesn’t declare duplicate variables.
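
To illustrate how the context association list is used, here is a small sketch in JavaScript rather than Haskell. It checks an assignment by looking up the variable’s declared type and comparing it with the type of the right-hand side. All names and node shapes here are hypothetical; the real checks are in the Haskell listing.

```javascript
// The context is an array of [name, typeName] pairs, searched front to back
// so that earlier entries shadow later ones (like an association list).
function lookup(name, context) {
  const entry = context.find(([key]) => key === name);
  return entry ? entry[1] : null;
}

// Type of a tiny expression language: int literals, bool literals, identifiers.
function typeOfExp(exp, context) {
  switch (exp.kind) {
    case "int": return "int";
    case "bool": return "bool";
    case "ident": {
      const typeName = lookup(exp.name, context);
      if (typeName === null) throw new Error("undeclared identifier " + exp.name);
      return typeName;
    }
    default:
      throw new Error("unhandled expression kind " + exp.kind);
  }
}

// An assignment type checks when the variable is declared and its type
// matches the type of the right-hand side.
function typeCheckAssign(ident, exp, context) {
  const identType = lookup(ident, context);
  if (identType === null) throw new Error("undeclared identifier " + ident);
  const expType = typeOfExp(exp, context);
  if (identType !== expType) {
    throw new Error("types do not match: " + identType + " vs " + expType);
  }
  return "Ok";
}
```

For example, with the context [["x", "int"], ["this", "Foo"]], assigning an int literal to x returns "Ok", while assigning a bool literal throws.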

On a high level we (1) create the parse tree, (2) create the class symbol table, and (3) call the type check function on the parse tree. I’ve listed the source code below.

main = do 
  inStr <- getContents
  let parseTree = newl (alexScanTokens2 inStr)  
  let defaultClasses = [("int", ClassSymbol "int" [] []), ("string", ClassSymbol "string" [] []), ("bool", ClassSymbol "bool" [] [])]
  let classes = defaultClasses ++ classSymbols parseTree
  let typeCheckingResult = typeCheck parseTree classes
  if typeCheckingResult == "Ok"
     then putStrLn "Semantic Analysis Results: Passed"
     else putStrLn ("Semantic Analysis Results: Failed, " ++ typeCheckingResult)
  print "done"

For type checking the parse tree we do the following:

  • Type check the main class which consists of type checking the single statement in
    public static void main(string[] params) { statement}

    Please note that the only context in the statement is the name of the argument parameters.

  • Type check each class in the program which consists of type checking the class variables and the class methods. For checking the methods we track which class the method is declared in via creating the context association list with the entry (“this”, className) where className is the name of the current class. I’ve listed the source code below.
    typeCheckClassDecl (ClassDecl className varDeclList methodDeclList) classes =
        if typeCheckVarDeclList varDeclList classes [("this", className)] == "Ok" 
        then typeCheckMethodDeclList methodDeclList classes [("this", className)]
        else "Fail2"
    

Doing the above type checking involves traversing down the parse tree and type checking more basic productions such as expressions. I’ll let you read the code for that since the code itself is reasonably self explanatory. The code listing is shown below. You can view the complete type checking program at github.

typeCheck (Program mainClass classDeclList) classes =  
 if typeCheckMainClass mainClass classes == "Ok" 
 then typeCheckClassDeclList classDeclList classes
 else "Fail1" ++ " " ++ (typeCheckMainClass mainClass classes)

typeCheckClassDeclList CEmpty classes = "Ok"
typeCheckClassDeclList (ClassDeclList classDecl classDeclList) classes = 
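  if typeCheckClassDecl classDecl classes  == "Ok"
  then typeCheckClassDeclList classDeclList classes -- this class passed, so check the rest
  else typeCheckClassDecl classDecl classes -- propagate the failure message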

typeCheckClassDecl (ClassDecl className varDeclList methodDeclList) classes =
    if typeCheckVarDeclList varDeclList classes [("this", className)] == "Ok" 
    then typeCheckMethodDeclList methodDeclList classes [("this", className)]
    else "Fail2"

typeCheckVarDeclList VEmpty _ _ = "Ok" -- empty variable declaration list so automatically successful

typeCheckVarDeclList (VarDeclList theType ident varDeclList) classes context =
    if checkForDuplicateVarDeclarations (VarDeclList theType ident varDeclList) [] context == "Ok" && typeCheckVarDecl theType ident classes context == "Ok" 
    then typeCheckVarDeclList varDeclList classes context
    else "Fail3"

checkForDuplicateVarDeclarations VEmpty vars _ = "Ok"
checkForDuplicateVarDeclarations (VarDeclList theType ident varDeclList) [] context =
  checkForDuplicateVarDeclarations varDeclList [ident] context
checkForDuplicateVarDeclarations (VarDeclList theType ident varDeclList) vars context =
  case elem ident vars of
    True -> error ("Double declaration of " ++ ident ++ " in " ++ (show (lookup "this" context)))
    False -> checkForDuplicateVarDeclarations varDeclList (ident : vars) context

getTypeName TypeBool = "bool"
getTypeName TypeInt = "int"
getTypeName TypeString = "string"
getTypeName (TypeIdent ident) = ident
getTypeName (TypeIdentArray ident) = ident


typeCheckVarDecl theType ident classes context = 
  case theType of
    TypeBool -> "Ok"
    TypeInt -> "Ok"
    TypeString -> "Ok"
    (TypeIdent typeName) -> case lookup typeName classes of
      Just classSym -> "Ok"
      Nothing -> error ("Unknown type " ++ typeName ++ " in the var declaration list of class " ++ (show (lookup "this" context)))
    (TypeIdentArray typeNameArray) -> let tName = reverse (drop 2 (reverse typeNameArray)) in -- drops the [] at the end of the typeNameArray.
      case lookup tName classes of
        Just classSym -> "Ok"
        Nothing -> error ("Unknown type " ++ typeNameArray ++ " in the var declaration list of class " ++ (show (lookup "this" context)))

typeCheckMethodDeclList MEmpty classes context = "Ok" -- no methods so automatically successful
typeCheckMethodDeclList (MethodDeclList methodDecl methodDeclList) classes context =
    if checkForDuplicateMethodDeclarations (MethodDeclList methodDecl methodDeclList) [] context == "Ok" && typeCheckMethodDecl methodDecl classes context == "Ok" 
    then typeCheckMethodDeclList methodDeclList classes context
    else "Fail4"

checkForDuplicateMethodDeclarations MEmpty _ _ = "Ok"
checkForDuplicateMethodDeclarations (MethodDeclList methodDecl methodDeclList) methods context =
  case elem (methodName methodDecl) methods of
    False -> checkForDuplicateMethodDeclarations methodDeclList ((methodName methodDecl) : methods) context
    True -> error("Duplicate method " ++ (methodName methodDecl))
  

typeCheckMethodDecl (MethodDecl theType methodName formalList varDeclList statementList exp) classes context =    
    let context2 = (getThisClassVariables classes context) ++ (getVarDeclListVariables varDeclList classes) ++ (getFormalListVariables formalList classes) ++ context
        typeName = getTypeName theType
    in
      if typeCheckStatementList statementList classes context2 == "Ok" 
      then 
        if typeName /=  (typeCheckExp exp classes context2)
        then error ("the type of method " ++ methodName ++ " does not match the return type of " ++ (show exp))
        else "Ok"
      else 
        "Fail5"          
                

getThisClassVariables classes context =
    case lookup "this" context of
      Nothing -> error("undeclared this in context: " ++ show context)
      Just thisTypeName -> cVarTypes thisTypeName classes

cVarTypes typeName classes = 
    case lookup typeName classes of
      Nothing -> error("undeclared type: " ++ typeName)
      Just classSym -> publicVariables classSym

getVarDeclListVariables VEmpty _ = [] -- empty variable declaration list

getVarDeclListVariables (VarDeclList theType ident varDeclList) classes = 
    let typeName = getTypeName theType 
    in 
      if lookup typeName classes == Nothing || checkForDuplicateVariableDecls varDeclList [(theType, ident)] /= "Ok"
      then error("unknown type " ++ typeName ++ " for variable " ++ ident)
      else (ident, getTypeName theType) : getVarDeclListVariables varDeclList classes

checkForDuplicateVariableDecls :: VarDeclList -> [(Type, Ident)] -> String
checkForDuplicateVariableDecls VEmpty _ = "Ok"
checkForDuplicateVariableDecls (VarDeclList theType ident varDeclList) varList =
  case elem (theType, ident) varList of
    False -> checkForDuplicateVariableDecls varDeclList ((theType, ident) : varList)
    True -> error("duplicate declaration of " ++ ident ++ " in " ++ show(varDeclList)) 



getFormalListVariables FEmpty classes = []
getFormalListVariables (FormalList theType ident formalList) classes = 
    let typeName = getTypeName theType
    in
      if lookup typeName classes == Nothing || checkForDuplicateFormalListVariables formalList [(theType, ident)] /= "Ok"
      then error("unknown type " ++ typeName ++ " for variable " ++ ident)
      else (ident, getTypeName theType) : getFormalListVariables formalList classes
    
checkForDuplicateFormalListVariables :: FormalList -> [(Type, Ident)] -> String
checkForDuplicateFormalListVariables FEmpty _ = "Ok"
checkForDuplicateFormalListVariables (FormalList theType ident formalList) fList = case elem (theType, ident) fList of
  False -> checkForDuplicateFormalListVariables formalList ([(theType, ident)] ++ fList)
  True -> error("duplicate declaration of " ++ ident ++ " in " ++ show(fList))
      

typeCheckMainClass (MClass className paramName statement) classes = if (lookup className classes == Nothing)
                                                                    then error("Error " ++ className ++ " is not a declared class")
                                                                    else typeCheckStatement statement classes [("this", className), (paramName, "string[]")]



typeCheckStatementList Empty classes context = "Ok"
typeCheckStatementList (StatementList statementList statement) classes context =
    if typeCheckStatement statement classes context == "Ok" && typeCheckStatementList statementList classes context == "Ok"
    then "Ok"
    else "Fail6"

typeCheckStatement (SList statementList) classes context = 
    typeCheckStatementList statementList classes context
typeCheckStatement (SIfElse exp1 statement1 statement2) classes context =
    if (typeCheckExp exp1 classes context) == "bool" && typeCheckStatement statement1 classes context == "Ok" && typeCheckStatement statement2 classes context == "Ok"
    then "Ok"
    else error ("Error in if else statement")

typeCheckStatement (SWhile exp statement) classes context = 
      if typeCheckExp exp classes context == "bool" 
      then typeCheckStatement statement classes context
      else error("Error type of " ++ show(exp) ++ " is not bool in while statement")


typeCheckStatement (SEqual ident exp1) classes context = 
    let identType = lookup ident context in
    case identType of 
         Nothing -> error("Error undeclared identifier " ++ ident ++ " in equal statement")
         Just iType -> if iType == typeCheckExp exp1 classes context
                     then "Ok"
                     else error("Error types do not match in equals statements, type1 " ++ iType ++ " type2 " ++ (typeCheckExp exp1 classes context))


typeCheckStatement (SPrint exp) classes context = if typeCheckExp exp classes context /= "" then "Ok" else "Fail"
typeCheckStatement (SArrayEqual ident exp1 exp2) classes context = 
  case lookup ident context of
    Nothing -> error("Error undeclared identifier " ++ ident ++ " in array assignment")
    -- an array type name ends in "[]", so its reversed form starts with "]["
    Just iType -> if take 2 (reverse iType) /= "][" then
                    error(ident ++ " is not an array")
                  else 
                    -- strip the trailing "[]" to get the element type name
                    let baseTypeName = reverse (drop 2 (reverse iType)) in
                    -- exp1 is the index (must be int), exp2 is the assigned value
                    if baseTypeName /= typeCheckExp exp2 classes context || typeCheckExp exp1 classes context /= "int" then
                      error("Error, can't assign to array")
                    else
                      "Ok"

-- some helper functions for typeCheckExp

expTypes ExpListEmpty classes context = []
expTypes (ExpList exp expRest) classes context = (typeCheckExp exp classes context) : expRestTypes expRest classes context
expTypes (ExpListExp exp) classes context = [typeCheckExp exp classes context]
    
expRestTypes (ExpRest exp) classes context = [typeCheckExp exp classes context]
        

checkFunctionCall (ClassSymbol cName vars []) methodName [] = 
    error (methodName ++ " is not a method of class " ++ cName)

checkFunctionCall (ClassSymbol cName var methods) methodName methodTypes = 
    let method = lookup methodName methods in
    case method of 
         Just theMethod -> checkMethodTypes theMethod methodTypes
         Nothing -> error (methodName ++ " is not a method of " ++ cName)

checkMethodTypes (MethodSymbol returnType name []) [] = returnType

checkMethodTypes (MethodSymbol returnType name []) methodTypes = error("method " ++ name ++ " doesn't take any arguments but arguments of type " ++ show(methodTypes) ++ " provided")

checkMethodTypes (MethodSymbol returnType name args) [] = error("method " ++ name ++ " takes arguments but no arguments provided")
 
checkMethodTypes (MethodSymbol returnType name ((argName, argType) : args)) (type1 : types)  = 
    if argType == type1 
    then checkMethodTypes (MethodSymbol returnType name args) types 
    else error ("method " ++ name ++ " argument type mismatch " ++ " expected type " ++ argType ++ " but got type " ++ type1)




-- the typeCheckExp function returns the type name of the expression
typeCheckExp (ExpOp exp1 char exp2) classes context = 
    if typeCheckExp exp1 classes context == "int" && typeCheckExp exp2 classes context == "int"
       then "int"
       else error ("one of the expressions exp1:" ++ show(exp1) ++ " exp2:" ++ show(exp2) ++ " is not an integer \n exp1 type: " ++ (typeCheckExp exp1 classes context) ++ "\n exp2 type: " ++ (typeCheckExp exp2 classes context))


typeCheckExp (ExpComOp exp1 char exp2) classes context = 
    if typeCheckExp exp1 classes context == "int" && typeCheckExp exp2 classes context == "int"
       then "bool"
       else error ("one of the expressions exp1:" ++ show(exp1) ++ " exp2:" ++ show(exp2) ++ " is not an integer \n exp1 type: " ++ (typeCheckExp exp1 classes context) ++ "\n exp2 type: " ++ (typeCheckExp exp2 classes context))


typeCheckExp (ExpArray exp1 exp2) classes context  =  -- "Exp [ Exp ]"
  if typeCheckExp exp2 classes context /= "int" then
    error("Error in ExpArray: index is not an int")
  else
    -- an array type name ends in "[]", so its reversed form starts with "]["
    if take 2 (reverse (typeCheckExp exp1 classes context)) /= "][" then
      error("Error in ExpArray: expression is not an array")
    else
      reverse(drop 2 (reverse (typeCheckExp exp1 classes context)))
    
  
typeCheckExp (ExpFCall exp ident expList) classes context =   
      let className = typeCheckExp exp classes context -- Exp . Ident ( ExpList )
          classSym = lookup className classes
          expListTypes = expTypes expList classes context 
      in case classSym of
           Just x -> (checkFunctionCall x ident expListTypes)
           Nothing -> error ("Undeclared class " ++ className ++ " in function call")

typeCheckExp (ExpInt int) classes context = "int"

typeCheckExp (ExpNewIntArray exp) classes context = 
    if typeCheckExp exp classes context == "int"
       then "int[]"
       else error ("Error new int[exp] the expression type is not an integer")


typeCheckExp (ExpNewBoolArray exp) classes context = 
    if typeCheckExp exp classes context == "int"
       then "bool[]"
       else error ("Error new bool[exp] the expression type is not an integer")


typeCheckExp (ExpNewStringArray exp) classes context = 
    if typeCheckExp exp classes context == "int"
       then "string[]"
       else error ("Error new string[exp] the expression type is not an integer")

typeCheckExp (ExpBool bool) classes context  = "bool" -- True or False


typeCheckExp (ExpIdent ident) classes context = 
    case lookup ident context of
      Just x -> x
      Nothing -> error ("Error " ++ ident ++ " is not a declared variable, context " ++ show context)
                                                       
typeCheckExp (ExpNewIdent ident) classes context = 
    if lookup ident classes == Nothing
    then error ("Error " ++ ident ++ " is not a declared class" ++ ", context " ++ show context)
    else ident

typeCheckExp (ExpNewIdentArray ident exp) classes context = 
    if lookup ident classes == Nothing || typeCheckExp exp classes context /= "int"
    then error ("Error " ++ ident ++ " is not a declared class or " ++ show(exp) ++ " is not an int" ++ ", context " ++ show context)
    else ident ++ "[]"

typeCheckExp (ExpExp exp) classes context  = typeCheckExp exp classes context -- Exp ( Exp )

typeCheckExp (ExpThis) classes context =
    let thisSym = lookup "this" context in
    case thisSym of 
      Just sym -> sym
      Nothing -> error ("this symbol undeclared")

typeCheckExp (ExpNot exp) classes context = 
    if typeCheckExp exp classes context == "bool"
       then "bool"
       else error "wrong type for !exp, expecting bool type"

typeCheckExp (ExpLength exp) classes context =
    if typeCheckExp exp classes context == "int[]"
           then "int"
           else error ("Error in " ++ show(exp) ++ ".length the expression is not of type int[] ")


main = do 
  inStr <- getContents
  let parseTree = newl (alexScanTokens2 inStr)  
  putStrLn ("parseTree: " ++ show(parseTree))
  let defaultClasses = [("int", ClassSymbol "int" [] []), ("string", ClassSymbol "string" [] []), ("bool", ClassSymbol "bool" [] [])]
  let classes = defaultClasses ++ classSymbols parseTree
  putStrLn "classes " 
  print classes
  let typeCheckingResult = typeCheck parseTree classes
  if typeCheckingResult == "Ok"
     then putStrLn "Semantic Analysis Results: Passed"
     else putStrLn ("Semantic Analysis Results: Failed, " ++ typeCheckingResult)
  
  putStrLn ("parseTree: " ++ show(parseTree))
  print "done"
}
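
The array cases in the type checker above treat any type name that ends in "[]" as an array type and strip that suffix to get the element type. Here is a minimal sketch of the same idea using `isSuffixOf` from `Data.List`. The helper names `isArrayType` and `elementType` are my own for illustration and are not part of the NewL source.

```haskell
import Data.List (isSuffixOf)

-- Hypothetical helpers (not in the original type checker).
-- A type name is an array type when it ends in "[]", e.g. "int[]".
isArrayType :: String -> Bool
isArrayType typeName = "[]" `isSuffixOf` typeName

-- Element type name of an array type name, or Nothing when the
-- type name is not an array type.
elementType :: String -> Maybe String
elementType typeName
  | isArrayType typeName = Just (take (length typeName - 2) typeName)
  | otherwise            = Nothing
```

With these, `elementType "int[]"` is `Just "int"` and `elementType "int"` is `Nothing`, which avoids reversing the string twice.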

Cleaning up NewL and Adding a Symbol Table Part II (Compiler Series Part VII)

Type checking has several stages, so I’m breaking it up into a few posts to cover it more fully. To type check a program we need to:

  1. Create the symbol table for the classes. This symbol table is an association list where the key is the class name and the value is a “ClassSymbol”.
  2. For each class, create the symbol table for the class methods, which is again an association list where the key is the method name and the value is a “MethodSymbol”.
  3. For each class, create the symbol table of the class instance variables. This symbol table is an association list where the key is the variable name and the value is the variable type name. For example, a class with one instance variable named “count” of type “int” has [("count", "int")] as its symbol table.
  4. Type check the statements and expressions within the methods. (future post)
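
To make the three tables concrete, here is a rough sketch of how they might fit together. The real definitions of ClassSymbol and MethodSymbol are in the earlier “Adding a Symbol Table” post, so the type definitions below are my assumption of their shape, and the `Counter` class, `classTable` and `instanceVarType` are made up for illustration.

```haskell
-- Assumed shapes of the symbol types; the actual definitions live in
-- the "Adding a Symbol Table" post and may differ.
type TypeName = String

data MethodSymbol = MethodSymbol
  { returnType :: TypeName
  , name       :: String
  , args       :: [(String, TypeName)]  -- (argument name, type name)
  } deriving (Show, Eq)

data ClassSymbol = ClassSymbol
  String                    -- class name
  [(String, TypeName)]      -- instance variables: (variable name, type name)
  [(String, MethodSymbol)]  -- methods: (method name, method symbol)
  deriving (Show, Eq)

-- Hypothetical class with one instance variable "count" of type "int"
-- and one method "next" that returns an int.
counterClass :: ClassSymbol
counterClass =
  ClassSymbol "Counter"
    [("count", "int")]
    [("next", MethodSymbol { returnType = "int", name = "next", args = [] })]

-- The class symbol table: class name -> ClassSymbol.
classTable :: [(String, ClassSymbol)]
classTable = [("Counter", counterClass)]

-- Look up the type name of an instance variable of a class.
instanceVarType :: String -> String -> [(String, ClassSymbol)] -> Maybe TypeName
instanceVarType className varName classes = do
  ClassSymbol _ vars _ <- lookup className classes
  lookup varName vars
```

Since every table is an association list, `lookup` from the Prelude is all that is needed to query them: `instanceVarType "Counter" "count" classTable` gives `Just "int"`.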

In looking over our language NewL I realized that arrays and strings are not handled consistently. In particular, a string was declared with a capital “S” in “String”, whereas “int” and “boolean” are all lowercase. I didn’t like this since it makes the declaration syntax inconsistent for built-in types. The second issue is that it is only possible to declare arrays of integers. I want to make arrays an integral feature of the language, so arrays should be allowed for any type. With those two changes in mind I’ve updated the grammar of NewL. The updated grammar is shown below. I’ve modified the lexer and parser where needed to reflect the update. You can get both here.

Terminal symbols are bolded and single character ones are in double quotes.

NewL Grammar

1. Program ::= MainClass { ClassDecl }
2. MainClass ::= class Ident “{” public static void main “(” string[] Ident “)” “{” Statement “}” “}”
3. ClassDecl ::= class Ident “{” { VarDecl } { MethodDecl } “}”
4. VarDecl ::= Type Ident “;”
5. MethodDecl ::= public Type Ident “(” FormalList “)” “{” { VarDecl } { Statement } return Exp “;” “}”
6. FormalList ::= Type Ident { FormalRest } | ε
7. FormalRest ::= “,” Type Ident
8. Type ::= int | bool | string | Ident | IdentArray
9. Statement ::= “{” { Statement } “}” | if “(” Exp “)” Statement else Statement | while “(” Exp “)” Statement | System.out.println “(” Exp “)” “;” | Ident “=” Exp “;” | Ident “[” Exp “]” “=” Exp “;”
10. Exp ::= Exp Op Exp | Exp “[” Exp “]” | Exp “.” length | Exp “.” Ident “(” ExpList “)” |
Integer_Literal | String_Literal | true | false | Ident | this | new Ident “[” Exp “]” | new Ident “(” “)” | “!” Exp | “(” Exp “)”
11. ExpList ::= Exp { ExpRest } | ε
12. ExpRest ::= “,” Exp
13. Ident ::= Letter { Letter | Digit | “_” }
14. IdentArray ::= Letter { Letter | Digit | “_” } “[” “]”
15. Integer_Literal ::= Digit { Digit }
16. String_Literal ::= “"” { Character } “"”
17. Op ::= && | < | + | - | *

Revised Hello World for NewL.

class HelloWorld {
    public static void main(string[] a) {
        System.out.println(new Hello().World());
    }
}
class Hello {
    public string World() {
        return "Hello World";
    }
}

With the revised NewL language, we can create the symbol table for the classes. Please note that each ClassSymbol structure itself contains a symbol table for its instance variables and one for its methods. The code below contains the functions for creating the symbol table of classes. There are several helper functions for creating the symbol table for the methods of each class and the symbol table of the instance variables of each class. Please refer back to the post on “Adding a Symbol Table” for the definition of ClassSymbol and MethodSymbol. With the symbol tables created we will be able to start the type checking.

classSymbols (Program mainClass classDeclList) = classSymbolMainClass mainClass : classSymbolsClassDeclList classDeclList

classSymbolMainClass (MClass className paramName statement) =
                      (className, (ClassSymbol className [] [
                                                            ("main", 
                                                                    -- args are (argument name, type name) pairs, as in argSymbols
                                                                    (MethodSymbol {returnType = "void", name = "main", args = [(paramName, "string[]")]})
                                                            )]
                                  )
                      )
classSymbolClassDecl (ClassDecl className parentClassName varDecls methodDecls) = (className, (ClassSymbol className (varSymbols varDecls) (methodSymbols methodDecls)))
classSymbolsClassDeclList (ClassDeclList classDecl classDeclList) = classSymbolClassDecl classDecl : classSymbolsClassDeclList classDeclList
classSymbolsClassDeclList (CEmpty) = []

varSymbols VEmpty = []
varSymbols (VarDeclList theType ident varDeclList) = varSymbol theType ident : varSymbols varDeclList

varSymbol (TypeString) ident = (ident, "string")
varSymbol (TypeBool) ident = (ident, "bool")
varSymbol (TypeInt) ident = (ident, "int")
varSymbol (TypeIdent identType) ident = (ident, identType)
varSymbol (TypeIdentArray identType) ident = (ident, identType)

methodSymbols MEmpty = []
methodSymbols (MethodDeclList methodDecl methodDeclList) = methodSymbol methodDecl : methodSymbols methodDeclList

methodSymbol (MethodDecl theType ident formalList varDeclList statementList exp)
             = case theType of
                    TypeInt -> (ident, MethodSymbol {returnType = "int", name = ident, args = (argSymbols formalList)})
                    TypeBool -> (ident, MethodSymbol {returnType = "bool", name = ident, args = (argSymbols formalList)})
                    TypeString -> (ident, MethodSymbol {returnType = "string", name = ident, args = (argSymbols formalList)})
                    TypeIdent classType -> (ident, MethodSymbol {returnType = classType, name = ident, args = (argSymbols formalList)})
                    TypeIdentArray classType -> (ident, MethodSymbol {returnType = classType, name = ident, args = (argSymbols formalList)})

argSymbols FEmpty = []
argSymbols (FormalList theType ident formalList) =
           case theType of
                TypeInt -> (ident, "int") : argSymbols formalList
                TypeBool -> (ident, "bool") : argSymbols formalList
                TypeString -> (ident, "string") : argSymbols formalList
                TypeIdent classType -> (ident, classType) : argSymbols formalList
                TypeIdentArray classType -> (ident, classType) : argSymbols formalList
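
As a sanity check, here is roughly the table I would expect for the revised Hello World program, written out by hand. This is only a sketch: the ClassSymbol and MethodSymbol definitions are my assumption of what the “Adding a Symbol Table” post declares, arguments are assumed to be stored as (name, type) pairs as `argSymbols` does, `World` is assumed to be declared with the lowercase `string` return type, and `helloWorldTable` and `methodReturnType` are hypothetical names that are not in the NewL source.

```haskell
-- Assumed shapes of the symbol types (see the "Adding a Symbol Table" post
-- for the real ones).
data MethodSymbol = MethodSymbol
  { returnType :: String
  , name       :: String
  , args       :: [(String, String)]  -- (argument name, type name)
  } deriving (Show, Eq)

data ClassSymbol = ClassSymbol
  String                    -- class name
  [(String, String)]        -- instance variables: (variable name, type name)
  [(String, MethodSymbol)]  -- methods: (method name, method symbol)
  deriving (Show, Eq)

-- Hand-written table for the revised Hello World program.
helloWorldTable :: [(String, ClassSymbol)]
helloWorldTable =
  [ ( "HelloWorld"
    , ClassSymbol "HelloWorld" []
        [("main", MethodSymbol { returnType = "void", name = "main", args = [("a", "string[]")] })]
    )
  , ( "Hello"
    , ClassSymbol "Hello" []
        [("World", MethodSymbol { returnType = "string", name = "World", args = [] })]
    )
  ]

-- Return type of a method, found by looking up the class first and
-- then the method inside that class's method table.
methodReturnType :: String -> String -> [(String, ClassSymbol)] -> Maybe String
methodReturnType className methodName classes = do
  ClassSymbol _ _ methods <- lookup className classes
  fmap returnType (lookup methodName methods)
```

For the call `new Hello().World()` the type checker needs exactly this two-step lookup: `methodReturnType "Hello" "World" helloWorldTable` gives `Just "string"`.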

You can download the source code here. Look in the Newl.y file for the parser with the symbol table added.