Be warned, this code is really ugly.  I need some help to make it better.

This week I was in a situation where I had a sample of some numbers. I knew the valid range the numbers could lie in, I knew about 30% of the actual values in the sample, and I knew the original average and standard deviation. I wanted more information about the other values in the sample, so I sat down and derived two formulas. The first takes an average, the number of items in it, and one number that you know was used in computing it, and calculates what the average would be if that number had not been there. The second does the same for the standard deviation. The second formula is fairly insane; if someone can figure out how to simplify it, let me know. It took forever to derive it without making any mistakes. Actually, I didn't even know it was possible to keep a running standard deviation like that until I sat down and derived it.
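On the simplification: assuming the standard deviation is the population form (sum of squared deviations divided by n, which is what the n * old_stdev^2 term in newStdev implies), expanding newStdev and substituting (n - 1) * new_avg = n * old_avg - a_n collapses the variance update to new variance = (n * σ^2 - (a - μ) * (a - μ')) / (n - 1), where μ' = (n * μ - a) / (n - 1) is the new average. A minimal sketch with hypothetical names removeMean and removeVariance, plus brute-force helpers to check the identity on a concrete list:

```haskell
-- Population statistics throughout (divide by the count), matching newStdev.

-- Hypothetical name: mean of the remaining n-1 values after removing a
-- from n values whose mean is mu.
removeMean :: Double -> Double -> Double -> Double
removeMean n mu a = (n * mu - a) / (n - 1)

-- Hypothetical name: population variance of the remaining n-1 values.
-- Uses S' = S - (a - mu) * (a - mu'), where S = n * sigma^2 is the old
-- sum of squared deviations and mu' is the new mean.
removeVariance :: Double -> Double -> Double -> Double -> Double
removeVariance n mu sigma a = (n * sigma ^ 2 - (a - mu) * (a - mu')) / (n - 1)
  where
    mu' = removeMean n mu a

-- Brute-force helpers, only used to check the formulas on a concrete list.
mean :: [Double] -> Double
mean xs = sum xs / fromIntegral (length xs)

popVariance :: [Double] -> Double
popVariance xs = sum [(x - m) ^ 2 | x <- xs] / fromIntegral (length xs)
  where
    m = mean xs
```

For example, [2,4,4,4,5,5,7,9] has mean 5 and population standard deviation 2; removeVariance 8 5 2 9 agrees with popVariance [2,4,4,4,5,5,7] (both are 96/49). A negative result is exactly the "impossible value" case, since sqrt of it is NaN.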

So anyway, the program asks the user how large the original sample was, along with its average and standard deviation. Then it runs in a loop, asking the user for a value to remove from the sample set and calculating the updated values. Some values, however, can be impossible given a particular sample size, average, and standard deviation, even though they lie inside the valid range (for example, if you have two numbers in the range [0..100] and the average is 100, both numbers must be 100, so you obviously can't remove a 0). I wanted to detect when the user enters such a value; for one thing, that lets me calculate exact upper and lower bounds for the actual range of the numbers just by trying all possible values.
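The "try all possible values" idea can be written as a pure filter over the known valid range. A minimal sketch, assuming integer data and a population standard deviation; possibleValues and bounds are hypothetical names. It uses the same test the program relies on (the updated variance must not be negative), so it only rules values out: a value that survives is not guaranteed to be achievable once the remaining values are also required to stay inside the range, and the bounds can therefore be looser than the true ones.

```haskell
-- Hypothetical helper: candidates in [lo .. hi] that are NOT ruled out by a
-- negative variance after removal. Assumes integer data and a population
-- standard deviation (sum of squared deviations divided by n).
possibleValues :: Int -> Double -> Double -> (Int, Int) -> [Int]
possibleValues n mu sigma (lo, hi) = filter ok [lo .. hi]
  where
    nD = fromIntegral n :: Double
    -- Removing from a sample of size 1 (or less) would leave nothing to average.
    ok a = n > 1 && variance' >= 0
      where
        aD = fromIntegral a
        mu' = (nD * mu - aD) / (nD - 1)
        -- Population variance of the remaining n-1 values.
        variance' = (nD * sigma ^ 2 - (aD - mu) * (aD - mu')) / (nD - 1)

-- Hypothetical helper: smallest and largest surviving candidate, if any.
bounds :: Int -> Double -> Double -> (Int, Int) -> Maybe (Int, Int)
bounds n mu sigma range = case possibleValues n mu sigma range of
  [] -> Nothing
  vs -> Just (minimum vs, maximum vs)
```

With the two-numbers-averaging-100 example, possibleValues 2 100 0 (0, 100) leaves only [100]. On the other hand, for [2,4,4,4,5,5,7,9] in the range [0..10], bounds 8 5 2 (0, 10) is Just (0, 10), which shows how weak the variance test alone can be.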

I know the nature of the program means there's going to be a lot of code in the IO monad, but I still feel like my code sucks.

For one thing, I really tried to make removePoints pure: it would just accept a list as its argument, and another function would do the I/O. I imagined some kind of lazy list comprehension where, each time removePoints asked for the next item, the list would ask for the input and do all that stuff. I couldn't figure out how to make this work. Perhaps it's not even possible, since a list comprehension that asks for input when it is forced isn't pure by definition. But I still think there's a better way to do it.
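One way to get there is to split removePoints into a pure per-value step and a pure driver over a list, leaving only reading and printing in IO. A minimal sketch, assuming a hypothetical Stats record in place of the (count, stdev, avg) tuple and the population-variance update from newStdev in a shorter, algebraically equivalent form; removeValue and removePointsPure are hypothetical names, and Nothing plays the role of the NaN message.

```haskell
import Data.List (mapAccumL)

-- Hypothetical record standing in for the (count, stdev, avg) tuple.
data Stats = Stats { count :: Int, stdev :: Double, avg :: Double }
  deriving (Show, Eq)

-- Hypothetical pure step: Nothing means the value could not have been in the
-- sample (the variance would go negative, or there is nothing left to remove).
removeValue :: Int -> Stats -> Maybe Stats
removeValue a (Stats n sigma mu)
  | n <= 1 || variance' < 0 = Nothing
  | otherwise = Just (Stats (n - 1) (sqrt variance') mu')
  where
    nD = fromIntegral n :: Double
    aD = fromIntegral a
    mu' = (nD * mu - aD) / (nD - 1)
    variance' = (nD * sigma ^ 2 - (aD - mu) * (aD - mu')) / (nD - 1)

-- Hypothetical pure replacement for removePoints: takes the whole list of values
-- and returns one report line per value, threading the Stats with mapAccumL.
-- An impossible value leaves the Stats unchanged, as in the original loop.
removePointsPure :: Stats -> [Int] -> [String]
removePointsPure start = snd . mapAccumL step start
  where
    step s a = case removeValue a s of
      Nothing -> (s, "That number could not have been there!")
      Just s' ->
        ( s'
        , "New count: " ++ show (count s') ++ ", new stdev: " ++ show (stdev s')
            ++ ", new avg: " ++ show (avg s')
        )
```

Because mapAccumL produces its output list lazily, the input list can itself be lazy, even infinite. That is where the "list that asks for input when forced" idea actually exists in Haskell: getContents returns a String whose characters are read from stdin only as they are demanded (lazy I/O), so the pure function never has to know that IO is involved.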

I also really dislike writing manual recursion to get input. There's got to be a way to use lists and built-in functions to do it for me, maybe with some kind of takeWhile.
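For the input side, the explicit recursion can be replaced by reading stdin lazily, cutting the list at the first empty line with takeWhile, and threading the state with foldM_ from Control.Monad. A minimal sketch; inputLines and forEachLine are hypothetical names:

```haskell
import Control.Monad (foldM_)
import System.IO (BufferMode (NoBuffering), hSetBuffering, stdout)

-- Hypothetical helper: everything typed up to the first empty line.
-- takeWhile replaces the "if the line is empty, stop recursing" case.
inputLines :: String -> [String]
inputLines = takeWhile (not . null) . lines

-- Hypothetical driver: run an IO step on each input line, threading a state
-- through with foldM_. getContents reads stdin lazily, so each line is only
-- read when foldM_ gets to it.
forEachLine :: (s -> String -> IO s) -> s -> IO ()
forEachLine step start = do
  hSetBuffering stdout NoBuffering -- so prompts without a newline show up
  input <- getContents
  foldM_ step start (inputLines input)
```

With this, the body of removePoints minus its recursive calls becomes the step: it takes (count, stdev, avg) and one line and returns the tuple to continue with. Since the next line is read only when foldM_ demands it, the prompt for the next value would be printed at the end of the step; prompt/read interleaving under lazy I/O can be subtle, so it is worth trying interactively.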

Anyway, if anyone has suggestions for making this code look better, more elegant, or more Haskelly, let me know :)


```haskell
{-# LANGUAGE ScopedTypeVariables #-}  -- allows the "num_values::Int <- ..." pattern signatures in main

import Control.Arrow ((&&&))
import System.IO (hFlush, stdout)

-- Keep prompting until the line parses as a value of the requested type.
promptUntilValid :: (Read a, Show a, Eq a) => String -> IO a
promptUntilValid prompt = do
    putStr prompt
    putStr " "
    hFlush stdout  -- stdout is usually line-buffered, so flush before getLine blocks
    result <- getLine
    case (reads result) of
        [(parsed,_)] -> return parsed
        invalid -> do
            putStrLn $ "Invalid input!  You entered " ++ result ++ ", and results length is " ++ (show (length invalid))
            promptUntilValid prompt

-- Standard deviation after removing removeValue. Treats the stdev as the population
-- form (sum of squared deviations divided by the count); an impossible removal makes
-- the variance negative, and sqrt turns that into NaN.
newStdev :: Int -> (Int,Int,Double,Double,Double) -> Double
newStdev removeValue (old_count, new_count, old_stdev, old_avg, new_avg) =
    sqrt.(/ (n-1.0)) $
        n * σ_x^2 + 2.0*old_avg*new_avg*(n-1.0) - (n-1.0)*old_avg^2 - (a_n - old_avg)^2 - (n-1.0)*new_avg^2
    where
        σ_x = old_stdev
        n = fromIntegral old_count
        a_n = fromIntegral removeValue

-- Remove one value from (count, stdev, avg) and return the updated triple.
removeSinglePoint :: Int -> (Int,Double,Double) -> (Int,Double,Double)
removeSinglePoint value (count,stdev,avg) =
    (new_count,new_stdev,new_avg)
    where new_count = count-1
          new_avg = (avg*(fromIntegral count) - (fromIntegral value)) / (fromIntegral new_count)
          new_stdev = newStdev value (count, new_count, stdev, avg, new_avg)

-- Interactive loop: read values to remove until the user enters an empty line.
removePoints :: (Int,Double,Double) -> IO ()
removePoints (count,stdev,avg) = do
    putStr "Enter a number to remove from the sample set (Enter to stop): "
    putStr " "
    hFlush stdout
    result <- getLine
    case (result,reads result) of
        ("",_) -> return ()
        (_, [(parsed,_)]) -> do
            if (any (uncurry (||) . (isInfinite &&& isNaN)) [new_stdev, new_avg])
             then do
                putStrLn "That number could not have been there!  The new values are NaN!"
                removePoints (count,stdev,avg)
             else do
                putStrLn $ "New count: " ++ (show new_count) ++ ", new stdev: " ++ (show new_stdev) ++ ", new avg: " ++ (show new_avg)
                removePoints (new_count,new_stdev,new_avg)
            where (new_count,new_stdev,new_avg) = removeSinglePoint parsed (count,stdev,avg)
        (_,invalid) -> do
            putStrLn $ "Invalid input!  You entered " ++ result ++ ", and results length is " ++ (show (length invalid))
            removePoints (count,stdev,avg)

main :: IO ()
main = do
    num_values::Int <- promptUntilValid "Enter the initial number of values: "
    stdev::Double <- promptUntilValid "Enter the initial standard deviation: "
    average::Double <- promptUntilValid "Enter the initial average: "

    putStrLn ("num_values = " ++ (show num_values))
    putStrLn ("stdev = " ++ (show stdev))
    putStrLn ("average = " ++ (show average))

    removePoints (num_values, stdev, average)

    return ()
```
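Two smaller cleanups for the code above, as a sketch. readMaybe from Text.Read only succeeds when the entire line parses (apart from surrounding whitespace), whereas the [(parsed,_)] pattern on reads also accepts a line like 12abc as 12 with leftover input. hFlush stdout makes sure the prompt is visible before getLine blocks, since stdout is usually line-buffered when attached to a terminal. And the NaN/Infinity check does not need Control.Arrow. promptRead and isBad are hypothetical names:

```haskell
import System.IO (hFlush, stdout)
import Text.Read (readMaybe)

-- Hypothetical replacement for promptUntilValid: readMaybe rejects trailing
-- garbage, so only a line that is entirely a valid value is accepted.
promptRead :: Read a => String -> IO a
promptRead prompt = do
  putStr prompt
  hFlush stdout -- show the prompt before getLine blocks
  line <- getLine
  case readMaybe line of
    Just x -> return x
    Nothing -> do
      putStrLn ("Invalid input! You entered " ++ line)
      promptRead prompt

-- Hypothetical helper: the same test as uncurry (||) . (isInfinite &&& isNaN).
isBad :: Double -> Bool
isBad x = isNaN x || isInfinite x
```

In main, the pattern signatures can then be avoided by annotating the call instead, e.g. num_values <- promptRead "Enter the initial number of values: " :: IO Int, which does not need ScopedTypeVariables.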